diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..333595863b98eb3393d74a1ac3144751a86329c6
--- /dev/null
+++ b/app.py
@@ -0,0 +1,181 @@
+import os
+import gradio as gr
+import subprocess
+import tempfile
+import shutil
+from pathlib import Path
+import sys
+import importlib.util
+
+# Ensure models directory exists
+MODELS_DIR = Path("models")
+os.makedirs(MODELS_DIR, exist_ok=True)
+
+def ensure_dependencies():
+    """Ensure all required dependencies are installed."""
+    required_packages = [
+        "ultralytics",
+        "boxmot",
+        "supervision"
+    ]
+
+    for package in required_packages:
+        try:
+            importlib.import_module(package)
+            print(f"✅ {package} is installed")
+        except ImportError:
+            print(f"⚠️ {package} is not installed, attempting to install...")
+            subprocess.run([sys.executable, "-m", "pip", "install", package], check=True)
+
+# Apply tracker patches if tracker_patch.py exists
+def apply_patches():
+    patch_path = Path("tracker_patch.py")
+    if patch_path.exists():
+        spec = importlib.util.spec_from_file_location("tracker_patch", patch_path)
+        if spec:
+            module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(module)
+            if hasattr(module, "patch_trackers"):
+                module.patch_trackers()
+                print("✅ Applied tracker patches")
+            else:
+                print("⚠️ tracker_patch.py exists but has no patch_trackers function")
+    else:
+        print("⚠️ tracker_patch.py not found, skipping patches")
+
+def run_tracking(video_file, yolo_model, reid_model, tracking_method, conf_threshold):
+    """Run object tracking on the uploaded video."""
+    try:
+        # Create temporary workspace
+        with tempfile.TemporaryDirectory() as temp_dir:
+            # Prepare input
+            input_path = os.path.join(temp_dir, "input_video.mp4")
+            shutil.copy(video_file, input_path)
+
+            # Prepare output directory
+            output_dir = os.path.join(temp_dir, "output")
+            os.makedirs(output_dir, exist_ok=True)
+
+            # Build command; sys.executable keeps the same interpreter/venv
+            # that ensure_dependencies() installed packages into.
+            cmd = [
+                sys.executable, "tracking/track.py",
+                "--yolo-model", str(MODELS_DIR / yolo_model),
+                "--reid-model", str(MODELS_DIR / reid_model),
+                "--tracking-method", tracking_method,
+                "--source", input_path,
+                "--conf", str(conf_threshold),
+                "--save",
+                "--project", output_dir,
+                "--name", "track",
+                "--exist-ok"
+            ]
+
+            # Special handling for OcSort
+            if tracking_method == "ocsort":
+                cmd.append("--per-class")
+
+            # Execute tracking with error handling
+            process = subprocess.run(
+                cmd,
+                capture_output=True,
+                text=True
+            )
+
+            # Check for errors in output
+            if process.returncode != 0:
+                error_message = process.stderr or process.stdout
+                return None, f"Error in tracking process: {error_message}"
+
+            # Find output video
+            output_files = []
+            for root, _, files in os.walk(output_dir):
+                for file in files:
+                    if file.lower().endswith((".mp4", ".avi", ".mov")):
+                        output_files.append(os.path.join(root, file))
+
+            if not output_files:
+                return None, "No output video was generated. Check if tracking was successful."
+
+            # Copy the result out of the temp dir before returning: once the
+            # `with` block exits the directory is deleted and the path dangles.
+            persistent_dir = tempfile.mkdtemp(prefix="tracking_result_")
+            result_path = os.path.join(persistent_dir, os.path.basename(output_files[0]))
+            shutil.copy(output_files[0], result_path)
+            return result_path, "Processing completed successfully!"
+
+    except Exception as e:
+        return None, f"Error: {str(e)}"
+
+# Define the Gradio interface
+def process_video(video_path, yolo_model, reid_model, tracking_method, conf_threshold):
+    # Validate inputs
+    if not video_path:
+        return None, "Please upload a video file"
+
+    output_path, status = run_tracking(
+        video_path,
+        yolo_model,
+        reid_model,
+        tracking_method,
+        conf_threshold
+    )
+
+    return output_path, status
+
+# Available models and tracking methods
+yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
+reid_models = ["osnet_x0_25_msmt17.pt"]
+tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
+
+# Ensure dependencies and apply patches at startup
+ensure_dependencies()
+apply_patches()
+
+# Create the Gradio interface
+with gr.Blocks(title="YOLO Object Tracking") as app:
+    gr.Markdown("# 🚀 YOLO Object Tracking")
+    gr.Markdown("Upload a video file to detect and track objects. Processing may take a few minutes depending on video length.")
+
+    with gr.Row():
+        with gr.Column():
+            input_video = gr.Video(label="Input Video", sources=["upload"])
+
+            with gr.Group():
+                yolo_model = gr.Dropdown(
+                    choices=yolo_models,
+                    value="yolov8n.pt",
+                    label="YOLO Model"
+                )
+                reid_model = gr.Dropdown(
+                    choices=reid_models,
+                    value="osnet_x0_25_msmt17.pt",
+                    label="ReID Model"
+                )
+                tracking_method = gr.Dropdown(
+                    choices=tracking_methods,
+                    value="bytetrack",
+                    label="Tracking Method"
+                )
+                conf_threshold = gr.Slider(
+                    minimum=0.1,
+                    maximum=0.9,
+                    value=0.3,
+                    step=0.05,
+                    label="Confidence Threshold"
+                )
+
+            process_btn = gr.Button("Process Video", variant="primary")
+
+        with gr.Column():
+            output_video = gr.Video(label="Output Video with Tracking", autoplay=True)
+            status_text = gr.Textbox(label="Status", value="Ready to process video")
+
+    process_btn.click(
+        fn=process_video,
+        inputs=[input_video, yolo_model, reid_model, tracking_method, conf_threshold],
+        outputs=[output_video, status_text]
+    )
+
+# Launch the app
+if __name__ == "__main__":
+    app.launch(debug=True, share=True)
\ No newline at end of file
diff --git a/boxmot/__init__.py b/boxmot/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..44d92aeb2700717188449808878236f4c35c1509
--- /dev/null
+++ b/boxmot/__init__.py
@@ -0,0 +1,21 @@
+# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
+
+__version__ = '12.0.7'
+
+from boxmot.postprocessing.gsi import gsi
+from boxmot.tracker_zoo import create_tracker, get_tracker_config
+from boxmot.trackers.botsort.botsort import BotSort
+from boxmot.trackers.bytetrack.bytetrack import ByteTrack
+from boxmot.trackers.deepocsort.deepocsort import DeepOcSort
+from boxmot.trackers.hybridsort.hybridsort import HybridSort
+from boxmot.trackers.ocsort.ocsort import OcSort
+from boxmot.trackers.strongsort.strongsort import StrongSort
+from boxmot.trackers.imprassoc.imprassoctrack import ImprAssocTrack
+from 
boxmot.trackers.boosttrack.boosttrack import BoostTrack + + +TRACKERS = ['bytetrack', 'botsort', 'strongsort', 'ocsort', 'deepocsort', 'hybridsort', 'imprassoc', 'boosttrack'] + +__all__ = ("__version__", + "StrongSort", "OcSort", "ByteTrack", "BotSort", "DeepOcSort", "HybridSort", "ImprAssocTrack", "BoostTrack", + "create_tracker", "get_tracker_config", "gsi") diff --git a/boxmot/appearance/__init__.py b/boxmot/appearance/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/boxmot/appearance/backbones/__init__.py b/boxmot/appearance/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6d3b5e380ba49ba2191f5afcb885352fc6a7c29 --- /dev/null +++ b/boxmot/appearance/backbones/__init__.py @@ -0,0 +1 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license \ No newline at end of file diff --git a/boxmot/appearance/backbones/clip/__init__.py b/boxmot/appearance/backbones/clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a4d0f1f9a635d5ccf0932bbf5c2069529a7ac --- /dev/null +++ b/boxmot/appearance/backbones/clip/__init__.py @@ -0,0 +1 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license diff --git a/boxmot/appearance/backbones/clip/clip/__init__.py b/boxmot/appearance/backbones/clip/clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a4d0f1f9a635d5ccf0932bbf5c2069529a7ac --- /dev/null +++ b/boxmot/appearance/backbones/clip/clip/__init__.py @@ -0,0 +1 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license diff --git a/boxmot/appearance/backbones/clip/clip/bpe_simple_vocab_16e6.txt.gz b/boxmot/appearance/backbones/clip/clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/boxmot/appearance/backbones/clip/clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/boxmot/appearance/backbones/clip/clip/clip.py b/boxmot/appearance/backbones/clip/clip/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..728a82a89fc9587596d7834a4c7b5dfa7775da8c --- /dev/null +++ b/boxmot/appearance/backbones/clip/clip/clip.py @@ -0,0 +1,222 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import hashlib +import os +import urllib +import warnings +from typing import List, Union + +import torch +from PIL import Image +from torchvision.transforms import (CenterCrop, Compose, Normalize, Resize, + ToTensor) +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", # noqa: E501 + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", # noqa: E501 + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", # noqa: E501 + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", # noqa: E501 + "ViT-B-32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", # noqa: E501 + "ViT-B-16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", # noqa: E501 +} + + +def _download(url: 
str, root: str = os.path.expanduser("~/.cache/clip")): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + lambda image: image.convert("RGB"), + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model 
(default). + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name]) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + 
+ def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + # import pdb + # pdb.set_trace() + if isinstance(texts, str): + texts = [texts] # ['a photo of a face.'] + + sot_token = _tokenizer.encoder["<|startoftext|>"] # 49406 + eot_token = _tokenizer.encoder["<|endoftext|>"] # 49407 + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) # 1,77 + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: # context_length 77 + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = 
torch.tensor(tokens) + + return result diff --git a/boxmot/appearance/backbones/clip/clip/model.py b/boxmot/appearance/backbones/clip/clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..47dd4af3e979b1066a245aebaa769b148f91c063 --- /dev/null +++ b/boxmot/appearance/backbones/clip/clip/model.py @@ -0,0 +1,504 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class 
AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + # NCHW -> (HW)NC #32,2048,7,7 ->49, 32, 2048 + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 50,32,2048 + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=1) + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x3 = self.layer3(x) + x4 = self.layer4(x3) + xproj = self.attnpool(x4) + + return x3, x4, xproj + + +class 
LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + for param in self.parameters(): + if param.dtype == torch.float16: + param.data = param.data.to(torch.float32) + ret = super().forward(x.to(torch.float32)) + return ret.to(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__( + self, + h_resolution: int, + w_resolution: int, + patch_size: int, + stride_size: int, + width: int, + layers: int, + heads: int, + output_dim: int + ): + super().__init__() + self.h_resolution = h_resolution + self.w_resolution = w_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + 
out_channels=width, + kernel_size=patch_size, + stride=stride_size, + bias=False + ) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(h_resolution * w_resolution + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor, cv_emb=None): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + + # shape = [*, grid ** 2 + 1, width] + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) + if cv_emb is not None: + x[:, 0] = x[:, 0] + cv_emb + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + + x11 = self.transformer.resblocks[:11](x) + x12 = self.transformer.resblocks[11](x11) + x11 = x11.permute(1, 0, 2) # LND -> NLD + x12 = x12.permute(1, 0, 2) # LND -> NLD + + x12 = self.ln_post(x12) + + if self.proj is not None: + xproj = x12 @ self.proj + + return x11, x12, xproj + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + vision_stride_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + h_resolution: int, + w_resolution: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, 
+ input_resolution=h_resolution * w_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + h_resolution=h_resolution, + w_resolution=w_resolution, + patch_size=vision_patch_size, + stride_size=vision_stride_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, 
std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) + x = self.transformer(x) + x = x.permute(1, 0, 2) + x = self.ln_final(x).type(self.dtype) + + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logit_scale * text_features @ image_features.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.float() + if l.bias is not None: + 
l.bias.data = l.bias.data.float() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.float() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.float() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict, h_resolution: int, w_resolution: int, vision_stride_size: int): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and + k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: # RN50 + counts: list = [len(set(k.split(".")[2] for k in state_dict if + k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] # 77 (77,512) + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, 
vision_patch_size, vision_stride_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, + h_resolution, w_resolution + ) + if vit: + state_dict["visual.positional_embedding"] = resize_pos_embed( + state_dict["visual.positional_embedding"], + model.visual.positional_embedding, + h_resolution, + w_resolution + ) + else: # RN50 + state_dict["visual.attnpool.positional_embedding"] = resize_pos_embed( + state_dict["visual.attnpool.positional_embedding"], + model.visual.attnpool.positional_embedding, + h_resolution, + w_resolution + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + + model.load_state_dict(state_dict) + return model.eval() + + +import math + + +def resize_pos_embed(posemb, posemb_new, hight, width): + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + print('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) + + ntok_new = posemb_new.shape[0] # 129,2048 + + posemb_token, posemb_grid = posemb[:1], posemb[1:] + ntok_new -= 1 + + gs_old = int(math.sqrt(len(posemb_grid))) # 14 + print('Position embedding resize to height:{} width: {}'.format(hight, width)) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear') + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1) + posemb = torch.cat([posemb_token, posemb_grid.squeeze()], dim=0) + return posemb diff --git a/boxmot/appearance/backbones/clip/clip/simple_tokenizer.py b/boxmot/appearance/backbones/clip/clip/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..34a46b13195b4a52fd6b9e571d0679256511ca5c --- /dev/null +++ 
b/boxmot/appearance/backbones/clip/clip/simple_tokenizer.py @@ -0,0 +1,136 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import gzip +import html +from functools import lru_cache + +import ftfy +import regex as re + +from boxmot.utils import BOXMOT + + +@lru_cache() +def default_bpe(): + return BOXMOT / "appearance/backbones/clip/clip/bpe_simple_vocab_16e6.txt.gz" + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2 ** 8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). 
+ """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + '' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) # noqa: E501 + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token + '' + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception: + + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + 
i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text diff --git a/boxmot/appearance/backbones/clip/config/__init__.py b/boxmot/appearance/backbones/clip/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a4d0f1f9a635d5ccf0932bbf5c2069529a7ac --- /dev/null +++ b/boxmot/appearance/backbones/clip/config/__init__.py @@ -0,0 +1 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license diff --git a/boxmot/appearance/backbones/clip/config/defaults.py b/boxmot/appearance/backbones/clip/config/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..15245b105782738b9275851014009dc6a097cc8b --- /dev/null +++ b/boxmot/appearance/backbones/clip/config/defaults.py @@ -0,0 +1,239 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +from yacs.config import CfgNode as CN + +# ----------------------------------------------------------------------------- +# Convention about Training / Test specific parameters +# ----------------------------------------------------------------------------- +# Whenever an argument can be either used for training or for testing, the +# corresponding name will be post-fixed by a _TRAIN for a training parameter, + +# ----------------------------------------------------------------------------- +# Config definition +# 
# -----------------------------------------------------------------------------

_C = CN()
# -----------------------------------------------------------------------------
# MODEL
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Using cuda or cpu for training
_C.MODEL.DEVICE = "cuda"
# ID number of GPU
_C.MODEL.DEVICE_ID = '0'
# Name of backbone
_C.MODEL.NAME = 'ViT-B-16'
# Last stride of backbone
_C.MODEL.LAST_STRIDE = 1
# Path to pretrained model of backbone
_C.MODEL.PRETRAIN_PATH = '/home/mikel.brostrom/yolo_tracking/clip_market1501.pt'

# Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model
# Options: 'imagenet' , 'self' , 'finetune'
_C.MODEL.PRETRAIN_CHOICE = 'imagenet'

# If train with BNNeck, options: 'bnneck' or 'no'
_C.MODEL.NECK = 'bnneck'
# If train loss include center loss, options: 'yes' or 'no'. Loss with center loss has different optimizer configuration
_C.MODEL.IF_WITH_CENTER = 'no'

# Loss-term types and relative weights used during training.
_C.MODEL.ID_LOSS_TYPE = 'softmax'
_C.MODEL.ID_LOSS_WEIGHT = 1.0
_C.MODEL.TRIPLET_LOSS_WEIGHT = 1.0
_C.MODEL.I2T_LOSS_WEIGHT = 1.0

_C.MODEL.METRIC_LOSS_TYPE = 'triplet'
# If train with multi-gpu ddp mode, options: 'True', 'False'
_C.MODEL.DIST_TRAIN = False
# If train with soft triplet loss, options: 'True', 'False'
_C.MODEL.NO_MARGIN = False
# If train with label smooth, options: 'on', 'off'
_C.MODEL.IF_LABELSMOOTH = 'on'
# If train with arcface loss, options: 'True', 'False'
_C.MODEL.COS_LAYER = False

# Transformer setting
_C.MODEL.DROP_PATH = 0.1
_C.MODEL.DROP_OUT = 0.0
_C.MODEL.ATT_DROP_RATE = 0.0
_C.MODEL.TRANSFORMER_TYPE = 'None'
_C.MODEL.STRIDE_SIZE = [16, 16]

# SIE Parameter
_C.MODEL.SIE_COE = 3.0
_C.MODEL.SIE_CAMERA = False
_C.MODEL.SIE_VIEW = False

# -----------------------------------------------------------------------------
# INPUT
# -----------------------------------------------------------------------------
_C.INPUT = CN()
# Size of the image during training
_C.INPUT.SIZE_TRAIN = [256, 128]
# Size of the image during test
_C.INPUT.SIZE_TEST = [256, 128]
# Random probability for image horizontal flip
_C.INPUT.PROB = 0.5
# Random probability for random erasing
_C.INPUT.RE_PROB = 0.5
# Values to be used for image normalization
_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
# Values to be used for image normalization
_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
# Value of padding size
_C.INPUT.PADDING = 10

# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------
_C.DATASETS = CN()
# List of the dataset names for training, as present in paths_catalog.py
_C.DATASETS.NAMES = ('market1501')
# Root directory where datasets should be used (and downloaded if not found)
_C.DATASETS.ROOT_DIR = ('../data')


# -----------------------------------------------------------------------------
# DataLoader
# -----------------------------------------------------------------------------
_C.DATALOADER = CN()
# Number of data loading threads
_C.DATALOADER.NUM_WORKERS = 8
# Sampler for data loading
_C.DATALOADER.SAMPLER = 'softmax'
# Number of instance for one batch
_C.DATALOADER.NUM_INSTANCE = 16

# ---------------------------------------------------------------------------- #
# Solver (two-stage training: STAGE1 and STAGE2 below)
_C.SOLVER = CN()
_C.SOLVER.SEED = 1234
_C.SOLVER.MARGIN = 0.3

# stage1
# ---------------------------------------------------------------------------- #
# Name of optimizer
_C.SOLVER.STAGE1 = CN()

_C.SOLVER.STAGE1.IMS_PER_BATCH = 64

_C.SOLVER.STAGE1.OPTIMIZER_NAME = "Adam"
# Number of max epoches
_C.SOLVER.STAGE1.MAX_EPOCHS = 100
# Base learning rate
_C.SOLVER.STAGE1.BASE_LR = 3e-4
# Momentum
_C.SOLVER.STAGE1.MOMENTUM = 0.9

# Settings of weight decay
_C.SOLVER.STAGE1.WEIGHT_DECAY = 0.0005
_C.SOLVER.STAGE1.WEIGHT_DECAY_BIAS = 0.0005

# warm up factor
_C.SOLVER.STAGE1.WARMUP_FACTOR = 0.01
# warm up epochs
_C.SOLVER.STAGE1.WARMUP_EPOCHS = 5
_C.SOLVER.STAGE1.WARMUP_LR_INIT = 0.01
_C.SOLVER.STAGE1.LR_MIN = 0.000016

_C.SOLVER.STAGE1.WARMUP_ITERS = 500
# method of warm up, option: 'constant','linear'
_C.SOLVER.STAGE1.WARMUP_METHOD = "linear"

_C.SOLVER.STAGE1.COSINE_MARGIN = 0.5
_C.SOLVER.STAGE1.COSINE_SCALE = 30

# epoch number of saving checkpoints
_C.SOLVER.STAGE1.CHECKPOINT_PERIOD = 10
# iteration of display training log
_C.SOLVER.STAGE1.LOG_PERIOD = 100
# epoch number of validation
# Number of images per batch
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 128, each GPU will
# contain 16 images per batch
# _C.SOLVER.STAGE1.IMS_PER_BATCH = 64
_C.SOLVER.STAGE1.EVAL_PERIOD = 10

# ---------------------------------------------------------------------------- #
# Solver
# stage2
# ---------------------------------------------------------------------------- #
_C.SOLVER.STAGE2 = CN()

_C.SOLVER.STAGE2.IMS_PER_BATCH = 64
# Name of optimizer
_C.SOLVER.STAGE2.OPTIMIZER_NAME = "Adam"
# Number of max epoches
_C.SOLVER.STAGE2.MAX_EPOCHS = 100
# Base learning rate
_C.SOLVER.STAGE2.BASE_LR = 3e-4
# Whether using larger learning rate for fc layer
_C.SOLVER.STAGE2.LARGE_FC_LR = False
# Factor of learning bias
_C.SOLVER.STAGE2.BIAS_LR_FACTOR = 1
# Momentum
_C.SOLVER.STAGE2.MOMENTUM = 0.9
# Margin of triplet loss
# Learning rate of SGD to learn the centers of center loss
_C.SOLVER.STAGE2.CENTER_LR = 0.5
# Balanced weight of center loss
_C.SOLVER.STAGE2.CENTER_LOSS_WEIGHT = 0.0005

# Settings of weight decay
_C.SOLVER.STAGE2.WEIGHT_DECAY = 0.0005
_C.SOLVER.STAGE2.WEIGHT_DECAY_BIAS = 0.0005

# decay rate of learning rate
_C.SOLVER.STAGE2.GAMMA = 0.1
# decay step of learning rate
_C.SOLVER.STAGE2.STEPS = (40, 70)
# warm up factor
_C.SOLVER.STAGE2.WARMUP_FACTOR = 0.01
# warm up epochs
_C.SOLVER.STAGE2.WARMUP_EPOCHS = 5
_C.SOLVER.STAGE2.WARMUP_LR_INIT = 0.01
_C.SOLVER.STAGE2.LR_MIN = 0.000016


_C.SOLVER.STAGE2.WARMUP_ITERS = 500
# method of warm up, option: 'constant','linear'
_C.SOLVER.STAGE2.WARMUP_METHOD = "linear"

_C.SOLVER.STAGE2.COSINE_MARGIN = 0.5
_C.SOLVER.STAGE2.COSINE_SCALE = 30

# epoch number of saving checkpoints
_C.SOLVER.STAGE2.CHECKPOINT_PERIOD = 10
# iteration of display training log
_C.SOLVER.STAGE2.LOG_PERIOD = 100
# epoch number of validation
_C.SOLVER.STAGE2.EVAL_PERIOD = 10
# Number of images per batch
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 128, each GPU will
# contain 16 images per batch

# ---------------------------------------------------------------------------- #
# TEST
# ---------------------------------------------------------------------------- #

_C.TEST = CN()
# Number of images per batch during test
_C.TEST.IMS_PER_BATCH = 128
# If test with re-ranking, options: 'True','False'
_C.TEST.RE_RANKING = False
# Path to trained model
_C.TEST.WEIGHT = ""
# Which feature of BNNeck to be used for test, before or after BNNneck, options: 'before' or 'after'
_C.TEST.NECK_FEAT = 'after'
# Whether feature is nomalized before test, if yes, it is equivalent to cosine distance
_C.TEST.FEAT_NORM = 'yes'

# Name for saving the distmat after testing.
_C.TEST.DIST_MAT = "dist_mat.npy"
# Whether calculate the eval score option: 'True', 'False'
_C.TEST.EVAL = False
# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
# Path to checkpoint and saved log of trained model
_C.OUTPUT_DIR = ""
diff --git a/boxmot/appearance/backbones/clip/config/defaults_base.py b/boxmot/appearance/backbones/clip/config/defaults_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..81c898f81f8f7583ef663c667e75a2b01c5f1b21
--- /dev/null
+++ b/boxmot/appearance/backbones/clip/config/defaults_base.py
@@ -0,0 +1,190 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

from yacs.config import CfgNode as CN

# -----------------------------------------------------------------------------
# Convention about Training / Test specific parameters
# -----------------------------------------------------------------------------
# Whenever an argument can be either used for training or for testing, the
# corresponding name will be post-fixed by a _TRAIN for a training parameter,

# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------

_C = CN()
# -----------------------------------------------------------------------------
# MODEL
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Using cuda or cpu for training
_C.MODEL.DEVICE = "cuda"
# ID number of GPU
_C.MODEL.DEVICE_ID = '0'
# Name of backbone
_C.MODEL.NAME = 'resnet50'
# Last stride of backbone
_C.MODEL.LAST_STRIDE = 1
# Path to pretrained model of backbone
_C.MODEL.PRETRAIN_PATH = ''

# Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model
# Options: 'imagenet' , 'self' , 'finetune'
_C.MODEL.PRETRAIN_CHOICE = 'imagenet'

# If train with BNNeck, options: 'bnneck' or 'no'
_C.MODEL.NECK = 'bnneck'
# If train loss include center loss, options: 'yes' or 'no'. Loss with center loss has different optimizer configuration
_C.MODEL.IF_WITH_CENTER = 'no'

# Loss-term types and relative weights used during training.
_C.MODEL.ID_LOSS_TYPE = 'softmax'
_C.MODEL.ID_LOSS_WEIGHT = 1.0
_C.MODEL.TRIPLET_LOSS_WEIGHT = 1.0
_C.MODEL.I2T_LOSS_WEIGHT = 1.0

_C.MODEL.METRIC_LOSS_TYPE = 'triplet'
# If train with multi-gpu ddp mode, options: 'True', 'False'
_C.MODEL.DIST_TRAIN = False
# If train with soft triplet loss, options: 'True', 'False'
_C.MODEL.NO_MARGIN = False
# If train with label smooth, options: 'on', 'off'
_C.MODEL.IF_LABELSMOOTH = 'on'
# If train with arcface loss, options: 'True', 'False'
_C.MODEL.COS_LAYER = False

# Transformer setting
_C.MODEL.DROP_PATH = 0.1
_C.MODEL.DROP_OUT = 0.0
_C.MODEL.ATT_DROP_RATE = 0.0
_C.MODEL.TRANSFORMER_TYPE = 'None'
_C.MODEL.STRIDE_SIZE = [16, 16]

# SIE Parameter
_C.MODEL.SIE_COE = 3.0
_C.MODEL.SIE_CAMERA = False
_C.MODEL.SIE_VIEW = False

# -----------------------------------------------------------------------------
# INPUT
# -----------------------------------------------------------------------------
_C.INPUT = CN()
# Size of the image during training
_C.INPUT.SIZE_TRAIN = [384, 128]
# Size of the image during test
_C.INPUT.SIZE_TEST = [384, 128]
# Random probability for image horizontal flip
_C.INPUT.PROB = 0.5
# Random probability for random erasing
_C.INPUT.RE_PROB = 0.5
# Values to be used for image normalization
_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
# Values to be used for image normalization
_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
# Value of padding size
_C.INPUT.PADDING = 10

# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------
_C.DATASETS = CN()
# List of the dataset names for training, as present in paths_catalog.py
_C.DATASETS.NAMES = ('market1501')
# Root directory where datasets should be used (and downloaded if not found)
_C.DATASETS.ROOT_DIR = ('../data')


# -----------------------------------------------------------------------------
# DataLoader
# -----------------------------------------------------------------------------
_C.DATALOADER = CN()
# Number of data loading threads
_C.DATALOADER.NUM_WORKERS = 8
# Sampler for data loading
_C.DATALOADER.SAMPLER = 'softmax'
# Number of instance for one batch
_C.DATALOADER.NUM_INSTANCE = 16

# ---------------------------------------------------------------------------- #
# Solver
# ---------------------------------------------------------------------------- #
_C.SOLVER = CN()
# Name of optimizer
_C.SOLVER.OPTIMIZER_NAME = "Adam"
# Number of max epoches
_C.SOLVER.MAX_EPOCHS = 100
# Base learning rate
_C.SOLVER.BASE_LR = 3e-4
# Whether using larger learning rate for fc layer
_C.SOLVER.LARGE_FC_LR = False
# Factor of learning bias
_C.SOLVER.BIAS_LR_FACTOR = 1
# Random seed for reproducibility
_C.SOLVER.SEED = 1234
# Momentum
_C.SOLVER.MOMENTUM = 0.9
# Margin of triplet loss
_C.SOLVER.MARGIN = 0.3
# Learning rate of SGD to learn the centers of center loss
_C.SOLVER.CENTER_LR = 0.5
# Balanced weight of center loss
_C.SOLVER.CENTER_LOSS_WEIGHT = 0.0005

# Settings of weight decay
_C.SOLVER.WEIGHT_DECAY = 0.0005
_C.SOLVER.WEIGHT_DECAY_BIAS = 0.0005

# decay rate of learning rate
_C.SOLVER.GAMMA = 0.1
# decay step of learning rate
_C.SOLVER.STEPS = (40, 70)
# warm up factor
_C.SOLVER.WARMUP_FACTOR = 0.01
# warm up epochs
_C.SOLVER.WARMUP_EPOCHS = 5
_C.SOLVER.WARMUP_LR_INIT = 0.01
_C.SOLVER.LR_MIN = 0.000016


_C.SOLVER.WARMUP_ITERS = 500
# method of warm up, option: 'constant','linear'
_C.SOLVER.WARMUP_METHOD = "linear"

_C.SOLVER.COSINE_MARGIN = 0.5
_C.SOLVER.COSINE_SCALE = 30

# epoch number of saving checkpoints
_C.SOLVER.CHECKPOINT_PERIOD = 10
# iteration of display training log
_C.SOLVER.LOG_PERIOD = 100
# epoch number of validation
_C.SOLVER.EVAL_PERIOD = 10
# Number of images per batch
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 128, each GPU will
# contain 16 images per batch
_C.SOLVER.IMS_PER_BATCH = 64

# ---------------------------------------------------------------------------- #
# TEST
# ---------------------------------------------------------------------------- #

_C.TEST = CN()
# Number of images per batch during test
_C.TEST.IMS_PER_BATCH = 128
# If test with re-ranking, options: 'True','False'
_C.TEST.RE_RANKING = False
# Path to trained model
_C.TEST.WEIGHT = ""
# Which feature of BNNeck to be used for test, before or after BNNneck, options: 'before' or 'after'
_C.TEST.NECK_FEAT = 'after'
# Whether feature is nomalized before test, if yes, it is equivalent to cosine distance
_C.TEST.FEAT_NORM = 'yes'

# Name for saving the distmat after testing.
_C.TEST.DIST_MAT = "dist_mat.npy"
# Whether calculate the eval score option: 'True', 'False'
_C.TEST.EVAL = False
# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
# Path to checkpoint and saved log of trained model
_C.OUTPUT_DIR = ""
diff --git a/boxmot/appearance/backbones/clip/make_model.py b/boxmot/appearance/backbones/clip/make_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b9cfc609a9e7cc53453e8f7cdc0388694ae90d9
--- /dev/null
+++ b/boxmot/appearance/backbones/clip/make_model.py
@@ -0,0 +1,161 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

import torch
import torch.nn as nn

from .clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Module-level tokenizer instance shared by this file.
_tokenizer = _Tokenizer()


def weights_init_kaiming(m):
    # Kaiming initialization for Linear/Conv layers; constant init for BatchNorm.
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
+ nn.init.constant_(m.bias, 0.0) + + elif classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') + if m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif classname.find('BatchNorm') != -1: + if m.affine: + nn.init.constant_(m.weight, 1.0) + nn.init.constant_(m.bias, 0.0) + + +def weights_init_classifier(m): + classname = m.__class__.__name__ + if classname.find('Linear') != -1: + nn.init.normal_(m.weight, std=0.001) + if m.bias: + nn.init.constant_(m.bias, 0.0) + + +class build_transformer(nn.Module): + def __init__(self, num_classes, camera_num, view_num, cfg): + super(build_transformer, self).__init__() + self.model_name = cfg.MODEL.NAME + self.cos_layer = cfg.MODEL.COS_LAYER + self.neck = cfg.MODEL.NECK + self.neck_feat = cfg.TEST.NECK_FEAT + if self.model_name == 'ViT-B-16': + self.in_planes = 768 + self.in_planes_proj = 512 + elif self.model_name == 'RN50': + self.in_planes = 2048 + self.in_planes_proj = 1024 + self.num_classes = num_classes + self.camera_num = camera_num + self.view_num = view_num + self.sie_coe = cfg.MODEL.SIE_COE + + self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) + self.classifier.apply(weights_init_classifier) + self.classifier_proj = nn.Linear(self.in_planes_proj, self.num_classes, bias=False) + self.classifier_proj.apply(weights_init_classifier) + + self.bottleneck = nn.BatchNorm1d(self.in_planes) + self.bottleneck.bias.requires_grad_(False) + self.bottleneck.apply(weights_init_kaiming) + self.bottleneck_proj = nn.BatchNorm1d(self.in_planes_proj) + self.bottleneck_proj.bias.requires_grad_(False) + self.bottleneck_proj.apply(weights_init_kaiming) + + self.h_resolution = int((cfg.INPUT.SIZE_TRAIN[0] - 16) // cfg.MODEL.STRIDE_SIZE[0] + 1) + self.w_resolution = int((cfg.INPUT.SIZE_TRAIN[1] - 16) // cfg.MODEL.STRIDE_SIZE[1] + 1) + self.vision_stride_size = cfg.MODEL.STRIDE_SIZE[0] + clip_model = load_clip_to_cpu(self.model_name, self.h_resolution, self.w_resolution, 
self.vision_stride_size) + + self.image_encoder = clip_model.visual + + # if cfg.MODEL.SIE_CAMERA and cfg.MODEL.SIE_VIEW: + # self.cv_embed = nn.Parameter(torch.zeros(camera_num * view_num, self.in_planes)) + # trunc_normal_(self.cv_embed, std=.02) + # print('camera number is : {}'.format(camera_num)) + # elif cfg.MODEL.SIE_CAMERA: + # self.cv_embed = nn.Parameter(torch.zeros(camera_num, self.in_planes)) + # trunc_normal_(self.cv_embed, std=.02) + # print('camera number is : {}'.format(camera_num)) + # elif cfg.MODEL.SIE_VIEW: + # self.cv_embed = nn.Parameter(torch.zeros(view_num, self.in_planes)) + # trunc_normal_(self.cv_embed, std=.02) + # print('camera number is : {}'.format(view_num)) + + def forward(self, x, label=None, cam_label=None, view_label=None): + if self.model_name == 'RN50': + image_features_last, image_features, image_features_proj = self.image_encoder(x) # B,512 B,128,512 + img_feature_last = nn.functional.avg_pool2d( + image_features_last, + image_features_last.shape[2:4]).view(x.shape[0], -1) + img_feature = nn.functional.avg_pool2d( + image_features, + image_features.shape[2:4]).view(x.shape[0], -1) + img_feature_proj = image_features_proj[0] + + elif self.model_name == 'ViT-B-16': + if cam_label is not None and view_label is not None: + cv_embed = self.sie_coe * self.cv_embed[cam_label * self.view_num + view_label] + elif cam_label is not None: + cv_embed = self.sie_coe * self.cv_embed[cam_label] + elif view_label is not None: + cv_embed = self.sie_coe * self.cv_embed[view_label] + else: + cv_embed = None + # B,512 B,128,512 + image_features_last, image_features, image_features_proj = self.image_encoder(x, cv_embed) + img_feature_last = image_features_last[:, 0] + img_feature = image_features[:, 0] + img_feature_proj = image_features_proj[:, 0] + + feat = self.bottleneck(img_feature) + feat_proj = self.bottleneck_proj(img_feature_proj) + + if self.training: + cls_score = self.classifier(feat) + cls_score_proj = self.classifier_proj(feat_proj) 
+ return [cls_score, cls_score_proj], [img_feature_last, img_feature, img_feature_proj] + + else: + if self.neck_feat == 'after': + # print("Test with feature after BN") + return torch.cat([feat, feat_proj], dim=1) + else: + return torch.cat([img_feature, img_feature_proj], dim=1) + + def load_param(self, trained_path): + param_dict = torch.load(trained_path, map_location=torch.device("cpu")) + for i in self.state_dict(): + self.state_dict()[i.replace('module.', '')].copy_(param_dict[i]) + # print('Loading pretrained model from {}'.format('/home/mikel.brostrom/yolo_tracking/clip_market1501.pt')) + + def load_param_finetune(self, model_path): + param_dict = torch.load(model_path) + for i in param_dict: + self.state_dict()[i].copy_(param_dict[i]) + # print('Loading pretrained model for finetuning from {}'.format(model_path)) + + +def make_model(cfg, num_class, camera_num, view_num): + model = build_transformer(num_class, camera_num, view_num, cfg) + return model + + +from .clip import clip + + +def load_clip_to_cpu(backbone_name, h_resolution, w_resolution, vision_stride_size): + url = clip._MODELS[backbone_name] + model_path = clip._download(url) + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location="cpu").eval() + state_dict = None + + except RuntimeError: + state_dict = torch.load(model_path, map_location="cpu") + + model = clip.build_model(state_dict or model.state_dict(), h_resolution, w_resolution, vision_stride_size) + + return model diff --git a/boxmot/appearance/backbones/clip/make_model_clipreid.py b/boxmot/appearance/backbones/clip/make_model_clipreid.py new file mode 100644 index 0000000000000000000000000000000000000000..51ea9897b21b91274a1f76209ce5c2211ce4bb7a --- /dev/null +++ b/boxmot/appearance/backbones/clip/make_model_clipreid.py @@ -0,0 +1,247 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import torch +import torch.nn as nn + +from .clip.simple_tokenizer import SimpleTokenizer as _Tokenizer + +_tokenizer = 
_Tokenizer() + + +def weights_init_kaiming(m): + classname = m.__class__.__name__ + if classname.find('Linear') != -1: + nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') + nn.init.constant_(m.bias, 0.0) + + elif classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') + if m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif classname.find('BatchNorm') != -1: + if m.affine: + nn.init.constant_(m.weight, 1.0) + nn.init.constant_(m.bias, 0.0) + + +def weights_init_classifier(m): + classname = m.__class__.__name__ + if classname.find('Linear') != -1: + nn.init.normal_(m.weight, std=0.001) + if m.bias: + nn.init.constant_(m.bias, 0.0) + + +class TextEncoder(nn.Module): + def __init__(self, clip_model): + super().__init__() + self.transformer = clip_model.transformer + self.positional_embedding = clip_model.positional_embedding + self.ln_final = clip_model.ln_final + self.text_projection = clip_model.text_projection + self.dtype = clip_model.dtype + + def forward(self, prompts, tokenized_prompts): + x = prompts + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection + return x + + +class build_transformer(nn.Module): + def __init__(self, num_classes, camera_num, view_num, cfg): + super(build_transformer, self).__init__() + self.model_name = cfg.MODEL.NAME + self.cos_layer = cfg.MODEL.COS_LAYER + self.neck = cfg.MODEL.NECK + self.neck_feat = cfg.TEST.NECK_FEAT + if self.model_name == 'ViT-B-16': + self.in_planes = 768 + self.in_planes_proj = 512 + elif self.model_name == 'RN50': + self.in_planes = 2048 + self.in_planes_proj = 1024 + self.num_classes = num_classes + 
self.camera_num = camera_num + self.view_num = view_num + self.sie_coe = cfg.MODEL.SIE_COE + + self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) + self.classifier.apply(weights_init_classifier) + self.classifier_proj = nn.Linear(self.in_planes_proj, self.num_classes, bias=False) + self.classifier_proj.apply(weights_init_classifier) + + self.bottleneck = nn.BatchNorm1d(self.in_planes) + self.bottleneck.bias.requires_grad_(False) + self.bottleneck.apply(weights_init_kaiming) + self.bottleneck_proj = nn.BatchNorm1d(self.in_planes_proj) + self.bottleneck_proj.bias.requires_grad_(False) + self.bottleneck_proj.apply(weights_init_kaiming) + + self.h_resolution = int((cfg.INPUT.SIZE_TRAIN[0] - 16) // cfg.MODEL.STRIDE_SIZE[0] + 1) + self.w_resolution = int((cfg.INPUT.SIZE_TRAIN[1] - 16) // cfg.MODEL.STRIDE_SIZE[1] + 1) + self.vision_stride_size = cfg.MODEL.STRIDE_SIZE[0] + clip_model = load_clip_to_cpu(self.model_name, self.h_resolution, self.w_resolution, self.vision_stride_size) + + self.image_encoder = clip_model.visual + + # if cfg.MODEL.SIE_CAMERA and cfg.MODEL.SIE_VIEW: + # self.cv_embed = nn.Parameter(torch.zeros(camera_num * view_num, self.in_planes)) + # trunc_normal_(self.cv_embed, std=.02) + # print('camera number is : {}'.format(camera_num)) + # elif cfg.MODEL.SIE_CAMERA: + # self.cv_embed = nn.Parameter(torch.zeros(camera_num, self.in_planes)) + # trunc_normal_(self.cv_embed, std=.02) + # print('camera number is : {}'.format(camera_num)) + # elif cfg.MODEL.SIE_VIEW: + # self.cv_embed = nn.Parameter(torch.zeros(view_num, self.in_planes)) + # trunc_normal_(self.cv_embed, std=.02) + # print('camera number is : {}'.format(view_num)) + + dataset_name = cfg.DATASETS.NAMES + self.prompt_learner = PromptLearner(num_classes, dataset_name, clip_model.dtype, clip_model.token_embedding) + self.text_encoder = TextEncoder(clip_model) + + def forward(self, x=None, label=None, get_image=False, get_text=False, cam_label=None, view_label=None): + if 
get_text is True: + prompts = self.prompt_learner(label) + text_features = self.text_encoder(prompts, self.prompt_learner.tokenized_prompts) + return text_features + + if get_image is True: + image_features_last, image_features, image_features_proj = self.image_encoder(x) + if self.model_name == 'RN50': + return image_features_proj[0] + elif self.model_name == 'ViT-B-16': + return image_features_proj[:, 0] + + if self.model_name == 'RN50': + image_features_last, image_features, image_features_proj = self.image_encoder(x) + img_feature_last = nn.functional.avg_pool2d( + image_features_last, + image_features_last.shape[2:4]).view(x.shape[0], -1) + img_feature = nn.functional.avg_pool2d( + image_features, + image_features.shape[2:4]).view(x.shape[0], -1) + img_feature_proj = image_features_proj[0] + + elif self.model_name == 'ViT-B-16': + if cam_label is not None and view_label is not None: + cv_embed = self.sie_coe * self.cv_embed[cam_label * self.view_num + view_label] + elif cam_label is not None: + cv_embed = self.sie_coe * self.cv_embed[cam_label] + elif view_label is not None: + cv_embed = self.sie_coe * self.cv_embed[view_label] + else: + cv_embed = None + image_features_last, image_features, image_features_proj = self.image_encoder(x, cv_embed) + img_feature_last = image_features_last[:, 0] + img_feature = image_features[:, 0] + img_feature_proj = image_features_proj[:, 0] + + feat = self.bottleneck(img_feature) + feat_proj = self.bottleneck_proj(img_feature_proj) + + if self.training: + cls_score = self.classifier(feat) + cls_score_proj = self.classifier_proj(feat_proj) + return [cls_score, cls_score_proj], [img_feature_last, img_feature, img_feature_proj], img_feature_proj + + else: + if self.neck_feat == 'after': + # print("Test with feature after BN") + return torch.cat([feat, feat_proj], dim=1) + else: + return torch.cat([img_feature, img_feature_proj], dim=1) + + def load_param(self, trained_path): + param_dict = torch.load(trained_path) + for i in 
def make_model(cfg, num_class, camera_num, view_num):
    """Factory: build the CLIP-based re-id transformer from a config node."""
    return build_transformer(num_class, camera_num, view_num, cfg)


from .clip import clip


def load_clip_to_cpu(backbone_name, h_resolution, w_resolution, vision_stride_size):
    """Download (if needed) and build a CLIP model entirely on the CPU.

    The checkpoint is first opened as a TorchScript archive; if that fails
    with a RuntimeError it is re-read as a plain state-dict checkpoint.

    Args:
        backbone_name: key into ``clip._MODELS`` (e.g. 'ViT-B-16', 'RN50').
        h_resolution, w_resolution: patch-grid resolution for the visual tower.
        vision_stride_size: patch-embedding stride.

    Returns:
        The CLIP model produced by ``clip.build_model``.
    """
    checkpoint_path = clip._download(clip._MODELS[backbone_name])

    state_dict = None
    try:
        # Preferred path: the official checkpoints are JIT archives.
        jit_model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
    except RuntimeError:
        # Fallback: a raw state-dict checkpoint.
        state_dict = torch.load(checkpoint_path, map_location="cpu")

    weights = state_dict if state_dict is not None else jit_model.state_dict()
    return clip.build_model(weights, h_resolution, w_resolution, vision_stride_size)
class PromptLearner(nn.Module):
    """CoOp-style prompt learner with per-class learnable context tokens.

    Builds prompts of the form ``"A photo of a X X X X person."`` (or
    ``vehicle`` for the VehicleID/veri datasets); the four ``X``
    placeholders are replaced at forward time by learned class-specific
    context vectors.

    Args:
        num_class: number of identity classes (one context block each).
        dataset_name: dataset key; selects "vehicle" vs "person" wording.
        dtype: CLIP model dtype; the frozen embeddings are cast to it.
        token_embedding: CLIP token-embedding module used to embed the
            fixed (non-learned) part of the prompt.
    """

    def __init__(self, num_class, dataset_name, dtype, token_embedding):
        super().__init__()
        if dataset_name == "VehicleID" or dataset_name == "veri":
            ctx_init = "A photo of a X X X X vehicle."
        else:
            ctx_init = "A photo of a X X X X person."

        ctx_dim = 512
        # use given words to initialize context vectors
        ctx_init = ctx_init.replace("_", " ")
        n_ctx = 4

        # Fix: tokenize on the same device as the embedding weights instead
        # of a hard-coded .cuda() — the original crashed on CPU-only hosts
        # and whenever token_embedding still lived on the CPU.
        # (assumes token_embedding exposes .weight, as nn.Embedding does —
        # TODO confirm for custom embedding modules)
        device = token_embedding.weight.device
        tokenized_prompts = clip.tokenize(ctx_init).to(device)
        with torch.no_grad():
            embedding = token_embedding(tokenized_prompts).type(dtype)
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor

        n_cls_ctx = 4
        cls_vectors = torch.empty(num_class, n_cls_ctx, ctx_dim, dtype=dtype)
        nn.init.normal_(cls_vectors, std=0.02)
        self.cls_ctx = nn.Parameter(cls_vectors)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, : n_ctx + 1, :])
        self.register_buffer("token_suffix", embedding[:, n_ctx + 1 + n_cls_ctx:, :])
        self.num_class = num_class
        self.n_cls_ctx = n_cls_ctx

    def forward(self, label):
        """Assemble prompt embeddings for a batch of class labels.

        Args:
            label: LongTensor (batch,) of class indices.

        Returns:
            Tensor (batch, seq_len, ctx_dim): prefix + learned per-class
            context + suffix, concatenated along the token axis.
        """
        cls_ctx = self.cls_ctx[label]
        b = label.shape[0]
        prefix = self.token_prefix.expand(b, -1, -1)
        suffix = self.token_suffix.expand(b, -1, -1)

        prompts = torch.cat(
            [
                prefix,   # (b, n_ctx + 1, dim)
                cls_ctx,  # (b, n_cls_ctx, dim)
                suffix,   # (b, *, dim)
            ],
            dim=1,
        )

        return prompts
+ """ + + def __init__(self, in_c, out_c, k, s=1, p=0): + super(ConvBlock, self).__init__() + self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p) + self.bn = nn.BatchNorm2d(out_c) + + def forward(self, x): + return F.relu(self.bn(self.conv(x))) + + +class InceptionA(nn.Module): + def __init__(self, in_channels, out_channels): + super(InceptionA, self).__init__() + mid_channels = out_channels // 4 + + self.stream1 = nn.Sequential( + ConvBlock(in_channels, mid_channels, 1), + ConvBlock(mid_channels, mid_channels, 3, p=1), + ) + self.stream2 = nn.Sequential( + ConvBlock(in_channels, mid_channels, 1), + ConvBlock(mid_channels, mid_channels, 3, p=1), + ) + self.stream3 = nn.Sequential( + ConvBlock(in_channels, mid_channels, 1), + ConvBlock(mid_channels, mid_channels, 3, p=1), + ) + self.stream4 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1), + ConvBlock(in_channels, mid_channels, 1), + ) + + def forward(self, x): + s1 = self.stream1(x) + s2 = self.stream2(x) + s3 = self.stream3(x) + s4 = self.stream4(x) + y = torch.cat([s1, s2, s3, s4], dim=1) + return y + + +class InceptionB(nn.Module): + def __init__(self, in_channels, out_channels): + super(InceptionB, self).__init__() + mid_channels = out_channels // 4 + + self.stream1 = nn.Sequential( + ConvBlock(in_channels, mid_channels, 1), + ConvBlock(mid_channels, mid_channels, 3, s=2, p=1), + ) + self.stream2 = nn.Sequential( + ConvBlock(in_channels, mid_channels, 1), + ConvBlock(mid_channels, mid_channels, 3, p=1), + ConvBlock(mid_channels, mid_channels, 3, s=2, p=1), + ) + self.stream3 = nn.Sequential( + nn.MaxPool2d(3, stride=2, padding=1), + ConvBlock(in_channels, mid_channels * 2, 1), + ) + + def forward(self, x): + s1 = self.stream1(x) + s2 = self.stream2(x) + s3 = self.stream3(x) + y = torch.cat([s1, s2, s3], dim=1) + return y + + +class SpatialAttn(nn.Module): + """Spatial Attention (Sec. 
class SpatialAttn(nn.Module):
    """Spatial Attention (Sec. 3.1.I.1).

    Produces a single-channel spatial saliency map at the input's
    resolution (downsample by 2, then bilinear upsample back).
    """

    def __init__(self):
        super(SpatialAttn, self).__init__()
        self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
        self.conv2 = ConvBlock(1, 1, 1)

    def forward(self, x):
        # global cross-channel averaging
        x = x.mean(1, keepdim=True)
        # 3-by-3 conv
        x = self.conv1(x)
        # bilinear resizing back to the original resolution.
        # Fix: F.upsample is deprecated; F.interpolate is the drop-in
        # replacement with identical semantics.
        x = F.interpolate(
            x, (x.size(2) * 2, x.size(3) * 2), mode="bilinear", align_corners=True
        )
        # scaling conv
        x = self.conv2(x)
        return x


class ChannelAttn(nn.Module):
    """Channel Attention (Sec. 3.1.I.2) — squeeze-and-excitation style.

    Args:
        in_channels: number of input channels (must be divisible by
            ``reduction_rate``).
        reduction_rate: bottleneck reduction factor.
    """

    def __init__(self, in_channels, reduction_rate=16):
        super(ChannelAttn, self).__init__()
        assert in_channels % reduction_rate == 0
        self.conv1 = ConvBlock(in_channels, in_channels // reduction_rate, 1)
        self.conv2 = ConvBlock(in_channels // reduction_rate, in_channels, 1)

    def forward(self, x):
        # squeeze operation (global average pooling)
        x = F.avg_pool2d(x, x.size()[2:])
        # excitation operation (2 conv layers)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class SoftAttn(nn.Module):
    """Soft Attention (Sec. 3.1.I).

    Aim: Spatial Attention + Channel Attention

    Output: attention maps with shape identical to input.
    """

    def __init__(self, in_channels):
        super(SoftAttn, self).__init__()
        self.spatial_attn = SpatialAttn()
        self.channel_attn = ChannelAttn(in_channels)
        self.conv = ConvBlock(in_channels, in_channels, 1)

    def forward(self, x):
        y_spatial = self.spatial_attn(x)
        y_channel = self.channel_attn(x)
        # broadcasted product fuses the (b,1,h,w) and (b,c,1,1) maps
        y = y_spatial * y_channel
        y = torch.sigmoid(self.conv(y))
        return y
class HardAttn(nn.Module):
    """Hard Attention (Sec. 3.1.II).

    Regresses one 2-D translation per region (4 regions) from globally
    pooled features; biases are initialized so the four regions start as
    evenly spaced horizontal bands.
    """

    def __init__(self, in_channels):
        super(HardAttn, self).__init__()
        self.fc = nn.Linear(in_channels, 4 * 2)
        self.init_params()

    def init_params(self):
        # zero weights + fixed biases -> deterministic initial regions
        self.fc.weight.data.zero_()
        initial_offsets = [0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75]
        self.fc.bias.data.copy_(torch.tensor(initial_offsets, dtype=torch.float))

    def forward(self, x):
        # global average pooling down to (batch, channels)
        pooled = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1))
        # predict transformation parameters, bounded to (-1, 1)
        theta = torch.tanh(self.fc(pooled))
        return theta.view(-1, 4, 2)


class HarmAttn(nn.Module):
    """Harmonious Attention (Sec. 3.1): soft map + hard region params."""

    def __init__(self, in_channels):
        super(HarmAttn, self).__init__()
        self.soft_attn = SoftAttn(in_channels)
        self.hard_attn = HardAttn(in_channels)

    def forward(self, x):
        return self.soft_attn(x), self.hard_attn(x)
class HACNN(nn.Module):
    """Harmonious Attention Convolutional Neural Network.

    Reference:
        Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.

    Public keys:
        - ``hacnn``: HACNN.

    Args:
        num_classes (int): number of classes to predict.
        loss (str): loss head, "softmax" or "triplet".
        nchannels (sequence): number of channels AFTER concatenation.
        feat_dim (int): feature dimension for a single stream.
        learn_region (bool): whether to learn region features (local branch).
        use_gpu (bool): kept for backward compatibility; intermediate
            tensors now follow the input's device instead of being forced
            onto CUDA by this flag.
    """

    def __init__(
        self,
        num_classes,
        loss="softmax",
        nchannels=(128, 256, 384),  # tuple avoids the mutable-default pitfall
        feat_dim=512,
        learn_region=True,
        use_gpu=True,
        **kwargs
    ):
        super(HACNN, self).__init__()
        self.loss = loss
        self.learn_region = learn_region
        self.use_gpu = use_gpu

        self.conv = ConvBlock(3, 32, 3, s=2, p=1)

        # Construct Inception + HarmAttn blocks
        # ============== Block 1 ==============
        self.inception1 = nn.Sequential(
            InceptionA(32, nchannels[0]),
            InceptionB(nchannels[0], nchannels[0]),
        )
        self.ha1 = HarmAttn(nchannels[0])

        # ============== Block 2 ==============
        self.inception2 = nn.Sequential(
            InceptionA(nchannels[0], nchannels[1]),
            InceptionB(nchannels[1], nchannels[1]),
        )
        self.ha2 = HarmAttn(nchannels[1])

        # ============== Block 3 ==============
        self.inception3 = nn.Sequential(
            InceptionA(nchannels[1], nchannels[2]),
            InceptionB(nchannels[2], nchannels[2]),
        )
        self.ha3 = HarmAttn(nchannels[2])

        self.fc_global = nn.Sequential(
            nn.Linear(nchannels[2], feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(),
        )
        self.classifier_global = nn.Linear(feat_dim, num_classes)

        if self.learn_region:
            self.init_scale_factors()
            self.local_conv1 = InceptionB(32, nchannels[0])
            self.local_conv2 = InceptionB(nchannels[0], nchannels[1])
            self.local_conv3 = InceptionB(nchannels[1], nchannels[2])
            self.fc_local = nn.Sequential(
                nn.Linear(nchannels[2] * 4, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.ReLU(),
            )
            self.classifier_local = nn.Linear(feat_dim, num_classes)
            self.feat_dim = feat_dim * 2
        else:
            self.feat_dim = feat_dim

    def init_scale_factors(self):
        # Fixed (s_w, s_h) scales for the four regions: full width, quarter
        # height each (the hard attention only predicts translations).
        self.scale_factors = [
            torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float) for _ in range(4)
        ]

    def stn(self, x, theta):
        """Performs spatial transform.

        x: (batch, channel, height, width)
        theta: (batch, 2, 3)
        """
        # align_corners=False is the current PyTorch default; passing it
        # explicitly silences the deprecation warning without changing
        # behavior on modern torch.
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

    def transform_theta(self, theta_i, region_idx):
        """Transforms theta to include (s_w, s_h), resulting in (batch, 2, 3).

        Fix: the original allocated theta on the CPU and moved it to CUDA
        only when ``use_gpu`` was set, which crashed for CPU inputs with
        use_gpu=True (and vice versa). The matrix now follows the device of
        the predicted translations.
        """
        scale_factors = self.scale_factors[region_idx]
        theta = torch.zeros(theta_i.size(0), 2, 3, device=theta_i.device)
        theta[:, :, :2] = scale_factors.to(theta_i.device)
        theta[:, :, -1] = theta_i
        return theta

    def forward(self, x):
        assert (
            x.size(2) == 160 and x.size(3) == 64
        ), "Input size does not match, expected (160, 64) but got ({}, {})".format(
            x.size(2), x.size(3)
        )
        x = self.conv(x)

        # ============== Block 1 ==============
        # global branch
        x1 = self.inception1(x)
        x1_attn, x1_theta = self.ha1(x1)
        x1_out = x1 * x1_attn
        # local branch: crop each region via STN and refine it
        if self.learn_region:
            x1_local_list = []
            for region_idx in range(4):
                x1_theta_i = x1_theta[:, region_idx, :]
                x1_theta_i = self.transform_theta(x1_theta_i, region_idx)
                x1_trans_i = self.stn(x, x1_theta_i)
                # Fix: F.upsample is deprecated; F.interpolate is the
                # drop-in replacement with identical semantics.
                x1_trans_i = F.interpolate(
                    x1_trans_i, (24, 28), mode="bilinear", align_corners=True
                )
                x1_local_list.append(self.local_conv1(x1_trans_i))

        # ============== Block 2 ==============
        # global branch
        x2 = self.inception2(x1_out)
        x2_attn, x2_theta = self.ha2(x2)
        x2_out = x2 * x2_attn
        # local branch (adds the corresponding block-1 region feature)
        if self.learn_region:
            x2_local_list = []
            for region_idx in range(4):
                x2_theta_i = x2_theta[:, region_idx, :]
                x2_theta_i = self.transform_theta(x2_theta_i, region_idx)
                x2_trans_i = self.stn(x1_out, x2_theta_i)
                x2_trans_i = F.interpolate(
                    x2_trans_i, (12, 14), mode="bilinear", align_corners=True
                )
                x2_local_i = x2_trans_i + x1_local_list[region_idx]
                x2_local_list.append(self.local_conv2(x2_local_i))

        # ============== Block 3 ==============
        # global branch
        x3 = self.inception3(x2_out)
        x3_attn, x3_theta = self.ha3(x3)
        x3_out = x3 * x3_attn
        # local branch
        if self.learn_region:
            x3_local_list = []
            for region_idx in range(4):
                x3_theta_i = x3_theta[:, region_idx, :]
                x3_theta_i = self.transform_theta(x3_theta_i, region_idx)
                x3_trans_i = self.stn(x2_out, x3_theta_i)
                x3_trans_i = F.interpolate(
                    x3_trans_i, (6, 7), mode="bilinear", align_corners=True
                )
                x3_local_i = x3_trans_i + x2_local_list[region_idx]
                x3_local_list.append(self.local_conv3(x3_local_i))

        # ============== Feature generation ==============
        # global branch: GAP + fc
        x_global = F.avg_pool2d(x3_out, x3_out.size()[2:]).view(
            x3_out.size(0), x3_out.size(1)
        )
        x_global = self.fc_global(x_global)
        # local branch: GAP each region, concat, fc
        if self.learn_region:
            x_local_list = []
            for region_idx in range(4):
                x_local_i = x3_local_list[region_idx]
                x_local_i = F.avg_pool2d(x_local_i, x_local_i.size()[2:]).view(
                    x_local_i.size(0), -1
                )
                x_local_list.append(x_local_i)
            x_local = torch.cat(x_local_list, 1)
            x_local = self.fc_local(x_local)

        if not self.training:
            # l2 normalization before concatenation
            if self.learn_region:
                x_global = x_global / x_global.norm(p=2, dim=1, keepdim=True)
                x_local = x_local / x_local.norm(p=2, dim=1, keepdim=True)
                return torch.cat([x_global, x_local], 1)
            else:
                return x_global

        prelogits_global = self.classifier_global(x_global)
        if self.learn_region:
            prelogits_local = self.classifier_local(x_local)

        if self.loss == "softmax":
            if self.learn_region:
                return (prelogits_global, prelogits_local)
            else:
                return prelogits_global

        elif self.loss == "triplet":
            if self.learn_region:
                return (prelogits_global, prelogits_local), (x_global, x_local)
            else:
                return prelogits_global, x_global

        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))
class BatchRandomErasing(nn.Module):
    """Random-erasing augmentation applied batch-wide.

    During training, with probability ``probability``, a single random
    rectangle (same location for the whole batch) is filled with the
    per-channel ``mean`` values. At eval time the input passes through
    unchanged.
    """

    def __init__(
        self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]
    ):
        super(BatchRandomErasing, self).__init__()

        self.probability = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1

    def forward(self, img):
        if not self.training:
            return img
        if random.uniform(0, 1) > self.probability:
            return img

        # Rejection-sample a rectangle (bounded attempts, as in the
        # original Random Erasing paper).
        for _ in range(100):
            area = img.size()[2] * img.size()[3]
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < img.size()[3] and h < img.size()[2]:
                x1 = random.randint(0, img.size()[2] - h)
                y1 = random.randint(0, img.size()[3] - w)
                if img.size()[1] == 3:
                    img[:, 0, x1: x1 + h, y1: y1 + w] = self.mean[0]
                    img[:, 1, x1: x1 + h, y1: y1 + w] = self.mean[1]
                    img[:, 2, x1: x1 + h, y1: y1 + w] = self.mean[2]
                else:
                    img[:, 0, x1: x1 + h, y1: y1 + w] = self.mean[0]
                return img

        return img


class BatchDrop(nn.Module):
    """
    Ref: Batch DropBlock Network for Person Re-identification and Beyond
    https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py
    Created by: daizuozhuo
    """

    def __init__(self, h_ratio, w_ratio):
        super(BatchDrop, self).__init__()
        self.h_ratio = h_ratio
        self.w_ratio = w_ratio

    def forward(self, x):
        # Training only: zero out one random rectangle, shared across batch.
        if not self.training:
            return x
        h, w = x.size()[-2:]
        rh = round(self.h_ratio * h)
        rw = round(self.w_ratio * w)
        sx = random.randint(0, h - rh)
        sy = random.randint(0, w - rw)
        mask = x.new_ones(x.size())
        mask[:, :, sx: sx + rh, sy: sy + rw] = 0
        return x * mask


class BatchDropTop(nn.Module):
    """
    Ref: Top-DB-Net: Top DropBlock for Activation Enhancement in Person Re-Identification
    https://github.com/RQuispeC/top-dropblock/blob/master/torchreid/models/bdnet.py
    Created by: RQuispeC

    """

    def __init__(self, h_ratio):
        super(BatchDropTop, self).__init__()
        self.h_ratio = h_ratio

    def forward(self, x, visdrop=False):
        # Active during training, or when a visualization mask is requested.
        if not (self.training or visdrop):
            return x

        b, c, h, w = x.size()
        rh = round(self.h_ratio * h)

        # Rank rows by their strongest (L2-normalized) activation ...
        act = (x**2).sum(1)
        act = act.view(b, h * w)
        act = F.normalize(act, p=2, dim=1)
        act = act.view(b, h, w)
        row_peak, _ = act.max(2)
        order = torch.argsort(row_peak, 1)
        top_rows = order[:, -rh:]

        # ... and zero out the top-activated rows per sample.
        per_sample = []
        for i in range(b):
            row_mask = torch.ones(h)
            row_mask[top_rows[i]] = 0
            per_sample.append(row_mask.unsqueeze(0))
        mask = torch.cat(per_sample)
        mask = torch.repeat_interleave(mask, w, 1).view(b, h, w)
        mask = torch.repeat_interleave(mask, c, 0).view(b, c, h, w)
        if x.is_cuda:
            mask = mask.cuda()
        if visdrop:
            return mask
        return x * mask
class BatchFeatureErase_Top(nn.Module):
    """
    Ref: Top-DB-Net: Top DropBlock for Activation Enhancement in Person Re-Identification
    https://github.com/RQuispeC/top-dropblock/blob/master/torchreid/models/bdnet.py
    Created by: RQuispeC

    """

    def __init__(
        self,
        channels,
        bottleneck_type,
        h_ratio=0.33,
        w_ratio=1.0,
        double_bottleneck=False,
    ):
        super(BatchFeatureErase_Top, self).__init__()

        # bottleneck_type is a Module class (e.g. OSBlock) mapping
        # `channels` -> 512 feature maps before the drop is applied.
        self.drop_batch_bottleneck = bottleneck_type(channels, 512)

        self.drop_batch_drop_basic = BatchDrop(h_ratio, w_ratio)
        self.drop_batch_drop_top = BatchDropTop(h_ratio)

    def forward(self, x, drop_top=True, bottleneck_features=True, visdrop=False):
        features = self.drop_batch_bottleneck(x)

        dropper = self.drop_batch_drop_top if drop_top else self.drop_batch_drop_basic
        dropped = dropper(features, visdrop=visdrop)

        if visdrop:
            return dropped  # dropped is the drop mask in this mode
        if bottleneck_features:
            return dropped, features
        return dropped


class SE_Module(Module):
    """Squeeze-and-excitation gate: per-channel sigmoid reweighting."""

    def __init__(self, channels, reduction=4):
        super(SE_Module, self).__init__()
        self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        gate = self.fc1(x)
        gate = self.relu(gate)
        gate = self.fc2(gate)
        gate = self.sigmoid(gate)
        return x * gate
(HxW) + """ + m_batchsize, C, height, width = x.size() + proj_query = ( + self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) + ) + proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) + energy = torch.bmm(proj_query, proj_key) + attention = self.softmax(energy) + proj_value = self.value_conv(x).view(m_batchsize, -1, width * height) + + out = torch.bmm(proj_value, attention.permute(0, 2, 1)) + out = out.view(m_batchsize, C, height, width) + + out = self.gamma * out + x + return out + + +class CAM_Module(Module): + """Channel attention module""" + + def __init__(self, in_dim): + super(CAM_Module, self).__init__() + self.chanel_in = in_dim + + self.gamma = Parameter(torch.zeros(1)) + self.softmax = Softmax(dim=-1) + + def forward(self, x): + """ + inputs : + x : input feature maps( B X C X H X W) + returns : + out : attention value + input feature + attention: B X C X C + """ + m_batchsize, C, height, width = x.size() + proj_query = x.view(m_batchsize, C, -1) + proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1) + # proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1).contiguous() + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy + attention = self.softmax(energy_new) + proj_value = x.view(m_batchsize, C, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(m_batchsize, C, height, width) + + out = self.gamma * out + x + return out + + +class Dual_Module(Module): + """ + # Created by: CASIA IVA + # Email: jliu@nlpr.ia.ac.cn + # Copyright (c) 2018 + + # Reference: Dual Attention Network for Scene Segmentation + # https://arxiv.org/pdf/1809.02983.pdf + # https://github.com/junfu1115/DANet/blob/master/encoding/nn/attention.py + """ + + def __init__(self, in_dim): + super(Dual_Module).__init__() + self.indim = in_dim + self.pam = PAM_Module(in_dim) + self.cam = CAM_Module(in_dim) + + def forward(self, x): + out1 = self.pam(x) + out2 = self.cam(x) + 
class BNNeck(nn.Module):
    """BNNeck head: flatten -> BatchNorm1d -> bias-free linear classifier.

    Args:
        input_dim: incoming feature dimension (input may be (b, c, 1, 1)).
        class_num: number of identity classes.
        return_f: when True, forward returns
            (after_neck, class_scores, before_neck); otherwise only scores.
    """

    def __init__(self, input_dim, class_num, return_f=False):
        super(BNNeck, self).__init__()
        self.return_f = return_f
        self.bn = nn.BatchNorm1d(input_dim)
        self.bn.bias.requires_grad_(False)  # BNNeck trick: freeze BN shift
        self.classifier = nn.Linear(input_dim, class_num, bias=False)
        self.bn.apply(self.weights_init_kaiming)
        self.classifier.apply(self.weights_init_classifier)

    def forward(self, x):
        before_neck = x.view(x.size(0), x.size(1))
        after_neck = self.bn(before_neck)

        if self.return_f:
            score = self.classifier(after_neck)
            return after_neck, score, before_neck
        # Fix: classify the normalized 2-D feature. The original passed the
        # raw (possibly 4-D) input to the Linear layer, which raised a
        # shape mismatch.
        return self.classifier(after_neck)

    def weights_init_kaiming(self, m):
        classname = m.__class__.__name__
        if classname.find("Linear") != -1:
            nn.init.kaiming_normal_(m.weight, a=0, mode="fan_out")
            nn.init.constant_(m.bias, 0.0)
        elif classname.find("Conv") != -1:
            nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif classname.find("BatchNorm") != -1:
            if m.affine:
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)

    def weights_init_classifier(self, m):
        classname = m.__class__.__name__
        if classname.find("Linear") != -1:
            nn.init.normal_(m.weight, std=0.001)
            # Fix: `if m.bias:` evaluates tensor truthiness (ambiguous for
            # multi-element biases); test for presence instead.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)


class BNNeck3(nn.Module):
    """BNNeck with a 1x1-conv dimensionality reduction in front.

    Args:
        input_dim: incoming channel count.
        class_num: number of identity classes.
        feat_dim: reduced feature dimension fed to BN + classifier.
        return_f: when True, forward returns
            (after_neck, class_scores, before_neck); otherwise only scores.
    """

    def __init__(self, input_dim, class_num, feat_dim, return_f=False):
        super(BNNeck3, self).__init__()
        self.return_f = return_f
        # 1x1 conv as a linear reduction over channels
        self.reduction = nn.Conv2d(input_dim, feat_dim, 1, bias=False)
        self.bn = nn.BatchNorm1d(feat_dim)

        self.bn.bias.requires_grad_(False)  # BNNeck trick: freeze BN shift
        self.classifier = nn.Linear(feat_dim, class_num, bias=False)
        self.bn.apply(self.weights_init_kaiming)
        self.classifier.apply(self.weights_init_classifier)

    def forward(self, x):
        x = self.reduction(x)
        before_neck = x.view(x.size(0), x.size(1))
        after_neck = self.bn(before_neck)
        if self.return_f:
            score = self.classifier(after_neck)
            return after_neck, score, before_neck
        # Fix: classify the flattened, normalized feature (the original
        # passed the 4-D conv output to the Linear layer).
        return self.classifier(after_neck)

    def weights_init_kaiming(self, m):
        classname = m.__class__.__name__
        if classname.find("Linear") != -1:
            nn.init.kaiming_normal_(m.weight, a=0, mode="fan_out")
            nn.init.constant_(m.bias, 0.0)
        elif classname.find("Conv") != -1:
            nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif classname.find("BatchNorm") != -1:
            if m.affine:
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)

    def weights_init_classifier(self, m):
        classname = m.__class__.__name__
        if classname.find("Linear") != -1:
            nn.init.normal_(m.weight, std=0.001)
            # Fix: presence check instead of tensor truthiness.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
# Defines the new fc layer and classification layer
# |--Linear--|--bn--|--relu--|--Linear--|


class ClassBlock(nn.Module):
    """Feature head: optional fc -> bn -> leaky-relu -> dropout, then a
    linear classifier.

    Args:
        input_dim: incoming feature dimension.
        class_num: number of classes.
        droprate: dropout probability (0 disables dropout).
        relu: append a LeakyReLU(0.1) when True.
        bnorm: append BatchNorm1d when True.
        num_bottleneck: bottleneck width (ignored when ``linear`` is False).
        linear: prepend a Linear(input_dim, num_bottleneck) when True.
        return_f: when True, forward returns (feature, logits, feature).
    """

    def __init__(
        self,
        input_dim,
        class_num,
        droprate=0,
        relu=False,
        bnorm=True,
        num_bottleneck=512,
        linear=True,
        return_f=False,
    ):
        super(ClassBlock, self).__init__()
        self.return_f = return_f

        layers = []
        if linear:
            layers.append(nn.Linear(input_dim, num_bottleneck))
        else:
            num_bottleneck = input_dim
        if bnorm:
            layers.append(nn.BatchNorm1d(num_bottleneck))
        if relu:
            layers.append(nn.LeakyReLU(0.1))
        if droprate > 0:
            layers.append(nn.Dropout(p=droprate))

        add_block = nn.Sequential(*layers)
        add_block.apply(self.weights_init_kaiming)

        classifier = nn.Sequential(nn.Linear(num_bottleneck, class_num))
        classifier.apply(self.weights_init_classifier)

        self.add_block = add_block
        self.classifier = classifier

    def forward(self, x):
        # collapse trailing (1, 1) spatial dims before the fc stack
        feat = self.add_block(x.squeeze(3).squeeze(2))
        if self.return_f:
            logits = self.classifier(feat)
            return feat, logits, feat
        return self.classifier(feat)

    def weights_init_kaiming(self, m):
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            # For old pytorch, you may use kaiming_normal.
            nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
        elif classname.find("Linear") != -1:
            nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out")
            nn.init.constant_(m.bias.data, 0.0)
        elif classname.find("BatchNorm1d") != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    def weights_init_classifier(self, m):
        classname = m.__class__.__name__
        if classname.find("Linear") != -1:
            nn.init.normal_(m.weight.data, std=0.001)
            nn.init.constant_(m.bias.data, 0.0)
class LMBN_n(nn.Module):
    """Lightweight Multi-Branch Network (OSNet backbone) for re-id.

    Three branches share an OSNet stem: a global branch (with top
    drop-block), a two-part horizontal partial branch, and a channel branch
    split into two halves. Training returns (class-score list, feature
    list); eval returns the stacked, flattened embedding.

    Args:
        num_classes: number of identity classes.
        loss: loss identifier (accepted for interface compatibility;
            not read by this module).
        pretrained: whether to load ImageNet-pretrained OSNet weights.
        use_gpu: accepted for interface compatibility; not read here.
    """

    def __init__(self, num_classes, loss, pretrained, use_gpu):
        super(LMBN_n, self).__init__()

        self.n_ch = 2
        self.chs = 512 // self.n_ch
        # NOTE(review): overrides nn.Module's default training=True at
        # construction; kept as-is since downstream code may rely on it.
        self.training = False

        # Fix: honor the `pretrained` argument — the original ignored it
        # and always built osnet_x1_0 with pretrained=True.
        osnet = osnet_x1_0(pretrained=pretrained)

        # (sic) "backone" kept: renaming would break checkpoint keys.
        self.backone = nn.Sequential(
            osnet.conv1, osnet.maxpool, osnet.conv2, osnet.conv3[0]
        )

        conv3 = osnet.conv3[1:]

        self.global_branch = nn.Sequential(
            copy.deepcopy(conv3), copy.deepcopy(osnet.conv4), copy.deepcopy(osnet.conv5)
        )

        self.partial_branch = nn.Sequential(
            copy.deepcopy(conv3), copy.deepcopy(osnet.conv4), copy.deepcopy(osnet.conv5)
        )

        self.channel_branch = nn.Sequential(
            copy.deepcopy(conv3), copy.deepcopy(osnet.conv4), copy.deepcopy(osnet.conv5)
        )

        self.global_pooling = nn.AdaptiveMaxPool2d((1, 1))
        self.partial_pooling = nn.AdaptiveAvgPool2d((2, 1))
        self.channel_pooling = nn.AdaptiveAvgPool2d((1, 1))

        reduction = BNNeck3(512, num_classes, 512, return_f=True)

        self.reduction_0 = copy.deepcopy(reduction)
        self.reduction_1 = copy.deepcopy(reduction)
        self.reduction_2 = copy.deepcopy(reduction)
        self.reduction_3 = copy.deepcopy(reduction)
        self.reduction_4 = copy.deepcopy(reduction)

        self.shared = nn.Sequential(
            nn.Conv2d(self.chs, 512, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True)
        )
        self.weights_init_kaiming(self.shared)

        self.reduction_ch_0 = BNNeck(512, num_classes, return_f=True)
        self.reduction_ch_1 = BNNeck(512, num_classes, return_f=True)

        # if args.drop_block:
        #     print('Using batch random erasing block.')
        #     self.batch_drop_block = BatchRandomErasing()
        # print('Using batch drop block.')
        # self.batch_drop_block = BatchDrop(
        #     h_ratio=args.h_ratio, w_ratio=args.w_ratio)
        self.batch_drop_block = BatchFeatureErase_Top(512, OSBlock)

        self.activation_map = False

    def forward(self, x):
        # if self.batch_drop_block is not None:
        #     x = self.batch_drop_block(x)

        x = self.backone(x)

        glo = self.global_branch(x)
        par = self.partial_branch(x)
        cha = self.channel_branch(x)

        if self.activation_map:
            glo_ = glo  # keep the pre-drop map for visualization

        if self.batch_drop_block is not None:
            glo_drop, glo = self.batch_drop_block(glo)

        if self.activation_map:
            _, _, h_par, _ = par.size()

            # upper/lower halves of the partial branch + channel halves
            fmap_p0 = par[:, :, :h_par // 2, :]
            fmap_p1 = par[:, :, h_par // 2:, :]
            fmap_c0 = cha[:, : self.chs, :, :]
            fmap_c1 = cha[:, self.chs:, :, :]
            print("Generating activation maps...")

            return glo, glo_, fmap_c0, fmap_c1, fmap_p0, fmap_p1

        glo_drop = self.global_pooling(glo_drop)
        glo = self.channel_pooling(glo)  # shape:(batchsize, 512,1,1)
        g_par = self.global_pooling(par)  # shape:(batchsize, 512,1,1)
        p_par = self.partial_pooling(par)  # shape:(batchsize, 512,2,1)
        cha = self.channel_pooling(cha)  # shape:(batchsize, 256,1,1)

        p0 = p_par[:, :, 0:1, :]
        p1 = p_par[:, :, 1:2, :]

        f_glo = self.reduction_0(glo)
        f_p0 = self.reduction_1(g_par)
        f_p1 = self.reduction_2(p0)
        f_p2 = self.reduction_3(p1)
        f_glo_drop = self.reduction_4(glo_drop)

        ################

        c0 = cha[:, : self.chs, :, :]
        c1 = cha[:, self.chs:, :, :]
        c0 = self.shared(c0)
        c1 = self.shared(c1)
        f_c0 = self.reduction_ch_0(c0)
        f_c1 = self.reduction_ch_1(c1)

        ################

        # each f_* is (after_neck, score, before_neck) from BNNeck(3)
        fea = [f_glo[-1], f_glo_drop[-1], f_p0[-1]]

        if not self.training:
            features = torch.stack(
                [f_glo[0], f_glo_drop[0], f_p0[0], f_p1[0], f_p2[0], f_c0[0], f_c1[0]],
                dim=2,
            )
            features = features.flatten(1, 2)
            return features

        return [
            f_glo[1],
            f_glo_drop[1],
            f_p0[1],
            f_p1[1],
            f_p2[1],
            f_c0[1],
            f_c1[1],
        ], fea

    def weights_init_kaiming(self, m):
        classname = m.__class__.__name__
        if classname.find("Linear") != -1:
            nn.init.kaiming_normal_(m.weight, a=0, mode="fan_out")
            nn.init.constant_(m.bias, 0.0)
        elif classname.find("Conv") != -1:
            nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif classname.find("BatchNorm") != -1:
            if m.affine:
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)


if __name__ == "__main__":
    # Here I left a simple forward function.
    # Test the model, before you train it.
    import argparse

    parser = argparse.ArgumentParser(description="MGN")
    parser.add_argument("--num_classes", type=int, default=751, help="")
    parser.add_argument("--bnneck", type=bool, default=True)
    parser.add_argument("--pool", type=str, default="max")
    parser.add_argument("--feats", type=int, default=512)
    parser.add_argument("--drop_block", type=bool, default=True)
    parser.add_argument("--w_ratio", type=float, default=1.0, help="")

    args = parser.parse_args()
    # net = MCMP_n(args)
    # net.classifier = nn.Sequential()
    # print([p for p in net.parameters()])
    # a=filter(lambda p: p.requires_grad, net.parameters())
    # print(a)

    # print(net)
    # input = Variable(torch.FloatTensor(8, 3, 384, 128))
    # net.eval()
    # output = net(input)
    # print(output.shape)
    print("net output size:")
    # print(len(output))
nn.BatchNorm2d(mid_channels) + self.fm_conv3 = nn.Conv2d(mid_channels, out_channels, 1, bias=False) + self.fm_bn3 = nn.BatchNorm2d(out_channels) + + # Factor Selection Module + self.fsm = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, fsm_channels[0], 1), + nn.BatchNorm2d(fsm_channels[0]), + nn.ReLU(inplace=True), + nn.Conv2d(fsm_channels[0], fsm_channels[1], 1), + nn.BatchNorm2d(fsm_channels[1]), + nn.ReLU(inplace=True), + nn.Conv2d(fsm_channels[1], self.groups, 1), + nn.BatchNorm2d(self.groups), + nn.Sigmoid(), + ) + + self.downsample = None + if in_channels != out_channels or stride > 1: + self.downsample = nn.Sequential( + nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False), + nn.BatchNorm2d(out_channels), + ) + + def forward(self, x): + residual = x + s = self.fsm(x) + + # reduce dimension + x = self.fm_conv1(x) + x = self.fm_bn1(x) + x = F.relu(x, inplace=True) + + # group convolution + x = self.fm_conv2(x) + x = self.fm_bn2(x) + x = F.relu(x, inplace=True) + + # factor selection + b, c = x.size(0), x.size(1) + n = c // self.groups + ss = s.repeat(1, n, 1, 1) # from (b, g, 1, 1) to (b, g*n=c, 1, 1) + ss = ss.view(b, n, self.groups, 1, 1) + ss = ss.permute(0, 2, 1, 3, 4).contiguous() + ss = ss.view(b, c, 1, 1) + x = ss * x + + # recover dimension + x = self.fm_conv3(x) + x = self.fm_bn3(x) + x = F.relu(x, inplace=True) + + if self.downsample is not None: + residual = self.downsample(residual) + + return F.relu(residual + x, inplace=True), s + + +class MLFN(nn.Module): + """Multi-Level Factorisation Net. + + Reference: + Chang et al. Multi-Level Factorisation Net for + Person Re-Identification. CVPR 2018. + + Public keys: + - ``mlfn``: MLFN (Multi-Level Factorisation Net). 
+ """ + + def __init__( + self, + num_classes, + loss="softmax", + groups=32, + channels=[64, 256, 512, 1024, 2048], + embed_dim=1024, + **kwargs + ): + super(MLFN, self).__init__() + self.loss = loss + self.groups = groups + + # first convolutional layer + self.conv1 = nn.Conv2d(3, channels[0], 7, stride=2, padding=3) + self.bn1 = nn.BatchNorm2d(channels[0]) + self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) + + # main body + self.feature = nn.ModuleList( + [ + # layer 1-3 + MLFNBlock(channels[0], channels[1], 1, [128, 64], self.groups), + MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups), + MLFNBlock(channels[1], channels[1], 1, [128, 64], self.groups), + # layer 4-7 + MLFNBlock(channels[1], channels[2], 2, [256, 128], self.groups), + MLFNBlock(channels[2], channels[2], 1, [256, 128], self.groups), + MLFNBlock(channels[2], channels[2], 1, [256, 128], self.groups), + MLFNBlock(channels[2], channels[2], 1, [256, 128], self.groups), + # layer 8-13 + MLFNBlock(channels[2], channels[3], 2, [512, 128], self.groups), + MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups), + MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups), + MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups), + MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups), + MLFNBlock(channels[3], channels[3], 1, [512, 128], self.groups), + # layer 14-16 + MLFNBlock(channels[3], channels[4], 2, [512, 128], self.groups), + MLFNBlock(channels[4], channels[4], 1, [512, 128], self.groups), + MLFNBlock(channels[4], channels[4], 1, [512, 128], self.groups), + ] + ) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + + # projection functions + self.fc_x = nn.Sequential( + nn.Conv2d(channels[4], embed_dim, 1, bias=False), + nn.BatchNorm2d(embed_dim), + nn.ReLU(inplace=True), + ) + self.fc_s = nn.Sequential( + nn.Conv2d(self.groups * 16, embed_dim, 1, bias=False), + nn.BatchNorm2d(embed_dim), + nn.ReLU(inplace=True), + ) + + self.classifier = 
nn.Linear(embed_dim, num_classes) + + self.init_params() + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x, inplace=True) + x = self.maxpool(x) + + s_hat = [] + for block in self.feature: + x, s = block(x) + s_hat.append(s) + s_hat = torch.cat(s_hat, 1) + + x = self.global_avgpool(x) + x = self.fc_x(x) + s_hat = self.fc_s(s_hat) + + v = (x + s_hat) * 0.5 + v = v.view(v.size(0), -1) + + if not self.training: + return v + + y = self.classifier(v) + + if self.loss == "softmax": + return y + elif self.loss == "triplet": + return y, v + else: + raise KeyError("Unsupported loss: {}".format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. 
+ """ + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +def mlfn(num_classes, loss="softmax", pretrained=True, **kwargs): + model = MLFN(num_classes, loss, **kwargs) + if pretrained: + # init_pretrained_weights(model, model_urls['imagenet']) + import warnings + + warnings.warn( + "The imagenet pretrained weights need to be manually downloaded from {}".format( + model_urls["imagenet"] + ) + ) + return model diff --git a/boxmot/appearance/backbones/mobilenetv2.py b/boxmot/appearance/backbones/mobilenetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..2e4071ac84d0a1600ae30a6e76254308adc72ad7 --- /dev/null +++ b/boxmot/appearance/backbones/mobilenetv2.py @@ -0,0 +1,246 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +from __future__ import absolute_import, division + +import torch.utils.model_zoo as model_zoo +from torch import nn +from torch.nn import functional as F + +__all__ = ["mobilenetv2_x1_0", "mobilenetv2_x1_4"] + +model_urls = { + # 1.0: top-1 71.3 + "mobilenetv2_x1_0": "https://mega.nz/#!NKp2wAIA!1NH1pbNzY_M2hVk_hdsxNM1NUOWvvGPHhaNr-fASF6c", + # 1.4: top-1 73.9 + "mobilenetv2_x1_4": "https://mega.nz/#!RGhgEIwS!xN2s2ZdyqI6vQ3EwgmRXLEW3khr9tpXg96G9SUJugGk", +} + + +class ConvBlock(nn.Module): + """Basic convolutional block. + + convolution (bias discarded) + batch normalization + relu6. + + Args: + in_c (int): number of input channels. + out_c (int): number of output channels. + k (int or tuple): kernel size. + s (int or tuple): stride. + p (int or tuple): padding. + g (int): number of blocked connections from input channels + to output channels (default: 1). 
+ """ + + def __init__(self, in_c, out_c, k, s=1, p=0, g=1): + super(ConvBlock, self).__init__() + self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p, bias=False, groups=g) + self.bn = nn.BatchNorm2d(out_c) + + def forward(self, x): + return F.relu6(self.bn(self.conv(x))) + + +class Bottleneck(nn.Module): + def __init__(self, in_channels, out_channels, expansion_factor, stride=1): + super(Bottleneck, self).__init__() + mid_channels = in_channels * expansion_factor + self.use_residual = stride == 1 and in_channels == out_channels + self.conv1 = ConvBlock(in_channels, mid_channels, 1) + self.dwconv2 = ConvBlock( + mid_channels, mid_channels, 3, stride, 1, g=mid_channels + ) + self.conv3 = nn.Sequential( + nn.Conv2d(mid_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + ) + + def forward(self, x): + m = self.conv1(x) + m = self.dwconv2(m) + m = self.conv3(m) + if self.use_residual: + return x + m + else: + return m + + +class MobileNetV2(nn.Module): + """MobileNetV2. + + Reference: + Sandler et al. MobileNetV2: Inverted Residuals and + Linear Bottlenecks. CVPR 2018. + + Public keys: + - ``mobilenetv2_x1_0``: MobileNetV2 x1.0. + - ``mobilenetv2_x1_4``: MobileNetV2 x1.4. 
+ """ + + def __init__( + self, + num_classes, + width_mult=1, + loss="softmax", + fc_dims=None, + dropout_p=None, + **kwargs + ): + super(MobileNetV2, self).__init__() + self.loss = loss + self.in_channels = int(32 * width_mult) + self.feature_dim = int(1280 * width_mult) if width_mult > 1 else 1280 + + # construct layers + self.conv1 = ConvBlock(3, self.in_channels, 3, s=2, p=1) + self.conv2 = self._make_layer(Bottleneck, 1, int(16 * width_mult), 1, 1) + self.conv3 = self._make_layer(Bottleneck, 6, int(24 * width_mult), 2, 2) + self.conv4 = self._make_layer(Bottleneck, 6, int(32 * width_mult), 3, 2) + self.conv5 = self._make_layer(Bottleneck, 6, int(64 * width_mult), 4, 2) + self.conv6 = self._make_layer(Bottleneck, 6, int(96 * width_mult), 3, 1) + self.conv7 = self._make_layer(Bottleneck, 6, int(160 * width_mult), 3, 2) + self.conv8 = self._make_layer(Bottleneck, 6, int(320 * width_mult), 1, 1) + self.conv9 = ConvBlock(self.in_channels, self.feature_dim, 1) + + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = self._construct_fc_layer(fc_dims, self.feature_dim, dropout_p) + self.classifier = nn.Linear(self.feature_dim, num_classes) + + self._init_params() + + def _make_layer(self, block, t, c, n, s): + # t: expansion factor + # c: output channels + # n: number of blocks + # s: stride for first layer + layers = [] + layers.append(block(self.in_channels, c, t, s)) + self.in_channels = c + for i in range(1, n): + layers.append(block(self.in_channels, c, t)) + return nn.Sequential(*layers) + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + """Constructs fully connected layer. 
+ + Args: + fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed + input_dim (int): input dimension + dropout_p (float): dropout probability, if None, dropout is unused + """ + if fc_dims is None: + self.feature_dim = input_dim + return None + + assert isinstance( + fc_dims, (list, tuple) + ), "fc_dims must be either list or tuple, but got {}".format(type(fc_dims)) + + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU(inplace=True)) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + + self.feature_dim = fc_dims[-1] + + return nn.Sequential(*layers) + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def featuremaps(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + x = self.conv5(x) + x = self.conv6(x) + x = self.conv7(x) + x = self.conv8(x) + x = self.conv9(x) + return x + + def forward(self, x): + f = self.featuremaps(x) + v = self.global_avgpool(f) + v = v.view(v.size(0), -1) + + if self.fc is not None: + v = self.fc(v) + + if not self.training: + return v + + y = self.classifier(v) + + if self.loss == "softmax": + return y + elif self.loss == "triplet": + return y, v + else: + raise KeyError("Unsupported loss: {}".format(self.loss)) + + +def init_pretrained_weights(model, model_url): + """Initializes model with pretrained weights. 
+ + Layers that don't match with pretrained layers in name or size are kept unchanged. + """ + pretrain_dict = model_zoo.load_url(model_url) + model_dict = model.state_dict() + pretrain_dict = { + k: v + for k, v in pretrain_dict.items() + if k in model_dict and model_dict[k].size() == v.size() + } + model_dict.update(pretrain_dict) + model.load_state_dict(model_dict) + + +def mobilenetv2_x1_0(num_classes, loss, pretrained=True, **kwargs): + model = MobileNetV2( + num_classes, loss=loss, width_mult=1, fc_dims=None, dropout_p=None, **kwargs + ) + if pretrained: + # init_pretrained_weights(model, model_urls['mobilenetv2_x1_0']) + import warnings + + warnings.warn( + "The imagenet pretrained weights need to be manually downloaded from {}".format( + model_urls["mobilenetv2_x1_0"] + ) + ) + return model + + +def mobilenetv2_x1_4(num_classes, loss, pretrained=True, **kwargs): + model = MobileNetV2( + num_classes, loss=loss, width_mult=1.4, fc_dims=None, dropout_p=None, **kwargs + ) + if pretrained: + # init_pretrained_weights(model, model_urls['mobilenetv2_x1_4']) + import warnings + + warnings.warn( + "The imagenet pretrained weights need to be manually downloaded from {}".format( + model_urls["mobilenetv2_x1_4"] + ) + ) + return model diff --git a/boxmot/appearance/backbones/osnet.py b/boxmot/appearance/backbones/osnet.py new file mode 100644 index 0000000000000000000000000000000000000000..709c7b5f21da0d9f842ffab8ddf82fe7c1aef0cc --- /dev/null +++ b/boxmot/appearance/backbones/osnet.py @@ -0,0 +1,560 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +from __future__ import absolute_import, division + +import warnings + +import torch +from torch import nn +from torch.nn import functional as F + +__all__ = ["osnet_x1_0", "osnet_x0_75", "osnet_x0_5", "osnet_x0_25", "osnet_ibn_x1_0"] + +pretrained_urls = { + "osnet_x1_0": "https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY", + "osnet_x0_75": 
"https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq", + "osnet_x0_5": "https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i", + "osnet_x0_25": "https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs", + "osnet_ibn_x1_0": "https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l", +} + + +########## +# Basic layers +########## +class ConvLayer(nn.Module): + """Convolution layer (conv + bn + relu).""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + groups=1, + IN=False, + ): + super(ConvLayer, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=False, + groups=groups, + ) + if IN: + self.bn = nn.InstanceNorm2d(out_channels, affine=True) + else: + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Conv1x1(nn.Module): + """1x1 convolution + bn + relu.""" + + def __init__(self, in_channels, out_channels, stride=1, groups=1): + super(Conv1x1, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + 1, + stride=stride, + padding=0, + bias=False, + groups=groups, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Conv1x1Linear(nn.Module): + """1x1 convolution + bn (w/o non-linearity).""" + + def __init__(self, in_channels, out_channels, stride=1): + super(Conv1x1Linear, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, 1, stride=stride, padding=0, bias=False + ) + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class Conv3x3(nn.Module): + """3x3 convolution + bn + relu.""" + + def __init__(self, in_channels, out_channels, stride=1, 
groups=1): + super(Conv3x3, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + 3, + stride=stride, + padding=1, + bias=False, + groups=groups, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class LightConv3x3(nn.Module): + """Lightweight 3x3 convolution. + + 1x1 (linear) + dw 3x3 (nonlinear). + """ + + def __init__(self, in_channels, out_channels): + super(LightConv3x3, self).__init__() + self.conv1 = nn.Conv2d( + in_channels, out_channels, 1, stride=1, padding=0, bias=False + ) + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + 3, + stride=1, + padding=1, + bias=False, + groups=out_channels, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.bn(x) + x = self.relu(x) + return x + + +########## +# Building blocks for omni-scale feature learning +########## +class ChannelGate(nn.Module): + """A mini-network that generates channel-wise gates conditioned on input tensor.""" + + def __init__( + self, + in_channels, + num_gates=None, + return_gates=False, + gate_activation="sigmoid", + reduction=16, + layer_norm=False, + ): + super(ChannelGate, self).__init__() + if num_gates is None: + num_gates = in_channels + self.return_gates = return_gates + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.fc1 = nn.Conv2d( + in_channels, in_channels // reduction, kernel_size=1, bias=True, padding=0 + ) + self.norm1 = None + if layer_norm: + self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1)) + self.relu = nn.ReLU(inplace=True) + self.fc2 = nn.Conv2d( + in_channels // reduction, num_gates, kernel_size=1, bias=True, padding=0 + ) + if gate_activation == "sigmoid": + self.gate_activation = nn.Sigmoid() + elif gate_activation == "relu": + self.gate_activation = nn.ReLU(inplace=True) + elif gate_activation == 
"linear": + self.gate_activation = None + else: + raise RuntimeError("Unknown gate activation: {}".format(gate_activation)) + + def forward(self, x): + input = x + x = self.global_avgpool(x) + x = self.fc1(x) + if self.norm1 is not None: + x = self.norm1(x) + x = self.relu(x) + x = self.fc2(x) + if self.gate_activation is not None: + x = self.gate_activation(x) + if self.return_gates: + return x + return input * x + + +class OSBlock(nn.Module): + """Omni-scale feature learning block.""" + + def __init__( + self, in_channels, out_channels, IN=False, bottleneck_reduction=4, **kwargs + ): + super(OSBlock, self).__init__() + mid_channels = out_channels // bottleneck_reduction + self.conv1 = Conv1x1(in_channels, mid_channels) + self.conv2a = LightConv3x3(mid_channels, mid_channels) + self.conv2b = nn.Sequential( + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + ) + self.conv2c = nn.Sequential( + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + ) + self.conv2d = nn.Sequential( + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + LightConv3x3(mid_channels, mid_channels), + ) + self.gate = ChannelGate(mid_channels) + self.conv3 = Conv1x1Linear(mid_channels, out_channels) + self.downsample = None + if in_channels != out_channels: + self.downsample = Conv1x1Linear(in_channels, out_channels) + self.IN = None + if IN: + self.IN = nn.InstanceNorm2d(out_channels, affine=True) + + def forward(self, x): + identity = x + x1 = self.conv1(x) + x2a = self.conv2a(x1) + x2b = self.conv2b(x1) + x2c = self.conv2c(x1) + x2d = self.conv2d(x1) + x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d) + x3 = self.conv3(x2) + if self.downsample is not None: + identity = self.downsample(identity) + out = x3 + identity + if self.IN is not None: + out = self.IN(out) + return 
F.relu(out) + + +########## +# Network architecture +########## +class OSNet(nn.Module): + """Omni-Scale Network. + + Reference: + - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019. + - Zhou et al. Learning Generalisable Omni-Scale Representations + for Person Re-Identification. TPAMI, 2021. + """ + + def __init__( + self, + num_classes, + blocks, + layers, + channels, + feature_dim=512, + loss="softmax", + IN=False, + **kwargs + ): + super(OSNet, self).__init__() + num_blocks = len(blocks) + assert num_blocks == len(layers) + assert num_blocks == len(channels) - 1 + self.loss = loss + self.feature_dim = feature_dim + + # convolutional backbone + self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN) + self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) + self.conv2 = self._make_layer( + blocks[0], + layers[0], + channels[0], + channels[1], + reduce_spatial_size=True, + IN=IN, + ) + self.conv3 = self._make_layer( + blocks[1], layers[1], channels[1], channels[2], reduce_spatial_size=True + ) + self.conv4 = self._make_layer( + blocks[2], layers[2], channels[2], channels[3], reduce_spatial_size=False + ) + self.conv5 = Conv1x1(channels[3], channels[3]) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + # fully connected layer + self.fc = self._construct_fc_layer( + self.feature_dim, channels[3], dropout_p=None + ) + # identity classification layer + self.classifier = nn.Linear(self.feature_dim, num_classes) + + self._init_params() + + def _make_layer( + self, block, layer, in_channels, out_channels, reduce_spatial_size, IN=False + ): + layers = [] + + layers.append(block(in_channels, out_channels, IN=IN)) + for i in range(1, layer): + layers.append(block(out_channels, out_channels, IN=IN)) + + if reduce_spatial_size: + layers.append( + nn.Sequential( + Conv1x1(out_channels, out_channels), nn.AvgPool2d(2, stride=2) + ) + ) + + return nn.Sequential(*layers) + + def _construct_fc_layer(self, fc_dims, input_dim, 
dropout_p=None): + if fc_dims is None or fc_dims < 0: + self.feature_dim = input_dim + return None + + if isinstance(fc_dims, int): + fc_dims = [fc_dims] + + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU(inplace=True)) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + + self.feature_dim = fc_dims[-1] + + return nn.Sequential(*layers) + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def featuremaps(self, x): + x = self.conv1(x) + x = self.maxpool(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + x = self.conv5(x) + return x + + def forward(self, x, return_featuremaps=False): + x = self.featuremaps(x) + if return_featuremaps: + return x + v = self.global_avgpool(x) + v = v.view(v.size(0), -1) + if self.fc is not None: + v = self.fc(v) + if not self.training: + return v + y = self.classifier(v) + if self.loss == "softmax": + return y + elif self.loss == "triplet": + return y, v + else: + raise KeyError("Unsupported loss: {}".format(self.loss)) + + +def init_pretrained_weights(model, key=""): + """Initializes model with pretrained weights. + + Layers that don't match with pretrained layers in name or size are kept unchanged. 
+ """ + import errno + import os + from collections import OrderedDict + + import gdown + + def _get_torch_home(): + ENV_TORCH_HOME = "TORCH_HOME" + ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" + DEFAULT_CACHE_DIR = "~/.cache" + torch_home = os.path.expanduser( + os.getenv( + ENV_TORCH_HOME, + os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"), + ) + ) + return torch_home + + torch_home = _get_torch_home() + model_dir = os.path.join(torch_home, "checkpoints") + try: + os.makedirs(model_dir) + except OSError as e: + if e.errno == errno.EEXIST: + # Directory already exists, ignore. + pass + else: + # Unexpected OSError, re-raise. + raise + filename = key + "_imagenet.pth" + cached_file = os.path.join(model_dir, filename) + + if not os.path.exists(cached_file): + gdown.download(pretrained_urls[key], cached_file, quiet=False) + + state_dict = torch.load(cached_file) + model_dict = model.state_dict() + new_state_dict = OrderedDict() + matched_layers, discarded_layers = [], [] + + for k, v in state_dict.items(): + if k.startswith("module."): + k = k[7:] # discard module. 
+ + if k in model_dict and model_dict[k].size() == v.size(): + new_state_dict[k] = v + matched_layers.append(k) + else: + discarded_layers.append(k) + + model_dict.update(new_state_dict) + model.load_state_dict(model_dict) + + if len(matched_layers) == 0: + warnings.warn( + 'The pretrained weights from "{}" cannot be loaded, ' + "please check the key names manually " + "(** ignored and continue **)".format(cached_file) + ) + else: + print( + 'Successfully loaded imagenet pretrained weights from "{}"'.format( + cached_file + ) + ) + if len(discarded_layers) > 0: + print( + "** The following layers are discarded " + "due to unmatched keys or layer size: {}".format(discarded_layers) + ) + + +########## +# Instantiation +########## +def osnet_x1_0(num_classes=1000, pretrained=True, loss="softmax", **kwargs): + # standard size (width x1.0) + model = OSNet( + num_classes, + blocks=[OSBlock, OSBlock, OSBlock], + layers=[2, 2, 2], + channels=[64, 256, 384, 512], + loss=loss, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, key="osnet_x1_0") + return model + + +def osnet_x0_75(num_classes=1000, pretrained=True, loss="softmax", **kwargs): + # medium size (width x0.75) + model = OSNet( + num_classes, + blocks=[OSBlock, OSBlock, OSBlock], + layers=[2, 2, 2], + channels=[48, 192, 288, 384], + loss=loss, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, key="osnet_x0_75") + return model + + +def osnet_x0_5(num_classes=1000, pretrained=True, loss="softmax", **kwargs): + # tiny size (width x0.5) + model = OSNet( + num_classes, + blocks=[OSBlock, OSBlock, OSBlock], + layers=[2, 2, 2], + channels=[32, 128, 192, 256], + loss=loss, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, key="osnet_x0_5") + return model + + +def osnet_x0_25(num_classes=1000, pretrained=True, loss="softmax", **kwargs): + # very tiny size (width x0.25) + model = OSNet( + num_classes, + blocks=[OSBlock, OSBlock, OSBlock], + layers=[2, 2, 2], + channels=[16, 64, 96, 
128], + loss=loss, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, key="osnet_x0_25") + return model + + +def osnet_ibn_x1_0(num_classes=1000, pretrained=True, loss="softmax", **kwargs): + # standard size (width x1.0) + IBN layer + # Ref: Pan et al. Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net. ECCV, 2018. + model = OSNet( + num_classes, + blocks=[OSBlock, OSBlock, OSBlock], + layers=[2, 2, 2], + channels=[64, 256, 384, 512], + loss=loss, + IN=True, + **kwargs + ) + if pretrained: + init_pretrained_weights(model, key="osnet_ibn_x1_0") + return model diff --git a/boxmot/appearance/backbones/osnet_ain.py b/boxmot/appearance/backbones/osnet_ain.py new file mode 100644 index 0000000000000000000000000000000000000000..09d5ef9c4423d83e1223b396d515ddf67f2bad35 --- /dev/null +++ b/boxmot/appearance/backbones/osnet_ain.py @@ -0,0 +1,582 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +from __future__ import absolute_import, division + +import warnings + +import torch +from torch import nn +from torch.nn import functional as F + +__all__ = ["osnet_ain_x1_0", "osnet_ain_x0_75", "osnet_ain_x0_5", "osnet_ain_x0_25"] + +pretrained_urls = { + "osnet_ain_x1_0": "https://drive.google.com/uc?id=1-CaioD9NaqbHK_kzSMW8VE4_3KcsRjEo", + "osnet_ain_x0_75": "https://drive.google.com/uc?id=1apy0hpsMypqstfencdH-jKIUEFOW4xoM", + "osnet_ain_x0_5": "https://drive.google.com/uc?id=1KusKvEYyKGDTUBVRxRiz55G31wkihB6l", + "osnet_ain_x0_25": "https://drive.google.com/uc?id=1SxQt2AvmEcgWNhaRb2xC4rP6ZwVDP0Wt", +} + + +########## +# Basic layers +########## +class ConvLayer(nn.Module): + """Convolution layer (conv + bn + relu).""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + groups=1, + IN=False, + ): + super(ConvLayer, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=False, + groups=groups, + ) + if IN: + self.bn = 
nn.InstanceNorm2d(out_channels, affine=True) + else: + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return self.relu(x) + + +class Conv1x1(nn.Module): + """1x1 convolution + bn + relu.""" + + def __init__(self, in_channels, out_channels, stride=1, groups=1): + super(Conv1x1, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + 1, + stride=stride, + padding=0, + bias=False, + groups=groups, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return self.relu(x) + + +class Conv1x1Linear(nn.Module): + """1x1 convolution + bn (w/o non-linearity).""" + + def __init__(self, in_channels, out_channels, stride=1, bn=True): + super(Conv1x1Linear, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, 1, stride=stride, padding=0, bias=False + ) + self.bn = None + if bn: + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + return x + + +class Conv3x3(nn.Module): + """3x3 convolution + bn + relu.""" + + def __init__(self, in_channels, out_channels, stride=1, groups=1): + super(Conv3x3, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + 3, + stride=stride, + padding=1, + bias=False, + groups=groups, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return self.relu(x) + + +class LightConv3x3(nn.Module): + """Lightweight 3x3 convolution. + + 1x1 (linear) + dw 3x3 (nonlinear). 
+ """ + + def __init__(self, in_channels, out_channels): + super(LightConv3x3, self).__init__() + self.conv1 = nn.Conv2d( + in_channels, out_channels, 1, stride=1, padding=0, bias=False + ) + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + 3, + stride=1, + padding=1, + bias=False, + groups=out_channels, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.bn(x) + return self.relu(x) + + +class LightConvStream(nn.Module): + """Lightweight convolution stream.""" + + def __init__(self, in_channels, out_channels, depth): + super(LightConvStream, self).__init__() + assert depth >= 1, "depth must be equal to or larger than 1, but got {}".format( + depth + ) + layers = [] + layers += [LightConv3x3(in_channels, out_channels)] + for i in range(depth - 1): + layers += [LightConv3x3(out_channels, out_channels)] + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +########## +# Building blocks for omni-scale feature learning +########## +class ChannelGate(nn.Module): + """A mini-network that generates channel-wise gates conditioned on input tensor.""" + + def __init__( + self, + in_channels, + num_gates=None, + return_gates=False, + gate_activation="sigmoid", + reduction=16, + layer_norm=False, + ): + super(ChannelGate, self).__init__() + if num_gates is None: + num_gates = in_channels + self.return_gates = return_gates + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.fc1 = nn.Conv2d( + in_channels, in_channels // reduction, kernel_size=1, bias=True, padding=0 + ) + self.norm1 = None + if layer_norm: + self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1)) + self.relu = nn.ReLU() + self.fc2 = nn.Conv2d( + in_channels // reduction, num_gates, kernel_size=1, bias=True, padding=0 + ) + if gate_activation == "sigmoid": + self.gate_activation = nn.Sigmoid() + elif gate_activation == "relu": + self.gate_activation = nn.ReLU() + 
class OSBlockINin(nn.Module):
    """Omni-scale feature learning block with instance normalization.

    Same structure as ``OSBlock`` except the final 1x1 projection skips its
    BatchNorm and the residual branch output is instance-normalized instead.
    """

    def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
        super(OSBlockINin, self).__init__()
        assert T >= 1
        assert out_channels >= reduction and out_channels % reduction == 0
        mid_channels = out_channels // reduction

        self.conv1 = Conv1x1(in_channels, mid_channels)
        # T parallel streams of increasing depth (1..T) capture features at
        # multiple receptive-field scales.
        self.conv2 = nn.ModuleList(
            [LightConvStream(mid_channels, mid_channels, depth) for depth in range(1, T + 1)]
        )
        self.gate = ChannelGate(mid_channels)
        # bn=False: the instance norm below replaces BatchNorm here.
        self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn=False)
        self.downsample = None
        if in_channels != out_channels:
            self.downsample = Conv1x1Linear(in_channels, out_channels)
        self.IN = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x):
        identity = x
        shared = self.conv1(x)
        # Gate each scale stream, then fuse by summation.
        fused = sum(self.gate(stream(shared)) for stream in self.conv2)
        out = self.IN(self.conv3(fused))  # IN inside the residual branch
        if self.downsample is not None:
            identity = self.downsample(identity)
        return F.relu(out + identity)
self._construct_fc_layer( + self.feature_dim, channels[3], dropout_p=None + ) + # identity classification layer + self.classifier = nn.Linear(self.feature_dim, num_classes) + + self._init_params() + + def _make_layer(self, blocks, layer, in_channels, out_channels): + layers = [] + layers += [blocks[0](in_channels, out_channels)] + for i in range(1, len(blocks)): + layers += [blocks[i](out_channels, out_channels)] + return nn.Sequential(*layers) + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + if fc_dims is None or fc_dims < 0: + self.feature_dim = input_dim + return None + + if isinstance(fc_dims, int): + fc_dims = [fc_dims] + + layers = [] + for dim in fc_dims: + layers.append(nn.Linear(input_dim, dim)) + layers.append(nn.BatchNorm1d(dim)) + layers.append(nn.ReLU()) + if dropout_p is not None: + layers.append(nn.Dropout(p=dropout_p)) + input_dim = dim + + self.feature_dim = fc_dims[-1] + + return nn.Sequential(*layers) + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.InstanceNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def featuremaps(self, x): + x = self.conv1(x) + x = self.maxpool(x) + x = self.conv2(x) + x = self.pool2(x) + x = self.conv3(x) + x = self.pool3(x) + x = self.conv4(x) + x = self.conv5(x) + return x + + def forward(self, x, return_featuremaps=False): + x = self.featuremaps(x) + if return_featuremaps: + return x + v = self.global_avgpool(x) + v = v.view(v.size(0), -1) + if 
def init_pretrained_weights(model, key=""):
    """Initialize ``model`` with ImageNet pretrained weights downloaded via gdown.

    Layers that don't match the pretrained checkpoint in name or size are kept
    unchanged; mismatches are reported so partial loads are visible.

    Args:
        model (nn.Module): model whose state dict is updated in place.
        key (str): lookup key into the module-level ``pretrained_urls`` table.
    """
    import os
    from collections import OrderedDict

    import gdown

    def _get_torch_home():
        # Mirror torch.hub's cache resolution: $TORCH_HOME, else
        # $XDG_CACHE_HOME/torch, else ~/.cache/torch.
        ENV_TORCH_HOME = "TORCH_HOME"
        ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
        DEFAULT_CACHE_DIR = "~/.cache"
        return os.path.expanduser(
            os.getenv(
                ENV_TORCH_HOME,
                os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"),
            )
        )

    model_dir = os.path.join(_get_torch_home(), "checkpoints")
    # exist_ok replaces the old errno.EEXIST try/except dance.
    os.makedirs(model_dir, exist_ok=True)
    cached_file = os.path.join(model_dir, key + "_imagenet.pth")

    if not os.path.exists(cached_file):
        gdown.download(pretrained_urls[key], cached_file, quiet=False)

    # Bug fix: map_location="cpu" lets CPU-only machines load checkpoints that
    # were saved on GPU; the caller moves the model to its target device later.
    state_dict = torch.load(cached_file, map_location="cpu")
    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    for k, v in state_dict.items():
        if k.startswith("module."):
            k = k[7:]  # strip DataParallel's "module." prefix

        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if len(matched_layers) == 0:
        warnings.warn(
            'The pretrained weights from "{}" cannot be loaded, '
            "please check the key names manually "
            "(** ignored and continue **)".format(cached_file)
        )
    else:
        print(
            'Successfully loaded imagenet pretrained weights from "{}"'.format(
                cached_file
            )
        )
        if len(discarded_layers) > 0:
            print(
                "** The following layers are discarded "
                "due to unmatched keys or layer size: {}".format(discarded_layers)
            )
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Bias-free 3x3 convolution whose padding tracks the dilation."""
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,  # keeps spatial size at stride 1, any dilation
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
class BasicBlock(nn.Module):
    """Basic two-conv residual block used by ResNet-18/34."""

    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # conv1 (and the optional shortcut) carries any downsampling stride.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Projection shortcut when shapes change, identity otherwise.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)
self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + """Residual network. + + Reference: + - He et al. Deep Residual Learning for Image Recognition. CVPR 2016. + - Xie et al. Aggregated Residual Transformations for Deep Neural Networks. CVPR 2017. + + Public keys: + - ``resnet18``: ResNet18. + - ``resnet34``: ResNet34. + - ``resnet50``: ResNet50. + - ``resnet101``: ResNet101. + - ``resnet152``: ResNet152. + - ``resnext50_32x4d``: ResNeXt50. + - ``resnext101_32x8d``: ResNeXt101. + - ``resnet50_fc512``: ResNet50 + FC. + """ + + def __init__( + self, + num_classes, + loss, + block, + layers, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=None, + last_stride=2, + fc_dims=None, + dropout_p=None, + **kwargs + ): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + self.loss = loss + self.feature_dim = 512 * block.expansion + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d( + 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + 
self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=last_stride, + dilate=replace_stride_with_dilation[2], + ) + self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = self._construct_fc_layer(fc_dims, 512 * block.expansion, dropout_p) + self.classifier = nn.Linear(self.feature_dim, num_classes) + + self._init_params() + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + + return nn.Sequential(*layers) + + def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): + """Constructs fully connected layer + + Args: + fc_dims (list or 
def init_pretrained_weights(model, model_url):
    """Initialize ``model`` with pretrained weights fetched from ``model_url``.

    Layers that don't match the pretrained checkpoint in name or size are
    left at their current (randomly initialized) values.

    Args:
        model (nn.Module): model whose state dict is updated in place.
        model_url (str): URL of the torchvision checkpoint to load.
    """
    # map_location="cpu" lets CPU-only machines load checkpoints that were
    # saved from GPU; the caller moves the model to its device afterwards.
    pretrain_dict = model_zoo.load_url(model_url, map_location="cpu")
    model_dict = model.state_dict()
    # Keep only entries that exist in this model with an identical shape
    # (the classifier head, for instance, rarely matches).
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)
def resnet50_fc512(num_classes, loss="softmax", pretrained=True, **kwargs):
    """ResNet-50 variant with last stride 1 and a 512-dim FC embedding head."""
    config = dict(
        num_classes=num_classes,
        loss=loss,
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        last_stride=1,   # keep higher-resolution final feature map for ReID
        fc_dims=[512],   # extra embedding layer on top of the backbone
        dropout_p=None,
    )
    model = ResNet(**config, **kwargs)
    if pretrained:
        init_pretrained_weights(model, model_urls["resnet50"])
    return model
ReIDModelRegistry.get_model_name(self.weights) + + self.model = ReIDModelRegistry.build_model( + self.model_name, + num_classes=ReIDModelRegistry.get_nr_classes(self.weights), + pretrained=not (self.weights and self.weights.is_file()), + use_gpu=device, + ) + self.checker = RequirementsChecker() + self.load_model(self.weights) + + + def get_crops(self, xyxys, img): + h, w = img.shape[:2] + resize_dims = (128, 256) + interpolation_method = cv2.INTER_LINEAR + mean_array = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1) + std_array = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1) + + # Preallocate tensor for crops + num_crops = len(xyxys) + crops = torch.empty((num_crops, 3, resize_dims[1], resize_dims[0]), + dtype=torch.half if self.half else torch.float, device=self.device) + + for i, box in enumerate(xyxys): + x1, y1, x2, y2 = box.round().astype('int') + x1, y1, x2, y2 = max(0, x1), max(0, y1), min(w, x2), min(h, y2) + crop = img[y1:y2, x1:x2] + + # Resize and convert color in one step + crop = cv2.resize(crop, resize_dims, interpolation=interpolation_method) + crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB) + + # Convert to tensor and normalize (convert to [0, 1] by dividing by 255 in batch later) + crop = torch.from_numpy(crop).to(self.device, dtype=torch.half if self.half else torch.float) + crops[i] = torch.permute(crop, (2, 0, 1)) # Change to (C, H, W) + + # Normalize the entire batch in one go + crops = crops / 255.0 + + # Standardize the batch + crops = (crops - mean_array) / std_array + + return crops + + + @torch.no_grad() + def get_features(self, xyxys, img): + if xyxys.size != 0: + crops = self.get_crops(xyxys, img) + crops = self.inference_preprocess(crops) + features = self.forward(crops) + features = self.inference_postprocess(features) + else: + features = np.array([]) + features = features / np.linalg.norm(features, axis=-1, keepdims=True) + return features + + def warmup(self, imgsz=[(256, 128, 3)]): 
+ # warmup model by running inference once + if self.device.type != "cpu": + im = np.random.randint(0, 255, *imgsz, dtype=np.uint8) + crops = self.get_crops(xyxys=np.array( + [[0, 0, 64, 64], [0, 0, 128, 128]]), + img=im + ) + crops = self.inference_preprocess(crops) + self.forward(crops) # warmup + + def to_numpy(self, x): + return x.cpu().numpy() if isinstance(x, torch.Tensor) else x + + def inference_preprocess(self, x): + if self.half: + if isinstance(x, torch.Tensor): + if x.dtype != torch.float16: + x = x.half() + elif isinstance(x, np.ndarray): + if x.dtype != np.float16: + x = x.astype(np.float16) + + if self.nhwc: + if isinstance(x, torch.Tensor): + x = x.permute(0, 2, 3, 1) # Convert from NCHW to NHWC + elif isinstance(x, np.ndarray): + x = np.transpose(x, (0, 2, 3, 1)) # Convert from NCHW to NHWC + return x + + def inference_postprocess(self, features): + if isinstance(features, (list, tuple)): + return ( + self.to_numpy(features[0]) if len(features) == 1 else [self.to_numpy(x) for x in features] + ) + else: + return self.to_numpy(features) + + @abstractmethod + def forward(self, im_batch): + raise NotImplementedError("This method should be implemented by subclasses.") + + @abstractmethod + def load_model(self, w): + raise NotImplementedError("This method should be implemented by subclasses.") + + + def download_model(self, w): + if w.suffix == ".pt": + model_url = ReIDModelRegistry.get_model_url(w) + if not w.exists() and model_url is not None: + gdown.download(model_url, str(w), quiet=False) + elif not w.exists(): + LOGGER.error( + f"No URL associated with the chosen StrongSORT weights ({w}). 
class ONNXBackend(BaseModelBackend):
    """ReID backend that executes the model through ONNXRuntime."""

    def __init__(self, weights, device, half):
        super().__init__(weights, device, half)
        self.nhwc = False
        self.half = half

    def load_model(self, w):
        """Create an ONNXRuntime session for the model at ``w``."""
        # ONNXRuntime walks the provider list in order and silently falls
        # back to the next entry when one is unavailable.
        if self.device == "mps":
            self.checker.check_packages(("onnxruntime-silicon==1.17.0",))
            providers = ["MPSExecutionProvider", "CPUExecutionProvider"]
        elif self.device == "cuda":
            self.checker.check_packages(("onnxruntime-gpu==1.17.0",))
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            self.checker.check_packages(("onnxruntime==1.17.0",))
            providers = ["CPUExecutionProvider"]

        import onnxruntime

        self.session = onnxruntime.InferenceSession(str(w), providers=providers)

    def forward(self, im_batch):
        # onnxruntime consumes numpy, so pull the batch off the torch tensor.
        feed = {self.session.get_inputs()[0].name: im_batch.cpu().numpy()}
        output_name = self.session.get_outputs()[0].name
        return self.session.run([output_name], feed)[0]
class OpenVinoBackend(BaseModelBackend):
    """ReID backend using the OpenVINO runtime (CPU-oriented inference)."""

    def __init__(self, weights, device, half):
        super().__init__(weights, device, half)
        self.nhwc = False
        self.half = half

    def load_model(self, w):
        """Read and compile the OpenVINO IR model at ``w`` (.xml/.bin pair).

        Raises:
            ImportError: when the openvino package is not installed.
        """
        self.checker.check_packages(("openvino-dev>=2022.3",))

        LOGGER.info(f"Loading {w} for OpenVINO inference...")
        try:
            # requires openvino-dev: https://pypi.org/project/openvino-dev/
            from openvino.runtime import Core, Layout
        except ImportError:
            LOGGER.error(
                f"Running {self.__class__} with the specified OpenVINO weights\n{w.name}\n"
                "requires openvino pip package to be installed!\n"
                "$ pip install openvino-dev>=2022.3\n"
            )
            # Bug fix: the old handler fell through and crashed below with a
            # NameError on the undefined ``Core``; re-raise the real cause.
            raise
        ie = Core()
        if not Path(w).is_file():  # if not *.xml
            # Resolve the *.xml file inside a *_openvino_model directory.
            w = next(Path(w).glob("*.xml"))
        network = ie.read_model(model=w, weights=Path(w).with_suffix(".bin"))
        if network.get_parameters()[0].get_layout().empty:
            network.get_parameters()[0].set_layout(Layout("NCWH"))
        self.executable_network = ie.compile_model(
            network, device_name="CPU"
        )  # device_name="MYRIAD" for Intel NCS2
        self.output_layer = next(iter(self.executable_network.outputs))

    def forward(self, im_batch):
        im_batch = im_batch.cpu().numpy()  # FP32
        features = self.executable_network([im_batch])[self.output_layer]
        return features
class PyTorchBackend(BaseModelBackend):
    """ReID backend that evaluates the native PyTorch model directly."""

    def __init__(self, weights, device, half):
        super().__init__(weights, device, half)
        self.nhwc = False
        self.half = half

    def load_model(self, w):
        """Optionally load checkpoint weights, then put the model in eval mode."""
        # Only load weights when an actual checkpoint file was provided; the
        # registry-built model is used as-is otherwise.
        if w and w.is_file():
            ReIDModelRegistry.load_pretrained_weights(self.model, w)
        self.model.to(self.device).eval()
        if self.half:
            self.model.half()
        else:
            self.model.float()

    def forward(self, im_batch):
        return self.model(im_batch)
"data", "ptr")) + logger = trt.Logger(trt.Logger.INFO) + + # Deserialize the engine + with open(w, "rb") as f, trt.Runtime(logger) as runtime: + self.model_ = runtime.deserialize_cuda_engine(f.read()) + + # Execution context + self.context = self.model_.create_execution_context() + self.bindings = OrderedDict() + + self.is_trt10 = not hasattr(self.model_, "num_bindings") + num = range(self.model_.num_io_tensors) if self.is_trt10 else range(self.model_.num_bindings) + + # Parse bindings + for index in num: + if self.is_trt10: + name = self.model_.get_tensor_name(index) + dtype = trt.nptype(self.model_.get_tensor_dtype(name)) + is_input = self.model_.get_tensor_mode(name) == trt.TensorIOMode.INPUT + if is_input and -1 in tuple(self.model_.get_tensor_shape(name)): + self.context.set_input_shape(name, tuple(self.model_.get_tensor_profile_shape(name, 0)[1])) + if is_input and dtype == np.float16: + self.fp16 = True + + shape = tuple(self.context.get_tensor_shape(name)) + + else: + name = self.model_.get_binding_name(index) + dtype = trt.nptype(self.model_.get_binding_dtype(index)) + is_input = self.model_.binding_is_input(index) + + # Handle dynamic shapes + if is_input and -1 in self.model_.get_binding_shape(index): + profile_index = 0 + min_shape, opt_shape, max_shape = self.model_.get_profile_shape(profile_index, index) + self.context.set_binding_shape(index, opt_shape) + + if is_input and dtype == np.float16: + self.fp16 = True + + shape = tuple(self.context.get_binding_shape(index)) + data = torch.from_numpy(np.empty(shape, dtype=dtype)).to(self.device) + self.bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr())) + + self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items()) + + def forward(self, im_batch): + temp_im_batch = im_batch.clone() + batch_array = [] + inp_batch = im_batch.shape[0] + out_batch = self.bindings["output"].shape[0] + resultant_features = [] + + # Divide batch to sub batches + while inp_batch > out_batch: 
+ batch_array.append(temp_im_batch[:out_batch]) + temp_im_batch = temp_im_batch[out_batch:] + inp_batch = temp_im_batch.shape[0] + if temp_im_batch.shape[0] > 0: + batch_array.append(temp_im_batch) + + for temp_batch in batch_array: + # Adjust for dynamic shapes + if temp_batch.shape != self.bindings["images"].shape: + if self.is_trt10: + + self.context.set_input_shape("images", temp_batch.shape) + self.bindings["images"] = self.bindings["images"]._replace(shape=temp_batch.shape) + self.bindings["output"].data.resize_(tuple(self.context.get_tensor_shape("output"))) + else: + i_in = self.model_.get_binding_index("images") + i_out = self.model_.get_binding_index("output") + self.context.set_binding_shape(i_in, temp_batch.shape) + self.bindings["images"] = self.bindings["images"]._replace(shape=temp_batch.shape) + output_shape = tuple(self.context.get_binding_shape(i_out)) + self.bindings["output"].data.resize_(output_shape) + + s = self.bindings["images"].shape + assert temp_batch.shape == s, f"Input size {temp_batch.shape} does not match model size {s}" + + self.binding_addrs["images"] = int(temp_batch.data_ptr()) + + # Execute inference + self.context.execute_v2(list(self.binding_addrs.values())) + features = self.bindings["output"].data + resultant_features.append(features.clone()) + + if len(resultant_features)== 1: + return resultant_features[0] + else: + rslt_features = torch.cat(resultant_features,dim=0) + rslt_features= rslt_features[:im_batch.shape[0]] + return rslt_features diff --git a/boxmot/appearance/backends/tflite_backend.py b/boxmot/appearance/backends/tflite_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..3e2e77068359299ec3d3bfe688c7ed206cda72de --- /dev/null +++ b/boxmot/appearance/backends/tflite_backend.py @@ -0,0 +1,86 @@ +import torch +import numpy as np +from pathlib import Path +from boxmot.utils import logger as LOGGER + +from boxmot.appearance.backends.base_backend import BaseModelBackend + + +class 
class TFLiteBackend(BaseModelBackend):
    """TensorFlow Lite ReID backend with dynamic batch-size support.

    Attributes:
        nhwc (bool): True — TFLite models consume NHWC batches.
        half (bool): Always False; this path runs full-precision inference.
        interpreter: tf.lite.Interpreter created in `load_model`.
        current_allocated_batch_size (int): Batch size the interpreter's
            tensors are currently allocated for.
    """

    def __init__(self, weights: Path, device: str, half: bool):
        """
        Args:
            weights (Path): Path to the TFLite model file.
            device (str): Device type (e.g., 'cpu', 'gpu').
            half (bool): Requested half precision (ignored by this backend).
        """
        super().__init__(weights, device, half)
        self.nhwc = True
        self.half = False  # TFLite path ignores the `half` request

    def load_model(self, w):
        """Create and allocate a TFLite interpreter for the model at `w`."""
        self.checker.check_packages(("tensorflow",))

        LOGGER.info(f"Loading {str(w)} for TensorFlow Lite inference...")

        import tensorflow as tf
        self.interpreter = tf.lite.Interpreter(model_path=str(w))

        self.interpreter.allocate_tensors()  # allocate
        self.input_details = self.interpreter.get_input_details()  # inputs
        self.output_details = self.interpreter.get_output_details()  # outputs
        self.current_allocated_batch_size = self.input_details[0]['shape'][0]

    def forward(self, im_batch: torch.Tensor) -> np.ndarray:
        """Run a forward pass, reallocating interpreter tensors when the batch
        size changes.

        Args:
            im_batch (torch.Tensor): Input image batch tensor (NHWC after
                backend preprocessing).

        Returns:
            np.ndarray: Output features from the TFLite model.
        """
        im_batch = im_batch.cpu().numpy()
        batch_size = im_batch.shape[0]

        if batch_size != self.current_allocated_batch_size:
            # FIX: the resize shape was hard-coded to [batch, 256, 128, 3];
            # derive H/W/C from the actual input so models with other crop
            # sizes also work (identical behavior for the 256x128 case).
            self.interpreter.resize_tensor_input(self.input_details[0]['index'], list(im_batch.shape))
            self.interpreter.allocate_tensors()
            self.current_allocated_batch_size = batch_size

        # Set the tensor to point to the input data and run inference.
        self.interpreter.set_tensor(self.input_details[0]['index'], im_batch)
        self.interpreter.invoke()

        return self.interpreter.get_tensor(self.output_details[0]['index'])
def export_decorator(export_func):
    """Wrap an exporter's `export`: verify required packages, log start and
    success (with output size), and convert any exception into a logged
    failure that returns None instead of propagating."""
    def wrapper(self, *args, **kwargs):
        try:
            if hasattr(self, 'required_packages'):
                # BUGFIX: subclasses declare the extra pip-index argument as
                # `cmds` (see EngineExporter / TFLiteExporter), but this
                # previously tested for `cmd`, so the extra index URL was
                # never forwarded to the requirements checker.
                if hasattr(self, 'cmds'):
                    self.checker.check_packages(self.required_packages, cmd=self.cmds)
                else:
                    self.checker.check_packages(self.required_packages)

            LOGGER.info(f"\nStarting {self.file} export with {self.__class__.__name__}...")
            result = export_func(self, *args, **kwargs)
            if result:
                LOGGER.info(f"Export success, saved as {result} ({self.file_size(result):.1f} MB)")
            return result
        except Exception as e:
            LOGGER.error(f"Export failure: {e}")
            return None
    return wrapper


class BaseExporter:
    """Common state and helpers for all model exporters. Subclasses implement
    `export`; __init_subclass__ auto-wraps it with `export_decorator`."""

    def __init__(self, model, im, file, optimize=False, dynamic=False, half=False, simplify=False):
        self.model = model
        self.im = im          # example input tensor used for tracing / shapes
        self.file = Path(file)
        self.optimize = optimize
        self.dynamic = dynamic
        self.half = half
        self.simplify = simplify
        self.checker = RequirementsChecker()
        self.workspace = 4    # TensorRT workspace budget (GB)

    @staticmethod
    def file_size(path):
        """Size of a file — or recursive size of a directory — in MB (0.0 if missing)."""
        path = Path(path)
        if path.is_file():
            return path.stat().st_size / 1e6
        elif path.is_dir():
            return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / 1e6
        else:
            return 0.0

    def export(self):
        raise NotImplementedError("Export method must be implemented in subclasses.")

    def __init_subclass__(cls, **kwargs):
        # Transparently wrap each concrete subclass's own `export`.
        super().__init_subclass__(**kwargs)
        if 'export' in cls.__dict__:
            cls.export = export_decorator(cls.export)
class ONNXExporter(BaseExporter):
    """Export the wrapped torch model to ONNX, optionally with a dynamic batch
    axis and onnx-simplifier post-processing."""

    required_packages = ("onnx>=1.16.1",)

    def export(self):
        out_path = self.file.with_suffix(".onnx")

        # Dynamic batch dimension on both input and output when requested.
        dynamic_axes = None
        if self.dynamic:
            dynamic_axes = {"images": {0: "batch"}, "output": {0: "batch"}}

        # Dynamic export must run on CPU tensors.
        export_model = self.model.cpu() if self.dynamic else self.model
        export_input = self.im.cpu() if self.dynamic else self.im

        torch.onnx.export(
            export_model,
            export_input,
            out_path,
            verbose=False,
            opset_version=12,
            do_constant_folding=True,
            input_names=["images"],
            output_names=["output"],
            dynamic_axes=dynamic_axes,
        )

        # Round-trip through onnx to validate the exported graph.
        checked = onnx.load(out_path)
        onnx.checker.check_model(checked)
        onnx.save(checked, out_path)

        if self.simplify:
            self.simplify_model(checked, out_path)

        return out_path

    def simplify_model(self, model_onnx, f):
        """Best-effort simplification with onnx-simplifier; failures are logged,
        never raised."""
        try:
            cuda = torch.cuda.is_available()
            self.checker.check_packages(
                (
                    "onnxruntime-gpu" if cuda else "onnxruntime",
                    "onnx-simplifier>=0.4.1",
                )
            )
            import onnxsim

            LOGGER.info(f"Simplifying with onnx-simplifier {onnxsim.__version__}...")
            model_onnx, check = onnxsim.simplify(model_onnx)
            assert check, "assert check failed"
            onnx.save(model_onnx, f)
        except Exception as e:
            LOGGER.error(f"Simplifier failure: {e}")
class OpenVINOExporter(BaseExporter):
    """Convert the previously exported ONNX model to OpenVINO IR (.xml/.bin),
    written into a sibling *_openvino_model/ directory."""

    required_packages = ("openvino-dev>=2023.0",)

    def export(self):
        # Output directory path, e.g. model_openvino_model/
        f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}")
        f_onnx = self.file.with_suffix(".onnx")
        f_ov = str(Path(f) / self.file.with_suffix(".xml").name)

        # FIX: the output directory was never created, so serializing the IR
        # into it failed on a fresh export; create it up front.
        Path(f).mkdir(parents=True, exist_ok=True)

        ov_model = mo.convert_model(
            f_onnx,
            model_name=self.file.with_suffix(".xml"),
            framework="onnx",
            compress_to_fp16=self.half,
        )
        ov.serialize(ov_model, f_ov)

        return f
class EngineExporter(BaseExporter):
    """Build a TensorRT .engine from the model via an intermediate ONNX export."""

    required_packages = ("nvidia-tensorrt",)
    cmds = '--extra-index-url https://pypi.ngc.nvidia.com'

    def export(self):
        assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. `python export.py --device 0`"
        try:
            import tensorrt as trt
        except ImportError as e:
            # BUGFIX: the handler previously re-ran the exact same failing
            # import; raise an actionable error instead.
            raise ImportError(
                "TensorRT is required for engine export: "
                "pip install nvidia-tensorrt --extra-index-url https://pypi.ngc.nvidia.com"
            ) from e

        onnx_file = self.export_onnx()
        LOGGER.info(f"\nStarting export with TensorRT {trt.__version__}...")
        is_trt10 = int(trt.__version__.split(".")[0]) >= 10  # is TensorRT >= 10
        assert onnx_file.exists(), f"Failed to export ONNX file: {onnx_file}"
        f = self.file.with_suffix(".engine")
        logger = trt.Logger(trt.Logger.INFO)
        # Always verbose (replaces a dead `if True:` guard in the original).
        logger.min_severity = trt.Logger.Severity.VERBOSE

        builder = trt.Builder(logger)
        config = builder.create_builder_config()
        workspace = int(self.workspace * (1 << 30))  # GB -> bytes
        if is_trt10:
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
        else:  # TensorRT versions 7, 8
            config.max_workspace_size = workspace

        flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        network = builder.create_network(flag)
        parser = trt.OnnxParser(network, logger)
        if not parser.parse_from_file(str(onnx_file)):
            raise RuntimeError(f"Failed to load ONNX file: {onnx_file}")

        inputs = [network.get_input(i) for i in range(network.num_inputs)]
        outputs = [network.get_output(i) for i in range(network.num_outputs)]
        LOGGER.info("Network Description:")
        for inp in inputs:
            LOGGER.info(f'\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
        for out in outputs:
            LOGGER.info(f'\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')

        if self.dynamic:
            if self.im.shape[0] <= 1:
                LOGGER.warning("WARNING: --dynamic model requires maximum --batch-size argument")
            profile = builder.create_optimization_profile()
            for inp in inputs:
                if self.half:
                    inp.dtype = trt.float16
                # min / opt / max shapes for the dynamic batch dimension
                profile.set_shape(
                    inp.name,
                    (1, *self.im.shape[1:]),
                    (max(1, self.im.shape[0] // 2), *self.im.shape[1:]),
                    self.im.shape,
                )
            config.add_optimization_profile(profile)

        LOGGER.info(f"Building FP{16 if builder.platform_has_fast_fp16 and self.half else 32} engine in {f}")
        if builder.platform_has_fast_fp16 and self.half:
            config.set_flag(trt.BuilderFlag.FP16)
        config.default_device_type = trt.DeviceType.GPU

        # TRT >= 10 returns an already-serialized engine; older versions
        # return an engine object that must be serialized here.
        build = builder.build_serialized_network if is_trt10 else builder.build_engine
        with build(network, config) as engine, open(f, "wb") as t:
            t.write(engine if is_trt10 else engine.serialize())

        return f

    def export_onnx(self):
        """Produce the intermediate ONNX file this engine build consumes."""
        onnx_exporter = ONNXExporter(self.model, self.im, self.file, self.optimize, self.dynamic, self.half, self.simplify)
        return onnx_exporter.export()
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

import pandas as pd


def export_formats():
    """Table of supported yolo-tracking export formats.

    Returns:
        pd.DataFrame: one row per format with columns
        Format / Argument / Suffix / CPU / GPU.
    """
    columns = ["Format", "Argument", "Suffix", "CPU", "GPU"]
    rows = [
        ["PyTorch", "-", ".pt", True, True],
        ["TorchScript", "torchscript", ".torchscript", True, True],
        ["ONNX", "onnx", ".onnx", True, True],
        ["OpenVINO", "openvino", "_openvino_model", True, False],
        ["TensorRT", "engine", ".engine", False, True],
        ["TensorFlow Lite", "tflite", ".tflite", True, False],
    ]
    return pd.DataFrame(rows, columns=columns)
class ReidAutoBackend():
    """Picks and wraps the right ReID inference backend (PyTorch, TorchScript,
    ONNX, TensorRT, OpenVINO or TFLite) based on the weights-file suffix."""

    def __init__(
        self,
        weights: Path = WEIGHTS / "osnet_x0_25_msmt17.pt",
        device: torch.device = torch.device("cpu"),
        half: bool = False) -> None:
        """
        Initializes the ReidAutoBackend instance with specified weights, device, and precision mode.

        Args:
            weights: Path to the model weights; if a list, the first element is used.
            device (torch.device): The device to run the model on, e.g., CPU or GPU.
            half (bool): Whether to use half precision for model inference.
        """
        super().__init__()
        w = weights[0] if isinstance(weights, list) else weights
        (
            self.pt,
            self.jit,
            self.onnx,
            self.xml,
            self.engine,
            self.tflite,
        ) = self.model_type(w)  # one boolean per supported export format

        self.weights = weights
        self.device = select_device(device)
        self.half = half
        self.model = self.get_backend()

    def get_backend(self):
        """Return an instance of the backend matching the detected model type.

        Raises:
            SystemExit: If no supported model framework is detected.
        """
        # Mapping of detection flags to backend constructors
        backend_map = {
            self.pt: PyTorchBackend,
            self.jit: TorchscriptBackend,
            self.onnx: ONNXBackend,
            self.engine: TensorRTBackend,
            self.xml: OpenVinoBackend,
            self.tflite: TFLiteBackend,
        }

        # Return the first backend whose flag is set
        for condition, backend_class in backend_map.items():
            if condition:
                return backend_class(self.weights, self.device, self.half)

        LOGGER.error("This model framework is not supported yet!")
        exit()

    def forward(self, im_batch: torch.Tensor) -> torch.Tensor:
        """Preprocess an image batch and extract features via the selected backend.

        Args:
            im_batch (torch.Tensor): The batch of images to process.

        Returns:
            torch.Tensor: The extracted feature batch.
        """
        # BUGFIX: this method referenced `self.backend`, which is never
        # assigned anywhere (the constructed backend is stored in
        # `self.model` by __init__), so every call raised AttributeError.
        im_batch = self.model.preprocess_input(im_batch)
        return self.model.get_features(im_batch)

    def check_suffix(self, file: Path = "osnet_x0_25_msmt17.pt", suffix: Union[str, Tuple[str, ...]] = (".pt",), msg: str = "") -> None:
        """Log an error for any given file whose suffix is not in `suffix`.

        Args:
            file: The file or files to check.
            suffix: Acceptable suffix or suffixes.
            msg (str): Additional message to log in case of an error.
        """
        suffix = [suffix] if isinstance(suffix, str) else list(suffix)
        files = [file] if isinstance(file, (str, Path)) else list(file)

        for f in files:
            file_suffix = Path(f).suffix.lower()
            if file_suffix and file_suffix not in suffix:
                LOGGER.error(f"File {f} does not have an acceptable suffix. Expected: {suffix}")

    def model_type(self, p: Path) -> Tuple[bool, ...]:
        """Determine the model type from the file name.

        Args:
            p: The file path to the model.

        Returns:
            Tuple[bool, ...]: Booleans in the order
            (pt, jit, onnx, xml, engine, tflite) telling which export suffix
            appears in the file name.
        """
        sf = list(export_formats().Suffix)  # export suffixes
        self.check_suffix(p, sf)  # checks
        types = [s in Path(p).name for s in sf]
        return types
+ + Returns: + Tuple[bool, ...]: A tuple of booleans indicating the model type, corresponding to pt, jit, onnx, xml, engine, and tflite. + """ + + sf = list(export_formats().Suffix) # export suffixes + self.check_suffix(p, sf) # checks + types = [s in Path(p).name for s in sf] + return types \ No newline at end of file diff --git a/boxmot/appearance/reid/config.py b/boxmot/appearance/reid/config.py new file mode 100644 index 0000000000000000000000000000000000000000..4293273904b2e7107dc97ce3a29ab9708cab71e7 --- /dev/null +++ b/boxmot/appearance/reid/config.py @@ -0,0 +1,73 @@ +MODEL_TYPES = [ + "resnet50", + "resnet101", + "mlfn", + "hacnn", + "mobilenetv2_x1_0", + "mobilenetv2_x1_4", + "osnet_x1_0", + "osnet_x0_75", + "osnet_x0_5", + "osnet_x0_25", + "osnet_ibn_x1_0", + "osnet_ain_x1_0", + "lmbn_n", + "clip", +] + +TRAINED_URLS = { + # resnet50 + "resnet50_market1501.pt": "https://drive.google.com/uc?id=1dUUZ4rHDWohmsQXCRe2C_HbYkzz94iBV", + "resnet50_dukemtmcreid.pt": "https://drive.google.com/uc?id=17ymnLglnc64NRvGOitY3BqMRS9UWd1wg", + "resnet50_msmt17.pt": "https://drive.google.com/uc?id=1ep7RypVDOthCRIAqDnn4_N-UhkkFHJsj", + "resnet50_fc512_market1501.pt": "https://drive.google.com/uc?id=1kv8l5laX_YCdIGVCetjlNdzKIA3NvsSt", + "resnet50_fc512_dukemtmcreid.pt": "https://drive.google.com/uc?id=13QN8Mp3XH81GK4BPGXobKHKyTGH50Rtx", + "resnet50_fc512_msmt17.pt": "https://drive.google.com/uc?id=1fDJLcz4O5wxNSUvImIIjoaIF9u1Rwaud", + # mlfn + "mlfn_market1501.pt": "https://drive.google.com/uc?id=1wXcvhA_b1kpDfrt9s2Pma-MHxtj9pmvS", + "mlfn_dukemtmcreid.pt": "https://drive.google.com/uc?id=1rExgrTNb0VCIcOnXfMsbwSUW1h2L1Bum", + "mlfn_msmt17.pt": "https://drive.google.com/uc?id=18JzsZlJb3Wm7irCbZbZ07TN4IFKvR6p-", + # hacnn + "hacnn_market1501.pt": "https://drive.google.com/uc?id=1LRKIQduThwGxMDQMiVkTScBwR7WidmYF", + "hacnn_dukemtmcreid.pt": "https://drive.google.com/uc?id=1zNm6tP4ozFUCUQ7Sv1Z98EAJWXJEhtYH", + "hacnn_msmt17.pt": 
"https://drive.google.com/uc?id=1MsKRtPM5WJ3_Tk2xC0aGOO7pM3VaFDNZ", + # mobilenetv2 + "mobilenetv2_x1_0_market1501.pt": "https://drive.google.com/uc?id=18DgHC2ZJkjekVoqBWszD8_Xiikz-fewp", + "mobilenetv2_x1_0_dukemtmcreid.pt": "https://drive.google.com/uc?id=1q1WU2FETRJ3BXcpVtfJUuqq4z3psetds", + "mobilenetv2_x1_0_msmt17.pt": "https://drive.google.com/uc?id=1j50Hv14NOUAg7ZeB3frzfX-WYLi7SrhZ", + "mobilenetv2_x1_4_market1501.pt": "https://drive.google.com/uc?id=1t6JCqphJG-fwwPVkRLmGGyEBhGOf2GO5", + "mobilenetv2_x1_4_dukemtmcreid.pt": "https://drive.google.com/uc?id=12uD5FeVqLg9-AFDju2L7SQxjmPb4zpBN", + "mobilenetv2_x1_4_msmt17.pt": "https://drive.google.com/uc?id=1ZY5P2Zgm-3RbDpbXM0kIBMPvspeNIbXz", + # osnet + "osnet_x1_0_market1501.pt": "https://drive.google.com/uc?id=1vduhq5DpN2q1g4fYEZfPI17MJeh9qyrA", + "osnet_x1_0_dukemtmcreid.pt": "https://drive.google.com/uc?id=1QZO_4sNf4hdOKKKzKc-TZU9WW1v6zQbq", + "osnet_x1_0_msmt17.pt": "https://drive.google.com/uc?id=112EMUfBPYeYg70w-syK6V6Mx8-Qb9Q1M", + "osnet_x0_75_market1501.pt": "https://drive.google.com/uc?id=1ozRaDSQw_EQ8_93OUmjDbvLXw9TnfPer", + "osnet_x0_75_dukemtmcreid.pt": "https://drive.google.com/uc?id=1IE3KRaTPp4OUa6PGTFL_d5_KQSJbP0Or", + "osnet_x0_75_msmt17.pt": "https://drive.google.com/uc?id=1QEGO6WnJ-BmUzVPd3q9NoaO_GsPNlmWc", + "osnet_x0_5_market1501.pt": "https://drive.google.com/uc?id=1PLB9rgqrUM7blWrg4QlprCuPT7ILYGKT", + "osnet_x0_5_dukemtmcreid.pt": "https://drive.google.com/uc?id=1KoUVqmiST175hnkALg9XuTi1oYpqcyTu", + "osnet_x0_5_msmt17.pt": "https://drive.google.com/uc?id=1UT3AxIaDvS2PdxzZmbkLmjtiqq7AIKCv", + "osnet_x0_25_market1501.pt": "https://drive.google.com/uc?id=1z1UghYvOTtjx7kEoRfmqSMu-z62J6MAj", + "osnet_x0_25_dukemtmcreid.pt": "https://drive.google.com/uc?id=1eumrtiXT4NOspjyEV4j8cHmlOaaCGk5l", + "osnet_x0_25_msmt17.pt": "https://drive.google.com/uc?id=1sSwXSUlj4_tHZequ_iZ8w_Jh0VaRQMqF", + # osnet_ain | osnet_ibn + "osnet_ibn_x1_0_msmt17.pt": 
"https://drive.google.com/uc?id=1q3Sj2ii34NlfxA4LvmHdWO_75NDRmECJ", + "osnet_ain_x1_0_msmt17.pt": "https://drive.google.com/uc?id=1SigwBE6mPdqiJMqhuIY4aqC7--5CsMal", + # lmbn + "lmbn_n_duke.pt": "https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_duke.pth", + "lmbn_n_market.pt": "https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_market.pth", + "lmbn_n_cuhk03_d.pt": "https://github.com/mikel-brostrom/yolov8_tracking/releases/download/v9.0/lmbn_n_cuhk03_d.pth", + # clip + "clip_market1501.pt": "https://drive.google.com/uc?id=1GnyAVeNOg3Yug1KBBWMKKbT2x43O5Ch7", + "clip_duke.pt": "https://drive.google.com/uc?id=1ldjSkj-7pXAWmx8on5x0EftlCaolU4dY", + "clip_veri.pt": "https://drive.google.com/uc?id=1RyfHdOBI2pan_wIGSim5-l6cM4S2WN8e", + "clip_vehicleid.pt": "https://drive.google.com/uc?id=168BLegHHxNqatW5wx1YyL2REaThWoof5" +} + +NR_CLASSES_DICT = { + "market1501": 751, + "duke": 702, + "veri": 576, + "vehicleid": 576, +} \ No newline at end of file diff --git a/boxmot/appearance/reid/export.py b/boxmot/appearance/reid/export.py new file mode 100644 index 0000000000000000000000000000000000000000..d6282d43d227f9cec679d485ded11849d01ad32f --- /dev/null +++ b/boxmot/appearance/reid/export.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +import argparse +import time +from pathlib import Path + +import torch + +from boxmot.appearance.exporters.base_exporter import BaseExporter +from boxmot.appearance.exporters.onnx_exporter import ONNXExporter +from boxmot.appearance.exporters.openvino_exporter import OpenVINOExporter +from boxmot.appearance.exporters.tflite_exporter import TFLiteExporter +from boxmot.appearance.exporters.torchscript_exporter import TorchScriptExporter +from boxmot.appearance.exporters.tensorrt_exporter import EngineExporter +from boxmot.appearance.reid import export_formats +from boxmot.appearance.reid.auto_backend import ReidAutoBackend +from boxmot.appearance.reid.registry import ReIDModelRegistry +from 
def parse_args():
    """Parse command-line arguments for the ReID export script."""
    p = argparse.ArgumentParser(description="ReID Export Script")
    p.add_argument("--batch-size", type=int, default=1,
                   help="Batch size for export")
    p.add_argument("--imgsz", "--img", "--img-size", nargs="+", type=int,
                   default=[256, 128],
                   help="Image size in the format: height width")
    p.add_argument("--device", default="cpu",
                   help="CUDA device (e.g., '0', '0,1,2,3', or 'cpu')")
    p.add_argument("--optimize", action="store_true",
                   help="Optimize TorchScript for mobile (CPU export only)")
    p.add_argument("--dynamic", action="store_true",
                   help="Enable dynamic axes for ONNX/TF/TensorRT export")
    p.add_argument("--simplify", action="store_true",
                   help="Simplify ONNX model")
    p.add_argument("--opset", type=int, default=12,
                   help="ONNX opset version")
    p.add_argument("--workspace", type=int, default=4,
                   help="TensorRT workspace size (GB)")
    p.add_argument("--verbose", action="store_true",
                   help="Enable verbose logging for TensorRT")
    p.add_argument("--weights", type=Path,
                   default=WEIGHTS / "osnet_x0_25_msmt17.pt",
                   help="Path to the model weights (.pt file)")
    p.add_argument("--half", action="store_true",
                   help="Enable FP16 half-precision export (GPU only)")
    p.add_argument("--include", nargs="+", default=["torchscript"],
                   help=("Export formats to include. Options: torchscript, onnx, "
                         "openvino, engine, tflite"))
    return p.parse_args()
def setup_model(args):
    """
    Initialize and prepare the ReID model for export.

    Args:
        args: Parsed command-line arguments.

    Returns:
        tuple: (model (torch.nn.Module), dummy_input (torch.Tensor))

    Raises:
        AssertionError: for --half on CPU or --optimize on CUDA.
    """
    # Select the correct device
    args.device = select_device(args.device)
    if args.half and args.device.type == "cpu":
        raise AssertionError("--half only compatible with GPU export, use --device 0 for GPU")

    # Initialize backend model using the auto backend
    auto_backend = ReidAutoBackend(weights=args.weights, device=args.device, half=args.half)
    _ = auto_backend.get_backend()  # Backend model is managed internally

    # Build and load the ReID model from the registry
    model_name = ReIDModelRegistry.get_model_name(args.weights)
    nr_classes = ReIDModelRegistry.get_nr_classes(args.weights)
    # Download pretrained weights only when no local .pt file was supplied.
    pretrained = not (args.weights and args.weights.is_file() and args.weights.suffix == ".pt")
    model = ReIDModelRegistry.build_model(
        model_name,
        num_classes=nr_classes,
        pretrained=pretrained,
        # NOTE(review): passes a torch.device where the name suggests a boolean
        # flag is expected — confirm against ReIDModelRegistry.build_model.
        use_gpu=args.device,
    ).to(args.device)
    ReIDModelRegistry.load_pretrained_weights(model, args.weights)
    model.eval()

    # Ensure --optimize is only used with CPU exports
    if args.optimize and args.device.type != "cpu":
        raise AssertionError("--optimize not compatible with CUDA devices, use --device cpu")

    # LMBN models use a taller input crop
    if "lmbn" in str(args.weights):
        args.imgsz = [384, 128]

    # FIX: was torch.empty — uninitialized memory may contain NaN/Inf, which
    # the warm-up passes below would feed straight through the model. Zeros
    # give a deterministic, finite warm-up input of the same shape.
    dummy_input = torch.zeros(args.batch_size, 3, args.imgsz[0], args.imgsz[1]).to(args.device)
    for _ in range(2):
        _ = model(dummy_input)  # warm up / materialize any lazy modules

    # Convert to half precision if required
    if args.half:
        dummy_input = dummy_input.half()
        model = model.half()

    return model, dummy_input
+ """ + exported_files = {} + for fmt, (flag, exporter_class, exp_args) in export_tasks.items(): + if flag: + exporter = exporter_class(*exp_args) + export_result = exporter.export() + exported_files[fmt] = export_result + return exported_files + + +def main(): + """Main function to execute the ReID export process.""" + args = parse_args() + start_time = time.time() + + # Ensure the weights directory exists + WEIGHTS.mkdir(parents=False, exist_ok=True) + + # Setup model and create a dummy input tensor + model, dummy_input = setup_model(args) + + # Log model output shape and file size + output = model(dummy_input) + output_tensor = output[0] if isinstance(output, tuple) else output + output_shape = tuple(output_tensor.shape) + LOGGER.info( + f"\nStarting from {args.weights} with output shape {output_shape} " + f"({BaseExporter.file_size(args.weights):.1f} MB)" + ) + + # Create export tasks + export_tasks = create_export_tasks(args, model, dummy_input) + + # Perform exports for enabled formats + exported_files = perform_exports(export_tasks) + + if exported_files: + elapsed_time = time.time() - start_time + LOGGER.info( + f"\nExport complete ({elapsed_time:.1f}s)" + f"\nResults saved to {args.weights.parent.resolve()}" + f"\nVisualize: https://netron.app" + ) + + +if __name__ == "__main__": + main() diff --git a/boxmot/appearance/reid/factory.py b/boxmot/appearance/reid/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..5bca3c39629adb1e17295a801526afdfeb2cb464 --- /dev/null +++ b/boxmot/appearance/reid/factory.py @@ -0,0 +1,40 @@ +from boxmot.appearance.backbones.clip.make_model import make_model +from boxmot.appearance.backbones.hacnn import HACNN +from boxmot.appearance.backbones.lmbn.lmbn_n import LMBN_n +from boxmot.appearance.backbones.mlfn import mlfn +from boxmot.appearance.backbones.mobilenetv2 import mobilenetv2_x1_0, mobilenetv2_x1_4 +from boxmot.appearance.backbones.osnet import ( + osnet_ibn_x1_0, + osnet_x0_5, + osnet_x0_25, 
+ osnet_x0_75, + osnet_x1_0, +) +from boxmot.appearance.backbones.osnet_ain import ( + osnet_ain_x0_5, + osnet_ain_x0_25, + osnet_ain_x0_75, + osnet_ain_x1_0, +) +from boxmot.appearance.backbones.resnet import resnet50, resnet101 + +# Map model names to their respective constructors +MODEL_FACTORY = { + "resnet50": resnet50, + "resnet101": resnet101, + "mobilenetv2_x1_0": mobilenetv2_x1_0, + "mobilenetv2_x1_4": mobilenetv2_x1_4, + "hacnn": HACNN, + "mlfn": mlfn, + "osnet_x1_0": osnet_x1_0, + "osnet_x0_75": osnet_x0_75, + "osnet_x0_5": osnet_x0_5, + "osnet_x0_25": osnet_x0_25, + "osnet_ibn_x1_0": osnet_ibn_x1_0, + "osnet_ain_x1_0": osnet_ain_x1_0, + "osnet_ain_x0_75": osnet_ain_x0_75, + "osnet_ain_x0_5": osnet_ain_x0_5, + "osnet_ain_x0_25": osnet_ain_x0_25, + "lmbn_n": LMBN_n, + "clip": make_model, +} \ No newline at end of file diff --git a/boxmot/appearance/reid/registry.py b/boxmot/appearance/reid/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc90bd3c60a2df1815bf84d1f336b7325481ee7 --- /dev/null +++ b/boxmot/appearance/reid/registry.py @@ -0,0 +1,87 @@ +# model_registry.py +import torch +from collections import OrderedDict +from boxmot.utils import logger as LOGGER + +from boxmot.appearance.reid.config import MODEL_TYPES, TRAINED_URLS, NR_CLASSES_DICT +from boxmot.appearance.reid.factory import MODEL_FACTORY + +class ReIDModelRegistry: + """Encapsulates model registration and related utilities.""" + + @staticmethod + def show_downloadable_models(): + LOGGER.info("Available .pt ReID models for automatic download") + LOGGER.info(list(TRAINED_URLS.keys())) + + @staticmethod + def get_model_name(model): + for name in MODEL_TYPES: + if name in model.name: + return name + return None + + @staticmethod + def get_model_url(model): + return TRAINED_URLS.get(model.name, None) + + @staticmethod + def load_pretrained_weights(model, weight_path): + """ + Loads pretrained weights into a model. 
+ Chooses the proper map_location based on CUDA availability. + """ + device = "cpu" if not torch.cuda.is_available() else None + checkpoint = torch.load(weight_path, map_location=torch.device("cpu") if device == "cpu" else None) + state_dict = checkpoint.get("state_dict", checkpoint) + model_dict = model.state_dict() + + if "lmbn" in weight_path.parts: + model.load_state_dict(model_dict, strict=True) + else: + new_state_dict = OrderedDict() + matched_layers, discarded_layers = [], [] + for k, v in state_dict.items(): + # Remove 'module.' prefix if present + key = k[7:] if k.startswith("module.") else k + if key in model_dict and model_dict[key].size() == v.size(): + new_state_dict[key] = v + matched_layers.append(key) + else: + discarded_layers.append(key) + model_dict.update(new_state_dict) + model.load_state_dict(model_dict) + + if not matched_layers: + LOGGER.debug(f"Pretrained weights from {weight_path} cannot be loaded. Check key names manually.") + else: + LOGGER.success(f"Loaded pretrained weights from {weight_path}") + + if discarded_layers: + LOGGER.debug(f"Discarded layers due to unmatched keys or size: {discarded_layers}") + + @staticmethod + def show_available_models(): + LOGGER.info("Available models:") + LOGGER.info(list(MODEL_FACTORY.keys())) + + @staticmethod + def get_nr_classes(weights): + # Extract dataset name from weights name, then look up in the class dictionary + dataset_key = weights.name.split('_')[1] + return NR_CLASSES_DICT.get(dataset_key, 1) + + @staticmethod + def build_model(name, num_classes, loss="softmax", pretrained=True, use_gpu=True): + if name not in MODEL_FACTORY: + available = list(MODEL_FACTORY.keys()) + raise KeyError(f"Unknown model '{name}'. 
Must be one of {available}") + + # Special case handling for clip model + if 'clip' in name: + from boxmot.appearance.backbones.clip.config.defaults import _C as cfg + return MODEL_FACTORY[name](cfg, num_class=num_classes, camera_num=2, view_num=1) + + return MODEL_FACTORY[name]( + num_classes=num_classes, loss=loss, pretrained=pretrained, use_gpu=use_gpu + ) diff --git a/boxmot/configs/__init__.py b/boxmot/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a4d0f1f9a635d5ccf0932bbf5c2069529a7ac --- /dev/null +++ b/boxmot/configs/__init__.py @@ -0,0 +1 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license diff --git a/boxmot/configs/boosttrack.yaml b/boxmot/configs/boosttrack.yaml new file mode 100644 index 0000000000000000000000000000000000000000..927616997297c226cc70f08affb0a10b05a6401a --- /dev/null +++ b/boxmot/configs/boosttrack.yaml @@ -0,0 +1,90 @@ +max_age: + type: uniform + default: 60 + range: [15, 90] + +min_hits: + type: uniform + default: 3 + range: [1, 5] + +det_thresh: + type: uniform + default: 0.6 + range: [0.1, 0.9] + +iou_threshold: + type: uniform + default: 0.3 + range: [0.1, 0.9] + +use_ecc: + type: choice + default: True + options: [False, True] + +min_box_area: + type: uniform + default: 10 + range: [5, 100] + +aspect_ratio_thresh: + type: uniform + default: 1.6 + range: [0.1, 2.0] + +lambda_iou: + type: uniform + default: 0.5 + range: [0.3, 2.0] + +lambda_mhd: + type: uniform + default: 0.25 + range: [0.5, 2.0] + +lambda_shape: + type: uniform + default: 0.25 + range: [0.5, 2.0] + +use_dlo_boost: + type: choice + default: True + options: [False, True] + +use_duo_boost: + type: choice + default: True + options: [False, True] + +dlo_boost_coef: + type: uniform + default: 0.65 + range: [0.3, 2.0] + +s_sim_corr: + type: choice + default: False + options: [False, True] + +use_rich_s: + type: choice + default: True # True for BoostTrack++ + options: [False, True] + +use_sb: + type: choice + default: True 
# True for BoostTrack++ + options: [False, True] + +use_vt: + type: choice + default: True # True for BoostTrack++ + options: [False, True] + +with_reid: + type: choice + default: True # True for BoostTrack+ and BoostTrack++ + options: [False, True] + diff --git a/boxmot/configs/botsort.yaml b/boxmot/configs/botsort.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4ff1115e492900e4ee9db48552eb40cba5ae62ef --- /dev/null +++ b/boxmot/configs/botsort.yaml @@ -0,0 +1,39 @@ +track_high_thresh: + type: uniform + default: 0.6 # from the default parameters + range: [0.3, 0.7] + +track_low_thresh: + type: uniform + default: 0.1 # from the default parameters + range: [0.1, 0.3] + +new_track_thresh: + type: uniform + default: 0.7 # from the default parameters + range: [0.1, 0.8] + +track_buffer: + type: randint + default: 30 # from the default parameters + range: [20, 81] + +match_thresh: + type: uniform + default: 0.8 # from the default parameters + range: [0.1, 0.9] + +proximity_thresh: + type: uniform + default: 0.5 # from the default parameters + range: [0.25, 0.75] + +appearance_thresh: + type: uniform + default: 0.25 # from the default parameters + range: [0.1, 0.8] + +cmc_method: + type: choice + default: ecc # from the default parameters + options: [sof, ecc] \ No newline at end of file diff --git a/boxmot/configs/bytetrack.yaml b/boxmot/configs/bytetrack.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6d64d79edb70389786357f82e956f8784015df0a --- /dev/null +++ b/boxmot/configs/bytetrack.yaml @@ -0,0 +1,24 @@ +min_conf: + type: uniform + default: 0.1 # from the default parameters + range: [0.1, 0.3] + +track_thresh: + type: uniform + default: 0.6 # from the default parameters + range: [0.4, 0.6] + +track_buffer: + type: randint + default: 30 # from the default parameters + range: [10, 61, 10] # step size of 10, upper bound exclusive + +match_thresh: + type: uniform + default: 0.9 # from the default parameters + range: 
[0.7, 0.9] + +frame_rate: + type: choice + default: 30 # from the default parameters + choices: [30] # static choice for Ray Search \ No newline at end of file diff --git a/boxmot/configs/deepocsort.yaml b/boxmot/configs/deepocsort.yaml new file mode 100644 index 0000000000000000000000000000000000000000..567386b5b7766e19395caf0c25ccc452025070c7 --- /dev/null +++ b/boxmot/configs/deepocsort.yaml @@ -0,0 +1,74 @@ +det_thresh: + type: uniform + default: 0.5 # from the default parameters + range: [0.3, 0.6] + +max_age: + type: randint + default: 30 # from the default parameters + range: [10, 61, 10] # step size of 10, upper bound exclusive + +min_hits: + type: randint + default: 3 # from the default parameters + range: [1, 6] # upper bound exclusive + +iou_thresh: + type: uniform + default: 0.3 # from the default parameters + range: [0.1, 0.4] + +delta_t: + type: randint + default: 3 # from the default parameters + range: [1, 6] # upper bound exclusive + +asso_func: + type: choice + default: iou # from the default parameters + options: ['iou', 'giou', 'diou', 'ciou', 'hmiou'] + +inertia: + type: uniform + default: 0.2 # from the default parameters + range: [0.1, 0.4] + +w_association_emb: + type: uniform + default: 0.75 # from the default parameters + range: [0.5, 0.9] + +alpha_fixed_emb: + type: uniform + default: 0.95 # from the default parameters + range: [0.9, 0.999] + +aw_param: + type: uniform + default: 0.5 # from the default parameters + range: [0.3, 0.7] + +embedding_off: + type: choice + default: false # from the default parameters + options: [True, False] + +cmc_off: + type: choice + default: false # from the default parameters + options: [True, False] + +aw_off: + type: choice + default: false # from the default parameters + options: [True, False] + +Q_xy_scaling: + type: uniform + default: 0.01 # from the default parameters + range: [0.01, 1] + +Q_s_scaling: + type: uniform + default: 0.0001 # from the default parameters + range: [0.0001, 1] diff --git 
a/boxmot/configs/hybridsort.yaml b/boxmot/configs/hybridsort.yaml new file mode 100644 index 0000000000000000000000000000000000000000..db72d502c555804af1c0dcc6a95ba2fa45035f5f --- /dev/null +++ b/boxmot/configs/hybridsort.yaml @@ -0,0 +1,49 @@ +det_thresh: + type: uniform + default: 0.12442660055370669 # from the default parameters + range: [0, 0.6] + +max_age: + type: randint + default: 30 # from the default parameters + range: [10, 151, 10] # step size of 10, upper bound exclusive + +min_hits: + type: randint + default: 1 # from the default parameters + range: [1, 6] # upper bound exclusive + +delta_t: + type: randint + default: 5 # from the default parameters + range: [1, 6] # upper bound exclusive + +asso_func: + type: choice + default: hmiou # from the default parameters + options: ['iou', 'giou', 'diou'] + +iou_threshold: + type: uniform + default: 0.3 # from the default parameters + range: [0.1, 0.4] + +inertia: + type: uniform + default: 0.369525477649008 # from the default parameters + range: [0.1, 0.4] + +TCM_first_step_weight: + type: uniform + default: 0.2866529225304586 # from the default parameters + range: [0, 0.5] + +longterm_reid_weight: + type: uniform + default: 0.0509704360503877 # from the default parameters + range: [0, 0.5] + +use_byte: + type: choice + default: False # from the default parameters + options: [True, False] \ No newline at end of file diff --git a/boxmot/configs/imprassoc.yaml b/boxmot/configs/imprassoc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5bc3f2f813f7fd6b0c019afdabc1dbbb8f2b6e7b --- /dev/null +++ b/boxmot/configs/imprassoc.yaml @@ -0,0 +1,59 @@ +track_high_thresh: + type: uniform + default: 0.5 # from the default parameters + range: [0.3, 0.7] + +track_low_thresh: + type: uniform + default: 0.1 # from the default parameters + range: [0.05, 0.3] + +new_track_thresh: + type: uniform + default: 0.5 # from the default parameters + range: [0.5, 0.9] + +track_buffer: + type: qrandint + default: 
35 # from the default parameters + range: [20, 80, 10] # step size of 10, upper bound exclusive + +match_thresh: + type: uniform + default: 0.65 # from the default parameters + range: [0.1, 0.9] + +second_match_thresh: + type: uniform + default: 0.19 # from the default parameters + range: [0.1, 0.4] + +overlap_thresh: + type: uniform + default: 0.55 # from the default parameters + range: [0.3, 0.6] + +proximity_thresh: + type: uniform + default: 0.1 # from the default parameters + range: [0.1, 0.8] + +appearance_thresh: + type: uniform + default: 0.25 # from the default parameters + range: [0.1, 0.8] + +cmc_method: + type: choice + default: sparseOptFlow # from the default parameters + options: ['sparseOptFlow'] + +frame_rate: + type: choice + default: 30 # from the default parameters + options: [30] + +lambda_: + type: uniform + default: 0.05 # from the default parameters + range: [0.05, 0.3] \ No newline at end of file diff --git a/boxmot/configs/ocsort.yaml b/boxmot/configs/ocsort.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d11c952aa3b2df8b24acc02c020502a160bfa1c7 --- /dev/null +++ b/boxmot/configs/ocsort.yaml @@ -0,0 +1,49 @@ +min_conf: + type: uniform + default: 0.1 # from the default parameters + range: [0.1, 0.3] + +det_thresh: + type: uniform + default: 0.6 # from the default parameters + range: [0, 0.6] + +max_age: + type: grid_search + default: 30 # from the default parameters + values: [10, 20, 30, 40, 50, 60] + +min_hits: + type: grid_search + default: 3 # from the default parameters + values: [1, 2, 3, 4, 5] + +delta_t: + type: grid_search + default: 3 # from the default parameters + values: [1, 2, 3, 4, 5] + +asso_func: + type: choice + default: iou # from the default parameters + options: ['iou', 'giou', 'diou', 'ciou', 'hmiou'] + +use_byte: + type: choice + default: false # from the default parameters + options: [True, False] + +inertia: + type: uniform + default: 0.1 # from the default parameters + range: [0.1, 0.4] + 
+Q_xy_scaling: + type: loguniform + default: 0.01 # from the default parameters + range: [0.01, 1] + +Q_s_scaling: + type: loguniform + default: 0.0001 # from the default parameters + range: [0.0001, 1] diff --git a/boxmot/configs/strongsort.yaml b/boxmot/configs/strongsort.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a7982f10982635324ea3ac89781c2292d48ec326 --- /dev/null +++ b/boxmot/configs/strongsort.yaml @@ -0,0 +1,39 @@ +min_conf: + type: uniform + default: 0.6 # from the default parameters + range: [0.2, 0.8] + +ema_alpha: + type: uniform + default: 0.9 # from the default parameters + range: [0.7, 0.95] + +max_cos_dist: + type: uniform + default: 0.4 # from the default parameters + range: [0.1, 0.4] + +max_iou_dist: + type: uniform + default: 0.7 # from the default parameters + range: [0.5, 0.95] + +max_age: + type: randint + default: 30 # from the default parameters + range: [10, 151] # upper bound exclusive + +n_init: + type: randint + default: 3 # from the default parameters + range: [1, 4] # upper bound exclusive + +mc_lambda: + type: uniform + default: 0.98 # from the default parameters + range: [0.90, 0.999] + +nn_budget: + type: choice + default: 100 # from the default parameters + options: [100] diff --git a/boxmot/data/__init__.py b/boxmot/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a4d0f1f9a635d5ccf0932bbf5c2069529a7ac --- /dev/null +++ b/boxmot/data/__init__.py @@ -0,0 +1 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license diff --git a/boxmot/data/loader.py b/boxmot/data/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..ef85db1d08aa91475c9316e04f3d860f30674691 --- /dev/null +++ b/boxmot/data/loader.py @@ -0,0 +1,128 @@ +import os +import cv2 +import glob +import math +import numpy as np +from pathlib import Path +from PIL import Image + + +VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv" # include video 
suffixes + + +class LoadImagesAndVideos: + """ + A data loader for handling both images and videos, providing batches of frames or images for processing. + Supports various image formats, including HEIC, and handles text files with paths to images/videos. + """ + + def __init__(self, path, batch_size=1, vid_stride=1): + self.batch_size = batch_size + self.vid_stride = vid_stride + self.files = self._load_files(path) + self.video_flag = [self._is_video(f) for f in self.files] + self.nf = len(self.files) + self.ni = sum(not is_video for is_video in self.video_flag) + self.mode = "image" + + self.cap = None + if any(self.video_flag): + self._start_video(self.files[self.video_flag.index(True)]) + + if not self.files: + raise FileNotFoundError(f"No images or videos found in {path}.") + + def _load_files(self, path): + """Load files from a given path, which may be a directory, list, or text file.""" + if isinstance(path, str) and Path(path).suffix == ".txt": + path = Path(path).read_text().splitlines() + + files = [] + for p in sorted(path) if isinstance(path, (list, tuple)) else [path]: + p = str(Path(p).absolute()) + if "*" in p: + files.extend(glob.glob(p, recursive=True)) + elif os.path.isdir(p): + files.extend(glob.glob(os.path.join(p, "*.*"))) + elif os.path.isfile(p): + files.append(p) + else: + raise FileNotFoundError(f"{p} does not exist") + return files + + def _is_video(self, file_path): + """Check if a file is a video based on its extension.""" + return file_path.split('.')[-1].lower() in VID_FORMATS + + def __iter__(self): + self.count = 0 + return self + + def __next__(self): + paths, imgs, infos = [], [], [] + while len(imgs) < self.batch_size: + if self.count >= self.nf: + if imgs: + return paths, imgs, infos + else: + raise StopIteration + + path = self.files[self.count] + if self.video_flag[self.count]: + self._process_video(paths, imgs, infos, path) + else: + self._process_image(paths, imgs, infos, path) + self.count += 1 + + return paths, imgs, infos 
+ + def _process_image(self, paths, imgs, infos, path): + """Process an image file and append it to the batch.""" + img = self._read_image(path) + if img is not None: + paths.append(path) + imgs.append(img) + infos.append(f"image {self.count + 1}/{self.nf} {path}") + + def _process_video(self, paths, imgs, infos, path): + """Process a video file, reading frames as per the stride.""" + self.mode = "video" + if not self.cap or not self.cap.isOpened(): + self._start_video(path) + + success = False + for _ in range(self.vid_stride): + success = self.cap.grab() + if not success: + break + + if success: + _, frame = self.cap.retrieve() + paths.append(path) + imgs.append(frame) + infos.append(f"video {self.count + 1}/{self.nf} frame {self.frame}/{self.frames} {path}") + self.frame += 1 + if self.frame >= self.frames: + self.cap.release() + + def _read_image(self, path): + """Read an image from a file, handling HEIC format if necessary.""" + if path.lower().endswith("heic"): + from pillow_heif import register_heif_opener + register_heif_opener() + with Image.open(path) as img: + return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + else: + return cv2.imread(path) + + def _start_video(self, path): + """Initialize video capture for a new video file.""" + self.cap = cv2.VideoCapture(path) + if not self.cap.isOpened(): + raise FileNotFoundError(f"Failed to open video {path}") + self.fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride) + self.frame = 0 + + def __len__(self): + return math.ceil(self.nf / self.batch_size) diff --git a/boxmot/motion/__init__.py b/boxmot/motion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a4d0f1f9a635d5ccf0932bbf5c2069529a7ac --- /dev/null +++ b/boxmot/motion/__init__.py @@ -0,0 +1 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license diff --git a/boxmot/motion/cmc/__init__.py b/boxmot/motion/cmc/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..594cef7df58177bceb5e815966bcbe9d60467096 --- /dev/null +++ b/boxmot/motion/cmc/__init__.py @@ -0,0 +1,19 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +from boxmot.motion.cmc.ecc import ECC +from boxmot.motion.cmc.orb import ORB +from boxmot.motion.cmc.sift import SIFT +from boxmot.motion.cmc.sof import SOF + + +def get_cmc_method(cmc_method): + if cmc_method == 'ecc': + return ECC + elif cmc_method == 'orb': + return ORB + elif cmc_method == 'sof': + return SOF + elif cmc_method == 'sift': + return SIFT + else: + return None diff --git a/boxmot/motion/cmc/base_cmc.py b/boxmot/motion/cmc/base_cmc.py new file mode 100644 index 0000000000000000000000000000000000000000..ebb48291bc14a4a517608eb9f7b8a3677d0ef722 --- /dev/null +++ b/boxmot/motion/cmc/base_cmc.py @@ -0,0 +1,42 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import cv2 +import numpy as np +from abc import ABC, abstractmethod + + +class BaseCMC(ABC): + + @abstractmethod + def apply(self, im): + pass + + def generate_mask(self, img, dets, scale): + h, w = img.shape + mask = np.zeros_like(img) + + mask[int(0.02 * h): int(0.98 * h), int(0.02 * w): int(0.98 * w)] = 255 + if dets is not None: + for det in dets: + tlbr = np.multiply(det, scale).astype(int) + mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0 + + return mask + + def preprocess(self, img): + + # bgr2gray + if self.grayscale: + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + # resize + if self.scale is not None: + img = cv2.resize( + img, + (0, 0), + fx=self.scale, + fy=self.scale, + interpolation=cv2.INTER_LINEAR + ) + + return img diff --git a/boxmot/motion/cmc/ecc.py b/boxmot/motion/cmc/ecc.py new file mode 100644 index 0000000000000000000000000000000000000000..7a02fce5ee884dcdbb0d47156ee5f26e1630c04b --- /dev/null +++ b/boxmot/motion/cmc/ecc.py @@ -0,0 +1,143 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import time + +import cv2 +import numpy as np + +from 
boxmot.motion.cmc.base_cmc import BaseCMC +from boxmot.utils import BOXMOT +from boxmot.utils import logger as LOGGER + + +class ECC(BaseCMC): + def __init__( + self, + warp_mode: int = cv2.MOTION_EUCLIDEAN, + eps: float = 1e-5, + max_iter: int = 100, + scale: float = 0.1, + align: bool = False, + grayscale: bool = True + ) -> None: + """Compute the warp matrix from src to dst. + + Parameters + ---------- + warp_mode: opencv flag + translation: cv2.MOTION_TRANSLATION + rotated and shifted: cv2.MOTION_EUCLIDEAN + affine(shift,rotated,shear): cv2.MOTION_AFFINE + homography(3d): cv2.MOTION_HOMOGRAPHY + eps: float + the threshold of the increment in the correlation coefficient between two iterations + max_iter: int + the number of iterations. + scale: float or [int, int] + scale_ratio: float + scale_size: [W, H] + align: bool + whether to warp affine or perspective transforms to the source image + grayscale: bool + whether to transform 3 channel RGB to single channel grayscale for faster computations + + Returns + ------- + warp matrix : ndarray + Returns the warp matrix from src to dst. + if motion models is homography, the warp matrix will be 3x3, otherwise 2x3 + src_aligned: ndarray + aligned source image of gray + """ + self.align = align + self.grayscale = grayscale + self.scale = scale + self.warp_mode = warp_mode + self.termination_criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, max_iter, eps) + self.prev_img = None + + def apply(self, img: np.ndarray, dets: np.ndarray = None) -> np.ndarray: + """Apply sparse optical flow to compute the warp matrix. + + Parameters: + img (ndarray): The input image. + dets: Description of dets parameter. + + Returns: + ndarray: The warp matrix from the source to the destination. + If the motion model is homography, the warp matrix will be 3x3; otherwise, it will be 2x3. 
+ """ + + if self.warp_mode == cv2.MOTION_HOMOGRAPHY: + warp_matrix = np.eye(3, 3, dtype=np.float32) + else: + warp_matrix = np.eye(2, 3, dtype=np.float32) + + if self.prev_img is None: + self.prev_img = self.preprocess(img) + return warp_matrix + + img = self.preprocess(img) + + try: + (ret_val, warp_matrix) = cv2.findTransformECC( + self.prev_img, # already processed + img, + warp_matrix, + self.warp_mode, + self.termination_criteria, + None, + 1 + ) + except Exception as e: + LOGGER.warning(f'Affine matrix could not be generated: {e}. Returning identity') + return warp_matrix + + # upscale warp matrix to original images size + if self.scale < 1: + warp_matrix[0, 2] /= self.scale + warp_matrix[1, 2] /= self.scale + + if self.align: + h, w = self.prev_img.shape + if self.warp_mode == cv2.MOTION_HOMOGRAPHY: + # Use warpPerspective for Homography + self.prev_img_aligned = cv2.warpPerspective(self.prev_img, warp_matrix, (w, h), flags=cv2.INTER_LINEAR) + else: + # Use warpAffine for Translation, Euclidean and Affine + self.prev_img_aligned = cv2.warpAffine(self.prev_img, warp_matrix, (w, h), flags=cv2.INTER_LINEAR) + else: + self.prev_img_aligned = None + + self.prev_img = img + + return warp_matrix # , prev_img_aligned + + +def main(): + ecc = ECC(scale=0.5, align=True, grayscale=True) + curr_img = cv2.imread('assets/MOT17-mini/train/MOT17-2-FRCNN/img1/000005.jpg') + prev_img = cv2.imread('assets/MOT17-mini/train/MOT17-2-FRCNN/img1/000001.jpg') + + warp_matrix = ecc.apply(prev_img, None) + warp_matrix = ecc.apply(curr_img, None) + + start = time.process_time() + for i in range(0, 100): + warp_matrix = ecc.apply(prev_img, None) + warp_matrix = ecc.apply(curr_img, None) + end = time.process_time() + print('Total time', end - start) + print(warp_matrix) + + if ecc.prev_img_aligned is not None: + curr_img = ecc.preprocess(curr_img) + prev_img = ecc.preprocess(prev_img) + weighted_img = cv2.addWeighted(curr_img, 0.5, ecc.prev_img_aligned, 0.5, 0) + 
cv2.imshow('prev_img_aligned', weighted_img) + cv2.waitKey(0) + cv2.imwrite(str(BOXMOT / 'motion/cmc/ecc_aligned.jpg'), weighted_img) + + +if __name__ == "__main__": + main() diff --git a/boxmot/motion/cmc/orb.py b/boxmot/motion/cmc/orb.py new file mode 100644 index 0000000000000000000000000000000000000000..841b9a7e032f05d637ecdc41a393836962d8e41f --- /dev/null +++ b/boxmot/motion/cmc/orb.py @@ -0,0 +1,251 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import copy +import time + +import cv2 +import numpy as np + +from boxmot.motion.cmc.base_cmc import BaseCMC +from boxmot.utils import BOXMOT + + +class ORB(BaseCMC): + + def __init__( + self, + feature_detector_threshold: int = 20, + matcher_norm_type: int = cv2.NORM_HAMMING, + scale: float = 0.1, + grayscale: bool = True, + draw_keypoint_matches: bool = False, + align: bool = False + ) -> None: + """Compute the warp matrix from src to dst. + + Parameters + ---------- + feature_detector_threshold: int, optional + The threshold for feature extraction. Defaults to 20. + matcher_norm_type: int, optional + The norm type of the matcher. Defaults to cv2.NORM_HAMMING. + scale: float, optional + Scale ratio. Defaults to 0.1. + grayscale: bool, optional + Whether to transform 3-channel RGB to single-channel grayscale for faster computations. + Defaults to True. + draw_keypoint_matches: bool, optional + Whether to draw keypoint matches on the output image. Defaults to False. + align: bool, optional + Whether to align the images based on keypoint matches. Defaults to False. 
+ """ + self.grayscale = grayscale + self.scale = scale + + self.detector = cv2.FastFeatureDetector_create(threshold=feature_detector_threshold) + self.extractor = cv2.ORB_create() + self.matcher = cv2.BFMatcher(matcher_norm_type) + + self.prev_img = None + self.draw_keypoint_matches = draw_keypoint_matches + self.align = align + + def apply(self, img: np.ndarray, dets: np.ndarray) -> np.ndarray: + """Apply ORB-based sparse optical flow to compute the warp matrix. + + Parameters + ---------- + img : ndarray + The input image. + dets : ndarray + Detected bounding boxes in the image. + + Returns + ------- + ndarray + The warp matrix from the matching keypoint in the previous image to the current. + The warp matrix is always 2x3. + """ + + H = np.eye(2, 3) + + img = self.preprocess(img) + h, w = img.shape + + # generate dynamic object maks + mask = self.generate_mask(img, dets, self.scale) + + # find static keypoints + keypoints = self.detector.detect(img, mask) + + # compute the descriptors + keypoints, descriptors = self.extractor.compute(img, keypoints) + + # handle first frame + if self.prev_img is None: + # Initialize data + self.prev_dets = dets.copy() + self.prev_img = img.copy() + self.prev_keypoints = copy.copy(keypoints) + self.prev_descriptors = copy.copy(descriptors) + + return H + + # Match descriptors. 
+ knnMatches = self.matcher.knnMatch(self.prev_descriptors, descriptors, k=2) + + # Handle empty matches case + if len(knnMatches) == 0: + # Store to next iteration + self.prev_img = img.copy() + self.prev_keypoints = copy.copy(keypoints) + self.prev_descriptors = copy.copy(descriptors) + + return H + + # filtered matches based on smallest spatial distance + matches = [] + spatial_distances = [] + max_spatial_distance = 0.25 * np.array([w, h]) + + for m, n in knnMatches: + if m.distance < 0.9 * n.distance: + prevKeyPointLocation = self.prev_keypoints[m.queryIdx].pt + currKeyPointLocation = keypoints[m.trainIdx].pt + + spatial_distance = (prevKeyPointLocation[0] - currKeyPointLocation[0], + prevKeyPointLocation[1] - currKeyPointLocation[1]) + + if (np.abs(spatial_distance[0]) < max_spatial_distance[0]) and \ + (np.abs(spatial_distance[1]) < max_spatial_distance[1]): + spatial_distances.append(spatial_distance) + matches.append(m) + + mean_spatial_distances = np.mean(spatial_distances, 0) + std_spatial_distances = np.std(spatial_distances, 0) + + inliesrs = (spatial_distances - mean_spatial_distances) < 2.5 * std_spatial_distances + + goodMatches = [] + prevPoints = [] + currPoints = [] + for i in range(len(matches)): + if inliesrs[i, 0] and inliesrs[i, 1]: + goodMatches.append(matches[i]) + prevPoints.append(self.prev_keypoints[matches[i].queryIdx].pt) + currPoints.append(keypoints[matches[i].trainIdx].pt) + + prevPoints = np.array(prevPoints) + currPoints = np.array(currPoints) + + # draw keypoint matches on the output image + if self.draw_keypoint_matches: + self.prev_img[:, :][mask == True] = 0 # noqa:E712 + self.matches_img = np.hstack((self.prev_img, img)) + self.matches_img = cv2.cvtColor(self.matches_img, cv2.COLOR_GRAY2BGR) + + W = np.size(self.prev_img, 1) + for m in goodMatches: + prev_pt = np.array(self.prev_keypoints[m.queryIdx].pt, dtype=np.int_) + curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_) + curr_pt[0] += W + color = 
np.random.randint(0, 255, (3,)) + color = (int(color[0]), int(color[1]), int(color[2])) + self.matches_img = cv2.line(self.matches_img, prev_pt, curr_pt, tuple(color), 1, cv2.LINE_AA) + self.matches_img = cv2.circle(self.matches_img, prev_pt, 2, tuple(color), -1) + self.matches_img = cv2.circle(self.matches_img, curr_pt, 2, tuple(color), -1) + for det in dets: + det = np.multiply(det, self.scale).astype(int) + start = (det[0] + w, det[1]) + end = (det[2] + w, det[3]) + self.matches_img = cv2.rectangle(self.matches_img, start, end, (0, 0, 255), 2) + for det in self.prev_dets: + det = np.multiply(det, self.scale).astype(int) + start = (det[0], det[1]) + end = (det[2], det[3]) + self.matches_img = cv2.rectangle(self.matches_img, start, end, (0, 0, 255), 2) + else: + self.matches_img = None + + # find rigid matrix + if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)): + H, inliesrs = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC) + + # upscale warp matrix to original images size + if self.scale < 1.0: + H[0, 2] /= self.scale + H[1, 2] /= self.scale + + if self.align: + self.prev_img_aligned = cv2.warpAffine(self.prev_img, H, (w, h), flags=cv2.INTER_LINEAR) + else: + print('Warning: not enough matching points') + + # Store to next iteration + self.prev_img = img.copy() + self.prev_keypoints = copy.copy(keypoints) + self.prev_descriptors = copy.copy(descriptors) + + return H + + +def main(): + orb = ORB(scale=0.5, align=True, grayscale=True, draw_keypoint_matches=False) + curr_img = cv2.imread('assets/MOT17-mini/train/MOT17-13-FRCNN/img1/000005.jpg') + prev_img = cv2.imread('assets/MOT17-mini/train/MOT17-13-FRCNN/img1/000001.jpg') + curr_dets = np.array( + [[1083.8207, 541.5978, 1195.7952, 655.8790], # noqa:E241 + [1635.6456, 563.8348, 1695.4153, 686.6704], # noqa:E241 + [ 957.0879, 545.6558, 1042.6743, 611.8740], # noqa:E241,E261,E201 + [1550.0317, 562.5705, 1600.3931, 684.7425], # noqa:E241 + [ 78.8801, 714.3307, 
121.0272, 817.6857], # noqa:E241,E261,E201 + [1382.9938, 512.2731, 1418.6012, 620.1938], # noqa:E241 + [1459.7921, 496.2123, 1488.5767, 584.3533], # noqa:E241 + [ 982.9818, 492.8579, 1013.6625, 517.9271], # noqa:E241,E261,E201 + [ 496.1809, 541.3972, 531.4617, 638.0989], # noqa:E241,E261,E201 + [1498.8512, 522.6646, 1526.1145, 587.7672], # noqa:E241 + [ 536.4527, 548.4061, 569.2723, 635.5656], # noqa:E241,E261,E201 + [ 247.8834, 580.8851, 287.2241, 735.3685], # noqa:E241,E261,E201 + [ 151.4096, 572.3918, 203.5401, 731.1011], # noqa:E241,E261,E201 + [1227.4098, 440.5505, 1252.7986, 489.5295]] # noqa:E241 + ) + prev_dets = np.array( + [[2.1069e-02, 6.7026e+02, 4.9816e+01, 8.8407e+02], + [1.0765e+03, 5.4009e+02, 1.1883e+03, 6.5219e+02], + [1.5208e+03, 5.6322e+02, 1.5711e+03, 6.7676e+02], + [1.6111e+03, 5.5926e+02, 1.6640e+03, 6.7443e+02], + [9.5244e+02, 5.4681e+02, 1.0384e+03, 6.1180e+02], + [1.3691e+03, 5.1258e+02, 1.4058e+03, 6.1695e+02], + [1.2043e+02, 7.0780e+02, 1.7309e+02, 8.0518e+02], + [1.4454e+03, 5.0919e+02, 1.4724e+03, 5.8270e+02], + [9.7848e+02, 4.9563e+02, 1.0083e+03, 5.1980e+02], + [5.0166e+02, 5.4778e+02, 5.3796e+02, 6.3940e+02], + [1.4777e+03, 5.1856e+02, 1.5105e+03, 5.9523e+02], + [1.9540e+02, 5.7292e+02, 2.3711e+02, 7.2717e+02], + [2.7373e+02, 5.8564e+02, 3.1335e+02, 7.3281e+02], + [5.4038e+02, 5.4735e+02, 5.7359e+02, 6.3797e+02], + [1.2190e+03, 4.4176e+02, 1.2414e+03, 4.9038e+02]] + ) + + warp_matrix = orb.apply(prev_img, prev_dets) + warp_matrix = orb.apply(curr_img, curr_dets) + + start = time.process_time() + for i in range(0, 100): + warp_matrix = orb.apply(prev_img, prev_dets) + warp_matrix = orb.apply(curr_img, curr_dets) + end = time.process_time() + print('Total time', end - start) + print(warp_matrix) + + if orb.prev_img_aligned is not None: + curr_img = orb.preprocess(curr_img) + prev_img = orb.preprocess(prev_img) + weighted_img = cv2.addWeighted(curr_img, 0.5, orb.prev_img_aligned, 0.5, 0) + cv2.imshow('prev_img_aligned', weighted_img) + 
cv2.waitKey(0) + cv2.imwrite(str(BOXMOT / 'motion/cmc/orb_aligned.jpg'), weighted_img) + + +if __name__ == "__main__": + main() diff --git a/boxmot/motion/cmc/sift.py b/boxmot/motion/cmc/sift.py new file mode 100644 index 0000000000000000000000000000000000000000..7c80e0a90fbb5dc60d4dc55ddfa4900e158fb1b5 --- /dev/null +++ b/boxmot/motion/cmc/sift.py @@ -0,0 +1,264 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import copy +import time + +import cv2 +import numpy as np + +from boxmot.motion.cmc.base_cmc import BaseCMC +from boxmot.utils import BOXMOT + + +class SIFT(BaseCMC): + + def __init__( + self, + warp_mode=cv2.MOTION_EUCLIDEAN, + eps=1e-5, + max_iter=100, + scale=0.1, + grayscale=True, + draw_keypoint_matches=False, + align=False + ): + """Compute the warp matrix from src to dst. + + Parameters + ---------- + warp_mode: opencv flag + translation: cv2.MOTION_TRANSLATION + rotated and shifted: cv2.MOTION_EUCLIDEAN + affine(shift,rotated,shear): cv2.MOTION_AFFINE + homography(3d): cv2.MOTION_HOMOGRAPHY + eps: float + the threshold of the increment in the correlation coefficient between two iterations + max_iter: int + the number of iterations. + scale: float or [int, int] + scale_ratio: float + scale_size: [W, H] + align: bool + whether to warp affine or perspective transforms to the source image + grayscale: bool + whether to transform 3 channel RGB to single channel grayscale for faster computations + + Returns + ------- + warp matrix : ndarray + Returns the warp matrix from src to dst. 
+ if motion models is homography, the warp matrix will be 3x3, otherwise 2x3 + src_aligned: ndarray + aligned source image of gray + """ + self.grayscale = grayscale + self.scale = scale + self.warp_mode = warp_mode + self.termination_criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, max_iter, eps) + if self.warp_mode == cv2.MOTION_HOMOGRAPHY: + self.warp_matrix = np.eye(3, 3, dtype=np.float32) + else: + self.warp_matrix = np.eye(2, 3, dtype=np.float32) + + self.detector = cv2.SIFT_create(nOctaveLayers=2, contrastThreshold=0.5, edgeThreshold=10) + self.extractor = cv2.SIFT_create(nOctaveLayers=2, contrastThreshold=0.5, edgeThreshold=10) + self.matcher = cv2.BFMatcher(cv2.NORM_L2) + + self.prev_img = None + self.minimum_features = 10 + self.prev_desc = None + self.draw_keypoint_matches = draw_keypoint_matches + self.align = align + + def apply(self, img: np.ndarray, dets: np.ndarray) -> np.ndarray: + """Apply ORB-based sparse optical flow to compute the warp matrix. + + Parameters + ---------- + img : ndarray + The input image. + dets : ndarray + Detected bounding boxes in the image. + + Returns + ------- + ndarray + The warp matrix from the matching keypoint in the previous image to the current. + The warp matrix is always 2x3. + """ + + H = np.eye(2, 3) + + img = self.preprocess(img) + h, w = img.shape + + # generate dynamic object maks + mask = self.generate_mask(img, dets, self.scale) + + # find static keypoints + keypoints = self.detector.detect(img, mask) + + # compute the descriptors + keypoints, descriptors = self.extractor.compute(img, keypoints) + + # handle first frame + if self.prev_img is None: + # Initialize data + self.prev_dets = dets.copy() + self.prev_img = img.copy() + self.prev_keypoints = copy.copy(keypoints) + self.prev_descriptors = copy.copy(descriptors) + + return H + + # Match descriptors. 
+ knnMatches = self.matcher.knnMatch(self.prev_descriptors, descriptors, k=2) + + # Handle empty matches case + if len(knnMatches) == 0: + # Store to next iteration + self.prev_img = img.copy() + self.prev_keypoints = copy.copy(keypoints) + self.prev_descriptors = copy.copy(descriptors) + + return H + + # filtered matches based on smallest spatial distance + matches = [] + spatial_distances = [] + max_spatial_distance = 0.25 * np.array([w, h]) + + for m, n in knnMatches: + if m.distance < 0.9 * n.distance: + prevKeyPointLocation = self.prev_keypoints[m.queryIdx].pt + currKeyPointLocation = keypoints[m.trainIdx].pt + + spatial_distance = (prevKeyPointLocation[0] - currKeyPointLocation[0], + prevKeyPointLocation[1] - currKeyPointLocation[1]) + + if (np.abs(spatial_distance[0]) < max_spatial_distance[0]) and \ + (np.abs(spatial_distance[1]) < max_spatial_distance[1]): + spatial_distances.append(spatial_distance) + matches.append(m) + + mean_spatial_distances = np.mean(spatial_distances, 0) + std_spatial_distances = np.std(spatial_distances, 0) + + inliers = (spatial_distances - mean_spatial_distances) < 2.5 * std_spatial_distances + + good_matches = [matches[i] for i in range(len(matches)) if np.all(inliers[i])] + + prevPoints = np.array([self.prev_keypoints[m.queryIdx].pt for m in good_matches]) + currPoints = np.array([keypoints[m.trainIdx].pt for m in good_matches]) + + # Draw the keypoint matches on the output image + if self.draw_keypoint_matches: + self.prev_img[:, :][mask == True] = 0 # noqa:E712 + self.matches_img = np.hstack((self.prev_img, img)) + self.matches_img = cv2.cvtColor(self.matches_img, cv2.COLOR_GRAY2BGR) + + W = np.size(self.prev_img, 1) + for m in goodMatches: + prev_pt = np.array(self.prev_keypoints[m.queryIdx].pt, dtype=np.int_) + curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_) + curr_pt[0] += W + color = np.random.randint(0, 255, (3,)) + color = (int(color[0]), int(color[1]), int(color[2])) + self.matches_img = 
cv2.line(self.matches_img, prev_pt, curr_pt, tuple(color), 1, cv2.LINE_AA) + self.matches_img = cv2.circle(self.matches_img, prev_pt, 2, tuple(color), -1) + self.matches_img = cv2.circle(self.matches_img, curr_pt, 2, tuple(color), -1) + for det in dets: + det = np.multiply(det, self.scale).astype(int) + start = (det[0] + w, det[1]) + end = (det[2] + w, det[3]) + self.matches_img = cv2.rectangle(self.matches_img, start, end, (0, 0, 255), 2) + for det in self.prev_dets: + det = np.multiply(det, self.scale).astype(int) + start = (det[0], det[1]) + end = (det[2], det[3]) + self.matches_img = cv2.rectangle(self.matches_img, start, end, (0, 0, 255), 2) + else: + self.matches_img = None + + # find rigid matrix + if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)): + H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC) + + # upscale warp matrix to original images size + if self.scale < 1.0: + H[0, 2] /= self.scale + H[1, 2] /= self.scale + + if self.align: + self.prev_img_aligned = cv2.warpAffine(self.prev_img, H, (w, h), flags=cv2.INTER_LINEAR) + else: + print('Warning: not enough matching points') + + # Store to next iteration + self.prev_img = img.copy() + self.prev_keypoints = copy.copy(keypoints) + self.prev_descriptors = copy.copy(descriptors) + + return H + + +def main(): + sift = SIFT(scale=0.5, align=True, grayscale=True, draw_keypoint_matches=False) + curr_img = cv2.imread('assets/MOT17-mini/train/MOT17-13-FRCNN/img1/000005.jpg') + prev_img = cv2.imread('assets/MOT17-mini/train/MOT17-13-FRCNN/img1/000001.jpg') + curr_dets = np.array( + [[1083.8207, 541.5978, 1195.7952, 655.8790], # noqa:E241 + [1635.6456, 563.8348, 1695.4153, 686.6704], # noqa:E241 + [ 957.0879, 545.6558, 1042.6743, 611.8740], # noqa:E241,E261,E201 + [1550.0317, 562.5705, 1600.3931, 684.7425], # noqa:E241 + [ 78.8801, 714.3307, 121.0272, 817.6857], # noqa:E241,E261,E201 + [1382.9938, 512.2731, 1418.6012, 620.1938], # noqa:E241 + 
[1459.7921, 496.2123, 1488.5767, 584.3533], # noqa:E241 + [ 982.9818, 492.8579, 1013.6625, 517.9271], # noqa:E241,E261,E201 + [ 496.1809, 541.3972, 531.4617, 638.0989], # noqa:E241,E261,E201 + [1498.8512, 522.6646, 1526.1145, 587.7672], # noqa:E241 + [ 536.4527, 548.4061, 569.2723, 635.5656], # noqa:E241,E261,E201 + [ 247.8834, 580.8851, 287.2241, 735.3685], # noqa:E241,E261,E201 + [ 151.4096, 572.3918, 203.5401, 731.1011], # noqa:E241,E261,E201 + [1227.4098, 440.5505, 1252.7986, 489.5295]] # noqa:E241 + ) + prev_dets = np.array( + [[2.1069e-02, 6.7026e+02, 4.9816e+01, 8.8407e+02], + [1.0765e+03, 5.4009e+02, 1.1883e+03, 6.5219e+02], + [1.5208e+03, 5.6322e+02, 1.5711e+03, 6.7676e+02], + [1.6111e+03, 5.5926e+02, 1.6640e+03, 6.7443e+02], + [9.5244e+02, 5.4681e+02, 1.0384e+03, 6.1180e+02], + [1.3691e+03, 5.1258e+02, 1.4058e+03, 6.1695e+02], + [1.2043e+02, 7.0780e+02, 1.7309e+02, 8.0518e+02], + [1.4454e+03, 5.0919e+02, 1.4724e+03, 5.8270e+02], + [9.7848e+02, 4.9563e+02, 1.0083e+03, 5.1980e+02], + [5.0166e+02, 5.4778e+02, 5.3796e+02, 6.3940e+02], + [1.4777e+03, 5.1856e+02, 1.5105e+03, 5.9523e+02], + [1.9540e+02, 5.7292e+02, 2.3711e+02, 7.2717e+02], + [2.7373e+02, 5.8564e+02, 3.1335e+02, 7.3281e+02], + [5.4038e+02, 5.4735e+02, 5.7359e+02, 6.3797e+02], + [1.2190e+03, 4.4176e+02, 1.2414e+03, 4.9038e+02]] + ) + + warp_matrix = sift.apply(prev_img, prev_dets) + warp_matrix = sift.apply(curr_img, curr_dets) + + start = time.process_time() + for i in range(0, 100): + warp_matrix = sift.apply(prev_img, prev_dets) + warp_matrix = sift.apply(curr_img, curr_dets) + end = time.process_time() + print('Total time', end - start) + print(warp_matrix) + + if sift.prev_img_aligned is not None: + curr_img = sift.preprocess(curr_img) + prev_img = sift.preprocess(prev_img) + weighted_img = cv2.addWeighted(curr_img, 0.5, sift.prev_img_aligned, 0.5, 0) + cv2.imshow('prev_img_aligned', weighted_img) + cv2.waitKey(0) + cv2.imwrite(str(BOXMOT / 'motion/cmc/sift_aligned.jpg'), weighted_img) + + 
+if __name__ == "__main__": + main() diff --git a/boxmot/motion/cmc/sof.py b/boxmot/motion/cmc/sof.py new file mode 100644 index 0000000000000000000000000000000000000000..82885cfca5e13ea111e85abce1cd6a888f3802b5 --- /dev/null +++ b/boxmot/motion/cmc/sof.py @@ -0,0 +1,169 @@ +import cv2 +import numpy as np +import copy +import time +from boxmot.motion.cmc.base_cmc import BaseCMC + + +class SOF(BaseCMC): + """ + Sparse Optical Flow (SOF) tracker for estimating a 2x3 warp (affine transformation) + between consecutive frames. This class is modeled after a GMC implementation using + the 'sparseOptFlow' method. + """ + def __init__(self, scale=0.1): + """ + Initialize the SOF object. + + Parameters + ---------- + downscale : int, optional + Factor by which to downscale the input frames. Defaults to 1 (no downscale). + feature_params : dict, optional + Parameters for cv2.goodFeaturesToTrack. Defaults to: + { + maxCorners: 1000, + qualityLevel: 0.01, + minDistance: 1, + blockSize: 3, + useHarrisDetector: False, + k: 0.04 + } + lk_params : dict, optional + Lucas-Kanade optical flow parameters. Defaults to: + { + winSize: (21, 21), + maxLevel: 3, + criteria: (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 30, 0.01) + } + """ + self.scale = scale + self.grayscale = True + + # Set default feature detection parameters if not provided + self.feature_params = dict( + maxCorners=1000, + qualityLevel=0.01, + minDistance=1, + blockSize=3, + useHarrisDetector=False, + k=0.04 + ) + + # Set default Lucas-Kanade optical flow parameters if not provided. + self.lk_params = dict( + winSize=(21, 21), + maxLevel=3, + criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 30, 0.01) + ) + + self.prevFrame = None + self.prevKeyPoints = None + self.initializedFirstFrame = False + + def apply(self, img, detections=None): + """ + Apply sparse optical flow tracking to estimate a warp (affine transformation) + between the previous frame and the current raw frame. 
+ + Parameters + ---------- + raw_frame : np.ndarray + The current input color image. + detections : Any, optional + (Not used here but provided for API compatibility.) + + Returns + ------- + np.ndarray + The estimated 2x3 warp matrix. If estimation fails, returns an identity matrix. + """ + # Convert the raw frame to grayscale. + frame_gray = self.preprocess(img) + height, width = frame_gray.shape + + # Default transformation: identity. + H = np.eye(2, 3, dtype=np.float32) + + # On the first frame, detect keypoints and initialize internal state. + if not self.initializedFirstFrame: + keypoints = cv2.goodFeaturesToTrack(frame_gray, mask=None, **self.feature_params) + if keypoints is None: + return H + # Optional subpixel refinement. + term_crit = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 30, 0.01) + cv2.cornerSubPix(frame_gray, keypoints, winSize=(5, 5), zeroZone=(-1, -1), criteria=term_crit) + self.prevFrame = frame_gray.copy() + self.prevKeyPoints = keypoints.copy() + self.initializedFirstFrame = True + return H + + # Compute optical flow to track the previous keypoints into the current frame. + nextKeypoints, status, err = cv2.calcOpticalFlowPyrLK( + self.prevFrame, frame_gray, self.prevKeyPoints, None, **self.lk_params + ) + + # Filter out points that were not successfully tracked. + valid_prev = [] + valid_next = [] + for i, s in enumerate(status): + if s: + valid_prev.append(self.prevKeyPoints[i]) + valid_next.append(nextKeypoints[i]) + + if len(valid_prev) < 4: + print("Warning: not enough matching points detected; redetecting keypoints.") + # If too few matches, re-detect keypoints for the current frame. + keypoints = cv2.goodFeaturesToTrack(frame_gray, mask=None, **self.feature_params) + self.prevFrame = frame_gray.copy() + self.prevKeyPoints = keypoints if keypoints is not None else np.array([]) + return H + + valid_prev = np.array(valid_prev) + valid_next = np.array(valid_next) + + # Estimate the affine warp matrix using robust RANSAC. 
+ H_est, inliers = cv2.estimateAffinePartial2D(valid_prev, valid_next, method=cv2.RANSAC) + if H_est is None: + H_est = H + else: + # If the frame was downscaled, adjust the translation parameters back to original scale. + if self.scale < 1: + H_est[0, 2] /= self.scale + H_est[1, 2] /= self.scale + + # Update the previous frame and keypoints for the next iteration. + # Optionally, you might want to re-detect keypoints rather than simply tracking them. + new_keypoints = cv2.goodFeaturesToTrack(frame_gray, mask=None, **self.feature_params) + if new_keypoints is None: + # Use the tracked keypoints if new detection fails. + new_keypoints = valid_next + self.prevFrame = frame_gray.copy() + self.prevKeyPoints = new_keypoints.copy() + + return H_est + + +# ============================================================================== +# Example Usage +# ============================================================================== + +def main(): + # Create an instance of the SOF class with a downscaling factor, if desired. + sof_tracker = SOF(scale=0.3) + + # For example purposes, load two consecutive frames. + prev_img = cv2.imread("assets/MOT17-mini/train/MOT17-13-FRCNN/img1/000001.jpg") + curr_img = cv2.imread("assets/MOT17-mini/train/MOT17-13-FRCNN/img1/000005.jpg") + + # Process the first frame to initialize the tracker. + _ = sof_tracker.apply(prev_img) + + # Now process the next frame to compute the warp matrix. + H = sof_tracker.apply(curr_img) + print("Estimated warp matrix:\n", H) + + # Optionally, you can visualize the transformation (overlay, etc.) 
+ +if __name__ == "__main__": + main() diff --git a/boxmot/motion/kalman_filters/__init__.py b/boxmot/motion/kalman_filters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a4d0f1f9a635d5ccf0932bbf5c2069529a7ac --- /dev/null +++ b/boxmot/motion/kalman_filters/__init__.py @@ -0,0 +1 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license diff --git a/boxmot/motion/kalman_filters/aabb/base_kalman_filter.py b/boxmot/motion/kalman_filters/aabb/base_kalman_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..a20bd498aa24b8ad9ab4082946b06e817539ff30 --- /dev/null +++ b/boxmot/motion/kalman_filters/aabb/base_kalman_filter.py @@ -0,0 +1,158 @@ +import numpy as np +import scipy.linalg +from typing import Tuple + +""" +Table for the 0.95 quantile of the chi-square distribution with N degrees of +freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv +function and used as Mahalanobis gating threshold. +""" +chi2inv95 = { + 1: 3.8415, + 2: 5.9915, + 3: 7.8147, + 4: 9.4877, + 5: 11.070, + 6: 12.592, + 7: 14.067, + 8: 15.507, + 9: 16.919 +} + +class BaseKalmanFilter: + """ + Base class for Kalman filters tracking bounding boxes in image space. + """ + + def __init__(self, ndim: int): + self.ndim = ndim + self.dt = 1. + + # Create Kalman filter model matrices. + self._motion_mat = np.eye(2 * ndim, 2 * ndim) # State transition matrix + for i in range(ndim): + self._motion_mat[i, ndim + i] = self.dt + self._update_mat = np.eye(ndim, 2 * ndim) # Observation matrix + + # Motion and observation uncertainty weights. + self._std_weight_position = 1. / 20 + self._std_weight_velocity = 1. / 160 + + def initiate(self, measurement: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Create track from unassociated measurement. 
+ """ + mean_pos = measurement + mean_vel = np.zeros_like(mean_pos) + mean = np.r_[mean_pos, mean_vel] + + std = self._get_initial_covariance_std(measurement) + covariance = np.diag(np.square(std)) + return mean, covariance + + def _get_initial_covariance_std(self, measurement: np.ndarray) -> np.ndarray: + """ + Return initial standard deviations for the covariance matrix. + Should be implemented by subclasses. + """ + raise NotImplementedError + + def predict(self, mean: np.ndarray, covariance: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Run Kalman filter prediction step. + """ + std_pos, std_vel = self._get_process_noise_std(mean) + motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) + + mean = np.dot(mean, self._motion_mat.T) + covariance = np.linalg.multi_dot(( + self._motion_mat, covariance, self._motion_mat.T)) + motion_cov + + return mean, covariance + + def _get_process_noise_std(self, mean: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Return standard deviations for process noise. + Should be implemented by subclasses. + """ + raise NotImplementedError + + def project(self, mean: np.ndarray, covariance: np.ndarray, confidence: float = 0.0) -> Tuple[np.ndarray, np.ndarray]: + """ + Project state distribution to measurement space. + """ + std = self._get_measurement_noise_std(mean, confidence) + + # NSA Kalman algorithm from GIAOTracker, which proposes a formula to + # adaptively calculate the noise covariance Rek: + # Rk = (1 − ck) Rk + # where Rk is the preset constant measurement noise covariance + # and ck is the detection confidence score at state k. Intuitively, + # the detection has a higher score ck when it has less noise, + # which results in a low Re. 
+ std = [(1 - confidence) * x for x in std] + + innovation_cov = np.diag(np.square(std)) + + mean = np.dot(self._update_mat, mean) + covariance = np.linalg.multi_dot(( + self._update_mat, covariance, self._update_mat.T)) + return mean, covariance + innovation_cov + + def multi_predict(self, mean: np.ndarray, covariance: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Run Kalman filter prediction step (Vectorized version). + """ + std_pos, std_vel = self._get_multi_process_noise_std(mean) + sqr = np.square(np.r_[std_pos, std_vel]).T + + motion_cov = [np.diag(sqr[i]) for i in range(len(mean))] + motion_cov = np.asarray(motion_cov) + + mean = np.dot(mean, self._motion_mat.T) + left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2)) + covariance = np.dot(left, self._motion_mat.T) + motion_cov + + return mean, covariance + + def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray, confidence: float = 0.0) -> Tuple[np.ndarray, np.ndarray]: + """ + Run Kalman filter correction step. + """ + projected_mean, projected_cov = self.project(mean, covariance, confidence) + + chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False) + kalman_gain = scipy.linalg.cho_solve((chol_factor, lower), np.dot(covariance, self._update_mat.T).T, check_finite=False).T + innovation = measurement - projected_mean + + new_mean = mean + np.dot(innovation, kalman_gain.T) + new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T)) + return new_mean, new_covariance + + def _get_multi_process_noise_std(self, mean: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Return standard deviations for process noise in vectorized form. + Should be implemented by subclasses. 
+ """ + raise NotImplementedError + + def gating_distance(self, mean: np.ndarray, covariance: np.ndarray, measurements: np.ndarray, only_position: bool = False, metric: str = 'maha') -> np.ndarray: + """ + Compute gating distance between state distribution and measurements. + """ + mean, covariance = self.project(mean, covariance) + + if only_position: + mean, covariance = mean[:2], covariance[:2, :2] + measurements = measurements[:, :2] + + d = measurements - mean + if metric == 'gaussian': + return np.sum(d * d, axis=1) + elif metric == 'maha': + cholesky_factor = np.linalg.cholesky(covariance) + z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True) + squared_maha = np.sum(z * z, axis=0) + return squared_maha + else: + raise ValueError('invalid distance metric') \ No newline at end of file diff --git a/boxmot/motion/kalman_filters/aabb/xyah_kf.py b/boxmot/motion/kalman_filters/aabb/xyah_kf.py new file mode 100644 index 0000000000000000000000000000000000000000..63100669fda3762137266bd550e1b73f59074654 --- /dev/null +++ b/boxmot/motion/kalman_filters/aabb/xyah_kf.py @@ -0,0 +1,73 @@ +import numpy as np +from typing import Tuple +from boxmot.motion.kalman_filters.aabb.base_kalman_filter import BaseKalmanFilter + + +class KalmanFilterXYAH(BaseKalmanFilter): + """ + A Kalman filter for tracking bounding boxes in image space with state space: + x, y, a, h, vx, vy, va, vh + """ + + def __init__(self): + super().__init__(ndim=4) + + def _get_initial_covariance_std(self, measurement: np.ndarray) -> np.ndarray: + # initial uncertainty in the aspect ratio is very low, + # suggesting that it is not expected to vary significantly. 
+ return [ + 2 * self._std_weight_position * measurement[3], # x + 2 * self._std_weight_position * measurement[3], # y + 1e-2, # a (aspect ratio) + 2 * self._std_weight_position * measurement[3], # H + 10 * self._std_weight_velocity * measurement[3], # vx + 10 * self._std_weight_velocity * measurement[3], # vy + 1e-5, # va (aspect ration vel) + 10 * self._std_weight_velocity * measurement[3] # vh + ] + + def _get_process_noise_std(self, mean: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + # very small process noise standard deviation assigned to the + # aspect ratio state and its velocity. Suggests + # that the aspect ratio is expected to remain relatively constant + # over time. + std_pos = [ + self._std_weight_position * mean[3], + self._std_weight_position * mean[3], + 1e-2, + self._std_weight_position * mean[3] + ] + std_vel = [ + self._std_weight_velocity * mean[3], + self._std_weight_velocity * mean[3], + 1e-5, + self._std_weight_velocity * mean[3] + ] + return std_pos, std_vel + + def _get_measurement_noise_std(self, mean: np.ndarray, confidence: float) -> np.ndarray: + # small measurement noise standard deviation for + # aspect ratio state, indicating low expected measurement noise in + # the aspect ratio. 
+ std_noise = [ + self._std_weight_position * mean[3], + self._std_weight_position * mean[3], + 1e-1, + self._std_weight_position * mean[3] + ] + return std_noise + + def _get_multi_process_noise_std(self, mean: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + std_pos = [ + self._std_weight_position * mean[:, 3], + self._std_weight_position * mean[:, 3], + 1e-2 * np.ones_like(mean[:, 3]), + self._std_weight_position * mean[:, 3] + ] + std_vel = [ + self._std_weight_velocity * mean[:, 3], + self._std_weight_velocity * mean[:, 3], + 1e-5 * np.ones_like(mean[:, 3]), + self._std_weight_velocity * mean[:, 3] + ] + return std_pos, std_vel diff --git a/boxmot/motion/kalman_filters/aabb/xysr_kf.py b/boxmot/motion/kalman_filters/aabb/xysr_kf.py new file mode 100644 index 0000000000000000000000000000000000000000..0043c073cb6c9055648a1ec9cc2aa8c5ab9f6a83 --- /dev/null +++ b/boxmot/motion/kalman_filters/aabb/xysr_kf.py @@ -0,0 +1,469 @@ +""" +This module implements the linear Kalman filter in both an object +oriented and procedural form. The KalmanFilter class implements +the filter by storing the various matrices in instance variables, +minimizing the amount of bookkeeping you have to do. +All Kalman filters operate with a predict->update cycle. The +predict step, implemented with the method or function predict(), +uses the state transition matrix F to predict the state in the next +time period (epoch). The state is stored as a gaussian (x, P), where +x is the state (column) vector, and P is its covariance. Covariance +matrix Q specifies the process covariance. In Bayesian terms, this +prediction is called the *prior*, which you can think of colloquially +as the estimate prior to incorporating the measurement. +The update step, implemented with the method or function `update()`, +incorporates the measurement z with covariance R, into the state +estimate (x, P). 
The class stores the system uncertainty in S, +the innovation (residual between prediction and measurement in +measurement space) in y, and the Kalman gain in k. The procedural +form returns these variables to you. In Bayesian terms this computes +the *posterior* - the estimate after the information from the +measurement is incorporated. +Whether you use the OO form or procedural form is up to you. If +matrices such as H, R, and F are changing each epoch, you'll probably +opt to use the procedural form. If they are unchanging, the OO +form is perhaps easier to use since you won't need to keep track +of these matrices. This is especially useful if you are implementing +banks of filters or comparing various KF designs for performance; +a trivial coding bug could lead to using the wrong sets of matrices. +This module also offers an implementation of the RTS smoother, and +other helper functions, such as log likelihood computations. +The Saver class allows you to easily save the state of the +KalmanFilter class after every update. +""" + +from __future__ import absolute_import, division + +from copy import deepcopy +from math import log, exp, sqrt +import sys +import numpy as np +from numpy import dot, zeros, eye, isscalar, shape +import numpy.linalg as linalg +from filterpy.stats import logpdf +from filterpy.common import pretty_str, reshape_z +from collections import deque + + +class KalmanFilterXYSR(object): + """ Implements a Kalman filter. You are responsible for setting the + various state variables to reasonable values; the defaults will + not give you a functional filter. 
+ """ + + def __init__(self, dim_x, dim_z, dim_u=0, max_obs=50): + if dim_x < 1: + raise ValueError('dim_x must be 1 or greater') + if dim_z < 1: + raise ValueError('dim_z must be 1 or greater') + if dim_u < 0: + raise ValueError('dim_u must be 0 or greater') + + self.dim_x = dim_x + self.dim_z = dim_z + self.dim_u = dim_u + + self.x = zeros((dim_x, 1)) # state + self.P = eye(dim_x) # uncertainty covariance + self.Q = eye(dim_x) # process uncertainty + self.B = None # control transition matrix + self.F = eye(dim_x) # state transition matrix + self.H = zeros((dim_z, dim_x)) # measurement function + self.R = eye(dim_z) # measurement uncertainty + self._alpha_sq = 1. # fading memory control + self.M = np.zeros((dim_x, dim_z)) # process-measurement cross correlation + self.z = np.array([[None]*self.dim_z]).T + + # gain and residual are computed during the innovation step. We + # save them so that in case you want to inspect them for various + # purposes + self.K = np.zeros((dim_x, dim_z)) # kalman gain + self.y = zeros((dim_z, 1)) + self.S = np.zeros((dim_z, dim_z)) # system uncertainty + self.SI = np.zeros((dim_z, dim_z)) # inverse system uncertainty + + # identity matrix. Do not alter this. + self._I = np.eye(dim_x) + + # these will always be a copy of x,P after predict() is called + self.x_prior = self.x.copy() + self.P_prior = self.P.copy() + + # these will always be a copy of x,P after update() is called + self.x_post = self.x.copy() + self.P_post = self.P.copy() + + # Only computed only if requested via property + self._log_likelihood = log(sys.float_info.min) + self._likelihood = sys.float_info.min + self._mahalanobis = None + + # keep all observations + self.max_obs = max_obs + self.history_obs = deque([], maxlen=self.max_obs) + + self.inv = np.linalg.inv + + self.attr_saved = None + self.observed = False + self.last_measurement = None + + + def apply_affine_correction(self, m, t): + """ + Apply to both last state and last observation for OOS smoothing. 
+ + Messy due to internal logic for kalman filter being messy. + """ + + scale = np.linalg.norm(m[:, 0]) + self.x[:2] = m @ self.x[:2] + t + self.x[4:6] = m @ self.x[4:6] + + self.P[:2, :2] = m @ self.P[:2, :2] @ m.T + self.P[4:6, 4:6] = m @ self.P[4:6, 4:6] @ m.T + + # If frozen, also need to update the frozen state for OOS + if not self.observed and self.attr_saved is not None: + self.attr_saved["x"][:2] = m @ self.attr_saved["x"][:2] + t + self.attr_saved["x"][4:6] = m @ self.attr_saved["x"][4:6] + + self.attr_saved["P"][:2, :2] = m @ self.attr_saved["P"][:2, :2] @ m.T + self.attr_saved["P"][4:6, 4:6] = m @ self.attr_saved["P"][4:6, 4:6] @ m.T + + self.attr_saved["last_measurement"][:2] = m @ self.attr_saved["last_measurement"][:2] + t + + + def predict(self, u=None, B=None, F=None, Q=None): + """ + Predict next state (prior) using the Kalman filter state propagation + equations. + Parameters + ---------- + u : np.array, default 0 + Optional control vector. + B : np.array(dim_x, dim_u), or None + Optional control transition matrix; a value of None + will cause the filter to use `self.B`. + F : np.array(dim_x, dim_x), or None + Optional state transition matrix; a value of None + will cause the filter to use `self.F`. + Q : np.array(dim_x, dim_x), scalar, or None + Optional process noise matrix; a value of None will cause the + filter to use `self.Q`. 
+ """ + if B is None: + B = self.B + if F is None: + F = self.F + if Q is None: + Q = self.Q + elif isscalar(Q): + Q = eye(self.dim_x) * Q + + # x = Fx + Bu + if B is not None and u is not None: + self.x = dot(F, self.x) + dot(B, u) + else: + self.x = dot(F, self.x) + + # P = FPF' + Q + self.P = self._alpha_sq * dot(dot(F, self.P), F.T) + Q + + # save prior + self.x_prior = self.x.copy() + self.P_prior = self.P.copy() + + def freeze(self): + """ + Save the parameters before non-observation forward + """ + self.attr_saved = deepcopy(self.__dict__) + + def unfreeze(self): + if self.attr_saved is not None: + new_history = deepcopy(list(self.history_obs)) + self.__dict__ = self.attr_saved + self.history_obs = deque(list(self.history_obs)[:-1], maxlen=self.max_obs) + occur = [int(d is None) for d in new_history] + indices = np.where(np.array(occur) == 0)[0] + index1, index2 = indices[-2], indices[-1] + box1, box2 = new_history[index1], new_history[index2] + x1, y1, s1, r1 = box1 + w1, h1 = np.sqrt(s1 * r1), np.sqrt(s1 / r1) + x2, y2, s2, r2 = box2 + w2, h2 = np.sqrt(s2 * r2), np.sqrt(s2 / r2) + time_gap = index2 - index1 + dx, dy = (x2 - x1) / time_gap, (y2 - y1) / time_gap + dw, dh = (w2 - w1) / time_gap, (h2 - h1) / time_gap + + for i in range(index2 - index1): + x, y = x1 + (i + 1) * dx, y1 + (i + 1) * dy + w, h = w1 + (i + 1) * dw, h1 + (i + 1) * dh + s, r = w * h, w / float(h) + new_box = np.array([x, y, s, r]).reshape((4, 1)) + self.update(new_box) + if not i == (index2 - index1 - 1): + self.predict() + self.history_obs.pop() + self.history_obs.pop() + + def update(self, z, R=None, H=None): + """ + Add a new measurement (z) to the Kalman filter. If z is None, nothing is changed. + Parameters + ---------- + z : np.array + Measurement for this update. z can be a scalar if dim_z is 1, + otherwise it must be a column vector. + R : np.array, scalar, or None + Measurement noise. If None, the filter's self.R value is used. + H : np.array, or None + Measurement function. 
If None, the filter's self.H value is used. + """ + + # set to None to force recompute + self._log_likelihood = None + self._likelihood = None + self._mahalanobis = None + + # append the observation + self.history_obs.append(z) + + if z is None: + if self.observed: + """ + Got no observation so freeze the current parameters for future + potential online smoothing. + """ + self.last_measurement = self.history_obs[-2] + self.freeze() + self.observed = False + self.z = np.array([[None] * self.dim_z]).T + self.x_post = self.x.copy() + self.P_post = self.P.copy() + self.y = zeros((self.dim_z, 1)) + return + + # self.observed = True + if not self.observed: + """ + Get observation, use online smoothing to re-update parameters + """ + self.unfreeze() + self.observed = True + + if R is None: + R = self.R + elif isscalar(R): + R = eye(self.dim_z) * R + if H is None: + z = reshape_z(z, self.dim_z, self.x.ndim) + H = self.H + + # y = z - Hx + # error (residual) between measurement and prediction + self.y = z - dot(H, self.x) + + # common subexpression for speed + PHT = dot(self.P, H.T) + + # S = HPH' + R + self.S = dot(H, PHT) + R + self.SI = self.inv(self.S) + + # K = PH'inv(S) + self.K = PHT.dot(self.SI) + + # x = x + Ky + self.x = self.x + dot(self.K, self.y) + + # P = (I-KH)P(I-KH)' + KRK' + I_KH = self._I - dot(self.K, H) + self.P = dot(dot(I_KH, self.P), I_KH.T) + dot(dot(self.K, R), self.K.T) + + # save measurement and posterior state + self.z = deepcopy(z) + self.x_post = self.x.copy() + self.P_post = self.P.copy() + + # save history of observations + self.history_obs.append(z) + + def update_steadystate(self, z, H=None): + """ Update Kalman filter using the Kalman gain and state covariance + matrix as computed for the steady state. Only x is updated, and the + new value is stored in self.x. P is left unchanged. Must be called + after a prior call to compute_steady_state(). 
+ """ + if z is None: + self.history_obs.append(z) + return + + if H is None: + H = self.H + + H = np.asarray(H) + # error (residual) between measurement and prediction + self.y = z - dot(H, self.x) + + # x = x + Ky + self.x = self.x + dot(self.K_steady_state, self.y) + + # save measurement and posterior state + self.z = deepcopy(z) + self.x_post = self.x.copy() + + # save history of observations + self.history_obs.append(z) + + def log_likelihood(self, z=None): + """ log-likelihood of the measurement z. Computed from the + system uncertainty S. + """ + + if z is None: + z = self.z + return logpdf(z, dot(self.H, self.x), self.S) + + def likelihood(self, z=None): + """ likelihood of the measurement z. Computed from the + system uncertainty S. + """ + + if z is None: + z = self.z + return exp(self.log_likelihood(z)) + + @property + def log_likelihood(self): + """ log-likelihood of the last measurement. + """ + + return self._log_likelihood + + @property + def likelihood(self): + """ likelihood of the last measurement. + """ + + return self._likelihood + + +def batch_filter(x, P, zs, Fs, Qs, Hs, Rs, Bs=None, us=None, update_first=False, saver=None): + """ + Batch processes a sequences of measurements. + Parameters + ---------- + zs : list-like + list of measurements at each time step. Missing measurements must be + represented by None. + Fs : list-like + list of values to use for the state transition matrix matrix. + Qs : list-like + list of values to use for the process error + covariance. + Hs : list-like + list of values to use for the measurement matrix. + Rs : list-like + list of values to use for the measurement error + covariance. + Bs : list-like, optional + list of values to use for the control transition matrix; + a value of None in any position will cause the filter + to use `self.B` for that time step. 
+ us : list-like, optional + list of values to use for the control input vector; + a value of None in any position will cause the filter to use + 0 for that time step. + update_first : bool, optional + controls whether the order of operations is update followed by + predict, or predict followed by update. Default is predict->update. + saver : filterpy.common.Saver, optional + filterpy.common.Saver object. If provided, saver.save() will be + called after every epoch + Returns + ------- + means : np.array((n,dim_x,1)) + array of the state for each time step after the update. Each entry + is an np.array. In other words `means[k,:]` is the state at step + `k`. + covariance : np.array((n,dim_x,dim_x)) + array of the covariances for each time step after the update. + In other words `covariance[k,:,:]` is the covariance at step `k`. + means_predictions : np.array((n,dim_x,1)) + array of the state for each time step after the predictions. Each + entry is an np.array. In other words `means[k,:]` is the state at + step `k`. + covariance_predictions : np.array((n,dim_x,dim_x)) + array of the covariances for each time step after the prediction. + In other words `covariance[k,:,:]` is the covariance at step `k`. + Examples + -------- + .. 
code-block:: Python + zs = [t + random.randn()*4 for t in range (40)] + Fs = [kf.F for t in range (40)] + Hs = [kf.H for t in range (40)] + (mu, cov, _, _) = kf.batch_filter(zs, Rs=R_list, Fs=Fs, Hs=Hs, Qs=None, + Bs=None, us=None, update_first=False) + (xs, Ps, Ks, Pps) = kf.rts_smoother(mu, cov, Fs=Fs, Qs=None) + """ + + n = np.size(zs, 0) + dim_x = x.shape[0] + + # mean estimates from Kalman Filter + if x.ndim == 1: + means = zeros((n, dim_x)) + means_p = zeros((n, dim_x)) + else: + means = zeros((n, dim_x, 1)) + means_p = zeros((n, dim_x, 1)) + + # state covariances from Kalman Filter + covariances = zeros((n, dim_x, dim_x)) + covariances_p = zeros((n, dim_x, dim_x)) + + if us is None: + us = [0.0] * n + Bs = [0.0] * n + + if update_first: + for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)): + + x, P = update(x, P, z, R=R, H=H) + means[i, :] = x + covariances[i, :, :] = P + + x, P = predict(x, P, u=u, B=B, F=F, Q=Q) + means_p[i, :] = x + covariances_p[i, :, :] = P + if saver is not None: + saver.save() + else: + for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)): + + x, P = predict(x, P, u=u, B=B, F=F, Q=Q) + means_p[i, :] = x + covariances_p[i, :, :] = P + + x, P = update(x, P, z, R=R, H=H) + means[i, :] = x + covariances[i, :, :] = P + if saver is not None: + saver.save() + + return (means, covariances, means_p, covariances_p) + + def batch_filter(self, zs, Rs=None): + """ + Batch process a sequence of measurements. This method is suitable + for cases where the measurement noise varies with each measurement. 
+ """ + means, covariances = [], [] + for z, R in zip(zs, Rs): + self.predict() + self.update(z, R=R) + means.append(self.x.copy()) + covariances.append(self.P.copy()) + return np.array(means), np.array(covariances) \ No newline at end of file diff --git a/boxmot/motion/kalman_filters/aabb/xywh_kf.py b/boxmot/motion/kalman_filters/aabb/xywh_kf.py new file mode 100644 index 0000000000000000000000000000000000000000..8c933ca45df75a98b6e2bf19ef3ee4e479549c4b --- /dev/null +++ b/boxmot/motion/kalman_filters/aabb/xywh_kf.py @@ -0,0 +1,64 @@ +import numpy as np +from typing import Tuple +from boxmot.motion.kalman_filters.aabb.base_kalman_filter import BaseKalmanFilter + + +class KalmanFilterXYWH(BaseKalmanFilter): + """ + A Kalman filter for tracking bounding boxes in image space with state space: + x, y, w, h, vx, vy, vw, vh + """ + + def __init__(self): + super().__init__(ndim=4) + + def _get_initial_covariance_std(self, measurement: np.ndarray) -> np.ndarray: + return [ + 2 * self._std_weight_position * measurement[2], + 2 * self._std_weight_position * measurement[3], + 2 * self._std_weight_position * measurement[2], + 2 * self._std_weight_position * measurement[3], + 10 * self._std_weight_velocity * measurement[2], + 10 * self._std_weight_velocity * measurement[3], + 10 * self._std_weight_velocity * measurement[2], + 10 * self._std_weight_velocity * measurement[3] + ] + + def _get_process_noise_std(self, mean: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + std_pos = [ + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + self._std_weight_position * mean[2], + self._std_weight_position * mean[3] + ] + std_vel = [ + self._std_weight_velocity * mean[2], + self._std_weight_velocity * mean[3], + self._std_weight_velocity * mean[2], + self._std_weight_velocity * mean[3] + ] + return std_pos, std_vel + + def _get_measurement_noise_std(self, mean: np.ndarray, confidence: float) -> np.ndarray: + std_noise = [ + self._std_weight_position * mean[2], + 
self._std_weight_position * mean[3], + self._std_weight_position * mean[2], + self._std_weight_position * mean[3] + ] + return std_noise + + def _get_multi_process_noise_std(self, mean: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + std_pos = [ + self._std_weight_position * mean[:, 2], + self._std_weight_position * mean[:, 3], + self._std_weight_position * mean[:, 2], + self._std_weight_position * mean[:, 3] + ] + std_vel = [ + self._std_weight_velocity * mean[:, 2], + self._std_weight_velocity * mean[:, 3], + self._std_weight_velocity * mean[:, 2], + self._std_weight_velocity * mean[:, 3] + ] + return std_pos, std_vel diff --git a/boxmot/motion/kalman_filters/obb/xywha_kf.py b/boxmot/motion/kalman_filters/obb/xywha_kf.py new file mode 100644 index 0000000000000000000000000000000000000000..c42f0adc6edf7f2ed50a2aed1915e98809c397e6 --- /dev/null +++ b/boxmot/motion/kalman_filters/obb/xywha_kf.py @@ -0,0 +1,637 @@ +from __future__ import absolute_import, division + +from copy import deepcopy +from math import log, exp, sqrt, pi +import sys +import numpy as np +from numpy import dot, zeros, eye, isscalar +import numpy.linalg as linalg +from filterpy.stats import logpdf +from filterpy.common import pretty_str, reshape_z +from collections import deque + + +def speed_direction_obb(bbox1, bbox2): + cx1, cy1 = bbox1[0], bbox1[1] + cx2, cy2 = bbox2[0], bbox2[1] + speed = np.array([cy2 - cy1, cx2 - cx1]) + norm = np.sqrt((cy2 - cy1) ** 2 + (cx2 - cx1) ** 2) + 1e-6 + return speed / norm + + +class KalmanBoxTrackerOBB(object): + """ + This class represents the internal state of individual tracked objects observed as oriented bbox. + """ + + count = 0 + + def __init__(self, bbox, cls, det_ind, delta_t=3, max_obs=50, Q_xy_scaling = 0.01, Q_a_scaling = 0.01): + """ + Initialises a tracker using initial bounding box. 
+ """ + # define constant velocity model + self.det_ind = det_ind + + self.Q_xy_scaling = Q_xy_scaling + self.Q_a_scaling = Q_a_scaling + + self.kf = KalmanFilterXYWHA(dim_x=10, dim_z=5, max_obs=max_obs) + self.kf.F = np.array( + [ + [1, 0, 0, 0, 0, 1, 0, 0, 0, 0], # cx = cx + vx + [0, 1, 0, 0, 0, 0, 1, 0, 0, 0], # cy = cy + vy + [0, 0, 1, 0, 0, 0, 0, 1, 0, 0], # w = w + vw + [0, 0, 0, 1, 0, 0, 0, 0, 1, 0], # h = h + vh + [0, 0, 0, 0, 1, 0, 0, 0, 0, 1], # a = a + va + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1] + ] + ) + self.kf.H = np.array( + [ + [1, 0, 0, 0, 0, 0, 0, 0, 0 ,0], # cx + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], # cy + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], # w + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # h + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], # angle + ] + ) + + self.kf.R[2:, 2:] *= 10.0 + self.kf.P[ + 5:, 5: + ] *= 1000.0 # give high uncertainty to the unobservable initial velocities + self.kf.P *= 10.0 + + self.kf.Q[5:7, 5:7] *= self.Q_xy_scaling + self.kf.Q[-1, -1] *= self.Q_a_scaling + + self.kf.x[:5] = bbox[:5].reshape((5, 1)) # x, y, w, h, angle (dont take confidence score) + self.time_since_update = 0 + self.id = KalmanBoxTrackerOBB.count + KalmanBoxTrackerOBB.count += 1 + self.max_obs = max_obs + self.history = deque([], maxlen=self.max_obs) + self.hits = 0 + self.hit_streak = 0 + self.age = 0 + self.conf = bbox[-1] + self.cls = cls + """ + NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same for the return of + function k_previous_obs. It is ugly and I do not like it. But to support generate observation array in a + fast and unified way, which you would see below k_observations = np.array([k_previous_obs(...]]), + let's bear it for now. 
+ """ + self.last_observation = np.array([-1, -1, -1, -1, -1, -1]) #WARNING : -1 is a valid angle value + self.observations = dict() + self.history_observations = deque([], maxlen=self.max_obs) + self.velocity = None + self.delta_t = delta_t + + def update(self, bbox, cls, det_ind): + """ + Updates the state vector with observed bbox. + """ + self.det_ind = det_ind + if bbox is not None: + self.conf = bbox[-1] + self.cls = cls + if self.last_observation.sum() >= 0: # no previous observation + previous_box = None + for i in range(self.delta_t): + dt = self.delta_t - i + if self.age - dt in self.observations: + previous_box = self.observations[self.age - dt] + break + if previous_box is None: + previous_box = self.last_observation + """ + Estimate the track speed direction with observations \Delta t steps away + """ + self.velocity = speed_direction_obb(previous_box, bbox) + + """ + Insert new observations. This is a ugly way to maintain both self.observations + and self.history_observations. Bear it for the moment. + """ + self.last_observation = bbox + self.observations[self.age] = bbox + self.history_observations.append(bbox) + + self.time_since_update = 0 + self.hits += 1 + self.hit_streak += 1 + self.kf.update(bbox[:5].reshape((5, 1))) # x, y, w, h, angle as column vector (dont take confidence score) + else: + self.kf.update(bbox) + + def predict(self): + """ + Advances the state vector and returns the predicted bounding box estimate. + """ + if (self.kf.x[7] + self.kf.x[2]) <= 0: # Negative width + self.kf.x[7] *= 0.0 + if (self.kf.x[8] + self.kf.x[3]) <= 0: # Negative Height + self.kf.x[8] *= 0.0 + self.kf.predict() + self.age += 1 + if self.time_since_update > 0: + self.hit_streak = 0 + self.time_since_update += 1 + self.history.append(self.kf.x[0:5].reshape((1, 5))) + return self.history[-1] + + def get_state(self): + """ + Returns the current bounding box estimate. 
+ """ + return self.kf.x[0:5].reshape((1, 5)) + + +class KalmanFilterXYWHA(object): + """ + Implements a Kalman Filter specialized for tracking Oriented Bounding Boxes. + The default state vector is [x, y, w, h, a]^T: + + - (x, y): center of the bounding box + - w, h : width and height of the bounding box + - a : orientation angle (radians) + + This filter supports "freeze" and "unfreeze" methods to handle missing + observations (no measurements) or out-of-sequence (OOS) smoothing logic. + """ + + def __init__(self, dim_x, dim_z, dim_u=0, max_obs=50): + """ + Parameters + ---------- + dim_x : int + Dimensionality of the state vector. Typically 5 if [x, y, w, h, a]. + dim_z : int + Dimensionality of the measurement vector. Typically also 5. + dim_u : int + Dimensionality of the control vector. Default is 0 (no control). + max_obs : int + Maximum number of stored observations for freeze/unfreeze logic. + """ + if dim_x < 1: + raise ValueError('dim_x must be 1 or greater') + if dim_z < 1: + raise ValueError('dim_z must be 1 or greater') + if dim_u < 0: + raise ValueError('dim_u must be 0 or greater') + + self.dim_x = dim_x + self.dim_z = dim_z + self.dim_u = dim_u + + # State: x is a (dim_x, 1) column vector + self.x = zeros((dim_x, 1)) # state + self.P = eye(dim_x) # covariance of the state + self.Q = eye(dim_x) # process noise covariance + self.B = None # control transition matrix + self.F = eye(dim_x) # state transition matrix + self.H = zeros((dim_z, dim_x)) # measurement function + self.R = eye(dim_z) # measurement noise covariance + self._alpha_sq = 1. 
# fading memory control + self.M = np.zeros((dim_x, dim_z)) # cross correlation (rarely used) + self.z = np.array([[None]*self.dim_z]).T + + # Gains and residuals computed during update + self.K = np.zeros((dim_x, dim_z)) # Kalman gain + self.y = zeros((dim_z, 1)) # residual + self.S = np.zeros((dim_z, dim_z)) # system uncertainty (innovation covariance) + self.SI = np.zeros((dim_z, dim_z)) # inverse system uncertainty + + # Identity matrix (used in update) + self._I = np.eye(dim_x) + + # Save prior (after predict) and posterior (after update) + self.x_prior = self.x.copy() + self.P_prior = self.P.copy() + self.x_post = self.x.copy() + self.P_post = self.P.copy() + + # Internal log-likelihood computations + self._log_likelihood = log(sys.float_info.min) + self._likelihood = sys.float_info.min + self._mahalanobis = None + + # Store recent observations for freeze/unfreeze logic + self.max_obs = max_obs + self.history_obs = deque([], maxlen=self.max_obs) + + # For potential smoothing usage + self.inv = np.linalg.inv + self.attr_saved = None + self.observed = False + self.last_measurement = None + + def apply_affine_correction(self, m, t): + """ + Apply an affine transform to the current state and covariance. + This is useful if the image or reference frame is warped. + + Parameters + ---------- + m : np.array(2x2) + Affine transform (rotation/scale) to be applied to x,y and maybe w,h + t : np.array(2x1) + Translation vector to be added after applying the transform. + + TODO: adapt for oriented bounding box (especially if the orientation + is also changed by the transform). + """ + # For demonstration, we apply the transform to [x, y] and [x_dot, y_dot], etc. + # But for your OBB case, consider carefully how w, h, and angle should transform. + + # Example basic approach: transform x, y + self.x[:2] = m @ self.x[:2] + t + + # Possibly transform w, h. But if w,h are not purely length in x,y directions, + # you might have to do something more elaborate. 
For demonstration: + self.x[2:4] = np.abs(m @ self.x[2:4]) # naive approach: scale w,h + + # P block for positions: + self.P[:2, :2] = m @ self.P[:2, :2] @ m.T + # P block for widths/heights (again naive if we treat w,h as x,y scale): + self.P[2:4, 2:4] = m @ self.P[2:4, 2:4] @ m.T + + # If angle is included, consider adjusting it or leaving it if the transform + # is purely in the plane with no orientation offset. Could do angle wrap here. + + # If we froze the filter, also update the frozen state + if not self.observed and self.attr_saved is not None: + self.attr_saved["x"][:2] = m @ self.attr_saved["x"][:2] + t + self.attr_saved["x"][2:4] = np.abs(m @ self.attr_saved["x"][2:4]) + self.attr_saved["P"][:2, :2] = m @ self.attr_saved["P"][:2, :2] @ m.T + self.attr_saved["P"][2:4, 2:4] = m @ self.attr_saved["P"][2:4, 2:4] @ m.T + + # last_measurement might need updating similarly + self.attr_saved["last_measurement"][:2] = ( + m @ self.attr_saved["last_measurement"][:2] + t + ) + + def predict(self, u=None, B=None, F=None, Q=None): + """ + Predict next state (prior) using the state transition matrix F + and process noise Q. + + Parameters + ---------- + u : np.array(dim_u, 1), optional + Control vector. If not provided, assumed 0. + B : np.array(dim_x, dim_u), optional + Control transition matrix. If None, self.B is used. + F : np.array(dim_x, dim_x), optional + State transition matrix. If None, self.F is used. + Q : np.array(dim_x, dim_x) or scalar, optional + Process noise matrix. If None, self.Q is used. If scalar, + Q = scalar * I. 
+ """ + if B is None: + B = self.B + if F is None: + F = self.F + if Q is None: + Q = self.Q + elif isscalar(Q): + Q = eye(self.dim_x) * Q + + # x = F x + B u + if B is not None and u is not None: + self.x = dot(F, self.x) + dot(B, u) + else: + self.x = dot(F, self.x) + + # P = F P F^T + Q + self.P = self._alpha_sq * dot(dot(F, self.P), F.T) + Q + + # Save the prior + self.x_prior = self.x.copy() + self.P_prior = self.P.copy() + + # ---- New: Enforce bounding box and angle constraints (if dim_x >= 5) ---- + if self.dim_x >= 5: + # clamp w, h > 0 + self.x[2, 0] = max(self.x[2, 0], 1e-4) + self.x[3, 0] = max(self.x[3, 0], 1e-4) + + # wrap angle to [-pi, pi] + self.x[4, 0] = (self.x[4, 0] + pi) % (2 * pi) - pi + + def freeze(self): + """ + Save the current filter parameters in attr_saved so that if the next + observation is missing, we can revert to these parameters for + out-of-sequence or offline smoothing. + """ + self.attr_saved = deepcopy(self.__dict__) + + def unfreeze(self): + """ + Revert the filter parameters to the saved (frozen) state, then "replay" + the missing measurements from history to smooth the intermediate states. 
+ """ + if self.attr_saved is not None: + new_history = deepcopy(list(self.history_obs)) + # revert to the frozen attributes + self.__dict__ = self.attr_saved + # remove last measurement from history (since we'll re-apply them) + self.history_obs = deque(list(self.history_obs)[:-1], maxlen=self.max_obs) + + # naive approach: re-update states between the two known measurements + occur = [int(d is None) for d in new_history] + indices = np.where(np.array(occur) == 0)[0] + if len(indices) < 2: + return # not enough measurements to replay + + index1, index2 = indices[-2], indices[-1] + box1, box2 = new_history[index1], new_history[index2] + x1, y1, w1, h1, a1 = box1 + x2, y2, w2, h2, a2 = box2 + time_gap = index2 - index1 + dx, dy = (x2 - x1) / time_gap, (y2 - y1) / time_gap + dw, dh = (w2 - w1) / time_gap, (h2 - h1) / time_gap + da = (a2 - a1) / time_gap + + for i in range(index2 - index1): + x_ = x1 + (i + 1) * dx + y_ = y1 + (i + 1) * dy + w_ = w1 + (i + 1) * dw + h_ = h1 + (i + 1) * dh + a_ = a1 + (i + 1) * da + + new_box = np.array([x_, y_, w_, h_, a_]).reshape((5, 1)) + self.update(new_box) + if i != (index2 - index1 - 1): + self.predict() + self.history_obs.pop() + self.history_obs.pop() + + def update(self, z, R=None, H=None): + """ + Incorporate a new measurement z into the state estimate. + + Parameters + ---------- + z : np.array(dim_z, 1) + Measurement vector. If None, skip update step (missing measurement). + R : np.array(dim_z, dim_z), scalar, or None + Measurement noise matrix. If None, self.R is used. + H : np.array(dim_z, dim_x) or None + Measurement function. If None, self.H is used. 
+ """ + # reset log-likelihood computations + self._log_likelihood = None + self._likelihood = None + self._mahalanobis = None + + # Save the observation (even if None) + self.history_obs.append(z) + + # If measurement is missing + if z is None: + if self.observed: + # freeze the current parameters for future potential smoothing + self.last_measurement = self.history_obs[-2] + self.freeze() + self.observed = False + self.z = np.array([[None] * self.dim_z]).T + self.x_post = self.x.copy() + self.P_post = self.P.copy() + self.y = zeros((self.dim_z, 1)) + return + + # If we haven't observed for a while, revert to the frozen state + if not self.observed: + self.unfreeze() + self.observed = True + + if R is None: + R = self.R + elif isscalar(R): + R = eye(self.dim_z) * R + + if H is None: + H = self.H + z = reshape_z(z, self.dim_z, self.x.ndim) + + # y = z - Hx (residual) + self.y = z - dot(H, self.x) + + PHT = dot(self.P, H.T) + self.S = dot(H, PHT) + R + self.SI = self.inv(self.S) + + # K = PHT * SI + self.K = PHT.dot(self.SI) + + # Optional gating (commented out): + # mahal_dist = float(self.y.T @ self.SI @ self.y) + # gating_threshold = 9.21 # e.g., chi-square with 2-5 dof + # if mahal_dist > gating_threshold: + # # Outlier measurement, skip or handle differently + # return + + # x = x + K y + self.x = self.x + dot(self.K, self.y) + + # P = (I - K H) P (I - K H)^T + K R K^T + I_KH = self._I - dot(self.K, H) + self.P = dot(dot(I_KH, self.P), I_KH.T) + dot(dot(self.K, R), self.K.T) + + # Save measurement and posterior + self.z = deepcopy(z) + self.x_post = self.x.copy() + self.P_post = self.P.copy() + + # ---- New: Enforce bounding box and angle constraints (if dim_x >= 5) ---- + if self.dim_x >= 5: + # clamp w, h > 0 + self.x[2, 0] = max(self.x[2, 0], 1e-4) + self.x[3, 0] = max(self.x[3, 0], 1e-4) + + # wrap angle to [-pi, pi] + self.x[4, 0] = (self.x[4, 0] + pi) % (2 * pi) - pi + + def update_steadystate(self, z, H=None): + """ + Update using precomputed 
steady-state gain (K_steady_state) and + steady-state covariance P. Only x is updated here. + P remains unchanged. + """ + if z is None: + self.history_obs.append(z) + return + + if H is None: + H = self.H + + # residual + self.y = z - dot(H, self.x) + + # x = x + K_steady_state * y + self.x = self.x + dot(self.K_steady_state, self.y) + + # Save measurement and posterior + self.z = deepcopy(z) + self.x_post = self.x.copy() + + self.history_obs.append(z) + + def log_likelihood_of(self, z=None): + """ + Compute the log-likelihood of measurement z given the current + measurement prediction. This uses logpdf from filterpy.stats. + """ + if z is None: + z = self.z + return logpdf(z, dot(self.H, self.x), self.S) + + def likelihood_of(self, z=None): + """ + Compute the likelihood (probability) of measurement z given + the current measurement prediction. + """ + return exp(self.log_likelihood_of(z)) + + @property + def log_likelihood(self): + """ log-likelihood of the last measurement. """ + return self._log_likelihood + + @property + def likelihood(self): + """ likelihood of the last measurement. """ + return self._likelihood + + +def batch_filter(x, P, zs, Fs, Qs, Hs, Rs, Bs=None, us=None, + update_first=False, saver=None): + """ + Batch processes a sequence of measurements. + + Parameters + ---------- + x : np.array(dim_x, 1) + Initial state. + P : np.array(dim_x, dim_x) + Initial covariance. + zs : list-like + List of measurements at each time step (None for missing). + Fs : list-like + State transition matrices for each step. + Qs : list-like + Process noise covariances for each step. + Hs : list-like + Measurement matrices for each step. + Rs : list-like + Measurement noise covariances for each step. + Bs : list-like, optional + Control transition matrices for each step. + us : list-like, optional + Control vectors for each step. + update_first : bool + If True, apply update->predict. Otherwise predict->update. 
+ saver : filterpy.common.Saver, optional + If provided, saver.save() is called at each step. + + Returns + ------- + means : np.array((n,dim_x,1)) + covariances : np.array((n,dim_x,dim_x)) + means_p : np.array((n,dim_x,1)) + Predictions after each step + covariances_p : np.array((n,dim_x,dim_x)) + Covariances after prediction each step + """ + n = np.size(zs, 0) + dim_x = x.shape[0] + + # Arrays to store results + if x.ndim == 1: + means = np.zeros((n, dim_x)) + means_p = np.zeros((n, dim_x)) + else: + means = np.zeros((n, dim_x, 1)) + means_p = np.zeros((n, dim_x, 1)) + + covariances = np.zeros((n, dim_x, dim_x)) + covariances_p = np.zeros((n, dim_x, dim_x)) + + if us is None: + us = [0.0] * n + Bs = [0.0] * n + + # Procedural version of predict->update or update->predict + for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)): + + if update_first: + # Update step + x, P = update(x, P, z, R=R, H=H) + means[i, :] = x + covariances[i, :, :] = P + + # Predict step + x, P = predict(x, P, u=u, B=B, F=F, Q=Q) + means_p[i, :] = x + covariances_p[i, :, :] = P + + else: + # Predict step + x, P = predict(x, P, u=u, B=B, F=F, Q=Q) + means_p[i, :] = x + covariances_p[i, :, :] = P + + # Update step + x, P = update(x, P, z, R=R, H=H) + means[i, :] = x + covariances[i, :, :] = P + + if saver is not None: + saver.save() + + return (means, covariances, means_p, covariances_p) + + +def update(x, P, z, R, H): + """ + Procedural form of the update step of the Kalman Filter. + """ + if z is None: + return x, P + + # y = z - Hx + y = z - dot(H, x) + PHT = dot(P, H.T) + S = dot(H, PHT) + R + SI = linalg.inv(S) + K = dot(PHT, SI) + + # x = x + Ky + x = x + dot(K, y) + + # P = (I - KH)P(I - KH)' + KRK' + I_KH = np.eye(x.shape[0]) - dot(K, H) + P = dot(dot(I_KH, P), I_KH.T) + dot(dot(K, R), K.T) + + return x, P + + +def predict(x, P, F, Q, B=None, u=None): + """ + Procedural form of the predict step of the Kalman Filter. 
+ """ + if B is not None and u is not None: + x = dot(F, x) + dot(B, u) + else: + x = dot(F, x) + + P = dot(dot(F, P), F.T) + Q + return x, P diff --git a/boxmot/postprocessing/__init__.py b/boxmot/postprocessing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a4d0f1f9a635d5ccf0932bbf5c2069529a7ac --- /dev/null +++ b/boxmot/postprocessing/__init__.py @@ -0,0 +1 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license diff --git a/boxmot/postprocessing/gsi.py b/boxmot/postprocessing/gsi.py new file mode 100644 index 0000000000000000000000000000000000000000..86a1b5b751a1adff3ee31dde3408ef1cf75d2209 --- /dev/null +++ b/boxmot/postprocessing/gsi.py @@ -0,0 +1,143 @@ +from pathlib import Path +import numpy as np +import argparse +from sklearn.gaussian_process import GaussianProcessRegressor as GPR +from sklearn.gaussian_process.kernels import RBF +from boxmot.utils import logger as LOGGER +import concurrent.futures +from tqdm import tqdm + + +def linear_interpolation(data: np.ndarray, interval: int) -> np.ndarray: + """ + Apply linear interpolation between rows in the tracking results. + + The function assumes the first two columns of `data` represent frame number and object ID. + Interpolated rows are added when consecutive rows for the same ID have a gap of more than 1 + frame but less than the specified interval. + + Parameters: + data (np.ndarray): Input tracking results. + interval (int): Maximum gap to perform interpolation. + + Returns: + np.ndarray: Tracking results with interpolated rows included. 
+ """ + # Sort data by frame and then by ID + sorted_data = data[np.lexsort((data[:, 0], data[:, 1]))] + result_rows = [] + previous_id = None + previous_frame = None + previous_row = None + + for row in sorted_data: + current_frame, current_id = int(row[0]), int(row[1]) + if previous_id is not None and current_id == previous_id and previous_frame + 1 < current_frame < previous_frame + interval: + gap = current_frame - previous_frame - 1 + for i in range(1, gap + 1): + # Linear interpolation for each missing frame + new_row = previous_row + (row - previous_row) * (i / (current_frame - previous_frame)) + result_rows.append(new_row) + result_rows.append(row) + previous_id, previous_frame, previous_row = current_id, current_frame, row + + result_array = np.array(result_rows) + # Resort the array + return result_array[np.lexsort((result_array[:, 0], result_array[:, 1]))] + + +def gaussian_smooth(data: np.ndarray, tau: float) -> np.ndarray: + """ + Apply Gaussian process smoothing to specified columns in the tracking results. + + For each unique object ID in the data, this function smooths columns 2 through 5 using + a Gaussian Process with an RBF kernel. Additional columns (columns 6 and 7) and a constant + value (-1) are appended to each row. + + Parameters: + data (np.ndarray): Tracking results. + tau (float): Smoothing parameter. + + Returns: + np.ndarray: Tracking results with smoothed columns. 
+ """ + smoothed_output = [] + unique_ids = np.unique(data[:, 1]) + for obj_id in unique_ids: + tracks = data[data[:, 1] == obj_id] + num_tracks = len(tracks) + # Determine length scale using logarithmic scaling with clipping + length_scale = np.clip(tau * np.log(tau ** 3 / num_tracks), tau ** -1, tau ** 2) + t = tracks[:, 0].reshape(-1, 1) + kernel = RBF(length_scale, length_scale_bounds="fixed") + gpr = GPR(kernel) + + # Smooth columns 2 to 5 simultaneously (if supported by your version of scikit-learn) + smoothed_columns = gpr.fit(t, tracks[:, 2:6]).predict(t) + + # Build new rows with the smoothed data, retaining other columns and appending -1 + for i in range(len(tracks)): + new_row = np.concatenate(([tracks[i, 0], obj_id], smoothed_columns[i], tracks[i, 6:8], [-1])) + smoothed_output.append(new_row) + + return np.array(smoothed_output) + + +def process_file(file_path: Path, interval: int, tau: float): + """ + Process a single MOT results file by applying linear interpolation and Gaussian smoothing. + + Parameters: + file_path (Path): Path to the tracking results file. + interval (int): Interval for linear interpolation. + tau (float): Smoothing parameter for Gaussian process. + """ + LOGGER.info(f"Applying GSI to: {file_path}") + tracking_results = np.loadtxt(file_path, delimiter=',') + if tracking_results.size != 0: + interpolated_results = linear_interpolation(tracking_results, interval) + smoothed_results = gaussian_smooth(interpolated_results, tau) + np.savetxt(file_path, smoothed_results, fmt='%d %d %d %d %d %d %d %d %d') + else: + LOGGER.warning(f'No tracking results in {file_path}. Skipping...') + + +def gsi(mot_results_folder: Path, interval: int = 20, tau: float = 10): + """ + Apply Gaussian Smoothed Interpolation (GSI) to all tracking result files in a folder. + + Parameters: + mot_results_folder (Path): Path to the folder containing MOT result files. + interval (int, optional): Maximum gap to perform interpolation. Defaults to 20. 
def main():
    """
    Parse command-line arguments and run Gaussian Smoothed Interpolation.
    """
    parser = argparse.ArgumentParser(
        description='Apply Gaussian Smoothed Interpolation (GSI) to tracking results.'
    )
    parser.add_argument('--path', type=str, required=True, help='Path to MOT results folder')
    args = parser.parse_args()

    gsi(Path(args.path))


if __name__ == "__main__":
    main()


def get_tracker_config(tracker_type):
    """Return the path to the YAML configuration file for ``tracker_type``."""
    return TRACKER_CONFIGS / f'{tracker_type}.yaml'


def create_tracker(tracker_type, tracker_config=None, reid_weights=None, device=None, half=None, per_class=None, evolve_param_dict=None):
    """
    Creates and returns an instance of the specified tracker type.

    Parameters:
    - tracker_type: The type of the tracker (e.g., 'strongsort', 'ocsort').
    - tracker_config: Path to the tracker configuration file.
    - reid_weights: Weights for ReID (re-identification).
    - device: Device to run the tracker on (e.g., 'cpu', 'cuda').
    - half: Boolean indicating whether to use half-precision.
    - per_class: Boolean for class-specific tracking (optional).
    - evolve_param_dict: A dictionary of parameters for evolving the tracker.

    Returns:
    - An instance of the selected tracker.

    Raises:
    - ValueError: If ``tracker_type`` is not a known tracker.
    """
    # Load tuning parameters from the YAML config unless an evolved parameter
    # dictionary is supplied explicitly.
    if evolve_param_dict is None:
        with open(tracker_config, "r") as f:
            yaml_config = yaml.load(f, Loader=yaml.FullLoader)
        tracker_args = {param: details['default'] for param, details in yaml_config.items()}
    else:
        tracker_args = evolve_param_dict

    # Arguments specific to ReID (appearance) models.
    reid_args = {
        'reid_weights': reid_weights,
        'device': device,
        'half': half,
    }

    # Map tracker types to their corresponding classes.
    tracker_mapping = {
        'strongsort': 'boxmot.trackers.strongsort.strongsort.StrongSort',
        'ocsort': 'boxmot.trackers.ocsort.ocsort.OcSort',
        'bytetrack': 'boxmot.trackers.bytetrack.bytetrack.ByteTrack',
        'botsort': 'boxmot.trackers.botsort.botsort.BotSort',
        'deepocsort': 'boxmot.trackers.deepocsort.deepocsort.DeepOcSort',
        'hybridsort': 'boxmot.trackers.hybridsort.hybridsort.HybridSort',
        'imprassoc': 'boxmot.trackers.imprassoc.imprassoctrack.ImprAssocTrack',
        'boosttrack': 'boxmot.trackers.boosttrack.boosttrack.BoostTrack',
    }

    if tracker_type not in tracker_mapping:
        # Raise instead of print()+exit() so callers can handle the error.
        raise ValueError(
            f"No such tracker: '{tracker_type}'. Valid options: {', '.join(tracker_mapping)}"
        )

    # Dynamically import and instantiate the correct tracker class.
    module_path, class_name = tracker_mapping[tracker_type].rsplit('.', 1)
    tracker_class = getattr(__import__(module_path, fromlist=[class_name]), class_name)

    # ReID-capable trackers additionally receive the appearance-model args.
    if tracker_type in ['strongsort', 'botsort', 'deepocsort', 'hybridsort', 'imprassoc', 'boosttrack']:
        tracker_args['per_class'] = per_class
        tracker_args.update(reid_args)
        if tracker_type in ['strongsort', 'boosttrack']:
            tracker_args.pop('per_class')  # per-class tracking not supported by these trackers
    else:
        tracker_args['per_class'] = per_class

    # Return the instantiated tracker class with arguments.
    return tracker_class(**tracker_args)
+ - active_tracks (list): List to hold active tracks, may be used differently in subclasses. + """ + self.det_thresh = det_thresh + self.max_age = max_age + self.max_obs = max_obs + self.min_hits = min_hits + self.per_class = per_class # Track per class or not + self.nr_classes = nr_classes + self.iou_threshold = iou_threshold + self.last_emb_size = None + self.asso_func_name = asso_func+"_obb" if is_obb else asso_func + self.is_obb = is_obb + + self.frame_count = 0 + self.active_tracks = [] # This might be handled differently in derived classes + self.per_class_active_tracks = None + self._first_frame_processed = False # Flag to track if the first frame has been processed + self._first_dets_processed = False + + # Initialize per-class active tracks + if self.per_class: + self.per_class_active_tracks = {} + for i in range(self.nr_classes): + self.per_class_active_tracks[i] = [] + + if self.max_age >= self.max_obs: + LOGGER.warning("Max age > max observations, increasing size of max observations...") + self.max_obs = self.max_age + 5 + print("self.max_obs", self.max_obs) + + @abstractmethod + def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray: + """ + Abstract method to update the tracker with new detections for a new frame. This method + should be implemented by subclasses. + + Parameters: + - dets (np.ndarray): Array of detections for the current frame. + - img (np.ndarray): The current frame as an image array. + - embs (np.ndarray, optional): Embeddings associated with the detections, if any. + + Raises: + - NotImplementedError: If the subclass does not implement this method. 
+ """ + raise NotImplementedError("The update method needs to be implemented by the subclass.") + + def get_class_dets_n_embs(self, dets, embs, cls_id): + # Initialize empty arrays for detections and embeddings + class_dets = np.empty((0, 6)) + class_embs = np.empty((0, self.last_emb_size)) if self.last_emb_size is not None else None + + # Check if there are detections + if dets.size > 0: + class_indices = np.where(dets[:, 5] == cls_id)[0] + class_dets = dets[class_indices] + + if embs is not None: + # Assert that if embeddings are provided, they have the same number of elements as detections + assert dets.shape[0] == embs.shape[0], "Detections and embeddings must have the same number of elements when both are provided" + + if embs.size > 0: + class_embs = embs[class_indices] + self.last_emb_size = class_embs.shape[1] # Update the last known embedding size + else: + class_embs = None + return class_dets, class_embs + + @staticmethod + def setup_decorator(method): + """ + Decorator to perform setup on the first frame only. + This ensures that initialization tasks (like setting the association function) only + happen once, on the first frame, and are skipped on subsequent frames. 
+ """ + def wrapper(self, *args, **kwargs): + # If setup hasn't been done yet, perform it + # Even if dets is empty (e.g., shape (0, 7)), this check will still pass if it's Nx7 + if not self._first_dets_processed: + dets = args[0] + if dets is not None: + if dets.ndim == 2 and dets.shape[1] == 6: + self.is_obb = False + self._first_dets_processed = True + elif dets.ndim == 2 and dets.shape[1] == 7: + self.is_obb = True + self._first_dets_processed = True + + if not self._first_frame_processed: + img = args[1] + self.h, self.w = img.shape[0:2] + self.asso_func = AssociationFunction(w=self.w, h=self.h, asso_mode=self.asso_func_name).asso_func + + # Mark that the first frame setup has been done + self._first_frame_processed = True + + # Call the original method (e.g., update) + return method(self, *args, **kwargs) + + return wrapper + + + @staticmethod + def per_class_decorator(update_method): + """ + Decorator for the update method to handle per-class processing. + """ + def wrapper(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None): + + #handle different types of inputs + if dets is None or len(dets) == 0: + dets = np.empty((0, 6)) + + if self.per_class: + # Initialize an array to store the tracks for each class + per_class_tracks = [] + + # same frame count for all classes + frame_count = self.frame_count + + for cls_id in range(self.nr_classes): + # Get detections and embeddings for the current class + class_dets, class_embs = self.get_class_dets_n_embs(dets, embs, cls_id) + + LOGGER.debug(f"Processing class {int(cls_id)}: {class_dets.shape} with embeddings {class_embs.shape if class_embs is not None else None}") + + # Activate the specific active tracks for this class id + self.active_tracks = self.per_class_active_tracks[cls_id] + + # Reset frame count for every class + self.frame_count = frame_count + + # Update detections using the decorated method + tracks = update_method(self, dets=class_dets, img=img, embs=class_embs) + + # Save the updated 
active tracks + self.per_class_active_tracks[cls_id] = self.active_tracks + + if tracks.size > 0: + per_class_tracks.append(tracks) + + # Increase frame count by 1 + self.frame_count = frame_count + 1 + + return np.vstack(per_class_tracks) if per_class_tracks else np.empty((0, 8)) + else: + # Process all detections at once if per_class is False + return update_method(self, dets=dets, img=img, embs=embs) + return wrapper + + + def check_inputs(self, dets, img): + assert isinstance( + dets, np.ndarray + ), f"Unsupported 'dets' input format '{type(dets)}', valid format is np.ndarray" + assert isinstance( + img, np.ndarray + ), f"Unsupported 'img_numpy' input format '{type(img)}', valid format is np.ndarray" + assert ( + len(dets.shape) == 2 + ), "Unsupported 'dets' dimensions, valid number of dimensions is two" + if self.is_obb: + assert ( + dets.shape[1] == 7 + ), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6 (cx,cy,w,h,angle,conf,cls)" + else : + assert ( + dets.shape[1] == 6 + ), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6 (x1,y1,x2,y2,conf,cls)" + + + def id_to_color(self, id: int, saturation: float = 0.75, value: float = 0.95) -> tuple: + """ + Generates a consistent unique BGR color for a given ID using hashing. + + Parameters: + - id (int): Unique identifier for which to generate a color. + - saturation (float): Saturation value for the color in HSV space. + - value (float): Value (brightness) for the color in HSV space. + + Returns: + - tuple: A tuple representing the BGR color. 
+ """ + + # Hash the ID to get a consistent unique value + hash_object = hashlib.sha256(str(id).encode()) + hash_digest = hash_object.hexdigest() + + # Convert the first few characters of the hash to an integer + # and map it to a value between 0 and 1 for the hue + hue = int(hash_digest[:8], 16) / 0xffffffff + + # Convert HSV to RGB + rgb = colorsys.hsv_to_rgb(hue, saturation, value) + + # Convert RGB from 0-1 range to 0-255 range and format as hexadecimal + rgb_255 = tuple(int(component * 255) for component in rgb) + hex_color = '#%02x%02x%02x' % rgb_255 + # Strip the '#' character and convert the string to RGB integers + rgb = tuple(int(hex_color.strip('#')[i:i+2], 16) for i in (0, 2, 4)) + + # Convert RGB to BGR for OpenCV + bgr = rgb[::-1] + + return bgr + + def plot_box_on_img(self, img: np.ndarray, box: tuple, conf: float, cls: int, id: int, thickness: int = 2, fontscale: float = 0.5) -> np.ndarray: + """ + Draws a bounding box with ID, confidence, and class information on an image. + + Parameters: + - img (np.ndarray): The image array to draw on. + - box (tuple): The bounding box coordinates as (x1, y1, x2, y2). + - conf (float): Confidence score of the detection. + - cls (int): Class ID of the detection. + - id (int): Unique identifier for the detection. + - thickness (int): The thickness of the bounding box. + - fontscale (float): The font scale for the text. + + Returns: + - np.ndarray: The image array with the bounding box drawn on it. 
+ """ + if self.is_obb: + + angle = box[4] * 180.0 / np.pi # Convert radians to degrees + box_poly = ((box[0], box[1]), (box[2], box[3]), angle) + # print((width, height)) + rotrec = cv.boxPoints(box_poly) + box_poly = np.int_(rotrec) # Convert to integer + + # Draw the rectangle on the image + img = cv.polylines(img, [box_poly], isClosed=True, color=self.id_to_color(id), thickness=thickness) + + img = cv.putText( + img, + f'id: {int(id)}, conf: {conf:.2f}, c: {int(cls)}, a: {box[4]:.2f}', + (int(box[0]), int(box[1]) - 10), + cv.FONT_HERSHEY_SIMPLEX, + fontscale, + self.id_to_color(id), + thickness + ) + else : + + img = cv.rectangle( + img, + (int(box[0]), int(box[1])), + (int(box[2]), int(box[3])), + self.id_to_color(id), + thickness + ) + img = cv.putText( + img, + f'id: {int(id)}, conf: {conf:.2f}, c: {int(cls)}', + (int(box[0]), int(box[1]) - 10), + cv.FONT_HERSHEY_SIMPLEX, + fontscale, + self.id_to_color(id), + thickness + ) + return img + + + def plot_trackers_trajectories(self, img: np.ndarray, observations: list, id: int) -> np.ndarray: + """ + Draws the trajectories of tracked objects based on historical observations. Each point + in the trajectory is represented by a circle, with the thickness increasing for more + recent observations to visualize the path of movement. + + Parameters: + - img (np.ndarray): The image array on which to draw the trajectories. + - observations (list): A list of bounding box coordinates representing the historical + observations of a tracked object. Each observation is in the format (x1, y1, x2, y2). + - id (int): The unique identifier of the tracked object for color consistency in visualization. + + Returns: + - np.ndarray: The image array with the trajectories drawn on it. 
+ """ + for i, box in enumerate(observations): + trajectory_thickness = int(np.sqrt(float (i + 1)) * 1.2) + if self.is_obb: + img = cv.circle( + img, + (int(box[0]), int(box[1])), + 2, + color=self.id_to_color(int(id)), + thickness=trajectory_thickness + ) + else: + + img = cv.circle( + img, + (int((box[0] + box[2]) / 2), + int((box[1] + box[3]) / 2)), + 2, + color=self.id_to_color(int(id)), + thickness=trajectory_thickness + ) + return img + + + def plot_results(self, img: np.ndarray, show_trajectories: bool, thickness: int = 2, fontscale: float = 0.5) -> np.ndarray: + """ + Visualizes the trajectories of all active tracks on the image. For each track, + it draws the latest bounding box and the path of movement if the history of + observations is longer than two. This helps in understanding the movement patterns + of each tracked object. + + Parameters: + - img (np.ndarray): The image array on which to draw the trajectories and bounding boxes. + - show_trajectories (bool): Whether to show the trajectories. + - thickness (int): The thickness of the bounding box. + - fontscale (float): The font scale for the text. + + Returns: + - np.ndarray: The image array with trajectories and bounding boxes of all active tracks. 
+ """ + + # if values in dict + if self.per_class_active_tracks is not None: + for k in self.per_class_active_tracks.keys(): + active_tracks = self.per_class_active_tracks[k] + for a in active_tracks: + if a.history_observations: + if len(a.history_observations) > 2: + box = a.history_observations[-1] + img = self.plot_box_on_img(img, box, a.conf, a.cls, a.id, thickness, fontscale) + if show_trajectories: + img = self.plot_trackers_trajectories(img, a.history_observations, a.id) + else: + for a in self.active_tracks: + if a.history_observations: + if len(a.history_observations) > 2: + box = a.history_observations[-1] + img = self.plot_box_on_img(img, box, a.conf, a.cls, a.id, thickness, fontscale) + if show_trajectories: + img = self.plot_trackers_trajectories(img, a.history_observations, a.id) + + return img + diff --git a/boxmot/trackers/boosttrack/assoc.py b/boxmot/trackers/boosttrack/assoc.py new file mode 100644 index 0000000000000000000000000000000000000000..a37b79b0e660efee86108242977112bcf25e614c --- /dev/null +++ b/boxmot/trackers/boosttrack/assoc.py @@ -0,0 +1,209 @@ +import warnings +from copy import deepcopy +from typing import Optional + +import lap +import numpy as np + + +def shape_similarity(detects: np.ndarray, tracks: np.ndarray, s_sim_corr: bool) -> np.ndarray: + if not s_sim_corr: + return shape_similarity_v1(detects, tracks) + else: + return shape_similarity_v2(detects, tracks) + +def shape_similarity_v1(detects: np.ndarray, tracks: np.ndarray) -> np.ndarray: + if detects.size == 0 or tracks.size == 0: + return np.zeros((0, 0)) + + dw = (detects[:, 2] - detects[:, 0]).reshape((-1, 1)) + dh = (detects[:, 3] - detects[:, 1]).reshape((-1, 1)) + tw = (tracks[:, 2] - tracks[:, 0]).reshape((1, -1)) + th = (tracks[:, 3] - tracks[:, 1]).reshape((1, -1)) + return np.exp(-(np.abs(dw - tw)/np.maximum(dw, tw) + np.abs(dh - th)/np.maximum(dw, tw))) + + +def shape_similarity_v2(detects: np.ndarray, tracks: np.ndarray) -> np.ndarray: + if detects.size == 0 or 
def MhDist_similarity(mahalanobis_distance: np.ndarray, softmax_temp: float = 1.0) -> np.ndarray:
    """
    Convert Mahalanobis distances into a column-normalized similarity score.

    Distances are clipped at the chi-square 99% bound, inverted so larger
    means more similar, passed through a column-wise softmax, and zeroed
    wherever the original distance exceeded the bound.
    """
    limit = 13.2767  # 99% conf interval https://www.mathworks.com/help/stats/chi2inv.html
    dist = deepcopy(mahalanobis_distance)
    over_limit = dist > limit
    dist[over_limit] = limit
    dist = limit - dist

    scores = np.exp(dist / softmax_temp) / np.exp(dist / softmax_temp).sum(0).reshape((1, -1))
    return np.where(over_limit, 0, scores)


def iou_batch(bboxes1, bboxes2):
    """
    From SORT: computes the pairwise IoU matrix between two sets of
    [x1, y1, x2, y2] boxes (rows of bboxes1 vs rows of bboxes2).
    """
    boxes_a = np.expand_dims(bboxes1, 1)
    boxes_b = np.expand_dims(bboxes2, 0)

    inter_x1 = np.maximum(boxes_a[..., 0], boxes_b[..., 0])
    inter_y1 = np.maximum(boxes_a[..., 1], boxes_b[..., 1])
    inter_x2 = np.minimum(boxes_a[..., 2], boxes_b[..., 2])
    inter_y2 = np.minimum(boxes_a[..., 3], boxes_b[..., 3])
    inter = np.maximum(0.0, inter_x2 - inter_x1) * np.maximum(0.0, inter_y2 - inter_y1)

    area_a = (boxes_a[..., 2] - boxes_a[..., 0]) * (boxes_a[..., 3] - boxes_a[..., 1])
    area_b = (boxes_b[..., 2] - boxes_b[..., 0]) * (boxes_b[..., 3] - boxes_b[..., 1])
    return inter / (area_a + area_b - inter)


def soft_biou_batch(bboxes1, bboxes2):
    """
    Computes soft BIoU between two bboxes in the form [x1,y1,x2,y2]
    BIoU is introduced in https://arxiv.org/pdf/2211.14317
    Soft BIoU is introduced as part of BoostTrack++
    # Author : Vukasin Stanojevic
    # Email : vukasin.stanojevic@pmf.edu.rs
    """
    boxes_a = np.expand_dims(bboxes1, 1)
    boxes_b = np.expand_dims(bboxes2, 0)
    k1 = 0.25
    k2 = 0.5
    # Each box is buffered proportionally to (1 - confidence of box b).
    slack = 1 - boxes_b[..., 4]

    a_x1 = boxes_a[..., 0] - (boxes_a[..., 2] - boxes_a[..., 0]) * slack * k1
    b_x1 = boxes_b[..., 0] - (boxes_b[..., 2] - boxes_b[..., 0]) * slack * k2
    xx1 = np.maximum(a_x1, b_x1)

    a_y1 = boxes_a[..., 1] - (boxes_a[..., 3] - boxes_a[..., 1]) * slack * k1
    b_y1 = boxes_b[..., 1] - (boxes_b[..., 3] - boxes_b[..., 1]) * slack * k2
    yy1 = np.maximum(a_y1, b_y1)

    a_x2 = boxes_a[..., 2] + (boxes_a[..., 2] - boxes_a[..., 0]) * slack * k1
    b_x2 = boxes_b[..., 2] + (boxes_b[..., 2] - boxes_b[..., 0]) * slack * k2
    xx2 = np.minimum(a_x2, b_x2)

    a_y2 = boxes_a[..., 3] + (boxes_a[..., 3] - boxes_a[..., 1]) * slack * k1
    b_y2 = boxes_b[..., 3] + (boxes_b[..., 3] - boxes_b[..., 1]) * slack * k2
    yy2 = np.minimum(a_y2, b_y2)

    inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
    return inter / ((a_x2 - a_x1) * (a_y2 - a_y1) + (b_x2 - b_x1) * (b_y2 - b_y1) - inter)


def match(cost_matrix: np.ndarray, threshold: float) -> np.ndarray:
    """
    Solve the assignment for a similarity matrix (higher is better).

    Uses the trivial solution when thresholding already yields at most one
    candidate per row and column, otherwise falls back to the
    Jonker-Volgenant solver.
    """
    if cost_matrix.size == 0:
        return np.empty(shape=(0, 2))

    above = (cost_matrix > threshold).astype(np.int32)
    if above.sum(1).max() == 1 and above.sum(0).max() == 1:
        # Matching is unambiguous without running the solver.
        return np.stack(np.where(above), axis=1)

    _, col_for_row, row_for_col = lap.lapjv(-cost_matrix, extend_cost=True)
    return np.array([[row_for_col[c], c] for c in col_for_row if c >= 0])
def associate(
    detections,
    trackers,
    iou_threshold,
    mahalanobis_distance: Optional[np.ndarray] = None,
    track_confidence: Optional[np.ndarray] = None,
    detection_confidence: Optional[np.ndarray] = None,
    emb_cost: Optional[np.ndarray] = None,
    lambda_iou: float = 0.5,
    lambda_mhd: float = 0.25,
    lambda_shape: float = 0.25,
    s_sim_corr: bool = False,
):
    """
    Build the BoostTrack association cost and solve the matching.

    The cost starts from the IoU matrix and is boosted with detection-track
    confidence, Mahalanobis similarity, shape similarity and (optionally)
    appearance-embedding similarity before linear assignment.
    """
    if len(trackers) == 0:
        # Nothing to match against: every detection is unmatched.
        return (
            np.empty((0, 2), dtype=int),
            np.arange(len(detections)),
            np.empty((0, 5), dtype=int),
            np.empty((0, 0))
        )

    iou_matrix = iou_batch(detections, trackers)
    cost_matrix = deepcopy(iou_matrix)

    if detection_confidence is None or track_confidence is None:
        warnings.warn("Detections or tracklet confidence is None and detection-tracklet confidence cannot be computed!")
        conf = None
    else:
        # Pairwise det*trk confidence, gated by the IoU threshold.
        conf = detection_confidence.reshape((-1, 1)) * track_confidence.reshape((1, -1))
        conf[iou_matrix < iou_threshold] = 0
        cost_matrix += lambda_iou * conf * iou_batch(detections, trackers)

    if mahalanobis_distance is not None and mahalanobis_distance.size > 0:
        mahalanobis_distance = MhDist_similarity(mahalanobis_distance)
        cost_matrix += lambda_mhd * mahalanobis_distance
        if conf is not None:
            cost_matrix += lambda_shape * conf * shape_similarity(detections, trackers, s_sim_corr)

    if emb_cost is not None:
        # Appearance term is scaled above the sum of the other lambdas.
        cost_matrix += (1 + lambda_iou + lambda_shape + lambda_mhd) * 1.5 * emb_cost

    return linear_assignment(detections, trackers, iou_matrix, cost_matrix, iou_threshold, emb_cost)


def convert_bbox_to_z(bbox):
    """
    Converts a bounding box [x1,y1,x2,y2] to state vector [x, y, h, r].
    """
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    center_x = bbox[0] + width / 2.0
    center_y = bbox[1] + height / 2.0
    # 1e-6 guards against division by zero for degenerate boxes.
    aspect = width / float(height + 1e-6)
    return np.array([center_x, center_y, height, aspect]).reshape((4, 1))


def convert_x_to_bbox(x, score=None):
    """
    Converts a state vector [x, y, h, r] back to bounding box [x1,y1,x2,y2].
    """
    height = x[2]
    aspect = x[3]
    # Non-positive aspect ratios collapse to a zero-width box.
    width = 0 if aspect <= 0 else aspect * height
    corners = [x[0] - width / 2.0, x[1] - height / 2.0,
               x[0] + width / 2.0, x[1] + height / 2.0]
    if score is None:
        return np.array(corners).reshape((1, 4))
    return np.array(corners + [score]).reshape((1, 5))
+ """ + count = 0 + + def __init__(self, det, max_obs, emb: Optional[np.ndarray] = None): + self.bbox_to_z_func = convert_bbox_to_z + self.x_to_bbox_func = convert_x_to_bbox + KalmanBoxTracker.count += 1 + + self.time_since_update = 0 + self.id = KalmanBoxTracker.count + self.kf = KalmanFilter(self.bbox_to_z_func(det[:4])) + self.conf = det[4] + self.cls = det[5] + self.det_ind = det[6] + self.emb = emb + self.hit_streak = 0 + self.age = 0 + self.history_observations = deque([], maxlen=max_obs) + + def get_confidence(self, coef: float = 0.9) -> float: + n = 7 + if self.age < n: + return coef ** (n - self.age) + return coef ** (self.time_since_update - 1) + + def update(self, det: np.ndarray, score: float = 0): + self.time_since_update = 0 + self.hit_streak += 1 + self.history_observations.append(self.get_state()[0]) + self.kf.update(self.bbox_to_z_func(det), score) + self.conf = det[4] + self.cls = det[5] + self.det_ind = det[6] + + def camera_update(self, transform: np.ndarray): + x1, y1, x2, y2 = self.get_state()[0] + x1_, y1_, _ = transform @ np.array([x1, y1, 1]).T + x2_, y2_, _ = transform @ np.array([x2, y2, 1]).T + w, h = x2_ - x1_, y2_ - y1_ + cx, cy = x1_ + w / 2, y1_ + h / 2 + self.kf.x[:4] = [cx, cy, h, w / h] + + def predict(self): + self.kf.predict() + self.age += 1 + if self.time_since_update > 0: + self.hit_streak = 0 + self.time_since_update += 1 + return self.get_state() + + def get_state(self): + return self.x_to_bbox_func(self.kf.x) + + def update_emb(self, emb, alpha=0.9): + self.emb = alpha * self.emb + (1 - alpha) * emb + self.emb /= np.linalg.norm(self.emb) + + def get_emb(self): + return self.emb + + +class BoostTrack(BaseTracker): + + def __init__( + self, + reid_weights, + device, + half: bool, + + max_age: int = 60, + min_hits: int = 3, + det_thresh: float = 0.6, + iou_threshold: float = 0.3, + use_ecc: bool = True, + min_box_area: int = 10, + aspect_ratio_thresh: bool = 1.6, + + # BoostTrack parameters + lambda_iou: float = 0.5, + 
lambda_mhd: float = 0.25, + lambda_shape: float = 0.25, + use_dlo_boost: bool = True, + use_duo_boost: bool = True, + dlo_boost_coef: float = 0.65, + s_sim_corr: bool = False, + + # BoostTrack++ parameters + use_rich_s: bool = False, + use_sb: bool = False, + use_vt: bool = False, + + with_reid: bool = False, + ): + super().__init__() + self.frame_count = 0 + self.trackers: List[KalmanBoxTracker] = [] + + # Parameters for BoostTrack (these can be tuned as needed) + self.max_age = max_age # maximum allowed frames without update + self.min_hits = min_hits # minimum hits to output a track + self.det_thresh = det_thresh # detection confidence threshold + self.iou_threshold = iou_threshold # association IoU threshold + self.use_ecc = use_ecc # use ECC for camera motion compensation + self.min_box_area = min_box_area # minimum box area for detections + self.aspect_ratio_thresh = aspect_ratio_thresh # aspect ratio threshold for detections + + self.lambda_iou = lambda_iou + self.lambda_mhd = lambda_mhd + self.lambda_shape = lambda_shape + self.use_dlo_boost = use_dlo_boost + self.use_duo_boost = use_duo_boost + self.dlo_boost_coef = dlo_boost_coef + self.s_sim_corr = s_sim_corr + + self.use_rich_s = use_rich_s + self.use_sb = use_sb + self.use_vt = use_vt + + self.with_reid = with_reid + + if self.with_reid: + self.reid_model = ReidAutoBackend(weights=reid_weights, device=device, half=half).model + else: + self.reid_model = None + + if self.use_ecc: + self.ecc = ECC(scale=350, video_name=None, use_cache=True) + else: + self.ecc = None + + def update(self, dets: np.ndarray, img: np.ndarray, embs: Optional[np.ndarray] = None) -> np.ndarray: + """ + Update the tracker with detections and an image. + + Args: + dets (np.ndarray): Detection boxes in the format [[x1,y1,x2,y2,score], ...] + img (np.ndarray): The current image frame. + embs (Optional[np.ndarray]): Optional precomputed embeddings. 
+ + Returns: + np.ndarray: Tracked objects in the format + [x1, y1, x2, y2, id, confidence, cls, det_ind] + (with cls and det_ind set to -1 if unused) + """ + if dets is None or dets.size == 0: + dets = np.empty((0, 6)) + + dets = np.hstack([dets, np.arange(len(dets)).reshape(-1, 1)]) + + self.frame_count += 1 + + if self.ecc is not None: + transform = self.ecc(img, self.frame_count) + for trk in self.trackers: + trk.camera_update(transform) + + trks = [] + confs = [] + + for trk in self.trackers: + pos = trk.predict()[0] + conf = trk.get_confidence() + confs.append(conf) + trks.append(np.concatenate([pos, [conf]])) + trks_np = np.vstack(trks) if len(trks) > 0 else np.empty((0, 5)) + + if self.use_dlo_boost: + dets = self.dlo_confidence_boost(dets) + if self.use_duo_boost: + dets = self.duo_confidence_boost(dets) + + dets_embs = np.ones((dets.shape[0], 1)) + if dets.size > 0: + remain_inds = dets[:, 4] >= self.det_thresh + dets = dets[remain_inds] + scores = dets[:, 4] + + if self.with_reid: + if embs is not None: + dets_embs = embs[remain_inds] + else: + dets_embs = self.reid_model.get_features(dets[:, :4], img) + else: + scores = np.empty(0) + dets_embs = np.ones((dets.shape[0], 1)) + + + if self.with_reid and len(self.trackers) > 0: + tracker_embs = np.array([trk.get_emb() for trk in self.trackers]) + if dets_embs.shape[0] == 0: + emb_cost = np.empty((0, tracker_embs.shape[0])) + else: + emb_cost = dets_embs.reshape(dets_embs.shape[0], -1) @ tracker_embs.reshape((tracker_embs.shape[0], -1)).T + else: + emb_cost = None + + mh_dist_matrix = self.get_mh_dist_matrix(dets) + + matched, unmatched_dets, unmatched_trks, _ = associate( + dets, + trks_np, + self.iou_threshold, + mahalanobis_distance=mh_dist_matrix, + track_confidence=np.array(confs).reshape(-1, 1), + detection_confidence=scores, + emb_cost=emb_cost, + lambda_iou=self.lambda_iou, + lambda_mhd=self.lambda_mhd, + lambda_shape=self.lambda_shape, + s_sim_corr=self.s_sim_corr + ) + + if dets.size > 0: + trust = 
(dets[:, 4] - self.det_thresh) / (1 - self.det_thresh) + af = 0.95 + dets_alpha = af + (1 - af) * (1 - trust) + else: + dets_alpha = np.empty(0) + + for m in matched: + self.trackers[m[1]].update(dets[m[0], :], scores[m[0]]) + self.trackers[m[1]].update_emb(dets_embs[m[0]], alpha=dets_alpha[m[0]]) + + for i in unmatched_dets: + if dets[i, 4] >= self.det_thresh: + self.trackers.append(KalmanBoxTracker(dets[i, :], max_obs=self.max_obs, emb=dets_embs[i])) + + outputs = [] + self.active_tracks = [] + for trk in self.trackers: + d = trk.get_state()[0] + if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits): + # Format: [x1, y1, x2, y2, id, confidence, cls, det_ind] + outputs.append(np.array([d[0], d[1], d[2], d[3], trk.id + 1, trk.conf, trk.cls, trk.det_ind])) + self.active_tracks.append(trk) + + self.trackers = [trk for trk in self.trackers if trk.time_since_update <= self.max_age] + + if len(outputs) > 0: + outputs = np.vstack(outputs) + return self.filter_outputs(outputs) + return np.empty((0, 8)) + + + def filter_outputs(self, outputs: np.ndarray) -> np.ndarray: + + w_arr = outputs[:, 2] - outputs[:, 0] + h_arr = outputs[:, 3] - outputs[:, 1] + + vertical_filter = w_arr / h_arr <= self.aspect_ratio_thresh + area_filter = w_arr * h_arr > self.min_box_area + + return outputs[vertical_filter & area_filter] + + def dump_cache(self): + if self.ecc is not None: + self.ecc.save_cache() + + def get_iou_matrix(self, detections: np.ndarray, buffered: bool = False) -> np.ndarray: + trackers = np.zeros((len(self.trackers), 5)) + for t, trk in enumerate(trackers): + pos = self.trackers[t].get_state()[0] + trk[:] = [pos[0], pos[1], pos[2], pos[3], self.trackers[t].get_confidence()] + + return iou_batch(detections, trackers) if not buffered else soft_biou_batch(detections, trackers) + + def get_mh_dist_matrix(self, detections: np.ndarray, n_dims: int = 4) -> np.ndarray: + if len(self.trackers) == 0: + return np.zeros((0, 0)) + z 
= np.zeros((len(detections), n_dims), dtype=float) + x = np.zeros((len(self.trackers), n_dims), dtype=float) + sigma_inv = np.zeros((len(self.trackers), n_dims), dtype=float) + + for i in range(len(detections)): + z[i, :n_dims] = convert_bbox_to_z(detections[i, :]).reshape(-1)[:n_dims] + for i, trk in enumerate(self.trackers): + x[i] = trk.kf.x[:n_dims] + sigma_inv[i] = np.reciprocal(np.diag(trk.kf.covariance[:n_dims, :n_dims])) + return ((z.reshape((-1, 1, n_dims)) - x.reshape((1, -1, n_dims))) ** 2 * + sigma_inv.reshape((1, -1, n_dims))).sum(axis=2) + + def duo_confidence_boost(self, detections: np.ndarray) -> np.ndarray: + if len(detections) == 0: + return detections + + n_dims = 4 + limit = 13.2767 + mh_dist = self.get_mh_dist_matrix(detections, n_dims) + if mh_dist.size > 0 and self.frame_count > 1: + min_dists = mh_dist.min(1) + mask = (min_dists > limit) & (detections[:, 4] < self.det_thresh) + boost_inds = np.where(mask)[0] + iou_limit = 0.3 + if len(boost_inds) > 0: + bdiou = iou_batch(detections[boost_inds], detections[boost_inds]) - np.eye(len(boost_inds)) + bdiou_max = bdiou.max(axis=1) + remaining = boost_inds[bdiou_max <= iou_limit] + args = np.where(bdiou_max > iou_limit)[0] + for i in range(len(args)): + bi = args[i] + tmp = np.where(bdiou[bi] > iou_limit)[0] + args_tmp = np.append(np.intersect1d(boost_inds[args], boost_inds[tmp]), boost_inds[bi]) + conf_max = np.max(detections[args_tmp, 4]) + if detections[boost_inds[bi], 4] == conf_max: + remaining = np.concatenate([remaining, [boost_inds[bi]]]) + mask_boost = np.zeros_like(detections[:, 4], dtype=bool) + mask_boost[remaining] = True + detections[:, 4] = np.where(mask_boost, self.det_thresh + 1e-4, detections[:, 4]) + return detections + + def dlo_confidence_boost(self, detections: np.ndarray) -> np.ndarray: + if len(detections) == 0: + return detections + + sbiou_matrix = self.get_iou_matrix(detections, True) + if sbiou_matrix.size == 0: + return detections + + trackers = 
np.zeros((len(self.trackers), 6)) + for t, trk in enumerate(self.trackers): + pos = trk.get_state()[0] + trackers[t] = [pos[0], pos[1], pos[2], pos[3], 0, trk.time_since_update - 1] + + if self.use_rich_s: + mhd_sim = MhDist_similarity(self.get_mh_dist_matrix(detections), 1) + shape_sim = shape_similarity(detections, trackers, self.s_sim_corr) + S = (mhd_sim + shape_sim + sbiou_matrix) / 3 + else: + S = self.get_iou_matrix(detections, False) + + if not self.use_sb and not self.use_vt: + max_s = S.max(1) + detections[:, 4] = np.maximum(detections[:, 4], max_s * self.dlo_boost_coef) + else: + if self.use_sb: + max_s = S.max(1) + alpha = 0.65 + detections[:, 4] = np.maximum( + detections[:, 4], + alpha * detections[:, 4] + (1 - alpha) * max_s ** 1.5) + if self.use_vt: + threshold_s = 0.95 + threshold_e = 0.8 + n_steps = 20 + alpha = (threshold_s - threshold_e) / n_steps + tmp = (S > np.maximum(threshold_s - np.array([trk.time_since_update - 1 for trk in self.trackers]), + threshold_e)).max(1) + scores = detections[:, 4].copy() + scores[tmp] = np.maximum(scores[tmp], self.det_thresh + 1e-5) + detections[:, 4] = scores + return detections diff --git a/boxmot/trackers/boosttrack/ecc.py b/boxmot/trackers/boosttrack/ecc.py new file mode 100644 index 0000000000000000000000000000000000000000..9e8eafa45579e3a90bfd8ce188a602bf1c446b84 --- /dev/null +++ b/boxmot/trackers/boosttrack/ecc.py @@ -0,0 +1,163 @@ +# -*- coding: utf-8 -*- +# Author : HuangPiao +# Email : huangpiao2985@163.com +# Date : 3/11/2019 + +from __future__ import division + +from copy import deepcopy +from typing import Optional, Dict + +import numpy as np +import cv2 +import os +import json + + +def ecc(src, dst, warp_mode = cv2.MOTION_EUCLIDEAN, eps = 1e-5, + max_iter = 100, scale = 0.1, align = False): + """Compute the warp matrix from src to dst. + + Parameters + ---------- + src : ndarray + An NxM matrix of source img(BGR or Gray), it must be the same format as dst. 
+ dst : ndarray + An NxM matrix of target img(BGR or Gray). + warp_mode: flags of opencv + translation: cv2.MOTION_TRANSLATION + rotated and shifted: cv2.MOTION_EUCLIDEAN + affine(shift,rotated,shear): cv2.MOTION_AFFINE + homography(3d): cv2.MOTION_HOMOGRAPHY + eps: float + the threshold of the increment in the correlation coefficient between two iterations + max_iter: int + the number of iterations. + scale: float or [int, int] + scale_ratio: float + scale_size: [W, H] + align: bool + whether to warp affine or perspective transforms to the source image + + Returns + ------- + warp matrix : ndarray + Returns the warp matrix from src to dst. + if motion models is homography, the warp matrix will be 3x3, otherwise 2x3 + src_aligned: ndarray + aligned source image of gray + """ + assert src.shape == dst.shape, "the source image must be the same format to the target image!" + + # BGR2GRAY + if src.ndim == 3: + # Convert images to grayscale + src = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY) + dst = cv2.cvtColor(dst, cv2.COLOR_BGR2GRAY) + + # make the imgs smaller to speed up + if scale is not None: + if isinstance(scale, float): + if scale != 1: + src_r = cv2.resize(src, (0, 0), fx = scale, fy = scale,interpolation = cv2.INTER_LINEAR) + dst_r = cv2.resize(dst, (0, 0), fx = scale, fy = scale,interpolation = cv2.INTER_LINEAR) + scale = [scale, scale] + else: + src_r, dst_r = src, dst + scale = None + elif isinstance(scale, int): + scale = scale / src.shape[1] + src_r = cv2.resize(src, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) + dst_r = cv2.resize(dst, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) + scale = [scale, scale] + else: + if scale[0] != src.shape[1] and scale[1] != src.shape[0]: + src_r = cv2.resize(src, (scale[0], scale[1]), interpolation = cv2.INTER_LINEAR) + dst_r = cv2.resize(dst, (scale[0], scale[1]), interpolation=cv2.INTER_LINEAR) + scale = [scale[0] / src.shape[1], scale[1] / src.shape[0]] + else: + src_r, dst_r = src, dst + 
scale = None + else: + src_r, dst_r = src, dst + + # Define 2x3 or 3x3 matrices and initialize the matrix to identity + if warp_mode == cv2.MOTION_HOMOGRAPHY : + warp_matrix = np.eye(3, 3, dtype=np.float32) + else : + warp_matrix = np.eye(2, 3, dtype=np.float32) + + # Define termination criteria + criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, max_iter, eps) + + # Run the ECC algorithm. The results are stored in warp_matrix. + (cc, warp_matrix) = cv2.findTransformECC (src_r, dst_r, warp_matrix, warp_mode, criteria, None, 1) + + if scale is not None: + warp_matrix[0, 2] = warp_matrix[0, 2] / scale[0] + warp_matrix[1, 2] = warp_matrix[1, 2] / scale[1] + + if align: + sz = src.shape + if warp_mode == cv2.MOTION_HOMOGRAPHY: + # Use warpPerspective for Homography + src_aligned = cv2.warpPerspective(src, warp_matrix, (sz[1],sz[0]), flags=cv2.INTER_LINEAR) + else : + # Use warpAffine for Translation, Euclidean and Affine + src_aligned = cv2.warpAffine(src, warp_matrix, (sz[1],sz[0]), flags=cv2.INTER_LINEAR) + return warp_matrix, src_aligned + else: + return warp_matrix, None + + +class ECC: + + def __init__(self, warp_mode = cv2.MOTION_EUCLIDEAN, eps = 1e-4, + max_iter = 100, scale = 0.15, align = False, + video_name: Optional[str] = None, use_cache: bool = True): + self.wrap_mode = warp_mode + self.eps = eps + self.max_iter = max_iter + self.scale = scale + self.align = align + self.prev_image: Optional[np.ndarray] = None + self.video_name = video_name + self.use_cache = use_cache + self.cache: Dict[str, np.ndarray] = dict() + if self.use_cache and self.video_name is not None and len(self.video_name) > 0: + try: + self.cache = json.load(open(os.path.join("cache", self.video_name + ".json"), 'r')) + for k in self.cache: + self.cache[k] = np.array(self.cache[k]) + if len(self.cache) > 1: + print("USING CMC CACHE!") + except: + pass + + def __call__(self, np_image: np.ndarray, frame_id: int, video: Optional[str] = "") -> np.ndarray: + if frame_id == 1: + 
self.prev_image = deepcopy(np_image) + return np.eye(3, dtype=float) + key = "{}-{}".format(video, frame_id) + if key in self.cache: + return self.cache[key] + + result, _ = ecc(self.prev_image, np_image, self.wrap_mode, self.eps, self.max_iter, self.scale, self.align) + self.prev_image = deepcopy(np_image) + if result.shape == (2, 3): + result = np.vstack((result, np.array([[0, 0, 1]], dtype=float))) + + if self.use_cache: + self.cache[key] = deepcopy(result) + + return result + + def save_cache(self): + if not self.use_cache: + return + if self.video_name is not None and len(self.video_name) > 0: + f = open(os.path.join("cache", self.video_name + ".json"), "w") + for k in self.cache: + self.cache[k] = self.cache[k].tolist() + json.dump(self.cache, f) + f.close() diff --git a/boxmot/trackers/boosttrack/kalmanfilter.py b/boxmot/trackers/boosttrack/kalmanfilter.py new file mode 100644 index 0000000000000000000000000000000000000000..fedc72dc04fadb92d9e57268b795f4091fda0945 --- /dev/null +++ b/boxmot/trackers/boosttrack/kalmanfilter.py @@ -0,0 +1,186 @@ +# vim: expandtab:ts=4:sw=4 +import math +from abc import ABC, abstractmethod +from copy import deepcopy +from typing import Optional, Tuple, Union + +import numpy as np +import scipy.linalg + +""" +Table for the 0.95 quantile of the chi-square distribution with N degrees of +freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv +function and used as Mahalanobis gating threshold. +""" +chi2inv95 = { + 1: 3.8415, + 2: 5.9915, + 3: 7.8147, + 4: 9.4877, + 5: 11.070, + 6: 12.592, + 7: 14.067, + 8: 15.507, + 9: 16.919} + + +class CovariancePolicy(ABC): + + def __init__(self, x_dim: int, z_dim: int): + self.x_dim = x_dim + self.z_dim = z_dim + + @abstractmethod + def get_init_state_cov(self, z: np.ndarray) -> np.ndarray: + pass + + @abstractmethod + def get_R(self, x: np.ndarray, confidence: float = 0.0) -> np.ndarray: + ... + + @abstractmethod + def get_Q(self, x: np.ndarray) -> np.ndarray: + ... 
+ + +class ConstantNoise(CovariancePolicy): + + def get_init_state_cov(self, z: np.ndarray) -> np.ndarray: + + P = np.eye(self.x_dim) + P[4:, 4:] *= 1000.0 # give high uncertainty to the unobservable initial velocities + P *= 10.0 + + return P + + def get_R(self, x: np.ndarray, confidence: float = 0.0) -> np.ndarray: + return np.diag([1, 1, 10, 0.01]) + + def get_Q(self, x: np.ndarray) -> np.ndarray: + Q = np.eye(self.x_dim) + Q[4:, 4:] *= 0.01 + + return Q + + +class KalmanFilter(object): + """ + A simple Kalman filter for tracking bounding boxes in image space. + + The 8-dimensional state space + + x, y, h, a, vx, vy, vh, va + + contains the bounding box center position (x, y), aspect ratio a, height h, + and their respective velocities. + + Object motion follows a constant velocity model. The bounding box location + (x, y, h, a) is taken as direct observation of the state space (linear + observation model). + + """ + + def __init__(self, z: np.ndarray, ndim: int = 8, dt: int = 1, + cov_update_policy: CovariancePolicy = ConstantNoise, + id: int = -1): + if z.ndim == 2: + z = deepcopy(z.reshape((-1, ))) + + self.dt = dt + self.ndim = ndim + self.cov_update_policy: CovariancePolicy = cov_update_policy(ndim, z.size) + # Create Kalman filter model matrices. + self._motion_mat = np.eye(ndim, ndim) + for i in range(4 - (ndim % 2)): + self._motion_mat[i, i + 4] = dt + + self._update_mat = np.eye(4, ndim) + + self.x = np.zeros((ndim,)) + self.x[:4] = z[:] + + self.covariance = self.cov_update_policy.get_init_state_cov(z) + self.id = id + + def predict(self, mean: Optional[np.ndarray] = None, + covariance: Optional[np.ndarray] = None): + """Run Kalman filter prediction step. + + Parameters + ---------- + mean : ndarray + The 8 dimensional mean vector of the object state at the previous + time step. + covariance : ndarray + The 8x8 dimensional covariance matrix of the object state at the + previous time step. 
+ + Returns + ------- + (ndarray, ndarray) + Returns the mean vector and covariance matrix of the predicted + state. Unobserved velocities are initialized to 0 mean. + + """ + update = False + if mean is None: + mean = self.x + covariance = self.covariance + update = True + motion_cov = self.cov_update_policy.get_Q(mean) + + mean = np.dot(self._motion_mat, mean) + covariance = np.linalg.multi_dot(( + self._motion_mat, covariance, self._motion_mat.T)) + motion_cov + + if update: + self.x = mean + self.covariance = covariance + + return mean, covariance + + def project(self, confidence=.0): + """Project state distribution to measurement space. + + Returns + ------- + (ndarray, ndarray) + Returns the projected mean and covariance matrix of the given state + estimate. + + """ + + innovation_cov = self.cov_update_policy.get_R(self.x, 0) + + mean = np.dot(self._update_mat, self.x) + covariance = np.linalg.multi_dot(( + self._update_mat, self.covariance, self._update_mat.T)) + return mean, covariance + innovation_cov + + def update(self, z: np.ndarray, confidence=.0): + """Run Kalman filter correction step. + + Returns + ------- + (ndarray, ndarray) + Returns the measurement-corrected state distribution. 
+ + """ + + if z.ndim == 2: + z = deepcopy(z.reshape((-1, ))) + projected_mean, projected_cov = self.project(confidence) + + chol_factor, lower = scipy.linalg.cho_factor( + projected_cov, lower=True, check_finite=False) + kalman_gain = scipy.linalg.cho_solve( + (chol_factor, lower), np.dot(self.covariance, self._update_mat.T).T, + check_finite=False).T + + innovation = z - projected_mean + + self.x = self.x + np.dot(innovation, kalman_gain.T) + self.covariance = self.covariance - np.linalg.multi_dot(( + kalman_gain, projected_cov, kalman_gain.T)) + + return self.x, self.covariance diff --git a/boxmot/trackers/botsort/__init__.py b/boxmot/trackers/botsort/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a4d0f1f9a635d5ccf0932bbf5c2069529a7ac --- /dev/null +++ b/boxmot/trackers/botsort/__init__.py @@ -0,0 +1 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license diff --git a/boxmot/trackers/botsort/basetrack.py b/boxmot/trackers/botsort/basetrack.py new file mode 100644 index 0000000000000000000000000000000000000000..664ee63cef75973d2f56ea0fd064c59d456efb8f --- /dev/null +++ b/boxmot/trackers/botsort/basetrack.py @@ -0,0 +1,135 @@ +from collections import OrderedDict +import numpy as np + +class TrackState: + """ + Enum-like class for tracking states. + + Attributes: + New (int): Represents a newly created track. + Tracked (int): Represents a currently tracked object. + Lost (int): Represents a temporarily lost track. + LongLost (int): Represents a track that has been lost for a long time. + Removed (int): Represents a track that has been removed. + """ + New = 0 + Tracked = 1 + Lost = 2 + LongLost = 3 + Removed = 4 + + +class BaseTrack: + """ + Base class for managing the state of a track in multi-object tracking. + + Attributes: + _count (int): Class variable to keep track of the number of tracks created. + track_id (int): The unique ID assigned to the track. + is_activated (bool): Whether the track has been activated. 
+ state (TrackState): The current state of the track. + history (OrderedDict): A history of the track's past states or observations. + features (list): A list of feature vectors associated with the track. + curr_feature (np.ndarray): The most recent feature vector. + score (float): The confidence score of the track. + start_frame (int): The frame where the track started. + frame_id (int): The most recent frame ID associated with the track. + time_since_update (int): The number of frames since the track was last updated. + location (tuple): The location of the object in multi-camera tracking (set to infinity by default). + """ + _count = 0 + + track_id: int = 0 + is_activated: bool = False + state: int = TrackState.New + + history: OrderedDict = OrderedDict() + features: list = [] + curr_feature: np.ndarray = None + score: float = 0 + start_frame: int = 0 + frame_id: int = 0 + time_since_update: int = 0 + + # multi-camera + location: tuple = (np.inf, np.inf) + + @property + def end_frame(self) -> int: + """ + Returns the last frame the track was updated. + + Returns: + int: The frame ID of the last update. + """ + return self.frame_id + + @staticmethod + def next_id() -> int: + """ + Generates the next unique track ID. + + Returns: + int: A unique track ID. + """ + BaseTrack._count += 1 + return BaseTrack._count + + def activate(self, *args): + """ + Activates the track. This method should be implemented in subclasses. + + Args: + *args: Variable length argument list. + + Raises: + NotImplementedError: If this method is not implemented in the subclass. + """ + raise NotImplementedError + + def predict(self): + """ + Predicts the next state of the track using a motion model. This method should be implemented in subclasses. + + Raises: + NotImplementedError: If this method is not implemented in the subclass. + """ + raise NotImplementedError + + def update(self, *args, **kwargs): + """ + Updates the state of the track based on a new observation. 
This method should be implemented in subclasses. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Raises: + NotImplementedError: If this method is not implemented in the subclass. + """ + raise NotImplementedError + + def mark_lost(self): + """ + Marks the track as lost. + """ + self.state = TrackState.Lost + + def mark_long_lost(self): + """ + Marks the track as long lost. + """ + self.state = TrackState.LongLost + + def mark_removed(self): + """ + Marks the track as removed. + """ + self.state = TrackState.Removed + + @staticmethod + def clear_count(): + """ + Resets the track ID counter to 0. + """ + BaseTrack._count = 0 diff --git a/boxmot/trackers/botsort/botsort.py b/boxmot/trackers/botsort/botsort.py new file mode 100644 index 0000000000000000000000000000000000000000..9b81f6c6b6c5d97fc252e11d2b6d31780921e34f --- /dev/null +++ b/boxmot/trackers/botsort/botsort.py @@ -0,0 +1,327 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import torch +import numpy as np +from pathlib import Path + +from boxmot.motion.kalman_filters.aabb.xywh_kf import KalmanFilterXYWH +from boxmot.appearance.reid.auto_backend import ReidAutoBackend +from boxmot.motion.cmc.sof import SOF +from boxmot.trackers.botsort.basetrack import BaseTrack, TrackState +from boxmot.utils.matching import (embedding_distance, fuse_score, + iou_distance, linear_assignment) +from boxmot.trackers.basetracker import BaseTracker +from boxmot.trackers.botsort.botsort_utils import joint_stracks, sub_stracks, remove_duplicate_stracks +from boxmot.trackers.botsort.botsort_track import STrack +from boxmot.motion.cmc import get_cmc_method + + + +class BotSort(BaseTracker): + """ + BoTSORT Tracker: A tracking algorithm that combines appearance and motion-based tracking. + + Args: + reid_weights (str): Path to the model weights for ReID. + device (torch.device): Device to run the model on (e.g., 'cpu' or 'cuda'). 
+ half (bool): Use half-precision (fp16) for faster inference. + per_class (bool, optional): Whether to perform per-class tracking. + track_high_thresh (float, optional): Detection confidence threshold for first association. + track_low_thresh (float, optional): Detection confidence threshold for ignoring detections. + new_track_thresh (float, optional): Threshold for creating a new track. + track_buffer (int, optional): Frames to keep a track alive after last detection. + match_thresh (float, optional): Matching threshold for data association. + proximity_thresh (float, optional): IoU threshold for first-round association. + appearance_thresh (float, optional): Appearance embedding distance threshold for ReID. + cmc_method (str, optional): Method for correcting camera motion, e.g., "sof" (simple optical flow). + frame_rate (int, optional): Video frame rate, used to scale the track buffer. + fuse_first_associate (bool, optional): Fuse appearance and motion in the first association step. + with_reid (bool, optional): Use ReID features for association. 
+ """ + + def __init__( + self, + reid_weights: Path, + device: torch.device, + half: bool, + per_class: bool = False, + track_high_thresh: float = 0.5, + track_low_thresh: float = 0.1, + new_track_thresh: float = 0.6, + track_buffer: int = 30, + match_thresh: float = 0.8, + proximity_thresh: float = 0.5, + appearance_thresh: float = 0.25, + cmc_method: str = "ecc", + frame_rate=30, + fuse_first_associate: bool = False, + with_reid: bool = True, + ): + super().__init__(per_class=per_class) + self.lost_stracks = [] # type: list[STrack] + self.removed_stracks = [] # type: list[STrack] + BaseTrack.clear_count() + + self.per_class = per_class + self.track_high_thresh = track_high_thresh + self.track_low_thresh = track_low_thresh + self.new_track_thresh = new_track_thresh + self.match_thresh = match_thresh + + self.buffer_size = int(frame_rate / 30.0 * track_buffer) + self.max_time_lost = self.buffer_size + self.kalman_filter = KalmanFilterXYWH() + + # ReID module + self.proximity_thresh = proximity_thresh + self.appearance_thresh = appearance_thresh + self.with_reid = with_reid + if self.with_reid: + self.model = ReidAutoBackend( + weights=reid_weights, device=device, half=half + ).model + + self.cmc = get_cmc_method(cmc_method)() + self.fuse_first_associate = fuse_first_associate + + @BaseTracker.setup_decorator + @BaseTracker.per_class_decorator + def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray: + self.check_inputs(dets, img) + self.frame_count += 1 + + activated_stracks, refind_stracks, lost_stracks, removed_stracks = [], [], [], [] + + # Preprocess detections + dets, dets_first, embs_first, dets_second = self._split_detections(dets, embs) + + # Extract appearance features + if self.with_reid and embs is None: + features_high = self.model.get_features(dets_first[:, 0:4], img) + else: + features_high = embs_first if embs_first is not None else [] + + # Create detections + detections = self._create_detections(dets_first, 
features_high) + + # Separate unconfirmed and active tracks + unconfirmed, active_tracks = self._separate_tracks() + + strack_pool = joint_stracks(active_tracks, self.lost_stracks) + + # First association + matches_first, u_track_first, u_detection_first = self._first_association(dets, dets_first, active_tracks, unconfirmed, img, detections, activated_stracks, refind_stracks, strack_pool) + + # Second association + matches_second, u_track_second, u_detection_second = self._second_association(dets_second, activated_stracks, lost_stracks, refind_stracks, u_track_first, strack_pool) + + # Handle unconfirmed tracks + matches_unc, u_track_unc, u_detection_unc = self._handle_unconfirmed_tracks(u_detection_first, detections, activated_stracks, removed_stracks, unconfirmed) + + # Initialize new tracks + self._initialize_new_tracks(u_detection_unc, activated_stracks, [detections[i] for i in u_detection_first]) + + # Update lost and removed tracks + self._update_track_states(lost_stracks, removed_stracks) + + # Merge and prepare output + return self._prepare_output(activated_stracks, refind_stracks, lost_stracks, removed_stracks) + + def _split_detections(self, dets, embs): + dets = np.hstack([dets, np.arange(len(dets)).reshape(-1, 1)]) + confs = dets[:, 4] + second_mask = np.logical_and(confs > self.track_low_thresh, confs < self.track_high_thresh) + dets_second = dets[second_mask] + first_mask = confs > self.track_high_thresh + dets_first = dets[first_mask] + embs_first = embs[first_mask] if embs is not None else None + return dets, dets_first, embs_first, dets_second + + def _create_detections(self, dets_first, features_high): + if len(dets_first) > 0: + if self.with_reid: + detections = [STrack(det, f, max_obs=self.max_obs) for (det, f) in zip(dets_first, features_high)] + else: + detections = [STrack(det, max_obs=self.max_obs) for det in dets_first] + else: + detections = [] + return detections + + def _separate_tracks(self): + unconfirmed, active_tracks = [], [] + for 
track in self.active_tracks: + if not track.is_activated: + unconfirmed.append(track) + else: + active_tracks.append(track) + return unconfirmed, active_tracks + + def _first_association(self, dets, dets_first, active_tracks, unconfirmed, img, detections, activated_stracks, refind_stracks, strack_pool): + + STrack.multi_predict(strack_pool) + + # Fix camera motion + warp = self.cmc.apply(img, dets) + STrack.multi_gmc(strack_pool, warp) + STrack.multi_gmc(unconfirmed, warp) + + # Associate with high confidence detection boxes + ious_dists = iou_distance(strack_pool, detections) + ious_dists_mask = ious_dists > self.proximity_thresh + if self.fuse_first_associate: + ious_dists = fuse_score(ious_dists, detections) + + if self.with_reid: + emb_dists = embedding_distance(strack_pool, detections) / 2.0 + emb_dists[emb_dists > self.appearance_thresh] = 1.0 + emb_dists[ious_dists_mask] = 1.0 + dists = np.minimum(ious_dists, emb_dists) + else: + dists = ious_dists + + matches, u_track, u_detection = linear_assignment(dists, thresh=self.match_thresh) + + for itracked, idet in matches: + track = strack_pool[itracked] + det = detections[idet] + if track.state == TrackState.Tracked: + track.update(detections[idet], self.frame_count) + activated_stracks.append(track) + else: + track.re_activate(det, self.frame_count, new_id=False) + refind_stracks.append(track) + + return matches, u_track, u_detection + + def _second_association(self, dets_second, activated_stracks, lost_stracks, refind_stracks, u_track_first, strack_pool): + if len(dets_second) > 0: + detections_second = [STrack(det, max_obs=self.max_obs) for det in dets_second] + else: + detections_second = [] + + r_tracked_stracks = [ + strack_pool[i] + for i in u_track_first + if strack_pool[i].state == TrackState.Tracked + ] + + dists = iou_distance(r_tracked_stracks, detections_second) + matches, u_track, u_detection = linear_assignment(dists, thresh=0.5) + + for itracked, idet in matches: + track = 
r_tracked_stracks[itracked]
            det = detections_second[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_count)
                activated_stracks.append(track)
            else:
                # Track was lost before this frame: revive it with the match.
                track.re_activate(det, self.frame_count, new_id=False)
                refind_stracks.append(track)

        # Tracks left unmatched after the low-confidence round become lost.
        for it in u_track:
            track = r_tracked_stracks[it]
            if not track.state == TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(track)

        return matches, u_track, u_detection


    def _handle_unconfirmed_tracks(self, u_detection, detections, activated_stracks, removed_stracks, unconfirmed):
        """
        Handle unconfirmed tracks (tracks with only one detection frame).

        Args:
            u_detection: Unconfirmed detection indices.
            detections: Current list of detections.
            activated_stracks: List of newly activated tracks.
            removed_stracks: List of tracks to remove.
            unconfirmed: Unconfirmed tracks to match against the detections.
        """
        # Only use detections that are unconfirmed (filtered by u_detection)
        detections = [detections[i] for i in u_detection]

        # Calculate IoU distance between unconfirmed tracks and detections
        ious_dists = iou_distance(unconfirmed, detections)

        # Apply IoU mask to filter out distances that exceed proximity threshold
        ious_dists_mask = ious_dists > self.proximity_thresh
        ious_dists = fuse_score(ious_dists, detections)

        # Fuse scores for IoU-based and embedding-based matching (if applicable)
        if self.with_reid:
            emb_dists = embedding_distance(unconfirmed, detections) / 2.0
            emb_dists[emb_dists > self.appearance_thresh] = 1.0
            emb_dists[ious_dists_mask] = 1.0  # Apply the IoU mask to embedding distances
            dists = np.minimum(ious_dists, emb_dists)
        else:
            dists = ious_dists

        # Perform data association using linear assignment on the combined distances
        matches, u_unconfirmed, u_detection = linear_assignment(dists, thresh=0.7)

        # Update matched unconfirmed tracks
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_count)
            activated_stracks.append(unconfirmed[itracked])

        # Mark unmatched unconfirmed tracks as removed
        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)

        return matches, u_unconfirmed, u_detection

    def _initialize_new_tracks(self, u_detections, activated_stracks, detections):
        """Start new tracks from detections left unmatched by association."""
        for inew in u_detections:
            track = detections[inew]
            # Only confident-enough detections may spawn a new track.
            if track.conf < self.new_track_thresh:
                continue

            track.activate(self.kalman_filter, self.frame_count)
            activated_stracks.append(track)

    def _update_tracks(self, matches, strack_pool, detections, activated_stracks, refind_stracks, mark_removed=False):
        """Update matched tracks in place; optionally remove unmatched ones."""
        # Update or reactivate matched tracks
        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_count)
                activated_stracks.append(track)
            else:
                track.re_activate(det, self.frame_count, new_id=False)
                refind_stracks.append(track)

        # Mark only unmatched tracks as removed, if mark_removed flag is True
        if mark_removed:
            unmatched_tracks = [strack_pool[i] for i in range(len(strack_pool)) if i not in [m[0] for m in matches]]
            for track in unmatched_tracks:
                track.mark_removed()

    def _update_track_states(self, lost_stracks, removed_stracks):
        """Retire lost tracks older than max_time_lost.

        NOTE(review): the `lost_stracks` parameter is unused — the loop reads
        `self.lost_stracks` instead. Confirm whether that is intentional.
        """
        for track in self.lost_stracks:
            if self.frame_count - track.end_frame > self.max_time_lost:
                track.mark_removed()
                removed_stracks.append(track)

    def _prepare_output(self, activated_stracks, refind_stracks, lost_stracks, removed_stracks):
        """Fold this frame's per-category track lists into tracker state."""
        # Keep only tracks still actively tracked, then merge in the tracks
        # updated or re-activated this frame (id-deduplicated joins).
        self.active_tracks = [
            t for t in self.active_tracks if t.state == TrackState.Tracked
        ]
        self.active_tracks = joint_stracks(self.active_tracks, activated_stracks)
        self.active_tracks = joint_stracks(self.active_tracks, refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.active_tracks)
        self.lost_stracks.extend(lost_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
        self.removed_stracks.extend(removed_stracks)
        self.active_tracks, self.lost_stracks = remove_duplicate_stracks(
            self.active_tracks, self.lost_stracks
        )

        # One output row per emitted track:
        # [x1, y1, x2, y2, id, conf, cls, det_ind]
        outputs = [
            [*t.xyxy, t.id, t.conf, t.cls, t.det_ind]
            for t in self.active_tracks if t.is_activated
        ]

        return np.asarray(outputs)
diff --git a/boxmot/trackers/botsort/botsort_track.py b/boxmot/trackers/botsort/botsort_track.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f97b9f9bdc89e17dec41c771aa04e4dc4ce245a
--- /dev/null
+++ b/boxmot/trackers/botsort/botsort_track.py
@@ -0,0 +1,150 @@
import numpy as np
from collections import deque

from boxmot.trackers.botsort.basetrack import BaseTrack, TrackState
from boxmot.motion.kalman_filters.aabb.xywh_kf import KalmanFilterXYWH
from boxmot.utils.ops import xywh2xyxy, xyxy2xywh


class STrack(BaseTrack):
    # Single Kalman filter shared by all tracks, used by multi_predict().
    shared_kalman = KalmanFilterXYWH()

    def __init__(self, det, feat=None, feat_history=50, max_obs=50):
        """
        One BoT-SORT tracklet.

        Args:
            det: Detection row [x1, y1, x2, y2, conf, cls, det_ind].
            feat: Optional appearance embedding of this detection.
            feat_history: Max number of embeddings kept per track.
            max_obs: Max number of past box observations kept per track.
        """
        # Initialize detection parameters
        self.xywh = xyxy2xywh(det[:4])  # Convert to (xc, yc, w, h)
        self.conf = det[4]
        self.cls = det[5]
        self.det_ind = det[6]
        self.max_obs = max_obs

        # Kalman filter and tracking state
        self.kalman_filter = None
        self.mean, self.covariance = None, None
        self.is_activated = False
        self.tracklet_len = 0

        # Classification history and feature history
        self.cls_hist = []
        self.history_observations = deque(maxlen=self.max_obs)
        self.features = deque(maxlen=feat_history)
        self.smooth_feat = None
        self.curr_feat = None
        self.alpha = 0.9  # EMA weight for embedding smoothing

        # Update initial class and features
        self.update_cls(self.cls, self.conf)
        if feat is not None:
            self.update_features(feat)

    def update_features(self, feat):
        """Normalize and update feature vectors.

        NOTE(review): `feat /= ...` normalizes the caller's array in place —
        confirm callers do not reuse the raw embedding afterwards.
        """
        feat /= np.linalg.norm(feat)
        self.curr_feat = feat
        if self.smooth_feat is None:
            self.smooth_feat = feat
        else:
            # Exponential moving average, re-normalized to unit length.
            self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
        self.smooth_feat /= np.linalg.norm(self.smooth_feat)
        self.features.append(feat)

    def update_cls(self, cls, conf):
        """Update class history based on detection confidence."""
        # cls_hist accumulates confidence per class; the track's class is the
        # one with the highest accumulated confidence so far.
        max_freq = 0
        found = False
        for c in self.cls_hist:
            if cls == c[0]:
                c[1] += conf
                found = True
                if c[1] > max_freq:
                    max_freq = c[1]
                    self.cls = c[0]
        if not found:
            self.cls_hist.append([cls, conf])
            self.cls = cls

    def predict(self):
        """Predict the next state using Kalman filter."""
        mean_state = self.mean.copy()
        if self.state != TrackState.Tracked:
            mean_state[6:8] = 0  # Reset velocities
        self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

    @staticmethod
    def multi_predict(stracks):
        """Perform batch prediction for multiple tracks."""
        if not stracks:
            return
        multi_mean = np.asarray([st.mean.copy() for st in stracks])
        multi_covariance = np.asarray([st.covariance for st in stracks])
        for i, st in enumerate(stracks):
            if st.state != TrackState.Tracked:
                multi_mean[i][6:8] = 0  # Reset velocities
        multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
        for st, mean, cov in zip(stracks, multi_mean, multi_covariance):
            st.mean, st.covariance = mean, cov

    @staticmethod
    def multi_gmc(stracks, H=np.eye(2, 3)):
        """Apply geometric motion compensation to multiple tracks."""
        if not stracks:
            return
        R = H[:2, :2]
        # Block-diagonal expansion so the 2x2 rotation acts on each of the
        # four (x, y)-style pairs of the 8-dim state.
        R8x8 = np.kron(np.eye(4), R)
        t = H[:2, 2]

        for st in stracks:
            mean = R8x8.dot(st.mean)
            mean[:2] += t
            st.mean = mean
            st.covariance = R8x8.dot(st.covariance).dot(R8x8.T)

    def activate(self, kalman_filter, frame_id):
        """Activate a new track."""
        self.kalman_filter = kalman_filter
        self.id = self.next_id()
        self.mean, self.covariance = self.kalman_filter.initiate(self.xywh)
        self.tracklet_len = 0
        self.state = TrackState.Tracked
        # Only tracks born on the very first frame are trusted immediately.
        if frame_id == 1:
            self.is_activated = True
        self.frame_id = frame_id
        self.start_frame = frame_id

    def re_activate(self, new_track, frame_id, new_id=False):
"""Re-activate a track with a new detection.""" + self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, new_track.xywh) + if new_track.curr_feat is not None: + self.update_features(new_track.curr_feat) + self.tracklet_len = 0 + self.state = TrackState.Tracked + self.is_activated = True + self.frame_id = frame_id + if new_id: + self.id = self.next_id() + self.conf = new_track.conf + self.cls = new_track.cls + self.det_ind = new_track.det_ind + self.update_cls(new_track.cls, new_track.conf) + + def update(self, new_track, frame_id): + """Update the current track with a matched detection.""" + self.frame_id = frame_id + self.tracklet_len += 1 + self.history_observations.append(self.xyxy) + + self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, new_track.xywh) + if new_track.curr_feat is not None: + self.update_features(new_track.curr_feat) + + self.state = TrackState.Tracked + self.is_activated = True + self.conf = new_track.conf + self.cls = new_track.cls + self.det_ind = new_track.det_ind + self.update_cls(new_track.cls, new_track.conf) + + @property + def xyxy(self): + """Convert bounding box format to `(min x, min y, max x, max y)`.""" + ret = self.mean[:4].copy() if self.mean is not None else self.xywh.copy() + return xywh2xyxy(ret) diff --git a/boxmot/trackers/botsort/botsort_utils.py b/boxmot/trackers/botsort/botsort_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8526c49bd5f561769c63c6f3590694c38d8bfe32 --- /dev/null +++ b/boxmot/trackers/botsort/botsort_utils.py @@ -0,0 +1,77 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import numpy as np +from typing import List, Tuple +from boxmot.utils.matching import iou_distance + + +def joint_stracks(tlista: List['STrack'], tlistb: List['STrack']) -> List['STrack']: + """ + Joins two lists of tracks, ensuring that there are no duplicates based on track IDs. + + Args: + tlista (List[STrack]): The first list of tracks. 
+ tlistb (List[STrack]): The second list of tracks. + + Returns: + List[STrack]: A combined list of tracks from both input lists, without duplicates. + """ + exists = {} + res = [] + for t in tlista: + exists[t.id] = 1 + res.append(t) + for t in tlistb: + tid = t.id + if not exists.get(tid, 0): + exists[tid] = 1 + res.append(t) + return res + + +def sub_stracks(tlista: List['STrack'], tlistb: List['STrack']) -> List['STrack']: + """ + Subtracts the tracks in tlistb from tlista based on track IDs. + + Args: + tlista (List[STrack]): The list of tracks from which tracks will be removed. + tlistb (List[STrack]): The list of tracks to be removed from tlista. + + Returns: + List[STTrack]: The remaining tracks after removal. + """ + stracks = {t.id: t for t in tlista} + for t in tlistb: + tid = t.id + if tid in stracks: + del stracks[tid] + return list(stracks.values()) + + +def remove_duplicate_stracks(stracksa: List['STrack'], stracksb: List['STrack']) -> Tuple[List['STrack'], List['STrack']]: + """ + Removes duplicate tracks between two lists based on their IoU distance and track duration. + + Args: + stracksa (List[STrack]): The first list of tracks. + stracksb (List[STrack]): The second list of tracks. + + Returns: + Tuple[List[STrack], List[STrack]]: The filtered track lists, with duplicates removed. 
+ """ + pdist = iou_distance(stracksa, stracksb) + pairs = np.where(pdist < 0.15) + dupa, dupb = [], [] + + for p, q in zip(*pairs): + timep = stracksa[p].frame_id - stracksa[p].start_frame + timeq = stracksb[q].frame_id - stracksb[q].start_frame + if timep > timeq: + dupb.append(q) + else: + dupa.append(p) + + resa = [t for i, t in enumerate(stracksa) if i not in dupa] + resb = [t for i, t in enumerate(stracksb) if i not in dupb] + + return resa, resb diff --git a/boxmot/trackers/bytetrack/__init__.py b/boxmot/trackers/bytetrack/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a4d0f1f9a635d5ccf0932bbf5c2069529a7ac --- /dev/null +++ b/boxmot/trackers/bytetrack/__init__.py @@ -0,0 +1 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license diff --git a/boxmot/trackers/bytetrack/basetrack.py b/boxmot/trackers/bytetrack/basetrack.py new file mode 100644 index 0000000000000000000000000000000000000000..da1ec8f010b3d31d5b179d2e1636e4ba50f7690c --- /dev/null +++ b/boxmot/trackers/bytetrack/basetrack.py @@ -0,0 +1,59 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +from collections import OrderedDict + +import numpy as np + + +class TrackState(object): + New = 0 + Tracked = 1 + Lost = 2 + Removed = 3 + + +class BaseTrack(object): + _count = 0 + + track_id = 0 + is_activated = False + state = TrackState.New + + history = OrderedDict() + features = [] + curr_feature = None + conf = 0 + start_frame = 0 + frame_id = 0 + time_since_update = 0 + + # multi-camera + location = (np.inf, np.inf) + + @property + def end_frame(self): + return self.frame_id + + @staticmethod + def next_id(): + BaseTrack._count += 1 + return BaseTrack._count + + def activate(self, *args): + raise NotImplementedError + + def predict(self): + raise NotImplementedError + + def update(self, *args, **kwargs): + raise NotImplementedError + + def mark_lost(self): + self.state = TrackState.Lost + + def mark_removed(self): + self.state = TrackState.Removed + + 
@staticmethod + def clear_count(): + BaseTrack._count = 0 diff --git a/boxmot/trackers/bytetrack/bytetrack.py b/boxmot/trackers/bytetrack/bytetrack.py new file mode 100644 index 0000000000000000000000000000000000000000..40f4b110b793bb733f74f53add3b4c3a9d9457d6 --- /dev/null +++ b/boxmot/trackers/bytetrack/bytetrack.py @@ -0,0 +1,342 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import numpy as np +from collections import deque + +from boxmot.motion.kalman_filters.aabb.xyah_kf import KalmanFilterXYAH +from boxmot.trackers.bytetrack.basetrack import BaseTrack, TrackState +from boxmot.utils.matching import fuse_score, iou_distance, linear_assignment +from boxmot.utils.ops import tlwh2xyah, xywh2tlwh, xywh2xyxy, xyxy2xywh +from boxmot.trackers.basetracker import BaseTracker + + +class STrack(BaseTrack): + shared_kalman = KalmanFilterXYAH() + + def __init__(self, det, max_obs): + # wait activate + self.xywh = xyxy2xywh(det[0:4]) # (x1, y1, x2, y2) --> (xc, yc, w, h) + self.tlwh = xywh2tlwh(self.xywh) # (xc, yc, w, h) --> (t, l, w, h) + self.xyah = tlwh2xyah(self.tlwh) + self.conf = det[4] + self.cls = det[5] + self.det_ind = det[6] + self.max_obs=max_obs + self.kalman_filter = None + self.mean, self.covariance = None, None + self.is_activated = False + self.tracklet_len = 0 + self.history_observations = deque([], maxlen=self.max_obs) + + def predict(self): + mean_state = self.mean.copy() + if self.state != TrackState.Tracked: + mean_state[7] = 0 + self.mean, self.covariance = self.kalman_filter.predict( + mean_state, self.covariance + ) + + @staticmethod + def multi_predict(stracks): + if len(stracks) > 0: + multi_mean = np.asarray([st.mean.copy() for st in stracks]) + multi_covariance = np.asarray([st.covariance for st in stracks]) + for i, st in enumerate(stracks): + if st.state != TrackState.Tracked: + multi_mean[i][7] = 0 + multi_mean, multi_covariance = STrack.shared_kalman.multi_predict( + multi_mean, multi_covariance + ) + for i, (mean, cov) in 
enumerate(zip(multi_mean, multi_covariance)):
                stracks[i].mean = mean
                stracks[i].covariance = cov

    def activate(self, kalman_filter, frame_id):
        """Start a new tracklet"""
        self.kalman_filter = kalman_filter
        self.id = self.next_id()
        self.mean, self.covariance = self.kalman_filter.initiate(self.xyah)

        self.tracklet_len = 0
        self.state = TrackState.Tracked
        # Only tracks born on the very first frame are trusted immediately;
        # later ones must survive a re-match before being emitted.
        if frame_id == 1:
            self.is_activated = True
        # self.is_activated = True
        self.frame_id = frame_id
        self.start_frame = frame_id

    def re_activate(self, new_track, frame_id, new_id=False):
        """Revive a lost track with a freshly matched detection."""
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, new_track.xyah
        )
        self.tracklet_len = 0
        self.state = TrackState.Tracked
        self.is_activated = True
        self.frame_id = frame_id
        if new_id:
            self.id = self.next_id()
        self.conf = new_track.conf
        self.cls = new_track.cls
        self.det_ind = new_track.det_ind

    def update(self, new_track, frame_id):
        """
        Update a matched track
        :type new_track: STrack
        :type frame_id: int
        :type update_feature: bool
        :return:
        """
        self.frame_id = frame_id
        self.tracklet_len += 1
        # Remember the previous box before folding in the new measurement.
        self.history_observations.append(self.xyxy)

        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, new_track.xyah
        )
        self.state = TrackState.Tracked
        self.is_activated = True

        self.conf = new_track.conf
        self.cls = new_track.cls
        self.det_ind = new_track.det_ind

    @property
    def xyxy(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        if self.mean is None:
            ret = self.xywh.copy()  # (xc, yc, w, h)
        else:
            ret = self.mean[:4].copy()  # kf (xc, yc, a, h)
            ret[2] *= ret[3]  # (xc, yc, a, h) --> (xc, yc, w, h)
        ret = xywh2xyxy(ret)
        return ret


class ByteTrack(BaseTracker):
    """
    BYTETracker: A tracking algorithm based on ByteTrack, which utilizes motion-based tracking.

    Args:
        min_conf (float, optional): Threshold for detection confidence. Detections below this threshold are discarded.
        track_thresh (float, optional): Threshold for detection confidence. Detections above this threshold are considered for tracking in the first association round.
        match_thresh (float, optional): Threshold for the matching step in data association. Controls the maximum distance allowed between tracklets and detections for a match.
        track_buffer (int, optional): Number of frames to keep a track alive after it was last detected. A longer buffer allows for more robust tracking but may increase identity switches.
        frame_rate (int, optional): Frame rate of the video being processed. Used to scale the track buffer size.
        per_class (bool, optional): Whether to perform per-class tracking. If True, tracks are maintained separately for each object class.
    """
    def __init__(
        self,
        min_conf: float = 0.1,
        track_thresh: float = 0.45,
        match_thresh: float = 0.8,
        track_buffer: int = 25,
        frame_rate: int = 30,
        per_class: bool = False,
    ):
        super().__init__(per_class=per_class)
        self.active_tracks = []  # type: list[STrack]
        self.lost_stracks = []  # type: list[STrack]
        self.removed_stracks = []  # type: list[STrack]

        self.frame_id = 0
        self.track_buffer = track_buffer

        self.per_class = per_class
        self.min_conf = min_conf
        self.track_thresh = track_thresh
        self.match_thresh = match_thresh
        self.det_thresh = track_thresh
        # Scale the lost-track buffer to the actual frame rate (30 fps reference).
        self.buffer_size = int(frame_rate / 30.0 * track_buffer)
        self.max_time_lost = self.buffer_size
        self.kalman_filter = KalmanFilterXYAH()

    @BaseTracker.setup_decorator
    @BaseTracker.per_class_decorator
    def update(self, dets: np.ndarray, img: np.ndarray = None, embs: np.ndarray = None) -> np.ndarray:
        """Run one ByteTrack step on this frame's detections and return the
        active tracks as rows [x1, y1, x2, y2, id, conf, cls, det_ind]."""

        self.check_inputs(dets, img)

        # Append the original detection index as a trailing column.
        dets = np.hstack([dets, np.arange(len(dets)).reshape(-1, 1)])
        self.frame_count += 1
        # NOTE(review): 'starcks' typo kept — it is a local name used throughout.
        activated_starcks = []
        refind_stracks = []
        lost_stracks = []
        removed_stracks = []
        confs = dets[:, 4]

        # Split detections into a high-confidence set (first association) and
        # a low-confidence set (second association), as in ByteTrack.
        remain_inds = confs > self.track_thresh

        inds_low = confs > self.min_conf
        inds_high = confs < self.track_thresh
        inds_second = np.logical_and(inds_low, inds_high)

        dets_second = dets[inds_second]
        dets = dets[remain_inds]

        if len(dets) > 0:
            """Detections"""
            detections = [
                STrack(det, max_obs=self.max_obs) for det in dets
            ]
        else:
            detections = []

        """ Add newly detected tracklets to tracked_stracks"""
        unconfirmed = []
        tracked_stracks = []  # type: list[STrack]
        for track in self.active_tracks:
            if not track.is_activated:
                unconfirmed.append(track)
            else:
                tracked_stracks.append(track)

        """ Step 2: First association, with high conf detection boxes"""
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
        # Predict the current location with KF
        STrack.multi_predict(strack_pool)
        dists = iou_distance(strack_pool, detections)
        # if not self.args.mot20:
        dists = fuse_score(dists, detections)
        matches, u_track, u_detection = linear_assignment(
            dists, thresh=self.match_thresh
        )

        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(detections[idet], self.frame_count)
                activated_starcks.append(track)
            else:
                # Lost track matched again: revive it without changing its id.
                track.re_activate(det, self.frame_count, new_id=False)
                refind_stracks.append(track)

        """ Step 3: Second association, with low conf detection boxes"""
        # association the untrack to the low conf detections
        if len(dets_second) > 0:
            """Detections"""
            detections_second = [STrack(det_second, max_obs=self.max_obs) for det_second in dets_second]
        else:
            detections_second = []
        r_tracked_stracks = [
            strack_pool[i]
            for i in u_track
            if strack_pool[i].state == TrackState.Tracked
        ]
        dists = iou_distance(r_tracked_stracks, detections_second)
        matches, u_track, u_detection_second = linear_assignment(dists, thresh=0.5)
        for itracked, idet in matches:
            track = r_tracked_stracks[itracked]
            det = detections_second[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_count)
                activated_starcks.append(track)
            else:
                track.re_activate(det, self.frame_count, new_id=False)
                refind_stracks.append(track)

        # Tracks still unmatched after both rounds become lost.
        for it in u_track:
            track = r_tracked_stracks[it]
            if not track.state == TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(track)

        """Deal with unconfirmed tracks, usually tracks with only one beginning frame"""
        detections = [detections[i] for i in u_detection]
        dists = iou_distance(unconfirmed, detections)
        # if not self.args.mot20:
        dists = fuse_score(dists, detections)
        matches, u_unconfirmed, u_detection = linear_assignment(dists, thresh=0.7)
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_count)
            activated_starcks.append(unconfirmed[itracked])
        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)

        """ Step 4: Init new stracks"""
        for inew in u_detection:
            track = detections[inew]
            # Only confident-enough detections may spawn a new track.
            if track.conf < self.det_thresh:
                continue
            track.activate(self.kalman_filter, self.frame_count)
            activated_starcks.append(track)
        """ Step 5: Update state"""
        for track in self.lost_stracks:
            if self.frame_count - track.end_frame > self.max_time_lost:
                track.mark_removed()
                removed_stracks.append(track)

        # Merge the per-frame lists back into tracker state (id-deduplicated).
        self.active_tracks = [
            t for t in self.active_tracks if t.state == TrackState.Tracked
        ]
        self.active_tracks = joint_stracks(self.active_tracks, activated_starcks)
        self.active_tracks = joint_stracks(self.active_tracks, refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.active_tracks)
        self.lost_stracks.extend(lost_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
        self.removed_stracks.extend(removed_stracks)
        self.active_tracks, self.lost_stracks = remove_duplicate_stracks(
            self.active_tracks, self.lost_stracks
        )
+ # get confs of lost tracks + output_stracks = [track for track in self.active_tracks if track.is_activated] + outputs = [] + for t in output_stracks: + output = [] + output.extend(t.xyxy) + output.append(t.id) + output.append(t.conf) + output.append(t.cls) + output.append(t.det_ind) + outputs.append(output) + outputs = np.asarray(outputs) + return outputs + + +# id, class_id, conf + + +def joint_stracks(tlista, tlistb): + exists = {} + res = [] + for t in tlista: + exists[t.id] = 1 + res.append(t) + for t in tlistb: + tid = t.id + if not exists.get(tid, 0): + exists[tid] = 1 + res.append(t) + return res + + +def sub_stracks(tlista, tlistb): + stracks = {} + for t in tlista: + stracks[t.id] = t + for t in tlistb: + tid = t.id + if stracks.get(tid, 0): + del stracks[tid] + return list(stracks.values()) + + +def remove_duplicate_stracks(stracksa, stracksb): + pdist = iou_distance(stracksa, stracksb) + pairs = np.where(pdist < 0.15) + dupa, dupb = list(), list() + for p, q in zip(*pairs): + timep = stracksa[p].frame_id - stracksa[p].start_frame + timeq = stracksb[q].frame_id - stracksb[q].start_frame + if timep > timeq: + dupb.append(q) + else: + dupa.append(p) + resa = [t for i, t in enumerate(stracksa) if i not in dupa] + resb = [t for i, t in enumerate(stracksb) if i not in dupb] + return resa, resb diff --git a/boxmot/trackers/deepocsort/__init__.py b/boxmot/trackers/deepocsort/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a4d0f1f9a635d5ccf0932bbf5c2069529a7ac --- /dev/null +++ b/boxmot/trackers/deepocsort/__init__.py @@ -0,0 +1 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license diff --git a/boxmot/trackers/deepocsort/deepocsort.py b/boxmot/trackers/deepocsort/deepocsort.py new file mode 100644 index 0000000000000000000000000000000000000000..55e63c17a731b7cdda617dddab1bf04ee58a9aac --- /dev/null +++ b/boxmot/trackers/deepocsort/deepocsort.py @@ -0,0 +1,467 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import 
numpy as np
import torch
from pathlib import Path
from collections import deque

from boxmot.appearance.reid.auto_backend import ReidAutoBackend
from boxmot.motion.cmc import get_cmc_method
from boxmot.motion.kalman_filters.aabb.xysr_kf import KalmanFilterXYSR
from boxmot.utils.association import associate, linear_assignment
from boxmot.trackers.basetracker import BaseTracker
from boxmot.utils.ops import xyxy2xysr


def k_previous_obs(observations, cur_age, k):
    """Return the stored observation closest to `k` frames before `cur_age`,
    falling back to the latest one; [-1]*5 when there are no observations."""
    if len(observations) == 0:
        return [-1, -1, -1, -1, -1]
    for i in range(k):
        dt = k - i
        if cur_age - dt in observations:
            return observations[cur_age - dt]
    max_age = max(observations.keys())
    return observations[max_age]


def convert_x_to_bbox(x, score=None):
    """
    Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
    [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
    """
    # s is the area and r the aspect ratio, so w = sqrt(s*r), h = s/w.
    w = np.sqrt(x[2] * x[3])
    h = x[2] / w
    if score is None:
        return np.array([x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0]).reshape((1, 4))
    else:
        return np.array([x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0, score]).reshape((1, 5))


def speed_direction(bbox1, bbox2):
    """Unit (dy, dx) direction vector between the centres of two boxes."""
    cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
    cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
    speed = np.array([cy2 - cy1, cx2 - cx1])
    # Epsilon keeps the division safe when both centres coincide.
    norm = np.sqrt((cy2 - cy1) ** 2 + (cx2 - cx1) ** 2) + 1e-6
    return speed / norm


class KalmanBoxTracker(object):
    """
    This class represents the internal state of individual tracked objects observed as bbox.
    """

    count = 0

    def __init__(self, det, delta_t=3, emb=None, alpha=0, max_obs=50, Q_xy_scaling = 0.01, Q_s_scaling = 0.0001):
        """
        Initialises a tracker using initial bounding box.

        """
        # define constant velocity model
        self.max_obs=max_obs
        bbox = det[0:5]
        self.conf = det[4]
        self.cls = det[5]
        self.det_ind = det[6]

        self.Q_xy_scaling = Q_xy_scaling
        self.Q_s_scaling = Q_s_scaling

        # State is [x, y, s, r, x', y', s']: centre, scale (area), aspect
        # ratio, plus velocities for x, y and s (r is modeled as constant).
        self.kf = KalmanFilterXYSR(dim_x=7, dim_z=4)
        self.kf.F = np.array(
            [
                # x y s r x' y' s'
                [1, 0, 0, 0, 1, 0, 0],
                [0, 1, 0, 0, 0, 1, 0],
                [0, 0, 1, 0, 0, 0, 1],
                [0, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0],
                [0, 0, 0, 0, 0, 1, 0],
                [0, 0, 0, 0, 0, 0, 1],
            ]
        )
        self.kf.H = np.array(
            [
                [1, 0, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0, 0, 0],
                [0, 0, 1, 0, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 0],
            ]
        )
        self.kf.R[2:, 2:] *= 10.0
        self.kf.P[4:, 4:] *= 1000.0  # give high uncertainty to the unobservable initial velocities
        self.kf.P *= 10.0
        self.kf.Q[4:6, 4:6] *= self.Q_xy_scaling
        self.kf.Q[-1, -1] *= self.Q_s_scaling

        self.bbox_to_z_func = xyxy2xysr
        self.x_to_bbox_func = convert_x_to_bbox

        self.kf.x[:4] = self.bbox_to_z_func(bbox)

        self.time_since_update = 0
        self.id = KalmanBoxTracker.count
        KalmanBoxTracker.count += 1
        self.history = deque([], maxlen=self.max_obs)
        self.hits = 0
        self.hit_streak = 0
        self.age = 0
        """
        NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same for the return of
        function k_previous_obs. It is ugly and I do not like it. But to support generate observation array in a
        fast and unified way, which you would see below k_observations = np.array([k_previous_obs(...]]),
        let's bear it for now.
+ """ + # Used for OCR + self.last_observation = np.array([-1, -1, -1, -1, -1]) # placeholder + # Used to output track after min_hits reached + self.features = deque([], maxlen=self.max_obs) + # Used for velocity + self.observations = dict() + self.velocity = None + self.delta_t = delta_t + self.history_observations = deque([], maxlen=self.max_obs) + + self.emb = emb + + self.frozen = False + + def update(self, det): + """ + Updates the state vector with observed bbox. + """ + + if det is not None: + bbox = det[0:5] + self.conf = det[4] + self.cls = det[5] + self.det_ind = det[6] + self.frozen = False + + if self.last_observation.sum() >= 0: # no previous observation + previous_box = None + for dt in range(self.delta_t, 0, -1): + if self.age - dt in self.observations: + previous_box = self.observations[self.age - dt] + break + if previous_box is None: + previous_box = self.last_observation + """ + Estimate the track speed direction with observations \Delta t steps away + """ + self.velocity = speed_direction(previous_box, bbox) + """ + Insert new observations. This is a ugly way to maintain both self.observations + and self.history_observations. Bear it for the moment. 
+ """ + self.last_observation = bbox + self.observations[self.age] = bbox + self.history_observations.append(bbox) + + self.time_since_update = 0 + self.hits += 1 + self.hit_streak += 1 + + self.kf.update(self.bbox_to_z_func(bbox)) + else: + self.kf.update(det) + self.frozen = True + + def update_emb(self, emb, alpha=0.9): + self.emb = alpha * self.emb + (1 - alpha) * emb + self.emb /= np.linalg.norm(self.emb) + + def get_emb(self): + return self.emb + + def apply_affine_correction(self, affine): + m = affine[:, :2] + t = affine[:, 2].reshape(2, 1) + # For OCR + if self.last_observation.sum() > 0: + ps = self.last_observation[:4].reshape(2, 2).T + ps = m @ ps + t + self.last_observation[:4] = ps.T.reshape(-1) + + # Apply to each box in the range of velocity computation + for dt in range(self.delta_t, -1, -1): + if self.age - dt in self.observations: + ps = self.observations[self.age - dt][:4].reshape(2, 2).T + ps = m @ ps + t + self.observations[self.age - dt][:4] = ps.T.reshape(-1) + + # Also need to change kf state, but might be frozen + self.kf.apply_affine_correction(m, t) + + def predict(self): + """ + Advances the state vector and returns the predicted bounding box estimate. + """ + # Don't allow negative bounding boxes + if (self.kf.x[6] + self.kf.x[2]) <= 0: + self.kf.x[6] *= 0.0 + Q = None + + self.kf.predict(Q=Q) + self.age += 1 + if self.time_since_update > 0: + self.hit_streak = 0 + self.time_since_update += 1 + self.history.append(self.x_to_bbox_func(self.kf.x)) + return self.history[-1] + + def get_state(self): + """ + Returns the current bounding box estimate. + """ + return self.x_to_bbox_func(self.kf.x) + + def mahalanobis(self, bbox): + """Should be run after a predict() call for accuracy.""" + return self.kf.md_for_measurement(self.bbox_to_z_func(bbox)) + + +class DeepOcSort(BaseTracker): + """ + DeepOCSort Tracker: A tracking algorithm that utilizes a combination of appearance and motion-based tracking. 

    Args:
        reid_weights (Path): Path to the model weights for ReID (Re-Identification).
        device (torch.device): Device on which to run the model (e.g., 'cpu' or 'cuda').
        half (bool): Whether to use half-precision (fp16) for faster inference on compatible devices.
        per_class (bool, optional): Whether to perform per-class tracking. If True, tracks are maintained separately for each object class.
        det_thresh (float, optional): Detection confidence threshold. Detections below this threshold will be ignored.
        max_age (int, optional): Maximum number of frames to keep a track alive without any detections.
        min_hits (int, optional): Minimum number of hits required to confirm a track.
        iou_threshold (float, optional): Intersection over Union (IoU) threshold for data association.
        delta_t (int, optional): Time delta for velocity estimation in Kalman Filter.
        asso_func (str, optional): Association function to use for data association. Options include "iou" for IoU-based association.
        inertia (float, optional): Weight for inertia in motion modeling. Higher values make tracks less responsive to changes.
        w_association_emb (float, optional): Weight for the embedding-based association score.
        alpha_fixed_emb (float, optional): Fixed alpha for updating embeddings. Controls the contribution of new and old embeddings in the ReID model.
        aw_param (float, optional): Parameter for adaptive weighting between association costs.
        embedding_off (bool, optional): Whether to turn off the embedding-based association.
        cmc_off (bool, optional): Whether to turn off camera motion compensation (CMC).
        aw_off (bool, optional): Whether to turn off adaptive weighting.
        Q_xy_scaling (float, optional): Scaling factor for the process noise covariance in the Kalman Filter for position coordinates.
        Q_s_scaling (float, optional): Scaling factor for the process noise covariance in the Kalman Filter for scale coordinates.
        **kwargs: Additional arguments for future extensions or parameters.
    """
    def __init__(
        self,
        reid_weights: Path,
        device: torch.device,
        half: bool,
        per_class: bool = False,
        det_thresh: float = 0.3,
        max_age: int = 30,
        min_hits: int = 3,
        iou_threshold: float = 0.3,
        delta_t: int = 3,
        asso_func: str = "iou",
        inertia: float = 0.2,
        w_association_emb: float = 0.5,
        alpha_fixed_emb: float = 0.95,
        aw_param: float = 0.5,
        embedding_off: bool = False,
        cmc_off: bool = False,
        aw_off: bool = False,
        Q_xy_scaling: float = 0.01,
        Q_s_scaling: float = 0.0001,
        **kwargs: dict
    ):
        super().__init__(max_age=max_age, per_class=per_class, asso_func=asso_func)
        """
        Sets key parameters for SORT
        """
        self.max_age = max_age
        self.min_hits = min_hits
        self.iou_threshold = iou_threshold
        self.det_thresh = det_thresh
        self.delta_t = delta_t
        self.asso_func = asso_func
        self.inertia = inertia
        self.w_association_emb = w_association_emb
        self.alpha_fixed_emb = alpha_fixed_emb
        self.aw_param = aw_param
        self.per_class = per_class
        self.Q_xy_scaling = Q_xy_scaling
        self.Q_s_scaling = Q_s_scaling
        # Reset the global id counter so a fresh tracker starts ids at 1.
        KalmanBoxTracker.count = 1

        self.model = ReidAutoBackend(
            weights=reid_weights, device=device, half=half
        ).model
        # "similarity transforms using feature point extraction, optical flow, and RANSAC"
        self.cmc = get_cmc_method('sof')()
        self.embedding_off = embedding_off
        self.cmc_off = cmc_off
        self.aw_off = aw_off

    @BaseTracker.setup_decorator
    @BaseTracker.per_class_decorator
    def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray:
        """
        Params:
          dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
        Requires: this method must be called once for each frame even with empty detections
        (use np.empty((0, 5)) for frames without detections).
        Returns the a similar array, where the last column is the object ID.
        NOTE: The number of objects returned may differ from the number of detections provided.
+ """ + #dets, s, c = dets.data + #print(dets, s, c) + self.check_inputs(dets, img) + + self.frame_count += 1 + self.height, self.width = img.shape[:2] + + scores = dets[:, 4] + dets = np.hstack([dets, np.arange(len(dets)).reshape(-1, 1)]) + assert dets.shape[1] == 7 + remain_inds = scores > self.det_thresh + dets = dets[remain_inds] + + # appearance descriptor extraction + if self.embedding_off or dets.shape[0] == 0: + dets_embs = np.ones((dets.shape[0], 1)) + elif embs is not None: + dets_embs = embs[remain_inds] + else: + # (Ndets x X) [512, 1024, 2048] + dets_embs = self.model.get_features(dets[:, 0:4], img) + + # CMC + if not self.cmc_off: + transform = self.cmc.apply(img, dets[:, :4]) + for trk in self.active_tracks: + trk.apply_affine_correction(transform) + + trust = (dets[:, 4] - self.det_thresh) / (1 - self.det_thresh) + af = self.alpha_fixed_emb + # From [self.alpha_fixed_emb, 1], goes to 1 as detector is less confident + dets_alpha = af + (1 - af) * (1 - trust) + + # get predicted locations from existing trackers. 
+ trks = np.zeros((len(self.active_tracks), 5)) + trk_embs = [] + to_del = [] + ret = [] + for t, trk in enumerate(trks): + pos = self.active_tracks[t].predict()[0] + trk[:] = [pos[0], pos[1], pos[2], pos[3], 0] + if np.any(np.isnan(pos)): + to_del.append(t) + else: + trk_embs.append(self.active_tracks[t].get_emb()) + trks = np.ma.compress_rows(np.ma.masked_invalid(trks)) + + if len(trk_embs) > 0: + trk_embs = np.vstack(trk_embs) + else: + trk_embs = np.array(trk_embs) + + for t in reversed(to_del): + self.active_tracks.pop(t) + + velocities = np.array([trk.velocity if trk.velocity is not None else np.array((0, 0)) for trk in self.active_tracks]) + last_boxes = np.array([trk.last_observation for trk in self.active_tracks]) + k_observations = np.array([k_previous_obs(trk.observations, trk.age, self.delta_t) for trk in self.active_tracks]) + + """ + First round of association + """ + # (M detections X N tracks, final score) + if self.embedding_off or dets.shape[0] == 0 or trk_embs.shape[0] == 0: + stage1_emb_cost = None + else: + stage1_emb_cost = dets_embs @ trk_embs.T + matched, unmatched_dets, unmatched_trks = associate( + dets[:, 0:5], + trks, + self.asso_func, + self.iou_threshold, + velocities, + k_observations, + self.inertia, + img.shape[1], # w + img.shape[0], # h + stage1_emb_cost, + self.w_association_emb, + self.aw_off, + self.aw_param, + ) + for m in matched: + self.active_tracks[m[1]].update(dets[m[0], :]) + self.active_tracks[m[1]].update_emb(dets_embs[m[0]], alpha=dets_alpha[m[0]]) + + """ + Second round of associaton by OCR + """ + if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0: + left_dets = dets[unmatched_dets] + left_dets_embs = dets_embs[unmatched_dets] + left_trks = last_boxes[unmatched_trks] + left_trks_embs = trk_embs[unmatched_trks] + + iou_left = self.asso_func(left_dets, left_trks) + # TODO: is better without this + emb_cost_left = left_dets_embs @ left_trks_embs.T + if self.embedding_off: + emb_cost_left = 
np.zeros_like(emb_cost_left) + iou_left = np.array(iou_left) + if iou_left.max() > self.iou_threshold: + """ + NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may + get a higher performance especially on MOT17/MOT20 datasets. But we keep it + uniform here for simplicity + """ + rematched_indices = linear_assignment(-iou_left) + to_remove_det_indices = [] + to_remove_trk_indices = [] + for m in rematched_indices: + det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]] + if iou_left[m[0], m[1]] < self.iou_threshold: + continue + self.active_tracks[trk_ind].update(dets[det_ind, :]) + self.active_tracks[trk_ind].update_emb(dets_embs[det_ind], alpha=dets_alpha[det_ind]) + to_remove_det_indices.append(det_ind) + to_remove_trk_indices.append(trk_ind) + unmatched_dets = np.setdiff1d(unmatched_dets, np.array(to_remove_det_indices)) + unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices)) + + for m in unmatched_trks: + self.active_tracks[m].update(None) + + # create and initialise new trackers for unmatched detections + for i in unmatched_dets: + trk = KalmanBoxTracker( + dets[i], + delta_t=self.delta_t, + emb=dets_embs[i], + alpha=dets_alpha[i], + Q_xy_scaling=self.Q_xy_scaling, + Q_s_scaling=self.Q_s_scaling, + max_obs=self.max_obs + ) + self.active_tracks.append(trk) + i = len(self.active_tracks) + for trk in reversed(self.active_tracks): + if trk.last_observation.sum() < 0: + d = trk.get_state()[0] + else: + """ + this is optional to use the recent observation or the kalman filter prediction, + we didn't notice significant difference here + """ + d = trk.last_observation[:4] + if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits): + # +1 as MOT benchmark requires positive + ret.append(np.concatenate((d, [trk.id], [trk.conf], [trk.cls], [trk.det_ind])).reshape(1, -1)) + i -= 1 + # remove dead tracklet + if trk.time_since_update > self.max_age: + 
self.active_tracks.pop(i) + if len(ret) > 0: + return np.concatenate(ret) + return np.array([]) diff --git a/boxmot/trackers/hybridsort/__init__.py b/boxmot/trackers/hybridsort/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a4d0f1f9a635d5ccf0932bbf5c2069529a7ac --- /dev/null +++ b/boxmot/trackers/hybridsort/__init__.py @@ -0,0 +1 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license diff --git a/boxmot/trackers/hybridsort/association.py b/boxmot/trackers/hybridsort/association.py new file mode 100644 index 0000000000000000000000000000000000000000..80b1612678889313a5e9a526574207e8d7d06ddf --- /dev/null +++ b/boxmot/trackers/hybridsort/association.py @@ -0,0 +1,684 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import numpy as np + + +def intersection_batch(bboxes1, bboxes2): + bboxes2 = np.expand_dims(bboxes2, 0) + bboxes1 = np.expand_dims(bboxes1, 1) + + xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0]) + yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1]) + xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2]) + yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3]) + w = np.maximum(0., xx2 - xx1) + h = np.maximum(0., yy2 - yy1) + intersections = w * h + return intersections + + +def box_area(bbox): + area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + return area + + +def iou_batch(bboxes1, bboxes2): + """ + From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2] + """ + bboxes2 = np.expand_dims(bboxes2, 0) + bboxes1 = np.expand_dims(bboxes1, 1) + + xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0]) + yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1]) + xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2]) + yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3]) + w = np.maximum(0., xx2 - xx1) + h = np.maximum(0., yy2 - yy1) + wh = w * h + o = wh / ((bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1]) + + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) - wh) + return 
def cal_score_dif_batch(bboxes1, bboxes2):
    """Pairwise absolute score difference (column 4 of both box sets).

    Returns an (N1, N2) matrix for N1 rows in bboxes1 and N2 in bboxes2.
    """
    cols = np.expand_dims(bboxes2, 0)
    rows = np.expand_dims(bboxes1, 1)
    return abs(cols[..., 4] - rows[..., 4])


def cal_score_dif_batch_two_score(bboxes1, bboxes2):
    """Pairwise |score| difference when the second set stores its score in
    column 5 (the first set still uses column 4)."""
    cols = np.expand_dims(bboxes2, 0)
    rows = np.expand_dims(bboxes1, 1)
    return abs(cols[..., 5] - rows[..., 4])


def hmiou(bboxes1, bboxes2):
    """Height-Modulated IoU: plain IoU scaled by the ratio of intersected
    vertical extent to merged vertical extent. Pairwise (N1, N2) matrix."""
    cols = np.expand_dims(bboxes2, 0)
    rows = np.expand_dims(bboxes1, 1)

    # Vertical-overlap ratio of each pair.
    top_i = np.maximum(rows[..., 1], cols[..., 1])
    bot_i = np.minimum(rows[..., 3], cols[..., 3])
    top_u = np.minimum(rows[..., 1], cols[..., 1])
    bot_u = np.maximum(rows[..., 3], cols[..., 3])
    height_ratio = (bot_i - top_i) / (bot_u - top_u)

    xx1 = np.maximum(rows[..., 0], cols[..., 0])
    yy1 = np.maximum(rows[..., 1], cols[..., 1])
    xx2 = np.minimum(rows[..., 2], cols[..., 2])
    yy2 = np.minimum(rows[..., 3], cols[..., 3])
    inter = np.maximum(0., xx2 - xx1) * np.maximum(0., yy2 - yy1)
    union = ((rows[..., 2] - rows[..., 0]) * (rows[..., 3] - rows[..., 1])
             + (cols[..., 2] - cols[..., 0]) * (cols[..., 3] - cols[..., 1]) - inter)
    return height_ratio * (inter / union)


def giou_batch(bboxes1, bboxes2):
    """Generalized IoU rescaled from (-1, 1) to (0, 1); pairwise matrix.

    See https://arxiv.org/pdf/1902.09630.pdf. Both inputs are batches of
    [x1, y1, x2, y2] boxes.
    """
    cols = np.expand_dims(bboxes2, 0)
    rows = np.expand_dims(bboxes1, 1)

    xx1 = np.maximum(rows[..., 0], cols[..., 0])
    yy1 = np.maximum(rows[..., 1], cols[..., 1])
    xx2 = np.minimum(rows[..., 2], cols[..., 2])
    yy2 = np.minimum(rows[..., 3], cols[..., 3])
    wh = np.maximum(0., xx2 - xx1) * np.maximum(0., yy2 - yy1)
    iou = wh / ((rows[..., 2] - rows[..., 0]) * (rows[..., 3] - rows[..., 1])
                + (cols[..., 2] - cols[..., 0]) * (cols[..., 3] - cols[..., 1]) - wh)

    # Smallest enclosing box of each pair.
    xxc1 = np.minimum(rows[..., 0], cols[..., 0])
    yyc1 = np.minimum(rows[..., 1], cols[..., 1])
    xxc2 = np.maximum(rows[..., 2], cols[..., 2])
    yyc2 = np.maximum(rows[..., 3], cols[..., 3])
    wc = xxc2 - xxc1
    hc = yyc2 - yyc1
    assert (wc > 0).all() and (hc > 0).all()
    area_enclose = wc * hc
    # NOTE: the penalty here uses the intersection (wh) rather than the
    # union — behaviour kept as-is; giou_batch_true below uses the union.
    giou = iou - (area_enclose - wh) / area_enclose
    return (giou + 1.) / 2.0  # resize from (-1,1) to (0,1)


def giou_batch_true(bboxes1, bboxes2):
    """GIoU with the textbook union-based enclosing penalty, rescaled to
    (0, 1); pairwise matrix. See https://arxiv.org/pdf/1902.09630.pdf."""
    cols = np.expand_dims(bboxes2, 0)
    rows = np.expand_dims(bboxes1, 1)

    xx1 = np.maximum(rows[..., 0], cols[..., 0])
    yy1 = np.maximum(rows[..., 1], cols[..., 1])
    xx2 = np.minimum(rows[..., 2], cols[..., 2])
    yy2 = np.minimum(rows[..., 3], cols[..., 3])
    wh = np.maximum(0., xx2 - xx1) * np.maximum(0., yy2 - yy1)
    union = ((rows[..., 2] - rows[..., 0]) * (rows[..., 3] - rows[..., 1])
             + (cols[..., 2] - cols[..., 0]) * (cols[..., 3] - cols[..., 1]) - wh)
    iou = wh / union

    xxc1 = np.minimum(rows[..., 0], cols[..., 0])
    yyc1 = np.minimum(rows[..., 1], cols[..., 1])
    xxc2 = np.maximum(rows[..., 2], cols[..., 2])
    yyc2 = np.maximum(rows[..., 3], cols[..., 3])
    wc = xxc2 - xxc1
    hc = yyc2 - yyc1
    assert (wc > 0).all() and (hc > 0).all()
    area_enclose = wc * hc
    giou = iou - (area_enclose - union) / area_enclose
    return (giou + 1.) / 2.0  # resize from (-1,1) to (0,1)


def diou_batch(bboxes1, bboxes2):
    """Distance-IoU rescaled to (0, 1): IoU minus the squared centre distance
    normalised by the enclosing box diagonal. Pairwise matrix."""
    cols = np.expand_dims(bboxes2, 0)
    rows = np.expand_dims(bboxes1, 1)

    xx1 = np.maximum(rows[..., 0], cols[..., 0])
    yy1 = np.maximum(rows[..., 1], cols[..., 1])
    xx2 = np.minimum(rows[..., 2], cols[..., 2])
    yy2 = np.minimum(rows[..., 3], cols[..., 3])
    wh = np.maximum(0., xx2 - xx1) * np.maximum(0., yy2 - yy1)
    iou = wh / ((rows[..., 2] - rows[..., 0]) * (rows[..., 3] - rows[..., 1])
                + (cols[..., 2] - cols[..., 0]) * (cols[..., 3] - cols[..., 1]) - wh)

    # Squared distance between box centres.
    centerx1 = (rows[..., 0] + rows[..., 2]) / 2.0
    centery1 = (rows[..., 1] + rows[..., 3]) / 2.0
    centerx2 = (cols[..., 0] + cols[..., 2]) / 2.0
    centery2 = (cols[..., 1] + cols[..., 3]) / 2.0
    inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2

    # Squared diagonal of the smallest enclosing box.
    xxc1 = np.minimum(rows[..., 0], cols[..., 0])
    yyc1 = np.minimum(rows[..., 1], cols[..., 1])
    xxc2 = np.maximum(rows[..., 2], cols[..., 2])
    yyc2 = np.maximum(rows[..., 3], cols[..., 3])
    outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2

    diou = iou - inner_diag / outer_diag
    return (diou + 1) / 2.0  # resize from (-1,1) to (0,1)


def ciou_batch(bboxes1, bboxes2):
    """Complete-IoU rescaled to (0, 1): DIoU plus an aspect-ratio penalty.

    NOTE: for identical boxes the alpha term divides 0/0 and yields NaN —
    behaviour preserved from the reference implementation.
    """
    cols = np.expand_dims(bboxes2, 0)
    rows = np.expand_dims(bboxes1, 1)

    xx1 = np.maximum(rows[..., 0], cols[..., 0])
    yy1 = np.maximum(rows[..., 1], cols[..., 1])
    xx2 = np.minimum(rows[..., 2], cols[..., 2])
    yy2 = np.minimum(rows[..., 3], cols[..., 3])
    wh = np.maximum(0., xx2 - xx1) * np.maximum(0., yy2 - yy1)
    iou = wh / ((rows[..., 2] - rows[..., 0]) * (rows[..., 3] - rows[..., 1])
                + (cols[..., 2] - cols[..., 0]) * (cols[..., 3] - cols[..., 1]) - wh)

    centerx1 = (rows[..., 0] + rows[..., 2]) / 2.0
    centery1 = (rows[..., 1] + rows[..., 3]) / 2.0
    centerx2 = (cols[..., 0] + cols[..., 2]) / 2.0
    centery2 = (cols[..., 1] + cols[..., 3]) / 2.0
    inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2

    xxc1 = np.minimum(rows[..., 0], cols[..., 0])
    yyc1 = np.minimum(rows[..., 1], cols[..., 1])
    xxc2 = np.maximum(rows[..., 2], cols[..., 2])
    yyc2 = np.maximum(rows[..., 3], cols[..., 3])
    outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2

    w1 = rows[..., 2] - rows[..., 0]
    h1 = rows[..., 3] - rows[..., 1]
    w2 = cols[..., 2] - cols[..., 0]
    h2 = cols[..., 3] - cols[..., 1]

    # prevent dividing over zero. add one pixel shift
    h2 = h2 + 1.
    h1 = h1 + 1.
    arctan = np.arctan(w2 / h2) - np.arctan(w1 / h1)
    v = (4 / (np.pi ** 2)) * (arctan ** 2)
    S = 1 - iou
    alpha = v / (S + v)
    ciou = iou - inner_diag / outer_diag - alpha * v
    return (ciou + 1) / 2.0  # resize from (-1,1) to (0,1)


def ct_dist(bboxes1, bboxes2):
    """Coarse centre-distance affinity in (0, 1), larger meaning closer.

    The linear rescaling by the batch maximum is naive — unstable and
    sensitive to frame rate / object speed, so not recommended as the sole
    association signal.
    """
    cols = np.expand_dims(bboxes2, 0)
    rows = np.expand_dims(bboxes1, 1)

    centerx1 = (rows[..., 0] + rows[..., 2]) / 2.0
    centery1 = (rows[..., 1] + rows[..., 3]) / 2.0
    centerx2 = (cols[..., 0] + cols[..., 2]) / 2.0
    centery2 = (cols[..., 1] + cols[..., 3]) / 2.0

    dist = np.sqrt((centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2)
    dist = dist / dist.max()      # naive linear rescaling
    return dist.max() - dist      # invert so closer pairs score higher
def speed_direction_batch(dets, tracks):
    """Batched unit direction from each track centre to each detection centre.

    dets: (num_det, 4+) boxes; tracks: (num_track, 5+) boxes.
    Returns (dy, dx), each of shape (num_track, num_det).
    """
    tracks = tracks[..., np.newaxis]
    cx_d = (dets[:, 0] + dets[:, 2]) / 2.0
    cy_d = (dets[:, 1] + dets[:, 3]) / 2.0
    cx_t = (tracks[:, 0] + tracks[:, 2]) / 2.0
    cy_t = (tracks[:, 1] + tracks[:, 3]) / 2.0
    dx = cx_d - cx_t
    dy = cy_d - cy_t
    norm = np.sqrt(dx ** 2 + dy ** 2) + 1e-6   # epsilon avoids 0/0
    return dy / norm, dx / norm


def linear_assignment(cost_matrix, thresh=0.):
    """Solve the assignment problem on cost_matrix.

    Uses `lap.lapjv` when available (honouring `thresh` as a cost limit
    when non-zero); otherwise falls back to SciPy's Hungarian solver,
    which ignores `thresh`. Returns an array of matched (row, col) pairs.
    """
    try:
        import lap
        if thresh != 0:
            _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
        else:
            _, x, y = lap.lapjv(cost_matrix, extend_cost=True)
        return np.array([[y[i], i] for i in x if i >= 0])
    except ImportError:
        from scipy.optimize import linear_sum_assignment
        rows, cols = linear_sum_assignment(cost_matrix)
        return np.array(list(zip(rows, cols)))


def cost_vel(Y, X, trackers, velocities, detections, previous_obs, vdc_weight):
    """Velocity-direction consistency cost.

    Y, X: (num_track, num_det) unit direction components toward each
    detection. velocities holds each track's inertial (dy, dx). Tracks with
    no valid previous observation (previous_obs[:, 4] < 0) contribute zero.
    The result is weighted by each detection's score (last column) and
    vdc_weight, and transposed to (num_det, num_track).
    """
    inertia_Y = np.repeat(velocities[:, 0][:, np.newaxis], Y.shape[1], axis=1)
    inertia_X = np.repeat(velocities[:, 1][:, np.newaxis], X.shape[1], axis=1)
    cos_sim = np.clip(inertia_X * X + inertia_Y * Y, a_min=-1, a_max=1)
    # Aligned directions (small angle) score higher.
    diff_angle = (np.pi / 2.0 - np.abs(np.arccos(cos_sim))) / np.pi

    valid_mask = np.ones(previous_obs.shape[0])
    valid_mask[np.where(previous_obs[:, 4] < 0)] = 0
    valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)

    scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
    angle_diff_cost = ((valid_mask * diff_angle) * vdc_weight).T
    return angle_diff_cost * scores


def speed_direction_batch_lt(dets, tracks):
    """Batched unit direction between (x1, y1) corners; returns (dy, dx) of
    shape (num_track, num_det)."""
    tracks = tracks[..., np.newaxis]
    dx = dets[:, 0] - tracks[:, 0]
    dy = dets[:, 1] - tracks[:, 1]
    norm = np.sqrt(dx ** 2 + dy ** 2) + 1e-6
    return dy / norm, dx / norm


def speed_direction_batch_rt(dets, tracks):
    """Batched unit direction between (x1, y2) corners; returns (dy, dx) of
    shape (num_track, num_det)."""
    tracks = tracks[..., np.newaxis]
    dx = dets[:, 0] - tracks[:, 0]
    dy = dets[:, 3] - tracks[:, 3]
    norm = np.sqrt(dx ** 2 + dy ** 2) + 1e-6
    return dy / norm, dx / norm
def speed_direction_batch_lb(dets, tracks):
    """Batched unit direction between (x2, y1) corners; returns (dy, dx) of
    shape (num_track, num_det)."""
    tracks = tracks[..., np.newaxis]
    dx = dets[:, 2] - tracks[:, 2]
    dy = dets[:, 1] - tracks[:, 1]
    norm = np.sqrt(dx ** 2 + dy ** 2) + 1e-6
    return dy / norm, dx / norm


def speed_direction_batch_rb(dets, tracks):
    """Batched unit direction between (x2, y2) corners; returns (dy, dx) of
    shape (num_track, num_det)."""
    tracks = tracks[..., np.newaxis]
    dx = dets[:, 2] - tracks[:, 2]
    dy = dets[:, 3] - tracks[:, 3]
    norm = np.sqrt(dx ** 2 + dy ** 2) + 1e-6
    return dy / norm, dx / norm


def associate_4_points(
    detections, trackers, iou_threshold, lt, rt, lb, rb, previous_obs, vdc_weight, iou_type=None, args=None
):
    """Associate detections to tracks using four corner-velocity costs
    (lt/rt/lb/rb inertias) added to a pairwise IoU-style matrix.

    Returns (matches, unmatched_detection_indices, unmatched_tracker_indices).
    """
    if len(trackers) == 0:
        # No tracks yet: every detection is unmatched.
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)

    corner_costs = []
    for direction_fn, inertia in (
        (speed_direction_batch_lt, lt),
        (speed_direction_batch_rt, rt),
        (speed_direction_batch_lb, lb),
        (speed_direction_batch_rb, rb),
    ):
        dy, dx = direction_fn(detections, previous_obs)
        corner_costs.append(cost_vel(dy, dx, trackers, inertia, detections, previous_obs, vdc_weight))
    # Centre direction is computed but unused; the call is pure, kept only
    # for parity with the reference implementation.
    speed_direction_batch(detections, previous_obs)

    iou_matrix = iou_type(detections, trackers)
    angle_diff_cost = corner_costs[0] + corner_costs[1] + corner_costs[2] + corner_costs[3]

    if min(iou_matrix.shape) > 0:
        above = (iou_matrix > iou_threshold).astype(np.int32)
        if above.sum(1).max() == 1 and above.sum(0).max() == 1:
            # Unique one-to-one overlaps: trivial matching, no solver needed.
            matched_indices = np.stack(np.where(above), axis=1)
        else:
            matched_indices = linear_assignment(-(iou_matrix + angle_diff_cost))
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = [d for d in range(len(detections)) if d not in matched_indices[:, 0]]
    unmatched_trackers = [t for t in range(len(trackers)) if t not in matched_indices[:, 1]]

    # Drop matches whose IoU is below the threshold.
    matches = []
    for m in matched_indices:
        if iou_matrix[m[0], m[1]] < iou_threshold:
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    matches = np.empty((0, 2), dtype=int) if len(matches) == 0 else np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
def associate_4_points_with_score(
    detections, trackers, iou_threshold, lt, rt, lb, rb, previous_obs, vdc_weight, iou_type=None, args=None
):
    """Corner-velocity association with a TCM confidence penalty.

    Like associate_4_points but subtracts args.TCM_first_step_weight times
    the pairwise score difference from the motion cost before solving.
    Returns (matches, unmatched_detection_indices, unmatched_tracker_indices).
    """
    if len(trackers) == 0:
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)

    corner_costs = []
    for direction_fn, inertia in (
        (speed_direction_batch_lt, lt),
        (speed_direction_batch_rt, rt),
        (speed_direction_batch_lb, lb),
        (speed_direction_batch_rb, rb),
    ):
        dy, dx = direction_fn(detections, previous_obs)
        corner_costs.append(cost_vel(dy, dx, trackers, inertia, detections, previous_obs, vdc_weight))
    iou_matrix = iou_type(detections, trackers)
    score_dif = cal_score_dif_batch(detections, trackers)

    angle_diff_cost = corner_costs[0] + corner_costs[1] + corner_costs[2] + corner_costs[3]
    # TCM: penalise pairs whose detection scores diverge.
    angle_diff_cost -= score_dif * args.TCM_first_step_weight

    if min(iou_matrix.shape) > 0:
        above = (iou_matrix > iou_threshold).astype(np.int32)
        if above.sum(1).max() == 1 and above.sum(0).max() == 1:
            matched_indices = np.stack(np.where(above), axis=1)
        else:
            matched_indices = linear_assignment(-(iou_matrix + angle_diff_cost))
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = [d for d in range(len(detections)) if d not in matched_indices[:, 0]]
    unmatched_trackers = [t for t in range(len(trackers)) if t not in matched_indices[:, 1]]

    # Reject low-IoU matches.
    matches = []
    for m in matched_indices:
        if iou_matrix[m[0], m[1]] < iou_threshold:
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    matches = np.empty((0, 2), dtype=int) if len(matches) == 0 else np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)


def associate_4_points_with_score_with_reid(
    detections, trackers, iou_threshold, lt, rt, lb, rb, previous_obs, vdc_weight,
    TCM_first_step_weight,
    iou_type=None, emb_cost=None, weights=(1.0, 0), thresh=0.8,
    long_emb_dists=None, with_longterm_reid=False,
    longterm_reid_weight=0.0, with_longterm_reid_correction=False,
    longterm_reid_correction_thresh=0.0, dataset="dancetrack"
):
    """Corner-velocity + TCM association, optionally fused with ReID costs.

    When emb_cost is given, the solver cost is a weighted blend of the
    (negated) motion+IoU cost and the embedding cost, optionally adding
    long-term ReID distances. With with_longterm_reid_correction, matches
    whose embedding cost exceeds longterm_reid_correction_thresh AND whose
    score-adjusted IoU falls below iou_threshold are rejected.
    Returns (matches, unmatched_detection_indices, unmatched_tracker_indices).
    """
    if len(trackers) == 0:
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)

    corner_costs = []
    for direction_fn, inertia in (
        (speed_direction_batch_lt, lt),
        (speed_direction_batch_rt, rt),
        (speed_direction_batch_lb, lb),
        (speed_direction_batch_rb, rb),
    ):
        dy, dx = direction_fn(detections, previous_obs)
        corner_costs.append(cost_vel(dy, dx, trackers, inertia, detections, previous_obs, vdc_weight))
    iou_matrix = iou_type(detections, trackers)
    score_dif = cal_score_dif_batch(detections, trackers)

    angle_diff_cost = corner_costs[0] + corner_costs[1] + corner_costs[2] + corner_costs[3]
    # TCM: penalise score mismatch in the motion cost.
    angle_diff_cost -= score_dif * TCM_first_step_weight

    if min(iou_matrix.shape) > 0:
        if emb_cost is None:
            above = (iou_matrix > iou_threshold).astype(np.int32)
            if above.sum(1).max() == 1 and above.sum(0).max() == 1:
                matched_indices = np.stack(np.where(above), axis=1)
            else:
                matched_indices = linear_assignment(-(iou_matrix + angle_diff_cost))
        else:
            if not with_longterm_reid:
                fused = weights[0] * (-(iou_matrix + angle_diff_cost)) + weights[1] * emb_cost
            else:
                # Blend in long-term ReID feature distances as well.
                fused = (weights[0] * (-(iou_matrix + angle_diff_cost))
                         + weights[1] * emb_cost + longterm_reid_weight * long_emb_dists)
            matched_indices = linear_assignment(fused)  # thresh intentionally unused here
            if matched_indices.size == 0:
                matched_indices = np.empty(shape=(0, 2))
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = [d for d in range(len(detections)) if d not in matched_indices[:, 0]]
    unmatched_trackers = [t for t in range(len(trackers)) if t not in matched_indices[:, 1]]

    # Filter matches by score-adjusted IoU (and long-term ReID agreement).
    matches = []
    iou_matrix_thre = iou_matrix - score_dif
    if with_longterm_reid_correction:
        for m in matched_indices:
            if (emb_cost[m[0], m[1]] > longterm_reid_correction_thresh) and \
                    (iou_matrix_thre[m[0], m[1]] < iou_threshold):
                print("correction:", emb_cost[m[0], m[1]])
                unmatched_detections.append(m[0])
                unmatched_trackers.append(m[1])
            else:
                matches.append(m.reshape(1, 2))
    else:
        for m in matched_indices:
            if iou_matrix_thre[m[0], m[1]] < iou_threshold:
                unmatched_detections.append(m[0])
                unmatched_trackers.append(m[1])
            else:
                matches.append(m.reshape(1, 2))

    matches = np.empty((0, 2), dtype=int) if len(matches) == 0 else np.concatenate(matches, axis=0)
    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)


def associate_kitti(
    detections, trackers, det_cates, iou_threshold, velocities, previous_obs, vdc_weight
):
    """Category-aware association (KITTI-style): IoU plus centre-velocity
    consistency, with cross-category pairs forbidden by a large penalty.

    Returns (matches, unmatched_detection_indices, unmatched_tracker_indices).
    """
    if len(trackers) == 0:
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)

    # Velocity-direction consistency cost (same scheme as cost_vel).
    Y, X = speed_direction_batch(detections, previous_obs)
    inertia_Y = np.repeat(velocities[:, 0][:, np.newaxis], Y.shape[1], axis=1)
    inertia_X = np.repeat(velocities[:, 1][:, np.newaxis], X.shape[1], axis=1)
    cos_sim = np.clip(inertia_X * X + inertia_Y * Y, a_min=-1, a_max=1)
    diff_angle = (np.pi / 2.0 - np.abs(np.arccos(cos_sim))) / np.pi

    valid_mask = np.ones(previous_obs.shape[0])
    valid_mask[np.where(previous_obs[:, 4] < 0)] = 0
    valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)

    scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
    angle_diff_cost = ((valid_mask * diff_angle) * vdc_weight).T * scores

    # IoU cost.
    iou_matrix = iou_batch(detections, trackers)

    # Forbid cross-category matches with a large negative affinity.
    num_dets = detections.shape[0]
    num_trk = trackers.shape[0]
    cate_matrix = np.zeros((num_dets, num_trk))
    for i in range(num_dets):
        for j in range(num_trk):
            if det_cates[i] != trackers[j, 4]:
                cate_matrix[i][j] = -1e6

    cost_matrix = - iou_matrix - angle_diff_cost - cate_matrix

    if min(iou_matrix.shape) > 0:
        above = (iou_matrix > iou_threshold).astype(np.int32)
        if above.sum(1).max() == 1 and above.sum(0).max() == 1:
            matched_indices = np.stack(np.where(above), axis=1)
        else:
            matched_indices = linear_assignment(cost_matrix)
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = [d for d in range(len(detections)) if d not in matched_indices[:, 0]]
    unmatched_trackers = [t for t in range(len(trackers)) if t not in matched_indices[:, 1]]

    # Reject low-IoU matches.
    matches = []
    for m in matched_indices:
        if iou_matrix[m[0], m[1]] < iou_threshold:
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    matches = np.empty((0, 2), dtype=int) if len(matches) == 0 else np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)


# compute embedding distance and gating, borrowed and modified from FairMOT
from scipy.spatial.distance import cdist


def embedding_distance(tracks_feat, detections_feat, metric='cosine'):
    """Pairwise appearance distance between track and detection embeddings.

    :param tracks_feat: (num_tracks, emb_dim) array of track features
    :param detections_feat: (num_dets, emb_dim) array of detection features
    :param metric: cdist metric name (default 'cosine')
    :return: (num_tracks, num_dets) float64 matrix, clamped at >= 0
    """
    cost_matrix = np.zeros((len(tracks_feat), len(detections_feat)), dtype=np.float64)
    if cost_matrix.size == 0:
        return cost_matrix
    return np.maximum(0.0, cdist(tracks_feat, detections_feat, metric))
# --- boxmot/trackers/hybridsort/hybridsort.py (new file) ---
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

"""
    This script is adopted from the SORT script by Alex Bewley alex@bewley.ai
"""

from collections import deque  # [hgx0418] deque for reid feature

import numpy as np

from boxmot.appearance.reid.auto_backend import ReidAutoBackend
from boxmot.motion.cmc import get_cmc_method
from boxmot.trackers.hybridsort.association import (
    associate_4_points_with_score, associate_4_points_with_score_with_reid,
    cal_score_dif_batch_two_score, embedding_distance, linear_assignment)
from boxmot.trackers.basetracker import BaseTracker


np.random.seed(0)


def k_previous_obs(observations, cur_age, k):
    """Return the observation made k frames before cur_age.

    Falls back to the closest more recent one, then to the newest stored
    observation, and finally to the [-1]*5 "no observation" placeholder
    when the dict is empty.
    """
    if len(observations) == 0:
        return [-1, -1, -1, -1, -1]
    for i in range(k):
        dt = k - i
        if cur_age - dt in observations:
            return observations[cur_age - dt]
    max_age = max(observations.keys())
    return observations[max_age]


def convert_bbox_to_z(bbox):
    """
    Takes a bounding box in the form [x1,y1,x2,y2,score] and returns z in the
    form [x,y,s,score,r] where x,y is the centre of the box, s is the
    scale/area, and r is the aspect ratio. If score is None a 4x1
    [x,y,s,r] vector is returned instead.
    """
    w = bbox[2] - bbox[0]
    h = bbox[3] - bbox[1]
    x = bbox[0] + w / 2.
    y = bbox[1] + h / 2.
    s = w * h  # scale is just area
    r = w / float(h + 1e-6)  # epsilon guards zero-height boxes
    score = bbox[4]
    # FIX: was `if score:`, which dropped the score channel for detections
    # with confidence exactly 0.0 and returned a 4x1 vector that cannot be
    # assigned into the 5-observation Kalman state (kf.x[:5]).
    if score is not None:
        return np.array([x, y, s, score, r]).reshape((5, 1))
    else:
        return np.array([x, y, s, r]).reshape((4, 1))


def convert_x_to_bbox(x, score=None):
    """
    Takes a state vector in the centre form [x,y,s,score,r] and returns the
    box [x1,y1,x2,y2,score] where x1,y1 is the top left and x2,y2 is the
    bottom right. NOTE: the `score` parameter is immediately shadowed by
    the score stored in the state (x[3]); only a stored score of None
    yields the 4-column form — behaviour kept as-is.
    """
    w = np.sqrt(x[2] * x[4])
    h = x[2] / w
    score = x[3]
    if score is None:
        return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2.]).reshape((1, 4))
    else:
        return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2., score]).reshape((1, 5))


def speed_direction(bbox1, bbox2):
    """Unit (dy, dx) direction between the centres of two boxes."""
    cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
    cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
    speed = np.array([cy2 - cy1, cx2 - cx1])
    norm = np.sqrt((cy2 - cy1) ** 2 + (cx2 - cx1) ** 2) + 1e-6
    return speed / norm


def speed_direction_lt(bbox1, bbox2):
    """Unit (dy, dx) direction between the (x1, y1) corners of two boxes."""
    cx1, cy1 = bbox1[0], bbox1[1]
    cx2, cy2 = bbox2[0], bbox2[1]
    speed = np.array([cy2 - cy1, cx2 - cx1])
    norm = np.sqrt((cy2 - cy1) ** 2 + (cx2 - cx1) ** 2) + 1e-6
    return speed / norm


def speed_direction_rt(bbox1, bbox2):
    """Unit (dy, dx) direction between the (x1, y2) corners of two boxes."""
    cx1, cy1 = bbox1[0], bbox1[3]
    cx2, cy2 = bbox2[0], bbox2[3]
    speed = np.array([cy2 - cy1, cx2 - cx1])
    norm = np.sqrt((cy2 - cy1) ** 2 + (cx2 - cx1) ** 2) + 1e-6
    return speed / norm


def speed_direction_lb(bbox1, bbox2):
    """Unit (dy, dx) direction between the (x2, y1) corners of two boxes."""
    cx1, cy1 = bbox1[2], bbox1[1]
    cx2, cy2 = bbox2[2], bbox2[1]
    speed = np.array([cy2 - cy1, cx2 - cx1])
    norm = np.sqrt((cy2 - cy1) ** 2 + (cx2 - cx1) ** 2) + 1e-6
    return speed / norm


def speed_direction_rb(bbox1, bbox2):
    """Unit (dy, dx) direction between the (x2, y2) corners of two boxes."""
    cx1, cy1 = bbox1[2], bbox1[3]
    cx2, cy2 = bbox2[2], bbox2[3]
    speed = np.array([cy2 - cy1, cx2 - cx1])
    norm = np.sqrt((cy2 - cy1) ** 2 + (cx2 - cx1) ** 2) + 1e-6
    return speed / norm
class KalmanBoxTracker(object):
    """
    This class represents the internal state of individual tracked objects observed as bbox.

    State vector: [u, v, s, c, r, du, dv, ds, dc] — centre, area, confidence,
    aspect ratio, and their velocities (aspect ratio has no velocity term).
    Also maintains a ReID feature bank (``self.features``) with an
    exponentially smoothed summary feature (``self.smooth_feat``).
    """
    # Monotonically increasing id shared by all instances; reset by HybridSort.__init__.
    count = 0

    def __init__(
        self,
        bbox,
        cls,
        det_ind,
        temp_feat,
        delta_t=3,
        orig=False,
        buffer_size=30,
        longterm_bank_length=30,
        alpha=0.8,
        max_obs=50
    ):  # 'temp_feat' and 'buffer_size' for reid feature
        """
        Initialises a tracker using initial bounding box.

        bbox: [x1, y1, x2, y2, score] of the first detection.
        cls / det_ind: class id and detection index carried through for output.
        temp_feat: initial ReID embedding for this track.
        delta_t: lookback window (frames) for velocity estimation.
        orig: unused here — presumably a leftover switch for an alternative
              Kalman filter; TODO confirm before removing.
        buffer_size: immediately overwritten by longterm_bank_length below.
        alpha: momentum for the smoothed ReID feature update.
        max_obs: bound for the history deques.
        """
        # define constant velocity model
        # if not orig and not args.kalman_GPR:
        from boxmot.motion.kalman_filters.aabb.xysr_kf import KalmanFilterXYSR
        self.kf = KalmanFilterXYSR(dim_x=9, dim_z=5, max_obs=max_obs)

        # u, v, s, c, r, ~u, ~v, ~s, ~c
        self.kf.F = np.array([[1, 0, 0, 0, 0, 1, 0, 0, 0],
                              [0, 1, 0, 0, 0, 0, 1, 0, 0],
                              [0, 0, 1, 0, 0, 0, 0, 1, 0],
                              [0, 0, 0, 1, 0, 0, 0, 0, 1],
                              [0, 0, 0, 0, 1, 0, 0, 0, 0],
                              [0, 0, 0, 0, 0, 1, 0, 0, 0],
                              [0, 0, 0, 0, 0, 0, 1, 0, 0],
                              [0, 0, 0, 0, 0, 0, 0, 1, 0],
                              [0, 0, 0, 0, 0, 0, 0, 0, 1]])
        # Measurement maps onto the first five state components.
        self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0, 0, 0],
                              [0, 1, 0, 0, 0, 0, 0, 0, 0],
                              [0, 0, 1, 0, 0, 0, 0, 0, 0],
                              [0, 0, 0, 1, 0, 0, 0, 0, 0],
                              [0, 0, 0, 0, 1, 0, 0, 0, 0]])

        self.kf.R[2:, 2:] *= 10.
        self.kf.P[5:, 5:] *= 1000.  # give high uncertainty to the unobservable initial velocities
        self.kf.P *= 10.
        self.kf.Q[-1, -1] *= 0.01
        self.kf.Q[-2, -2] *= 0.01
        self.kf.Q[5:, 5:] *= 0.01

        self.kf.x[:5] = convert_bbox_to_z(bbox)

        self.time_since_update = 0
        self.id = KalmanBoxTracker.count
        KalmanBoxTracker.count += 1
        self.max_obs = max_obs
        self.history = deque([], maxlen=self.max_obs)
        self.hits = 0
        self.hit_streak = 0
        self.age = 0
        self.conf = bbox[4]
        self.cls = cls
        self.det_ind = det_ind
        # Adaptive feature smoothing (score-weighted); off by default.
        self.adapfs = False
        """
        NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same for the return of
        function k_previous_obs. It is ugly and I do not like it. But to support generate observation array in a
        fast and unified way, which you would see below k_observations = np.array([k_previous_obs(...]]),
        let's bear it for now.
        """
        self.last_observation = np.array([-1, -1, -1, -1, -1])  # placeholder
        self.last_observation_save = np.array([-1, -1, -1, -1, -1])
        self.observations = dict()
        self.history_observations = deque([], maxlen=self.max_obs)
        # Per-corner velocity estimates used by the 4-point association.
        self.velocity_lt = None
        self.velocity_rt = None
        self.velocity_lb = None
        self.velocity_rb = None
        self.delta_t = delta_t
        self.confidence_pre = None
        self.confidence = bbox[4]

        # add the following values and functions
        self.smooth_feat = None
        buffer_size = longterm_bank_length
        self.features = deque([], maxlen=buffer_size)
        self.update_features(temp_feat)

        # momentum of embedding update
        self.alpha = alpha

    # ReID. for update embeddings during tracking
    def update_features(self, feat, score=-1):
        """Push a new (L2-normalised) ReID feature and refresh the EMA summary.

        When ``self.adapfs`` is set, the blend weights are additionally scaled
        by detection confidence (requires ``score > 0``).
        """
        feat /= np.linalg.norm(feat)
        self.curr_feat = feat
        if self.smooth_feat is None:
            self.smooth_feat = feat
        else:
            if self.adapfs:
                assert score > 0
                pre_w = self.alpha * (self.confidence / (self.confidence + score))
                cur_w = (1 - self.alpha) * (score / (self.confidence + score))
                sum_w = pre_w + cur_w
                pre_w = pre_w / sum_w
                cur_w = cur_w / sum_w
                self.smooth_feat = pre_w * self.smooth_feat + cur_w * feat
            else:
                self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
        self.features.append(feat)
        self.smooth_feat /= np.linalg.norm(self.smooth_feat)

    def camera_update(self, warp_matrix):
        """
        update 'self.mean' of current tracklet with ecc results.
        Parameters
        ----------
        warp_matrix: warp matrix computed by ECC.
        """
        # Warp both corners of the current state estimate, then re-encode.
        x1, y1, x2, y2, s = convert_x_to_bbox(self.kf.x)[0]
        x1_, y1_ = warp_matrix @ np.array([x1, y1, 1]).T
        x2_, y2_ = warp_matrix @ np.array([x2, y2, 1]).T
        # w, h = x2_ - x1_, y2_ - y1_
        # cx, cy = x1_ + w / 2, y1_ + h / 2
        self.kf.x[:5] = convert_bbox_to_z([x1_, y1_, x2_, y2_, s])

    def update(self, bbox, cls, det_ind, id_feature, update_feature=True):
        """
        Updates the state vector with observed bbox.

        ``bbox is None`` marks a missed frame: the Kalman filter is updated
        with None and confidence bookkeeping is reset.
        """
        velocity_lt = None
        velocity_rt = None
        velocity_lb = None
        velocity_rb = None
        if bbox is not None:
            self.conf = bbox[-1]
            self.cls = cls
            self.det_ind = det_ind
            if self.last_observation.sum() >= 0:  # a previous observation exists
                previous_box = None
                # Accumulate corner velocities over every stored observation
                # within the delta_t window (not just the closest one).
                for i in range(self.delta_t):
                    # dt = self.delta_t - i
                    if self.age - i - 1 in self.observations:
                        previous_box = self.observations[self.age - i - 1]
                        if velocity_lt is not None:
                            velocity_lt += speed_direction_lt(previous_box, bbox)
                            velocity_rt += speed_direction_rt(previous_box, bbox)
                            velocity_lb += speed_direction_lb(previous_box, bbox)
                            velocity_rb += speed_direction_rb(previous_box, bbox)
                        else:
                            velocity_lt = speed_direction_lt(previous_box, bbox)
                            velocity_rt = speed_direction_rt(previous_box, bbox)
                            velocity_lb = speed_direction_lb(previous_box, bbox)
                            velocity_rb = speed_direction_rb(previous_box, bbox)
                    # break
                if previous_box is None:
                    # Nothing inside the window: fall back to the last observation.
                    previous_box = self.last_observation
                    self.velocity_lt = speed_direction_lt(previous_box, bbox)
                    self.velocity_rt = speed_direction_rt(previous_box, bbox)
                    self.velocity_lb = speed_direction_lb(previous_box, bbox)
                    self.velocity_rb = speed_direction_rb(previous_box, bbox)
                else:
                    self.velocity_lt = velocity_lt
                    self.velocity_rt = velocity_rt
                    self.velocity_lb = velocity_lb
                    self.velocity_rb = velocity_rb
            """
            Insert new observations. This is a ugly way to maintain both self.observations
            and self.history_observations. Bear it for the moment.
            """
            self.last_observation = bbox
            self.last_observation_save = bbox
            self.observations[self.age] = bbox
            self.history_observations.append(bbox)

            self.time_since_update = 0
            # NOTE(review): this rebinds the bounded deque from __init__ to a
            # plain list, so the max_obs cap on predictions is lost after the
            # first update — confirm whether that is intentional.
            self.history = []
            self.hits += 1
            self.hit_streak += 1
            self.kf.update(convert_bbox_to_z(bbox))
            # add interface for update feature or not
            if update_feature:
                if self.adapfs:
                    self.update_features(id_feature, score=bbox[4])
                else:
                    self.update_features(id_feature)
            self.confidence_pre = self.confidence
            self.confidence = bbox[4]
        else:
            self.kf.update(bbox)
            self.confidence_pre = None

    def predict(self, track_thresh=0.6):
        """
        Advances the state vector and returns the predicted bounding box estimate.

        Returns a 3-tuple: (predicted bbox, clipped Kalman confidence,
        clipped simple/extrapolated confidence).
        """
        # Prevent the area from going non-positive after adding its velocity.
        if ((self.kf.x[7] + self.kf.x[2]) <= 0):
            self.kf.x[7] *= 0.0

        self.kf.predict()
        self.age += 1
        if (self.time_since_update > 0):
            self.hit_streak = 0
        self.time_since_update += 1
        self.history.append(convert_x_to_bbox(self.kf.x))
        if not self.confidence_pre:
            return (
                self.history[-1],
                np.clip(self.kf.x[3], track_thresh, 1.0),
                np.clip(self.confidence, 0.1, track_thresh)
            )
        else:
            # Linear extrapolation of the confidence trend.
            return (
                self.history[-1],
                np.clip(self.kf.x[3], track_thresh, 1.0),
                np.clip(self.confidence - (self.confidence_pre - self.confidence), 0.1, track_thresh)
            )

    def get_state(self):
        """
        Returns the current bounding box estimate.
        """
        return convert_x_to_bbox(self.kf.x)
class HybridSort(BaseTracker):
    """
    HybridSORT Tracker: A tracking algorithm that utilizes a combination of appearance and motion-based tracking
    and temporal consistency models (TCM) for improved tracking accuracy and robustness.

    Args:
        reid_weights (str): Path to the model weights for ReID (Re-Identification).
        device (str): Device on which to run the model (e.g., 'cpu' or 'cuda').
        half (bool): Whether to use half-precision (fp16) for faster inference on compatible devices.
        det_thresh (float): Detection confidence threshold. Detections below this threshold will be ignored in the first association step.
        per_class (bool, optional): Whether to perform per-class tracking. If True, tracks are maintained separately for each object class.
        max_age (int, optional): Maximum number of frames to keep a track alive without any detections.
        min_hits (int, optional): Minimum number of hits required to confirm a track.
        iou_threshold (float, optional): Intersection over Union (IoU) threshold for data association.
        delta_t (int, optional): Time delta for velocity estimation in Kalman Filter.
        asso_func (str, optional): Association function to use for data association. Options include "iou" for IoU-based association.
        inertia (float, optional): Weight for inertia in motion modeling. Higher values make tracks less responsive to changes.
        longterm_reid_weight (float, optional): Weight for the long-term ReID feature in the association process.
        TCM_first_step_weight (float, optional): Weight for the Temporal Consistency Model (TCM) in the first association step.
        use_byte (bool, optional): Whether to use BYTE association in the second association step.
    """
    def __init__(self, reid_weights, device, half, det_thresh, per_class=False, max_age=30, min_hits=3,
                 iou_threshold=0.3, delta_t=3, asso_func="iou", inertia=0.2, longterm_reid_weight=0, TCM_first_step_weight=0, use_byte=False):
        super().__init__(max_age=max_age, per_class=per_class, asso_func=asso_func)

        """
        Sets key parameters for SORT
        """
        self.max_age: int = max_age
        self.min_hits: int = min_hits
        self.iou_threshold: float = iou_threshold
        self.per_class: bool = per_class
        self.frame_count: int = 0
        self.det_thresh: float = det_thresh
        self.delta_t: int = delta_t
        self.inertia: float = inertia
        self.use_byte: bool = use_byte
        # Fixed hyper-parameters of the HybridSORT association cascade.
        self.low_thresh: float = 0.1
        self.EG_weight_high_score: float = 1.3
        self.EG_weight_low_score: float = 1.2
        self.TCM_first_step: bool = True
        self.with_longterm_reid: bool = True
        self.with_longterm_reid_correction: bool = True
        self.longterm_reid_weight: float = longterm_reid_weight
        self.TCM_first_step_weight: float = TCM_first_step_weight
        self.high_score_matching_thresh: float = 0.8
        self.longterm_reid_correction_thresh: float = 0.4
        self.longterm_reid_correction_thresh_low: float = 0.4
        self.TCM_byte_step: bool = True
        self.TCM_byte_step_weight: float = 1.0
        self.dataset: str = 'dancetrack'
        self.ECC: bool = False
        # Restart track ids for every new tracker instance.
        KalmanBoxTracker.count = 0

        self.model = ReidAutoBackend(
            weights=reid_weights, device=device, half=half
        ).model
        self.cmc = get_cmc_method('ecc')()

    def camera_update(self, trackers, warp_matrix):
        """Propagate an ECC camera-motion warp to every given tracker."""
        for tracker in trackers:
            tracker.camera_update(warp_matrix)

    @BaseTracker.setup_decorator
    @BaseTracker.per_class_decorator
    def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray:
        """
        Params:
          dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
        Requires: this method must be called once for each frame even with empty detections
        (use np.empty((0, 5)) for frames without detections).
        Returns the a similar array, where the last column is the object ID.
        NOTE: The number of objects returned may differ from the number of detections provided.
        """

        self.check_inputs(dets, img)

        if dets is None:
            return np.empty((0, 7))

        if self.ECC:
            warp_matrix = self.cmc.apply(img, dets)
            if warp_matrix is not None:
                self.camera_update(self.active_tracks, warp_matrix)

        self.frame_count += 1
        scores = dets[:, 4]
        bboxes = dets[:, :4]

        dets_embs = self.model.get_features(bboxes, img)
        # dets0 keeps the full rows (incl. cls / det_ind columns) plus a
        # duplicated score column; dets is reduced to [x1,y1,x2,y2,score].
        dets0 = np.concatenate((dets, np.expand_dims(scores, axis=-1)), axis=1)
        dets = np.concatenate((bboxes, np.expand_dims(scores, axis=-1)), axis=1)
        inds_low = scores > self.low_thresh
        inds_high = scores < self.det_thresh
        inds_second = np.logical_and(inds_low, inds_high)  # self.det_thresh > score > 0.1, for second matching
        dets_second = dets[inds_second]  # detections for second matching
        remain_inds = scores > self.det_thresh
        dets = dets[remain_inds]
        # NOTE(review): dets is filtered by remain_inds but dets0 is not, yet
        # dets0 is later indexed with match indices into the filtered dets —
        # verify the rows stay aligned (they only do when all scores pass).
        id_feature_keep = dets_embs[remain_inds]  # ID feature of 1st stage matching
        id_feature_second = dets_embs[inds_second]  # ID feature of 2nd stage matching

        # get predicted locations from existing trackers.
        trks = np.zeros((len(self.active_tracks), 8))
        to_del = []
        ret = []
        for t, trk in enumerate(trks):
            pos, kalman_score, simple_score = self.active_tracks[t].predict()
            trk[:6] = [pos[0][0], pos[0][1], pos[0][2], pos[0][3], kalman_score[0], simple_score]
            if np.any(np.isnan(pos)):
                to_del.append(t)
        trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
        for t in reversed(to_del):
            self.active_tracks.pop(t)

        velocities_lt = np.array(
            [trk.velocity_lt if trk.velocity_lt is not None else np.array((0, 0)) for trk in self.active_tracks])
        velocities_rt = np.array(
            [trk.velocity_rt if trk.velocity_rt is not None else np.array((0, 0)) for trk in self.active_tracks])
        velocities_lb = np.array(
            [trk.velocity_lb if trk.velocity_lb is not None else np.array((0, 0)) for trk in self.active_tracks])
        velocities_rb = np.array(
            [trk.velocity_rb if trk.velocity_rb is not None else np.array((0, 0)) for trk in self.active_tracks])
        last_boxes = np.array([trk.last_observation for trk in self.active_tracks])
        k_observations = np.array(
            [k_previous_obs(trk.observations, trk.age, self.delta_t) for trk in self.active_tracks])

        """
        First round of association
        """
        if self.EG_weight_high_score > 0 and self.TCM_first_step:
            track_features = np.asarray([track.smooth_feat for track in self.active_tracks],
                                        dtype=np.float64)
            emb_dists = embedding_distance(track_features, id_feature_keep).T
            if self.with_longterm_reid or self.with_longterm_reid_correction:
                # Long-term appearance: mean of the whole feature bank.
                long_track_features = np.asarray([np.vstack(list(track.features)).mean(0) for track in self.active_tracks],
                                                 dtype=np.float64)
                assert track_features.shape == long_track_features.shape
                long_emb_dists = embedding_distance(long_track_features, id_feature_keep).T
                assert emb_dists.shape == long_emb_dists.shape
                matched, unmatched_dets, unmatched_trks = associate_4_points_with_score_with_reid(
                    dets, trks, self.iou_threshold, velocities_lt, velocities_rt, velocities_lb, velocities_rb,
                    k_observations, self.inertia, self.TCM_first_step_weight, self.asso_func, emb_cost=emb_dists,
                    weights=(1.0, self.EG_weight_high_score), thresh=self.high_score_matching_thresh,
                    long_emb_dists=long_emb_dists, with_longterm_reid=self.with_longterm_reid,
                    longterm_reid_weight=self.longterm_reid_weight,
                    with_longterm_reid_correction=self.with_longterm_reid_correction,
                    longterm_reid_correction_thresh=self.longterm_reid_correction_thresh,
                    dataset=self.dataset)
            else:
                matched, unmatched_dets, unmatched_trks = associate_4_points_with_score_with_reid(
                    dets, trks, self.iou_threshold, velocities_lt, velocities_rt, velocities_lb, velocities_rb,
                    k_observations, self.inertia, self.TCM_first_step_weight, self.asso_func, emb_cost=emb_dists,
                    weights=(1.0, self.EG_weight_high_score), thresh=self.high_score_matching_thresh)
        elif self.TCM_first_step:
            # NOTE(review): there is no final else — if both EG_weight_high_score
            # and TCM_first_step were disabled, `matched` would be unbound below.
            # Safe with the hard-coded defaults above, but confirm before
            # exposing these as parameters.
            matched, unmatched_dets, unmatched_trks = associate_4_points_with_score(
                dets, trks, self.iou_threshold, velocities_lt, velocities_rt, velocities_lb, velocities_rb,
                k_observations, self.inertia, self.TCM_first_step_weight, self.asso_func)

        # update with id feature
        for m in matched:
            self.active_tracks[m[1]].update(dets[m[0], :], dets0[m[0], 5], dets0[m[0], 6], id_feature_keep[m[0], :])

        """
        Second round of associaton by OCR
        """
        # BYTE association
        if self.use_byte and len(dets_second) > 0 and unmatched_trks.shape[0] > 0:
            u_trks = trks[unmatched_trks]
            u_tracklets = [self.active_tracks[index] for index in unmatched_trks]
            iou_left = self.asso_func(dets_second, u_trks)
            iou_left = np.array(iou_left)
            if iou_left.max() > self.iou_threshold:
                """
                NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
                get a higher performance especially on MOT17/MOT20 datasets. But we keep it
                uniform here for simplicity
                """
                if self.TCM_byte_step:
                    # Penalise IoU by the confidence gap (TCM for the byte step).
                    iou_left -= np.array(
                        cal_score_dif_batch_two_score(dets_second, u_trks) * self.TCM_byte_step_weight
                    )
                    iou_left_thre = iou_left
                if self.EG_weight_low_score > 0:
                    u_track_features = np.asarray([track.smooth_feat for track in u_tracklets], dtype=np.float64)
                    emb_dists_low_score = embedding_distance(u_track_features, id_feature_second).T
                    matched_indices = linear_assignment(-iou_left + self.EG_weight_low_score * emb_dists_low_score,
                                                        )
                else:
                    matched_indices = linear_assignment(-iou_left)
                to_remove_trk_indices = []
                for m in matched_indices:
                    det_ind, trk_ind = m[0], unmatched_trks[m[1]]
                    if self.with_longterm_reid_correction and self.EG_weight_low_score > 0:
                        if (iou_left_thre[m[0], m[1]] < self.iou_threshold) or \
                                (emb_dists_low_score[m[0], m[1]] > self.longterm_reid_correction_thresh_low):
                            # NOTE(review): library code printing to stdout —
                            # consider routing through a logger.
                            print("correction 2nd:", emb_dists_low_score[m[0], m[1]])
                            continue
                    else:
                        if iou_left_thre[m[0], m[1]] < self.iou_threshold:
                            continue
                    # NOTE(review): KalmanBoxTracker.update takes
                    # (bbox, cls, det_ind, id_feature, update_feature) — this call
                    # passes the id feature in the `cls` slot and omits two
                    # positional args, so it would raise TypeError when the BYTE
                    # branch runs. Compare with the first-round call above.
                    self.active_tracks[trk_ind].update(
                        dets_second[det_ind, :],
                        id_feature_second[det_ind, :],
                        update_feature=False
                    )  # [hgx0523] do not update with id feature
                    to_remove_trk_indices.append(trk_ind)
                unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices))

        # OCR: retry unmatched detections against the last observed boxes.
        if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
            left_dets = dets[unmatched_dets]
            # left_id_feature = id_feature_keep[unmatched_dets]   # update id feature, if needed
            left_trks = last_boxes[unmatched_trks]
            iou_left = self.asso_func(left_dets, left_trks)
            iou_left = np.array(iou_left)

            if iou_left.max() > self.iou_threshold:
                """
                NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
                get a higher performance especially on MOT17/MOT20 datasets. But we keep it
                uniform here for simplicity
                """
                rematched_indices = linear_assignment(-iou_left)
                to_remove_det_indices = []
                to_remove_trk_indices = []
                for m in rematched_indices:
                    det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]]
                    if iou_left[m[0], m[1]] < self.iou_threshold:
                        continue
                    self.active_tracks[trk_ind].update(
                        dets[det_ind, :],
                        dets0[det_ind, 5],
                        dets0[det_ind, 6],
                        id_feature_keep[det_ind, :],
                        update_feature=False
                    )
                    to_remove_det_indices.append(det_ind)
                    to_remove_trk_indices.append(trk_ind)
                unmatched_dets = np.setdiff1d(unmatched_dets, np.array(to_remove_det_indices))
                unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices))

        # Missed frame: update Kalman filter with a None observation.
        for m in unmatched_trks:
            self.active_tracks[m].update(None, None, None, None)

        # create and initialise new trackers for unmatched detections
        for i in unmatched_dets:
            trk = KalmanBoxTracker(dets[i, :], dets0[i, 5], dets0[i, 6], id_feature_keep[i, :], delta_t=self.delta_t, max_obs=self.max_obs)
            self.active_tracks.append(trk)
        i = len(self.active_tracks)
        for trk in reversed(self.active_tracks):
            if trk.last_observation.sum() < 0:
                d = trk.get_state()[0][:4]
            else:
                """
                this is optional to use the recent observation or the kalman filter prediction,
                we didn't notice significant difference here
                """
                d = trk.last_observation[:4]
            if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
                # +1 as MOT benchmark requires positive
                ret.append(np.concatenate((d, [trk.id + 1], [trk.conf], [trk.cls], [trk.det_ind])).reshape(1, -1))
            i -= 1
            # remove dead tracklet
            if (trk.time_since_update > self.max_age):
                self.active_tracks.pop(i)
        if (len(ret) > 0):
            return np.concatenate(ret)
        return np.empty((0, 7))
import numpy as np
from collections import OrderedDict


class TrackState(object):
    """Life-cycle states a track can move through."""
    New = 0
    Tracked = 1
    Lost = 2
    LongLost = 3
    Removed = 4


class BaseTrack(object):
    """Common bookkeeping shared by concrete track implementations.

    Subclasses must implement :meth:`activate`, :meth:`predict` and
    :meth:`update`; this base only provides id allocation and state flags.

    NOTE(review): ``history`` and ``features`` are mutable class attributes,
    so they are shared by every instance that does not shadow them — confirm
    subclasses always assign their own before relying on them.
    """
    _count = 0

    track_id = 0
    is_activated = False
    state = TrackState.New

    history = OrderedDict()
    features = []
    curr_feature = None
    score = 0
    start_frame = 0
    frame_id = 0
    time_since_update = 0

    # multi-camera
    location = (np.inf, np.inf)

    @property
    def end_frame(self):
        """The most recent frame this track was updated on."""
        return self.frame_id

    @staticmethod
    def next_id():
        """Hand out a strictly increasing id shared across all track types."""
        BaseTrack._count = BaseTrack._count + 1
        return BaseTrack._count

    @staticmethod
    def clear_count():
        """Reset the global id counter (e.g. between sequences)."""
        BaseTrack._count = 0

    def activate(self, *args):
        raise NotImplementedError

    def predict(self):
        raise NotImplementedError

    def update(self, *args, **kwargs):
        raise NotImplementedError

    def mark_lost(self):
        """Flag this track as temporarily lost."""
        self.state = TrackState.Lost

    def mark_long_lost(self):
        """Flag this track as lost for an extended period."""
        self.state = TrackState.LongLost

    def mark_removed(self):
        """Flag this track as permanently removed."""
        self.state = TrackState.Removed
# Raif Olson

import numpy as np
from collections import deque
from pathlib import Path
from torch import device

from boxmot.appearance.reid.auto_backend import ReidAutoBackend
from boxmot.motion.cmc.sof import SOF
from boxmot.motion.kalman_filters.aabb.xywh_kf import KalmanFilterXYWH
from boxmot.trackers.imprassoc.basetrack import BaseTrack, TrackState
from boxmot.utils.matching import (embedding_distance, fuse_score,
                                   iou_distance, linear_assignment,
                                   d_iou_distance)
from boxmot.utils.ops import xywh2xyxy, xyxy2xywh
from boxmot.trackers.basetracker import BaseTracker


class STrack(BaseTrack):
    """Single tracked object with an XYWH Kalman state and a ReID feature bank."""
    # Shared filter used only for the vectorised multi_predict step.
    shared_kalman = KalmanFilterXYWH()

    def __init__(self, det, feat=None, feat_history=15, max_obs=15):
        """
        det: one detection row [x1, y1, x2, y2, conf, cls, det_ind].
        feat: optional ReID embedding for this detection.
        feat_history: capacity of the per-track feature deque.
        max_obs: capacity of the observation history deque.
        """
        # wait activate
        self.xywh = xyxy2xywh(det[0:4])  # (x1, y1, x2, y2) --> (xc, yc, w, h)
        self.conf = det[4]
        self.cls = det[5]
        self.det_ind = det[6]
        self.max_obs = max_obs
        self.kalman_filter = None
        self.mean, self.covariance = None, None
        self.is_activated = False
        self.cls_hist = []  # (cls id, freq)
        self.update_cls(self.cls, self.conf)
        self.history_observations = deque([], maxlen=self.max_obs)

        self.tracklet_len = 0

        self.smooth_feat = None
        self.curr_feat = None
        if feat is not None:
            self.update_features(feat)
        self.features = deque([], maxlen=feat_history)
        self.alpha = 0.9

    def update_features(self, feat):
        """Push an L2-normalised feature and refresh the EMA summary feature."""
        feat /= np.linalg.norm(feat)
        self.curr_feat = feat
        if self.smooth_feat is None:
            self.smooth_feat = feat
        else:
            self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
        self.features.append(feat)
        self.smooth_feat /= np.linalg.norm(self.smooth_feat)

    def update_cls(self, cls, conf):
        """Track a confidence-weighted class histogram and keep the majority class."""
        if len(self.cls_hist) > 0:
            max_freq = 0
            found = False
            for c in self.cls_hist:
                if cls == c[0]:
                    c[1] += conf
                    found = True

                # Majority vote runs over every entry, matched or not.
                if c[1] > max_freq:
                    max_freq = c[1]
                    self.cls = c[0]
            if not found:
                self.cls_hist.append([cls, conf])
                self.cls = cls
        else:
            self.cls_hist.append([cls, conf])
            self.cls = cls

    def predict(self):
        """Advance this track's Kalman state; zero velocities while not Tracked."""
        mean_state = self.mean.copy()
        if self.state != TrackState.Tracked:
            mean_state[6] = 0
            mean_state[7] = 0

        self.mean, self.covariance = self.kalman_filter.predict(
            mean_state, self.covariance
        )

    @staticmethod
    def multi_predict(stracks):
        """Vectorised predict over several tracks via the shared Kalman filter."""
        if len(stracks) > 0:
            multi_mean = np.asarray([st.mean.copy() for st in stracks])
            multi_covariance = np.asarray([st.covariance for st in stracks])
            for i, st in enumerate(stracks):
                if st.state != TrackState.Tracked:
                    multi_mean[i][6] = 0
                    multi_mean[i][7] = 0
            multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(
                multi_mean, multi_covariance
            )
            for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
                stracks[i].mean = mean
                stracks[i].covariance = cov

    @staticmethod
    def multi_gmc(stracks, H=np.eye(2, 3)):
        """Apply a global camera-motion compensation warp H to every track state."""
        if len(stracks) > 0:
            multi_mean = np.asarray([st.mean.copy() for st in stracks])
            multi_covariance = np.asarray([st.covariance for st in stracks])

            R = H[:2, :2]
            # Same 2x2 rotation applied to all four (pos, vel) pairs of the state.
            R8x8 = np.kron(np.eye(4, dtype=float), R)
            t = H[:2, 2]

            for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
                mean = R8x8.dot(mean)
                mean[:2] += t
                cov = R8x8.dot(cov).dot(R8x8.transpose())

                stracks[i].mean = mean
                stracks[i].covariance = cov

    def activate(self, kalman_filter, frame_count):
        """Start a new tracklet"""
        self.kalman_filter = kalman_filter
        self.id = self.next_id()

        self.mean, self.covariance = self.kalman_filter.initiate(self.xywh)

        self.tracklet_len = 0
        self.state = TrackState.Tracked
        # from OAI track, no unconfirmed tracks.
        self.is_activated = True
        self.frame_count = frame_count
        self.start_frame = frame_count

    def re_activate(self, new_track, frame_count, new_id=False):
        """Revive a lost track with a freshly matched detection."""
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, new_track.xywh, self.conf
        )
        if new_track.curr_feat is not None:
            self.update_features(new_track.curr_feat)
        self.tracklet_len = 0
        self.state = TrackState.Tracked
        self.is_activated = True
        self.frame_count = frame_count
        if new_id:
            self.id = self.next_id()
        self.conf = new_track.conf
        self.cls = new_track.cls
        self.det_ind = new_track.det_ind

        self.update_cls(new_track.cls, new_track.conf)

    def update(self, new_track, frame_count):
        """
        Update a matched track
        :type new_track: STrack
        :type frame_count: int
        :type update_feature: bool
        :return:
        """
        self.frame_count = frame_count
        self.tracklet_len += 1

        self.history_observations.append(self.xyxy)

        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, new_track.xywh, self.conf
        )

        if new_track.curr_feat is not None:
            self.update_features(new_track.curr_feat)

        self.state = TrackState.Tracked
        self.is_activated = True

        self.conf = new_track.conf
        self.cls = new_track.cls
        self.det_ind = new_track.det_ind
        self.update_cls(new_track.cls, new_track.conf)

    @property
    def xyxy(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        if self.mean is None:
            ret = self.xywh.copy()  # (xc, yc, w, h)
        else:
            ret = self.mean[:4].copy()  # kf (xc, yc, w, h)
        ret = xywh2xyxy(ret)
        return ret


class ImprAssocTrack(BaseTracker):
    """
    ImprAssocTrack Tracker: A tracking algorithm that utilizes a combination of appearance and motion-based tracking.

    Args:
        model_weights (str): Path to the model weights for ReID (Re-Identification).
        device (str): Device on which to run the model (e.g., 'cpu' or 'cuda').
+ fp16 (bool): Whether to use half-precision (fp16) for faster inference on compatible devices. + per_class (bool, optional): Whether to perform per-class tracking. If True, tracks are maintained separately for each object class. + track_high_thresh (float, optional): High threshold for detection confidence. Detections above this threshold are used in the first association round. + track_low_thresh (float, optional): Low threshold for detection confidence. Detections below this threshold are ignored. + new_track_thresh (float, optional): Threshold for creating a new track. Detections above this threshold will be considered as potential new tracks. + match_thresh (float, optional): Threshold for the matching step in data association. Controls the maximum distance allowed between tracklets and detections for a match. + second_match_thresh (float, optional): Threshold for the second round of matching, used to associate low confidence detections. + overlap_thresh (float, optional): Threshold for discarding overlapping detections after association. + lambda_ (float, optional): Weighting factor for combining different association costs (e.g., IoU and ReID distance). + track_buffer (int, optional): Number of frames to keep a track alive after it was last detected. A longer buffer allows for more robust tracking but may increase identity switches. + proximity_thresh (float, optional): Threshold for IoU (Intersection over Union) distance in first-round association. + appearance_thresh (float, optional): Threshold for appearance embedding distance in the ReID module. + cmc_method (str, optional): Method for correcting camera motion. Options include "sparseOptFlow" (Sparse Optical Flow). + frame_rate (int, optional): Frame rate of the video being processed. Used to scale the track buffer size. + with_reid (bool, optional): Whether to use ReID (Re-Identification) features for association. 
+ """ + def __init__( + self, + reid_weights: Path, + device: device, + half: bool, + per_class: bool = False, + track_high_thresh: float = 0.6, + track_low_thresh: float = 0.1, + new_track_thresh: float = 0.7, + match_thresh: float = 0.65, # bigger? + second_match_thresh: float = 0.19, + overlap_thresh: float = 0.55, + lambda_: float = 0.2, + track_buffer: int = 35, + proximity_thresh: float = 0.1, + appearance_thresh: float = 0.25, + cmc_method: str = "sparseOptFlow", + frame_rate=30, + with_reid: bool = True + ): + super().__init__(per_class=per_class) + self.active_tracks = [] # type: list[STrack] + self.lost_stracks = [] # type: list[STrack] + self.removed_stracks = [] # type: list[STrack] + BaseTrack.clear_count() + + self.per_class = per_class + self.track_high_thresh = track_high_thresh + self.track_low_thresh = track_low_thresh + self.new_track_thresh = new_track_thresh + self.match_thresh = match_thresh + + self.second_match_thresh = second_match_thresh + self.overlap_thresh = overlap_thresh + self.lambda_ = lambda_ + + self.buffer_size = int(frame_rate / 30.0 * track_buffer) + self.max_time_lost = self.buffer_size + self.kalman_filter = KalmanFilterXYWH() + + # ReID module + self.proximity_thresh = proximity_thresh + self.appearance_thresh = appearance_thresh + + self.with_reid = with_reid + if self.with_reid: + rab = ReidAutoBackend( + weights=reid_weights, device=device, half=half + ) + self.model = rab.get_backend() + + self.cmc = SOF() + + + @BaseTracker.setup_decorator + @BaseTracker.per_class_decorator + def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray: + self.check_inputs(dets, img) + + self.frame_count += 1 + activated_starcks = [] + refind_stracks = [] + lost_stracks = [] + removed_stracks = [] + + dets = np.hstack([dets, np.arange(len(dets)).reshape(-1, 1)]) + + # Remove bad detections + confs = dets[:, 4] + + # find second round association detections + second_mask = np.logical_and(confs > 
self.track_low_thresh, confs < self.track_high_thresh) + dets_second = dets[second_mask] + + # find first round association detections + first_mask = confs > self.track_high_thresh + dets_first = dets[first_mask] + + """Extract embeddings """ + # appearance descriptor extraction + if self.with_reid: + if embs is not None: + features_high = embs + else: + # (Ndets x X) [512, 1024, 2048] + features_high = self.model.get_features(dets_first[:, 0:4], img) + + if len(dets) > 0: + """Detections""" + if self.with_reid: + detections = [STrack(det, f, max_obs=self.max_obs) for (det, f) in zip(dets_first, features_high)] + else: + detections = [STrack(det, max_obs=self.max_obs) for (det) in np.array(dets_first)] + else: + detections = [] + + """ Add newly detected tracklets to active_tracks""" + unconfirmed = [] + low_tent = [] + high_tent = [] + active_tracks = [] # type: list[STrack] + for track in self.active_tracks: + if not track.is_activated: + unconfirmed.append(track) + else: + active_tracks.append(track) + + '''Improved Association: First they calc the cost matrix of the high + detections(func_1 -> cost_h), then the calc the cost matrix of the low + detections (func_2 -> cost_l) and get the max values of both. Then + B = det_h_max / det_l_max. 
+ Finally they calc cost = concat(cost_h, B*cost_l) for the matching + ''' + + ''' Step 2: First association, with high score detection boxes''' + strack_pool = joint_stracks(active_tracks, self.lost_stracks) + + # Fix camera motion + warp = self.cmc.apply(img, dets_first) + STrack.multi_gmc(strack_pool, warp) + STrack.multi_gmc(unconfirmed, warp) + + # Predict the current location with KF + STrack.multi_predict(strack_pool) + + # Associate with high score detection boxes + d_ious_dists = d_iou_distance(strack_pool, detections) + ious = 1 - iou_distance(strack_pool, detections) + ious_dists_mask = (ious < self.proximity_thresh) # o_min in ImprAssoc paper + + num_high_detections = len(detections) + + if self.with_reid: + # ConfTrack version + # emb_dists = embedding_distance(strack_pool, detections) / 2.0 + # raw_emb_dists = emb_dists.copy() + # emb_dists[emb_dists > self.appearance_thresh] = 1.0 + # emb_dists[ious_dists_mask] = 1.0 + # dists = np.minimum(ious_dists, emb_dists) + + # Popular ReID method (JDE / FairMOT) + # raw_emb_dists = matching.embedding_distance(strack_pool, detections) + # dists = matching.fuse_motion(self.kalman_filter, raw_emb_dists, strack_pool, detections) + # emb_dists = dists + + # IoU making ReID + # dists = matching.embedding_distance(strack_pool, detections) + # dists[ious_dists_mask] = 1.0 + + # Improved Association Version (CD) + emb_dists = embedding_distance(strack_pool, detections) # high dets + dists = self.lambda_*d_ious_dists + (1-self.lambda_)*emb_dists + dists[ious_dists_mask] = self.match_thresh + 0.00001 + else: + dists = d_ious_dists + dists[ious_dists_mask] = self.match_thresh + 0.00001 + + # Add in the low score detection boxes + + # association the untrack to the low score detections + if len(dets_second) > 0: + '''Detections''' + detections_second = [STrack(det, max_obs=self.max_obs) for + (det) in np.array(dets_second)] + else: + detections_second = [] + dists_second = iou_distance(strack_pool, detections_second) + 
dists_second_mask = (dists_second > self.second_match_thresh) # this is what the paper used + dists_second[dists_second_mask] = self.second_match_thresh + 0.00001 + + + B = self.match_thresh/self.second_match_thresh + + combined_dists = np.concatenate((dists, B*dists_second), axis=1) + + matches, track_conf_remain, det_remain = linear_assignment(combined_dists, thresh=self.match_thresh) + + # concat detections so that it all works + detections = np.concatenate((detections, detections_second), axis=0) + + for itracked, idet in matches: + track = strack_pool[itracked] + det = detections[idet] + if track.state == TrackState.Tracked: + track.update(detections[idet], self.frame_count) + activated_starcks.append(track) + else: + track.re_activate(det, self.frame_count, new_id=False) + refind_stracks.append(track) + + '''Deal with lost tracks''' + + # left over confirmed tracks get lost + for it in track_conf_remain: + track = strack_pool[it] + if not track.state == TrackState.Lost: + track.mark_lost() + lost_stracks.append(track) + + '''now do OAI from Improved Association paper''' + # calc the iou between every unmatched det and all tracks if the max iou + # for a det D is above overlap_thresh, discard it. 
+ sdet_remain = [detections[i] for i in det_remain] + + if self.with_reid: + # if we don't need to recompute features + if (self.new_track_thresh >= self.track_high_thresh) and features_high is not None: + features = [features_high[i] for i in det_remain if i < num_high_detections] + else: + bboxes = [track.xyxy for track in sdet_remain] + bboxes = np.array(bboxes) + # (Ndets x X) [512, 1024, 2048] + features = self.model.get_features(bboxes, img) + + unmatched_overlap = 1 - iou_distance(strack_pool, sdet_remain) + + for det_ind in range(unmatched_overlap.shape[1]): # loop over the rows + if len(unmatched_overlap[:, det_ind]) != 0: + if np.max(unmatched_overlap[:, det_ind]) < self.overlap_thresh: + # now initialize it + track = sdet_remain[det_ind] + if track.conf > self.new_track_thresh: + track.activate(self.kalman_filter, self.frame_count) + if self.with_reid: + track.update_features(features[det_ind]) + activated_starcks.append(track) + else: + # if no curr tracks, then init one + track = sdet_remain[det_ind] + if track.conf > self.new_track_thresh: + track.activate(self.kalman_filter, self.frame_count) + if self.with_reid: + track.update_features(features[det_ind]) + activated_starcks.append(track) + + + """ Step 6: Update state""" + for track in self.lost_stracks: + if self.frame_count - track.end_frame > self.max_time_lost: + track.mark_removed() + removed_stracks.append(track) + + """ Merge """ + self.active_tracks = [ + t for t in self.active_tracks if t.state == TrackState.Tracked + ] + self.active_tracks = joint_stracks(self.active_tracks, activated_starcks) + self.active_tracks = joint_stracks(self.active_tracks, refind_stracks) + self.lost_stracks = sub_stracks(self.lost_stracks, self.active_tracks) + self.lost_stracks.extend(lost_stracks) + self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks) + self.removed_stracks.extend(removed_stracks) + self.active_tracks, self.lost_stracks = remove_duplicate_stracks( + self.active_tracks, 
self.lost_stracks + ) + + output_stracks = [track for track in self.active_tracks] + outputs = [] + for t in output_stracks: + output = [] + output.extend(t.xyxy) + output.append(t.id) + output.append(t.conf) + output.append(t.cls) + output.append(t.det_ind) + outputs.append(output) + + outputs = np.asarray(outputs) + return outputs + + +def joint_stracks(tlista, tlistb): + exists = {} + res = [] + for t in tlista: + exists[t.id] = 1 + res.append(t) + for t in tlistb: + tid = t.id + if not exists.get(tid, 0): + exists[tid] = 1 + res.append(t) + return res + + +def sub_stracks(tlista, tlistb): + stracks = {} + for t in tlista: + stracks[t.id] = t + for t in tlistb: + tid = t.id + if stracks.get(tid, 0): + del stracks[tid] + return list(stracks.values()) + + +def remove_duplicate_stracks(stracksa, stracksb): + pdist = iou_distance(stracksa, stracksb) + pairs = np.where(pdist < 0.15) + dupa, dupb = list(), list() + for p, q in zip(*pairs): + timep = stracksa[p].frame_count - stracksa[p].start_frame + timeq = stracksb[q].frame_count - stracksb[q].start_frame + if timep > timeq: + dupb.append(q) + else: + dupa.append(p) + resa = [t for i, t in enumerate(stracksa) if i not in dupa] + resb = [t for i, t in enumerate(stracksb) if i not in dupb] + return resa, resb diff --git a/boxmot/trackers/ocsort/__init__.py b/boxmot/trackers/ocsort/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a4d0f1f9a635d5ccf0932bbf5c2069529a7ac --- /dev/null +++ b/boxmot/trackers/ocsort/__init__.py @@ -0,0 +1 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license diff --git a/boxmot/trackers/ocsort/ocsort.py b/boxmot/trackers/ocsort/ocsort.py new file mode 100644 index 0000000000000000000000000000000000000000..d04e3547eba118566bf5a4cd063a7701e721f06a --- /dev/null +++ b/boxmot/trackers/ocsort/ocsort.py @@ -0,0 +1,395 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +""" + This script is adopted from the SORT script by Alex Bewley alex@bewley.ai +""" 
+import numpy as np +from collections import deque + + +from boxmot.motion.kalman_filters.aabb.xysr_kf import KalmanFilterXYSR +from boxmot.utils.association import associate, linear_assignment +from boxmot.trackers.basetracker import BaseTracker +from boxmot.utils.ops import xyxy2xysr +from boxmot.motion.kalman_filters.obb.xywha_kf import KalmanBoxTrackerOBB + + +def k_previous_obs(observations, cur_age, k, is_obb=False): + if len(observations) == 0: + if is_obb: + return [-1, -1, -1, -1, -1, -1] + else : + return [-1, -1, -1, -1, -1] + for i in range(k): + dt = k - i + if cur_age - dt in observations: + return observations[cur_age - dt] + max_age = max(observations.keys()) + return observations[max_age] + + +def convert_x_to_bbox(x, score=None): + """ + Takes a bounding box in the centre form [x,y,s,r] and returns it in the form + [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right + """ + w = np.sqrt(x[2] * x[3]) + h = x[2] / w + if score is None: + return np.array( + [x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0] + ).reshape((1, 4)) + else: + return np.array( + [x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0, score] + ).reshape((1, 5)) + + +def speed_direction(bbox1, bbox2): + cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0 + cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0 + speed = np.array([cy2 - cy1, cx2 - cx1]) + norm = np.sqrt((cy2 - cy1) ** 2 + (cx2 - cx1) ** 2) + 1e-6 + return speed / norm + + +class KalmanBoxTracker(object): + """ + This class represents the internal state of individual tracked objects observed as bbox. + """ + + count = 0 + + def __init__(self, bbox, cls, det_ind, delta_t=3, max_obs=50, Q_xy_scaling = 0.01, Q_s_scaling = 0.0001): + """ + Initialises a tracker using initial bounding box. 
+ + """ + # define constant velocity model + self.det_ind = det_ind + + self.Q_xy_scaling = Q_xy_scaling + self.Q_s_scaling = Q_s_scaling + + self.kf = KalmanFilterXYSR(dim_x=7, dim_z=4, max_obs=max_obs) + self.kf.F = np.array( + [ + [1, 0, 0, 0, 1, 0, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 0, 1], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 1], + ] + ) + self.kf.H = np.array( + [ + [1, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + ] + ) + + self.kf.R[2:, 2:] *= 10.0 + self.kf.P[ + 4:, 4: + ] *= 1000.0 # give high uncertainty to the unobservable initial velocities + self.kf.P *= 10.0 + + self.kf.Q[4:6, 4:6] *= self.Q_xy_scaling + self.kf.Q[-1, -1] *= self.Q_s_scaling + + self.kf.x[:4] = xyxy2xysr(bbox) + self.time_since_update = 0 + self.id = KalmanBoxTracker.count + KalmanBoxTracker.count += 1 + self.max_obs = max_obs + self.history = deque([], maxlen=self.max_obs) + self.hits = 0 + self.hit_streak = 0 + self.age = 0 + self.conf = bbox[-1] + self.cls = cls + """ + NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same for the return of + function k_previous_obs. It is ugly and I do not like it. But to support generate observation array in a + fast and unified way, which you would see below k_observations = np.array([k_previous_obs(...]]), + let's bear it for now. + """ + self.last_observation = np.array([-1, -1, -1, -1, -1]) # placeholder + self.observations = dict() + self.history_observations = deque([], maxlen=self.max_obs) + self.velocity = None + self.delta_t = delta_t + + def update(self, bbox, cls, det_ind): + """ + Updates the state vector with observed bbox. 
+ """ + self.det_ind = det_ind + if bbox is not None: + self.conf = bbox[-1] + self.cls = cls + if self.last_observation.sum() >= 0: # no previous observation + previous_box = None + for i in range(self.delta_t): + dt = self.delta_t - i + if self.age - dt in self.observations: + previous_box = self.observations[self.age - dt] + break + if previous_box is None: + previous_box = self.last_observation + """ + Estimate the track speed direction with observations \Delta t steps away + """ + self.velocity = speed_direction(previous_box, bbox) + + """ + Insert new observations. This is a ugly way to maintain both self.observations + and self.history_observations. Bear it for the moment. + """ + self.last_observation = bbox + self.observations[self.age] = bbox + self.history_observations.append(bbox) + + self.time_since_update = 0 + self.hits += 1 + self.hit_streak += 1 + self.kf.update(xyxy2xysr(bbox)) + else: + self.kf.update(bbox) + + def predict(self): + """ + Advances the state vector and returns the predicted bounding box estimate. + """ + if (self.kf.x[6] + self.kf.x[2]) <= 0: + self.kf.x[6] *= 0.0 + + self.kf.predict() + self.age += 1 + if self.time_since_update > 0: + self.hit_streak = 0 + self.time_since_update += 1 + self.history.append(convert_x_to_bbox(self.kf.x)) + return self.history[-1] + + def get_state(self): + """ + Returns the current bounding box estimate. + """ + return convert_x_to_bbox(self.kf.x) + + +class OcSort(BaseTracker): + """ + OCSort Tracker: A tracking algorithm that utilizes motion-based tracking. + + Args: + per_class (bool, optional): Whether to perform per-class tracking. If True, tracks are maintained separately for each object class. + det_thresh (float, optional): Detection confidence threshold. Detections below this threshold are ignored in the first association step. + max_age (int, optional): Maximum number of frames to keep a track alive without any detections. 
+ min_hits (int, optional): Minimum number of hits required to confirm a track. + asso_threshold (float, optional): Threshold for the association step in data association. Controls the maximum distance allowed between tracklets and detections for a match. + delta_t (int, optional): Time delta for velocity estimation in Kalman Filter. + asso_func (str, optional): Association function to use for data association. Options include "iou" for IoU-based association. + inertia (float, optional): Weight for inertia in motion modeling. Higher values make tracks less responsive to changes. + use_byte (bool, optional): Whether to use BYTE association in the second association step. + Q_xy_scaling (float, optional): Scaling factor for the process noise covariance in the Kalman Filter for position coordinates. + Q_s_scaling (float, optional): Scaling factor for the process noise covariance in the Kalman Filter for scale coordinates. + """ + def __init__( + self, + per_class: bool = False, + min_conf: float = 0.1, + det_thresh: float = 0.2, + max_age: int = 30, + min_hits: int = 3, + asso_threshold: float = 0.3, + delta_t: int = 3, + asso_func: str = "iou", + inertia: float = 0.2, + use_byte: bool = False, + Q_xy_scaling: float = 0.01, + Q_s_scaling: float = 0.0001, + ): + super().__init__(max_age=max_age, per_class=per_class, asso_func=asso_func) + """ + Sets key parameters for SORT + """ + self.per_class = per_class + self.min_conf = min_conf + self.max_age = max_age + self.min_hits = min_hits + self.asso_threshold = asso_threshold + self.frame_count = 0 + self.det_thresh = det_thresh + self.delta_t = delta_t + self.inertia = inertia + self.use_byte = use_byte + self.Q_xy_scaling = Q_xy_scaling + self.Q_s_scaling = Q_s_scaling + KalmanBoxTracker.count = 0 + + @BaseTracker.setup_decorator + @BaseTracker.per_class_decorator + def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray: + """ + Params: + dets - a numpy array of detections in the 
format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...] + Requires: this method must be called once for each frame even with empty detections + (use np.empty((0, 5)) for frames without detections). + Returns the a similar array, where the last column is the object ID. + NOTE: The number of objects returned may differ from the number of detections provided. + """ + + self.check_inputs(dets, img) + + self.frame_count += 1 + h, w = img.shape[0:2] + + dets = np.hstack([dets, np.arange(len(dets)).reshape(-1, 1)]) + confs = dets[:, 4+self.is_obb] + + inds_low = confs > self.min_conf + inds_high = confs < self.det_thresh + inds_second = np.logical_and( + inds_low, inds_high + ) # self.det_thresh > score > 0.1, for second matching + dets_second = dets[inds_second] # detections for second matching + remain_inds = confs > self.det_thresh + dets = dets[remain_inds] + + # get predicted locations from existing trackers. + trks = np.zeros((len(self.active_tracks), 5+self.is_obb)) + to_del = [] + ret = [] + for t, trk in enumerate(trks): + pos = self.active_tracks[t].predict()[0] + trk[:] = [pos[i] for i in range(4+self.is_obb)] + [0] + if np.any(np.isnan(pos)): + to_del.append(t) + trks = np.ma.compress_rows(np.ma.masked_invalid(trks)) + for t in reversed(to_del): + self.active_tracks.pop(t) + + velocities = np.array( + [ + trk.velocity if trk.velocity is not None else np.array((0, 0)) + for trk in self.active_tracks + ] + ) + last_boxes = np.array([trk.last_observation for trk in self.active_tracks]) + + k_observations = np.array( + [ + k_previous_obs(trk.observations, trk.age, self.delta_t, is_obb=self.is_obb) + for trk in self.active_tracks + ] + ) + + """ + First round of association + """ + matched, unmatched_dets, unmatched_trks = associate( + dets[:, 0:5+self.is_obb], trks, self.asso_func, self.asso_threshold, velocities, k_observations, self.inertia, w, h + ) + for m in matched: + self.active_tracks[m[1]].update(dets[m[0], :-2], dets[m[0], -2], dets[m[0], -1]) + + """ + 
Second round of associaton by OCR + """ + # BYTE association + if self.use_byte and len(dets_second) > 0 and unmatched_trks.shape[0] > 0: + u_trks = trks[unmatched_trks] + iou_left = self.asso_func( + dets_second, u_trks + ) # iou between low score detections and unmatched tracks + iou_left = np.array(iou_left) + if iou_left.max() > self.asso_threshold: + """ + NOTE: by using a lower threshold, e.g., self.asso_threshold - 0.1, you may + get a higher performance especially on MOT17/MOT20 datasets. But we keep it + uniform here for simplicity + """ + matched_indices = linear_assignment(-iou_left) + to_remove_trk_indices = [] + for m in matched_indices: + det_ind, trk_ind = m[0], unmatched_trks[m[1]] + if iou_left[m[0], m[1]] < self.asso_threshold: + continue + self.active_tracks[trk_ind].update( + dets_second[det_ind, :-2], dets_second[det_ind, -2], dets_second[det_ind, -1] + ) + to_remove_trk_indices.append(trk_ind) + unmatched_trks = np.setdiff1d( + unmatched_trks, np.array(to_remove_trk_indices) + ) + + if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0: + left_dets = dets[unmatched_dets] + left_trks = last_boxes[unmatched_trks] + iou_left = self.asso_func(left_dets, left_trks) + iou_left = np.array(iou_left) + if iou_left.max() > self.asso_threshold: + """ + NOTE: by using a lower threshold, e.g., self.asso_threshold - 0.1, you may + get a higher performance especially on MOT17/MOT20 datasets. 
But we keep it + uniform here for simplicity + """ + rematched_indices = linear_assignment(-iou_left) + to_remove_det_indices = [] + to_remove_trk_indices = [] + for m in rematched_indices: + det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]] + if iou_left[m[0], m[1]] < self.asso_threshold: + continue + self.active_tracks[trk_ind].update(dets[det_ind, :-2], dets[det_ind, -2], dets[det_ind, -1]) + to_remove_det_indices.append(det_ind) + to_remove_trk_indices.append(trk_ind) + unmatched_dets = np.setdiff1d( + unmatched_dets, np.array(to_remove_det_indices) + ) + unmatched_trks = np.setdiff1d( + unmatched_trks, np.array(to_remove_trk_indices) + ) + + for m in unmatched_trks: + self.active_tracks[m].update(None, None, None) + + # create and initialise new trackers for unmatched detections + for i in unmatched_dets: + if self.is_obb: + trk = KalmanBoxTrackerOBB(dets[i, :-2], dets[i, -2], dets[i, -1], delta_t=self.delta_t, Q_xy_scaling=self.Q_xy_scaling, Q_a_scaling=self.Q_s_scaling, max_obs=self.max_obs) + else: + trk = KalmanBoxTracker(dets[i, :5], dets[i, 5], dets[i, 6], delta_t=self.delta_t, Q_xy_scaling=self.Q_xy_scaling, Q_s_scaling=self.Q_s_scaling, max_obs=self.max_obs) + self.active_tracks.append(trk) + i = len(self.active_tracks) + for trk in reversed(self.active_tracks): + if trk.last_observation.sum() < 0: + d = trk.get_state()[0] + else: + """ + this is optional to use the recent observation or the kalman filter prediction, + we didn't notice significant difference here + """ + d = trk.last_observation[:4+self.is_obb] + if (trk.time_since_update < 1) and ( + trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits + ): + # +1 as MOT benchmark requires positive + ret.append( + np.concatenate((d, [trk.id + 1], [trk.conf], [trk.cls], [trk.det_ind])).reshape( + 1, -1 + ) + ) + i -= 1 + # remove dead tracklet + if trk.time_since_update > self.max_age: + self.active_tracks.pop(i) + if len(ret) > 0: + return np.concatenate(ret) + return 
np.array([]) \ No newline at end of file diff --git a/boxmot/trackers/strongsort/__init__.py b/boxmot/trackers/strongsort/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a4d0f1f9a635d5ccf0932bbf5c2069529a7ac --- /dev/null +++ b/boxmot/trackers/strongsort/__init__.py @@ -0,0 +1 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license diff --git a/boxmot/trackers/strongsort/sort/__init__.py b/boxmot/trackers/strongsort/sort/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a4d0f1f9a635d5ccf0932bbf5c2069529a7ac --- /dev/null +++ b/boxmot/trackers/strongsort/sort/__init__.py @@ -0,0 +1 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license diff --git a/boxmot/trackers/strongsort/sort/detection.py b/boxmot/trackers/strongsort/sort/detection.py new file mode 100644 index 0000000000000000000000000000000000000000..5f94b5330a5d1b32c82896cc71b09360dd0c0134 --- /dev/null +++ b/boxmot/trackers/strongsort/sort/detection.py @@ -0,0 +1,41 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +class Detection(object): + """ + This class represents a bounding box detection in a single image. + + Parameters + ---------- + tlwh : array_like + Bounding box in format `(x, y, w, h)`. + confidence : float + Detector confidence score. + feature : array_like + A feature vector that describes the object contained in this image. + + Attributes + ---------- + tlwh : ndarray + Bounding box in format `(top left x, top left y, width, height)`. + confidence : ndarray + Detector confidence score. + feature : ndarray | NoneType + A feature vector that describes the object contained in this image. + + """ + + def __init__(self, tlwh, conf, cls, det_ind, feat): + self.tlwh = tlwh + self.conf = conf + self.cls = cls + self.det_ind = det_ind + self.feat = feat + + def to_xyah(self): + """Convert bounding box to format `(center x, center y, aspect ratio, + height)`, where the aspect ratio is `width / height`. 
+ """ + ret = self.tlwh.copy() + ret[:2] += ret[2:] / 2 + ret[2] /= ret[3] + return ret diff --git a/boxmot/trackers/strongsort/sort/iou_matching.py b/boxmot/trackers/strongsort/sort/iou_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..3483ce47248cf334e02760143a15a8c812b597b0 --- /dev/null +++ b/boxmot/trackers/strongsort/sort/iou_matching.py @@ -0,0 +1,87 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +from __future__ import absolute_import + +import numpy as np + +from . import linear_assignment + + +def iou(bbox, candidates): + """Computer intersection over union. + + Parameters + ---------- + bbox : ndarray + A bounding box in format `(top left x, top left y, width, height)`. + candidates : ndarray + A matrix of candidate bounding boxes (one per row) in the same format + as `bbox`. + + Returns + ------- + ndarray + The intersection over union in [0, 1] between the `bbox` and each + candidate. A higher score means a larger fraction of the `bbox` is + occluded by the candidate. + + """ + bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:] + candidates_tl = candidates[:, :2] + candidates_br = candidates[:, :2] + candidates[:, 2:] + + tl = np.c_[ + np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis], + np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis], + ] + br = np.c_[ + np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis], + np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis], + ] + wh = np.maximum(0.0, br - tl) + + area_intersection = wh.prod(axis=1) + area_bbox = bbox[2:].prod() + area_candidates = candidates[:, 2:].prod(axis=1) + return area_intersection / (area_bbox + area_candidates - area_intersection) + + +def iou_cost(tracks, detections, track_indices=None, detection_indices=None): + """An intersection over union distance metric. + + Parameters + ---------- + tracks : List[deep_sort.track.Track] + A list of tracks. 
+ detections : List[deep_sort.detection.Detection] + A list of detections. + track_indices : Optional[List[int]] + A list of indices to tracks that should be matched. Defaults to + all `tracks`. + detection_indices : Optional[List[int]] + A list of indices to detections that should be matched. Defaults + to all `detections`. + + Returns + ------- + ndarray + Returns a cost matrix of shape + len(track_indices), len(detection_indices) where entry (i, j) is + `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`. + + """ + if track_indices is None: + track_indices = np.arange(len(tracks)) + if detection_indices is None: + detection_indices = np.arange(len(detections)) + + cost_matrix = np.zeros((len(track_indices), len(detection_indices))) + for row, track_idx in enumerate(track_indices): + if tracks[track_idx].time_since_update > 1: + cost_matrix[row, :] = linear_assignment.INFTY_COST + continue + + bbox = tracks[track_idx].to_tlwh() + candidates = np.asarray([detections[i].tlwh for i in detection_indices]) + cost_matrix[row, :] = 1.0 - iou(bbox, candidates) + return cost_matrix diff --git a/boxmot/trackers/strongsort/sort/linear_assignment.py b/boxmot/trackers/strongsort/sort/linear_assignment.py new file mode 100644 index 0000000000000000000000000000000000000000..c27c4c303682f8209bc09c5703000d64b4cc5322 --- /dev/null +++ b/boxmot/trackers/strongsort/sort/linear_assignment.py @@ -0,0 +1,200 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +from __future__ import absolute_import + +import numpy as np +from scipy.optimize import linear_sum_assignment + +from boxmot.utils.matching import chi2inv95 + +INFTY_COST = 1e5 + + +def min_cost_matching( + distance_metric, + max_distance, + tracks, + detections, + track_indices=None, + detection_indices=None, +): + """Solve linear assignment problem. 
+ Parameters + ---------- + distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray + The distance metric is given a list of tracks and detections as well as + a list of N track indices and M detection indices. The metric should + return the NxM dimensional cost matrix, where element (i, j) is the + association cost between the i-th track in the given track indices and + the j-th detection in the given detection_indices. + max_distance : float + Gating threshold. Associations with cost larger than this value are + disregarded. + tracks : List[track.Track] + A list of predicted tracks at the current time step. + detections : List[detection.Detection] + A list of detections at the current time step. + track_indices : List[int] + List of track indices that maps rows in `cost_matrix` to tracks in + `tracks` (see description above). + detection_indices : List[int] + List of detection indices that maps columns in `cost_matrix` to + detections in `detections` (see description above). + Returns + ------- + (List[(int, int)], List[int], List[int]) + Returns a tuple with the following three entries: + * A list of matched track and detection indices. + * A list of unmatched track indices. + * A list of unmatched detection indices. + """ + if track_indices is None: + track_indices = np.arange(len(tracks)) + if detection_indices is None: + detection_indices = np.arange(len(detections)) + + if len(detection_indices) == 0 or len(track_indices) == 0: + return [], track_indices, detection_indices # Nothing to match. 
+ + cost_matrix = distance_metric(tracks, detections, track_indices, detection_indices) + cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5 + row_indices, col_indices = linear_sum_assignment(cost_matrix) + + matches, unmatched_tracks, unmatched_detections = [], [], [] + for col, detection_idx in enumerate(detection_indices): + if col not in col_indices: + unmatched_detections.append(detection_idx) + for row, track_idx in enumerate(track_indices): + if row not in row_indices: + unmatched_tracks.append(track_idx) + for row, col in zip(row_indices, col_indices): + track_idx = track_indices[row] + detection_idx = detection_indices[col] + if cost_matrix[row, col] > max_distance: + unmatched_tracks.append(track_idx) + unmatched_detections.append(detection_idx) + else: + matches.append((track_idx, detection_idx)) + return matches, unmatched_tracks, unmatched_detections + + +def matching_cascade( + distance_metric, + max_distance, + cascade_depth, + tracks, + detections, + track_indices=None, + detection_indices=None, +): + """Run matching cascade. + Parameters + ---------- + distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray + The distance metric is given a list of tracks and detections as well as + a list of N track indices and M detection indices. The metric should + return the NxM dimensional cost matrix, where element (i, j) is the + association cost between the i-th track in the given track indices and + the j-th detection in the given detection indices. + max_distance : float + Gating threshold. Associations with cost larger than this value are + disregarded. + cascade_depth: int + The cascade depth, should be se to the maximum track age. + tracks : List[track.Track] + A list of predicted tracks at the current time step. + detections : List[detection.Detection] + A list of detections at the current time step. 
+ track_indices : Optional[List[int]] + List of track indices that maps rows in `cost_matrix` to tracks in + `tracks` (see description above). Defaults to all tracks. + detection_indices : Optional[List[int]] + List of detection indices that maps columns in `cost_matrix` to + detections in `detections` (see description above). Defaults to all + detections. + Returns + ------- + (List[(int, int)], List[int], List[int]) + Returns a tuple with the following three entries: + * A list of matched track and detection indices. + * A list of unmatched track indices. + * A list of unmatched detection indices. + """ + if track_indices is None: + track_indices = list(range(len(tracks))) + if detection_indices is None: + detection_indices = list(range(len(detections))) + + unmatched_detections = detection_indices + matches = [] + track_indices_l = [k for k in track_indices] + matches_l, _, unmatched_detections = min_cost_matching( + distance_metric, + max_distance, + tracks, + detections, + track_indices_l, + unmatched_detections, + ) + matches += matches_l + unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches)) + return matches, unmatched_tracks, unmatched_detections + + +def gate_cost_matrix( + cost_matrix, + tracks, + detections, + track_indices, + detection_indices, + mc_lambda, + gated_cost=INFTY_COST, + only_position=False, +): + """Invalidate infeasible entries in cost matrix based on the state + distributions obtained by Kalman filtering. + Parameters + ---------- + kf : The Kalman filter. + cost_matrix : ndarray + The NxM dimensional cost matrix, where N is the number of track indices + and M is the number of detection indices, such that entry (i, j) is the + association cost between `tracks[track_indices[i]]` and + `detections[detection_indices[j]]`. + tracks : List[track.Track] + A list of predicted tracks at the current time step. + detections : List[detection.Detection] + A list of detections at the current time step. 
+ track_indices : List[int] + List of track indices that maps rows in `cost_matrix` to tracks in + `tracks` (see description above). + detection_indices : List[int] + List of detection indices that maps columns in `cost_matrix` to + detections in `detections` (see description above). + gated_cost : Optional[float] + Entries in the cost matrix corresponding to infeasible associations are + set this value. Defaults to a very large value. + only_position : Optional[bool] + If True, only the x, y position of the state distribution is considered + during gating. Defaults to False. + Returns + ------- + ndarray + Returns the modified cost matrix. + """ + + gating_threshold = chi2inv95[4] + measurements = np.asarray([detections[i].to_xyah() for i in detection_indices]) + for row, track_idx in enumerate(track_indices): + track = tracks[track_idx] + gating_distance = track.kf.gating_distance( + track.mean, + track.covariance, + measurements, + only_position + ) + cost_matrix[row, gating_distance > gating_threshold] = gated_cost + cost_matrix[row] = ( + mc_lambda * cost_matrix[row] + (1 - mc_lambda) * gating_distance + ) + return cost_matrix diff --git a/boxmot/trackers/strongsort/sort/track.py b/boxmot/trackers/strongsort/sort/track.py new file mode 100644 index 0000000000000000000000000000000000000000..c33367819c0b1b8c8c6f0c43e7fe22b494a2034a --- /dev/null +++ b/boxmot/trackers/strongsort/sort/track.py @@ -0,0 +1,200 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import os +import numpy as np + +from boxmot.motion.kalman_filters.aabb.xyah_kf import KalmanFilterXYAH + + +class TrackState: + """ + Enumeration type for the single target track state. Newly created tracks are + classified as `tentative` until enough evidence has been collected. Then, + the track state is changed to `confirmed`. Tracks that are no longer alive + are classified as `deleted` to mark them for removal from the set of active + tracks. 
+ + """ + + Tentative = 1 + Confirmed = 2 + Deleted = 3 + + +class Track: + """ + A single target track with state space `(x, y, a, h)` and associated + velocities, where `(x, y)` is the center of the bounding box, `a` is the + aspect ratio and `h` is the height. + + Parameters + ---------- + mean : ndarray + Mean vector of the initial state distribution. + covariance : ndarray + Covariance matrix of the initial state distribution. + track_id : int + A unique track identifier. + n_init : int + Number of consecutive detections before the track is confirmed. The + track state is set to `Deleted` if a miss occurs within the first + `n_init` frames. + max_age : int + The maximum number of consecutive misses before the track state is + set to `Deleted`. + feature : Optional[ndarray] + Feature vector of the detection this track originates from. If not None, + this feature is added to the `features` cache. + + Attributes + ---------- + mean : ndarray + Mean vector of the initial state distribution. + covariance : ndarray + Covariance matrix of the initial state distribution. + track_id : int + A unique track identifier. + hits : int + Total number of measurement updates. + age : int + Total number of frames since first occurance. + time_since_update : int + Total number of frames since last measurement update. + state : TrackState + The current track state. + features : List[ndarray] + A cache of features. On each measurement update, the associated feature + vector is added to this list. 
+ + """ + + def __init__( + self, + detection, + id, + n_init, + max_age, + ema_alpha, + ): + self.id = id + self.bbox = detection.to_xyah() + self.conf = detection.conf + self.cls = detection.cls + self.det_ind = detection.det_ind + self.hits = 1 + self.age = 1 + self.time_since_update = 0 + self.ema_alpha = ema_alpha + + # start with confirmed in Ci as test expect equal amount of outputs as inputs + self.state = TrackState.Confirmed if (os.getenv('GITHUB_ACTIONS') == 'true' and os.getenv('GITHUB_JOB') != 'mot-metrics-benchmark') else TrackState.Tentative + self.features = [] + if detection.feat is not None: + detection.feat /= np.linalg.norm(detection.feat) + self.features.append(detection.feat) + + self._n_init = n_init + self._max_age = max_age + + self.kf = KalmanFilterXYAH() + self.mean, self.covariance = self.kf.initiate(self.bbox) + + def to_tlwh(self): + """Get current position in bounding box format `(top left x, top left y, + width, height)`. + + Returns + ------- + ndarray + The bounding box. + + """ + ret = self.mean[:4].copy() + ret[2] *= ret[3] + ret[:2] -= ret[2:] / 2 + return ret + + def to_tlbr(self): + """Get kf estimated current position in bounding box format `(min x, miny, max x, + max y)`. + + Returns + ------- + ndarray + The predicted kf bounding box. + + """ + ret = self.to_tlwh() + ret[2:] = ret[:2] + ret[2:] + return ret + + def camera_update(self, warp_matrix): + [a, b] = warp_matrix + warp_matrix = np.array([a, b, [0, 0, 1]]) + warp_matrix = warp_matrix.tolist() + x1, y1, x2, y2 = self.to_tlbr() + x1_, y1_, _ = warp_matrix @ np.array([x1, y1, 1]).T + x2_, y2_, _ = warp_matrix @ np.array([x2, y2, 1]).T + w, h = x2_ - x1_, y2_ - y1_ + cx, cy = x1_ + w / 2, y1_ + h / 2 + self.mean[:4] = [cx, cy, w / h, h] + + def increment_age(self): + self.age += 1 + self.time_since_update += 1 + + def predict(self): + """Propagate the state distribution to the current time step using a + Kalman filter prediction step. 
+ """ + self.mean, self.covariance = self.kf.predict(self.mean, self.covariance) + self.age += 1 + self.time_since_update += 1 + + def update(self, detection): + """Perform Kalman filter measurement update step and update the feature + cache. + Parameters + ---------- + detection : Detection + The associated detection. + """ + self.bbox = detection.to_xyah() + self.conf = detection.conf + self.cls = detection.cls + self.det_ind = detection.det_ind + self.mean, self.covariance = self.kf.update( + self.mean, self.covariance, self.bbox, self.conf + ) + + feature = detection.feat / np.linalg.norm(detection.feat) + + smooth_feat = ( + self.ema_alpha * self.features[-1] + (1 - self.ema_alpha) * feature + ) + smooth_feat /= np.linalg.norm(smooth_feat) + self.features = [smooth_feat] + + self.hits += 1 + self.time_since_update = 0 + if self.state == TrackState.Tentative and self.hits >= self._n_init: + self.state = TrackState.Confirmed + + def mark_missed(self): + """Mark this track as missed (no association at the current time step).""" + if self.state == TrackState.Tentative: + self.state = TrackState.Deleted + elif self.time_since_update > self._max_age: + self.state = TrackState.Deleted + + def is_tentative(self): + """Returns True if this track is tentative (unconfirmed).""" + return self.state == TrackState.Tentative + + def is_confirmed(self): + """Returns True if this track is confirmed.""" + return self.state == TrackState.Confirmed + + def is_deleted(self): + """Returns True if this track is dead and should be deleted.""" + return self.state == TrackState.Deleted diff --git a/boxmot/trackers/strongsort/sort/tracker.py b/boxmot/trackers/strongsort/sort/tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..c5721cda97fd9d2cec81ed5eb81438e6a0b7a2d4 --- /dev/null +++ b/boxmot/trackers/strongsort/sort/tracker.py @@ -0,0 +1,169 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +from __future__ import absolute_import + +import numpy as 
class Tracker:
    """
    This is the multi-target tracker.

    Parameters
    ----------
    metric : nn_matching.NearestNeighborDistanceMetric
        A distance metric for measurement-to-track association.
    max_iou_dist : float
        Gating threshold for the IoU association stage.
    max_age : int
        Maximum number of consecutive misses before a track is deleted.
    n_init : int
        Number of consecutive detections before the track is confirmed. The
        track state is set to `Deleted` if a miss occurs within the first
        `n_init` frames.

    Attributes
    ----------
    metric : nn_matching.NearestNeighborDistanceMetric
        The distance metric used for measurement to track association.
    max_age : int
        Maximum number of consecutive misses before a track is deleted.
    n_init : int
        Number of frames that a track remains in initialization phase.
    tracks : List[Track]
        The list of active tracks at the current time step.
    """

    # 0.95 chi-square gate for a 4-DOF (x, y, a, h) measurement.
    # NOTE(review): not referenced in this class's visible methods — the
    # actual gating happens in linear_assignment.gate_cost_matrix.
    GATING_THRESHOLD = np.sqrt(chi2inv95[4])

    def __init__(
        self,
        metric,
        max_iou_dist=0.9,
        max_age=30,
        n_init=3,
        _lambda=0,
        ema_alpha=0.9,
        mc_lambda=0.995,
    ):
        self.metric = metric
        self.max_iou_dist = max_iou_dist
        self.max_age = max_age
        self.n_init = n_init
        self._lambda = _lambda  # NOTE(review): stored but unused in visible code
        self.ema_alpha = ema_alpha
        self.mc_lambda = mc_lambda

        self.tracks = []
        self._next_id = 1
        self.cmc = get_cmc_method('ecc')()

    def predict(self):
        """Propagate track state distributions one time step forward.

        This function should be called once every time step, before `update`.
        """
        for track in self.tracks:
            track.predict()

    def increment_ages(self):
        # Age all tracks without a KF predict (e.g. on frames with no
        # detections) and immediately apply the miss bookkeeping.
        for track in self.tracks:
            track.increment_age()
            track.mark_missed()

    def update(self, detections):
        """Perform measurement update and track management.

        Parameters
        ----------
        detections : List[deep_sort.detection.Detection]
            A list of detections at the current time step.

        """
        # Run matching cascade.
        matches, unmatched_tracks, unmatched_detections = self._match(detections)

        # Update track set: matched tracks get a KF update, unmatched tracks
        # are marked missed, unmatched detections spawn new tracks.
        for track_idx, detection_idx in matches:
            self.tracks[track_idx].update(detections[detection_idx])
        for track_idx in unmatched_tracks:
            self.tracks[track_idx].mark_missed()
        for detection_idx in unmatched_detections:
            self._initiate_track(detections[detection_idx])
        self.tracks = [t for t in self.tracks if not t.is_deleted()]

        # Update distance metric with the features of confirmed tracks only.
        active_targets = [t.id for t in self.tracks if t.is_confirmed()]
        features, targets = [], []
        for track in self.tracks:
            if not track.is_confirmed():
                continue
            features += track.features
            targets += [track.id for _ in track.features]
        self.metric.partial_fit(
            np.asarray(features), np.asarray(targets), active_targets
        )

    def _match(self, detections):
        def gated_metric(tracks, dets, track_indices, detection_indices):
            # Appearance distance, then gated/blended with Mahalanobis
            # motion distance (weight mc_lambda).
            features = np.array([dets[i].feat for i in detection_indices])
            targets = np.array([tracks[i].id for i in track_indices])
            cost_matrix = self.metric.distance(features, targets)
            cost_matrix = linear_assignment.gate_cost_matrix(
                cost_matrix,
                tracks,
                dets,
                track_indices,
                detection_indices,
                self.mc_lambda,
            )

            return cost_matrix

        # Split track set into confirmed and unconfirmed tracks.
        confirmed_tracks = [i for i, t in enumerate(self.tracks) if t.is_confirmed()]
        unconfirmed_tracks = [i for i, t in enumerate(self.tracks) if not t.is_confirmed()]

        # Associate confirmed tracks using appearance features.
        matches_a, unmatched_tracks_a, unmatched_detections = linear_assignment.matching_cascade(
            gated_metric,
            self.metric.matching_threshold,
            self.max_age,
            self.tracks,
            detections,
            confirmed_tracks,
        )

        # Associate remaining tracks together with unconfirmed tracks using IOU.
        # Only tracks missed for exactly one frame get this second chance;
        # older unmatched tracks stay unmatched.
        iou_track_candidates = unconfirmed_tracks + [
            k for k in unmatched_tracks_a if self.tracks[k].time_since_update == 1
        ]
        unmatched_tracks_a = [
            k for k in unmatched_tracks_a if self.tracks[k].time_since_update != 1
        ]

        matches_b, unmatched_tracks_b, unmatched_detections = linear_assignment.min_cost_matching(
            iou_matching.iou_cost,
            self.max_iou_dist,
            self.tracks,
            detections,
            iou_track_candidates,
            unmatched_detections,
        )

        matches = matches_a + matches_b
        unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b))
        return matches, unmatched_tracks, unmatched_detections

    def _initiate_track(self, detection):
        # Spawn a new Track with a fresh monotonically increasing id.
        self.tracks.append(
            Track(
                detection,
                self._next_id,
                self.n_init,
                self.max_age,
                self.ema_alpha,
            )
        )
        self._next_id += 1
class StrongSort(object):
    """
    StrongSORT Tracker: A tracking algorithm that utilizes a combination of
    appearance (ReID) and motion-based cues.

    Args:
        reid_weights (Path): Path to the model weights for ReID (Re-Identification).
        device (torch.device): Device on which to run the ReID model (e.g. 'cpu' or 'cuda').
        half (bool): Whether to use half-precision (fp16) for faster ReID inference.
        per_class (bool, optional): Whether to perform per-class tracking. If True,
            tracks are maintained separately for each object class.
        min_conf (float, optional): Detections below this confidence are discarded.
        max_cos_dist (float, optional): Maximum cosine distance for ReID feature
            matching in the Nearest Neighbor Distance Metric.
        max_iou_dist (float, optional): Maximum IoU distance for data association.
        max_age (int, optional): Maximum number of frames to keep a track alive
            without any detections.
        n_init (int, optional): Number of consecutive frames required to confirm a track.
        nn_budget (int, optional): Maximum size of the appearance feature library.
            If the library size exceeds this value, the oldest features are removed.
        mc_lambda (float, optional): Weight for motion consistency in the cost mix.
        ema_alpha (float, optional): Alpha for the exponential moving average update
            of appearance features.
    """
    def __init__(
        self,
        reid_weights: Path,
        device: device,
        half: bool,
        per_class: bool = False,
        min_conf: float = 0.1,
        max_cos_dist=0.2,
        max_iou_dist=0.7,
        max_age=30,
        n_init=3,
        nn_budget=100,
        mc_lambda=0.98,
        ema_alpha=0.9,
    ):

        self.per_class = per_class
        self.min_conf = min_conf
        # ReID feature extractor backend (ONNX/torch/... resolved internally).
        self.model = ReidAutoBackend(
            weights=reid_weights, device=device, half=half
        ).model

        self.tracker = Tracker(
            metric=NearestNeighborDistanceMetric("cosine", max_cos_dist, nn_budget),
            max_iou_dist=max_iou_dist,
            max_age=max_age,
            n_init=n_init,
            mc_lambda=mc_lambda,
            ema_alpha=ema_alpha,
        )
        # ECC-based camera motion compensation.
        self.cmc = get_cmc_method('ecc')()

    @BaseTracker.per_class_decorator
    def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray:
        """Run one tracking step.

        Args:
            dets: (N, 6) array of detections [x1, y1, x2, y2, conf, cls].
            img: the current frame (used for CMC and ReID feature extraction).
            embs: optional precomputed (N, D) appearance embeddings; when
                given, the ReID model is skipped.

        Returns:
            (M, 8) array [x1, y1, x2, y2, id, conf, cls, det_ind] for tracks
            matched in this frame, or an empty array when there are none.
        """
        assert isinstance(
            dets, np.ndarray
        ), f"Unsupported 'dets' input format '{type(dets)}', valid format is np.ndarray"
        assert isinstance(
            img, np.ndarray
        ), f"Unsupported 'img' input format '{type(img)}', valid format is np.ndarray"
        assert (
            len(dets.shape) == 2
        ), "Unsupported 'dets' dimensions, valid number of dimensions is two"
        assert (
            dets.shape[1] == 6
        ), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6"

        # Append the original detection index as column 6, then drop
        # low-confidence detections.
        dets = np.hstack([dets, np.arange(len(dets)).reshape(-1, 1)])
        remain_inds = dets[:, 4] >= self.min_conf
        dets = dets[remain_inds]

        xyxy = dets[:, 0:4]
        confs = dets[:, 4]
        clss = dets[:, 5]
        det_ind = dets[:, 6]

        # Compensate existing tracks for camera motion before association.
        if len(self.tracker.tracks) >= 1:
            warp_matrix = self.cmc.apply(img, xyxy)
            for track in self.tracker.tracks:
                track.camera_update(warp_matrix)

        # extract appearance information for each detection
        if embs is not None:
            features = embs[remain_inds]
        else:
            features = self.model.get_features(xyxy, img)

        tlwh = xyxy2tlwh(xyxy)
        detections = [
            Detection(box, conf, cls, det_ind, feat) for
            box, conf, cls, det_ind, feat in
            zip(tlwh, confs, clss, det_ind, features)
        ]

        # update tracker
        self.tracker.predict()
        self.tracker.update(detections)

        # output bbox identities: only confirmed tracks that were matched in
        # this frame (update() resets time_since_update to 0) are emitted.
        outputs = []
        for track in self.tracker.tracks:
            if not track.is_confirmed() or track.time_since_update >= 1:
                continue

            x1, y1, x2, y2 = track.to_tlbr()

            id = track.id
            conf = track.conf
            cls = track.cls
            det_ind = track.det_ind

            outputs.append(
                np.concatenate(([x1, y1, x2, y2], [id], [conf], [cls], [det_ind])).reshape(1, -1)
            )
        if len(outputs) > 0:
            return np.concatenate(outputs)
        # NOTE(review): empty result has shape (0,), not (0, 8) — callers
        # must guard before column-indexing.
        return np.array([])
class KalmanFilter(object):
    """
    A simple Kalman filter for tracking bounding boxes in image space.

    The 8-dimensional state space

        x, y, a, h, vx, vy, va, vh

    contains the bounding box center position (x, y), aspect ratio a, height h,
    and their respective velocities.

    Object motion follows a constant velocity model. The bounding box location
    (x, y, a, h) is taken as direct observation of the state space (linear
    observation model).

    """

    def __init__(self):
        ndim, dt = 4, 1.

        # Constant-velocity transition matrix: position block gets +dt *
        # velocity each step; everything else is identity.
        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
        for i in range(ndim):
            self._motion_mat[i, ndim + i] = dt

        # Observation matrix: measurements observe (x, y, a, h) directly.
        self._update_mat = np.eye(ndim, 2 * ndim)

        # Motion and observation uncertainty are chosen relative to the current
        # state estimate. These weights control the amount of uncertainty in
        # the model. This is a bit hacky.
        self._std_weight_position = 1. / 20
        self._std_weight_velocity = 1. / 160

    def initiate(self, measurement):
        """Create track from unassociated measurement.

        Parameters
        ----------
        measurement : ndarray
            Bounding box coordinates (x, y, a, h) with center position (x, y),
            aspect ratio a, and height h.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector (8 dimensional) and covariance matrix (8x8
            dimensional) of the new track. Unobserved velocities are initialized
            to 0 mean.

        """
        mean_pos = measurement
        mean_vel = np.zeros_like(mean_pos)
        mean = np.r_[mean_pos, mean_vel]

        # Initial std deviations scale with box height h (measurement[3]);
        # the aspect-ratio components use small fixed constants.
        std = [
            2 * self._std_weight_position * measurement[3],
            2 * self._std_weight_position * measurement[3],
            1e-2,
            2 * self._std_weight_position * measurement[3],
            10 * self._std_weight_velocity * measurement[3],
            10 * self._std_weight_velocity * measurement[3],
            1e-5,
            10 * self._std_weight_velocity * measurement[3]]
        covariance = np.diag(np.square(std))
        return mean, covariance

    def predict(self, mean, covariance):
        """Run Kalman filter prediction step.

        Parameters
        ----------
        mean : ndarray
            The 8 dimensional mean vector of the object state at the previous
            time step.
        covariance : ndarray
            The 8x8 dimensional covariance matrix of the object state at the
            previous time step.

        Returns
        -------
        (ndarray, ndarray)
            Returns the mean vector and covariance matrix of the predicted
            state. Unobserved velocities are initialized to 0 mean.

        """
        # Process noise also scales with the current height estimate mean[3].
        std_pos = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-2,
            self._std_weight_position * mean[3]]
        std_vel = [
            self._std_weight_velocity * mean[3],
            self._std_weight_velocity * mean[3],
            1e-5,
            self._std_weight_velocity * mean[3]]
        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))

        mean = np.dot(self._motion_mat, mean)
        covariance = np.linalg.multi_dot((
            self._motion_mat, covariance, self._motion_mat.T)) + motion_cov

        return mean, covariance

    def project(self, mean, covariance, confidence=.0):
        """Project state distribution to measurement space.

        Parameters
        ----------
        mean : ndarray
            The state's mean vector (8 dimensional array).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).
        confidence : float
            Detection confidence score; measurement noise is scaled by
            (1 - confidence), so high-confidence detections are trusted more
            (NSA-Kalman style noise adaptation).

        Returns
        -------
        (ndarray, ndarray)
            Returns the projected mean and covariance matrix of the given state
            estimate.

        """
        std = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-1,
            self._std_weight_position * mean[3]]

        # Scale measurement noise down as confidence approaches 1.
        std = [(1 - confidence) * x for x in std]

        innovation_cov = np.diag(np.square(std))

        mean = np.dot(self._update_mat, mean)
        covariance = np.linalg.multi_dot((
            self._update_mat, covariance, self._update_mat.T))
        return mean, covariance + innovation_cov

    def update(self, mean, covariance, measurement, confidence=.0):
        """Run Kalman filter correction step.

        Parameters
        ----------
        mean : ndarray
            The predicted state's mean vector (8 dimensional).
        covariance : ndarray
            The state's covariance matrix (8x8 dimensional).
        measurement : ndarray
            The 4 dimensional measurement vector (x, y, a, h), where (x, y)
            is the center position, a the aspect ratio, and h the height of the
            bounding box.
        confidence : float
            Detection confidence score, forwarded to `project` to scale the
            measurement noise.

        Returns
        -------
        (ndarray, ndarray)
            Returns the measurement-corrected state distribution.

        """
        projected_mean, projected_cov = self.project(mean, covariance, confidence)

        # Kalman gain via Cholesky solve (cheaper and more stable than an
        # explicit inverse of the projected covariance).
        chol_factor, lower = scipy.linalg.cho_factor(
            projected_cov, lower=True, check_finite=False)
        kalman_gain = scipy.linalg.cho_solve(
            (chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
            check_finite=False).T
        innovation = measurement - projected_mean

        new_mean = mean + np.dot(innovation, kalman_gain.T)
        new_covariance = covariance - np.linalg.multi_dot((
            kalman_gain, projected_cov, kalman_gain.T))
        return new_mean, new_covariance

    def gating_distance(self, mean, covariance, measurements,
                        only_position=False):
        """Compute gating distance between state distribution and measurements.

        A suitable distance threshold can be obtained from `chi2inv95`. If
        `only_position` is False, the chi-square distribution has 4 degrees of
        freedom, otherwise 2.

        Parameters
        ----------
        mean : ndarray
            Mean vector over the state distribution (8 dimensional).
        covariance : ndarray
            Covariance of the state distribution (8x8 dimensional).
        measurements : ndarray
            An Nx4 dimensional matrix of N measurements, each in
            format (x, y, a, h) where (x, y) is the bounding box center
            position, a the aspect ratio, and h the height.
        only_position : Optional[bool]
            If True, distance computation is done with respect to the bounding
            box center position only.

        Returns
        -------
        ndarray
            Returns an array of length N, where the i-th element contains the
            squared Mahalanobis distance between (mean, covariance) and
            `measurements[i]`.

        """
        # Project with the default confidence of 0 (unscaled noise).
        mean, covariance = self.project(mean, covariance)

        if only_position:
            mean, covariance = mean[:2], covariance[:2, :2]
            measurements = measurements[:, :2]

        # Squared Mahalanobis distance via triangular solve against the
        # Cholesky factor — avoids forming the covariance inverse.
        cholesky_factor = np.linalg.cholesky(covariance)
        d = measurements - mean
        z = scipy.linalg.solve_triangular(
            cholesky_factor, d.T, lower=True, check_finite=False,
            overwrite_b=True)
        squared_maha = np.sum(z * z, axis=0)
        return squared_maha
0] + dets[:, 2]) / 2.0, (dets[:, 1] + dets[:, 3]) / 2.0 + CX2, CY2 = (tracks[:, 0] + tracks[:, 2]) / 2.0, (tracks[:, 1] + tracks[:, 3]) / 2.0 + dx = CX1 - CX2 + dy = CY1 - CY2 + norm = np.sqrt(dx**2 + dy**2) + 1e-6 + dx = dx / norm + dy = dy / norm + return dy, dx # size: num_track x num_det + + +def linear_assignment(cost_matrix): + try: + import lap + _, x, y = lap.lapjv(cost_matrix, extend_cost=True) + return np.array([[y[i], i] for i in x if i >= 0]) # + except ImportError: + from scipy.optimize import linear_sum_assignment + x, y = linear_sum_assignment(cost_matrix) + return np.array([list(zip(x, y))]) + + +def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3): + """ + Assigns detections to tracked object (both represented as bounding boxes) + Returns 3 lists of matches, unmatched_detections and unmatched_trackers + """ + if len(trackers) == 0: + return ( + np.empty((0, 2), dtype=int), + np.arange(len(detections)), + np.empty((0, 5), dtype=int), + ) + + iou_matrix = AssociationFunction.iou_batch(detections, trackers) + + if min(iou_matrix.shape) > 0: + a = (iou_matrix > iou_threshold).astype(np.int32) + if a.sum(1).max() == 1 and a.sum(0).max() == 1: + matched_indices = np.stack(np.where(a), axis=1) + else: + matched_indices = linear_assignment(-iou_matrix) + else: + matched_indices = np.empty(shape=(0, 2)) + + unmatched_detections = [] + for d, det in enumerate(detections): + if d not in matched_indices[:, 0]: + unmatched_detections.append(d) + unmatched_trackers = [] + for t, trk in enumerate(trackers): + if t not in matched_indices[:, 1]: + unmatched_trackers.append(t) + + # filter out matched with low IOU + matches = [] + for m in matched_indices: + if iou_matrix[m[0], m[1]] < iou_threshold: + unmatched_detections.append(m[0]) + unmatched_trackers.append(m[1]) + else: + matches.append(m.reshape(1, 2)) + if len(matches) == 0: + matches = np.empty((0, 2), dtype=int) + else: + matches = np.concatenate(matches, axis=0) + + return 
matches, np.array(unmatched_detections), np.array(unmatched_trackers) + + +def compute_aw_max_metric(emb_cost, w_association_emb, bottom=0.5): + w_emb = np.full_like(emb_cost, w_association_emb) + + for idx in range(emb_cost.shape[0]): + inds = np.argsort(-emb_cost[idx]) + # If there's less than two matches, just keep original weight + if len(inds) < 2: + continue + if emb_cost[idx, inds[0]] == 0: + row_weight = 0 + else: + row_weight = 1 - max( + (emb_cost[idx, inds[1]] / emb_cost[idx, inds[0]]) - bottom, 0 + ) / (1 - bottom) + w_emb[idx] *= row_weight + + for idj in range(emb_cost.shape[1]): + inds = np.argsort(-emb_cost[:, idj]) + # If there's less than two matches, just keep original weight + if len(inds) < 2: + continue + if emb_cost[inds[0], idj] == 0: + col_weight = 0 + else: + col_weight = 1 - max( + (emb_cost[inds[1], idj] / emb_cost[inds[0], idj]) - bottom, 0 + ) / (1 - bottom) + w_emb[:, idj] *= col_weight + + return w_emb * emb_cost + + +def associate( + detections, + trackers, + asso_func, + iou_threshold, + velocities, + previous_obs, + vdc_weight, + w, + h, + emb_cost=None, + w_assoc_emb=None, + aw_off=None, + aw_param=None, + +): + if len(trackers) == 0: + return ( + np.empty((0, 2), dtype=int), + np.arange(len(detections)), + np.empty((0, 5), dtype=int), + ) + + Y, X = speed_direction_batch(detections, previous_obs) + inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1] + inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1) + inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1) + diff_angle_cos = inertia_X * X + inertia_Y * Y + diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1) + diff_angle = np.arccos(diff_angle_cos) + diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi + + valid_mask = np.ones(previous_obs.shape[0]) + valid_mask[np.where(previous_obs[:, 4] < 0)] = 0 + + iou_matrix = asso_func(detections, trackers) + #iou_matrix = iou_batch(detections, trackers) + scores = np.repeat(detections[:, 
-1][:, np.newaxis], trackers.shape[0], axis=1) + # iou_matrix = iou_matrix * scores # a trick sometiems works, we don't encourage this + valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1) + + angle_diff_cost = (valid_mask * diff_angle) * vdc_weight + angle_diff_cost = angle_diff_cost.T + angle_diff_cost = angle_diff_cost * scores + + if min(iou_matrix.shape): + a = (iou_matrix > iou_threshold).astype(np.int32) + if a.sum(1).max() == 1 and a.sum(0).max() == 1: + matched_indices = np.stack(np.where(a), axis=1) + else: + if emb_cost is None: + emb_cost = 0 + else: + emb_cost = emb_cost + emb_cost[iou_matrix <= 0] = 0 + if not aw_off: + emb_cost = compute_aw_max_metric(emb_cost, w_assoc_emb, bottom=aw_param) + else: + emb_cost *= w_assoc_emb + + final_cost = -(iou_matrix + angle_diff_cost + emb_cost) + matched_indices = linear_assignment(final_cost) + if matched_indices.size == 0: + matched_indices = np.empty(shape=(0, 2)) + + else: + matched_indices = np.empty(shape=(0, 2)) + + unmatched_detections = [] + for d, det in enumerate(detections): + if d not in matched_indices[:, 0]: + unmatched_detections.append(d) + unmatched_trackers = [] + for t, trk in enumerate(trackers): + if t not in matched_indices[:, 1]: + unmatched_trackers.append(t) + + # filter out matched with low IOU + matches = [] + for m in matched_indices: + if iou_matrix[m[0], m[1]] < iou_threshold: + unmatched_detections.append(m[0]) + unmatched_trackers.append(m[1]) + else: + matches.append(m.reshape(1, 2)) + if len(matches) == 0: + matches = np.empty((0, 2), dtype=int) + else: + matches = np.concatenate(matches, axis=0) + + return matches, np.array(unmatched_detections), np.array(unmatched_trackers) + + +def associate_kitti( + detections, trackers, det_cates, iou_threshold, velocities, previous_obs, vdc_weight +): + if len(trackers) == 0: + return ( + np.empty((0, 2), dtype=int), + np.arange(len(detections)), + np.empty((0, 5), dtype=int), + ) + + """ + Cost from the velocity 
def associate_kitti(
    detections, trackers, det_cates, iou_threshold, velocities, previous_obs, vdc_weight
):
    """KITTI-style association: IoU + velocity-direction consistency, with a
    hard penalty that forbids matching across object categories.

    Parameters
    ----------
    detections : ndarray
        (N, 5+) detections [x1, y1, x2, y2, score, ...].
    trackers : ndarray
        (M, 5+) tracker boxes; column 4 holds the tracker's category.
    det_cates : sequence
        Per-detection category labels, compared against `trackers[:, 4]`.
    iou_threshold : float
        Minimum IoU for a valid match.
    velocities, previous_obs, vdc_weight :
        Same semantics as in `associate`.

    Returns
    -------
    (matches, unmatched_detections, unmatched_trackers)
    """
    if len(trackers) == 0:
        return (
            np.empty((0, 2), dtype=int),
            np.arange(len(detections)),
            np.empty((0, 5), dtype=int),
        )

    """
    Cost from the velocity direction consistency
    """
    Y, X = speed_direction_batch(detections, previous_obs)
    inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1]
    inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
    inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
    diff_angle_cos = inertia_X * X + inertia_Y * Y
    diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
    diff_angle = np.arccos(diff_angle_cos)
    diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi

    # Tracks without a valid previous observation contribute no direction cost.
    valid_mask = np.ones(previous_obs.shape[0])
    valid_mask[np.where(previous_obs[:, 4] < 0)] = 0
    valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)

    scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
    angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
    angle_diff_cost = angle_diff_cost.T
    angle_diff_cost = angle_diff_cost * scores

    """
    Cost from IoU
    """
    iou_matrix = AssociationFunction.iou_batch(detections, trackers)

    """
    With multiple categories, generate the cost for catgory mismatch
    """
    # Large negative affinity (-1e6) effectively bans cross-category matches.
    num_dets = detections.shape[0]
    num_trk = trackers.shape[0]
    cate_matrix = np.zeros((num_dets, num_trk))
    for i in range(num_dets):
        for j in range(num_trk):
            if det_cates[i] != trackers[j, 4]:
                cate_matrix[i][j] = -1e6

    cost_matrix = -iou_matrix - angle_diff_cost - cate_matrix

    if min(iou_matrix.shape) > 0:
        a = (iou_matrix > iou_threshold).astype(np.int32)
        # Fast path: thresholded IoU already gives a one-to-one matching.
        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
            matched_indices = np.stack(np.where(a), axis=1)
        else:
            matched_indices = linear_assignment(cost_matrix)
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = []
    for d, det in enumerate(detections):
        if d not in matched_indices[:, 0]:
            unmatched_detections.append(d)
    unmatched_trackers = []
    for t, trk in enumerate(trackers):
        if t not in matched_indices[:, 1]:
            unmatched_trackers.append(t)

    # filter out matched with low IOU
    matches = []
    for m in matched_indices:
        if iou_matrix[m[0], m[1]] < iou_threshold:
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    if len(matches) == 0:
        matches = np.empty((0, 2), dtype=int)
    else:
        matches = np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
import numpy as np


def iou_obb_pair(i, j, bboxes1, bboxes2):
    """
    Compute IoU for the rotated rectangles at index i and j in the batches
    `bboxes1`, `bboxes2`.  Each row is (cx, cy, w, h, angle, ...).

    Returns 0.0 when the rectangles do not intersect or the union area is
    degenerate.
    """
    # Local import so the module stays importable without OpenCV when only
    # the axis-aligned association functions are needed.
    import cv2 as cv

    rect1 = bboxes1[int(i)]
    rect2 = bboxes2[int(j)]

    (cx1, cy1, w1, h1, angle1) = rect1[0:5]
    (cx2, cy2, w2, h2, angle2) = rect2[0:5]

    r1 = ((cx1, cy1), (w1, h1), angle1)
    r2 = ((cx2, cy2), (w2, h2), angle2)

    # Intersection polygon of the two rotated rectangles.
    ret, intersect = cv.rotatedRectangleIntersection(r1, r2)
    if ret == 0 or intersect is None:
        return 0.0  # no intersection

    intersection_area = cv.contourArea(intersect)

    area1 = w1 * h1
    area2 = w2 * h2
    union_area = area1 + area2 - intersection_area

    return intersection_area / union_area if union_area > 0 else 0.0


class AssociationFunction:
    def __init__(self, w, h, asso_mode="iou"):
        """
        Initializes the AssociationFunction class with the necessary parameters
        for bounding box operations.  The association function is selected based
        on the `asso_mode` string provided during construction.

        Parameters:
            w (int): Frame width, used for normalizing centroid distance.
            h (int): Frame height, used for normalizing centroid distance.
            asso_mode (str): The association function to use
                (e.g. "iou", "giou", "centroid", ...).
        """
        self.w = w
        self.h = h
        self.asso_mode = asso_mode
        self.asso_func = self._get_asso_func(asso_mode)

    @staticmethod
    def iou_batch(bboxes1, bboxes2) -> np.ndarray:
        """Pairwise IoU between (N, 4+) and (M, 4+) xyxy boxes -> (N, M)."""
        bboxes2 = np.expand_dims(bboxes2, 0)
        bboxes1 = np.expand_dims(bboxes1, 1)

        xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
        yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
        xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
        yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
        w = np.maximum(0.0, xx2 - xx1)
        h = np.maximum(0.0, yy2 - yy1)
        wh = w * h
        o = wh / (
            (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
            + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
            - wh
        )
        return o

    @staticmethod
    def iou_batch_obb(bboxes1, bboxes2) -> np.ndarray:
        """Pairwise IoU between batches of oriented bounding boxes -> (N, M)."""
        N, M = len(bboxes1), len(bboxes2)

        def wrapper(i, j):
            return iou_obb_pair(i, j, bboxes1, bboxes2)

        iou_matrix = np.fromfunction(np.vectorize(wrapper), shape=(N, M), dtype=int)
        return iou_matrix

    @staticmethod
    def hmiou_batch(bboxes1, bboxes2):
        """
        Modified IoU (hIoU) between two batches of xyxy boxes, scaling the
        standard IoU by the vertical overlap ratio.

        Parameters:
            bboxes1: (N, 4) array of [x1, y1, x2, y2]
            bboxes2: (M, 4) array of [x1, y1, x2, y2]

        Returns:
            (N, M) array where entry (i, j) is IoU(i, j) * vertical_overlap(i, j).
        """
        bboxes1 = np.expand_dims(bboxes1, axis=1)  # (N, 1, 4)
        bboxes2 = np.expand_dims(bboxes2, axis=0)  # (1, M, 4)

        # Vertical overlap ratio 'o': shared y-extent over combined y-extent.
        intersect_y1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
        intersect_y2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
        intersection_height = np.maximum(0.0, intersect_y2 - intersect_y1)

        union_y1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
        union_y2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])
        union_height = np.maximum(1e-10, union_y2 - union_y1)

        o = intersection_height / union_height

        # Standard IoU.
        inter_x1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
        inter_y1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
        inter_x2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
        inter_y2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])

        inter_w = np.maximum(0.0, inter_x2 - inter_x1)
        inter_h = np.maximum(0.0, inter_y2 - inter_y1)
        inter_area = inter_w * inter_h

        area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
        area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])

        union_area = area1 + area2 - inter_area
        iou = inter_area / (union_area + 1e-10)

        return iou * o

    @staticmethod
    def giou_batch(bboxes1, bboxes2) -> np.ndarray:
        """
        Generalized IoU (https://arxiv.org/pdf/1902.09630.pdf) between two
        batches of xyxy boxes, rescaled from (-1, 1) to (0, 1).
        """
        bboxes2 = np.expand_dims(bboxes2, 0)
        bboxes1 = np.expand_dims(bboxes1, 1)

        xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
        yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
        xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
        yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
        w = np.maximum(0.0, xx2 - xx1)
        h = np.maximum(0.0, yy2 - yy1)
        wh = w * h  # intersection area

        area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
        area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])

        union_area = area1 + area2 - wh
        iou = wh / union_area

        # Smallest enclosing box.
        xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
        yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
        xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
        yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])
        wc = xxc2 - xxc1
        hc = yyc2 - yyc1
        assert (wc > 0).all() and (hc > 0).all()
        area_enclose = wc * hc

        giou = iou - (area_enclose - union_area) / area_enclose
        giou = (giou + 1.0) / 2.0  # rescale from (-1, 1) to (0, 1)
        return giou

    def centroid_batch(self, bboxes1, bboxes2) -> np.ndarray:
        """Similarity from normalized centroid distance of xyxy boxes -> (N, M)."""
        centroids1 = np.stack(((bboxes1[..., 0] + bboxes1[..., 2]) / 2,
                               (bboxes1[..., 1] + bboxes1[..., 3]) / 2), axis=-1)
        centroids2 = np.stack(((bboxes2[..., 0] + bboxes2[..., 2]) / 2,
                               (bboxes2[..., 1] + bboxes2[..., 3]) / 2), axis=-1)

        centroids1 = np.expand_dims(centroids1, 1)
        centroids2 = np.expand_dims(centroids2, 0)

        distances = np.sqrt(np.sum((centroids1 - centroids2) ** 2, axis=-1))
        # Normalize by the frame diagonal so the result lies in [0, 1].
        norm_factor = np.sqrt(self.w ** 2 + self.h ** 2)
        normalized_distances = distances / norm_factor

        return 1 - normalized_distances

    def centroid_batch_obb(self, bboxes1, bboxes2) -> np.ndarray:
        """Same as centroid_batch, for OBBs whose first two columns are (cx, cy)."""
        centroids1 = np.stack((bboxes1[..., 0], bboxes1[..., 1]), axis=-1)
        centroids2 = np.stack((bboxes2[..., 0], bboxes2[..., 1]), axis=-1)

        centroids1 = np.expand_dims(centroids1, 1)
        centroids2 = np.expand_dims(centroids2, 0)

        distances = np.sqrt(np.sum((centroids1 - centroids2) ** 2, axis=-1))
        norm_factor = np.sqrt(self.w ** 2 + self.h ** 2)
        normalized_distances = distances / norm_factor

        return 1 - normalized_distances

    @staticmethod
    def ciou_batch(bboxes1, bboxes2) -> np.ndarray:
        """
        Complete IoU (CIoU) between batches of xyxy boxes, rescaled to [0, 1].

        :param bboxes1: (N, 4) boxes as (x1, y1, x2, y2)
        :param bboxes2: (M, 4) boxes as (x1, y1, x2, y2)
        :return: (N, M) CIoU scores scaled between 0 and 1
        """
        epsilon = 1e-7  # guards all divisions below

        bboxes2 = np.expand_dims(bboxes2, 0)  # (1, M, 4)
        bboxes1 = np.expand_dims(bboxes1, 1)  # (N, 1, 4)

        xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
        yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
        xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
        yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
        w = np.maximum(0.0, xx2 - xx1)
        h = np.maximum(0.0, yy2 - yy1)
        wh = w * h

        area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
        area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
        iou = wh / (area1 + area2 - wh + epsilon)

        centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
        centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
        centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
        centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0

        # Squared center distance.
        inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2

        # Squared diagonal of the smallest enclosing box.
        xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
        yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
        xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
        yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])
        outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2 + epsilon

        # Aspect-ratio consistency term.
        w1 = bboxes1[..., 2] - bboxes1[..., 0]
        h1 = bboxes1[..., 3] - bboxes1[..., 1]
        w2 = bboxes2[..., 2] - bboxes2[..., 0]
        h2 = bboxes2[..., 3] - bboxes2[..., 1]

        h2 = h2 + epsilon
        h1 = h1 + epsilon
        arctan_diff = np.arctan(w2 / h2) - np.arctan(w1 / h1)
        v = (4 / (np.pi ** 2)) * (arctan_diff ** 2)

        S = 1 - iou
        alpha = v / (S + v + epsilon)

        ciou = iou - (inner_diag / outer_diag) + (alpha * v)

        return (ciou + 1) / 2.0  # rescale to [0, 1]

    @staticmethod
    def diou_batch(bboxes1, bboxes2) -> np.ndarray:
        """
        Distance IoU (https://arxiv.org/pdf/1902.09630.pdf) between batches of
        xyxy boxes, rescaled to [0, 1].

        NOTE: @staticmethod was previously missing, which broke access through
        an instance (the instance would be bound as `bboxes1`).
        """
        bboxes2 = np.expand_dims(bboxes2, 0)
        bboxes1 = np.expand_dims(bboxes1, 1)

        xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
        yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
        xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
        yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
        w = np.maximum(0.0, xx2 - xx1)
        h = np.maximum(0.0, yy2 - yy1)
        wh = w * h
        iou = wh / (
            (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
            + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
            - wh
        )

        centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
        centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
        centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
        centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0

        inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2

        xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
        yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
        xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
        yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])

        outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2
        diou = iou - inner_diag / outer_diag

        return (diou + 1) / 2.0

    def run_asso_func(self, bboxes1, bboxes2):
        """
        Runs the association function selected at construction time.

        BUGFIX: this was decorated @staticmethod while still taking `self`,
        so every instance call raised TypeError (the first box batch was
        bound as `self`).  It is a genuine instance method.

        Parameters:
            bboxes1: first set of bounding boxes.
            bboxes2: second set of bounding boxes.
        """
        return self.asso_func(bboxes1, bboxes2)

    def _get_asso_func(self, asso_mode):
        """
        Returns the association function matching `asso_mode`.

        Parameters:
            asso_mode (str): e.g. "iou", "giou", "centroid", ...

        Returns:
            function: the matching association callable.

        Raises:
            ValueError: if `asso_mode` is not a known mode.
        """
        ASSO_FUNCS = {
            "iou": AssociationFunction.iou_batch,
            "iou_obb": AssociationFunction.iou_batch_obb,
            "hmiou": AssociationFunction.hmiou_batch,
            "giou": AssociationFunction.giou_batch,
            "ciou": AssociationFunction.ciou_batch,
            "diou": AssociationFunction.diou_batch,
            "centroid": self.centroid_batch,  # bound: needs frame size
            "centroid_obb": self.centroid_batch_obb,
        }

        # Use the argument (previously ignored in favor of self.asso_mode).
        if asso_mode not in ASSO_FUNCS:
            raise ValueError(
                f"Invalid association mode: {asso_mode}. Choose from {list(ASSO_FUNCS.keys())}"
            )

        return ASSO_FUNCS[asso_mode]
+""" +chi2inv95 = { + 1: 3.8415, + 2: 5.9915, + 3: 7.8147, + 4: 9.4877, + 5: 11.070, + 6: 12.592, + 7: 14.067, + 8: 15.507, + 9: 16.919, +} + + +def merge_matches(m1, m2, shape): + O, P, Q = shape + m1 = np.asarray(m1) + m2 = np.asarray(m2) + + M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P)) + M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q)) + + mask = M1 * M2 + match = mask.nonzero() + match = list(zip(match[0], match[1])) + unmatched_O = tuple(set(range(O)) - set([i for i, j in match])) + unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match])) + + return match, unmatched_O, unmatched_Q + + +def _indices_to_matches(cost_matrix, indices, thresh): + matched_cost = cost_matrix[tuple(zip(*indices))] + matched_mask = matched_cost <= thresh + + matches = indices[matched_mask] + unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0])) + unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1])) + + return matches, unmatched_a, unmatched_b + + +def linear_assignment(cost_matrix, thresh): + if cost_matrix.size == 0: + return ( + np.empty((0, 2), dtype=int), + tuple(range(cost_matrix.shape[0])), + tuple(range(cost_matrix.shape[1])), + ) + matches, unmatched_a, unmatched_b = [], [], [] + cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh) + for ix, mx in enumerate(x): + if mx >= 0: + matches.append([ix, mx]) + unmatched_a = np.where(x < 0)[0] + unmatched_b = np.where(y < 0)[0] + matches = np.asarray(matches) + return matches, unmatched_a, unmatched_b + + +def ious(atlbrs, btlbrs): + """ + Compute cost based on IoU + :type atlbrs: list[tlbr] | np.ndarray + :type atlbrs: list[tlbr] | np.ndarray + + :rtype ious np.ndarray + """ + ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32) + if ious.size == 0: + return ious + + ious = bbox_ious( + np.ascontiguousarray(atlbrs, dtype=np.float32), + np.ascontiguousarray(btlbrs, dtype=np.float32), 
+ ) + + return ious + +def d_iou_distance(atracks, btracks): + """ + Compute cost based on IoU + :type atracks: list[STrack] + :type btracks: list[STrack] + + :rtype cost_matrix np.ndarray + """ + + if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or ( + len(btracks) > 0 and isinstance(btracks[0], np.ndarray) + ): + atlbrs = atracks + btlbrs = btracks + else: + atlbrs = [track.xyxy for track in atracks] + btlbrs = [track.xyxy for track in btracks] + + ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32) + if ious.size == 0: + return ious + _ious = AssociationFunction.diou_batch(atlbrs, btlbrs) + + cost_matrix = 1 - _ious + + return cost_matrix + +def iou_distance(atracks, btracks): + """ + Compute cost based on IoU + :type atracks: list[STrack] + :type btracks: list[STrack] + + :rtype cost_matrix np.ndarray + """ + + if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or ( + len(btracks) > 0 and isinstance(btracks[0], np.ndarray) + ): + atlbrs = atracks + btlbrs = btracks + else: + atlbrs = [track.xyxy for track in atracks] + btlbrs = [track.xyxy for track in btracks] + + ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32) + if ious.size == 0: + return ious + _ious = AssociationFunction.iou_batch(atlbrs, btlbrs) + + cost_matrix = 1 - _ious + + return cost_matrix + + +def v_iou_distance(atracks, btracks): + """ + Compute cost based on IoU + :type atracks: list[STrack] + :type btracks: list[STrack] + + :rtype cost_matrix np.ndarray + """ + + if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or ( + len(btracks) > 0 and isinstance(btracks[0], np.ndarray) + ): + atlbrs = atracks + btlbrs = btracks + else: + atlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in atracks] + btlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in btracks] + _ious = ious(atlbrs, btlbrs) + cost_matrix = 1 - _ious + + return cost_matrix + + +def embedding_distance(tracks, detections, metric="cosine"): + """ + :param tracks: 
list[STrack] + :param detections: list[BaseTrack] + :param metric: + :return: cost_matrix np.ndarray + """ + + cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32) + if cost_matrix.size == 0: + return cost_matrix + det_features = np.asarray( + [track.curr_feat for track in detections], dtype=np.float32 + ) + # for i, track in enumerate(tracks): + # cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric)) + track_features = np.asarray( + [track.smooth_feat for track in tracks], dtype=np.float32 + ) + cost_matrix = np.maximum( + 0.0, cdist(track_features, det_features, metric) + ) # Nomalized features + return cost_matrix + + +def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False): + if cost_matrix.size == 0: + return cost_matrix + gating_dim = 2 if only_position else 4 + gating_threshold = chi2inv95[gating_dim] + measurements = np.asarray([det.to_xyah() for det in detections]) + for row, track in enumerate(tracks): + gating_distance = kf.gating_distance( + track.mean, track.covariance, measurements, only_position + ) + cost_matrix[row, gating_distance > gating_threshold] = np.inf + return cost_matrix + + +def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98): + if cost_matrix.size == 0: + return cost_matrix + gating_dim = 2 if only_position else 4 + gating_threshold = chi2inv95[gating_dim] + measurements = np.asarray([det.to_xyah() for det in detections]) + for row, track in enumerate(tracks): + gating_distance = kf.gating_distance( + track.mean, track.covariance, measurements, only_position, metric="maha" + ) + cost_matrix[row, gating_distance > gating_threshold] = np.inf + cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance + return cost_matrix + + +def fuse_iou(cost_matrix, tracks, detections): + if cost_matrix.size == 0: + return cost_matrix + reid_sim = 1 - cost_matrix + iou_dist = iou_distance(tracks, detections) + 
iou_sim = 1 - iou_dist + fuse_sim = reid_sim * (1 + iou_sim) / 2 + det_confs = np.array([det.conf for det in detections]) + det_confs = np.expand_dims(det_confs, axis=0).repeat(cost_matrix.shape[0], axis=0) + # fuse_sim = fuse_sim * (1 + det_confs) / 2 + fuse_cost = 1 - fuse_sim + return fuse_cost + + +def fuse_score(cost_matrix, detections): + if cost_matrix.size == 0: + return cost_matrix + iou_sim = 1 - cost_matrix + det_confs = np.array([det.conf for det in detections]) + det_confs = np.expand_dims(det_confs, axis=0).repeat(cost_matrix.shape[0], axis=0) + fuse_sim = iou_sim * det_confs + fuse_cost = 1 - fuse_sim + return fuse_cost + + +def _pdist(a, b): + """Compute pair-wise squared distance between points in `a` and `b`. + Parameters + ---------- + a : array_like + An NxM matrix of N samples of dimensionality M. + b : array_like + An LxM matrix of L samples of dimensionality M. + Returns + ------- + ndarray + Returns a matrix of size len(a), len(b) such that eleement (i, j) + contains the squared distance between `a[i]` and `b[j]`. + """ + a, b = np.asarray(a), np.asarray(b) + if len(a) == 0 or len(b) == 0: + return np.zeros((len(a), len(b))) + a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1) + r2 = -2.0 * np.dot(a, b.T) + a2[:, None] + b2[None, :] + r2 = np.clip(r2, 0.0, float(np.inf)) + return r2 + + +def _cosine_distance(a, b, data_is_normalized=False): + """Compute pair-wise cosine distance between points in `a` and `b`. + Parameters + ---------- + a : array_like + An NxM matrix of N samples of dimensionality M. + b : array_like + An LxM matrix of L samples of dimensionality M. + data_is_normalized : Optional[bool] + If True, assumes rows in a and b are unit length vectors. + Otherwise, a and b are explicitly normalized to lenght 1. + Returns + ------- + ndarray + Returns a matrix of size len(a), len(b) such that eleement (i, j) + contains the squared distance between `a[i]` and `b[j]`. 
+ """ + if not data_is_normalized: + a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True) + b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True) + return 1.0 - np.dot(a, b.T) + + +def _nn_euclidean_distance(x, y): + """Helper function for nearest neighbor distance metric (Euclidean). + Parameters + ---------- + x : ndarray + A matrix of N row-vectors (sample points). + y : ndarray + A matrix of M row-vectors (query points). + Returns + ------- + ndarray + A vector of length M that contains for each entry in `y` the + smallest Euclidean distance to a sample in `x`. + """ + # x_ = torch.from_numpy(np.asarray(x) / np.linalg.norm(x, axis=1, keepdims=True)) + # y_ = torch.from_numpy(np.asarray(y) / np.linalg.norm(y, axis=1, keepdims=True)) + distances = distances = _pdist(x, y) + return np.maximum(0.0, torch.min(distances, axis=0)[0].numpy()) + + +def _nn_cosine_distance(x, y): + """Helper function for nearest neighbor distance metric (cosine). + Parameters + ---------- + x : ndarray + A matrix of N row-vectors (sample points). + y : ndarray + A matrix of M row-vectors (query points). + Returns + ------- + ndarray + A vector of length M that contains for each entry in `y` the + smallest cosine distance to a sample in `x`. + """ + x_ = torch.from_numpy(np.asarray(x)) + y_ = torch.from_numpy(np.asarray(y)) + distances = _cosine_distance(x_, y_) + distances = distances + return distances.min(axis=0) + + +class NearestNeighborDistanceMetric(object): + """ + A nearest neighbor distance metric that, for each target, returns + the closest distance to any sample that has been observed so far. + Parameters + ---------- + metric : str + Either "euclidean" or "cosine". + matching_threshold: float + The matching threshold. Samples with larger distance are considered an + invalid match. + budget : Optional[int] + If not None, fix samples per class to at most this number. Removes + the oldest samples when the budget is reached. 
+ Attributes + ---------- + samples : Dict[int -> List[ndarray]] + A dictionary that maps from target identities to the list of samples + that have been observed so far. + """ + + def __init__(self, metric, matching_threshold, budget=None): + if metric == "euclidean": + self._metric = _nn_euclidean_distance + elif metric == "cosine": + self._metric = _nn_cosine_distance + else: + raise ValueError("Invalid metric; must be either 'euclidean' or 'cosine'") + self.matching_threshold = matching_threshold + self.budget = budget + self.samples = {} + + def partial_fit(self, features, targets, active_targets): + """Update the distance metric with new data. + Parameters + ---------- + features : ndarray + An NxM matrix of N features of dimensionality M. + targets : ndarray + An integer array of associated target identities. + active_targets : List[int] + A list of targets that are currently present in the scene. + """ + for feature, target in zip(features, targets): + self.samples.setdefault(target, []).append(feature) + if self.budget is not None: + self.samples[target] = self.samples[target][-self.budget:] + self.samples = {k: self.samples[k] for k in active_targets} + + def distance(self, features, targets): + """Compute distance between features and targets. + Parameters + ---------- + features : ndarray + An NxM matrix of N features of dimensionality M. + targets : List[int] + A list of targets to match the given `features` against. + Returns + ------- + ndarray + Returns a cost matrix of shape len(targets), len(features), where + element (i, j) contains the closest squared distance between + `targets[i]` and `features[j]`. 
+ """ + cost_matrix = np.zeros((len(targets), len(features))) + for i, target in enumerate(targets): + cost_matrix[i, :] = self._metric(self.samples[target], features) + return cost_matrix diff --git a/boxmot/utils/misc.py b/boxmot/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..8e9221fca98b070fe86bcb8154b442d05917a7aa --- /dev/null +++ b/boxmot/utils/misc.py @@ -0,0 +1,34 @@ +from pathlib import Path +import os + +def increment_path(path, exist_ok=False, sep="", mkdir=False): + """ + Generates an incremented file or directory path if it already exists, with an option to create the directory. + + Args: + path (str or Path): Initial file or directory path. + exist_ok (bool): If True, returns the original path even if it exists. + sep (str): Separator to use between path stem and increment. + mkdir (bool): If True, creates the directory if it doesn’t exist. + + Returns: + Path: Incremented path, or original if exist_ok is True. + + Example: + runs/exp --> runs/exp2, runs/exp3, etc. 
+ """ + path = Path(path) # ensures OS compatibility + if path.exists() and not exist_ok: + base, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "") + + # Increment path until a non-existing one is found + for n in range(2, 9999): + new_path = f"{base}{sep}{n}{suffix}" + if not Path(new_path).exists(): + path = Path(new_path) + break + + if mkdir: + path.mkdir(parents=True, exist_ok=True) # creates the directory if it doesn’t exist + + return path \ No newline at end of file diff --git a/boxmot/utils/ops.py b/boxmot/utils/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..809396ec62c76f145f7ba4a7e9f408560443de2a --- /dev/null +++ b/boxmot/utils/ops.py @@ -0,0 +1,188 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import numpy as np +import torch +import cv2 +from typing import Tuple, Union + + +def xyxy2xywh(x): + """ + Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format. + + Args: + x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format. + Returns: + y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height) format. + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center + y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center + y[..., 2] = x[..., 2] - x[..., 0] # width + y[..., 3] = x[..., 3] - x[..., 1] # height + return y + + +def xywh2xyxy(x): + """ + Convert bounding box coordinates from (x_c, y_c, width, height) format to + (x1, y1, x2, y2) format where (x1, y1) is the top-left corner and (x2, y2) + is the bottom-right corner. + + Args: + x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x, y, width, height) format. + Returns: + y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format. 
+ """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x + y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y + y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x + y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y + return y + + +def xywh2tlwh(x): + """ + Convert bounding box coordinates from (x c, y c, w, h) format to (t, l, w, h) format where (t, l) is the + top-left corner and (w, h) is width and height. + + Args: + x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x, y, width, height) format. + Returns: + y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format. + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] - x[..., 2] / 2.0 # xc --> t + y[..., 1] = x[..., 1] - x[..., 3] / 2.0 # yc --> l + y[..., 2] = x[..., 2] # width + y[..., 3] = x[..., 3] # height + return y + + +def tlwh2xyxy(x): + """ + Convert bounding box coordinates from (t, l ,w ,h) format to (t, l, w, h) format where (t, l) is the + top-left corner and (w, h) is width and height. + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] + y[..., 1] = x[..., 1] + y[..., 2] = x[..., 0] + x[..., 2] + y[..., 3] = x[..., 1] + x[..., 3] + return y + + +def xyxy2tlwh(x): + """ + Convert bounding box coordinates from (t, l ,w ,h) format to (t, l, w, h) format where (t, l) is the + top-left corner and (w, h) is width and height. + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] + y[..., 1] = x[..., 1] + y[..., 2] = x[..., 2] - x[..., 0] + y[..., 3] = x[..., 3] - x[..., 1] + return y + + +def tlwh2xyah(x): + """ + Convert bounding box coordinates from (t, l ,w ,h) + to (center x, center y, aspect ratio, height)`, where the aspect ratio is `width / height`. 
+ """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] + (x[..., 2] / 2) + y[..., 1] = x[..., 1] + (x[..., 3] / 2) + y[..., 2] = x[..., 2] / x[..., 3] + y[..., 3] = x[..., 3] + return y + + +def xyxy2xysr(x): + """ + Converts bounding box coordinates from (x1, y1, x2, y2) format to (x, y, s, r) format. + + Args: + bbox (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format. + Returns: + z (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, s, r) format, where + x, y is the center of the box, + s is the scale (area), and + r is the aspect ratio. + """ + x = x[0:4] + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + w = y[..., 2] - y[..., 0] # width + h = y[..., 3] - y[..., 1] # height + y[..., 0] = y[..., 0] + w / 2.0 # x center + y[..., 1] = y[..., 1] + h / 2.0 # y center + y[..., 2] = w * h # scale (area) + y[..., 3] = w / (h + 1e-6) # aspect ratio + y = y.reshape((4, 1)) + return y + + +def letterbox( + img: np.ndarray, + new_shape: Union[int, Tuple[int, int]] = (640, 640), + color: Tuple[int, int, int] = (114, 114, 114), + auto: bool = True, + scaleFill: bool = False, + scaleup: bool = True +) -> Tuple[np.ndarray, Tuple[float, float], Tuple[float, float]]: + """ + Resizes an image to a new shape while maintaining aspect ratio, padding with color if needed. + + Args: + img (np.ndarray): The original image in BGR format. + new_shape (Union[int, Tuple[int, int]], optional): Desired size as an integer (e.g., 640) + or tuple (width, height). Default is (640, 640). + color (Tuple[int, int, int], optional): Padding color in BGR format. Default is (114, 114, 114). + auto (bool, optional): If True, adjusts padding to be a multiple of 32. Default is True. + scaleFill (bool, optional): If True, stretches the image to fill the new shape. Default is False. + scaleup (bool, optional): If True, allows scaling up; otherwise, only scales down. Default is True. 
+ + Returns: + Tuple[np.ndarray, Tuple[float, float], Tuple[float, float]]: + - Resized and padded image as np.ndarray. + - Scaling ratio used for width and height as (width_ratio, height_ratio). + - Padding applied to width and height as (width_padding, height_padding). + """ + shape = img.shape[:2] # current shape [height, width] + + # Ensure new_shape is a tuple (width, height) + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + # Calculate scale ratio + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: + r = min(r, 1.0) # only scale down + + # Calculate new dimensions and padding + ratio = (r, r) + new_unpad = (int(round(shape[1] * r)), int(round(shape[0] * r))) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] + + if auto: # minimum rectangle + dw, dh = np.mod(dw, 32), np.mod(dh, 32) + elif scaleFill: # stretch to fill + dw, dh = 0.0, 0.0 + new_unpad = new_shape + ratio = (new_shape[1] / shape[1], new_shape[0] / shape[0]) + + # Divide padding by 2 for even distribution + dw /= 2 + dh /= 2 + + # Resize image if necessary + if shape[::-1] != new_unpad: + img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) + + # Add border to the image + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) + + return img, ratio, (dw, dh) diff --git a/boxmot/utils/torch_utils.py b/boxmot/utils/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..70f7388e5e7f25e07888b028bf4e2955fe51a5aa --- /dev/null +++ b/boxmot/utils/torch_utils.py @@ -0,0 +1,53 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import os +import platform +import torch + +from .. import __version__ +from . 
import logger as LOGGER +from boxmot.utils import ROOT + + +def get_system_info(): + return f"Yolo Tracking v{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__}" + +def parse_device(device): + device = str(device).lower().replace("cuda:", "").replace("none", "").replace("(", "").replace(")", "").replace("[", "").replace("]", "").replace("'", "").replace(" ", "") + return device + +def assert_cuda_available(device): + if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(",", ""))): + install = ("See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no CUDA devices are seen by torch.\n" if torch.cuda.device_count() == 0 else "") + raise ValueError(f"Invalid CUDA 'device={device}' requested. Use 'device=cpu' or pass valid CUDA device(s) if available, i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n" + + f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}" + + f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}" + + f"\nos.environ['CUDA_VISIBLE_DEVICES']: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}\n{install}") + +def select_device(device="", batch=0): + s = get_system_info() + device = parse_device(device) + mps = device == "mps" + cpu = device == "cpu" or device == "" and not torch.cuda.is_available() + + if cpu or mps: + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + elif device: + os.environ["CUDA_VISIBLE_DEVICES"] = device + assert_cuda_available(device) + + if not cpu and not mps and torch.cuda.is_available(): + devices = device.split(",") if device else ["0"] + n = len(devices) + if n > 1 and batch > 0 and batch % n != 0: + raise ValueError(f"'batch={batch}' must be a multiple of GPU count {n}.") + s += "\n" + "\n".join(f"CUDA:{d} ({torch.cuda.get_device_properties(i).name}, {torch.cuda.get_device_properties(i).total_memory / (1 << 20):.0f}MiB)" for i, d in enumerate(devices)) + arg = "cuda:" + devices[0] + elif mps: + s += "MPS" + arg = "mps" + else: 
+ s += "CPU" + arg = "cpu" + LOGGER.info(s) + return torch.device(arg) diff --git a/download_models.py b/download_models.py new file mode 100644 index 0000000000000000000000000000000000000000..7d40abeb2c79e9163c1a42142108a2e88524b3d4 --- /dev/null +++ b/download_models.py @@ -0,0 +1,31 @@ +import os +import sys +import torch +import gdown +from pathlib import Path + +MODELS_DIR = Path("models") + +def download_models(): + """Download required models if they don't exist.""" + os.makedirs(MODELS_DIR, exist_ok=True) + + # Model URLs (you should replace these with actual URLs to your models) + models = { + "yolov8n.pt": "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt", + "osnet_x0_25_msmt17.pt": "https://github.com/mikel-brostrom/assets/releases/download/v0.0.0/osnet_x0_25_msmt17.pt" + } + + # Download missing models + for model_name, url in models.items(): + model_path = MODELS_DIR / model_name + if not model_path.exists(): + print(f"Downloading {model_name}...") + try: + torch.hub.download_url_to_file(url, model_path) + print(f"✅ Downloaded {model_name}") + except Exception as e: + print(f"⚠️ Failed to download {model_name}: {e}") + +if __name__ == "__main__": + download_models() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3b5d6d1aa9a8aff3f5efa51ee65d63d3f9a7ece0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +setuptools>=78.0.0 +filterpy>=1.4.5 +gdown>=5.1.0 +lapx>=0.5.5 +loguru>=0.7.2 +numpy==1.26.4 +pyyaml>=6.0.1 +regex>=2024.0.0 +yacs>=0.1.8 +scikit-learn>=1.3.0 +pandas>=2.0.0 +opencv-python>=4.7.0 +ftfy>=6.1.3 +gitpython>=3.1.42 +torch>=2.2.1 +torchvision>=0.17.1 +gradio>=3.50.2 +ultralytics +boxmot diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/performance/__init__.py 
b/tests/performance/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/performance/test_cmcs_p.py b/tests/performance/test_cmcs_p.py new file mode 100644 index 0000000000000000000000000000000000000000..69414af59b1c92ce174213114b5d780d788eca84 --- /dev/null +++ b/tests/performance/test_cmcs_p.py @@ -0,0 +1,45 @@ +import cv2 +import time +import numpy as np +import pytest +from boxmot.motion.cmc.ecc import ECC +from boxmot.motion.cmc.orb import ORB +from boxmot.motion.cmc.sift import SIFT +from boxmot.motion.cmc.sof import SOF +from boxmot.utils import ROOT + + + +# Fixture for creating CMC objects +@pytest.fixture +def cmc_object(request): + cmc_class = request.param + return cmc_class() + + +# Define the test function +@pytest.mark.parametrize("cmc_object", [ECC, ORB, SIFT, SOF], indirect=True) +def test_cmc_apply(cmc_object): + + # Create dummy images and detections + curr_img = cv2.imread(str(ROOT / 'assets/MOT17-mini/train/MOT17-04-FRCNN/img1/000005.jpg')) + prev_img = cv2.imread(str(ROOT / 'assets/MOT17-mini/train/MOT17-04-FRCNN/img1/000001.jpg')) + + print(curr_img.shape) + print(prev_img.shape) + + dets = np.array([[0, 0, 10, 10]]) + + n_runs = 100 + start = time.process_time() + for i in range(0, n_runs): + warp_matrix = cmc_object.apply(prev_img, dets) + warp_matrix = cmc_object.apply(curr_img, dets) + end = time.process_time() + elapsed_time_per_interation = (end - start) / n_runs + + # Define a threshold for the maximum allowed time + max_allowed_time = 0.1 + + # Assert that the elapsed time is within the allowed limit + assert elapsed_time_per_interation < max_allowed_time, "CMC algorithm processing time exceeds the allowed limit" \ No newline at end of file diff --git a/tests/performance/test_tracking_p.py b/tests/performance/test_tracking_p.py new file mode 100644 index 0000000000000000000000000000000000000000..57855a76e9beaee74c86021f669534d610f37f67 --- 
/dev/null +++ b/tests/performance/test_tracking_p.py @@ -0,0 +1,97 @@ +import pytest +import numpy as np +from pathlib import Path +from boxmot.utils import WEIGHTS +import time +import subprocess + +from numpy.testing import assert_allclose +from boxmot import ( + StrongSort, BotSort, DeepOcSort, OcSort, ByteTrack, ImprAssocTrack, get_tracker_config, create_tracker, +) +from tests.test_config import MOTION_ONLY_TRACKING_NAMES, MOTION_N_APPEARANCE_TRACKING_NAMES + + +@pytest.mark.parametrize("tracker_type", MOTION_ONLY_TRACKING_NAMES) +def test_motion_tracker_update_time(tracker_type): + tracker_conf = get_tracker_config(tracker_type) + tracker = create_tracker( + tracker_type=tracker_type, + tracker_config=tracker_conf, + reid_weights=WEIGHTS / 'mobilenetv2_x1_4_dukemtmcreid.pt', + device='cpu', + half=False, + per_class=False + ) + + rgb = np.random.randint(0, 255, size=(640, 640, 3), dtype=np.uint8) + det = np.array([[144, 212, 578, 480, 0.82, 0], + [425, 281, 576, 472, 0.56, 65]]) + + n_runs = 100 + + # Warm-up iteration to ensure initialization overhead is not measured + tracker.update(det, rgb) + + start = time.perf_counter() + for _ in range(n_runs): + tracker.update(det, rgb) + end = time.perf_counter() + + elapsed_time_per_iteration = (end - start) / n_runs + fps = 1.0 / elapsed_time_per_iteration + + # Print FPS for each tracker type + print(f"Tracker type: {tracker_type} - FPS: {fps:.2f}") + result = subprocess.run( + "cat /proc/cpuinfo | grep 'model name' | head -1", + shell=True, + capture_output=True, + text=True + ) + print(result.stdout.strip()) + max_allowed_time = 0.005 # maximum allowed time per iteration in seconds + + assert elapsed_time_per_iteration < max_allowed_time, ( + f"Tracking algorithm's processing time per iteration ({elapsed_time_per_iteration:.6f}s) " + f"exceeds the allowed limit of {max_allowed_time}s." 
+ ) + + +@pytest.mark.parametrize("tracker_type", MOTION_N_APPEARANCE_TRACKING_NAMES) +def test_motion_n_appearance_tracker_update_time(tracker_type): + tracker_conf = get_tracker_config(tracker_type) + tracker = create_tracker( + tracker_type=tracker_type, + tracker_config=tracker_conf, + reid_weights=WEIGHTS / 'mobilenetv2_x1_4_dukemtmcreid.pt', + device='cpu', + half=False, + per_class=False + ) + + rgb = np.random.randint(0, 255, size=(640, 640, 3), dtype=np.uint8) + det = np.array([[144, 212, 578, 480, 0.82, 0], + [425, 281, 576, 472, 0.56, 65]]) + + n_runs = 100 + + # Warm-up iteration to avoid initialization overhead in timing + tracker.update(det, rgb) + + start = time.perf_counter() + for _ in range(n_runs): + tracker.update(det, rgb) + end = time.perf_counter() + + elapsed_time_per_iteration = (end - start) / n_runs + fps = 1.0 / elapsed_time_per_iteration + + # Print FPS for each tracker type + print(f"Tracker type: {tracker_type} - FPS: {fps:.2f}") + max_allowed_time = 6 # maximum allowed time per iteration in seconds + + assert elapsed_time_per_iteration < max_allowed_time, ( + f"Tracking algorithm's processing time per iteration ({elapsed_time_per_iteration:.4f}s) " + f"exceeds the allowed limit of {max_allowed_time}s." 
+ ) \ No newline at end of file diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000000000000000000000000000000000000..155ee92a11d19aee256f99d4f037c8641ddbe5f2 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,12 @@ +from boxmot import ( + StrongSort, BotSort, DeepOcSort, OcSort, ByteTrack, ImprAssocTrack, BoostTrack, get_tracker_config, create_tracker, +) + +MOTION_N_APPEARANCE_TRACKING_NAMES = ['botsort', 'deepocsort', 'strongsort', 'imprassoc', 'boosttrack'] +MOTION_ONLY_TRACKING_NAMES = ['ocsort', 'bytetrack'] + +MOTION_N_APPEARANCE_TRACKING_METHODS=[StrongSort, BotSort, DeepOcSort, ImprAssocTrack, BoostTrack] +MOTION_ONLY_TRACKING_METHODS=[OcSort, ByteTrack] + +ALL_TRACKERS = ['botsort', 'deepocsort', 'ocsort', 'bytetrack', 'strongsort', 'imprassoc', 'boosttrack'] +PER_CLASS_TRACKERS = ['botsort', 'deepocsort', 'ocsort', 'bytetrack', 'imprassoc'] \ No newline at end of file diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/unit/test_cmcs_u.py b/tests/unit/test_cmcs_u.py new file mode 100644 index 0000000000000000000000000000000000000000..3e12431d44816c71292c7a230aaa51dc047e459c --- /dev/null +++ b/tests/unit/test_cmcs_u.py @@ -0,0 +1,49 @@ +import cv2 +import numpy as np +import pytest +from boxmot.motion.cmc.ecc import ECC +from boxmot.motion.cmc.orb import ORB +from boxmot.motion.cmc.sift import SIFT +from boxmot.motion.cmc.sof import SOF + + + +# Fixture for creating CMC objects +@pytest.fixture +def cmc_object(request): + cmc_class = request.param + return cmc_class() + + +# Define the test function +@pytest.mark.parametrize("cmc_object", [ECC, ORB, SIFT, SOF], indirect=True) +def test_cmc_apply(cmc_object): + # Create dummy images and detections + prev_img = np.zeros((100, 100, 3), dtype=np.uint8) + dets = np.array([[0, 0, 10, 10]]) + # Apply the CMC algorithm + result = 
cmc_object.apply(prev_img, dets) + # Assert the type of result + assert isinstance(result, np.ndarray) + + +# Test preprocessing function +@pytest.mark.parametrize("cmc_object", [ECC, ORB, SIFT, SOF], indirect=True) +def test_cmc_preprocess(cmc_object): + # Create a dummy image + img = np.zeros((100, 100, 3), dtype=np.uint8) + processed_img = cmc_object.preprocess(img) + # Assert the shape of the processed image, scale is 0.1 by default + assert processed_img.shape == (10, 10) + + +# Test apply function with empty detections +@pytest.mark.parametrize("cmc_object", [ECC, ORB, SIFT, SOF], indirect=True) +def test_cmc_apply_empty_detections(cmc_object): + # Create dummy images and empty detections + prev_img = np.zeros((100, 100, 3), dtype=np.uint8) + dets = np.array([]) + # Apply the CMC algorithm + result = cmc_object.apply(prev_img, dets) + # Assert that result is an identity matrix + assert np.array_equal(result, np.eye(2, 3, dtype=np.float32)) diff --git a/tests/unit/test_cuda.py b/tests/unit/test_cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..cd19eed04c1981348bc250dae6aaab6bbc35293f --- /dev/null +++ b/tests/unit/test_cuda.py @@ -0,0 +1,46 @@ +import cv2 +import torch +import pytest +import numpy as np +from pathlib import Path +from boxmot.utils import ROOT + +from boxmot.appearance.reid.auto_backend import ReidAutoBackend + +REID_MODELS = [ + Path('mobilenetv2_x1_0_market1501.pt'), +] + + +@pytest.mark.parametrize("reid_model", REID_MODELS) +def test_reidbackend_device(reid_model): + + device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + + rab = ReidAutoBackend( + weights=reid_model, device=device, half=False + ) + r = rab.get_backend() + + if torch.cuda.is_available(): + assert next(r.model.parameters()).is_cuda + else: + assert next(r.model.parameters()).device.type == 'cpu' + + +@pytest.mark.parametrize("reid_model", REID_MODELS) +def test_reidbackend_half(reid_model): + + half = True if torch.cuda.is_available() else 
False + device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + rab = ReidAutoBackend( + weights=reid_model, device=device, half=False + ) + r = rab.get_backend() + + if device == 'cpu': + expected_dtype = torch.float32 + else: + expected_dtype = torch.float16 + actual_dtype = next(r.model.parameters()).dtype + assert actual_dtype == expected_dtype diff --git a/tests/unit/test_postprocessing.py b/tests/unit/test_postprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..db24a8bded53e3f8c296090aac40dedd86394019 --- /dev/null +++ b/tests/unit/test_postprocessing.py @@ -0,0 +1,14 @@ +import numpy as np +from boxmot.postprocessing.gsi import gaussian_smooth, linear_interpolation + + +def test_gsi(): + tracking_results = np.array([ + [1, 1, 1475, 419, 75, 169, 0, 0, -1], + [2, 1, 1475, 419, 75, 169, 0, 0, -1], + [4, 1, 1475, 419, 75, 169, 0, 0, -1], + [6, 1, 1475, 419, 75, 169, 0, 0, -1] + ]) + li = linear_interpolation(tracking_results, interval=20) + gsi = gaussian_smooth(li, tau=10) + assert len(gsi) == 6 \ No newline at end of file diff --git a/tests/unit/test_reidbackend.py b/tests/unit/test_reidbackend.py new file mode 100644 index 0000000000000000000000000000000000000000..f1d134acd971ab8ffca8229fc00eaa64187d77ff --- /dev/null +++ b/tests/unit/test_reidbackend.py @@ -0,0 +1,56 @@ +import cv2 +import pytest +import numpy as np +from pathlib import Path +from boxmot.utils import ROOT, WEIGHTS +from boxmot.appearance.backends.onnx_backend import ONNXBackend +from boxmot.appearance.backends.openvino_backend import OpenVinoBackend +from boxmot.appearance.backends.pytorch_backend import PyTorchBackend +from boxmot.appearance.backends.tensorrt_backend import TensorRTBackend +from boxmot.appearance.backends.tflite_backend import TFLiteBackend +from boxmot.appearance.backends.torchscript_backend import TorchscriptBackend + +from boxmot.appearance.reid.auto_backend import ReidAutoBackend + +# generated in previous job step 
+EXPORTED_REID_MODELS = [ + WEIGHTS / 'osnet_x0_25_msmt17.pt', + WEIGHTS / 'osnet_x0_25_msmt17.torchscript', + WEIGHTS / 'osnet_x0_25_msmt17.onnx', + WEIGHTS / 'osnet_x0_25_msmt17_openvino_model' +] + +ASSOCIATED_BACKEND = [ + PyTorchBackend, + TorchscriptBackend, + ONNXBackend, + OpenVinoBackend +] + + +@pytest.mark.parametrize("reid_model", EXPORTED_REID_MODELS) +def test_reidbackend_output(reid_model): + + rab = ReidAutoBackend( + weights=reid_model, device='cpu', half=False + ) + b = rab.get_backend() + + img = cv2.imread(str(ROOT / 'assets/MOT17-mini/train/MOT17-04-FRCNN/img1/000001.jpg')) + dets = np.array([[144, 212, 578, 480, 0.82, 0], + [425, 281, 576, 472, 0.56, 65]]) + + embs = b.get_features(dets[:, 0:4], img) + assert embs.shape[0] == 2 # two crops should give two embeddings + assert embs.shape[1] == 512 # osnet embeddings are of size 512 + + +@pytest.mark.parametrize("exported_reid_model, backend", zip(EXPORTED_REID_MODELS, ASSOCIATED_BACKEND)) +def test_reidbackend_type(exported_reid_model, backend): + + rab = ReidAutoBackend( + weights=exported_reid_model, device='cpu', half=False + ) + b = rab.get_backend() + + assert isinstance(b, backend) \ No newline at end of file diff --git a/tests/unit/test_trackers.py b/tests/unit/test_trackers.py new file mode 100644 index 0000000000000000000000000000000000000000..23b8629b7b251f46dc4f415c3f9bda52b55d6e13 --- /dev/null +++ b/tests/unit/test_trackers.py @@ -0,0 +1,180 @@ +import pytest +import numpy as np +from pathlib import Path +from boxmot.utils import WEIGHTS + + +from numpy.testing import assert_allclose +from boxmot import ( + StrongSort, BotSort, DeepOcSort, OcSort, ByteTrack, ImprAssocTrack, get_tracker_config, create_tracker, +) + +from boxmot.trackers.ocsort.ocsort import KalmanBoxTracker as OCSortKalmanBoxTracker +from boxmot.trackers.deepocsort.deepocsort import KalmanBoxTracker as DeepOCSortKalmanBoxTracker +from tests.test_config import MOTION_ONLY_TRACKING_METHODS, 
MOTION_N_APPEARANCE_TRACKING_METHODS, ALL_TRACKERS, PER_CLASS_TRACKERS + + +@pytest.mark.parametrize("Tracker", MOTION_N_APPEARANCE_TRACKING_METHODS) +def test_motion_n_appearance_trackers_instantiation(Tracker): + Tracker( + reid_weights=Path(WEIGHTS / 'osnet_x0_25_msmt17.pt'), + device='cpu', + half=True, + ) + + +@pytest.mark.parametrize("Tracker", MOTION_ONLY_TRACKING_METHODS) +def test_motion_only_trackers_instantiation(Tracker): + Tracker() + + +@pytest.mark.parametrize("tracker_type", ALL_TRACKERS) +def test_tracker_output_size(tracker_type): + tracker_conf = get_tracker_config(tracker_type) + tracker = create_tracker( + tracker_type=tracker_type, + tracker_config=tracker_conf, + reid_weights=WEIGHTS / 'mobilenetv2_x1_4_dukemtmcreid.pt', + device='cpu', + half=False, + per_class=False + ) + + rgb = np.random.randint(255, size=(640, 640, 3), dtype=np.uint8) + det = np.array([[144, 212, 400, 480, 0.82, 0], + [425, 281, 576, 472, 0.72, 65]]) + + output = tracker.update(det, rgb) + assert output.shape == (2, 8) # two inputs should give two outputs + + +def test_dynamic_max_obs_based_on_max_age(): + max_age = 400 + ocsort = OcSort( + max_age=max_age + ) + + assert ocsort.max_obs == (max_age + 5) + + +def create_kalman_box_tracker_ocsort(bbox, cls, det_ind, tracker): + return OCSortKalmanBoxTracker( + bbox, + cls, + det_ind, + Q_xy_scaling=tracker.Q_xy_scaling, + Q_s_scaling=tracker.Q_s_scaling + ) + + +def create_kalman_box_tracker_deepocsort(bbox, cls, det_ind, tracker): + # DeepOCSort KalmanBoxTracker expects input in different format than OCSort + det = np.concatenate([bbox, [cls, det_ind]]) + return DeepOCSortKalmanBoxTracker( + det, + Q_xy_scaling=tracker.Q_xy_scaling, + Q_s_scaling=tracker.Q_s_scaling + ) + + +TRACKER_CREATORS = { + OcSort: create_kalman_box_tracker_ocsort, + DeepOcSort: create_kalman_box_tracker_deepocsort, +} + + +@pytest.mark.parametrize("Tracker, init_args", [ + (OcSort, {}), + (DeepOcSort, { + 'reid_weights': Path(WEIGHTS / 
'osnet_x0_25_msmt17.pt'), + 'device': 'cpu', + 'half': True + }), +]) +def test_Q_matrix_scaling(Tracker, init_args): + bbox = np.array([0, 0, 100, 100, 0.9]) + cls = 1 + det_ind = 0 + Q_xy_scaling = 0.05 + Q_s_scaling = 0.0005 + + tracker = Tracker( + Q_xy_scaling=Q_xy_scaling, + Q_s_scaling=Q_s_scaling, + **init_args + ) + + create_kalman_box_tracker = TRACKER_CREATORS[Tracker] + kalman_box_tracker = create_kalman_box_tracker(bbox, cls, det_ind, tracker) + + assert kalman_box_tracker.kf.Q[4, 4] == Q_xy_scaling, "Q_xy scaling incorrect for x' velocity" + assert kalman_box_tracker.kf.Q[5, 5] == Q_xy_scaling, "Q_xy scaling incorrect for y' velocity" + assert kalman_box_tracker.kf.Q[6, 6] == Q_s_scaling, "Q_s scaling incorrect for s' (scale) velocity" + + +@pytest.mark.parametrize("tracker_type", PER_CLASS_TRACKERS) +def test_per_class_tracker_output_size(tracker_type): + + tracker_conf = get_tracker_config(tracker_type) + tracker = create_tracker( + tracker_type=tracker_type, + tracker_config=tracker_conf, + reid_weights=WEIGHTS / 'mobilenetv2_x1_4_dukemtmcreid.pt', + device='cpu', + half=False, + per_class=True + ) + + rgb = np.random.randint(255, size=(640, 640, 3), dtype=np.uint8) + det = np.array([[144, 212, 578, 480, 0.82, 0], + [425, 281, 576, 472, 0.72, 65]]) + embs = np.random.random(size=(2, 512)) + + output = tracker.update(det, rgb, embs) + output = tracker.update(det, rgb, embs) + assert output.shape == (2, 8) # two inputs should give two outputs + + +@pytest.mark.parametrize("tracker_type", PER_CLASS_TRACKERS) +def test_per_class_tracker_active_tracks(tracker_type): + + tracker_conf = get_tracker_config(tracker_type) + tracker = create_tracker( + tracker_type=tracker_type, + tracker_config=tracker_conf, + reid_weights=WEIGHTS / 'mobilenetv2_x1_4_dukemtmcreid.pt', + device='cpu', + half=False, + per_class=True + ) + + rgb = np.random.randint(255, size=(640, 640, 3), dtype=np.uint8) + det = np.array([[144, 212, 578, 480, 0.82, 0], + [425, 281, 576, 472, 
0.72, 65]]) + embs = np.random.random(size=(2, 512)) + + tracker.update(det, rgb, embs) + + # Check that tracks are created under the class tracks + assert tracker.per_class_active_tracks[0], "No active tracks for class 0" + assert tracker.per_class_active_tracks[65], "No active tracks for class 65" + + +@pytest.mark.parametrize("tracker_type", ALL_TRACKERS) +@pytest.mark.parametrize("dets", [None, np.array([])]) +def test_tracker_with_no_detections(tracker_type, dets): + tracker_conf = get_tracker_config(tracker_type) + tracker = create_tracker( + tracker_type=tracker_type, + tracker_config=tracker_conf, + reid_weights=WEIGHTS / 'mobilenetv2_x1_4_dukemtmcreid.pt', + device='cpu', + half=False, + per_class=False + ) + + rgb = np.random.randint(255, size=(640, 640, 3), dtype=np.uint8) + embs = np.random.random(size=(2, 512)) + + output = tracker.update(dets, rgb, embs) + assert output.size == 0, "Output should be empty when no detections are provided" \ No newline at end of file diff --git a/tracking/detectors/__init__.py b/tracking/detectors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c02b77e2aa696f0467a4794625899175c95d341e --- /dev/null +++ b/tracking/detectors/__init__.py @@ -0,0 +1,63 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +from boxmot.utils import logger as LOGGER +from boxmot.utils.checks import RequirementsChecker + +checker = RequirementsChecker() + +UL_MODELS = ['yolov8', 'yolov9', 'yolov10', 'yolo11', 'yolo12', 'rtdetr', 'sam'] + + +def is_ultralytics_model(yolo_name): + return any(yolo in str(yolo_name) for yolo in UL_MODELS) + + +def is_yolox_model(yolo_name): + return 'yolox' in str(yolo_name) + + +def default_imgsz(yolo_name): + if is_ultralytics_model(yolo_name): + return [640, 640] + elif is_yolox_model(yolo_name): + return [800, 1440] + else: + return [640, 640] + + +def get_yolo_inferer(yolo_model): + + if is_yolox_model(yolo_model): + try: + import yolox # for linear_assignment + assert 
yolox.__version__ + except (ImportError, AssertionError, AttributeError): + checker.check_packages(('yolox',)) + checker.check_packages(('tabulate',)) # needed dependency + checker.check_packages(('thop',)) # needed dependency + from .yolox import YoloXStrategy + return YoloXStrategy + elif 'yolov8' in str(yolo_model): + # ultralytics already installed when running track.py + from .yolov8 import Yolov8Strategy + return Yolov8Strategy + elif 'rf-detr' in str(yolo_model): + try: + import rfdetr + except (ImportError, AssertionError, AttributeError): + checker.check_packages(('onnxruntime',)) # needed dependency + checker.check_packages(('rfdetr',)) # needed dependency + from .rfdetr import RFDETRStrategy + return RFDETRStrategy + elif 'yolo_nas' in str(yolo_model): + try: + import super_gradients # for linear_assignment + assert super_gradients.__version__ + except (ImportError, AssertionError, AttributeError): + checker.check_packages(('super-gradients==3.1.3',)) # install + from .yolonas import YoloNASStrategy + return YoloNASStrategy + else: + LOGGER.error('Failed to infer inference mode from yolo model name') + LOGGER.error('Your model name has to contain either yolox, yolo_nas or yolov8') + exit() diff --git a/tracking/detectors/rfdetr.py b/tracking/detectors/rfdetr.py new file mode 100644 index 0000000000000000000000000000000000000000..da22ea46e86339a96657c9fb3b824b065f114188 --- /dev/null +++ b/tracking/detectors/rfdetr.py @@ -0,0 +1,83 @@ +# Mikel Broström 🔥 RFDETR Tracking 🧾 AGPL-3.0 license + +import numpy as np +import torch +import cv2 +from PIL import Image +from rfdetr import RFDETRBase +from rfdetr.util.coco_classes import COCO_CLASSES +from ultralytics.engine.results import Results +from ultralytics.utils import ops +from ultralytics.models.yolo.detect import DetectionPredictor + + + +from boxmot.utils import logger as LOGGER +from tracking.detectors.yolo_interface import YoloInterface + + +class RFDETRStrategy(YoloInterface): + pt = False + stride = 
32 + fp16 = False + triton = False + names = COCO_CLASSES + + def __init__(self, model, device, args): + self.args = args + LOGGER.info("Loading RFDETR model") + self.model = RFDETRBase(device='cpu') + + @torch.no_grad() + def __call__(self, im, augment, visualize, embed): + + # Convert frame to PIL Image format for RFDETR + frame_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + image = Image.fromarray(frame_rgb) + with torch.no_grad(): + detections = self.model.predict(im, threshold=self.args.conf) + + preds = np.column_stack( + [ + detections.xyxy, + detections.confidence[:, np.newaxis], + detections.class_id[:, np.newaxis] + ] + ) + + preds = torch.from_numpy(preds).unsqueeze(0) + return preds + + def warmup(self, imgsz): + pass + + def update_im_paths(self, predictor: DetectionPredictor): + """ + This function saves image paths for the current batch, + being passed as callback on_predict_batch_start + """ + assert (isinstance(predictor, DetectionPredictor), + "Only ultralytics predictors are supported") + self.im_paths = predictor.batch[0] + + def preprocess(self, im) -> torch.Tensor: + assert isinstance(im, list) + return im[0] + + def postprocess(self, preds, im, im0s): + results = [] + for i, pred in enumerate(preds): + im_path = self.im_paths[i] if len(self.im_paths) else "" + if pred is None or len(pred) == 0: + pred = torch.empty((0, 6)) + else: + if self.args.classes: + pred = pred[torch.isin(pred[:, 5].cpu(), torch.as_tensor(self.args.classes))] + r = Results( + path=im_path, + boxes=pred, + orig_img=im0s[i], + names=COCO_CLASSES + ) + results.append(r) + return results diff --git a/tracking/detectors/yolo_interface.py b/tracking/detectors/yolo_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..815bacdb5f4df986f9aff09539f46a87a671699d --- /dev/null +++ b/tracking/detectors/yolo_interface.py @@ -0,0 +1,56 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +from pathlib import Path + +import numpy as np +import torch +from 
ultralytics.engine.results import Results +from abc import ABC, abstractmethod + + +class YoloInterface(ABC): + + @abstractmethod + def __call__(self, im): + pass + + @abstractmethod + def preprocess(self, ims): + pass + + @abstractmethod + def postprocess(self, preds): + pass + + def get_scaling_factors(self, im, im0): + + # im to im0 factor for predictions + im0_w = im0.shape[1] + im0_h = im0.shape[0] + im_w = im.shape[2] + im_h = im.shape[1] + w_r = im0_w / im_w + h_r = im0_h / im_h + + return im_w, im_h, w_r, h_r + + def scale_and_clip(self, preds, im_w, im_h, w_r, h_r): + # scale bboxes to original image + preds[:, [0, 2]] = preds[:, [0, 2]] * self.w_r + preds[:, [1, 3]] = preds[:, [1, 3]] * self.h_r + + if not isinstance(preds, (torch.Tensor)): + preds = torch.from_numpy(preds) + + preds[:, [0, 2]] = torch.clip(preds[:, [0, 2]], min=0) # max=im_w + preds[:, [1, 3]] = torch.clip(preds[:, [1, 3]], min=0) # max=im_h + + return preds + + def get_model_from_weigths(self, l, model): + model_type = None + for key in l: + if Path(key).stem in str(model.name): + model_type = str(Path(key).with_suffix('')) + break + return model_type diff --git a/tracking/detectors/yolonas.py b/tracking/detectors/yolonas.py new file mode 100644 index 0000000000000000000000000000000000000000..10167f3f2de95286fedd11519a6a8474198501a6 --- /dev/null +++ b/tracking/detectors/yolonas.py @@ -0,0 +1,117 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import numpy as np +import torch +from super_gradients.common.object_names import Models +from super_gradients.training import models +from ultralytics.engine.results import Results +from ultralytics.utils import ops + +from boxmot.utils import logger as LOGGER +from tracking.detectors.yolo_interface import YoloInterface + + +class YoloNASStrategy(YoloInterface): + pt = False + stride = 32 + fp16 = False + triton = False + names = { + 0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', + 6: 'train', 7: 
'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', + 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', + 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', + 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', + 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', + 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', + 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', + 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', + 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', + 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', + 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', + 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', + 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', + 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', + 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush' + } + + def __init__(self, model, device, args): + self.args = args + + avail_models = [x.lower() for x in list(Models.__dict__.keys())] + model_type = self.get_model_from_weigths(avail_models, model) + + LOGGER.info(f'Loading {model_type} with {str(model)}') + if not model.exists() and model.stem == model_type: + LOGGER.info('Downloading pretrained weights...') + self.model = models.get( + model_type, + pretrained_weights="coco" + ).to(device) + else: + self.model = models.get( + model_type, + num_classes=-1, # set your num classes + checkpoint_path=str(model) + ).to(device) + + self.device = device + + @torch.no_grad() + def __call__(self, im, augment, visualize, embed): + + im = im[0].permute(1, 2, 0).cpu().numpy() * 255 + + with torch.no_grad(): + preds = self.model.predict( + im, + iou=self.args.iou, + conf=self.args.conf, + fuse_model=False + )[0].prediction + + preds = 
np.concatenate( + [ + preds.bboxes_xyxy, + preds.confidence[:, np.newaxis], + preds.labels[:, np.newaxis] + ], axis=1 + ) + + preds = torch.from_numpy(preds).unsqueeze(0) + + return preds + + def warmup(self, imgsz): + pass + + def postprocess(self, path, preds, im, im0s): + + results = [] + for i, pred in enumerate(preds): + + if pred is None: + pred = torch.empty((0, 6)) + r = Results( + path=path, + boxes=pred, + orig_img=im0s[i], + names=self.names + ) + results.append(r) + else: + + pred[:, :4] = ops.scale_boxes(im.shape[2:], pred[:, :4], im0s[i].shape) + + # filter boxes by classes + if self.args.classes: + pred = pred[torch.isin(pred[:, 5].cpu(), torch.as_tensor(self.args.classes))] + + r = Results( + path=path, + boxes=pred, + orig_img=im0s[i], + names=self.names + ) + results.append(r) + return results diff --git a/tracking/detectors/yolov8.py b/tracking/detectors/yolov8.py new file mode 100644 index 0000000000000000000000000000000000000000..445946179e4a94bbc6de8574b28ba0edefd6ff24 --- /dev/null +++ b/tracking/detectors/yolov8.py @@ -0,0 +1,17 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +from .yolo_interface import YoloInterface + + +class Yolov8Strategy(YoloInterface): + + def __init__(self, model, device, args): + self.model = model + + def inference(self, im): + preds = self.model(im, augment=False, visualize=False) + return preds + + def postprocess(self, path, preds, im, im0s, predictor): + postprocessed_preds = predictor.postprocess(preds, im, im0s) + return postprocessed_preds diff --git a/tracking/detectors/yolov9.py b/tracking/detectors/yolov9.py new file mode 100644 index 0000000000000000000000000000000000000000..39fd70174ccf28688832d63ed4e10c5f4a7ff951 --- /dev/null +++ b/tracking/detectors/yolov9.py @@ -0,0 +1,118 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import numpy as np +import torch +from super_gradients.common.object_names import Models +from super_gradients.training import models +from 
ultralytics.engine.results import Results +from ultralytics.utils import ops +from ultralytics.utils.downloads import download +from ultralytics.utils import ops + +from yolov9 import load +from boxmot.utils import logger as LOGGER +from examples.detectors.yolo_interface import YoloInterface + + +YOLOv9_ZOO = { + 'gelan-c.pt': 'https://github.com/WongKinYiu/yolov9/releases/download/v0.1/gelan-c.pt', + 'gelan-e.pt': 'https://github.com/WongKinYiu/yolov9/releases/download/v0.1/gelan-e.pt', + 'yolov9-c.pt': 'https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-c.pt', + 'yolov9-e.pt': 'https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-e.pt', +} + +class Yolov9Strategy(YoloInterface): + pt = False + stride = 32 + fp16 = False + triton = False + names = { + 0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', + 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', + 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', + 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', + 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', + 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', + 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', + 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', + 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', + 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', + 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', + 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', + 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', + 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', + 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', + 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush' + } + 
+ def __init__(self, model, device, args): + + self.args = args + self.pt = False + self.stride = 32 # max stride in YOLOX + + # model_type one of: 'yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x' + model_type = self.get_model_from_weigths(YOLOv9_ZOO.keys(), model) + + LOGGER.info(f'Loading {model_type} with {str(model)}') + + # download crowdhuman bytetrack models + if not model.exists() and model.stem == model_type: + LOGGER.info('Downloading Yolov9 pretrained weights...') + # download( + # url=YOLOv9_ZOO[model_type + '.pt'], + # dir="./weights", + # ) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + self.model = load( + "/home/mikel.brostrom/yolo_tracking/examples/weights/yolov9-c.pt", + device=device, + ) + + self.model.conf = args.conf + self.model.iou = args.iou + self.model.classes = args.classes + + + @torch.no_grad() + def __call__(self, im, augment, visualize): + + im = im[0].permute(1, 2, 0).cpu().numpy() * 255 + + with torch.no_grad(): + results = self.model(im) + preds = results.pred[0] + + preds = preds.unsqueeze(0) + + return preds + + def warmup(self, imgsz): + pass + + def postprocess(self, path, preds, im, im0s): + + results = [] + for i, pred in enumerate(preds): + + if pred is None: + pred = torch.empty((0, 6)) + r = Results( + path=path, + boxes=pred, + orig_img=im0s[i], + names=self.names + ) + results.append(r) + else: + pred = self.clip(pred, im0s[i]) + r = Results( + path=path, + boxes=pred, + orig_img=im0s[i], + names=self.names + ) + results.append(r) + return results diff --git a/tracking/detectors/yolox.py b/tracking/detectors/yolox.py new file mode 100644 index 0000000000000000000000000000000000000000..b3d736253d55fe440bc152172929824e2b8ca338 --- /dev/null +++ b/tracking/detectors/yolox.py @@ -0,0 +1,221 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import gdown +import torch +import numpy as np +import cv2 +from ultralytics.engine.results import Results +from ultralytics.utils import 
ops +from ultralytics.models.yolo.detect import DetectionPredictor +from yolox.exp import get_exp +from yolox.utils import postprocess +from yolox.utils.model_utils import fuse_model + +from boxmot.utils import logger as LOGGER +from tracking.detectors.yolo_interface import YoloInterface + +# default model weigths for these model names +YOLOX_ZOO = { + 'yolox_n.pt': 'https://drive.google.com/uc?id=1AoN2AxzVwOLM0gJ15bcwqZUpFjlDV1dX', + 'yolox_s.pt': 'https://drive.google.com/uc?id=1uSmhXzyV1Zvb4TJJCzpsZOIcw7CCJLxj', + 'yolox_m.pt': 'https://drive.google.com/uc?id=11Zb0NN_Uu7JwUd9e6Nk8o2_EUfxWqsun', + 'yolox_l.pt': 'https://drive.google.com/uc?id=1XwfUuCBF4IgWBWK2H7oOhQgEj9Mrb3rz', + 'yolox_x.pt': 'https://drive.google.com/uc?id=1P4mY0Yyd3PPTybgZkjMYhFri88nTmJX5', + 'yolox_x_ablation.pt': 'https://drive.google.com/uc?id=1iqhM-6V_r1FpOlOzrdP_Ejshgk0DxOob', +} + + +class YoloXStrategy(YoloInterface): + pt = False + stride = 32 + fp16 = False + triton = False + names = { + 0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', + 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', + 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', + 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', + 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', + 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', + 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', + 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', + 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', + 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', + 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', + 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', + 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', + 66: 'keyboard', 67: 'cell phone', 
68: 'microwave', 69: 'oven', 70: 'toaster', + 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', + 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush' + } + + def __init__(self, model, device, args): + + self.args = args + self.imgsz = args.imgsz + self.pt = False + self.stride = 32 # max stride in YOLOX + + # model_type one of: 'yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x' + model_type = self.get_model_from_weigths(YOLOX_ZOO.keys(), model) + + if model_type == 'yolox_n': + exp = get_exp(None, 'yolox_nano') + else: + exp = get_exp(None, model_type) + + LOGGER.info(f'Loading {model_type} with {str(model)}') + + # download crowdhuman bytetrack models + if not model.exists() and (model.stem == model_type or + model.stem == 'yolox_x_ablation'): + LOGGER.info('Downloading pretrained weights...') + gdown.download( + url=YOLOX_ZOO[model.stem + '.pt'], + output=str(model), + quiet=False + ) + # needed for bytetrack yolox people models + # update with your custom model needs + exp.num_classes = 1 + elif model.stem.startswith(model_type): + exp.num_classes = 1 + + ckpt = torch.load( + str(model), + map_location=torch.device('cpu') + ) + + self.device = device + self.model = exp.get_model() + self.model.eval() + self.model.load_state_dict(ckpt["model"]) + self.model = fuse_model(self.model) + self.model.to(self.device) + self.model.eval() + self.im_paths = [] + self._preproc_data = [] + + @torch.no_grad() + def __call__(self, im, augment, visualize, embed): + if isinstance(im, list): + if len(im[0].shape) == 3: + im = torch.stack(im) + else: + im = torch.vstack(im) + + if len(im.shape) == 3: + im = im.unsqueeze(0) + + assert len(im.shape) == 4, f"Expected 4D tensor as input, got {im.shape}" + + preds = self.model(im) + return preds + + def warmup(self, imgsz): + pass + + def update_im_paths(self, predictor: DetectionPredictor): + """ + This function saves image paths for the current batch, + being passed as callback 
on_predict_batch_start + """ + assert (isinstance(predictor, DetectionPredictor), + "Only ultralytics predictors are supported") + self.im_paths = predictor.batch[0] + + # This preprocess differs from the current version of YOLOX preprocess, but ByteTrack uses it + # https://github.com/ifzhang/ByteTrack/blob/d1bf0191adff59bc8fcfeaa0b33d3d1642552a99/yolox/data/data_augment.py#L189 + def yolox_preprocess( + self, + image, + input_size, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + swap=(2, 0, 1) + ): + if len(image.shape) == 3: + padded_img = np.ones((input_size[0], input_size[1], 3)) * 114.0 + else: + padded_img = np.ones(input_size) * 114.0 + img = np.array(image) + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.float32) + padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img + + padded_img = padded_img[:, :, ::-1] + padded_img /= 255.0 + if mean is not None: + padded_img -= mean + if std is not None: + padded_img /= std + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + + + def preprocess(self, im) -> torch.Tensor: + assert isinstance(im, list) + im_preprocessed = [] + self._preproc_data = [] + for i, img in enumerate(im): + img_pre, ratio = self.yolox_preprocess(img, input_size=self.imgsz) + img_pre = torch.Tensor(img_pre).unsqueeze(0).to(self.device) + + im_preprocessed.append(img_pre) + self._preproc_data.append(ratio) + + im_preprocessed = torch.vstack(im_preprocessed) + + return im_preprocessed + + def postprocess(self, preds, im, im0s): + + results = [] + for i, pred in enumerate(preds): + im_path = self.im_paths[i] if len(self.im_paths) else "" + + pred = postprocess( + pred.unsqueeze(0), # YOLOX postprocessor expects 3D arary + 1, + conf_thre=self.args.conf, + nms_thre=self.args.iou, + 
class_agnostic=self.args.agnostic_nms + )[0] + + if pred is None: + pred = torch.empty((0, 6)) + r = Results( + path=im_path, + boxes=pred, + orig_img=im0s[i], + names=self.names + ) + results.append(r) + else: + ratio = self._preproc_data[i] + pred[:, 0] = pred[:, 0] / ratio + pred[:, 1] = pred[:, 1] / ratio + pred[:, 2] = pred[:, 2] / ratio + pred[:, 3] = pred[:, 3] / ratio + pred[:, 4] *= pred[:, 5] + pred = pred[:, [0, 1, 2, 3, 4, 6]] + + # filter boxes by classes + if self.args.classes: + pred = pred[torch.isin(pred[:, 5].cpu(), torch.as_tensor(self.args.classes))] + + r = Results( + path=im_path, + boxes=pred, + orig_img=im0s[i], + names=self.names + ) + + results.append(r) + + return results diff --git a/tracking/evolve.py b/tracking/evolve.py new file mode 100644 index 0000000000000000000000000000000000000000..fe7828febfaec9bd06fdedc8a08430b9f4a4122b --- /dev/null +++ b/tracking/evolve.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +""" +This script runs a hyperparameter tuning process for a multi-object tracking (MOT) tracker using Ray Tune. +It loads the tracker configuration from a YAML file, sets up the search space for hyperparameters, and evaluates +the tracker to optimize selected metrics (e.g., MOTA, HOTA, IDF1). +""" + +import os +from pathlib import Path +import yaml + +# Check required packages +from boxmot.utils.checks import RequirementsChecker +checker = RequirementsChecker() +checker.check_packages(('ray[tune]',)) # Install ray[tune] if not already present + +import ray +from ray import tune +from ray.air import RunConfig + +from boxmot.utils.checks import RequirementsChecker +from boxmot.utils import EXAMPLES, TRACKER_CONFIGS, ROOT, NUM_THREADS +from tracking.val import ( + run_generate_dets_embs, + run_generate_mot_results, + run_trackeval, + parse_opt as parse_optt, + download_mot_eval_tools +) + + +class Tracker: + """ + Encapsulates the evaluation of a tracking configuration. 
+ """ + def __init__(self, opt): + self.opt = opt + + def objective_function(self, config: dict) -> dict: + """ + Evaluates a given tracker configuration. + + Args: + config (dict): A dictionary of tracker hyperparameters. + + Returns: + dict: Combined evaluation metrics extracted from run_trackeval. + """ + # Ensure evaluation tools are available + download_mot_eval_tools(self.opt.val_tools_path) + # Generate MOT-compliant results with the specified tracker parameters + run_generate_mot_results(self.opt, config) + # Retrieve evaluation metrics (e.g., MOTA, HOTA, IDF1) + results = run_trackeval(self.opt) + # Extract only the desired objective results + combined_results = {key: results.get(key) for key in self.opt.objectives} + return combined_results + + +def load_yaml_config(tracking_method: str) -> dict: + """ + Loads the YAML configuration file for the given tracking method. + + Args: + tracking_method (str): Name of the tracking method. + + Returns: + dict: Configuration parameters loaded from the YAML file. + """ + config_path = TRACKER_CONFIGS / f"{tracking_method}.yaml" + with open(config_path, 'r') as file: + config = yaml.safe_load(file) + return config + + +def yaml_to_search_space(config: dict) -> dict: + """ + Converts a YAML configuration dictionary to a Ray Tune search space. + + Args: + config (dict): YAML configuration parameters. + + Returns: + dict: A dictionary representing the search space for hyperparameters. 
+ """ + search_space = {} + for param, details in config.items(): + search_type = details.get('type') + if search_type == 'uniform': + search_space[param] = tune.uniform(*details['range']) + elif search_type == 'randint': + search_space[param] = tune.randint(*details['range']) + elif search_type == 'qrandint': + search_space[param] = tune.qrandint(*details['range']) + elif search_type == 'choice': + search_space[param] = tune.choice(details['options']) + elif search_type == 'grid_search': + search_space[param] = tune.grid_search(details['values']) + elif search_type == 'loguniform': + search_space[param] = tune.loguniform(*details['range']) + return search_space + + +def main(): + # Parse options and set necessary paths + opt = parse_optt() + opt.val_tools_path = EXAMPLES / 'val_utils' + opt.source = Path(opt.source).resolve() + opt.yolo_model = [Path(y).resolve() for y in opt.yolo_model] + opt.reid_model = [Path(r).resolve() for r in opt.reid_model] + + # Load YAML configuration and convert it to a Ray Tune search space + yaml_config = load_yaml_config(opt.tracking_method) + search_space = yaml_to_search_space(yaml_config) + + # Create a Tracker instance + tracker = Tracker(opt) + + # Generate detection and embedding files required for evaluation + run_generate_dets_embs(opt) + + # Define a wrapper for the objective function for Ray Tune + def tune_wrapper(config): + return tracker.objective_function(config) + + results_dir = os.path.abspath("ray/") + + # Set up and run the hyperparameter tuning using Ray Tune + tuner = tune.Tuner( + tune.with_resources(tune_wrapper, {"cpu": NUM_THREADS, "gpu": 0}), + param_space=search_space, + tune_config=tune.TuneConfig(num_samples=opt.n_trials), + run_config=RunConfig(storage_path=results_dir) + ) + tuner.fit() + + # Print the tuning results + print(tuner.get_results()) + + +if __name__ == "__main__": + main() diff --git a/tracking/track.py b/tracking/track.py new file mode 100644 index 
0000000000000000000000000000000000000000..d28dd8377c44107a318e95ddcb813bf21d79a3e5 --- /dev/null +++ b/tracking/track.py @@ -0,0 +1,191 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import argparse +import cv2 +import numpy as np +from functools import partial +from pathlib import Path + +import torch + +from boxmot import TRACKERS +from boxmot.tracker_zoo import create_tracker +from boxmot.utils import ROOT, WEIGHTS, TRACKER_CONFIGS +from boxmot.utils.checks import RequirementsChecker +from tracking.detectors import (get_yolo_inferer, default_imgsz, + is_ultralytics_model, is_yolox_model) + +# checker = RequirementsChecker() +# checker.check_packages(('ultralytics @ git+https://github.com/mikel-brostrom/ultralytics.git', )) # install + +from ultralytics import YOLO +from ultralytics.utils.plotting import Annotator, colors +from ultralytics.data.utils import VID_FORMATS +from ultralytics.utils.plotting import save_one_box + + +def on_predict_start(predictor, persist=False): + """ + Initialize trackers for object tracking during prediction. + + Args: + predictor (object): The predictor object to initialize trackers for. + persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False. + """ + + assert predictor.custom_args.tracking_method in TRACKERS, \ + f"'{predictor.custom_args.tracking_method}' is not supported. 
Supported ones are {TRACKERS}" + + tracking_config = TRACKER_CONFIGS / (predictor.custom_args.tracking_method + '.yaml') + trackers = [] + for i in range(predictor.dataset.bs): + tracker = create_tracker( + predictor.custom_args.tracking_method, + tracking_config, + predictor.custom_args.reid_model, + predictor.device, + predictor.custom_args.half, + predictor.custom_args.per_class + ) + # motion only modeles do not have + if hasattr(tracker, 'model'): + tracker.model.warmup() + trackers.append(tracker) + + predictor.trackers = trackers + + +@torch.no_grad() +def run(args): + + if args.imgsz is None: + args.imgsz = default_imgsz(args.yolo_model) + yolo = YOLO( + args.yolo_model if is_ultralytics_model(args.yolo_model) + else 'yolov8n.pt', + ) + + results = yolo.track( + source=args.source, + conf=args.conf, + iou=args.iou, + agnostic_nms=args.agnostic_nms, + show=False, + stream=True, + device=args.device, + show_conf=args.show_conf, + save_txt=args.save_txt, + show_labels=args.show_labels, + save=args.save, + verbose=args.verbose, + exist_ok=args.exist_ok, + project=args.project, + name=args.name, + classes=args.classes, + imgsz=args.imgsz, + vid_stride=args.vid_stride, + line_width=args.line_width + ) + + yolo.add_callback('on_predict_start', partial(on_predict_start, persist=True)) + + if not is_ultralytics_model(args.yolo_model): + # replace yolov8 model + m = get_yolo_inferer(args.yolo_model) + yolo_model = m(model=args.yolo_model, device=yolo.predictor.device, + args=yolo.predictor.args) + yolo.predictor.model = yolo_model + + # If current model is YOLOX, change the preprocess and postprocess + if not is_ultralytics_model(args.yolo_model): + # add callback to save image paths for further processing + yolo.add_callback( + "on_predict_batch_start", + lambda p: yolo_model.update_im_paths(p) + ) + yolo.predictor.preprocess = ( + lambda imgs: yolo_model.preprocess(im=imgs)) + yolo.predictor.postprocess = ( + lambda preds, im, im0s: + 
yolo_model.postprocess(preds=preds, im=im, im0s=im0s)) + + # store custom args in predictor + yolo.predictor.custom_args = args + + for r in results: + + + if hasattr(yolo.predictor.trackers[0], "plot_results"): + img = yolo.predictor.trackers[0].plot_results(r.orig_img, args.show_trajectories) + else: + # Ultralytics Results handles its own image internally + img = r.plot() + + if args.show is True: + cv2.imshow('BoxMOT', img) + key = cv2.waitKey(1) & 0xFF + if key == ord(' ') or key == ord('q'): + break + +def parse_opt(): + + parser = argparse.ArgumentParser() + parser.add_argument('--yolo-model', type=Path, default=WEIGHTS / 'yolov8n', + help='yolo model path') + parser.add_argument('--reid-model', type=Path, default=WEIGHTS / 'osnet_x0_25_msmt17.pt', + help='reid model path') + parser.add_argument('--tracking-method', type=str, default='deepocsort', + help='deepocsort, botsort, strongsort, ocsort, bytetrack, imprassoc, boosttrack') + parser.add_argument('--source', type=str, default='0', + help='file/dir/URL/glob, 0 for webcam') + parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=None, + help='inference size h,w') + parser.add_argument('--conf', type=float, default=0.5, + help='confidence threshold') + parser.add_argument('--iou', type=float, default=0.7, + help='intersection over union (IoU) threshold for NMS') + parser.add_argument('--device', default='', + help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--show', action='store_true', + help='display tracking video results') + parser.add_argument('--save', action='store_true', + help='save video tracking results') + # class 0 is person, 1 is bycicle, 2 is car... 
79 is oven + parser.add_argument('--classes', nargs='+', type=int, + help='filter by class: --classes 0, or --classes 0 2 3') + parser.add_argument('--project', default=ROOT / 'runs' / 'track', + help='save results to project/name') + parser.add_argument('--name', default='exp', + help='save results to project/name') + parser.add_argument('--exist-ok', action='store_true', + help='existing project/name ok, do not increment') + parser.add_argument('--half', action='store_true', + help='use FP16 half-precision inference') + parser.add_argument('--vid-stride', type=int, default=1, + help='video frame-rate stride') + parser.add_argument('--show-labels', action='store_false', + help='either show all or only bboxes') + parser.add_argument('--show-conf', action='store_false', + help='hide confidences when show') + parser.add_argument('--show-trajectories', action='store_true', + help='show confidences') + parser.add_argument('--save-txt', action='store_true', + help='save tracking results in a txt file') + parser.add_argument('--save-id-crops', action='store_true', + help='save each crop to its respective id folder') + parser.add_argument('--line-width', default=None, type=int, + help='The line width of the bounding boxes. 
If None, it is scaled to the image size.') + parser.add_argument('--per-class', default=False, action='store_true', + help='not mix up classes when tracking') + parser.add_argument('--verbose', default=True, action='store_true', + help='print results per frame') + parser.add_argument('--agnostic-nms', default=False, action='store_true', + help='class-agnostic NMS') + + opt = parser.parse_args() + return opt + + +if __name__ == "__main__": + opt = parse_opt() + run(opt) diff --git a/tracking/ultralytics/__init__.py b/tracking/ultralytics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7418168ee1bb94b941c9a5962dbdd51ca16912be --- /dev/null +++ b/tracking/ultralytics/__init__.py @@ -0,0 +1,29 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +__version__ = "8.3.96" + +import os + +# Set ENV variables (place before imports) +if not os.environ.get("OMP_NUM_THREADS"): + os.environ["OMP_NUM_THREADS"] = "1" # default for reduced CPU utilization during training + +from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld +from ultralytics.utils import ASSETS, SETTINGS +from ultralytics.utils.checks import check_yolo as checks +from ultralytics.utils.downloads import download + +settings = SETTINGS +__all__ = ( + "__version__", + "ASSETS", + "YOLO", + "YOLOWorld", + "NAS", + "SAM", + "FastSAM", + "RTDETR", + "checks", + "download", + "settings", +) diff --git a/tracking/ultralytics/cfg/__init__.py b/tracking/ultralytics/cfg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49445ec545ce88456d450f8216a398cdcc5dc964 --- /dev/null +++ b/tracking/ultralytics/cfg/__init__.py @@ -0,0 +1,1029 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import shutil +import subprocess +import sys +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Dict, List, Union + +import cv2 + +from ultralytics.utils import ( + ASSETS, + 
DEFAULT_CFG, + DEFAULT_CFG_DICT, + DEFAULT_CFG_PATH, + DEFAULT_SOL_DICT, + IS_VSCODE, + LOGGER, + RANK, + ROOT, + RUNS_DIR, + SETTINGS, + SETTINGS_FILE, + TESTS_RUNNING, + IterableSimpleNamespace, + __version__, + checks, + colorstr, + deprecation_warn, + vscode_msg, + yaml_load, + yaml_print, +) + +# Define valid solutions +SOLUTION_MAP = { + "count": "ObjectCounter", + "crop": "ObjectCropper", + "blur": "ObjectBlurrer", + "workout": "AIGym", + "heatmap": "Heatmap", + "isegment": "InstanceSegmentation", + "visioneye": "VisionEye", + "speed": "SpeedEstimator", + "queue": "QueueManager", + "analytics": "Analytics", + "inference": "Inference", + "trackzone": "TrackZone", + "help": None, +} + +# Define valid tasks and modes +MODES = frozenset({"train", "val", "predict", "export", "track", "benchmark"}) +TASKS = frozenset({"detect", "segment", "classify", "pose", "obb"}) +TASK2DATA = { + "detect": "coco8.yaml", + "segment": "coco8-seg.yaml", + "classify": "imagenet10", + "pose": "coco8-pose.yaml", + "obb": "dota8.yaml", +} +TASK2MODEL = { + "detect": "yolo11n.pt", + "segment": "yolo11n-seg.pt", + "classify": "yolo11n-cls.pt", + "pose": "yolo11n-pose.pt", + "obb": "yolo11n-obb.pt", +} +TASK2METRIC = { + "detect": "metrics/mAP50-95(B)", + "segment": "metrics/mAP50-95(M)", + "classify": "metrics/accuracy_top1", + "pose": "metrics/mAP50-95(P)", + "obb": "metrics/mAP50-95(B)", +} +MODELS = frozenset({TASK2MODEL[task] for task in TASKS}) + +ARGV = sys.argv or ["", ""] # sometimes sys.argv = [] +SOLUTIONS_HELP_MSG = f""" + Arguments received: {str(["yolo"] + ARGV[1:])}. Ultralytics 'yolo solutions' usage overview: + + yolo solutions SOLUTION ARGS + + Where SOLUTION (optional) is one of {list(SOLUTION_MAP.keys())[:-1]} + ARGS (optional) are any number of custom 'arg=value' pairs like 'show_in=True' that override defaults + at https://docs.ultralytics.com/usage/cfg + + 1. 
Call object counting solution + yolo solutions count source="path/to/video.mp4" region="[(20, 400), (1080, 400), (1080, 360), (20, 360)]" + + 2. Call heatmaps solution + yolo solutions heatmap colormap=cv2.COLORMAP_PARULA model=yolo11n.pt + + 3. Call queue management solution + yolo solutions queue region="[(20, 400), (1080, 400), (1080, 360), (20, 360)]" model=yolo11n.pt + + 4. Call workouts monitoring solution for push-ups + yolo solutions workout model=yolo11n-pose.pt kpts=[6, 8, 10] + + 5. Generate analytical graphs + yolo solutions analytics analytics_type="pie" + + 6. Track objects within specific zones + yolo solutions trackzone source="path/to/video.mp4" region="[(150, 150), (1130, 150), (1130, 570), (150, 570)]" + + 7. Streamlit real-time webcam inference GUI + yolo streamlit-predict + """ +CLI_HELP_MSG = f""" + Arguments received: {str(["yolo"] + ARGV[1:])}. Ultralytics 'yolo' commands use the following syntax: + + yolo TASK MODE ARGS + + Where TASK (optional) is one of {TASKS} + MODE (required) is one of {MODES} + ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults. + See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg' + + 1. Train a detection model for 10 epochs with an initial learning_rate of 0.01 + yolo train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01 + + 2. Predict a YouTube video using a pretrained segmentation model at image size 320: + yolo predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320 + + 3. Val a pretrained detection model at batch-size 1 and image size 640: + yolo val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640 + + 4. Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required) + yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128 + + 5. Ultralytics solutions usage + yolo solutions count or in {list(SOLUTION_MAP.keys())[1:-1]} source="path/to/video.mp4" + + 6. 
Run special commands: + yolo help + yolo checks + yolo version + yolo settings + yolo copy-cfg + yolo cfg + yolo solutions help + + Docs: https://docs.ultralytics.com + Solutions: https://docs.ultralytics.com/solutions/ + Community: https://community.ultralytics.com + GitHub: https://github.com/ultralytics/ultralytics + """ + +# Define keys for arg type checks +CFG_FLOAT_KEYS = frozenset( + { # integer or float arguments, i.e. x=2 and x=2.0 + "warmup_epochs", + "box", + "cls", + "dfl", + "degrees", + "shear", + "time", + "workspace", + "batch", + } +) +CFG_FRACTION_KEYS = frozenset( + { # fractional float arguments with 0.0<=values<=1.0 + "dropout", + "lr0", + "lrf", + "momentum", + "weight_decay", + "warmup_momentum", + "warmup_bias_lr", + "hsv_h", + "hsv_s", + "hsv_v", + "translate", + "scale", + "perspective", + "flipud", + "fliplr", + "bgr", + "mosaic", + "mixup", + "copy_paste", + "conf", + "iou", + "fraction", + } +) +CFG_INT_KEYS = frozenset( + { # integer-only arguments + "epochs", + "patience", + "workers", + "seed", + "close_mosaic", + "mask_ratio", + "max_det", + "vid_stride", + "line_width", + "nbs", + "save_period", + } +) +CFG_BOOL_KEYS = frozenset( + { # boolean-only arguments + "save", + "exist_ok", + "verbose", + "deterministic", + "single_cls", + "rect", + "cos_lr", + "overlap_mask", + "val", + "save_json", + "save_hybrid", + "half", + "dnn", + "plots", + "show", + "save_txt", + "save_conf", + "save_crop", + "save_frames", + "show_labels", + "show_conf", + "visualize", + "augment", + "agnostic_nms", + "retina_masks", + "show_boxes", + "keras", + "optimize", + "int8", + "dynamic", + "simplify", + "nms", + "profile", + "multi_scale", + } +) + + +def cfg2dict(cfg: Union[str, Path, Dict, SimpleNamespace]) -> Dict: + """ + Converts a configuration object to a dictionary. + + Args: + cfg (str | Path | Dict | SimpleNamespace): Configuration object to be converted. Can be a file path, + a string, a dictionary, or a SimpleNamespace object. 
+ + Returns: + (dict): Configuration object in dictionary format. + + Examples: + Convert a YAML file path to a dictionary: + >>> config_dict = cfg2dict("config.yaml") + + Convert a SimpleNamespace to a dictionary: + >>> from types import SimpleNamespace + >>> config_sn = SimpleNamespace(param1="value1", param2="value2") + >>> config_dict = cfg2dict(config_sn) + + Pass through an already existing dictionary: + >>> config_dict = cfg2dict({"param1": "value1", "param2": "value2"}) + + Notes: + - If cfg is a path or string, it's loaded as YAML and converted to a dictionary. + - If cfg is a SimpleNamespace object, it's converted to a dictionary using vars(). + - If cfg is already a dictionary, it's returned unchanged. + """ + if isinstance(cfg, (str, Path)): + cfg = yaml_load(cfg) # load dict + elif isinstance(cfg, SimpleNamespace): + cfg = vars(cfg) # convert to dict + return cfg + + +def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None) -> SimpleNamespace: + """ + Load and merge configuration data from a file or dictionary, with optional overrides. + + Args: + cfg (str | Path | Dict | SimpleNamespace): Configuration data source. Can be a file path, dictionary, or + SimpleNamespace object. + overrides (Dict | None): Dictionary containing key-value pairs to override the base configuration. + + Returns: + (SimpleNamespace): Namespace containing the merged configuration arguments. + + Examples: + >>> from ultralytics.cfg import get_cfg + >>> config = get_cfg() # Load default configuration + >>> config_with_overrides = get_cfg("path/to/config.yaml", overrides={"epochs": 50, "batch_size": 16}) + + Notes: + - If both `cfg` and `overrides` are provided, the values in `overrides` will take precedence. + - Special handling ensures alignment and correctness of the configuration, such as converting numeric + `project` and `name` to strings and validating configuration keys and values. 
+ - The function performs type and value checks on the configuration data. + """ + cfg = cfg2dict(cfg) + + # Merge overrides + if overrides: + overrides = cfg2dict(overrides) + if "save_dir" not in cfg: + overrides.pop("save_dir", None) # special override keys to ignore + check_dict_alignment(cfg, overrides) + cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides) + + # Special handling for numeric project/name + for k in "project", "name": + if k in cfg and isinstance(cfg[k], (int, float)): + cfg[k] = str(cfg[k]) + if cfg.get("name") == "model": # assign model to 'name' arg + cfg["name"] = str(cfg.get("model", "")).split(".")[0] + LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.") + + # Type and Value checks + check_cfg(cfg) + + # Return instance + return IterableSimpleNamespace(**cfg) + + +def check_cfg(cfg: Dict, hard: bool = True) -> None: + """ + Checks configuration argument types and values for the Ultralytics library. + + This function validates the types and values of configuration arguments, ensuring correctness and converting + them if necessary. It checks for specific key types defined in global variables such as `CFG_FLOAT_KEYS`, + `CFG_FRACTION_KEYS`, `CFG_INT_KEYS`, and `CFG_BOOL_KEYS`. + + Args: + cfg (dict): Configuration dictionary to validate. + hard (bool): If True, raises exceptions for invalid types and values; if False, attempts to convert them. + + Examples: + >>> config = { + ... "epochs": 50, # valid integer + ... "lr0": 0.01, # valid float + ... "momentum": 1.2, # invalid float (out of 0.0-1.0 range) + ... "save": "true", # invalid bool + ... } + >>> check_cfg(config, hard=False) + >>> print(config) + {'epochs': 50, 'lr0': 0.01, 'momentum': 1.2, 'save': False} # corrected 'save' key + + Notes: + - The function modifies the input dictionary in-place. + - None values are ignored as they may be from optional arguments. 
+ - Fraction keys are checked to be within the range [0.0, 1.0]. + """ + for k, v in cfg.items(): + if v is not None: # None values may be from optional args + if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)): + if hard: + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " + f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')" + ) + cfg[k] = float(v) + elif k in CFG_FRACTION_KEYS: + if not isinstance(v, (int, float)): + if hard: + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " + f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')" + ) + cfg[k] = v = float(v) + if not (0.0 <= v <= 1.0): + raise ValueError(f"'{k}={v}' is an invalid value. Valid '{k}' values are between 0.0 and 1.0.") + elif k in CFG_INT_KEYS and not isinstance(v, int): + if hard: + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. '{k}' must be an int (i.e. '{k}=8')" + ) + cfg[k] = int(v) + elif k in CFG_BOOL_KEYS and not isinstance(v, bool): + if hard: + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " + f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')" + ) + cfg[k] = bool(v) + + +def get_save_dir(args: SimpleNamespace, name: str = None) -> Path: + """ + Returns the directory path for saving outputs, derived from arguments or default settings. + + Args: + args (SimpleNamespace): Namespace object containing configurations such as 'project', 'name', 'task', + 'mode', and 'save_dir'. + name (str | None): Optional name for the output directory. If not provided, it defaults to 'args.name' + or the 'args.mode'. + + Returns: + (Path): Directory path where outputs should be saved. 
+ + Examples: + >>> from types import SimpleNamespace + >>> args = SimpleNamespace(project="my_project", task="detect", mode="train", exist_ok=True) + >>> save_dir = get_save_dir(args) + >>> print(save_dir) + my_project/detect/train + """ + if getattr(args, "save_dir", None): + save_dir = args.save_dir + else: + from ultralytics.utils.files import increment_path + + project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task + name = name or args.name or f"{args.mode}" + save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in {-1, 0} else True) + + return Path(save_dir) + + +def _handle_deprecation(custom: Dict) -> Dict: + """ + Handles deprecated configuration keys by mapping them to current equivalents with deprecation warnings. + + Args: + custom (dict): Configuration dictionary potentially containing deprecated keys. + + Examples: + >>> custom_config = {"boxes": True, "hide_labels": "False", "line_thickness": 2} + >>> _handle_deprecation(custom_config) + >>> print(custom_config) + {'show_boxes': True, 'show_labels': True, 'line_width': 2} + + Notes: + This function modifies the input dictionary in-place, replacing deprecated keys with their current + equivalents. It also handles value conversions where necessary, such as inverting boolean values for + 'hide_labels' and 'hide_conf'. 
+ """ + for key in custom.copy().keys(): + if key == "boxes": + deprecation_warn(key, "show_boxes") + custom["show_boxes"] = custom.pop("boxes") + if key == "hide_labels": + deprecation_warn(key, "show_labels") + custom["show_labels"] = custom.pop("hide_labels") == "False" + if key == "hide_conf": + deprecation_warn(key, "show_conf") + custom["show_conf"] = custom.pop("hide_conf") == "False" + if key == "line_thickness": + deprecation_warn(key, "line_width") + custom["line_width"] = custom.pop("line_thickness") + if key == "label_smoothing": + deprecation_warn(key) + custom.pop("label_smoothing") + + return custom + + +def check_dict_alignment(base: Dict, custom: Dict, e: Exception = None) -> None: + """ + Checks alignment between custom and base configuration dictionaries, handling deprecated keys and providing error + messages for mismatched keys. + + Args: + base (dict): The base configuration dictionary containing valid keys. + custom (dict): The custom configuration dictionary to be checked for alignment. + e (Exception | None): Optional error instance passed by the calling function. + + Raises: + SystemExit: If mismatched keys are found between the custom and base dictionaries. + + Examples: + >>> base_cfg = {"epochs": 50, "lr0": 0.01, "batch_size": 16} + >>> custom_cfg = {"epoch": 100, "lr": 0.02, "batch_size": 32} + >>> try: + ... check_dict_alignment(base_cfg, custom_cfg) + ... except SystemExit: + ... print("Mismatched keys found") + + Notes: + - Suggests corrections for mismatched keys based on similarity to valid keys. + - Automatically replaces deprecated keys in the custom configuration with updated equivalents. + - Prints detailed error messages for each mismatched key to help users correct their configurations. 
+ """ + custom = _handle_deprecation(custom) + base_keys, custom_keys = (frozenset(x.keys()) for x in (base, custom)) + if mismatched := [k for k in custom_keys if k not in base_keys]: + from difflib import get_close_matches + + string = "" + for x in mismatched: + matches = get_close_matches(x, base_keys) # key list + matches = [f"{k}={base[k]}" if base.get(k) is not None else k for k in matches] + match_str = f"Similar arguments are i.e. {matches}." if matches else "" + string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n" + raise SyntaxError(string + CLI_HELP_MSG) from e + + +def merge_equals_args(args: List[str]) -> List[str]: + """ + Merges arguments around isolated '=' in a list of strings and joins fragments with brackets. + + This function handles the following cases: + 1. ['arg', '=', 'val'] becomes ['arg=val'] + 2. ['arg=', 'val'] becomes ['arg=val'] + 3. ['arg', '=val'] becomes ['arg=val'] + 4. Joins fragments with brackets, e.g., ['imgsz=[3,', '640,', '640]'] becomes ['imgsz=[3,640,640]'] + + Args: + args (List[str]): A list of strings where each element represents an argument or fragment. + + Returns: + (List[str]): A list of strings where the arguments around isolated '=' are merged and fragments with brackets are joined. 
+ + Examples: + >>> args = ["arg1", "=", "value", "arg2=", "value2", "arg3", "=value3", "imgsz=[3,", "640,", "640]"] + >>> merge_equals_args(args) + ['arg1=value', 'arg2=value2', 'arg3=value3', 'imgsz=[3,640,640]'] + """ + new_args = [] + current = "" + depth = 0 + + i = 0 + while i < len(args): + arg = args[i] + + # Handle equals sign merging + if arg == "=" and 0 < i < len(args) - 1: # merge ['arg', '=', 'val'] + new_args[-1] += f"={args[i + 1]}" + i += 2 + continue + elif arg.endswith("=") and i < len(args) - 1 and "=" not in args[i + 1]: # merge ['arg=', 'val'] + new_args.append(f"{arg}{args[i + 1]}") + i += 2 + continue + elif arg.startswith("=") and i > 0: # merge ['arg', '=val'] + new_args[-1] += arg + i += 1 + continue + + # Handle bracket joining + depth += arg.count("[") - arg.count("]") + current += arg + if depth == 0: + new_args.append(current) + current = "" + + i += 1 + + # Append any remaining current string + if current: + new_args.append(current) + + return new_args + + +def handle_yolo_hub(args: List[str]) -> None: + """ + Handles Ultralytics HUB command-line interface (CLI) commands for authentication. + + This function processes Ultralytics HUB CLI commands such as login and logout. It should be called when executing a + script with arguments related to HUB authentication. + + Args: + args (List[str]): A list of command line arguments. The first argument should be either 'login' + or 'logout'. For 'login', an optional second argument can be the API key. + + Examples: + ```bash + yolo login YOUR_API_KEY + ``` + + Notes: + - The function imports the 'hub' module from ultralytics to perform login and logout operations. + - For the 'login' command, if no API key is provided, an empty string is passed to the login function. + - The 'logout' command does not require any additional arguments. 
+ """ + from ultralytics import hub + + if args[0] == "login": + key = args[1] if len(args) > 1 else "" + # Log in to Ultralytics HUB using the provided API key + hub.login(key) + elif args[0] == "logout": + # Log out from Ultralytics HUB + hub.logout() + + +def handle_yolo_settings(args: List[str]) -> None: + """ + Handles YOLO settings command-line interface (CLI) commands. + + This function processes YOLO settings CLI commands such as reset and updating individual settings. It should be + called when executing a script with arguments related to YOLO settings management. + + Args: + args (List[str]): A list of command line arguments for YOLO settings management. + + Examples: + >>> handle_yolo_settings(["reset"]) # Reset YOLO settings + >>> handle_yolo_settings(["default_cfg_path=yolo11n.yaml"]) # Update a specific setting + + Notes: + - If no arguments are provided, the function will display the current settings. + - The 'reset' command will delete the existing settings file and create new default settings. + - Other arguments are treated as key-value pairs to update specific settings. + - The function will check for alignment between the provided settings and the existing ones. + - After processing, the updated settings will be displayed. 
+ - For more information on handling YOLO settings, visit: + https://docs.ultralytics.com/quickstart/#ultralytics-settings + """ + url = "https://docs.ultralytics.com/quickstart/#ultralytics-settings" # help URL + try: + if any(args): + if args[0] == "reset": + SETTINGS_FILE.unlink() # delete the settings file + SETTINGS.reset() # create new settings + LOGGER.info("Settings reset successfully") # inform the user that settings have been reset + else: # save a new setting + new = dict(parse_key_value_pair(a) for a in args) + check_dict_alignment(SETTINGS, new) + SETTINGS.update(new) + + print(SETTINGS) # print the current settings + LOGGER.info(f"💡 Learn more about Ultralytics Settings at {url}") + except Exception as e: + LOGGER.warning(f"WARNING ⚠️ settings error: '{e}'. Please see {url} for help.") + + +def handle_yolo_solutions(args: List[str]) -> None: + """ + Processes YOLO solutions arguments and runs the specified computer vision solutions pipeline. + + Args: + args (List[str]): Command-line arguments for configuring and running the Ultralytics YOLO + solutions: https://docs.ultralytics.com/solutions/, It can include solution name, source, + and other configuration parameters. + + Examples: + Run people counting solution with default settings: + >>> handle_yolo_solutions(["count"]) + + Run analytics with custom configuration: + >>> handle_yolo_solutions(["analytics", "conf=0.25", "source=path/to/video.mp4"]) + + Run inference with custom configuration, requires Streamlit version 1.29.0 or higher. 
+ >>> handle_yolo_solutions(["inference", "model=yolo11n.pt"]) + + Notes: + - Default configurations are merged from DEFAULT_SOL_DICT and DEFAULT_CFG_DICT + - Arguments can be provided in the format 'key=value' or as boolean flags + - Available solutions are defined in SOLUTION_MAP with their respective classes and methods + - If an invalid solution is provided, defaults to 'count' solution + - Output videos are saved in 'runs/solution/{solution_name}' directory + - For 'analytics' solution, frame numbers are tracked for generating analytical graphs + - Video processing can be interrupted by pressing 'q' + - Processes video frames sequentially and saves output in .avi format + - If no source is specified, downloads and uses a default sample video + - The inference solution will be launched using the 'streamlit run' command. + - The Streamlit app file is located in the Ultralytics package directory. + """ + full_args_dict = { + **DEFAULT_SOL_DICT, + **DEFAULT_CFG_DICT, + "blur_ratio": 0.5, + "vision_point": (20, 20), + "crop_dir": "cropped-detections", + } # arguments dictionary + overrides = {} + + # check dictionary alignment + for arg in merge_equals_args(args): + arg = arg.lstrip("-").rstrip(",") + if "=" in arg: + try: + k, v = parse_key_value_pair(arg) + overrides[k] = v + except (NameError, SyntaxError, ValueError, AssertionError) as e: + check_dict_alignment(full_args_dict, {arg: ""}, e) + elif arg in full_args_dict and isinstance(full_args_dict.get(arg), bool): + overrides[arg] = True + check_dict_alignment(full_args_dict, overrides) # dict alignment + + # Get solution name + if not args: + LOGGER.warning("⚠️ No solution name provided. i.e `yolo solutions count`. 
Defaulting to 'count'.")
+        args = ["count"]
+    if args[0] == "help":
+        LOGGER.info(SOLUTIONS_HELP_MSG)
+        return  # Early return for 'help' case
+    elif args[0] in SOLUTION_MAP:
+        solution_name = args.pop(0)  # Extract the solution name directly
+    else:
+        LOGGER.warning(
+            f"❌ '{args[0]}' is not a valid solution. 💡 Defaulting to 'count'.\n"
+            f"🚀 Available solutions: {', '.join(list(SOLUTION_MAP.keys())[:-1])}\n"
+        )
+        solution_name = "count"  # Default for invalid solution
+
+    if solution_name == "inference":
+        checks.check_requirements("streamlit>=1.29.0")
+        LOGGER.info("💡 Loading Ultralytics live inference app...")
+        subprocess.run(
+            [  # Run subprocess with Streamlit custom argument
+                "streamlit",
+                "run",
+                str(ROOT / "solutions/streamlit_inference.py"),
+                "--server.headless",
+                "true",
+                overrides.pop("model", "yolo11n.pt"),
+            ]
+        )
+    else:
+        from ultralytics import solutions
+
+        solution = getattr(solutions, SOLUTION_MAP[solution_name])(is_cli=True, **overrides)  # class i.e ObjectCounter
+
+        cap = cv2.VideoCapture(solution.CFG["source"])  # read the video file
+        if solution_name != "crop":
+            # extract width, height and fps of the video file, create save directory and initialize video writer
+            w, h, fps = (
+                int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)
+            )
+            if solution_name == "analytics":  # analytics output uses a fixed frame size, i.e. w=1280, h=720 below
+                w, h = 1280, 720
+            save_dir = get_save_dir(SimpleNamespace(project="runs/solutions", name="exp", exist_ok=False))
+            save_dir.mkdir(parents=True)  # create the output directory i.e. 
runs/solutions/exp + vw = cv2.VideoWriter(str(save_dir / f"{solution_name}.avi"), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) + + try: # Process video frames + f_n = 0 # frame number, required for analytical graphs + while cap.isOpened(): + success, frame = cap.read() + if not success: + break + results = solution(frame, f_n := f_n + 1) if solution_name == "analytics" else solution(frame) + if solution_name != "crop": + vw.write(results.plot_im) + if cv2.waitKey(1) & 0xFF == ord("q"): + break + finally: + cap.release() + + +def parse_key_value_pair(pair: str = "key=value") -> tuple: + """ + Parses a key-value pair string into separate key and value components. + + Args: + pair (str): A string containing a key-value pair in the format "key=value". + + Returns: + key (str): The parsed key. + value (str): The parsed value. + + Raises: + AssertionError: If the value is missing or empty. + + Examples: + >>> key, value = parse_key_value_pair("model=yolo11n.pt") + >>> print(f"Key: {key}, Value: {value}") + Key: model, Value: yolo11n.pt + + >>> key, value = parse_key_value_pair("epochs=100") + >>> print(f"Key: {key}, Value: {value}") + Key: epochs, Value: 100 + + Notes: + - The function splits the input string on the first '=' character. + - Leading and trailing whitespace is removed from both key and value. + - An assertion error is raised if the value is empty after stripping. + """ + k, v = pair.split("=", 1) # split on first '=' sign + k, v = k.strip(), v.strip() # remove spaces + assert v, f"missing '{k}' value" + return k, smart_value(v) + + +def smart_value(v: str) -> Any: + """ + Converts a string representation of a value to its appropriate Python type. + + This function attempts to convert a given string into a Python object of the most appropriate type. It handles + conversions to None, bool, int, float, and other types that can be evaluated safely. + + Args: + v (str): The string representation of the value to be converted. 
+ + Returns: + (Any): The converted value. The type can be None, bool, int, float, or the original string if no conversion + is applicable. + + Examples: + >>> smart_value("42") + 42 + >>> smart_value("3.14") + 3.14 + >>> smart_value("True") + True + >>> smart_value("None") + None + >>> smart_value("some_string") + 'some_string' + + Notes: + - The function uses a case-insensitive comparison for boolean and None values. + - For other types, it attempts to use Python's eval() function, which can be unsafe if used on untrusted input. + - If no conversion is possible, the original string is returned. + """ + v_lower = v.lower() + if v_lower == "none": + return None + elif v_lower == "true": + return True + elif v_lower == "false": + return False + else: + try: + return eval(v) + except Exception: + return v + + +def entrypoint(debug: str = "") -> None: + """ + Ultralytics entrypoint function for parsing and executing command-line arguments. + + This function serves as the main entry point for the Ultralytics CLI, parsing command-line arguments and + executing the corresponding tasks such as training, validation, prediction, exporting models, and more. + + Args: + debug (str): Space-separated string of command-line arguments for debugging purposes. + + Examples: + Train a detection model for 10 epochs with an initial learning_rate of 0.01: + >>> entrypoint("train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01") + + Predict a YouTube video using a pretrained segmentation model at image size 320: + >>> entrypoint("predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320") + + Validate a pretrained detection model at batch-size 1 and image size 640: + >>> entrypoint("val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640") + + Notes: + - If no arguments are passed, the function will display the usage help message. 
+ - For a list of all available commands and their arguments, see the provided help messages and the + Ultralytics documentation at https://docs.ultralytics.com. + """ + args = (debug.split(" ") if debug else ARGV)[1:] + if not args: # no arguments passed + LOGGER.info(CLI_HELP_MSG) + return + + special = { + "help": lambda: LOGGER.info(CLI_HELP_MSG), + "checks": checks.collect_system_info, + "version": lambda: LOGGER.info(__version__), + "settings": lambda: handle_yolo_settings(args[1:]), + "cfg": lambda: yaml_print(DEFAULT_CFG_PATH), + "hub": lambda: handle_yolo_hub(args[1:]), + "login": lambda: handle_yolo_hub(args), + "logout": lambda: handle_yolo_hub(args), + "copy-cfg": copy_default_cfg, + "solutions": lambda: handle_yolo_solutions(args[1:]), + } + full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special} + + # Define common misuses of special commands, i.e. -h, -help, --help + special.update({k[0]: v for k, v in special.items()}) # singular + special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith("s")}) # singular + special = {**special, **{f"-{k}": v for k, v in special.items()}, **{f"--{k}": v for k, v in special.items()}} + + overrides = {} # basic overrides, i.e. 
imgsz=320 + for a in merge_equals_args(args): # merge spaces around '=' sign + if a.startswith("--"): + LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require leading dashes '--', updating to '{a[2:]}'.") + a = a[2:] + if a.endswith(","): + LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.") + a = a[:-1] + if "=" in a: + try: + k, v = parse_key_value_pair(a) + if k == "cfg" and v is not None: # custom.yaml passed + LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}") + overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != "cfg"} + else: + overrides[k] = v + except (NameError, SyntaxError, ValueError, AssertionError) as e: + check_dict_alignment(full_args_dict, {a: ""}, e) + + elif a in TASKS: + overrides["task"] = a + elif a in MODES: + overrides["mode"] = a + elif a.lower() in special: + special[a.lower()]() + return + elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool): + overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True + elif a in DEFAULT_CFG_DICT: + raise SyntaxError( + f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign " + f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}" + ) + else: + check_dict_alignment(full_args_dict, {a: ""}) + + # Check keys + check_dict_alignment(full_args_dict, overrides) + + # Mode + mode = overrides.get("mode") + if mode is None: + mode = DEFAULT_CFG.mode or "predict" + LOGGER.warning(f"WARNING ⚠️ 'mode' argument is missing. Valid modes are {MODES}. Using default 'mode={mode}'.") + elif mode not in MODES: + raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}") + + # Task + task = overrides.pop("task", None) + if task: + if task not in TASKS: + if task == "track": + LOGGER.warning( + "WARNING ⚠️ invalid 'task=track', setting 'task=detect' and 'mode=track'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}." 
+ ) + task, mode = "detect", "track" + else: + raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}") + if "model" not in overrides: + overrides["model"] = TASK2MODEL[task] + + # Model + model = overrides.pop("model", DEFAULT_CFG.model) + if model is None: + model = "yolo11n.pt" + LOGGER.warning(f"WARNING ⚠️ 'model' argument is missing. Using default 'model={model}'.") + overrides["model"] = model + stem = Path(model).stem.lower() + if "rtdetr" in stem: # guess architecture + from ultralytics import RTDETR + + model = RTDETR(model) # no task argument + elif "fastsam" in stem: + from ultralytics import FastSAM + + model = FastSAM(model) + elif "sam_" in stem or "sam2_" in stem or "sam2.1_" in stem: + from ultralytics import SAM + + model = SAM(model) + else: + from ultralytics import YOLO + + model = YOLO(model, task=task) + if isinstance(overrides.get("pretrained"), str): + model.load(overrides["pretrained"]) + + # Task Update + if task != model.task: + if task: + LOGGER.warning( + f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. " + f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model." + ) + task = model.task + + # Mode + if mode in {"predict", "track"} and "source" not in overrides: + overrides["source"] = ( + "https://ultralytics.com/images/boats.jpg" if task == "obb" else DEFAULT_CFG.source or ASSETS + ) + LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.") + elif mode in {"train", "val"}: + if "data" not in overrides and "resume" not in overrides: + overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) + LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.") + elif mode == "export": + if "format" not in overrides: + overrides["format"] = DEFAULT_CFG.format or "torchscript" + LOGGER.warning(f"WARNING ⚠️ 'format' argument is missing. 
Using default 'format={overrides['format']}'.") + + # Run command in python + getattr(model, mode)(**overrides) # default args from model + + # Show help + LOGGER.info(f"💡 Learn more at https://docs.ultralytics.com/modes/{mode}") + + # Recommend VS Code extension + if IS_VSCODE and SETTINGS.get("vscode_msg", True): + LOGGER.info(vscode_msg()) + + +# Special modes -------------------------------------------------------------------------------------------------------- +def copy_default_cfg() -> None: + """ + Copies the default configuration file and creates a new one with '_copy' appended to its name. + + This function duplicates the existing default configuration file (DEFAULT_CFG_PATH) and saves it + with '_copy' appended to its name in the current working directory. It provides a convenient way + to create a custom configuration file based on the default settings. + + Examples: + >>> copy_default_cfg() + # Output: default.yaml copied to /path/to/current/directory/default_copy.yaml + # Example YOLO command with this new custom cfg: + # yolo cfg='/path/to/current/directory/default_copy.yaml' imgsz=320 batch=8 + + Notes: + - The new configuration file is created in the current working directory. + - After copying, the function prints a message with the new file's location and an example + YOLO command demonstrating how to use the new configuration file. + - This function is useful for users who want to modify the default configuration without + altering the original file. 
+ """ + new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml") + shutil.copy2(DEFAULT_CFG_PATH, new_file) + LOGGER.info( + f"{DEFAULT_CFG_PATH} copied to {new_file}\n" + f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8" + ) + + +if __name__ == "__main__": + # Example: entrypoint(debug='yolo predict model=yolo11n.pt') + entrypoint(debug="") diff --git a/tracking/ultralytics/cfg/datasets/Argoverse.yaml b/tracking/ultralytics/cfg/datasets/Argoverse.yaml new file mode 100644 index 0000000000000000000000000000000000000000..28f56bc7ca65cea8d679d2c7e1338300b4cc9a04 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/Argoverse.yaml @@ -0,0 +1,77 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Argoverse-HD dataset (ring-front-center camera) https://www.cs.cmu.edu/~mengtial/proj/streaming/ by Argo AI +# Documentation: https://docs.ultralytics.com/datasets/detect/argoverse/ +# Example usage: yolo train data=Argoverse.yaml +# parent +# ├── ultralytics +# └── datasets +# └── Argoverse ← downloads here (31.5 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/Argoverse # dataset root dir +train: Argoverse-1.1/images/train/ # train images (relative to 'path') 39384 images +val: Argoverse-1.1/images/val/ # val images (relative to 'path') 15062 images +test: Argoverse-1.1/images/test/ # test images (optional) https://eval.ai/web/challenges/challenge-page/800/overview + +# Classes +names: + 0: person + 1: bicycle + 2: car + 3: motorcycle + 4: bus + 5: truck + 6: traffic_light + 7: stop_sign + +# Download script/URL (optional) --------------------------------------------------------------------------------------- +download: | + import json + from pathlib import Path + + from tqdm import tqdm + from ultralytics.utils.downloads import download + + def argoverse2yolo(set): + """Convert Argoverse dataset annotations to YOLO format for object detection tasks.""" + labels = {} + a = json.load(open(set, "rb")) + for annot in tqdm(a["annotations"], desc=f"Converting {set} to YOLOv5 format..."): + img_id = annot["image_id"] + img_name = a["images"][img_id]["name"] + img_label_name = f"{img_name[:-3]}txt" + + cls = annot["category_id"] # instance class id + x_center, y_center, width, height = annot["bbox"] + x_center = (x_center + width / 2) / 1920.0 # offset and scale + y_center = (y_center + height / 2) / 1200.0 # offset and scale + width /= 1920.0 # scale + height /= 1200.0 # scale + + img_dir = set.parents[2] / "Argoverse-1.1" / "labels" / a["seq_dirs"][a["images"][annot["image_id"]]["sid"]] + if not img_dir.exists(): + img_dir.mkdir(parents=True, exist_ok=True) + + k = str(img_dir / img_label_name) + if k not in labels: + labels[k] = [] + labels[k].append(f"{cls} {x_center} {y_center} {width} {height}\n") + + for k in labels: + with open(k, "w", encoding="utf-8") as f: + f.writelines(labels[k]) + + + # Download 'https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip' (deprecated S3 link) + dir = Path(yaml["path"]) # dataset root dir + urls = 
["https://drive.google.com/file/d/1st9qW3BeIwQsnR0t8mRpvbsSWIo16ACi/view?usp=drive_link"] + print("\n\nWARNING: Argoverse dataset MUST be downloaded manually, autodownload will NOT work.") + print(f"WARNING: Manually download Argoverse dataset '{urls[0]}' to '{dir}' and re-run your command.\n\n") + # download(urls, dir=dir) + + # Convert + annotations_dir = "Argoverse-HD/annotations/" + (dir / "Argoverse-1.1" / "tracking").rename(dir / "Argoverse-1.1" / "images") # rename 'tracking' to 'images' + for d in "train.json", "val.json": + argoverse2yolo(dir / annotations_dir / d) # convert Argoverse annotations to YOLO labels diff --git a/tracking/ultralytics/cfg/datasets/DOTAv1.5.yaml b/tracking/ultralytics/cfg/datasets/DOTAv1.5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26c73808d7b253dea0b4555394bae246a815076d --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/DOTAv1.5.yaml @@ -0,0 +1,37 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# DOTA 1.5 dataset https://captain-whu.github.io/DOTA/index.html for object detection in aerial images by Wuhan University +# Documentation: https://docs.ultralytics.com/datasets/obb/dota-v2/ +# Example usage: yolo train model=yolov8n-obb.pt data=DOTAv1.5.yaml +# parent +# ├── ultralytics +# └── datasets +# └── dota1.5 ← downloads here (2GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/DOTAv1.5 # dataset root dir +train: images/train # train images (relative to 'path') 1411 images +val: images/val # val images (relative to 'path') 458 images +test: images/test # test images (optional) 937 images + +# Classes for DOTA 1.5 +names: + 0: plane + 1: ship + 2: storage tank + 3: baseball diamond + 4: tennis court + 5: basketball court + 6: ground track field + 7: harbor + 8: bridge + 9: large vehicle + 10: small vehicle + 11: helicopter + 12: roundabout + 13: soccer ball field + 14: swimming pool + 15: container crane + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/DOTAv1.5.zip diff --git a/tracking/ultralytics/cfg/datasets/DOTAv1.yaml b/tracking/ultralytics/cfg/datasets/DOTAv1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5e71d2188d50d3efc80562ebb3fa7b65e07d2b5f --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/DOTAv1.yaml @@ -0,0 +1,36 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# DOTA 1.0 dataset https://captain-whu.github.io/DOTA/index.html for object detection in aerial images by Wuhan University +# Documentation: https://docs.ultralytics.com/datasets/obb/dota-v2/ +# Example usage: yolo train model=yolov8n-obb.pt data=DOTAv1.yaml +# parent +# ├── ultralytics +# └── datasets +# └── dota1 ← downloads here (2GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/DOTAv1 # dataset root dir +train: images/train # train images (relative to 'path') 1411 images +val: images/val # val images (relative to 'path') 458 images +test: images/test # test images (optional) 937 images + +# Classes for DOTA 1.0 +names: + 0: plane + 1: ship + 2: storage tank + 3: baseball diamond + 4: tennis court + 5: basketball court + 6: ground track field + 7: harbor + 8: bridge + 9: large vehicle + 10: small vehicle + 11: helicopter + 12: roundabout + 13: soccer ball field + 14: swimming pool + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/DOTAv1.zip diff --git a/tracking/ultralytics/cfg/datasets/GlobalWheat2020.yaml b/tracking/ultralytics/cfg/datasets/GlobalWheat2020.yaml new file mode 100644 index 0000000000000000000000000000000000000000..55323bb825bd407cb76954f6a06fd1b9d6d47d9d --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/GlobalWheat2020.yaml @@ -0,0 +1,68 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Global Wheat 2020 dataset https://www.global-wheat.com/ by University of Saskatchewan +# Documentation: https://docs.ultralytics.com/datasets/detect/globalwheat2020/ +# Example usage: yolo train data=GlobalWheat2020.yaml +# parent +# ├── ultralytics +# └── datasets +# └── GlobalWheat2020 ← downloads here (7.0 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/GlobalWheat2020 # dataset root dir +train: # train images (relative to 'path') 3422 images + - images/arvalis_1 + - images/arvalis_2 + - images/arvalis_3 + - images/ethz_1 + - images/rres_1 + - images/inrae_1 + - images/usask_1 +val: # val images (relative to 'path') 748 images (WARNING: train set contains ethz_1) + - images/ethz_1 +test: # test images (optional) 1276 images + - images/utokyo_1 + - images/utokyo_2 + - images/nau_1 + - images/uq_1 + +# Classes +names: + 0: wheat_head + +# Download script/URL (optional) --------------------------------------------------------------------------------------- +download: | + from pathlib import Path + + from ultralytics.utils.downloads import download + + # Download + dir = Path(yaml["path"]) # dataset root dir + urls = [ + "https://zenodo.org/record/4298502/files/global-wheat-codalab-official.zip", + "https://github.com/ultralytics/assets/releases/download/v0.0.0/GlobalWheat2020_labels.zip", + ] + download(urls, dir=dir) + + # Make Directories + for p in "annotations", "images", "labels": + (dir / p).mkdir(parents=True, exist_ok=True) + + # Move + for p in ( + "arvalis_1", + "arvalis_2", + "arvalis_3", + "ethz_1", + "rres_1", + "inrae_1", + "usask_1", + "utokyo_1", + "utokyo_2", + "nau_1", + "uq_1", + ): + (dir / "global-wheat-codalab-official" / p).rename(dir / "images" / p) # move to /images + f = (dir / "global-wheat-codalab-official" / p).with_suffix(".json") # json file + if f.exists(): + f.rename((dir / "annotations" / p).with_suffix(".json")) # move to /annotations diff --git a/tracking/ultralytics/cfg/datasets/ImageNet.yaml b/tracking/ultralytics/cfg/datasets/ImageNet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f15b1638a82412b27483951e90f30b58df7bab05 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/ImageNet.yaml @@ -0,0 +1,2025 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# ImageNet-1k dataset https://www.image-net.org/index.php 
by Stanford University +# Simplified class names from https://github.com/anishathalye/imagenet-simple-labels +# Documentation: https://docs.ultralytics.com/datasets/classify/imagenet/ +# Example usage: yolo train task=classify data=imagenet +# parent +# ├── ultralytics +# └── datasets +# └── imagenet ← downloads here (144 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/imagenet # dataset root dir +train: train # train images (relative to 'path') 1281167 images +val: val # val images (relative to 'path') 50000 images +test: # test images (optional) + +# Classes +names: + 0: tench + 1: goldfish + 2: great white shark + 3: tiger shark + 4: hammerhead shark + 5: electric ray + 6: stingray + 7: cock + 8: hen + 9: ostrich + 10: brambling + 11: goldfinch + 12: house finch + 13: junco + 14: indigo bunting + 15: American robin + 16: bulbul + 17: jay + 18: magpie + 19: chickadee + 20: American dipper + 21: kite + 22: bald eagle + 23: vulture + 24: great grey owl + 25: fire salamander + 26: smooth newt + 27: newt + 28: spotted salamander + 29: axolotl + 30: American bullfrog + 31: tree frog + 32: tailed frog + 33: loggerhead sea turtle + 34: leatherback sea turtle + 35: mud turtle + 36: terrapin + 37: box turtle + 38: banded gecko + 39: green iguana + 40: Carolina anole + 41: desert grassland whiptail lizard + 42: agama + 43: frilled-necked lizard + 44: alligator lizard + 45: Gila monster + 46: European green lizard + 47: chameleon + 48: Komodo dragon + 49: Nile crocodile + 50: American alligator + 51: triceratops + 52: worm snake + 53: ring-necked snake + 54: eastern hog-nosed snake + 55: smooth green snake + 56: kingsnake + 57: garter snake + 58: water snake + 59: vine snake + 60: night snake + 61: boa constrictor + 62: African rock python + 63: Indian cobra + 64: green mamba + 65: sea snake + 66: Saharan horned viper + 67: eastern diamondback rattlesnake + 68: sidewinder + 
69: trilobite + 70: harvestman + 71: scorpion + 72: yellow garden spider + 73: barn spider + 74: European garden spider + 75: southern black widow + 76: tarantula + 77: wolf spider + 78: tick + 79: centipede + 80: black grouse + 81: ptarmigan + 82: ruffed grouse + 83: prairie grouse + 84: peacock + 85: quail + 86: partridge + 87: grey parrot + 88: macaw + 89: sulphur-crested cockatoo + 90: lorikeet + 91: coucal + 92: bee eater + 93: hornbill + 94: hummingbird + 95: jacamar + 96: toucan + 97: duck + 98: red-breasted merganser + 99: goose + 100: black swan + 101: tusker + 102: echidna + 103: platypus + 104: wallaby + 105: koala + 106: wombat + 107: jellyfish + 108: sea anemone + 109: brain coral + 110: flatworm + 111: nematode + 112: conch + 113: snail + 114: slug + 115: sea slug + 116: chiton + 117: chambered nautilus + 118: Dungeness crab + 119: rock crab + 120: fiddler crab + 121: red king crab + 122: American lobster + 123: spiny lobster + 124: crayfish + 125: hermit crab + 126: isopod + 127: white stork + 128: black stork + 129: spoonbill + 130: flamingo + 131: little blue heron + 132: great egret + 133: bittern + 134: crane (bird) + 135: limpkin + 136: common gallinule + 137: American coot + 138: bustard + 139: ruddy turnstone + 140: dunlin + 141: common redshank + 142: dowitcher + 143: oystercatcher + 144: pelican + 145: king penguin + 146: albatross + 147: grey whale + 148: killer whale + 149: dugong + 150: sea lion + 151: Chihuahua + 152: Japanese Chin + 153: Maltese + 154: Pekingese + 155: Shih Tzu + 156: King Charles Spaniel + 157: Papillon + 158: toy terrier + 159: Rhodesian Ridgeback + 160: Afghan Hound + 161: Basset Hound + 162: Beagle + 163: Bloodhound + 164: Bluetick Coonhound + 165: Black and Tan Coonhound + 166: Treeing Walker Coonhound + 167: English foxhound + 168: Redbone Coonhound + 169: borzoi + 170: Irish Wolfhound + 171: Italian Greyhound + 172: Whippet + 173: Ibizan Hound + 174: Norwegian Elkhound + 175: Otterhound + 176: Saluki + 177: 
Scottish Deerhound + 178: Weimaraner + 179: Staffordshire Bull Terrier + 180: American Staffordshire Terrier + 181: Bedlington Terrier + 182: Border Terrier + 183: Kerry Blue Terrier + 184: Irish Terrier + 185: Norfolk Terrier + 186: Norwich Terrier + 187: Yorkshire Terrier + 188: Wire Fox Terrier + 189: Lakeland Terrier + 190: Sealyham Terrier + 191: Airedale Terrier + 192: Cairn Terrier + 193: Australian Terrier + 194: Dandie Dinmont Terrier + 195: Boston Terrier + 196: Miniature Schnauzer + 197: Giant Schnauzer + 198: Standard Schnauzer + 199: Scottish Terrier + 200: Tibetan Terrier + 201: Australian Silky Terrier + 202: Soft-coated Wheaten Terrier + 203: West Highland White Terrier + 204: Lhasa Apso + 205: Flat-Coated Retriever + 206: Curly-coated Retriever + 207: Golden Retriever + 208: Labrador Retriever + 209: Chesapeake Bay Retriever + 210: German Shorthaired Pointer + 211: Vizsla + 212: English Setter + 213: Irish Setter + 214: Gordon Setter + 215: Brittany + 216: Clumber Spaniel + 217: English Springer Spaniel + 218: Welsh Springer Spaniel + 219: Cocker Spaniels + 220: Sussex Spaniel + 221: Irish Water Spaniel + 222: Kuvasz + 223: Schipperke + 224: Groenendael + 225: Malinois + 226: Briard + 227: Australian Kelpie + 228: Komondor + 229: Old English Sheepdog + 230: Shetland Sheepdog + 231: collie + 232: Border Collie + 233: Bouvier des Flandres + 234: Rottweiler + 235: German Shepherd Dog + 236: Dobermann + 237: Miniature Pinscher + 238: Greater Swiss Mountain Dog + 239: Bernese Mountain Dog + 240: Appenzeller Sennenhund + 241: Entlebucher Sennenhund + 242: Boxer + 243: Bullmastiff + 244: Tibetan Mastiff + 245: French Bulldog + 246: Great Dane + 247: St. 
Bernard + 248: husky + 249: Alaskan Malamute + 250: Siberian Husky + 251: Dalmatian + 252: Affenpinscher + 253: Basenji + 254: pug + 255: Leonberger + 256: Newfoundland + 257: Pyrenean Mountain Dog + 258: Samoyed + 259: Pomeranian + 260: Chow Chow + 261: Keeshond + 262: Griffon Bruxellois + 263: Pembroke Welsh Corgi + 264: Cardigan Welsh Corgi + 265: Toy Poodle + 266: Miniature Poodle + 267: Standard Poodle + 268: Mexican hairless dog + 269: grey wolf + 270: Alaskan tundra wolf + 271: red wolf + 272: coyote + 273: dingo + 274: dhole + 275: African wild dog + 276: hyena + 277: red fox + 278: kit fox + 279: Arctic fox + 280: grey fox + 281: tabby cat + 282: tiger cat + 283: Persian cat + 284: Siamese cat + 285: Egyptian Mau + 286: cougar + 287: lynx + 288: leopard + 289: snow leopard + 290: jaguar + 291: lion + 292: tiger + 293: cheetah + 294: brown bear + 295: American black bear + 296: polar bear + 297: sloth bear + 298: mongoose + 299: meerkat + 300: tiger beetle + 301: ladybug + 302: ground beetle + 303: longhorn beetle + 304: leaf beetle + 305: dung beetle + 306: rhinoceros beetle + 307: weevil + 308: fly + 309: bee + 310: ant + 311: grasshopper + 312: cricket + 313: stick insect + 314: cockroach + 315: mantis + 316: cicada + 317: leafhopper + 318: lacewing + 319: dragonfly + 320: damselfly + 321: red admiral + 322: ringlet + 323: monarch butterfly + 324: small white + 325: sulphur butterfly + 326: gossamer-winged butterfly + 327: starfish + 328: sea urchin + 329: sea cucumber + 330: cottontail rabbit + 331: hare + 332: Angora rabbit + 333: hamster + 334: porcupine + 335: fox squirrel + 336: marmot + 337: beaver + 338: guinea pig + 339: common sorrel + 340: zebra + 341: pig + 342: wild boar + 343: warthog + 344: hippopotamus + 345: ox + 346: water buffalo + 347: bison + 348: ram + 349: bighorn sheep + 350: Alpine ibex + 351: hartebeest + 352: impala + 353: gazelle + 354: dromedary + 355: llama + 356: weasel + 357: mink + 358: European polecat + 359: black-footed 
ferret + 360: otter + 361: skunk + 362: badger + 363: armadillo + 364: three-toed sloth + 365: orangutan + 366: gorilla + 367: chimpanzee + 368: gibbon + 369: siamang + 370: guenon + 371: patas monkey + 372: baboon + 373: macaque + 374: langur + 375: black-and-white colobus + 376: proboscis monkey + 377: marmoset + 378: white-headed capuchin + 379: howler monkey + 380: titi + 381: Geoffroy's spider monkey + 382: common squirrel monkey + 383: ring-tailed lemur + 384: indri + 385: Asian elephant + 386: African bush elephant + 387: red panda + 388: giant panda + 389: snoek + 390: eel + 391: coho salmon + 392: rock beauty + 393: clownfish + 394: sturgeon + 395: garfish + 396: lionfish + 397: pufferfish + 398: abacus + 399: abaya + 400: academic gown + 401: accordion + 402: acoustic guitar + 403: aircraft carrier + 404: airliner + 405: airship + 406: altar + 407: ambulance + 408: amphibious vehicle + 409: analog clock + 410: apiary + 411: apron + 412: waste container + 413: assault rifle + 414: backpack + 415: bakery + 416: balance beam + 417: balloon + 418: ballpoint pen + 419: Band-Aid + 420: banjo + 421: baluster + 422: barbell + 423: barber chair + 424: barbershop + 425: barn + 426: barometer + 427: barrel + 428: wheelbarrow + 429: baseball + 430: basketball + 431: bassinet + 432: bassoon + 433: swimming cap + 434: bath towel + 435: bathtub + 436: station wagon + 437: lighthouse + 438: beaker + 439: military cap + 440: beer bottle + 441: beer glass + 442: bell-cot + 443: bib + 444: tandem bicycle + 445: bikini + 446: ring binder + 447: binoculars + 448: birdhouse + 449: boathouse + 450: bobsleigh + 451: bolo tie + 452: poke bonnet + 453: bookcase + 454: bookstore + 455: bottle cap + 456: bow + 457: bow tie + 458: brass + 459: bra + 460: breakwater + 461: breastplate + 462: broom + 463: bucket + 464: buckle + 465: bulletproof vest + 466: high-speed train + 467: butcher shop + 468: taxicab + 469: cauldron + 470: candle + 471: cannon + 472: canoe + 473: can opener + 
474: cardigan + 475: car mirror + 476: carousel + 477: tool kit + 478: carton + 479: car wheel + 480: automated teller machine + 481: cassette + 482: cassette player + 483: castle + 484: catamaran + 485: CD player + 486: cello + 487: mobile phone + 488: chain + 489: chain-link fence + 490: chain mail + 491: chainsaw + 492: chest + 493: chiffonier + 494: chime + 495: china cabinet + 496: Christmas stocking + 497: church + 498: movie theater + 499: cleaver + 500: cliff dwelling + 501: cloak + 502: clogs + 503: cocktail shaker + 504: coffee mug + 505: coffeemaker + 506: coil + 507: combination lock + 508: computer keyboard + 509: confectionery store + 510: container ship + 511: convertible + 512: corkscrew + 513: cornet + 514: cowboy boot + 515: cowboy hat + 516: cradle + 517: crane (machine) + 518: crash helmet + 519: crate + 520: infant bed + 521: Crock Pot + 522: croquet ball + 523: crutch + 524: cuirass + 525: dam + 526: desk + 527: desktop computer + 528: rotary dial telephone + 529: diaper + 530: digital clock + 531: digital watch + 532: dining table + 533: dishcloth + 534: dishwasher + 535: disc brake + 536: dock + 537: dog sled + 538: dome + 539: doormat + 540: drilling rig + 541: drum + 542: drumstick + 543: dumbbell + 544: Dutch oven + 545: electric fan + 546: electric guitar + 547: electric locomotive + 548: entertainment center + 549: envelope + 550: espresso machine + 551: face powder + 552: feather boa + 553: filing cabinet + 554: fireboat + 555: fire engine + 556: fire screen sheet + 557: flagpole + 558: flute + 559: folding chair + 560: football helmet + 561: forklift + 562: fountain + 563: fountain pen + 564: four-poster bed + 565: freight car + 566: French horn + 567: frying pan + 568: fur coat + 569: garbage truck + 570: gas mask + 571: gas pump + 572: goblet + 573: go-kart + 574: golf ball + 575: golf cart + 576: gondola + 577: gong + 578: gown + 579: grand piano + 580: greenhouse + 581: grille + 582: grocery store + 583: guillotine + 584: barrette 
+ 585: hair spray + 586: half-track + 587: hammer + 588: hamper + 589: hair dryer + 590: hand-held computer + 591: handkerchief + 592: hard disk drive + 593: harmonica + 594: harp + 595: harvester + 596: hatchet + 597: holster + 598: home theater + 599: honeycomb + 600: hook + 601: hoop skirt + 602: horizontal bar + 603: horse-drawn vehicle + 604: hourglass + 605: iPod + 606: clothes iron + 607: jack-o'-lantern + 608: jeans + 609: jeep + 610: T-shirt + 611: jigsaw puzzle + 612: pulled rickshaw + 613: joystick + 614: kimono + 615: knee pad + 616: knot + 617: lab coat + 618: ladle + 619: lampshade + 620: laptop computer + 621: lawn mower + 622: lens cap + 623: paper knife + 624: library + 625: lifeboat + 626: lighter + 627: limousine + 628: ocean liner + 629: lipstick + 630: slip-on shoe + 631: lotion + 632: speaker + 633: loupe + 634: sawmill + 635: magnetic compass + 636: mail bag + 637: mailbox + 638: tights + 639: tank suit + 640: manhole cover + 641: maraca + 642: marimba + 643: mask + 644: match + 645: maypole + 646: maze + 647: measuring cup + 648: medicine chest + 649: megalith + 650: microphone + 651: microwave oven + 652: military uniform + 653: milk can + 654: minibus + 655: miniskirt + 656: minivan + 657: missile + 658: mitten + 659: mixing bowl + 660: mobile home + 661: Model T + 662: modem + 663: monastery + 664: monitor + 665: moped + 666: mortar + 667: square academic cap + 668: mosque + 669: mosquito net + 670: scooter + 671: mountain bike + 672: tent + 673: computer mouse + 674: mousetrap + 675: moving van + 676: muzzle + 677: nail + 678: neck brace + 679: necklace + 680: nipple + 681: notebook computer + 682: obelisk + 683: oboe + 684: ocarina + 685: odometer + 686: oil filter + 687: organ + 688: oscilloscope + 689: overskirt + 690: bullock cart + 691: oxygen mask + 692: packet + 693: paddle + 694: paddle wheel + 695: padlock + 696: paintbrush + 697: pajamas + 698: palace + 699: pan flute + 700: paper towel + 701: parachute + 702: parallel bars + 
703: park bench + 704: parking meter + 705: passenger car + 706: patio + 707: payphone + 708: pedestal + 709: pencil case + 710: pencil sharpener + 711: perfume + 712: Petri dish + 713: photocopier + 714: plectrum + 715: Pickelhaube + 716: picket fence + 717: pickup truck + 718: pier + 719: piggy bank + 720: pill bottle + 721: pillow + 722: ping-pong ball + 723: pinwheel + 724: pirate ship + 725: pitcher + 726: hand plane + 727: planetarium + 728: plastic bag + 729: plate rack + 730: plow + 731: plunger + 732: Polaroid camera + 733: pole + 734: police van + 735: poncho + 736: billiard table + 737: soda bottle + 738: pot + 739: potter's wheel + 740: power drill + 741: prayer rug + 742: printer + 743: prison + 744: projectile + 745: projector + 746: hockey puck + 747: punching bag + 748: purse + 749: quill + 750: quilt + 751: race car + 752: racket + 753: radiator + 754: radio + 755: radio telescope + 756: rain barrel + 757: recreational vehicle + 758: reel + 759: reflex camera + 760: refrigerator + 761: remote control + 762: restaurant + 763: revolver + 764: rifle + 765: rocking chair + 766: rotisserie + 767: eraser + 768: rugby ball + 769: ruler + 770: running shoe + 771: safe + 772: safety pin + 773: salt shaker + 774: sandal + 775: sarong + 776: saxophone + 777: scabbard + 778: weighing scale + 779: school bus + 780: schooner + 781: scoreboard + 782: CRT screen + 783: screw + 784: screwdriver + 785: seat belt + 786: sewing machine + 787: shield + 788: shoe store + 789: shoji + 790: shopping basket + 791: shopping cart + 792: shovel + 793: shower cap + 794: shower curtain + 795: ski + 796: ski mask + 797: sleeping bag + 798: slide rule + 799: sliding door + 800: slot machine + 801: snorkel + 802: snowmobile + 803: snowplow + 804: soap dispenser + 805: soccer ball + 806: sock + 807: solar thermal collector + 808: sombrero + 809: soup bowl + 810: space bar + 811: space heater + 812: space shuttle + 813: spatula + 814: motorboat + 815: spider web + 816: spindle + 
817: sports car + 818: spotlight + 819: stage + 820: steam locomotive + 821: through arch bridge + 822: steel drum + 823: stethoscope + 824: scarf + 825: stone wall + 826: stopwatch + 827: stove + 828: strainer + 829: tram + 830: stretcher + 831: couch + 832: stupa + 833: submarine + 834: suit + 835: sundial + 836: sunglass + 837: sunglasses + 838: sunscreen + 839: suspension bridge + 840: mop + 841: sweatshirt + 842: swimsuit + 843: swing + 844: switch + 845: syringe + 846: table lamp + 847: tank + 848: tape player + 849: teapot + 850: teddy bear + 851: television + 852: tennis ball + 853: thatched roof + 854: front curtain + 855: thimble + 856: threshing machine + 857: throne + 858: tile roof + 859: toaster + 860: tobacco shop + 861: toilet seat + 862: torch + 863: totem pole + 864: tow truck + 865: toy store + 866: tractor + 867: semi-trailer truck + 868: tray + 869: trench coat + 870: tricycle + 871: trimaran + 872: tripod + 873: triumphal arch + 874: trolleybus + 875: trombone + 876: tub + 877: turnstile + 878: typewriter keyboard + 879: umbrella + 880: unicycle + 881: upright piano + 882: vacuum cleaner + 883: vase + 884: vault + 885: velvet + 886: vending machine + 887: vestment + 888: viaduct + 889: violin + 890: volleyball + 891: waffle iron + 892: wall clock + 893: wallet + 894: wardrobe + 895: military aircraft + 896: sink + 897: washing machine + 898: water bottle + 899: water jug + 900: water tower + 901: whiskey jug + 902: whistle + 903: wig + 904: window screen + 905: window shade + 906: Windsor tie + 907: wine bottle + 908: wing + 909: wok + 910: wooden spoon + 911: wool + 912: split-rail fence + 913: shipwreck + 914: yawl + 915: yurt + 916: website + 917: comic book + 918: crossword + 919: traffic sign + 920: traffic light + 921: dust jacket + 922: menu + 923: plate + 924: guacamole + 925: consomme + 926: hot pot + 927: trifle + 928: ice cream + 929: ice pop + 930: baguette + 931: bagel + 932: pretzel + 933: cheeseburger + 934: hot dog + 935: 
mashed potato + 936: cabbage + 937: broccoli + 938: cauliflower + 939: zucchini + 940: spaghetti squash + 941: acorn squash + 942: butternut squash + 943: cucumber + 944: artichoke + 945: bell pepper + 946: cardoon + 947: mushroom + 948: Granny Smith + 949: strawberry + 950: orange + 951: lemon + 952: fig + 953: pineapple + 954: banana + 955: jackfruit + 956: custard apple + 957: pomegranate + 958: hay + 959: carbonara + 960: chocolate syrup + 961: dough + 962: meatloaf + 963: pizza + 964: pot pie + 965: burrito + 966: red wine + 967: espresso + 968: cup + 969: eggnog + 970: alp + 971: bubble + 972: cliff + 973: coral reef + 974: geyser + 975: lakeshore + 976: promontory + 977: shoal + 978: seashore + 979: valley + 980: volcano + 981: baseball player + 982: bridegroom + 983: scuba diver + 984: rapeseed + 985: daisy + 986: yellow lady's slipper + 987: corn + 988: acorn + 989: rose hip + 990: horse chestnut seed + 991: coral fungus + 992: agaric + 993: gyromitra + 994: stinkhorn mushroom + 995: earth star + 996: hen-of-the-woods + 997: bolete + 998: ear + 999: toilet paper + +# Imagenet class codes to human-readable names +map: + n01440764: tench + n01443537: goldfish + n01484850: great_white_shark + n01491361: tiger_shark + n01494475: hammerhead + n01496331: electric_ray + n01498041: stingray + n01514668: cock + n01514859: hen + n01518878: ostrich + n01530575: brambling + n01531178: goldfinch + n01532829: house_finch + n01534433: junco + n01537544: indigo_bunting + n01558993: robin + n01560419: bulbul + n01580077: jay + n01582220: magpie + n01592084: chickadee + n01601694: water_ouzel + n01608432: kite + n01614925: bald_eagle + n01616318: vulture + n01622779: great_grey_owl + n01629819: European_fire_salamander + n01630670: common_newt + n01631663: eft + n01632458: spotted_salamander + n01632777: axolotl + n01641577: bullfrog + n01644373: tree_frog + n01644900: tailed_frog + n01664065: loggerhead + n01665541: leatherback_turtle + n01667114: mud_turtle + n01667778: 
terrapin + n01669191: box_turtle + n01675722: banded_gecko + n01677366: common_iguana + n01682714: American_chameleon + n01685808: whiptail + n01687978: agama + n01688243: frilled_lizard + n01689811: alligator_lizard + n01692333: Gila_monster + n01693334: green_lizard + n01694178: African_chameleon + n01695060: Komodo_dragon + n01697457: African_crocodile + n01698640: American_alligator + n01704323: triceratops + n01728572: thunder_snake + n01728920: ringneck_snake + n01729322: hognose_snake + n01729977: green_snake + n01734418: king_snake + n01735189: garter_snake + n01737021: water_snake + n01739381: vine_snake + n01740131: night_snake + n01742172: boa_constrictor + n01744401: rock_python + n01748264: Indian_cobra + n01749939: green_mamba + n01751748: sea_snake + n01753488: horned_viper + n01755581: diamondback + n01756291: sidewinder + n01768244: trilobite + n01770081: harvestman + n01770393: scorpion + n01773157: black_and_gold_garden_spider + n01773549: barn_spider + n01773797: garden_spider + n01774384: black_widow + n01774750: tarantula + n01775062: wolf_spider + n01776313: tick + n01784675: centipede + n01795545: black_grouse + n01796340: ptarmigan + n01797886: ruffed_grouse + n01798484: prairie_chicken + n01806143: peacock + n01806567: quail + n01807496: partridge + n01817953: African_grey + n01818515: macaw + n01819313: sulphur-crested_cockatoo + n01820546: lorikeet + n01824575: coucal + n01828970: bee_eater + n01829413: hornbill + n01833805: hummingbird + n01843065: jacamar + n01843383: toucan + n01847000: drake + n01855032: red-breasted_merganser + n01855672: goose + n01860187: black_swan + n01871265: tusker + n01872401: echidna + n01873310: platypus + n01877812: wallaby + n01882714: koala + n01883070: wombat + n01910747: jellyfish + n01914609: sea_anemone + n01917289: brain_coral + n01924916: flatworm + n01930112: nematode + n01943899: conch + n01944390: snail + n01945685: slug + n01950731: sea_slug + n01955084: chiton + n01968897: chambered_nautilus + 
n01978287: Dungeness_crab + n01978455: rock_crab + n01980166: fiddler_crab + n01981276: king_crab + n01983481: American_lobster + n01984695: spiny_lobster + n01985128: crayfish + n01986214: hermit_crab + n01990800: isopod + n02002556: white_stork + n02002724: black_stork + n02006656: spoonbill + n02007558: flamingo + n02009229: little_blue_heron + n02009912: American_egret + n02011460: bittern + n02012849: crane_(bird) + n02013706: limpkin + n02017213: European_gallinule + n02018207: American_coot + n02018795: bustard + n02025239: ruddy_turnstone + n02027492: red-backed_sandpiper + n02028035: redshank + n02033041: dowitcher + n02037110: oystercatcher + n02051845: pelican + n02056570: king_penguin + n02058221: albatross + n02066245: grey_whale + n02071294: killer_whale + n02074367: dugong + n02077923: sea_lion + n02085620: Chihuahua + n02085782: Japanese_spaniel + n02085936: Maltese_dog + n02086079: Pekinese + n02086240: Shih-Tzu + n02086646: Blenheim_spaniel + n02086910: papillon + n02087046: toy_terrier + n02087394: Rhodesian_ridgeback + n02088094: Afghan_hound + n02088238: basset + n02088364: beagle + n02088466: bloodhound + n02088632: bluetick + n02089078: black-and-tan_coonhound + n02089867: Walker_hound + n02089973: English_foxhound + n02090379: redbone + n02090622: borzoi + n02090721: Irish_wolfhound + n02091032: Italian_greyhound + n02091134: whippet + n02091244: Ibizan_hound + n02091467: Norwegian_elkhound + n02091635: otterhound + n02091831: Saluki + n02092002: Scottish_deerhound + n02092339: Weimaraner + n02093256: Staffordshire_bullterrier + n02093428: American_Staffordshire_terrier + n02093647: Bedlington_terrier + n02093754: Border_terrier + n02093859: Kerry_blue_terrier + n02093991: Irish_terrier + n02094114: Norfolk_terrier + n02094258: Norwich_terrier + n02094433: Yorkshire_terrier + n02095314: wire-haired_fox_terrier + n02095570: Lakeland_terrier + n02095889: Sealyham_terrier + n02096051: Airedale + n02096177: cairn + n02096294: Australian_terrier 
+ n02096437: Dandie_Dinmont + n02096585: Boston_bull + n02097047: miniature_schnauzer + n02097130: giant_schnauzer + n02097209: standard_schnauzer + n02097298: Scotch_terrier + n02097474: Tibetan_terrier + n02097658: silky_terrier + n02098105: soft-coated_wheaten_terrier + n02098286: West_Highland_white_terrier + n02098413: Lhasa + n02099267: flat-coated_retriever + n02099429: curly-coated_retriever + n02099601: golden_retriever + n02099712: Labrador_retriever + n02099849: Chesapeake_Bay_retriever + n02100236: German_short-haired_pointer + n02100583: vizsla + n02100735: English_setter + n02100877: Irish_setter + n02101006: Gordon_setter + n02101388: Brittany_spaniel + n02101556: clumber + n02102040: English_springer + n02102177: Welsh_springer_spaniel + n02102318: cocker_spaniel + n02102480: Sussex_spaniel + n02102973: Irish_water_spaniel + n02104029: kuvasz + n02104365: schipperke + n02105056: groenendael + n02105162: malinois + n02105251: briard + n02105412: kelpie + n02105505: komondor + n02105641: Old_English_sheepdog + n02105855: Shetland_sheepdog + n02106030: collie + n02106166: Border_collie + n02106382: Bouvier_des_Flandres + n02106550: Rottweiler + n02106662: German_shepherd + n02107142: Doberman + n02107312: miniature_pinscher + n02107574: Greater_Swiss_Mountain_dog + n02107683: Bernese_mountain_dog + n02107908: Appenzeller + n02108000: EntleBucher + n02108089: boxer + n02108422: bull_mastiff + n02108551: Tibetan_mastiff + n02108915: French_bulldog + n02109047: Great_Dane + n02109525: Saint_Bernard + n02109961: Eskimo_dog + n02110063: malamute + n02110185: Siberian_husky + n02110341: dalmatian + n02110627: affenpinscher + n02110806: basenji + n02110958: pug + n02111129: Leonberg + n02111277: Newfoundland + n02111500: Great_Pyrenees + n02111889: Samoyed + n02112018: Pomeranian + n02112137: chow + n02112350: keeshond + n02112706: Brabancon_griffon + n02113023: Pembroke + n02113186: Cardigan + n02113624: toy_poodle + n02113712: miniature_poodle + n02113799: 
standard_poodle + n02113978: Mexican_hairless + n02114367: timber_wolf + n02114548: white_wolf + n02114712: red_wolf + n02114855: coyote + n02115641: dingo + n02115913: dhole + n02116738: African_hunting_dog + n02117135: hyena + n02119022: red_fox + n02119789: kit_fox + n02120079: Arctic_fox + n02120505: grey_fox + n02123045: tabby + n02123159: tiger_cat + n02123394: Persian_cat + n02123597: Siamese_cat + n02124075: Egyptian_cat + n02125311: cougar + n02127052: lynx + n02128385: leopard + n02128757: snow_leopard + n02128925: jaguar + n02129165: lion + n02129604: tiger + n02130308: cheetah + n02132136: brown_bear + n02133161: American_black_bear + n02134084: ice_bear + n02134418: sloth_bear + n02137549: mongoose + n02138441: meerkat + n02165105: tiger_beetle + n02165456: ladybug + n02167151: ground_beetle + n02168699: long-horned_beetle + n02169497: leaf_beetle + n02172182: dung_beetle + n02174001: rhinoceros_beetle + n02177972: weevil + n02190166: fly + n02206856: bee + n02219486: ant + n02226429: grasshopper + n02229544: cricket + n02231487: walking_stick + n02233338: cockroach + n02236044: mantis + n02256656: cicada + n02259212: leafhopper + n02264363: lacewing + n02268443: dragonfly + n02268853: damselfly + n02276258: admiral + n02277742: ringlet + n02279972: monarch + n02280649: cabbage_butterfly + n02281406: sulphur_butterfly + n02281787: lycaenid + n02317335: starfish + n02319095: sea_urchin + n02321529: sea_cucumber + n02325366: wood_rabbit + n02326432: hare + n02328150: Angora + n02342885: hamster + n02346627: porcupine + n02356798: fox_squirrel + n02361337: marmot + n02363005: beaver + n02364673: guinea_pig + n02389026: sorrel + n02391049: zebra + n02395406: hog + n02396427: wild_boar + n02397096: warthog + n02398521: hippopotamus + n02403003: ox + n02408429: water_buffalo + n02410509: bison + n02412080: ram + n02415577: bighorn + n02417914: ibex + n02422106: hartebeest + n02422699: impala + n02423022: gazelle + n02437312: Arabian_camel + n02437616: llama 
+ n02441942: weasel + n02442845: mink + n02443114: polecat + n02443484: black-footed_ferret + n02444819: otter + n02445715: skunk + n02447366: badger + n02454379: armadillo + n02457408: three-toed_sloth + n02480495: orangutan + n02480855: gorilla + n02481823: chimpanzee + n02483362: gibbon + n02483708: siamang + n02484975: guenon + n02486261: patas + n02486410: baboon + n02487347: macaque + n02488291: langur + n02488702: colobus + n02489166: proboscis_monkey + n02490219: marmoset + n02492035: capuchin + n02492660: howler_monkey + n02493509: titi + n02493793: spider_monkey + n02494079: squirrel_monkey + n02497673: Madagascar_cat + n02500267: indri + n02504013: Indian_elephant + n02504458: African_elephant + n02509815: lesser_panda + n02510455: giant_panda + n02514041: barracouta + n02526121: eel + n02536864: coho + n02606052: rock_beauty + n02607072: anemone_fish + n02640242: sturgeon + n02641379: gar + n02643566: lionfish + n02655020: puffer + n02666196: abacus + n02667093: abaya + n02669723: academic_gown + n02672831: accordion + n02676566: acoustic_guitar + n02687172: aircraft_carrier + n02690373: airliner + n02692877: airship + n02699494: altar + n02701002: ambulance + n02704792: amphibian + n02708093: analog_clock + n02727426: apiary + n02730930: apron + n02747177: ashcan + n02749479: assault_rifle + n02769748: backpack + n02776631: bakery + n02777292: balance_beam + n02782093: balloon + n02783161: ballpoint + n02786058: Band_Aid + n02787622: banjo + n02788148: bannister + n02790996: barbell + n02791124: barber_chair + n02791270: barbershop + n02793495: barn + n02794156: barometer + n02795169: barrel + n02797295: barrow + n02799071: baseball + n02802426: basketball + n02804414: bassinet + n02804610: bassoon + n02807133: bathing_cap + n02808304: bath_towel + n02808440: bathtub + n02814533: beach_wagon + n02814860: beacon + n02815834: beaker + n02817516: bearskin + n02823428: beer_bottle + n02823750: beer_glass + n02825657: bell_cote + n02834397: bib + n02835271: 
bicycle-built-for-two + n02837789: bikini + n02840245: binder + n02841315: binoculars + n02843684: birdhouse + n02859443: boathouse + n02860847: bobsled + n02865351: bolo_tie + n02869837: bonnet + n02870880: bookcase + n02871525: bookshop + n02877765: bottlecap + n02879718: bow + n02883205: bow_tie + n02892201: brass + n02892767: brassiere + n02894605: breakwater + n02895154: breastplate + n02906734: broom + n02909870: bucket + n02910353: buckle + n02916936: bulletproof_vest + n02917067: bullet_train + n02927161: butcher_shop + n02930766: cab + n02939185: caldron + n02948072: candle + n02950826: cannon + n02951358: canoe + n02951585: can_opener + n02963159: cardigan + n02965783: car_mirror + n02966193: carousel + n02966687: carpenter's_kit + n02971356: carton + n02974003: car_wheel + n02977058: cash_machine + n02978881: cassette + n02979186: cassette_player + n02980441: castle + n02981792: catamaran + n02988304: CD_player + n02992211: cello + n02992529: cellular_telephone + n02999410: chain + n03000134: chainlink_fence + n03000247: chain_mail + n03000684: chain_saw + n03014705: chest + n03016953: chiffonier + n03017168: chime + n03018349: china_cabinet + n03026506: Christmas_stocking + n03028079: church + n03032252: cinema + n03041632: cleaver + n03042490: cliff_dwelling + n03045698: cloak + n03047690: clog + n03062245: cocktail_shaker + n03063599: coffee_mug + n03063689: coffeepot + n03065424: coil + n03075370: combination_lock + n03085013: computer_keyboard + n03089624: confectionery + n03095699: container_ship + n03100240: convertible + n03109150: corkscrew + n03110669: cornet + n03124043: cowboy_boot + n03124170: cowboy_hat + n03125729: cradle + n03126707: crane_(machine) + n03127747: crash_helmet + n03127925: crate + n03131574: crib + n03133878: Crock_Pot + n03134739: croquet_ball + n03141823: crutch + n03146219: cuirass + n03160309: dam + n03179701: desk + n03180011: desktop_computer + n03187595: dial_telephone + n03188531: diaper + n03196217: digital_clock + 
n03197337: digital_watch + n03201208: dining_table + n03207743: dishrag + n03207941: dishwasher + n03208938: disk_brake + n03216828: dock + n03218198: dogsled + n03220513: dome + n03223299: doormat + n03240683: drilling_platform + n03249569: drum + n03250847: drumstick + n03255030: dumbbell + n03259280: Dutch_oven + n03271574: electric_fan + n03272010: electric_guitar + n03272562: electric_locomotive + n03290653: entertainment_center + n03291819: envelope + n03297495: espresso_maker + n03314780: face_powder + n03325584: feather_boa + n03337140: file + n03344393: fireboat + n03345487: fire_engine + n03347037: fire_screen + n03355925: flagpole + n03372029: flute + n03376595: folding_chair + n03379051: football_helmet + n03384352: forklift + n03388043: fountain + n03388183: fountain_pen + n03388549: four-poster + n03393912: freight_car + n03394916: French_horn + n03400231: frying_pan + n03404251: fur_coat + n03417042: garbage_truck + n03424325: gasmask + n03425413: gas_pump + n03443371: goblet + n03444034: go-kart + n03445777: golf_ball + n03445924: golfcart + n03447447: gondola + n03447721: gong + n03450230: gown + n03452741: grand_piano + n03457902: greenhouse + n03459775: grille + n03461385: grocery_store + n03467068: guillotine + n03476684: hair_slide + n03476991: hair_spray + n03478589: half_track + n03481172: hammer + n03482405: hamper + n03483316: hand_blower + n03485407: hand-held_computer + n03485794: handkerchief + n03492542: hard_disc + n03494278: harmonica + n03495258: harp + n03496892: harvester + n03498962: hatchet + n03527444: holster + n03529860: home_theater + n03530642: honeycomb + n03532672: hook + n03534580: hoopskirt + n03535780: horizontal_bar + n03538406: horse_cart + n03544143: hourglass + n03584254: iPod + n03584829: iron + n03590841: jack-o'-lantern + n03594734: jean + n03594945: jeep + n03595614: jersey + n03598930: jigsaw_puzzle + n03599486: jinrikisha + n03602883: joystick + n03617480: kimono + n03623198: knee_pad + n03627232: knot + 
n03630383: lab_coat + n03633091: ladle + n03637318: lampshade + n03642806: laptop + n03649909: lawn_mower + n03657121: lens_cap + n03658185: letter_opener + n03661043: library + n03662601: lifeboat + n03666591: lighter + n03670208: limousine + n03673027: liner + n03676483: lipstick + n03680355: Loafer + n03690938: lotion + n03691459: loudspeaker + n03692522: loupe + n03697007: lumbermill + n03706229: magnetic_compass + n03709823: mailbag + n03710193: mailbox + n03710637: maillot_(tights) + n03710721: maillot_(tank_suit) + n03717622: manhole_cover + n03720891: maraca + n03721384: marimba + n03724870: mask + n03729826: matchstick + n03733131: maypole + n03733281: maze + n03733805: measuring_cup + n03742115: medicine_chest + n03743016: megalith + n03759954: microphone + n03761084: microwave + n03763968: military_uniform + n03764736: milk_can + n03769881: minibus + n03770439: miniskirt + n03770679: minivan + n03773504: missile + n03775071: mitten + n03775546: mixing_bowl + n03776460: mobile_home + n03777568: Model_T + n03777754: modem + n03781244: monastery + n03782006: monitor + n03785016: moped + n03786901: mortar + n03787032: mortarboard + n03788195: mosque + n03788365: mosquito_net + n03791053: motor_scooter + n03792782: mountain_bike + n03792972: mountain_tent + n03793489: mouse + n03794056: mousetrap + n03796401: moving_van + n03803284: muzzle + n03804744: nail + n03814639: neck_brace + n03814906: necklace + n03825788: nipple + n03832673: notebook + n03837869: obelisk + n03838899: oboe + n03840681: ocarina + n03841143: odometer + n03843555: oil_filter + n03854065: organ + n03857828: oscilloscope + n03866082: overskirt + n03868242: oxcart + n03868863: oxygen_mask + n03871628: packet + n03873416: paddle + n03874293: paddlewheel + n03874599: padlock + n03876231: paintbrush + n03877472: pajama + n03877845: palace + n03884397: panpipe + n03887697: paper_towel + n03888257: parachute + n03888605: parallel_bars + n03891251: park_bench + n03891332: parking_meter + 
n03895866: passenger_car + n03899768: patio + n03902125: pay-phone + n03903868: pedestal + n03908618: pencil_box + n03908714: pencil_sharpener + n03916031: perfume + n03920288: Petri_dish + n03924679: photocopier + n03929660: pick + n03929855: pickelhaube + n03930313: picket_fence + n03930630: pickup + n03933933: pier + n03935335: piggy_bank + n03937543: pill_bottle + n03938244: pillow + n03942813: ping-pong_ball + n03944341: pinwheel + n03947888: pirate + n03950228: pitcher + n03954731: plane + n03956157: planetarium + n03958227: plastic_bag + n03961711: plate_rack + n03967562: plow + n03970156: plunger + n03976467: Polaroid_camera + n03976657: pole + n03977966: police_van + n03980874: poncho + n03982430: pool_table + n03983396: pop_bottle + n03991062: pot + n03992509: potter's_wheel + n03995372: power_drill + n03998194: prayer_rug + n04004767: printer + n04005630: prison + n04008634: projectile + n04009552: projector + n04019541: puck + n04023962: punching_bag + n04026417: purse + n04033901: quill + n04033995: quilt + n04037443: racer + n04039381: racket + n04040759: radiator + n04041544: radio + n04044716: radio_telescope + n04049303: rain_barrel + n04065272: recreational_vehicle + n04067472: reel + n04069434: reflex_camera + n04070727: refrigerator + n04074963: remote_control + n04081281: restaurant + n04086273: revolver + n04090263: rifle + n04099969: rocking_chair + n04111531: rotisserie + n04116512: rubber_eraser + n04118538: rugby_ball + n04118776: rule + n04120489: running_shoe + n04125021: safe + n04127249: safety_pin + n04131690: saltshaker + n04133789: sandal + n04136333: sarong + n04141076: sax + n04141327: scabbard + n04141975: scale + n04146614: school_bus + n04147183: schooner + n04149813: scoreboard + n04152593: screen + n04153751: screw + n04154565: screwdriver + n04162706: seat_belt + n04179913: sewing_machine + n04192698: shield + n04200800: shoe_shop + n04201297: shoji + n04204238: shopping_basket + n04204347: shopping_cart + n04208210: shovel 
+ n04209133: shower_cap + n04209239: shower_curtain + n04228054: ski + n04229816: ski_mask + n04235860: sleeping_bag + n04238763: slide_rule + n04239074: sliding_door + n04243546: slot + n04251144: snorkel + n04252077: snowmobile + n04252225: snowplow + n04254120: soap_dispenser + n04254680: soccer_ball + n04254777: sock + n04258138: solar_dish + n04259630: sombrero + n04263257: soup_bowl + n04264628: space_bar + n04265275: space_heater + n04266014: space_shuttle + n04270147: spatula + n04273569: speedboat + n04275548: spider_web + n04277352: spindle + n04285008: sports_car + n04286575: spotlight + n04296562: stage + n04310018: steam_locomotive + n04311004: steel_arch_bridge + n04311174: steel_drum + n04317175: stethoscope + n04325704: stole + n04326547: stone_wall + n04328186: stopwatch + n04330267: stove + n04332243: strainer + n04335435: streetcar + n04336792: stretcher + n04344873: studio_couch + n04346328: stupa + n04347754: submarine + n04350905: suit + n04355338: sundial + n04355933: sunglass + n04356056: sunglasses + n04357314: sunscreen + n04366367: suspension_bridge + n04367480: swab + n04370456: sweatshirt + n04371430: swimming_trunks + n04371774: swing + n04372370: switch + n04376876: syringe + n04380533: table_lamp + n04389033: tank + n04392985: tape_player + n04398044: teapot + n04399382: teddy + n04404412: television + n04409515: tennis_ball + n04417672: thatch + n04418357: theater_curtain + n04423845: thimble + n04428191: thresher + n04429376: throne + n04435653: tile_roof + n04442312: toaster + n04443257: tobacco_shop + n04447861: toilet_seat + n04456115: torch + n04458633: totem_pole + n04461696: tow_truck + n04462240: toyshop + n04465501: tractor + n04467665: trailer_truck + n04476259: tray + n04479046: trench_coat + n04482393: tricycle + n04483307: trimaran + n04485082: tripod + n04486054: triumphal_arch + n04487081: trolleybus + n04487394: trombone + n04493381: tub + n04501370: turnstile + n04505470: typewriter_keyboard + n04507155: umbrella + 
n04509417: unicycle + n04515003: upright + n04517823: vacuum + n04522168: vase + n04523525: vault + n04525038: velvet + n04525305: vending_machine + n04532106: vestment + n04532670: viaduct + n04536866: violin + n04540053: volleyball + n04542943: waffle_iron + n04548280: wall_clock + n04548362: wallet + n04550184: wardrobe + n04552348: warplane + n04553703: washbasin + n04554684: washer + n04557648: water_bottle + n04560804: water_jug + n04562935: water_tower + n04579145: whiskey_jug + n04579432: whistle + n04584207: wig + n04589890: window_screen + n04590129: window_shade + n04591157: Windsor_tie + n04591713: wine_bottle + n04592741: wing + n04596742: wok + n04597913: wooden_spoon + n04599235: wool + n04604644: worm_fence + n04606251: wreck + n04612504: yawl + n04613696: yurt + n06359193: web_site + n06596364: comic_book + n06785654: crossword_puzzle + n06794110: street_sign + n06874185: traffic_light + n07248320: book_jacket + n07565083: menu + n07579787: plate + n07583066: guacamole + n07584110: consomme + n07590611: hot_pot + n07613480: trifle + n07614500: ice_cream + n07615774: ice_lolly + n07684084: French_loaf + n07693725: bagel + n07695742: pretzel + n07697313: cheeseburger + n07697537: hotdog + n07711569: mashed_potato + n07714571: head_cabbage + n07714990: broccoli + n07715103: cauliflower + n07716358: zucchini + n07716906: spaghetti_squash + n07717410: acorn_squash + n07717556: butternut_squash + n07718472: cucumber + n07718747: artichoke + n07720875: bell_pepper + n07730033: cardoon + n07734744: mushroom + n07742313: Granny_Smith + n07745940: strawberry + n07747607: orange + n07749582: lemon + n07753113: fig + n07753275: pineapple + n07753592: banana + n07754684: jackfruit + n07760859: custard_apple + n07768694: pomegranate + n07802026: hay + n07831146: carbonara + n07836838: chocolate_sauce + n07860988: dough + n07871810: meat_loaf + n07873807: pizza + n07875152: potpie + n07880968: burrito + n07892512: red_wine + n07920052: espresso + n07930864: cup + 
n07932039: eggnog + n09193705: alp + n09229709: bubble + n09246464: cliff + n09256479: coral_reef + n09288635: geyser + n09332890: lakeside + n09399592: promontory + n09421951: sandbar + n09428293: seashore + n09468604: valley + n09472597: volcano + n09835506: ballplayer + n10148035: groom + n10565667: scuba_diver + n11879895: rapeseed + n11939491: daisy + n12057211: yellow_lady's_slipper + n12144580: corn + n12267677: acorn + n12620546: hip + n12768682: buckeye + n12985857: coral_fungus + n12998815: agaric + n13037406: gyromitra + n13040303: stinkhorn + n13044778: earthstar + n13052670: hen-of-the-woods + n13054560: bolete + n13133613: ear + n15075141: toilet_tissue + +# Download script/URL (optional) +download: ultralytics/data/scripts/get_imagenet.sh diff --git a/tracking/ultralytics/cfg/datasets/Objects365.yaml b/tracking/ultralytics/cfg/datasets/Objects365.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8f89d8bfed20ad6fd394f389f69e523cff013053 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/Objects365.yaml @@ -0,0 +1,443 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Objects365 dataset https://www.objects365.org/ by Megvii +# Documentation: https://docs.ultralytics.com/datasets/detect/objects365/ +# Example usage: yolo train data=Objects365.yaml +# parent +# ├── ultralytics +# └── datasets +# └── Objects365 ← downloads here (712 GB = 367G data + 345G zips) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/Objects365 # dataset root dir +train: images/train # train images (relative to 'path') 1742289 images +val: images/val # val images (relative to 'path') 80000 images +test: # test images (optional) + +# Classes +names: + 0: Person + 1: Sneakers + 2: Chair + 3: Other Shoes + 4: Hat + 5: Car + 6: Lamp + 7: Glasses + 8: Bottle + 9: Desk + 10: Cup + 11: Street Lights + 12: Cabinet/shelf + 13: Handbag/Satchel + 14: Bracelet + 15: Plate + 16: Picture/Frame + 17: Helmet + 18: Book + 19: Gloves + 20: Storage box + 21: Boat + 22: Leather Shoes + 23: Flower + 24: Bench + 25: Potted Plant + 26: Bowl/Basin + 27: Flag + 28: Pillow + 29: Boots + 30: Vase + 31: Microphone + 32: Necklace + 33: Ring + 34: SUV + 35: Wine Glass + 36: Belt + 37: Monitor/TV + 38: Backpack + 39: Umbrella + 40: Traffic Light + 41: Speaker + 42: Watch + 43: Tie + 44: Trash bin Can + 45: Slippers + 46: Bicycle + 47: Stool + 48: Barrel/bucket + 49: Van + 50: Couch + 51: Sandals + 52: Basket + 53: Drum + 54: Pen/Pencil + 55: Bus + 56: Wild Bird + 57: High Heels + 58: Motorcycle + 59: Guitar + 60: Carpet + 61: Cell Phone + 62: Bread + 63: Camera + 64: Canned + 65: Truck + 66: Traffic cone + 67: Cymbal + 68: Lifesaver + 69: Towel + 70: Stuffed Toy + 71: Candle + 72: Sailboat + 73: Laptop + 74: Awning + 75: Bed + 76: Faucet + 77: Tent + 78: Horse + 79: Mirror + 80: Power outlet + 81: Sink + 82: Apple + 83: Air Conditioner + 84: Knife + 85: Hockey Stick + 86: Paddle + 87: Pickup Truck + 88: Fork + 89: Traffic Sign + 90: Balloon + 91: Tripod + 92: Dog + 93: Spoon + 94: Clock + 95: Pot + 96: Cow + 97: Cake + 98: Dining Table + 99: Sheep + 100: Hanger + 101: Blackboard/Whiteboard + 102: Napkin + 103: Other Fish + 104: Orange/Tangerine + 105: Toiletry + 106: Keyboard + 107: Tomato + 108: Lantern + 109: Machinery Vehicle + 110: Fan + 111: Green Vegetables + 112: Banana + 113: Baseball Glove + 114: Airplane + 115: Mouse + 116: Train + 117: Pumpkin + 118: Soccer + 119: Skiboard + 120: Luggage + 121: 
Nightstand + 122: Tea pot + 123: Telephone + 124: Trolley + 125: Head Phone + 126: Sports Car + 127: Stop Sign + 128: Dessert + 129: Scooter + 130: Stroller + 131: Crane + 132: Remote + 133: Refrigerator + 134: Oven + 135: Lemon + 136: Duck + 137: Baseball Bat + 138: Surveillance Camera + 139: Cat + 140: Jug + 141: Broccoli + 142: Piano + 143: Pizza + 144: Elephant + 145: Skateboard + 146: Surfboard + 147: Gun + 148: Skating and Skiing shoes + 149: Gas stove + 150: Donut + 151: Bow Tie + 152: Carrot + 153: Toilet + 154: Kite + 155: Strawberry + 156: Other Balls + 157: Shovel + 158: Pepper + 159: Computer Box + 160: Toilet Paper + 161: Cleaning Products + 162: Chopsticks + 163: Microwave + 164: Pigeon + 165: Baseball + 166: Cutting/chopping Board + 167: Coffee Table + 168: Side Table + 169: Scissors + 170: Marker + 171: Pie + 172: Ladder + 173: Snowboard + 174: Cookies + 175: Radiator + 176: Fire Hydrant + 177: Basketball + 178: Zebra + 179: Grape + 180: Giraffe + 181: Potato + 182: Sausage + 183: Tricycle + 184: Violin + 185: Egg + 186: Fire Extinguisher + 187: Candy + 188: Fire Truck + 189: Billiards + 190: Converter + 191: Bathtub + 192: Wheelchair + 193: Golf Club + 194: Briefcase + 195: Cucumber + 196: Cigar/Cigarette + 197: Paint Brush + 198: Pear + 199: Heavy Truck + 200: Hamburger + 201: Extractor + 202: Extension Cord + 203: Tong + 204: Tennis Racket + 205: Folder + 206: American Football + 207: earphone + 208: Mask + 209: Kettle + 210: Tennis + 211: Ship + 212: Swing + 213: Coffee Machine + 214: Slide + 215: Carriage + 216: Onion + 217: Green beans + 218: Projector + 219: Frisbee + 220: Washing Machine/Drying Machine + 221: Chicken + 222: Printer + 223: Watermelon + 224: Saxophone + 225: Tissue + 226: Toothbrush + 227: Ice cream + 228: Hot-air balloon + 229: Cello + 230: French Fries + 231: Scale + 232: Trophy + 233: Cabbage + 234: Hot dog + 235: Blender + 236: Peach + 237: Rice + 238: Wallet/Purse + 239: Volleyball + 240: Deer + 241: Goose + 242: Tape + 
243: Tablet + 244: Cosmetics + 245: Trumpet + 246: Pineapple + 247: Golf Ball + 248: Ambulance + 249: Parking meter + 250: Mango + 251: Key + 252: Hurdle + 253: Fishing Rod + 254: Medal + 255: Flute + 256: Brush + 257: Penguin + 258: Megaphone + 259: Corn + 260: Lettuce + 261: Garlic + 262: Swan + 263: Helicopter + 264: Green Onion + 265: Sandwich + 266: Nuts + 267: Speed Limit Sign + 268: Induction Cooker + 269: Broom + 270: Trombone + 271: Plum + 272: Rickshaw + 273: Goldfish + 274: Kiwi fruit + 275: Router/modem + 276: Poker Card + 277: Toaster + 278: Shrimp + 279: Sushi + 280: Cheese + 281: Notepaper + 282: Cherry + 283: Pliers + 284: CD + 285: Pasta + 286: Hammer + 287: Cue + 288: Avocado + 289: Hami melon + 290: Flask + 291: Mushroom + 292: Screwdriver + 293: Soap + 294: Recorder + 295: Bear + 296: Eggplant + 297: Board Eraser + 298: Coconut + 299: Tape Measure/Ruler + 300: Pig + 301: Showerhead + 302: Globe + 303: Chips + 304: Steak + 305: Crosswalk Sign + 306: Stapler + 307: Camel + 308: Formula 1 + 309: Pomegranate + 310: Dishwasher + 311: Crab + 312: Hoverboard + 313: Meatball + 314: Rice Cooker + 315: Tuba + 316: Calculator + 317: Papaya + 318: Antelope + 319: Parrot + 320: Seal + 321: Butterfly + 322: Dumbbell + 323: Donkey + 324: Lion + 325: Urinal + 326: Dolphin + 327: Electric Drill + 328: Hair Dryer + 329: Egg tart + 330: Jellyfish + 331: Treadmill + 332: Lighter + 333: Grapefruit + 334: Game board + 335: Mop + 336: Radish + 337: Baozi + 338: Target + 339: French + 340: Spring Rolls + 341: Monkey + 342: Rabbit + 343: Pencil Case + 344: Yak + 345: Red Cabbage + 346: Binoculars + 347: Asparagus + 348: Barbell + 349: Scallop + 350: Noddles + 351: Comb + 352: Dumpling + 353: Oyster + 354: Table Tennis paddle + 355: Cosmetics Brush/Eyeliner Pencil + 356: Chainsaw + 357: Eraser + 358: Lobster + 359: Durian + 360: Okra + 361: Lipstick + 362: Cosmetics Mirror + 363: Curling + 364: Table Tennis + +# Download script/URL (optional) 
--------------------------------------------------------------------------------------- +download: | + from pathlib import Path + + import numpy as np + from tqdm import tqdm + + from ultralytics.utils.checks import check_requirements + from ultralytics.utils.downloads import download + from ultralytics.utils.ops import xyxy2xywhn + + check_requirements(("pycocotools>=2.0",)) + from pycocotools.coco import COCO + + # Make Directories + dir = Path(yaml["path"]) # dataset root dir + for p in "images", "labels": + (dir / p).mkdir(parents=True, exist_ok=True) + for q in "train", "val": + (dir / p / q).mkdir(parents=True, exist_ok=True) + + # Train, Val Splits + for split, patches in [("train", 50 + 1), ("val", 43 + 1)]: + print(f"Processing {split} in {patches} patches ...") + images, labels = dir / "images" / split, dir / "labels" / split + + # Download + url = f"https://dorc.ks3-cn-beijing.ksyun.com/data-set/2020Objects365%E6%95%B0%E6%8D%AE%E9%9B%86/{split}/" + if split == "train": + download([f"{url}zhiyuan_objv2_{split}.tar.gz"], dir=dir) # annotations json + download([f"{url}patch{i}.tar.gz" for i in range(patches)], dir=images, curl=True, threads=8) + elif split == "val": + download([f"{url}zhiyuan_objv2_{split}.json"], dir=dir) # annotations json + download([f"{url}images/v1/patch{i}.tar.gz" for i in range(15 + 1)], dir=images, curl=True, threads=8) + download([f"{url}images/v2/patch{i}.tar.gz" for i in range(16, patches)], dir=images, curl=True, threads=8) + + # Move + for f in tqdm(images.rglob("*.jpg"), desc=f"Moving {split} images"): + f.rename(images / f.name) # move to /images/{split} + + # Labels + coco = COCO(dir / f"zhiyuan_objv2_{split}.json") + names = [x["name"] for x in coco.loadCats(coco.getCatIds())] + for cid, cat in enumerate(names): + catIds = coco.getCatIds(catNms=[cat]) + imgIds = coco.getImgIds(catIds=catIds) + for im in tqdm(coco.loadImgs(imgIds), desc=f"Class {cid + 1}/{len(names)} {cat}"): + width, height = im["width"], im["height"] + 
path = Path(im["file_name"]) # image filename + try: + with open(labels / path.with_suffix(".txt").name, "a", encoding="utf-8") as file: + annIds = coco.getAnnIds(imgIds=im["id"], catIds=catIds, iscrowd=None) + for a in coco.loadAnns(annIds): + x, y, w, h = a["bbox"] # bounding box in xywh (xy top-left corner) + xyxy = np.array([x, y, x + w, y + h])[None] # pixels(1,4) + x, y, w, h = xyxy2xywhn(xyxy, w=width, h=height, clip=True)[0] # normalized and clipped + file.write(f"{cid} {x:.5f} {y:.5f} {w:.5f} {h:.5f}\n") + except Exception as e: + print(e) diff --git a/tracking/ultralytics/cfg/datasets/SKU-110K.yaml b/tracking/ultralytics/cfg/datasets/SKU-110K.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e056054590e4515ee636670cba12486d5ef8dea8 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/SKU-110K.yaml @@ -0,0 +1,58 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# SKU-110K retail items dataset https://github.com/eg4000/SKU110K_CVPR19 by Trax Retail +# Documentation: https://docs.ultralytics.com/datasets/detect/sku-110k/ +# Example usage: yolo train data=SKU-110K.yaml +# parent +# ├── ultralytics +# └── datasets +# └── SKU-110K ← downloads here (13.6 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/SKU-110K # dataset root dir +train: train.txt # train images (relative to 'path') 8219 images +val: val.txt # val images (relative to 'path') 588 images +test: test.txt # test images (optional) 2936 images + +# Classes +names: + 0: object + +# Download script/URL (optional) --------------------------------------------------------------------------------------- +download: | + import shutil + from pathlib import Path + + import numpy as np + import pandas as pd + from tqdm import tqdm + + from ultralytics.utils.downloads import download + from ultralytics.utils.ops import xyxy2xywh + + # Download + dir = Path(yaml["path"]) # dataset root dir + parent = Path(dir.parent) # download dir + urls = ["http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz"] + download(urls, dir=parent) + + # Rename directories + if dir.exists(): + shutil.rmtree(dir) + (parent / "SKU110K_fixed").rename(dir) # rename dir + (dir / "labels").mkdir(parents=True, exist_ok=True) # create labels dir + + # Convert labels + names = "image", "x1", "y1", "x2", "y2", "class", "image_width", "image_height" # column names + for d in "annotations_train.csv", "annotations_val.csv", "annotations_test.csv": + x = pd.read_csv(dir / "annotations" / d, names=names).values # annotations + images, unique_images = x[:, 0], np.unique(x[:, 0]) + with open((dir / d).with_suffix(".txt").__str__().replace("annotations_", ""), "w", encoding="utf-8") as f: + f.writelines(f"./images/{s}\n" for s in unique_images) + for im in tqdm(unique_images, desc=f"Converting {dir / d}"): + cls = 0 # single-class dataset + with open((dir / "labels" / im).with_suffix(".txt"), "a", encoding="utf-8") as f: + for r in x[images == im]: + w, h = r[6], r[7] # image width, height + xywh = xyxy2xywh(np.array([[r[1] / w, r[2] / h, r[3] / w, r[4] / h]]))[0] # instance + f.write(f"{cls} {xywh[0]:.5f} {xywh[1]:.5f} {xywh[2]:.5f} {xywh[3]:.5f}\n") # write label diff --git 
a/tracking/ultralytics/cfg/datasets/VOC.yaml b/tracking/ultralytics/cfg/datasets/VOC.yaml new file mode 100644 index 0000000000000000000000000000000000000000..19bec223f6440378794f0df0ff0a127ff95c06a5 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/VOC.yaml @@ -0,0 +1,106 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC by University of Oxford +# Documentation: https://docs.ultralytics.com/datasets/detect/voc/ +# Example usage: yolo train data=VOC.yaml +# parent +# ├── ultralytics +# └── datasets +# └── VOC ← downloads here (2.8 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/VOC +train: # train images (relative to 'path') 16551 images + - images/train2012 + - images/train2007 + - images/val2012 + - images/val2007 +val: # val images (relative to 'path') 4952 images + - images/test2007 +test: # test images (optional) + - images/test2007 + +# Classes +names: + 0: aeroplane + 1: bicycle + 2: bird + 3: boat + 4: bottle + 5: bus + 6: car + 7: cat + 8: chair + 9: cow + 10: diningtable + 11: dog + 12: horse + 13: motorbike + 14: person + 15: pottedplant + 16: sheep + 17: sofa + 18: train + 19: tvmonitor + +# Download script/URL (optional) --------------------------------------------------------------------------------------- +download: | + import xml.etree.ElementTree as ET + from pathlib import Path + + from tqdm import tqdm + + from ultralytics.utils.downloads import download + + + def convert_label(path, lb_path, year, image_id): + """Converts XML annotations from VOC format to YOLO format by extracting bounding boxes and class IDs.""" + + def convert_box(size, box): + dw, dh = 1.0 / size[0], 1.0 / size[1] + x, y, w, h = (box[0] + box[1]) / 2.0 - 1, (box[2] + box[3]) / 2.0 - 1, box[1] - box[0], box[3] - box[2] + return x * dw, y * dh, w * dw, h * dh + + in_file = 
open(path / f"VOC{year}/Annotations/{image_id}.xml") + out_file = open(lb_path, "w") + tree = ET.parse(in_file) + root = tree.getroot() + size = root.find("size") + w = int(size.find("width").text) + h = int(size.find("height").text) + + names = list(yaml["names"].values()) # names list + for obj in root.iter("object"): + cls = obj.find("name").text + if cls in names and int(obj.find("difficult").text) != 1: + xmlbox = obj.find("bndbox") + bb = convert_box((w, h), [float(xmlbox.find(x).text) for x in ("xmin", "xmax", "ymin", "ymax")]) + cls_id = names.index(cls) # class id + out_file.write(" ".join(str(a) for a in (cls_id, *bb)) + "\n") + + + # Download + dir = Path(yaml["path"]) # dataset root dir + url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/" + urls = [ + f"{url}VOCtrainval_06-Nov-2007.zip", # 446MB, 5012 images + f"{url}VOCtest_06-Nov-2007.zip", # 438MB, 4953 images + f"{url}VOCtrainval_11-May-2012.zip", # 1.95GB, 17126 images + ] + download(urls, dir=dir / "images", curl=True, threads=3, exist_ok=True) # download and unzip over existing (required) + + # Convert + path = dir / "images/VOCdevkit" + for year, image_set in ("2012", "train"), ("2012", "val"), ("2007", "train"), ("2007", "val"), ("2007", "test"): + imgs_path = dir / "images" / f"{image_set}{year}" + lbs_path = dir / "labels" / f"{image_set}{year}" + imgs_path.mkdir(exist_ok=True, parents=True) + lbs_path.mkdir(exist_ok=True, parents=True) + + with open(path / f"VOC{year}/ImageSets/Main/{image_set}.txt") as f: + image_ids = f.read().strip().split() + for id in tqdm(image_ids, desc=f"{image_set}{year}"): + f = path / f"VOC{year}/JPEGImages/{id}.jpg" # old img path + lb_path = (lbs_path / f.name).with_suffix(".txt") # new label path + f.rename(imgs_path / f.name) # move image + convert_label(path, lb_path, year, id) # convert labels to YOLO format diff --git a/tracking/ultralytics/cfg/datasets/VisDrone.yaml b/tracking/ultralytics/cfg/datasets/VisDrone.yaml new file mode 100644 
index 0000000000000000000000000000000000000000..35f84c4a8de1c1f6ede85cefa3332aeb9af2eefb --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/VisDrone.yaml @@ -0,0 +1,77 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# VisDrone2019-DET dataset https://github.com/VisDrone/VisDrone-Dataset by Tianjin University +# Documentation: https://docs.ultralytics.com/datasets/detect/visdrone/ +# Example usage: yolo train data=VisDrone.yaml +# parent +# ├── ultralytics +# └── datasets +# └── VisDrone ← downloads here (2.3 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/VisDrone # dataset root dir +train: VisDrone2019-DET-train/images # train images (relative to 'path') 6471 images +val: VisDrone2019-DET-val/images # val images (relative to 'path') 548 images +test: VisDrone2019-DET-test-dev/images # test images (optional) 1610 images + +# Classes +names: + 0: pedestrian + 1: people + 2: bicycle + 3: car + 4: van + 5: truck + 6: tricycle + 7: awning-tricycle + 8: bus + 9: motor + +# Download script/URL (optional) --------------------------------------------------------------------------------------- +download: | + import os + from pathlib import Path + + from ultralytics.utils.downloads import download + + + def visdrone2yolo(dir): + """Convert VisDrone annotations to YOLO format, creating label files with normalized bounding box coordinates.""" + from PIL import Image + from tqdm import tqdm + + def convert_box(size, box): + # Convert VisDrone box to YOLO xywh box + dw = 1.0 / size[0] + dh = 1.0 / size[1] + return (box[0] + box[2] / 2) * dw, (box[1] + box[3] / 2) * dh, box[2] * dw, box[3] * dh + + (dir / "labels").mkdir(parents=True, exist_ok=True) # make labels directory + pbar = tqdm((dir / "annotations").glob("*.txt"), desc=f"Converting {dir}") + for f in pbar: + img_size = Image.open((dir / "images" / f.name).with_suffix(".jpg")).size + lines = [] + 
with open(f, encoding="utf-8") as file: # read annotation.txt + for row in [x.split(",") for x in file.read().strip().splitlines()]: + if row[4] == "0": # VisDrone 'ignored regions' class 0 + continue + cls = int(row[5]) - 1 + box = convert_box(img_size, tuple(map(int, row[:4]))) + lines.append(f"{cls} {' '.join(f'{x:.6f}' for x in box)}\n") + with open(str(f).replace(f"{os.sep}annotations{os.sep}", f"{os.sep}labels{os.sep}"), "w", encoding="utf-8") as fl: + fl.writelines(lines) # write label.txt + + + # Download + dir = Path(yaml["path"]) # dataset root dir + urls = [ + "https://github.com/ultralytics/assets/releases/download/v0.0.0/VisDrone2019-DET-train.zip", + "https://github.com/ultralytics/assets/releases/download/v0.0.0/VisDrone2019-DET-val.zip", + "https://github.com/ultralytics/assets/releases/download/v0.0.0/VisDrone2019-DET-test-dev.zip", + "https://github.com/ultralytics/assets/releases/download/v0.0.0/VisDrone2019-DET-test-challenge.zip", + ] + download(urls, dir=dir, curl=True, threads=4) + + # Convert + for d in "VisDrone2019-DET-train", "VisDrone2019-DET-val", "VisDrone2019-DET-test-dev": + visdrone2yolo(dir / d) # convert VisDrone annotations to YOLO labels diff --git a/tracking/ultralytics/cfg/datasets/african-wildlife.yaml b/tracking/ultralytics/cfg/datasets/african-wildlife.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b825f8f068b54a47e4f32cd105ca225f3f9f1f8a --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/african-wildlife.yaml @@ -0,0 +1,25 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# African-wildlife dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/detect/african-wildlife/ +# Example usage: yolo train data=african-wildlife.yaml +# parent +# ├── ultralytics +# └── datasets +# └── african-wildlife ← downloads here (100 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/african-wildlife # dataset root dir +train: train/images # train images (relative to 'path') 1052 images +val: valid/images # val images (relative to 'path') 225 images +test: test/images # test images (relative to 'path') 227 images + +# Classes +names: + 0: buffalo + 1: elephant + 2: rhino + 3: zebra + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/african-wildlife.zip diff --git a/tracking/ultralytics/cfg/datasets/brain-tumor.yaml b/tracking/ultralytics/cfg/datasets/brain-tumor.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7a448e84afc76832914cde3a2f0d6208b8c78e29 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/brain-tumor.yaml @@ -0,0 +1,23 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Brain-tumor dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/detect/brain-tumor/ +# Example usage: yolo train data=brain-tumor.yaml +# parent +# ├── ultralytics +# └── datasets +# └── brain-tumor ← downloads here (4.05 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/brain-tumor # dataset root dir +train: train/images # train images (relative to 'path') 893 images +val: valid/images # val images (relative to 'path') 223 images +test: # test images (relative to 'path') + +# Classes +names: + 0: negative + 1: positive + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/brain-tumor.zip diff --git a/tracking/ultralytics/cfg/datasets/carparts-seg.yaml b/tracking/ultralytics/cfg/datasets/carparts-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9f15f9b06625e2a2575d24da262d317a9394e71a --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/carparts-seg.yaml @@ -0,0 +1,44 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Carparts-seg dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/segment/carparts-seg/ +# Example usage: yolo train data=carparts-seg.yaml +# parent +# ├── ultralytics +# └── datasets +# └── carparts-seg ← downloads here (132 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/carparts-seg # dataset root dir +train: train/images # train images (relative to 'path') 3516 images +val: valid/images # val images (relative to 'path') 276 images +test: test/images # test images (relative to 'path') 401 images + +# Classes +names: + 0: back_bumper + 1: back_door + 2: back_glass + 3: back_left_door + 4: back_left_light + 5: back_light + 6: back_right_door + 7: back_right_light + 8: front_bumper + 9: front_door + 10: front_glass + 11: front_left_door + 12: front_left_light + 13: front_light + 14: front_right_door + 15: front_right_light + 16: hood + 17: left_mirror + 18: object + 19: right_mirror + 20: tailgate + 21: trunk + 22: wheel + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/carparts-seg.zip diff --git a/tracking/ultralytics/cfg/datasets/coco-pose.yaml b/tracking/ultralytics/cfg/datasets/coco-pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cbe16f5c03177f50c93dd660800726c99c747458 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/coco-pose.yaml @@ -0,0 +1,42 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# COCO 2017 Keypoints dataset https://cocodataset.org by Microsoft +# Documentation: https://docs.ultralytics.com/datasets/pose/coco/ +# Example usage: yolo train data=coco-pose.yaml +# parent +# ├── ultralytics +# └── datasets +# └── coco-pose ← downloads here (20.1 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/coco-pose # dataset root dir +train: train2017.txt # train images (relative to 'path') 56599 images +val: val2017.txt # val images (relative to 'path') 2346 images +test: test-dev2017.txt # 20288 of 40670 images, submit to https://codalab.lisn.upsaclay.fr/competitions/7403 + +# Keypoints +kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] + +# Classes +names: + 0: person + +# Download script/URL (optional) +download: | + from pathlib import Path + + from ultralytics.utils.downloads import download + + # Download labels + dir = Path(yaml["path"]) # dataset root dir + url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/" + urls = [f"{url}coco2017labels-pose.zip"] + download(urls, dir=dir.parent) + # Download data + urls = [ + "http://images.cocodataset.org/zips/train2017.zip", # 19G, 118k images + "http://images.cocodataset.org/zips/val2017.zip", # 1G, 5k images + "http://images.cocodataset.org/zips/test2017.zip", # 7G, 41k images (optional) + ] + download(urls, dir=dir / "images", threads=3) diff --git a/tracking/ultralytics/cfg/datasets/coco.yaml b/tracking/ultralytics/cfg/datasets/coco.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5c86fbd58b1bc087af2d5514e87b420f3ae2d2a1 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/coco.yaml @@ -0,0 +1,118 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# COCO 2017 dataset https://cocodataset.org by Microsoft +# Documentation: https://docs.ultralytics.com/datasets/detect/coco/ +# Example usage: yolo train data=coco.yaml +# parent +# ├── ultralytics +# └── datasets +# └── coco ← downloads here (20.1 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/coco # dataset root dir +train: train2017.txt # train images (relative to 'path') 118287 images +val: val2017.txt # val images (relative to 'path') 5000 images +test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794 + +# Classes +names: + 0: person + 1: bicycle + 2: car + 3: motorcycle + 4: airplane + 5: bus + 6: train + 7: truck + 8: boat + 9: traffic light + 10: fire hydrant + 11: stop sign + 12: parking meter + 13: bench + 14: bird + 15: cat + 16: dog + 17: horse + 18: sheep + 19: cow + 20: elephant + 21: bear + 22: zebra + 23: giraffe + 24: backpack + 25: umbrella + 26: handbag + 27: tie + 28: suitcase + 29: frisbee + 30: skis + 31: snowboard + 32: sports ball + 33: kite + 34: baseball bat + 35: baseball glove + 36: skateboard + 37: surfboard + 38: tennis racket + 39: bottle + 40: wine glass + 41: cup + 42: fork + 43: knife + 44: spoon + 45: bowl + 46: banana + 47: apple + 48: sandwich + 49: orange + 50: broccoli + 51: carrot + 52: hot dog + 53: pizza + 54: donut + 55: cake + 56: chair + 57: couch + 58: potted plant + 59: bed + 60: dining table + 61: toilet + 62: tv + 63: laptop + 64: mouse + 65: remote + 66: keyboard + 67: cell phone + 68: microwave + 69: oven + 70: toaster + 71: sink + 72: refrigerator + 73: book + 74: clock + 75: vase + 76: scissors + 77: teddy bear + 78: hair drier + 79: toothbrush + +# Download script/URL (optional) +download: | + from pathlib import Path + + from ultralytics.utils.downloads import download + + # Download labels + segments = True # segment or box labels + dir = Path(yaml["path"]) # dataset root dir + url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/" + urls = [url + ("coco2017labels-segments.zip" if segments else "coco2017labels.zip")] # labels + download(urls, dir=dir.parent) + # Download data + urls = [ + "http://images.cocodataset.org/zips/train2017.zip", # 19G, 118k images + "http://images.cocodataset.org/zips/val2017.zip", # 
1G, 5k images + "http://images.cocodataset.org/zips/test2017.zip", # 7G, 41k images (optional) + ] + download(urls, dir=dir / "images", threads=3) diff --git a/tracking/ultralytics/cfg/datasets/coco128-seg.yaml b/tracking/ultralytics/cfg/datasets/coco128-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b023c676300db3ca9909839b5cdf7ac709ba5949 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/coco128-seg.yaml @@ -0,0 +1,101 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# COCO128-seg dataset https://www.kaggle.com/datasets/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/segment/coco/ +# Example usage: yolo train data=coco128.yaml +# parent +# ├── ultralytics +# └── datasets +# └── coco128-seg ← downloads here (7 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/coco128-seg # dataset root dir +train: images/train2017 # train images (relative to 'path') 128 images +val: images/train2017 # val images (relative to 'path') 128 images +test: # test images (optional) + +# Classes +names: + 0: person + 1: bicycle + 2: car + 3: motorcycle + 4: airplane + 5: bus + 6: train + 7: truck + 8: boat + 9: traffic light + 10: fire hydrant + 11: stop sign + 12: parking meter + 13: bench + 14: bird + 15: cat + 16: dog + 17: horse + 18: sheep + 19: cow + 20: elephant + 21: bear + 22: zebra + 23: giraffe + 24: backpack + 25: umbrella + 26: handbag + 27: tie + 28: suitcase + 29: frisbee + 30: skis + 31: snowboard + 32: sports ball + 33: kite + 34: baseball bat + 35: baseball glove + 36: skateboard + 37: surfboard + 38: tennis racket + 39: bottle + 40: wine glass + 41: cup + 42: fork + 43: knife + 44: spoon + 45: bowl + 46: banana + 47: apple + 48: sandwich + 49: orange + 50: broccoli + 51: carrot + 52: hot dog + 53: pizza + 54: donut + 55: cake + 56: 
chair + 57: couch + 58: potted plant + 59: bed + 60: dining table + 61: toilet + 62: tv + 63: laptop + 64: mouse + 65: remote + 66: keyboard + 67: cell phone + 68: microwave + 69: oven + 70: toaster + 71: sink + 72: refrigerator + 73: book + 74: clock + 75: vase + 76: scissors + 77: teddy bear + 78: hair drier + 79: toothbrush + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/coco128-seg.zip diff --git a/tracking/ultralytics/cfg/datasets/coco128.yaml b/tracking/ultralytics/cfg/datasets/coco128.yaml new file mode 100644 index 0000000000000000000000000000000000000000..12ff0511bcd0df62be6f05762743543e1eb21524 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/coco128.yaml @@ -0,0 +1,101 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# COCO128 dataset https://www.kaggle.com/datasets/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/detect/coco/ +# Example usage: yolo train data=coco128.yaml +# parent +# ├── ultralytics +# └── datasets +# └── coco128 ← downloads here (7 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/coco128 # dataset root dir +train: images/train2017 # train images (relative to 'path') 128 images +val: images/train2017 # val images (relative to 'path') 128 images +test: # test images (optional) + +# Classes +names: + 0: person + 1: bicycle + 2: car + 3: motorcycle + 4: airplane + 5: bus + 6: train + 7: truck + 8: boat + 9: traffic light + 10: fire hydrant + 11: stop sign + 12: parking meter + 13: bench + 14: bird + 15: cat + 16: dog + 17: horse + 18: sheep + 19: cow + 20: elephant + 21: bear + 22: zebra + 23: giraffe + 24: backpack + 25: umbrella + 26: handbag + 27: tie + 28: suitcase + 29: frisbee + 30: skis + 31: snowboard + 32: sports ball + 33: kite + 34: baseball bat + 35: baseball glove + 36: skateboard + 37: surfboard + 38: tennis racket + 39: bottle + 40: wine glass + 41: cup + 42: fork + 43: knife + 44: spoon + 45: bowl + 46: banana + 47: apple + 48: sandwich + 49: orange + 50: broccoli + 51: carrot + 52: hot dog + 53: pizza + 54: donut + 55: cake + 56: chair + 57: couch + 58: potted plant + 59: bed + 60: dining table + 61: toilet + 62: tv + 63: laptop + 64: mouse + 65: remote + 66: keyboard + 67: cell phone + 68: microwave + 69: oven + 70: toaster + 71: sink + 72: refrigerator + 73: book + 74: clock + 75: vase + 76: scissors + 77: teddy bear + 78: hair drier + 79: toothbrush + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/coco128.zip diff --git a/tracking/ultralytics/cfg/datasets/coco8-pose.yaml b/tracking/ultralytics/cfg/datasets/coco8-pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e8af1e344804571664f5be7de288ba5ff7c3822 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/coco8-pose.yaml @@ -0,0 +1,26 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# COCO8-pose dataset (first 8 images from COCO train2017) by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/pose/coco8-pose/ +# Example usage: 
yolo train data=coco8-pose.yaml +# parent +# ├── ultralytics +# └── datasets +# └── coco8-pose ← downloads here (1 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/coco8-pose # dataset root dir +train: images/train # train images (relative to 'path') 4 images +val: images/val # val images (relative to 'path') 4 images +test: # test images (optional) + +# Keypoints +kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] + +# Classes +names: + 0: person + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/coco8-pose.zip diff --git a/tracking/ultralytics/cfg/datasets/coco8-seg.yaml b/tracking/ultralytics/cfg/datasets/coco8-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1ea6b31004cbf5fe33914387532d7e46b1967aa8 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/coco8-seg.yaml @@ -0,0 +1,101 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# COCO8-seg dataset (first 8 images from COCO train2017) by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/segment/coco8-seg/ +# Example usage: yolo train data=coco8-seg.yaml +# parent +# ├── ultralytics +# └── datasets +# └── coco8-seg ← downloads here (1 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/coco8-seg # dataset root dir +train: images/train # train images (relative to 'path') 4 images +val: images/val # val images (relative to 'path') 4 images +test: # test images (optional) + +# Classes +names: + 0: person + 1: bicycle + 2: car + 3: motorcycle + 4: airplane + 5: bus + 6: train + 7: truck + 8: boat + 9: traffic light + 10: fire hydrant + 11: stop sign + 12: parking meter + 13: bench + 14: bird + 15: cat + 16: dog + 17: horse + 18: sheep + 19: cow + 20: elephant + 21: bear + 22: zebra + 23: giraffe + 24: backpack + 25: umbrella + 26: handbag + 27: tie + 28: suitcase + 29: frisbee + 30: skis + 31: snowboard + 32: sports ball + 33: kite + 34: baseball bat + 35: baseball glove + 36: skateboard + 37: surfboard + 38: tennis racket + 39: bottle + 40: wine glass + 41: cup + 42: fork + 43: knife + 44: spoon + 45: bowl + 46: banana + 47: apple + 48: sandwich + 49: orange + 50: broccoli + 51: carrot + 52: hot dog + 53: pizza + 54: donut + 55: cake + 56: chair + 57: couch + 58: potted plant + 59: bed + 60: dining table + 61: toilet + 62: tv + 63: laptop + 64: mouse + 65: remote + 66: keyboard + 67: cell phone + 68: microwave + 69: oven + 70: toaster + 71: sink + 72: refrigerator + 73: book + 74: clock + 75: vase + 76: scissors + 77: teddy bear + 78: hair drier + 79: toothbrush + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/coco8-seg.zip diff --git a/tracking/ultralytics/cfg/datasets/coco8.yaml b/tracking/ultralytics/cfg/datasets/coco8.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8200738b46d0c5414afc6b2c68027bcfe40b739c --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/coco8.yaml @@ -0,0 +1,101 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# COCO8 dataset (first 8 images from COCO train2017) by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/detect/coco8/ +# Example usage: yolo train data=coco8.yaml +# 
parent +# ├── ultralytics +# └── datasets +# └── coco8 ← downloads here (1 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/coco8 # dataset root dir +train: images/train # train images (relative to 'path') 4 images +val: images/val # val images (relative to 'path') 4 images +test: # test images (optional) + +# Classes +names: + 0: person + 1: bicycle + 2: car + 3: motorcycle + 4: airplane + 5: bus + 6: train + 7: truck + 8: boat + 9: traffic light + 10: fire hydrant + 11: stop sign + 12: parking meter + 13: bench + 14: bird + 15: cat + 16: dog + 17: horse + 18: sheep + 19: cow + 20: elephant + 21: bear + 22: zebra + 23: giraffe + 24: backpack + 25: umbrella + 26: handbag + 27: tie + 28: suitcase + 29: frisbee + 30: skis + 31: snowboard + 32: sports ball + 33: kite + 34: baseball bat + 35: baseball glove + 36: skateboard + 37: surfboard + 38: tennis racket + 39: bottle + 40: wine glass + 41: cup + 42: fork + 43: knife + 44: spoon + 45: bowl + 46: banana + 47: apple + 48: sandwich + 49: orange + 50: broccoli + 51: carrot + 52: hot dog + 53: pizza + 54: donut + 55: cake + 56: chair + 57: couch + 58: potted plant + 59: bed + 60: dining table + 61: toilet + 62: tv + 63: laptop + 64: mouse + 65: remote + 66: keyboard + 67: cell phone + 68: microwave + 69: oven + 70: toaster + 71: sink + 72: refrigerator + 73: book + 74: clock + 75: vase + 76: scissors + 77: teddy bear + 78: hair drier + 79: toothbrush + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/coco8.zip diff --git a/tracking/ultralytics/cfg/datasets/crack-seg.yaml b/tracking/ultralytics/cfg/datasets/crack-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..11bdd5f575fe7d8d9046636cb595ce707e8601a7 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/crack-seg.yaml @@ -0,0 +1,22 @@ +# Ultralytics 🚀 AGPL-3.0 License - 
https://ultralytics.com/license + +# Crack-seg dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/segment/crack-seg/ +# Example usage: yolo train data=crack-seg.yaml +# parent +# ├── ultralytics +# └── datasets +# └── crack-seg ← downloads here (91.2 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/crack-seg # dataset root dir +train: train/images # train images (relative to 'path') 3717 images +val: valid/images # val images (relative to 'path') 112 images +test: test/images # test images (relative to 'path') 200 images + +# Classes +names: + 0: crack + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/crack-seg.zip diff --git a/tracking/ultralytics/cfg/datasets/dog-pose.yaml b/tracking/ultralytics/cfg/datasets/dog-pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..447e542ce6c124533e25c7fa2a5caed88570d4ec --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/dog-pose.yaml @@ -0,0 +1,24 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Dogs dataset http://vision.stanford.edu/aditya86/ImageNetDogs/ by Stanford +# Documentation: https://docs.ultralytics.com/datasets/pose/dog-pose/ +# Example usage: yolo train data=dog-pose.yaml +# parent +# ├── ultralytics +# └── datasets +# └── dog-pose ← downloads here (337 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/dog-pose # dataset root dir +train: train # train images (relative to 'path') 6773 images +val: val # val images (relative to 'path') 1703 images + +# Keypoints +kpt_shape: [24, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) + +# Classes +names: + 0: dog + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/dog-pose.zip diff --git a/tracking/ultralytics/cfg/datasets/dota8.yaml b/tracking/ultralytics/cfg/datasets/dota8.yaml new file mode 100644 index 0000000000000000000000000000000000000000..486d9e2effbd0e2161ff982d2df2aea16b6d9fa0 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/dota8.yaml @@ -0,0 +1,35 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# DOTA8 dataset 8 images from split DOTAv1 dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/obb/dota8/ +# Example usage: yolo train model=yolov8n-obb.pt data=dota8.yaml +# parent +# ├── ultralytics +# └── datasets +# └── dota8 ← downloads here (1MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/dota8 # dataset root dir +train: images/train # train images (relative to 'path') 4 images +val: images/val # val images (relative to 'path') 4 images + +# Classes for DOTA 1.0 +names: + 0: plane + 1: ship + 2: storage tank + 3: baseball diamond + 4: tennis court + 5: basketball court + 6: ground track field + 7: harbor + 8: bridge + 9: large vehicle + 10: small vehicle + 11: helicopter + 12: roundabout + 13: soccer ball field + 14: swimming pool + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/dota8.zip diff --git a/tracking/ultralytics/cfg/datasets/hand-keypoints.yaml b/tracking/ultralytics/cfg/datasets/hand-keypoints.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6d2f765c789c828efa2b993249c238e466985eb3 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/hand-keypoints.yaml @@ -0,0 +1,26 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Hand Keypoints dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/pose/hand-keypoints/ +# Example usage: yolo train data=hand-keypoints.yaml +# parent +# ├── ultralytics +# └── datasets +# └── hand-keypoints ← downloads here (369 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/hand-keypoints # dataset root dir +train: train # train images (relative to 'path') 18776 images +val: val # val images (relative to 'path') 7992 images + +# Keypoints +kpt_shape: [21, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +flip_idx: + [0, 1, 2, 4, 3, 10, 11, 12, 13, 14, 5, 6, 7, 8, 9, 15, 16, 17, 18, 19, 20] + +# Classes +names: + 0: hand + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/hand-keypoints.zip diff --git a/tracking/ultralytics/cfg/datasets/lvis.yaml b/tracking/ultralytics/cfg/datasets/lvis.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3af62bbc800a8e832939697be9b8d75386bd2cb0 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/lvis.yaml @@ -0,0 +1,1240 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# LVIS dataset http://www.lvisdataset.org by Facebook AI Research. +# Documentation: https://docs.ultralytics.com/datasets/detect/lvis/ +# Example usage: yolo train data=lvis.yaml +# parent +# ├── ultralytics +# └── datasets +# └── lvis ← downloads here (20.1 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/lvis # dataset root dir +train: train.txt # train images (relative to 'path') 100170 images +val: val.txt # val images (relative to 'path') 19809 images +minival: minival.txt # minival images (relative to 'path') 5000 images + +names: + 0: aerosol can/spray can + 1: air conditioner + 2: airplane/aeroplane + 3: alarm clock + 4: alcohol/alcoholic beverage + 5: alligator/gator + 6: almond + 7: ambulance + 8: amplifier + 9: anklet/ankle bracelet + 10: antenna/aerial/transmitting aerial + 11: apple + 12: applesauce + 13: apricot + 14: apron + 15: aquarium/fish tank + 16: arctic/arctic type of shoe/galosh/golosh/rubber/rubber type of shoe/gumshoe + 17: armband + 18: armchair + 19: armoire + 20: armor/armour + 21: artichoke + 22: trash can/garbage can/wastebin/dustbin/trash barrel/trash bin + 23: ashtray + 24: asparagus + 25: atomizer/atomiser/spray/sprayer/nebulizer/nebuliser + 26: avocado + 27: award/accolade + 28: awning + 29: ax/axe + 30: baboon + 31: baby buggy/baby carriage/perambulator/pram/stroller + 32: basketball backboard + 33: backpack/knapsack/packsack/rucksack/haversack + 34: handbag/purse/pocketbook + 35: suitcase/baggage/luggage + 36: bagel/beigel + 37: bagpipe + 38: baguet/baguette + 39: bait/lure + 40: ball + 41: ballet skirt/tutu + 42: balloon + 43: bamboo + 44: banana + 45: Band Aid + 46: bandage + 47: bandanna/bandana + 48: banjo + 49: banner/streamer + 50: barbell + 51: barge + 52: barrel/cask + 53: barrette + 54: barrow/garden cart/lawn cart/wheelbarrow + 55: baseball base + 56: baseball + 57: baseball bat + 58: baseball cap/jockey cap/golf cap + 59: baseball glove/baseball mitt + 60: basket/handbasket + 61: basketball + 62: bass horn/sousaphone/tuba + 63: bat/bat animal + 64: bath mat + 65: bath towel + 66: bathrobe + 67: bathtub/bathing tub + 68: batter/batter food + 69: battery + 70: beachball + 71: bead + 72: bean curd/tofu + 73: beanbag + 74: beanie/beany + 75: bear + 76: bed + 77: bedpan + 78: bedspread/bedcover/bed 
covering/counterpane/spread + 79: cow + 80: beef/beef food/boeuf/boeuf food + 81: beeper/pager + 82: beer bottle + 83: beer can + 84: beetle + 85: bell + 86: bell pepper/capsicum + 87: belt + 88: belt buckle + 89: bench + 90: beret + 91: bib + 92: Bible + 93: bicycle/bike/bike bicycle + 94: visor/vizor + 95: billboard + 96: binder/ring-binder + 97: binoculars/field glasses/opera glasses + 98: bird + 99: birdfeeder + 100: birdbath + 101: birdcage + 102: birdhouse + 103: birthday cake + 104: birthday card + 105: pirate flag + 106: black sheep + 107: blackberry + 108: blackboard/chalkboard + 109: blanket + 110: blazer/sport jacket/sport coat/sports jacket/sports coat + 111: blender/liquidizer/liquidiser + 112: blimp + 113: blinker/flasher + 114: blouse + 115: blueberry + 116: gameboard + 117: boat/ship/ship boat + 118: bob/bobber/bobfloat + 119: bobbin/spool/reel + 120: bobby pin/hairgrip + 121: boiled egg/coddled egg + 122: bolo tie/bolo/bola tie/bola + 123: deadbolt + 124: bolt + 125: bonnet + 126: book + 127: bookcase + 128: booklet/brochure/leaflet/pamphlet + 129: bookmark/bookmarker + 130: boom microphone/microphone boom + 131: boot + 132: bottle + 133: bottle opener + 134: bouquet + 135: bow/bow weapon + 136: bow/bow decorative ribbons + 137: bow-tie/bowtie + 138: bowl + 139: pipe bowl + 140: bowler hat/bowler/derby hat/derby/plug hat + 141: bowling ball + 142: box + 143: boxing glove + 144: suspenders + 145: bracelet/bangle + 146: brass plaque + 147: brassiere/bra/bandeau + 148: bread-bin/breadbox + 149: bread + 150: breechcloth/breechclout/loincloth + 151: bridal gown/wedding gown/wedding dress + 152: briefcase + 153: broccoli + 154: broach + 155: broom + 156: brownie + 157: brussels sprouts + 158: bubble gum + 159: bucket/pail + 160: horse buggy + 161: horned cow + 162: bulldog + 163: bulldozer/dozer + 164: bullet train + 165: bulletin board/notice board + 166: bulletproof vest + 167: bullhorn/megaphone + 168: bun/roll + 169: bunk bed + 170: buoy + 171: 
burrito + 172: bus/bus vehicle/autobus/charabanc/double-decker/motorbus/motorcoach + 173: business card + 174: butter + 175: butterfly + 176: button + 177: cab/cab taxi/taxi/taxicab + 178: cabana + 179: cabin car/caboose + 180: cabinet + 181: locker/storage locker + 182: cake + 183: calculator + 184: calendar + 185: calf + 186: camcorder + 187: camel + 188: camera + 189: camera lens + 190: camper/camper vehicle/camping bus/motor home + 191: can/tin can + 192: can opener/tin opener + 193: candle/candlestick + 194: candle holder + 195: candy bar + 196: candy cane + 197: walking cane + 198: canister/canister + 199: canoe + 200: cantaloup/cantaloupe + 201: canteen + 202: cap/cap headwear + 203: bottle cap/cap/cap container lid + 204: cape + 205: cappuccino/coffee cappuccino + 206: car/car automobile/auto/auto automobile/automobile + 207: railcar/railcar part of a train/railway car/railway car part of a train/railroad car/railroad car part of a train + 208: elevator car + 209: car battery/automobile battery + 210: identity card + 211: card + 212: cardigan + 213: cargo ship/cargo vessel + 214: carnation + 215: horse carriage + 216: carrot + 217: tote bag + 218: cart + 219: carton + 220: cash register/register/register for cash transactions + 221: casserole + 222: cassette + 223: cast/plaster cast/plaster bandage + 224: cat + 225: cauliflower + 226: cayenne/cayenne spice/cayenne pepper/cayenne pepper spice/red pepper/red pepper spice + 227: CD player + 228: celery + 229: cellular telephone/cellular phone/cellphone/mobile phone/smart phone + 230: chain mail/ring mail/chain armor/chain armour/ring armor/ring armour + 231: chair + 232: chaise longue/chaise/daybed + 233: chalice + 234: chandelier + 235: chap + 236: checkbook/chequebook + 237: checkerboard + 238: cherry + 239: chessboard + 240: chicken/chicken animal + 241: chickpea/garbanzo + 242: chili/chili vegetable/chili pepper/chili pepper vegetable/chilli/chilli vegetable/chilly/chilly vegetable/chile/chile vegetable + 
243: chime/gong + 244: chinaware + 245: crisp/crisp potato chip/potato chip + 246: poker chip + 247: chocolate bar + 248: chocolate cake + 249: chocolate milk + 250: chocolate mousse + 251: choker/collar/neckband + 252: chopping board/cutting board/chopping block + 253: chopstick + 254: Christmas tree + 255: slide + 256: cider/cyder + 257: cigar box + 258: cigarette + 259: cigarette case/cigarette pack + 260: cistern/water tank + 261: clarinet + 262: clasp + 263: cleansing agent/cleanser/cleaner + 264: cleat/cleat for securing rope + 265: clementine + 266: clip + 267: clipboard + 268: clippers/clippers for plants + 269: cloak + 270: clock/timepiece/timekeeper + 271: clock tower + 272: clothes hamper/laundry basket/clothes basket + 273: clothespin/clothes peg + 274: clutch bag + 275: coaster + 276: coat + 277: coat hanger/clothes hanger/dress hanger + 278: coatrack/hatrack + 279: cock/rooster + 280: cockroach + 281: cocoa/cocoa beverage/hot chocolate/hot chocolate beverage/drinking chocolate + 282: coconut/cocoanut + 283: coffee maker/coffee machine + 284: coffee table/cocktail table + 285: coffeepot + 286: coil + 287: coin + 288: colander/cullender + 289: coleslaw/slaw + 290: coloring material/colouring material + 291: combination lock + 292: pacifier/teething ring + 293: comic book + 294: compass + 295: computer keyboard/keyboard/keyboard computer + 296: condiment + 297: cone/traffic cone + 298: control/controller + 299: convertible/convertible automobile + 300: sofa bed + 301: cooker + 302: cookie/cooky/biscuit/biscuit cookie + 303: cooking utensil + 304: cooler/cooler for food/ice chest + 305: cork/cork bottle plug/bottle cork + 306: corkboard + 307: corkscrew/bottle screw + 308: edible corn/corn/maize + 309: cornbread + 310: cornet/horn/trumpet + 311: cornice/valance/valance board/pelmet + 312: cornmeal + 313: corset/girdle + 314: costume + 315: cougar/puma/catamount/mountain lion/panther + 316: coverall + 317: cowbell + 318: cowboy hat/ten-gallon hat + 319: 
crab/crab animal + 320: crabmeat + 321: cracker + 322: crape/crepe/French pancake + 323: crate + 324: crayon/wax crayon + 325: cream pitcher + 326: crescent roll/croissant + 327: crib/cot + 328: crock pot/earthenware jar + 329: crossbar + 330: crouton + 331: crow + 332: crowbar/wrecking bar/pry bar + 333: crown + 334: crucifix + 335: cruise ship/cruise liner + 336: police cruiser/patrol car/police car/squad car + 337: crumb + 338: crutch + 339: cub/cub animal + 340: cube/square block + 341: cucumber/cuke + 342: cufflink + 343: cup + 344: trophy cup + 345: cupboard/closet + 346: cupcake + 347: hair curler/hair roller/hair crimper + 348: curling iron + 349: curtain/drapery + 350: cushion + 351: cylinder + 352: cymbal + 353: dagger + 354: dalmatian + 355: dartboard + 356: date/date fruit + 357: deck chair/beach chair + 358: deer/cervid + 359: dental floss/floss + 360: desk + 361: detergent + 362: diaper + 363: diary/journal + 364: die/dice + 365: dinghy/dory/rowboat + 366: dining table + 367: tux/tuxedo + 368: dish + 369: dish antenna + 370: dishrag/dishcloth + 371: dishtowel/tea towel + 372: dishwasher/dishwashing machine + 373: dishwasher detergent/dishwashing detergent/dishwashing liquid/dishsoap + 374: dispenser + 375: diving board + 376: Dixie cup/paper cup + 377: dog + 378: dog collar + 379: doll + 380: dollar/dollar bill/one dollar bill + 381: dollhouse/doll's house + 382: dolphin + 383: domestic ass/donkey + 384: doorknob/doorhandle + 385: doormat/welcome mat + 386: doughnut/donut + 387: dove + 388: dragonfly + 389: drawer + 390: underdrawers/boxers/boxershorts + 391: dress/frock + 392: dress hat/high hat/opera hat/silk hat/top hat + 393: dress suit + 394: dresser + 395: drill + 396: drone + 397: dropper/eye dropper + 398: drum/drum musical instrument + 399: drumstick + 400: duck + 401: duckling + 402: duct tape + 403: duffel bag/duffle bag/duffel/duffle + 404: dumbbell + 405: dumpster + 406: dustpan + 407: eagle + 408: earphone/earpiece/headphone + 409: 
earplug + 410: earring + 411: easel + 412: eclair + 413: eel + 414: egg/eggs + 415: egg roll/spring roll + 416: egg yolk/yolk/yolk egg + 417: eggbeater/eggwhisk + 418: eggplant/aubergine + 419: electric chair + 420: refrigerator + 421: elephant + 422: elk/moose + 423: envelope + 424: eraser + 425: escargot + 426: eyepatch + 427: falcon + 428: fan + 429: faucet/spigot/tap + 430: fedora + 431: ferret + 432: Ferris wheel + 433: ferry/ferryboat + 434: fig/fig fruit + 435: fighter jet/fighter aircraft/attack aircraft + 436: figurine + 437: file cabinet/filing cabinet + 438: file/file tool + 439: fire alarm/smoke alarm + 440: fire engine/fire truck + 441: fire extinguisher/extinguisher + 442: fire hose + 443: fireplace + 444: fireplug/fire hydrant/hydrant + 445: first-aid kit + 446: fish + 447: fish/fish food + 448: fishbowl/goldfish bowl + 449: fishing rod/fishing pole + 450: flag + 451: flagpole/flagstaff + 452: flamingo + 453: flannel + 454: flap + 455: flash/flashbulb + 456: flashlight/torch + 457: fleece + 458: flip-flop/flip-flop sandal + 459: flipper/flipper footwear/fin/fin footwear + 460: flower arrangement/floral arrangement + 461: flute glass/champagne flute + 462: foal + 463: folding chair + 464: food processor + 465: football/football American + 466: football helmet + 467: footstool/footrest + 468: fork + 469: forklift + 470: freight car + 471: French toast + 472: freshener/air freshener + 473: frisbee + 474: frog/toad/toad frog + 475: fruit juice + 476: frying pan/frypan/skillet + 477: fudge + 478: funnel + 479: futon + 480: gag/muzzle + 481: garbage + 482: garbage truck + 483: garden hose + 484: gargle/mouthwash + 485: gargoyle + 486: garlic/ail + 487: gasmask/respirator/gas helmet + 488: gazelle + 489: gelatin/jelly + 490: gemstone + 491: generator + 492: giant panda/panda/panda bear + 493: gift wrap + 494: ginger/gingerroot + 495: giraffe + 496: cincture/sash/waistband/waistcloth + 497: glass/glass drink container/drinking glass + 498: globe + 499: glove 
+ 500: goat + 501: goggles + 502: goldfish + 503: golf club/golf-club + 504: golfcart + 505: gondola/gondola boat + 506: goose + 507: gorilla + 508: gourd + 509: grape + 510: grater + 511: gravestone/headstone/tombstone + 512: gravy boat/gravy holder + 513: green bean + 514: green onion/spring onion/scallion + 515: griddle + 516: grill/grille/grillwork/radiator grille + 517: grits/hominy grits + 518: grizzly/grizzly bear + 519: grocery bag + 520: guitar + 521: gull/seagull + 522: gun + 523: hairbrush + 524: hairnet + 525: hairpin + 526: halter top + 527: ham/jambon/gammon + 528: hamburger/beefburger/burger + 529: hammer + 530: hammock + 531: hamper + 532: hamster + 533: hair dryer + 534: hand glass/hand mirror + 535: hand towel/face towel + 536: handcart/pushcart/hand truck + 537: handcuff + 538: handkerchief + 539: handle/grip/handgrip + 540: handsaw/carpenter's saw + 541: hardback book/hardcover book + 542: harmonium/organ/organ musical instrument/reed organ/reed organ musical instrument + 543: hat + 544: hatbox + 545: veil + 546: headband + 547: headboard + 548: headlight/headlamp + 549: headscarf + 550: headset + 551: headstall/headstall for horses/headpiece/headpiece for horses + 552: heart + 553: heater/warmer + 554: helicopter + 555: helmet + 556: heron + 557: highchair/feeding chair + 558: hinge + 559: hippopotamus + 560: hockey stick + 561: hog/pig + 562: home plate/home plate baseball/home base/home base baseball + 563: honey + 564: fume hood/exhaust hood + 565: hook + 566: hookah/narghile/nargileh/sheesha/shisha/water pipe + 567: hornet + 568: horse + 569: hose/hosepipe + 570: hot-air balloon + 571: hotplate + 572: hot sauce + 573: hourglass + 574: houseboat + 575: hummingbird + 576: hummus/humus/hommos/hoummos/humous + 577: polar bear + 578: icecream + 579: popsicle + 580: ice maker + 581: ice pack/ice bag + 582: ice skate + 583: igniter/ignitor/lighter + 584: inhaler/inhalator + 585: iPod + 586: iron/iron for clothing/smoothing iron/smoothing iron for 
clothing + 587: ironing board + 588: jacket + 589: jam + 590: jar + 591: jean/blue jean/denim + 592: jeep/landrover + 593: jelly bean/jelly egg + 594: jersey/T-shirt/tee shirt + 595: jet plane/jet-propelled plane + 596: jewel/gem/precious stone + 597: jewelry/jewellery + 598: joystick + 599: jumpsuit + 600: kayak + 601: keg + 602: kennel/doghouse + 603: kettle/boiler + 604: key + 605: keycard + 606: kilt + 607: kimono + 608: kitchen sink + 609: kitchen table + 610: kite + 611: kitten/kitty + 612: kiwi fruit + 613: knee pad + 614: knife + 615: knitting needle + 616: knob + 617: knocker/knocker on a door/doorknocker + 618: koala/koala bear + 619: lab coat/laboratory coat + 620: ladder + 621: ladle + 622: ladybug/ladybeetle/ladybird beetle + 623: lamb/lamb animal + 624: lamb-chop/lambchop + 625: lamp + 626: lamppost + 627: lampshade + 628: lantern + 629: lanyard/laniard + 630: laptop computer/notebook computer + 631: lasagna/lasagne + 632: latch + 633: lawn mower + 634: leather + 635: legging/legging clothing/leging/leging clothing/leg covering + 636: Lego/Lego set + 637: legume + 638: lemon + 639: lemonade + 640: lettuce + 641: license plate/numberplate + 642: life buoy/lifesaver/life belt/life ring + 643: life jacket/life vest + 644: lightbulb + 645: lightning rod/lightning conductor + 646: lime + 647: limousine + 648: lion + 649: lip balm + 650: liquor/spirits/hard liquor/liqueur/cordial + 651: lizard + 652: log + 653: lollipop + 654: speaker/speaker stereo equipment + 655: loveseat + 656: machine gun + 657: magazine + 658: magnet + 659: mail slot + 660: mailbox/mailbox at home/letter box/letter box at home + 661: mallard + 662: mallet + 663: mammoth + 664: manatee + 665: mandarin orange + 666: manger/trough + 667: manhole + 668: map + 669: marker + 670: martini + 671: mascot + 672: mashed potato + 673: masher + 674: mask/facemask + 675: mast + 676: mat/mat gym equipment/gym mat + 677: matchbox + 678: mattress + 679: measuring cup + 680: measuring 
stick/ruler/ruler measuring stick/measuring rod + 681: meatball + 682: medicine + 683: melon + 684: microphone + 685: microscope + 686: microwave oven + 687: milestone/milepost + 688: milk + 689: milk can + 690: milkshake + 691: minivan + 692: mint candy + 693: mirror + 694: mitten + 695: mixer/mixer kitchen tool/stand mixer + 696: money + 697: monitor/monitor computer equipment + 698: monkey + 699: motor + 700: motor scooter/scooter + 701: motor vehicle/automotive vehicle + 702: motorcycle + 703: mound/mound baseball/pitcher's mound + 704: mouse/mouse computer equipment/computer mouse + 705: mousepad + 706: muffin + 707: mug + 708: mushroom + 709: music stool/piano stool + 710: musical instrument/instrument/instrument musical + 711: nailfile + 712: napkin/table napkin/serviette + 713: neckerchief + 714: necklace + 715: necktie/tie/tie necktie + 716: needle + 717: nest + 718: newspaper/paper/paper newspaper + 719: newsstand + 720: nightshirt/nightwear/sleepwear/nightclothes + 721: nosebag/nosebag for animals/feedbag + 722: noseband/noseband for animals/nosepiece/nosepiece for animals + 723: notebook + 724: notepad + 725: nut + 726: nutcracker + 727: oar + 728: octopus/octopus food + 729: octopus/octopus animal + 730: oil lamp/kerosene lamp/kerosine lamp + 731: olive oil + 732: omelet/omelette + 733: onion + 734: orange/orange fruit + 735: orange juice + 736: ostrich + 737: ottoman/pouf/pouffe/hassock + 738: oven + 739: overalls/overalls clothing + 740: owl + 741: packet + 742: inkpad/inking pad/stamp pad + 743: pad + 744: paddle/boat paddle + 745: padlock + 746: paintbrush + 747: painting + 748: pajamas/pyjamas + 749: palette/pallet + 750: pan/pan for cooking/cooking pan + 751: pan/pan metal container + 752: pancake + 753: pantyhose + 754: papaya + 755: paper plate + 756: paper towel + 757: paperback book/paper-back book/softback book/soft-cover book + 758: paperweight + 759: parachute + 760: parakeet/parrakeet/parroket/paraquet/paroquet/parroquet + 761: 
parasail/parasail sports + 762: parasol/sunshade + 763: parchment + 764: parka/anorak + 765: parking meter + 766: parrot + 767: passenger car/passenger car part of a train/coach/coach part of a train + 768: passenger ship + 769: passport + 770: pastry + 771: patty/patty food + 772: pea/pea food + 773: peach + 774: peanut butter + 775: pear + 776: peeler/peeler tool for fruit and vegetables + 777: wooden leg/pegleg + 778: pegboard + 779: pelican + 780: pen + 781: pencil + 782: pencil box/pencil case + 783: pencil sharpener + 784: pendulum + 785: penguin + 786: pennant + 787: penny/penny coin + 788: pepper/peppercorn + 789: pepper mill/pepper grinder + 790: perfume + 791: persimmon + 792: person/baby/child/boy/girl/man/woman/human + 793: pet + 794: pew/pew church bench/church bench + 795: phonebook/telephone book/telephone directory + 796: phonograph record/phonograph recording/record/record phonograph recording + 797: piano + 798: pickle + 799: pickup truck + 800: pie + 801: pigeon + 802: piggy bank/penny bank + 803: pillow + 804: pin/pin non jewelry + 805: pineapple + 806: pinecone + 807: ping-pong ball + 808: pinwheel + 809: tobacco pipe + 810: pipe/piping + 811: pistol/handgun + 812: pita/pita bread/pocket bread + 813: pitcher/pitcher vessel for liquid/ewer + 814: pitchfork + 815: pizza + 816: place mat + 817: plate + 818: platter + 819: playpen + 820: pliers/plyers + 821: plow/plow farm equipment/plough/plough farm equipment + 822: plume + 823: pocket watch + 824: pocketknife + 825: poker/poker fire stirring tool/stove poker/fire hook + 826: pole/post + 827: polo shirt/sport shirt + 828: poncho + 829: pony + 830: pool table/billiard table/snooker table + 831: pop/pop soda/soda/soda pop/tonic/soft drink + 832: postbox/postbox public/mailbox/mailbox public + 833: postcard/postal card/mailing-card + 834: poster/placard + 835: pot + 836: flowerpot + 837: potato + 838: potholder + 839: pottery/clayware + 840: pouch + 841: power shovel/excavator/digger + 842: 
prawn/shrimp + 843: pretzel + 844: printer/printing machine + 845: projectile/projectile weapon/missile + 846: projector + 847: propeller/propellor + 848: prune + 849: pudding + 850: puffer/puffer fish/pufferfish/blowfish/globefish + 851: puffin + 852: pug-dog + 853: pumpkin + 854: puncher + 855: puppet/marionette + 856: puppy + 857: quesadilla + 858: quiche + 859: quilt/comforter + 860: rabbit + 861: race car/racing car + 862: racket/racquet + 863: radar + 864: radiator + 865: radio receiver/radio set/radio/tuner/tuner radio + 866: radish/daikon + 867: raft + 868: rag doll + 869: raincoat/waterproof jacket + 870: ram/ram animal + 871: raspberry + 872: rat + 873: razorblade + 874: reamer/reamer juicer/juicer/juice reamer + 875: rearview mirror + 876: receipt + 877: recliner/reclining chair/lounger/lounger chair + 878: record player/phonograph/phonograph record player/turntable + 879: reflector + 880: remote control + 881: rhinoceros + 882: rib/rib food + 883: rifle + 884: ring + 885: river boat + 886: road map + 887: robe + 888: rocking chair + 889: rodent + 890: roller skate + 891: Rollerblade + 892: rolling pin + 893: root beer + 894: router/router computer equipment + 895: rubber band/elastic band + 896: runner/runner carpet + 897: plastic bag/paper bag + 898: saddle/saddle on an animal + 899: saddle blanket/saddlecloth/horse blanket + 900: saddlebag + 901: safety pin + 902: sail + 903: salad + 904: salad plate/salad bowl + 905: salami + 906: salmon/salmon fish + 907: salmon/salmon food + 908: salsa + 909: saltshaker + 910: sandal/sandal type of shoe + 911: sandwich + 912: satchel + 913: saucepan + 914: saucer + 915: sausage + 916: sawhorse/sawbuck + 917: saxophone + 918: scale/scale measuring instrument + 919: scarecrow/strawman + 920: scarf + 921: school bus + 922: scissors + 923: scoreboard + 924: scraper + 925: screwdriver + 926: scrubbing brush + 927: sculpture + 928: seabird/seafowl + 929: seahorse + 930: seaplane/hydroplane + 931: seashell + 932: sewing 
machine + 933: shaker + 934: shampoo + 935: shark + 936: sharpener + 937: Sharpie + 938: shaver/shaver electric/electric shaver/electric razor + 939: shaving cream/shaving soap + 940: shawl + 941: shears + 942: sheep + 943: shepherd dog/sheepdog + 944: sherbert/sherbet + 945: shield + 946: shirt + 947: shoe/sneaker/sneaker type of shoe/tennis shoe + 948: shopping bag + 949: shopping cart + 950: short pants/shorts/shorts clothing/trunks/trunks clothing + 951: shot glass + 952: shoulder bag + 953: shovel + 954: shower head + 955: shower cap + 956: shower curtain + 957: shredder/shredder for paper + 958: signboard + 959: silo + 960: sink + 961: skateboard + 962: skewer + 963: ski + 964: ski boot + 965: ski parka/ski jacket + 966: ski pole + 967: skirt + 968: skullcap + 969: sled/sledge/sleigh + 970: sleeping bag + 971: sling/sling bandage/triangular bandage + 972: slipper/slipper footwear/carpet slipper/carpet slipper footwear + 973: smoothie + 974: snake/serpent + 975: snowboard + 976: snowman + 977: snowmobile + 978: soap + 979: soccer ball + 980: sock + 981: sofa/couch/lounge + 982: softball + 983: solar array/solar battery/solar panel + 984: sombrero + 985: soup + 986: soup bowl + 987: soupspoon + 988: sour cream/soured cream + 989: soya milk/soybean milk/soymilk + 990: space shuttle + 991: sparkler/sparkler fireworks + 992: spatula + 993: spear/lance + 994: spectacles/specs/eyeglasses/glasses + 995: spice rack + 996: spider + 997: crawfish/crayfish + 998: sponge + 999: spoon + 1000: sportswear/athletic wear/activewear + 1001: spotlight + 1002: squid/squid food/calamari/calamary + 1003: squirrel + 1004: stagecoach + 1005: stapler/stapler stapling machine + 1006: starfish/sea star + 1007: statue/statue sculpture + 1008: steak/steak food + 1009: steak knife + 1010: steering wheel + 1011: stepladder + 1012: step stool + 1013: stereo/stereo sound system + 1014: stew + 1015: stirrer + 1016: stirrup + 1017: stool + 1018: stop sign + 1019: brake light + 1020: 
stove/kitchen stove/range/range kitchen appliance/kitchen range/cooking stove + 1021: strainer + 1022: strap + 1023: straw/straw for drinking/drinking straw + 1024: strawberry + 1025: street sign + 1026: streetlight/street lamp + 1027: string cheese + 1028: stylus + 1029: subwoofer + 1030: sugar bowl + 1031: sugarcane/sugarcane plant + 1032: suit/suit clothing + 1033: sunflower + 1034: sunglasses + 1035: sunhat + 1036: surfboard + 1037: sushi + 1038: mop + 1039: sweat pants + 1040: sweatband + 1041: sweater + 1042: sweatshirt + 1043: sweet potato + 1044: swimsuit/swimwear/bathing suit/swimming costume/bathing costume/swimming trunks/bathing trunks + 1045: sword + 1046: syringe + 1047: Tabasco sauce + 1048: table-tennis table/ping-pong table + 1049: table + 1050: table lamp + 1051: tablecloth + 1052: tachometer + 1053: taco + 1054: tag + 1055: taillight/rear light + 1056: tambourine + 1057: army tank/armored combat vehicle/armoured combat vehicle + 1058: tank/tank storage vessel/storage tank + 1059: tank top/tank top clothing + 1060: tape/tape sticky cloth or paper + 1061: tape measure/measuring tape + 1062: tapestry + 1063: tarp + 1064: tartan/plaid + 1065: tassel + 1066: tea bag + 1067: teacup + 1068: teakettle + 1069: teapot + 1070: teddy bear + 1071: telephone/phone/telephone set + 1072: telephone booth/phone booth/call box/telephone box/telephone kiosk + 1073: telephone pole/telegraph pole/telegraph post + 1074: telephoto lens/zoom lens + 1075: television camera/tv camera + 1076: television set/tv/tv set + 1077: tennis ball + 1078: tennis racket + 1079: tequila + 1080: thermometer + 1081: thermos bottle + 1082: thermostat + 1083: thimble + 1084: thread/yarn + 1085: thumbtack/drawing pin/pushpin + 1086: tiara + 1087: tiger + 1088: tights/tights clothing/leotards + 1089: timer/stopwatch + 1090: tinfoil + 1091: tinsel + 1092: tissue paper + 1093: toast/toast food + 1094: toaster + 1095: toaster oven + 1096: toilet + 1097: toilet tissue/toilet paper/bathroom tissue 
+ 1098: tomato + 1099: tongs + 1100: toolbox + 1101: toothbrush + 1102: toothpaste + 1103: toothpick + 1104: cover + 1105: tortilla + 1106: tow truck + 1107: towel + 1108: towel rack/towel rail/towel bar + 1109: toy + 1110: tractor/tractor farm equipment + 1111: traffic light + 1112: dirt bike + 1113: trailer truck/tractor trailer/trucking rig/articulated lorry/semi truck + 1114: train/train railroad vehicle/railroad train + 1115: trampoline + 1116: tray + 1117: trench coat + 1118: triangle/triangle musical instrument + 1119: tricycle + 1120: tripod + 1121: trousers/pants/pants clothing + 1122: truck + 1123: truffle/truffle chocolate/chocolate truffle + 1124: trunk + 1125: vat + 1126: turban + 1127: turkey/turkey food + 1128: turnip + 1129: turtle + 1130: turtleneck/turtleneck clothing/polo-neck + 1131: typewriter + 1132: umbrella + 1133: underwear/underclothes/underclothing/underpants + 1134: unicycle + 1135: urinal + 1136: urn + 1137: vacuum cleaner + 1138: vase + 1139: vending machine + 1140: vent/blowhole/air vent + 1141: vest/waistcoat + 1142: videotape + 1143: vinegar + 1144: violin/fiddle + 1145: vodka + 1146: volleyball + 1147: vulture + 1148: waffle + 1149: waffle iron + 1150: wagon + 1151: wagon wheel + 1152: walking stick + 1153: wall clock + 1154: wall socket/wall plug/electric outlet/electrical outlet/outlet/electric receptacle + 1155: wallet/billfold + 1156: walrus + 1157: wardrobe + 1158: washbasin/basin/basin for washing/washbowl/washstand/handbasin + 1159: automatic washer/washing machine + 1160: watch/wristwatch + 1161: water bottle + 1162: water cooler + 1163: water faucet/water tap/tap/tap water faucet + 1164: water heater/hot-water heater + 1165: water jug + 1166: water gun/squirt gun + 1167: water scooter/sea scooter/jet ski + 1168: water ski + 1169: water tower + 1170: watering can + 1171: watermelon + 1172: weathervane/vane/vane weathervane/wind vane + 1173: webcam + 1174: wedding cake/bridecake + 1175: wedding ring/wedding band + 1176: wet 
suit + 1177: wheel + 1178: wheelchair + 1179: whipped cream + 1180: whistle + 1181: wig + 1182: wind chime + 1183: windmill + 1184: window box/window box for plants + 1185: windshield wiper/windscreen wiper/wiper/wiper for windshield or screen + 1186: windsock/air sock/air-sleeve/wind sleeve/wind cone + 1187: wine bottle + 1188: wine bucket/wine cooler + 1189: wineglass + 1190: blinder/blinder for horses + 1191: wok + 1192: wolf + 1193: wooden spoon + 1194: wreath + 1195: wrench/spanner + 1196: wristband + 1197: wristlet/wrist band + 1198: yacht + 1199: yogurt/yoghurt/yoghourt + 1200: yoke/yoke animal equipment + 1201: zebra + 1202: zucchini/courgette + +# Download script/URL (optional) +download: | + from pathlib import Path + + from ultralytics.utils.downloads import download + + # Download labels + dir = Path(yaml["path"]) # dataset root dir + url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/" + urls = [f"{url}lvis-labels-segments.zip"] + download(urls, dir=dir.parent) + + # Download data + urls = [ + "http://images.cocodataset.org/zips/train2017.zip", # 19G, 118k images + "http://images.cocodataset.org/zips/val2017.zip", # 1G, 5k images + "http://images.cocodataset.org/zips/test2017.zip", # 7G, 41k images (optional) + ] + download(urls, dir=dir / "images", threads=3) diff --git a/tracking/ultralytics/cfg/datasets/medical-pills.yaml b/tracking/ultralytics/cfg/datasets/medical-pills.yaml new file mode 100644 index 0000000000000000000000000000000000000000..25507c8b9bef023c866aa6d6d41f8a3a6bf0958a --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/medical-pills.yaml @@ -0,0 +1,22 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Medical-pills dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/detect/medical-pills/ +# Example usage: yolo train data=medical-pills.yaml +# parent +# ├── ultralytics +# └── datasets +# └── medical-pills ← downloads here (8.19 MB) + +# Train/val/test sets as 1) 
dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/medical-pills # dataset root dir +train: train/images # train images (relative to 'path') 92 images +val: valid/images # val images (relative to 'path') 23 images +test: # test images (relative to 'path') + +# Classes +names: + 0: pill + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/medical-pills.zip diff --git a/tracking/ultralytics/cfg/datasets/open-images-v7.yaml b/tracking/ultralytics/cfg/datasets/open-images-v7.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7e3439f56d31fe06024c964436fccf2525091cef --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/open-images-v7.yaml @@ -0,0 +1,666 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Open Images v7 dataset https://storage.googleapis.com/openimages/web/index.html by Google +# Documentation: https://docs.ultralytics.com/datasets/detect/open-images-v7/ +# Example usage: yolo train data=open-images-v7.yaml +# parent +# ├── ultralytics +# └── datasets +# └── open-images-v7 ← downloads here (561 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/open-images-v7 # dataset root dir +train: images/train # train images (relative to 'path') 1743042 images +val: images/val # val images (relative to 'path') 41620 images +test: # test images (optional) + +# Classes +names: + 0: Accordion + 1: Adhesive tape + 2: Aircraft + 3: Airplane + 4: Alarm clock + 5: Alpaca + 6: Ambulance + 7: Animal + 8: Ant + 9: Antelope + 10: Apple + 11: Armadillo + 12: Artichoke + 13: Auto part + 14: Axe + 15: Backpack + 16: Bagel + 17: Baked goods + 18: Balance beam + 19: Ball + 20: Balloon + 21: Banana + 22: Band-aid + 23: Banjo + 24: Barge + 25: Barrel + 26: Baseball bat + 27: Baseball glove + 28: Bat (Animal) + 29: Bathroom accessory + 30: Bathroom cabinet + 31: Bathtub + 32: Beaker + 33: Bear + 34: Bed + 35: Bee + 36: Beehive + 37: Beer + 38: Beetle + 39: Bell pepper + 40: Belt + 41: Bench + 42: Bicycle + 43: Bicycle helmet + 44: Bicycle wheel + 45: Bidet + 46: Billboard + 47: Billiard table + 48: Binoculars + 49: Bird + 50: Blender + 51: Blue jay + 52: Boat + 53: Bomb + 54: Book + 55: Bookcase + 56: Boot + 57: Bottle + 58: Bottle opener + 59: Bow and arrow + 60: Bowl + 61: Bowling equipment + 62: Box + 63: Boy + 64: Brassiere + 65: Bread + 66: Briefcase + 67: Broccoli + 68: Bronze sculpture + 69: Brown bear + 70: Building + 71: Bull + 72: Burrito + 73: Bus + 74: Bust + 75: Butterfly + 76: Cabbage + 77: Cabinetry + 78: Cake + 79: Cake stand + 80: Calculator + 81: Camel + 82: Camera + 83: Can opener + 84: Canary + 85: Candle + 86: Candy + 87: Cannon + 88: Canoe + 89: Cantaloupe + 90: Car + 91: Carnivore + 92: Carrot + 93: Cart + 94: Cassette deck + 95: Castle + 96: Cat + 97: Cat furniture + 98: Caterpillar + 99: Cattle + 100: Ceiling fan + 101: Cello + 102: Centipede + 103: Chainsaw + 104: Chair + 105: Cheese + 106: Cheetah + 107: Chest of drawers + 108: Chicken + 109: Chime + 110: Chisel + 111: Chopsticks + 112: Christmas tree + 113: Clock + 114: Closet + 115: Clothing + 116: Coat + 117: Cocktail + 118: Cocktail 
shaker + 119: Coconut + 120: Coffee + 121: Coffee cup + 122: Coffee table + 123: Coffeemaker + 124: Coin + 125: Common fig + 126: Common sunflower + 127: Computer keyboard + 128: Computer monitor + 129: Computer mouse + 130: Container + 131: Convenience store + 132: Cookie + 133: Cooking spray + 134: Corded phone + 135: Cosmetics + 136: Couch + 137: Countertop + 138: Cowboy hat + 139: Crab + 140: Cream + 141: Cricket ball + 142: Crocodile + 143: Croissant + 144: Crown + 145: Crutch + 146: Cucumber + 147: Cupboard + 148: Curtain + 149: Cutting board + 150: Dagger + 151: Dairy Product + 152: Deer + 153: Desk + 154: Dessert + 155: Diaper + 156: Dice + 157: Digital clock + 158: Dinosaur + 159: Dishwasher + 160: Dog + 161: Dog bed + 162: Doll + 163: Dolphin + 164: Door + 165: Door handle + 166: Doughnut + 167: Dragonfly + 168: Drawer + 169: Dress + 170: Drill (Tool) + 171: Drink + 172: Drinking straw + 173: Drum + 174: Duck + 175: Dumbbell + 176: Eagle + 177: Earrings + 178: Egg (Food) + 179: Elephant + 180: Envelope + 181: Eraser + 182: Face powder + 183: Facial tissue holder + 184: Falcon + 185: Fashion accessory + 186: Fast food + 187: Fax + 188: Fedora + 189: Filing cabinet + 190: Fire hydrant + 191: Fireplace + 192: Fish + 193: Flag + 194: Flashlight + 195: Flower + 196: Flowerpot + 197: Flute + 198: Flying disc + 199: Food + 200: Food processor + 201: Football + 202: Football helmet + 203: Footwear + 204: Fork + 205: Fountain + 206: Fox + 207: French fries + 208: French horn + 209: Frog + 210: Fruit + 211: Frying pan + 212: Furniture + 213: Garden Asparagus + 214: Gas stove + 215: Giraffe + 216: Girl + 217: Glasses + 218: Glove + 219: Goat + 220: Goggles + 221: Goldfish + 222: Golf ball + 223: Golf cart + 224: Gondola + 225: Goose + 226: Grape + 227: Grapefruit + 228: Grinder + 229: Guacamole + 230: Guitar + 231: Hair dryer + 232: Hair spray + 233: Hamburger + 234: Hammer + 235: Hamster + 236: Hand dryer + 237: Handbag + 238: Handgun + 239: Harbor seal + 240: 
Harmonica + 241: Harp + 242: Harpsichord + 243: Hat + 244: Headphones + 245: Heater + 246: Hedgehog + 247: Helicopter + 248: Helmet + 249: High heels + 250: Hiking equipment + 251: Hippopotamus + 252: Home appliance + 253: Honeycomb + 254: Horizontal bar + 255: Horse + 256: Hot dog + 257: House + 258: Houseplant + 259: Human arm + 260: Human beard + 261: Human body + 262: Human ear + 263: Human eye + 264: Human face + 265: Human foot + 266: Human hair + 267: Human hand + 268: Human head + 269: Human leg + 270: Human mouth + 271: Human nose + 272: Humidifier + 273: Ice cream + 274: Indoor rower + 275: Infant bed + 276: Insect + 277: Invertebrate + 278: Ipod + 279: Isopod + 280: Jacket + 281: Jacuzzi + 282: Jaguar (Animal) + 283: Jeans + 284: Jellyfish + 285: Jet ski + 286: Jug + 287: Juice + 288: Kangaroo + 289: Kettle + 290: Kitchen & dining room table + 291: Kitchen appliance + 292: Kitchen knife + 293: Kitchen utensil + 294: Kitchenware + 295: Kite + 296: Knife + 297: Koala + 298: Ladder + 299: Ladle + 300: Ladybug + 301: Lamp + 302: Land vehicle + 303: Lantern + 304: Laptop + 305: Lavender (Plant) + 306: Lemon + 307: Leopard + 308: Light bulb + 309: Light switch + 310: Lighthouse + 311: Lily + 312: Limousine + 313: Lion + 314: Lipstick + 315: Lizard + 316: Lobster + 317: Loveseat + 318: Luggage and bags + 319: Lynx + 320: Magpie + 321: Mammal + 322: Man + 323: Mango + 324: Maple + 325: Maracas + 326: Marine invertebrates + 327: Marine mammal + 328: Measuring cup + 329: Mechanical fan + 330: Medical equipment + 331: Microphone + 332: Microwave oven + 333: Milk + 334: Miniskirt + 335: Mirror + 336: Missile + 337: Mixer + 338: Mixing bowl + 339: Mobile phone + 340: Monkey + 341: Moths and butterflies + 342: Motorcycle + 343: Mouse + 344: Muffin + 345: Mug + 346: Mule + 347: Mushroom + 348: Musical instrument + 349: Musical keyboard + 350: Nail (Construction) + 351: Necklace + 352: Nightstand + 353: Oboe + 354: Office building + 355: Office supplies + 356: Orange + 
357: Organ (Musical Instrument) + 358: Ostrich + 359: Otter + 360: Oven + 361: Owl + 362: Oyster + 363: Paddle + 364: Palm tree + 365: Pancake + 366: Panda + 367: Paper cutter + 368: Paper towel + 369: Parachute + 370: Parking meter + 371: Parrot + 372: Pasta + 373: Pastry + 374: Peach + 375: Pear + 376: Pen + 377: Pencil case + 378: Pencil sharpener + 379: Penguin + 380: Perfume + 381: Person + 382: Personal care + 383: Personal flotation device + 384: Piano + 385: Picnic basket + 386: Picture frame + 387: Pig + 388: Pillow + 389: Pineapple + 390: Pitcher (Container) + 391: Pizza + 392: Pizza cutter + 393: Plant + 394: Plastic bag + 395: Plate + 396: Platter + 397: Plumbing fixture + 398: Polar bear + 399: Pomegranate + 400: Popcorn + 401: Porch + 402: Porcupine + 403: Poster + 404: Potato + 405: Power plugs and sockets + 406: Pressure cooker + 407: Pretzel + 408: Printer + 409: Pumpkin + 410: Punching bag + 411: Rabbit + 412: Raccoon + 413: Racket + 414: Radish + 415: Ratchet (Device) + 416: Raven + 417: Rays and skates + 418: Red panda + 419: Refrigerator + 420: Remote control + 421: Reptile + 422: Rhinoceros + 423: Rifle + 424: Ring binder + 425: Rocket + 426: Roller skates + 427: Rose + 428: Rugby ball + 429: Ruler + 430: Salad + 431: Salt and pepper shakers + 432: Sandal + 433: Sandwich + 434: Saucer + 435: Saxophone + 436: Scale + 437: Scarf + 438: Scissors + 439: Scoreboard + 440: Scorpion + 441: Screwdriver + 442: Sculpture + 443: Sea lion + 444: Sea turtle + 445: Seafood + 446: Seahorse + 447: Seat belt + 448: Segway + 449: Serving tray + 450: Sewing machine + 451: Shark + 452: Sheep + 453: Shelf + 454: Shellfish + 455: Shirt + 456: Shorts + 457: Shotgun + 458: Shower + 459: Shrimp + 460: Sink + 461: Skateboard + 462: Ski + 463: Skirt + 464: Skull + 465: Skunk + 466: Skyscraper + 467: Slow cooker + 468: Snack + 469: Snail + 470: Snake + 471: Snowboard + 472: Snowman + 473: Snowmobile + 474: Snowplow + 475: Soap dispenser + 476: Sock + 477: Sofa bed + 478: 
Sombrero + 479: Sparrow + 480: Spatula + 481: Spice rack + 482: Spider + 483: Spoon + 484: Sports equipment + 485: Sports uniform + 486: Squash (Plant) + 487: Squid + 488: Squirrel + 489: Stairs + 490: Stapler + 491: Starfish + 492: Stationary bicycle + 493: Stethoscope + 494: Stool + 495: Stop sign + 496: Strawberry + 497: Street light + 498: Stretcher + 499: Studio couch + 500: Submarine + 501: Submarine sandwich + 502: Suit + 503: Suitcase + 504: Sun hat + 505: Sunglasses + 506: Surfboard + 507: Sushi + 508: Swan + 509: Swim cap + 510: Swimming pool + 511: Swimwear + 512: Sword + 513: Syringe + 514: Table + 515: Table tennis racket + 516: Tablet computer + 517: Tableware + 518: Taco + 519: Tank + 520: Tap + 521: Tart + 522: Taxi + 523: Tea + 524: Teapot + 525: Teddy bear + 526: Telephone + 527: Television + 528: Tennis ball + 529: Tennis racket + 530: Tent + 531: Tiara + 532: Tick + 533: Tie + 534: Tiger + 535: Tin can + 536: Tire + 537: Toaster + 538: Toilet + 539: Toilet paper + 540: Tomato + 541: Tool + 542: Toothbrush + 543: Torch + 544: Tortoise + 545: Towel + 546: Tower + 547: Toy + 548: Traffic light + 549: Traffic sign + 550: Train + 551: Training bench + 552: Treadmill + 553: Tree + 554: Tree house + 555: Tripod + 556: Trombone + 557: Trousers + 558: Truck + 559: Trumpet + 560: Turkey + 561: Turtle + 562: Umbrella + 563: Unicycle + 564: Van + 565: Vase + 566: Vegetable + 567: Vehicle + 568: Vehicle registration plate + 569: Violin + 570: Volleyball (Ball) + 571: Waffle + 572: Waffle iron + 573: Wall clock + 574: Wardrobe + 575: Washing machine + 576: Waste container + 577: Watch + 578: Watercraft + 579: Watermelon + 580: Weapon + 581: Whale + 582: Wheel + 583: Wheelchair + 584: Whisk + 585: Whiteboard + 586: Willow + 587: Window + 588: Window blind + 589: Wine + 590: Wine glass + 591: Wine rack + 592: Winter melon + 593: Wok + 594: Woman + 595: Wood-burning stove + 596: Woodpecker + 597: Worm + 598: Wrench + 599: Zebra + 600: Zucchini + +# Download 
script/URL (optional) --------------------------------------------------------------------------------------- +download: | + import warnings + + from ultralytics.utils import LOGGER, SETTINGS, Path, get_ubuntu_version, is_ubuntu + from ultralytics.utils.checks import check_requirements, check_version + + check_requirements("fiftyone") + if is_ubuntu() and check_version(get_ubuntu_version(), ">=22.04"): + # Ubuntu>=22.04 patch https://github.com/voxel51/fiftyone/issues/2961#issuecomment-1666519347 + check_requirements("fiftyone-db-ubuntu2204") + + import fiftyone as fo + import fiftyone.zoo as foz + + name = "open-images-v7" + fo.config.dataset_zoo_dir = Path(SETTINGS["datasets_dir"]) / "fiftyone" / name + fraction = 1.0 # fraction of full dataset to use + LOGGER.warning("WARNING ⚠️ Open Images V7 dataset requires at least **561 GB of free space. Starting download...") + for split in "train", "validation": # 1743042 train, 41620 val images + train = split == "train" + + # Load Open Images dataset + dataset = foz.load_zoo_dataset( + name, + split=split, + label_types=["detections"], + max_samples=round((1743042 if train else 41620) * fraction), + ) + + # Define classes + if train: + classes = dataset.default_classes # all classes + # classes = dataset.distinct('ground_truth.detections.label') # only observed classes + + # Export to YOLO format + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning, module="fiftyone.utils.yolo") + dataset.export( + export_dir=str(Path(SETTINGS["datasets_dir"]) / name), + dataset_type=fo.types.YOLOv5Dataset, + label_field="ground_truth", + split="val" if split == "validation" else split, + classes=classes, + overwrite=train, + ) diff --git a/tracking/ultralytics/cfg/datasets/package-seg.yaml b/tracking/ultralytics/cfg/datasets/package-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..433ca04c7fe58a6f7da6130acda50e335901fbb1 --- /dev/null +++ 
b/tracking/ultralytics/cfg/datasets/package-seg.yaml @@ -0,0 +1,22 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Package-seg dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/segment/package-seg/ +# Example usage: yolo train data=package-seg.yaml +# parent +# ├── ultralytics +# └── datasets +# └── package-seg ← downloads here (102 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/package-seg # dataset root dir +train: train/images # train images (relative to 'path') 1920 images +val: valid/images # val images (relative to 'path') 89 images +test: test/images # test images (relative to 'path') 188 images + +# Classes +names: + 0: package + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/package-seg.zip diff --git a/tracking/ultralytics/cfg/datasets/signature.yaml b/tracking/ultralytics/cfg/datasets/signature.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5c9d5c338e95b566cd8e6f56294e32dc6c5ab323 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/signature.yaml @@ -0,0 +1,21 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Signature dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/detect/signature/ +# Example usage: yolo train data=signature.yaml +# parent +# ├── ultralytics +# └── datasets +# └── signature ← downloads here (11.2 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/signature # dataset root dir +train: train/images # train images (relative to 'path') 143 images +val: valid/images # val images (relative to 'path') 35 images + +# Classes +names: + 0: signature + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/signature.zip diff --git a/tracking/ultralytics/cfg/datasets/tiger-pose.yaml b/tracking/ultralytics/cfg/datasets/tiger-pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2b3f7b71761e475cff2aa3c97c8301c848e0844f --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/tiger-pose.yaml @@ -0,0 +1,25 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Tiger Pose dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/pose/tiger-pose/ +# Example usage: yolo train data=tiger-pose.yaml +# parent +# ├── ultralytics +# └── datasets +# └── tiger-pose ← downloads here (75.3 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/tiger-pose # dataset root dir +train: train # train images (relative to 'path') 210 images +val: val # val images (relative to 'path') 53 images + +# Keypoints +kpt_shape: [12, 2] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +flip_idx: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + +# Classes +names: + 0: tiger + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/tiger-pose.zip diff --git a/tracking/ultralytics/cfg/datasets/xView.yaml b/tracking/ultralytics/cfg/datasets/xView.yaml new file mode 100644 index 0000000000000000000000000000000000000000..265a74341d9edbb8e05da983c471492e741cd783 --- /dev/null +++ b/tracking/ultralytics/cfg/datasets/xView.yaml @@ -0,0 +1,155 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# DIUx xView 2018 Challenge https://challenge.xviewdataset.org by U.S. National Geospatial-Intelligence Agency (NGA) +# -------- DOWNLOAD DATA MANUALLY and jar xf val_images.zip to 'datasets/xView' before running train command! -------- +# Documentation: https://docs.ultralytics.com/datasets/detect/xview/ +# Example usage: yolo train data=xView.yaml +# parent +# ├── ultralytics +# └── datasets +# └── xView ← downloads here (20.7 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/xView # dataset root dir +train: images/autosplit_train.txt # train images (relative to 'path') 90% of 847 train images +val: images/autosplit_val.txt # train images (relative to 'path') 10% of 847 train images + +# Classes +names: + 0: Fixed-wing Aircraft + 1: Small Aircraft + 2: Cargo Plane + 3: Helicopter + 4: Passenger Vehicle + 5: Small Car + 6: Bus + 7: Pickup Truck + 8: Utility Truck + 9: Truck + 10: Cargo Truck + 11: Truck w/Box + 12: Truck Tractor + 13: Trailer + 14: Truck w/Flatbed + 15: Truck w/Liquid + 16: Crane Truck + 17: Railway Vehicle + 18: Passenger Car + 19: Cargo Car + 20: Flat Car + 21: Tank car + 22: Locomotive + 23: Maritime Vessel + 24: Motorboat + 25: Sailboat + 26: Tugboat + 27: Barge + 28: Fishing Vessel + 29: Ferry + 30: Yacht + 31: Container Ship + 32: Oil Tanker + 33: Engineering Vehicle + 34: Tower crane + 35: Container Crane + 36: Reach Stacker + 37: Straddle Carrier + 38: Mobile Crane + 39: Dump Truck + 40: Haul Truck + 41: Scraper/Tractor + 42: Front loader/Bulldozer + 43: Excavator + 44: Cement Mixer + 45: Ground Grader + 46: Hut/Tent + 47: Shed + 48: Building + 49: Aircraft Hangar + 50: Damaged Building + 51: Facility + 52: Construction Site + 53: Vehicle Lot + 54: Helipad + 55: Storage Tank + 56: Shipping container lot + 57: Shipping Container + 58: Pylon + 59: Tower + +# Download script/URL (optional) --------------------------------------------------------------------------------------- +download: | + import json + import os + from pathlib import Path + + import numpy as np + from PIL import Image + from tqdm import tqdm + + from ultralytics.data.utils import autosplit + from ultralytics.utils.ops import xyxy2xywhn + + + def convert_labels(fname=Path("xView/xView_train.geojson")): + """Converts xView geoJSON labels to YOLO format, mapping classes to indices 0-59 and saving as text files.""" + path = fname.parent + with open(fname, encoding="utf-8") as f: + print(f"Loading {fname}...") + data = json.load(f) + 
+ # Make dirs + labels = Path(path / "labels" / "train") + os.system(f"rm -rf {labels}") + labels.mkdir(parents=True, exist_ok=True) + + # xView classes 11-94 to 0-59 + xview_class2index = [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 1, 2, -1, 3, -1, 4, 5, 6, 7, 8, -1, 9, 10, 11, + 12, 13, 14, 15, -1, -1, 16, 17, 18, 19, 20, 21, 22, -1, 23, 24, 25, -1, 26, 27, -1, 28, -1, + 29, 30, 31, 32, 33, 34, 35, 36, 37, -1, 38, 39, 40, 41, 42, 43, 44, 45, -1, -1, -1, -1, 46, + 47, 48, 49, -1, 50, 51, -1, 52, -1, -1, -1, 53, 54, -1, 55, -1, -1, 56, -1, 57, -1, 58, 59] + + shapes = {} + for feature in tqdm(data["features"], desc=f"Converting {fname}"): + p = feature["properties"] + if p["bounds_imcoords"]: + id = p["image_id"] + file = path / "train_images" / id + if file.exists(): # 1395.tif missing + try: + box = np.array([int(num) for num in p["bounds_imcoords"].split(",")]) + assert box.shape[0] == 4, f"incorrect box shape {box.shape[0]}" + cls = p["type_id"] + cls = xview_class2index[int(cls)] # xView class to 0-59 + assert 59 >= cls >= 0, f"incorrect class index {cls}" + + # Write YOLO label + if id not in shapes: + shapes[id] = Image.open(file).size + box = xyxy2xywhn(box[None].astype(float), w=shapes[id][0], h=shapes[id][1], clip=True) + with open((labels / id).with_suffix(".txt"), "a", encoding="utf-8") as f: + f.write(f"{cls} {' '.join(f'{x:.6f}' for x in box[0])}\n") # write label.txt + except Exception as e: + print(f"WARNING: skipping one label for {file}: {e}") + + + # Download manually from https://challenge.xviewdataset.org + dir = Path(yaml["path"]) # dataset root dir + # urls = [ + # "https://d307kc0mrhucc3.cloudfront.net/train_labels.zip", # train labels + # "https://d307kc0mrhucc3.cloudfront.net/train_images.zip", # 15G, 847 train images + # "https://d307kc0mrhucc3.cloudfront.net/val_images.zip", # 5G, 282 val images (no labels) + # ] + # download(urls, dir=dir) + + # Convert labels + convert_labels(dir / "xView_train.geojson") + + # Move images + 
images = Path(dir / "images") + images.mkdir(parents=True, exist_ok=True) + Path(dir / "train_images").rename(dir / "images" / "train") + Path(dir / "val_images").rename(dir / "images" / "val") + + # Split + autosplit(dir / "images" / "train") diff --git a/tracking/ultralytics/cfg/default.yaml b/tracking/ultralytics/cfg/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4b6ba0fd252a8d54f41de0317be275b8b78c450d --- /dev/null +++ b/tracking/ultralytics/cfg/default.yaml @@ -0,0 +1,128 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Global configuration YAML with settings and hyperparameters for YOLO training, validation, prediction and export +# For documentation see https://docs.ultralytics.com/usage/cfg/ + +task: detect # (str) YOLO task, i.e. detect, segment, classify, pose, obb +mode: train # (str) YOLO mode, i.e. train, val, predict, export, track, benchmark + +# Train settings ------------------------------------------------------------------------------------------------------- +model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml +data: # (str, optional) path to data file, i.e. coco8.yaml +epochs: 100 # (int) number of epochs to train for +time: # (float, optional) number of hours to train for, overrides epochs if supplied +patience: 100 # (int) epochs to wait for no observable improvement for early stopping of training +batch: 16 # (int) number of images per batch (-1 for AutoBatch) +imgsz: 640 # (int | list) input images size as int for train and val modes, or list[h,w] for predict and export modes +save: True # (bool) save train checkpoints and predict results +save_period: -1 # (int) Save checkpoint every x epochs (disabled if < 1) +cache: False # (bool) True/ram, disk or False. Use cache for data loading +device: # (int | str | list, optional) device to run on, i.e. 
cuda device=0 or device=0,1,2,3 or device=cpu +workers: 8 # (int) number of worker threads for data loading (per RANK if DDP) +project: # (str, optional) project name +name: # (str, optional) experiment name, results saved to 'project/name' directory +exist_ok: False # (bool) whether to overwrite existing experiment +pretrained: True # (bool | str) whether to use a pretrained model (bool) or a model to load weights from (str) +optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto] +verbose: True # (bool) whether to print verbose output +seed: 0 # (int) random seed for reproducibility +deterministic: True # (bool) whether to enable deterministic mode +single_cls: False # (bool) train multi-class data as single-class +rect: False # (bool) rectangular training if mode='train' or rectangular validation if mode='val' +cos_lr: False # (bool) use cosine learning rate scheduler +close_mosaic: 10 # (int) disable mosaic augmentation for final epochs (0 to disable) +resume: False # (bool) resume training from last checkpoint +amp: True # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check +fraction: 1.0 # (float) dataset fraction to train on (default is 1.0, all images in train set) +profile: False # (bool) profile ONNX and TensorRT speeds during training for loggers +freeze: None # (int | list, optional) freeze first n layers, or freeze list of layer indices during training +multi_scale: False # (bool) Whether to use multiscale during training +# Segmentation +overlap_mask: True # (bool) merge object masks into a single image mask during training (segment train only) +mask_ratio: 4 # (int) mask downsample ratio (segment train only) +# Classification +dropout: 0.0 # (float) use dropout regularization (classify train only) + +# Val/Test settings ---------------------------------------------------------------------------------------------------- +val: True # (bool) validate/test during 
training +split: val # (str) dataset split to use for validation, i.e. 'val', 'test' or 'train' +save_json: False # (bool) save results to JSON file +save_hybrid: False # (bool) save hybrid version of labels (labels + additional predictions) +conf: # (float, optional) object confidence threshold for detection (default 0.25 predict, 0.001 val) +iou: 0.7 # (float) intersection over union (IoU) threshold for NMS +max_det: 300 # (int) maximum number of detections per image +half: False # (bool) use half precision (FP16) +dnn: False # (bool) use OpenCV DNN for ONNX inference +plots: True # (bool) save plots and images during train/val + +# Predict settings ----------------------------------------------------------------------------------------------------- +source: # (str, optional) source directory for images or videos +vid_stride: 1 # (int) video frame-rate stride +stream_buffer: False # (bool) buffer all streaming frames (True) or return the most recent frame (False) +visualize: False # (bool) visualize model features +augment: False # (bool) apply image augmentation to prediction sources +agnostic_nms: False # (bool) class-agnostic NMS +classes: # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3] +retina_masks: False # (bool) use high-resolution segmentation masks +embed: # (list[int], optional) return feature vectors/embeddings from given layers + +# Visualize settings --------------------------------------------------------------------------------------------------- +show: False # (bool) show predicted images and videos if environment allows +save_frames: False # (bool) save predicted individual video frames +save_txt: False # (bool) save results as .txt file +save_conf: False # (bool) save results with confidence scores +save_crop: False # (bool) save cropped images with results +show_labels: True # (bool) show prediction labels, i.e. 'person' +show_conf: True # (bool) show prediction confidence, i.e. 
'0.99' +show_boxes: True # (bool) show prediction boxes +line_width: # (int, optional) line width of the bounding boxes. Scaled to image size if None. + +# Export settings ------------------------------------------------------------------------------------------------------ +format: torchscript # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats +keras: False # (bool) use Keras +optimize: False # (bool) TorchScript: optimize for mobile +int8: False # (bool) CoreML/TF INT8 quantization +dynamic: False # (bool) ONNX/TF/TensorRT: dynamic axes +simplify: True # (bool) ONNX: simplify model using `onnxslim` +opset: # (int, optional) ONNX: opset version +workspace: None # (float, optional) TensorRT: workspace size (GiB), `None` will let TensorRT auto-allocate memory +nms: False # (bool) CoreML: add NMS + +# Hyperparameters ------------------------------------------------------------------------------------------------------ +lr0: 0.01 # (float) initial learning rate (i.e. 
SGD=1E-2, Adam=1E-3) +lrf: 0.01 # (float) final learning rate (lr0 * lrf) +momentum: 0.937 # (float) SGD momentum/Adam beta1 +weight_decay: 0.0005 # (float) optimizer weight decay 5e-4 +warmup_epochs: 3.0 # (float) warmup epochs (fractions ok) +warmup_momentum: 0.8 # (float) warmup initial momentum +warmup_bias_lr: 0.1 # (float) warmup initial bias lr +box: 7.5 # (float) box loss gain +cls: 0.5 # (float) cls loss gain (scale with pixels) +dfl: 1.5 # (float) dfl loss gain +pose: 12.0 # (float) pose loss gain +kobj: 1.0 # (float) keypoint obj loss gain +nbs: 64 # (int) nominal batch size +hsv_h: 0.015 # (float) image HSV-Hue augmentation (fraction) +hsv_s: 0.7 # (float) image HSV-Saturation augmentation (fraction) +hsv_v: 0.4 # (float) image HSV-Value augmentation (fraction) +degrees: 0.0 # (float) image rotation (+/- deg) +translate: 0.1 # (float) image translation (+/- fraction) +scale: 0.5 # (float) image scale (+/- gain) +shear: 0.0 # (float) image shear (+/- deg) +perspective: 0.0 # (float) image perspective (+/- fraction), range 0-0.001 +flipud: 0.0 # (float) image flip up-down (probability) +fliplr: 0.5 # (float) image flip left-right (probability) +bgr: 0.0 # (float) image channel BGR (probability) +mosaic: 1.0 # (float) image mosaic (probability) +mixup: 0.0 # (float) image mixup (probability) +copy_paste: 0.0 # (float) segment copy-paste (probability) +copy_paste_mode: "flip" # (str) the method to do copy_paste augmentation (flip, mixup) +auto_augment: randaugment # (str) auto augmentation policy for classification (randaugment, autoaugment, augmix) +erasing: 0.4 # (float) probability of random erasing during classification training (0-0.9), 0 means no erasing, must be less than 1.0. +crop_fraction: 1.0 # (float) image crop fraction for classification (0.1-1), 1.0 means no crop, must be greater than 0. 
+ +# Custom config.yaml --------------------------------------------------------------------------------------------------- +cfg: # (str, optional) for overriding defaults.yaml + +# Tracker settings ------------------------------------------------------------------------------------------------------ +tracker: botsort.yaml # (str) tracker type, choices=[botsort.yaml, bytetrack.yaml] diff --git a/tracking/ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml b/tracking/ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c427f855b006bd91314e0a5dfe1a9ec9357cb6de --- /dev/null +++ b/tracking/ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml @@ -0,0 +1,17 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLO11-cls image classification model with ResNet18 backbone +# Model docs: https://docs.ultralytics.com/models/yolo11 +# Task docs: https://docs.ultralytics.com/tasks/classify + +# Parameters +nc: 1000 # number of classes + +# ResNet18 backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, TorchVision, [512, resnet18, DEFAULT, True, 2]] # truncate two layers from the end + +# YOLO11n head +head: + - [-1, 1, Classify, [nc]] # Classify diff --git a/tracking/ultralytics/cfg/models/11/yolo11-cls.yaml b/tracking/ultralytics/cfg/models/11/yolo11-cls.yaml new file mode 100644 index 0000000000000000000000000000000000000000..753e27b6ddf2e1e93a22c6bf1bf2e8c1b062cfb9 --- /dev/null +++ b/tracking/ultralytics/cfg/models/11/yolo11-cls.yaml @@ -0,0 +1,33 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLO11-cls image classification model +# Model docs: https://docs.ultralytics.com/models/yolo11 +# Task docs: https://docs.ultralytics.com/tasks/classify + +# Parameters +nc: 1000 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolo11n-cls.yaml' will call yolo11-cls.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.50, 0.25, 1024] # summary: 86 layers, 1633584 parameters, 1633584 gradients, 0.5 GFLOPs + s: [0.50, 0.50, 1024] # summary: 86 layers, 5545488 parameters, 5545488 gradients, 1.6 GFLOPs + m: [0.50, 1.00, 512] # summary: 106 layers, 10455696 parameters, 10455696 gradients, 5.0 GFLOPs + l: [1.00, 1.00, 512] # summary: 176 layers, 12937104 parameters, 12937104 gradients, 6.2 GFLOPs + x: [1.00, 1.50, 512] # summary: 176 layers, 28458544 parameters, 28458544 gradients, 13.7 GFLOPs + +# YOLO11n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 2, C3k2, [256, False, 0.25]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 2, C3k2, [512, False, 0.25]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 2, C3k2, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 2, C3k2, [1024, True]] + - [-1, 2, C2PSA, [1024]] # 9 + +# YOLO11n head +head: + - [-1, 1, Classify, [nc]] # Classify diff --git a/tracking/ultralytics/cfg/models/11/yolo11-obb.yaml b/tracking/ultralytics/cfg/models/11/yolo11-obb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6ca7c60dde6c7cf98cdbf6cb3a5425f2d0d0da18 --- /dev/null +++ b/tracking/ultralytics/cfg/models/11/yolo11-obb.yaml @@ -0,0 +1,50 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLO11-obb Oriented Bounding Boxes (OBB) model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolo11 +# Task docs: https://docs.ultralytics.com/tasks/obb + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolo11n-obb.yaml' will call yolo11-obb.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.50, 0.25, 1024] # summary: 196 layers, 2695747 parameters, 2695731 gradients, 6.9 GFLOPs + s: [0.50, 0.50, 1024] # summary: 196 layers, 9744931 parameters, 9744915 gradients, 22.7 GFLOPs + m: [0.50, 1.00, 512] # summary: 246 layers, 20963523 parameters, 20963507 gradients, 72.2 GFLOPs + l: [1.00, 1.00, 512] # summary: 372 layers, 26220995 parameters, 26220979 gradients, 91.3 GFLOPs + x: [1.00, 1.50, 512] # summary: 372 layers, 58875331 parameters, 58875315 gradients, 204.3 GFLOPs + +# YOLO11n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 2, C3k2, [256, False, 0.25]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 2, C3k2, [512, False, 0.25]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 2, C3k2, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 2, C3k2, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 2, C2PSA, [1024]] # 10 + +# YOLO11n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 2, C3k2, [512, False]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, OBB, [nc, 1]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/11/yolo11-pose.yaml b/tracking/ultralytics/cfg/models/11/yolo11-pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..32766e7fe71da4461f018b98430bb99b0845b1a8 --- /dev/null +++ b/tracking/ultralytics/cfg/models/11/yolo11-pose.yaml @@ -0,0 
+1,51 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLO11-pose keypoints/pose estimation model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolo11 +# Task docs: https://docs.ultralytics.com/tasks/pose + +# Parameters +nc: 80 # number of classes +kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +scales: # model compound scaling constants, i.e. 'model=yolo11n-pose.yaml' will call yolo11.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.50, 0.25, 1024] # summary: 196 layers, 2908507 parameters, 2908491 gradients, 7.7 GFLOPs + s: [0.50, 0.50, 1024] # summary: 196 layers, 9948811 parameters, 9948795 gradients, 23.5 GFLOPs + m: [0.50, 1.00, 512] # summary: 246 layers, 20973273 parameters, 20973257 gradients, 72.3 GFLOPs + l: [1.00, 1.00, 512] # summary: 372 layers, 26230745 parameters, 26230729 gradients, 91.4 GFLOPs + x: [1.00, 1.50, 512] # summary: 372 layers, 58889881 parameters, 58889865 gradients, 204.3 GFLOPs + +# YOLO11n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 2, C3k2, [256, False, 0.25]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 2, C3k2, [512, False, 0.25]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 2, C3k2, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 2, C3k2, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 2, C2PSA, [1024]] # 10 + +# YOLO11n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 2, C3k2, [512, False]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 
10], 1, Concat, [1]] # cat head P5 + - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, Pose, [nc, kpt_shape]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/11/yolo11-seg.yaml b/tracking/ultralytics/cfg/models/11/yolo11-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1186666c6ca899290be6470962f3cfb67f55f419 --- /dev/null +++ b/tracking/ultralytics/cfg/models/11/yolo11-seg.yaml @@ -0,0 +1,50 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLO11-seg instance segmentation model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolo11 +# Task docs: https://docs.ultralytics.com/tasks/segment + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolo11n-seg.yaml' will call yolo11-seg.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.50, 0.25, 1024] # summary: 203 layers, 2876848 parameters, 2876832 gradients, 10.5 GFLOPs + s: [0.50, 0.50, 1024] # summary: 203 layers, 10113248 parameters, 10113232 gradients, 35.8 GFLOPs + m: [0.50, 1.00, 512] # summary: 253 layers, 22420896 parameters, 22420880 gradients, 123.9 GFLOPs + l: [1.00, 1.00, 512] # summary: 379 layers, 27678368 parameters, 27678352 gradients, 143.0 GFLOPs + x: [1.00, 1.50, 512] # summary: 379 layers, 62142656 parameters, 62142640 gradients, 320.2 GFLOPs + +# YOLO11n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 2, C3k2, [256, False, 0.25]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 2, C3k2, [512, False, 0.25]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 2, C3k2, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 2, C3k2, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 2, C2PSA, [1024]] # 10 + +# YOLO11n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] 
# cat backbone P4 + - [-1, 2, C3k2, [512, False]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, Segment, [nc, 32, 256]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/11/yolo11.yaml b/tracking/ultralytics/cfg/models/11/yolo11.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c90c444859b460a870d97235e0bbc74fc45e599d --- /dev/null +++ b/tracking/ultralytics/cfg/models/11/yolo11.yaml @@ -0,0 +1,50 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLO11 object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolo11 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolo11n.yaml' will call yolo11.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.50, 0.25, 1024] # summary: 181 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs + s: [0.50, 0.50, 1024] # summary: 181 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs + m: [0.50, 1.00, 512] # summary: 231 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs + l: [1.00, 1.00, 512] # summary: 357 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs + x: [1.00, 1.50, 512] # summary: 357 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs + +# YOLO11n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 2, C3k2, [256, False, 0.25]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 2, C3k2, [512, False, 0.25]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 2, C3k2, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 2, C3k2, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 2, C2PSA, [1024]] # 10 + +# YOLO11n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 2, C3k2, [512, False]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/12/yolo12-cls.yaml b/tracking/ultralytics/cfg/models/12/yolo12-cls.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b0e15ceeba58ccb10378a205e90cd28c0fc4e79b --- /dev/null +++ b/tracking/ultralytics/cfg/models/12/yolo12-cls.yaml @@ -0,0 +1,32 @@ +# 
Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLO12-cls image classification model +# Model docs: https://docs.ultralytics.com/models/yolo12 +# Task docs: https://docs.ultralytics.com/tasks/classify + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolo12n-cls.yaml' will call yolo12-cls.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.50, 0.25, 1024] # summary: 152 layers, 1,820,976 parameters, 1,820,976 gradients, 3.7 GFLOPs + s: [0.50, 0.50, 1024] # summary: 152 layers, 6,206,992 parameters, 6,206,992 gradients, 13.6 GFLOPs + m: [0.50, 1.00, 512] # summary: 172 layers, 12,083,088 parameters, 12,083,088 gradients, 44.2 GFLOPs + l: [1.00, 1.00, 512] # summary: 312 layers, 15,558,640 parameters, 15,558,640 gradients, 56.9 GFLOPs + x: [1.00, 1.50, 512] # summary: 312 layers, 34,172,592 parameters, 34,172,592 gradients, 126.5 GFLOPs + +# YOLO12n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 2, C3k2, [256, False, 0.25]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 2, C3k2, [512, False, 0.25]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 4, A2C2f, [512, True, 4]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 4, A2C2f, [1024, True, 1]] # 8 + +# YOLO12n head +head: + - [-1, 1, Classify, [nc]] # Classify diff --git a/tracking/ultralytics/cfg/models/12/yolo12-obb.yaml b/tracking/ultralytics/cfg/models/12/yolo12-obb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e5c36f0c0e095e2f0f23902821773b20fdf81d73 --- /dev/null +++ b/tracking/ultralytics/cfg/models/12/yolo12-obb.yaml @@ -0,0 +1,48 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLO12-obb Oriented Bounding Boxes (OBB) model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolo12 +# Task docs: https://docs.ultralytics.com/tasks/obb + +# 
Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolo12n-obb.yaml' will call yolo12-obb.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.50, 0.25, 1024] # summary: 287 layers, 2,673,955 parameters, 2,673,939 gradients, 6.9 GFLOPs + s: [0.50, 0.50, 1024] # summary: 287 layers, 9,570,275 parameters, 9,570,259 gradients, 22.7 GFLOPs + m: [0.50, 1.00, 512] # summary: 307 layers, 21,048,003 parameters, 21,047,987 gradients, 71.8 GFLOPs + l: [1.00, 1.00, 512] # summary: 503 layers, 27,299,619 parameters, 27,299,603 gradients, 93.4 GFLOPs + x: [1.00, 1.50, 512] # summary: 503 layers, 61,119,939 parameters, 61,119,923 gradients, 208.6 GFLOPs + +# YOLO12n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 2, C3k2, [256, False, 0.25]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 2, C3k2, [512, False, 0.25]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 4, A2C2f, [512, True, 4]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 4, A2C2f, [1024, True, 1]] # 8 + +# YOLO12n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 2, A2C2f, [512, False, -1]] # 11 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 2, A2C2f, [256, False, -1]] # 14 + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 11], 1, Concat, [1]] # cat head P4 + - [-1, 2, A2C2f, [512, False, -1]] # 17 + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 8], 1, Concat, [1]] # cat head P5 + - [-1, 2, C3k2, [1024, True]] # 20 (P5/32-large) + + - [[14, 17, 20], 1, OBB, [nc, 1]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/12/yolo12-pose.yaml b/tracking/ultralytics/cfg/models/12/yolo12-pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..104a1865e5f57bec3f91505e71412555e7c49694 --- /dev/null +++ 
b/tracking/ultralytics/cfg/models/12/yolo12-pose.yaml @@ -0,0 +1,49 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLO12-pose keypoints/pose estimation model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolo12 +# Task docs: https://docs.ultralytics.com/tasks/pose + +# Parameters +nc: 80 # number of classes +kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +scales: # model compound scaling constants, i.e. 'model=yolo12n-pose.yaml' will call yolo12-pose.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.50, 0.25, 1024] # summary: 287 layers, 2,886,715 parameters, 2,886,699 gradients, 7.8 GFLOPs + s: [0.50, 0.50, 1024] # summary: 287 layers, 9,774,155 parameters, 9,774,139 gradients, 23.5 GFLOPs + m: [0.50, 1.00, 512] # summary: 307 layers, 21,057,753 parameters, 21,057,737 gradients, 71.8 GFLOPs + l: [1.00, 1.00, 512] # summary: 503 layers, 27,309,369 parameters, 27,309,353 gradients, 93.5 GFLOPs + x: [1.00, 1.50, 512] # summary: 503 layers, 61,134,489 parameters, 61,134,473 gradients, 208.7 GFLOPs + +# YOLO12n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 2, C3k2, [256, False, 0.25]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 2, C3k2, [512, False, 0.25]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 4, A2C2f, [512, True, 4]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 4, A2C2f, [1024, True, 1]] # 8 + +# YOLO12n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 2, A2C2f, [512, False, -1]] # 11 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 2, A2C2f, [256, False, -1]] # 14 + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 11], 1, Concat, [1]] # cat head P4 + - [-1, 2, A2C2f, [512, False, -1]] # 17 + + - [-1, 1, Conv, [512, 3, 2]] 
+ - [[-1, 8], 1, Concat, [1]] # cat head P5 + - [-1, 2, C3k2, [1024, True]] # 20 (P5/32-large) + + - [[14, 17, 20], 1, Pose, [nc, kpt_shape]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/12/yolo12-seg.yaml b/tracking/ultralytics/cfg/models/12/yolo12-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6d03a3e2e503a50816aab9c91658fae0fbe9df5c --- /dev/null +++ b/tracking/ultralytics/cfg/models/12/yolo12-seg.yaml @@ -0,0 +1,48 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLO12-seg instance segmentation model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolo12 +# Task docs: https://docs.ultralytics.com/tasks/segment + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolo12n-seg.yaml' will call yolo12-seg.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.50, 0.25, 1024] # summary: 294 layers, 2,855,056 parameters, 2,855,040 gradients, 10.6 GFLOPs + s: [0.50, 0.50, 1024] # summary: 294 layers, 9,938,592 parameters, 9,938,576 gradients, 35.7 GFLOPs + m: [0.50, 1.00, 512] # summary: 314 layers, 22,505,376 parameters, 22,505,360 gradients, 123.5 GFLOPs + l: [1.00, 1.00, 512] # summary: 510 layers, 28,756,992 parameters, 28,756,976 gradients, 145.1 GFLOPs + x: [1.00, 1.50, 512] # summary: 510 layers, 64,387,264 parameters, 64,387,248 gradients, 324.6 GFLOPs + +# YOLO12n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 2, C3k2, [256, False, 0.25]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 2, C3k2, [512, False, 0.25]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 4, A2C2f, [512, True, 4]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 4, A2C2f, [1024, True, 1]] # 8 + +# YOLO12n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 2, A2C2f, 
[512, False, -1]] # 11 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 2, A2C2f, [256, False, -1]] # 14 + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 11], 1, Concat, [1]] # cat head P4 + - [-1, 2, A2C2f, [512, False, -1]] # 17 + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 8], 1, Concat, [1]] # cat head P5 + - [-1, 2, C3k2, [1024, True]] # 20 (P5/32-large) + + - [[14, 17, 20], 1, Segment, [nc, 32, 256]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/12/yolo12.yaml b/tracking/ultralytics/cfg/models/12/yolo12.yaml new file mode 100644 index 0000000000000000000000000000000000000000..737c03339b7ef54abb1db151f7d7fa4c60c3a923 --- /dev/null +++ b/tracking/ultralytics/cfg/models/12/yolo12.yaml @@ -0,0 +1,48 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLO12 object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolo12 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolo12n.yaml' will call yolo12.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.50, 0.25, 1024] # summary: 272 layers, 2,602,288 parameters, 2,602,272 gradients, 6.7 GFLOPs + s: [0.50, 0.50, 1024] # summary: 272 layers, 9,284,096 parameters, 9,284,080 gradients, 21.7 GFLOPs + m: [0.50, 1.00, 512] # summary: 292 layers, 20,199,168 parameters, 20,199,152 gradients, 68.1 GFLOPs + l: [1.00, 1.00, 512] # summary: 488 layers, 26,450,784 parameters, 26,450,768 gradients, 89.7 GFLOPs + x: [1.00, 1.50, 512] # summary: 488 layers, 59,210,784 parameters, 59,210,768 gradients, 200.3 GFLOPs + +# YOLO12n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 2, C3k2, [256, False, 0.25]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 2, C3k2, [512, False, 0.25]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 4, A2C2f, [512, True, 4]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 4, A2C2f, [1024, True, 1]] # 8 + +# YOLO12n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 2, A2C2f, [512, False, -1]] # 11 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 2, A2C2f, [256, False, -1]] # 14 + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 11], 1, Concat, [1]] # cat head P4 + - [-1, 2, A2C2f, [512, False, -1]] # 17 + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 8], 1, Concat, [1]] # cat head P5 + - [-1, 2, C3k2, [1024, True]] # 20 (P5/32-large) + + - [[14, 17, 20], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/README.md b/tracking/ultralytics/cfg/models/README.md new file mode 100644 index 0000000000000000000000000000000000000000..68a9238384ec4c8f08c1db1fd6bad95e824d96d4 --- /dev/null +++ b/tracking/ultralytics/cfg/models/README.md @@ -0,0 +1,48 @@ +## Models + +Welcome to the [Ultralytics](https://www.ultralytics.com/) Models 
directory! Here you will find a wide variety of pre-configured model configuration files (`*.yaml`s) that can be used to create custom YOLO models. The models in this directory have been expertly crafted and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image segmentation tasks. + +These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms, from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this directory provides a great starting point for your custom model development needs. + +To get started, simply browse through the models in this directory and find one that best suits your needs. Once you've selected a model, you can use the provided `*.yaml` file to train and deploy your custom YOLO model with ease. See full details at the Ultralytics [Docs](https://docs.ultralytics.com/models/), and if you need help or have any questions, feel free to reach out to the Ultralytics team for support. So, don't wait, start creating your custom YOLO model now! 
+ +### Usage + +Model `*.yaml` files may be used directly in the [Command Line Interface (CLI)](https://docs.ultralytics.com/usage/cli/) with a `yolo` command: + +```bash +# Train a YOLO11n model using the coco8 dataset for 100 epochs +yolo task=detect mode=train model=yolo11n.yaml data=coco8.yaml epochs=100 +``` + +They may also be used directly in a Python environment, and accept the same [arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above: + +```python +from ultralytics import YOLO + +# Initialize a YOLO11n model from a YAML configuration file +model = YOLO("model.yaml") + +# If a pre-trained model is available, use it instead +# model = YOLO("model.pt") + +# Display model information +model.info() + +# Train the model using the COCO8 dataset for 100 epochs +model.train(data="coco8.yaml", epochs=100) +``` + +## Pre-trained Model Architectures + +Ultralytics supports many model architectures. Visit [Ultralytics Models](https://docs.ultralytics.com/models/) to view detailed information and usage. Any of these models can be used by loading their configurations or pretrained checkpoints if available. + +## Contribute New Models + +Have you trained a new YOLO variant or achieved state-of-the-art performance with specific tuning? We'd love to showcase your work in our Models section! Contributions from the community in the form of new models, architectures, or optimizations are highly valued and can significantly enrich our repository. + +By contributing to this section, you're helping us offer a wider array of model choices and configurations to the community. It's a fantastic way to share your knowledge and expertise while making the Ultralytics YOLO ecosystem even more versatile. + +To get started, please consult our [Contributing Guide](https://docs.ultralytics.com/help/contributing/) for step-by-step instructions on how to submit a Pull Request (PR) 🛠️. Your contributions are eagerly awaited! 
+ +Let's join hands to extend the range and capabilities of the Ultralytics YOLO models 🙏! diff --git a/tracking/ultralytics/cfg/models/rt-detr/rtdetr-l.yaml b/tracking/ultralytics/cfg/models/rt-detr/rtdetr-l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d8d6b4f410be5cfdea18c1d3dae48501e443fec2 --- /dev/null +++ b/tracking/ultralytics/cfg/models/rt-detr/rtdetr-l.yaml @@ -0,0 +1,53 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics RT-DETR-l hybrid object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/rtdetr +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + l: [1.00, 1.00, 1024] + +backbone: + # [from, repeats, module, args] + - [-1, 1, HGStem, [32, 48]] # 0-P2/4 + - [-1, 6, HGBlock, [48, 128, 3]] # stage 1 + + - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8 + - [-1, 6, HGBlock, [96, 512, 3]] # stage 2 + + - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16 + - [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut + - [-1, 6, HGBlock, [192, 1024, 5, True, True]] + - [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3 + + - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32 + - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4 + +head: + - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2 + - [-1, 1, AIFI, [1024, 8]] + - [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1 + - [[-2, -1], 1, Concat, [1]] + - [-1, 3, RepC3, [256]] # 16, fpn_blocks.0 + - [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 
input_proj.0 + - [[-2, -1], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1 + + - [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0 + - [[-1, 17], 1, Concat, [1]] # cat Y4 + - [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0 + + - [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1 + - [[-1, 12], 1, Concat, [1]] # cat Y5 + - [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1 + + - [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml b/tracking/ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b13e94512bd5f995937a995002d66c93bff7803f --- /dev/null +++ b/tracking/ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml @@ -0,0 +1,45 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics RT-DETR-ResNet101 hybrid object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/rtdetr +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + l: [1.00, 1.00, 1024] + +backbone: + # [from, repeats, module, args] + - [-1, 1, ResNetLayer, [3, 64, 1, True, 1]] # 0 + - [-1, 1, ResNetLayer, [64, 64, 1, False, 3]] # 1 + - [-1, 1, ResNetLayer, [256, 128, 2, False, 4]] # 2 + - [-1, 1, ResNetLayer, [512, 256, 2, False, 23]] # 3 + - [-1, 1, ResNetLayer, [1024, 512, 2, False, 3]] # 4 + +head: + - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 5 + - [-1, 1, AIFI, [1024, 8]] + - [-1, 1, Conv, [256, 1, 1]] # 7 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 9 + - [[-2, -1], 1, Concat, [1]] + - [-1, 3, RepC3, [256]] # 11 + - [-1, 1, Conv, [256, 1, 1]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [2, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 + - [[-2, -1], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, RepC3, [256]] # X3 (16), fpn_blocks.1 + + - [-1, 1, Conv, [256, 3, 2]] # 17, downsample_convs.0 + - [[-1, 12], 1, Concat, [1]] # cat Y4 + - [-1, 3, RepC3, [256]] # F4 (19), pan_blocks.0 + + - [-1, 1, Conv, [256, 3, 2]] # 20, downsample_convs.1 + - [[-1, 7], 1, Concat, [1]] # cat Y5 + - [-1, 3, RepC3, [256]] # F5 (22), pan_blocks.1 + + - [[16, 19, 22], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml b/tracking/ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8172ad4ed4c4fe9263b87d2595a61625d0644ad2 --- /dev/null +++ b/tracking/ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml @@ -0,0 +1,45 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics RT-DETR-ResNet50 hybrid object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/rtdetr +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # 
model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + l: [1.00, 1.00, 1024] + +backbone: + # [from, repeats, module, args] + - [-1, 1, ResNetLayer, [3, 64, 1, True, 1]] # 0 + - [-1, 1, ResNetLayer, [64, 64, 1, False, 3]] # 1 + - [-1, 1, ResNetLayer, [256, 128, 2, False, 4]] # 2 + - [-1, 1, ResNetLayer, [512, 256, 2, False, 6]] # 3 + - [-1, 1, ResNetLayer, [1024, 512, 2, False, 3]] # 4 + +head: + - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 5 + - [-1, 1, AIFI, [1024, 8]] + - [-1, 1, Conv, [256, 1, 1]] # 7 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 9 + - [[-2, -1], 1, Concat, [1]] + - [-1, 3, RepC3, [256]] # 11 + - [-1, 1, Conv, [256, 1, 1]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [2, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 + - [[-2, -1], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, RepC3, [256]] # X3 (16), fpn_blocks.1 + + - [-1, 1, Conv, [256, 3, 2]] # 17, downsample_convs.0 + - [[-1, 12], 1, Concat, [1]] # cat Y4 + - [-1, 3, RepC3, [256]] # F4 (19), pan_blocks.0 + + - [-1, 1, Conv, [256, 3, 2]] # 20, downsample_convs.1 + - [[-1, 7], 1, Concat, [1]] # cat Y5 + - [-1, 3, RepC3, [256]] # F5 (22), pan_blocks.1 + + - [[16, 19, 22], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/rt-detr/rtdetr-x.yaml b/tracking/ultralytics/cfg/models/rt-detr/rtdetr-x.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f9c4a19c8ab919d06a68ff655d6ee788ac5b23a0 --- /dev/null +++ b/tracking/ultralytics/cfg/models/rt-detr/rtdetr-x.yaml @@ -0,0 +1,57 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics RT-DETR-x hybrid object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/rtdetr +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes 
+scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + x: [1.00, 1.00, 2048] + +backbone: + # [from, repeats, module, args] + - [-1, 1, HGStem, [32, 64]] # 0-P2/4 + - [-1, 6, HGBlock, [64, 128, 3]] # stage 1 + + - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8 + - [-1, 6, HGBlock, [128, 512, 3]] + - [-1, 6, HGBlock, [128, 512, 3, False, True]] # 4-stage 2 + + - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 5-P3/16 + - [-1, 6, HGBlock, [256, 1024, 5, True, False]] # cm, c2, k, light, shortcut + - [-1, 6, HGBlock, [256, 1024, 5, True, True]] + - [-1, 6, HGBlock, [256, 1024, 5, True, True]] + - [-1, 6, HGBlock, [256, 1024, 5, True, True]] + - [-1, 6, HGBlock, [256, 1024, 5, True, True]] # 10-stage 3 + + - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 11-P4/32 + - [-1, 6, HGBlock, [512, 2048, 5, True, False]] + - [-1, 6, HGBlock, [512, 2048, 5, True, True]] # 13-stage 4 + +head: + - [-1, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 14 input_proj.2 + - [-1, 1, AIFI, [2048, 8]] + - [-1, 1, Conv, [384, 1, 1]] # 16, Y5, lateral_convs.0 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [10, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 18 input_proj.1 + - [[-2, -1], 1, Concat, [1]] + - [-1, 3, RepC3, [384]] # 20, fpn_blocks.0 + - [-1, 1, Conv, [384, 1, 1]] # 21, Y4, lateral_convs.1 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [4, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 23 input_proj.0 + - [[-2, -1], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, RepC3, [384]] # X3 (25), fpn_blocks.1 + + - [-1, 1, Conv, [384, 3, 2]] # 26, downsample_convs.0 + - [[-1, 21], 1, Concat, [1]] # cat Y4 + - [-1, 3, RepC3, [384]] # F4 (28), pan_blocks.0 + + - [-1, 1, Conv, [384, 3, 2]] # 29, downsample_convs.1 + - [[-1, 16], 1, Concat, [1]] # cat Y5 + - [-1, 3, RepC3, [384]] # F5 (31), pan_blocks.1 + + - [[25, 28, 31], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) diff --git 
a/tracking/ultralytics/cfg/models/v10/yolov10b.yaml b/tracking/ultralytics/cfg/models/v10/yolov10b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..750379128cc77d22329cbe9315ce5a644b722baa --- /dev/null +++ b/tracking/ultralytics/cfg/models/v10/yolov10b.yaml @@ -0,0 +1,45 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv10b object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov10 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov10n.yaml' will call yolov10.yaml with scale 'n' + # [depth, width, max_channels] + b: [0.67, 1.00, 512] + +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2fCIB, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv10.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2fCIB, [512, True]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v10/yolov10l.yaml b/tracking/ultralytics/cfg/models/v10/yolov10l.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..1dedd752e2372c186ffef0e10e8272227af1ce27 --- /dev/null +++ b/tracking/ultralytics/cfg/models/v10/yolov10l.yaml @@ -0,0 +1,45 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv10l object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov10 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov10n.yaml' will call yolov10.yaml with scale 'n' + # [depth, width, max_channels] + l: [1.00, 1.00, 512] + +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2fCIB, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv10.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2fCIB, [512, True]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v10/yolov10m.yaml b/tracking/ultralytics/cfg/models/v10/yolov10m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6ba4020b3309322ca22dd84d107f3dd75802d5fa --- /dev/null +++ b/tracking/ultralytics/cfg/models/v10/yolov10m.yaml @@ -0,0 
+1,45 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv10m object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov10 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov10n.yaml' will call yolov10.yaml with scale 'n' + # [depth, width, max_channels] + m: [0.67, 0.75, 768] + +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2fCIB, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv10.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v10/yolov10n.yaml b/tracking/ultralytics/cfg/models/v10/yolov10n.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a9aa7018950c2897c6efb3f38ad780016db8817b --- /dev/null +++ b/tracking/ultralytics/cfg/models/v10/yolov10n.yaml @@ -0,0 +1,45 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv10n object detection model with P3/8 - P5/32 outputs +# Model docs: 
https://docs.ultralytics.com/models/yolov10 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov10n.yaml' will call yolov10.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv10.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v10/yolov10s.yaml b/tracking/ultralytics/cfg/models/v10/yolov10s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dbb678b277d72bd2aabca2473974193495157559 --- /dev/null +++ b/tracking/ultralytics/cfg/models/v10/yolov10s.yaml @@ -0,0 +1,45 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv10s object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov10 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling 
constants, i.e. 'model=yolov10n.yaml' will call yolov10.yaml with scale 'n' + # [depth, width, max_channels] + s: [0.33, 0.50, 1024] + +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2fCIB, [1024, True, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv10.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v10/yolov10x.yaml b/tracking/ultralytics/cfg/models/v10/yolov10x.yaml new file mode 100644 index 0000000000000000000000000000000000000000..57482133863ee7eee26b66f83d2d6567c9fa9baf --- /dev/null +++ b/tracking/ultralytics/cfg/models/v10/yolov10x.yaml @@ -0,0 +1,45 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv10x object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov10 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov10n.yaml' will call yolov10.yaml with scale 'n' + # [depth, width, max_channels] + x: [1.00, 1.25, 512] + +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2fCIB, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2fCIB, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv10.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2fCIB, [512, True]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v3/yolov3-spp.yaml b/tracking/ultralytics/cfg/models/v3/yolov3-spp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6aef25ab748bc84ee6ad054057029de97e7e1b14 --- /dev/null +++ b/tracking/ultralytics/cfg/models/v3/yolov3-spp.yaml @@ -0,0 +1,49 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv3-SPP object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov3 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# darknet53 backbone +backbone: + # [from, number, module, args] + - [-1, 1, Conv, [32, 3, 1]] # 0 + - 
[-1, 1, Conv, [64, 3, 2]] # 1-P1/2 + - [-1, 1, Bottleneck, [64]] + - [-1, 1, Conv, [128, 3, 2]] # 3-P2/4 + - [-1, 2, Bottleneck, [128]] + - [-1, 1, Conv, [256, 3, 2]] # 5-P3/8 + - [-1, 8, Bottleneck, [256]] + - [-1, 1, Conv, [512, 3, 2]] # 7-P4/16 + - [-1, 8, Bottleneck, [512]] + - [-1, 1, Conv, [1024, 3, 2]] # 9-P5/32 + - [-1, 4, Bottleneck, [1024]] # 10 + +# YOLOv3-SPP head +head: + - [-1, 1, Bottleneck, [1024, False]] + - [-1, 1, SPP, [512, [5, 9, 13]]] + - [-1, 1, Conv, [1024, 3, 1]] + - [-1, 1, Conv, [512, 1, 1]] + - [-1, 1, Conv, [1024, 3, 1]] # 15 (P5/32-large) + + - [-2, 1, Conv, [256, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, Bottleneck, [512, False]] + - [-1, 1, Bottleneck, [512, False]] + - [-1, 1, Conv, [256, 1, 1]] + - [-1, 1, Conv, [512, 3, 1]] # 22 (P4/16-medium) + + - [-2, 1, Conv, [128, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, Bottleneck, [256, False]] + - [-1, 2, Bottleneck, [256, False]] # 27 (P3/8-small) + + - [[27, 22, 15], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v3/yolov3-tiny.yaml b/tracking/ultralytics/cfg/models/v3/yolov3-tiny.yaml new file mode 100644 index 0000000000000000000000000000000000000000..91a0bb03f7d8f8436cd4a09cb8e46d8e38484c5c --- /dev/null +++ b/tracking/ultralytics/cfg/models/v3/yolov3-tiny.yaml @@ -0,0 +1,40 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv3-tiny object detection model with P4/16 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov3 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# YOLOv3-tiny backbone +backbone: + # [from, number, module, args] + - [-1, 1, Conv, [16, 3, 1]] # 0 + - [-1, 1, nn.MaxPool2d, [2, 2, 0]] # 
1-P1/2 + - [-1, 1, Conv, [32, 3, 1]] + - [-1, 1, nn.MaxPool2d, [2, 2, 0]] # 3-P2/4 + - [-1, 1, Conv, [64, 3, 1]] + - [-1, 1, nn.MaxPool2d, [2, 2, 0]] # 5-P3/8 + - [-1, 1, Conv, [128, 3, 1]] + - [-1, 1, nn.MaxPool2d, [2, 2, 0]] # 7-P4/16 + - [-1, 1, Conv, [256, 3, 1]] + - [-1, 1, nn.MaxPool2d, [2, 2, 0]] # 9-P5/32 + - [-1, 1, Conv, [512, 3, 1]] + - [-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]] # 11 + - [-1, 1, nn.MaxPool2d, [2, 1, 0]] # 12 + +# YOLOv3-tiny head +head: + - [-1, 1, Conv, [1024, 3, 1]] + - [-1, 1, Conv, [256, 1, 1]] + - [-1, 1, Conv, [512, 3, 1]] # 15 (P5/32-large) + + - [-2, 1, Conv, [128, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, Conv, [256, 3, 1]] # 19 (P4/16-medium) + + - [[19, 15], 1, Detect, [nc]] # Detect(P4, P5) diff --git a/tracking/ultralytics/cfg/models/v3/yolov3.yaml b/tracking/ultralytics/cfg/models/v3/yolov3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..95c99de52be649df826589f73da346d4a75dd05f --- /dev/null +++ b/tracking/ultralytics/cfg/models/v3/yolov3.yaml @@ -0,0 +1,49 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv3 object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov3 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# darknet53 backbone +backbone: + # [from, number, module, args] + - [-1, 1, Conv, [32, 3, 1]] # 0 + - [-1, 1, Conv, [64, 3, 2]] # 1-P1/2 + - [-1, 1, Bottleneck, [64]] + - [-1, 1, Conv, [128, 3, 2]] # 3-P2/4 + - [-1, 2, Bottleneck, [128]] + - [-1, 1, Conv, [256, 3, 2]] # 5-P3/8 + - [-1, 8, Bottleneck, [256]] + - [-1, 1, Conv, [512, 3, 2]] # 7-P4/16 + - [-1, 8, Bottleneck, [512]] + - [-1, 1, Conv, [1024, 3, 2]] # 9-P5/32 + - [-1, 4, Bottleneck, [1024]] # 10 + +# YOLOv3 head +head: + - [-1, 1, 
Bottleneck, [1024, False]] + - [-1, 1, Conv, [512, 1, 1]] + - [-1, 1, Conv, [1024, 3, 1]] + - [-1, 1, Conv, [512, 1, 1]] + - [-1, 1, Conv, [1024, 3, 1]] # 15 (P5/32-large) + + - [-2, 1, Conv, [256, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, Bottleneck, [512, False]] + - [-1, 1, Bottleneck, [512, False]] + - [-1, 1, Conv, [256, 1, 1]] + - [-1, 1, Conv, [512, 3, 1]] # 22 (P4/16-medium) + + - [-2, 1, Conv, [128, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, Bottleneck, [256, False]] + - [-1, 2, Bottleneck, [256, False]] # 27 (P3/8-small) + + - [[27, 22, 15], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v5/yolov5-p6.yaml b/tracking/ultralytics/cfg/models/v5/yolov5-p6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..376d1aba90c24c0df93c14af4a45f5b98b6dbf02 --- /dev/null +++ b/tracking/ultralytics/cfg/models/v5/yolov5-p6.yaml @@ -0,0 +1,62 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv5 object detection model with P3/8 - P6/64 outputs +# Model docs: https://docs.ultralytics.com/models/yolov5 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov5n-p6.yaml' will call yolov5-p6.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 1024] + l: [1.00, 1.00, 1024] + x: [1.33, 1.25, 1024] + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + - [-1, 1, Conv, [64, 6, 2, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C3, [128]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C3, [256]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 9, C3, [512]] + - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 + - [-1, 3, C3, [768]] + - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 + - [-1, 3, C3, [1024]] + - [-1, 1, SPPF, [1024, 5]] # 11 + +# YOLOv5 v6.0 head +head: + - [-1, 1, Conv, [768, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P5 + - [-1, 3, C3, [768, False]] # 15 + + - [-1, 1, Conv, [512, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C3, [512, False]] # 19 + + - [-1, 1, Conv, [256, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C3, [256, False]] # 23 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 20], 1, Concat, [1]] # cat head P4 + - [-1, 3, C3, [512, False]] # 26 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 16], 1, Concat, [1]] # cat head P5 + - [-1, 3, C3, [768, False]] # 29 (P5/32-large) + + - [-1, 1, Conv, [768, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P6 + - [-1, 3, C3, [1024, False]] # 32 (P6/64-xlarge) + + - [[23, 26, 29, 32], 1, Detect, [nc]] # Detect(P3, P4, P5, P6) diff --git a/tracking/ultralytics/cfg/models/v5/yolov5.yaml b/tracking/ultralytics/cfg/models/v5/yolov5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..76a4749ae4f102cdee4db2a14dae4912476736c4 --- /dev/null +++ b/tracking/ultralytics/cfg/models/v5/yolov5.yaml @@ -0,0 +1,51 @@ +# Ultralytics 🚀 AGPL-3.0 License 
- https://ultralytics.com/license + +# Ultralytics YOLOv5 object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov5 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov5n.yaml' will call yolov5.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 1024] + l: [1.00, 1.00, 1024] + x: [1.33, 1.25, 1024] + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + - [-1, 1, Conv, [64, 6, 2, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C3, [128]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C3, [256]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 9, C3, [512]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C3, [1024]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv5 v6.0 head +head: + - [-1, 1, Conv, [512, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C3, [512, False]] # 13 + + - [-1, 1, Conv, [256, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C3, [256, False]] # 17 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 14], 1, Concat, [1]] # cat head P4 + - [-1, 3, C3, [512, False]] # 20 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C3, [1024, False]] # 23 (P5/32-large) + + - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v6/yolov6.yaml b/tracking/ultralytics/cfg/models/v6/yolov6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4a45224e570b5ffcc85d3711ef2bba2ee7b73013 --- /dev/null +++ b/tracking/ultralytics/cfg/models/v6/yolov6.yaml @@ -0,0 +1,56 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Meituan YOLOv6 
object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov6 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +activation: torch.nn.ReLU() # (optional) model default activation function +scales: # model compound scaling constants, i.e. 'model=yolov6n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 768] + l: [1.00, 1.00, 512] + x: [1.00, 1.25, 512] + +# YOLOv6-3.0s backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 6, Conv, [128, 3, 1]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 12, Conv, [256, 3, 1]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 18, Conv, [512, 3, 1]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 6, Conv, [1024, 3, 1]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv6-3.0s head +head: + - [-1, 1, Conv, [256, 1, 1]] + - [-1, 1, nn.ConvTranspose2d, [256, 2, 2, 0]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, Conv, [256, 3, 1]] + - [-1, 9, Conv, [256, 3, 1]] # 14 + + - [-1, 1, Conv, [128, 1, 1]] + - [-1, 1, nn.ConvTranspose2d, [128, 2, 2, 0]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, Conv, [128, 3, 1]] + - [-1, 9, Conv, [128, 3, 1]] # 19 + + - [-1, 1, Conv, [128, 3, 2]] + - [[-1, 15], 1, Concat, [1]] # cat head P4 + - [-1, 1, Conv, [256, 3, 1]] + - [-1, 9, Conv, [256, 3, 1]] # 23 + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 1, Conv, [512, 3, 1]] + - [-1, 9, Conv, [512, 3, 1]] # 27 + + - [[19, 23, 27], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml b/tracking/ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml new file mode 100644 index 0000000000000000000000000000000000000000..44cc00ebf2243977441dc3e492eb860f8c442267 --- 
/dev/null +++ b/tracking/ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml @@ -0,0 +1,28 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-cls image classification model with ResNet101 backbone +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/classify + +# Parameters +nc: 1000 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 1024] + l: [1.00, 1.00, 1024] + x: [1.00, 1.25, 1024] + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, ResNetLayer, [3, 64, 1, True, 1]] # 0-P1/2 + - [-1, 1, ResNetLayer, [64, 64, 1, False, 3]] # 1-P2/4 + - [-1, 1, ResNetLayer, [256, 128, 2, False, 4]] # 2-P3/8 + - [-1, 1, ResNetLayer, [512, 256, 2, False, 23]] # 3-P4/16 + - [-1, 1, ResNetLayer, [1024, 512, 2, False, 3]] # 4-P5/32 + +# YOLOv8.0n head +head: + - [-1, 1, Classify, [nc]] # Classify diff --git a/tracking/ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml b/tracking/ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1d05e0753fcd22195eaecd4e5dde22fb58750a14 --- /dev/null +++ b/tracking/ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml @@ -0,0 +1,28 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-cls image classification model with ResNet50 backbone +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/classify + +# Parameters +nc: 1000 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 1024] + l: [1.00, 1.00, 1024] + x: [1.00, 1.25, 1024] + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, ResNetLayer, [3, 64, 1, True, 1]] # 0-P1/2 + - [-1, 1, ResNetLayer, [64, 64, 1, False, 3]] # 1-P2/4 + - [-1, 1, ResNetLayer, [256, 128, 2, False, 4]] # 2-P3/8 + - [-1, 1, ResNetLayer, [512, 256, 2, False, 6]] # 3-P4/16 + - [-1, 1, ResNetLayer, [1024, 512, 2, False, 3]] # 4-P5/32 + +# YOLOv8.0n head +head: + - [-1, 1, Classify, [nc]] # Classify diff --git a/tracking/ultralytics/cfg/models/v8/yolov8-cls.yaml b/tracking/ultralytics/cfg/models/v8/yolov8-cls.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e346e5e1b76164894a01f0640d30145d4161ce53 --- /dev/null +++ b/tracking/ultralytics/cfg/models/v8/yolov8-cls.yaml @@ -0,0 +1,32 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-cls image classification model with YOLO backbone +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/classify + +# Parameters +nc: 1000 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 1024] + l: [1.00, 1.00, 1024] + x: [1.00, 1.25, 1024] + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + +# YOLOv8.0n head +head: + - [-1, 1, Classify, [nc]] # Classify diff --git a/tracking/ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml b/tracking/ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..08209a1ea130a5d97908b95dd63a01914837dc0d --- /dev/null +++ b/tracking/ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml @@ -0,0 +1,58 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8 object detection model with P2/4 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/detect +# Employs Ghost convolutions and modules proposed in Huawei's GhostNet in https://arxiv.org/abs/1911.11907v2 + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n-ghost-p2 summary: 290 layers, 2033944 parameters, 2033928 gradients, 13.8 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s-ghost-p2 summary: 290 layers, 5562080 parameters, 5562064 gradients, 25.1 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m-ghost-p2 summary: 434 layers, 9031728 parameters, 9031712 gradients, 42.8 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l-ghost-p2 summary: 578 layers, 12214448 parameters, 12214432 gradients, 69.1 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x-ghost-p2 summary: 578 layers, 18664776 parameters, 18664760 gradients, 103.3 GFLOPs + +# YOLOv8.0-ghost backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, GhostConv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C3Ghost, [128, True]] + - [-1, 1, GhostConv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C3Ghost, [256, True]] + - [-1, 1, GhostConv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C3Ghost, [512, True]] + - [-1, 1, GhostConv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C3Ghost, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0-ghost-p2 head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C3Ghost, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C3Ghost, [256]] # 15 (P3/8-small) + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 2], 1, Concat, [1]] # cat backbone P2 + - [-1, 3, C3Ghost, [128]] # 18 (P2/4-xsmall) + + - [-1, 1, GhostConv, [128, 3, 2]] + - [[-1, 15], 1, Concat, [1]] # cat head P3 + - [-1, 3, C3Ghost, [256]] # 21 (P3/8-small) + + - [-1, 1, GhostConv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C3Ghost, [512]] # 24 (P4/16-medium) + + - [-1, 1, GhostConv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C3Ghost, [1024]] # 27 (P5/32-large) + + - [[18, 21, 24, 27], 
1, Detect, [nc]] # Detect(P2, P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml b/tracking/ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c11bbbe6a98f533456a6eba760a5ccf0fb08b7ae --- /dev/null +++ b/tracking/ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml @@ -0,0 +1,60 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8 object detection model with P3/8 - P6/64 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/detect +# Employs Ghost convolutions and modules proposed in Huawei's GhostNet in https://arxiv.org/abs/1911.11907v2 + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n-ghost-p6 summary: 312 layers, 2901100 parameters, 2901084 gradients, 5.8 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s-ghost-p6 summary: 312 layers, 9520008 parameters, 9519992 gradients, 16.4 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m-ghost-p6 summary: 468 layers, 18002904 parameters, 18002888 gradients, 34.4 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l-ghost-p6 summary: 624 layers, 21227584 parameters, 21227568 gradients, 55.3 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x-ghost-p6 summary: 624 layers, 33057852 parameters, 33057836 gradients, 85.7 GFLOPs + +# YOLOv8.0-ghost backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, GhostConv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C3Ghost, [128, True]] + - [-1, 1, GhostConv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C3Ghost, [256, True]] + - [-1, 1, GhostConv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C3Ghost, [512, True]] + - [-1, 1, GhostConv, [768, 3, 2]] # 7-P5/32 + - [-1, 3, C3Ghost, [768, True]] + - [-1, 1, GhostConv, [1024, 3, 2]] # 9-P6/64 + - [-1, 3, 
C3Ghost, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 11 + +# YOLOv8.0-ghost-p6 head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P5 + - [-1, 3, C3Ghost, [768]] # 14 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C3Ghost, [512]] # 17 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C3Ghost, [256]] # 20 (P3/8-small) + + - [-1, 1, GhostConv, [256, 3, 2]] + - [[-1, 17], 1, Concat, [1]] # cat head P4 + - [-1, 3, C3Ghost, [512]] # 23 (P4/16-medium) + + - [-1, 1, GhostConv, [512, 3, 2]] + - [[-1, 14], 1, Concat, [1]] # cat head P5 + - [-1, 3, C3Ghost, [768]] # 26 (P5/32-large) + + - [-1, 1, GhostConv, [768, 3, 2]] + - [[-1, 11], 1, Concat, [1]] # cat head P6 + - [-1, 3, C3Ghost, [1024]] # 29 (P6/64-xlarge) + + - [[20, 23, 26, 29], 1, Detect, [nc]] # Detect(P3, P4, P5, P6) diff --git a/tracking/ultralytics/cfg/models/v8/yolov8-ghost.yaml b/tracking/ultralytics/cfg/models/v8/yolov8-ghost.yaml new file mode 100644 index 0000000000000000000000000000000000000000..371b766eee8590ac5d543341240e7d7e3c945a2b --- /dev/null +++ b/tracking/ultralytics/cfg/models/v8/yolov8-ghost.yaml @@ -0,0 +1,50 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8 object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/detect +# Employs Ghost convolutions and modules proposed in Huawei's GhostNet in https://arxiv.org/abs/1911.11907v2 + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n-ghost summary: 237 layers, 1865316 parameters, 1865300 gradients, 5.8 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s-ghost summary: 237 layers, 5960072 parameters, 5960056 gradients, 16.4 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m-ghost summary: 357 layers, 10336312 parameters, 10336296 gradients, 32.7 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l-ghost summary: 477 layers, 14277872 parameters, 14277856 gradients, 53.7 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x-ghost summary: 477 layers, 22229308 parameters, 22229292 gradients, 83.3 GFLOPs + +# YOLOv8.0n-ghost backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, GhostConv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C3Ghost, [128, True]] + - [-1, 1, GhostConv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C3Ghost, [256, True]] + - [-1, 1, GhostConv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C3Ghost, [512, True]] + - [-1, 1, GhostConv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C3Ghost, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C3Ghost, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C3Ghost, [256]] # 15 (P3/8-small) + + - [-1, 1, GhostConv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C3Ghost, [512]] # 18 (P4/16-medium) + + - [-1, 1, GhostConv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C3Ghost, [1024]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v8/yolov8-obb.yaml b/tracking/ultralytics/cfg/models/v8/yolov8-obb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0b6cef3479d64988418127efde89406dd5045559 --- /dev/null +++ 
b/tracking/ultralytics/cfg/models/v8/yolov8-obb.yaml @@ -0,0 +1,49 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-obb Oriented Bounding Boxes (OBB) model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/obb + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n-obb summary: 144 layers, 3228867 parameters, 3228851 gradients, 9.1 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s-obb summary: 144 layers, 11452739 parameters, 11452723 gradients, 29.8 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m-obb summary: 184 layers, 26463235 parameters, 26463219 gradients, 81.5 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l-obb summary: 224 layers, 44540355 parameters, 44540339 gradients, 169.4 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x-obb summary: 224 layers, 69555651 parameters, 69555635 gradients, 264.3 GFLOPs + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - 
[-1, 3, C2f, [1024]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, OBB, [nc, 1]] # OBB(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v8/yolov8-p2.yaml b/tracking/ultralytics/cfg/models/v8/yolov8-p2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..676bc8348c4b264e03140f9d4ec2da7a65b02a8b --- /dev/null +++ b/tracking/ultralytics/cfg/models/v8/yolov8-p2.yaml @@ -0,0 +1,57 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8 object detection model with P2/4 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 768] + l: [1.00, 1.00, 512] + x: [1.00, 1.25, 512] + +# YOLOv8.0 backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0-p2 head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 2], 1, Concat, [1]] # cat backbone P2 + - [-1, 3, C2f, [128]] # 18 (P2/4-xsmall) + + - [-1, 1, Conv, [128, 3, 2]] + - [[-1, 15], 1, Concat, [1]] # cat head P3 + - [-1, 3, C2f, [256]] # 21 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, 
[1]] # cat head P4 + - [-1, 3, C2f, [512]] # 24 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 27 (P5/32-large) + + - [[18, 21, 24, 27], 1, Detect, [nc]] # Detect(P2, P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v8/yolov8-p6.yaml b/tracking/ultralytics/cfg/models/v8/yolov8-p6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7fb243fbe36433ffbec4c4ea286ec23956f7ead0 --- /dev/null +++ b/tracking/ultralytics/cfg/models/v8/yolov8-p6.yaml @@ -0,0 +1,59 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8 object detection model with P3/8 - P6/64 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n-p6 summary: 170 layers, 4984352 parameters, 4984336 gradients, 8.8 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s-p6 summary: 170 layers, 17911792 parameters, 17911776 gradients, 28.7 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m-p6 summary: 222 layers, 44887488 parameters, 44887472 gradients, 83.5 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l-p6 summary: 274 layers, 62384016 parameters, 62384000 gradients, 167.9 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x-p6 summary: 274 layers, 97423072 parameters, 97423056 gradients, 261.8 GFLOPs + +# YOLOv8.0x6 backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [768, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 + - 
[-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 11 + +# YOLOv8.0x6 head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P5 + - [-1, 3, C2, [768, False]] # 14 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2, [512, False]] # 17 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2, [256, False]] # 20 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 17], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 14], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2, [768, False]] # 26 (P5/32-large) + + - [-1, 1, Conv, [768, 3, 2]] + - [[-1, 11], 1, Concat, [1]] # cat head P6 + - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) + + - [[20, 23, 26, 29], 1, Detect, [nc]] # Detect(P3, P4, P5, P6) diff --git a/tracking/ultralytics/cfg/models/v8/yolov8-pose-p6.yaml b/tracking/ultralytics/cfg/models/v8/yolov8-pose-p6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..447a21aab0703b0f3a6520ba85b112ff6bd4f809 --- /dev/null +++ b/tracking/ultralytics/cfg/models/v8/yolov8-pose-p6.yaml @@ -0,0 +1,60 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-pose keypoints/pose estimation model with P3/8 - P6/64 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/pose + +# Parameters +nc: 1 # number of classes +kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +scales: # model compound scaling constants, i.e. 
'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 768] + l: [1.00, 1.00, 512] + x: [1.00, 1.25, 512] + +# YOLOv8.0x6 backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [768, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 11 + +# YOLOv8.0x6 head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P5 + - [-1, 3, C2, [768, False]] # 14 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2, [512, False]] # 17 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2, [256, False]] # 20 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 17], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 14], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2, [768, False]] # 26 (P5/32-large) + + - [-1, 1, Conv, [768, 3, 2]] + - [[-1, 11], 1, Concat, [1]] # cat head P6 + - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) + + - [[20, 23, 26, 29], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5, P6) diff --git a/tracking/ultralytics/cfg/models/v8/yolov8-pose.yaml b/tracking/ultralytics/cfg/models/v8/yolov8-pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c22bc435b578d58b1a432e1ee0f203562e73f076 --- /dev/null +++ b/tracking/ultralytics/cfg/models/v8/yolov8-pose.yaml @@ -0,0 +1,50 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# 
Ultralytics YOLOv8-pose keypoints/pose estimation model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/pose + +# Parameters +nc: 1 # number of classes +kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +scales: # model compound scaling constants, i.e. 'model=yolov8n-pose.yaml' will call yolov8-pose.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 768] + l: [1.00, 1.00, 512] + x: [1.00, 1.25, 512] + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v8/yolov8-rtdetr.yaml b/tracking/ultralytics/cfg/models/v8/yolov8-rtdetr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2abded63b5d5688af2f0ad727c787ec366b56e8b --- /dev/null +++ b/tracking/ultralytics/cfg/models/v8/yolov8-rtdetr.yaml @@ -0,0 +1,49 @@ +# Ultralytics 🚀 AGPL-3.0 License - 
https://ultralytics.com/license + +# Ultralytics YOLOv8-RTDETR hybrid object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/rtdetr +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n-rtdetr summary: 235 layers, 9643868 parameters, 9643868 gradients, 17.1 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s-rtdetr summary: 235 layers, 16518572 parameters, 16518572 gradients, 32.8 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m-rtdetr summary: 275 layers, 29645180 parameters, 29645180 gradients, 75.8 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l-rtdetr summary: 315 layers, 45644364 parameters, 45644364 gradients, 152.3 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x-rtdetr summary: 315 layers, 67113884 parameters, 67113884 gradients, 230.8 GFLOPs + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, RTDETRDecoder, [nc]] # 
Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v8/yolov8-seg-p6.yaml b/tracking/ultralytics/cfg/models/v8/yolov8-seg-p6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4c7ba9bf4dddaa23adc5521ff51ab149ae09495a --- /dev/null +++ b/tracking/ultralytics/cfg/models/v8/yolov8-seg-p6.yaml @@ -0,0 +1,59 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-seg instance segmentation model with P3/8 - P6/64 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/segment + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n-seg-p6.yaml' will call yolov8-seg-p6.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 768] + l: [1.00, 1.00, 512] + x: [1.00, 1.25, 512] + +# YOLOv8.0x6 backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [768, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 11 + +# YOLOv8.0x6 head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P5 + - [-1, 3, C2, [768, False]] # 14 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2, [512, False]] # 17 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2, [256, False]] # 20 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 17], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) + + - [-1, 1, 
Conv, [512, 3, 2]] + - [[-1, 14], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2, [768, False]] # 26 (P5/32-large) + + - [-1, 1, Conv, [768, 3, 2]] + - [[-1, 11], 1, Concat, [1]] # cat head P6 + - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) + + - [[20, 23, 26, 29], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5, P6) diff --git a/tracking/ultralytics/cfg/models/v8/yolov8-seg.yaml b/tracking/ultralytics/cfg/models/v8/yolov8-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52b1c7e9aedc4fbc04d5c4f48f4ad214ad6a51e2 --- /dev/null +++ b/tracking/ultralytics/cfg/models/v8/yolov8-seg.yaml @@ -0,0 +1,49 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-seg instance segmentation model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/segment + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n-seg.yaml' will call yolov8-seg.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 768] + l: [1.00, 1.00, 512] + x: [1.00, 1.25, 512] + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 
+ - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v8/yolov8-world.yaml b/tracking/ultralytics/cfg/models/v8/yolov8-world.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7c7a023915bd50a36024da47320673085fd97b2c --- /dev/null +++ b/tracking/ultralytics/cfg/models/v8/yolov8-world.yaml @@ -0,0 +1,51 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-World hybrid object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolo-world +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n-world summary: 161 layers, 4204111 parameters, 4204095 gradients, 39.6 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s-world summary: 161 layers, 13383496 parameters, 13383480 gradients, 71.5 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m-world summary: 201 layers, 29065310 parameters, 29065294 gradients, 131.4 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l-world summary: 241 layers, 47553970 parameters, 47553954 gradients, 225.6 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x-world summary: 241 layers, 73690217 parameters, 73690201 gradients, 330.8 GFLOPs + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 
5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2fAttn, [512, 256, 8]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2fAttn, [256, 128, 4]] # 15 (P3/8-small) + + - [[15, 12, 9], 1, ImagePoolingAttn, [256]] # 16 (P3/8-small) + + - [15, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fAttn, [512, 256, 8]] # 19 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fAttn, [1024, 512, 16]] # 22 (P5/32-large) + + - [[15, 19, 22], 1, WorldDetect, [nc, 512, False]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v8/yolov8-worldv2.yaml b/tracking/ultralytics/cfg/models/v8/yolov8-worldv2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8396009823acc084b9a69a963eb32dd02da92b3a --- /dev/null +++ b/tracking/ultralytics/cfg/models/v8/yolov8-worldv2.yaml @@ -0,0 +1,49 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-Worldv2 hybrid object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolo-world +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n-worldv2 summary: 148 layers, 3695183 parameters, 3695167 gradients, 19.5 GFLOPS + s: [0.33, 0.50, 1024] # YOLOv8s-worldv2 summary: 148 layers, 12759880 parameters, 12759864 gradients, 51.0 GFLOPS + m: [0.67, 0.75, 768] # YOLOv8m-worldv2 summary: 188 layers, 28376158 parameters, 28376142 gradients, 110.5 GFLOPS + l: [1.00, 1.00, 512] # YOLOv8l-worldv2 summary: 228 layers, 46832050 parameters, 46832034 gradients, 204.5 GFLOPS + x: [1.00, 1.25, 512] # YOLOv8x-worldv2 summary: 228 layers, 72886377 parameters, 72886361 gradients, 309.3 GFLOPS + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2fAttn, [512, 256, 8]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2fAttn, [256, 128, 4]] # 15 (P3/8-small) + + - [15, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fAttn, [512, 256, 8]] # 18 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fAttn, [1024, 512, 16]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, WorldDetect, [nc, 512, True]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v8/yolov8.yaml b/tracking/ultralytics/cfg/models/v8/yolov8.yaml new file mode 100644 index 0000000000000000000000000000000000000000..189b89c10b5369b1d8a6b88b731b39ac7040ae56 --- /dev/null +++ 
b/tracking/ultralytics/cfg/models/v8/yolov8.yaml @@ -0,0 +1,49 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8 object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n summary: 129 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPS + s: [0.33, 0.50, 1024] # YOLOv8s summary: 129 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPS + m: [0.67, 0.75, 768] # YOLOv8m summary: 169 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPS + l: [1.00, 1.00, 512] # YOLOv8l summary: 209 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPS + x: [1.00, 1.25, 512] # YOLOv8x summary: 209 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPS + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) 
+ + - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v9/yolov9c-seg.yaml b/tracking/ultralytics/cfg/models/v9/yolov9c-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..94940fe4a3083da9b04518d48425ec7f1d177dd1 --- /dev/null +++ b/tracking/ultralytics/cfg/models/v9/yolov9c-seg.yaml @@ -0,0 +1,41 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv9c-seg instance segmentation model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov9 +# Task docs: https://docs.ultralytics.com/tasks/segment +# 380 layers, 27897120 parameters, 159.4 GFLOPs + +# Parameters +nc: 80 # number of classes + +# GELAN backbone +backbone: + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]] # 2 + - [-1, 1, ADown, [256]] # 3-P3/8 + - [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]] # 4 + - [-1, 1, ADown, [512]] # 5-P4/16 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 6 + - [-1, 1, ADown, [512]] # 7-P5/32 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 8 + - [-1, 1, SPPELAN, [512, 256]] # 9 + +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]] # 15 (P3/8-small) + + - [-1, 1, ADown, [256]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 18 (P4/16-medium) + + - [-1, 1, ADown, [512]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v9/yolov9c.yaml b/tracking/ultralytics/cfg/models/v9/yolov9c.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..2808d8160a0a316f199a11de53d1887697150b5c --- /dev/null +++ b/tracking/ultralytics/cfg/models/v9/yolov9c.yaml @@ -0,0 +1,41 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv9c object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov9 +# Task docs: https://docs.ultralytics.com/tasks/detect +# 358 layers, 25590912 parameters, 104.0 GFLOPs + +# Parameters +nc: 80 # number of classes + +# GELAN backbone +backbone: + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]] # 2 + - [-1, 1, ADown, [256]] # 3-P3/8 + - [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]] # 4 + - [-1, 1, ADown, [512]] # 5-P4/16 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 6 + - [-1, 1, ADown, [512]] # 7-P5/32 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 8 + - [-1, 1, SPPELAN, [512, 256]] # 9 + +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]] # 15 (P3/8-small) + + - [-1, 1, ADown, [256]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 18 (P4/16-medium) + + - [-1, 1, ADown, [512]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v9/yolov9e-seg.yaml b/tracking/ultralytics/cfg/models/v9/yolov9e-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..056ec842e0335e9f5c058ddea819c11a4d038285 --- /dev/null +++ b/tracking/ultralytics/cfg/models/v9/yolov9e-seg.yaml @@ -0,0 +1,64 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# 
YOLOv9e-seg instance segmentation model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov9 +# Task docs: https://docs.ultralytics.com/tasks/segment +# 743 layers, 60512800 parameters, 248.4 GFLOPs + +# Parameters +nc: 80 # number of classes + +# GELAN backbone +backbone: + - [-1, 1, nn.Identity, []] + - [-1, 1, Conv, [64, 3, 2]] # 1-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 2-P2/4 + - [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]] # 3 + - [-1, 1, ADown, [256]] # 4-P3/8 + - [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]] # 5 + - [-1, 1, ADown, [512]] # 6-P4/16 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 7 + - [-1, 1, ADown, [1024]] # 8-P5/32 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 9 + + - [1, 1, CBLinear, [[64]]] # 10 + - [3, 1, CBLinear, [[64, 128]]] # 11 + - [5, 1, CBLinear, [[64, 128, 256]]] # 12 + - [7, 1, CBLinear, [[64, 128, 256, 512]]] # 13 + - [9, 1, CBLinear, [[64, 128, 256, 512, 1024]]] # 14 + + - [0, 1, Conv, [64, 3, 2]] # 15-P1/2 + - [[10, 11, 12, 13, 14, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]] # 16 + - [-1, 1, Conv, [128, 3, 2]] # 17-P2/4 + - [[11, 12, 13, 14, -1], 1, CBFuse, [[1, 1, 1, 1]]] # 18 + - [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]] # 19 + - [-1, 1, ADown, [256]] # 20-P3/8 + - [[12, 13, 14, -1], 1, CBFuse, [[2, 2, 2]]] # 21 + - [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]] # 22 + - [-1, 1, ADown, [512]] # 23-P4/16 + - [[13, 14, -1], 1, CBFuse, [[3, 3]]] # 24 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 25 + - [-1, 1, ADown, [1024]] # 26-P5/32 + - [[14, -1], 1, CBFuse, [[4]]] # 27 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 28 + - [-1, 1, SPPELAN, [512, 256]] # 29 + +# GELAN head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 25], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]] # 32 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 22], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]] # 35 (P3/8-small) + + - [-1, 1, ADown, 
[256]] + - [[-1, 32], 1, Concat, [1]] # cat head P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]] # 38 (P4/16-medium) + + - [-1, 1, ADown, [512]] + - [[-1, 29], 1, Concat, [1]] # cat head P5 + - [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]] # 41 (P5/32-large) + + - [[35, 38, 41], 1, Segment, [nc, 32, 256]] # Segment (P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v9/yolov9e.yaml b/tracking/ultralytics/cfg/models/v9/yolov9e.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29a61e5cbb2f0d85c86331e5d08241b46e06f8e8 --- /dev/null +++ b/tracking/ultralytics/cfg/models/v9/yolov9e.yaml @@ -0,0 +1,64 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv9e object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov9 +# Task docs: https://docs.ultralytics.com/tasks/detect +# 721 layers, 58206592 parameters, 193.0 GFLOPs + +# Parameters +nc: 80 # number of classes + +# GELAN backbone +backbone: + - [-1, 1, nn.Identity, []] + - [-1, 1, Conv, [64, 3, 2]] # 1-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 2-P2/4 + - [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]] # 3 + - [-1, 1, ADown, [256]] # 4-P3/8 + - [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]] # 5 + - [-1, 1, ADown, [512]] # 6-P4/16 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 7 + - [-1, 1, ADown, [1024]] # 8-P5/32 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 9 + + - [1, 1, CBLinear, [[64]]] # 10 + - [3, 1, CBLinear, [[64, 128]]] # 11 + - [5, 1, CBLinear, [[64, 128, 256]]] # 12 + - [7, 1, CBLinear, [[64, 128, 256, 512]]] # 13 + - [9, 1, CBLinear, [[64, 128, 256, 512, 1024]]] # 14 + + - [0, 1, Conv, [64, 3, 2]] # 15-P1/2 + - [[10, 11, 12, 13, 14, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]] # 16 + - [-1, 1, Conv, [128, 3, 2]] # 17-P2/4 + - [[11, 12, 13, 14, -1], 1, CBFuse, [[1, 1, 1, 1]]] # 18 + - [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]] # 19 + - [-1, 1, ADown, [256]] # 20-P3/8 + - [[12, 13, 14, -1], 1, CBFuse, [[2, 2, 2]]] # 21 + - [-1, 1, 
RepNCSPELAN4, [512, 256, 128, 2]] # 22 + - [-1, 1, ADown, [512]] # 23-P4/16 + - [[13, 14, -1], 1, CBFuse, [[3, 3]]] # 24 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 25 + - [-1, 1, ADown, [1024]] # 26-P5/32 + - [[14, -1], 1, CBFuse, [[4]]] # 27 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 28 + - [-1, 1, SPPELAN, [512, 256]] # 29 + +# GELAN head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 25], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]] # 32 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 22], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]] # 35 (P3/8-small) + + - [-1, 1, ADown, [256]] + - [[-1, 32], 1, Concat, [1]] # cat head P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]] # 38 (P4/16-medium) + + - [-1, 1, ADown, [512]] + - [[-1, 29], 1, Concat, [1]] # cat head P5 + - [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]] # 41 (P5/32-large) + + - [[35, 38, 41], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v9/yolov9m.yaml b/tracking/ultralytics/cfg/models/v9/yolov9m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..683f90dac8b505caf2bd808d4c7c3f0424e74379 --- /dev/null +++ b/tracking/ultralytics/cfg/models/v9/yolov9m.yaml @@ -0,0 +1,41 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv9m object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov9 +# Task docs: https://docs.ultralytics.com/tasks/detect +# 348 layers, 20216160 parameters, 77.9 GFLOPs + +# Parameters +nc: 80 # number of classes + +# GELAN backbone +backbone: + - [-1, 1, Conv, [32, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [64, 3, 2]] # 1-P2/4 + - [-1, 1, RepNCSPELAN4, [128, 128, 64, 1]] # 2 + - [-1, 1, AConv, [240]] # 3-P3/8 + - [-1, 1, RepNCSPELAN4, [240, 240, 120, 1]] # 4 + - [-1, 1, AConv, [360]] # 5-P4/16 + - [-1, 1, RepNCSPELAN4, [360, 360, 180, 1]] # 6 + - [-1, 1, AConv, 
[480]] # 7-P5/32 + - [-1, 1, RepNCSPELAN4, [480, 480, 240, 1]] # 8 + - [-1, 1, SPPELAN, [480, 240]] # 9 + +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, RepNCSPELAN4, [360, 360, 180, 1]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, RepNCSPELAN4, [240, 240, 120, 1]] # 15 + + - [-1, 1, AConv, [180]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 1, RepNCSPELAN4, [360, 360, 180, 1]] # 18 (P4/16-medium) + + - [-1, 1, AConv, [240]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 1, RepNCSPELAN4, [480, 480, 240, 1]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v9/yolov9s.yaml b/tracking/ultralytics/cfg/models/v9/yolov9s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6758a2952282e992bb9d3ed979f73ae95bd3b7d9 --- /dev/null +++ b/tracking/ultralytics/cfg/models/v9/yolov9s.yaml @@ -0,0 +1,41 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv9s object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov9 +# Task docs: https://docs.ultralytics.com/tasks/detect +# 544 layers, 7318368 parameters, 27.6 GFLOPs + +# Parameters +nc: 80 # number of classes + +# GELAN backbone +backbone: + - [-1, 1, Conv, [32, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [64, 3, 2]] # 1-P2/4 + - [-1, 1, ELAN1, [64, 64, 32]] # 2 + - [-1, 1, AConv, [128]] # 3-P3/8 + - [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]] # 4 + - [-1, 1, AConv, [192]] # 5-P4/16 + - [-1, 1, RepNCSPELAN4, [192, 192, 96, 3]] # 6 + - [-1, 1, AConv, [256]] # 7-P5/32 + - [-1, 1, RepNCSPELAN4, [256, 256, 128, 3]] # 8 + - [-1, 1, SPPELAN, [256, 128]] # 9 + +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, RepNCSPELAN4, [192, 192, 96, 3]] # 12 + + - [-1, 1, 
nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]] # 15 + + - [-1, 1, AConv, [96]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 1, RepNCSPELAN4, [192, 192, 96, 3]] # 18 (P4/16-medium) + + - [-1, 1, AConv, [128]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 1, RepNCSPELAN4, [256, 256, 128, 3]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/models/v9/yolov9t.yaml b/tracking/ultralytics/cfg/models/v9/yolov9t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..950d28f499de8320a296c00cf391666993b0253b --- /dev/null +++ b/tracking/ultralytics/cfg/models/v9/yolov9t.yaml @@ -0,0 +1,41 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv9t object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov9 +# Task docs: https://docs.ultralytics.com/tasks/detect +# 544 layers, 2128720 parameters, 8.5 GFLOPs + +# Parameters +nc: 80 # number of classes + +# GELAN backbone +backbone: + - [-1, 1, Conv, [16, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [32, 3, 2]] # 1-P2/4 + - [-1, 1, ELAN1, [32, 32, 16]] # 2 + - [-1, 1, AConv, [64]] # 3-P3/8 + - [-1, 1, RepNCSPELAN4, [64, 64, 32, 3]] # 4 + - [-1, 1, AConv, [96]] # 5-P4/16 + - [-1, 1, RepNCSPELAN4, [96, 96, 48, 3]] # 6 + - [-1, 1, AConv, [128]] # 7-P5/32 + - [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]] # 8 + - [-1, 1, SPPELAN, [128, 64]] # 9 + +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, RepNCSPELAN4, [96, 96, 48, 3]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, RepNCSPELAN4, [64, 64, 32, 3]] # 15 + + - [-1, 1, AConv, [48]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 1, RepNCSPELAN4, [96, 96, 48, 3]] # 18 (P4/16-medium) + + - [-1, 1, AConv, [64]] + - 
[[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/tracking/ultralytics/cfg/solutions/default.yaml b/tracking/ultralytics/cfg/solutions/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a4afb49b324894c7caa6f6a1382cdc8acae10db9 --- /dev/null +++ b/tracking/ultralytics/cfg/solutions/default.yaml @@ -0,0 +1,24 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Global configuration YAML with settings and arguments for Ultralytics Solutions +# For documentation see https://docs.ultralytics.com/solutions/ + +# Object counting settings -------------------------------------------------------------------------------------------- +region: # list[tuple[int, int]] object counting, queue or speed estimation region points. +show_in: True # (bool) flag to display objects moving *into* the defined region +show_out: True # (bool) flag to display objects moving *out of* the defined region + +# Heatmaps settings ---------------------------------------------------------------------------------------------------- +colormap: # (int | str) colormap for heatmap, Only OPENCV supported colormaps can be used. + +# Workouts monitoring settings ----------------------------------------------------------------------------------------- +up_angle: 145.0 # (float) Workouts up_angle for counts, 145.0 is default value. +down_angle: 90 # (float) Workouts down_angle for counts, 90 is default value. +kpts: [6, 8, 10] # (list[int]) keypoints for workouts monitoring, i.e. for push-ups kpts have values of [6, 8, 10]. + +# Analytics settings --------------------------------------------------------------------------------------------------- +analytics_type: "line" # (str) analytics type i.e "line", "pie", "bar" or "area" charts. +json_file: # (str) parking system regions file path. 
+ +# Security alarm system settings --------------------------------------------------------------------------------------- +records: 5 # (int) Total detections count to send an email about security diff --git a/tracking/ultralytics/cfg/trackers/botsort.yaml b/tracking/ultralytics/cfg/trackers/botsort.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aedcee4860fe00e969aca4c51ff0975d931ac283 --- /dev/null +++ b/tracking/ultralytics/cfg/trackers/botsort.yaml @@ -0,0 +1,21 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Default Ultralytics settings for BoT-SORT tracker when using mode="track" +# For documentation and examples see https://docs.ultralytics.com/modes/track/ +# For BoT-SORT source code see https://github.com/NirAharon/BoT-SORT + +tracker_type: botsort # tracker type, ['botsort', 'bytetrack'] +track_high_thresh: 0.25 # threshold for the first association +track_low_thresh: 0.1 # threshold for the second association +new_track_thresh: 0.25 # threshold for init new track if the detection does not match any tracks +track_buffer: 30 # buffer to calculate the time when to remove tracks +match_thresh: 0.8 # threshold for matching tracks +fuse_score: True # Whether to fuse confidence scores with the iou distances before matching +# min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now) + +# BoT-SORT settings +gmc_method: sparseOptFlow # method of global motion compensation +# ReID model related thresh (not supported yet) +proximity_thresh: 0.5 +appearance_thresh: 0.25 +with_reid: False diff --git a/tracking/ultralytics/cfg/trackers/bytetrack.yaml b/tracking/ultralytics/cfg/trackers/bytetrack.yaml new file mode 100644 index 0000000000000000000000000000000000000000..62071a3022da1a38fc8aaf74a538d3829a489502 --- /dev/null +++ b/tracking/ultralytics/cfg/trackers/bytetrack.yaml @@ -0,0 +1,14 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Default 
Ultralytics settings for ByteTrack tracker when using mode="track" +# For documentation and examples see https://docs.ultralytics.com/modes/track/ +# For ByteTrack source code see https://github.com/ifzhang/ByteTrack + +tracker_type: bytetrack # tracker type, ['botsort', 'bytetrack'] +track_high_thresh: 0.25 # threshold for the first association +track_low_thresh: 0.1 # threshold for the second association +new_track_thresh: 0.25 # threshold for init new track if the detection does not match any tracks +track_buffer: 30 # buffer to calculate the time when to remove tracks +match_thresh: 0.8 # threshold for matching tracks +fuse_score: True # Whether to fuse confidence scores with the iou distances before matching +# min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now) diff --git a/tracking/ultralytics/data/__init__.py b/tracking/ultralytics/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a5d258d5df5f417903d2a42c9dce4174610a9804 --- /dev/null +++ b/tracking/ultralytics/data/__init__.py @@ -0,0 +1,26 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .base import BaseDataset +from .build import build_dataloader, build_grounding, build_yolo_dataset, load_inference_source +from .dataset import ( + ClassificationDataset, + GroundingDataset, + SemanticDataset, + YOLOConcatDataset, + YOLODataset, + YOLOMultiModalDataset, +) + +__all__ = ( + "BaseDataset", + "ClassificationDataset", + "SemanticDataset", + "YOLODataset", + "YOLOMultiModalDataset", + "YOLOConcatDataset", + "GroundingDataset", + "build_yolo_dataset", + "build_grounding", + "build_dataloader", + "load_inference_source", +) diff --git a/tracking/ultralytics/data/annotator.py b/tracking/ultralytics/data/annotator.py new file mode 100644 index 0000000000000000000000000000000000000000..cb72ab9f3a75e3dfad3735e38e55493ee8d87a81 --- /dev/null +++ b/tracking/ultralytics/data/annotator.py @@ -0,0 +1,65 @@ +# 
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from pathlib import Path

from ultralytics import SAM, YOLO


def auto_annotate(
    data,
    det_model="yolo11x.pt",
    sam_model="sam_b.pt",
    device="",
    conf=0.25,
    iou=0.45,
    imgsz=640,
    max_det=300,
    classes=None,
    output_dir=None,
):
    """
    Automatically annotate images using a YOLO detector plus a SAM segmenter.

    Runs YOLO detection over every image in `data`, feeds the resulting boxes to SAM
    for segmentation, and writes one text label file per image containing the class id
    followed by the normalized polygon coordinates.

    Args:
        data (str | Path): Path to a folder containing images to be annotated.
        det_model (str): Path or name of the pre-trained YOLO detection model.
        sam_model (str): Path or name of the pre-trained SAM segmentation model.
        device (str): Device to run the models on (e.g., 'cpu', 'cuda', '0').
        conf (float): Confidence threshold for the detection model.
        iou (float): IoU threshold for filtering overlapping boxes in detection results.
        imgsz (int): Input image resize dimension.
        max_det (int): Maximum number of detections per image.
        classes (List[int] | None): Filter predictions to specified class IDs only.
        output_dir (str | Path | None): Directory for the label files; defaults to a
            sibling folder named "<data>_auto_annotate_labels".

    Examples:
        >>> from ultralytics.data.annotator import auto_annotate
        >>> auto_annotate(data="ultralytics/assets", det_model="yolo11n.pt", sam_model="mobile_sam.pt")
    """
    det_model = YOLO(det_model)
    sam_model = SAM(sam_model)

    data = Path(data)
    if not output_dir:
        output_dir = data.parent / f"{data.stem}_auto_annotate_labels"
    Path(output_dir).mkdir(exist_ok=True, parents=True)

    # stream=True yields results lazily so large folders do not exhaust memory
    det_results = det_model(
        data, stream=True, device=device, conf=conf, iou=iou, imgsz=imgsz, max_det=max_det, classes=classes
    )

    for result in det_results:
        class_ids = result.boxes.cls.int().tolist()  # noqa
        if class_ids:  # skip images with no detections
            boxes = result.boxes.xyxy  # Boxes object for bbox outputs
            sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)
            segments = sam_results[0].masks.xyn  # normalized polygon per box

            with open(f"{Path(output_dir) / Path(result.path).stem}.txt", "w", encoding="utf-8") as f:
                for i, s in enumerate(segments):
                    if s.any():  # drop empty segments
                        segment = map(str, s.reshape(-1).tolist())
                        f.write(f"{class_ids[i]} " + " ".join(segment) + "\n")


# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# (start of tracking/ultralytics/data/augment.py)

import math
import random
from copy import deepcopy
from typing import Tuple, Union

import cv2
import numpy as np
import torch
from PIL import Image

from ultralytics.data.utils import polygons2masks, polygons2masks_overlap
from ultralytics.utils import LOGGER, colorstr
from ultralytics.utils.checks import check_version
from ultralytics.utils.instance import Instances
from ultralytics.utils.metrics import bbox_ioa
from ultralytics.utils.ops import segment2box, xyxyxyxy2xywhr
from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13
DEFAULT_MEAN = (0.0, 0.0, 0.0)
DEFAULT_STD = (1.0, 1.0, 1.0)
DEFAULT_CROP_FRACTION = 1.0


class BaseTransform:
    """
    Base class for image transformations in the Ultralytics library.

    Subclasses override the three ``apply_*`` hooks to implement real logic; the base
    implementations are no-ops that return ``None``. Calling an instance invokes all
    three hooks in order (image, instances, semantic) purely for their side effects on
    the ``labels`` argument — ``__call__`` itself returns ``None``.

    Examples:
        >>> transform = BaseTransform()
        >>> labels = {"img": np.random.rand(640, 640, 3), "instances": []}
        >>> transform(labels)
    """

    def __init__(self) -> None:
        """Initialize the transform; the base class keeps no state."""
        pass

    def apply_image(self, labels):
        """
        Hook for image-level transformations.

        The base implementation does nothing and returns ``None``. Subclasses are
        expected to transform the image content of ``labels``.

        Args:
            labels (Any): Label data whose image component should be transformed.
        """
        return None

    def apply_instances(self, labels):
        """
        Hook for object-instance transformations.

        The base implementation does nothing and returns ``None``. Subclasses are
        expected to transform the instance annotations in ``labels``.

        Args:
            labels (dict): Label data containing object instances.
        """
        return None

    def apply_semantic(self, labels):
        """
        Hook for semantic-segmentation transformations.

        The base implementation does nothing and returns ``None``. Subclasses are
        expected to transform the semantic masks in ``labels``.

        Args:
            labels (Any): Label data or semantic mask to be transformed.
        """
        return None

    def __call__(self, labels):
        """
        Run the image, instance, and semantic hooks on ``labels`` in that order.

        The hooks are called for their side effects on ``labels``; nothing is returned.

        Args:
            labels (dict): Image data and annotations (e.g. 'img', 'instances').
        """
        self.apply_image(labels)
        self.apply_instances(labels)
        self.apply_semantic(labels)
class Compose:
    """
    Compose several transforms into a single callable applied in sequence.

    Attributes:
        transforms (list): The transforms, executed first-to-last on every call.

    Examples:
        >>> pipeline = Compose([RandomFlip(), RandomPerspective(30)])
        >>> out = pipeline(data)
        >>> pipeline.append(CenterCrop((224, 224)))
        >>> pipeline.insert(0, RandomFlip())
    """

    def __init__(self, transforms):
        """
        Store the transform pipeline.

        Args:
            transforms (list | Callable): One transform or a list of transforms. A bare
                callable is wrapped in a single-element list.
        """
        self.transforms = transforms if isinstance(transforms, list) else [transforms]

    def __call__(self, data):
        """
        Feed ``data`` through every transform in order and return the final result.

        Args:
            data (Any): Input passed to the first transform.

        Returns:
            (Any): Output of the last transform.
        """
        result = data
        for transform in self.transforms:
            result = transform(result)
        return result

    def append(self, transform):
        """Add ``transform`` to the end of the pipeline."""
        self.transforms.append(transform)

    def insert(self, index, transform):
        """Insert ``transform`` at position ``index`` in the pipeline."""
        self.transforms.insert(index, transform)

    def __getitem__(self, index: Union[list, int]) -> "Compose":
        """
        Return a new Compose holding the transform(s) at ``index``.

        Args:
            index (int | List[int]): Position or positions to select.

        Returns:
            (Compose): A new pipeline containing only the selected transforms.

        Raises:
            AssertionError: If ``index`` is neither an int nor a list.
        """
        assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
        selected = [index] if isinstance(index, int) else index
        return Compose([self.transforms[i] for i in selected])

    def __setitem__(self, index: Union[list, int], value: Union[list, int]) -> None:
        """
        Replace the transform(s) at ``index`` with ``value``.

        Args:
            index (int | List[int]): Position or positions to overwrite.
            value (Any | List[Any]): Replacement transform(s); a list when ``index`` is a list.

        Raises:
            AssertionError: On invalid index type, mismatched value type, or out-of-range index.
        """
        assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
        if isinstance(index, list):
            assert isinstance(value, list), (
                f"The indices should be the same type as values, but got {type(index)} and {type(value)}"
            )
        if isinstance(index, int):
            index, value = [index], [value]
        for position, replacement in zip(index, value):
            assert position < len(self.transforms), f"list index {position} out of range {len(self.transforms)}."
            self.transforms[position] = replacement

    def tolist(self):
        """Return the underlying list of transforms."""
        return self.transforms

    def __repr__(self):
        """Return ``Compose(<transforms>)`` for debugging."""
        inner = ", ".join(str(t) for t in self.transforms)
        return f"{self.__class__.__name__}({inner})"
class BaseMixTransform:
    """
    Base class for mix-style augmentations such as MixUp and Mosaic.

    Handles the probability gate, sampling of extra images from the dataset, optional
    pre-transforms, and text-label merging. Subclasses provide ``get_indexes`` and
    ``_mix_transform``.

    Attributes:
        dataset (Any): Dataset providing ``get_image_and_label(index)``.
        pre_transform (Callable | None): Transform applied to each sampled image first.
        p (float): Probability in [0, 1] of applying the mix augmentation.
    """

    def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
        """
        Initialize the mix transform.

        Args:
            dataset (Any): Dataset the extra images are drawn from.
            pre_transform (Callable | None): Optional transform for the sampled images.
            p (float): Probability of applying the mix transformation, in [0.0, 1.0].
        """
        self.dataset = dataset
        self.pre_transform = pre_transform
        self.p = p

    def __call__(self, labels):
        """
        Possibly apply the mix transform to ``labels``.

        With probability ``1 - p`` the input is returned untouched. Otherwise extra
        images are sampled via ``get_indexes``, pre-transformed if configured, attached
        under ``labels["mix_labels"]``, mixed by ``_mix_transform``, and the temporary
        key is removed again.

        Args:
            labels (dict): Label data for one image.

        Returns:
            (dict): The (possibly mixed) label dict.
        """
        if random.uniform(0, 1) > self.p:
            return labels

        # One index (MixUp) or several (Mosaic); normalize to a list
        chosen = self.get_indexes()
        if isinstance(chosen, int):
            chosen = [chosen]

        extra = [self.dataset.get_image_and_label(i) for i in chosen]
        if self.pre_transform is not None:
            extra = [self.pre_transform(item) for item in extra]
        labels["mix_labels"] = extra

        # Unify text labels and class IDs across all images, then mix
        labels = self._update_label_text(labels)
        labels = self._mix_transform(labels)
        labels.pop("mix_labels", None)
        return labels

    def _mix_transform(self, labels):
        """
        Mix ``labels`` with the entries in ``labels["mix_labels"]``.

        Must be implemented by subclasses (e.g. MixUp, Mosaic).

        Args:
            labels (dict): Label dict carrying a "mix_labels" list.

        Returns:
            (dict): The augmented label dict.
        """
        raise NotImplementedError

    def get_indexes(self):
        """
        Return the dataset index(es) of the images to mix in.

        Must be implemented by subclasses.

        Returns:
            (int | List[int]): Index or indexes into the dataset.
        """
        raise NotImplementedError

    @staticmethod
    def _update_label_text(labels):
        """
        Merge per-image text labels across the main and mixed labels.

        Builds one deduplicated list of text labels shared by every label dict and
        remaps each ``cls`` entry to its index in that list. No-op when ``labels``
        has no "texts" key.

        Args:
            labels (dict): Label dict with optional "texts"/"cls" fields and a
                "mix_labels" list of further label dicts.

        Returns:
            (dict): ``labels`` with the unified "texts" list and remapped "cls" values.
        """
        if "texts" not in labels:
            return labels

        everyone = [labels] + labels["mix_labels"]
        pooled = sum([entry["texts"] for entry in everyone], [])
        unique = list({tuple(t) for t in pooled})  # dedupe; order follows set iteration
        lookup = {text: i for i, text in enumerate(unique)}

        for entry in everyone:
            # Remap class IDs against the old texts, then install the unified list
            for i, cls in enumerate(entry["cls"].squeeze(-1).tolist()):
                entry["cls"][i] = lookup[tuple(entry["texts"][int(cls)])]
            entry["texts"] = unique
        return labels
class Mosaic(BaseMixTransform):
    """
    Mosaic augmentation: combine several images into a single mosaic image.

    Supports 2x2 (n=4) and 3x3 (n=9) grids; a 1x3 variant (_mosaic3) also exists but
    see the NOTE in _mix_transform about its reachability.

    Attributes:
        dataset: Source dataset for the extra mosaic images.
        imgsz (int): Image size (height and width) of a single tile after the mosaic pipeline.
        p (float): Probability of applying the augmentation, in [0, 1].
        n (int): Grid size, 4 (2x2) or 9 (3x3).
        border (Tuple[int, int]): Negative half-size border used to crop the final mosaic.
    """

    def __init__(self, dataset, imgsz=640, p=1.0, n=4):
        """
        Initialize the Mosaic augmentation.

        Args:
            dataset (Any): Dataset the extra images are sampled from.
            imgsz (int): Tile image size (height and width).
            p (float): Probability of applying the augmentation, in [0, 1].
            n (int): Grid size; must be 4 or 9 (enforced by assertion).
        """
        assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
        assert n in {4, 9}, "grid must be equal to 4 or 9."
        super().__init__(dataset=dataset, p=p)
        self.imgsz = imgsz
        self.border = (-imgsz // 2, -imgsz // 2)  # width, height
        self.n = n

    def get_indexes(self, buffer=True):
        """
        Return n-1 random dataset indexes for the other mosaic tiles.

        Args:
            buffer (bool): If True, sample (with replacement) from the dataset's image
                buffer; otherwise sample uniformly from the whole dataset.

        Returns:
            (List[int]): n-1 random image indexes.
        """
        if buffer:  # select images from buffer
            return random.choices(list(self.dataset.buffer), k=self.n - 1)
        else:  # select any images
            return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]

    def _mix_transform(self, labels):
        """
        Dispatch to the mosaic builder matching ``self.n``.

        NOTE(review): the ``n == 3`` branch appears unreachable as written, since
        __init__ asserts ``n in {4, 9}``; _mosaic3 can only run if that assertion is
        relaxed elsewhere — confirm before relying on it.

        Args:
            labels (dict): Must have no "rect_shape" (rect and mosaic are exclusive)
                and a non-empty "mix_labels" list.

        Returns:
            (dict): Mosaic-augmented labels.
        """
        assert labels.get("rect_shape", None) is None, "rect and mosaic are mutually exclusive."
        assert len(labels.get("mix_labels", [])), "There are no other images for mosaic augment."
        return (
            self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels)
        )  # This code is modified for mosaic3 method.

    def _mosaic3(self, labels):
        """
        Build a 1x3 mosaic: the main image centered with one tile on each side.

        Args:
            labels (dict): Labels for the center image; "mix_labels" holds the two
                side images. Each patch dict provides "img" and "resized_shape".

        Returns:
            (dict): Combined labels with the border-cropped mosaic under "img".
        """
        mosaic_labels = []
        s = self.imgsz
        for i in range(3):
            labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
            # Load image
            img = labels_patch["img"]
            h, w = labels_patch.pop("resized_shape")

            # Place img in img3
            if i == 0:  # center
                img3 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # gray canvas for 3 tiles
                h0, w0 = h, w
                c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
            elif i == 1:  # right
                c = s + w0, s, s + w0 + w, s + h
            elif i == 2:  # left
                c = s - w, s + h0 - h, s, s + h0

            padw, padh = c[:2]
            x1, y1, x2, y2 = (max(x, 0) for x in c)  # clamp to canvas

            img3[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :]  # img3[ymin:ymax, xmin:xmax]
            # hp, wp = h, w  # height, width previous for next iteration

            # Labels assuming imgsz*2 mosaic size
            labels_patch = self._update_labels(labels_patch, padw + self.border[0], padh + self.border[1])
            mosaic_labels.append(labels_patch)
        final_labels = self._cat_labels(mosaic_labels)

        # Crop the central (imgsz*2, imgsz*2) region out of the 3x canvas
        final_labels["img"] = img3[-self.border[0] : self.border[0], -self.border[1] : self.border[1]]
        return final_labels

    def _mosaic4(self, labels):
        """
        Build a 2x2 mosaic from the main image plus three mixed images.

        The mosaic center (xc, yc) is drawn at random so tiles are cropped unevenly.

        Args:
            labels (dict): Labels for the base image; "mix_labels" holds the other three.

        Returns:
            (dict): Combined labels with the (imgsz*2, imgsz*2) mosaic under "img".
        """
        mosaic_labels = []
        s = self.imgsz
        yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border)  # mosaic center x, y
        for i in range(4):
            labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
            # Load image
            img = labels_patch["img"]
            h, w = labels_patch.pop("resized_shape")

            # Place img in img4
            if i == 0:  # top left
                img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # gray canvas for 4 tiles
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            elif i == 3:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
            padw = x1a - x1b
            padh = y1a - y1b

            labels_patch = self._update_labels(labels_patch, padw, padh)
            mosaic_labels.append(labels_patch)
        final_labels = self._cat_labels(mosaic_labels)
        final_labels["img"] = img4
        return final_labels

    def _mosaic9(self, labels):
        """
        Build a 3x3 mosaic: the main image centered, eight mixed images around it.

        Args:
            labels (dict): Labels for the center image; "mix_labels" holds the other
                eight patch dicts, each with "img" and "resized_shape".

        Returns:
            (dict): Combined labels with the border-cropped mosaic under "img".
        """
        mosaic_labels = []
        s = self.imgsz
        hp, wp = -1, -1  # height, width previous
        for i in range(9):
            labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
            # Load image
            img = labels_patch["img"]
            h, w = labels_patch.pop("resized_shape")

            # Place img in img9
            if i == 0:  # center
                img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # gray canvas for the 3x3 grid
                h0, w0 = h, w
                c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
            elif i == 1:  # top
                c = s, s - h, s + w, s
            elif i == 2:  # top right
                c = s + wp, s - h, s + wp + w, s
            elif i == 3:  # right
                c = s + w0, s, s + w0 + w, s + h
            elif i == 4:  # bottom right
                c = s + w0, s + hp, s + w0 + w, s + hp + h
            elif i == 5:  # bottom
                c = s + w0 - w, s + h0, s + w0, s + h0 + h
            elif i == 6:  # bottom left
                c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
            elif i == 7:  # left
                c = s - w, s + h0 - h, s, s + h0
            elif i == 8:  # top left
                c = s - w, s + h0 - hp - h, s, s + h0 - hp

            padw, padh = c[:2]
            x1, y1, x2, y2 = (max(x, 0) for x in c)  # clamp to canvas

            # Image
            img9[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :]  # img9[ymin:ymax, xmin:xmax]
            hp, wp = h, w  # height, width previous for next iteration

            # Labels assuming imgsz*2 mosaic size
            labels_patch = self._update_labels(labels_patch, padw + self.border[0], padh + self.border[1])
            mosaic_labels.append(labels_patch)
        final_labels = self._cat_labels(mosaic_labels)

        # Crop the central (imgsz*2, imgsz*2) region out of the 3x canvas
        final_labels["img"] = img9[-self.border[0] : self.border[0], -self.border[1] : self.border[1]]
        return final_labels

    @staticmethod
    def _update_labels(labels, padw, padh):
        """
        Shift instance coordinates by the tile's padding offsets.

        Converts boxes to xyxy, denormalizes against the patch image size, then adds
        (padw, padh) so coordinates live in the mosaic frame.

        Args:
            labels (dict): Patch labels with "img" and "instances".
            padw (int): X offset of the patch inside the mosaic.
            padh (int): Y offset of the patch inside the mosaic.

        Returns:
            (dict): ``labels`` with shifted instances.
        """
        nh, nw = labels["img"].shape[:2]
        labels["instances"].convert_bbox(format="xyxy")
        labels["instances"].denormalize(nw, nh)
        labels["instances"].add_padding(padw, padh)
        return labels

    def _cat_labels(self, mosaic_labels):
        """
        Concatenate per-tile labels into one mosaic-level label dict.

        Clips instances to the (imgsz*2, imgsz*2) mosaic canvas and drops boxes whose
        area collapses to zero after clipping (their cls entries are dropped in step).

        Args:
            mosaic_labels (List[Dict]): Label dicts of every tile.

        Returns:
            (dict): Merged labels — im_file/ori_shape from the first tile, concatenated
                cls and instances, the mosaic resized_shape, mosaic_border, and "texts"
                when present on the first tile. Empty dict for empty input.
        """
        if len(mosaic_labels) == 0:
            return {}
        cls = []
        instances = []
        imgsz = self.imgsz * 2  # mosaic imgsz
        for labels in mosaic_labels:
            cls.append(labels["cls"])
            instances.append(labels["instances"])
        # Final labels
        final_labels = {
            "im_file": mosaic_labels[0]["im_file"],
            "ori_shape": mosaic_labels[0]["ori_shape"],
            "resized_shape": (imgsz, imgsz),
            "cls": np.concatenate(cls, 0),
            "instances": Instances.concatenate(instances, axis=0),
            "mosaic_border": self.border,
        }
        final_labels["instances"].clip(imgsz, imgsz)
        good = final_labels["instances"].remove_zero_area_boxes()
        final_labels["cls"] = final_labels["cls"][good]
        if "texts" in mosaic_labels[0]:
            final_labels["texts"] = mosaic_labels[0]["texts"]
        return final_labels
+ - ori_shape (Tuple[int, int]): Original shape of the first image. + - resized_shape (Tuple[int, int]): Shape of the mosaic image (imgsz * 2, imgsz * 2). + - cls (np.ndarray): Concatenated class labels. + - instances (Instances): Concatenated instance annotations. + - mosaic_border (Tuple[int, int]): Mosaic border size. + - texts (List[str], optional): Text labels if present in the original labels. + + Examples: + >>> mosaic = Mosaic(dataset, imgsz=640) + >>> mosaic_labels = [{"cls": np.array([0, 1]), "instances": Instances(...)} for _ in range(4)] + >>> result = mosaic._cat_labels(mosaic_labels) + >>> print(result.keys()) + dict_keys(['im_file', 'ori_shape', 'resized_shape', 'cls', 'instances', 'mosaic_border']) + """ + if len(mosaic_labels) == 0: + return {} + cls = [] + instances = [] + imgsz = self.imgsz * 2 # mosaic imgsz + for labels in mosaic_labels: + cls.append(labels["cls"]) + instances.append(labels["instances"]) + # Final labels + final_labels = { + "im_file": mosaic_labels[0]["im_file"], + "ori_shape": mosaic_labels[0]["ori_shape"], + "resized_shape": (imgsz, imgsz), + "cls": np.concatenate(cls, 0), + "instances": Instances.concatenate(instances, axis=0), + "mosaic_border": self.border, + } + final_labels["instances"].clip(imgsz, imgsz) + good = final_labels["instances"].remove_zero_area_boxes() + final_labels["cls"] = final_labels["cls"][good] + if "texts" in mosaic_labels[0]: + final_labels["texts"] = mosaic_labels[0]["texts"] + return final_labels + + +class MixUp(BaseMixTransform): + """ + Applies MixUp augmentation to image datasets. + + This class implements the MixUp augmentation technique as described in the paper [mixup: Beyond Empirical Risk + Minimization](https://arxiv.org/abs/1710.09412). MixUp combines two images and their labels using a random weight. + + Attributes: + dataset (Any): The dataset to which MixUp augmentation will be applied. + pre_transform (Callable | None): Optional transform to apply before MixUp. 
class MixUp(BaseMixTransform):
    """
    MixUp augmentation (https://arxiv.org/abs/1710.09412).

    Blends two images with a weight drawn from a Beta(32, 32) distribution and
    concatenates their instance and class labels.

    Attributes:
        dataset (Any): Dataset supplying the second image.
        pre_transform (Callable | None): Optional transform applied before mixing.
        p (float): Probability of applying MixUp.

    Methods:
        get_indexes: Returns a random index from the dataset.
        _mix_transform: Applies MixUp augmentation to the input labels.

    Examples:
        >>> from ultralytics.data.augment import MixUp
        >>> dataset = YourDataset(...)  # Your image dataset
        >>> mixup = MixUp(dataset, p=0.5)
        >>> augmented_labels = mixup(original_labels)
    """

    def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
        """
        Initialize the MixUp augmentation object.

        Args:
            dataset (Any): Dataset to draw the partner image from.
            pre_transform (Callable | None): Transform applied to images before mixing.
            p (float): Probability of applying MixUp, in [0, 1].

        Examples:
            >>> mixup = MixUp(dataset, pre_transform=None, p=0.5)
        """
        super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)

    def get_indexes(self):
        """
        Return one random dataset index used to pick the MixUp partner image.

        Returns:
            (int): A random integer index within the range of the dataset length.
        """
        return random.randint(0, len(self.dataset) - 1)

    def _mix_transform(self, labels):
        """
        Blend the current sample with its partner and merge their annotations.

        Args:
            labels (dict): Labels containing 'img', 'instances', 'cls' and a
                'mix_labels' list holding the partner sample.

        Returns:
            (dict): Labels updated with the blended image and concatenated annotations.

        Examples:
            >>> mixer = MixUp(dataset)
            >>> mixed_labels = mixer._mix_transform(labels)
        """
        ratio = np.random.beta(32.0, 32.0)  # mixup weight, alpha=beta=32.0
        partner = labels["mix_labels"][0]
        blended = labels["img"] * ratio + partner["img"] * (1 - ratio)
        labels["img"] = blended.astype(np.uint8)
        labels["instances"] = Instances.concatenate([labels["instances"], partner["instances"]], axis=0)
        labels["cls"] = np.concatenate([labels["cls"], partner["cls"]], 0)
        return labels
class RandomPerspective:
    """
    Random perspective/affine augmentation for images and their annotations.

    Applies random rotation, translation, scaling, shearing, and perspective
    distortion, and updates bounding boxes, segments, and keypoints to match.

    Attributes:
        degrees (float): Maximum absolute rotation in degrees.
        translate (float): Maximum translation as a fraction of image size.
        scale (float): Scale jitter range, e.g. 0.5 allows 50%-150%.
        shear (float): Maximum shear angle in degrees.
        perspective (float): Perspective distortion factor.
        border (Tuple[int, int]): Mosaic border size as (x, y).
        pre_transform (Callable | None): Optional transform applied first.

    Methods:
        affine_transform: Applies affine transformations to the input image.
        apply_bboxes: Transforms bounding boxes using the affine matrix.
        apply_segments: Transforms segments and generates new bounding boxes.
        apply_keypoints: Transforms keypoints using the affine matrix.
        __call__: Applies the random perspective transformation to images and annotations.
        box_candidates: Filters transformed bounding boxes based on size and aspect ratio.

    Examples:
        >>> transform = RandomPerspective(degrees=10, translate=0.1, scale=0.1, shear=10)
        >>> image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
        >>> labels = {"img": image, "cls": np.array([0, 1]), "instances": Instances(...)}
        >>> result = transform(labels)
    """

    def __init__(
        self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0), pre_transform=None
    ):
        """
        Initialize RandomPerspective with transformation parameters.

        Args:
            degrees (float): Degree range for random rotations.
            translate (float): Fraction of total width/height for random translation.
            scale (float): Scaling factor interval, e.g. 0.5 allows 50%-150% resize.
            shear (float): Shear intensity (angle in degrees).
            perspective (float): Perspective distortion factor.
            border (Tuple[int, int]): Mosaic border (top/bottom, left/right).
            pre_transform (Callable | None): Transform applied before the random warp.
        """
        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.perspective = perspective
        self.border = border  # mosaic border
        self.pre_transform = pre_transform

    def affine_transform(self, img, border):
        """
        Build a random affine/perspective matrix and warp the image with it.

        Translation to the image center, perspective, rotation+scale, shear, and the
        final translation are composed right-to-left into a single 3x3 matrix.

        Args:
            img (np.ndarray): Image to warp.
            border (Tuple[int, int]): Border dimensions for the transformed image.

        Returns:
            (Tuple[np.ndarray, np.ndarray, float]): Warped image, 3x3 transformation
                matrix, and the scale factor applied.
        """
        # Shift the origin to the image center so the warp is centered
        center = np.eye(3, dtype=np.float32)
        center[0, 2] = -img.shape[1] / 2  # x translation (pixels)
        center[1, 2] = -img.shape[0] / 2  # y translation (pixels)

        # Perspective distortion
        persp = np.eye(3, dtype=np.float32)
        persp[2, 0] = random.uniform(-self.perspective, self.perspective)  # x perspective (about y)
        persp[2, 1] = random.uniform(-self.perspective, self.perspective)  # y perspective (about x)

        # Rotation and scale
        rot = np.eye(3, dtype=np.float32)
        angle = random.uniform(-self.degrees, self.degrees)
        scale = random.uniform(1 - self.scale, 1 + self.scale)
        rot[:2] = cv2.getRotationMatrix2D(angle=angle, center=(0, 0), scale=scale)

        # Shear
        shear_m = np.eye(3, dtype=np.float32)
        shear_m[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180)  # x shear (deg)
        shear_m[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180)  # y shear (deg)

        # Final translation into the output canvas
        trans = np.eye(3, dtype=np.float32)
        trans[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[0]  # x (pixels)
        trans[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[1]  # y (pixels)

        # Compose right-to-left: center -> perspective -> rotate/scale -> shear -> translate
        matrix = trans @ shear_m @ rot @ persp @ center
        # Only warp when something actually changed
        if (border[0] != 0) or (border[1] != 0) or (matrix != np.eye(3)).any():
            if self.perspective:
                img = cv2.warpPerspective(img, matrix, dsize=self.size, borderValue=(114, 114, 114))
            else:  # affine
                img = cv2.warpAffine(img, matrix[:2], dsize=self.size, borderValue=(114, 114, 114))
        return img, matrix, scale

    def apply_bboxes(self, bboxes, M):
        """
        Transform xyxy bounding boxes with an affine/perspective matrix.

        Each box is expanded to its four corners, the corners are transformed, and a
        new axis-aligned box is fit around the result.

        Args:
            bboxes (np.ndarray): Boxes in xyxy format, shape (N, 4).
            M (np.ndarray): 3x3 transformation matrix.

        Returns:
            (np.ndarray): Transformed boxes in xyxy format, shape (N, 4).
        """
        n = len(bboxes)
        if n == 0:
            return bboxes

        corners = np.ones((n * 4, 3), dtype=bboxes.dtype)
        corners[:, :2] = bboxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        corners = corners @ M.T  # transform
        # Rescale by w for perspective, otherwise drop the homogeneous coordinate
        pts = (corners[:, :2] / corners[:, 2:3] if self.perspective else corners[:, :2]).reshape(n, 8)

        # Fit new axis-aligned boxes around the transformed corners
        xs = pts[:, [0, 2, 4, 6]]
        ys = pts[:, [1, 3, 5, 7]]
        return np.concatenate((xs.min(1), ys.min(1), xs.max(1), ys.max(1)), dtype=bboxes.dtype).reshape(4, n).T

    def apply_segments(self, segments, M):
        """
        Transform polygon segments and derive tight boxes around them.

        Args:
            segments (np.ndarray): Segments of shape (N, M, 2).
            M (np.ndarray): 3x3 transformation matrix.

        Returns:
            (Tuple[np.ndarray, np.ndarray]): New xyxy boxes of shape (N, 4) and the
                transformed segments clipped into those boxes.
        """
        n, num = segments.shape[:2]
        if n == 0:
            return [], segments

        pts = np.ones((n * num, 3), dtype=segments.dtype)
        pts[:, :2] = segments.reshape(-1, 2)
        pts = pts @ M.T  # transform
        pts = pts[:, :2] / pts[:, 2:3]  # perspective rescale
        segments = pts.reshape(n, -1, 2)
        bboxes = np.stack([segment2box(seg, self.size[0], self.size[1]) for seg in segments], 0)
        # Keep segment points inside their own (possibly clipped) boxes
        segments[..., 0] = segments[..., 0].clip(bboxes[:, 0:1], bboxes[:, 2:3])
        segments[..., 1] = segments[..., 1].clip(bboxes[:, 1:2], bboxes[:, 3:4])
        return bboxes, segments

    def apply_keypoints(self, keypoints, M):
        """
        Transform keypoints and zero the visibility of points pushed off-image.

        Args:
            keypoints (np.ndarray): Keypoints of shape (N, nkpt, 3) as (x, y, visibility).
            M (np.ndarray): 3x3 transformation matrix.

        Returns:
            (np.ndarray): Transformed keypoints with the same shape as the input.
        """
        n, nkpt = keypoints.shape[:2]
        if n == 0:
            return keypoints
        pts = np.ones((n * nkpt, 3), dtype=keypoints.dtype)
        visible = keypoints[..., 2].reshape(n * nkpt, 1)
        pts[:, :2] = keypoints[..., :2].reshape(n * nkpt, 2)
        pts = pts @ M.T  # transform
        pts = pts[:, :2] / pts[:, 2:3]  # perspective rescale or affine
        # Hide keypoints transformed outside the output canvas
        outside = (pts[:, 0] < 0) | (pts[:, 1] < 0) | (pts[:, 0] > self.size[0]) | (pts[:, 1] > self.size[1])
        visible[outside] = 0
        return np.concatenate([pts, visible], axis=-1).reshape(n, nkpt, 3)

    def __call__(self, labels):
        """
        Apply the random perspective transform to an image and its annotations.

        Args:
            labels (dict): Must contain 'img' (np.ndarray), 'cls' (np.ndarray), and
                'instances' (Instances); may contain 'mosaic_border' from a preceding
                mosaic step.

        Returns:
            (dict): Labels with the warped image, filtered instances/classes, and the
                new 'resized_shape'.

        Examples:
            >>> transform = RandomPerspective()
            >>> result = transform(labels)
            >>> assert result["img"].shape[:2] == result["resized_shape"]
        """
        if self.pre_transform and "mosaic_border" not in labels:
            labels = self.pre_transform(labels)
        labels.pop("ratio_pad", None)  # ratio pad is invalidated by the warp

        img = labels["img"]
        cls = labels["cls"]
        instances = labels.pop("instances")
        # Normalize annotation formats before transforming
        instances.convert_bbox(format="xyxy")
        instances.denormalize(*img.shape[:2][::-1])

        border = labels.pop("mosaic_border", self.border)
        self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2  # output (w, h)
        # M is the affine matrix; scale feeds box_candidates below
        img, M, scale = self.affine_transform(img, border)

        bboxes = self.apply_bboxes(instances.bboxes, M)
        segments = instances.segments
        keypoints = instances.keypoints
        # Segment-derived boxes are tighter than corner-transformed boxes
        if len(segments):
            bboxes, segments = self.apply_segments(segments, M)
        if keypoints is not None:
            keypoints = self.apply_keypoints(keypoints, M)

        new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
        new_instances.clip(*self.size)

        # Rescale original boxes so the candidate filter compares like with like
        instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)
        keep = self.box_candidates(
            box1=instances.bboxes.T, box2=new_instances.bboxes.T, area_thr=0.01 if len(segments) else 0.10
        )
        labels["instances"] = new_instances[keep]
        labels["cls"] = cls[keep]
        labels["img"] = img
        labels["resized_shape"] = img.shape[:2]
        return labels

    @staticmethod
    def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
        """
        Select boxes that survive augmentation without excessive distortion.

        A transformed box is kept when it is at least wh_thr pixels in both dimensions,
        its aspect ratio stays below ar_thr, and it retains more than area_thr of the
        original box's area.

        Args:
            box1 (np.ndarray): Pre-augmentation boxes, shape (4, N), xyxy absolute.
            box2 (np.ndarray): Post-augmentation boxes, shape (4, N), xyxy absolute.
            wh_thr (float): Minimum width/height in pixels.
            ar_thr (float): Maximum allowed aspect ratio.
            area_thr (float): Minimum new/old area ratio.
            eps (float): Small value preventing division by zero.

        Returns:
            (np.ndarray): Boolean mask of shape (N,), True for kept boxes.

        Examples:
            >>> box1 = np.array([[0, 0, 100, 100], [0, 0, 50, 50]]).T
            >>> box2 = np.array([[10, 10, 90, 90], [5, 5, 45, 45]]).T
            >>> RandomPerspective.box_candidates(box1, box2)
            array([ True,  True])
        """
        w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
        w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
        ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
        return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates
class RandomHSV:
    """
    Randomly adjusts the Hue, Saturation, and Value (HSV) channels of an image.

    This class applies random HSV augmentation to images within predefined limits set by hgain, sgain, and vgain.
    The input image is modified in place and the labels dictionary is returned.

    Attributes:
        hgain (float): Maximum variation for hue. Range is typically [0, 1].
        sgain (float): Maximum variation for saturation. Range is typically [0, 1].
        vgain (float): Maximum variation for value. Range is typically [0, 1].

    Methods:
        __call__: Applies random HSV augmentation to an image.

    Examples:
        >>> import numpy as np
        >>> from ultralytics.data.augment import RandomHSV
        >>> augmenter = RandomHSV(hgain=0.5, sgain=0.5, vgain=0.5)
        >>> image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
        >>> labels = {"img": image}
        >>> augmented_labels = augmenter(labels)  # fixed: bind the returned dict before reading it
        >>> augmented_image = augmented_labels["img"]
    """

    def __init__(self, hgain=0.5, sgain=0.5, vgain=0.5) -> None:
        """
        Initializes the RandomHSV object for random HSV (Hue, Saturation, Value) augmentation.

        Args:
            hgain (float): Maximum variation for hue. Should be in the range [0, 1].
            sgain (float): Maximum variation for saturation. Should be in the range [0, 1].
            vgain (float): Maximum variation for value. Should be in the range [0, 1].

        Examples:
            >>> hsv_aug = RandomHSV(hgain=0.5, sgain=0.5, vgain=0.5)
        """
        self.hgain = hgain
        self.sgain = sgain
        self.vgain = vgain

    def __call__(self, labels):
        """
        Applies random HSV augmentation to an image within predefined limits.

        The 'img' array in 'labels' is modified in place; the same dictionary is also
        returned (the previous docstring incorrectly stated a None return). When all
        gains are zero the image is left untouched.

        Args:
            labels (dict): A dictionary containing image data and metadata. Must include an 'img' key with
                the image as a numpy array.

        Returns:
            (dict): The same labels dictionary with its 'img' entry HSV-augmented in place.

        Examples:
            >>> hsv_augmenter = RandomHSV(hgain=0.5, sgain=0.5, vgain=0.5)
            >>> labels = {"img": np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)}
            >>> augmented = hsv_augmenter(labels)
            >>> augmented_img = augmented["img"]
        """
        if self.hgain or self.sgain or self.vgain:
            img = labels["img"]
            dtype = img.dtype  # uint8

            r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain]  # random gains
            x = np.arange(0, 256, dtype=r.dtype)
            # lut_hue = ((x * (r[0] + 1)) % 180).astype(dtype)  # original hue implementation from ultralytics<=8.3.78
            lut_hue = ((x + r[0] * 180) % 180).astype(dtype)
            lut_sat = np.clip(x * (r[1] + 1), 0, 255).astype(dtype)
            lut_val = np.clip(x * (r[2] + 1), 0, 255).astype(dtype)
            lut_sat[0] = 0  # prevent pure white changing color, introduced in 8.3.79

            hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
            im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
            cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed
        return labels
class RandomFlip:
    """
    Randomly flips an image horizontally or vertically with a given probability.

    Bounding boxes and (optionally) keypoints are updated to match the flipped image.

    Attributes:
        p (float): Flip probability, in [0, 1].
        direction (str): Flip direction, 'horizontal' or 'vertical'.
        flip_idx (array-like): Keypoint index remapping applied after a flip.

    Methods:
        __call__: Applies the random flip transformation to an image and its annotations.

    Examples:
        >>> transform = RandomFlip(p=0.5, direction="horizontal")
        >>> result = transform({"img": image, "instances": instances})
    """

    def __init__(self, p=0.5, direction="horizontal", flip_idx=None) -> None:
        """
        Initialize RandomFlip with probability and direction.

        Args:
            p (float): The probability of applying the flip. Must be between 0 and 1.
            direction (str): The direction to apply the flip. Must be 'horizontal' or 'vertical'.
            flip_idx (List[int] | None): Index mapping for flipping keypoints, if any.

        Raises:
            AssertionError: If direction is invalid or p is outside [0, 1].

        Examples:
            >>> flip = RandomFlip(p=0.5, direction="horizontal")
        """
        assert direction in {"horizontal", "vertical"}, f"Support direction `horizontal` or `vertical`, got {direction}"
        assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."

        self.p = p
        self.direction = direction
        self.flip_idx = flip_idx

    def __call__(self, labels):
        """
        Maybe flip the image and keep its instances consistent with the result.

        Args:
            labels (dict): A dictionary with:
                'img' (np.ndarray): The image to be flipped.
                'instances' (Instances): Bounding boxes and optional keypoints.

        Returns:
            (dict): The same dictionary with the (possibly) flipped image and updated instances.

        Examples:
            >>> random_flip = RandomFlip(p=0.5, direction="horizontal")
            >>> flipped = random_flip({"img": np.random.rand(640, 640, 3), "instances": instances})
        """
        img = labels["img"]
        instances = labels.pop("instances")
        instances.convert_bbox(format="xywh")
        rows, cols = img.shape[:2]
        # Normalized coordinates flip around 1.0 rather than the pixel extent
        if instances.normalized:
            rows = cols = 1

        # Flip up-down
        if self.direction == "vertical" and random.random() < self.p:
            img = np.flipud(img)
            instances.flipud(rows)
        # Flip left-right
        if self.direction == "horizontal" and random.random() < self.p:
            img = np.fliplr(img)
            instances.fliplr(cols)
        # Remap left/right keypoints after a flip
        if self.flip_idx is not None and instances.keypoints is not None:
            instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :])
        labels["img"] = np.ascontiguousarray(img)
        labels["instances"] = instances
        return labels
class LetterBox:
    """
    Resize and pad images to a target shape while preserving aspect ratio.

    Used for detection, instance segmentation, and pose pipelines; bounding boxes and
    other instance annotations are updated to match the resized, padded image.

    Attributes:
        new_shape (tuple): Target shape (height, width).
        auto (bool): Use the minimum stride-aligned rectangle instead of new_shape.
        scale_fill (bool): Stretch to new_shape without any padding.
        scaleup (bool): Allow upscaling; when False only downscale.
        stride (int): Stride used to round padding in auto mode.
        center (bool): Center the image instead of anchoring it top-left.

    Methods:
        __call__: Resize and pad the image, updating labels and bounding boxes.

    Examples:
        >>> transform = LetterBox(new_shape=(640, 640))
        >>> result = transform(labels)
    """

    def __init__(self, new_shape=(640, 640), auto=False, scale_fill=False, scaleup=True, center=True, stride=32):
        """
        Initialize LetterBox for resizing and padding images.

        Args:
            new_shape (Tuple[int, int]): Target size (height, width) for the resized image.
            auto (bool): If True, pad only to the nearest stride multiple.
            scale_fill (bool): If True, stretch the image to new_shape without padding.
            scaleup (bool): If True, allow scaling up. If False, only scale down.
            center (bool): If True, center the placed image. If False, place it top-left.
            stride (int): Stride of the model (e.g., 32 for YOLOv5).

        Examples:
            >>> letterbox = LetterBox(new_shape=(640, 640), auto=False, scale_fill=False, scaleup=True, stride=32)
        """
        self.new_shape = new_shape
        self.auto = auto
        self.scale_fill = scale_fill
        self.scaleup = scaleup
        self.stride = stride
        self.center = center  # center the image or anchor it top-left

    def __call__(self, labels=None, image=None):
        """
        Letterbox an image and update any accompanying labels.

        Args:
            labels (Dict | None): Label dict containing 'img' and annotations, or None/empty.
            image (np.ndarray | None): Image to process; falls back to labels['img'] when None.

        Returns:
            (Dict | np.ndarray): The updated label dict when labels were supplied,
                otherwise the letterboxed image alone.

        Examples:
            >>> letterbox = LetterBox(new_shape=(640, 640))
            >>> result = letterbox(labels={"img": np.zeros((480, 640, 3)), "instances": instances})
        """
        if labels is None:
            labels = {}
        img = labels.get("img") if image is None else image
        shape = img.shape[:2]  # current (height, width)
        new_shape = labels.pop("rect_shape", self.new_shape)
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)

        # Uniform scale ratio (new / old)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        if not self.scaleup:  # only scale down, do not scale up (for better val mAP)
            r = min(r, 1.0)

        ratio = r, r  # width, height ratios
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        dw = new_shape[1] - new_unpad[0]  # width padding
        dh = new_shape[0] - new_unpad[1]  # height padding
        if self.auto:  # pad only to the nearest stride multiple
            dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride)
        elif self.scale_fill:  # stretch, no padding at all
            dw, dh = 0.0, 0.0
            new_unpad = (new_shape[1], new_shape[0])
            ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

        if self.center:
            dw /= 2  # split padding between both sides
            dh /= 2

        if shape[::-1] != new_unpad:  # resize
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
        top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))  # border
        if labels.get("ratio_pad"):
            labels["ratio_pad"] = (labels["ratio_pad"], (left, top))  # for evaluation

        if len(labels):
            labels = self._update_labels(labels, ratio, left, top)
            labels["img"] = img
            labels["resized_shape"] = new_shape
            return labels
        else:
            return img

    @staticmethod
    def _update_labels(labels, ratio, padw, padh):
        """
        Scale and shift instance coordinates to match a letterboxed image.

        Args:
            labels (dict): Label dict with 'img' and 'instances'.
            ratio (Tuple[float, float]): (width, height) scale factors applied to the image.
            padw (float): Left padding added to the image.
            padh (float): Top padding added to the image.

        Returns:
            (dict): Labels with transformed instance coordinates.

        Examples:
            >>> updated = LetterBox._update_labels(labels, (0.5, 0.5), 10, 20)
        """
        instances = labels["instances"]
        instances.convert_bbox(format="xyxy")
        instances.denormalize(*labels["img"].shape[:2][::-1])
        instances.scale(*ratio)
        instances.add_padding(padw, padh)
        return labels
+ + This class implements the Copy-Paste augmentation technique as described in the paper "Simple Copy-Paste is a Strong + Data Augmentation Method for Instance Segmentation" (https://arxiv.org/abs/2012.07177). It combines objects from + different images to create new training samples. + + Attributes: + dataset (Any): The dataset to which Copy-Paste augmentation will be applied. + pre_transform (Callable | None): Optional transform to apply before Copy-Paste. + p (float): Probability of applying Copy-Paste augmentation. + + Methods: + get_indexes: Returns a random index from the dataset. + _mix_transform: Applies Copy-Paste augmentation to the input labels. + __call__: Applies the Copy-Paste transformation to images and annotations. + + Examples: + >>> from ultralytics.data.augment import CopyPaste + >>> dataset = YourDataset(...) # Your image dataset + >>> copypaste = CopyPaste(dataset, p=0.5) + >>> augmented_labels = copypaste(original_labels) + """ + + def __init__(self, dataset=None, pre_transform=None, p=0.5, mode="flip") -> None: + """Initializes CopyPaste object with dataset, pre_transform, and probability of applying MixUp.""" + super().__init__(dataset=dataset, pre_transform=pre_transform, p=p) + assert mode in {"flip", "mixup"}, f"Expected `mode` to be `flip` or `mixup`, but got {mode}." 
+ self.mode = mode + + def get_indexes(self): + """Returns a list of random indexes from the dataset for CopyPaste augmentation.""" + return random.randint(0, len(self.dataset) - 1) + + def _mix_transform(self, labels): + """Applies Copy-Paste augmentation to combine objects from another image into the current image.""" + labels2 = labels["mix_labels"][0] + return self._transform(labels, labels2) + + def __call__(self, labels): + """Applies Copy-Paste augmentation to an image and its labels.""" + if len(labels["instances"].segments) == 0 or self.p == 0: + return labels + if self.mode == "flip": + return self._transform(labels) + + # Get index of one or three other images + indexes = self.get_indexes() + if isinstance(indexes, int): + indexes = [indexes] + + # Get images information will be used for Mosaic or MixUp + mix_labels = [self.dataset.get_image_and_label(i) for i in indexes] + + if self.pre_transform is not None: + for i, data in enumerate(mix_labels): + mix_labels[i] = self.pre_transform(data) + labels["mix_labels"] = mix_labels + + # Update cls and texts + labels = self._update_label_text(labels) + # Mosaic or MixUp + labels = self._mix_transform(labels) + labels.pop("mix_labels", None) + return labels + + def _transform(self, labels1, labels2={}): + """Applies Copy-Paste augmentation to combine objects from another image into the current image.""" + im = labels1["img"] + cls = labels1["cls"] + h, w = im.shape[:2] + instances = labels1.pop("instances") + instances.convert_bbox(format="xyxy") + instances.denormalize(w, h) + + im_new = np.zeros(im.shape, np.uint8) + instances2 = labels2.pop("instances", None) + if instances2 is None: + instances2 = deepcopy(instances) + instances2.fliplr(w) + ioa = bbox_ioa(instances2.bboxes, instances.bboxes) # intersection over area, (N, M) + indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, ) + n = len(indexes) + sorted_idx = np.argsort(ioa.max(1)[indexes]) + indexes = indexes[sorted_idx] + for j in indexes[: 
round(self.p * n)]: + cls = np.concatenate((cls, labels2.get("cls", cls)[[j]]), axis=0) + instances = Instances.concatenate((instances, instances2[[j]]), axis=0) + cv2.drawContours(im_new, instances2.segments[[j]].astype(np.int32), -1, (1, 1, 1), cv2.FILLED) + + result = labels2.get("img", cv2.flip(im, 1)) # augment segments + i = im_new.astype(bool) + im[i] = result[i] + + labels1["img"] = im + labels1["cls"] = cls + labels1["instances"] = instances + return labels1 + + +class Albumentations: + """ + Albumentations transformations for image augmentation. + + This class applies various image transformations using the Albumentations library. It includes operations such as + Blur, Median Blur, conversion to grayscale, Contrast Limited Adaptive Histogram Equalization (CLAHE), random changes + in brightness and contrast, RandomGamma, and image quality reduction through compression. + + Attributes: + p (float): Probability of applying the transformations. + transform (albumentations.Compose): Composed Albumentations transforms. + contains_spatial (bool): Indicates if the transforms include spatial operations. + + Methods: + __call__: Applies the Albumentations transformations to the input labels. + + Examples: + >>> transform = Albumentations(p=0.5) + >>> augmented_labels = transform(labels) + + Notes: + - The Albumentations package must be installed to use this class. + - If the package is not installed or an error occurs during initialization, the transform will be set to None. + - Spatial transforms are handled differently and require special processing for bounding boxes. + """ + + def __init__(self, p=1.0): + """ + Initialize the Albumentations transform object for YOLO bbox formatted parameters. 
+ + This class applies various image augmentations using the Albumentations library, including Blur, Median Blur, + conversion to grayscale, Contrast Limited Adaptive Histogram Equalization, random changes of brightness and + contrast, RandomGamma, and image quality reduction through compression. + + Args: + p (float): Probability of applying the augmentations. Must be between 0 and 1. + + Attributes: + p (float): Probability of applying the augmentations. + transform (albumentations.Compose): Composed Albumentations transforms. + contains_spatial (bool): Indicates if the transforms include spatial transformations. + + Raises: + ImportError: If the Albumentations package is not installed. + Exception: For any other errors during initialization. + + Examples: + >>> transform = Albumentations(p=0.5) + >>> augmented = transform(image=image, bboxes=bboxes, class_labels=classes) + >>> augmented_image = augmented["image"] + >>> augmented_bboxes = augmented["bboxes"] + + Notes: + - Requires Albumentations version 1.0.3 or higher. + - Spatial transforms are handled differently to ensure bbox compatibility. + - Some transforms are applied with very low probability (0.01) by default. 
+ """ + self.p = p + self.transform = None + prefix = colorstr("albumentations: ") + + try: + import albumentations as A + + check_version(A.__version__, "1.0.3", hard=True) # version requirement + + # List of possible spatial transforms + spatial_transforms = { + "Affine", + "BBoxSafeRandomCrop", + "CenterCrop", + "CoarseDropout", + "Crop", + "CropAndPad", + "CropNonEmptyMaskIfExists", + "D4", + "ElasticTransform", + "Flip", + "GridDistortion", + "GridDropout", + "HorizontalFlip", + "Lambda", + "LongestMaxSize", + "MaskDropout", + "MixUp", + "Morphological", + "NoOp", + "OpticalDistortion", + "PadIfNeeded", + "Perspective", + "PiecewiseAffine", + "PixelDropout", + "RandomCrop", + "RandomCropFromBorders", + "RandomGridShuffle", + "RandomResizedCrop", + "RandomRotate90", + "RandomScale", + "RandomSizedBBoxSafeCrop", + "RandomSizedCrop", + "Resize", + "Rotate", + "SafeRotate", + "ShiftScaleRotate", + "SmallestMaxSize", + "Transpose", + "VerticalFlip", + "XYMasking", + } # from https://albumentations.ai/docs/getting_started/transforms_and_targets/#spatial-level-transforms + + # Transforms + T = [ + A.Blur(p=0.01), + A.MedianBlur(p=0.01), + A.ToGray(p=0.01), + A.CLAHE(p=0.01), + A.RandomBrightnessContrast(p=0.0), + A.RandomGamma(p=0.0), + A.ImageCompression(quality_range=(75, 100), p=0.0), + ] + + # Compose transforms + self.contains_spatial = any(transform.__class__.__name__ in spatial_transforms for transform in T) + self.transform = ( + A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"])) + if self.contains_spatial + else A.Compose(T) + ) + if hasattr(self.transform, "set_random_seed"): + # Required for deterministic transforms in albumentations>=1.4.21 + self.transform.set_random_seed(torch.initial_seed()) + LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p)) + except ImportError: # package not installed, skip + pass + except Exception as e: + LOGGER.info(f"{prefix}{e}") + + def 
__call__(self, labels): + """ + Applies Albumentations transformations to input labels. + + This method applies a series of image augmentations using the Albumentations library. It can perform both + spatial and non-spatial transformations on the input image and its corresponding labels. + + Args: + labels (dict): A dictionary containing image data and annotations. Expected keys are: + - 'img': numpy.ndarray representing the image + - 'cls': numpy.ndarray of class labels + - 'instances': object containing bounding boxes and other instance information + + Returns: + (dict): The input dictionary with augmented image and updated annotations. + + Examples: + >>> transform = Albumentations(p=0.5) + >>> labels = { + ... "img": np.random.rand(640, 640, 3), + ... "cls": np.array([0, 1]), + ... "instances": Instances(bboxes=np.array([[0, 0, 1, 1], [0.5, 0.5, 0.8, 0.8]])), + ... } + >>> augmented = transform(labels) + >>> assert augmented["img"].shape == (640, 640, 3) + + Notes: + - The method applies transformations with probability self.p. + - Spatial transforms update bounding boxes, while non-spatial transforms only modify the image. + - Requires the Albumentations library to be installed. 
+ """ + if self.transform is None or random.random() > self.p: + return labels + + if self.contains_spatial: + cls = labels["cls"] + if len(cls): + im = labels["img"] + labels["instances"].convert_bbox("xywh") + labels["instances"].normalize(*im.shape[:2][::-1]) + bboxes = labels["instances"].bboxes + # TODO: add supports of segments and keypoints + new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed + if len(new["class_labels"]) > 0: # skip update if no bbox in new im + labels["img"] = new["image"] + labels["cls"] = np.array(new["class_labels"]) + bboxes = np.array(new["bboxes"], dtype=np.float32) + labels["instances"].update(bboxes=bboxes) + else: + labels["img"] = self.transform(image=labels["img"])["image"] # transformed + + return labels + + +class Format: + """ + A class for formatting image annotations for object detection, instance segmentation, and pose estimation tasks. + + This class standardizes image and instance annotations to be used by the `collate_fn` in PyTorch DataLoader. + + Attributes: + bbox_format (str): Format for bounding boxes. Options are 'xywh' or 'xyxy'. + normalize (bool): Whether to normalize bounding boxes. + return_mask (bool): Whether to return instance masks for segmentation. + return_keypoint (bool): Whether to return keypoints for pose estimation. + return_obb (bool): Whether to return oriented bounding boxes. + mask_ratio (int): Downsample ratio for masks. + mask_overlap (bool): Whether to overlap masks. + batch_idx (bool): Whether to keep batch indexes. + bgr (float): The probability to return BGR images. + + Methods: + __call__: Formats labels dictionary with image, classes, bounding boxes, and optionally masks and keypoints. + _format_img: Converts image from Numpy array to PyTorch tensor. + _format_segments: Converts polygon points to bitmap masks. 
+ + Examples: + >>> formatter = Format(bbox_format="xywh", normalize=True, return_mask=True) + >>> formatted_labels = formatter(labels) + >>> img = formatted_labels["img"] + >>> bboxes = formatted_labels["bboxes"] + >>> masks = formatted_labels["masks"] + """ + + def __init__( + self, + bbox_format="xywh", + normalize=True, + return_mask=False, + return_keypoint=False, + return_obb=False, + mask_ratio=4, + mask_overlap=True, + batch_idx=True, + bgr=0.0, + ): + """ + Initializes the Format class with given parameters for image and instance annotation formatting. + + This class standardizes image and instance annotations for object detection, instance segmentation, and pose + estimation tasks, preparing them for use in PyTorch DataLoader's `collate_fn`. + + Args: + bbox_format (str): Format for bounding boxes. Options are 'xywh', 'xyxy', etc. + normalize (bool): Whether to normalize bounding boxes to [0,1]. + return_mask (bool): If True, returns instance masks for segmentation tasks. + return_keypoint (bool): If True, returns keypoints for pose estimation tasks. + return_obb (bool): If True, returns oriented bounding boxes. + mask_ratio (int): Downsample ratio for masks. + mask_overlap (bool): If True, allows mask overlap. + batch_idx (bool): If True, keeps batch indexes. + bgr (float): Probability of returning BGR images instead of RGB. + + Attributes: + bbox_format (str): Format for bounding boxes. + normalize (bool): Whether bounding boxes are normalized. + return_mask (bool): Whether to return instance masks. + return_keypoint (bool): Whether to return keypoints. + return_obb (bool): Whether to return oriented bounding boxes. + mask_ratio (int): Downsample ratio for masks. + mask_overlap (bool): Whether masks can overlap. + batch_idx (bool): Whether to keep batch indexes. + bgr (float): The probability to return BGR images. 
+ + Examples: + >>> format = Format(bbox_format="xyxy", return_mask=True, return_keypoint=False) + >>> print(format.bbox_format) + xyxy + """ + self.bbox_format = bbox_format + self.normalize = normalize + self.return_mask = return_mask # set False when training detection only + self.return_keypoint = return_keypoint + self.return_obb = return_obb + self.mask_ratio = mask_ratio + self.mask_overlap = mask_overlap + self.batch_idx = batch_idx # keep the batch indexes + self.bgr = bgr + + def __call__(self, labels): + """ + Formats image annotations for object detection, instance segmentation, and pose estimation tasks. + + This method standardizes the image and instance annotations to be used by the `collate_fn` in PyTorch + DataLoader. It processes the input labels dictionary, converting annotations to the specified format and + applying normalization if required. + + Args: + labels (dict): A dictionary containing image and annotation data with the following keys: + - 'img': The input image as a numpy array. + - 'cls': Class labels for instances. + - 'instances': An Instances object containing bounding boxes, segments, and keypoints. + + Returns: + (dict): A dictionary with formatted data, including: + - 'img': Formatted image tensor. + - 'cls': Class label's tensor. + - 'bboxes': Bounding boxes tensor in the specified format. + - 'masks': Instance masks tensor (if return_mask is True). + - 'keypoints': Keypoints tensor (if return_keypoint is True). + - 'batch_idx': Batch index tensor (if batch_idx is True). 
+ + Examples: + >>> formatter = Format(bbox_format="xywh", normalize=True, return_mask=True) + >>> labels = {"img": np.random.rand(640, 640, 3), "cls": np.array([0, 1]), "instances": Instances(...)} + >>> formatted_labels = formatter(labels) + >>> print(formatted_labels.keys()) + """ + img = labels.pop("img") + h, w = img.shape[:2] + cls = labels.pop("cls") + instances = labels.pop("instances") + instances.convert_bbox(format=self.bbox_format) + instances.denormalize(w, h) + nl = len(instances) + + if self.return_mask: + if nl: + masks, instances, cls = self._format_segments(instances, cls, w, h) + masks = torch.from_numpy(masks) + else: + masks = torch.zeros( + 1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio + ) + labels["masks"] = masks + labels["img"] = self._format_img(img) + labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl) + labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4)) + if self.return_keypoint: + labels["keypoints"] = torch.from_numpy(instances.keypoints) + if self.normalize: + labels["keypoints"][..., 0] /= w + labels["keypoints"][..., 1] /= h + if self.return_obb: + labels["bboxes"] = ( + xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(instances.segments) else torch.zeros((0, 5)) + ) + # NOTE: need to normalize obb in xywhr format for width-height consistency + if self.normalize: + labels["bboxes"][:, [0, 2]] /= w + labels["bboxes"][:, [1, 3]] /= h + # Then we can use collate_fn + if self.batch_idx: + labels["batch_idx"] = torch.zeros(nl) + return labels + + def _format_img(self, img): + """ + Formats an image for YOLO from a Numpy array to a PyTorch tensor. + + This function performs the following operations: + 1. Ensures the image has 3 dimensions (adds a channel dimension if needed). + 2. Transposes the image from HWC to CHW format. + 3. Optionally flips the color channels from RGB to BGR. + 4. Converts the image to a contiguous array. 
+ 5. Converts the Numpy array to a PyTorch tensor. + + Args: + img (np.ndarray): Input image as a Numpy array with shape (H, W, C) or (H, W). + + Returns: + (torch.Tensor): Formatted image as a PyTorch tensor with shape (C, H, W). + + Examples: + >>> import numpy as np + >>> img = np.random.rand(100, 100, 3) + >>> formatted_img = self._format_img(img) + >>> print(formatted_img.shape) + torch.Size([3, 100, 100]) + """ + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + img = img.transpose(2, 0, 1) + img = np.ascontiguousarray(img[::-1] if random.uniform(0, 1) > self.bgr else img) + img = torch.from_numpy(img) + return img + + def _format_segments(self, instances, cls, w, h): + """ + Converts polygon segments to bitmap masks. + + Args: + instances (Instances): Object containing segment information. + cls (numpy.ndarray): Class labels for each instance. + w (int): Width of the image. + h (int): Height of the image. + + Returns: + masks (numpy.ndarray): Bitmap masks with shape (N, H, W) or (1, H, W) if mask_overlap is True. + instances (Instances): Updated instances object with sorted segments if mask_overlap is True. + cls (numpy.ndarray): Updated class labels, sorted if mask_overlap is True. + + Notes: + - If self.mask_overlap is True, masks are overlapped and sorted by area. + - If self.mask_overlap is False, each mask is represented separately. + - Masks are downsampled according to self.mask_ratio. + """ + segments = instances.segments + if self.mask_overlap: + masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=self.mask_ratio) + masks = masks[None] # (640, 640) -> (1, 640, 640) + instances = instances[sorted_idx] + cls = cls[sorted_idx] + else: + masks = polygons2masks((h, w), segments, color=1, downsample_ratio=self.mask_ratio) + + return masks, instances, cls + + +class RandomLoadText: + """ + Randomly samples positive and negative texts and updates class indices accordingly. 
+ + This class is responsible for sampling texts from a given set of class texts, including both positive + (present in the image) and negative (not present in the image) samples. It updates the class indices + to reflect the sampled texts and can optionally pad the text list to a fixed length. + + Attributes: + prompt_format (str): Format string for text prompts. + neg_samples (Tuple[int, int]): Range for randomly sampling negative texts. + max_samples (int): Maximum number of different text samples in one image. + padding (bool): Whether to pad texts to max_samples. + padding_value (str): The text used for padding when padding is True. + + Methods: + __call__: Processes the input labels and returns updated classes and texts. + + Examples: + >>> loader = RandomLoadText(prompt_format="Object: {}", neg_samples=(5, 10), max_samples=20) + >>> labels = {"cls": [0, 1, 2], "texts": [["cat"], ["dog"], ["bird"]], "instances": [...]} + >>> updated_labels = loader(labels) + >>> print(updated_labels["texts"]) + ['Object: cat', 'Object: dog', 'Object: bird', 'Object: elephant', 'Object: car'] + """ + + def __init__( + self, + prompt_format: str = "{}", + neg_samples: Tuple[int, int] = (80, 80), + max_samples: int = 80, + padding: bool = False, + padding_value: str = "", + ) -> None: + """ + Initializes the RandomLoadText class for randomly sampling positive and negative texts. + + This class is designed to randomly sample positive texts and negative texts, and update the class + indices accordingly to the number of samples. It can be used for text-based object detection tasks. + + Args: + prompt_format (str): Format string for the prompt. Default is '{}'. The format string should + contain a single pair of curly braces {} where the text will be inserted. + neg_samples (Tuple[int, int]): A range to randomly sample negative texts. The first integer + specifies the minimum number of negative samples, and the second integer specifies the + maximum. Default is (80, 80). 
+ max_samples (int): The maximum number of different text samples in one image. Default is 80. + padding (bool): Whether to pad texts to max_samples. If True, the number of texts will always + be equal to max_samples. Default is False. + padding_value (str): The padding text to use when padding is True. Default is an empty string. + + Attributes: + prompt_format (str): The format string for the prompt. + neg_samples (Tuple[int, int]): The range for sampling negative texts. + max_samples (int): The maximum number of text samples. + padding (bool): Whether padding is enabled. + padding_value (str): The value used for padding. + + Examples: + >>> random_load_text = RandomLoadText(prompt_format="Object: {}", neg_samples=(50, 100), max_samples=120) + >>> random_load_text.prompt_format + 'Object: {}' + >>> random_load_text.neg_samples + (50, 100) + >>> random_load_text.max_samples + 120 + """ + self.prompt_format = prompt_format + self.neg_samples = neg_samples + self.max_samples = max_samples + self.padding = padding + self.padding_value = padding_value + + def __call__(self, labels: dict) -> dict: + """ + Randomly samples positive and negative texts and updates class indices accordingly. + + This method samples positive texts based on the existing class labels in the image, and randomly + selects negative texts from the remaining classes. It then updates the class indices to match the + new sampled text order. + + Args: + labels (dict): A dictionary containing image labels and metadata. Must include 'texts' and 'cls' keys. + + Returns: + (dict): Updated labels dictionary with new 'cls' and 'texts' entries. + + Examples: + >>> loader = RandomLoadText(prompt_format="A photo of {}", neg_samples=(5, 10), max_samples=20) + >>> labels = {"cls": np.array([[0], [1], [2]]), "texts": [["dog"], ["cat"], ["bird"]]} + >>> updated_labels = loader(labels) + """ + assert "texts" in labels, "No texts found in labels." 
+ class_texts = labels["texts"] + num_classes = len(class_texts) + cls = np.asarray(labels.pop("cls"), dtype=int) + pos_labels = np.unique(cls).tolist() + + if len(pos_labels) > self.max_samples: + pos_labels = random.sample(pos_labels, k=self.max_samples) + + neg_samples = min(min(num_classes, self.max_samples) - len(pos_labels), random.randint(*self.neg_samples)) + neg_labels = [i for i in range(num_classes) if i not in pos_labels] + neg_labels = random.sample(neg_labels, k=neg_samples) + + sampled_labels = pos_labels + neg_labels + random.shuffle(sampled_labels) + + label2ids = {label: i for i, label in enumerate(sampled_labels)} + valid_idx = np.zeros(len(labels["instances"]), dtype=bool) + new_cls = [] + for i, label in enumerate(cls.squeeze(-1).tolist()): + if label not in label2ids: + continue + valid_idx[i] = True + new_cls.append([label2ids[label]]) + labels["instances"] = labels["instances"][valid_idx] + labels["cls"] = np.array(new_cls) + + # Randomly select one prompt when there's more than one prompts + texts = [] + for label in sampled_labels: + prompts = class_texts[label] + assert len(prompts) > 0 + prompt = self.prompt_format.format(prompts[random.randrange(len(prompts))]) + texts.append(prompt) + + if self.padding: + valid_labels = len(pos_labels) + len(neg_labels) + num_padding = self.max_samples - valid_labels + if num_padding > 0: + texts += [self.padding_value] * num_padding + + labels["texts"] = texts + return labels + + +def v8_transforms(dataset, imgsz, hyp, stretch=False): + """ + Applies a series of image transformations for training. + + This function creates a composition of image augmentation techniques to prepare images for YOLO training. + It includes operations such as mosaic, copy-paste, random perspective, mixup, and various color adjustments. + + Args: + dataset (Dataset): The dataset object containing image data and annotations. + imgsz (int): The target image size for resizing. 
+ hyp (Namespace): A dictionary of hyperparameters controlling various aspects of the transformations. + stretch (bool): If True, applies stretching to the image. If False, uses LetterBox resizing. + + Returns: + (Compose): A composition of image transformations to be applied to the dataset. + + Examples: + >>> from ultralytics.data.dataset import YOLODataset + >>> from ultralytics.utils import IterableSimpleNamespace + >>> dataset = YOLODataset(img_path="path/to/images", imgsz=640) + >>> hyp = IterableSimpleNamespace(mosaic=1.0, copy_paste=0.5, degrees=10.0, translate=0.2, scale=0.9) + >>> transforms = v8_transforms(dataset, imgsz=640, hyp=hyp) + >>> augmented_data = transforms(dataset[0]) + """ + mosaic = Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic) + affine = RandomPerspective( + degrees=hyp.degrees, + translate=hyp.translate, + scale=hyp.scale, + shear=hyp.shear, + perspective=hyp.perspective, + pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)), + ) + + pre_transform = Compose([mosaic, affine]) + if hyp.copy_paste_mode == "flip": + pre_transform.insert(1, CopyPaste(p=hyp.copy_paste, mode=hyp.copy_paste_mode)) + else: + pre_transform.append( + CopyPaste( + dataset, + pre_transform=Compose([Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic), affine]), + p=hyp.copy_paste, + mode=hyp.copy_paste_mode, + ) + ) + flip_idx = dataset.data.get("flip_idx", []) # for keypoints augmentation + if dataset.use_keypoints: + kpt_shape = dataset.data.get("kpt_shape", None) + if len(flip_idx) == 0 and hyp.fliplr > 0.0: + hyp.fliplr = 0.0 + LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'") + elif flip_idx and (len(flip_idx) != kpt_shape[0]): + raise ValueError(f"data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}") + + return Compose( + [ + pre_transform, + MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup), + Albumentations(p=1.0), + RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, 
vgain=hyp.hsv_v), + RandomFlip(direction="vertical", p=hyp.flipud), + RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx), + ] + ) # transforms + + +# Classification augmentations ----------------------------------------------------------------------------------------- +def classify_transforms( + size=224, + mean=DEFAULT_MEAN, + std=DEFAULT_STD, + interpolation="BILINEAR", + crop_fraction: float = DEFAULT_CROP_FRACTION, +): + """ + Creates a composition of image transforms for classification tasks. + + This function generates a sequence of torchvision transforms suitable for preprocessing images + for classification models during evaluation or inference. The transforms include resizing, + center cropping, conversion to tensor, and normalization. + + Args: + size (int | tuple): The target size for the transformed image. If an int, it defines the shortest edge. If a + tuple, it defines (height, width). + mean (tuple): Mean values for each RGB channel used in normalization. + std (tuple): Standard deviation values for each RGB channel used in normalization. + interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'. + crop_fraction (float): Fraction of the image to be cropped. + + Returns: + (torchvision.transforms.Compose): A composition of torchvision transforms. 
+ + Examples: + >>> transforms = classify_transforms(size=224) + >>> img = Image.open("path/to/image.jpg") + >>> transformed_img = transforms(img) + """ + import torchvision.transforms as T # scope for faster 'import ultralytics' + + if isinstance(size, (tuple, list)): + assert len(size) == 2, f"'size' tuples must be length 2, not length {len(size)}" + scale_size = tuple(math.floor(x / crop_fraction) for x in size) + else: + scale_size = math.floor(size / crop_fraction) + scale_size = (scale_size, scale_size) + + # Aspect ratio is preserved, crops center within image, no borders are added, image is lost + if scale_size[0] == scale_size[1]: + # Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg) + tfl = [T.Resize(scale_size[0], interpolation=getattr(T.InterpolationMode, interpolation))] + else: + # Resize the shortest edge to matching target dim for non-square target + tfl = [T.Resize(scale_size)] + tfl.extend( + [ + T.CenterCrop(size), + T.ToTensor(), + T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), + ] + ) + return T.Compose(tfl) + + +# Classification training augmentations -------------------------------------------------------------------------------- +def classify_augmentations( + size=224, + mean=DEFAULT_MEAN, + std=DEFAULT_STD, + scale=None, + ratio=None, + hflip=0.5, + vflip=0.0, + auto_augment=None, + hsv_h=0.015, # image HSV-Hue augmentation (fraction) + hsv_s=0.4, # image HSV-Saturation augmentation (fraction) + hsv_v=0.4, # image HSV-Value augmentation (fraction) + force_color_jitter=False, + erasing=0.0, + interpolation="BILINEAR", +): + """ + Creates a composition of image augmentation transforms for classification tasks. + + This function generates a set of image transformations suitable for training classification models. It includes + options for resizing, flipping, color jittering, auto augmentation, and random erasing. + + Args: + size (int): Target size for the image after transformations. 
+ mean (tuple): Mean values for normalization, one per channel. + std (tuple): Standard deviation values for normalization, one per channel. + scale (tuple | None): Range of size of the origin size cropped. + ratio (tuple | None): Range of aspect ratio of the origin aspect ratio cropped. + hflip (float): Probability of horizontal flip. + vflip (float): Probability of vertical flip. + auto_augment (str | None): Auto augmentation policy. Can be 'randaugment', 'augmix', 'autoaugment' or None. + hsv_h (float): Image HSV-Hue augmentation factor. + hsv_s (float): Image HSV-Saturation augmentation factor. + hsv_v (float): Image HSV-Value augmentation factor. + force_color_jitter (bool): Whether to apply color jitter even if auto augment is enabled. + erasing (float): Probability of random erasing. + interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'. + + Returns: + (torchvision.transforms.Compose): A composition of image augmentation transforms. + + Examples: + >>> transforms = classify_augmentations(size=224, auto_augment="randaugment") + >>> augmented_image = transforms(original_image) + """ + # Transforms to apply if Albumentations not installed + import torchvision.transforms as T # scope for faster 'import ultralytics' + + if not isinstance(size, int): + raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)") + scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range + ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range + interpolation = getattr(T.InterpolationMode, interpolation) + primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)] + if hflip > 0.0: + primary_tfl.append(T.RandomHorizontalFlip(p=hflip)) + if vflip > 0.0: + primary_tfl.append(T.RandomVerticalFlip(p=vflip)) + + secondary_tfl = [] + disable_color_jitter = False + if auto_augment: + assert isinstance(auto_augment, str), f"Provided argument should be 
string, but got type {type(auto_augment)}" + # color jitter is typically disabled if AA/RA on, + # this allows override without breaking old hparm cfgs + disable_color_jitter = not force_color_jitter + + if auto_augment == "randaugment": + if TORCHVISION_0_11: + secondary_tfl.append(T.RandAugment(interpolation=interpolation)) + else: + LOGGER.warning('"auto_augment=randaugment" requires torchvision >= 0.11.0. Disabling it.') + + elif auto_augment == "augmix": + if TORCHVISION_0_13: + secondary_tfl.append(T.AugMix(interpolation=interpolation)) + else: + LOGGER.warning('"auto_augment=augmix" requires torchvision >= 0.13.0. Disabling it.') + + elif auto_augment == "autoaugment": + if TORCHVISION_0_10: + secondary_tfl.append(T.AutoAugment(interpolation=interpolation)) + else: + LOGGER.warning('"auto_augment=autoaugment" requires torchvision >= 0.10.0. Disabling it.') + + else: + raise ValueError( + f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", ' + f'"augmix", "autoaugment" or None' + ) + + if not disable_color_jitter: + secondary_tfl.append(T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)) + + final_tfl = [ + T.ToTensor(), + T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), + T.RandomErasing(p=erasing, inplace=True), + ] + + return T.Compose(primary_tfl + secondary_tfl + final_tfl) + + +# NOTE: keep this class for backward compatibility +class ClassifyLetterBox: + """ + A class for resizing and padding images for classification tasks. + + This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]). + It resizes and pads images to a specified size while maintaining the original aspect ratio. + + Attributes: + h (int): Target height of the image. + w (int): Target width of the image. + auto (bool): If True, automatically calculates the short side using stride. + stride (int): The stride value, used when 'auto' is True. 
+ + Methods: + __call__: Applies the letterbox transformation to an input image. + + Examples: + >>> transform = ClassifyLetterBox(size=(640, 640), auto=False, stride=32) + >>> img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) + >>> result = transform(img) + >>> print(result.shape) + (640, 640, 3) + """ + + def __init__(self, size=(640, 640), auto=False, stride=32): + """ + Initializes the ClassifyLetterBox object for image preprocessing. + + This class is designed to be part of a transformation pipeline for image classification tasks. It resizes and + pads images to a specified size while maintaining the original aspect ratio. + + Args: + size (int | Tuple[int, int]): Target size for the letterboxed image. If an int, a square image of + (size, size) is created. If a tuple, it should be (height, width). + auto (bool): If True, automatically calculates the short side based on stride. Default is False. + stride (int): The stride value, used when 'auto' is True. Default is 32. + + Attributes: + h (int): Target height of the letterboxed image. + w (int): Target width of the letterboxed image. + auto (bool): Flag indicating whether to automatically calculate short side. + stride (int): Stride value for automatic short side calculation. + + Examples: + >>> transform = ClassifyLetterBox(size=224) + >>> img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) + >>> result = transform(img) + >>> print(result.shape) + (224, 224, 3) + """ + super().__init__() + self.h, self.w = (size, size) if isinstance(size, int) else size + self.auto = auto # pass max size integer, automatically solve for short side using stride + self.stride = stride # used with auto + + def __call__(self, im): + """ + Resizes and pads an image using the letterbox method. + + This method resizes the input image to fit within the specified dimensions while maintaining its aspect ratio, + then pads the resized image to match the target size. 
+ + Args: + im (numpy.ndarray): Input image as a numpy array with shape (H, W, C). + + Returns: + (numpy.ndarray): Resized and padded image as a numpy array with shape (hs, ws, 3), where hs and ws are + the target height and width respectively. + + Examples: + >>> letterbox = ClassifyLetterBox(size=(640, 640)) + >>> image = np.random.randint(0, 255, (720, 1280, 3), dtype=np.uint8) + >>> resized_image = letterbox(image) + >>> print(resized_image.shape) + (640, 640, 3) + """ + imh, imw = im.shape[:2] + r = min(self.h / imh, self.w / imw) # ratio of new/old dimensions + h, w = round(imh * r), round(imw * r) # resized image dimensions + + # Calculate padding dimensions + hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else (self.h, self.w) + top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1) + + # Create padded image + im_out = np.full((hs, ws, 3), 114, dtype=im.dtype) + im_out[top : top + h, left : left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR) + return im_out + + +# NOTE: keep this class for backward compatibility +class CenterCrop: + """ + Applies center cropping to images for classification tasks. + + This class performs center cropping on input images, resizing them to a specified size while maintaining the aspect + ratio. It is designed to be part of a transformation pipeline, e.g., T.Compose([CenterCrop(size), ToTensor()]). + + Attributes: + h (int): Target height of the cropped image. + w (int): Target width of the cropped image. + + Methods: + __call__: Applies the center crop transformation to an input image. + + Examples: + >>> transform = CenterCrop(640) + >>> image = np.random.randint(0, 255, (1080, 1920, 3), dtype=np.uint8) + >>> cropped_image = transform(image) + >>> print(cropped_image.shape) + (640, 640, 3) + """ + + def __init__(self, size=640): + """ + Initializes the CenterCrop object for image preprocessing. 
+ + This class is designed to be part of a transformation pipeline, e.g., T.Compose([CenterCrop(size), ToTensor()]). + It performs a center crop on input images to a specified size. + + Args: + size (int | Tuple[int, int]): The desired output size of the crop. If size is an int, a square crop + (size, size) is made. If size is a sequence like (h, w), it is used as the output size. + + Returns: + (None): This method initializes the object and does not return anything. + + Examples: + >>> transform = CenterCrop(224) + >>> img = np.random.rand(300, 300, 3) + >>> cropped_img = transform(img) + >>> print(cropped_img.shape) + (224, 224, 3) + """ + super().__init__() + self.h, self.w = (size, size) if isinstance(size, int) else size + + def __call__(self, im): + """ + Applies center cropping to an input image. + + This method resizes and crops the center of the image using a letterbox method. It maintains the aspect + ratio of the original image while fitting it into the specified dimensions. + + Args: + im (numpy.ndarray | PIL.Image.Image): The input image as a numpy array of shape (H, W, C) or a + PIL Image object. + + Returns: + (numpy.ndarray): The center-cropped and resized image as a numpy array of shape (self.h, self.w, C). + + Examples: + >>> transform = CenterCrop(size=224) + >>> image = np.random.randint(0, 255, (640, 480, 3), dtype=np.uint8) + >>> cropped_image = transform(image) + >>> assert cropped_image.shape == (224, 224, 3) + """ + if isinstance(im, Image.Image): # convert from PIL to numpy array if required + im = np.asarray(im) + imh, imw = im.shape[:2] + m = min(imh, imw) # min dimension + top, left = (imh - m) // 2, (imw - m) // 2 + return cv2.resize(im[top : top + m, left : left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR) + + +# NOTE: keep this class for backward compatibility +class ToTensor: + """ + Converts an image from a numpy array to a PyTorch tensor. 
+ + This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]). + + Attributes: + half (bool): If True, converts the image to half precision (float16). + + Methods: + __call__: Applies the tensor conversion to an input image. + + Examples: + >>> transform = ToTensor(half=True) + >>> img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8) + >>> tensor_img = transform(img) + >>> print(tensor_img.shape, tensor_img.dtype) + torch.Size([3, 640, 640]) torch.float16 + + Notes: + The input image is expected to be in BGR format with shape (H, W, C). + The output tensor will be in RGB format with shape (C, H, W), normalized to [0, 1]. + """ + + def __init__(self, half=False): + """ + Initializes the ToTensor object for converting images to PyTorch tensors. + + This class is designed to be used as part of a transformation pipeline for image preprocessing in the + Ultralytics YOLO framework. It converts numpy arrays or PIL Images to PyTorch tensors, with an option + for half-precision (float16) conversion. + + Args: + half (bool): If True, converts the tensor to half precision (float16). Default is False. + + Examples: + >>> transform = ToTensor(half=True) + >>> img = np.random.rand(640, 640, 3) + >>> tensor_img = transform(img) + >>> print(tensor_img.dtype) + torch.float16 + """ + super().__init__() + self.half = half + + def __call__(self, im): + """ + Transforms an image from a numpy array to a PyTorch tensor. + + This method converts the input image from a numpy array to a PyTorch tensor, applying optional + half-precision conversion and normalization. The image is transposed from HWC to CHW format and + the color channels are reversed from BGR to RGB. + + Args: + im (numpy.ndarray): Input image as a numpy array with shape (H, W, C) in BGR order. + + Returns: + (torch.Tensor): The transformed image as a PyTorch tensor in float32 or float16, normalized + to [0, 1] with shape (C, H, W) in RGB order. 
+ + Examples: + >>> transform = ToTensor(half=True) + >>> img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8) + >>> tensor_img = transform(img) + >>> print(tensor_img.shape, tensor_img.dtype) + torch.Size([3, 640, 640]) torch.float16 + """ + im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous + im = torch.from_numpy(im) # to torch + im = im.half() if self.half else im.float() # uint8 to fp16/32 + im /= 255.0 # 0-255 to 0.0-1.0 + return im diff --git a/tracking/ultralytics/data/base.py b/tracking/ultralytics/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..b059e1f194cde31ca44ead83ff9c37417f292afe --- /dev/null +++ b/tracking/ultralytics/data/base.py @@ -0,0 +1,432 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import glob +import math +import os +import random +from copy import deepcopy +from multiprocessing.pool import ThreadPool +from pathlib import Path +from typing import Optional + +import cv2 +import numpy as np +import psutil +from torch.utils.data import Dataset + +from ultralytics.data.utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS +from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM + + +class BaseDataset(Dataset): + """ + Base dataset class for loading and processing image data. + + This class provides core functionality for loading images, caching, and preparing data for training and inference + in object detection tasks. + + Attributes: + img_path (str): Path to the folder containing images. + imgsz (int): Target image size for resizing. + augment (bool): Whether to apply data augmentation. + single_cls (bool): Whether to treat all objects as a single class. + prefix (str): Prefix to print in log messages. + fraction (float): Fraction of dataset to utilize. + im_files (List[str]): List of image file paths. + labels (List[Dict]): List of label data dictionaries. 
+ ni (int): Number of images in the dataset. + rect (bool): Whether to use rectangular training. + batch_size (int): Size of batches. + stride (int): Stride used in the model. + pad (float): Padding value. + buffer (list): Buffer for mosaic images. + max_buffer_length (int): Maximum buffer size. + ims (list): List of loaded images. + im_hw0 (list): List of original image dimensions (h, w). + im_hw (list): List of resized image dimensions (h, w). + npy_files (List[Path]): List of numpy file paths. + cache (str): Cache images to RAM or disk during training. + transforms (callable): Image transformation function. + + Methods: + get_img_files: Read image files from the specified path. + update_labels: Update labels to include only specified classes. + load_image: Load an image from the dataset. + cache_images: Cache images to memory or disk. + cache_images_to_disk: Save an image as an *.npy file for faster loading. + check_cache_disk: Check image caching requirements vs available disk space. + check_cache_ram: Check image caching requirements vs available memory. + set_rectangle: Set the shape of bounding boxes as rectangles. + get_image_and_label: Get and return label information from the dataset. + update_labels_info: Custom label format method to be implemented by subclasses. + build_transforms: Build transformation pipeline to be implemented by subclasses. + get_labels: Get labels method to be implemented by subclasses. + """ + + def __init__( + self, + img_path, + imgsz=640, + cache=False, + augment=True, + hyp=DEFAULT_CFG, + prefix="", + rect=False, + batch_size=16, + stride=32, + pad=0.5, + single_cls=False, + classes=None, + fraction=1.0, + ): + """ + Initialize BaseDataset with given configuration and options. + + Args: + img_path (str): Path to the folder containing images. + imgsz (int, optional): Image size for resizing. + cache (bool | str, optional): Cache images to RAM or disk during training. 
+ augment (bool, optional): If True, data augmentation is applied. + hyp (dict, optional): Hyperparameters to apply data augmentation. + prefix (str, optional): Prefix to print in log messages. + rect (bool, optional): If True, rectangular training is used. + batch_size (int, optional): Size of batches. + stride (int, optional): Stride used in the model. + pad (float, optional): Padding value. + single_cls (bool, optional): If True, single class training is used. + classes (list, optional): List of included classes. + fraction (float, optional): Fraction of dataset to utilize. + """ + super().__init__() + self.img_path = img_path + self.imgsz = imgsz + self.augment = augment + self.single_cls = single_cls + self.prefix = prefix + self.fraction = fraction + self.im_files = self.get_img_files(self.img_path) + self.labels = self.get_labels() + self.update_labels(include_class=classes) # single_cls and include_class + self.ni = len(self.labels) # number of images + self.rect = rect + self.batch_size = batch_size + self.stride = stride + self.pad = pad + if self.rect: + assert self.batch_size is not None + self.set_rectangle() + + # Buffer thread for mosaic images + self.buffer = [] # buffer size = batch size + self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0 + + # Cache images (options are cache = True, False, None, "ram", "disk") + self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni + self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files] + self.cache = cache.lower() if isinstance(cache, str) else "ram" if cache is True else None + if self.cache == "ram" and self.check_cache_ram(): + if hyp.deterministic: + LOGGER.warning( + "WARNING ⚠️ cache='ram' may produce non-deterministic training results. " + "Consider cache='disk' as a deterministic alternative if your disk space allows." 
+ ) + self.cache_images() + elif self.cache == "disk" and self.check_cache_disk(): + self.cache_images() + + # Transforms + self.transforms = self.build_transforms(hyp=hyp) + + def get_img_files(self, img_path): + """ + Read image files from the specified path. + + Args: + img_path (str | List[str]): Path or list of paths to image directories or files. + + Returns: + (List[str]): List of image file paths. + + Raises: + FileNotFoundError: If no images are found or the path doesn't exist. + """ + try: + f = [] # image files + for p in img_path if isinstance(img_path, list) else [img_path]: + p = Path(p) # os-agnostic + if p.is_dir(): # dir + f += glob.glob(str(p / "**" / "*.*"), recursive=True) + # F = list(p.rglob('*.*')) # pathlib + elif p.is_file(): # file + with open(p, encoding="utf-8") as t: + t = t.read().strip().splitlines() + parent = str(p.parent) + os.sep + f += [x.replace("./", parent) if x.startswith("./") else x for x in t] # local to global path + # F += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib) + else: + raise FileNotFoundError(f"{self.prefix}{p} does not exist") + im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS) + # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib + assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}" + except Exception as e: + raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e + if self.fraction < 1: + im_files = im_files[: round(len(im_files) * self.fraction)] # retain a fraction of the dataset + return im_files + + def update_labels(self, include_class: Optional[list]): + """ + Update labels to include only specified classes. + + Args: + include_class (list, optional): List of classes to include. If None, all classes are included. 
+ """ + include_class_array = np.array(include_class).reshape(1, -1) + for i in range(len(self.labels)): + if include_class is not None: + cls = self.labels[i]["cls"] + bboxes = self.labels[i]["bboxes"] + segments = self.labels[i]["segments"] + keypoints = self.labels[i]["keypoints"] + j = (cls == include_class_array).any(1) + self.labels[i]["cls"] = cls[j] + self.labels[i]["bboxes"] = bboxes[j] + if segments: + self.labels[i]["segments"] = [segments[si] for si, idx in enumerate(j) if idx] + if keypoints is not None: + self.labels[i]["keypoints"] = keypoints[j] + if self.single_cls: + self.labels[i]["cls"][:, 0] = 0 + + def load_image(self, i, rect_mode=True): + """ + Load an image from dataset index 'i'. + + Args: + i (int): Index of the image to load. + rect_mode (bool, optional): Whether to use rectangular resizing. + + Returns: + (np.ndarray): Loaded image. + (tuple): Original image dimensions (h, w). + (tuple): Resized image dimensions (h, w). + + Raises: + FileNotFoundError: If the image file is not found. 
+ """ + im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i] + if im is None: # not cached in RAM + if fn.exists(): # load npy + try: + im = np.load(fn) + except Exception as e: + LOGGER.warning(f"{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}") + Path(fn).unlink(missing_ok=True) + im = cv2.imread(f) # BGR + else: # read image + im = cv2.imread(f) # BGR + if im is None: + raise FileNotFoundError(f"Image Not Found {f}") + + h0, w0 = im.shape[:2] # orig hw + if rect_mode: # resize long side to imgsz while maintaining aspect ratio + r = self.imgsz / max(h0, w0) # ratio + if r != 1: # if sizes are not equal + w, h = (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz)) + im = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR) + elif not (h0 == w0 == self.imgsz): # resize by stretching image to square imgsz + im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR) + + # Add to buffer if training with augmentations + if self.augment: + self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized + self.buffer.append(i) + if 1 < len(self.buffer) >= self.max_buffer_length: # prevent empty buffer + j = self.buffer.pop(0) + if self.cache != "ram": + self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None + + return im, (h0, w0), im.shape[:2] + + return self.ims[i], self.im_hw0[i], self.im_hw[i] + + def cache_images(self): + """Cache images to memory or disk for faster training.""" + b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes + fcn, storage = (self.cache_images_to_disk, "Disk") if self.cache == "disk" else (self.load_image, "RAM") + with ThreadPool(NUM_THREADS) as pool: + results = pool.imap(fcn, range(self.ni)) + pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0) + for i, x in pbar: + if self.cache == "disk": + b += self.npy_files[i].stat().st_size + else: # 'ram' + self.ims[i], self.im_hw0[i], self.im_hw[i] 
= x # im, hw_orig, hw_resized = load_image(self, i) + b += self.ims[i].nbytes + pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {storage})" + pbar.close() + + def cache_images_to_disk(self, i): + """Save an image as an *.npy file for faster loading.""" + f = self.npy_files[i] + if not f.exists(): + np.save(f.as_posix(), cv2.imread(self.im_files[i]), allow_pickle=False) + + def check_cache_disk(self, safety_margin=0.5): + """ + Check if there's enough disk space for caching images. + + Args: + safety_margin (float, optional): Safety margin factor for disk space calculation. + + Returns: + (bool): True if there's enough disk space, False otherwise. + """ + import shutil + + b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes + n = min(self.ni, 30) # extrapolate from 30 random images + for _ in range(n): + im_file = random.choice(self.im_files) + im = cv2.imread(im_file) + if im is None: + continue + b += im.nbytes + if not os.access(Path(im_file).parent, os.W_OK): + self.cache = None + LOGGER.info(f"{self.prefix}Skipping caching images to disk, directory not writeable ⚠️") + return False + disk_required = b * self.ni / n * (1 + safety_margin) # bytes required to cache dataset to disk + total, used, free = shutil.disk_usage(Path(self.im_files[0]).parent) + if disk_required > free: + self.cache = None + LOGGER.info( + f"{self.prefix}{disk_required / gb:.1f}GB disk space required, " + f"with {int(safety_margin * 100)}% safety margin but only " + f"{free / gb:.1f}/{total / gb:.1f}GB free, not caching images to disk ⚠️" + ) + return False + return True + + def check_cache_ram(self, safety_margin=0.5): + """ + Check if there's enough RAM for caching images. + + Args: + safety_margin (float, optional): Safety margin factor for RAM calculation. + + Returns: + (bool): True if there's enough RAM, False otherwise. 
+ """ + b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes + n = min(self.ni, 30) # extrapolate from 30 random images + for _ in range(n): + im = cv2.imread(random.choice(self.im_files)) # sample image + if im is None: + continue + ratio = self.imgsz / max(im.shape[0], im.shape[1]) # max(h, w) # ratio + b += im.nbytes * ratio**2 + mem_required = b * self.ni / n * (1 + safety_margin) # GB required to cache dataset into RAM + mem = psutil.virtual_memory() + if mem_required > mem.available: + self.cache = None + LOGGER.info( + f"{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images " + f"with {int(safety_margin * 100)}% safety margin but only " + f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, not caching images ⚠️" + ) + return False + return True + + def set_rectangle(self): + """Set the shape of bounding boxes for YOLO detections as rectangles.""" + bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index + nb = bi[-1] + 1 # number of batches + + s = np.array([x.pop("shape") for x in self.labels]) # hw + ar = s[:, 0] / s[:, 1] # aspect ratio + irect = ar.argsort() + self.im_files = [self.im_files[i] for i in irect] + self.labels = [self.labels[i] for i in irect] + ar = ar[irect] + + # Set training image shapes + shapes = [[1, 1]] * nb + for i in range(nb): + ari = ar[bi == i] + mini, maxi = ari.min(), ari.max() + if maxi < 1: + shapes[i] = [maxi, 1] + elif mini > 1: + shapes[i] = [1, 1 / mini] + + self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride + self.batch = bi # batch index of image + + def __getitem__(self, index): + """Return transformed label information for given index.""" + return self.transforms(self.get_image_and_label(index)) + + def get_image_and_label(self, index): + """ + Get and return label information from the dataset. + + Args: + index (int): Index of the image to retrieve. 
+ + Returns: + (dict): Label dictionary with image and metadata. + """ + label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948 + label.pop("shape", None) # shape is for rect, remove it + label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index) + label["ratio_pad"] = ( + label["resized_shape"][0] / label["ori_shape"][0], + label["resized_shape"][1] / label["ori_shape"][1], + ) # for evaluation + if self.rect: + label["rect_shape"] = self.batch_shapes[self.batch[index]] + return self.update_labels_info(label) + + def __len__(self): + """Return the length of the labels list for the dataset.""" + return len(self.labels) + + def update_labels_info(self, label): + """Custom your label format here.""" + return label + + def build_transforms(self, hyp=None): + """ + Users can customize augmentations here. + + Examples: + >>> if self.augment: + ... # Training transforms + ... return Compose([]) + >>> else: + ... # Val transforms + ... return Compose([]) + """ + raise NotImplementedError + + def get_labels(self): + """ + Users can customize their own format here. 
+ + Note: + Ensure output is a dictionary with the following keys: + ```python + dict( + im_file=im_file, + shape=shape, # format: (height, width) + cls=cls, + bboxes=bboxes, # xywh + segments=segments, # xy + keypoints=keypoints, # xy + normalized=True, # or False + bbox_format="xyxy", # or xywh, ltwh + ) + ``` + """ + raise NotImplementedError diff --git a/tracking/ultralytics/data/build.py b/tracking/ultralytics/data/build.py new file mode 100644 index 0000000000000000000000000000000000000000..3051dbb6ad934da8e98413f2309f146b70049333 --- /dev/null +++ b/tracking/ultralytics/data/build.py @@ -0,0 +1,258 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import os +import random +from pathlib import Path + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import dataloader, distributed + +from ultralytics.data.dataset import GroundingDataset, YOLODataset, YOLOMultiModalDataset +from ultralytics.data.loaders import ( + LOADERS, + LoadImagesAndVideos, + LoadPilAndNumpy, + LoadScreenshots, + LoadStreams, + LoadTensor, + SourceTypes, + autocast_list, +) +from ultralytics.data.utils import IMG_FORMATS, PIN_MEMORY, VID_FORMATS +from ultralytics.utils import RANK, colorstr +from ultralytics.utils.checks import check_file + + +class InfiniteDataLoader(dataloader.DataLoader): + """ + Dataloader that reuses workers. + + This dataloader extends the PyTorch DataLoader to provide infinite recycling of workers, which improves efficiency + for training loops that need to iterate through the dataset multiple times. + + Attributes: + batch_sampler (_RepeatSampler): A sampler that repeats indefinitely. + iterator (Iterator): The iterator from the parent DataLoader. + + Methods: + __len__: Returns the length of the batch sampler's sampler. + __iter__: Creates a sampler that repeats indefinitely. + __del__: Ensures workers are properly terminated. + reset: Resets the iterator, useful when modifying dataset settings during training. 
+ """ + + def __init__(self, *args, **kwargs): + """Initialize the InfiniteDataLoader with the same arguments as DataLoader.""" + super().__init__(*args, **kwargs) + object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler)) + self.iterator = super().__iter__() + + def __len__(self): + """Return the length of the batch sampler's sampler.""" + return len(self.batch_sampler.sampler) + + def __iter__(self): + """Create an iterator that yields indefinitely from the underlying iterator.""" + for _ in range(len(self)): + yield next(self.iterator) + + def __del__(self): + """Ensure that workers are properly terminated when the dataloader is deleted.""" + try: + if not hasattr(self.iterator, "_workers"): + return + for w in self.iterator._workers: # force terminate + if w.is_alive(): + w.terminate() + self.iterator._shutdown_workers() # cleanup + except Exception: + pass + + def reset(self): + """Reset the iterator to allow modifications to the dataset during training.""" + self.iterator = self._get_iterator() + + +class _RepeatSampler: + """ + Sampler that repeats forever. + + This sampler wraps another sampler and yields its contents indefinitely, allowing for infinite iteration + over a dataset. + + Attributes: + sampler (Dataset.sampler): The sampler to repeat. 
+ """ + + def __init__(self, sampler): + """Initialize the _RepeatSampler with a sampler to repeat indefinitely.""" + self.sampler = sampler + + def __iter__(self): + """Iterate over the sampler indefinitely, yielding its contents.""" + while True: + yield from iter(self.sampler) + + +def seed_worker(worker_id): # noqa + """Set dataloader worker seed for reproducibility across worker processes.""" + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) + + +def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False): + """Build and return a YOLO dataset based on configuration parameters.""" + dataset = YOLOMultiModalDataset if multi_modal else YOLODataset + return dataset( + img_path=img_path, + imgsz=cfg.imgsz, + batch_size=batch, + augment=mode == "train", # augmentation + hyp=cfg, # TODO: probably add a get_hyps_from_cfg function + rect=cfg.rect or rect, # rectangular batches + cache=cfg.cache or None, + single_cls=cfg.single_cls or False, + stride=int(stride), + pad=0.0 if mode == "train" else 0.5, + prefix=colorstr(f"{mode}: "), + task=cfg.task, + classes=cfg.classes, + data=data, + fraction=cfg.fraction if mode == "train" else 1.0, + ) + + +def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32): + """Build and return a GroundingDataset based on configuration parameters.""" + return GroundingDataset( + img_path=img_path, + json_file=json_file, + imgsz=cfg.imgsz, + batch_size=batch, + augment=mode == "train", # augmentation + hyp=cfg, # TODO: probably add a get_hyps_from_cfg function + rect=cfg.rect or rect, # rectangular batches + cache=cfg.cache or None, + single_cls=cfg.single_cls or False, + stride=int(stride), + pad=0.0 if mode == "train" else 0.5, + prefix=colorstr(f"{mode}: "), + task=cfg.task, + classes=cfg.classes, + fraction=cfg.fraction if mode == "train" else 1.0, + ) + + +def build_dataloader(dataset, batch, workers, 
shuffle=True, rank=-1): + """ + Create and return an InfiniteDataLoader or DataLoader for training or validation. + + Args: + dataset (Dataset): Dataset to load data from. + batch (int): Batch size for the dataloader. + workers (int): Number of worker threads for loading data. + shuffle (bool): Whether to shuffle the dataset. + rank (int): Process rank in distributed training. -1 for single-GPU training. + + Returns: + (InfiniteDataLoader): A dataloader that can be used for training or validation. + """ + batch = min(batch, len(dataset)) + nd = torch.cuda.device_count() # number of CUDA devices + nw = min(os.cpu_count() // max(nd, 1), workers) # number of workers + sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) + generator = torch.Generator() + generator.manual_seed(6148914691236517205 + RANK) + return InfiniteDataLoader( + dataset=dataset, + batch_size=batch, + shuffle=shuffle and sampler is None, + num_workers=nw, + sampler=sampler, + pin_memory=PIN_MEMORY, + collate_fn=getattr(dataset, "collate_fn", None), + worker_init_fn=seed_worker, + generator=generator, + ) + + +def check_source(source): + """ + Check the type of input source and return corresponding flag values. + + Args: + source (str | int | Path | List | Tuple | np.ndarray | PIL.Image | torch.Tensor): The input source to check. + + Returns: + (tuple): A tuple containing: + - source: The processed source. + - webcam (bool): Whether the source is a webcam. + - screenshot (bool): Whether the source is a screenshot. + - from_img (bool): Whether the source is an image or list of images. + - in_memory (bool): Whether the source is an in-memory object. + - tensor (bool): Whether the source is a torch.Tensor. + + Raises: + TypeError: If the source type is unsupported. 
def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False):
    """
    Build a dataset object for inference from an arbitrary input source.

    The raw source is first classified by `check_source`, then wrapped in the loader
    matching its kind (tensor, in-memory loader, stream, screenshot, PIL/numpy image,
    or image/video path).

    Args:
        source (str | Path | torch.Tensor | PIL.Image | np.ndarray, optional): The input source for inference.
        batch (int, optional): Batch size for dataloaders.
        vid_stride (int, optional): The frame interval for video sources.
        buffer (bool, optional): Whether stream frames will be buffered.

    Returns:
        (Dataset): A dataset object for the specified input source with attached source_type attribute.
    """
    checked, stream, screenshot, from_img, in_memory, tensor = check_source(source)

    # In-memory loaders already carry their own source_type; otherwise derive it from the flags.
    if in_memory:
        source_type = checked.source_type
    else:
        source_type = SourceTypes(stream, screenshot, from_img, tensor)

    # Exactly one flag is set by check_source, so this chain picks a single loader.
    if tensor:
        dataset = LoadTensor(checked)
    elif in_memory:
        dataset = checked
    elif stream:
        dataset = LoadStreams(checked, vid_stride=vid_stride, buffer=buffer)
    elif screenshot:
        dataset = LoadScreenshots(checked)
    elif from_img:
        dataset = LoadPilAndNumpy(checked)
    else:
        dataset = LoadImagesAndVideos(checked, batch=batch, vid_stride=vid_stride)

    dataset.source_type = source_type  # downstream predictor logic reads these flags
    return dataset
# The 11 COCO "paper" (91-index, 1-based) category IDs that have no counterpart
# in the 80-class (val2014) set.
_COCO91_ABSENT = {12, 26, 29, 30, 45, 66, 68, 69, 71, 83, 91}


def coco91_to_coco80_class():
    """
    Converts 91-index COCO class IDs to 80-index COCO class IDs.

    Returns:
        (list): A 91-element list where position ``i`` holds the 80-index class ID for the
            91-index category ID ``i + 1``, or None if that category has no 80-index counterpart.
    """
    mapping = []
    next_id = 0  # next 80-index class ID to assign
    for coco91_id in range(1, 92):
        if coco91_id in _COCO91_ABSENT:
            mapping.append(None)
        else:
            mapping.append(next_id)
            next_id += 1
    return mapping


def coco80_to_coco91_class():
    r"""
    Converts 80-index (val2014) to 91-index (paper).
    For details see https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/.

    Returns:
        (list): An 80-element list where position ``i`` holds the 91-index (1-based) category ID
            for the 80-index class ``i``.
    """
    return [i for i in range(1, 92) if i not in _COCO91_ABSENT]
def convert_coco(
    labels_dir="../coco/annotations/",
    save_dir="coco_converted/",
    use_segments=False,
    use_keypoints=False,
    cls91to80=True,
    lvis=False,
):
    """
    Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.

    Args:
        labels_dir (str, optional): Path to directory containing COCO dataset annotation files.
        save_dir (str, optional): Path to directory to save results to.
        use_segments (bool, optional): Whether to include segmentation masks in the output.
        use_keypoints (bool, optional): Whether to include keypoint annotations in the output.
        cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.
        lvis (bool, optional): Whether to convert data in lvis dataset way.

    Examples:
        >>> from ultralytics.data.converter import convert_coco

        Convert COCO annotations to YOLO format
        >>> convert_coco("../datasets/coco/annotations/", use_segments=True, use_keypoints=False, cls91to80=False)

        Convert LVIS annotations to YOLO format
        >>> convert_coco(
        ...     "../datasets/lvis/annotations/",
        ...     use_segments=True,
        ...     use_keypoints=False,
        ...     cls91to80=False,
        ...     lvis=True
        ... )

    Output:
        Generates output files in the specified output directory.
    """
    # Create dataset directory
    save_dir = increment_path(save_dir)  # increment if save directory already exists
    for p in save_dir / "labels", save_dir / "images":
        p.mkdir(parents=True, exist_ok=True)  # make dir

    # Convert classes
    coco80 = coco91_to_coco80_class()

    # Import json
    for json_file in sorted(Path(labels_dir).resolve().glob("*.json")):
        lname = "" if lvis else json_file.stem.replace("instances_", "")
        fn = Path(save_dir) / "labels" / lname  # folder name
        fn.mkdir(parents=True, exist_ok=True)
        if lvis:
            # NOTE: create folders for both train and val in advance,
            # since LVIS val set contains images from COCO 2017 train in addition to the COCO 2017 val split.
            (fn / "train2017").mkdir(parents=True, exist_ok=True)
            (fn / "val2017").mkdir(parents=True, exist_ok=True)
        with open(json_file, encoding="utf-8") as f:
            data = json.load(f)

        # Create image dict keyed by the string form of the integer image id
        images = {f"{x['id']:d}": x for x in data["images"]}
        # Create image-annotations dict
        annotations = defaultdict(list)
        for ann in data["annotations"]:
            annotations[ann["image_id"]].append(ann)

        image_txt = []
        # Write labels file
        for img_id, anns in TQDM(annotations.items(), desc=f"Annotations {json_file}"):
            img = images[f"{img_id:d}"]
            h, w = img["height"], img["width"]
            # LVIS stores the image path only inside coco_url; plain COCO has file_name
            f = str(Path(img["coco_url"]).relative_to("http://images.cocodataset.org")) if lvis else img["file_name"]
            if lvis:
                image_txt.append(str(Path("./images") / f))

            bboxes = []
            segments = []
            keypoints = []
            for ann in anns:
                if ann.get("iscrowd", False):
                    continue
                # The COCO box format is [top left x, top left y, width, height]
                box = np.array(ann["bbox"], dtype=np.float64)
                box[:2] += box[2:] / 2  # xy top-left corner to center
                box[[0, 2]] /= w  # normalize x
                box[[1, 3]] /= h  # normalize y
                if box[2] <= 0 or box[3] <= 0:  # if w <= 0 and h <= 0
                    continue

                cls = coco80[ann["category_id"] - 1] if cls91to80 else ann["category_id"] - 1  # class
                box = [cls] + box.tolist()
                # NOTE(review): duplicate boxes are skipped here but their segments/keypoints are still
                # appended below, so bboxes[i] / segments[i] can fall out of alignment when a JSON file
                # contains duplicate annotations — TODO confirm upstream data never duplicates.
                if box not in bboxes:
                    bboxes.append(box)
                if use_segments and ann.get("segmentation") is not None:
                    if len(ann["segmentation"]) == 0:
                        segments.append([])
                        continue
                    elif len(ann["segmentation"]) > 1:
                        # Multi-part polygons are stitched into one closed polygon
                        s = merge_multi_segment(ann["segmentation"])
                        s = (np.concatenate(s, axis=0) / np.array([w, h])).reshape(-1).tolist()
                    else:
                        s = [j for i in ann["segmentation"] for j in i]  # all segments concatenated
                        s = (np.array(s).reshape(-1, 2) / np.array([w, h])).reshape(-1).tolist()
                    s = [cls] + s
                    segments.append(s)
                if use_keypoints and ann.get("keypoints") is not None:
                    # Keypoints are normalized by (w, h) for x, y; visibility flag left untouched
                    keypoints.append(
                        box + (np.array(ann["keypoints"]).reshape(-1, 3) / np.array([w, h, 1])).reshape(-1).tolist()
                    )

            # Write one label file per image; append mode supports repeated image ids
            with open((fn / f).with_suffix(".txt"), "a", encoding="utf-8") as file:
                for i in range(len(bboxes)):
                    if use_keypoints:
                        line = (*(keypoints[i]),)  # cls, box, keypoints
                    else:
                        line = (
                            *(segments[i] if use_segments and len(segments[i]) > 0 else bboxes[i]),
                        )  # cls, box or segments
                    file.write(("%g " * len(line)).rstrip() % line + "\n")

        if lvis:
            # LVIS also needs a split-level image list file for training
            filename = Path(save_dir) / json_file.name.replace("lvis_v1_", "").replace(".json", ".txt")
            with open(filename, "a", encoding="utf-8") as f:
                f.writelines(f"{line}\n" for line in image_txt)

    LOGGER.info(f"{'LVIS' if lvis else 'COCO'} data converted successfully.\nResults saved to {save_dir.resolve()}")
def convert_segment_masks_to_yolo_seg(masks_dir, output_dir, classes):
    """
    Converts a dataset of segmentation mask images to the YOLO segmentation format.

    This function takes the directory containing the binary format mask images and converts them into YOLO
    segmentation format. The converted masks are saved in the specified output directory.

    Args:
        masks_dir (str): The path to the directory where all mask images (png, jpg) are stored.
        output_dir (str): The path to the directory where the converted YOLO segmentation masks will be stored.
        classes (int): Total classes in the dataset i.e. for COCO classes=80

    Examples:
        >>> from ultralytics.data.converter import convert_segment_masks_to_yolo_seg

        The classes here is the total classes in the dataset, for COCO dataset we have 80 classes
        >>> convert_segment_masks_to_yolo_seg("path/to/masks_directory", "path/to/output/directory", classes=80)

    Notes:
        The expected directory structure for the masks is:

        - masks
            ├─ mask_image_01.png or mask_image_01.jpg
            ├─ mask_image_02.png or mask_image_02.jpg
            ├─ mask_image_03.png or mask_image_03.jpg
            └─ mask_image_04.png or mask_image_04.jpg

        After execution, the labels will be organized in the following structure:

        - output_dir
            ├─ mask_yolo_01.txt
            ├─ mask_yolo_02.txt
            ├─ mask_yolo_03.txt
            └─ mask_yolo_04.txt
    """
    # Pixel value N in the mask maps to class index N-1; pixel 0 is background
    pixel_to_class_mapping = {i + 1: i for i in range(classes)}
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)  # fix: open() below failed if output_dir did not exist
    for mask_path in Path(masks_dir).iterdir():
        if mask_path.suffix in {".png", ".jpg"}:
            mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)  # Read the mask image in grayscale
            if mask is None:
                # fix: cv2.imread returns None for unreadable/corrupt files; previously this
                # crashed with AttributeError on mask.shape
                LOGGER.warning(f"Could not read mask image {mask_path}, skipping.")
                continue
            img_height, img_width = mask.shape  # Get image dimensions
            LOGGER.info(f"Processing {mask_path} imgsz = {img_height} x {img_width}")

            unique_values = np.unique(mask)  # Get unique pixel values representing different classes
            yolo_format_data = []

            for value in unique_values:
                if value == 0:
                    continue  # Skip background
                class_index = pixel_to_class_mapping.get(value, -1)
                if class_index == -1:
                    LOGGER.warning(f"Unknown class for pixel value {value} in file {mask_path}, skipping.")
                    continue

                # Create a binary mask for the current class and find contours
                contours, _ = cv2.findContours(
                    (mask == value).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
                )  # Find contours

                for contour in contours:
                    if len(contour) >= 3:  # YOLO requires at least 3 points for a valid segmentation
                        contour = contour.squeeze()  # Remove single-dimensional entries
                        yolo_format = [class_index]
                        for point in contour:
                            # Normalize the coordinates to [0, 1]
                            yolo_format.append(round(point[0] / img_width, 6))  # Rounding to 6 decimal places
                            yolo_format.append(round(point[1] / img_height, 6))
                        yolo_format_data.append(yolo_format)
            # Save Ultralytics YOLO format data to file
            output_path = output_dir / f"{mask_path.stem}.txt"
            with open(output_path, "w", encoding="utf-8") as file:
                for item in yolo_format_data:
                    line = " ".join(map(str, item))
                    file.write(line + "\n")
            LOGGER.info(f"Processed and stored at {output_path} imgsz = {img_height} x {img_width}")
yolo_format.append(round(point[0] / img_width, 6)) # Rounding to 6 decimal places + yolo_format.append(round(point[1] / img_height, 6)) + yolo_format_data.append(yolo_format) + # Save Ultralytics YOLO format data to file + output_path = Path(output_dir) / f"{mask_path.stem}.txt" + with open(output_path, "w", encoding="utf-8") as file: + for item in yolo_format_data: + line = " ".join(map(str, item)) + file.write(line + "\n") + LOGGER.info(f"Processed and stored at {output_path} imgsz = {img_height} x {img_width}") + + +def convert_dota_to_yolo_obb(dota_root_path: str): + """ + Converts DOTA dataset annotations to YOLO OBB (Oriented Bounding Box) format. + + The function processes images in the 'train' and 'val' folders of the DOTA dataset. For each image, it reads the + associated label from the original labels directory and writes new labels in YOLO OBB format to a new directory. + + Args: + dota_root_path (str): The root directory path of the DOTA dataset. + + Examples: + >>> from ultralytics.data.converter import convert_dota_to_yolo_obb + >>> convert_dota_to_yolo_obb("path/to/DOTA") + + Notes: + The directory structure assumed for the DOTA dataset: + + - DOTA + ├─ images + │ ├─ train + │ └─ val + └─ labels + ├─ train_original + └─ val_original + + After execution, the function will organize the labels into: + + - DOTA + └─ labels + ├─ train + └─ val + """ + dota_root_path = Path(dota_root_path) + + # Class names to indices mapping + class_mapping = { + "plane": 0, + "ship": 1, + "storage-tank": 2, + "baseball-diamond": 3, + "tennis-court": 4, + "basketball-court": 5, + "ground-track-field": 6, + "harbor": 7, + "bridge": 8, + "large-vehicle": 9, + "small-vehicle": 10, + "helicopter": 11, + "roundabout": 12, + "soccer-ball-field": 13, + "swimming-pool": 14, + "container-crane": 15, + "airport": 16, + "helipad": 17, + } + + def convert_label(image_name, image_width, image_height, orig_label_dir, save_dir): + """Converts a single image's DOTA annotation to YOLO OBB 
def min_index(arr1, arr2):
    """
    Find a pair of indexes with the shortest distance between two arrays of 2D points.

    Args:
        arr1 (np.ndarray): A NumPy array of shape (N, 2) representing N 2D points.
        arr2 (np.ndarray): A NumPy array of shape (M, 2) representing M 2D points.

    Returns:
        (tuple): A tuple containing the indexes of the points with the shortest distance in arr1 and arr2 respectively.
    """
    # Pairwise squared Euclidean distances via broadcasting: shape (N, M)
    diffs = arr1[:, None, :] - arr2[None, :, :]
    sq_dist = (diffs**2).sum(-1)
    # Flat argmin converted back to a (row, col) = (index in arr1, index in arr2) pair
    return np.unravel_index(sq_dist.argmin(), sq_dist.shape)
def merge_multi_segment(segments):
    """
    Merge multiple segments into one list by connecting the coordinates with the minimum distance between each segment.

    This function connects these coordinates with a thin line to merge all segments into one.

    Args:
        segments (List[List]): Original segmentations in COCO's JSON file.
            Each element is a list of coordinates, like [segmentation1, segmentation2,...].

    Returns:
        s (List[np.ndarray]): A list of connected segments represented as NumPy arrays.
    """
    s = []
    segments = [np.array(i).reshape(-1, 2) for i in segments]
    # idx_list[i] holds, for segment i, the point indexes where it connects to its neighbors
    # (1 entry for the first/last segment, 2 entries for middle segments)
    idx_list = [[] for _ in range(len(segments))]

    # Record the indexes with min distance between each segment
    for i in range(1, len(segments)):
        idx1, idx2 = min_index(segments[i - 1], segments[i])
        idx_list[i - 1].append(idx1)
        idx_list[i].append(idx2)

    # Use two round to connect all the segments
    for k in range(2):
        # Forward connection
        if k == 0:
            for i, idx in enumerate(idx_list):
                # Middle segments have two indexes, reverse the index of middle segments
                # so traversal always goes from the entry point toward the exit point
                if len(idx) == 2 and idx[0] > idx[1]:
                    idx = idx[::-1]
                    segments[i] = segments[i][::-1, :]

                # Rotate the polygon so it starts at the connection point, then close it
                # by repeating the first vertex
                segments[i] = np.roll(segments[i], -idx[0], axis=0)
                segments[i] = np.concatenate([segments[i], segments[i][:1]])
                # Deal with the first segment and the last one
                if i in {0, len(idx_list) - 1}:
                    s.append(segments[i])
                else:
                    # Middle segments contribute only the entry→exit stretch on this pass
                    idx = [0, idx[1] - idx[0]]
                    s.append(segments[i][idx[0] : idx[1] + 1])

        else:
            # Backward pass: walk middle segments in reverse and append the remaining
            # (exit→end) stretch so the merged polygon returns along the other side
            for i in range(len(idx_list) - 1, -1, -1):
                if i not in {0, len(idx_list) - 1}:
                    idx = idx_list[i]
                    nidx = abs(idx[1] - idx[0])
                    s.append(segments[i][nidx:])
    return s
def yolo_bbox2segment(im_dir, save_dir=None, sam_model="sam_b.pt", device=None):
    """
    Converts existing object detection dataset (bounding boxes) to segmentation dataset or oriented bounding box (OBB)
    in YOLO format. Generates segmentation data using SAM auto-annotator as needed.

    Args:
        im_dir (str | Path): Path to image directory to convert.
        save_dir (str | Path): Path to save the generated labels, labels will be saved
            into `labels-segment` in the same directory level of `im_dir` if save_dir is None.
        sam_model (str): Segmentation model to use for intermediate segmentation data.
        device (int | str): The specific device to run SAM models.

    Notes:
        The input directory structure assumed for dataset:

            - im_dir
                ├─ 001.jpg
                ├─ ...
                └─ NNN.jpg
            - labels
                ├─ 001.txt
                ├─ ...
                └─ NNN.txt
    """
    # Local imports avoid a circular dependency between converter and model/dataset modules
    from ultralytics import SAM
    from ultralytics.data import YOLODataset
    from ultralytics.utils.ops import xywh2xyxy

    # NOTE: add placeholder to pass class index check
    dataset = YOLODataset(im_dir, data=dict(names=list(range(1000))))
    if len(dataset.labels[0]["segments"]) > 0:  # if it's segment data
        LOGGER.info("Segmentation labels detected, no need to generate new ones!")
        return

    LOGGER.info("Detection labels detected, generating segment labels by SAM model!")
    sam_model = SAM(sam_model)
    for label in TQDM(dataset.labels, total=len(dataset.labels), desc="Generating segment labels"):
        h, w = label["shape"]
        boxes = label["bboxes"]
        if len(boxes) == 0:  # skip empty labels
            continue
        # Denormalize boxes to pixel coordinates in place (label dicts are not reused afterwards)
        boxes[:, [0, 2]] *= w
        boxes[:, [1, 3]] *= h
        im = cv2.imread(label["im_file"])
        # Prompt SAM with the xyxy boxes to get one polygon per box
        sam_results = sam_model(im, bboxes=xywh2xyxy(boxes), verbose=False, save=False, device=device)
        label["segments"] = sam_results[0].masks.xyn

    save_dir = Path(save_dir) if save_dir else Path(im_dir).parent / "labels-segment"
    save_dir.mkdir(parents=True, exist_ok=True)
    for label in dataset.labels:
        texts = []
        lb_name = Path(label["im_file"]).with_suffix(".txt").name
        txt_file = save_dir / lb_name
        cls = label["cls"]
        for i, s in enumerate(label["segments"]):
            if len(s) == 0:
                continue  # SAM returned no mask for this box
            line = (int(cls[i]), *s.reshape(-1))
            texts.append(("%g " * len(line)).rstrip() % line)
        # Append mode: rerunning the conversion adds lines rather than overwriting —
        # NOTE(review): presumably intentional for incremental runs; verify if rerun-safety matters.
        with open(txt_file, "a", encoding="utf-8") as f:
            f.writelines(text + "\n" for text in texts)
    LOGGER.info(f"Generated segment labels saved in {save_dir}")
def create_synthetic_coco_dataset():
    """
    Creates a synthetic COCO dataset with random images based on filenames from label lists.

    This function downloads COCO labels, reads image filenames from label list files,
    creates synthetic images for train2017 and val2017 subsets, and organizes
    them in the COCO dataset structure. It uses multithreading to generate images efficiently.

    Examples:
        >>> from ultralytics.data.converter import create_synthetic_coco_dataset
        >>> create_synthetic_coco_dataset()

    Notes:
        - Requires internet connection to download label files.
        - Generates random RGB images of varying sizes (480x480 to 640x640 pixels).
        - Existing test2017 directory is removed as it's not needed.
        - Reads image filenames from train2017.txt and val2017.txt files.
    """

    def create_synthetic_image(image_file):
        """Generates a random solid-color RGB image at image_file unless it already exists."""
        if not image_file.exists():
            size = (random.randint(480, 640), random.randint(480, 640))
            Image.new(
                "RGB",
                size=size,
                color=(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
            ).save(image_file)

    # Download labels (fix: renamed local from `dir`, which shadowed the builtin)
    coco_dir = DATASETS_DIR / "coco"
    url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/"
    label_zip = "coco2017labels-segments.zip"
    download([url + label_zip], dir=coco_dir.parent)

    # Create synthetic images
    shutil.rmtree(coco_dir / "labels" / "test2017", ignore_errors=True)  # Remove test2017 directory as not needed
    with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
        for subset in ["train2017", "val2017"]:
            subset_dir = coco_dir / "images" / subset
            subset_dir.mkdir(parents=True, exist_ok=True)

            # Read image filenames from label list file
            label_list_file = coco_dir / f"{subset}.txt"
            if label_list_file.exists():
                with open(label_list_file, encoding="utf-8") as f:
                    image_files = [coco_dir / line.strip() for line in f]

                # Submit all tasks; TQDM just drains completions for progress display
                futures = [executor.submit(create_synthetic_image, image_file) for image_file in image_files]
                for _ in TQDM(as_completed(futures), total=len(futures), desc=f"Generating images for {subset}"):
                    pass  # The actual work is done in the background
            else:
                # fix: use the module LOGGER (consistent with every sibling function) instead of print
                LOGGER.warning(f"Labels file {label_list_file} does not exist. Skipping image creation for {subset}.")

    LOGGER.info("Synthetic COCO dataset created successfully.")
class YOLODataset(BaseDataset):
    """
    Dataset class for loading object detection and/or segmentation labels in YOLO format.

    This class supports loading data for object detection, segmentation, pose estimation, and oriented bounding box
    (OBB) tasks using the YOLO format.

    Attributes:
        use_segments (bool): Indicates if segmentation masks should be used.
        use_keypoints (bool): Indicates if keypoints should be used for pose estimation.
        use_obb (bool): Indicates if oriented bounding boxes should be used.
        data (dict): Dataset configuration dictionary.

    Methods:
        cache_labels: Cache dataset labels, check images and read shapes.
        get_labels: Returns dictionary of labels for YOLO training.
        build_transforms: Builds and appends transforms to the list.
        close_mosaic: Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations.
        update_labels_info: Updates label format for different tasks.
        collate_fn: Collates data samples into batches.

    Examples:
        >>> dataset = YOLODataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect")
        >>> dataset.get_labels()
    """

    def __init__(self, *args, data=None, task="detect", **kwargs):
        """
        Initialize the YOLODataset.

        Args:
            data (dict, optional): Dataset configuration dictionary.
            task (str): Task type, one of 'detect', 'segment', 'pose', or 'obb'.
            *args (Any): Additional positional arguments for the parent class.
            **kwargs (Any): Additional keyword arguments for the parent class.
        """
        # Task flags are derived once; the rest of the class branches on them
        self.use_segments = task == "segment"
        self.use_keypoints = task == "pose"
        self.use_obb = task == "obb"
        self.data = data
        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
        super().__init__(*args, **kwargs)

    def cache_labels(self, path=Path("./labels.cache")):
        """
        Cache dataset labels, check images and read shapes.

        Args:
            path (Path): Path where to save the cache file.

        Returns:
            (dict): Dictionary containing cached labels and related information.
        """
        x = {"labels": []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
        total = len(self.im_files)
        nkpt, ndim = self.data.get("kpt_shape", (0, 0))
        if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
            raise ValueError(
                "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
                "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
            )
        # Verify every (image, label) pair in parallel; imap preserves input order
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(
                func=verify_image_label,
                iterable=zip(
                    self.im_files,
                    self.label_files,
                    repeat(self.prefix),
                    repeat(self.use_keypoints),
                    repeat(len(self.data["names"])),
                    repeat(nkpt),
                    repeat(ndim),
                    repeat(self.single_cls),
                ),
            )
            pbar = TQDM(results, desc=desc, total=total)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:  # empty im_file means the pair was rejected by the verifier
                    x["labels"].append(
                        {
                            "im_file": im_file,
                            "shape": shape,
                            "cls": lb[:, 0:1],  # n, 1
                            "bboxes": lb[:, 1:],  # n, 4
                            "segments": segments,
                            "keypoints": keypoint,
                            "normalized": True,
                            "bbox_format": "xywh",
                        }
                    )
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            pbar.close()

        if msgs:
            LOGGER.info("\n".join(msgs))
        if nf == 0:
            LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
        # Hash over file lists lets get_labels detect stale caches after dataset changes
        x["hash"] = get_hash(self.label_files + self.im_files)
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
        return x

    def get_labels(self):
        """
        Returns dictionary of labels for YOLO training.

        This method loads labels from disk or cache, verifies their integrity, and prepares them for training.

        Returns:
            (List[dict]): List of label dictionaries, each containing information about an image and its annotations.
        """
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            assert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hash
        except (FileNotFoundError, AssertionError, AttributeError):
            # Any mismatch or missing/corrupt cache triggers a full rescan
            cache, exists = self.cache_labels(cache_path), False  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:  # only the main process prints in DDP
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))  # display warnings

        # Read cache
        [cache.pop(k) for k in ("hash", "version", "msgs")]  # remove items
        labels = cache["labels"]
        if not labels:
            LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
        self.im_files = [lb["im_file"] for lb in labels]  # update im_files

        # Check if the dataset is all boxes or all segments
        # NOTE(review): if labels is empty, the 3-way unpack below raises ValueError right after the
        # warning above — TODO confirm whether an empty dataset should fail louder/earlier.
        lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
        len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
            )
            for lb in labels:
                lb["segments"] = []
        if len_cls == 0:
            LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
        return labels

    def build_transforms(self, hyp=None):
        """
        Builds and appends transforms to the list.

        Args:
            hyp (dict, optional): Hyperparameters for transforms.

        Returns:
            (Compose): Composed transforms.
        """
        if self.augment:
            # Mosaic/mixup are incompatible with rectangular training batches
            hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
            hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
            transforms = v8_transforms(self, self.imgsz, hyp)
        else:
            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
        # Format is always last: it converts labels into the tensor layout the trainer expects
        transforms.append(
            Format(
                bbox_format="xywh",
                normalize=True,
                return_mask=self.use_segments,
                return_keypoint=self.use_keypoints,
                return_obb=self.use_obb,
                batch_idx=True,
                mask_ratio=hyp.mask_ratio,
                mask_overlap=hyp.overlap_mask,
                bgr=hyp.bgr if self.augment else 0.0,  # only affect training.
            )
        )
        return transforms

    def close_mosaic(self, hyp):
        """
        Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations.

        Args:
            hyp (dict): Hyperparameters for transforms.
        """
        hyp.mosaic = 0.0  # set mosaic ratio=0.0
        hyp.copy_paste = 0.0  # keep the same behavior as previous v8 close-mosaic
        hyp.mixup = 0.0  # keep the same behavior as previous v8 close-mosaic
        self.transforms = self.build_transforms(hyp)

    def update_labels_info(self, label):
        """
        Custom your label format here.

        Args:
            label (dict): Label dictionary containing bboxes, segments, keypoints, etc.

        Returns:
            (dict): Updated label dictionary with instances.

        Note:
            cls is not with bboxes now, classification and semantic segmentation need an independent cls label
            Can also support classification and semantic segmentation by adding or removing dict keys there.
        """
        bboxes = label.pop("bboxes")
        segments = label.pop("segments", [])
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")

        # NOTE: do NOT resample oriented boxes
        segment_resamples = 100 if self.use_obb else 1000
        if len(segments) > 0:
            # make sure segments interpolate correctly if original length is greater than segment_resamples
            max_len = max(len(s) for s in segments)
            segment_resamples = (max_len + 1) if segment_resamples < max_len else segment_resamples
            # list[np.array(segment_resamples, 2)] * num_samples
            segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
        else:
            segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
    def collate_fn(batch):
        """
        Collates data samples into batches.

        Args:
            batch (List[dict]): List of dictionaries containing sample data.

        Returns:
            (dict): Collated batch with stacked tensors.
        """
        new_batch = {}
        keys = batch[0].keys()
        # Transpose list-of-dicts into per-key tuples of values
        values = list(zip(*[list(b.values()) for b in batch]))
        for i, k in enumerate(keys):
            value = values[i]
            if k == "img":
                value = torch.stack(value, 0)  # images share a size, so stacking works
            if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
                value = torch.cat(value, 0)  # variable-length targets are concatenated instead
            new_batch[k] = value
        new_batch["batch_idx"] = list(new_batch["batch_idx"])
        for i in range(len(new_batch["batch_idx"])):
            new_batch["batch_idx"][i] += i  # add target image index for build_targets()
        new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
        return new_batch
class YOLOMultiModalDataset(YOLODataset):
    """
    YOLO-format detection/segmentation dataset with multi-modal (image + text) support.

    Extends YOLODataset by attaching per-class text prompts to each label and by inserting a
    text-loading augmentation into the training transform pipeline, so multi-modal models can
    consume both image and text inputs.

    Methods:
        update_labels_info: Adds text information for multi-modal model training.
        build_transforms: Enhances data transformations with text augmentation.

    Examples:
        >>> dataset = YOLOMultiModalDataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect")
        >>> batch = next(iter(dataset))
        >>> print(batch.keys())  # Should include 'texts'
    """

    def __init__(self, *args, data=None, task="detect", **kwargs):
        """
        Initialize a YOLOMultiModalDataset.

        Args:
            data (dict, optional): Dataset configuration dictionary.
            task (str): Task type, one of 'detect', 'segment', 'pose', or 'obb'.
            *args (Any): Additional positional arguments for the parent class.
            **kwargs (Any): Additional keyword arguments for the parent class.
        """
        super().__init__(*args, data=data, task=task, **kwargs)

    def update_labels_info(self, label):
        """
        Add texts information for multi-modal model training.

        Args:
            label (dict): Label dictionary containing bboxes, segments, keypoints, etc.

        Returns:
            (dict): Updated label dictionary with instances and texts.
        """
        updated = super().update_labels_info(label)
        # Category names may bundle synonyms separated by "/"; split each into a list of strings.
        updated["texts"] = [name.split("/") for name in self.data["names"].values()]
        return updated

    def build_transforms(self, hyp=None):
        """
        Enhances data transformations with optional text augmentation for multi-modal training.

        Args:
            hyp (dict, optional): Hyperparameters for transforms.

        Returns:
            (Compose): Composed transforms including text augmentation if applicable.
        """
        transforms = super().build_transforms(hyp)
        if self.augment:
            # NOTE: args are hard-coded for now; text sampling is capped at 80 classes per image.
            text_aug = RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True)
            transforms.insert(-1, text_aug)  # keep Format as the final transform
        return transforms
+ + Returns: + (List[dict]): List of label dictionaries, each containing information about an image and its annotations. + """ + labels = [] + LOGGER.info("Loading annotation file...") + with open(self.json_file) as f: + annotations = json.load(f) + images = {f"{x['id']:d}": x for x in annotations["images"]} + img_to_anns = defaultdict(list) + for ann in annotations["annotations"]: + img_to_anns[ann["image_id"]].append(ann) + for img_id, anns in TQDM(img_to_anns.items(), desc=f"Reading annotations {self.json_file}"): + img = images[f"{img_id:d}"] + h, w, f = img["height"], img["width"], img["file_name"] + im_file = Path(self.img_path) / f + if not im_file.exists(): + continue + self.im_files.append(str(im_file)) + bboxes = [] + cat2id = {} + texts = [] + for ann in anns: + if ann["iscrowd"]: + continue + box = np.array(ann["bbox"], dtype=np.float32) + box[:2] += box[2:] / 2 + box[[0, 2]] /= float(w) + box[[1, 3]] /= float(h) + if box[2] <= 0 or box[3] <= 0: + continue + + caption = img["caption"] + cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]]) + if cat_name not in cat2id: + cat2id[cat_name] = len(cat2id) + texts.append([cat_name]) + cls = cat2id[cat_name] # class + box = [cls] + box.tolist() + if box not in bboxes: + bboxes.append(box) + lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32) + labels.append( + { + "im_file": im_file, + "shape": (h, w), + "cls": lb[:, 0:1], # n, 1 + "bboxes": lb[:, 1:], # n, 4 + "normalized": True, + "bbox_format": "xywh", + "texts": texts, + } + ) + return labels + + def build_transforms(self, hyp=None): + """ + Configures augmentations for training with optional text loading. + + Args: + hyp (dict, optional): Hyperparameters for transforms. + + Returns: + (Compose): Composed transforms including text augmentation if applicable. + """ + transforms = super().build_transforms(hyp) + if self.augment: + # NOTE: hard-coded the args for now. 
+ transforms.insert(-1, RandomLoadText(max_samples=80, padding=True)) + return transforms + + +class YOLOConcatDataset(ConcatDataset): + """ + Dataset as a concatenation of multiple datasets. + + This class is useful to assemble different existing datasets for YOLO training, ensuring they use the same + collation function. + + Methods: + collate_fn: Static method that collates data samples into batches using YOLODataset's collation function. + + Examples: + >>> dataset1 = YOLODataset(...) + >>> dataset2 = YOLODataset(...) + >>> combined_dataset = YOLOConcatDataset([dataset1, dataset2]) + """ + + @staticmethod + def collate_fn(batch): + """ + Collates data samples into batches. + + Args: + batch (List[dict]): List of dictionaries containing sample data. + + Returns: + (dict): Collated batch with stacked tensors. + """ + return YOLODataset.collate_fn(batch) + + +# TODO: support semantic segmentation +class SemanticDataset(BaseDataset): + """Semantic Segmentation Dataset.""" + + def __init__(self): + """Initialize a SemanticDataset object.""" + super().__init__() + + +class ClassificationDataset: + """ + Extends torchvision ImageFolder to support YOLO classification tasks. + + This class offers functionalities like image augmentation, caching, and verification. It's designed to efficiently + handle large datasets for training deep learning models, with optional image transformations and caching mechanisms + to speed up training. + + Attributes: + cache_ram (bool): Indicates if caching in RAM is enabled. + cache_disk (bool): Indicates if caching on disk is enabled. + samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache + file (if caching on disk), and optionally the loaded image array (if caching in RAM). + torch_transforms (callable): PyTorch transforms to be applied to the images. + root (str): Root directory of the dataset. + prefix (str): Prefix for logging and cache filenames. 
+ + Methods: + __getitem__: Returns subset of data and targets corresponding to given indices. + __len__: Returns the total number of samples in the dataset. + verify_images: Verifies all images in dataset. + """ + + def __init__(self, root, args, augment=False, prefix=""): + """ + Initialize YOLO object with root, image size, augmentations, and cache settings. + + Args: + root (str): Path to the dataset directory where images are stored in a class-specific folder structure. + args (Namespace): Configuration containing dataset-related settings such as image size, augmentation + parameters, and cache settings. + augment (bool, optional): Whether to apply augmentations to the dataset. + prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification. + """ + import torchvision # scope for faster 'import ultralytics' + + # Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import + if TORCHVISION_0_18: # 'allow_empty' argument first introduced in torchvision 0.18 + self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True) + else: + self.base = torchvision.datasets.ImageFolder(root=root) + self.samples = self.base.samples + self.root = self.base.root + + # Initialize attributes + if augment and args.fraction < 1.0: # reduce training fraction + self.samples = self.samples[: round(len(self.samples) * args.fraction)] + self.prefix = colorstr(f"{prefix}: ") if prefix else "" + self.cache_ram = args.cache is True or str(args.cache).lower() == "ram" # cache images into RAM + if self.cache_ram: + LOGGER.warning( + "WARNING ⚠️ Classification `cache_ram` training has known memory leak in " + "https://github.com/ultralytics/ultralytics/issues/9824, setting `cache_ram=False`." 
+ ) + self.cache_ram = False + self.cache_disk = str(args.cache).lower() == "disk" # cache images on hard drive as uncompressed *.npy files + self.samples = self.verify_images() # filter out bad images + self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im + scale = (1.0 - args.scale, 1.0) # (0.08, 1.0) + self.torch_transforms = ( + classify_augmentations( + size=args.imgsz, + scale=scale, + hflip=args.fliplr, + vflip=args.flipud, + erasing=args.erasing, + auto_augment=args.auto_augment, + hsv_h=args.hsv_h, + hsv_s=args.hsv_s, + hsv_v=args.hsv_v, + ) + if augment + else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction) + ) + + def __getitem__(self, i): + """ + Returns subset of data and targets corresponding to given indices. + + Args: + i (int): Index of the sample to retrieve. + + Returns: + (dict): Dictionary containing the image and its class index. + """ + f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image + if self.cache_ram: + if im is None: # Warning: two separate if statements required here, do not combine this with previous line + im = self.samples[i][3] = cv2.imread(f) + elif self.cache_disk: + if not fn.exists(): # load npy + np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False) + im = np.load(fn) + else: # read image + im = cv2.imread(f) # BGR + # Convert NumPy array to PIL image + im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB)) + sample = self.torch_transforms(im) + return {"img": sample, "cls": j} + + def __len__(self) -> int: + """Return the total number of samples in the dataset.""" + return len(self.samples) + + def verify_images(self): + """ + Verify all images in dataset. + + Returns: + (list): List of valid samples after verification. + """ + desc = f"{self.prefix}Scanning {self.root}..." 
+ path = Path(self.root).with_suffix(".cache") # *.cache file path + + try: + cache = load_dataset_cache_file(path) # attempt to load a *.cache file + assert cache["version"] == DATASET_CACHE_VERSION # matches current version + assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash + nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total + if LOCAL_RANK in {-1, 0}: + d = f"{desc} {nf} images, {nc} corrupt" + TQDM(None, desc=d, total=n, initial=n) + if cache["msgs"]: + LOGGER.info("\n".join(cache["msgs"])) # display warnings + return samples + + except (FileNotFoundError, AssertionError, AttributeError): + # Run scan if *.cache retrieval failed + nf, nc, msgs, samples, x = 0, 0, [], [], {} + with ThreadPool(NUM_THREADS) as pool: + results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix))) + pbar = TQDM(results, desc=desc, total=len(self.samples)) + for sample, nf_f, nc_f, msg in pbar: + if nf_f: + samples.append(sample) + if msg: + msgs.append(msg) + nf += nf_f + nc += nc_f + pbar.desc = f"{desc} {nf} images, {nc} corrupt" + pbar.close() + if msgs: + LOGGER.info("\n".join(msgs)) + x["hash"] = get_hash([x[0] for x in self.samples]) + x["results"] = nf, nc, len(samples), samples + x["msgs"] = msgs # warnings + save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION) + return samples diff --git a/tracking/ultralytics/data/loaders.py b/tracking/ultralytics/data/loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..1e25731440480075dfdd307ef9153e14243e82f7 --- /dev/null +++ b/tracking/ultralytics/data/loaders.py @@ -0,0 +1,659 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import glob +import math +import os +import time +from dataclasses import dataclass +from pathlib import Path +from threading import Thread +from urllib.parse import urlparse + +import cv2 +import numpy as np +import requests +import torch +from PIL import 
Image + +from ultralytics.data.utils import FORMATS_HELP_MSG, IMG_FORMATS, VID_FORMATS +from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, ops +from ultralytics.utils.checks import check_requirements +from ultralytics.utils.patches import imread + + +@dataclass +class SourceTypes: + """ + Class to represent various types of input sources for predictions. + + This class uses dataclass to define boolean flags for different types of input sources that can be used for + making predictions with YOLO models. + + Attributes: + stream (bool): Flag indicating if the input source is a video stream. + screenshot (bool): Flag indicating if the input source is a screenshot. + from_img (bool): Flag indicating if the input source is an image file. + tensor (bool): Flag indicating if the input source is a tensor. + + Examples: + >>> source_types = SourceTypes(stream=True, screenshot=False, from_img=False) + >>> print(source_types.stream) + True + >>> print(source_types.from_img) + False + """ + + stream: bool = False + screenshot: bool = False + from_img: bool = False + tensor: bool = False + + +class LoadStreams: + """ + Stream Loader for various types of video streams. + + Supports RTSP, RTMP, HTTP, and TCP streams. This class handles the loading and processing of multiple video + streams simultaneously, making it suitable for real-time video analysis tasks. + + Attributes: + sources (List[str]): The source input paths or URLs for the video streams. + vid_stride (int): Video frame-rate stride. + buffer (bool): Whether to buffer input streams. + running (bool): Flag to indicate if the streaming thread is running. + mode (str): Set to 'stream' indicating real-time capture. + imgs (List[List[np.ndarray]]): List of image frames for each stream. + fps (List[float]): List of FPS for each stream. + frames (List[int]): List of total frames for each stream. + threads (List[Thread]): List of threads for each stream. 
+ shape (List[Tuple[int, int, int]]): List of shapes for each stream. + caps (List[cv2.VideoCapture]): List of cv2.VideoCapture objects for each stream. + bs (int): Batch size for processing. + + Methods: + update: Read stream frames in daemon thread. + close: Close stream loader and release resources. + __iter__: Returns an iterator object for the class. + __next__: Returns source paths, transformed, and original images for processing. + __len__: Return the length of the sources object. + + Examples: + >>> stream_loader = LoadStreams("rtsp://example.com/stream1.mp4") + >>> for sources, imgs, _ in stream_loader: + ... # Process the images + ... pass + >>> stream_loader.close() + + Notes: + - The class uses threading to efficiently load frames from multiple streams simultaneously. + - It automatically handles YouTube links, converting them to the best available stream URL. + - The class implements a buffer system to manage frame storage and retrieval. + """ + + def __init__(self, sources="file.streams", vid_stride=1, buffer=False): + """Initialize stream loader for multiple video sources, supporting various stream types.""" + torch.backends.cudnn.benchmark = True # faster for fixed-size inference + self.buffer = buffer # buffer input streams + self.running = True # running flag for Thread + self.mode = "stream" + self.vid_stride = vid_stride # video frame-rate stride + + sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources] + n = len(sources) + self.bs = n + self.fps = [0] * n # frames per second + self.frames = [0] * n + self.threads = [None] * n + self.caps = [None] * n # video capture objects + self.imgs = [[] for _ in range(n)] # images + self.shape = [[] for _ in range(n)] # image shapes + self.sources = [ops.clean_str(x).replace(os.sep, "_") for x in sources] # clean source names for later + for i, s in enumerate(sources): # index, source + # Start thread to read frames from video stream + st = f"{i + 1}/{n}: {s}... 
" + if urlparse(s).hostname in {"www.youtube.com", "youtube.com", "youtu.be"}: # if source is YouTube video + # YouTube format i.e. 'https://www.youtube.com/watch?v=Jsn8D3aC840' or 'https://youtu.be/Jsn8D3aC840' + s = get_best_youtube_url(s) + s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam + if s == 0 and (IS_COLAB or IS_KAGGLE): + raise NotImplementedError( + "'source=0' webcam not supported in Colab and Kaggle notebooks. " + "Try running 'source=0' in a local environment." + ) + self.caps[i] = cv2.VideoCapture(s) # store video capture object + if not self.caps[i].isOpened(): + raise ConnectionError(f"{st}Failed to open {s}") + w = int(self.caps[i].get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(self.caps[i].get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = self.caps[i].get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan + self.frames[i] = max(int(self.caps[i].get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float( + "inf" + ) # infinite stream fallback + self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback + + success, im = self.caps[i].read() # guarantee first frame + if not success or im is None: + raise ConnectionError(f"{st}Failed to read images from {s}") + self.imgs[i].append(im) + self.shape[i] = im.shape + self.threads[i] = Thread(target=self.update, args=([i, self.caps[i], s]), daemon=True) + LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)") + self.threads[i].start() + LOGGER.info("") # newline + + def update(self, i, cap, stream): + """Read stream frames in daemon thread and update image buffer.""" + n, f = 0, self.frames[i] # frame number, frame array + while self.running and cap.isOpened() and n < (f - 1): + if len(self.imgs[i]) < 30: # keep a <=30-image buffer + n += 1 + cap.grab() # .read() = .grab() followed by .retrieve() + if n % self.vid_stride == 0: + success, im = cap.retrieve() + if not success: + im = np.zeros(self.shape[i], dtype=np.uint8) + LOGGER.warning("WARNING ⚠️ 
Video stream unresponsive, please check your IP camera connection.") + cap.open(stream) # re-open stream if signal was lost + if self.buffer: + self.imgs[i].append(im) + else: + self.imgs[i] = [im] + else: + time.sleep(0.01) # wait until the buffer is empty + + def close(self): + """Terminates stream loader, stops threads, and releases video capture resources.""" + self.running = False # stop flag for Thread + for thread in self.threads: + if thread.is_alive(): + thread.join(timeout=5) # Add timeout + for cap in self.caps: # Iterate through the stored VideoCapture objects + try: + cap.release() # release video capture + except Exception as e: + LOGGER.warning(f"WARNING ⚠️ Could not release VideoCapture object: {e}") + cv2.destroyAllWindows() + + def __iter__(self): + """Iterates through YOLO image feed and re-opens unresponsive streams.""" + self.count = -1 + return self + + def __next__(self): + """Returns the next batch of frames from multiple video streams for processing.""" + self.count += 1 + + images = [] + for i, x in enumerate(self.imgs): + # Wait until a frame is available in each buffer + while not x: + if not self.threads[i].is_alive() or cv2.waitKey(1) == ord("q"): # q to quit + self.close() + raise StopIteration + time.sleep(1 / min(self.fps)) + x = self.imgs[i] + if not x: + LOGGER.warning(f"WARNING ⚠️ Waiting for stream {i}") + + # Get and remove the first frame from imgs buffer + if self.buffer: + images.append(x.pop(0)) + + # Get the last frame, and clear the rest from the imgs buffer + else: + images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8)) + x.clear() + + return self.sources, images, [""] * self.bs + + def __len__(self): + """Return the number of video streams in the LoadStreams object.""" + return self.bs # 1E12 frames = 32 streams at 30 FPS for 30 years + + +class LoadScreenshots: + """ + Ultralytics screenshot dataloader for capturing and processing screen images. 
+ + This class manages the loading of screenshot images for processing with YOLO. It is suitable for use with + `yolo predict source=screen`. + + Attributes: + source (str): The source input indicating which screen to capture. + screen (int): The screen number to capture. + left (int): The left coordinate for screen capture area. + top (int): The top coordinate for screen capture area. + width (int): The width of the screen capture area. + height (int): The height of the screen capture area. + mode (str): Set to 'stream' indicating real-time capture. + frame (int): Counter for captured frames. + sct (mss.mss): Screen capture object from `mss` library. + bs (int): Batch size, set to 1. + fps (int): Frames per second, set to 30. + monitor (Dict[str, int]): Monitor configuration details. + + Methods: + __iter__: Returns an iterator object. + __next__: Captures the next screenshot and returns it. + + Examples: + >>> loader = LoadScreenshots("0 100 100 640 480") # screen 0, top-left (100,100), 640x480 + >>> for source, im, im0s, vid_cap, s in loader: + ... 
print(f"Captured frame: {im.shape}") + """ + + def __init__(self, source): + """Initialize screenshot capture with specified screen and region parameters.""" + check_requirements("mss") + import mss # noqa + + source, *params = source.split() + self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0 + if len(params) == 1: + self.screen = int(params[0]) + elif len(params) == 4: + left, top, width, height = (int(x) for x in params) + elif len(params) == 5: + self.screen, left, top, width, height = (int(x) for x in params) + self.mode = "stream" + self.frame = 0 + self.sct = mss.mss() + self.bs = 1 + self.fps = 30 + + # Parse monitor shape + monitor = self.sct.monitors[self.screen] + self.top = monitor["top"] if top is None else (monitor["top"] + top) + self.left = monitor["left"] if left is None else (monitor["left"] + left) + self.width = width or monitor["width"] + self.height = height or monitor["height"] + self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height} + + def __iter__(self): + """Yields the next screenshot image from the specified screen or region for processing.""" + return self + + def __next__(self): + """Captures and returns the next screenshot as a numpy array using the mss library.""" + im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3] # BGRA to BGR + s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: " + + self.frame += 1 + return [str(self.screen)], [im0], [s] # screen, img, string + + +class LoadImagesAndVideos: + """ + A class for loading and processing images and videos for YOLO object detection. + + This class manages the loading and pre-processing of image and video data from various sources, including + single image files, video files, and lists of image and video paths. + + Attributes: + files (List[str]): List of image and video file paths. + nf (int): Total number of files (images and videos). 
+ video_flag (List[bool]): Flags indicating whether a file is a video (True) or an image (False). + mode (str): Current mode, 'image' or 'video'. + vid_stride (int): Stride for video frame-rate. + bs (int): Batch size. + cap (cv2.VideoCapture): Video capture object for OpenCV. + frame (int): Frame counter for video. + frames (int): Total number of frames in the video. + count (int): Counter for iteration, initialized at 0 during __iter__(). + ni (int): Number of images. + + Methods: + __init__: Initialize the LoadImagesAndVideos object. + __iter__: Returns an iterator object for VideoStream or ImageFolder. + __next__: Returns the next batch of images or video frames along with their paths and metadata. + _new_video: Creates a new video capture object for the given path. + __len__: Returns the number of batches in the object. + + Examples: + >>> loader = LoadImagesAndVideos("path/to/data", batch=32, vid_stride=1) + >>> for paths, imgs, info in loader: + ... # Process batch of images or video frames + ... pass + + Notes: + - Supports various image formats including HEIC. + - Handles both local files and directories. + - Can read from a text file containing paths to images and videos. 
+ """ + + def __init__(self, path, batch=1, vid_stride=1): + """Initialize dataloader for images and videos, supporting various input formats.""" + parent = None + if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line + parent = Path(path).parent + path = Path(path).read_text().splitlines() # list of sources + files = [] + for p in sorted(path) if isinstance(path, (list, tuple)) else [path]: + a = str(Path(p).absolute()) # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912 + if "*" in a: + files.extend(sorted(glob.glob(a, recursive=True))) # glob + elif os.path.isdir(a): + files.extend(sorted(glob.glob(os.path.join(a, "*.*")))) # dir + elif os.path.isfile(a): + files.append(a) # files (absolute or relative to CWD) + elif parent and (parent / p).is_file(): + files.append(str((parent / p).absolute())) # files (relative to *.txt file parent) + else: + raise FileNotFoundError(f"{p} does not exist") + + # Define files as images or videos + images, videos = [], [] + for f in files: + suffix = f.split(".")[-1].lower() # Get file extension without the dot and lowercase + if suffix in IMG_FORMATS: + images.append(f) + elif suffix in VID_FORMATS: + videos.append(f) + ni, nv = len(images), len(videos) + + self.files = images + videos + self.nf = ni + nv # number of files + self.ni = ni # number of images + self.video_flag = [False] * ni + [True] * nv + self.mode = "video" if ni == 0 else "image" # default to video if no images + self.vid_stride = vid_stride # video frame-rate stride + self.bs = batch + if any(videos): + self._new_video(videos[0]) # new video + else: + self.cap = None + if self.nf == 0: + raise FileNotFoundError(f"No images or videos found in {p}. 
{FORMATS_HELP_MSG}") + + def __iter__(self): + """Iterates through image/video files, yielding source paths, images, and metadata.""" + self.count = 0 + return self + + def __next__(self): + """Returns the next batch of images or video frames with their paths and metadata.""" + paths, imgs, info = [], [], [] + while len(imgs) < self.bs: + if self.count >= self.nf: # end of file list + if imgs: + return paths, imgs, info # return last partial batch + else: + raise StopIteration + + path = self.files[self.count] + if self.video_flag[self.count]: + self.mode = "video" + if not self.cap or not self.cap.isOpened(): + self._new_video(path) + + success = False + for _ in range(self.vid_stride): + success = self.cap.grab() + if not success: + break # end of video or failure + + if success: + success, im0 = self.cap.retrieve() + if success: + self.frame += 1 + paths.append(path) + imgs.append(im0) + info.append(f"video {self.count + 1}/{self.nf} (frame {self.frame}/{self.frames}) {path}: ") + if self.frame == self.frames: # end of video + self.count += 1 + self.cap.release() + else: + # Move to the next file if the current video ended or failed to open + self.count += 1 + if self.cap: + self.cap.release() + if self.count < self.nf: + self._new_video(self.files[self.count]) + else: + # Handle image files (including HEIC) + self.mode = "image" + if path.split(".")[-1].lower() == "heic": + # Load HEIC image using Pillow with pillow-heif + check_requirements("pillow-heif") + + from pillow_heif import register_heif_opener + + register_heif_opener() # Register HEIF opener with Pillow + with Image.open(path) as img: + im0 = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR) # convert image to BGR nparray + else: + im0 = imread(path) # BGR + if im0 is None: + LOGGER.warning(f"WARNING ⚠️ Image Read Error {path}") + else: + paths.append(path) + imgs.append(im0) + info.append(f"image {self.count + 1}/{self.nf} {path}: ") + self.count += 1 # move to the next file + if self.count >= 
self.ni: # end of image list + break + + return paths, imgs, info + + def _new_video(self, path): + """Creates a new video capture object for the given path and initializes video-related attributes.""" + self.frame = 0 + self.cap = cv2.VideoCapture(path) + self.fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + if not self.cap.isOpened(): + raise FileNotFoundError(f"Failed to open video {path}") + self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride) + + def __len__(self): + """Returns the number of files (images and videos) in the dataset.""" + return math.ceil(self.nf / self.bs) # number of batches + + +class LoadPilAndNumpy: + """ + Load images from PIL and Numpy arrays for batch processing. + + This class manages loading and pre-processing of image data from both PIL and Numpy formats. It performs basic + validation and format conversion to ensure that the images are in the required format for downstream processing. + + Attributes: + paths (List[str]): List of image paths or autogenerated filenames. + im0 (List[np.ndarray]): List of images stored as Numpy arrays. + mode (str): Type of data being processed, set to 'image'. + bs (int): Batch size, equivalent to the length of `im0`. + + Methods: + _single_check: Validate and format a single image to a Numpy array. + + Examples: + >>> from PIL import Image + >>> import numpy as np + >>> pil_img = Image.new("RGB", (100, 100)) + >>> np_img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + >>> loader = LoadPilAndNumpy([pil_img, np_img]) + >>> paths, images, _ = next(iter(loader)) + >>> print(f"Loaded {len(images)} images") + Loaded 2 images + """ + + def __init__(self, im0): + """Initializes a loader for PIL and Numpy images, converting inputs to a standardized format.""" + if not isinstance(im0, list): + im0 = [im0] + # use `image{i}.jpg` when Image.filename returns an empty path. 
+ self.paths = [getattr(im, "filename", "") or f"image{i}.jpg" for i, im in enumerate(im0)] + self.im0 = [self._single_check(im) for im in im0] + self.mode = "image" + self.bs = len(self.im0) + + @staticmethod + def _single_check(im): + """Validate and format an image to numpy array, ensuring RGB order and contiguous memory.""" + assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}" + if isinstance(im, Image.Image): + if im.mode != "RGB": + im = im.convert("RGB") + im = np.asarray(im)[:, :, ::-1] + im = np.ascontiguousarray(im) # contiguous + return im + + def __len__(self): + """Returns the length of the 'im0' attribute, representing the number of loaded images.""" + return len(self.im0) + + def __next__(self): + """Returns the next batch of images, paths, and metadata for processing.""" + if self.count == 1: # loop only once as it's batch inference + raise StopIteration + self.count += 1 + return self.paths, self.im0, [""] * self.bs + + def __iter__(self): + """Iterates through PIL/numpy images, yielding paths, raw images, and metadata for processing.""" + self.count = 0 + return self + + +class LoadTensor: + """ + A class for loading and processing tensor data for object detection tasks. + + This class handles the loading and pre-processing of image data from PyTorch tensors, preparing them for + further processing in object detection pipelines. + + Attributes: + im0 (torch.Tensor): The input tensor containing the image(s) with shape (B, C, H, W). + bs (int): Batch size, inferred from the shape of `im0`. + mode (str): Current processing mode, set to 'image'. + paths (List[str]): List of image paths or auto-generated filenames. + + Methods: + _single_check: Validates and formats an input tensor. 
+ + Examples: + >>> import torch + >>> tensor = torch.rand(1, 3, 640, 640) + >>> loader = LoadTensor(tensor) + >>> paths, images, info = next(iter(loader)) + >>> print(f"Processed {len(images)} images") + """ + + def __init__(self, im0) -> None: + """Initialize LoadTensor object for processing torch.Tensor image data.""" + self.im0 = self._single_check(im0) + self.bs = self.im0.shape[0] + self.mode = "image" + self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)] + + @staticmethod + def _single_check(im, stride=32): + """Validates and formats a single image tensor, ensuring correct shape and normalization.""" + s = ( + f"WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) " + f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible." + ) + if len(im.shape) != 4: + if len(im.shape) != 3: + raise ValueError(s) + LOGGER.warning(s) + im = im.unsqueeze(0) + if im.shape[2] % stride or im.shape[3] % stride: + raise ValueError(s) + if im.max() > 1.0 + torch.finfo(im.dtype).eps: # torch.float32 eps is 1.2e-07 + LOGGER.warning( + f"WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. " + f"Dividing input by 255." 
+ ) + im = im.float() / 255.0 + + return im + + def __iter__(self): + """Yields an iterator object for iterating through tensor image data.""" + self.count = 0 + return self + + def __next__(self): + """Yields the next batch of tensor images and metadata for processing.""" + if self.count == 1: + raise StopIteration + self.count += 1 + return self.paths, self.im0, [""] * self.bs + + def __len__(self): + """Returns the batch size of the tensor input.""" + return self.bs + + +def autocast_list(source): + """Merges a list of sources into a list of numpy arrays or PIL images for Ultralytics prediction.""" + files = [] + for im in source: + if isinstance(im, (str, Path)): # filename or uri + files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith("http") else im)) + elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image + files.append(im) + else: + raise TypeError( + f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n" + f"See https://docs.ultralytics.com/modes/predict for supported source types." + ) + + return files + + +def get_best_youtube_url(url, method="pytube"): + """ + Retrieves the URL of the best quality MP4 video stream from a given YouTube video. + + Args: + url (str): The URL of the YouTube video. + method (str): The method to use for extracting video info. Options are "pytube", "pafy", and "yt-dlp". + Defaults to "pytube". + + Returns: + (str | None): The URL of the best quality MP4 video stream, or None if no suitable stream is found. + + Examples: + >>> url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ" + >>> best_url = get_best_youtube_url(url) + >>> print(best_url) + https://rr4---sn-q4flrnek.googlevideo.com/videoplayback?expire=... + + Notes: + - Requires additional libraries based on the chosen method: pytubefix, pafy, or yt-dlp. + - The function prioritizes streams with at least 1080p resolution when available. 
+ - For the "yt-dlp" method, it looks for formats with video codec, no audio, and *.mp4 extension. + """ + if method == "pytube": + # Switched from pytube to pytubefix to resolve https://github.com/pytube/pytube/issues/1954 + check_requirements("pytubefix>=6.5.2") + from pytubefix import YouTube + + streams = YouTube(url).streams.filter(file_extension="mp4", only_video=True) + streams = sorted(streams, key=lambda s: s.resolution, reverse=True) # sort streams by resolution + for stream in streams: + if stream.resolution and int(stream.resolution[:-1]) >= 1080: # check if resolution is at least 1080p + return stream.url + + elif method == "pafy": + check_requirements(("pafy", "youtube_dl==2020.12.2")) + import pafy # noqa + + return pafy.new(url).getbestvideo(preftype="mp4").url + + elif method == "yt-dlp": + check_requirements("yt-dlp") + import yt_dlp + + with yt_dlp.YoutubeDL({"quiet": True}) as ydl: + info_dict = ydl.extract_info(url, download=False) # extract info + for f in reversed(info_dict.get("formats", [])): # reversed because best is usually last + # Find a format with video codec, no audio, *.mp4 extension at least 1920x1080 size + good_size = (f.get("width") or 0) >= 1920 or (f.get("height") or 0) >= 1080 + if good_size and f["vcodec"] != "none" and f["acodec"] == "none" and f["ext"] == "mp4": + return f.get("url") + + +# Define constants +LOADERS = (LoadStreams, LoadPilAndNumpy, LoadImagesAndVideos, LoadScreenshots) diff --git a/tracking/ultralytics/data/scripts/download_weights.sh b/tracking/ultralytics/data/scripts/download_weights.sh new file mode 100644 index 0000000000000000000000000000000000000000..59ea9b926e47fb5cb01bfba9792903ff6c09e661 --- /dev/null +++ b/tracking/ultralytics/data/scripts/download_weights.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Download latest models from https://github.com/ultralytics/assets/releases +# Example usage: bash ultralytics/data/scripts/download_weights.sh +# parent +# └── 
weights +# ├── yolov8n.pt ← downloads here +# ├── yolov8s.pt +# └── ... + +python << EOF +from ultralytics.utils.downloads import attempt_download_asset + +assets = [f"yolov8{size}{suffix}.pt" for size in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose")] +for x in assets: + attempt_download_asset(f"weights/{x}") +EOF diff --git a/tracking/ultralytics/data/scripts/get_coco.sh b/tracking/ultralytics/data/scripts/get_coco.sh new file mode 100644 index 0000000000000000000000000000000000000000..bd667cbbfbd4de6728ed11a8f1bbde22640dc03c --- /dev/null +++ b/tracking/ultralytics/data/scripts/get_coco.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Download COCO 2017 dataset https://cocodataset.org +# Example usage: bash data/scripts/get_coco.sh +# parent +# ├── ultralytics +# └── datasets +# └── coco ← downloads here + +# Arguments (optional) Usage: bash data/scripts/get_coco.sh --train --val --test --segments +if [ "$#" -gt 0 ]; then + for opt in "$@"; do + case "${opt}" in + --train) train=true ;; + --val) val=true ;; + --test) test=true ;; + --segments) segments=true ;; + --sama) sama=true ;; + esac + done +else + train=true + val=true + test=false + segments=false + sama=false +fi + +# Download/unzip labels +d='../datasets' # unzip directory +url=https://github.com/ultralytics/assets/releases/download/v0.0.0/ +if [ "$segments" == "true" ]; then + f='coco2017labels-segments.zip' # 169 MB +elif [ "$sama" == "true" ]; then + f='coco2017labels-segments-sama.zip' # 199 MB https://www.sama.com/sama-coco-dataset/ +else + f='coco2017labels.zip' # 46 MB +fi +echo 'Downloading' $url$f ' ...' +curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f & + +# Download/unzip images +d='../datasets/coco/images' # unzip directory +url=http://images.cocodataset.org/zips/ +if [ "$train" == "true" ]; then + f='train2017.zip' # 19G, 118k images + echo 'Downloading' $url$f '...' 
+ curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f & +fi +if [ "$val" == "true" ]; then + f='val2017.zip' # 1G, 5k images + echo 'Downloading' $url$f '...' + curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f & +fi +if [ "$test" == "true" ]; then + f='test2017.zip' # 7G, 41k images (optional) + echo 'Downloading' $url$f '...' + curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f & +fi +wait # finish background tasks diff --git a/tracking/ultralytics/data/scripts/get_coco128.sh b/tracking/ultralytics/data/scripts/get_coco128.sh new file mode 100644 index 0000000000000000000000000000000000000000..8260f018643af86d3944c9dbea727d8a32229eab --- /dev/null +++ b/tracking/ultralytics/data/scripts/get_coco128.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Download COCO128 dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) +# Example usage: bash data/scripts/get_coco128.sh +# parent +# ├── ultralytics +# └── datasets +# └── coco128 ← downloads here + +# Download/unzip images and labels +d='../datasets' # unzip directory +url=https://github.com/ultralytics/assets/releases/download/v0.0.0/ +f='coco128.zip' # or 'coco128-segments.zip', 68 MB +echo 'Downloading' $url$f ' ...' 
+curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f & + +wait # finish background tasks diff --git a/tracking/ultralytics/data/scripts/get_imagenet.sh b/tracking/ultralytics/data/scripts/get_imagenet.sh new file mode 100644 index 0000000000000000000000000000000000000000..091c4d5b04614965625ab511b2417668fde95450 --- /dev/null +++ b/tracking/ultralytics/data/scripts/get_imagenet.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Download ILSVRC2012 ImageNet dataset https://image-net.org +# Example usage: bash data/scripts/get_imagenet.sh +# parent +# ├── ultralytics +# └── datasets +# └── imagenet ← downloads here + +# Arguments (optional) Usage: bash data/scripts/get_imagenet.sh --train --val +if [ "$#" -gt 0 ]; then + for opt in "$@"; do + case "${opt}" in + --train) train=true ;; + --val) val=true ;; + esac + done +else + train=true + val=true +fi + +# Make dir +d='../datasets/imagenet' # unzip directory +mkdir -p $d && cd $d + +# Download/unzip train +if [ "$train" == "true" ]; then + wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar # download 138G, 1281167 images + mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train + tar -xf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar + find . -name "*.tar" | while read NAME; do + mkdir -p "${NAME%.tar}" + tar -xf "${NAME}" -C "${NAME%.tar}" + rm -f "${NAME}" + done + cd .. 
+fi + +# Download/unzip val +if [ "$val" == "true" ]; then + wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar # download 6.3G, 50000 images + mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xf ILSVRC2012_img_val.tar + wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash # move into subdirs +fi + +# Delete corrupted image (optional: PNG under JPEG name that may cause dataloaders to fail) +# rm train/n04266014/n04266014_10835.JPEG + +# TFRecords (optional) +# wget https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_lsvrc_2015_synsets.txt diff --git a/tracking/ultralytics/data/split_dota.py b/tracking/ultralytics/data/split_dota.py new file mode 100644 index 0000000000000000000000000000000000000000..8e61343773003b63702eba4796063faf0b3f57a6 --- /dev/null +++ b/tracking/ultralytics/data/split_dota.py @@ -0,0 +1,325 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import itertools +from glob import glob +from math import ceil +from pathlib import Path + +import cv2 +import numpy as np +from PIL import Image + +from ultralytics.data.utils import exif_size, img2label_paths +from ultralytics.utils import TQDM +from ultralytics.utils.checks import check_requirements + + +def bbox_iof(polygon1, bbox2, eps=1e-6): + """ + Calculate Intersection over Foreground (IoF) between polygons and bounding boxes. + + Args: + polygon1 (np.ndarray): Polygon coordinates with shape (n, 8). + bbox2 (np.ndarray): Bounding boxes with shape (n, 4). + eps (float, optional): Small value to prevent division by zero. + + Returns: + (np.ndarray): IoF scores with shape (n, 1) or (n, m) if bbox2 is (m, 4). + + Notes: + Polygon format: [x1, y1, x2, y2, x3, y3, x4, y4]. + Bounding box format: [x_min, y_min, x_max, y_max]. 
+ """ + check_requirements("shapely") + from shapely.geometry import Polygon + + polygon1 = polygon1.reshape(-1, 4, 2) + lt_point = np.min(polygon1, axis=-2) # left-top + rb_point = np.max(polygon1, axis=-2) # right-bottom + bbox1 = np.concatenate([lt_point, rb_point], axis=-1) + + lt = np.maximum(bbox1[:, None, :2], bbox2[..., :2]) + rb = np.minimum(bbox1[:, None, 2:], bbox2[..., 2:]) + wh = np.clip(rb - lt, 0, np.inf) + h_overlaps = wh[..., 0] * wh[..., 1] + + left, top, right, bottom = (bbox2[..., i] for i in range(4)) + polygon2 = np.stack([left, top, right, top, right, bottom, left, bottom], axis=-1).reshape(-1, 4, 2) + + sg_polys1 = [Polygon(p) for p in polygon1] + sg_polys2 = [Polygon(p) for p in polygon2] + overlaps = np.zeros(h_overlaps.shape) + for p in zip(*np.nonzero(h_overlaps)): + overlaps[p] = sg_polys1[p[0]].intersection(sg_polys2[p[-1]]).area + unions = np.array([p.area for p in sg_polys1], dtype=np.float32) + unions = unions[..., None] + + unions = np.clip(unions, eps, np.inf) + outputs = overlaps / unions + if outputs.ndim == 1: + outputs = outputs[..., None] + return outputs + + +def load_yolo_dota(data_root, split="train"): + """ + Load DOTA dataset. + + Args: + data_root (str): Data root directory. + split (str): The split data set, could be `train` or `val`. + + Returns: + (List[Dict]): List of annotation dictionaries containing image information. + + Notes: + The directory structure assumed for the DOTA dataset: + - data_root + - images + - train + - val + - labels + - train + - val + """ + assert split in {"train", "val"}, f"Split must be 'train' or 'val', not {split}." + im_dir = Path(data_root) / "images" / split + assert im_dir.exists(), f"Can't find {im_dir}, please check your data root." 
+ im_files = glob(str(Path(data_root) / "images" / split / "*")) + lb_files = img2label_paths(im_files) + annos = [] + for im_file, lb_file in zip(im_files, lb_files): + w, h = exif_size(Image.open(im_file)) + with open(lb_file, encoding="utf-8") as f: + lb = [x.split() for x in f.read().strip().splitlines() if len(x)] + lb = np.array(lb, dtype=np.float32) + annos.append(dict(ori_size=(h, w), label=lb, filepath=im_file)) + return annos + + +def get_windows(im_size, crop_sizes=(1024,), gaps=(200,), im_rate_thr=0.6, eps=0.01): + """ + Get the coordinates of windows. + + Args: + im_size (tuple): Original image size, (h, w). + crop_sizes (List[int]): Crop size of windows. + gaps (List[int]): Gap between crops. + im_rate_thr (float): Threshold of windows areas divided by image areas. + eps (float): Epsilon value for math operations. + + Returns: + (np.ndarray): Array of window coordinates with shape (n, 4) where each row is [x_start, y_start, x_stop, y_stop]. + """ + h, w = im_size + windows = [] + for crop_size, gap in zip(crop_sizes, gaps): + assert crop_size > gap, f"invalid crop_size gap pair [{crop_size} {gap}]" + step = crop_size - gap + + xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1) + xs = [step * i for i in range(xn)] + if len(xs) > 1 and xs[-1] + crop_size > w: + xs[-1] = w - crop_size + + yn = 1 if h <= crop_size else ceil((h - crop_size) / step + 1) + ys = [step * i for i in range(yn)] + if len(ys) > 1 and ys[-1] + crop_size > h: + ys[-1] = h - crop_size + + start = np.array(list(itertools.product(xs, ys)), dtype=np.int64) + stop = start + crop_size + windows.append(np.concatenate([start, stop], axis=1)) + windows = np.concatenate(windows, axis=0) + + im_in_wins = windows.copy() + im_in_wins[:, 0::2] = np.clip(im_in_wins[:, 0::2], 0, w) + im_in_wins[:, 1::2] = np.clip(im_in_wins[:, 1::2], 0, h) + im_areas = (im_in_wins[:, 2] - im_in_wins[:, 0]) * (im_in_wins[:, 3] - im_in_wins[:, 1]) + win_areas = (windows[:, 2] - windows[:, 0]) * 
(windows[:, 3] - windows[:, 1])
+    im_rates = im_areas / win_areas
+    if not (im_rates > im_rate_thr).any():
+        max_rate = im_rates.max()
+        im_rates[abs(im_rates - max_rate) < eps] = 1
+    return windows[im_rates > im_rate_thr]
+
+
+def get_window_obj(anno, windows, iof_thr=0.7):
+    """Return, for each window, the labels whose IoF with that window is at least `iof_thr`."""
+    h, w = anno["ori_size"]
+    label = anno["label"]
+    if len(label):
+        label[:, 1::2] *= w  # denormalize x coords to pixels, in place (mutates anno["label"])
+        label[:, 2::2] *= h  # denormalize y coords to pixels, in place
+        iofs = bbox_iof(label[:, 1:], windows)
+        # Labels are still in absolute full-image pixel coords here; the shift to window-relative
+        # coords and re-normalization happen later in crop_and_save.
+        return [(label[iofs[:, i] >= iof_thr]) for i in range(len(windows))]  # window_anns
+    else:
+        return [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))]  # window_anns
+
+
+def crop_and_save(anno, windows, window_objs, im_dir, lb_dir, allow_background_images=True):
+    """
+    Crop images and save new labels.
+
+    Args:
+        anno (dict): Annotation dict, including `filepath`, `label`, `ori_size` as its keys.
+        windows (np.ndarray): Array of windows coordinates with shape (n, 4).
+        window_objs (list): A list of labels inside each window.
+        im_dir (str): The output directory path of images.
+        lb_dir (str): The output directory path of labels.
+        allow_background_images (bool): Whether to include background images without labels.
+ + Notes: + The directory structure assumed for the DOTA dataset: + - data_root + - images + - train + - val + - labels + - train + - val + """ + im = cv2.imread(anno["filepath"]) + name = Path(anno["filepath"]).stem + for i, window in enumerate(windows): + x_start, y_start, x_stop, y_stop = window.tolist() + new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}" + patch_im = im[y_start:y_stop, x_start:x_stop] + ph, pw = patch_im.shape[:2] + + label = window_objs[i] + if len(label) or allow_background_images: + cv2.imwrite(str(Path(im_dir) / f"{new_name}.jpg"), patch_im) + if len(label): + label[:, 1::2] -= x_start + label[:, 2::2] -= y_start + label[:, 1::2] /= pw + label[:, 2::2] /= ph + + with open(Path(lb_dir) / f"{new_name}.txt", "w", encoding="utf-8") as f: + for lb in label: + formatted_coords = [f"{coord:.6g}" for coord in lb[1:]] + f.write(f"{int(lb[0])} {' '.join(formatted_coords)}\n") + + +def split_images_and_labels(data_root, save_dir, split="train", crop_sizes=(1024,), gaps=(200,)): + """ + Split both images and labels. + + Args: + data_root (str): Root directory of the dataset. + save_dir (str): Directory to save the split dataset. + split (str): The split data set, could be `train` or `val`. + crop_sizes (tuple): Tuple of crop sizes. + gaps (tuple): Tuple of gaps between crops. 
+ + Notes: + The directory structure assumed for the DOTA dataset: + - data_root + - images + - split + - labels + - split + and the output directory structure is: + - save_dir + - images + - split + - labels + - split + """ + im_dir = Path(save_dir) / "images" / split + im_dir.mkdir(parents=True, exist_ok=True) + lb_dir = Path(save_dir) / "labels" / split + lb_dir.mkdir(parents=True, exist_ok=True) + + annos = load_yolo_dota(data_root, split=split) + for anno in TQDM(annos, total=len(annos), desc=split): + windows = get_windows(anno["ori_size"], crop_sizes, gaps) + window_objs = get_window_obj(anno, windows) + crop_and_save(anno, windows, window_objs, str(im_dir), str(lb_dir)) + + +def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)): + """ + Split train and val set of DOTA. + + Args: + data_root (str): Root directory of the dataset. + save_dir (str): Directory to save the split dataset. + crop_size (int): Base crop size. + gap (int): Base gap between crops. + rates (tuple): Scaling rates for crop_size and gap. + + Notes: + The directory structure assumed for the DOTA dataset: + - data_root + - images + - train + - val + - labels + - train + - val + and the output directory structure is: + - save_dir + - images + - train + - val + - labels + - train + - val + """ + crop_sizes, gaps = [], [] + for r in rates: + crop_sizes.append(int(crop_size / r)) + gaps.append(int(gap / r)) + for split in ["train", "val"]: + split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps) + + +def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)): + """ + Split test set of DOTA, labels are not included within this set. + + Args: + data_root (str): Root directory of the dataset. + save_dir (str): Directory to save the split dataset. + crop_size (int): Base crop size. + gap (int): Base gap between crops. + rates (tuple): Scaling rates for crop_size and gap. 
+ + Notes: + The directory structure assumed for the DOTA dataset: + - data_root + - images + - test + and the output directory structure is: + - save_dir + - images + - test + """ + crop_sizes, gaps = [], [] + for r in rates: + crop_sizes.append(int(crop_size / r)) + gaps.append(int(gap / r)) + save_dir = Path(save_dir) / "images" / "test" + save_dir.mkdir(parents=True, exist_ok=True) + + im_dir = Path(data_root) / "images" / "test" + assert im_dir.exists(), f"Can't find {im_dir}, please check your data root." + im_files = glob(str(im_dir / "*")) + for im_file in TQDM(im_files, total=len(im_files), desc="test"): + w, h = exif_size(Image.open(im_file)) + windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps) + im = cv2.imread(im_file) + name = Path(im_file).stem + for window in windows: + x_start, y_start, x_stop, y_stop = window.tolist() + new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}" + patch_im = im[y_start:y_stop, x_start:x_stop] + cv2.imwrite(str(save_dir / f"{new_name}.jpg"), patch_im) + + +if __name__ == "__main__": + split_trainval(data_root="DOTAv2", save_dir="DOTAv2-split") + split_test(data_root="DOTAv2", save_dir="DOTAv2-split") diff --git a/tracking/ultralytics/data/utils.py b/tracking/ultralytics/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7c9e1fe66ea708902d1d4ccb0d67c90280a74493 --- /dev/null +++ b/tracking/ultralytics/data/utils.py @@ -0,0 +1,711 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import hashlib +import json +import os +import random +import subprocess +import time +import zipfile +from multiprocessing.pool import ThreadPool +from pathlib import Path +from tarfile import is_tarfile + +import cv2 +import numpy as np +from PIL import Image, ImageOps + +from ultralytics.nn.autobackend import check_class_names +from ultralytics.utils import ( + DATASETS_DIR, + LOGGER, + NUM_THREADS, + ROOT, + SETTINGS_FILE, + TQDM, + clean_url, + colorstr, + emojis, + 
is_dir_writeable, + yaml_load, + yaml_save, +) +from ultralytics.utils.checks import check_file, check_font, is_ascii +from ultralytics.utils.downloads import download, safe_download, unzip_file +from ultralytics.utils.ops import segments2boxes + +HELP_URL = "See https://docs.ultralytics.com/datasets for dataset formatting guidance." +IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm", "heic"} # image suffixes +VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"} # video suffixes +PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders +FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}" + + +def img2label_paths(img_paths): + """Define label paths as a function of image paths.""" + sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}" # /images/, /labels/ substrings + return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths] + + +def get_hash(paths): + """Returns a single hash value of a list of paths (files or dirs).""" + size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes + h = hashlib.sha256(str(size).encode()) # hash sizes + h.update("".join(paths).encode()) # hash paths + return h.hexdigest() # return hash + + +def exif_size(img: Image.Image): + """Returns exif-corrected PIL size.""" + s = img.size # (width, height) + if img.format == "JPEG": # only support JPEG images + try: + if exif := img.getexif(): + rotation = exif.get(274, None) # the EXIF key for the orientation tag is 274 + if rotation in {6, 8}: # rotation 270 or 90 + s = s[1], s[0] + except Exception: + pass + return s + + +def verify_image(args): + """Verify one image.""" + (im_file, cls), prefix = args + # Number (found, corrupt), message + nf, nc, msg = 0, 0, "" + try: + im = Image.open(im_file) + im.verify() # PIL verify + shape = exif_size(im) # image size + shape = (shape[1], 
shape[0]) # hw + assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" + assert im.format.lower() in IMG_FORMATS, f"Invalid image format {im.format}. {FORMATS_HELP_MSG}" + if im.format.lower() in {"jpg", "jpeg"}: + with open(im_file, "rb") as f: + f.seek(-2, 2) + if f.read() != b"\xff\xd9": # corrupt JPEG + ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100) + msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved" + nf = 1 + except Exception as e: + nc = 1 + msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}" + return (im_file, cls), nf, nc, msg + + +def verify_image_label(args): + """Verify one image-label pair.""" + im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim, single_cls = args + # Number (missing, found, empty, corrupt), message, segments, keypoints + nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None + try: + # Verify images + im = Image.open(im_file) + im.verify() # PIL verify + shape = exif_size(im) # image size + shape = (shape[1], shape[0]) # hw + assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" + assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}. {FORMATS_HELP_MSG}" + if im.format.lower() in {"jpg", "jpeg"}: + with open(im_file, "rb") as f: + f.seek(-2, 2) + if f.read() != b"\xff\xd9": # corrupt JPEG + ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100) + msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved" + + # Verify labels + if os.path.isfile(lb_file): + nf = 1 # label found + with open(lb_file, encoding="utf-8") as f: + lb = [x.split() for x in f.read().strip().splitlines() if len(x)] + if any(len(x) > 6 for x in lb) and (not keypoint): # is segment + classes = np.array([x[0] for x in lb], dtype=np.float32) + segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...) 
+                lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
+            lb = np.array(lb, dtype=np.float32)
+            if nl := len(lb):
+                if keypoint:
+                    assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each"
+                    points = lb[:, 5:].reshape(-1, ndim)[:, :2]
+                else:
+                    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
+                    points = lb[:, 1:]
+                assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}"
+                assert lb.min() >= 0, f"negative label values {lb[lb < 0]}"
+
+                # All labels
+                if single_cls:
+                    lb[:, 0] = 0
+                max_cls = lb[:, 0].max()  # max label count
+                assert max_cls < num_cls, (
+                    f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
+                    f"Possible class labels are 0-{num_cls - 1}"
+                )
+                _, i = np.unique(lb, axis=0, return_index=True)
+                if len(i) < nl:  # duplicate row check
+                    lb = lb[i]  # remove duplicates
+                    if segments:
+                        segments = [segments[x] for x in i]
+                    msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
+            else:
+                ne = 1  # label empty
+                lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
+        else:
+            nm = 1  # label missing
+            lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)  # fix: was `keypoints` (always None here), which silently gave 5 columns even in keypoint mode; use the `keypoint` flag to match the empty-label branch above
+        if keypoint:
+            keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
+            if ndim == 2:
+                kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32)
+                keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1)  # (nl, nkpt, 3)
+            lb = lb[:, :5]
+        return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
+    except Exception as e:
+        nc = 1
+        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
+        return [None, None, None, None, None, nm, nf, ne, nc, msg]
+
+
+def visualize_image_annotations(image_path, txt_path, label_map):
+    """
+    Visualizes YOLO annotations (bounding boxes and class labels) on an image.
+ + This function reads an image and its corresponding annotation file in YOLO format, then + draws bounding boxes around detected objects and labels them with their respective class names. + The bounding box colors are assigned based on the class ID, and the text color is dynamically + adjusted for readability, depending on the background color's luminance. + + Args: + image_path (str): The path to the image file to annotate, and it can be in formats supported by PIL. + txt_path (str): The path to the annotation file in YOLO format, that should contain one line per object. + label_map (dict): A dictionary that maps class IDs (integers) to class labels (strings). + + Examples: + >>> label_map = {0: "cat", 1: "dog", 2: "bird"} # It should include all annotated classes details + >>> visualize_image_annotations("path/to/image.jpg", "path/to/annotations.txt", label_map) + """ + import matplotlib.pyplot as plt + + from ultralytics.utils.plotting import colors + + img = np.array(Image.open(image_path)) + img_height, img_width = img.shape[:2] + annotations = [] + with open(txt_path, encoding="utf-8") as file: + for line in file: + class_id, x_center, y_center, width, height = map(float, line.split()) + x = (x_center - width / 2) * img_width + y = (y_center - height / 2) * img_height + w = width * img_width + h = height * img_height + annotations.append((x, y, w, h, int(class_id))) + fig, ax = plt.subplots(1) # Plot the image and annotations + for x, y, w, h, label in annotations: + color = tuple(c / 255 for c in colors(label, True)) # Get and normalize the RGB color + rect = plt.Rectangle((x, y), w, h, linewidth=2, edgecolor=color, facecolor="none") # Create a rectangle + ax.add_patch(rect) + luminance = 0.2126 * color[0] + 0.7152 * color[1] + 0.0722 * color[2] # Formula for luminance + ax.text(x, y - 5, label_map[label], color="white" if luminance < 0.5 else "black", backgroundcolor=color) + ax.imshow(img) + plt.show() + + +def polygon2mask(imgsz, polygons, color=1, 
downsample_ratio=1): + """ + Convert a list of polygons to a binary mask of the specified image size. + + Args: + imgsz (tuple): The size of the image as (height, width). + polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where + N is the number of polygons, and M is the number of points such that M % 2 = 0. + color (int, optional): The color value to fill in the polygons on the mask. + downsample_ratio (int, optional): Factor by which to downsample the mask. + + Returns: + (np.ndarray): A binary mask of the specified image size with the polygons filled in. + """ + mask = np.zeros(imgsz, dtype=np.uint8) + polygons = np.asarray(polygons, dtype=np.int32) + polygons = polygons.reshape((polygons.shape[0], -1, 2)) + cv2.fillPoly(mask, polygons, color=color) + nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio) + # Note: fillPoly first then resize is trying to keep the same loss calculation method when mask-ratio=1 + return cv2.resize(mask, (nw, nh)) + + +def polygons2masks(imgsz, polygons, color, downsample_ratio=1): + """ + Convert a list of polygons to a set of binary masks of the specified image size. + + Args: + imgsz (tuple): The size of the image as (height, width). + polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where + N is the number of polygons, and M is the number of points such that M % 2 = 0. + color (int): The color value to fill in the polygons on the masks. + downsample_ratio (int, optional): Factor by which to downsample each mask. + + Returns: + (np.ndarray): A set of binary masks of the specified image size with the polygons filled in. 
+ """ + return np.array([polygon2mask(imgsz, [x.reshape(-1)], color, downsample_ratio) for x in polygons]) + + +def polygons2masks_overlap(imgsz, segments, downsample_ratio=1): + """Return a (640, 640) overlap mask.""" + masks = np.zeros( + (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio), + dtype=np.int32 if len(segments) > 255 else np.uint8, + ) + areas = [] + ms = [] + for si in range(len(segments)): + mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1) + ms.append(mask.astype(masks.dtype)) + areas.append(mask.sum()) + areas = np.asarray(areas) + index = np.argsort(-areas) + ms = np.array(ms)[index] + for i in range(len(segments)): + mask = ms[i] * (i + 1) + masks = masks + mask + masks = np.clip(masks, a_min=0, a_max=i + 1) + return masks, index + + +def find_dataset_yaml(path: Path) -> Path: + """ + Find and return the YAML file associated with a Detect, Segment or Pose dataset. + + This function searches for a YAML file at the root level of the provided directory first, and if not found, it + performs a recursive search. It prefers YAML files that have the same stem as the provided path. + + Args: + path (Path): The directory path to search for the YAML file. + + Returns: + (Path): The path of the found YAML file. + """ + files = list(path.glob("*.yaml")) or list(path.rglob("*.yaml")) # try root level first and then recursive + assert files, f"No YAML file found in '{path.resolve()}'" + if len(files) > 1: + files = [f for f in files if f.stem == path.stem] # prefer *.yaml files that match + assert len(files) == 1, f"Expected 1 YAML file in '{path.resolve()}', but found {len(files)}.\n{files}" + return files[0] + + +def check_det_dataset(dataset, autodownload=True): + """ + Download, verify, and/or unzip a dataset if not found locally. + + This function checks the availability of a specified dataset, and if not found, it has the option to download and + unzip the dataset. 
def check_det_dataset(dataset, autodownload=True):
    """
    Download, verify, and/or unzip a dataset if not found locally.

    This function checks the availability of a specified dataset, and if not found, it has the option to download and
    unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also
    resolves paths related to the dataset.

    Args:
        dataset (str): Path to the dataset or dataset descriptor (like a YAML file).
        autodownload (bool, optional): Whether to automatically download the dataset if not found.

    Returns:
        (dict): Parsed dataset information and paths.

    Raises:
        SyntaxError: If required keys ('train'/'val', 'names' or 'nc') are missing or inconsistent.
        FileNotFoundError: If 'val' images are missing and no download source/autodownload is available.
    """
    file = check_file(dataset)

    # Download (optional) — an archive is unzipped into DATASETS_DIR and its YAML located inside
    extract_dir = ""
    if zipfile.is_zipfile(file) or is_tarfile(file):
        new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
        file = find_dataset_yaml(DATASETS_DIR / new_dir)
        extract_dir, autodownload = file.parent, False  # already extracted, so disable autodownload

    # Read YAML
    data = yaml_load(file, append_filename=True)  # dictionary

    # Checks: 'train' and 'val' are mandatory; a legacy 'validation' key is renamed to 'val'
    for k in "train", "val":
        if k not in data:
            if k != "val" or "validation" not in data:
                raise SyntaxError(
                    emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")
                )
            LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.")
            data["val"] = data.pop("validation")  # replace 'validation' key with 'val' key
    if "names" not in data and "nc" not in data:
        raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
    if "names" in data and "nc" in data and len(data["names"]) != data["nc"]:
        raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
    if "names" not in data:
        data["names"] = [f"class_{i}" for i in range(data["nc"])]  # synthesize generic class names
    else:
        data["nc"] = len(data["names"])  # 'names' is authoritative for the class count

    data["names"] = check_class_names(data["names"])

    # Resolve paths: dataset root comes from the extract dir, the YAML 'path' key, or the YAML file's parent
    path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent)  # dataset root
    if not path.is_absolute():
        path = (DATASETS_DIR / path).resolve()

    # Set paths: each split entry is made absolute relative to the dataset root
    data["path"] = path  # download scripts
    for k in "train", "val", "test", "minival":
        if data.get(k):  # prepend path
            if isinstance(data[k], str):
                x = (path / data[k]).resolve()
                # Fallback for YAMLs written relative to a parent dir (leading '../')
                if not x.exists() and data[k].startswith("../"):
                    x = (path / data[k][3:]).resolve()
                data[k] = str(x)
            else:
                data[k] = [str((path / x).resolve()) for x in data[k]]

    # Parse YAML: verify val images exist, otherwise attempt the YAML's 'download' recipe
    val, s = (data.get(x) for x in ("val", "download"))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            name = clean_url(dataset)  # dataset name with URL auth stripped
            m = f"\nDataset '{name}' images not found ⚠️, missing path '{[x for x in val if not x.exists()][0]}'"
            if s and autodownload:
                LOGGER.warning(m)
            else:
                m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_FILE}'"
                raise FileNotFoundError(m)
            t = time.time()
            r = None  # success
            if s.startswith("http") and s.endswith(".zip"):  # URL
                safe_download(url=s, dir=DATASETS_DIR, delete=True)
            elif s.startswith("bash "):  # bash script
                LOGGER.info(f"Running {s} ...")
                r = os.system(s)
            else:  # python script — NOTE(review): 'download' YAML text is exec'd; only trusted YAMLs should be used
                exec(s, {"yaml": data})
            dt = f"({round(time.time() - t, 1)}s)"
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in {0, None} else f"failure {dt} ❌"
            LOGGER.info(f"Dataset download {s}\n")
    check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf")  # download fonts

    return data  # dictionary
def check_cls_dataset(dataset, split=""):
    """
    Checks a classification dataset such as Imagenet.

    This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
    If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.

    Args:
        dataset (str | Path): The name of the dataset.
        split (str, optional): The split of the dataset. Either 'val', 'test', or ''.

    Returns:
        (dict): A dictionary containing the following keys:
            - 'train' (Path): The directory path containing the training set of the dataset.
            - 'val' (Path): The directory path containing the validation set of the dataset.
            - 'test' (Path): The directory path containing the test set of the dataset.
            - 'nc' (int): The number of classes in the dataset.
            - 'names' (dict): A dictionary of class names in the dataset.
    """
    # Download (optional if dataset=https://file.zip is passed directly)
    if str(dataset).startswith(("http:/", "https:/")):
        dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
    elif Path(dataset).suffix in {".zip", ".tar", ".gz"}:
        file = check_file(dataset)
        dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)

    dataset = Path(dataset)
    data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
    if not data_dir.is_dir():
        # Local dataset missing — attempt download (imagenet via shell script, others from GitHub assets)
        LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...")
        t = time.time()
        if str(dataset) == "imagenet":
            subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
        else:
            url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{dataset}.zip"
            download(url, dir=data_dir.parent)
        s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
        LOGGER.info(s)
    train_set = data_dir / "train"
    # Accept either 'val' or the legacy 'validation' directory name
    val_set = (
        data_dir / "val"
        if (data_dir / "val").exists()
        else data_dir / "validation"
        if (data_dir / "validation").exists()
        else None
    )  # data/test or data/val
    test_set = data_dir / "test" if (data_dir / "test").exists() else None  # data/val or data/test
    # Fall back to the other split if the requested one is absent
    if split == "val" and not val_set:
        LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
        val_set = test_set
    elif split == "test" and not test_set:
        LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
        test_set = val_set

    nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()])  # number of classes
    names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()]  # class names list
    names = dict(enumerate(sorted(names)))  # index -> name, sorted for a stable mapping

    # Print to console: per-split image/class counts, with errors for an empty or inconsistent train split
    for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
        prefix = f"{colorstr(f'{k}:')} {v}..."
        if v is None:
            LOGGER.info(prefix)
        else:
            files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS]
            nf = len(files)  # number of files
            nd = len({file.parent for file in files})  # number of directories
            if nf == 0:
                if k == "train":
                    raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
                else:
                    LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found")
            elif nd != nc:
                LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}")
            else:
                LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ")

    return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}
    def __init__(self, path="coco8.yaml", task="detect", autodownload=False):
        """
        Initialize class.

        Args:
            path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip).
            task (str): Dataset task. One of 'detect', 'segment', 'pose', 'classify', 'obb'.
            autodownload (bool): Attempt to download the dataset if not found locally.
        """
        path = Path(path).resolve()
        LOGGER.info(f"Starting HUB dataset checks for {path}....")

        self.task = task  # detect, segment, pose, classify, obb
        if self.task == "classify":
            unzip_dir = unzip_file(path)
            data = check_cls_dataset(unzip_dir)
            data["path"] = unzip_dir
        else:  # detect, segment, pose, obb
            _, data_dir, yaml_path = self._unzip(Path(path))
            try:
                # Load YAML with checks
                data = yaml_load(yaml_path)
                data["path"] = ""  # strip path since YAML should be in dataset root for all HUB datasets
                yaml_save(yaml_path, data)  # persist the stripped 'path' back to disk
                data = check_det_dataset(yaml_path, autodownload)  # dict
                data["path"] = data_dir  # YAML path should be set to '' (relative) or parent (absolute)
            except Exception as e:
                raise Exception("error/HUB/dataset_stats/init") from e

        self.hub_dir = Path(f"{data['path']}-hub")  # sibling '<dataset>-hub' output directory
        self.im_dir = self.hub_dir / "images"
        self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())}  # statistics dictionary
        self.data = data

    @staticmethod
    def _unzip(path):
        """Unzip data.zip; returns (zipped, data_dir, yaml_path)."""
        if not str(path).endswith(".zip"):  # path is data.yaml
            return False, None, path
        unzip_dir = unzip_file(path, path=path.parent)
        assert unzip_dir.is_dir(), (
            f"Error unzipping {path}, {unzip_dir} not found. path/to/abc.zip MUST unzip to path/to/abc/"
        )
        return True, str(unzip_dir), find_dataset_yaml(unzip_dir)  # zipped, data_dir, yaml_path

    def _hub_ops(self, f):
        """Saves a compressed image for HUB previews."""
        compress_one_image(f, self.im_dir / Path(f).name)  # save to dataset-hub

    def get_json(self, save=False, verbose=False):
        """
        Return dataset JSON for Ultralytics HUB.

        Args:
            save (bool): Also write the stats to '<hub_dir>/stats.json'.
            verbose (bool): Log the stats dictionary as formatted JSON.

        Returns:
            (dict): Per-split instance/image statistics plus per-image rounded labels.
        """

        def _round(labels):
            """Update labels to integer class and 4 decimal place floats."""
            if self.task == "detect":
                coordinates = labels["bboxes"]
            elif self.task in {"segment", "obb"}:  # Segment and OBB use segments. OBB segments are normalized xyxyxyxy
                coordinates = [x.flatten() for x in labels["segments"]]
            elif self.task == "pose":
                n, nk, nd = labels["keypoints"].shape
                coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, nk * nd)), 1)
            else:
                raise ValueError(f"Undefined dataset task={self.task}.")
            zipped = zip(labels["cls"], coordinates)
            return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]

        for split in "train", "val", "test":
            self.stats[split] = None  # predefine

            path = self.data.get(split)

            # Check split
            if path is None:  # no split
                continue
            files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS]  # image files in split
            if not files:  # no images
                continue

            # Get dataset statistics
            if self.task == "classify":
                from torchvision.datasets import ImageFolder  # scope for faster 'import ultralytics'

                dataset = ImageFolder(self.data[split])

                x = np.zeros(len(dataset.classes)).astype(int)  # per-class image counts
                for im in dataset.imgs:
                    x[im[1]] += 1

                self.stats[split] = {
                    "instance_stats": {"total": len(dataset), "per_class": x.tolist()},
                    "image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()},
                    "labels": [{Path(k).name: v} for k, v in dataset.imgs],
                }
            else:
                from ultralytics.data import YOLODataset

                dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)
                x = np.array(
                    [
                        np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"])
                        for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics")
                    ]
                )  # shape(128x80)
                self.stats[split] = {
                    "instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()},
                    "image_stats": {
                        "total": len(dataset),
                        "unlabelled": int(np.all(x == 0, 1).sum()),  # images with zero instances of any class
                        "per_class": (x > 0).sum(0).tolist(),
                    },
                    "labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)],
                }

        # Save, print and return
        if save:
            self.hub_dir.mkdir(parents=True, exist_ok=True)  # makes dataset-hub/
            stats_path = self.hub_dir / "stats.json"
            LOGGER.info(f"Saving {stats_path.resolve()}...")
            with open(stats_path, "w", encoding="utf-8") as f:
                json.dump(self.stats, f)  # save stats.json
        if verbose:
            LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
        return self.stats

    def process_images(self):
        """Compress images for Ultralytics HUB; returns the output image directory."""
        from ultralytics.data import YOLODataset  # ClassificationDataset

        self.im_dir.mkdir(parents=True, exist_ok=True)  # makes dataset-hub/images/
        for split in "train", "val", "test":
            if self.data.get(split) is None:
                continue
            dataset = YOLODataset(img_path=self.data[split], data=self.data)
            with ThreadPool(NUM_THREADS) as pool:
                for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"):
                    pass
        LOGGER.info(f"Done. All images saved to {self.im_dir}")
        return self.im_dir
def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
    """
    Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the Python
    Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will not be
    resized.

    Args:
        f (str): The path to the input image file.
        f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
        max_dim (int, optional): The maximum dimension (width or height) of the output image.
        quality (int, optional): The image compression quality as a percentage.

    Examples:
        >>> from pathlib import Path
        >>> from ultralytics.data.utils import compress_one_image
        >>> for f in Path("path/to/dataset").rglob("*.jpg"):
        >>> compress_one_image(f)
    """
    try:  # use PIL
        im = Image.open(f)
        r = max_dim / max(im.height, im.width)  # ratio
        if r < 1.0:  # image too large
            im = im.resize((int(im.width * r), int(im.height * r)))
        im.save(f_new or f, "JPEG", quality=quality, optimize=True)  # save
    except Exception as e:  # use OpenCV fallback (e.g. PIL cannot save the mode as JPEG)
        LOGGER.info(f"WARNING ⚠️ HUB ops PIL failure {f}: {e}")
        im = cv2.imread(f)
        im_height, im_width = im.shape[:2]
        r = max_dim / max(im_height, im_width)  # ratio
        if r < 1.0:  # image too large
            im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
        cv2.imwrite(str(f_new or f), im)


def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
    """
    Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.

    Args:
        path (Path, optional): Path to images directory.
        weights (list | tuple, optional): Train, validation, and test split fractions.
        annotated_only (bool, optional): If True, only images with an associated txt file are used.

    Examples:
        >>> from ultralytics.data.utils import autosplit
        >>> autosplit()
    """
    path = Path(path)  # images dir
    files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
    n = len(files)  # number of files
    random.seed(0)  # for reproducibility
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

    txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"]  # 3 txt files
    for x in txt:
        if (path.parent / x).exists():
            (path.parent / x).unlink()  # remove existing

    LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only)
    for i, img in TQDM(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path.parent / txt[i], "a", encoding="utf-8") as f:
                f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n")  # add image to txt file


def load_dataset_cache_file(path):
    """
    Load an Ultralytics *.cache dictionary from path.

    Args:
        path (str | Path): Path to the *.cache file (a pickled dict saved via np.save).

    Returns:
        (dict): The cached dataset dictionary.
    """
    import gc

    gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
    try:
        cache = np.load(str(path), allow_pickle=True).item()  # load dict
    finally:
        # Fix: always re-enable garbage collection, even if np.load raises (e.g. corrupt cache file);
        # previously an exception here left GC disabled for the rest of the process.
        gc.enable()
    return cache


def save_dataset_cache_file(prefix, path, x, version):
    """
    Save an Ultralytics dataset *.cache dictionary x to path.

    Args:
        prefix (str): Log-message prefix.
        path (Path): Destination *.cache file path.
        x (dict): Dataset cache dictionary to persist.
        version (str): Cache format version, stored under the 'version' key.
    """
    x["version"] = version  # add cache version
    if is_dir_writeable(path.parent):
        if path.exists():
            path.unlink()  # remove *.cache file if exists
        with open(str(path), "wb") as file:  # context manager here fixes windows async np.save bug
            np.save(file, x)
        LOGGER.info(f"{prefix}New cache created: {path}")
    else:
        LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
0000000000000000000000000000000000000000..77a19dcf0f8093de419453747db2e7e719f96349 --- /dev/null +++ b/tracking/ultralytics/engine/__init__.py @@ -0,0 +1 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license diff --git a/tracking/ultralytics/engine/exporter.py b/tracking/ultralytics/engine/exporter.py new file mode 100644 index 0000000000000000000000000000000000000000..472fb8e3dca41601c5c0a9727811495a7a998e9b --- /dev/null +++ b/tracking/ultralytics/engine/exporter.py @@ -0,0 +1,1649 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Export a YOLO PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit. + +Format | `format=argument` | Model +--- | --- | --- +PyTorch | - | yolo11n.pt +TorchScript | `torchscript` | yolo11n.torchscript +ONNX | `onnx` | yolo11n.onnx +OpenVINO | `openvino` | yolo11n_openvino_model/ +TensorRT | `engine` | yolo11n.engine +CoreML | `coreml` | yolo11n.mlpackage +TensorFlow SavedModel | `saved_model` | yolo11n_saved_model/ +TensorFlow GraphDef | `pb` | yolo11n.pb +TensorFlow Lite | `tflite` | yolo11n.tflite +TensorFlow Edge TPU | `edgetpu` | yolo11n_edgetpu.tflite +TensorFlow.js | `tfjs` | yolo11n_web_model/ +PaddlePaddle | `paddle` | yolo11n_paddle_model/ +MNN | `mnn` | yolo11n.mnn +NCNN | `ncnn` | yolo11n_ncnn_model/ +IMX | `imx` | yolo11n_imx_model/ +RKNN | `rknn` | yolo11n_rknn_model/ + +Requirements: + $ pip install "ultralytics[export]" + +Python: + from ultralytics import YOLO + model = YOLO('yolo11n.pt') + results = model.export(format='onnx') + +CLI: + $ yolo mode=export model=yolo11n.pt format=onnx + +Inference: + $ yolo predict model=yolo11n.pt # PyTorch + yolo11n.torchscript # TorchScript + yolo11n.onnx # ONNX Runtime or OpenCV DNN with dnn=True + yolo11n_openvino_model # OpenVINO + yolo11n.engine # TensorRT + yolo11n.mlpackage # CoreML (macOS-only) + yolo11n_saved_model # TensorFlow SavedModel + yolo11n.pb # TensorFlow GraphDef + yolo11n.tflite # 
def export_formats():
    """Return a dictionary mapping column names ('Format', 'Argument', 'Suffix', 'CPU', 'GPU', 'Arguments') to per-format export metadata tuples."""
    table = [
        ("PyTorch", "-", ".pt", True, True, []),
        ("TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize", "nms"]),
        ("ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify", "nms"]),
        ("OpenVINO", "openvino", "_openvino_model", True, False, ["batch", "dynamic", "half", "int8", "nms"]),
        ("TensorRT", "engine", ".engine", False, True, ["batch", "dynamic", "half", "int8", "simplify", "nms"]),
        ("CoreML", "coreml", ".mlpackage", True, False, ["batch", "half", "int8", "nms"]),
        ("TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras", "nms"]),
        ("TensorFlow GraphDef", "pb", ".pb", True, True, ["batch"]),
        ("TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8", "nms"]),
        ("TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False, []),
        ("TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8", "nms"]),
        ("PaddlePaddle", "paddle", "_paddle_model", True, True, ["batch"]),
        ("MNN", "mnn", ".mnn", True, True, ["batch", "half", "int8"]),
        ("NCNN", "ncnn", "_ncnn_model", True, True, ["batch", "half"]),
        ("IMX", "imx", "_imx_model", True, True, ["int8"]),
        ("RKNN", "rknn", "_rknn_model", False, False, ["batch", "name"]),
    ]
    columns = ("Format", "Argument", "Suffix", "CPU", "GPU", "Arguments")
    return dict(zip(columns, zip(*table)))  # transpose rows into per-column tuples
+ custom = {"batch": 1, "data": None, "device": None} # exporter defaults + default_args = get_cfg(DEFAULT_CFG, custom) + for arg in export_args: + not_default = getattr(passed_args, arg, None) != getattr(default_args, arg, None) + if not_default: + assert arg in valid_args, f"ERROR ❌️ argument '{arg}' is not supported for format='{format}'" + + +def gd_outputs(gd): + """Return TensorFlow GraphDef model output node names.""" + name_list, input_list = [], [] + for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef + name_list.append(node.name) + input_list.extend(node.input) + return sorted(f"{x}:0" for x in list(set(name_list) - set(input_list)) if not x.startswith("NoOp")) + + +def try_export(inner_func): + """YOLO export decorator, i.e. @try_export.""" + inner_args = get_default_args(inner_func) + + def outer_func(*args, **kwargs): + """Export a model.""" + prefix = inner_args["prefix"] + dt = 0.0 + try: + with Profile() as dt: + f, model = inner_func(*args, **kwargs) + LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)") + return f, model + except Exception as e: + LOGGER.error(f"{prefix} export failure ❌ {dt.t:.1f}s: {e}") + raise e + + return outer_func + + +@contextmanager +def arange_patch(args): + """ + Workaround for ONNX torch.arange incompatibility with FP16. + + https://github.com/pytorch/pytorch/issues/148041. + """ + if args.dynamic and args.half and args.format == "onnx": + func = torch.arange + + def arange(*args, dtype=None, **kwargs): + """Return a 1-D tensor of size with values from the interval and common difference.""" + return func(*args, **kwargs).to(dtype) # cast to dtype instead of passing dtype + + torch.arange = arange # patch + yield + torch.arange = func # unpatch + else: + yield + + +class Exporter: + """ + A class for exporting a model. + + Attributes: + args (SimpleNamespace): Configuration for the exporter. + callbacks (list, optional): List of callback functions. 
+ """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initialize the Exporter class. + + Args: + cfg (str, optional): Path to a configuration file. + overrides (dict, optional): Configuration overrides. + _callbacks (dict, optional): Dictionary of callback functions. + """ + self.args = get_cfg(cfg, overrides) + if self.args.format.lower() in {"coreml", "mlmodel"}: # fix attempt for protobuf<3.20.x errors + os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # must run before TensorBoard callback + + self.callbacks = _callbacks or callbacks.get_default_callbacks() + callbacks.add_integration_callbacks(self) + + def __call__(self, model=None) -> str: + """Return list of exported files/dirs after running callbacks.""" + self.run_callbacks("on_export_start") + t = time.time() + fmt = self.args.format.lower() # to lowercase + if fmt in {"tensorrt", "trt"}: # 'engine' aliases + fmt = "engine" + if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}: # 'coreml' aliases + fmt = "coreml" + fmts_dict = export_formats() + fmts = tuple(fmts_dict["Argument"][1:]) # available export formats + if fmt not in fmts: + import difflib + + # Get the closest match if format is invalid + matches = difflib.get_close_matches(fmt, fmts, n=1, cutoff=0.6) # 60% similarity required to match + if not matches: + raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}") + LOGGER.warning(f"WARNING ⚠️ Invalid export format='{fmt}', updating to format='{matches[0]}'") + fmt = matches[0] + flags = [x == fmt for x in fmts] + if sum(flags) != 1: + raise ValueError(f"Invalid export format='{fmt}'. 
Valid formats are {fmts}") + (jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, mnn, ncnn, imx, rknn) = ( + flags # export booleans + ) + + is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs)) + + # Device + dla = None + if fmt == "engine" and self.args.device is None: + LOGGER.warning("WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0") + self.args.device = "0" + if fmt == "engine" and "dla" in str(self.args.device): # convert int/list to str first + dla = self.args.device.split(":")[-1] + self.args.device = "0" # update device to "0" + assert dla in {"0", "1"}, f"Expected self.args.device='dla:0' or 'dla:1, but got {self.args.device}." + self.device = select_device("cpu" if self.args.device is None else self.args.device) + + # Argument compatibility checks + fmt_keys = fmts_dict["Arguments"][flags.index(True) + 1] + validate_args(fmt, self.args, fmt_keys) + if imx and not self.args.int8: + LOGGER.warning("WARNING ⚠️ IMX only supports int8 export, setting int8=True.") + self.args.int8 = True + if not hasattr(model, "names"): + model.names = default_class_names() + model.names = check_class_names(model.names) + if self.args.half and self.args.int8: + LOGGER.warning("WARNING ⚠️ half=True and int8=True are mutually exclusive, setting half=False.") + self.args.half = False + if self.args.half and onnx and self.device.type == "cpu": + LOGGER.warning("WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0") + self.args.half = False + self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size + if self.args.int8 and engine: + self.args.dynamic = True # enforce dynamic to export TensorRT INT8 + if self.args.optimize: + assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False" + assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. 
use device='cpu'" + if rknn: + if not self.args.name: + LOGGER.warning( + "WARNING ⚠️ Rockchip RKNN export requires a missing 'name' arg for processor type. " + "Using default name='rk3588'." + ) + self.args.name = "rk3588" + self.args.name = self.args.name.lower() + assert self.args.name in RKNN_CHIPS, ( + f"Invalid processor name '{self.args.name}' for Rockchip RKNN export. Valid names are {RKNN_CHIPS}." + ) + if self.args.int8 and tflite: + assert not getattr(model, "end2end", False), "TFLite INT8 export not supported for end2end models." + if self.args.nms: + assert not isinstance(model, ClassificationModel), "'nms=True' is not valid for classification models." + assert not (tflite and ARM64 and LINUX), "TFLite export with NMS unsupported on ARM64 Linux" + if getattr(model, "end2end", False): + LOGGER.warning("WARNING ⚠️ 'nms=True' is not available for end2end models. Forcing 'nms=False'.") + self.args.nms = False + self.args.conf = self.args.conf or 0.25 # set conf default value for nms export + if edgetpu: + if not LINUX or ARM64: + raise SystemError( + "Edge TPU export only supported on non-aarch64 Linux. See https://coral.ai/docs/edgetpu/compiler" + ) + elif self.args.batch != 1: # see github.com/ultralytics/ultralytics/pull/13420 + LOGGER.warning("WARNING ⚠️ Edge TPU export requires batch size 1, setting batch=1.") + self.args.batch = 1 + if isinstance(model, WorldModel): + LOGGER.warning( + "WARNING ⚠️ YOLOWorld (original version) export is not supported to any format.\n" + "WARNING ⚠️ YOLOWorldv2 models (i.e. 'yolov8s-worldv2.pt') only support export to " + "(torchscript, onnx, openvino, engine, coreml) formats. " + "See https://docs.ultralytics.com/models/yolo-world for details." 
+ ) + model.clip_model = None # openvino int8 export error: https://github.com/ultralytics/ultralytics/pull/18445 + if self.args.int8 and not self.args.data: + self.args.data = DEFAULT_CFG.data or TASK2DATA[getattr(model, "task", "detect")] # assign default data + LOGGER.warning( + "WARNING ⚠️ INT8 export requires a missing 'data' arg for calibration. " + f"Using default 'data={self.args.data}'." + ) + if tfjs and (ARM64 and LINUX): + raise SystemError("TF.js exports are not currently supported on ARM64 Linux") + + # Input + im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device) + file = Path( + getattr(model, "pt_path", None) or getattr(model, "yaml_file", None) or model.yaml.get("yaml_file", "") + ) + if file.suffix in {".yaml", ".yml"}: + file = Path(file.name) + + # Update model + model = deepcopy(model).to(self.device) + for p in model.parameters(): + p.requires_grad = False + model.eval() + model.float() + model = model.fuse() + + if imx: + from ultralytics.utils.torch_utils import FXModel + + model = FXModel(model) + for m in model.modules(): + if isinstance(m, Classify): + m.export = True + if isinstance(m, (Detect, RTDETRDecoder)): # includes all Detect subclasses like Segment, Pose, OBB + m.dynamic = self.args.dynamic + m.export = True + m.format = self.args.format + m.max_det = self.args.max_det + elif isinstance(m, C2f) and not is_tf_format: + # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph + m.forward = m.forward_split + if isinstance(m, Detect) and imx: + from ultralytics.utils.tal import make_anchors + + m.anchors, m.strides = ( + x.transpose(0, 1) + for x in make_anchors( + torch.cat([s / m.stride.unsqueeze(-1) for s in self.imgsz], dim=1), m.stride, 0.5 + ) + ) + + y = None + for _ in range(2): # dry runs + y = NMSModel(model, self.args)(im) if self.args.nms and not coreml else model(im) + if self.args.half and onnx and self.device.type != "cpu": + im, model = im.half(), model.half() # to FP16 + + # Filter 
warnings + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) # suppress TracerWarning + warnings.filterwarnings("ignore", category=UserWarning) # suppress shape prim::Constant missing ONNX warning + warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress CoreML np.bool deprecation warning + + # Assign + self.im = im + self.model = model + self.file = file + self.output_shape = ( + tuple(y.shape) + if isinstance(y, torch.Tensor) + else tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y) + ) + self.pretty_name = Path(self.model.yaml.get("yaml_file", self.file)).stem.replace("yolo", "YOLO") + data = model.args["data"] if hasattr(model, "args") and isinstance(model.args, dict) else "" + description = f"Ultralytics {self.pretty_name} model {f'trained on {data}' if data else ''}" + self.metadata = { + "description": description, + "author": "Ultralytics", + "date": datetime.now().isoformat(), + "version": __version__, + "license": "AGPL-3.0 License (https://ultralytics.com/license)", + "docs": "https://docs.ultralytics.com", + "stride": int(max(model.stride)), + "task": model.task, + "batch": self.args.batch, + "imgsz": self.imgsz, + "names": model.names, + "args": {k: v for k, v in self.args if k in fmt_keys}, + } # model metadata + if dla is not None: + self.metadata["dla"] = dla # make sure `AutoBackend` uses correct dla device if it has one + if model.task == "pose": + self.metadata["kpt_shape"] = model.model[-1].kpt_shape + + LOGGER.info( + f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and " + f"output shape(s) {self.output_shape} ({file_size(file):.1f} MB)" + ) + + # Exports + f = [""] * len(fmts) # exported filenames + if jit or ncnn: # TorchScript + f[0], _ = self.export_torchscript() + if engine: # TensorRT required before ONNX + f[1], _ = self.export_engine(dla=dla) + if onnx: # ONNX + f[2], _ = self.export_onnx() + if xml: # OpenVINO + f[3], _ = 
self.export_openvino() + if coreml: # CoreML + f[4], _ = self.export_coreml() + if is_tf_format: # TensorFlow formats + self.args.int8 |= edgetpu + f[5], keras_model = self.export_saved_model() + if pb or tfjs: # pb prerequisite to tfjs + f[6], _ = self.export_pb(keras_model=keras_model) + if tflite: + f[7], _ = self.export_tflite() + if edgetpu: + f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f"{self.file.stem}_full_integer_quant.tflite") + if tfjs: + f[9], _ = self.export_tfjs() + if paddle: # PaddlePaddle + f[10], _ = self.export_paddle() + if mnn: # MNN + f[11], _ = self.export_mnn() + if ncnn: # NCNN + f[12], _ = self.export_ncnn() + if imx: + f[13], _ = self.export_imx() + if rknn: + f[14], _ = self.export_rknn() + + # Finish + f = [str(x) for x in f if x] # filter out '' and None + if any(f): + f = str(Path(f[-1])) + square = self.imgsz[0] == self.imgsz[1] + s = ( + "" + if square + else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " + f"work. Use export 'imgsz={max(self.imgsz)}' if val is required." 
+ ) + imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(" ", "") + predict_data = f"data={data}" if model.task == "segment" and fmt == "pb" else "" + q = "int8" if self.args.int8 else "half" if self.args.half else "" # quantization + LOGGER.info( + f"\nExport complete ({time.time() - t:.1f}s)" + f"\nResults saved to {colorstr('bold', file.parent.resolve())}" + f"\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}" + f"\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}" + f"\nVisualize: https://netron.app" + ) + + self.run_callbacks("on_export_end") + return f # return list of exported files/dirs + + def get_int8_calibration_dataloader(self, prefix=""): + """Build and return a dataloader for calibration of INT8 models.""" + LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'") + data = (check_cls_dataset if self.model.task == "classify" else check_det_dataset)(self.args.data) + # TensorRT INT8 calibration should use 2x batch size + batch = self.args.batch * (2 if self.args.format == "engine" else 1) + dataset = YOLODataset( + data[self.args.split or "val"], + data=data, + task=self.model.task, + imgsz=self.imgsz[0], + augment=False, + batch_size=batch, + ) + n = len(dataset) + if n < self.args.batch: + raise ValueError( + f"The calibration dataset ({n} images) must have at least as many images as the batch size " + f"('batch={self.args.batch}')." 
    @try_export
    def export_torchscript(self, prefix=colorstr("TorchScript:")):
        """
        YOLO TorchScript model export.

        Traces `self.model` (wrapped in NMSModel when nms=True) with `self.im`, embeds
        `self.metadata` as 'config.txt' extra file, and optionally optimizes for mobile.

        Returns:
            f (Path): Path to the saved `.torchscript` file.
            (None): Placeholder for the model slot of the (path, model) convention.
        """
        LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
        f = self.file.with_suffix(".torchscript")

        ts = torch.jit.trace(NMSModel(self.model, self.args) if self.args.nms else self.model, self.im, strict=False)
        extra_files = {"config.txt": json.dumps(self.metadata)}  # torch._C.ExtraFilesMap()
        if self.args.optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
            LOGGER.info(f"{prefix} optimizing for mobile...")
            from torch.utils.mobile_optimizer import optimize_for_mobile

            optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
        else:
            ts.save(str(f), _extra_files=extra_files)
        return f, None

    @try_export
    def export_onnx(self, prefix=colorstr("ONNX:")):
        """
        YOLO ONNX export.

        Returns:
            f (str): Path to the saved `.onnx` file.
            model_onnx (onnx.ModelProto): The exported (and optionally slimmed) ONNX model.
        """
        requirements = ["onnx>=1.12.0"]
        if self.args.simplify:
            requirements += ["onnxslim", "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")]
        check_requirements(requirements)
        import onnx  # noqa

        opset_version = self.args.opset or get_latest_opset()
        LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...")
        f = str(self.file.with_suffix(".onnx"))
        output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
        dynamic = self.args.dynamic
        if dynamic:
            dynamic = {"images": {0: "batch", 2: "height", 3: "width"}}  # shape(1,3,640,640)
            if isinstance(self.model, SegmentationModel):
                dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 116, 8400)
                dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"}  # shape(1,32,160,160)
            elif isinstance(self.model, DetectionModel):
                dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 84, 8400)
            if self.args.nms:  # only batch size is dynamic with NMS
                dynamic["output0"].pop(2)
        if self.args.nms and self.model.task == "obb":
            self.args.opset = opset_version  # for NMSModel
            # OBB error https://github.com/pytorch/pytorch/issues/110859#issuecomment-1757841865
            try:
                torch.onnx.register_custom_op_symbolic("aten::lift_fresh", lambda g, x: x, opset_version)
            except RuntimeError:  # it will fail if it's already registered
                pass
            check_requirements("onnxslim>=0.1.46")  # Older versions has bug with OBB

        with arange_patch(self.args):
            torch.onnx.export(
                NMSModel(self.model, self.args) if self.args.nms else self.model,
                self.im,
                f,
                verbose=False,
                opset_version=opset_version,
                do_constant_folding=True,  # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
                input_names=["images"],
                output_names=output_names,
                dynamic_axes=dynamic or None,
            )

        # Checks
        model_onnx = onnx.load(f)  # load onnx model

        # Simplify
        if self.args.simplify:
            try:
                import onnxslim

                LOGGER.info(f"{prefix} slimming with onnxslim {onnxslim.__version__}...")
                model_onnx = onnxslim.slim(model_onnx)

            except Exception as e:
                LOGGER.warning(f"{prefix} simplifier failure: {e}")

        # Metadata
        for k, v in self.metadata.items():
            meta = model_onnx.metadata_props.add()
            meta.key, meta.value = k, str(v)

        onnx.save(model_onnx, f)
        return f, model_onnx

    @try_export
    def export_openvino(self, prefix=colorstr("OpenVINO:")):
        """
        YOLO OpenVINO export.

        Converts the model via `ov.convert_model`, writes runtime info plus a metadata.yaml,
        and when int8=True additionally produces an NNCF-quantized copy in a separate folder.

        Returns:
            f (str): Path to the exported `*_openvino_model` (or `*_int8_openvino_model`) directory.
            (None): Placeholder for the model slot of the (path, model) convention.
        """
        check_requirements("openvino>=2024.0.0,!=2025.0.0")
        import openvino as ov

        LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
        assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed"
        ov_model = ov.convert_model(
            NMSModel(self.model, self.args) if self.args.nms else self.model,
            input=None if self.args.dynamic else [self.im.shape],
            example_input=self.im,
        )

        def serialize(ov_model, file):
            """Set RT info, serialize, and save metadata YAML."""
            ov_model.set_rt_info("YOLO", ["model_info", "model_type"])
            ov_model.set_rt_info(True, ["model_info", "reverse_input_channels"])
            ov_model.set_rt_info(114, ["model_info", "pad_value"])
            ov_model.set_rt_info([255.0], ["model_info", "scale_values"])
            ov_model.set_rt_info(self.args.iou, ["model_info", "iou_threshold"])
            ov_model.set_rt_info([v.replace(" ", "_") for v in self.model.names.values()], ["model_info", "labels"])
            if self.model.task != "classify":
                ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"])

            ov.save_model(ov_model, file, compress_to_fp16=self.args.half)
            yaml_save(Path(file).parent / "metadata.yaml", self.metadata)  # add metadata.yaml

        if self.args.int8:
            fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}")
            fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name)
            check_requirements("nncf>=2.14.0")
            import nncf

            def transform_fn(data_item) -> np.ndarray:
                """Quantization transform function."""
                data_item: torch.Tensor = data_item["img"] if isinstance(data_item, dict) else data_item
                assert data_item.dtype == torch.uint8, "Input image must be uint8 for the quantization preprocessing"
                im = data_item.numpy().astype(np.float32) / 255.0  # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
                return np.expand_dims(im, 0) if im.ndim == 3 else im

            # Generate calibration data for integer quantization
            ignored_scope = None
            if isinstance(self.model.model[-1], Detect):
                # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
                head_module_name = ".".join(list(self.model.named_modules())[-1][0].split(".")[:2])
                ignored_scope = nncf.IgnoredScope(  # ignore operations
                    patterns=[
                        f".*{head_module_name}/.*/Add",
                        f".*{head_module_name}/.*/Sub*",
                        f".*{head_module_name}/.*/Mul*",
                        f".*{head_module_name}/.*/Div*",
                        f".*{head_module_name}\\.dfl.*",
                    ],
                    types=["Sigmoid"],
                )

            quantized_ov_model = nncf.quantize(
                model=ov_model,
                calibration_dataset=nncf.Dataset(self.get_int8_calibration_dataloader(prefix), transform_fn),
                preset=nncf.QuantizationPreset.MIXED,
                ignored_scope=ignored_scope,
            )
            serialize(quantized_ov_model, fq_ov)
            return fq, None

        f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}")
        f_ov = str(Path(f) / self.file.with_suffix(".xml").name)

        serialize(ov_model, f_ov)
        return f, None
    @try_export
    def export_paddle(self, prefix=colorstr("PaddlePaddle:")):
        """
        YOLO Paddle export.

        Returns:
            f (str): Path to the exported `*_paddle_model` directory (contains metadata.yaml).
            (None): Placeholder for the model slot of the (path, model) convention.
        """
        check_requirements(("paddlepaddle-gpu" if torch.cuda.is_available() else "paddlepaddle", "x2paddle"))
        import x2paddle  # noqa
        from x2paddle.convert import pytorch2paddle  # noqa

        LOGGER.info(f"\n{prefix} starting export with X2Paddle {x2paddle.__version__}...")
        f = str(self.file).replace(self.file.suffix, f"_paddle_model{os.sep}")

        pytorch2paddle(module=self.model, save_dir=f, jit_type="trace", input_examples=[self.im])  # export
        yaml_save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
        return f, None

    @try_export
    def export_mnn(self, prefix=colorstr("MNN:")):
        """
        YOLOv8 MNN export using MNN https://github.com/alibaba/MNN.

        Exports to ONNX first, then converts with `mnnconvert`; `self.metadata` is embedded
        as the MNN bizCode. Honors int8 (8-bit weight quantization) and half (fp16) args.

        Returns:
            f (str): Path to the saved `.mnn` file.
            (None): Placeholder for the model slot of the (path, model) convention.
        """
        f_onnx, _ = self.export_onnx()  # get onnx model first

        check_requirements("MNN>=2.9.6")
        import MNN  # noqa
        from MNN.tools import mnnconvert

        # Setup and checks
        LOGGER.info(f"\n{prefix} starting export with MNN {MNN.version()}...")
        assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
        f = str(self.file.with_suffix(".mnn"))  # MNN model file
        args = ["", "-f", "ONNX", "--modelFile", f_onnx, "--MNNModel", f, "--bizCode", json.dumps(self.metadata)]
        if self.args.int8:
            args.extend(("--weightQuantBits", "8"))
        if self.args.half:
            args.append("--fp16")
        mnnconvert.convert(args)
        # remove scratch file for model convert optimize
        convert_scratch = Path(self.file.parent / ".__convert_external_data.bin")
        if convert_scratch.exists():
            convert_scratch.unlink()
        return f, None
    @try_export
    def export_ncnn(self, prefix=colorstr("NCNN:")):
        """
        YOLO NCNN export using PNNX https://github.com/pnnx/pnnx.

        Converts the TorchScript file with the external `pnnx` binary (downloaded from GitHub
        releases if not found locally), then writes metadata.yaml into the output directory.

        Returns:
            (str): Path to the exported `*_ncnn_model` directory.
            (None): Placeholder for the model slot of the (path, model) convention.
        """
        check_requirements("ncnn")
        import ncnn  # noqa

        LOGGER.info(f"\n{prefix} starting export with NCNN {ncnn.__version__}...")
        f = Path(str(self.file).replace(self.file.suffix, f"_ncnn_model{os.sep}"))
        f_ts = self.file.with_suffix(".torchscript")

        name = Path("pnnx.exe" if WINDOWS else "pnnx")  # PNNX filename
        pnnx = name if name.is_file() else (ROOT / name)
        if not pnnx.is_file():
            LOGGER.warning(
                f"{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from "
                "https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory "
                f"or in {ROOT}. See PNNX repo for full installation instructions."
            )
            system = "macos" if MACOS else "windows" if WINDOWS else "linux-aarch64" if ARM64 else "linux"
            try:
                release, assets = get_github_assets(repo="pnnx/pnnx")
                asset = [x for x in assets if f"{system}.zip" in x][0]
                assert isinstance(asset, str), "Unable to retrieve PNNX repo assets"  # i.e. pnnx-20240410-macos.zip
                LOGGER.info(f"{prefix} successfully found latest PNNX asset file {asset}")
            except Exception as e:
                # Fall back to a pinned release when the GitHub API is unavailable
                release = "20240410"
                asset = f"pnnx-{release}-{system}.zip"
                LOGGER.warning(f"{prefix} WARNING ⚠️ PNNX GitHub assets not found: {e}, using default {asset}")
            unzip_dir = safe_download(f"https://github.com/pnnx/pnnx/releases/download/{release}/{asset}", delete=True)
            if check_is_path_safe(Path.cwd(), unzip_dir):  # avoid path traversal security vulnerability
                shutil.move(src=unzip_dir / name, dst=pnnx)  # move binary to ROOT
                pnnx.chmod(0o777)  # set read, write, and execute permissions for everyone
                shutil.rmtree(unzip_dir)  # delete unzip dir

        ncnn_args = [
            f"ncnnparam={f / 'model.ncnn.param'}",
            f"ncnnbin={f / 'model.ncnn.bin'}",
            f"ncnnpy={f / 'model_ncnn.py'}",
        ]

        pnnx_args = [
            f"pnnxparam={f / 'model.pnnx.param'}",
            f"pnnxbin={f / 'model.pnnx.bin'}",
            f"pnnxpy={f / 'model_pnnx.py'}",
            f"pnnxonnx={f / 'model.pnnx.onnx'}",
        ]

        cmd = [
            str(pnnx),
            str(f_ts),
            *ncnn_args,
            *pnnx_args,
            f"fp16={int(self.args.half)}",
            f"device={self.device.type}",
            f'inputshape="{[self.args.batch, 3, *self.imgsz]}"',
        ]
        f.mkdir(exist_ok=True)  # make ncnn_model directory
        LOGGER.info(f"{prefix} running '{' '.join(cmd)}'")
        subprocess.run(cmd, check=True)

        # Remove debug files
        pnnx_files = [x.split("=")[-1] for x in pnnx_args]
        for f_debug in ("debug.bin", "debug.param", "debug2.bin", "debug2.param", *pnnx_files):
            Path(f_debug).unlink(missing_ok=True)

        yaml_save(f / "metadata.yaml", self.metadata)  # add metadata.yaml
        return str(f), None

    @try_export
    def export_coreml(self, prefix=colorstr("CoreML:")):
        """
        YOLO CoreML export.

        Traces the model and converts with coremltools; optionally quantizes (int8 kmeans /
        fp16 linear) and, for detect models with nms=True, builds an NMS pipeline.

        Returns:
            f (Path): Path to the saved `.mlpackage` (or legacy `.mlmodel`) file.
            ct_model: The converted coremltools model.
        """
        mlmodel = self.args.format.lower() == "mlmodel"  # legacy *.mlmodel export format requested
        check_requirements("coremltools>=6.0,<=6.2" if mlmodel else "coremltools>=8.0")
        import coremltools as ct  # noqa

        LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...")
        assert not WINDOWS, "CoreML export is not supported on Windows, please run on macOS or Linux."
        assert self.args.batch == 1, "CoreML batch sizes > 1 are not supported. Please retry at 'batch=1'."
        f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage")
        if f.is_dir():
            shutil.rmtree(f)

        bias = [0.0, 0.0, 0.0]
        scale = 1 / 255
        classifier_config = None
        if self.model.task == "classify":
            classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None
            model = self.model
        elif self.model.task == "detect":
            model = IOSDetectModel(self.model, self.im) if self.args.nms else self.model
        else:
            if self.args.nms:
                LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is only available for Detect models like 'yolo11n.pt'.")
                # TODO CoreML Segment and Pose model pipelining
            model = self.model

        ts = torch.jit.trace(model.eval(), self.im, strict=False)  # TorchScript model
        ct_model = ct.convert(
            ts,
            inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)],
            classifier_config=classifier_config,
            convert_to="neuralnetwork" if mlmodel else "mlprogram",
        )
        bits, mode = (8, "kmeans") if self.args.int8 else (16, "linear") if self.args.half else (32, None)
        if bits < 32:
            if "kmeans" in mode:
                check_requirements("scikit-learn")  # scikit-learn package required for k-means quantization
            if mlmodel:
                ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
            elif bits == 8:  # mlprogram already quantized to FP16
                import coremltools.optimize.coreml as cto

                op_config = cto.OpPalettizerConfig(mode="kmeans", nbits=bits, weight_threshold=512)
                config = cto.OptimizationConfig(global_config=op_config)
                ct_model = cto.palettize_weights(ct_model, config=config)
        if self.args.nms and self.model.task == "detect":
            if mlmodel:
                # coremltools<=6.2 NMS export requires Python<3.11
                check_version(PYTHON_VERSION, "<3.11", name="Python ", hard=True)
                weights_dir = None
            else:
                ct_model.save(str(f))  # save otherwise weights_dir does not exist
                weights_dir = str(f / "Data/com.apple.CoreML/weights")
            ct_model = self._pipeline_coreml(ct_model, weights_dir=weights_dir)

        m = self.metadata  # metadata dict
        ct_model.short_description = m.pop("description")
        ct_model.author = m.pop("author")
        ct_model.license = m.pop("license")
        ct_model.version = m.pop("version")
        ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
        try:
            ct_model.save(str(f))  # save *.mlpackage
        except Exception as e:
            LOGGER.warning(
                f"{prefix} WARNING ⚠️ CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. "
                f"Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928."
            )
            f = f.with_suffix(".mlmodel")
            ct_model.save(str(f))
        return f, ct_model
    @try_export
    def export_engine(self, dla=None, prefix=colorstr("TensorRT:")):
        """
        YOLO TensorRT export https://developer.nvidia.com/tensorrt.

        Exports to ONNX first, then builds a serialized TensorRT engine (FP32/FP16/INT8,
        optionally on a Jetson DLA core). JSON metadata is prepended to the engine file
        as a little-endian length header followed by the encoded metadata.

        Args:
            dla (str, optional): DLA core id ("0" or "1") to target; requires half or int8.

        Returns:
            f (Path): Path to the saved `.engine` file.
            (None): Placeholder for the model slot of the (path, model) convention.
        """
        # Must be on GPU: the dummy input was created on self.device in __call__
        assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
        f_onnx, _ = self.export_onnx()  # run before TRT import https://github.com/ultralytics/ultralytics/issues/7016

        try:
            import tensorrt as trt  # noqa
        except ImportError:
            if LINUX:
                check_requirements("tensorrt>7.0.0,!=10.1.0")
            import tensorrt as trt  # noqa
        check_version(trt.__version__, ">=7.0.0", hard=True)
        check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")

        # Setup and checks
        LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
        is_trt10 = int(trt.__version__.split(".")[0]) >= 10  # is TensorRT >= 10
        assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
        f = self.file.with_suffix(".engine")  # TensorRT engine file
        logger = trt.Logger(trt.Logger.INFO)
        if self.args.verbose:
            logger.min_severity = trt.Logger.Severity.VERBOSE

        # Engine builder
        builder = trt.Builder(logger)
        config = builder.create_builder_config()
        workspace = int((self.args.workspace or 0) * (1 << 30))  # GiB to bytes
        if is_trt10 and workspace > 0:
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
        elif workspace > 0:  # TensorRT versions 7, 8
            config.max_workspace_size = workspace
        flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        network = builder.create_network(flag)
        half = builder.platform_has_fast_fp16 and self.args.half
        int8 = builder.platform_has_fast_int8 and self.args.int8

        # Optionally switch to DLA if enabled
        if dla is not None:
            if not IS_JETSON:
                raise ValueError("DLA is only available on NVIDIA Jetson devices")
            LOGGER.info(f"{prefix} enabling DLA on core {dla}...")
            if not self.args.half and not self.args.int8:
                raise ValueError(
                    "DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again."
                )
            config.default_device_type = trt.DeviceType.DLA
            config.DLA_core = int(dla)
            config.set_flag(trt.BuilderFlag.GPU_FALLBACK)

        # Read ONNX file
        parser = trt.OnnxParser(network, logger)
        if not parser.parse_from_file(f_onnx):
            raise RuntimeError(f"failed to load ONNX file: {f_onnx}")

        # Network inputs
        inputs = [network.get_input(i) for i in range(network.num_inputs)]
        outputs = [network.get_output(i) for i in range(network.num_outputs)]
        for inp in inputs:
            LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
        for out in outputs:
            LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')

        # NOTE(review): `profile` is only defined in this branch but referenced below under
        # `if int8:`; __call__ forces dynamic=True for int8 engine exports, so this holds — verify.
        if self.args.dynamic:
            shape = self.im.shape
            if shape[0] <= 1:
                LOGGER.warning(f"{prefix} WARNING ⚠️ 'dynamic=True' model requires max batch size, i.e. 'batch=16'")
            profile = builder.create_optimization_profile()
            min_shape = (1, shape[1], 32, 32)  # minimum input shape
            max_shape = (*shape[:2], *(int(max(1, self.args.workspace or 1) * d) for d in shape[2:]))  # max input shape
            for inp in inputs:
                profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
            config.add_optimization_profile(profile)

        LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {f}")
        if int8:
            config.set_flag(trt.BuilderFlag.INT8)
            config.set_calibration_profile(profile)
            config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED

            class EngineCalibrator(trt.IInt8Calibrator):
                def __init__(
                    self,
                    dataset,  # ultralytics.data.build.InfiniteDataLoader
                    batch: int,
                    cache: str = "",
                ) -> None:
                    trt.IInt8Calibrator.__init__(self)
                    self.dataset = dataset
                    self.data_iter = iter(dataset)
                    self.algo = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2
                    self.batch = batch
                    self.cache = Path(cache)

                def get_algorithm(self) -> trt.CalibrationAlgoType:
                    """Get the calibration algorithm to use."""
                    return self.algo

                def get_batch_size(self) -> int:
                    """Get the batch size to use for calibration."""
                    return self.batch or 1

                def get_batch(self, names) -> list:
                    """Get the next batch to use for calibration, as a list of device memory pointers."""
                    try:
                        im0s = next(self.data_iter)["img"] / 255.0
                        im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s
                        return [int(im0s.data_ptr())]
                    except StopIteration:
                        # Return [] or None, signal to TensorRT there is no calibration data remaining
                        return None

                def read_calibration_cache(self) -> bytes:
                    """Use existing cache instead of calibrating again, otherwise, implicitly return None."""
                    if self.cache.exists() and self.cache.suffix == ".cache":
                        return self.cache.read_bytes()

                def write_calibration_cache(self, cache) -> None:
                    """Write calibration cache to disk."""
                    _ = self.cache.write_bytes(cache)

            # Load dataset w/ builder (for batching) and calibrate
            config.int8_calibrator = EngineCalibrator(
                dataset=self.get_int8_calibration_dataloader(prefix),
                batch=2 * self.args.batch,  # TensorRT INT8 calibration should use 2x batch size
                cache=str(self.file.with_suffix(".cache")),
            )

        elif half:
            config.set_flag(trt.BuilderFlag.FP16)

        # Free CUDA memory
        del self.model
        gc.collect()
        torch.cuda.empty_cache()

        # Write file
        build = builder.build_serialized_network if is_trt10 else builder.build_engine
        with build(network, config) as engine, open(f, "wb") as t:
            # Metadata
            meta = json.dumps(self.metadata)
            t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
            t.write(meta.encode())
            # Model
            t.write(engine if is_trt10 else engine.serialize())

        return f, None
batch size to use for calibration.""" + return self.batch or 1 + + def get_batch(self, names) -> list: + """Get the next batch to use for calibration, as a list of device memory pointers.""" + try: + im0s = next(self.data_iter)["img"] / 255.0 + im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s + return [int(im0s.data_ptr())] + except StopIteration: + # Return [] or None, signal to TensorRT there is no calibration data remaining + return None + + def read_calibration_cache(self) -> bytes: + """Use existing cache instead of calibrating again, otherwise, implicitly return None.""" + if self.cache.exists() and self.cache.suffix == ".cache": + return self.cache.read_bytes() + + def write_calibration_cache(self, cache) -> None: + """Write calibration cache to disk.""" + _ = self.cache.write_bytes(cache) + + # Load dataset w/ builder (for batching) and calibrate + config.int8_calibrator = EngineCalibrator( + dataset=self.get_int8_calibration_dataloader(prefix), + batch=2 * self.args.batch, # TensorRT INT8 calibration should use 2x batch size + cache=str(self.file.with_suffix(".cache")), + ) + + elif half: + config.set_flag(trt.BuilderFlag.FP16) + + # Free CUDA memory + del self.model + gc.collect() + torch.cuda.empty_cache() + + # Write file + build = builder.build_serialized_network if is_trt10 else builder.build_engine + with build(network, config) as engine, open(f, "wb") as t: + # Metadata + meta = json.dumps(self.metadata) + t.write(len(meta).to_bytes(4, byteorder="little", signed=True)) + t.write(meta.encode()) + # Model + t.write(engine if is_trt10 else engine.serialize()) + + return f, None + + @try_export + def export_saved_model(self, prefix=colorstr("TensorFlow SavedModel:")): + """YOLO TensorFlow SavedModel export.""" + cuda = torch.cuda.is_available() + try: + import tensorflow as tf # noqa + except ImportError: + check_requirements("tensorflow>=2.0.0") + import tensorflow as tf # noqa + check_requirements( + ( + "keras", # required by 'onnx2tf' 
package + "tf_keras", # required by 'onnx2tf' package + "sng4onnx>=1.0.1", # required by 'onnx2tf' package + "onnx_graphsurgeon>=0.3.26", # required by 'onnx2tf' package + "ai-edge-litert>=1.2.0", # required by 'onnx2tf' package + "onnx>=1.12.0", + "onnx2tf>=1.26.3", + "onnxslim>=0.1.31", + "tflite_support<=0.4.3" if IS_JETSON else "tflite_support", # fix ImportError 'GLIBCXX_3.4.29' + "flatbuffers>=23.5.26,<100", # update old 'flatbuffers' included inside tensorflow package + "onnxruntime-gpu" if cuda else "onnxruntime", + "protobuf>=5", # tflite_support pins <=4 but >=5 works + ), + cmds="--extra-index-url https://pypi.ngc.nvidia.com", # onnx_graphsurgeon only on NVIDIA + ) + + LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...") + check_version( + tf.__version__, + ">=2.0.0", + name="tensorflow", + verbose=True, + msg="https://github.com/ultralytics/ultralytics/issues/5161", + ) + import onnx2tf + + f = Path(str(self.file).replace(self.file.suffix, "_saved_model")) + if f.is_dir(): + shutil.rmtree(f) # delete output folder + + # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545 + onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy") + if not onnx2tf_file.exists(): + attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True) + + # Export to ONNX + self.args.simplify = True + f_onnx, _ = self.export_onnx() + + # Export to TF + np_data = None + if self.args.int8: + tmp_file = f / "tmp_tflite_int8_calibration_images.npy" # int8 calibration images file + if self.args.data: + f.mkdir() + images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)] + images = torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz).permute( + 0, 2, 3, 1 + ) + np.save(str(tmp_file), images.numpy().astype(np.float32)) # BHWC + np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]] + + LOGGER.info(f"{prefix} starting TFLite export 
with onnx2tf {onnx2tf.__version__}...")
+        keras_model = onnx2tf.convert(
+            input_onnx_file_path=f_onnx,
+            output_folder_path=str(f),
+            not_use_onnxsim=True,
+            verbosity="error",  # note INT8-FP16 activation bug https://github.com/ultralytics/ultralytics/issues/15873
+            output_integer_quantized_tflite=self.args.int8,
+            quant_type="per-tensor",  # "per-tensor" (faster) or "per-channel" (slower but more accurate)
+            custom_input_op_name_np_data_path=np_data,
+            disable_group_convolution=True,  # for end-to-end model compatibility
+            enable_batchmatmul_unfold=True,  # for end-to-end model compatibility
+        )
+        yaml_save(f / "metadata.yaml", self.metadata)  # add metadata.yaml
+
+        # Remove/rename TFLite models
+        if self.args.int8:
+            tmp_file.unlink(missing_ok=True)
+            for file in f.rglob("*_dynamic_range_quant.tflite"):
+                file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix))
+            for file in f.rglob("*_integer_quant_with_int16_act.tflite"):
+                file.unlink()  # delete extra fp16 activation TFLite files
+
+        # Add TFLite metadata
+        for file in f.rglob("*.tflite"):
+            file.unlink() if "quant_with_int16_act.tflite" in str(file) else self._add_tflite_metadata(file)  # fix: was str(f)/f.unlink(), which tested the output dir instead of the iterated file
+
+        return str(f), keras_model  # or keras_model = tf.saved_model.load(f, tags=None, options=None)
+
+    @try_export
+    def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
+        """YOLO TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow."""
+        import tensorflow as tf  # noqa
+        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2  # noqa
+
+        LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
+        f = self.file.with_suffix(".pb")
+
+        m = tf.function(lambda x: keras_model(x))  # full model
+        m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
+        frozen_func = convert_variables_to_constants_v2(m)
+        frozen_func.graph.as_graph_def()
+        
tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False) + return f, None + + @try_export + def export_tflite(self, prefix=colorstr("TensorFlow Lite:")): + """YOLO TensorFlow Lite export.""" + # BUG https://github.com/ultralytics/ultralytics/issues/13436 + import tensorflow as tf # noqa + + LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...") + saved_model = Path(str(self.file).replace(self.file.suffix, "_saved_model")) + if self.args.int8: + f = saved_model / f"{self.file.stem}_int8.tflite" # fp32 in/out + elif self.args.half: + f = saved_model / f"{self.file.stem}_float16.tflite" # fp32 in/out + else: + f = saved_model / f"{self.file.stem}_float32.tflite" + return str(f), None + + @try_export + def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")): + """YOLO Edge TPU export https://coral.ai/docs/edgetpu/models-intro/.""" + LOGGER.warning(f"{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185") + + cmd = "edgetpu_compiler --version" + help_url = "https://coral.ai/docs/edgetpu/compiler/" + assert LINUX, f"export only supported on Linux. See {help_url}" + if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0: + LOGGER.info(f"\n{prefix} export requires Edge TPU compiler. 
Attempting install from {help_url}") + for c in ( + "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -", + 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | ' + "sudo tee /etc/apt/sources.list.d/coral-edgetpu.list", + "sudo apt-get update", + "sudo apt-get install edgetpu-compiler", + ): + subprocess.run(c if is_sudo_available() else c.replace("sudo ", ""), shell=True, check=True) + ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1] + + LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...") + f = str(tflite_model).replace(".tflite", "_edgetpu.tflite") # Edge TPU model + + cmd = ( + "edgetpu_compiler " + f'--out_dir "{Path(f).parent}" ' + "--show_operations " + "--search_delegate " + "--delegate_search_step 30 " + "--timeout_sec 180 " + f'"{tflite_model}"' + ) + LOGGER.info(f"{prefix} running '{cmd}'") + subprocess.run(cmd, shell=True) + self._add_tflite_metadata(f) + return f, None + + @try_export + def export_tfjs(self, prefix=colorstr("TensorFlow.js:")): + """YOLO TensorFlow.js export.""" + check_requirements("tensorflowjs") + import tensorflow as tf + import tensorflowjs as tfjs # noqa + + LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...") + f = str(self.file).replace(self.file.suffix, "_web_model") # js dir + f_pb = str(self.file.with_suffix(".pb")) # *.pb path + + gd = tf.Graph().as_graph_def() # TF GraphDef + with open(f_pb, "rb") as file: + gd.ParseFromString(file.read()) + outputs = ",".join(gd_outputs(gd)) + LOGGER.info(f"\n{prefix} output node names: {outputs}") + + quantization = "--quantize_float16" if self.args.half else "--quantize_uint8" if self.args.int8 else "" + with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_: # exporter can not handle spaces in path + cmd = ( + "tensorflowjs_converter " + f'--input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" 
"{f_}"' + ) + LOGGER.info(f"{prefix} running '{cmd}'") + subprocess.run(cmd, shell=True) + + if " " in f: + LOGGER.warning(f"{prefix} WARNING ⚠️ your model may not work correctly with spaces in path '{f}'.") + + # Add metadata + yaml_save(Path(f) / "metadata.yaml", self.metadata) # add metadata.yaml + return f, None + + @try_export + def export_rknn(self, prefix=colorstr("RKNN:")): + """YOLO RKNN model export.""" + LOGGER.info(f"\n{prefix} starting export with rknn-toolkit2...") + + check_requirements("rknn-toolkit2") + if IS_COLAB: + # Prevent 'exit' from closing the notebook https://github.com/airockchip/rknn-toolkit2/issues/259 + import builtins + + builtins.exit = lambda: None + + from rknn.api import RKNN + + f, _ = self.export_onnx() + export_path = Path(f"{Path(f).stem}_rknn_model") + export_path.mkdir(exist_ok=True) + + rknn = RKNN(verbose=False) + rknn.config(mean_values=[[0, 0, 0]], std_values=[[255, 255, 255]], target_platform=self.args.name) + rknn.load_onnx(model=f) + rknn.build(do_quantization=False) # TODO: Add quantization support + f = f.replace(".onnx", f"-{self.args.name}.rknn") + rknn.export_rknn(f"{export_path / f}") + yaml_save(export_path / "metadata.yaml", self.metadata) + return export_path, None + + @try_export + def export_imx(self, prefix=colorstr("IMX:")): + """YOLO IMX export.""" + gptq = False + assert LINUX, ( + "export only supported on Linux. 
See https://developer.aitrios.sony-semicon.com/en/raspberrypi-ai-camera/documentation/imx500-converter" + ) + if getattr(self.model, "end2end", False): + raise ValueError("IMX export is not supported for end2end models.") + if "C2f" not in self.model.__str__(): + raise ValueError("IMX export is only supported for YOLOv8n detection models") + check_requirements(("model-compression-toolkit==2.1.1", "sony-custom-layers==0.2.0", "tensorflow==2.12.0")) + check_requirements("imx500-converter[pt]==3.14.3") # Separate requirements for imx500-converter + + import model_compression_toolkit as mct + import onnx + from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms + + LOGGER.info(f"\n{prefix} starting export with model_compression_toolkit {mct.__version__}...") + + try: + out = subprocess.run( + ["java", "--version"], check=True, capture_output=True + ) # Java 17 is required for imx500-converter + if "openjdk 17" not in str(out.stdout): + raise FileNotFoundError + except FileNotFoundError: + c = ["apt", "install", "-y", "openjdk-17-jdk", "openjdk-17-jre"] + if is_sudo_available(): + c.insert(0, "sudo") + subprocess.run(c, check=True) + + def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)): + for batch in dataloader: + img = batch["img"] + img = img / 255.0 + yield [img] + + tpc = mct.get_target_platform_capabilities( + fw_name="pytorch", target_platform_name="imx500", target_platform_version="v1" + ) + + config = mct.core.CoreConfig( + mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=10), + quantization_config=mct.core.QuantizationConfig(concat_threshold_update=True), + ) + + resource_utilization = mct.core.ResourceUtilization(weights_memory=3146176 * 0.76) + + quant_model = ( + mct.gptq.pytorch_gradient_post_training_quantization( # Perform Gradient-Based Post Training Quantization + model=self.model, + representative_data_gen=representative_dataset_gen, + 
target_resource_utilization=resource_utilization, + gptq_config=mct.gptq.get_pytorch_gptq_config(n_epochs=1000, use_hessian_based_weights=False), + core_config=config, + target_platform_capabilities=tpc, + )[0] + if gptq + else mct.ptq.pytorch_post_training_quantization( # Perform post training quantization + in_module=self.model, + representative_data_gen=representative_dataset_gen, + target_resource_utilization=resource_utilization, + core_config=config, + target_platform_capabilities=tpc, + )[0] + ) + + class NMSWrapper(torch.nn.Module): + def __init__( + self, + model: torch.nn.Module, + score_threshold: float = 0.001, + iou_threshold: float = 0.7, + max_detections: int = 300, + ): + """ + Wrapping PyTorch Module with multiclass_nms layer from sony_custom_layers. + + Args: + model (nn.Module): Model instance. + score_threshold (float): Score threshold for non-maximum suppression. + iou_threshold (float): Intersection over union threshold for non-maximum suppression. + max_detections (float): The number of detections to return. 
+ """ + super().__init__() + self.model = model + self.score_threshold = score_threshold + self.iou_threshold = iou_threshold + self.max_detections = max_detections + + def forward(self, images): + # model inference + outputs = self.model(images) + + boxes = outputs[0] + scores = outputs[1] + nms = multiclass_nms( + boxes=boxes, + scores=scores, + score_threshold=self.score_threshold, + iou_threshold=self.iou_threshold, + max_detections=self.max_detections, + ) + return nms + + quant_model = NMSWrapper( + model=quant_model, + score_threshold=self.args.conf or 0.001, + iou_threshold=self.args.iou, + max_detections=self.args.max_det, + ).to(self.device) + + f = Path(str(self.file).replace(self.file.suffix, "_imx_model")) + f.mkdir(exist_ok=True) + onnx_model = f / Path(str(self.file.name).replace(self.file.suffix, "_imx.onnx")) # js dir + mct.exporter.pytorch_export_model( + model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen + ) + + model_onnx = onnx.load(onnx_model) # load onnx model + for k, v in self.metadata.items(): + meta = model_onnx.metadata_props.add() + meta.key, meta.value = k, str(v) + + onnx.save(model_onnx, onnx_model) + + subprocess.run( + ["imxconv-pt", "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"], + check=True, + ) + + # Needed for imx models. 
+ with open(f / "labels.txt", "w", encoding="utf-8") as file: + file.writelines([f"{name}\n" for _, name in self.model.names.items()]) + + return f, None + + def _add_tflite_metadata(self, file): + """Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata.""" + import flatbuffers + + try: + # TFLite Support bug https://github.com/tensorflow/tflite-support/issues/954#issuecomment-2108570845 + from tensorflow_lite_support.metadata import metadata_schema_py_generated as schema # noqa + from tensorflow_lite_support.metadata.python import metadata # noqa + except ImportError: # ARM64 systems may not have the 'tensorflow_lite_support' package available + from tflite_support import metadata # noqa + from tflite_support import metadata_schema_py_generated as schema # noqa + + # Create model info + model_meta = schema.ModelMetadataT() + model_meta.name = self.metadata["description"] + model_meta.version = self.metadata["version"] + model_meta.author = self.metadata["author"] + model_meta.license = self.metadata["license"] + + # Label file + tmp_file = Path(file).parent / "temp_meta.txt" + with open(tmp_file, "w", encoding="utf-8") as f: + f.write(str(self.metadata)) + + label_file = schema.AssociatedFileT() + label_file.name = tmp_file.name + label_file.type = schema.AssociatedFileType.TENSOR_AXIS_LABELS + + # Create input info + input_meta = schema.TensorMetadataT() + input_meta.name = "image" + input_meta.description = "Input image to be detected." 
+ input_meta.content = schema.ContentT() + input_meta.content.contentProperties = schema.ImagePropertiesT() + input_meta.content.contentProperties.colorSpace = schema.ColorSpaceType.RGB + input_meta.content.contentPropertiesType = schema.ContentProperties.ImageProperties + + # Create output info + output1 = schema.TensorMetadataT() + output1.name = "output" + output1.description = "Coordinates of detected objects, class labels, and confidence score" + output1.associatedFiles = [label_file] + if self.model.task == "segment": + output2 = schema.TensorMetadataT() + output2.name = "output" + output2.description = "Mask protos" + output2.associatedFiles = [label_file] + + # Create subgraph info + subgraph = schema.SubGraphMetadataT() + subgraph.inputTensorMetadata = [input_meta] + subgraph.outputTensorMetadata = [output1, output2] if self.model.task == "segment" else [output1] + model_meta.subgraphMetadata = [subgraph] + + b = flatbuffers.Builder(0) + b.Finish(model_meta.Pack(b), metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + metadata_buf = b.Output() + + populator = metadata.MetadataPopulator.with_model_file(str(file)) + populator.load_metadata_buffer(metadata_buf) + populator.load_associated_files([str(tmp_file)]) + populator.populate() + tmp_file.unlink() + + def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")): + """YOLO CoreML pipeline.""" + import coremltools as ct # noqa + + LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...") + _, _, h, w = list(self.im.shape) # BCHW + + # Output shapes + spec = model.get_spec() + out0, out1 = iter(spec.description.output) + if MACOS: + from PIL import Image + + img = Image.new("RGB", (w, h)) # w=192, h=320 + out = model.predict({"image": img}) + out0_shape = out[out0.name].shape # (3780, 80) + out1_shape = out[out1.name].shape # (3780, 4) + else: # linux and windows can not run model.predict(), get sizes from PyTorch model output y + out0_shape = 
self.output_shape[2], self.output_shape[1] - 4 # (3780, 80) + out1_shape = self.output_shape[2], 4 # (3780, 4) + + # Checks + names = self.metadata["names"] + nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height + _, nc = out0_shape # number of anchors, number of classes + assert len(names) == nc, f"{len(names)} names found for nc={nc}" # check + + # Define output shapes (missing) + out0.type.multiArrayType.shape[:] = out0_shape # (3780, 80) + out1.type.multiArrayType.shape[:] = out1_shape # (3780, 4) + + # Model from spec + model = ct.models.MLModel(spec, weights_dir=weights_dir) + + # 3. Create NMS protobuf + nms_spec = ct.proto.Model_pb2.Model() + nms_spec.specificationVersion = 5 + for i in range(2): + decoder_output = model._spec.description.output[i].SerializeToString() + nms_spec.description.input.add() + nms_spec.description.input[i].ParseFromString(decoder_output) + nms_spec.description.output.add() + nms_spec.description.output[i].ParseFromString(decoder_output) + + nms_spec.description.output[0].name = "confidence" + nms_spec.description.output[1].name = "coordinates" + + output_sizes = [nc, 4] + for i in range(2): + ma_type = nms_spec.description.output[i].type.multiArrayType + ma_type.shapeRange.sizeRanges.add() + ma_type.shapeRange.sizeRanges[0].lowerBound = 0 + ma_type.shapeRange.sizeRanges[0].upperBound = -1 + ma_type.shapeRange.sizeRanges.add() + ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i] + ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i] + del ma_type.shape[:] + + nms = nms_spec.nonMaximumSuppression + nms.confidenceInputFeatureName = out0.name # 1x507x80 + nms.coordinatesInputFeatureName = out1.name # 1x507x4 + nms.confidenceOutputFeatureName = "confidence" + nms.coordinatesOutputFeatureName = "coordinates" + nms.iouThresholdInputFeatureName = "iouThreshold" + nms.confidenceThresholdInputFeatureName = "confidenceThreshold" + nms.iouThreshold = self.args.iou + 
nms.confidenceThreshold = self.args.conf + nms.pickTop.perClass = True + nms.stringClassLabels.vector.extend(names.values()) + nms_model = ct.models.MLModel(nms_spec) + + # 4. Pipeline models together + pipeline = ct.models.pipeline.Pipeline( + input_features=[ + ("image", ct.models.datatypes.Array(3, ny, nx)), + ("iouThreshold", ct.models.datatypes.Double()), + ("confidenceThreshold", ct.models.datatypes.Double()), + ], + output_features=["confidence", "coordinates"], + ) + pipeline.add_model(model) + pipeline.add_model(nms_model) + + # Correct datatypes + pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString()) + pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString()) + pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString()) + + # Update metadata + pipeline.spec.specificationVersion = 5 + pipeline.spec.description.metadata.userDefined.update( + {"IoU threshold": str(nms.iouThreshold), "Confidence threshold": str(nms.confidenceThreshold)} + ) + + # Save the model + model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir) + model.input_description["image"] = "Input image" + model.input_description["iouThreshold"] = f"(optional) IoU threshold override (default: {nms.iouThreshold})" + model.input_description["confidenceThreshold"] = ( + f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})" + ) + model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")' + model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)" + LOGGER.info(f"{prefix} pipeline success") + return model + + def add_callback(self, event: str, callback): + """Appends the given callback.""" + self.callbacks[event].append(callback) + + def run_callbacks(self, event: str): + """Execute all callbacks for a given event.""" + for 
callback in self.callbacks.get(event, []): + callback(self) + + +class IOSDetectModel(torch.nn.Module): + """Wrap an Ultralytics YOLO model for Apple iOS CoreML export.""" + + def __init__(self, model, im): + """Initialize the IOSDetectModel class with a YOLO model and example image.""" + super().__init__() + _, _, h, w = im.shape # batch, channel, height, width + self.model = model + self.nc = len(model.names) # number of classes + if w == h: + self.normalize = 1.0 / w # scalar + else: + self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h]) # broadcast (slower, smaller) + + def forward(self, x): + """Normalize predictions of object detection model with input size-dependent factors.""" + xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1) + return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4) + + +class NMSModel(torch.nn.Module): + """Model wrapper with embedded NMS for Detect, Segment, Pose and OBB.""" + + def __init__(self, model, args): + """ + Initialize the NMSModel. + + Args: + model (torch.nn.module): The model to wrap with NMS postprocessing. + args (Namespace): The export arguments. + """ + super().__init__() + self.model = model + self.args = args + self.obb = model.task == "obb" + self.is_tf = self.args.format in frozenset({"saved_model", "tflite", "tfjs"}) + + def forward(self, x): + """ + Performs inference with NMS post-processing. Supports Detect, Segment, OBB and Pose. + + Args: + x (torch.Tensor): The preprocessed tensor with shape (N, 3, H, W). + + Returns: + (torch.Tensor): List of detections, each an (N, max_det, 4 + 2 + extra_shape) Tensor where N is the number of detections after NMS. 
+ """ + from functools import partial + + from torchvision.ops import nms + + preds = self.model(x) + pred = preds[0] if isinstance(preds, tuple) else preds + kwargs = dict(device=pred.device, dtype=pred.dtype) + bs = pred.shape[0] + pred = pred.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84) + extra_shape = pred.shape[-1] - (4 + len(self.model.names)) # extras from Segment, OBB, Pose + if self.args.dynamic and self.args.batch > 1: # batch size needs to always be same due to loop unroll + pad = torch.zeros(torch.max(torch.tensor(self.args.batch - bs), torch.tensor(0)), *pred.shape[1:], **kwargs) + pred = torch.cat((pred, pad)) + boxes, scores, extras = pred.split([4, len(self.model.names), extra_shape], dim=2) + scores, classes = scores.max(dim=-1) + self.args.max_det = min(pred.shape[1], self.args.max_det) # in case num_anchors < max_det + # (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape). + out = torch.zeros(bs, self.args.max_det, boxes.shape[-1] + 2 + extra_shape, **kwargs) + for i in range(bs): + box, cls, score, extra = boxes[i], classes[i], scores[i], extras[i] + mask = score > self.args.conf + if self.is_tf: + # TFLite GatherND error if mask is empty + score *= mask + # Explicit length otherwise reshape error, hardcoded to `self.args.max_det * 5` + mask = score.topk(min(self.args.max_det * 5, score.shape[0])).indices + box, score, cls, extra = box[mask], score[mask], cls[mask], extra[mask] + if not self.obb: + box = xywh2xyxy(box) + if self.is_tf: + # TFlite bug returns less boxes + box = torch.nn.functional.pad(box, (0, 0, 0, mask.shape[0] - box.shape[0])) + nmsbox = box.clone() + # `8` is the minimum value experimented to get correct NMS results for obb + multiplier = 8 if self.obb else 1 + # Normalize boxes for NMS since large values for class offset causes issue with int8 quantization + if self.args.format == "tflite": # TFLite is already normalized + nmsbox *= multiplier + else: + nmsbox = multiplier * nmsbox / 
torch.tensor(x.shape[2:], **kwargs).max() + if not self.args.agnostic_nms: # class-specific NMS + end = 2 if self.obb else 4 + # fully explicit expansion otherwise reshape error + # large max_wh causes issues when quantizing + cls_offset = cls.reshape(-1, 1).expand(nmsbox.shape[0], end) + offbox = nmsbox[:, :end] + cls_offset * multiplier + nmsbox = torch.cat((offbox, nmsbox[:, end:]), dim=-1) + nms_fn = ( + partial( + nms_rotated, + use_triu=not ( + self.is_tf + or (self.args.opset or 14) < 14 + or (self.args.format == "openvino" and self.args.int8) # OpenVINO int8 error with triu + ), + ) + if self.obb + else nms + ) + keep = nms_fn( + torch.cat([nmsbox, extra], dim=-1) if self.obb else nmsbox, + score, + self.args.iou, + )[: self.args.max_det] + dets = torch.cat( + [box[keep], score[keep].view(-1, 1), cls[keep].view(-1, 1).to(out.dtype), extra[keep]], dim=-1 + ) + # Zero-pad to max_det size to avoid reshape error + pad = (0, 0, 0, self.args.max_det - dets.shape[0]) + out[i] = torch.nn.functional.pad(dets, pad) + return (out[:bs], preds[1]) if self.model.task == "segment" else out[:bs] diff --git a/tracking/ultralytics/engine/model.py b/tracking/ultralytics/engine/model.py new file mode 100644 index 0000000000000000000000000000000000000000..4ff1c039aefa27f578fc0635d49bf32ec06c87b2 --- /dev/null +++ b/tracking/ultralytics/engine/model.py @@ -0,0 +1,1156 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import inspect +from pathlib import Path +from typing import Any, Dict, List, Union + +import numpy as np +import torch +from PIL import Image + +from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir +from ultralytics.engine.results import Results +from ultralytics.hub import HUB_WEB_ROOT, HUBTrainingSession +from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, yaml_model_load +from ultralytics.utils import ( + ARGV, + ASSETS, + DEFAULT_CFG_DICT, + LOGGER, + RANK, + SETTINGS, + callbacks, + checks, + emojis, + 
yaml_load, +) + + +class Model(torch.nn.Module): + """ + A base class for implementing YOLO models, unifying APIs across different model types. + + This class provides a common interface for various operations related to YOLO models, such as training, + validation, prediction, exporting, and benchmarking. It handles different types of models, including those + loaded from local files, Ultralytics HUB, or Triton Server. + + Attributes: + callbacks (dict): A dictionary of callback functions for various events during model operations. + predictor (BasePredictor): The predictor object used for making predictions. + model (torch.nn.Module): The underlying PyTorch model. + trainer (BaseTrainer): The trainer object used for training the model. + ckpt (dict): The checkpoint data if the model is loaded from a *.pt file. + cfg (str): The configuration of the model if loaded from a *.yaml file. + ckpt_path (str): The path to the checkpoint file. + overrides (dict): A dictionary of overrides for model configuration. + metrics (dict): The latest training/validation metrics. + session (HUBTrainingSession): The Ultralytics HUB session, if applicable. + task (str): The type of task the model is intended for. + model_name (str): The name of the model. + + Methods: + __call__: Alias for the predict method, enabling the model instance to be callable. + _new: Initializes a new model based on a configuration file. + _load: Loads a model from a checkpoint file. + _check_is_pytorch_model: Ensures that the model is a PyTorch model. + reset_weights: Resets the model's weights to their initial state. + load: Loads model weights from a specified file. + save: Saves the current state of the model to a file. + info: Logs or returns information about the model. + fuse: Fuses Conv2d and BatchNorm2d layers for optimized inference. + predict: Performs object detection predictions. + track: Performs object tracking. + val: Validates the model on a dataset. 
+ benchmark: Benchmarks the model on various export formats. + export: Exports the model to different formats. + train: Trains the model on a dataset. + tune: Performs hyperparameter tuning. + _apply: Applies a function to the model's tensors. + add_callback: Adds a callback function for an event. + clear_callback: Clears all callbacks for an event. + reset_callbacks: Resets all callbacks to their default functions. + + Examples: + >>> from ultralytics import YOLO + >>> model = YOLO("yolo11n.pt") + >>> results = model.predict("image.jpg") + >>> model.train(data="coco8.yaml", epochs=3) + >>> metrics = model.val() + >>> model.export(format="onnx") + """ + + def __init__( + self, + model: Union[str, Path] = "yolo11n.pt", + task: str = None, + verbose: bool = False, + ) -> None: + """ + Initialize a new instance of the YOLO model class. + + This constructor sets up the model based on the provided model path or name. It handles various types of + model sources, including local files, Ultralytics HUB models, and Triton Server models. The method + initializes several important attributes of the model and prepares it for operations like training, + prediction, or export. + + Args: + model (str | Path): Path or name of the model to load or create. Can be a local file path, a + model name from Ultralytics HUB, or a Triton Server model. + task (str | None): The task type associated with the YOLO model, specifying its application domain. + verbose (bool): If True, enables verbose output during the model's initialization and subsequent + operations. + + Raises: + FileNotFoundError: If the specified model file does not exist or is inaccessible. + ValueError: If the model file or configuration is invalid or unsupported. + ImportError: If required dependencies for specific model types (like HUB SDK) are not installed. 
+ + Examples: + >>> model = Model("yolo11n.pt") + >>> model = Model("path/to/model.yaml", task="detect") + >>> model = Model("hub_model", verbose=True) + """ + super().__init__() + self.callbacks = callbacks.get_default_callbacks() + self.predictor = None # reuse predictor + self.model = None # model object + self.trainer = None # trainer object + self.ckpt = {} # if loaded from *.pt + self.cfg = None # if loaded from *.yaml + self.ckpt_path = None + self.overrides = {} # overrides for trainer object + self.metrics = None # validation/training metrics + self.session = None # HUB session + self.task = task # task type + self.model_name = None # model name + model = str(model).strip() + + # Check if Ultralytics HUB model from https://hub.ultralytics.com + if self.is_hub_model(model): + # Fetch model from HUB + checks.check_requirements("hub-sdk>=0.0.12") + session = HUBTrainingSession.create_session(model) + model = session.model_file + if session.train_args: # training sent from HUB + self.session = session + + # Check if Triton Server model + elif self.is_triton_model(model): + self.model_name = self.model = model + self.overrides["task"] = task or "detect" # set `task=detect` if not explicitly set + return + + # Load or create new YOLO model + __import__("os").environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # to avoid deterministic warnings + if Path(model).suffix in {".yaml", ".yml"}: + self._new(model, task=task, verbose=verbose) + else: + self._load(model, task=task) + + # Delete super().training for accessing self.model.training + del self.training + + def __call__( + self, + source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None, + stream: bool = False, + **kwargs: Any, + ) -> list: + """ + Alias for the predict method, enabling the model instance to be callable for predictions. + + This method simplifies the process of making predictions by allowing the model instance to be called + directly with the required arguments. 
+ + Args: + source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source of + the image(s) to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch + tensor, or a list/tuple of these. + stream (bool): If True, treat the input source as a continuous stream for predictions. + **kwargs (Any): Additional keyword arguments to configure the prediction process. + + Returns: + (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a + Results object. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model("https://ultralytics.com/images/bus.jpg") + >>> for r in results: + ... print(f"Detected {len(r)} objects in image") + """ + return self.predict(source, stream, **kwargs) + + @staticmethod + def is_triton_model(model: str) -> bool: + """ + Check if the given model string is a Triton Server URL. + + This static method determines whether the provided model string represents a valid Triton Server URL by + parsing its components using urllib.parse.urlsplit(). + + Args: + model (str): The model string to be checked. + + Returns: + (bool): True if the model string is a valid Triton Server URL, False otherwise. + + Examples: + >>> Model.is_triton_model("http://localhost:8000/v2/models/yolo11n") + True + >>> Model.is_triton_model("yolo11n.pt") + False + """ + from urllib.parse import urlsplit + + url = urlsplit(model) + return url.netloc and url.path and url.scheme in {"http", "grpc"} + + @staticmethod + def is_hub_model(model: str) -> bool: + """ + Check if the provided model is an Ultralytics HUB model. + + This static method determines whether the given model string represents a valid Ultralytics HUB model + identifier. + + Args: + model (str): The model string to check. + + Returns: + (bool): True if the model is a valid Ultralytics HUB model, False otherwise. 
+ + Examples: + >>> Model.is_hub_model("https://hub.ultralytics.com/models/MODEL") + True + >>> Model.is_hub_model("yolo11n.pt") + False + """ + return model.startswith(f"{HUB_WEB_ROOT}/models/") + + def _new(self, cfg: str, task=None, model=None, verbose=False) -> None: + """ + Initialize a new model and infer the task type from model definitions. + + Creates a new model instance based on the provided configuration file. Loads the model configuration, infers + the task type if not specified, and initializes the model using the appropriate class from the task map. + + Args: + cfg (str): Path to the model configuration file in YAML format. + task (str | None): The specific task for the model. If None, it will be inferred from the config. + model (torch.nn.Module | None): A custom model instance. If provided, it will be used instead of creating + a new one. + verbose (bool): If True, displays model information during loading. + + Raises: + ValueError: If the configuration file is invalid or the task cannot be inferred. + ImportError: If the required dependencies for the specified task are not installed. + + Examples: + >>> model = Model() + >>> model._new("yolo11n.yaml", task="detect", verbose=True) + """ + cfg_dict = yaml_model_load(cfg) + self.cfg = cfg + self.task = task or guess_model_task(cfg_dict) + self.model = (model or self._smart_load("model"))(cfg_dict, verbose=verbose and RANK == -1) # build model + self.overrides["model"] = self.cfg + self.overrides["task"] = self.task + + # Below added to allow export from YAMLs + self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args) + self.model.task = self.task + self.model_name = cfg + + def _load(self, weights: str, task=None) -> None: + """ + Load a model from a checkpoint file or initialize it from a weights file. + + This method handles loading models from either .pt checkpoint files or other weight file formats. 
It sets + up the model, task, and related attributes based on the loaded weights. + + Args: + weights (str): Path to the model weights file to be loaded. + task (str | None): The task associated with the model. If None, it will be inferred from the model. + + Raises: + FileNotFoundError: If the specified weights file does not exist or is inaccessible. + ValueError: If the weights file format is unsupported or invalid. + + Examples: + >>> model = Model() + >>> model._load("yolo11n.pt") + >>> model._load("path/to/weights.pth", task="detect") + """ + if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): + weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"]) # download and return local file + weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolo11n -> yolo11n.pt + + if Path(weights).suffix == ".pt": + self.model, self.ckpt = attempt_load_one_weight(weights) + self.task = self.model.args["task"] + self.overrides = self.model.args = self._reset_ckpt_args(self.model.args) + self.ckpt_path = self.model.pt_path + else: + weights = checks.check_file(weights) # runs in all cases, not redundant with above call + self.model, self.ckpt = weights, None + self.task = task or guess_model_task(weights) + self.ckpt_path = weights + self.overrides["model"] = weights + self.overrides["task"] = self.task + self.model_name = weights + + def _check_is_pytorch_model(self) -> None: + """ + Check if the model is a PyTorch model and raise TypeError if it's not. + + This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that + certain operations that require a PyTorch model are only performed on compatible model types. + + Raises: + TypeError: If the model is not a PyTorch module or a .pt file. The error message provides detailed + information about supported model formats and operations. 
+ + Examples: + >>> model = Model("yolo11n.pt") + >>> model._check_is_pytorch_model() # No error raised + >>> model = Model("yolo11n.onnx") + >>> model._check_is_pytorch_model() # Raises TypeError + """ + pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt" + pt_module = isinstance(self.model, torch.nn.Module) + if not (pt_module or pt_str): + raise TypeError( + f"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. " + f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported " + f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, " + f"i.e. 'yolo predict model=yolo11n.onnx'.\nTo run CUDA or MPS inference please pass the device " + f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'" + ) + + def reset_weights(self) -> "Model": + """ + Reset the model's weights to their initial state. + + This method iterates through all modules in the model and resets their parameters if they have a + 'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True, + enabling them to be updated during training. + + Returns: + (Model): The instance of the class with reset weights. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model.reset_weights() + """ + self._check_is_pytorch_model() + for m in self.model.modules(): + if hasattr(m, "reset_parameters"): + m.reset_parameters() + for p in self.model.parameters(): + p.requires_grad = True + return self + + def load(self, weights: Union[str, Path] = "yolo11n.pt") -> "Model": + """ + Load parameters from the specified weights file into the model. + + This method supports loading weights from a file or directly from a weights object. It matches parameters by + name and shape and transfers them to the model. 
+ + Args: + weights (Union[str, Path]): Path to the weights file or a weights object. + + Returns: + (Model): The instance of the class with loaded weights. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = Model() + >>> model.load("yolo11n.pt") + >>> model.load(Path("path/to/weights.pt")) + """ + self._check_is_pytorch_model() + if isinstance(weights, (str, Path)): + self.overrides["pretrained"] = weights # remember the weights for DDP training + weights, self.ckpt = attempt_load_one_weight(weights) + self.model.load(weights) + return self + + def save(self, filename: Union[str, Path] = "saved_model.pt") -> None: + """ + Save the current model state to a file. + + This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as + the date, Ultralytics version, license information, and a link to the documentation. + + Args: + filename (str | Path): The name of the file to save the model to. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model.save("my_model.pt") + """ + self._check_is_pytorch_model() + from copy import deepcopy + from datetime import datetime + + from ultralytics import __version__ + + updates = { + "model": deepcopy(self.model).half() if isinstance(self.model, torch.nn.Module) else self.model, + "date": datetime.now().isoformat(), + "version": __version__, + "license": "AGPL-3.0 License (https://ultralytics.com/license)", + "docs": "https://docs.ultralytics.com", + } + torch.save({**self.ckpt, **updates}, filename) + + def info(self, detailed: bool = False, verbose: bool = True): + """ + Display model information. + + This method provides an overview or detailed information about the model, depending on the arguments + passed. It can control the verbosity of the output and return the information as a list. 
+ + Args: + detailed (bool): If True, shows detailed information about the model layers and parameters. + verbose (bool): If True, prints the information. If False, returns the information as a list. + + Returns: + (List[str]): A list of strings containing various types of information about the model, including + model summary, layer details, and parameter counts. Empty if verbose is True. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model.info() # Prints model summary + >>> info_list = model.info(detailed=True, verbose=False) # Returns detailed info as a list + """ + self._check_is_pytorch_model() + return self.model.info(detailed=detailed, verbose=verbose) + + def fuse(self) -> None: + """ + Fuse Conv2d and BatchNorm2d layers in the model for optimized inference. + + This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers + into a single layer. This fusion can significantly improve inference speed by reducing the number of + operations and memory accesses required during forward passes. + + The fusion process typically involves folding the BatchNorm2d parameters (mean, variance, weight, and + bias) into the preceding Conv2d layer's weights and biases. This results in a single Conv2d layer that + performs both convolution and normalization in one step. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model.fuse() + >>> # Model is now fused and ready for optimized inference + """ + self._check_is_pytorch_model() + self.model.fuse() + + def embed( + self, + source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None, + stream: bool = False, + **kwargs: Any, + ) -> list: + """ + Generate image embeddings based on the provided source. + + This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image + source. It allows customization of the embedding process through various keyword arguments. 
+ + Args: + source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor): The source of the image for + generating embeddings. Can be a file path, URL, PIL image, numpy array, etc. + stream (bool): If True, predictions are streamed. + **kwargs (Any): Additional keyword arguments for configuring the embedding process. + + Returns: + (List[torch.Tensor]): A list containing the image embeddings. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> image = "https://ultralytics.com/images/bus.jpg" + >>> embeddings = model.embed(image) + >>> print(embeddings[0].shape) + """ + if not kwargs.get("embed"): + kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed + return self.predict(source, stream, **kwargs) + + def predict( + self, + source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None, + stream: bool = False, + predictor=None, + **kwargs: Any, + ) -> List[Results]: + """ + Performs predictions on the given image source using the YOLO model. + + This method facilitates the prediction process, allowing various configurations through keyword arguments. + It supports predictions with custom predictors or the default predictor method. The method handles different + types of image sources and can operate in a streaming mode. + + Args: + source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source + of the image(s) to make predictions on. Accepts various types including file paths, URLs, PIL + images, numpy arrays, and torch tensors. + stream (bool): If True, treats the input source as a continuous stream for predictions. + predictor (BasePredictor | None): An instance of a custom predictor class for making predictions. + If None, the method uses a default predictor. + **kwargs (Any): Additional keyword arguments for configuring the prediction process. 
+ + Returns: + (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a + Results object. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.predict(source="path/to/image.jpg", conf=0.25) + >>> for r in results: + ... print(r.boxes.data) # print detection bounding boxes + + Notes: + - If 'source' is not provided, it defaults to the ASSETS constant with a warning. + - The method sets up a new predictor if not already present and updates its arguments with each call. + - For SAM-type models, 'prompts' can be passed as a keyword argument. + """ + if source is None: + source = ASSETS + LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") + + is_cli = (ARGV[0].endswith("yolo") or ARGV[0].endswith("ultralytics")) and any( + x in ARGV for x in ("predict", "track", "mode=predict", "mode=track") + ) + + custom = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"} # method defaults + args = {**self.overrides, **custom, **kwargs} # highest priority args on the right + prompts = args.pop("prompts", None) # for SAM-type models + + if not self.predictor: + self.predictor = (predictor or self._smart_load("predictor"))(overrides=args, _callbacks=self.callbacks) + self.predictor.setup_model(model=self.model, verbose=is_cli) + else: # only update args if predictor is already setup + self.predictor.args = get_cfg(self.predictor.args, args) + if "project" in args or "name" in args: + self.predictor.save_dir = get_save_dir(self.predictor.args) + if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models + self.predictor.set_prompts(prompts) + return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream) + + def track( + self, + source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None, + stream: bool = False, + persist: bool = False, + **kwargs: Any, + ) -> List[Results]: + """ + Conducts object tracking on the 
specified input source using the registered trackers. + + This method performs object tracking using the model's predictors and optionally registered trackers. It handles + various input sources such as file paths or video streams, and supports customization through keyword arguments. + The method registers trackers if not already present and can persist them between calls. + + Args: + source (Union[str, Path, int, List, Tuple, np.ndarray, torch.Tensor], optional): Input source for object + tracking. Can be a file path, URL, or video stream. + stream (bool): If True, treats the input source as a continuous video stream. + persist (bool): If True, persists trackers between different calls to this method. + **kwargs (Any): Additional keyword arguments for configuring the tracking process. + + Returns: + (List[ultralytics.engine.results.Results]): A list of tracking results, each a Results object. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.track(source="path/to/video.mp4", show=True) + >>> for r in results: + ... print(r.boxes.id) # print tracking IDs + + Notes: + - This method sets a default confidence threshold of 0.1 for ByteTrack-based tracking. + - The tracking mode is explicitly set in the keyword arguments. + - Batch size is set to 1 for tracking in videos. + """ + if not hasattr(self.predictor, "trackers"): + from ultralytics.trackers import register_tracker + + register_tracker(self, persist) + kwargs["conf"] = kwargs.get("conf") or 0.1 # ByteTrack-based method needs low confidence predictions as input + kwargs["batch"] = kwargs.get("batch") or 1 # batch-size 1 for tracking in videos + kwargs["mode"] = "track" + return self.predict(source=source, stream=stream, **kwargs) + + def val( + self, + validator=None, + **kwargs: Any, + ): + """ + Validate the model using a specified dataset and validation configuration. + + This method facilitates the model validation process, allowing for customization through various settings. 
It + supports validation with a custom validator or the default validation approach. The method combines default + configurations, method-specific defaults, and user-provided arguments to configure the validation process. + + Args: + validator (ultralytics.engine.validator.BaseValidator | None): An instance of a custom validator class for + validating the model. + **kwargs (Any): Arbitrary keyword arguments for customizing the validation process. + + Returns: + (ultralytics.utils.metrics.DetMetrics): Validation metrics obtained from the validation process. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.val(data="coco8.yaml", imgsz=640) + >>> print(results.box.map) # Print mAP50-95 + """ + custom = {"rect": True} # method defaults + args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right + + validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks) + validator(model=self.model) + self.metrics = validator.metrics + return validator.metrics + + def benchmark( + self, + **kwargs: Any, + ): + """ + Benchmark the model across various export formats to evaluate performance. + + This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc. + It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is + configured using a combination of default configuration values, model-specific arguments, method-specific + defaults, and any additional user-provided keyword arguments. + + Args: + **kwargs (Any): Arbitrary keyword arguments to customize the benchmarking process. Common options include: + - data (str): Path to the dataset for benchmarking. + - imgsz (int | List[int]): Image size for benchmarking. + - half (bool): Whether to use half-precision (FP16) mode. + - int8 (bool): Whether to use int8 precision mode. 
+ - device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda'). + - verbose (bool): Whether to print detailed benchmark information. + - format (str): Export format name for specific benchmarking. + + Returns: + (dict): A dictionary containing the results of the benchmarking process, including metrics for + different export formats. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.benchmark(data="coco8.yaml", imgsz=640, half=True) + >>> print(results) + """ + self._check_is_pytorch_model() + from ultralytics.utils.benchmarks import benchmark + + custom = {"verbose": False} # method defaults + args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"} + return benchmark( + model=self, + data=kwargs.get("data"), # if no 'data' argument passed set data=None for default datasets + imgsz=args["imgsz"], + half=args["half"], + int8=args["int8"], + device=args["device"], + verbose=kwargs.get("verbose", False), + format=kwargs.get("format", ""), + ) + + def export( + self, + **kwargs: Any, + ) -> str: + """ + Export the model to a different format suitable for deployment. + + This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment + purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method + defaults, and any additional arguments provided. + + Args: + **kwargs (Any): Arbitrary keyword arguments to customize the export process. These are combined with + the model's overrides and method defaults. Common arguments include: + format (str): Export format (e.g., 'onnx', 'engine', 'coreml'). + half (bool): Export model in half-precision. + int8 (bool): Export model in int8 precision. + device (str): Device to run the export on. + workspace (int): Maximum memory workspace size for TensorRT engines. + nms (bool): Add Non-Maximum Suppression (NMS) module to model. 
+ simplify (bool): Simplify ONNX model. + + Returns: + (str): The path to the exported model file. + + Raises: + AssertionError: If the model is not a PyTorch model. + ValueError: If an unsupported export format is specified. + RuntimeError: If the export process fails due to errors. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> model.export(format="onnx", dynamic=True, simplify=True) + 'path/to/exported/model.onnx' + """ + self._check_is_pytorch_model() + from .exporter import Exporter + + custom = { + "imgsz": self.model.args["imgsz"], + "batch": 1, + "data": None, + "device": None, # reset to avoid multi-GPU errors + "verbose": False, + } # method defaults + args = {**self.overrides, **custom, **kwargs, "mode": "export"} # highest priority args on the right + return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model) + + def train( + self, + trainer=None, + **kwargs: Any, + ): + """ + Trains the model using the specified dataset and training configuration. + + This method facilitates model training with a range of customizable settings. It supports training with a + custom trainer or the default training approach. The method handles scenarios such as resuming training + from a checkpoint, integrating with Ultralytics HUB, and updating model and configuration after training. + + When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training + arguments and warns if local arguments are provided. It checks for pip updates and combines default + configurations, method-specific defaults, and user-provided arguments to configure the training process. + + Args: + trainer (BaseTrainer | None): Custom trainer instance for model training. If None, uses default. + **kwargs (Any): Arbitrary keyword arguments for training configuration. Common options include: + data (str): Path to dataset configuration file. + epochs (int): Number of training epochs. + batch_size (int): Batch size for training. 
+ imgsz (int): Input image size. + device (str): Device to run training on (e.g., 'cuda', 'cpu'). + workers (int): Number of worker threads for data loading. + optimizer (str): Optimizer to use for training. + lr0 (float): Initial learning rate. + patience (int): Epochs to wait for no observable improvement for early stopping of training. + + Returns: + (Dict | None): Training metrics if available and training is successful; otherwise, None. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.train(data="coco8.yaml", epochs=3) + """ + self._check_is_pytorch_model() + if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model + if any(kwargs): + LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.") + kwargs = self.session.train_args # overwrite kwargs + + checks.check_pip_update_available() + + overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides + custom = { + # NOTE: handle the case when 'cfg' includes 'data'. 
+ "data": overrides.get("data") or DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task], + "model": self.overrides["model"], + "task": self.task, + } # method defaults + args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right + if args.get("resume"): + args["resume"] = self.ckpt_path + + self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks) + if not args.get("resume"): # manually set model only if not resuming + self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml) + self.model = self.trainer.model + + self.trainer.hub_session = self.session # attach optional HUB session + self.trainer.train() + # Update model and cfg after training + if RANK in {-1, 0}: + ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last + self.model, self.ckpt = attempt_load_one_weight(ckpt) + self.overrides = self.model.args + self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP + return self.metrics + + def tune( + self, + use_ray=False, + iterations=10, + *args: Any, + **kwargs: Any, + ): + """ + Conducts hyperparameter tuning for the model, with an option to use Ray Tune. + + This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method. + When Ray Tune is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module. + Otherwise, it uses the internal 'Tuner' class for tuning. The method combines default, overridden, and + custom arguments to configure the tuning process. + + Args: + use_ray (bool): Whether to use Ray Tune for hyperparameter tuning. If False, uses internal tuning method. + iterations (int): Number of tuning iterations to perform. + *args (Any): Additional positional arguments to pass to the tuner. + **kwargs (Any): Additional keyword arguments for tuning configuration. 
These are combined with model + overrides and defaults to configure the tuning process. + + Returns: + (dict): Results of the hyperparameter search, including best parameters and performance metrics. + + Raises: + TypeError: If the model is not a PyTorch model. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.tune(data="coco8.yaml", iterations=5) + >>> print(results) + + # Use Ray Tune for more advanced hyperparameter search + >>> results = model.tune(use_ray=True, iterations=20, data="coco8.yaml") + """ + self._check_is_pytorch_model() + if use_ray: + from ultralytics.utils.tuner import run_ray_tune + + return run_ray_tune(self, max_samples=iterations, *args, **kwargs) + else: + from .tuner import Tuner + + custom = {} # method defaults + args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right + return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations) + + def _apply(self, fn) -> "Model": + """ + Apply a function to model tensors that are not parameters or registered buffers. + + This method extends the functionality of the parent class's _apply method by additionally resetting the + predictor and updating the device in the model's overrides. It's typically used for operations like + moving the model to a different device or changing its precision. + + Args: + fn (Callable): A function to be applied to the model's tensors. This is typically a method like + to(), cpu(), cuda(), half(), or float(). + + Returns: + (Model): The model instance with the function applied and updated attributes. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model = model._apply(lambda t: t.cuda()) # Move model to GPU + """ + self._check_is_pytorch_model() + self = super()._apply(fn) # noqa + self.predictor = None # reset predictor as device may have changed + self.overrides["device"] = self.device # was str(self.device) i.e. 
device(type='cuda', index=0) -> 'cuda:0' + return self + + @property + def names(self) -> Dict[int, str]: + """ + Retrieves the class names associated with the loaded model. + + This property returns the class names if they are defined in the model. It checks the class names for validity + using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not + initialized, it sets it up before retrieving the names. + + Returns: + (Dict[int, str]): A dictionary of class names associated with the model, where keys are class indices and + values are the corresponding class names. + + Raises: + AttributeError: If the model or predictor does not have a 'names' attribute. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> print(model.names) + {0: 'person', 1: 'bicycle', 2: 'car', ...} + """ + from ultralytics.nn.autobackend import check_class_names + + if hasattr(self.model, "names"): + return check_class_names(self.model.names) + if not self.predictor: # export formats will not have predictor defined until predict() is called + self.predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks) + self.predictor.setup_model(model=self.model, verbose=False) + return self.predictor.model.names + + @property + def device(self) -> torch.device: + """ + Get the device on which the model's parameters are allocated. + + This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is + applicable only to models that are instances of torch.nn.Module. + + Returns: + (torch.device): The device (CPU/GPU) of the model. + + Raises: + AttributeError: If the model is not a torch.nn.Module instance. 
+ + Examples: + >>> model = YOLO("yolo11n.pt") + >>> print(model.device) + device(type='cuda', index=0) # if CUDA is available + >>> model = model.to("cpu") + >>> print(model.device) + device(type='cpu') + """ + return next(self.model.parameters()).device if isinstance(self.model, torch.nn.Module) else None + + @property + def transforms(self): + """ + Retrieves the transformations applied to the input data of the loaded model. + + This property returns the transformations if they are defined in the model. The transforms + typically include preprocessing steps like resizing, normalization, and data augmentation + that are applied to input data before it is fed into the model. + + Returns: + (object | None): The transform object of the model if available, otherwise None. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> transforms = model.transforms + >>> if transforms: + ... print(f"Model transforms: {transforms}") + ... else: + ... print("No transforms defined for this model.") + """ + return self.model.transforms if hasattr(self.model, "transforms") else None + + def add_callback(self, event: str, func) -> None: + """ + Add a callback function for a specified event. + + This method allows registering custom callback functions that are triggered on specific events during + model operations such as training or inference. Callbacks provide a way to extend and customize the + behavior of the model at various stages of its lifecycle. + + Args: + event (str): The name of the event to attach the callback to. Must be a valid event name recognized + by the Ultralytics framework. + func (Callable): The callback function to be registered. This function will be called when the + specified event occurs. + + Raises: + ValueError: If the event name is not recognized or is invalid. + + Examples: + >>> def on_train_start(trainer): + ... 
print("Training is starting!") + >>> model = YOLO("yolo11n.pt") + >>> model.add_callback("on_train_start", on_train_start) + >>> model.train(data="coco8.yaml", epochs=1) + """ + self.callbacks[event].append(func) + + def clear_callback(self, event: str) -> None: + """ + Clears all callback functions registered for a specified event. + + This method removes all custom and default callback functions associated with the given event. + It resets the callback list for the specified event to an empty list, effectively removing all + registered callbacks for that event. + + Args: + event (str): The name of the event for which to clear the callbacks. This should be a valid event name + recognized by the Ultralytics callback system. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> model.add_callback("on_train_start", lambda: print("Training started")) + >>> model.clear_callback("on_train_start") + >>> # All callbacks for 'on_train_start' are now removed + + Notes: + - This method affects both custom callbacks added by the user and default callbacks + provided by the Ultralytics framework. + - After calling this method, no callbacks will be executed for the specified event + until new ones are added. + - Use with caution as it removes all callbacks, including essential ones that might + be required for proper functioning of certain operations. + """ + self.callbacks[event] = [] + + def reset_callbacks(self) -> None: + """ + Reset all callbacks to their default functions. + + This method reinstates the default callback functions for all events, removing any custom callbacks that were + previously added. It iterates through all default callback events and replaces the current callbacks with the + default ones. + + The default callbacks are defined in the 'callbacks.default_callbacks' dictionary, which contains predefined + functions for various events in the model's lifecycle, such as on_train_start, on_epoch_end, etc. 
+ + This method is useful when you want to revert to the original set of callbacks after making custom + modifications, ensuring consistent behavior across different runs or experiments. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> model.add_callback("on_train_start", custom_function) + >>> model.reset_callbacks() + # All callbacks are now reset to their default functions + """ + for event in callbacks.default_callbacks.keys(): + self.callbacks[event] = [callbacks.default_callbacks[event][0]] + + @staticmethod + def _reset_ckpt_args(args: dict) -> dict: + """ + Reset specific arguments when loading a PyTorch model checkpoint. + + This method filters the input arguments dictionary to retain only a specific set of keys that are + considered important for model loading. It's used to ensure that only relevant arguments are preserved + when loading a model from a checkpoint, discarding any unnecessary or potentially conflicting settings. + + Args: + args (dict): A dictionary containing various model arguments and settings. + + Returns: + (dict): A new dictionary containing only the specified include keys from the input arguments. + + Examples: + >>> original_args = {"imgsz": 640, "data": "coco.yaml", "task": "detect", "batch": 16, "epochs": 100} + >>> reset_args = Model._reset_ckpt_args(original_args) + >>> print(reset_args) + {'imgsz': 640, 'data': 'coco.yaml', 'task': 'detect'} + """ + include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model + return {k: v for k, v in args.items() if k in include} + + # def __getattr__(self, attr): + # """Raises error if object has no requested attribute.""" + # name = self.__class__.__name__ + # raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") + + def _smart_load(self, key: str): + """ + Intelligently loads the appropriate module based on the model task. 
+ + This method dynamically selects and returns the correct module (model, trainer, validator, or predictor) + based on the current task of the model and the provided key. It uses the task_map dictionary to determine + the appropriate module to load for the specific task. + + Args: + key (str): The type of module to load. Must be one of 'model', 'trainer', 'validator', or 'predictor'. + + Returns: + (object): The loaded module class corresponding to the specified key and current task. + + Raises: + NotImplementedError: If the specified key is not supported for the current task. + + Examples: + >>> model = Model(task="detect") + >>> predictor_class = model._smart_load("predictor") + >>> trainer_class = model._smart_load("trainer") + """ + try: + return self.task_map[self.task][key] + except Exception as e: + name = self.__class__.__name__ + mode = inspect.stack()[1][3] # get the function name. + raise NotImplementedError( + emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.") + ) from e + + @property + def task_map(self) -> dict: + """ + Provides a mapping from model tasks to corresponding classes for different modes. + + This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify) + to a nested dictionary. The nested dictionary contains mappings for different operational modes + (model, trainer, validator, predictor) to their respective class implementations. + + The mapping allows for dynamic loading of appropriate classes based on the model's task and the + desired operational mode. This facilitates a flexible and extensible architecture for handling + various tasks and modes within the Ultralytics framework. + + Returns: + (Dict[str, Dict[str, Any]]): A dictionary mapping task names to nested dictionaries. Each nested dictionary + contains mappings for 'model', 'trainer', 'validator', and 'predictor' keys to their respective class + implementations for that task. 
+ + Examples: + >>> model = Model("yolo11n.pt") + >>> task_map = model.task_map + >>> detect_predictor = task_map["detect"]["predictor"] + >>> segment_trainer = task_map["segment"]["trainer"] + """ + raise NotImplementedError("Please provide task map for your model!") + + def eval(self): + """ + Sets the model to evaluation mode. + + This method changes the model's mode to evaluation, which affects layers like dropout and batch normalization + that behave differently during training and evaluation. In evaluation mode, these layers use running statistics + rather than computing batch statistics, and dropout layers are disabled. + + Returns: + (Model): The model instance with evaluation mode set. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> model.eval() + >>> # Model is now in evaluation mode for inference + """ + self.model.eval() + return self + + def __getattr__(self, name): + """ + Enable accessing model attributes directly through the Model class. + + This method provides a way to access attributes of the underlying model directly through the Model class + instance. It first checks if the requested attribute is 'model', in which case it returns the model from + the module dictionary. Otherwise, it delegates the attribute lookup to the underlying model. + + Args: + name (str): The name of the attribute to retrieve. + + Returns: + (Any): The requested attribute value. + + Raises: + AttributeError: If the requested attribute does not exist in the model. 
+ + Examples: + >>> model = YOLO("yolo11n.pt") + >>> print(model.stride) # Access model.stride attribute + >>> print(model.names) # Access model.names attribute + """ + return self._modules["model"] if name == "model" else getattr(self.model, name) diff --git a/tracking/ultralytics/engine/predictor.py b/tracking/ultralytics/engine/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..a095825da6d7d7587f5a651a8637c3ab699995f5 --- /dev/null +++ b/tracking/ultralytics/engine/predictor.py @@ -0,0 +1,495 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc. + +Usage - sources: + $ yolo mode=predict model=yolo11n.pt source=0 # webcam + img.jpg # image + vid.mp4 # video + screen # screenshot + path/ # directory + list.txt # list of images + list.streams # list of streams + 'path/*.jpg' # glob + 'https://youtu.be/LNwODJXcvt4' # YouTube + 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP, TCP stream + +Usage - formats: + $ yolo mode=predict model=yolo11n.pt # PyTorch + yolo11n.torchscript # TorchScript + yolo11n.onnx # ONNX Runtime or OpenCV DNN with dnn=True + yolo11n_openvino_model # OpenVINO + yolo11n.engine # TensorRT + yolo11n.mlpackage # CoreML (macOS-only) + yolo11n_saved_model # TensorFlow SavedModel + yolo11n.pb # TensorFlow GraphDef + yolo11n.tflite # TensorFlow Lite + yolo11n_edgetpu.tflite # TensorFlow Edge TPU + yolo11n_paddle_model # PaddlePaddle + yolo11n.mnn # MNN + yolo11n_ncnn_model # NCNN + yolo11n_imx_model # Sony IMX + yolo11n_rknn_model # Rockchip RKNN +""" + +import platform +import re +import threading +from pathlib import Path + +import cv2 +import numpy as np +import torch + +from ultralytics.cfg import get_cfg, get_save_dir +from ultralytics.data import load_inference_source +from ultralytics.data.augment import LetterBox, classify_transforms +from ultralytics.nn.autobackend import AutoBackend +from 
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
from ultralytics.utils.checks import check_imgsz, check_imshow
from ultralytics.utils.files import increment_path
from ultralytics.utils.torch_utils import select_device, smart_inference_mode

STREAM_WARNING = """
WARNING ⚠️ inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.

Example:
    results = model(source=..., stream=True)  # generator of Results objects
    for r in results:
        boxes = r.boxes  # Boxes object for bbox outputs
        masks = r.masks  # Masks object for segment masks outputs
        probs = r.probs  # Class probabilities for classification outputs
"""


class BasePredictor:
    """
    A base class for creating predictors.

    Handles model setup, preprocessing, inference, postprocessing, and saving/showing
    of results across the supported input source types (images, videos, streams, tensors).

    Attributes:
        args (SimpleNamespace): Configuration for the predictor.
        save_dir (Path): Directory to save results.
        done_warmup (bool): Whether model warmup has run.
        model (torch.nn.Module): Model used for prediction (AutoBackend-wrapped).
        data (dict): Data configuration.
        device (torch.device): Device used for prediction.
        dataset (Dataset): Dataset used for prediction.
        vid_writer (dict): {save_path: cv2.VideoWriter} for saving video output.
        plotted_img (numpy.ndarray): Last plotted image.
        source_type (SimpleNamespace): Type of input source.
        seen (int): Number of images processed.
        windows (list): Window names opened for visualization.
        batch (tuple): Current batch data (paths, images, strings).
        results (list): Current batch results.
        transforms (callable): Image transforms for classification tasks.
        callbacks (dict): Callback functions keyed by event name.
        txt_path (Path): Path to save text results.
        _lock (threading.Lock): Lock for thread-safe inference.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize the BasePredictor.

        Args:
            cfg (str | dict): Path to a configuration file or a configuration dictionary.
            overrides (dict | None): Configuration overrides.
            _callbacks (dict | None): Dictionary of callback functions; defaults are used when None.
        """
        self.args = get_cfg(cfg, overrides)
        self.save_dir = get_save_dir(self.args)
        if self.args.conf is None:
            self.args.conf = 0.25  # default conf=0.25
        self.done_warmup = False
        if self.args.show:
            self.args.show = check_imshow(warn=True)  # downgrade show if the environment has no display

        # Usable if setup is done
        self.model = None
        self.data = self.args.data  # data_dict
        self.imgsz = None
        self.device = None
        self.dataset = None
        self.vid_writer = {}  # dict of {save_path: video_writer, ...}
        self.plotted_img = None
        self.source_type = None
        self.seen = 0
        self.windows = []
        self.batch = None
        self.results = None
        self.transforms = None
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        self.txt_path = None
        self._lock = threading.Lock()  # for automatic thread-safe inference
        callbacks.add_integration_callbacks(self)

    def preprocess(self, im):
        """
        Prepare the input image batch before inference.

        Args:
            im (torch.Tensor | List[np.ndarray]): (N, 3, h, w) tensor, or list of N HWC BGR arrays.

        Returns:
            (torch.Tensor): Batch tensor on self.device, fp16/fp32 per model, scaled to [0, 1] for array input.
        """
        not_tensor = not isinstance(im, torch.Tensor)
        if not_tensor:
            im = np.stack(self.pre_transform(im))
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
            im = np.ascontiguousarray(im)  # contiguous
            im = torch.from_numpy(im)

        im = im.to(self.device)
        im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
        if not_tensor:
            im /= 255  # 0 - 255 to 0.0 - 1.0; tensor input is assumed pre-scaled by the caller
        return im

    def inference(self, im, *args, **kwargs):
        """Run inference on a preprocessed batch; feature visualization dir is created only when requested."""
        visualize = (
            increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
            if self.args.visualize and (not self.source_type.tensor)
            else False
        )
        return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)

    def pre_transform(self, im):
        """
        Letterbox input images to the inference size before stacking.

        Args:
            im (List[np.ndarray]): List of N HWC images.

        Returns:
            (List[np.ndarray]): Letterboxed images. Rectangular (auto) letterboxing is only
                enabled when all images share one shape and the model supports dynamic shapes.
        """
        same_shapes = len({x.shape for x in im}) == 1
        letterbox = LetterBox(
            self.imgsz,
            auto=same_shapes and (self.model.pt or (getattr(self.model, "dynamic", False) and not self.model.imx)),
            stride=self.model.stride,
        )
        return [letterbox(image=x) for x in im]

    def postprocess(self, preds, img, orig_imgs):
        """Post-process raw predictions; identity here, overridden by task-specific predictors."""
        return preds

    def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
        """
        Perform inference on an image or stream.

        Args:
            source (str | Path | List | np.ndarray | torch.Tensor | None): Source for inference.
            model (str | Path | torch.nn.Module | None): Model for inference.
            stream (bool): If True, return a generator; otherwise materialize all results in a list.
            *args (Any): Additional arguments for the inference method.
            **kwargs (Any): Additional keyword arguments for the inference method.

        Returns:
            (List[ultralytics.engine.results.Results] | generator): Results list or generator of Results.
        """
        self.stream = stream
        if stream:
            return self.stream_inference(source, model, *args, **kwargs)
        else:
            return list(self.stream_inference(source, model, *args, **kwargs))  # merge list of Result into one

    def predict_cli(self, source=None, model=None):
        """
        Run prediction for the command line interface.

        Consumes the stream_inference generator without storing results, so outputs never
        accumulate in memory — do not replace this with a list() call.

        Args:
            source (str | Path | List | np.ndarray | torch.Tensor | None): Source for inference.
            model (str | Path | torch.nn.Module | None): Model for inference.
        """
        gen = self.stream_inference(source, model)
        for _ in gen:  # sourcery skip: remove-empty-nested-block, noqa
            pass

    def setup_source(self, source):
        """
        Set up the input source and inference mode.

        Configures imgsz against the model stride, classification transforms (classify task
        only), the inference dataset, and warns when a non-streaming call targets a source
        likely to exhaust RAM (streams, screenshots, >1000 images, videos).

        Args:
            source (str | Path | List | np.ndarray | torch.Tensor): Source for inference.
        """
        self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size
        self.transforms = (
            getattr(
                self.model.model,
                "transforms",
                classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction),
            )
            if self.args.task == "classify"
            else None
        )
        self.dataset = load_inference_source(
            source=source,
            batch=self.args.batch,
            vid_stride=self.args.vid_stride,
            buffer=self.args.stream_buffer,
        )
        self.source_type = self.dataset.source_type
        if not getattr(self, "stream", True) and (
            self.source_type.stream
            or self.source_type.screenshot
            or len(self.dataset) > 1000  # many images
            or any(getattr(self.dataset, "video_flag", [False]))
        ):  # videos
            LOGGER.warning(STREAM_WARNING)
        self.vid_writer = {}

    @smart_inference_mode()
    def stream_inference(self, source=None, model=None, *args, **kwargs):
        """
        Stream real-time inference on an input source and save results to file.

        Lazily sets up the model, (re-)sets up the source on every call, warms up once,
        then yields a Results object per image while profiling preprocess/inference/postprocess.
        Video writers are released and a speed/save summary is logged on completion.

        Args:
            source (str | Path | List | np.ndarray | torch.Tensor | None): Source for inference.
            model (str | Path | torch.nn.Module | None): Model for inference.
            *args (Any): Additional arguments for the inference method.
            **kwargs (Any): Additional keyword arguments for the inference method.

        Yields:
            (ultralytics.engine.results.Results): Results objects (or embedding tensors when args.embed is set).
        """
        if self.args.verbose:
            LOGGER.info("")

        # Setup model
        if not self.model:
            self.setup_model(model)

        with self._lock:  # for thread-safe inference
            # Setup source every time predict is called
            self.setup_source(source if source is not None else self.args.source)

            # Check if save_dir/ label file exists
            if self.args.save or self.args.save_txt:
                (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

            # Warmup model
            if not self.done_warmup:
                self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
                self.done_warmup = True

            self.seen, self.windows, self.batch = 0, [], None
            # One profiler per stage: preprocess, inference, postprocess
            profilers = (
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
            )
            self.run_callbacks("on_predict_start")
            for self.batch in self.dataset:
                self.run_callbacks("on_predict_batch_start")
                paths, im0s, s = self.batch

                # Preprocess
                with profilers[0]:
                    im = self.preprocess(im0s)

                # Inference
                with profilers[1]:
                    preds = self.inference(im, *args, **kwargs)
                    if self.args.embed:
                        yield from [preds] if isinstance(preds, torch.Tensor) else preds  # yield embedding tensors
                        continue

                # Postprocess
                with profilers[2]:
                    self.results = self.postprocess(preds, im, im0s)
                self.run_callbacks("on_predict_postprocess_end")

                # Visualize, save, write results
                n = len(im0s)
                for i in range(n):
                    self.seen += 1
                    # Per-image speed is the batch stage time split evenly across the batch
                    self.results[i].speed = {
                        "preprocess": profilers[0].dt * 1e3 / n,
                        "inference": profilers[1].dt * 1e3 / n,
                        "postprocess": profilers[2].dt * 1e3 / n,
                    }
                    if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
                        s[i] += self.write_results(i, Path(paths[i]), im, s)

                # Print batch results
                if self.args.verbose:
                    LOGGER.info("\n".join(s))

                self.run_callbacks("on_predict_batch_end")
                yield from self.results

        # Release assets
        for v in self.vid_writer.values():
            if isinstance(v, cv2.VideoWriter):
                v.release()

        # Print final results
        if self.args.verbose and self.seen:
            t = tuple(x.t / self.seen * 1e3 for x in profilers)  # speeds per image
            LOGGER.info(
                f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
                f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t
            )
        if self.args.save or self.args.save_txt or self.args.save_crop:
            nl = len(list(self.save_dir.glob("labels/*.txt")))  # number of labels
            s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
        self.run_callbacks("on_predict_end")

    def setup_model(self, model, verbose=True):
        """
        Initialize the model via AutoBackend and set it to evaluation mode.

        Also syncs self.device and self.args.half from the loaded backend.

        Args:
            model (str | Path | torch.nn.Module | None): Model to load or use; falls back to args.model.
            verbose (bool): Whether to print verbose output.
        """
        self.model = AutoBackend(
            weights=model or self.args.model,
            device=select_device(self.args.device, verbose=verbose),
            dnn=self.args.dnn,
            data=self.args.data,
            fp16=self.args.half,
            batch=self.args.batch,
            fuse=True,
            verbose=verbose,
        )

        self.device = self.model.device  # update device
        self.args.half = self.model.fp16  # update half
        self.model.eval()

    def write_results(self, i, p, im, s):
        """
        Write inference results for one image: plot, save txt/crops, show, and save media.

        Args:
            i (int): Index of the current image in the batch.
            p (Path): Path to the current image.
            im (torch.Tensor): Preprocessed image tensor.
            s (List[str]): List of per-image log strings for the batch.

        Returns:
            (str): Log-string fragment describing this image's results and inference time.
        """
        string = ""  # print string
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        if self.source_type.stream or self.source_type.from_img or self.source_type.tensor:  # batch_size >= 1
            string += f"{i}: "
            frame = self.dataset.count
        else:
            match = re.search(r"frame (\d+)/", s[i])
            frame = int(match[1]) if match else None  # None if frame undetermined

        self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}"))
        string += "{:g}x{:g} ".format(*im.shape[2:])
        result = self.results[i]
        result.save_dir = self.save_dir.__str__()  # used in other locations
        string += f"{result.verbose()}{result.speed['inference']:.1f}ms"

        # Add predictions to image
        if self.args.save or self.args.show:
            self.plotted_img = result.plot(
                line_width=self.args.line_width,
                boxes=self.args.show_boxes,
                conf=self.args.show_conf,
                labels=self.args.show_labels,
                im_gpu=None if self.args.retina_masks else im[i],
            )

        # Save results
        if self.args.save_txt:
            result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
        if self.args.save_crop:
            result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)
        if self.args.show:
            self.show(str(p))
        if self.args.save:
            self.save_predicted_images(str(self.save_dir / p.name), frame)

        return string

    def save_predicted_images(self, save_path="", frame=0):
        """
        Save the last plotted image as video frames (stream/video mode) or as a JPG.

        A cv2.VideoWriter is created lazily per save_path with a platform-appropriate
        container/codec; individual frames are additionally saved when args.save_frames is set.

        Args:
            save_path (str): Path to save the results.
            frame (int): Frame number for video mode.
        """
        im = self.plotted_img

        # Save videos and streams
        if self.dataset.mode in {"stream", "video"}:
            fps = self.dataset.fps if self.dataset.mode == "video" else 30
            frames_path = f"{save_path.split('.', 1)[0]}_frames/"
            if save_path not in self.vid_writer:  # new video
                if self.args.save_frames:
                    Path(frames_path).mkdir(parents=True, exist_ok=True)
                suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
                self.vid_writer[save_path] = cv2.VideoWriter(
                    filename=str(Path(save_path).with_suffix(suffix)),
                    fourcc=cv2.VideoWriter_fourcc(*fourcc),
                    fps=fps,  # integer required, floats produce error in MP4 codec
                    frameSize=(im.shape[1], im.shape[0]),  # (width, height)
                )

            # Save video
            self.vid_writer[save_path].write(im)
            if self.args.save_frames:
                cv2.imwrite(f"{frames_path}{frame}.jpg", im)

        # Save images
        else:
            cv2.imwrite(str(Path(save_path).with_suffix(".jpg")), im)  # save to JPG for best support

    def show(self, p=""):
        """Display the last plotted image in a named window (resizable on Linux)."""
        im = self.plotted_img
        if platform.system() == "Linux" and p not in self.windows:
            self.windows.append(p)
            cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
            cv2.resizeWindow(p, im.shape[1], im.shape[0])  # (width, height)
        cv2.imshow(p, im)
        cv2.waitKey(300 if self.dataset.mode == "image" else 1)  # longer pause for still images

    def run_callbacks(self, event: str):
        """Run all registered callbacks for a specific event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def add_callback(self, event: str, func):
        """Add a callback function for a specific event."""
        self.callbacks[event].append(func)
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
Ultralytics Results, Boxes and Masks classes for handling inference results.

Usage: See https://docs.ultralytics.com/modes/predict/
"""

from copy import deepcopy
from functools import lru_cache
from pathlib import Path

import numpy as np
import torch

from ultralytics.data.augment import LetterBox
from ultralytics.utils import LOGGER, SimpleClass, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.plotting import Annotator, colors, save_one_box
from ultralytics.utils.torch_utils import smart_inference_mode


class BaseTensor(SimpleClass):
    """
    Base tensor class with convenience methods for device handling and indexing.

    Wraps prediction data (boxes, masks, keypoints, ...) together with the original
    image shape, and returns a new wrapper of the same class from every conversion,
    so subclass types are preserved.

    Attributes:
        data (torch.Tensor | np.ndarray): Wrapped prediction data.
        orig_shape (Tuple[int, int]): Original image shape as (height, width).
    """

    def __init__(self, data, orig_shape) -> None:
        """
        Store prediction data together with the original image shape.

        Args:
            data (torch.Tensor | np.ndarray): Prediction data such as boxes, masks, or keypoints.
            orig_shape (Tuple[int, int]): Original image shape in (height, width) format.
        """
        assert isinstance(data, (torch.Tensor, np.ndarray)), "data must be torch.Tensor or np.ndarray"
        self.data = data
        self.orig_shape = orig_shape

    @property
    def shape(self):
        """Shape of the underlying data tensor or array."""
        return self.data.shape

    def cpu(self):
        """Return a copy of this object with its tensor on CPU memory (self when data is numpy)."""
        if isinstance(self.data, np.ndarray):
            return self
        return self.__class__(self.data.cpu(), self.orig_shape)

    def numpy(self):
        """Return a copy of this object with its tensor converted to a numpy array (self when already numpy)."""
        if isinstance(self.data, np.ndarray):
            return self
        return self.__class__(self.data.numpy(), self.orig_shape)

    def cuda(self):
        """Return a copy of this object with its data moved to GPU memory."""
        gpu_data = torch.as_tensor(self.data).cuda()
        return self.__class__(gpu_data, self.orig_shape)

    def to(self, *args, **kwargs):
        """Return a copy with data converted via torch.Tensor.to (args/kwargs forwarded)."""
        converted = torch.as_tensor(self.data).to(*args, **kwargs)
        return self.__class__(converted, self.orig_shape)

    def __len__(self):  # override len(results)
        """Number of elements along the first dimension of the data."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return a new instance of the same class wrapping the indexed data.

        Args:
            idx (int | List[int] | torch.Tensor): Index or indices selecting from the data.

        Returns:
            (BaseTensor): New wrapper around the indexed data with the same orig_shape.
        """
        return self.__class__(self.data[idx], self.orig_shape)
+ + Examples: + >>> data = torch.tensor([[1, 2, 3], [4, 5, 6]]) + >>> base_tensor = BaseTensor(data, orig_shape=(720, 1280)) + >>> result = base_tensor[0] # Select the first row + >>> print(result.data) + tensor([1, 2, 3]) + """ + return self.__class__(self.data[idx], self.orig_shape) + + +class Results(SimpleClass): + """ + A class for storing and manipulating inference results. + + This class provides methods for accessing, manipulating, and visualizing inference results from various + Ultralytics models, including detection, segmentation, classification, and pose estimation. + + Attributes: + orig_img (numpy.ndarray): The original image as a numpy array. + orig_shape (Tuple[int, int]): Original image shape in (height, width) format. + boxes (Boxes | None): Detected bounding boxes. + masks (Masks | None): Segmentation masks. + probs (Probs | None): Classification probabilities. + keypoints (Keypoints | None): Detected keypoints. + obb (OBB | None): Oriented bounding boxes. + speed (dict): Dictionary containing inference speed information. + names (dict): Dictionary mapping class indices to class names. + path (str): Path to the input image file. + save_dir (str | None): Directory to save results. + + Methods: + update: Updates the Results object with new detection data. + cpu: Returns a copy of the Results object with all tensors moved to CPU memory. + numpy: Converts all tensors in the Results object to numpy arrays. + cuda: Moves all tensors in the Results object to GPU memory. + to: Moves all tensors to the specified device and dtype. + new: Creates a new Results object with the same image, path, names, and speed attributes. + plot: Plots detection results on an input RGB image. + show: Displays the image with annotated inference results. + save: Saves annotated inference results image to file. + verbose: Returns a log string for each task in the results. + save_txt: Saves detection results to a text file. 
+ save_crop: Saves cropped detection images to specified directory. + summary: Converts inference results to a summarized dictionary. + to_df: Converts detection results to a Pandas Dataframe. + to_json: Converts detection results to JSON format. + to_csv: Converts detection results to a CSV format. + to_xml: Converts detection results to XML format. + to_html: Converts detection results to HTML format. + to_sql: Converts detection results to an SQL-compatible format. + + Examples: + >>> results = model("path/to/image.jpg") + >>> result = results[0] # Get the first result + >>> boxes = result.boxes # Get the boxes for the first result + >>> masks = result.masks # Get the masks for the first result + >>> for result in results: + >>> result.plot() # Plot detection results + """ + + def __init__( + self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None, obb=None, speed=None + ) -> None: + """ + Initialize the Results class for storing and manipulating inference results. + + Args: + orig_img (numpy.ndarray): The original image as a numpy array. + path (str): The path to the image file. + names (dict): A dictionary of class names. + boxes (torch.Tensor | None): A 2D tensor of bounding box coordinates for each detection. + masks (torch.Tensor | None): A 3D tensor of detection masks, where each mask is a binary image. + probs (torch.Tensor | None): A 1D tensor of probabilities of each class for classification task. + keypoints (torch.Tensor | None): A 2D tensor of keypoint coordinates for each detection. + obb (torch.Tensor | None): A 2D tensor of oriented bounding box coordinates for each detection. + speed (Dict | None): A dictionary containing preprocess, inference, and postprocess speeds (ms/image). 
+ + Examples: + >>> results = model("path/to/image.jpg") + >>> result = results[0] # Get the first result + >>> boxes = result.boxes # Get the boxes for the first result + >>> masks = result.masks # Get the masks for the first result + + Notes: + For the default pose model, keypoint indices for human body pose estimation are: + 0: Nose, 1: Left Eye, 2: Right Eye, 3: Left Ear, 4: Right Ear + 5: Left Shoulder, 6: Right Shoulder, 7: Left Elbow, 8: Right Elbow + 9: Left Wrist, 10: Right Wrist, 11: Left Hip, 12: Right Hip + 13: Left Knee, 14: Right Knee, 15: Left Ankle, 16: Right Ankle + """ + self.orig_img = orig_img + self.orig_shape = orig_img.shape[:2] + self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes + self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks + self.probs = Probs(probs) if probs is not None else None + self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None + self.obb = OBB(obb, self.orig_shape) if obb is not None else None + self.speed = speed if speed is not None else {"preprocess": None, "inference": None, "postprocess": None} + self.names = names + self.path = path + self.save_dir = None + self._keys = "boxes", "masks", "probs", "keypoints", "obb" + + def __getitem__(self, idx): + """ + Return a Results object for a specific index of inference results. + + Args: + idx (int | slice): Index or slice to retrieve from the Results object. + + Returns: + (Results): A new Results object containing the specified subset of inference results. + + Examples: + >>> results = model("path/to/image.jpg") # Perform inference + >>> single_result = results[0] # Get the first result + >>> subset_results = results[1:4] # Get a slice of results + """ + return self._apply("__getitem__", idx) + + def __len__(self): + """ + Return the number of detections in the Results object. 
+ + Returns: + (int): The number of detections, determined by the length of the first non-empty attribute + (boxes, masks, probs, keypoints, or obb). + + Examples: + >>> results = Results(orig_img, path, names, boxes=torch.rand(5, 4)) + >>> len(results) + 5 + """ + for k in self._keys: + v = getattr(self, k) + if v is not None: + return len(v) + + def update(self, boxes=None, masks=None, probs=None, obb=None, keypoints=None): + """ + Updates the Results object with new detection data. + + This method allows updating the boxes, masks, probabilities, and oriented bounding boxes (OBB) of the + Results object. It ensures that boxes are clipped to the original image shape. + + Args: + boxes (torch.Tensor | None): A tensor of shape (N, 6) containing bounding box coordinates and + confidence scores. The format is (x1, y1, x2, y2, conf, class). + masks (torch.Tensor | None): A tensor of shape (N, H, W) containing segmentation masks. + probs (torch.Tensor | None): A tensor of shape (num_classes,) containing class probabilities. + obb (torch.Tensor | None): A tensor of shape (N, 5) containing oriented bounding box coordinates. + keypoints (torch.Tensor | None): A tensor of shape (N, 17, 3) containing keypoints. + + Examples: + >>> results = model("image.jpg") + >>> new_boxes = torch.tensor([[100, 100, 200, 200, 0.9, 0]]) + >>> results[0].update(boxes=new_boxes) + """ + if boxes is not None: + self.boxes = Boxes(ops.clip_boxes(boxes, self.orig_shape), self.orig_shape) + if masks is not None: + self.masks = Masks(masks, self.orig_shape) + if probs is not None: + self.probs = probs + if obb is not None: + self.obb = OBB(obb, self.orig_shape) + if keypoints is not None: + self.keypoints = Keypoints(keypoints, self.orig_shape) + + def _apply(self, fn, *args, **kwargs): + """ + Applies a function to all non-empty attributes and returns a new Results object with modified attributes. + + This method is internally called by methods like .to(), .cuda(), .cpu(), etc. 
+
+        Args:
+            fn (str): The name of the function to apply.
+            *args (Any): Variable length argument list to pass to the function.
+            **kwargs (Any): Arbitrary keyword arguments to pass to the function.
+
+        Returns:
+            (Results): A new Results object with attributes modified by the applied function.
+
+        Examples:
+            >>> results = model("path/to/image.jpg")
+            >>> for result in results:
+            ...     result_cuda = result.cuda()
+            ...     result_cpu = result.cpu()
+        """
+        r = self.new()
+        for k in self._keys:
+            v = getattr(self, k)
+            if v is not None:
+                setattr(r, k, getattr(v, fn)(*args, **kwargs))
+        return r
+
+    def cpu(self):
+        """
+        Returns a copy of the Results object with all its tensors moved to CPU memory.
+
+        This method creates a new Results object with all tensor attributes (boxes, masks, probs, keypoints, obb)
+        transferred to CPU memory. It's useful for moving data from GPU to CPU for further processing or saving.
+
+        Returns:
+            (Results): A new Results object with all tensor attributes on CPU memory.
+
+        Examples:
+            >>> results = model("path/to/image.jpg")  # Perform inference
+            >>> cpu_result = results[0].cpu()  # Move the first result to CPU
+            >>> print(cpu_result.boxes.device)  # Output: cpu
+        """
+        return self._apply("cpu")
+
+    def numpy(self):
+        """
+        Converts all tensors in the Results object to numpy arrays.
+
+        Returns:
+            (Results): A new Results object with all tensors converted to numpy arrays.
+
+        Examples:
+            >>> results = model("path/to/image.jpg")
+            >>> numpy_result = results[0].numpy()
+            >>> type(numpy_result.boxes.data)
+            <class 'numpy.ndarray'>
+
+        Notes:
+            This method creates a new Results object, leaving the original unchanged. It's useful for
+            interoperability with numpy-based libraries or when CPU-based operations are required.
+        """
+        return self._apply("numpy")
+
+    def cuda(self):
+        """
+        Moves all tensors in the Results object to GPU memory.
+
+        Returns:
+            (Results): A new Results object with all tensors moved to CUDA device.
+ + Examples: + >>> results = model("path/to/image.jpg") + >>> cuda_results = results[0].cuda() # Move first result to GPU + >>> for result in results: + ... result_cuda = result.cuda() # Move each result to GPU + """ + return self._apply("cuda") + + def to(self, *args, **kwargs): + """ + Moves all tensors in the Results object to the specified device and dtype. + + Args: + *args (Any): Variable length argument list to be passed to torch.Tensor.to(). + **kwargs (Any): Arbitrary keyword arguments to be passed to torch.Tensor.to(). + + Returns: + (Results): A new Results object with all tensors moved to the specified device and dtype. + + Examples: + >>> results = model("path/to/image.jpg") + >>> result_cuda = results[0].to("cuda") # Move first result to GPU + >>> result_cpu = results[0].to("cpu") # Move first result to CPU + >>> result_half = results[0].to(dtype=torch.float16) # Convert first result to half precision + """ + return self._apply("to", *args, **kwargs) + + def new(self): + """ + Creates a new Results object with the same image, path, names, and speed attributes. + + Returns: + (Results): A new Results object with copied attributes from the original instance. + + Examples: + >>> results = model("path/to/image.jpg") + >>> new_result = results[0].new() + """ + return Results(orig_img=self.orig_img, path=self.path, names=self.names, speed=self.speed) + + def plot( + self, + conf=True, + line_width=None, + font_size=None, + font="Arial.ttf", + pil=False, + img=None, + im_gpu=None, + kpt_radius=5, + kpt_line=True, + labels=True, + boxes=True, + masks=True, + probs=True, + show=False, + save=False, + filename=None, + color_mode="class", + txt_color=(255, 255, 255), + ): + """ + Plots detection results on an input RGB image. + + Args: + conf (bool): Whether to plot detection confidence scores. + line_width (float | None): Line width of bounding boxes. If None, scaled to image size. + font_size (float | None): Font size for text. If None, scaled to image size. 
+ font (str): Font to use for text. + pil (bool): Whether to return the image as a PIL Image. + img (np.ndarray | None): Image to plot on. If None, uses original image. + im_gpu (torch.Tensor | None): Normalized image on GPU for faster mask plotting. + kpt_radius (int): Radius of drawn keypoints. + kpt_line (bool): Whether to draw lines connecting keypoints. + labels (bool): Whether to plot labels of bounding boxes. + boxes (bool): Whether to plot bounding boxes. + masks (bool): Whether to plot masks. + probs (bool): Whether to plot classification probabilities. + show (bool): Whether to display the annotated image. + save (bool): Whether to save the annotated image. + filename (str | None): Filename to save image if save is True. + color_mode (bool): Specify the color mode, e.g., 'instance' or 'class'. Default to 'class'. + txt_color (tuple[int, int, int]): Specify the RGB text color for classification task + + Returns: + (np.ndarray): Annotated image as a numpy array. + + Examples: + >>> results = model("image.jpg") + >>> for result in results: + >>> im = result.plot() + >>> im.show() + """ + assert color_mode in {"instance", "class"}, f"Expected color_mode='instance' or 'class', not {color_mode}." 
+ if img is None and isinstance(self.orig_img, torch.Tensor): + img = (self.orig_img[0].detach().permute(1, 2, 0).contiguous() * 255).to(torch.uint8).cpu().numpy() + + names = self.names + is_obb = self.obb is not None + pred_boxes, show_boxes = self.obb if is_obb else self.boxes, boxes + pred_masks, show_masks = self.masks, masks + pred_probs, show_probs = self.probs, probs + annotator = Annotator( + deepcopy(self.orig_img if img is None else img), + line_width, + font_size, + font, + pil or (pred_probs is not None and show_probs), # Classify tasks default to pil=True + example=names, + ) + + # Plot Segment results + if pred_masks and show_masks: + if im_gpu is None: + img = LetterBox(pred_masks.shape[1:])(image=annotator.result()) + im_gpu = ( + torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device) + .permute(2, 0, 1) + .flip(0) + .contiguous() + / 255 + ) + idx = ( + pred_boxes.id + if pred_boxes.id is not None and color_mode == "instance" + else pred_boxes.cls + if pred_boxes and color_mode == "class" + else reversed(range(len(pred_masks))) + ) + annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu) + + # Plot Detect results + if pred_boxes is not None and show_boxes: + for i, d in enumerate(reversed(pred_boxes)): + c, d_conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item()) + name = ("" if id is None else f"id:{id} ") + names[c] + label = (f"{name} {d_conf:.2f}" if conf else name) if labels else None + box = d.xyxyxyxy.reshape(-1, 4, 2).squeeze() if is_obb else d.xyxy.squeeze() + annotator.box_label( + box, + label, + color=colors( + c + if color_mode == "class" + else id + if id is not None + else i + if color_mode == "instance" + else None, + True, + ), + rotated=is_obb, + ) + + # Plot Classify results + if pred_probs is not None and show_probs: + text = ",\n".join(f"{names[j] if names else j} {pred_probs.data[j]:.2f}" for j in pred_probs.top5) + x = 
round(self.orig_shape[0] * 0.03) + annotator.text([x, x], text, txt_color=txt_color) + + # Plot Pose results + if self.keypoints is not None: + for i, k in enumerate(reversed(self.keypoints.data)): + annotator.kpts( + k, + self.orig_shape, + radius=kpt_radius, + kpt_line=kpt_line, + kpt_color=colors(i, True) if color_mode == "instance" else None, + ) + + # Show results + if show: + annotator.show(self.path) + + # Save results + if save: + annotator.save(filename) + + return annotator.im if pil else annotator.result() + + def show(self, *args, **kwargs): + """ + Display the image with annotated inference results. + + This method plots the detection results on the original image and displays it. It's a convenient way to + visualize the model's predictions directly. + + Args: + *args (Any): Variable length argument list to be passed to the `plot()` method. + **kwargs (Any): Arbitrary keyword arguments to be passed to the `plot()` method. + + Examples: + >>> results = model("path/to/image.jpg") + >>> results[0].show() # Display the first result + >>> for result in results: + >>> result.show() # Display all results + """ + self.plot(show=True, *args, **kwargs) + + def save(self, filename=None, *args, **kwargs): + """ + Saves annotated inference results image to file. + + This method plots the detection results on the original image and saves the annotated image to a file. It + utilizes the `plot` method to generate the annotated image and then saves it to the specified filename. + + Args: + filename (str | Path | None): The filename to save the annotated image. If None, a default filename + is generated based on the original image path. + *args (Any): Variable length argument list to be passed to the `plot` method. + **kwargs (Any): Arbitrary keyword arguments to be passed to the `plot` method. 
+ + Examples: + >>> results = model("path/to/image.jpg") + >>> for result in results: + >>> result.save("annotated_image.jpg") + >>> # Or with custom plot arguments + >>> for result in results: + >>> result.save("annotated_image.jpg", conf=False, line_width=2) + """ + if not filename: + filename = f"results_{Path(self.path).name}" + self.plot(save=True, filename=filename, *args, **kwargs) + return filename + + def verbose(self): + """ + Returns a log string for each task in the results, detailing detection and classification outcomes. + + This method generates a human-readable string summarizing the detection and classification results. It includes + the number of detections for each class and the top probabilities for classification tasks. + + Returns: + (str): A formatted string containing a summary of the results. For detection tasks, it includes the + number of detections per class. For classification tasks, it includes the top 5 class probabilities. + + Examples: + >>> results = model("path/to/image.jpg") + >>> for result in results: + >>> print(result.verbose()) + 2 persons, 1 car, 3 traffic lights, + dog 0.92, cat 0.78, horse 0.64, + + Notes: + - If there are no detections, the method returns "(no detections), " for detection tasks. + - For classification tasks, it returns the top 5 class probabilities and their corresponding class names. + - The returned string is comma-separated and ends with a comma and a space. + """ + log_string = "" + probs = self.probs + if len(self) == 0: + return log_string if probs is not None else f"{log_string}(no detections), " + if probs is not None: + log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, " + if boxes := self.boxes: + for c in boxes.cls.unique(): + n = (boxes.cls == c).sum() # detections per class + log_string += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " + return log_string + + def save_txt(self, txt_file, save_conf=False): + """ + Save detection results to a text file. 
+ + Args: + txt_file (str | Path): Path to the output text file. + save_conf (bool): Whether to include confidence scores in the output. + + Returns: + (str): Path to the saved text file. + + Examples: + >>> from ultralytics import YOLO + >>> model = YOLO("yolo11n.pt") + >>> results = model("path/to/image.jpg") + >>> for result in results: + >>> result.save_txt("output.txt") + + Notes: + - The file will contain one line per detection or classification with the following structure: + - For detections: `class confidence x_center y_center width height` + - For classifications: `confidence class_name` + - For masks and keypoints, the specific formats will vary accordingly. + - The function will create the output directory if it does not exist. + - If save_conf is False, the confidence scores will be excluded from the output. + - Existing contents of the file will not be overwritten; new results will be appended. + """ + is_obb = self.obb is not None + boxes = self.obb if is_obb else self.boxes + masks = self.masks + probs = self.probs + kpts = self.keypoints + texts = [] + if probs is not None: + # Classify + [texts.append(f"{probs.data[j]:.2f} {self.names[j]}") for j in probs.top5] + elif boxes: + # Detect/segment/pose + for j, d in enumerate(boxes): + c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item()) + line = (c, *(d.xyxyxyxyn.view(-1) if is_obb else d.xywhn.view(-1))) + if masks: + seg = masks[j].xyn[0].copy().reshape(-1) # reversed mask.xyn, (n,2) to (n*2) + line = (c, *seg) + if kpts is not None: + kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn + line += (*kpt.reshape(-1).tolist(),) + line += (conf,) * save_conf + (() if id is None else (id,)) + texts.append(("%g " * len(line)).rstrip() % line) + + if texts: + Path(txt_file).parent.mkdir(parents=True, exist_ok=True) # make directory + with open(txt_file, "a", encoding="utf-8") as f: + f.writelines(text + "\n" for text in texts) + 
+ def save_crop(self, save_dir, file_name=Path("im.jpg")): + """ + Saves cropped detection images to specified directory. + + This method saves cropped images of detected objects to a specified directory. Each crop is saved in a + subdirectory named after the object's class, with the filename based on the input file_name. + + Args: + save_dir (str | Path): Directory path where cropped images will be saved. + file_name (str | Path): Base filename for the saved cropped images. Default is Path("im.jpg"). + + Notes: + - This method does not support Classify or Oriented Bounding Box (OBB) tasks. + - Crops are saved as 'save_dir/class_name/file_name.jpg'. + - The method will create necessary subdirectories if they don't exist. + - Original image is copied before cropping to avoid modifying the original. + + Examples: + >>> results = model("path/to/image.jpg") + >>> for result in results: + >>> result.save_crop(save_dir="path/to/crops", file_name="detection") + """ + if self.probs is not None: + LOGGER.warning("WARNING ⚠️ Classify task do not support `save_crop`.") + return + if self.obb is not None: + LOGGER.warning("WARNING ⚠️ OBB task do not support `save_crop`.") + return + for d in self.boxes: + save_one_box( + d.xyxy, + self.orig_img.copy(), + file=Path(save_dir) / self.names[int(d.cls)] / Path(file_name).with_suffix(".jpg"), + BGR=True, + ) + + def summary(self, normalize=False, decimals=5): + """ + Converts inference results to a summarized dictionary with optional normalization for box coordinates. + + This method creates a list of detection dictionaries, each containing information about a single + detection or classification result. For classification tasks, it returns the top class and its + confidence. For detection tasks, it includes class information, bounding box coordinates, and + optionally mask segments and keypoints. + + Args: + normalize (bool): Whether to normalize bounding box coordinates by image dimensions. 
+ decimals (int): Number of decimal places to round the output values to. + + Returns: + (List[Dict]): A list of dictionaries, each containing summarized information for a single + detection or classification result. The structure of each dictionary varies based on the + task type (classification or detection) and available information (boxes, masks, keypoints). + + Examples: + >>> results = model("image.jpg") + >>> for result in results: + >>> summary = result.summary() + >>> print(summary) + """ + # Create list of detection dictionaries + results = [] + if self.probs is not None: + class_id = self.probs.top1 + results.append( + { + "name": self.names[class_id], + "class": class_id, + "confidence": round(self.probs.top1conf.item(), decimals), + } + ) + return results + + is_obb = self.obb is not None + data = self.obb if is_obb else self.boxes + h, w = self.orig_shape if normalize else (1, 1) + for i, row in enumerate(data): # xyxy, track_id if tracking, conf, class_id + class_id, conf = int(row.cls), round(row.conf.item(), decimals) + box = (row.xyxyxyxy if is_obb else row.xyxy).squeeze().reshape(-1, 2).tolist() + xy = {} + for j, b in enumerate(box): + xy[f"x{j + 1}"] = round(b[0] / w, decimals) + xy[f"y{j + 1}"] = round(b[1] / h, decimals) + result = {"name": self.names[class_id], "class": class_id, "confidence": conf, "box": xy} + if data.is_track: + result["track_id"] = int(row.id.item()) # track ID + if self.masks: + result["segments"] = { + "x": (self.masks.xy[i][:, 0] / w).round(decimals).tolist(), + "y": (self.masks.xy[i][:, 1] / h).round(decimals).tolist(), + } + if self.keypoints is not None: + x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1) # torch Tensor + result["keypoints"] = { + "x": (x / w).numpy().round(decimals).tolist(), # decimals named argument required + "y": (y / h).numpy().round(decimals).tolist(), + "visible": visible.numpy().round(decimals).tolist(), + } + results.append(result) + + return results + + def to_df(self, 
normalize=False, decimals=5): + """ + Converts detection results to a Pandas Dataframe. + + This method converts the detection results into Pandas Dataframe format. It includes information + about detected objects such as bounding boxes, class names, confidence scores, and optionally + segmentation masks and keypoints. + + Args: + normalize (bool): Whether to normalize the bounding box coordinates by the image dimensions. + If True, coordinates will be returned as float values between 0 and 1. + decimals (int): Number of decimal places to round the output values to. + + Returns: + (DataFrame): A Pandas Dataframe containing all the information in results in an organized way. + + Examples: + >>> results = model("path/to/image.jpg") + >>> for result in results: + >>> df_result = result.to_df() + >>> print(df_result) + """ + import pandas as pd # scope for faster 'import ultralytics' + + return pd.DataFrame(self.summary(normalize=normalize, decimals=decimals)) + + def to_csv(self, normalize=False, decimals=5, *args, **kwargs): + """ + Converts detection results to a CSV format. + + This method serializes the detection results into a CSV format. It includes information + about detected objects such as bounding boxes, class names, confidence scores, and optionally + segmentation masks and keypoints. + + Args: + normalize (bool): Whether to normalize the bounding box coordinates by the image dimensions. + If True, coordinates will be returned as float values between 0 and 1. + decimals (int): Number of decimal places to round the output values to. + *args (Any): Variable length argument list to be passed to pandas.DataFrame.to_csv(). + **kwargs (Any): Arbitrary keyword arguments to be passed to pandas.DataFrame.to_csv(). + + + Returns: + (str): CSV containing all the information in results in an organized way. 
+
+        Examples:
+            >>> results = model("path/to/image.jpg")
+            >>> for result in results:
+            >>>     csv_result = result.to_csv()
+            >>>     print(csv_result)
+        """
+        return self.to_df(normalize=normalize, decimals=decimals).to_csv(*args, **kwargs)
+
+    def to_xml(self, normalize=False, decimals=5, *args, **kwargs):
+        """
+        Converts detection results to XML format.
+
+        This method serializes the detection results into an XML format. It includes information
+        about detected objects such as bounding boxes, class names, confidence scores, and optionally
+        segmentation masks and keypoints.
+
+        Args:
+            normalize (bool): Whether to normalize the bounding box coordinates by the image dimensions.
+                If True, coordinates will be returned as float values between 0 and 1.
+            decimals (int): Number of decimal places to round the output values to.
+            *args (Any): Variable length argument list to be passed to pandas.DataFrame.to_xml().
+            **kwargs (Any): Arbitrary keyword arguments to be passed to pandas.DataFrame.to_xml().
+
+        Returns:
+            (str): An XML string containing all the information in results in an organized way.
+
+        Examples:
+            >>> results = model("path/to/image.jpg")
+            >>> for result in results:
+            >>>     xml_result = result.to_xml()
+            >>>     print(xml_result)
+        """
+        check_requirements("lxml")
+        df = self.to_df(normalize=normalize, decimals=decimals)
+        return '<?xml version="1.0" encoding="utf-8"?>\n<root></root>' if df.empty else df.to_xml(*args, **kwargs)
+
+    def to_html(self, normalize=False, decimals=5, index=False, *args, **kwargs):
+        """
+        Converts detection results to HTML format.
+
+        This method serializes the detection results into an HTML format. It includes information
+        about detected objects such as bounding boxes, class names, confidence scores, and optionally
+        segmentation masks and keypoints.
+
+        Args:
+            normalize (bool): Whether to normalize the bounding box coordinates by the image dimensions.
+                If True, coordinates will be returned as float values between 0 and 1.
+
+            decimals (int): Number of decimal places to round the output values to.
+            index (bool): Whether to include the DataFrame index in the HTML output.
+            *args (Any): Variable length argument list to be passed to pandas.DataFrame.to_html().
+            **kwargs (Any): Arbitrary keyword arguments to be passed to pandas.DataFrame.to_html().
+
+        Returns:
+            (str): An HTML string containing all the information in results in an organized way.
+
+        Examples:
+            >>> results = model("path/to/image.jpg")
+            >>> for result in results:
+            >>>     html_result = result.to_html()
+            >>>     print(html_result)
+        """
+        df = self.to_df(normalize=normalize, decimals=decimals)
+        return "<table></table>" if df.empty else df.to_html(index=index, *args, **kwargs)
+
+    def tojson(self, normalize=False, decimals=5):
+        """Deprecated version of to_json()."""
+        LOGGER.warning("WARNING ⚠️ 'result.tojson()' is deprecated, replace with 'result.to_json()'.")
+        return self.to_json(normalize, decimals)
+
+    def to_json(self, normalize=False, decimals=5):
+        """
+        Converts detection results to JSON format.
+
+        This method serializes the detection results into a JSON-compatible format. It includes information
+        about detected objects such as bounding boxes, class names, confidence scores, and optionally
+        segmentation masks and keypoints.
+
+        Args:
+            normalize (bool): Whether to normalize the bounding box coordinates by the image dimensions.
+                If True, coordinates will be returned as float values between 0 and 1.
+            decimals (int): Number of decimal places to round the output values to.
+
+        Returns:
+            (str): A JSON string containing the serialized detection results.
+
+        Examples:
+            >>> results = model("path/to/image.jpg")
+            >>> for result in results:
+            >>>     json_result = result.to_json()
+            >>>     print(json_result)
+
+        Notes:
+            - For classification tasks, the JSON will contain class probabilities instead of bounding boxes.
+            - For object detection tasks, the JSON will include bounding box coordinates, class names, and
+              confidence scores.
+            - If available, segmentation masks and keypoints will also be included in the JSON output.
+            - The method uses the `summary` method internally to generate the data structure before
+              converting it to JSON.
+        """
+        import json
+
+        return json.dumps(self.summary(normalize=normalize, decimals=decimals), indent=2)
+
+    def to_sql(self, table_name="results", normalize=False, decimals=5, db_path="results.db"):
+        """
+        Converts detection results to an SQL-compatible format.
+
+        This method serializes the detection results into a format compatible with SQL databases.
+ It includes information about detected objects such as bounding boxes, class names, confidence scores, + and optionally segmentation masks, keypoints or oriented bounding boxes. + + Args: + table_name (str): Name of the SQL table where the data will be inserted. + normalize (bool): Whether to normalize the bounding box coordinates by the image dimensions. + If True, coordinates will be returned as float values between 0 and 1. + decimals (int): Number of decimal places to round the bounding boxes values to. + db_path (str): Path to the SQLite database file. + + Examples: + >>> results = model("path/to/image.jpg") + >>> for result in results: + >>> result.to_sql() + """ + import json + import sqlite3 + + # Convert results to a list of dictionaries + data = self.summary(normalize=normalize, decimals=decimals) + if len(data) == 0: + LOGGER.warning("⚠️ No results to save to SQL. Results dict is empty") + return + + # Connect to the SQLite database + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + # Create table if it doesn't exist + columns = ( + "id INTEGER PRIMARY KEY AUTOINCREMENT, class_name TEXT, confidence REAL, box TEXT, masks TEXT, kpts TEXT" + ) + cursor.execute(f"CREATE TABLE IF NOT EXISTS {table_name} ({columns})") + + # Insert data into the table + for item in data: + cursor.execute( + f"INSERT INTO {table_name} (class_name, confidence, box, masks, kpts) VALUES (?, ?, ?, ?, ?)", + ( + item.get("name"), + item.get("confidence"), + json.dumps(item.get("box", {})), + json.dumps(item.get("segments", {})), + json.dumps(item.get("keypoints", {})), + ), + ) + + # Commit and close the connection + conn.commit() + conn.close() + + LOGGER.info(f"✅ Detection results successfully written to SQL table '{table_name}' in database '{db_path}'.") + + +class Boxes(BaseTensor): + """ + A class for managing and manipulating detection boxes. 
+ + This class provides functionality for handling detection boxes, including their coordinates, confidence scores, + class labels, and optional tracking IDs. It supports various box formats and offers methods for easy manipulation + and conversion between different coordinate systems. + + Attributes: + data (torch.Tensor | numpy.ndarray): The raw tensor containing detection boxes and associated data. + orig_shape (Tuple[int, int]): The original image dimensions (height, width). + is_track (bool): Indicates whether tracking IDs are included in the box data. + xyxy (torch.Tensor | numpy.ndarray): Boxes in [x1, y1, x2, y2] format. + conf (torch.Tensor | numpy.ndarray): Confidence scores for each box. + cls (torch.Tensor | numpy.ndarray): Class labels for each box. + id (torch.Tensor | None): Tracking IDs for each box (if available). + xywh (torch.Tensor | numpy.ndarray): Boxes in [x, y, width, height] format. + xyxyn (torch.Tensor | numpy.ndarray): Normalized [x1, y1, x2, y2] boxes relative to orig_shape. + xywhn (torch.Tensor | numpy.ndarray): Normalized [x, y, width, height] boxes relative to orig_shape. + + Methods: + cpu(): Returns a copy of the object with all tensors on CPU memory. + numpy(): Returns a copy of the object with all tensors as numpy arrays. + cuda(): Returns a copy of the object with all tensors on GPU memory. + to(*args, **kwargs): Returns a copy of the object with tensors on specified device and dtype. + + Examples: + >>> import torch + >>> boxes_data = torch.tensor([[100, 50, 150, 100, 0.9, 0], [200, 150, 300, 250, 0.8, 1]]) + >>> orig_shape = (480, 640) # height, width + >>> boxes = Boxes(boxes_data, orig_shape) + >>> print(boxes.xyxy) + >>> print(boxes.conf) + >>> print(boxes.cls) + >>> print(boxes.xywhn) + """ + + def __init__(self, boxes, orig_shape) -> None: + """ + Initialize the Boxes class with detection box data and the original image shape. 
+ + This class manages detection boxes, providing easy access and manipulation of box coordinates, + confidence scores, class identifiers, and optional tracking IDs. It supports multiple formats + for box coordinates, including both absolute and normalized forms. + + Args: + boxes (torch.Tensor | np.ndarray): A tensor or numpy array with detection boxes of shape + (num_boxes, 6) or (num_boxes, 7). Columns should contain + [x1, y1, x2, y2, confidence, class, (optional) track_id]. + orig_shape (Tuple[int, int]): The original image shape as (height, width). Used for normalization. + + Attributes: + data (torch.Tensor): The raw tensor containing detection boxes and their associated data. + orig_shape (Tuple[int, int]): The original image size, used for normalization. + is_track (bool): Indicates whether tracking IDs are included in the box data. + + Examples: + >>> import torch + >>> boxes = torch.tensor([[100, 50, 150, 100, 0.9, 0]]) + >>> orig_shape = (480, 640) + >>> detection_boxes = Boxes(boxes, orig_shape) + >>> print(detection_boxes.xyxy) + tensor([[100., 50., 150., 100.]]) + """ + if boxes.ndim == 1: + boxes = boxes[None, :] + n = boxes.shape[-1] + assert n in {6, 7}, f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls + super().__init__(boxes, orig_shape) + self.is_track = n == 7 + self.orig_shape = orig_shape + + @property + def xyxy(self): + """ + Returns bounding boxes in [x1, y1, x2, y2] format. + + Returns: + (torch.Tensor | numpy.ndarray): A tensor or numpy array of shape (n, 4) containing bounding box + coordinates in [x1, y1, x2, y2] format, where n is the number of boxes. + + Examples: + >>> results = model("image.jpg") + >>> boxes = results[0].boxes + >>> xyxy = boxes.xyxy + >>> print(xyxy) + """ + return self.data[:, :4] + + @property + def conf(self): + """ + Returns the confidence scores for each detection box. 
+ + Returns: + (torch.Tensor | numpy.ndarray): A 1D tensor or array containing confidence scores for each detection, + with shape (N,) where N is the number of detections. + + Examples: + >>> boxes = Boxes(torch.tensor([[10, 20, 30, 40, 0.9, 0]]), orig_shape=(100, 100)) + >>> conf_scores = boxes.conf + >>> print(conf_scores) + tensor([0.9000]) + """ + return self.data[:, -2] + + @property + def cls(self): + """ + Returns the class ID tensor representing category predictions for each bounding box. + + Returns: + (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the class IDs for each detection box. + The shape is (N,), where N is the number of boxes. + + Examples: + >>> results = model("image.jpg") + >>> boxes = results[0].boxes + >>> class_ids = boxes.cls + >>> print(class_ids) # tensor([0., 2., 1.]) + """ + return self.data[:, -1] + + @property + def id(self): + """ + Returns the tracking IDs for each detection box if available. + + Returns: + (torch.Tensor | None): A tensor containing tracking IDs for each box if tracking is enabled, + otherwise None. Shape is (N,) where N is the number of boxes. + + Examples: + >>> results = model.track("path/to/video.mp4") + >>> for result in results: + ... boxes = result.boxes + ... if boxes.is_track: + ... track_ids = boxes.id + ... print(f"Tracking IDs: {track_ids}") + ... else: + ... print("Tracking is not enabled for these boxes.") + + Notes: + - This property is only available when tracking is enabled (i.e., when `is_track` is True). + - The tracking IDs are typically used to associate detections across multiple frames in video analysis. + """ + return self.data[:, -3] if self.is_track else None + + @property + @lru_cache(maxsize=2) # maxsize 1 should suffice + def xywh(self): + """ + Convert bounding boxes from [x1, y1, x2, y2] format to [x, y, width, height] format. 
+ + Returns: + (torch.Tensor | numpy.ndarray): Boxes in [x_center, y_center, width, height] format, where x_center, y_center are the coordinates of + the center point of the bounding box, width, height are the dimensions of the bounding box and the + shape of the returned tensor is (N, 4), where N is the number of boxes. + + Examples: + >>> boxes = Boxes(torch.tensor([[100, 50, 150, 100], [200, 150, 300, 250]]), orig_shape=(480, 640)) + >>> xywh = boxes.xywh + >>> print(xywh) + tensor([[100.0000, 50.0000, 50.0000, 50.0000], + [200.0000, 150.0000, 100.0000, 100.0000]]) + """ + return ops.xyxy2xywh(self.xyxy) + + @property + @lru_cache(maxsize=2) + def xyxyn(self): + """ + Returns normalized bounding box coordinates relative to the original image size. + + This property calculates and returns the bounding box coordinates in [x1, y1, x2, y2] format, + normalized to the range [0, 1] based on the original image dimensions. + + Returns: + (torch.Tensor | numpy.ndarray): Normalized bounding box coordinates with shape (N, 4), where N is + the number of boxes. Each row contains [x1, y1, x2, y2] values normalized to [0, 1]. + + Examples: + >>> boxes = Boxes(torch.tensor([[100, 50, 300, 400, 0.9, 0]]), orig_shape=(480, 640)) + >>> normalized = boxes.xyxyn + >>> print(normalized) + tensor([[0.1562, 0.1042, 0.4688, 0.8333]]) + """ + xyxy = self.xyxy.clone() if isinstance(self.xyxy, torch.Tensor) else np.copy(self.xyxy) + xyxy[..., [0, 2]] /= self.orig_shape[1] + xyxy[..., [1, 3]] /= self.orig_shape[0] + return xyxy + + @property + @lru_cache(maxsize=2) + def xywhn(self): + """ + Returns normalized bounding boxes in [x, y, width, height] format. + + This property calculates and returns the normalized bounding box coordinates in the format + [x_center, y_center, width, height], where all values are relative to the original image dimensions. + + Returns: + (torch.Tensor | numpy.ndarray): Normalized bounding boxes with shape (N, 4), where N is the + number of boxes. 
Each row contains [x_center, y_center, width, height] values normalized + to [0, 1] based on the original image dimensions. + + Examples: + >>> boxes = Boxes(torch.tensor([[100, 50, 150, 100, 0.9, 0]]), orig_shape=(480, 640)) + >>> normalized = boxes.xywhn + >>> print(normalized) + tensor([[0.1953, 0.1562, 0.0781, 0.1042]]) + """ + xywh = ops.xyxy2xywh(self.xyxy) + xywh[..., [0, 2]] /= self.orig_shape[1] + xywh[..., [1, 3]] /= self.orig_shape[0] + return xywh + + +class Masks(BaseTensor): + """ + A class for storing and manipulating detection masks. + + This class extends BaseTensor and provides functionality for handling segmentation masks, + including methods for converting between pixel and normalized coordinates. + + Attributes: + data (torch.Tensor | numpy.ndarray): The raw tensor or array containing mask data. + orig_shape (tuple): Original image shape in (height, width) format. + xy (List[numpy.ndarray]): A list of segments in pixel coordinates. + xyn (List[numpy.ndarray]): A list of normalized segments. + + Methods: + cpu(): Returns a copy of the Masks object with the mask tensor on CPU memory. + numpy(): Returns a copy of the Masks object with the mask tensor as a numpy array. + cuda(): Returns a copy of the Masks object with the mask tensor on GPU memory. + to(*args, **kwargs): Returns a copy of the Masks object with the mask tensor on specified device and dtype. + + Examples: + >>> masks_data = torch.rand(1, 160, 160) + >>> orig_shape = (720, 1280) + >>> masks = Masks(masks_data, orig_shape) + >>> pixel_coords = masks.xy + >>> normalized_coords = masks.xyn + """ + + def __init__(self, masks, orig_shape) -> None: + """ + Initialize the Masks class with detection mask data and the original image shape. + + Args: + masks (torch.Tensor | np.ndarray): Detection masks with shape (num_masks, height, width). + orig_shape (tuple): The original image shape as (height, width). Used for normalization. 
        Examples:
            >>> import torch
            >>> from ultralytics.engine.results import Masks
            >>> masks = torch.rand(10, 160, 160)  # 10 masks of 160x160 resolution
            >>> orig_shape = (720, 1280)  # Original image shape
            >>> mask_obj = Masks(masks, orig_shape)
        """
        if masks.ndim == 2:
            # Promote a single (H, W) mask to a (1, H, W) batch
            masks = masks[None, :]
        super().__init__(masks, orig_shape)

    @property
    @lru_cache(maxsize=1)
    def xyn(self):
        """
        Returns normalized xy-coordinates of the segmentation masks.

        This property calculates and caches the normalized xy-coordinates of the segmentation masks. The coordinates
        are normalized relative to the original image shape.

        Returns:
            (List[numpy.ndarray]): A list of numpy arrays, where each array contains the normalized xy-coordinates
                of a single segmentation mask. Each array has shape (N, 2), where N is the number of points in the
                mask contour.

        Examples:
            >>> results = model("image.jpg")
            >>> masks = results[0].masks
            >>> normalized_coords = masks.xyn
            >>> print(normalized_coords[0])  # Normalized coordinates of the first mask
        """
        # masks2segments yields one polygon per mask; scale_coords maps points from the
        # mask grid (self.data.shape[1:]) to the original image, normalized when requested
        return [
            ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True)
            for x in ops.masks2segments(self.data)
        ]

    @property
    @lru_cache(maxsize=1)
    def xy(self):
        """
        Returns the [x, y] pixel coordinates for each segment in the mask tensor.

        This property calculates and returns a list of pixel coordinates for each segmentation mask in the
        Masks object. The coordinates are scaled to match the original image dimensions.

        Returns:
            (List[numpy.ndarray]): A list of numpy arrays, where each array contains the [x, y] pixel
                coordinates for a single segmentation mask. Each array has shape (N, 2), where N is the
                number of points in the segment.

        Examples:
            >>> results = model("image.jpg")
            >>> masks = results[0].masks
            >>> xy_coords = masks.xy
            >>> print(len(xy_coords))  # Number of masks
            >>> print(xy_coords[0].shape)  # Shape of first mask's coordinates
        """
        return [
            ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False)
            for x in ops.masks2segments(self.data)
        ]


class Keypoints(BaseTensor):
    """
    A class for storing and manipulating detection keypoints.

    This class encapsulates functionality for handling keypoint data, including coordinate manipulation,
    normalization, and confidence values.

    Attributes:
        data (torch.Tensor): The raw tensor containing keypoint data.
        orig_shape (Tuple[int, int]): The original image dimensions (height, width).
        has_visible (bool): Indicates whether visibility information is available for keypoints.
        xy (torch.Tensor): Keypoint coordinates in [x, y] format.
        xyn (torch.Tensor): Normalized keypoint coordinates in [x, y] format, relative to orig_shape.
        conf (torch.Tensor): Confidence values for each keypoint, if available.

    Methods:
        cpu(): Returns a copy of the keypoints tensor on CPU memory.
        numpy(): Returns a copy of the keypoints tensor as a numpy array.
        cuda(): Returns a copy of the keypoints tensor on GPU memory.
        to(*args, **kwargs): Returns a copy of the keypoints tensor with specified device and dtype.
    Examples:
        >>> import torch
        >>> from ultralytics.engine.results import Keypoints
        >>> keypoints_data = torch.rand(1, 17, 3)  # 1 detection, 17 keypoints, (x, y, conf)
        >>> orig_shape = (480, 640)  # Original image shape (height, width)
        >>> keypoints = Keypoints(keypoints_data, orig_shape)
        >>> print(keypoints.xy.shape)  # Access xy coordinates
        >>> print(keypoints.conf)  # Access confidence values
        >>> keypoints_cpu = keypoints.cpu()  # Move keypoints to CPU
    """

    @smart_inference_mode()  # avoid keypoints < conf in-place error
    def __init__(self, keypoints, orig_shape) -> None:
        """
        Initializes the Keypoints object with detection keypoints and original image dimensions.

        This method processes the input keypoints tensor, handling both 2D and 3D formats. For 3D tensors
        (x, y, confidence), it masks out low-confidence keypoints by setting their coordinates to zero.

        Args:
            keypoints (torch.Tensor): A tensor containing keypoint data. Shape can be either:
                - (num_objects, num_keypoints, 2) for x, y coordinates only
                - (num_objects, num_keypoints, 3) for x, y coordinates and confidence scores
            orig_shape (Tuple[int, int]): The original image dimensions (height, width).

        Examples:
            >>> kpts = torch.rand(1, 17, 3)  # 1 object, 17 keypoints (COCO format), x,y,conf
            >>> orig_shape = (720, 1280)  # Original image height, width
            >>> keypoints = Keypoints(kpts, orig_shape)
        """
        if keypoints.ndim == 2:
            keypoints = keypoints[None, :]
        if keypoints.shape[2] == 3:  # x, y, conf
            mask = keypoints[..., 2] < 0.5  # points with conf < 0.5 (not visible)
            # NOTE(review): this zeroes coordinates in the caller's tensor in place —
            # pass a copy if the raw low-confidence coordinates must be preserved
            keypoints[..., :2][mask] = 0
        super().__init__(keypoints, orig_shape)
        self.has_visible = self.data.shape[-1] == 3

    @property
    @lru_cache(maxsize=1)
    def xy(self):
        """
        Returns x, y coordinates of keypoints.

        Returns:
            (torch.Tensor): A tensor containing the x, y coordinates of keypoints with shape (N, K, 2), where N is
                the number of detections and K is the number of keypoints per detection.

        Examples:
            >>> results = model("image.jpg")
            >>> keypoints = results[0].keypoints
            >>> xy = keypoints.xy
            >>> print(xy.shape)  # (N, K, 2)
            >>> print(xy[0])  # x, y coordinates of keypoints for first detection

        Notes:
            - The returned coordinates are in pixel units relative to the original image dimensions.
            - If keypoints were initialized with confidence values, keypoints with confidence < 0.5 had their
              x, y coordinates set to zero at initialization; no keypoints are removed, so shape is always (N, K, 2).
            - This property uses LRU caching to improve performance on repeated access.
        """
        return self.data[..., :2]

    @property
    @lru_cache(maxsize=1)
    def xyn(self):
        """
        Returns normalized coordinates (x, y) of keypoints relative to the original image size.

        Returns:
            (torch.Tensor | numpy.ndarray): A tensor or array of shape (N, K, 2) containing normalized keypoint
                coordinates, where N is the number of instances, K is the number of keypoints, and the last
                dimension contains [x, y] values in the range [0, 1].

        Examples:
            >>> keypoints = Keypoints(torch.rand(1, 17, 2), orig_shape=(480, 640))
            >>> normalized_kpts = keypoints.xyn
            >>> print(normalized_kpts.shape)
            torch.Size([1, 17, 2])
        """
        # Copy before the in-place division so the cached xy result is not mutated
        xy = self.xy.clone() if isinstance(self.xy, torch.Tensor) else np.copy(self.xy)
        xy[..., 0] /= self.orig_shape[1]  # x / width
        xy[..., 1] /= self.orig_shape[0]  # y / height
        return xy

    @property
    @lru_cache(maxsize=1)
    def conf(self):
        """
        Returns confidence values for each keypoint.

        Returns:
            (torch.Tensor | None): A tensor containing confidence scores for each keypoint if available,
                otherwise None. Shape is (num_detections, num_keypoints) for batched data or (num_keypoints,)
                for single detection.

        Examples:
            >>> keypoints = Keypoints(torch.rand(1, 17, 3), orig_shape=(640, 640))  # 1 detection, 17 keypoints
            >>> conf = keypoints.conf
            >>> print(conf.shape)  # torch.Size([1, 17])
        """
        return self.data[..., 2] if self.has_visible else None


class Probs(BaseTensor):
    """
    A class for storing and manipulating classification probabilities.

    This class extends BaseTensor and provides methods for accessing and manipulating
    classification probabilities, including top-1 and top-5 predictions.

    Attributes:
        data (torch.Tensor | numpy.ndarray): The raw tensor or array containing classification probabilities.
        orig_shape (tuple | None): The original image shape as (height, width). Not used in this class.
        top1 (int): Index of the class with the highest probability.
        top5 (List[int]): Indices of the top 5 classes by probability.
        top1conf (torch.Tensor | numpy.ndarray): Confidence score of the top 1 class.
        top5conf (torch.Tensor | numpy.ndarray): Confidence scores of the top 5 classes.

    Methods:
        cpu(): Returns a copy of the probabilities tensor on CPU memory.
        numpy(): Returns a copy of the probabilities tensor as a numpy array.
        cuda(): Returns a copy of the probabilities tensor on GPU memory.
        to(*args, **kwargs): Returns a copy of the probabilities tensor with specified device and dtype.

    Examples:
        >>> probs = torch.tensor([0.1, 0.3, 0.6])
        >>> p = Probs(probs)
        >>> print(p.top1)
        2
        >>> print(p.top5)
        [2, 1, 0]
        >>> print(p.top1conf)
        tensor(0.6000)
        >>> print(p.top5conf)
        tensor([0.6000, 0.3000, 0.1000])
    """

    def __init__(self, probs, orig_shape=None) -> None:
        """
        Initialize the Probs class with classification probabilities.

        This class stores and manages classification probabilities, providing easy access to top predictions and their
        confidences.

        Args:
            probs (torch.Tensor | np.ndarray): A 1D tensor or array of classification probabilities.
            orig_shape (tuple | None): The original image shape as (height, width). Not used in this class but kept for
                consistency with other result classes.

        Attributes:
            data (torch.Tensor | np.ndarray): The raw tensor or array containing classification probabilities.
            top1 (int): Index of the top 1 class.
            top5 (List[int]): Indices of the top 5 classes.
            top1conf (torch.Tensor | np.ndarray): Confidence of the top 1 class.
            top5conf (torch.Tensor | np.ndarray): Confidences of the top 5 classes.

        Examples:
            >>> import torch
            >>> probs = torch.tensor([0.1, 0.3, 0.2, 0.4])
            >>> p = Probs(probs)
            >>> print(p.top1)
            3
            >>> print(p.top1conf)
            tensor(0.4000)
            >>> print(p.top5)
            [3, 1, 2, 0]
        """
        super().__init__(probs, orig_shape)

    @property
    @lru_cache(maxsize=1)
    def top1(self):
        """
        Returns the index of the class with the highest probability.

        Returns:
            (int): Index of the class with the highest probability.

        Examples:
            >>> probs = Probs(torch.tensor([0.1, 0.3, 0.6]))
            >>> probs.top1
            2
        """
        # int() yields a plain Python int for both torch and numpy backends
        return int(self.data.argmax())

    @property
    @lru_cache(maxsize=1)
    def top5(self):
        """
        Returns the indices of the top 5 class probabilities.

        Returns:
            (List[int]): A list containing the indices of the top 5 class probabilities, sorted in descending order.

        Examples:
            >>> probs = Probs(torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]))
            >>> print(probs.top5)
            [4, 3, 2, 1, 0]
        """
        # Negate then argsort ascending == argsort descending
        return (-self.data).argsort(0)[:5].tolist()  # this way works with both torch and numpy.

    @property
    @lru_cache(maxsize=1)
    def top1conf(self):
        """
        Returns the confidence score of the highest probability class.

        This property retrieves the confidence score (probability) of the class with the highest predicted probability
        from the classification results.

        Returns:
            (torch.Tensor | numpy.ndarray): A tensor containing the confidence score of the top 1 class.

        Examples:
            >>> results = model("image.jpg")  # classify an image
            >>> probs = results[0].probs  # get classification probabilities
            >>> top1_confidence = probs.top1conf  # get confidence of top 1 class
            >>> print(f"Top 1 class confidence: {top1_confidence.item():.4f}")
        """
        return self.data[self.top1]

    @property
    @lru_cache(maxsize=1)
    def top5conf(self):
        """
        Returns confidence scores for the top 5 classification predictions.

        This property retrieves the confidence scores corresponding to the top 5 class probabilities
        predicted by the model. It provides a quick way to access the most likely class predictions
        along with their associated confidence levels.

        Returns:
            (torch.Tensor | numpy.ndarray): A tensor or array containing the confidence scores for the
                top 5 predicted classes, sorted in descending order of probability.

        Examples:
            >>> results = model("image.jpg")
            >>> probs = results[0].probs
            >>> top5_conf = probs.top5conf
            >>> print(top5_conf)  # Prints confidence scores for top 5 classes
        """
        return self.data[self.top5]


class OBB(BaseTensor):
    """
    A class for storing and manipulating Oriented Bounding Boxes (OBB).

    This class provides functionality to handle oriented bounding boxes, including conversion between
    different formats, normalization, and access to various properties of the boxes.

    Attributes:
        data (torch.Tensor): The raw OBB tensor containing box coordinates and associated data.
        orig_shape (tuple): Original image size as (height, width).
        is_track (bool): Indicates whether tracking IDs are included in the box data.
        xywhr (torch.Tensor | numpy.ndarray): Boxes in [x_center, y_center, width, height, rotation] format.
        conf (torch.Tensor | numpy.ndarray): Confidence scores for each box.
        cls (torch.Tensor | numpy.ndarray): Class labels for each box.
        id (torch.Tensor | numpy.ndarray): Tracking IDs for each box, if available.
        xyxyxyxy (torch.Tensor | numpy.ndarray): Boxes in 8-point [x1, y1, x2, y2, x3, y3, x4, y4] format.
        xyxyxyxyn (torch.Tensor | numpy.ndarray): Normalized 8-point coordinates relative to orig_shape.
        xyxy (torch.Tensor | numpy.ndarray): Axis-aligned bounding boxes in [x1, y1, x2, y2] format.

    Methods:
        cpu(): Returns a copy of the OBB object with all tensors on CPU memory.
        numpy(): Returns a copy of the OBB object with all tensors as numpy arrays.
        cuda(): Returns a copy of the OBB object with all tensors on GPU memory.
        to(*args, **kwargs): Returns a copy of the OBB object with tensors on specified device and dtype.

    Examples:
        >>> boxes = torch.tensor([[100, 50, 150, 100, 30, 0.9, 0]])  # xywhr, conf, cls
        >>> obb = OBB(boxes, orig_shape=(480, 640))
        >>> print(obb.xyxyxyxy)
        >>> print(obb.conf)
        >>> print(obb.cls)
    """

    def __init__(self, boxes, orig_shape) -> None:
        """
        Initialize an OBB (Oriented Bounding Box) instance with oriented bounding box data and original image shape.

        This class stores and manipulates Oriented Bounding Boxes (OBB) for object detection tasks. It provides
        various properties and methods to access and transform the OBB data.

        Args:
            boxes (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the detection boxes,
                with shape (num_boxes, 7) or (num_boxes, 8). The last two columns contain confidence and class values.
                If present, the third last column contains track IDs, and the fifth column contains rotation.
            orig_shape (Tuple[int, int]): Original image size, in the format (height, width).

        Attributes:
            data (torch.Tensor | numpy.ndarray): The raw OBB tensor.
            orig_shape (Tuple[int, int]): The original image shape.
            is_track (bool): Whether the boxes include tracking IDs.

        Raises:
            AssertionError: If the number of values per box is not 7 or 8.

        Examples:
            >>> import torch
            >>> boxes = torch.rand(3, 7)  # 3 boxes with 7 values each
            >>> orig_shape = (640, 480)
            >>> obb = OBB(boxes, orig_shape)
            >>> print(obb.xywhr)  # Access the boxes in xywhr format
        """
        if boxes.ndim == 1:
            # Promote a single box to a (1, n) batch so column indexing below is uniform
            boxes = boxes[None, :]
        n = boxes.shape[-1]
        assert n in {7, 8}, f"expected 7 or 8 values but got {n}"  # xywh, rotation, track_id, conf, cls
        super().__init__(boxes, orig_shape)
        self.is_track = n == 8
        self.orig_shape = orig_shape

    @property
    def xywhr(self):
        """
        Returns boxes in [x_center, y_center, width, height, rotation] format.

        Returns:
            (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the oriented bounding boxes with format
                [x_center, y_center, width, height, rotation]. The shape is (N, 5) where N is the number of boxes.

        Examples:
            >>> results = model("image.jpg")
            >>> obb = results[0].obb
            >>> xywhr = obb.xywhr
            >>> print(xywhr.shape)
            torch.Size([3, 5])
        """
        return self.data[:, :5]

    @property
    def conf(self):
        """
        Returns the confidence scores for Oriented Bounding Boxes (OBBs).

        This property retrieves the confidence values associated with each OBB detection. The confidence score
        represents the model's certainty in the detection.

        Returns:
            (torch.Tensor | numpy.ndarray): A tensor or numpy array of shape (N,) containing confidence scores
                for N detections, where each score is in the range [0, 1].

        Examples:
            >>> results = model("image.jpg")
            >>> obb_result = results[0].obb
            >>> confidence_scores = obb_result.conf
            >>> print(confidence_scores)
        """
        # Confidence is the second-to-last column in both the 7- and 8-column layouts
        return self.data[:, -2]

    @property
    def cls(self):
        """
        Returns the class values of the oriented bounding boxes.

        Returns:
            (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the class values for each oriented
                bounding box. The shape is (N,), where N is the number of boxes.

        Examples:
            >>> results = model("image.jpg")
            >>> result = results[0]
            >>> obb = result.obb
            >>> class_values = obb.cls
            >>> print(class_values)
        """
        # Class ID is always the last column regardless of tracking
        return self.data[:, -1]

    @property
    def id(self):
        """
        Returns the tracking IDs of the oriented bounding boxes (if available).

        Returns:
            (torch.Tensor | numpy.ndarray | None): A tensor or numpy array containing the tracking IDs for each
                oriented bounding box. Returns None if tracking IDs are not available.

        Examples:
            >>> results = model("image.jpg", tracker=True)  # Run inference with tracking
            >>> for result in results:
            ...     if result.obb is not None:
            ...         track_ids = result.obb.id
            ...         if track_ids is not None:
            ...             print(f"Tracking IDs: {track_ids}")
        """
        # Track ID occupies column -3 only in the 8-column (tracking) layout
        return self.data[:, -3] if self.is_track else None

    @property
    @lru_cache(maxsize=2)
    def xyxyxyxy(self):
        """
        Converts OBB format to 8-point (xyxyxyxy) coordinate format for rotated bounding boxes.

        Returns:
            (torch.Tensor | numpy.ndarray): Rotated bounding boxes in xyxyxyxy format with shape (N, 4, 2), where N is
                the number of boxes. Each box is represented by 4 points (x, y), starting from the top-left corner and
                moving clockwise.

        Examples:
            >>> obb = OBB(torch.tensor([[100, 100, 50, 30, 0.5, 0.9, 0]]), orig_shape=(640, 640))
            >>> xyxyxyxy = obb.xyxyxyxy
            >>> print(xyxyxyxy.shape)
            torch.Size([1, 4, 2])
        """
        return ops.xywhr2xyxyxyxy(self.xywhr)

    @property
    @lru_cache(maxsize=2)
    def xyxyxyxyn(self):
        """
        Converts rotated bounding boxes to normalized xyxyxyxy format.

        Returns:
            (torch.Tensor | numpy.ndarray): Normalized rotated bounding boxes in xyxyxyxy format with shape (N, 4, 2),
                where N is the number of boxes. Each box is represented by 4 points (x, y), normalized relative to
                the original image dimensions.

        Examples:
            >>> obb = OBB(torch.rand(10, 7), orig_shape=(640, 480))  # 10 random OBBs
            >>> normalized_boxes = obb.xyxyxyxyn
            >>> print(normalized_boxes.shape)
            torch.Size([10, 4, 2])
        """
        # Copy before the in-place division so the lru_cached xyxyxyxy result is not mutated
        xyxyxyxyn = self.xyxyxyxy.clone() if isinstance(self.xyxyxyxy, torch.Tensor) else np.copy(self.xyxyxyxy)
        xyxyxyxyn[..., 0] /= self.orig_shape[1]  # x / width
        xyxyxyxyn[..., 1] /= self.orig_shape[0]  # y / height
        return xyxyxyxyn

    @property
    @lru_cache(maxsize=2)
    def xyxy(self):
        """
        Converts oriented bounding boxes (OBB) to axis-aligned bounding boxes in xyxy format.

        This property calculates the minimal enclosing rectangle for each oriented bounding box and returns it in
        xyxy format (x1, y1, x2, y2). This is useful for operations that require axis-aligned bounding boxes, such
        as IoU calculation with non-rotated boxes.

        Returns:
            (torch.Tensor | numpy.ndarray): Axis-aligned bounding boxes in xyxy format with shape (N, 4), where N
                is the number of boxes. Each row contains [x1, y1, x2, y2] coordinates.

        Examples:
            >>> import torch
            >>> from ultralytics import YOLO
            >>> model = YOLO("yolo11n-obb.pt")
            >>> results = model("path/to/image.jpg")
            >>> for result in results:
            ...     obb = result.obb
            ...     if obb is not None:
            ...         xyxy_boxes = obb.xyxy
            ...         print(xyxy_boxes.shape)  # (N, 4)

        Notes:
            - This method approximates the OBB by its minimal enclosing rectangle.
            - The returned format is compatible with standard object detection metrics and visualization tools.
            - The property uses caching to improve performance for repeated access.
+ """ + x = self.xyxyxyxy[..., 0] + y = self.xyxyxyxy[..., 1] + return ( + torch.stack([x.amin(1), y.amin(1), x.amax(1), y.amax(1)], -1) + if isinstance(x, torch.Tensor) + else np.stack([x.min(1), y.min(1), x.max(1), y.max(1)], -1) + ) diff --git a/tracking/ultralytics/engine/trainer.py b/tracking/ultralytics/engine/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..25c551d7d868794ceea4ebfa5136a5aade01706a --- /dev/null +++ b/tracking/ultralytics/engine/trainer.py @@ -0,0 +1,840 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Train a model on a dataset. + +Usage: + $ yolo mode=train model=yolo11n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16 +""" + +import gc +import math +import os +import subprocess +import time +import warnings +from copy import copy, deepcopy +from datetime import datetime, timedelta +from pathlib import Path + +import numpy as np +import torch +from torch import distributed as dist +from torch import nn, optim + +from ultralytics.cfg import get_cfg, get_save_dir +from ultralytics.data.utils import check_cls_dataset, check_det_dataset +from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights +from ultralytics.utils import ( + DEFAULT_CFG, + LOCAL_RANK, + LOGGER, + RANK, + TQDM, + __version__, + callbacks, + clean_url, + colorstr, + emojis, + yaml_save, +) +from ultralytics.utils.autobatch import check_train_batch_size +from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args +from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command +from ultralytics.utils.files import get_latest_run +from ultralytics.utils.torch_utils import ( + TORCH_2_4, + EarlyStopping, + ModelEMA, + autocast, + convert_optimizer_state_dict_to_fp16, + init_seeds, + one_cycle, + select_device, + strip_optimizer, + torch_distributed_zero_first, + unset_deterministic, +) + + +class BaseTrainer: + """ + A base class for creating 
trainers. + + Attributes: + args (SimpleNamespace): Configuration for the trainer. + validator (BaseValidator): Validator instance. + model (nn.Module): Model instance. + callbacks (defaultdict): Dictionary of callbacks. + save_dir (Path): Directory to save results. + wdir (Path): Directory to save weights. + last (Path): Path to the last checkpoint. + best (Path): Path to the best checkpoint. + save_period (int): Save checkpoint every x epochs (disabled if < 1). + batch_size (int): Batch size for training. + epochs (int): Number of epochs to train for. + start_epoch (int): Starting epoch for training. + device (torch.device): Device to use for training. + amp (bool): Flag to enable AMP (Automatic Mixed Precision). + scaler (amp.GradScaler): Gradient scaler for AMP. + data (str): Path to data. + trainset (torch.utils.data.Dataset): Training dataset. + testset (torch.utils.data.Dataset): Testing dataset. + ema (nn.Module): EMA (Exponential Moving Average) of the model. + resume (bool): Resume training from a checkpoint. + lf (nn.Module): Loss function. + scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler. + best_fitness (float): The best fitness value achieved. + fitness (float): Current fitness value. + loss (float): Current loss value. + tloss (float): Total loss value. + loss_names (list): List of loss names. + csv (Path): Path to results CSV file. + metrics (dict): Dictionary of metrics. + plots (dict): Dictionary of plots. + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initialize the BaseTrainer class. + + Args: + cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG. + overrides (dict, optional): Configuration overrides. Defaults to None. + _callbacks (list, optional): List of callback functions. Defaults to None. 
+ """ + self.args = get_cfg(cfg, overrides) + self.check_resume(overrides) + self.device = select_device(self.args.device, self.args.batch) + self.validator = None + self.metrics = None + self.plots = {} + init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic) + + # Dirs + self.save_dir = get_save_dir(self.args) + self.args.name = self.save_dir.name # update name for loggers + self.wdir = self.save_dir / "weights" # weights dir + if RANK in {-1, 0}: + self.wdir.mkdir(parents=True, exist_ok=True) # make dir + self.args.save_dir = str(self.save_dir) + yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args + self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt" # checkpoint paths + self.save_period = self.args.save_period + + self.batch_size = self.args.batch + self.epochs = self.args.epochs or 100 # in case users accidentally pass epochs=None with timed training + self.start_epoch = 0 + if RANK == -1: + print_args(vars(self.args)) + + # Device + if self.device.type in {"cpu", "mps"}: + self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading + + # Model and Dataset + self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. 
yolo11n -> yolo11n.pt + with torch_distributed_zero_first(LOCAL_RANK): # avoid auto-downloading dataset multiple times + self.trainset, self.testset = self.get_dataset() + self.ema = None + + # Optimization utils init + self.lf = None + self.scheduler = None + + # Epoch level metrics + self.best_fitness = None + self.fitness = None + self.loss = None + self.tloss = None + self.loss_names = ["Loss"] + self.csv = self.save_dir / "results.csv" + self.plot_idx = [0, 1, 2] + + # HUB + self.hub_session = None + + # Callbacks + self.callbacks = _callbacks or callbacks.get_default_callbacks() + if RANK in {-1, 0}: + callbacks.add_integration_callbacks(self) + + def add_callback(self, event: str, callback): + """Append the given callback to the event's callback list.""" + self.callbacks[event].append(callback) + + def set_callback(self, event: str, callback): + """Override the existing callbacks with the given callback for the specified event.""" + self.callbacks[event] = [callback] + + def run_callbacks(self, event: str): + """Run all existing callbacks associated with a particular event.""" + for callback in self.callbacks.get(event, []): + callback(self) + + def train(self): + """Allow device='', device=None on Multi-GPU systems to default to device=0.""" + if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3' + world_size = len(self.args.device.split(",")) + elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list) + world_size = len(self.args.device) + elif self.args.device in {"cpu", "mps"}: # i.e. device='cpu' or 'mps' + world_size = 0 + elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number + world_size = 1 # default to device 0 + else: # i.e. 
device=None or device='' + world_size = 0 + + # Run subprocess if DDP training, else train normally + if world_size > 1 and "LOCAL_RANK" not in os.environ: + # Argument checks + if self.args.rect: + LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'") + self.args.rect = False + if self.args.batch < 1.0: + LOGGER.warning( + "WARNING ⚠️ 'batch<1' for AutoBatch is incompatible with Multi-GPU training, setting " + "default 'batch=16'" + ) + self.args.batch = 16 + + # Command + cmd, file = generate_ddp_command(world_size, self) + try: + LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}") + subprocess.run(cmd, check=True) + except Exception as e: + raise e + finally: + ddp_cleanup(self, str(file)) + + else: + self._do_train(world_size) + + def _setup_scheduler(self): + """Initialize training learning rate scheduler.""" + if self.args.cos_lr: + self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf'] + else: + self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf # linear + self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf) + + def _setup_ddp(self, world_size): + """Initialize and set the DistributedDataParallel parameters for training.""" + torch.cuda.set_device(RANK) + self.device = torch.device("cuda", RANK) + # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}') + os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout + dist.init_process_group( + backend="nccl" if dist.is_nccl_available() else "gloo", + timeout=timedelta(seconds=10800), # 3 hours + rank=RANK, + world_size=world_size, + ) + + def _setup_train(self, world_size): + """Build dataloaders and optimizer on correct rank process.""" + # Model + self.run_callbacks("on_pretrain_routine_start") + ckpt = self.setup_model() + self.model = self.model.to(self.device) + self.set_model_attributes() + + # Freeze layers + freeze_list 
= ( + self.args.freeze + if isinstance(self.args.freeze, list) + else range(self.args.freeze) + if isinstance(self.args.freeze, int) + else [] + ) + always_freeze_names = [".dfl"] # always freeze these layers + freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names + for k, v in self.model.named_parameters(): + # v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results) + if any(x in k for x in freeze_layer_names): + LOGGER.info(f"Freezing layer '{k}'") + v.requires_grad = False + elif not v.requires_grad and v.dtype.is_floating_point: # only floating point Tensor can require gradients + LOGGER.info( + f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. " + "See ultralytics.engine.trainer for customization of frozen layers." + ) + v.requires_grad = True + + # Check AMP + self.amp = torch.tensor(self.args.amp).to(self.device) # True or False + if self.amp and RANK in {-1, 0}: # Single-GPU and DDP + callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them + self.amp = torch.tensor(check_amp(self.model), device=self.device) + callbacks.default_callbacks = callbacks_backup # restore callbacks + if RANK > -1 and world_size > 1: # DDP + dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None) + self.amp = bool(self.amp) # as boolean + self.scaler = ( + torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp) + ) + if world_size > 1: + self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True) + + # Check imgsz + gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32) # grid size (max stride) + self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1) + self.stride = gs # for multiscale training + + # Batch size + if self.batch_size < 1 and RANK == -1: # 
single-GPU only, estimate best batch size + self.args.batch = self.batch_size = self.auto_batch() + + # Dataloaders + batch_size = self.batch_size // max(world_size, 1) + self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=LOCAL_RANK, mode="train") + if RANK in {-1, 0}: + # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects. + self.test_loader = self.get_dataloader( + self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val" + ) + self.validator = self.get_validator() + metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val") + self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) + self.ema = ModelEMA(self.model) + if self.args.plots: + self.plot_training_labels() + + # Optimizer + self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing + weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay + iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs + self.optimizer = self.build_optimizer( + model=self.model, + name=self.args.optimizer, + lr=self.args.lr0, + momentum=self.args.momentum, + decay=weight_decay, + iterations=iterations, + ) + # Scheduler + self._setup_scheduler() + self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False + self.resume_training(ckpt) + self.scheduler.last_epoch = self.start_epoch - 1 # do not move + self.run_callbacks("on_pretrain_routine_end") + + def _do_train(self, world_size=1): + """Train the model with the specified world size.""" + if world_size > 1: + self._setup_ddp(world_size) + self._setup_train(world_size) + + nb = len(self.train_loader) # number of batches + nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1 # warmup iterations + last_opt_step = -1 + self.epoch_time = 
None + self.epoch_time_start = time.time() + self.train_time_start = time.time() + self.run_callbacks("on_train_start") + LOGGER.info( + f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n" + f"Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n" + f"Logging results to {colorstr('bold', self.save_dir)}\n" + f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...") + ) + if self.args.close_mosaic: + base_idx = (self.epochs - self.args.close_mosaic) * nb + self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2]) + epoch = self.start_epoch + self.optimizer.zero_grad() # zero any resumed gradients to ensure stability on train start + while True: + self.epoch = epoch + self.run_callbacks("on_train_epoch_start") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()' + self.scheduler.step() + + self.model.train() + if RANK != -1: + self.train_loader.sampler.set_epoch(epoch) + pbar = enumerate(self.train_loader) + # Update dataloader attributes (optional) + if epoch == (self.epochs - self.args.close_mosaic): + self._close_dataloader_mosaic() + self.train_loader.reset() + + if RANK in {-1, 0}: + LOGGER.info(self.progress_string()) + pbar = TQDM(enumerate(self.train_loader), total=nb) + self.tloss = None + for i, batch in pbar: + self.run_callbacks("on_train_batch_start") + # Warmup + ni = i + nb * epoch + if ni <= nw: + xi = [0, nw] # x interp + self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())) + for j, x in enumerate(self.optimizer.param_groups): + # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0 + x["lr"] = np.interp( + ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)] + ) + if "momentum" in x: + x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum]) + + # Forward + with 
autocast(self.amp): + batch = self.preprocess_batch(batch) + self.loss, self.loss_items = self.model(batch) + if RANK != -1: + self.loss *= world_size + self.tloss = ( + (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items + ) + + # Backward + self.scaler.scale(self.loss).backward() + + # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html + if ni - last_opt_step >= self.accumulate: + self.optimizer_step() + last_opt_step = ni + + # Timed stopping + if self.args.time: + self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600) + if RANK != -1: # if DDP training + broadcast_list = [self.stop if RANK == 0 else None] + dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks + self.stop = broadcast_list[0] + if self.stop: # training time exceeded + break + + # Log + if RANK in {-1, 0}: + loss_length = self.tloss.shape[0] if len(self.tloss.shape) else 1 + pbar.set_description( + ("%11s" * 2 + "%11.4g" * (2 + loss_length)) + % ( + f"{epoch + 1}/{self.epochs}", + f"{self._get_memory():.3g}G", # (GB) GPU memory util + *(self.tloss if loss_length > 1 else torch.unsqueeze(self.tloss, 0)), # losses + batch["cls"].shape[0], # batch size, i.e. 
8 + batch["img"].shape[-1], # imgsz, i.e 640 + ) + ) + self.run_callbacks("on_batch_end") + if self.args.plots and ni in self.plot_idx: + self.plot_training_samples(batch, ni) + + self.run_callbacks("on_train_batch_end") + + self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers + self.run_callbacks("on_train_epoch_end") + if RANK in {-1, 0}: + final_epoch = epoch + 1 >= self.epochs + self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"]) + + # Validation + if self.args.val or final_epoch or self.stopper.possible_stop or self.stop: + self.metrics, self.fitness = self.validate() + self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr}) + self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch + if self.args.time: + self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600) + + # Save model + if self.args.save or final_epoch: + self.save_model() + self.run_callbacks("on_model_save") + + # Scheduler + t = time.time() + self.epoch_time = t - self.epoch_time_start + self.epoch_time_start = t + if self.args.time: + mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1) + self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time) + self._setup_scheduler() + self.scheduler.last_epoch = self.epoch # do not move + self.stop |= epoch >= self.epochs # stop if exceeded epochs + self.run_callbacks("on_fit_epoch_end") + if self._get_memory(fraction=True) > 0.9: + self._clear_memory() # clear if memory utilization > 90% + + # Early Stopping + if RANK != -1: # if DDP training + broadcast_list = [self.stop if RANK == 0 else None] + dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks + self.stop = broadcast_list[0] + if self.stop: + break # must break all DDP ranks + epoch += 1 + + if RANK in {-1, 0}: + # Do final val with best.pt + seconds = time.time() - 
self.train_time_start + LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.") + self.final_eval() + if self.args.plots: + self.plot_metrics() + self.run_callbacks("on_train_end") + self._clear_memory() + unset_deterministic() + self.run_callbacks("teardown") + + def auto_batch(self, max_num_obj=0): + """Calculate optimal batch size based on model and device memory constraints.""" + return check_train_batch_size( + model=self.model, + imgsz=self.args.imgsz, + amp=self.amp, + batch=self.batch_size, + max_num_obj=max_num_obj, + ) # returns batch size + + def _get_memory(self, fraction=False): + """Get accelerator memory utilization in GB or as a fraction of total memory.""" + memory, total = 0, 0 + if self.device.type == "mps": + memory = torch.mps.driver_allocated_memory() + if fraction: + return __import__("psutil").virtual_memory().percent / 100 + elif self.device.type == "cpu": + pass + else: + memory = torch.cuda.memory_reserved() + if fraction: + total = torch.cuda.get_device_properties(self.device).total_memory + return ((memory / total) if total > 0 else 0) if fraction else (memory / 2**30) + + def _clear_memory(self): + """Clear accelerator memory by calling garbage collector and emptying cache.""" + gc.collect() + if self.device.type == "mps": + torch.mps.empty_cache() + elif self.device.type == "cpu": + return + else: + torch.cuda.empty_cache() + + def read_results_csv(self): + """Read results.csv into a dictionary using pandas.""" + import pandas as pd # scope for faster 'import ultralytics' + + return pd.read_csv(self.csv).to_dict(orient="list") + + def save_model(self): + """Save model training checkpoints with additional metadata.""" + import io + + # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls) + buffer = io.BytesIO() + torch.save( + { + "epoch": self.epoch, + "best_fitness": self.best_fitness, + "model": None, # resume and final checkpoints derive from EMA + "ema": 
deepcopy(self.ema.ema).half(), + "updates": self.ema.updates, + "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())), + "train_args": vars(self.args), # save as dict + "train_metrics": {**self.metrics, **{"fitness": self.fitness}}, + "train_results": self.read_results_csv(), + "date": datetime.now().isoformat(), + "version": __version__, + "license": "AGPL-3.0 (https://ultralytics.com/license)", + "docs": "https://docs.ultralytics.com", + }, + buffer, + ) + serialized_ckpt = buffer.getvalue() # get the serialized content to save + + # Save checkpoints + self.last.write_bytes(serialized_ckpt) # save last.pt + if self.best_fitness == self.fitness: + self.best.write_bytes(serialized_ckpt) # save best.pt + if (self.save_period > 0) and (self.epoch % self.save_period == 0): + (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt' + # if self.args.close_mosaic and self.epoch == (self.epochs - self.args.close_mosaic - 1): + # (self.wdir / "last_mosaic.pt").write_bytes(serialized_ckpt) # save mosaic checkpoint + + def get_dataset(self): + """ + Get train and validation datasets from data dictionary. + + Returns: + (tuple): A tuple containing the training and validation/test datasets. 
+ """ + try: + if self.args.task == "classify": + data = check_cls_dataset(self.args.data) + elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in { + "detect", + "segment", + "pose", + "obb", + }: + data = check_det_dataset(self.args.data) + if "yaml_file" in data: + self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage + except Exception as e: + raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e + self.data = data + if self.args.single_cls: + LOGGER.info("Overriding class names with single class.") + self.data["names"] = {0: "item"} + self.data["nc"] = 1 + return data["train"], data.get("val") or data.get("test") + + def setup_model(self): + """ + Load, create, or download model for any task. + + Returns: + (dict): Optional checkpoint to resume training from. + """ + if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed + return + + cfg, weights = self.model, None + ckpt = None + if str(self.model).endswith(".pt"): + weights, ckpt = attempt_load_one_weight(self.model) + cfg = weights.yaml + elif isinstance(self.args.pretrained, (str, Path)): + weights, _ = attempt_load_one_weight(self.args.pretrained) + self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights) + return ckpt + + def optimizer_step(self): + """Perform a single step of the training optimizer with gradient clipping and EMA update.""" + self.scaler.unscale_(self.optimizer) # unscale gradients + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0) # clip gradients + self.scaler.step(self.optimizer) + self.scaler.update() + self.optimizer.zero_grad() + if self.ema: + self.ema.update(self.model) + + def preprocess_batch(self, batch): + """Allows custom preprocessing model inputs and ground truths depending on task type.""" + return batch + + def validate(self): + """ + Run validation on test set using self.validator. 

        Returns:
            (tuple): A tuple containing metrics dictionary and fitness score.
        """
        metrics = self.validator(self)
        # Fall back to negative loss as the fitness measure when the validator
        # does not report a 'fitness' key (lower loss -> higher fitness).
        fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
        if not self.best_fitness or self.best_fitness < fitness:
            self.best_fitness = fitness
        return metrics, fitness

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Get model and raise NotImplementedError for loading cfg files."""
        raise NotImplementedError("This task trainer doesn't support loading cfg files")

    def get_validator(self):
        """Raise NotImplementedError; task-specific trainers must override to return a validator instance."""
        raise NotImplementedError("get_validator function not implemented in trainer")

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
        """Raise NotImplementedError; task-specific trainers must override to return a torch.utils.data.DataLoader."""
        raise NotImplementedError("get_dataloader function not implemented in trainer")

    def build_dataset(self, img_path, mode="train", batch=None):
        """Raise NotImplementedError; task-specific trainers must override to build and return a dataset."""
        raise NotImplementedError("build_dataset function not implemented in trainer")

    def label_loss_items(self, loss_items=None, prefix="train"):
        """
        Returns a loss dict with labelled training loss items tensor.
+ + Note: + This is not needed for classification but necessary for segmentation & detection + """ + return {"loss": loss_items} if loss_items is not None else ["loss"] + + def set_model_attributes(self): + """Set or update model parameters before training.""" + self.model.names = self.data["names"] + + def build_targets(self, preds, targets): + """Builds target tensors for training YOLO model.""" + pass + + def progress_string(self): + """Returns a string describing training progress.""" + return "" + + # TODO: may need to put these following functions into callback + def plot_training_samples(self, batch, ni): + """Plots training samples during YOLO training.""" + pass + + def plot_training_labels(self): + """Plots training labels for YOLO model.""" + pass + + def save_metrics(self, metrics): + """Save training metrics to a CSV file.""" + keys, vals = list(metrics.keys()), list(metrics.values()) + n = len(metrics) + 2 # number of cols + s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n") # header + t = time.time() - self.train_time_start + with open(self.csv, "a", encoding="utf-8") as f: + f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n") + + def plot_metrics(self): + """Plot and display metrics visually.""" + pass + + def on_plot(self, name, data=None): + """Registers plots (e.g. 
to be consumed in callbacks).""" + path = Path(name) + self.plots[path] = {"data": data, "timestamp": time.time()} + + def final_eval(self): + """Perform final evaluation and validation for object detection YOLO model.""" + ckpt = {} + for f in self.last, self.best: + if f.exists(): + if f is self.last: + ckpt = strip_optimizer(f) + elif f is self.best: + k = "train_results" # update best.pt train_metrics from last.pt + strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None) + LOGGER.info(f"\nValidating {f}...") + self.validator.args.plots = self.args.plots + self.metrics = self.validator(model=f) + self.metrics.pop("fitness", None) + self.run_callbacks("on_fit_epoch_end") + + def check_resume(self, overrides): + """Check if resume checkpoint exists and update arguments accordingly.""" + resume = self.args.resume + if resume: + try: + exists = isinstance(resume, (str, Path)) and Path(resume).exists() + last = Path(check_file(resume) if exists else get_latest_run()) + + # Check that resume data YAML exists, otherwise strip to force re-download of dataset + ckpt_args = attempt_load_weights(last).args + if not Path(ckpt_args["data"]).exists(): + ckpt_args["data"] = self.args.data + + resume = True + self.args = get_cfg(ckpt_args) + self.args.model = self.args.resume = str(last) # reinstate model + for k in ( + "imgsz", + "batch", + "device", + "close_mosaic", + ): # allow arg updates to reduce memory or update device on resume + if k in overrides: + setattr(self.args, k, overrides[k]) + + except Exception as e: + raise FileNotFoundError( + "Resume checkpoint not found. Please pass a valid checkpoint to resume from, " + "i.e. 
'yolo train resume model=path/to/last.pt'" + ) from e + self.resume = resume + + def resume_training(self, ckpt): + """Resume YOLO training from given epoch and best fitness.""" + if ckpt is None or not self.resume: + return + best_fitness = 0.0 + start_epoch = ckpt.get("epoch", -1) + 1 + if ckpt.get("optimizer", None) is not None: + self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer + best_fitness = ckpt["best_fitness"] + if self.ema and ckpt.get("ema"): + self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA + self.ema.updates = ckpt["updates"] + assert start_epoch > 0, ( + f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n" + f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'" + ) + LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs") + if self.epochs < start_epoch: + LOGGER.info( + f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs." + ) + self.epochs += ckpt["epoch"] # finetune additional epochs + self.best_fitness = best_fitness + self.start_epoch = start_epoch + if start_epoch > (self.epochs - self.args.close_mosaic): + self._close_dataloader_mosaic() + + def _close_dataloader_mosaic(self): + """Update dataloaders to stop using mosaic augmentation.""" + if hasattr(self.train_loader.dataset, "mosaic"): + self.train_loader.dataset.mosaic = False + if hasattr(self.train_loader.dataset, "close_mosaic"): + LOGGER.info("Closing dataloader mosaic") + self.train_loader.dataset.close_mosaic(hyp=copy(self.args)) + + def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5): + """ + Construct an optimizer for the given model. + + Args: + model (torch.nn.Module): The model for which to build an optimizer. + name (str, optional): The name of the optimizer to use. 
If 'auto', the optimizer is selected + based on the number of iterations. Default: 'auto'. + lr (float, optional): The learning rate for the optimizer. Default: 0.001. + momentum (float, optional): The momentum factor for the optimizer. Default: 0.9. + decay (float, optional): The weight decay for the optimizer. Default: 1e-5. + iterations (float, optional): The number of iterations, which determines the optimizer if + name is 'auto'. Default: 1e5. + + Returns: + (torch.optim.Optimizer): The constructed optimizer. + """ + g = [], [], [] # optimizer parameter groups + bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d() + if name == "auto": + LOGGER.info( + f"{colorstr('optimizer:')} 'optimizer=auto' found, " + f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and " + f"determining best 'optimizer', 'lr0' and 'momentum' automatically... " + ) + nc = self.data.get("nc", 10) # number of classes + lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places + name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9) + self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam + + for module_name, module in model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + fullname = f"{module_name}.{param_name}" if module_name else param_name + if "bias" in fullname: # bias (no decay) + g[2].append(param) + elif isinstance(module, bn): # weight (no decay) + g[1].append(param) + else: # weight (with decay) + g[0].append(param) + + optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"} + name = {x.lower(): x for x in optimizers}.get(name.lower()) + if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}: + optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0) + elif name == "RMSProp": + optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum) + 
elif name == "SGD": + optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True) + else: + raise NotImplementedError( + f"Optimizer '{name}' not found in list of available optimizers {optimizers}. " + "Request support for addition optimizers at https://github.com/ultralytics/ultralytics." + ) + + optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay + optimizer.add_param_group({"params": g[1], "weight_decay": 0.0}) # add g1 (BatchNorm2d weights) + LOGGER.info( + f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups " + f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)" + ) + return optimizer diff --git a/tracking/ultralytics/engine/tuner.py b/tracking/ultralytics/engine/tuner.py new file mode 100644 index 0000000000000000000000000000000000000000..d9615a12a185541cc8c37cc120c89effa46a9a29 --- /dev/null +++ b/tracking/ultralytics/engine/tuner.py @@ -0,0 +1,236 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection, instance +segmentation, image classification, pose estimation, and multi-object tracking. + +Hyperparameter tuning is the process of systematically searching for the optimal set of hyperparameters +that yield the best model performance. This is particularly crucial in deep learning models like YOLO, +where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency. + +Examples: + Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations. 
    >>> from ultralytics import YOLO
    >>> model = YOLO("yolo11n.pt")
    >>> model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
"""

import random
import shutil
import subprocess
import time

import numpy as np
import torch

from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, remove_colorstr, yaml_print, yaml_save
from ultralytics.utils.plotting import plot_tune_results


class Tuner:
    """
    A class for hyperparameter tuning of YOLO models.

    The class evolves YOLO model hyperparameters over a given number of iterations by mutating them according to the
    search space and retraining the model to evaluate their performance.

    Attributes:
        space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
        tune_dir (Path): Directory where evolution logs and results will be saved.
        tune_csv (Path): Path to the CSV file where evolution logs are saved.
        args (dict): Configuration arguments for the tuning process.
        callbacks (list): Callback functions to be executed during tuning.
        prefix (str): Prefix string for logging messages.

    Methods:
        _mutate: Mutates the given hyperparameters within the specified bounds.
        __call__: Executes the hyperparameter evolution across multiple iterations.

    Examples:
        Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=10 for 300 tuning iterations.
        >>> from ultralytics import YOLO
        >>> model = YOLO("yolo11n.pt")
        >>> model.tune(
        ...     data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False
        ... )

        Tune with custom search space.
        >>> model.tune(space={key1: val1, key2: val2})  # custom search space dictionary
    """

    def __init__(self, args=DEFAULT_CFG, _callbacks=None):
        """
        Initialize the Tuner with configurations.

        Args:
            args (dict): Configuration for hyperparameter evolution.
            _callbacks (list, optional): Callback functions to be executed during tuning.
        """
        self.space = args.pop("space", None) or {  # key: (min, max, gain(optional))
            # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
            "lr0": (1e-5, 1e-1),  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
            "lrf": (0.0001, 0.1),  # final OneCycleLR learning rate (lr0 * lrf)
            "momentum": (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1
            "weight_decay": (0.0, 0.001),  # optimizer weight decay 5e-4
            "warmup_epochs": (0.0, 5.0),  # warmup epochs (fractions ok)
            "warmup_momentum": (0.0, 0.95),  # warmup initial momentum
            "box": (1.0, 20.0),  # box loss gain
            "cls": (0.2, 4.0),  # cls loss gain (scale with pixels)
            "dfl": (0.4, 6.0),  # dfl loss gain
            "hsv_h": (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
            "hsv_s": (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
            "hsv_v": (0.0, 0.9),  # image HSV-Value augmentation (fraction)
            "degrees": (0.0, 45.0),  # image rotation (+/- deg)
            "translate": (0.0, 0.9),  # image translation (+/- fraction)
            "scale": (0.0, 0.95),  # image scale (+/- gain)
            "shear": (0.0, 10.0),  # image shear (+/- deg)
            "perspective": (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
            "flipud": (0.0, 1.0),  # image flip up-down (probability)
            "fliplr": (0.0, 1.0),  # image flip left-right (probability)
            "bgr": (0.0, 1.0),  # image channel bgr (probability)
            "mosaic": (0.0, 1.0),  # image mosaic (probability)
            "mixup": (0.0, 1.0),  # image mixup (probability)
            "copy_paste": (0.0, 1.0),  # segment copy-paste (probability)
        }
        self.args = get_cfg(overrides=args)
        self.tune_dir = get_save_dir(self.args, name=self.args.name or "tune")
        self.args.name = None  # reset to not affect training directory
        self.tune_csv = self.tune_dir / "tune_results.csv"
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        self.prefix =
colorstr("Tuner: ") + callbacks.add_integration_callbacks(self) + LOGGER.info( + f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n" + f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning" + ) + + def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2): + """ + Mutate hyperparameters based on bounds and scaling factors specified in `self.space`. + + Args: + parent (str): Parent selection method: 'single' or 'weighted'. + n (int): Number of parents to consider. + mutation (float): Probability of a parameter mutation in any given iteration. + sigma (float): Standard deviation for Gaussian random number generator. + + Returns: + (dict): A dictionary containing mutated hyperparameters. + """ + if self.tune_csv.exists(): # if CSV file exists: select best hyps and mutate + # Select parent(s) + x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1) + fitness = x[:, 0] # first column + n = min(n, len(x)) # number of previous results to consider + x = x[np.argsort(-fitness)][:n] # top n mutations + w = x[:, 0] - x[:, 0].min() + 1e-6 # weights (sum > 0) + if parent == "single" or len(x) == 1: + # x = x[random.randint(0, n - 1)] # random selection + x = x[random.choices(range(n), weights=w)[0]] # weighted selection + elif parent == "weighted": + x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination + + # Mutate + r = np.random # method + r.seed(int(time.time())) + g = np.array([v[2] if len(v) == 3 else 1.0 for v in self.space.values()]) # gains 0-1 + ng = len(self.space) + v = np.ones(ng) + while all(v == 1): # mutate until a change occurs (prevent duplicates) + v = (g * (r.random(ng) < mutation) * r.randn(ng) * r.random() * sigma + 1).clip(0.3, 3.0) + hyp = {k: float(x[i + 1] * v[i]) for i, k in enumerate(self.space.keys())} + else: + hyp = {k: getattr(self.args, k) for k in self.space.keys()} + + # Constrain to limits + for k, v in self.space.items(): + hyp[k] = max(hyp[k], 
v[0]) # lower limit + hyp[k] = min(hyp[k], v[1]) # upper limit + hyp[k] = round(hyp[k], 5) # significant digits + + return hyp + + def __call__(self, model=None, iterations=10, cleanup=True): + """ + Execute the hyperparameter evolution process when the Tuner instance is called. + + This method iterates through the number of iterations, performing the following steps in each iteration: + + 1. Load the existing hyperparameters or initialize new ones. + 2. Mutate the hyperparameters using the `mutate` method. + 3. Train a YOLO model with the mutated hyperparameters. + 4. Log the fitness score and mutated hyperparameters to a CSV file. + + Args: + model (Model): A pre-initialized YOLO model to be used for training. + iterations (int): The number of generations to run the evolution for. + cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning. + + Note: + The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores. + Ensure this path is set correctly in the Tuner instance. 
+ """ + t0 = time.time() + best_save_dir, best_metrics = None, None + (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True) + for i in range(iterations): + # Mutate hyperparameters + mutated_hyp = self._mutate() + LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}") + + metrics = {} + train_args = {**vars(self.args), **mutated_hyp} + save_dir = get_save_dir(get_cfg(train_args)) + weights_dir = save_dir / "weights" + try: + # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang) + launch = [__import__("sys").executable, "-m", "ultralytics.cfg.__init__"] # workaround yolo not found + cmd = [*launch, "train", *(f"{k}={v}" for k, v in train_args.items())] + return_code = subprocess.run(cmd, check=True).returncode + ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt") + metrics = torch.load(ckpt_file)["train_metrics"] + assert return_code == 0, "training failed" + + except Exception as e: + LOGGER.warning(f"WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}") + + # Save results and mutated_hyp to CSV + fitness = metrics.get("fitness", 0.0) + log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()] + headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n") + with open(self.tune_csv, "a", encoding="utf-8") as f: + f.write(headers + ",".join(map(str, log_row)) + "\n") + + # Get best results + x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1) + fitness = x[:, 0] # first column + best_idx = fitness.argmax() + best_is_current = best_idx == i + if best_is_current: + best_save_dir = save_dir + best_metrics = {k: round(v, 5) for k, v in metrics.items()} + for ckpt in weights_dir.glob("*.pt"): + shutil.copy2(ckpt, self.tune_dir / "weights") + elif cleanup: + shutil.rmtree(weights_dir, ignore_errors=True) # remove iteration weights/ dir to 
reduce storage space + + # Plot tune results + plot_tune_results(self.tune_csv) + + # Save and print tune results + header = ( + f"{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n" + f"{self.prefix}Results saved to {colorstr('bold', self.tune_dir)}\n" + f"{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n" + f"{self.prefix}Best fitness metrics are {best_metrics}\n" + f"{self.prefix}Best fitness model is {best_save_dir}\n" + f"{self.prefix}Best fitness hyperparameters are printed below.\n" + ) + LOGGER.info("\n" + header) + data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())} + yaml_save( + self.tune_dir / "best_hyperparameters.yaml", + data=data, + header=remove_colorstr(header.replace(self.prefix, "# ")) + "\n", + ) + yaml_print(self.tune_dir / "best_hyperparameters.yaml") diff --git a/tracking/ultralytics/engine/validator.py b/tracking/ultralytics/engine/validator.py new file mode 100644 index 0000000000000000000000000000000000000000..d266dd795fb94180862f2555638cf38fd654ef20 --- /dev/null +++ b/tracking/ultralytics/engine/validator.py @@ -0,0 +1,377 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Check a model's accuracy on a test or val split of a dataset. 
Usage:
    $ yolo mode=val model=yolo11n.pt data=coco8.yaml imgsz=640

Usage - formats:
    $ yolo mode=val model=yolo11n.pt                 # PyTorch
                          yolo11n.torchscript        # TorchScript
                          yolo11n.onnx               # ONNX Runtime or OpenCV DNN with dnn=True
                          yolo11n_openvino_model     # OpenVINO
                          yolo11n.engine             # TensorRT
                          yolo11n.mlpackage          # CoreML (macOS-only)
                          yolo11n_saved_model        # TensorFlow SavedModel
                          yolo11n.pb                 # TensorFlow GraphDef
                          yolo11n.tflite             # TensorFlow Lite
                          yolo11n_edgetpu.tflite     # TensorFlow Edge TPU
                          yolo11n_paddle_model       # PaddlePaddle
                          yolo11n.mnn                # MNN
                          yolo11n_ncnn_model         # NCNN
                          yolo11n_imx_model          # Sony IMX
                          yolo11n_rknn_model         # Rockchip RKNN
"""

import json
import time
from pathlib import Path

import numpy as np
import torch

from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode


class BaseValidator:
    """
    A base class for creating validators.

    This class provides the foundation for validation processes, including model evaluation, metric computation, and
    result visualization. Subclasses implement the task-specific hooks (get_dataloader, preprocess, postprocess,
    metric methods); the base class drives the overall validation loop in `__call__`.

    Attributes:
        args (SimpleNamespace): Configuration for the validator.
        dataloader (DataLoader): Dataloader to use for validation.
        pbar (tqdm): Progress bar to update during validation.
        model (nn.Module): Model to validate.
        data (dict): Data dictionary containing dataset information.
        device (torch.device): Device to use for validation.
        batch_i (int): Current batch index.
        training (bool): Whether the model is in training mode.
        names (dict): Class names mapping.
        seen (int): Number of images seen so far during validation.
        stats (dict): Statistics collected during validation.
        confusion_matrix: Confusion matrix for classification evaluation.
        nc (int): Number of classes.
        iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
        jdict (list): List to store JSON validation results.
        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
            batch processing times in milliseconds.
        save_dir (Path): Directory to save results.
        plots (dict): Dictionary to store plots for visualization.
        callbacks (dict): Dictionary to store various callback functions.

    Methods:
        __call__: Execute validation process, running inference on dataloader and computing performance metrics.
        match_predictions: Match predictions to ground truth objects using IoU.
        add_callback: Append the given callback to the specified event.
        run_callbacks: Run all callbacks associated with a specified event.
        get_dataloader: Get data loader from dataset path and batch size.
        build_dataset: Build dataset from image path.
        preprocess: Preprocess an input batch.
        postprocess: Postprocess the predictions.
        init_metrics: Initialize performance metrics for the YOLO model.
        update_metrics: Update metrics based on predictions and batch.
        finalize_metrics: Finalize and return all metrics.
        get_stats: Return statistics about the model's performance.
        check_stats: Check statistics.
        print_results: Print the results of the model's predictions.
        get_desc: Get description of the YOLO model.
        on_plot: Register plots (e.g. to be consumed in callbacks).
        plot_val_samples: Plot validation samples during training.
        plot_predictions: Plot YOLO model predictions on batch images.
        pred_to_json: Convert predictions to JSON format.
        eval_json: Evaluate and return JSON format of prediction statistics.
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """
        Initialize a BaseValidator instance.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
            save_dir (Path, optional): Directory to save results.
            pbar (tqdm.tqdm, optional): Progress bar for displaying progress.
            args (SimpleNamespace, optional): Configuration for the validator.
            _callbacks (dict, optional): Dictionary to store various callback functions.
        """
        self.args = get_cfg(overrides=args)
        self.dataloader = dataloader
        self.pbar = pbar
        # Most state below is populated lazily in __call__ / init_metrics
        self.stride = None
        self.data = None
        self.device = None
        self.batch_i = None
        self.training = True
        self.names = None
        self.seen = None
        self.stats = None
        self.confusion_matrix = None
        self.nc = None
        self.iouv = None
        self.jdict = None
        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}

        self.save_dir = save_dir or get_save_dir(self.args)
        (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
        if self.args.conf is None:
            self.args.conf = 0.001  # default conf=0.001
        self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)

        self.plots = {}
        self.callbacks = _callbacks or callbacks.get_default_callbacks()

    @smart_inference_mode()
    def __call__(self, trainer=None, model=None):
        """
        Execute validation process, running inference on dataloader and computing performance metrics.

        Args:
            trainer (object, optional): Trainer object that contains the model to validate.
            model (nn.Module, optional): Model to validate if not using a trainer.

        Returns:
            stats (dict): Dictionary containing validation statistics.
        """
        self.training = trainer is not None
        augment = self.args.augment and (not self.training)
        if self.training:
            # Mid-training validation: reuse the trainer's device/data and (EMA) model
            self.device = trainer.device
            self.data = trainer.data
            # Force FP16 val during training
            self.args.half = self.device.type != "cpu" and trainer.amp
            model = trainer.ema.ema or trainer.model
            model = model.half() if self.args.half else model.float()
            # self.model = model
            self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
            # Only produce plots on the final epoch or when early stopping may trigger
            self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
            model.eval()
        else:
            # Standalone validation: load the model through AutoBackend (handles all export formats)
            if str(self.args.model).endswith(".yaml") and model is None:
                LOGGER.warning("WARNING ⚠️ validating an untrained model YAML will result in 0 mAP.")
            callbacks.add_integration_callbacks(self)
            model = AutoBackend(
                weights=model or self.args.model,
                device=select_device(self.args.device, self.args.batch),
                dnn=self.args.dnn,
                data=self.args.data,
                fp16=self.args.half,
            )
            # self.model = model
            self.device = model.device  # update device
            self.args.half = model.fp16  # update half
            stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
            imgsz = check_imgsz(self.args.imgsz, stride=stride)
            if engine:
                self.args.batch = model.batch_size
            elif not pt and not jit:
                self.args.batch = model.metadata.get("batch", 1)  # export.py models default to batch-size 1
                LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")

            if str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
                self.data = check_det_dataset(self.args.data)
            elif self.args.task == "classify":
                self.data = check_cls_dataset(self.args.data, split=self.args.split)
            else:
                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

            if self.device.type in {"cpu", "mps"}:
                self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
            if not pt:
                self.args.rect = False  # non-PyTorch backends need fixed-size input
            self.stride = model.stride  # used in get_dataloader() for padding
            self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

            model.eval()
            model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup

        self.run_callbacks("on_val_start")
        # One Profile timer per stage: preprocess, inference, loss, postprocess
        dt = (
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
        )
        bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
        self.init_metrics(de_parallel(model))
        self.jdict = []  # empty before each val
        for batch_i, batch in enumerate(bar):
            self.run_callbacks("on_val_batch_start")
            self.batch_i = batch_i
            # Preprocess
            with dt[0]:
                batch = self.preprocess(batch)

            # Inference
            with dt[1]:
                preds = model(batch["img"], augment=augment)

            # Loss
            with dt[2]:
                if self.training:
                    self.loss += model.loss(batch, preds)[1]

            # Postprocess
            with dt[3]:
                preds = self.postprocess(preds)

            self.update_metrics(preds, batch)
            if self.args.plots and batch_i < 3:  # only plot the first 3 batches
                self.plot_val_samples(batch, batch_i)
                self.plot_predictions(batch, preds, batch_i)

            self.run_callbacks("on_val_batch_end")
        stats = self.get_stats()
        self.check_stats(stats)
        # Convert total stage times to per-image milliseconds
        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
        self.finalize_metrics()
        self.print_results()
        self.run_callbacks("on_val_end")
        if self.training:
            model.float()
            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
        else:
            LOGGER.info(
                "Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
                    *tuple(self.speed.values())
                )
            )
            if self.args.save_json and self.jdict:
                with open(str(self.save_dir / "predictions.json"), "w", encoding="utf-8") as f:
                    LOGGER.info(f"Saving {f.name}...")
                    json.dump(self.jdict, f)  # flatten and save
                stats = self.eval_json(stats)  # update stats
            if self.args.plots or self.args.save_json:
                LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
            return stats

    def match_predictions(
        self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
    ) -> torch.Tensor:
        """
        Match predictions to ground truth objects using IoU.

        Args:
            pred_classes (torch.Tensor): Predicted class indices of shape (N,).
            true_classes (torch.Tensor): Target class indices of shape (M,).
            iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground truth.
            use_scipy (bool): Whether to use scipy for matching (more precise).

        Returns:
            (torch.Tensor): Correct tensor of shape (N, 10) for 10 IoU thresholds.
        """
        # Dx10 matrix, where D - detections, 10 - IoU thresholds
        correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
        # LxD matrix where L - labels (rows), D - detections (columns)
        correct_class = true_classes[:, None] == pred_classes
        iou = iou * correct_class  # zero out the wrong classes
        iou = iou.cpu().numpy()
        for i, threshold in enumerate(self.iouv.cpu().tolist()):
            if use_scipy:
                # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
                import scipy  # scope import to avoid importing for all commands

                # Optimal one-to-one assignment via the Hungarian algorithm
                cost_matrix = iou * (iou >= threshold)
                if cost_matrix.any():
                    labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix)
                    valid = cost_matrix[labels_idx, detections_idx] > 0
                    if valid.any():
                        correct[detections_idx[valid], i] = True
            else:
                # Greedy matching: sort candidate pairs by IoU (descending), then deduplicate per
                # detection and per label so each is matched at most once
                matches = np.nonzero(iou >= threshold)  # IoU > threshold and classes match
                matches = np.array(matches).T
                if matches.shape[0]:
                    if matches.shape[0] > 1:
                        matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                        matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                        # matches = matches[matches[:, 2].argsort()[::-1]]
                        matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                    correct[matches[:, 1].astype(int), i] = True
        return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)

    def add_callback(self, event: str, callback):
        """Append the given callback to the specified event."""
        self.callbacks[event].append(callback)

    def run_callbacks(self, event: str):
        """Run all callbacks associated with a specified event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def get_dataloader(self, dataset_path, batch_size):
        """Get data loader from dataset path and batch size."""
        raise NotImplementedError("get_dataloader function not implemented for this validator")

    def build_dataset(self, img_path):
        """Build dataset from image path."""
        raise NotImplementedError("build_dataset function not implemented in validator")

    def preprocess(self, batch):
        """Preprocess an input batch."""
        return batch

    def postprocess(self, preds):
        """Postprocess the predictions."""
        return preds

    def init_metrics(self, model):
        """Initialize performance metrics for the YOLO model."""
        pass

    def update_metrics(self, preds, batch):
        """Update metrics based on predictions and batch."""
        pass

    def finalize_metrics(self, *args, **kwargs):
        """Finalize and return all metrics."""
        pass

    def get_stats(self):
        """Return statistics about the model's performance."""
        return {}

    def check_stats(self, stats):
        """Check statistics."""
        pass

    def print_results(self):
        """Print the results of the model's predictions."""
        pass

    def get_desc(self):
        """Get description of the YOLO model."""
        pass

    @property
    def metric_keys(self):
        """Return the metric keys used in YOLO training/validation."""
        return []

    def on_plot(self, name, data=None):
        """Register plots (e.g. to be consumed in callbacks)."""
        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}

    # TODO: may need to put these following functions into callback
    def plot_val_samples(self, batch, ni):
        """Plot validation samples during training."""
        pass

    def plot_predictions(self, batch, preds, ni):
        """Plot YOLO model predictions on batch images."""
        pass

    def pred_to_json(self, preds, batch):
        """Convert predictions to JSON format."""
        pass

    def eval_json(self, stats):
        """Evaluate and return JSON format of prediction statistics."""
        pass
+ + Returns: + (bool): True if authentication is successful, False otherwise. + """ + checks.check_requirements("hub-sdk>=0.0.12") + from hub_sdk import HUBClient + + api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys" # set the redirect URL + saved_key = SETTINGS.get("api_key") + active_key = api_key or saved_key + credentials = {"api_key": active_key} if active_key and active_key != "" else None # set credentials + + client = HUBClient(credentials) # initialize HUBClient + + if client.authenticated: + # Successfully authenticated with HUB + + if save and client.api_key != saved_key: + SETTINGS.update({"api_key": client.api_key}) # update settings with valid API key + + # Set message based on whether key was provided or retrieved from settings + log_message = ( + "New authentication successful ✅" if client.api_key == api_key or not credentials else "Authenticated ✅" + ) + LOGGER.info(f"{PREFIX}{log_message}") + + return True + else: + # Failed to authenticate with HUB + LOGGER.info(f"{PREFIX}Get API key from {api_key_url} and then run 'yolo login API_KEY'") + return False + + +def logout(): + """ + Log out of Ultralytics HUB by removing the API key from the settings file. To log in again, use 'yolo login'. + + Examples: + >>> from ultralytics import hub + >>> hub.logout() + """ + SETTINGS["api_key"] = "" + LOGGER.info(f"{PREFIX}logged out ✅. 
To log in again, use 'yolo login'.") + + +def reset_model(model_id: str = ""): + """Reset a trained model to an untrained state.""" + r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key}) + if r.status_code == 200: + LOGGER.info(f"{PREFIX}Model reset successfully") + return + LOGGER.warning(f"{PREFIX}Model reset failure {r.status_code} {r.reason}") + + +def export_fmts_hub(): + """Returns a list of HUB-supported export formats.""" + from ultralytics.engine.exporter import export_formats + + return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"] + + +def export_model(model_id: str = "", format: str = "torchscript"): + """Export a model to the specified format.""" + assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}" + r = requests.post( + f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key} + ) + assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}" + LOGGER.info(f"{PREFIX}{format} export started ✅") + + +def get_export(model_id: str = "", format: str = "torchscript"): + """Get an exported model dictionary with download URL.""" + assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}" + r = requests.post( + f"{HUB_API_ROOT}/get-export", + json={"apiKey": Auth().api_key, "modelId": model_id, "format": format}, + headers={"x-api-key": Auth().api_key}, + ) + assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}" + return r.json() + + +def check_dataset(path: str, task: str) -> None: + """ + Check HUB dataset Zip file for errors before upload. + + Args: + path (str): Path to data.zip (with data.yaml inside data.zip). + task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify', 'obb'. 
+ + Examples: + >>> from ultralytics.hub import check_dataset + >>> check_dataset("path/to/coco8.zip", task="detect") # detect dataset + >>> check_dataset("path/to/coco8-seg.zip", task="segment") # segment dataset + >>> check_dataset("path/to/coco8-pose.zip", task="pose") # pose dataset + >>> check_dataset("path/to/dota8.zip", task="obb") # OBB dataset + >>> check_dataset("path/to/imagenet10.zip", task="classify") # classification dataset + + Note: + Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets + i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip. + """ + HUBDatasetStats(path=path, task=task).get_json() + LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.") diff --git a/tracking/ultralytics/hub/auth.py b/tracking/ultralytics/hub/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..8ac17ba6e54c583d3b95b82b2ef33cf84141258b --- /dev/null +++ b/tracking/ultralytics/hub/auth.py @@ -0,0 +1,137 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import requests + +from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials +from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, emojis + +API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys" + + +class Auth: + """ + Manages authentication processes including API key handling, cookie-based authentication, and header generation. + + The class supports different methods of authentication: + 1. Directly using an API key. + 2. Authenticating using browser cookies (specifically in Google Colab). + 3. Prompting the user to enter an API key. + + Attributes: + id_token (str | bool): Token used for identity verification, initialized as False. + api_key (str | bool): API key for authentication, initialized as False. + model_key (bool): Placeholder for model key, initialized as False. 
+ """ + + id_token = api_key = model_key = False + + def __init__(self, api_key: str = "", verbose: bool = False): + """ + Initialize Auth class and authenticate user. + + Handles API key validation, Google Colab authentication, and new key requests. Updates SETTINGS upon successful + authentication. + + Args: + api_key (str): API key or combined key_id format. + verbose (bool): Enable verbose logging. + """ + # Split the input API key in case it contains a combined key_model and keep only the API key part + api_key = api_key.split("_")[0] + + # Set API key attribute as value passed or SETTINGS API key if none passed + self.api_key = api_key or SETTINGS.get("api_key", "") + + # If an API key is provided + if self.api_key: + # If the provided API key matches the API key in the SETTINGS + if self.api_key == SETTINGS.get("api_key"): + # Log that the user is already logged in + if verbose: + LOGGER.info(f"{PREFIX}Authenticated ✅") + return + else: + # Attempt to authenticate with the provided API key + success = self.authenticate() + # If the API key is not provided and the environment is a Google Colab notebook + elif IS_COLAB: + # Attempt to authenticate using browser cookies + success = self.auth_with_cookies() + else: + # Request an API key + success = self.request_api_key() + + # Update SETTINGS with the new API key after successful authentication + if success: + SETTINGS.update({"api_key": self.api_key}) + # Log that the new login was successful + if verbose: + LOGGER.info(f"{PREFIX}New authentication successful ✅") + elif verbose: + LOGGER.info(f"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo login API_KEY'") + + def request_api_key(self, max_attempts: int = 3) -> bool: + """Prompt the user to input their API key.""" + import getpass + + for attempts in range(max_attempts): + LOGGER.info(f"{PREFIX}Login. 
Attempt {attempts + 1} of {max_attempts}") + input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ") + self.api_key = input_key.split("_")[0] # remove model id if present + if self.authenticate(): + return True + raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌")) + + def authenticate(self) -> bool: + """ + Attempt to authenticate with the server using either id_token or API key. + + Returns: + (bool): True if authentication is successful, False otherwise. + """ + try: + if header := self.get_auth_header(): + r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header) + if not r.json().get("success", False): + raise ConnectionError("Unable to authenticate.") + return True + raise ConnectionError("User has not authenticated locally.") + except ConnectionError: + self.id_token = self.api_key = False # reset invalid + LOGGER.warning(f"{PREFIX}Invalid API key ⚠️") + return False + + def auth_with_cookies(self) -> bool: + """ + Attempt to fetch authentication via cookies and set id_token. + + User must be logged in to HUB and running in a supported browser. + + Returns: + (bool): True if authentication is successful, False otherwise. + """ + if not IS_COLAB: + return False # Currently only works with Colab + try: + authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto") + if authn.get("success", False): + self.id_token = authn.get("data", {}).get("idToken", None) + self.authenticate() + return True + raise ConnectionError("Unable to fetch browser authentication details.") + except ConnectionError: + self.id_token = False # reset invalid + return False + + def get_auth_header(self): + """ + Get the authentication header for making API requests. + + Returns: + (dict | None): The authentication header if id_token or API key is set, None otherwise. 
#     [fragment: body of Auth.get_auth_header(), reconstructed together with
#      its `def` line earlier in this file — returned the Bearer / x-api-key
#      header dict, else None]

# --- new file: tracking/ultralytics/hub/google/__init__.py ---
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import concurrent.futures
import statistics
import time
from typing import List, Optional, Tuple


class GCPRegions:
    """
    A class for managing and analyzing Google Cloud Platform (GCP) regions.

    This class provides functionality to initialize, categorize, and analyze GCP regions based on their
    geographical location, tier classification, and network latency.

    Attributes:
        regions (Dict[str, Tuple[int, str, str]]): A dictionary of GCP regions with their tier, city, and country.

    Methods:
        tier1: Returns a list of tier 1 GCP regions.
        tier2: Returns a list of tier 2 GCP regions.
        lowest_latency: Determines the GCP region(s) with the lowest network latency.

    Examples:
        >>> from ultralytics.hub.google import GCPRegions
        >>> regions = GCPRegions()
        >>> lowest_latency_region = regions.lowest_latency(verbose=True, attempts=3)
        >>> print(f"Lowest latency region: {lowest_latency_region[0][0]}")
    """

    def __init__(self):
        """Initializes the GCPRegions class with predefined Google Cloud Platform regions and their details."""
        # region name -> (tier, city, country)
        self.regions = {
            "asia-east1": (1, "Taiwan", "China"),
            "asia-east2": (2, "Hong Kong", "China"),
            "asia-northeast1": (1, "Tokyo", "Japan"),
            "asia-northeast2": (1, "Osaka", "Japan"),
            "asia-northeast3": (2, "Seoul", "South Korea"),
            "asia-south1": (2, "Mumbai", "India"),
            "asia-south2": (2, "Delhi", "India"),
            "asia-southeast1": (2, "Jurong West", "Singapore"),
            "asia-southeast2": (2, "Jakarta", "Indonesia"),
            "australia-southeast1": (2, "Sydney", "Australia"),
            "australia-southeast2": (2, "Melbourne", "Australia"),
            "europe-central2": (2, "Warsaw", "Poland"),
            "europe-north1": (1, "Hamina", "Finland"),
            "europe-southwest1": (1, "Madrid", "Spain"),
            "europe-west1": (1, "St. Ghislain", "Belgium"),
            "europe-west10": (2, "Berlin", "Germany"),
            "europe-west12": (2, "Turin", "Italy"),
            "europe-west2": (2, "London", "United Kingdom"),
            "europe-west3": (2, "Frankfurt", "Germany"),
            "europe-west4": (1, "Eemshaven", "Netherlands"),
            "europe-west6": (2, "Zurich", "Switzerland"),
            "europe-west8": (1, "Milan", "Italy"),
            "europe-west9": (1, "Paris", "France"),
            "me-central1": (2, "Doha", "Qatar"),
            "me-west1": (1, "Tel Aviv", "Israel"),
            "northamerica-northeast1": (2, "Montreal", "Canada"),
            "northamerica-northeast2": (2, "Toronto", "Canada"),
            "southamerica-east1": (2, "São Paulo", "Brazil"),
            "southamerica-west1": (2, "Santiago", "Chile"),
            "us-central1": (1, "Iowa", "United States"),
            "us-east1": (1, "South Carolina", "United States"),
            "us-east4": (1, "Northern Virginia", "United States"),
            "us-east5": (1, "Columbus", "United States"),
            "us-south1": (1, "Dallas", "United States"),
            "us-west1": (1, "Oregon", "United States"),
            "us-west2": (2, "Los Angeles", "United States"),
            "us-west3": (2, "Salt Lake City", "United States"),
            "us-west4": (2, "Las Vegas", "United States"),
        }

    def tier1(self) -> List[str]:
        """Returns a list of GCP regions classified as tier 1 based on predefined criteria."""
        return [region for region, info in self.regions.items() if info[0] == 1]

    def tier2(self) -> List[str]:
        """Returns a list of GCP regions classified as tier 2 based on predefined criteria."""
        return [region for region, info in self.regions.items() if info[0] == 2]

    @staticmethod
    def _ping_region(region: str, attempts: int = 1) -> Tuple[str, float, float, float, float]:
        """
        Ping a specified GCP region and return latency statistics.

        Args:
            region (str): GCP region name, e.g. "us-central1".
            attempts (int): Number of HEAD requests to issue.

        Returns:
            (Tuple[str, float, float, float, float]): (region, mean, std_dev, min, max) latency in
                milliseconds, or all inf if every attempt failed.
        """
        import requests  # local import: only needed when actually probing latency

        url = f"https://{region}-docker.pkg.dev"
        latencies = []
        for _ in range(attempts):
            try:
                start_time = time.time()
                requests.head(url, timeout=5)
                # convert latency to milliseconds (always finite; no inf check needed)
                latencies.append((time.time() - start_time) * 1000)
            except requests.RequestException:
                pass  # failed attempt contributes no sample
        if not latencies:
            return region, float("inf"), float("inf"), float("inf"), float("inf")

        std_dev = statistics.stdev(latencies) if len(latencies) > 1 else 0
        return region, statistics.mean(latencies), std_dev, min(latencies), max(latencies)

    def lowest_latency(
        self,
        top: int = 1,
        verbose: bool = False,
        tier: Optional[int] = None,
        attempts: int = 1,
    ) -> List[Tuple[str, float, float, float, float]]:
        """
        Determine the GCP regions with the lowest latency based on ping tests.

        Args:
            top (int): Number of top regions to return.
            verbose (bool): If True, prints detailed latency information for all tested regions.
            tier (int | None): Filter regions by tier (1 or 2). If None, all regions are tested.
            attempts (int): Number of ping attempts per region.

        Returns:
            (List[Tuple[str, float, float, float, float]]): List of tuples containing region information and
                latency statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).

        Examples:
            >>> regions = GCPRegions()
            >>> results = regions.lowest_latency(top=3, verbose=True, tier=1, attempts=2)
            >>> print(results[0][0])  # Print the name of the lowest latency region
        """
        if verbose:
            # BUGFIX: singular of "attempts" is "attempt" (was "retry")
            print(f"Testing GCP regions for latency (with {attempts} {'attempt' if attempts == 1 else 'attempts'})...")

        regions_to_test = [k for k, v in self.regions.items() if v[0] == tier] if tier else list(self.regions.keys())
        with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:
            results = list(executor.map(lambda r: self._ping_region(r, attempts), regions_to_test))

        sorted_results = sorted(results, key=lambda x: x[1])  # ascending mean latency

        if verbose:
            print(f"{'Region':<25} {'Location':<35} {'Tier':<5} Latency (ms)")
            for region, mean, std, min_, max_ in sorted_results:
                # renamed from `tier` to avoid clobbering the method argument
                region_tier, city, country = self.regions[region]
                location = f"{city}, {country}"
                if mean == float("inf"):
                    print(f"{region:<25} {location:<35} {region_tier:<5} Timeout")
                else:
                    print(f"{region:<25} {location:<35} {region_tier:<5} {mean:.0f} ± {std:.0f} ({min_:.0f} - {max_:.0f})")
            print(f"\nLowest latency region{'s' if top > 1 else ''}:")
            for region, mean, std, min_, max_ in sorted_results[:top]:
                region_tier, city, country = self.regions[region]
                location = f"{city}, {country}"
                print(f"{region} ({location}, {mean:.0f} ± {std:.0f} ms ({min_:.0f} - {max_:.0f}))")

        return sorted_results[:top]


# Usage example
if __name__ == "__main__":
    regions = GCPRegions()
    top_3_latency_tier1 = regions.lowest_latency(top=3, verbose=True, tier=1, attempts=3)

# --- new file: tracking/ultralytics/hub/session.py ---
# Fragment: this file's license header and import block begin here in the
# original collapsed text; they are reconstructed in full at the head of the
# HUBTrainingSession block that follows.
# [reconstructed from a collapsed git-diff hunk — tracking/ultralytics/hub/session.py]
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import shutil
import threading
import time
from http import HTTPStatus
from pathlib import Path
from urllib.parse import parse_qs, urlparse

import requests

from ultralytics.hub.utils import HELP_MSG, HUB_WEB_ROOT, PREFIX, TQDM
from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, __version__, checks, emojis
from ultralytics.utils.errors import HUBModelError

AGENT_NAME = f"python-{__version__}-colab" if IS_COLAB else f"python-{__version__}-local"


class HUBTrainingSession:
    """
    HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.

    This class encapsulates the functionality for interacting with Ultralytics HUB during model training, including
    model creation, metrics tracking, and checkpoint uploading.

    Attributes:
        model_id (str): Identifier for the YOLO model being trained.
        model_url (str): URL for the model in Ultralytics HUB.
        rate_limits (dict): Rate limits for different API calls (in seconds).
        timers (dict): Timers for rate limiting.
        metrics_queue (dict): Queue for the model's metrics.
        metrics_upload_failed_queue (dict): Queue for metrics that failed to upload.
        model (dict): Model data fetched from Ultralytics HUB.
        model_file (str): Path to the model file.
        train_args (dict): Arguments for training the model.
        client (HUBClient): Client for interacting with Ultralytics HUB.
        filename (str): Filename of the model.

    Examples:
        >>> session = HUBTrainingSession("https://hub.ultralytics.com/models/example-model")
        >>> session.upload_metrics()
    """

    def __init__(self, identifier):
        """
        Initialize the HUBTrainingSession with the provided model identifier.

        Args:
            identifier (str): Model identifier used to initialize the HUB training session.
                It can be a URL string or a model key with specific format.

        Raises:
            ValueError: If the provided model identifier is invalid.
            ConnectionError: If connecting with global API key is not supported.
            ModuleNotFoundError: If hub-sdk package is not installed.
        """
        from hub_sdk import HUBClient

        self.rate_limits = {"metrics": 3, "ckpt": 900, "heartbeat": 300}  # rate limits (seconds)
        self.metrics_queue = {}  # holds metrics for each epoch until upload
        self.metrics_upload_failed_queue = {}  # holds metrics for each epoch if upload failed
        self.timers = {}  # holds timers in ultralytics/utils/callbacks/hub.py
        self.model = None
        self.model_url = None
        self.model_file = None
        self.train_args = None

        # Parse input
        api_key, model_id, self.filename = self._parse_identifier(identifier)

        # Get credentials: an explicit URL key overrides the one in settings
        active_key = api_key or SETTINGS.get("api_key")
        credentials = {"api_key": active_key} if active_key else None  # set credentials

        # Initialize client
        self.client = HUBClient(credentials)

        # Load models
        try:
            if model_id:
                self.load_model(model_id)  # load existing model
            else:
                self.model = self.client.model()  # load empty model
        except Exception:
            # Only hint at login when the user explicitly passed a HUB URL but is unauthenticated
            if identifier.startswith(f"{HUB_WEB_ROOT}/models/") and not self.client.authenticated:
                LOGGER.warning(
                    f"{PREFIX}WARNING ⚠️ Please log in using 'yolo login API_KEY'. "
                    "You can find your API Key at: https://hub.ultralytics.com/settings?tab=api+keys."
                )

    @classmethod
    def create_session(cls, identifier, args=None):
        """
        Create an authenticated HUBTrainingSession or return None.

        Args:
            identifier (str): Model identifier used to initialize the HUB training session.
            args (dict, optional): Arguments for creating a new model if identifier is not a HUB model URL.

        Returns:
            (HUBTrainingSession | None): An authenticated session or None if creation fails.
        """
        try:
            session = cls(identifier)
            if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"):  # not a HUB model URL
                session.create_model(args)
                assert session.model.id, "HUB model not loaded correctly"
            return session
        # PermissionError and ModuleNotFoundError indicate hub-sdk not installed
        except (PermissionError, ModuleNotFoundError, AssertionError):
            return None

    def load_model(self, model_id):
        """
        Load an existing model from Ultralytics HUB using the provided model identifier.

        Args:
            model_id (str): The identifier of the model to load.

        Raises:
            ValueError: If the specified HUB model does not exist.
        """
        self.model = self.client.model(model_id)
        if not self.model.data:  # then model does not exist
            raise ValueError(emojis("❌ The specified HUB model does not exist"))  # TODO: improve error handling

        self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
        if self.model.is_trained():
            # Trained model: just fetch the best weights, no training session needed
            LOGGER.info(f"Loading trained HUB model {self.model_url} 🚀")
            url = self.model.get_weights_url("best")  # download URL with auth
            self.model_file = checks.check_file(url, download_dir=Path(SETTINGS["weights_dir"]) / "hub" / self.model.id)
            return

        # Set training args and start heartbeats for HUB to monitor agent
        self._set_train_args()
        self.model.start_heartbeat(self.rate_limits["heartbeat"])
        LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

    def create_model(self, model_args):
        """
        Initialize a HUB training session with the specified model arguments.

        Args:
            model_args (dict): Arguments for creating the model, including batch size, epochs, image size, etc.

        Returns:
            (None): If the model could not be created.
        """
        payload = {
            "config": {
                "batchSize": model_args.get("batch", -1),
                "epochs": model_args.get("epochs", 300),
                "imageSize": model_args.get("imgsz", 640),
                "patience": model_args.get("patience", 100),
                "device": str(model_args.get("device", "")),  # convert None to string
                "cache": str(model_args.get("cache", "ram")),  # convert True, False, None to string
            },
            "dataset": {"name": model_args.get("data")},
            "lineage": {
                "architecture": {"name": self.filename.replace(".pt", "").replace(".yaml", "")},
                "parent": {},
            },
            "meta": {"name": self.filename},
        }

        if self.filename.endswith(".pt"):
            payload["lineage"]["parent"]["name"] = self.filename

        self.model.create_model(payload)

        # Model could not be created
        # TODO: improve error handling
        if not self.model.id:
            return None

        self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"

        # Start heartbeats for HUB to monitor agent
        self.model.start_heartbeat(self.rate_limits["heartbeat"])

        LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

    @staticmethod
    def _parse_identifier(identifier):
        """
        Parse the given identifier to determine the type and extract relevant components.

        The method supports different identifier formats:
            - A HUB model URL https://hub.ultralytics.com/models/MODEL
            - A HUB model URL with API Key https://hub.ultralytics.com/models/MODEL?api_key=APIKEY
            - A local filename that ends with '.pt' or '.yaml'

        Args:
            identifier (str): The identifier string to be parsed.

        Returns:
            (tuple): A tuple containing the API key, model ID, and filename as applicable.

        Raises:
            HUBModelError: If the identifier format is not recognized.
        """
        api_key, model_id, filename = None, None, None
        if Path(identifier).suffix in {".pt", ".yaml"}:
            filename = identifier
        elif identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
            parsed_url = urlparse(identifier)
            model_id = Path(parsed_url.path).stem  # handle possible final backslash robustly
            query_params = parse_qs(parsed_url.query)  # dictionary, i.e. {"api_key": ["API_KEY_HERE"]}
            api_key = query_params.get("api_key", [None])[0]
        else:
            # BUGFIX: closing quote after {identifier} was missing in the message
            raise HUBModelError(f"model='{identifier}' invalid, correct format is {HUB_WEB_ROOT}/models/MODEL_ID")
        return api_key, model_id, filename

    def _set_train_args(self):
        """
        Initialize training arguments and create a model entry on the Ultralytics HUB.

        This method sets up training arguments based on the model's state and updates them with any additional
        arguments provided. It handles different states of the model, such as whether it's resumable, pretrained,
        or requires specific file setup.

        Raises:
            ValueError: If the model is already trained, if required dataset information is missing, or if there are
                issues with the provided training arguments.
        """
        if self.model.is_resumable():
            # Model has saved weights
            self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
            self.model_file = self.model.get_weights_url("last")
        else:
            # Model has no saved weights
            self.train_args = self.model.data.get("train_args")  # new response

            # Set the model file as either a *.pt or *.yaml file
            self.model_file = (
                self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
            )

            if "data" not in self.train_args:
                # RF bug - datasets are sometimes not exported
                raise ValueError("Dataset may still be processing. Please wait a minute and try again.")

        self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
        self.model_id = self.model.id

    def request_queue(
        self,
        request_func,
        retry=3,
        timeout=30,
        thread=True,
        verbose=True,
        progress_total=None,
        stream_response=None,
        *args,
        **kwargs,
    ):
        """
        Attempt to execute `request_func` with retries, timeout handling, optional threading, and progress tracking.

        Args:
            request_func (callable): The function to execute.
            retry (int): Number of retry attempts.
            timeout (int): Maximum time to wait for the request to complete.
            thread (bool): Whether to run the request in a separate thread.
            verbose (bool): Whether to log detailed messages.
            progress_total (int, optional): Total size for progress tracking.
            stream_response (bool, optional): Whether to stream the response.
            *args (Any): Additional positional arguments for request_func.
            **kwargs (Any): Additional keyword arguments for request_func.

        Returns:
            (requests.Response | None): The response object if thread=False, otherwise None.
        """

        def retry_request():
            """Attempt to call `request_func` with retries, timeout, and optional threading."""
            t0 = time.time()  # Record the start time for the timeout
            response = None
            for i in range(retry + 1):
                if (time.time() - t0) > timeout:
                    LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
                    break  # Timeout reached, exit loop

                response = request_func(*args, **kwargs)
                if response is None:
                    LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
                    time.sleep(2**i)  # Exponential backoff before retrying
                    continue  # Skip further processing and retry

                if progress_total:
                    self._show_upload_progress(progress_total, response)
                elif stream_response:
                    self._iterate_content(response)

                if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
                    # if request related to metrics upload
                    if kwargs.get("metrics"):
                        self.metrics_upload_failed_queue = {}
                    return response  # Success, no need to retry

                if i == 0:
                    # Initial attempt, check status code and provide messages
                    message = self._get_failure_message(response, retry, timeout)

                    if verbose:
                        LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})")

                if not self._should_retry(response.status_code):
                    # BUGFIX: closing parenthesis after status_code was missing
                    LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code})")
                    break  # Not an error that should be retried, exit loop

                time.sleep(2**i)  # Exponential backoff for retries

            # if request related to metrics upload and exceed retries
            if response is None and kwargs.get("metrics"):
                self.metrics_upload_failed_queue.update(kwargs.get("metrics"))

            return response

        if thread:
            # Start a new thread to run the retry_request function
            threading.Thread(target=retry_request, daemon=True).start()
        else:
            # If running in the main thread, call retry_request directly
            return retry_request()

    @staticmethod
    def _should_retry(status_code):
        """
        Determine if a request should be retried based on the HTTP status code.

        Args:
            status_code (int): The HTTP status code from the response.

        Returns:
            (bool): True if the request should be retried, False otherwise.
        """
        retry_codes = {
            HTTPStatus.REQUEST_TIMEOUT,
            HTTPStatus.BAD_GATEWAY,
            HTTPStatus.GATEWAY_TIMEOUT,
        }
        return status_code in retry_codes

    def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
        """
        Generate a retry message based on the response status code.

        Args:
            response (requests.Response): The HTTP response object.
            retry (int): The number of retry attempts allowed.
            timeout (int): The maximum timeout duration.

        Returns:
            (str): The retry message.
        """
        if self._should_retry(response.status_code):
            return f"Retrying {retry}x for {timeout}s." if retry else ""
        elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS:  # rate limit
            headers = response.headers
            return (
                f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
                f"Please retry after {headers['Retry-After']}s."
            )
        else:
            try:
                return response.json().get("message", "No JSON message.")
            except AttributeError:
                return "Unable to read JSON."

    def upload_metrics(self):
        """Upload model metrics to Ultralytics HUB."""
        return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)

    def upload_model(
        self,
        epoch: int,
        weights: str,
        is_best: bool = False,
        map: float = 0.0,
        final: bool = False,
    ) -> None:
        """
        Upload a model checkpoint to Ultralytics HUB.

        Args:
            epoch (int): The current training epoch.
            weights (str): Path to the model weights file.
            is_best (bool): Indicates if the current model is the best one so far.
            map (float): Mean average precision of the model.
            final (bool): Indicates if the model is the final model after training.
        """
        weights = Path(weights)
        if not weights.is_file():
            last = weights.with_name(f"last{weights.suffix}")
            if final and last.is_file():
                LOGGER.warning(
                    f"{PREFIX} WARNING ⚠️ Model 'best.pt' not found, copying 'last.pt' to 'best.pt' and uploading. "
                    "This often happens when resuming training in transient environments like Google Colab. "
                    "For more reliable training, consider using Ultralytics HUB Cloud. "
                    "Learn more at https://docs.ultralytics.com/hub/cloud-training."
                )
                shutil.copy(last, weights)  # copy last.pt to best.pt
            else:
                LOGGER.warning(f"{PREFIX} WARNING ⚠️ Model upload issue. Missing model {weights}.")
                return

        self.request_queue(
            self.model.upload_model,
            epoch=epoch,
            weights=str(weights),
            is_best=is_best,
            map=map,
            final=final,
            retry=10,
            timeout=3600,
            thread=not final,  # final upload is blocking so training can exit cleanly
            progress_total=weights.stat().st_size if final else None,  # only show progress if final
            stream_response=True,
        )

    @staticmethod
    def _show_upload_progress(content_length: int, response: requests.Response) -> None:
        """
        Display a progress bar to track the upload progress of a file download.

        Args:
            content_length (int): The total size of the content to be downloaded in bytes.
            response (requests.Response): The response object from the file download request.
        """
        with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
            for data in response.iter_content(chunk_size=1024):
                pbar.update(len(data))

    @staticmethod
    def _iterate_content(response: requests.Response) -> None:
        """
        Process the streamed HTTP response data.

        Args:
            response (requests.Response): The response object from the file download request.
        """
        for _ in response.iter_content(chunk_size=1024):
            pass  # Do nothing with data chunks; just drain the stream


# --- new file: tracking/ultralytics/hub/utils.py ---
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import os
import platform
import random
import threading
import time
from pathlib import Path

import requests

from ultralytics.utils import (
    ARGV,
    ENVIRONMENT,
    IS_COLAB,
    IS_GIT_DIR,
    IS_PIP_PACKAGE,
    LOGGER,
    ONLINE,
    RANK,
    SETTINGS,
    TESTS_RUNNING,
    TQDM,
    TryExcept,
    __version__,
    colorstr,
    get_git_origin_url,
)
from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES

HUB_API_ROOT = os.environ.get("ULTRALYTICS_HUB_API", "https://api.ultralytics.com")
HUB_WEB_ROOT = os.environ.get("ULTRALYTICS_HUB_WEB", "https://hub.ultralytics.com")

PREFIX = colorstr("Ultralytics HUB: ")
HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance."

# Fragment: request_with_credentials() begins here in the original collapsed
# text; it is reconstructed in full (def line included) in the next block.
def request_with_credentials(url: str):
    """
    Make an AJAX request with cookies attached in a Google Colab environment.

    Note: the original return annotation was `-> any`, which names the builtin
    `any()` function, not a type; the annotation is dropped and the return is
    documented here instead.

    Args:
        url (str): The URL to make the request to.

    Returns:
        (Any): The response data from the AJAX request.

    Raises:
        OSError: If the function is not run in a Google Colab environment.
    """
    if not IS_COLAB:
        raise OSError("request_with_credentials() must run in a Colab environment")
    from google.colab import output  # noqa
    from IPython import display  # noqa

    display.display(
        display.Javascript(
            f"""
            window._hub_tmp = new Promise((resolve, reject) => {{
                const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
                fetch("{url}", {{
                    method: 'POST',
                    credentials: 'include'
                }})
                    .then((response) => resolve(response.json()))
                    .then((json) => {{
                    clearTimeout(timeout);
                }}).catch((err) => {{
                    clearTimeout(timeout);
                    reject(err);
                }});
            }});
            """
        )
    )
    return output.eval_js("_hub_tmp")


def requests_with_progress(method, url, **kwargs):
    """
    Make an HTTP request using the specified method and URL, with an optional progress bar.

    Args:
        method (str): The HTTP method to use (e.g. 'GET', 'POST').
        url (str): The URL to send the request to.
        **kwargs (Any): Additional keyword arguments to pass to the underlying `requests.request` function.

    Returns:
        (requests.Response): The response object from the HTTP request.

    Notes:
        - If 'progress' is set to True, the progress bar will display the download progress for responses with a known
          content length.
        - If 'progress' is a number then progress bar will display assuming content length = progress.
    """
    progress = kwargs.pop("progress", False)
    if not progress:
        return requests.request(method, url, **kwargs)
    response = requests.request(method, url, stream=True, **kwargs)
    total = int(response.headers.get("content-length", 0) if isinstance(progress, bool) else progress)  # total size
    pbar = TQDM(total=total, unit="B", unit_scale=True, unit_divisor=1024)
    try:
        for data in response.iter_content(chunk_size=1024):
            pbar.update(len(data))
    except requests.exceptions.ChunkedEncodingError:  # avoid 'Connection broken: IncompleteRead' warnings
        response.close()
    finally:
        pbar.close()  # BUGFIX: bar was previously leaked on ChunkedEncodingError
    return response


def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, progress=False, **kwargs):
    """
    Make an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.

    Args:
        method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.
        url (str): The URL to make the request to.
        retry (int, optional): Number of retries to attempt before giving up.
        timeout (int, optional): Timeout in seconds after which the function will give up retrying.
        thread (bool, optional): Whether to execute the request in a separate daemon thread.
        code (int, optional): An identifier for the request, used for logging purposes.
        verbose (bool, optional): A flag to determine whether to print out to console or not.
        progress (bool, optional): Whether to show a progress bar during the request.
        **kwargs (Any): Keyword arguments to be passed to the requests function specified in method.

    Returns:
        (requests.Response): The HTTP response object. If the request is executed in a separate thread, returns None.
    """
    retry_codes = (408, 500)  # retry only these codes

    @TryExcept(verbose=verbose)
    def func(func_method, func_url, **func_kwargs):
        """Make HTTP requests with retries and timeouts, with optional progress tracking."""
        r = None  # response
        t0 = time.time()  # initial time for timer
        for i in range(retry + 1):
            if (time.time() - t0) > timeout:
                break
            r = requests_with_progress(func_method, func_url, **func_kwargs)  # i.e. get(url, data, json, files)
            if r.status_code < 300:  # return codes in the 2xx range are generally considered "good" or "successful"
                break
            try:
                m = r.json().get("message", "No JSON message.")
            except AttributeError:
                m = "Unable to read JSON."
            if i == 0:
                if r.status_code in retry_codes:
                    m += f" Retrying {retry}x for {timeout}s." if retry else ""
                elif r.status_code == 429:  # rate limit
                    h = r.headers  # response headers
                    m = (
                        f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). "
                        f"Please retry after {h['Retry-After']}s."
                    )
                if verbose:
                    LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")
                if r.status_code not in retry_codes:
                    return r
            time.sleep(2**i)  # exponential standoff
        return r

    args = method, url
    kwargs["progress"] = progress
    if thread:
        threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
    else:
        return func(*args, **kwargs)


class Events:
    """
    A class for collecting anonymous event analytics.

    Event analytics are enabled when sync=True in settings and disabled when sync=False. Run 'yolo settings' to see and
    update settings.

    Attributes:
        url (str): The URL to send anonymous events.
        rate_limit (float): The rate limit in seconds for sending events.
        metadata (dict): A dictionary containing metadata about the environment.
        enabled (bool): A flag to enable or disable Events based on certain conditions.
    """

    url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw"

    def __init__(self):
        """Initialize the Events object with default values for events, rate_limit, and metadata."""
        self.events = []  # events list
        self.rate_limit = 30.0  # rate limit (seconds)
        self.t = 0.0  # rate limit timer (seconds)
        self.metadata = {
            "cli": Path(ARGV[0]).name == "yolo",
            "install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other",
            "python": ".".join(platform.python_version_tuple()[:2]),  # i.e. 3.10
            "version": __version__,
            "env": ENVIRONMENT,
            "session_id": round(random.random() * 1e15),
            "engagement_time_msec": 1000,
        }
        self.enabled = (
            SETTINGS["sync"]
            and RANK in {-1, 0}
            and not TESTS_RUNNING
            and ONLINE
            and (IS_PIP_PACKAGE or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
        )

    def __call__(self, cfg):
        """
        Attempt to add a new event to the events list and send events if the rate limit is reached.

        Args:
            cfg (IterableSimpleNamespace): The configuration object containing mode and task information.
        """
        if not self.enabled:
            # Events disabled, do nothing
            return

        # Attempt to add to events
        if len(self.events) < 25:  # Events list limited to 25 events (drop any events past this)
            params = {
                **self.metadata,
                "task": cfg.task,
                "model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom",
            }
            if cfg.mode == "export":
                params["format"] = cfg.format
            self.events.append({"name": cfg.mode, "params": params})

        # Check rate limit
        t = time.time()
        if (t - self.t) < self.rate_limit:
            # Time is under rate limiter, wait to send
            return

        # Time is over rate limiter, send now
        data = {"client_id": SETTINGS["uuid"], "events": self.events}  # SHA-256 anonymized UUID hash and events list

        # POST equivalent to requests.post(self.url, json=data)
        smart_request("post", self.url, json=data, retry=0, verbose=False)

        # Reset events and rate limit timer
        self.events = []
        self.t = t


# Run below code on hub/utils init -------------------------------------------------------------------------------------
events = Events()

# --- new file: tracking/ultralytics/models/__init__.py ---
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .fastsam import FastSAM
from .nas import NAS
from .rtdetr import RTDETR
from .sam import SAM
from .yolo import YOLO, YOLOWorld

__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld"  # allow simpler import

# --- new file: tracking/ultralytics/models/fastsam/__init__.py ---
# Fragment: this file's license header begins here in the original collapsed
# text; its imports and __all__ are reconstructed in the next block.
# [reconstructed from a collapsed git-diff hunk — tracking/ultralytics/models/fastsam/__init__.py]
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .model import FastSAM
from .predict import FastSAMPredictor
from .val import FastSAMValidator

__all__ = "FastSAMPredictor", "FastSAM", "FastSAMValidator"

# --- new file: tracking/ultralytics/models/fastsam/model.py ---
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from pathlib import Path

from ultralytics.engine.model import Model

from .predict import FastSAMPredictor
from .val import FastSAMValidator


class FastSAM(Model):
    """
    FastSAM model interface for segment anything tasks.

    This class extends the base Model class to provide specific functionality for the FastSAM
    (Fast Segment Anything Model) implementation, allowing for efficient and accurate image segmentation.

    Attributes:
        model (str): Path to the pre-trained FastSAM model file.
        task (str): The task type, set to "segment" for FastSAM models.

    Examples:
        >>> from ultralytics import FastSAM
        >>> model = FastSAM("last.pt")
        >>> results = model.predict("ultralytics/assets/bus.jpg")
    """

    def __init__(self, model="FastSAM-x.pt"):
        """Initialize the FastSAM model with the specified pre-trained weights."""
        if str(model) == "FastSAM.pt":
            model = "FastSAM-x.pt"  # legacy alias maps to the -x variant
        # NOTE: assert kept (not ValueError) so the exception type callers may catch is unchanged
        assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
        super().__init__(model=model, task="segment")

    def predict(self, source, stream=False, bboxes=None, points=None, labels=None, texts=None, **kwargs):
        """
        Perform segmentation prediction on image or video source.

        Supports prompted segmentation with bounding boxes, points, labels, and texts. The method packages these
        prompts and passes them to the parent class predict method.

        Args:
            source (str | PIL.Image | numpy.ndarray): Input source for prediction, can be a file path, URL, PIL image,
                or numpy array.
            stream (bool): Whether to enable real-time streaming mode for video inputs.
            bboxes (list): Bounding box coordinates for prompted segmentation in format [[x1, y1, x2, y2], ...].
            points (list): Point coordinates for prompted segmentation in format [[x, y], ...].
            labels (list): Class labels for prompted segmentation.
            texts (list): Text prompts for segmentation guidance.
            **kwargs (Any): Additional keyword arguments passed to the predictor.

        Returns:
            (list): List of Results objects containing the prediction results.
        """
        prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
        return super().predict(source, stream, prompts=prompts, **kwargs)

    @property
    def task_map(self):
        """Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
        return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}}


# --- new file: tracking/ultralytics/models/fastsam/predict.py ---
# Fragment: FastSAMPredictor continues past the end of this chunk and is NOT
# reconstructed here. Visible head (for reference):
#   imports: torch, PIL.Image, SegmentationPredictor, DEFAULT_CFG, checks,
#            box_iou, scale_masks, adjust_bboxes_to_image_border
#   class FastSAMPredictor(SegmentationPredictor): ... (docstring continues)
class FastSAMPredictor(SegmentationPredictor):
    """
    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.

    This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
    adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for
    single-class segmentation.

    Attributes:
        prompts (dict): Dictionary containing prompt information for segmentation (bboxes, points, labels, texts).
        device (torch.device): Device on which model and tensors are processed.
        clip_model (Any, optional): CLIP model for text-based prompting, loaded on demand.
        clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.

    Methods:
        postprocess: Applies box postprocessing for FastSAM predictions.
        prompt: Performs image segmentation inference based on various prompt types.
        _clip_inference: Performs CLIP inference to calculate similarity between images and text prompts.
        set_prompts: Sets prompts to be used during inference.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize the FastSAMPredictor with configuration and callbacks."""
        super().__init__(cfg, overrides, _callbacks)
        # Prompts are injected via set_prompts() (or the predict(prompts=...) kwarg) and
        # consumed (popped) once per postprocess() call.
        self.prompts = {}

    def postprocess(self, preds, img, orig_imgs):
        """
        Apply postprocessing to FastSAM predictions and handle prompts.

        Args:
            preds (List[torch.Tensor]): Raw predictions from the model.
            img (torch.Tensor): Input image tensor that was fed to the model.
            orig_imgs (List[numpy.ndarray]): Original images before preprocessing.

        Returns:
            (List[Results]): Processed results with prompts applied.
        """
        # pop() so prompts apply to exactly one batch and do not leak into the next call.
        bboxes = self.prompts.pop("bboxes", None)
        points = self.prompts.pop("points", None)
        labels = self.prompts.pop("labels", None)
        texts = self.prompts.pop("texts", None)
        results = super().postprocess(preds, img, orig_imgs)
        for result in results:
            # Full-image box in xyxy: [0, 0, width, height].
            full_box = torch.tensor(
                [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
            )
            # Snap near-border box edges onto the border, then replace any box that almost
            # covers the whole image (IoU > 0.9 with full_box) by the exact full-image box.
            boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
            idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
            if idx.numel() != 0:
                result.boxes.xyxy[idx] = full_box

        return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)

    def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
        """
        Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.

        Args:
            results (Results | List[Results]): Original inference results from FastSAM models without any prompts.
            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
            texts (str | List[str], optional): Textual prompts, a list containing string objects.

        Returns:
            (List[Results]): Output results filtered and determined by the provided prompts.
        """
        if bboxes is None and points is None and texts is None:
            # No prompts at all: return results unfiltered.
            return results
        prompt_results = []
        if not isinstance(results, list):
            results = [results]
        for result in results:
            if len(result) == 0:
                # Nothing detected; keep the empty result as-is.
                prompt_results.append(result)
                continue
            masks = result.masks.data
            if masks.shape[1:] != result.orig_shape:
                # Upsample masks to original image resolution so pixel prompts index correctly.
                masks = scale_masks(masks[None], result.orig_shape)[0]
            # Boolean selector over detections; each prompt type ORs its picks into it.
            idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
            if bboxes is not None:
                bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
                bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes  # promote single box to (1, 4)
                bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
                # Mask pixels falling inside each prompt box, per detection.
                mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
                full_mask_areas = torch.sum(masks, dim=(1, 2))

                # IoU-like score between each prompt box and each mask; pick the best mask per box.
                union = bbox_areas[:, None] + full_mask_areas - mask_areas
                idx[torch.argmax(mask_areas / union, dim=1)] = True
            if points is not None:
                points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
                points = points[None] if points.ndim == 1 else points
                if labels is None:
                    labels = torch.ones(points.shape[0])  # default: all points are foreground
                labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
                assert len(labels) == len(points), (
                    f"Excepted `labels` got same size as `point`, but got {len(labels)} and {len(points)}"
                )
                # Start from all-True when every point is negative (pure exclusion),
                # otherwise start from all-False and let foreground points select masks.
                point_idx = (
                    torch.ones(len(result), dtype=torch.bool, device=self.device)
                    if labels.sum() == 0  # all negative points
                    else torch.zeros(len(result), dtype=torch.bool, device=self.device)
                )
                for point, label in zip(points, labels):
                    # Masks covering the point (note [row, col] = [y, x]) are kept/dropped per label.
                    point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
                idx |= point_idx
            if texts is not None:
                if isinstance(texts, str):
                    texts = [texts]
                crop_ims, filter_idx = [], []
                for i, b in enumerate(result.boxes.xyxy.tolist()):
                    x1, y1, x2, y2 = (int(x) for x in b)
                    if masks[i].sum() <= 100:
                        # Skip tiny masks; remember their indices to re-align CLIP picks below.
                        filter_idx.append(i)
                        continue
                    # BGR -> RGB crop for CLIP (PIL expects RGB channel order).
                    crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
                similarity = self._clip_inference(crop_ims, texts)
                text_idx = torch.argmax(similarity, dim=-1)  # (M, )
                if len(filter_idx):
                    # Shift the argmax index past skipped detections so it maps back to `result` order.
                    text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
                idx[text_idx] = True

            prompt_results.append(result[idx])

        return prompt_results

    def _clip_inference(self, images, texts):
        """
        Perform CLIP inference to calculate similarity between images and text prompts.

        Args:
            images (List[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.
            texts (List[str]): List of prompt texts, each should be a string object.

        Returns:
            (torch.Tensor): Similarity matrix between given images and texts with shape (M, N).
        """
        try:
            import clip
        except ImportError:
            # Install CLIP on demand, then retry the import.
            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
            import clip
        if (not hasattr(self, "clip_model")) or (not hasattr(self, "clip_preprocess")):
            # Lazy-load CLIP once and cache it on the predictor instance.
            self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
        images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images])
        tokenized_text = clip.tokenize(texts).to(self.device)
        image_features = self.clip_model.encode_image(images)
        text_features = self.clip_model.encode_text(tokenized_text)
        # L2-normalize so the dot product below is cosine similarity.
        image_features /= image_features.norm(dim=-1, keepdim=True)  # (N, 512)
        text_features /= text_features.norm(dim=-1, keepdim=True)  # (M, 512)
        return (image_features * text_features[:, None]).sum(-1)  # (M, N)

    def set_prompts(self, prompts):
        """Set prompts to be used during inference."""
        self.prompts = prompts
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
    """
    Snap box edges lying within `threshold` pixels of the image border onto the border.

    The adjustment is performed in place on `boxes`, which is also returned for convenience.

    Args:
        boxes (torch.Tensor): Bounding boxes with shape (n, 4) in xyxy format.
        image_shape (Tuple[int, int]): Image dimensions as (height, width).
        threshold (int): Pixel threshold for considering an edge close to the border.

    Returns:
        boxes (torch.Tensor): Adjusted bounding boxes with shape (n, 4).
    """
    height, width = image_shape

    # Boolean masks per edge: which boxes sit close enough to each border to be snapped.
    near_left = boxes[:, 0] < threshold
    near_top = boxes[:, 1] < threshold
    near_right = boxes[:, 2] > width - threshold
    near_bottom = boxes[:, 3] > height - threshold

    boxes[near_left, 0] = 0  # x1 -> left edge
    boxes[near_top, 1] = 0  # y1 -> top edge
    boxes[near_right, 2] = width  # x2 -> right edge
    boxes[near_bottom, 3] = height  # y2 -> bottom edge
    return boxes
class FastSAMValidator(SegmentationValidator):
    """
    Validator for fast SAM (Segment Anything Model) segmentation in the Ultralytics YOLO framework.

    Extends SegmentationValidator, forcing the task to 'segment', using SegmentMetrics for
    evaluation, and disabling plotting to avoid errors during validation.

    Attributes:
        dataloader (torch.utils.data.DataLoader): The data loader object used for validation.
        save_dir (Path): The directory where validation results will be saved.
        pbar (tqdm.tqdm): A progress bar object for displaying validation progress.
        args (SimpleNamespace): Additional arguments for customization of the validation process.
        _callbacks (list): List of callback functions to be invoked during validation.
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """
        Initialize the FastSAMValidator.

        Args:
            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
            save_dir (Path, optional): Directory to save results.
            pbar (tqdm.tqdm): Progress bar for displaying progress.
            args (SimpleNamespace): Configuration for the validator.
            _callbacks (list): List of callback functions to be invoked during validation.

        Notes:
            Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
        """
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        # Force segmentation task and swap in segmentation metrics.
        self.args.task = "segment"
        self.metrics = SegmentMetrics(save_dir=self.save_dir)
        # Plotting (ConfusionMatrix etc.) errors out for FastSAM, so it is disabled.
        self.args.plots = False
class NAS(Model):
    """
    YOLO NAS model for object detection.

    This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
    It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.

    Attributes:
        model (torch.nn.Module): The loaded YOLO-NAS model.
        task (str): The task type for the model, defaults to 'detect'.
        predictor (NASPredictor): The predictor instance for making predictions.
        validator (NASValidator): The validator instance for model validation.

    Examples:
        >>> from ultralytics import NAS
        >>> model = NAS("yolo_nas_s")
        >>> results = model.predict("ultralytics/assets/bus.jpg")

    Notes:
        YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
    """

    def __init__(self, model: str = "yolo_nas_s.pt") -> None:
        """Initialize the NAS model with the provided or default model."""
        assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
        super().__init__(model, task="detect")

    def _load(self, weights: str, task=None) -> None:
        """
        Load an existing NAS model weights or create a new NAS model with pretrained weights.

        Args:
            weights (str): Path to the model weights file or model name.
            task (str, optional): Task type for the model.
        """
        # super_gradients is imported lazily: it is only required when actually loading a NAS model.
        import super_gradients

        suffix = Path(weights).suffix
        if suffix == ".pt":
            # NOTE(review): pickle-based torch.load with no weights_only guard — the
            # checkpoint file must be trusted; confirm this is acceptable for deployment.
            self.model = torch.load(attempt_download_asset(weights))
        elif suffix == "":
            # Bare model name (e.g. "yolo_nas_s"): fetch COCO-pretrained weights from super-gradients.
            self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
        # NOTE(review): any other suffix silently leaves self.model unset here.

        # Override the forward method to ignore additional arguments
        def new_forward(x, *args, **kwargs):
            """Ignore additional __call__ arguments."""
            return self.model._original_forward(x)

        # Order matters: stash the original forward BEFORE replacing it, so new_forward
        # delegates to the real implementation rather than recursing.
        self.model._original_forward = self.model.forward
        self.model.forward = new_forward

        # Standardize model: patch in the attributes/methods the Ultralytics engine expects.
        self.model.fuse = lambda verbose=True: self.model  # no-op fuse, returns model unchanged
        self.model.stride = torch.tensor([32])
        self.model.names = dict(enumerate(self.model._class_names))
        self.model.is_fused = lambda: False  # for info()
        self.model.yaml = {}  # for info()
        self.model.pt_path = weights  # for export()
        self.model.task = "detect"  # for export()
        self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # for export()

    def info(self, detailed: bool = False, verbose: bool = True):
        """
        Log model information.

        Args:
            detailed (bool): Show detailed information about model.
            verbose (bool): Controls verbosity.

        Returns:
            (dict): Model information dictionary.
        """
        return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)

    @property
    def task_map(self):
        """Return a dictionary mapping tasks to respective predictor and validator classes."""
        return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}
+ """ + + def postprocess(self, preds_in, img, orig_imgs): + """Postprocess predictions and returns a list of Results objects.""" + # Convert boxes from xyxy to xywh format and concatenate with class scores + boxes = ops.xyxy2xywh(preds_in[0][0]) + preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) + + # Apply non-maximum suppression to filter overlapping detections + preds = ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + classes=self.args.classes, + ) + + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list + orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + + results = [] + for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]): + # Scale bounding boxes to match original image dimensions + pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) + results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred)) + return results diff --git a/tracking/ultralytics/models/nas/val.py b/tracking/ultralytics/models/nas/val.py new file mode 100644 index 0000000000000000000000000000000000000000..b45064b9079644540a35afea4103afcfe4932952 --- /dev/null +++ b/tracking/ultralytics/models/nas/val.py @@ -0,0 +1,42 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import torch + +from ultralytics.models.yolo.detect import DetectionValidator +from ultralytics.utils import ops + +__all__ = ["NASValidator"] + + +class NASValidator(DetectionValidator): + """ + Ultralytics YOLO NAS Validator for object detection. + + Extends `DetectionValidator` from the Ultralytics models package and is designed to post-process the raw predictions + generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes, + ultimately producing the final detections. 
class NASValidator(DetectionValidator):
    """
    Ultralytics YOLO NAS Validator for object detection.

    Extends `DetectionValidator` to post-process raw YOLO-NAS predictions: converts the raw
    box/score pair into the layout expected by the detection NMS, which removes overlapping
    and low-confidence boxes to produce the final detections.

    Attributes:
        args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU
            thresholds.
        lb (torch.Tensor): Optional tensor for multilabel NMS.

    Examples:
        >>> from ultralytics import NAS
        >>> model = NAS("yolo_nas_s")
        >>> validator = model.validator
        Assumes that raw_preds are available
        >>> final_preds = validator.postprocess(raw_preds)

    Notes:
        This class is generally not instantiated directly but is used internally within the `NAS` class.
    """

    def postprocess(self, preds_in):
        """Apply Non-maximum suppression to prediction outputs."""
        xywh = ops.xyxy2xywh(preds_in[0][0])  # xyxy -> xywh box format
        # Append class scores and permute to (batch, attrs, anchors) for the parent NMS.
        merged = torch.cat((xywh, preds_in[0][1]), -1).permute(0, 2, 1)
        return super().postprocess(merged, max_time_img=0.5)
class RTDETR(Model):
    """
    Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.

    Provides real-time performance with high accuracy via an efficient hybrid encoder and
    IoU-aware query selection, with adaptable inference speed.

    Attributes:
        model (str): Path to the pre-trained model.

    Examples:
        >>> from ultralytics import RTDETR
        >>> model = RTDETR("rtdetr-l.pt")
        >>> results = model("image.jpg")
    """

    def __init__(self, model: str = "rtdetr-l.pt") -> None:
        """
        Initialize the RT-DETR model from a pre-trained model file.

        Args:
            model (str): Path to the pre-trained model. Supports .pt, .yaml, and .yml formats.

        Raises:
            NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
        """
        super().__init__(model=model, task="detect")

    @property
    def task_map(self) -> dict:
        """
        Return the task map associating 'detect' with the RT-DETR-specific Ultralytics classes.

        Returns:
            (dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
        """
        detect_classes = {
            "predictor": RTDETRPredictor,
            "validator": RTDETRValidator,
            "trainer": RTDETRTrainer,
            "model": RTDETRDetectionModel,
        }
        return {"detect": detect_classes}
class RTDETRPredictor(BasePredictor):
    """
    RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.

    This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy.
    It supports key features like efficient hybrid encoding and IoU-aware query selection.

    Attributes:
        imgsz (int): Image size for inference (must be square and scale-filled).
        args (dict): Argument overrides for the predictor.
        model (torch.nn.Module): The loaded RT-DETR model.
        batch (list): Current batch of processed inputs.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.rtdetr import RTDETRPredictor
        >>> args = dict(model="rtdetr-l.pt", source=ASSETS)
        >>> predictor = RTDETRPredictor(overrides=args)
        >>> predictor.predict_cli()
    """

    def postprocess(self, preds, img, orig_imgs):
        """
        Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.

        The method filters detections based on confidence and class if specified in `self.args`. It converts
        model predictions to Results objects containing properly scaled bounding boxes.

        Args:
            preds (List | Tuple): List of [predictions, extra] from the model, where predictions contain
                bounding boxes and scores.
            img (torch.Tensor): Processed input images with shape (N, 3, H, W).
            orig_imgs (List | torch.Tensor): Original, unprocessed images.

        Returns:
            (List[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
                and class labels.
        """
        if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
            preds = [preds, None]

        # Last dim = 4 box coords + per-class scores; split them apart.
        nd = preds[0].shape[-1]
        bboxes, scores = preds[0].split((4, nd - 4), dim=-1)

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        for bbox, score, orig_img, img_path in zip(bboxes, scores, orig_imgs, self.batch[0]):  # (300, 4)
            bbox = ops.xywh2xyxy(bbox)
            # Best class per query; max_score is that class's confidence.
            max_score, cls = score.max(-1, keepdim=True)  # (300, 1)
            idx = max_score.squeeze(-1) > self.args.conf  # (300, )
            if self.args.classes is not None:
                # Keep only detections whose class is in the requested filter AND above conf.
                idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
            pred = torch.cat([bbox, max_score, cls], dim=-1)[idx]  # filter
            oh, ow = orig_img.shape[:2]
            # Box coords appear to be fractions of the image size (scaled by width/height
            # below) — assumption from the scaling here; confirm against model output spec.
            pred[..., [0, 2]] *= ow  # scale x coordinates to original width
            pred[..., [1, 3]] *= oh  # scale y coordinates to original height
            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
        return results

    def pre_transform(self, im):
        """
        Pre-transforms the input images before feeding them into the model for inference. The input images are
        letterboxed to ensure a square aspect ratio and scale-filled. The size must be square(640) and scale_filled.

        Args:
            im (list[np.ndarray] |torch.Tensor): Input images of shape (N,3,h,w) for tensor, [(h,w,3) x N] for list.

        Returns:
            (list): List of pre-transformed images ready for model inference.
        """
        # auto=False + scale_fill=True: stretch to the exact square imgsz instead of padding.
        letterbox = LetterBox(self.imgsz, auto=False, scale_fill=True)
        return [letterbox(image=x) for x in im]
class RTDETRTrainer(DetectionTrainer):
    """
    Trainer class for the RT-DETR model developed by Baidu for real-time object detection.

    Extends DetectionTrainer to build RT-DETR-specific models, datasets, and validators.

    Attributes:
        loss_names (Tuple[str]): Names of the loss components used for training.
        data (dict): Dataset configuration containing class count and other parameters.
        args (dict): Training arguments and hyperparameters.
        save_dir (Path): Directory to save training results.
        test_loader (DataLoader): DataLoader for validation/testing data.

    Notes:
        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.

    Examples:
        >>> from ultralytics.models.rtdetr.train import RTDETRTrainer
        >>> args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
        >>> trainer = RTDETRTrainer(overrides=args)
        >>> trainer.train()
    """

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Initialize and return an RT-DETR model for object detection tasks.

        Args:
            cfg (dict, optional): Model configuration.
            weights (str, optional): Path to pre-trained model weights.
            verbose (bool): Verbose logging if True.

        Returns:
            (RTDETRDetectionModel): Initialized model.
        """
        # Only log on the main process (RANK == -1 means single-process training).
        model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def build_dataset(self, img_path, mode="val", batch=None):
        """
        Build and return an RT-DETR dataset for training or validation.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): Dataset mode, either 'train' or 'val'.
            batch (int, optional): Batch size for rectangle training.

        Returns:
            (RTDETRDataset): Dataset object for the specific mode.
        """
        return RTDETRDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch,
            augment=mode == "train",  # augment only while training
            hyp=self.args,
            rect=False,  # rect batching is not supported by RT-DETR
            cache=self.args.cache or None,
            single_cls=self.args.single_cls or False,
            prefix=colorstr(f"{mode}: "),
            classes=self.args.classes,
            data=self.data,
            fraction=self.args.fraction if mode == "train" else 1.0,
        )

    def get_validator(self):
        """Return an RTDETRValidator configured with the RT-DETR loss names."""
        self.loss_names = "giou_loss", "cls_loss", "l1_loss"
        return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

    def preprocess_batch(self, batch):
        """
        Preprocess a batch of images by scaling and converting to float format.

        Delegates entirely to DetectionTrainer.preprocess_batch. The previous
        implementation additionally gathered per-image ground-truth boxes and classes
        into local lists (`gt_bbox`, `gt_class`) that were never used or attached to
        the batch — dead work on every training step — so that loop was removed.

        Args:
            batch (dict): Dictionary containing a batch of images, bboxes, and labels.

        Returns:
            (dict): Preprocessed batch.
        """
        return super().preprocess_batch(batch)
+ """ + if self.augment: + hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0 + hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0 + transforms = v8_transforms(self, self.imgsz, hyp, stretch=True) + else: + # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scale_fill=True)]) + transforms = Compose([]) + transforms.append( + Format( + bbox_format="xywh", + normalize=True, + return_mask=self.use_segments, + return_keypoint=self.use_keypoints, + batch_idx=True, + mask_ratio=hyp.mask_ratio, + mask_overlap=hyp.overlap_mask, + ) + ) + return transforms + + +class RTDETRValidator(DetectionValidator): + """ + RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for + the RT-DETR (Real-Time DETR) object detection model. + + The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for + post-processing, and updates evaluation metrics accordingly. + + Examples: + >>> from ultralytics.models.rtdetr import RTDETRValidator + >>> args = dict(model="rtdetr-l.pt", data="coco8.yaml") + >>> validator = RTDETRValidator(args=args) + >>> validator() + + Note: + For further details on the attributes and methods, refer to the parent DetectionValidator class. + """ + + def build_dataset(self, img_path, mode="val", batch=None): + """ + Build an RTDETR Dataset. + + Args: + img_path (str): Path to the folder containing images. + mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. + batch (int, optional): Size of batches, this is for `rect`. + + Returns: + (RTDETRDataset): Dataset configured for RT-DETR validation. 
+ """ + return RTDETRDataset( + img_path=img_path, + imgsz=self.args.imgsz, + batch_size=batch, + augment=False, # no augmentation + hyp=self.args, + rect=False, # no rect + cache=self.args.cache or None, + prefix=colorstr(f"{mode}: "), + data=self.data, + ) + + def postprocess(self, preds): + """ + Apply Non-maximum suppression to prediction outputs. + + Args: + preds (List | Tuple | torch.Tensor): Raw predictions from the model. + + Returns: + (List[torch.Tensor]): List of processed predictions for each image in batch. + """ + if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference + preds = [preds, None] + + bs, _, nd = preds[0].shape + bboxes, scores = preds[0].split((4, nd - 4), dim=-1) + bboxes *= self.args.imgsz + outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs + for i, bbox in enumerate(bboxes): # (300, 4) + bbox = ops.xywh2xyxy(bbox) + score, cls = scores[i].max(-1) # (300, ) + # Do not need threshold for evaluation as only got 300 boxes here + # idx = score > self.args.conf + pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1) # filter + # Sort by confidence to correctly get internal metrics + pred = pred[score.argsort(descending=True)] + outputs[i] = pred # [idx] + + return outputs + + def _prepare_batch(self, si, batch): + """ + Prepares a batch for validation by applying necessary transformations. + + Args: + si (int): Batch index. + batch (dict): Batch data containing images and annotations. + + Returns: + (dict): Prepared batch with transformed annotations. 
+ """ + idx = batch["batch_idx"] == si + cls = batch["cls"][idx].squeeze(-1) + bbox = batch["bboxes"][idx] + ori_shape = batch["ori_shape"][si] + imgsz = batch["img"].shape[2:] + ratio_pad = batch["ratio_pad"][si] + if len(cls): + bbox = ops.xywh2xyxy(bbox) # target boxes + bbox[..., [0, 2]] *= ori_shape[1] # native-space pred + bbox[..., [1, 3]] *= ori_shape[0] # native-space pred + return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad} + + def _prepare_pred(self, pred, pbatch): + """ + Prepares predictions by scaling bounding boxes to original image dimensions. + + Args: + pred (torch.Tensor): Raw predictions. + pbatch (dict): Prepared batch information. + + Returns: + (torch.Tensor): Predictions scaled to original image dimensions. + """ + predn = pred.clone() + predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz # native-space pred + predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz # native-space pred + return predn.float() diff --git a/tracking/ultralytics/models/sam/__init__.py b/tracking/ultralytics/models/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..95ed6016c710d2437e1a315a073e411f54279e80 --- /dev/null +++ b/tracking/ultralytics/models/sam/__init__.py @@ -0,0 +1,6 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .model import SAM +from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor + +__all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor" # tuple or list of exportable items diff --git a/tracking/ultralytics/models/sam/amg.py b/tracking/ultralytics/models/sam/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..e5c577c0bd298ec0032bd2fa77034d63a83591e3 --- /dev/null +++ b/tracking/ultralytics/models/sam/amg.py @@ -0,0 +1,239 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import math +from itertools import product +from typing import Any, 
from typing import Any, Generator, List, Tuple

import numpy as np
import torch


def is_box_near_crop_edge(
    boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
) -> torch.Tensor:
    """Return a boolean mask of boxes that touch the crop boundary but not the original image boundary."""
    crop_edges = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
    image_edges = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
    uncropped = uncrop_boxes_xyxy(boxes, crop_box).float()
    touches_crop = torch.isclose(uncropped, crop_edges[None, :], atol=atol, rtol=0)
    touches_image = torch.isclose(uncropped, image_edges[None, :], atol=atol, rtol=0)
    # A coordinate only counts as "near the crop edge" when that edge is interior to the image.
    return torch.any(touches_crop & ~touches_image, dim=1)


def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
    """Yield aligned slices of length `batch_size` from every input sequence in lockstep."""
    assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
    total = len(args[0])
    for start in range(0, total, batch_size):
        yield [seq[start : start + batch_size] for seq in args]


def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
    """
    Compute per-mask stability scores.

    The score is the IoU between the binary masks obtained by thresholding the logits at
    `mask_threshold ± threshold_offset`; the high-threshold mask is always contained in the low-threshold one.

    Args:
        masks (torch.Tensor): Batch of predicted mask logits.
        mask_threshold (float): Central binarization threshold.
        threshold_offset (float): Offset defining the tight/loose thresholds.

    Returns:
        (torch.Tensor): Stability score per mask.
    """
    tight = mask_threshold + threshold_offset
    loose = mask_threshold - threshold_offset
    # int16 row sums then int32 totals avoid an unnecessary cast to int64.
    intersections = (masks > tight).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
    unions = (masks > loose).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
    return intersections / unions


def build_point_grid(n_per_side: int) -> np.ndarray:
    """Return an (n_per_side**2, 2) grid of evenly spaced (x, y) points covering [0, 1] x [0, 1]."""
    half_step = 1 / (2 * n_per_side)
    coords = np.linspace(half_step, 1 - half_step, n_per_side)
    xs, ys = np.meshgrid(coords, coords)  # 'xy' indexing: x varies along columns, y along rows
    return np.stack([xs, ys], axis=-1).reshape(-1, 2)


def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
    """Return one point grid per crop layer, with density shrunk by `scale_per_layer` at each level."""
    return [build_point_grid(int(n_per_side / scale_per_layer**level)) for level in range(n_layers + 1)]


def generate_crop_boxes(
    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
    """
    Generate overlapping crop boxes at multiple scales for multiscale mask generation.

    Args:
        im_size (Tuple[int, ...]): (height, width) of the input image.
        n_layers (int): Number of crop layers beyond the full image.
        overlap_ratio (float): Fraction of the short side used as overlap between adjacent crops.

    Returns:
        (List[List[int]]): Crop boxes in [x0, y0, x1, y1] format.
        (List[int]): Layer index for each crop box (0 is the full image).
    """
    im_h, im_w = im_size
    short_side = min(im_h, im_w)

    # Layer 0 is always the untouched full image.
    all_boxes: List[List[int]] = [[0, 0, im_w, im_h]]
    all_layers: List[int] = [0]

    def side_len(orig_len, n_crops, overlap):
        """Length of each crop along one axis for the given crop count and overlap."""
        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))

    for layer in range(n_layers):
        crops_per_side = 2 ** (layer + 1)
        overlap = int(overlap_ratio * short_side * (2 / crops_per_side))

        crop_w = side_len(im_w, crops_per_side, overlap)
        crop_h = side_len(im_h, crops_per_side, overlap)

        x_starts = [int((crop_w - overlap) * i) for i in range(crops_per_side)]
        y_starts = [int((crop_h - overlap) * i) for i in range(crops_per_side)]

        for x0, y0 in product(x_starts, y_starts):
            all_boxes.append([x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)])
            all_layers.append(layer + 1)

    return all_boxes, all_layers


def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    """Translate XYXY boxes from crop coordinates back to full-image coordinates."""
    x0, y0, _, _ = crop_box
    shift = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
    if len(boxes.shape) == 3:  # boxes carry a channel dimension
        shift = shift.unsqueeze(1)
    return boxes + shift


def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    """Translate points from crop coordinates back to full-image coordinates."""
    x0, y0, _, _ = crop_box
    shift = torch.tensor([[x0, y0]], device=points.device)
    if len(points.shape) == 3:  # points carry a channel dimension
        shift = shift.unsqueeze(1)
    return points + shift


def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
    """Zero-pad cropped masks back to the original image size."""
    x0, y0, x1, y1 = crop_box
    if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
        return masks  # crop already covers the full image
    pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
    # F.pad order is (left, right, top, bottom).
    return torch.nn.functional.pad(masks, (x0, pad_x - x0, y0, pad_y - y0), value=0)


def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
    """
    Remove small disconnected regions ('islands') or fill small holes in a binary mask.

    Args:
        mask (np.ndarray): Binary mask to process.
        area_thresh (float): Connected components smaller than this area are removed/filled.
        mode (str): 'holes' to fill small holes, 'islands' to drop small foreground regions.

    Returns:
        (np.ndarray): Processed binary mask.
        (bool): True when the mask was modified.
    """
    import cv2  # type: ignore

    assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
    correct_holes = mode == "holes"
    # For hole filling, invert so holes become foreground components.
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # row 0 is the background label
    small = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if not small:
        return mask, False
    fill_labels = [0] + small
    if not correct_holes:
        # Keep only the non-small labels; if everything is below threshold, keep the largest region.
        fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
    return np.isin(regions, fill_labels), True


def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
    """
    Compute XYXY bounding boxes around binary masks.

    Args:
        masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).

    Returns:
        (torch.Tensor): Boxes in XYXY format with shape (B, 4) or (B, C, 4); empty masks yield [0, 0, 0, 0].
    """
    # torch.max raises on empty inputs, so short-circuit that case.
    if torch.numel(masks) == 0:
        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)

    shape = masks.shape
    h, w = shape[-2:]
    flat = masks.flatten(0, -3) if len(shape) > 2 else masks.unsqueeze(0)  # normalize to CxHxW

    # Vertical extent: index of last/first row containing any foreground.
    row_any, _ = torch.max(flat, dim=-1)
    row_coords = row_any * torch.arange(h, device=row_any.device)[None, :]
    bottom, _ = torch.max(row_coords, dim=-1)
    top, _ = torch.min(row_coords + h * (~row_any), dim=-1)  # push empty rows past the end

    # Horizontal extent: index of last/first column containing any foreground.
    col_any, _ = torch.max(flat, dim=-2)
    col_coords = col_any * torch.arange(w, device=col_any.device)[None, :]
    right, _ = torch.max(col_coords, dim=-1)
    left, _ = torch.min(col_coords + w * (~col_any), dim=-1)

    # An empty mask leaves right < left (and bottom < top); replace those boxes with zeros.
    empty = (right < left) | (bottom < top)
    boxes = torch.stack([left, top, right, bottom], dim=-1) * (~empty).unsqueeze(-1)

    return boxes.reshape(*shape[:-2], 4) if len(shape) > 2 else boxes[0]
+ +from functools import partial + +import torch + +from ultralytics.utils.downloads import attempt_download_asset + +from .modules.decoders import MaskDecoder +from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder +from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer +from .modules.sam import SAM2Model, SAMModel +from .modules.tiny_encoder import TinyViT +from .modules.transformer import TwoWayTransformer + + +def build_sam_vit_h(checkpoint=None): + """Builds and returns a Segment Anything Model (SAM) h-size model with specified encoder parameters.""" + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +def build_sam_vit_l(checkpoint=None): + """Builds and returns a Segment Anything Model (SAM) l-size model with specified encoder parameters.""" + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + """Constructs and returns a Segment Anything Model (SAM) with b-size architecture and optional checkpoint.""" + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +def build_mobile_sam(checkpoint=None): + """Builds and returns a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation.""" + return _build_sam( + encoder_embed_dim=[64, 128, 160, 320], + encoder_depth=[2, 2, 6, 2], + encoder_num_heads=[2, 4, 5, 10], + encoder_global_attn_indexes=None, + mobile_sam=True, + checkpoint=checkpoint, + ) + + +def build_sam2_t(checkpoint=None): + """Builds and returns a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters.""" + return _build_sam2( + encoder_embed_dim=96, 
+ encoder_stages=[1, 2, 7, 2], + encoder_num_heads=1, + encoder_global_att_blocks=[5, 7, 9], + encoder_window_spec=[8, 4, 14, 7], + encoder_backbone_channel_list=[768, 384, 192, 96], + checkpoint=checkpoint, + ) + + +def build_sam2_s(checkpoint=None): + """Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters.""" + return _build_sam2( + encoder_embed_dim=96, + encoder_stages=[1, 2, 11, 2], + encoder_num_heads=1, + encoder_global_att_blocks=[7, 10, 13], + encoder_window_spec=[8, 4, 14, 7], + encoder_backbone_channel_list=[768, 384, 192, 96], + checkpoint=checkpoint, + ) + + +def build_sam2_b(checkpoint=None): + """Builds and returns a SAM2 base-size model with specified architecture parameters.""" + return _build_sam2( + encoder_embed_dim=112, + encoder_stages=[2, 3, 16, 3], + encoder_num_heads=2, + encoder_global_att_blocks=[12, 16, 20], + encoder_window_spec=[8, 4, 14, 7], + encoder_window_spatial_size=[14, 14], + encoder_backbone_channel_list=[896, 448, 224, 112], + checkpoint=checkpoint, + ) + + +def build_sam2_l(checkpoint=None): + """Builds and returns a large-size Segment Anything Model (SAM2) with specified architecture parameters.""" + return _build_sam2( + encoder_embed_dim=144, + encoder_stages=[2, 6, 36, 4], + encoder_num_heads=2, + encoder_global_att_blocks=[23, 33, 43], + encoder_window_spec=[8, 4, 16, 8], + encoder_backbone_channel_list=[1152, 576, 288, 144], + checkpoint=checkpoint, + ) + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, + mobile_sam=False, +): + """ + Builds a Segment Anything Model (SAM) with specified encoder parameters. + + Args: + encoder_embed_dim (int | List[int]): Embedding dimension for the encoder. + encoder_depth (int | List[int]): Depth of the encoder. + encoder_num_heads (int | List[int]): Number of attention heads in the encoder. 
+ encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder. + checkpoint (str | None): Path to the model checkpoint file. + mobile_sam (bool): Whether to build a Mobile-SAM model. + + Returns: + (SAMModel): A Segment Anything Model instance with the specified architecture. + + Examples: + >>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11]) + >>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True) + """ + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + image_encoder = ( + TinyViT( + img_size=1024, + in_chans=3, + num_classes=1000, + embed_dims=encoder_embed_dim, + depths=encoder_depth, + num_heads=encoder_num_heads, + window_sizes=[7, 7, 14, 7], + mlp_ratio=4.0, + drop_rate=0.0, + drop_path_rate=0.0, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=0.8, + ) + if mobile_sam + else ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ) + ) + sam = SAMModel( + image_encoder=image_encoder, + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + if checkpoint is not None: + checkpoint = 
attempt_download_asset(checkpoint) + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + sam.eval() + return sam + + +def _build_sam2( + encoder_embed_dim=1280, + encoder_stages=[2, 6, 36, 4], + encoder_num_heads=2, + encoder_global_att_blocks=[7, 15, 23, 31], + encoder_backbone_channel_list=[1152, 576, 288, 144], + encoder_window_spatial_size=[7, 7], + encoder_window_spec=[8, 4, 16, 8], + checkpoint=None, +): + """ + Builds and returns a Segment Anything Model 2 (SAM2) with specified architecture parameters. + + Args: + encoder_embed_dim (int): Embedding dimension for the encoder. + encoder_stages (List[int]): Number of blocks in each stage of the encoder. + encoder_num_heads (int): Number of attention heads in the encoder. + encoder_global_att_blocks (List[int]): Indices of global attention blocks in the encoder. + encoder_backbone_channel_list (List[int]): Channel dimensions for each level of the encoder backbone. + encoder_window_spatial_size (List[int]): Spatial size of the window for position embeddings. + encoder_window_spec (List[int]): Window specifications for each stage of the encoder. + checkpoint (str | None): Path to the checkpoint file for loading pre-trained weights. + + Returns: + (SAM2Model): A configured and initialized SAM2 model. 
+ + Examples: + >>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2]) + >>> sam2_model.eval() + """ + image_encoder = ImageEncoder( + trunk=Hiera( + embed_dim=encoder_embed_dim, + num_heads=encoder_num_heads, + stages=encoder_stages, + global_att_blocks=encoder_global_att_blocks, + window_pos_embed_bkg_spatial_size=encoder_window_spatial_size, + window_spec=encoder_window_spec, + ), + neck=FpnNeck( + d_model=256, + backbone_channel_list=encoder_backbone_channel_list, + fpn_top_down_levels=[2, 3], + fpn_interp_model="nearest", + ), + scalp=1, + ) + memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer()) + memory_encoder = MemoryEncoder(out_dim=64) + + is_sam2_1 = checkpoint is not None and "sam2.1" in checkpoint + sam2 = SAM2Model( + image_encoder=image_encoder, + memory_attention=memory_attention, + memory_encoder=memory_encoder, + num_maskmem=7, + image_size=1024, + sigmoid_scale_for_mem_enc=20.0, + sigmoid_bias_for_mem_enc=-10.0, + use_mask_input_as_output_without_sam=True, + directly_add_no_mem_embed=True, + use_high_res_features_in_sam=True, + multimask_output_in_sam=True, + iou_prediction_use_sigmoid=True, + use_obj_ptrs_in_encoder=True, + add_tpos_enc_to_obj_ptrs=True, + only_obj_ptrs_in_the_past_for_eval=True, + pred_obj_scores=True, + pred_obj_scores_mlp=True, + fixed_no_obj_ptr=True, + multimask_output_for_tracking=True, + use_multimask_token_for_obj_ptr=True, + multimask_min_pt_num=0, + multimask_max_pt_num=1, + use_mlp_for_obj_ptr_proj=True, + compile_image_encoder=False, + no_obj_embed_spatial=is_sam2_1, + proj_tpos_enc_in_obj_ptrs=is_sam2_1, + use_signed_tpos_enc_to_obj_ptrs=is_sam2_1, + sam_mask_decoder_extra_args=dict( + dynamic_multimask_via_stability=True, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + ), + ) + + if checkpoint is not None: + checkpoint = attempt_download_asset(checkpoint) + with open(checkpoint, "rb") as f: + 
state_dict = torch.load(f)["model"] + sam2.load_state_dict(state_dict) + sam2.eval() + return sam2 + + +sam_model_map = { + "sam_h.pt": build_sam_vit_h, + "sam_l.pt": build_sam_vit_l, + "sam_b.pt": build_sam_vit_b, + "mobile_sam.pt": build_mobile_sam, + "sam2_t.pt": build_sam2_t, + "sam2_s.pt": build_sam2_s, + "sam2_b.pt": build_sam2_b, + "sam2_l.pt": build_sam2_l, + "sam2.1_t.pt": build_sam2_t, + "sam2.1_s.pt": build_sam2_s, + "sam2.1_b.pt": build_sam2_b, + "sam2.1_l.pt": build_sam2_l, +} + + +def build_sam(ckpt="sam_b.pt"): + """ + Builds and returns a Segment Anything Model (SAM) based on the provided checkpoint. + + Args: + ckpt (str | Path): Path to the checkpoint file or name of a pre-defined SAM model. + + Returns: + (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance. + + Raises: + FileNotFoundError: If the provided checkpoint is not a supported SAM model. + + Examples: + >>> sam_model = build_sam("sam_b.pt") + >>> sam_model = build_sam("path/to/custom_checkpoint.pt") + + Notes: + Supported pre-defined models include: + - SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt' + - SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt' + """ + model_builder = None + ckpt = str(ckpt) # to allow Path ckpt types + for k in sam_model_map.keys(): + if ckpt.endswith(k): + model_builder = sam_model_map.get(k) + + if not model_builder: + raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}") + + return model_builder(ckpt) diff --git a/tracking/ultralytics/models/sam/model.py b/tracking/ultralytics/models/sam/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a31a92af53bc4faba5dfb084898f99f53f92a81a --- /dev/null +++ b/tracking/ultralytics/models/sam/model.py @@ -0,0 +1,169 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +SAM model interface. 
+ +This module provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for real-time image +segmentation tasks. The SAM model allows for promptable segmentation with unparalleled versatility in image analysis, +and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new +image distributions and tasks without prior knowledge. + +Key Features: + - Promptable segmentation + - Real-time performance + - Zero-shot transfer capabilities + - Trained on SA-1B dataset +""" + +from pathlib import Path + +from ultralytics.engine.model import Model +from ultralytics.utils.torch_utils import model_info + +from .build import build_sam +from .predict import Predictor, SAM2Predictor + + +class SAM(Model): + """ + SAM (Segment Anything Model) interface class for real-time image segmentation tasks. + + This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for + promptable segmentation with versatility in image analysis. It supports various prompts such as bounding + boxes, points, or labels, and features zero-shot performance capabilities. + + Attributes: + model (torch.nn.Module): The loaded SAM model. + is_sam2 (bool): Indicates whether the model is SAM2 variant. + task (str): The task type, set to "segment" for SAM models. + + Methods: + predict: Performs segmentation prediction on the given image or video source. + info: Logs information about the SAM model. + + Examples: + >>> sam = SAM("sam_b.pt") + >>> results = sam.predict("image.jpg", points=[[500, 375]]) + >>> for r in results: + >>> print(f"Detected {len(r.masks)} masks") + """ + + def __init__(self, model="sam_b.pt") -> None: + """ + Initialize the SAM (Segment Anything Model) instance. + + Args: + model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension. + + Raises: + NotImplementedError: If the model file extension is not .pt or .pth. 
+ + Examples: + >>> sam = SAM("sam_b.pt") + >>> print(sam.is_sam2) + """ + if model and Path(model).suffix not in {".pt", ".pth"}: + raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.") + self.is_sam2 = "sam2" in Path(model).stem + super().__init__(model=model, task="segment") + + def _load(self, weights: str, task=None): + """ + Load the specified weights into the SAM model. + + Args: + weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters. + task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for. + + Examples: + >>> sam = SAM("sam_b.pt") + >>> sam._load("path/to/custom_weights.pt") + """ + self.model = build_sam(weights) + + def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs): + """ + Perform segmentation prediction on the given image or video source. + + Args: + source (str | PIL.Image | numpy.ndarray): Path to the image or video file, or a PIL.Image object, or + a numpy.ndarray object. + stream (bool): If True, enables real-time streaming. + bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation. + points (List[List[float]] | None): List of points for prompted segmentation. + labels (List[int] | None): List of labels for prompted segmentation. + **kwargs (Any): Additional keyword arguments for prediction. + + Returns: + (list): The model predictions. + + Examples: + >>> sam = SAM("sam_b.pt") + >>> results = sam.predict("image.jpg", points=[[500, 375]]) + >>> for r in results: + ... 
print(f"Detected {len(r.masks)} masks") + """ + overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024) + kwargs = {**overrides, **kwargs} + prompts = dict(bboxes=bboxes, points=points, labels=labels) + return super().predict(source, stream, prompts=prompts, **kwargs) + + def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs): + """ + Perform segmentation prediction on the given image or video source. + + This method is an alias for the 'predict' method, providing a convenient way to call the SAM model + for segmentation tasks. + + Args: + source (str | PIL.Image | numpy.ndarray | None): Path to the image or video file, or a PIL.Image + object, or a numpy.ndarray object. + stream (bool): If True, enables real-time streaming. + bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation. + points (List[List[float]] | None): List of points for prompted segmentation. + labels (List[int] | None): List of labels for prompted segmentation. + **kwargs (Any): Additional keyword arguments to be passed to the predict method. + + Returns: + (list): The model predictions, typically containing segmentation masks and other relevant information. + + Examples: + >>> sam = SAM("sam_b.pt") + >>> results = sam("image.jpg", points=[[500, 375]]) + >>> print(f"Detected {len(results[0].masks)} masks") + """ + return self.predict(source, stream, bboxes, points, labels, **kwargs) + + def info(self, detailed=False, verbose=True): + """ + Log information about the SAM model. + + Args: + detailed (bool): If True, displays detailed information about the model layers and operations. + verbose (bool): If True, prints the information to the console. + + Returns: + (tuple): A tuple containing the model's information (string representations of the model). 
+ + Examples: + >>> sam = SAM("sam_b.pt") + >>> info = sam.info() + >>> print(info[0]) # Print summary information + """ + return model_info(self.model, detailed=detailed, verbose=verbose) + + @property + def task_map(self): + """ + Provide a mapping from the 'segment' task to its corresponding 'Predictor'. + + Returns: + (Dict[str, Dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding Predictor + class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor. + + Examples: + >>> sam = SAM("sam_b.pt") + >>> task_map = sam.task_map + >>> print(task_map) + {'segment': {'predictor': }} + """ + return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}} diff --git a/tracking/ultralytics/models/sam/modules/__init__.py b/tracking/ultralytics/models/sam/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77a19dcf0f8093de419453747db2e7e719f96349 --- /dev/null +++ b/tracking/ultralytics/models/sam/modules/__init__.py @@ -0,0 +1 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license diff --git a/tracking/ultralytics/models/sam/modules/blocks.py b/tracking/ultralytics/models/sam/modules/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..7dab134cb8267ea6205f3c07f17be3c560a559cd --- /dev/null +++ b/tracking/ultralytics/models/sam/modules/blocks.py @@ -0,0 +1,1129 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import copy +import math +from functools import partial +from typing import Any, Optional, Tuple, Type, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from ultralytics.nn.modules import MLP, LayerNorm2d, MLPBlock + +from .transformer import Attention, TwoWayAttentionBlock, TwoWayTransformer +from .utils import add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition + + +class 
class DropPath(nn.Module):
    """
    Stochastic depth: randomly zeroes entire samples' residual paths during training.

    Attributes:
        drop_prob (float): Probability of dropping a path during training.
        scale_by_keep (bool): If True, rescale surviving paths by 1 / keep probability.

    Examples:
        >>> drop_path = DropPath(drop_prob=0.2, scale_by_keep=True)
        >>> x = torch.randn(32, 64, 224, 224)
        >>> output = drop_path(x)
    """

    def __init__(self, drop_prob=0.0, scale_by_keep=True):
        """Initialize stochastic-depth regularization with the given drop probability."""
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Randomly zero whole samples of `x` during training; identity otherwise."""
        # No-op at inference time or when dropping is disabled.
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast across all remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if self.scale_by_keep and keep_prob > 0.0:
            mask.div_(keep_prob)
        return x * mask


class MaskDownSampler(nn.Module):
    """
    Downsample and embed binary masks using strided convolutions, LayerNorm2d, and activations.

    Spatial dimensions shrink by `total_stride` overall while channels expand by `stride**2`
    per layer; a final 1x1 convolution maps to `embed_dim` channels.

    Attributes:
        encoder (nn.Sequential): Conv / LayerNorm2d / activation stack ending in a 1x1 projection.

    Examples:
        >>> mask_downsampler = MaskDownSampler(embed_dim=256, kernel_size=4, stride=4, padding=0, total_stride=16)
        >>> output = mask_downsampler(torch.randn(1, 1, 256, 256))
        >>> print(output.shape)
        torch.Size([1, 256, 16, 16])
    """

    def __init__(
        self,
        embed_dim=256,
        kernel_size=4,
        stride=4,
        padding=0,
        total_stride=16,
        activation=nn.GELU,
    ):
        """Build the progressive downsampling encoder; `total_stride` must be a power of `stride`."""
        super().__init__()
        num_layers = int(math.log2(total_stride) // math.log2(stride))
        assert stride**num_layers == total_stride
        stages = []
        in_ch = 1
        for _ in range(num_layers):
            out_ch = in_ch * (stride**2)  # channel expansion matches the spatial reduction
            stages.extend(
                [
                    nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding),
                    LayerNorm2d(out_ch),
                    activation(),
                ]
            )
            in_ch = out_ch
        stages.append(nn.Conv2d(in_ch, embed_dim, kernel_size=1))
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Downsample and encode an input mask to `embed_dim` channels."""
        return self.encoder(x)


class CXBlock(nn.Module):
    """
    ConvNeXt block: depthwise conv -> LayerNorm2d -> pointwise MLP (linear layers), with
    optional layer scale (`gamma`) and stochastic depth, wrapped in a residual connection.

    Attributes:
        dwconv (nn.Conv2d): Depthwise (or standard) 2D convolution.
        norm (LayerNorm2d): Channel-wise layer normalization.
        pwconv1 (nn.Linear): First pointwise convolution as a linear layer.
        act (nn.GELU): GELU activation.
        pwconv2 (nn.Linear): Second pointwise convolution as a linear layer.
        gamma (nn.Parameter | None): Learnable layer-scale parameter, or None when disabled.
        drop_path (nn.Module): DropPath (or Identity) for stochastic depth.

    Examples:
        >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
        >>> output = block(torch.randn(1, 64, 56, 56))
        >>> print(output.shape)
        torch.Size([1, 64, 56, 56])
    """

    def __init__(
        self,
        dim,
        kernel_size=7,
        padding=3,
        drop_path=0.0,
        layer_scale_init_value=1e-6,
        use_dwconv=True,
    ):
        """
        Initialize the ConvNeXt block.

        Args:
            dim (int): Number of input channels.
            kernel_size (int): Convolution kernel size.
            padding (int): Convolution padding.
            drop_path (float): Stochastic depth rate.
            layer_scale_init_value (float): Initial layer-scale value; <= 0 disables `gamma`.
            use_dwconv (bool): Whether the spatial conv is depthwise (groups=dim) or dense.
        """
        super().__init__()
        groups = dim if use_dwconv else 1
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=padding, groups=groups)
        self.norm = LayerNorm2d(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs via linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.gamma = None
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        """Apply the ConvNeXt transformation with a residual connection."""
        residual = x
        y = self.norm(self.dwconv(x))
        y = y.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C) for the linear layers
        y = self.pwconv2(self.act(self.pwconv1(y)))
        if self.gamma is not None:
            y = self.gamma * y
        y = y.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return residual + self.drop_path(y)
class Fuser(nn.Module):
    """
    Fuse features by running them through `num_layers` deep-copied instances of a template layer.

    Attributes:
        proj (nn.Module): Optional 1x1 convolutional input projection (Identity when disabled).
        layers (nn.ModuleList): The cloned layers, applied sequentially.

    Examples:
        >>> fuser = Fuser(CXBlock(dim=256), num_layers=3, dim=256, input_projection=True)
        >>> output = fuser(torch.randn(1, 256, 32, 32))
    """

    def __init__(self, layer, num_layers, dim=None, input_projection=False):
        """Clone `layer` `num_layers` times; `dim` is required when `input_projection` is True."""
        super().__init__()
        self.proj = nn.Identity()
        self.layers = nn.ModuleList(copy.deepcopy(layer) for _ in range(num_layers))
        if input_projection:
            assert dim is not None
            self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        """Apply the optional projection, then every cloned layer in sequence."""
        out = self.proj(x)
        for stage in self.layers:
            out = stage(out)
        return out


class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
    """
    TwoWayAttentionBlock specialized for SAM2: identical attention wiring, but the sparse-branch
    MLP is replaced with the shared two-layer `MLP` module.

    Examples:
        >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        Build the parent two-way block, then swap in a 2-layer MLP for the sparse branch.

        Args:
            embedding_dim (int): Channel dimension of the embeddings.
            num_heads (int): Number of attention heads.
            mlp_dim (int): Hidden dimension of the MLP block.
            activation (Type[nn.Module]): Activation used inside the MLP.
            attention_downsample_rate (int): Downsample rate for attention computations.
            skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
        """
        super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe)
        self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation)
class SAM2TwoWayTransformer(TwoWayTransformer):
    """
    TwoWayTransformer whose layers are `SAM2TwoWayAttentionBlock` instances.

    The parent builds the final attention and norm; this subclass rebuilds `self.layers`
    so that only the very first layer skips positional encoding.

    Examples:
        >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
    """

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        Initialize the SAM2 two-way transformer.

        Args:
            depth (int): Number of layers in the transformer.
            embedding_dim (int): Channel dimension of the input embeddings.
            num_heads (int): Number of attention heads; must divide embedding_dim.
            mlp_dim (int): Internal channel dimension of the MLP block.
            activation (Type[nn.Module]): Activation used in the MLP block.
            attention_downsample_rate (int): Downsampling rate for attention computations.
        """
        super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate)
        self.layers = nn.ModuleList(
            SAM2TwoWayAttentionBlock(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                mlp_dim=mlp_dim,
                activation=activation,
                attention_downsample_rate=attention_downsample_rate,
                skip_first_layer_pe=(i == 0),  # only the first layer skips PE
            )
            for i in range(depth)
        )


class RoPEAttention(Attention):
    """
    Attention with rotary position encoding (RoPE) applied to queries and a prefix of the keys.

    Attributes:
        compute_cis (Callable): Computes axial complex rotation factors for a (w, h) grid.
        freqs_cis (Tensor): Cached rotation factors for the current grid size.
        rope_k_repeat (bool): Repeat q's RoPE table to match k's length (cross-attention to memories).

    Examples:
        >>> rope_attn = RoPEAttention(embedding_dim=256, num_heads=8, rope_theta=10000.0, feat_sizes=(32, 32))
        >>> out = rope_attn(torch.randn(1, 1024, 256), torch.randn(1, 1024, 256), torch.randn(1, 1024, 256))
    """

    def __init__(
        self,
        *args,
        rope_theta=10000.0,
        rope_k_repeat=False,
        feat_sizes=(32, 32),  # [w, h] for stride 16 feats at 512 resolution
        **kwargs,
    ):
        """Precompute axial RoPE rotation factors for a `feat_sizes` grid."""
        super().__init__(*args, **kwargs)
        self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)
        self.freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
        # Needed when cross-attending to memories: k is longer than q.
        self.rope_k_repeat = rope_k_repeat

    def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
        """Project q/k/v, rotate positions with RoPE (excluding the last `num_k_exclude_rope` keys), and attend."""
        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)

        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Rebuild the rotation table if the (assumed square) token grid changed size.
        w = h = math.sqrt(q.shape[-2])
        self.freqs_cis = self.freqs_cis.to(q.device)
        if self.freqs_cis.shape[0] != q.shape[-2]:
            self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
        if q.shape[-2] != k.shape[-2]:
            assert self.rope_k_repeat

        # Rotate q and only the first num_k_rope keys; the tail keys stay unrotated.
        num_k_rope = k.size(-2) - num_k_exclude_rope
        q, k[:, :, :num_k_rope] = apply_rotary_enc(
            q,
            k[:, :, :num_k_rope],
            freqs_cis=self.freqs_cis,
            repeat_freqs_k=self.rope_k_repeat,
        )

        # Scaled dot-product attention per head.
        _, _, _, c_per_head = q.shape
        scores = (q @ k.permute(0, 1, 3, 2)) / math.sqrt(c_per_head)  # B x N_heads x N_tokens x N_tokens
        probs = torch.softmax(scores, dim=-1)
        out = self._recombine_heads(probs @ v)
        return self.out_proj(out)


def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
    """Pool an NHWC tensor with an NCHW pooling module, then optionally normalize."""
    if pool is None:
        return x
    # NHWC -> NCHW -> pool -> NHWC
    pooled = pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
    return norm(pooled) if norm else pooled
class MultiScaleAttention(nn.Module):
    """
    Multi-head self-attention over NHWC feature maps with optional query pooling.

    Attributes:
        dim (int): Input channel dimension.
        dim_out (int): Output channel dimension.
        num_heads (int): Number of attention heads.
        scale (float): Dot-product scaling factor (kept for interface parity; SDPA scales internally).
        q_pool (nn.Module | None): Optional pooling applied to queries at stage changes.
        qkv (nn.Linear): Joint query/key/value projection.
        proj (nn.Linear): Output projection.

    Examples:
        >>> msa = MultiScaleAttention(dim=256, dim_out=256, num_heads=8)
        >>> out = msa(torch.randn(1, 64, 64, 256))
        >>> print(out.shape)
        torch.Size([1, 64, 64, 256])
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_pool: nn.Module = None,
    ):
        """Set up qkv/out projections; `q_pool` (if given) downsamples queries."""
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out
        self.num_heads = num_heads
        self.scale = (dim_out // num_heads) ** -0.5
        self.q_pool = q_pool
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the flattened spatial grid, pooling queries first if configured."""
        B, H, W, _ = x.shape
        # (B, H*W, 3, nHead, C_head)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
        q, k, v = torch.unbind(qkv, 2)

        if self.q_pool:
            # Query pooling (stage transition): downsample q and recompute the spatial extent.
            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
            H, W = q.shape[1:3]
            q = q.reshape(B, H * W, self.num_heads, -1)

        # SDPA expects (B, nheads, tokens, C_head).
        attn_out = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
        )
        attn_out = attn_out.transpose(1, 2).reshape(B, H, W, -1)
        return self.proj(attn_out)


class MultiScaleBlock(nn.Module):
    """
    Vision-transformer block over NHWC maps with window attention, optional query pooling, and an MLP.

    Attributes:
        dim (int): Input dimension.
        dim_out (int): Output dimension.
        norm1 (nn.Module): Pre-attention normalization.
        window_size (int): Window size for partitioning; 0 means global attention.
        pool (nn.Module | None): Max-pool used for query downsampling when `q_stride` is set.
        q_stride (Tuple[int, int] | None): Query pooling stride.
        attn (MultiScaleAttention): Attention module.
        drop_path (nn.Module): Stochastic depth (or Identity).
        norm2 (nn.Module): Pre-MLP normalization.
        mlp (MLP): Two-layer MLP.
        proj (nn.Linear): Residual projection, only created when dim != dim_out.

    Examples:
        >>> block = MultiScaleBlock(dim=256, dim_out=512, num_heads=8, window_size=7)
        >>> out = block(torch.randn(1, 56, 56, 256))
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: Union[nn.Module, str] = "LayerNorm",
        q_stride: Tuple[int, int] = None,
        act_layer: nn.Module = nn.GELU,
        window_size: int = 0,
    ):
        """Initialize attention, MLP, norms, and the optional query-pooling / residual projection."""
        super().__init__()
        if isinstance(norm_layer, str):
            # Resolve e.g. "LayerNorm" to the torch.nn class with a fixed eps.
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)
        self.window_size = window_size

        self.pool, self.q_stride = None, q_stride
        if self.q_stride:
            self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False)

        self.attn = MultiScaleAttention(dim, dim_out, num_heads=num_heads, q_pool=self.pool)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = MLP(dim_out, int(dim_out * mlp_ratio), dim_out, num_layers=2, act=act_layer)

        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run windowed attention (with optional query pooling) and the MLP, each with residuals."""
        shortcut = x  # (B, H, W, C)
        x = self.norm1(x)

        # Residual branch: match channels (and spatial size, when pooling) to the attention output.
        if self.dim != self.dim_out:
            shortcut = do_pool(self.proj(x), self.pool)

        win = self.window_size
        if win > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, win)

        x = self.attn(x)
        if self.q_stride:
            # Query pooling shrank the grid; recompute window size and padding from the shortcut.
            win = self.window_size // self.q_stride[0]
            H, W = shortcut.shape[1:3]
            pad_h = (win - H % win) % win
            pad_w = (win - W % win) % win
            pad_hw = (H + pad_h, W + pad_w)

        if self.window_size > 0:
            x = window_unpartition(x, win, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        return x + self.drop_path(self.mlp(self.norm2(x)))
class PositionEmbeddingSine(nn.Module):
    """
    DETR-style sinusoidal 2D positional embeddings with a per-resolution cache.

    Half of the output channels encode the (optionally normalized) y coordinate and half
    encode x, each as interleaved sine/cosine waves over a geometric frequency ladder.

    Attributes:
        num_pos_feats (int): Features per axis (half the embedding dimension).
        temperature (int): Temperature of the frequency ladder.
        normalize (bool): Whether coordinates are normalized before encoding.
        scale (float): Scaling applied to normalized coordinates.
        cache (dict): Maps (H, W) -> precomputed (C, H, W) embedding.

    Examples:
        >>> pos_emb = PositionEmbeddingSine(num_pos_feats=128)
        >>> emb = pos_emb(torch.randn(1, 3, 224, 224))
        >>> print(emb.shape)
        torch.Size([1, 128, 224, 224])
    """

    def __init__(
        self,
        num_pos_feats,
        temperature: int = 10000,
        normalize: bool = True,
        scale: Optional[float] = None,
    ):
        """Initialize the embedding; `num_pos_feats` must be even (split between y and x)."""
        super().__init__()
        assert num_pos_feats % 2 == 0, "Expecting even model width"
        self.num_pos_feats = num_pos_feats // 2
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and not normalize:
            raise ValueError("normalize should be True if scale is passed")
        self.scale = 2 * math.pi if scale is None else scale
        self.cache = {}  # (H, W) -> precomputed (C, H, W) embedding

    def _encode_xy(self, x, y):
        """Sine/cosine-encode two matched 1-D coordinate vectors."""
        assert len(x) == len(y) and x.ndim == y.ndim == 1
        xe = x * self.scale
        ye = y * self.scale

        freqs = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        freqs = self.temperature ** (2 * (freqs // 2) / self.num_pos_feats)

        px = xe[:, None] / freqs
        py = ye[:, None] / freqs
        # Interleave sin (even indices) and cos (odd indices).
        px = torch.stack((px[:, 0::2].sin(), px[:, 1::2].cos()), dim=2).flatten(1)
        py = torch.stack((py[:, 0::2].sin(), py[:, 1::2].cos()), dim=2).flatten(1)
        return px, py

    @torch.no_grad()
    def encode_boxes(self, x, y, w, h):
        """Embed box centers sinusoidally and append raw width/height columns."""
        px, py = self._encode_xy(x, y)
        return torch.cat((py, px, h[:, None], w[:, None]), dim=1)

    encode = encode_boxes  # Backwards compatibility

    @torch.no_grad()
    def encode_points(self, x, y, labels):
        """Embed batched (x, y) points sinusoidally and append their labels as a final channel."""
        (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
        assert bx == by and nx == ny and bx == bl and nx == nl
        px, py = self._encode_xy(x.flatten(), y.flatten())
        px, py = px.reshape(bx, nx, -1), py.reshape(by, ny, -1)
        return torch.cat((py, px, labels[:, :, None]), dim=2)

    @torch.no_grad()
    def forward(self, x: torch.Tensor):
        """Return (B, C, H, W) sinusoidal embeddings for the spatial size of `x`, using the cache."""
        key = (x.shape[-2], x.shape[-1])
        if key in self.cache:
            # Cached embeddings are stored without batch; broadcast to the current batch size.
            return self.cache[key][None].repeat(x.shape[0], 1, 1, 1)

        b, h, w = x.shape[0], x.shape[-2], x.shape[-1]
        ye = torch.arange(1, h + 1, dtype=torch.float32, device=x.device).view(1, -1, 1).repeat(b, 1, w)
        xe = torch.arange(1, w + 1, dtype=torch.float32, device=x.device).view(1, 1, -1).repeat(b, h, 1)

        if self.normalize:
            eps = 1e-6
            ye = ye / (ye[:, -1:, :] + eps) * self.scale
            xe = xe / (xe[:, :, -1:] + eps) * self.scale

        freqs = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        freqs = self.temperature ** (2 * (freqs // 2) / self.num_pos_feats)

        px = xe[:, :, :, None] / freqs
        py = ye[:, :, :, None] / freqs
        px = torch.stack((px[:, :, :, 0::2].sin(), px[:, :, :, 1::2].cos()), dim=4).flatten(3)
        py = torch.stack((py[:, :, :, 0::2].sin(), py[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((py, px), dim=3).permute(0, 3, 1, 2)
        self.cache[key] = pos[0]
        return pos


class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding from random Gaussian spatial frequencies (SAM prompt-encoder style).

    Attributes:
        positional_encoding_gaussian_matrix (torch.Tensor): (2, num_pos_feats) random frequency buffer.

    Examples:
        >>> pe = PositionEmbeddingRandom(num_pos_feats=64)
        >>> enc = pe((32, 32))
        >>> print(enc.shape)
        torch.Size([128, 32, 32])
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        """Register a (2, num_pos_feats) Gaussian frequency matrix scaled by `scale` (default 1.0)."""
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))

        # forward() relies on cumsum, whose CUDA kernel has no deterministic implementation.
        # NOTE(review): this flips *global* torch determinism state as a constructor side
        # effect — confirm callers expect that.
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.deterministic = False

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Encode coordinates normalized to [0, 1]; output has 2 * num_pos_feats channels."""
        pts = 2 * coords - 1  # map [0, 1] -> [-1, 1]
        pts = pts @ self.positional_encoding_gaussian_matrix
        pts = 2 * np.pi * pts
        # d_1 x ... x d_n x 2 -> d_1 x ... x d_n x C
        return torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Return a (C, H, W) positional encoding grid for the given spatial size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        ones = torch.ones((h, w), device=device, dtype=torch.float32)
        # Pixel-center coordinates, normalized to [0, 1].
        ye = (ones.cumsum(dim=0) - 0.5) / h
        xe = (ones.cumsum(dim=1) - 0.5) / w
        return self._pe_encoding(torch.stack([xe, ye], dim=-1)).permute(2, 0, 1)  # C x H x W

    def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
        """Encode raw pixel coordinates, normalizing x by image width and y by image height."""
        pts = coords_input.clone()
        pts[:, :, 0] = pts[:, :, 0] / image_size[1]
        pts[:, :, 1] = pts[:, :, 1] / image_size[0]
        return self._pe_encoding(pts.to(torch.float))  # B x N x C

    Examples:
        >>> import torch
        >>> block = Block(dim=256, num_heads=8, window_size=7)
        >>> x = torch.randn(1, 56, 56, 256)
        >>> output = block(x)
        >>> print(output.shape)
        torch.Size([1, 56, 56, 256])
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Initializes a transformer block with optional window attention and relative positional embeddings.

        This constructor sets up a transformer block that can use either global or windowed self-attention,
        followed by a feed-forward network. It supports relative positional embeddings and is designed
        for use in vision transformer architectures.

        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in the self-attention layer.
            mlp_ratio (float): Ratio of mlp hidden dimension to embedding dimension.
            qkv_bias (bool): If True, adds a learnable bias to query, key, value projections.
            norm_layer (Type[nn.Module]): Type of normalization layer to use.
            act_layer (Type[nn.Module]): Type of activation function to use in the MLP block.
            use_rel_pos (bool): If True, uses relative positional embeddings in attention.
            rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
            window_size (int): Size of attention window. If 0, uses global attention.
            input_size (Optional[Tuple[int, int]]): Input resolution for calculating relative positional parameter size.

        Examples:
            >>> block = Block(dim=256, num_heads=8, window_size=7)
            >>> x = torch.randn(1, 56, 56, 256)
            >>> output = block(x)
            >>> print(output.shape)
            torch.Size([1, 56, 56, 256])
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        # Windowed blocks attend over (window_size, window_size) tokens; global blocks (window_size=0)
        # attend over the full input resolution, so the rel-pos tables are sized accordingly.
        self.attn = REAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
        )

        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)

        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Processes input through transformer block with optional windowed self-attention and residual connection."""
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            # Record the pre-partition spatial size so the partition can be reversed exactly.
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + x
        # Second residual branch: MLP applied to the normalized attention output.
        return x + self.mlp(self.norm2(x))


class REAttention(nn.Module):
    """
    Multi-head self-attention with optional decomposed relative positional embeddings.

    This class implements multi-head attention over 2D feature maps of shape (B, H, W, C). When
    enabled, decomposed relative positional embeddings along the height and width axes are added
    to the attention map before the softmax, via `add_decomposed_rel_pos`.

    Attributes:
        num_heads (int): Number of attention heads.
        scale (float): Per-head query scaling factor (head_dim ** -0.5).
        qkv (nn.Linear): Combined linear projection producing queries, keys, and values.
        proj (nn.Linear): Output projection back to the input dimension.
        use_rel_pos (bool): If True, adds decomposed relative positional embeddings to attention.
        rel_pos_h (nn.Parameter): Relative positional embeddings for the height axis (when used).
        rel_pos_w (nn.Parameter): Relative positional embeddings for the width axis (when used).

    Methods:
        forward: Applies multi-head attention with optional relative positional encoding to the input.

    Examples:
        >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))
        >>> x = torch.randn(1, 32, 32, 256)
        >>> output = attention(x)
        >>> print(output.shape)
        torch.Size([1, 32, 32, 256])
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Initializes a Relative Position Attention module for transformer-based architectures.

        This module implements multi-head attention with optional relative positional encodings, designed
        specifically for vision tasks in transformer models.

        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads. Default is 8.
            qkv_bias (bool): If True, adds a learnable bias to query, key, value projections. Default is True.
            use_rel_pos (bool): If True, uses relative positional encodings. Default is False.
            rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero. Default is True.
            input_size (Tuple[int, int] | None): Input resolution for calculating relative positional parameter size.
                Required if use_rel_pos is True. Default is None.

        Examples:
            >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))
            >>> x = torch.randn(1, 32, 32, 256)
            >>> output = attention(x)
            >>> print(output.shape)
            torch.Size([1, 32, 32, 256])
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5  # standard scaled dot-product attention factor

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert input_size is not None, "Input size must be provided if using relative positional encoding."
            # Initialize relative positional embeddings
            # One row per possible relative offset along each axis: 2 * size - 1 offsets.
            # NOTE(review): `rel_pos_zero_init` is accepted but not consulted here — the tables are
            # always zero-initialized. Confirm against the upstream implementation.
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies multi-head attention with optional relative positional encoding to input tensor."""
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

        attn = attn.softmax(dim=-1)
        # Fold heads back together and restore the (B, H, W, C) spatial layout before projecting.
        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        return self.proj(x)


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding module for vision transformer architectures.

    This module converts an input image into a sequence of patch embeddings using a convolutional layer.
    It is commonly used as the first layer in vision transformer architectures to transform image data
    into a suitable format for subsequent transformer blocks.

    Attributes:
        proj (nn.Conv2d): Convolutional layer for projecting image patches to embeddings.

    Methods:
        forward: Applies patch embedding to the input tensor.

    Examples:
        >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)
        >>> x = torch.randn(1, 3, 224, 224)
        >>> output = patch_embed(x)
        >>> print(output.shape)
        torch.Size([1, 14, 14, 768])
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Initializes the PatchEmbed module for converting image patches to embeddings.

        This module is typically used as the first layer in vision transformer architectures to transform
        image data into a suitable format for subsequent transformer blocks.

        Args:
            kernel_size (Tuple[int, int]): Size of the convolutional kernel for patch extraction.
            stride (Tuple[int, int]): Stride of the convolutional operation.
            padding (Tuple[int, int]): Padding applied to the input before convolution.
            in_chans (int): Number of input image channels.
            embed_dim (int): Dimensionality of the output patch embeddings.

        Examples:
            >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)
            >>> x = torch.randn(1, 3, 224, 224)
            >>> output = patch_embed(x)
            >>> print(output.shape)
            torch.Size([1, 14, 14, 768])
        """
        super().__init__()

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Computes patch embedding by applying convolution and transposing resulting tensor."""
        return self.proj(x).permute(0, 2, 3, 1)  # B C H W -> B H W C
diff --git a/tracking/ultralytics/models/sam/modules/decoders.py b/tracking/ultralytics/models/sam/modules/decoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dca2b5a0a6d40e3b20162a2f245a77ff736c3ea
--- /dev/null
+++ b/tracking/ultralytics/models/sam/modules/decoders.py
@@ -0,0 +1,515 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from typing import List, Optional, Tuple, Type

import torch
from torch import nn

from ultralytics.nn.modules import MLP, LayerNorm2d


class MaskDecoder(nn.Module):
    """
    Decoder module for generating masks and their associated quality scores using a transformer architecture.

    This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and
    generate mask predictions along with their quality scores.

    Attributes:
        transformer_dim (int): Channel dimension for the transformer module.
        transformer (nn.Module): Transformer module used for mask prediction.
        num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
        iou_token (nn.Embedding): Embedding for the IoU token.
        num_mask_tokens (int): Number of mask tokens.
        mask_tokens (nn.Embedding): Embedding for the mask tokens.
        output_upscaling (nn.Sequential): Neural network sequence for upscaling the output.
+ output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks. + iou_prediction_head (nn.Module): MLP for predicting mask quality. + + Methods: + forward: Predicts masks given image and prompt embeddings. + predict_masks: Internal method for mask prediction. + + Examples: + >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module) + >>> masks, iou_pred = decoder( + ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output=True + ... ) + >>> print(f"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}") + """ + + def __init__( + self, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Initialize the MaskDecoder module for generating masks and their associated quality scores. + + Args: + transformer_dim (int): Channel dimension for the transformer module. + transformer (nn.Module): Transformer module used for mask prediction. + num_multimask_outputs (int): Number of masks to predict for disambiguating masks. + activation (Type[nn.Module]): Type of activation to use when upscaling masks. + iou_head_depth (int): Depth of the MLP used to predict mask quality. + iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality. 
+ + Examples: + >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6) + >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer) + >>> print(decoder) + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)] + ) + + self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Args: + image_embeddings (torch.Tensor): Embeddings from the image encoder. + image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings. + sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes. + dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs. + multimask_output (bool): Whether to return multiple masks or a single mask. + + Returns: + masks (torch.Tensor): Batched predicted masks. + iou_pred (torch.Tensor): Batched predictions of mask quality. 
+ + Examples: + >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module) + >>> image_emb = torch.rand(1, 256, 64, 64) + >>> image_pe = torch.rand(1, 256, 64, 64) + >>> sparse_emb = torch.rand(1, 2, 256) + >>> dense_emb = torch.rand(1, 256, 64, 64) + >>> masks, iou_pred = decoder(image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True) + >>> print(f"Masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}") + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + mask_slice = slice(1, None) if multimask_output else slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predict masks and quality scores using image and prompt embeddings via transformer architecture.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the 
mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [ + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens) + ] + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +class SAM2MaskDecoder(nn.Module): + """ + Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings. + + This class extends the functionality of the MaskDecoder, incorporating additional features such as + high-resolution feature processing, dynamic multimask output, and object score prediction. + + Attributes: + transformer_dim (int): Channel dimension of the transformer. + transformer (nn.Module): Transformer used to predict masks. + num_multimask_outputs (int): Number of masks to predict when disambiguating masks. + iou_token (nn.Embedding): Embedding for IOU token. + num_mask_tokens (int): Total number of mask tokens. + mask_tokens (nn.Embedding): Embedding for mask tokens. + pred_obj_scores (bool): Whether to predict object scores. + obj_score_token (nn.Embedding): Embedding for object score token. + use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer. + output_upscaling (nn.Sequential): Upscaling layers for output. + use_high_res_features (bool): Whether to use high-resolution features. + conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0). + conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1). + output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks. + iou_prediction_head (MLP): MLP for IOU prediction. 
+ pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction. + dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability. + dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability. + dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability. + + Methods: + forward: Predicts masks given image and prompt embeddings. + predict_masks: Predicts instance segmentation masks from image and prompt embeddings. + _get_stability_scores: Computes mask stability scores based on IoU between thresholds. + _dynamic_multimask_via_stability: Dynamically selects the most stable mask output. + + Examples: + >>> image_embeddings = torch.rand(1, 256, 64, 64) + >>> image_pe = torch.rand(1, 256, 64, 64) + >>> sparse_prompt_embeddings = torch.rand(1, 2, 256) + >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64) + >>> decoder = SAM2MaskDecoder(256, transformer) + >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward( + ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False + ... ) + """ + + def __init__( + self, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid=False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + ) -> None: + """ + Initialize the SAM2MaskDecoder module for predicting instance segmentation masks. + + This decoder extends the functionality of MaskDecoder, incorporating additional features such as + high-resolution feature processing, dynamic multimask output, and object score prediction. 
+ + Args: + transformer_dim (int): Channel dimension of the transformer. + transformer (nn.Module): Transformer used to predict masks. + num_multimask_outputs (int): Number of masks to predict when disambiguating masks. + activation (Type[nn.Module]): Type of activation to use when upscaling masks. + iou_head_depth (int): Depth of the MLP used to predict mask quality. + iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality. + use_high_res_features (bool): Whether to use high-resolution features. + iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction. + dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability. + dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability. + dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability. + pred_obj_scores (bool): Whether to predict object scores. + pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction. + use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer. 
+ + Examples: + >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6) + >>> decoder = SAM2MaskDecoder(transformer_dim=256, transformer=transformer) + >>> print(decoder) + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.pred_obj_scores = pred_obj_scores + if self.pred_obj_scores: + self.obj_score_token = nn.Embedding(1, transformer_dim) + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1) + self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1) + + self.output_hypernetworks_mlps = nn.ModuleList( + [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)] + ) + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + self.num_mask_tokens, + iou_head_depth, + sigmoid=iou_prediction_use_sigmoid, + ) + if self.pred_obj_scores: + self.pred_obj_score_head = nn.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. 
+ self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Args: + image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W). + image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings (B, C, H, W). + sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes with shape (B, N, C). + dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs with shape (B, C, H, W). + multimask_output (bool): Whether to return multiple masks or a single mask. + repeat_image (bool): Flag to repeat the image embeddings. + high_res_features (List[torch.Tensor] | None): Optional high-resolution features. + + Returns: + masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W). + iou_pred (torch.Tensor): Batched predictions of mask quality with shape (B, N). + sam_tokens_out (torch.Tensor): Batched SAM token for mask output with shape (B, N, C). + object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1). + + Examples: + >>> image_embeddings = torch.rand(1, 256, 64, 64) + >>> image_pe = torch.rand(1, 256, 64, 64) + >>> sparse_prompt_embeddings = torch.rand(1, 2, 256) + >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64) + >>> decoder = SAM2MaskDecoder(256, transformer) + >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward( + ... 
image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False + ... ) + """ + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features=high_res_features, + ) + + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, 0:1, :, :] + iou_pred = iou_pred[:, 0:1] + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). 
+ sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Predict instance segmentation masks from image and prompt embeddings using a transformer.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + s = 1 + else: + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + assert image_embeddings.shape[0] == tokens.shape[0] + src = image_embeddings + src = src + dense_prompt_embeddings + assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + if not self.use_high_res_features: + upscaled_embedding = self.output_upscaling(src) + else: + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + 
            # Inject the high-resolution skip features into the two upscaling stages.
            # assumes feat_s1 matches the first (coarser) upscale and feat_s0 the second — TODO confirm
            upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
            upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)

        hyper_in_list: List[torch.Tensor] = [
            self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
        ]
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        # Each mask token's hypernetwork output acts as per-mask weights applied to the upscaled embedding.
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)
        if self.pred_obj_scores:
            assert s == 1  # object-score token occupies slot 0 when pred_obj_scores is enabled
            object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
        else:
            # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
            object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)

        return masks, iou_pred, mask_tokens_out, object_score_logits

    def _get_stability_scores(self, mask_logits: torch.Tensor) -> torch.Tensor:
        """Compute mask stability scores based on IoU between upper and lower thresholds."""
        # Stability = IoU between the binary masks obtained by thresholding the logits at
        # +stability_delta (intersection) and -stability_delta (union).
        mask_logits = mask_logits.flatten(-2)
        stability_delta = self.dynamic_multimask_stability_delta
        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
        # An empty union (no pixels above -delta) is treated as perfectly stable (score 1.0).
        return torch.where(area_u > 0, area_i / area_u, 1.0)

    def _dynamic_multimask_via_stability(self, all_mask_logits: torch.Tensor, all_iou_scores: torch.Tensor):
        """
        Dynamically select the most stable mask output based on stability scores and IoU predictions.

        This method is used when outputting a single mask. If the stability score from the current single-mask
        output (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs
        (based on output tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask
        for both clicking and tracking scenarios.

        Args:
            all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is
                batch size, N is number of masks (typically 4), and H, W are mask dimensions.
            all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).

        Returns:
            mask_logits_out (torch.Tensor): Selected mask logits, shape (B, 1, H, W).
            iou_scores_out (torch.Tensor): Selected IoU scores, shape (B, 1).

        Examples:
            >>> decoder = SAM2MaskDecoder(...)
            >>> all_mask_logits = torch.rand(2, 4, 256, 256)  # 2 images, 4 masks each
            >>> all_iou_scores = torch.rand(2, 4)
            >>> mask_logits, iou_scores = decoder._dynamic_multimask_via_stability(all_mask_logits, all_iou_scores)
            >>> print(mask_logits.shape, iou_scores.shape)
            torch.Size([2, 1, 256, 256]) torch.Size([2, 1])
        """
        # The best mask from multimask output tokens (1~3)
        multimask_logits = all_mask_logits[:, 1:, :, :]
        multimask_iou_scores = all_iou_scores[:, 1:]
        best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
        # Advanced indexing with (batch_inds, best_scores_inds) picks one mask per batch element.
        batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
        best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
        best_multimask_logits = best_multimask_logits.unsqueeze(1)
        best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
        best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)

        # The mask from singlemask output token 0 and its stability score
        singlemask_logits = all_mask_logits[:, 0:1, :, :]
        singlemask_iou_scores = all_iou_scores[:, 0:1]
        stability_scores = self._get_stability_scores(singlemask_logits)
        is_stable = stability_scores >= self.dynamic_multimask_stability_thresh

        # Dynamically fall back to best multimask output upon low stability scores.
        mask_logits_out = torch.where(
            is_stable[..., None, None].expand_as(singlemask_logits),
            singlemask_logits,
            best_multimask_logits,
        )
        iou_scores_out = torch.where(
            is_stable.expand_as(singlemask_iou_scores),
            singlemask_iou_scores,
            best_multimask_iou_scores,
        )
        return mask_logits_out, iou_scores_out
diff --git a/tracking/ultralytics/models/sam/modules/encoders.py b/tracking/ultralytics/models/sam/modules/encoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4e2b687cdc7100db0c0fefde9875283e575d985
--- /dev/null
+++ b/tracking/ultralytics/models/sam/modules/encoders.py
@@ -0,0 +1,773 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from typing import List, Optional, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.nn.modules import LayerNorm2d

from .blocks import (
    Block,
    CXBlock,
    Fuser,
    MaskDownSampler,
    MultiScaleBlock,
    PatchEmbed,
    PositionEmbeddingRandom,
    PositionEmbeddingSine,
)


class ImageEncoderViT(nn.Module):
    """
    An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.

    This class processes images by splitting them into patches, applying transformer blocks, and generating a final
    encoded representation through a neck module.

    Attributes:
        img_size (int): Dimension of input images, assumed to be square.
        patch_embed (PatchEmbed): Module for patch embedding.
        pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
        blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.
        neck (nn.Sequential): Neck module to further process the output.

    Methods:
        forward: Processes input through patch embedding, positional embedding, blocks, and neck.
+ + Examples: + >>> import torch + >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12) + >>> input_image = torch.randn(1, 3, 224, 224) + >>> output = encoder(input_image) + >>> print(output.shape) + """ + + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Initialize an ImageEncoderViT instance for encoding images using Vision Transformer architecture. + + Args: + img_size (int): Input image size, assumed to be square. + patch_size (int): Size of image patches. + in_chans (int): Number of input image channels. + embed_dim (int): Dimension of patch embeddings. + depth (int): Number of transformer blocks. + num_heads (int): Number of attention heads in each block. + mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. + out_chans (int): Number of output channels from the neck module. + qkv_bias (bool): If True, adds learnable bias to query, key, value projections. + norm_layer (Type[nn.Module]): Type of normalization layer to use. + act_layer (Type[nn.Module]): Type of activation layer to use. + use_abs_pos (bool): If True, uses absolute positional embeddings. + use_rel_pos (bool): If True, adds relative positional embeddings to attention maps. + rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero. + window_size (int): Size of attention window for windowed attention blocks. + global_attn_indexes (Tuple[int, ...]): Indices of blocks that use global attention. 
+ + Examples: + >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12) + >>> input_image = torch.randn(1, 3, 224, 224) + >>> output = encoder(input_image) + >>> print(output.shape) + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Process input through patch embedding, positional embedding, transformer blocks, and neck module.""" + x = self.patch_embed(x) + if self.pos_embed is not None: + pos_embed = ( + F.interpolate(self.pos_embed.permute(0, 3, 1, 2), scale_factor=self.img_size / 1024).permute(0, 2, 3, 1) + if self.img_size != 1024 + else self.pos_embed + ) + x = x + pos_embed + for blk in self.blocks: + x = blk(x) + return self.neck(x.permute(0, 3, 1, 2)) + + +class PromptEncoder(nn.Module): + """ + Encodes different types of prompts for input to SAM's mask decoder, producing sparse and 
dense embeddings. + + Attributes: + embed_dim (int): Dimension of the embeddings. + input_image_size (Tuple[int, int]): Size of the input image as (H, W). + image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W). + pe_layer (PositionEmbeddingRandom): Module for random position embedding. + num_point_embeddings (int): Number of point embeddings for different types of points. + point_embeddings (nn.ModuleList): List of point embeddings. + not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label. + mask_input_size (Tuple[int, int]): Size of the input mask. + mask_downscaling (nn.Sequential): Neural network for downscaling the mask. + no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided. + + Methods: + get_dense_pe: Returns the positional encoding used to encode point prompts. + forward: Embeds different types of prompts, returning both sparse and dense embeddings. + + Examples: + >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16) + >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5))) + >>> boxes = torch.rand(1, 2, 2) + >>> masks = torch.rand(1, 1, 256, 256) + >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks) + >>> print(sparse_embeddings.shape, dense_embeddings.shape) + torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64]) + """ + + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Initialize the PromptEncoder module for encoding various types of prompts. + + Args: + embed_dim (int): The dimension of the embeddings. + image_embedding_size (Tuple[int, int]): The spatial size of the image embedding as (H, W). + input_image_size (Tuple[int, int]): The padded size of the input image as (H, W). + mask_in_chans (int): The number of hidden channels used for encoding input masks. 
+ activation (Type[nn.Module]): The activation function to use when encoding input masks. + + Examples: + >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16) + >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5))) + >>> boxes = torch.rand(1, 2, 2) + >>> masks = torch.rand(1, 1, 256, 256) + >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks) + >>> print(sparse_embeddings.shape, dense_embeddings.shape) + torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64]) + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Return the dense positional encoding used for encoding point prompts. + + Generate a positional encoding for a dense set of points matching the shape of the image + encoding. The encoding is used to provide spatial information to the model when processing point prompts. + + Returns: + (torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the + height and width of the image embedding size, respectively. 
+ + Examples: + >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16) + >>> dense_pe = prompt_encoder.get_dense_pe() + >>> print(dense_pe.shape) + torch.Size([1, 256, 64, 64]) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor: + """Embed point prompts by applying positional encoding and label-specific embeddings.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + point_embedding[labels == 2] += self.point_embeddings[2].weight + point_embedding[labels == 3] += self.point_embeddings[3].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embed box prompts by applying positional encoding and adding corner embeddings.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embed mask inputs by downscaling and processing through convolutional layers.""" + return self.mask_downscaling(masks) + + @staticmethod + def _get_batch_size( + points: Optional[Tuple[torch.Tensor, 
torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """Get the batch size of the output given the batch size of the input prompts.""" + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + """Return the device of the first point embedding's weight tensor.""" + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embed different types of prompts, returning both sparse and dense embeddings. + + Args: + points (Tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first + tensor contains coordinates with shape (B, N, 2), and the second tensor contains labels with + shape (B, N). + boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes. + masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W). + + Returns: + (Tuple[torch.Tensor, torch.Tensor]): A tuple containing: + - sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim). + - dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W). 
+
+        Examples:
+            >>> encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+            >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
+            >>> boxes = torch.rand(1, 1, 2, 2)
+            >>> masks = torch.rand(1, 1, 256, 256)
+            >>> sparse_emb, dense_emb = encoder(points, boxes, masks)
+            >>> print(sparse_emb.shape, dense_emb.shape)
+            torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
+        """
+        bs = self._get_batch_size(points, boxes, masks)
+        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
+        if points is not None:
+            coords, labels = points
+            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
+            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+        if boxes is not None:
+            box_embeddings = self._embed_boxes(boxes)
+            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
+
+        if masks is not None:
+            dense_embeddings = self._embed_masks(masks)
+        else:
+            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+            )
+
+        return sparse_embeddings, dense_embeddings
+
+
+class MemoryEncoder(nn.Module):
+    """
+    Encode pixel features and masks into a memory representation for efficient image segmentation.
+
+    This class processes pixel-level features and masks, fusing them to generate encoded memory representations
+    suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
+
+    Attributes:
+        mask_downsampler (MaskDownSampler): Module for downsampling input masks.
+        pix_feat_proj (nn.Conv2d): Convolutional layer for projecting pixel features.
+        fuser (Fuser): Module for fusing pixel features and masks.
+        position_encoding (PositionEmbeddingSine): Module for adding positional encoding to features.
+        out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d.
+
+    Methods:
+        forward: Process input pixel features and masks to generate encoded memory representations.
+
+    Examples:
+        >>> import torch
+        >>> encoder = MemoryEncoder(out_dim=256, in_dim=256)
+        >>> pix_feat = torch.randn(1, 256, 64, 64)
+        >>> masks = torch.randn(1, 1, 64, 64)
+        >>> out = encoder(pix_feat, masks)
+        >>> print(out["vision_features"].shape, out["vision_pos_enc"][0].shape)
+        torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])
+    """
+
+    def __init__(
+        self,
+        out_dim,
+        in_dim=256,  # in_dim of pix_feats
+    ):
+        """Initialize the MemoryEncoder for encoding pixel features and masks into memory representations."""
+        super().__init__()
+
+        self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
+
+        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
+        self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
+        self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)
+        self.out_proj = nn.Identity()
+        if out_dim != in_dim:
+            self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
+
+    def forward(
+        self,
+        pix_feat: torch.Tensor,
+        masks: torch.Tensor,
+        skip_mask_sigmoid: bool = False,
+    ) -> dict:
+        """Process pixel features and masks to generate encoded memory representations for segmentation."""
+        if not skip_mask_sigmoid:
+            masks = F.sigmoid(masks)
+        masks = self.mask_downsampler(masks)
+
+        # Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA
+        pix_feat = pix_feat.to(masks.device)
+
+        x = self.pix_feat_proj(pix_feat)
+        x = x + masks
+        x = self.fuser(x)
+        x = self.out_proj(x)
+
+        pos = self.position_encoding(x).to(x.dtype)
+
+        return {"vision_features": x, "vision_pos_enc": [pos]}
+
+
+class ImageEncoder(nn.Module):
+    """
+    Encode images using a trunk-neck architecture, producing multiscale features and positional encodings.
+
+    This class combines a trunk network for feature extraction with a neck network for feature refinement
+    and positional encoding generation. It can optionally discard the lowest resolution features.
+
+    Attributes:
+        trunk (nn.Module): The trunk network for initial feature extraction.
+        neck (nn.Module): The neck network for feature refinement and positional encoding generation.
+        scalp (int): Number of lowest resolution feature levels to discard.
+
+    Methods:
+        forward: Process the input image through the trunk and neck networks.
+
+    Examples:
+        >>> trunk = SomeTrunkNetwork()
+        >>> neck = SomeNeckNetwork()
+        >>> encoder = ImageEncoder(trunk, neck, scalp=1)
+        >>> image = torch.randn(1, 3, 224, 224)
+        >>> output = encoder(image)
+        >>> print(output.keys())
+        dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
+    """
+
+    def __init__(
+        self,
+        trunk: nn.Module,
+        neck: nn.Module,
+        scalp: int = 0,
+    ):
+        """Initialize the ImageEncoder with trunk and neck networks for feature extraction and refinement."""
+        super().__init__()
+        self.trunk = trunk
+        self.neck = neck
+        self.scalp = scalp
+        assert self.trunk.channel_list == self.neck.backbone_channel_list, (
+            f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
+        )
+
+    def forward(self, sample: torch.Tensor):
+        """Encode the input image through the trunk and neck networks, returning multiscale features and encodings."""
+        features, pos = self.neck(self.trunk(sample))
+        if self.scalp > 0:
+            # Discard the lowest resolution features
+            features, pos = features[: -self.scalp], pos[: -self.scalp]
+
+        src = features[-1]
+        return {
+            "vision_features": src,
+            "vision_pos_enc": pos,
+            "backbone_fpn": features,
+        }
+
+
+class FpnNeck(nn.Module):
+    """
+    A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.
+ + This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing, + similar to ViT positional embedding interpolation. + + Attributes: + position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module. + convs (nn.ModuleList): List of convolutional layers for each backbone level. + backbone_channel_list (List[int]): List of channel dimensions from the backbone. + fpn_interp_model (str): Interpolation mode for FPN feature resizing. + fuse_type (str): Type of feature fusion, either 'sum' or 'avg'. + fpn_top_down_levels (List[int]): Levels to have top-down features in outputs. + + Methods: + forward: Perform forward pass through the FPN neck. + + Examples: + >>> backbone_channels = [64, 128, 256, 512] + >>> fpn_neck = FpnNeck(256, backbone_channels) + >>> inputs = [torch.rand(1, c, 32, 32) for c in backbone_channels] + >>> outputs, positions = fpn_neck(inputs) + >>> print(len(outputs), len(positions)) + 4 4 + """ + + def __init__( + self, + d_model: int, + backbone_channel_list: List[int], + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + fpn_interp_model: str = "bilinear", + fuse_type: str = "sum", + fpn_top_down_levels: Optional[List[int]] = None, + ): + """ + Initializes a modified Feature Pyramid Network (FPN) neck. + + This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing, + similar to ViT positional embedding interpolation. + + Args: + d_model (int): Dimension of the model. + backbone_channel_list (List[int]): List of channel dimensions from the backbone. + kernel_size (int): Kernel size for the convolutional layers. + stride (int): Stride for the convolutional layers. + padding (int): Padding for the convolutional layers. + fpn_interp_model (str): Interpolation mode for FPN feature resizing. + fuse_type (str): Type of feature fusion, either 'sum' or 'avg'. + fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs. 
+ + Examples: + >>> backbone_channels = [64, 128, 256, 512] + >>> fpn_neck = FpnNeck(256, backbone_channels) + >>> print(fpn_neck) + """ + super().__init__() + self.position_encoding = PositionEmbeddingSine(num_pos_feats=256) + self.convs = nn.ModuleList() + self.backbone_channel_list = backbone_channel_list + for dim in backbone_channel_list: + current = nn.Sequential() + current.add_module( + "conv", + nn.Conv2d( + in_channels=dim, + out_channels=d_model, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + ) + + self.convs.append(current) + self.fpn_interp_model = fpn_interp_model + assert fuse_type in {"sum", "avg"} + self.fuse_type = fuse_type + + # levels to have top-down features in its outputs + # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 + # have top-down propagation, while outputs of level 0 and level 1 have only + # lateral features from the same backbone level. + if fpn_top_down_levels is None: + # default is to have top-down features on all levels + fpn_top_down_levels = range(len(self.convs)) + self.fpn_top_down_levels = list(fpn_top_down_levels) + + def forward(self, xs: List[torch.Tensor]): + """ + Performs forward pass through the Feature Pyramid Network (FPN) neck. + + This method processes a list of input tensors from the backbone through the FPN, applying lateral connections + and top-down feature fusion. It generates output feature maps and corresponding positional encodings. + + Args: + xs (List[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W). + + Returns: + (Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing: + - out (List[torch.Tensor]): List of output feature maps after FPN processing, each with shape + (B, d_model, H, W). + - pos (List[torch.Tensor]): List of positional encodings corresponding to each output feature map. 
+ + Examples: + >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512]) + >>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]] + >>> outputs, positions = fpn_neck(inputs) + >>> print(len(outputs), len(positions)) + 4 4 + """ + out = [None] * len(self.convs) + pos = [None] * len(self.convs) + assert len(xs) == len(self.convs) + # fpn forward pass + # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py + prev_features = None + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + x = xs[i] + lateral_features = self.convs[n - i](x) + if i in self.fpn_top_down_levels and prev_features is not None: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode=self.fpn_interp_model, + align_corners=(None if self.fpn_interp_model == "nearest" else False), + antialias=False, + ) + prev_features = lateral_features + top_down_features + if self.fuse_type == "avg": + prev_features /= 2 + else: + prev_features = lateral_features + x_out = prev_features + out[i] = x_out + pos[i] = self.position_encoding(x_out).to(x_out.dtype) + + return out, pos + + +class Hiera(nn.Module): + """ + Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks. + + This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for + efficient multiscale feature extraction. It uses a series of transformer blocks organized into stages, + with optional pooling and global attention mechanisms. + + Attributes: + window_spec (Tuple[int, ...]): Window sizes for each stage. + q_stride (Tuple[int, int]): Downsampling stride between stages. + stage_ends (List[int]): Indices of the last block in each stage. + q_pool_blocks (List[int]): Indices of blocks where pooling is applied. 
+ return_interm_layers (bool): Whether to return intermediate layer outputs. + patch_embed (PatchEmbed): Module for patch embedding. + global_att_blocks (Tuple[int, ...]): Indices of blocks with global attention. + window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background. + pos_embed (nn.Parameter): Positional embedding for the background. + pos_embed_window (nn.Parameter): Positional embedding for the window. + blocks (nn.ModuleList): List of MultiScaleBlock modules. + channel_list (List[int]): List of output channel dimensions for each stage. + + Methods: + _get_pos_embed: Generate positional embeddings by interpolating and combining window and background embeddings. + forward: Perform the forward pass through the Hiera model. + + Examples: + >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3)) + >>> input_tensor = torch.randn(1, 3, 224, 224) + >>> output_features = model(input_tensor) + >>> for feat in output_features: + ... print(feat.shape) + """ + + def __init__( + self, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + drop_path_rate: float = 0.0, # stochastic depth + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages + stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage + dim_mul: float = 2.0, # dim_mul factor at stage shift + head_mul: float = 2.0, # head_mul factor at stage shift + window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), + # window size per stage, when not using global att. + window_spec: Tuple[int, ...] = ( + 8, + 4, + 14, + 7, + ), + # global attn in these blocks + global_att_blocks: Tuple[int, ...] 
= ( + 12, + 16, + 20, + ), + return_interm_layers=True, # return feats from every stage + ): + """Initialize the Hiera model, configuring its hierarchical vision transformer architecture.""" + super().__init__() + + assert len(stages) == len(window_spec) + self.window_spec = window_spec + + depth = sum(stages) + self.q_stride = q_stride + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + assert 0 <= q_pool <= len(self.stage_ends[:-1]) + self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] + self.return_interm_layers = return_interm_layers + + self.patch_embed = PatchEmbed( + embed_dim=embed_dim, + kernel_size=(7, 7), + stride=(4, 4), + padding=(3, 3), + ) + # Which blocks have global att? + self.global_att_blocks = global_att_blocks + + # Windowed positional embedding (https://arxiv.org/abs/2311.05613) + self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size + self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)) + self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + cur_stage = 1 + self.blocks = nn.ModuleList() + + for i in range(depth): + dim_out = embed_dim + # lags by a block, so first block of + # next stage uses an initial window size + # of previous stage and final window size of current stage + window_size = self.window_spec[cur_stage - 1] + + if self.global_att_blocks is not None: + window_size = 0 if i in self.global_att_blocks else window_size + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + drop_path=dpr[i], + q_stride=self.q_stride if i in self.q_pool_blocks else None, + window_size=window_size, + ) + + embed_dim = dim_out + 
self.blocks.append(block) + + self.channel_list = ( + [self.blocks[i].dim_out for i in self.stage_ends[::-1]] + if return_interm_layers + else [self.blocks[-1].dim_out] + ) + + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: + """Generate positional embeddings by interpolating and combining window and background embeddings.""" + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)]) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + """Perform forward pass through Hiera model, extracting multiscale features from input images.""" + x = self.patch_embed(x) + # x: (B, H, W, C) + + # Add pos embed + x = x + self._get_pos_embed(x.shape[1:3]) + + outputs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers): + feats = x.permute(0, 3, 1, 2) + outputs.append(feats) + + return outputs diff --git a/tracking/ultralytics/models/sam/modules/memory_attention.py b/tracking/ultralytics/models/sam/modules/memory_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3759cdef54715a16aefc6cc4018fb7279e39949d --- /dev/null +++ b/tracking/ultralytics/models/sam/modules/memory_attention.py @@ -0,0 +1,241 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import copy +from typing import Optional + +import torch +from torch import Tensor, nn + +from .blocks import RoPEAttention + + +class MemoryAttentionLayer(nn.Module): + """ + Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks. + + This class combines self-attention, cross-attention, and feedforward components to process input tensors and + generate memory-based attention outputs. 
class MemoryAttentionLayer(nn.Module):
    """
    A transformer-style layer that fuses a target sequence with memory features.

    The layer applies RoPE-based self-attention on the target, RoPE-based cross-attention from the
    target onto memory features, and a two-layer feedforward network; each stage uses pre-layer-norm,
    dropout, and a residual connection. Positional encodings can optionally be injected at
    self-attention and at the cross-attention queries/keys via the ``pos_enc_at_*`` flags.

    Attributes:
        d_model (int): Dimensionality of the model.
        dim_feedforward (int): Hidden width of the feedforward network.
        dropout_value (float): Dropout rate for regularization.
        self_attn (RoPEAttention): Self-attention over the target sequence.
        cross_attn_image (RoPEAttention): Cross-attention from the target onto memory features.
        linear1 (nn.Linear): First linear layer of the feedforward network.
        linear2 (nn.Linear): Second linear layer of the feedforward network.
        norm1 (nn.LayerNorm): Pre-norm for self-attention.
        norm2 (nn.LayerNorm): Pre-norm for cross-attention.
        norm3 (nn.LayerNorm): Pre-norm for the feedforward network.
        dropout1 (nn.Dropout): Dropout on the self-attention residual branch.
        dropout2 (nn.Dropout): Dropout on the cross-attention residual branch.
        dropout3 (nn.Dropout): Dropout on the feedforward residual branch.
        activation (nn.ReLU): Activation of the feedforward network.
        pos_enc_at_attn (bool): Add positional encoding to self-attention queries/keys.
        pos_enc_at_cross_attn_queries (bool): Add positional encoding to cross-attention queries.
        pos_enc_at_cross_attn_keys (bool): Add positional encoding to cross-attention keys.

    Examples:
        >>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)
        >>> tgt = torch.randn(1, 100, 256)
        >>> memory = torch.randn(1, 100, 64)
        >>> pos = torch.randn(1, 100, 256)
        >>> query_pos = torch.randn(1, 100, 256)
        >>> output = layer(tgt, memory, pos, query_pos)
        >>> print(output.shape)
        torch.Size([1, 100, 256])
    """

    def __init__(
        self,
        d_model: int = 256,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        pos_enc_at_attn: bool = False,
        pos_enc_at_cross_attn_keys: bool = True,
        pos_enc_at_cross_attn_queries: bool = False,
    ):
        """Initialize a memory attention layer with self-attention, cross-attention, and feedforward components."""
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        # Use d_model for the attention embedding width instead of a hard-coded 256 so the layer
        # stays internally consistent for non-default d_model (identical behavior at the default).
        self.self_attn = RoPEAttention(embedding_dim=d_model, num_heads=1, downsample_rate=1)
        self.cross_attn_image = RoPEAttention(
            rope_k_repeat=True,
            embedding_dim=d_model,
            num_heads=1,
            downsample_rate=1,
            kv_in_dim=64,  # memory features arrive at 64 channels
        )

        # Implementation of the feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = nn.ReLU()

        # Where to add pos enc
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

    def _forward_sa(self, tgt: Tensor, query_pos: Optional[Tensor]) -> Tensor:
        """Perform self-attention on input tensor using positional encoding and RoPE attention mechanism."""
        tgt2 = self.norm1(tgt)
        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
        tgt2 = self.self_attn(q, k, v=tgt2)
        tgt = tgt + self.dropout1(tgt2)
        return tgt

    def _forward_ca(
        self,
        tgt: Tensor,
        memory: Tensor,
        query_pos: Optional[Tensor],
        pos: Optional[Tensor],
        num_k_exclude_rope: int = 0,
    ) -> Tensor:
        """Perform cross-attention between target and memory tensors using RoPEAttention mechanism."""
        kwds = {}
        if num_k_exclude_rope > 0:
            # num_k_exclude_rope is only meaningful for RoPE attention (excludes trailing keys,
            # e.g. object-pointer tokens, from rotary encoding).
            assert isinstance(self.cross_attn_image, RoPEAttention)
            kwds = {"num_k_exclude_rope": num_k_exclude_rope}

        # Cross-Attention
        tgt2 = self.norm2(tgt)
        tgt2 = self.cross_attn_image(
            q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
            k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            v=memory,
            **kwds,
        )
        tgt = tgt + self.dropout2(tgt2)
        return tgt

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        num_k_exclude_rope: int = 0,
    ) -> torch.Tensor:
        """Process input tensors through self-attention, cross-attention, and feedforward network layers."""
        tgt = self._forward_sa(tgt, query_pos)
        tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
        # MLP
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt
class MemoryAttention(nn.Module):
    """
    Memory attention module for processing sequential data with self and cross-attention mechanisms.

    Stacks ``num_layers`` deep copies of a given ``MemoryAttentionLayer`` and applies a final layer
    norm, optionally adding (a scaled-down copy of) the positional encoding to the input first.

    Attributes:
        d_model (int): The dimension of the model's hidden state.
        layers (nn.ModuleList): The stacked MemoryAttentionLayer modules.
        num_layers (int): The number of attention layers.
        norm (nn.LayerNorm): Layer normalization applied to the output.
        pos_enc_at_input (bool): Whether to add positional encoding at the input.
        batch_first (bool): Whether the layers expect batch-first tensors.

    Examples:
        >>> d_model = 256
        >>> layer = MemoryAttentionLayer(d_model)
        >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
        >>> curr = torch.randn(10, 32, d_model)  # (seq_len, batch_size, d_model)
        >>> memory = torch.randn(20, 32, d_model)  # (mem_len, batch_size, d_model)
        >>> curr_pos = torch.randn(10, 32, d_model)
        >>> memory_pos = torch.randn(20, 32, d_model)
        >>> output = attention(curr, memory, curr_pos, memory_pos)
        >>> print(output.shape)
        torch.Size([10, 32, 256])
    """

    def __init__(
        self,
        d_model: int,
        pos_enc_at_input: bool,
        layer: nn.Module,
        num_layers: int,
        batch_first: bool = True,  # Do layers expect batch first input?
    ):
        """Initialize MemoryAttention with specified layers and normalization for sequential data processing."""
        super().__init__()
        self.d_model = d_model
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc_at_input = pos_enc_at_input
        self.batch_first = batch_first

    def forward(
        self,
        curr: torch.Tensor,  # self-attention inputs
        memory: torch.Tensor,  # cross-attention inputs
        curr_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
        memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
    ) -> torch.Tensor:
        """Process inputs through attention layers, applying self and cross-attention with positional encoding."""
        if isinstance(curr, list):
            assert isinstance(curr_pos, list)
            assert len(curr) == len(curr_pos) == 1
            curr, curr_pos = curr[0], curr_pos[0]

        assert curr.shape[1] == memory.shape[1], "Batch size must be the same for curr and memory"

        output = curr
        if self.pos_enc_at_input and curr_pos is not None:
            output = output + 0.1 * curr_pos

        if self.batch_first:
            # Convert to batch first. The positional encodings are Optional and may be None,
            # in which case there is nothing to transpose (previously this crashed on None).
            output = output.transpose(0, 1)
            memory = memory.transpose(0, 1)
            if curr_pos is not None:
                curr_pos = curr_pos.transpose(0, 1)
            if memory_pos is not None:
                memory_pos = memory_pos.transpose(0, 1)

        for layer in self.layers:
            kwds = {}
            if isinstance(layer.cross_attn_image, RoPEAttention):
                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}

            output = layer(
                tgt=output,
                memory=memory,
                pos=memory_pos,
                query_pos=curr_pos,
                **kwds,
            )
        normed_output = self.norm(output)

        if self.batch_first:
            # Convert back to seq first
            normed_output = normed_output.transpose(0, 1)
            if curr_pos is not None:
                curr_pos = curr_pos.transpose(0, 1)

        return normed_output
class SAMModel(nn.Module):
    """
    Segment Anything Model (SAM) for promptable object segmentation.

    Wires together an image encoder, a prompt encoder, and a mask decoder, and holds the
    pixel-normalization constants as (non-persistent) buffers.

    Attributes:
        mask_threshold (float): Threshold applied to predicted mask logits.
        image_encoder (ImageEncoderViT): Backbone that embeds input images.
        prompt_encoder (PromptEncoder): Encoder for the various input prompt types.
        mask_decoder (MaskDecoder): Predicts masks from image and prompt embeddings.

    Examples:
        >>> image_encoder = ImageEncoderViT(...)
        >>> prompt_encoder = PromptEncoder(...)
        >>> mask_decoder = MaskDecoder(...)
        >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
        >>> # Further usage depends on SAMPredictor class

    Notes:
        All forward() operations are implemented in the SAMPredictor class.
    """

    mask_threshold: float = 0.0

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = (123.675, 116.28, 103.53),
        pixel_std: List[float] = (58.395, 57.12, 57.375),
    ) -> None:
        """
        Assemble the SAM sub-modules and register pixel-normalization buffers.

        Args:
            image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
            prompt_encoder (PromptEncoder): Encodes various types of input prompts.
            mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
            pixel_mean (List[float]): Per-channel mean for normalizing pixels in the input image.
            pixel_std (List[float]): Per-channel std for normalizing pixels in the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        # persistent=False: the buffers move with .to(device) but are not saved in state_dict.
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    def set_imgsz(self, imgsz):
        """
        Update size-dependent bookkeeping so the model accepts a different input image size.

        Args:
            imgsz (Tuple[int, int]): The size of the input image.
        """
        if hasattr(self.image_encoder, "set_imgsz"):
            self.image_encoder.set_imgsz(imgsz)
        self.prompt_encoder.input_image_size = imgsz
        # The ViT patch size is fixed at 16, so each embedding cell covers a 16x16 pixel patch.
        self.prompt_encoder.image_embedding_size = [dim // 16 for dim in imgsz]
        self.image_encoder.img_size = imgsz[0]
+ sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks. + obj_ptr_proj (nn.Module): Projection layer for object pointers. + obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers. + + Methods: + forward_image: Processes image batch through encoder to extract multi-level features. + track_step: Performs a single tracking step, updating object masks and memory features. + + Examples: + >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder) + >>> image_batch = torch.rand(1, 3, 512, 512) + >>> features = model.forward_image(image_batch) + >>> track_results = model.track_step(0, True, features, None, None, None, {}) + """ + + mask_threshold: float = 0.0 + + def __init__( + self, + image_encoder, + memory_attention, + memory_encoder, + num_maskmem=7, + image_size=512, + backbone_stride=16, + sigmoid_scale_for_mem_enc=1.0, + sigmoid_bias_for_mem_enc=0.0, + binarize_mask_from_pts_for_mem_enc=False, + use_mask_input_as_output_without_sam=False, + max_cond_frames_in_attn=-1, + directly_add_no_mem_embed=False, + use_high_res_features_in_sam=False, + multimask_output_in_sam=False, + multimask_min_pt_num=1, + multimask_max_pt_num=1, + multimask_output_for_tracking=False, + use_multimask_token_for_obj_ptr: bool = False, + iou_prediction_use_sigmoid=False, + memory_temporal_stride_for_eval=1, + non_overlap_masks_for_mem_enc=False, + use_obj_ptrs_in_encoder=False, + max_obj_ptrs_in_encoder=16, + add_tpos_enc_to_obj_ptrs=True, + proj_tpos_enc_in_obj_ptrs=False, + use_signed_tpos_enc_to_obj_ptrs=False, + only_obj_ptrs_in_the_past_for_eval=False, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + fixed_no_obj_ptr: bool = False, + soft_no_obj_ptr: bool = False, + use_mlp_for_obj_ptr_proj: bool = False, + no_obj_embed_spatial: bool = False, + sam_mask_decoder_extra_args=None, + compile_image_encoder: bool = False, + ): + """ + Initialize the SAM2Model for video object segmentation with 
memory-based tracking. + + Args: + image_encoder (nn.Module): Visual encoder for extracting image features. + memory_attention (nn.Module): Module for attending to memory features. + memory_encoder (nn.Module): Encoder for generating memory representations. + num_maskmem (int): Number of accessible memory frames. Default is 7 (1 input frame + 6 previous frames). + image_size (int): Size of input images. + backbone_stride (int): Stride of the image backbone output. + sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability. + sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability. + binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames + with clicks during evaluation. + use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM + prompt encoder and mask decoder on frames with mask input. + max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention. + -1 means no limit. + directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the + first frame. + use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder. + multimask_output_in_sam (bool): Whether to output multiple (3) masks for the first click on initial + conditioning frames. + multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM. + multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM. + multimask_output_for_tracking (bool): Whether to use multimask output for tracking. + use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers. + iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1]. + memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation. 
+ non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in + memory encoder during evaluation. + use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder. + max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder + cross-attention. + add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in + the encoder. + proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional + encoding in object pointers. + use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance (instead of unsigned absolute distance) + in the temporal positional encoding in the object pointers, only relevant when both + `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`. + only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past + during evaluation. + pred_obj_scores (bool): Whether to predict if there is an object in the frame. + pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores. + fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present. + soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation. + use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection. + no_obj_embed_spatial (bool): Whether add no obj embedding to spatial frames. + sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder. + compile_image_encoder (bool): Whether to compile the image encoder for faster inference. + + Examples: + >>> image_encoder = ImageEncoderViT(...) + >>> memory_attention = SAM2TwoWayTransformer(...) + >>> memory_encoder = nn.Sequential(...) 
+ >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder) + >>> image_batch = torch.rand(1, 3, 512, 512) + >>> features = model.forward_image(image_batch) + >>> track_results = model.track_step(0, True, features, None, None, None, {}) + """ + super().__init__() + + # Part 1: the image backbone + self.image_encoder = image_encoder + # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting + self.use_high_res_features_in_sam = use_high_res_features_in_sam + self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 + self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + if use_obj_ptrs_in_encoder: + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. + self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs + if proj_tpos_enc_in_obj_ptrs: + assert add_tpos_enc_to_obj_ptrs # these options need to be used together + self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs + self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval + + # Part 2: memory attention to condition current frame's visual features + # with memories (and obj ptrs) from past frames + self.memory_attention = memory_attention + self.hidden_dim = memory_attention.d_model + + # Part 3: memory encoder for the previous frame's outputs + self.memory_encoder = memory_encoder + self.mem_dim = self.hidden_dim + if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"): + # if there is compression of memories along channel dim + self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] + self.num_maskmem = num_maskmem 
# Number of memories accessible + # Temporal encoding of the memories + self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim)) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + # a single token to indicate no memory embedding from previous frames + self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + trunc_normal_(self.no_mem_embed, std=0.02) + trunc_normal_(self.no_mem_pos_enc, std=0.02) + self.directly_add_no_mem_embed = directly_add_no_mem_embed + # Apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc + self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.pred_obj_scores = pred_obj_scores + self.pred_obj_scores_mlp = 
pred_obj_scores_mlp + self.fixed_no_obj_ptr = fixed_no_obj_ptr + self.soft_no_obj_ptr = soft_no_obj_ptr + if self.fixed_no_obj_ptr: + assert self.pred_obj_scores + assert self.use_obj_ptrs_in_encoder + if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: + self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + trunc_normal_(self.no_obj_ptr, std=0.02) + self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + self.no_obj_embed_spatial = None + if no_obj_embed_spatial: + self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim)) + trunc_normal_(self.no_obj_embed_spatial, std=0.02) + + self._build_sam_heads() + self.max_cond_frames_in_attn = max_cond_frames_in_attn + + # Model compilation + if compile_image_encoder: + # Compile the forward function (not the full module) to allow loading checkpoints. + print("Image encoder compilation is enabled. First forward pass will be slow.") + self.image_encoder.forward = torch.compile( + self.image_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + @property + def device(self): + """Return the device on which the model's parameters are stored.""" + return next(self.parameters()).device + + def forward(self, *args, **kwargs): + """Process image and prompt inputs to generate object masks and scores in video sequences.""" + raise NotImplementedError( + "Please use the corresponding methods in SAM2VideoPredictor for inference." + "See notebooks/video_predictor_example.ipynb for an example." 
    def _build_sam_heads(self):
        """Build SAM-style prompt encoder and mask decoder for image segmentation tasks."""
        self.sam_prompt_embed_dim = self.hidden_dim
        self.sam_image_embedding_size = self.image_size // self.backbone_stride

        # Build PromptEncoder and MaskDecoder from SAM (hyperparameters like `mask_in_chans=16` are from SAM code)
        self.sam_prompt_encoder = PromptEncoder(
            embed_dim=self.sam_prompt_embed_dim,
            image_embedding_size=(
                self.sam_image_embedding_size,
                self.sam_image_embedding_size,
            ),
            input_image_size=(self.image_size, self.image_size),
            mask_in_chans=16,
        )
        self.sam_mask_decoder = SAM2MaskDecoder(
            num_multimask_outputs=3,
            transformer=SAM2TwoWayTransformer(
                depth=2,
                embedding_dim=self.sam_prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=self.sam_prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            use_high_res_features=self.use_high_res_features_in_sam,
            iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
            pred_obj_scores=self.pred_obj_scores,
            pred_obj_scores_mlp=self.pred_obj_scores_mlp,
            use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
            **(self.sam_mask_decoder_extra_args or {}),
        )
        if self.use_obj_ptrs_in_encoder:
            # a linear projection on SAM output tokens to turn them into object pointers
            self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
            if self.use_mlp_for_obj_ptr_proj:
                # NOTE: when enabled this replaces (not wraps) the linear projection above
                self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
        else:
            self.obj_ptr_proj = torch.nn.Identity()
        if self.proj_tpos_enc_in_obj_ptrs:
            # a linear projection on temporal positional encoding in object pointers to
            # avoid potential interference with spatial positional encoding
            self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
        else:
            self.obj_ptr_tpos_proj = torch.nn.Identity()

    def _forward_sam_heads(
        self,
        backbone_features,
        point_inputs=None,
        mask_inputs=None,
        high_res_features=None,
        multimask_output=False,
    ):
        """
        Forward pass through SAM prompt encoders and mask heads.

        This method processes image features and optional point/mask inputs to generate object masks and scores.

        Args:
            backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
            point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts.
                'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
                    pixel-unit coordinates in (x, y) format for P input points.
                'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
                    0 means negative clicks, and -1 means padding.
            mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
                same spatial size as the image.
            high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes
                (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
                for SAM decoder.
            multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
                output only 1 mask and its IoU estimate.

        Returns:
            (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
                low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
                high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
                ious: Tensor of shape (B, M) with estimated IoU for each output mask.
                low_res_masks: Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.
                high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
                obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
                object_score_logits: Tensor of shape (B) with object score logits.
            Where M is 3 if multimask_output=True, and 1 if multimask_output=False.

        Examples:
            >>> backbone_features = torch.rand(1, 256, 32, 32)
            >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
            >>> mask_inputs = torch.rand(1, 1, 512, 512)
            >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
            >>> (
            ...     low_res_multimasks,
            ...     high_res_multimasks,
            ...     ious,
            ...     low_res_masks,
            ...     high_res_masks,
            ...     obj_ptr,
            ...     object_score_logits,
            ... ) = results
        """
        B = backbone_features.size(0)
        device = backbone_features.device
        assert backbone_features.size(1) == self.sam_prompt_embed_dim
        assert backbone_features.size(2) == self.sam_image_embedding_size
        assert backbone_features.size(3) == self.sam_image_embedding_size

        # a) Handle point prompts
        if point_inputs is not None:
            sam_point_coords = point_inputs["point_coords"]
            sam_point_labels = point_inputs["point_labels"]
            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
        else:
            # If no points are provided, pad with an empty point (with label -1)
            sam_point_coords = torch.zeros(B, 1, 2, device=device)
            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

        # b) Handle mask prompts
        if mask_inputs is not None:
            # If mask_inputs is provided, downsize it into low-res mask input if needed
            # and feed it as a dense mask prompt into the SAM mask encoder
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                sam_mask_prompt = F.interpolate(
                    mask_inputs.float(),
                    size=self.sam_prompt_encoder.mask_input_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,  # use antialias for downsampling
                )
            else:
                sam_mask_prompt = mask_inputs
        else:
            # Otherwise, simply feed None (and SAM's prompt encoder will add
            # a learned `no_mask_embed` to indicate no mask input in this case).
            sam_mask_prompt = None

        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
            points=(sam_point_coords, sam_point_labels),
            boxes=None,
            masks=sam_mask_prompt,
        )
        low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder(
            image_embeddings=backbone_features,
            image_pe=self.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=False,  # the image is already batched
            high_res_features=high_res_features,
        )
        if self.pred_obj_scores:
            is_obj_appearing = object_score_logits > 0

            # Spatial memory mask is a *hard* choice between obj and no obj, consistent with actual mask prediction
            low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE)

        # convert masks from possibly bfloat16 (or float16) to float32
        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
        low_res_multimasks = low_res_multimasks.float()
        high_res_multimasks = F.interpolate(
            low_res_multimasks,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )

        sam_output_token = sam_output_tokens[:, 0]
        if multimask_output:
            # take the best mask prediction (with the highest IoU estimation)
            best_iou_inds = torch.argmax(ious, dim=-1)
            batch_inds = torch.arange(B, device=device)
            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            if sam_output_tokens.size(1) > 1:
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
        else:
            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks

        # Extract object pointer from the SAM output token (with occlusion handling)
        obj_ptr = self.obj_ptr_proj(sam_output_token)
        if self.pred_obj_scores:
            # Allow *soft* no obj ptr, unlike for masks
            if self.soft_no_obj_ptr:
                lambda_is_obj_appearing = object_score_logits.sigmoid()
            else:
                # `is_obj_appearing` was computed above, also under `self.pred_obj_scores`
                lambda_is_obj_appearing = is_obj_appearing.float()

            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        return (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )

    def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
        """Process mask inputs directly as output, bypassing SAM encoder/decoder."""
        # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
        out_scale, out_bias = 20.0, -10.0  # sigmoid(-10.0)=4.5398e-05
        mask_inputs_float = mask_inputs.float()
        high_res_masks = mask_inputs_float * out_scale + out_bias
        low_res_masks = F.interpolate(
            high_res_masks,
            size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
            align_corners=False,
            mode="bilinear",
            antialias=True,  # use antialias for downsampling
        )
        # a dummy IoU prediction of all 1's under mask input
        ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
        if not self.use_obj_ptrs_in_encoder:
            # all zeros as a dummy object pointer (of shape [B, C])
            obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
        else:
            # produce an object pointer using the SAM decoder from the mask input
            _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
                backbone_features=backbone_features,
                mask_inputs=self.mask_downsample(mask_inputs_float),
                high_res_features=high_res_features,
            )
        # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
        # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
        # on the object_scores from the SAM decoder.
        is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
        is_obj_appearing = is_obj_appearing[..., None]
        lambda_is_obj_appearing = is_obj_appearing.float()
        # Mirror the +/-10 logit convention used for the mask pixels above.
        object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
        if self.pred_obj_scores:
            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            # Blend in the learned no-object pointer when the mask shows no object.
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        # The multimask slots reuse the single mask here (no multimask output under mask input).
        return (
            low_res_masks,
            high_res_masks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )

    def forward_image(self, img_batch: torch.Tensor):
        """Process image batch through encoder to extract multi-level features for SAM model."""
        backbone_out = self.image_encoder(img_batch)
        if self.use_high_res_features_in_sam:
            # precompute projected level 0 and level 1 features in SAM decoder
            # to avoid running it again on every SAM click
            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
        return backbone_out

    def _prepare_backbone_features(self, backbone_out):
        """Prepare and flatten visual features from the image backbone output for further processing."""
        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
        assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels

        # Only the top `num_feature_levels` (finest-to-coarsest ordering, so the last entries) are used.
        feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]

        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
        # flatten NxCxHxW to HWxNxC (sequence-first tokens for the memory attention)
        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]

        return backbone_out, vision_feats, vision_pos_embeds, feat_sizes

    def _prepare_memory_conditioned_features(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
    ):
        """Prepare memory-conditioned features by fusing current frame's visual features with previous memories."""
        B = current_vision_feats[-1].size(1)  # batch size on this frame
        C = self.hidden_dim
        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
        device = current_vision_feats[-1].device
        # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
        # In this case, we skip the fusion with any memory.
        if self.num_maskmem == 0:  # Disable memory and skip fusion
            return current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
        num_obj_ptr_tokens = 0
        tpos_sign_mul = -1 if track_in_reverse else 1
        # Step 1: condition the visual features of the current frame on previous memories
        if not is_init_cond_frame:
            # Retrieve the memories encoded with the maskmem backbone
            to_cat_memory, to_cat_memory_pos_embed = [], []
            # Add conditioning frame's output first (all cond frames have t_pos=0 for
            # when getting temporal positional embedding below)
            assert len(output_dict["cond_frame_outputs"]) > 0
            # Select a maximum number of temporally closest cond frames for cross attention
            cond_outputs = output_dict["cond_frame_outputs"]
            selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
                frame_idx, cond_outputs, self.max_cond_frames_in_attn
            )
            t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
            # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
            # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
            # We also allow taking the memory frame non-consecutively (with r>1), in which case
            # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
            r = 1 if self.training else self.memory_temporal_stride_for_eval
            for t_pos in range(1, self.num_maskmem):
                t_rel = self.num_maskmem - t_pos  # how many frames before current frame
                if t_rel == 1:
                    # for t_rel == 1, we take the last frame (regardless of r)
                    prev_frame_idx = frame_idx + t_rel if track_in_reverse else frame_idx - t_rel
                elif not track_in_reverse:
                    # first find the nearest frame among every r-th frames before this frame
                    # for r=1, this would be (frame_idx - 2)
                    prev_frame_idx = ((frame_idx - 2) // r) * r
                    # then seek further among every r-th frames
                    prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
                else:
                    # first find the nearest frame among every r-th frames after this frame
                    # for r=1, this would be (frame_idx + 2); -(-x // r) is ceil-division by r
                    prev_frame_idx = -(-(frame_idx + 2) // r) * r
                    # then seek further among every r-th frames
                    prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
                out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
                if out is None:
                    # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
                    # frames, we still attend to it as if it's a non-conditioning frame.
                    out = unselected_cond_outputs.get(prev_frame_idx, None)
                t_pos_and_prevs.append((t_pos, out))

            for t_pos, prev in t_pos_and_prevs:
                if prev is None:
                    continue  # skip padding frames
                # "maskmem_features" might have been offloaded to CPU in demo use cases,
                # so we load it back to inference device (it's a no-op if it's already on device).
                feats = prev["maskmem_features"].to(device=device, non_blocking=True)
                to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
                # Spatial positional encoding (it might have been offloaded to CPU in eval)
                maskmem_enc = prev["maskmem_pos_enc"][-1].to(device=device)
                maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
                # Temporal positional encoding
                maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
                to_cat_memory_pos_embed.append(maskmem_enc)

            # Construct the list of past object pointers
            if self.use_obj_ptrs_in_encoder:
                max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
                # First add those object pointers from selected conditioning frames
                # (optionally, only include object pointers in the past during evaluation)
                if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
                    ptr_cond_outputs = {
                        t: out
                        for t, out in selected_cond_outputs.items()
                        if (t >= frame_idx if track_in_reverse else t <= frame_idx)
                    }
                else:
                    ptr_cond_outputs = selected_cond_outputs
                pos_and_ptrs = [
                    # Temporal pos encoding contains how far away each pointer is from current frame
                    (
                        (
                            (frame_idx - t) * tpos_sign_mul
                            if self.use_signed_tpos_enc_to_obj_ptrs
                            else abs(frame_idx - t)
                        ),
                        out["obj_ptr"],
                    )
                    for t, out in ptr_cond_outputs.items()
                ]
                # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
                for t_diff in range(1, max_obj_ptrs_in_encoder):
                    t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
                    if t < 0 or (num_frames is not None and t >= num_frames):
                        break
                    out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
                    if out is not None:
                        pos_and_ptrs.append((t_diff, out["obj_ptr"]))
                # If we have at least one object pointer, add them to the across attention
                if pos_and_ptrs:
                    pos_list, ptrs_list = zip(*pos_and_ptrs)
                    # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
                    obj_ptrs = torch.stack(ptrs_list, dim=0)
                    # a temporal positional embedding based on how far each object pointer is from
                    # the current frame (sine embedding normalized by the max pointer num).
                    if self.add_tpos_enc_to_obj_ptrs:
                        t_diff_max = max_obj_ptrs_in_encoder - 1
                        tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
                        obj_pos = torch.tensor(pos_list, device=device)
                        obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
                        obj_pos = self.obj_ptr_tpos_proj(obj_pos)
                        obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
                    else:
                        obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
                    if self.mem_dim < C:
                        # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
                        obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
                        obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
                        obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
                    to_cat_memory.append(obj_ptrs)
                    to_cat_memory_pos_embed.append(obj_pos)
                    num_obj_ptr_tokens = obj_ptrs.shape[0]
                else:
                    num_obj_ptr_tokens = 0
        else:
            # for initial conditioning frames, encode them without using any previous memory
            if self.directly_add_no_mem_embed:
                # directly add no-mem embedding (instead of using the transformer encoder)
                pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
                pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
                return pix_feat_with_mem

            # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)
            to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
            to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]

        # Step 2: Concatenate the memories and forward through the transformer encoder
        memory = torch.cat(to_cat_memory, dim=0)
        memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)

        pix_feat_with_mem = self.memory_attention(
            curr=current_vision_feats,
            curr_pos=current_vision_pos_embeds,
            memory=memory,
            memory_pos=memory_pos_embed,
            num_obj_ptr_tokens=num_obj_ptr_tokens,
        )
        # reshape the output (HW)BC => BCHW
        pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
        return pix_feat_with_mem

    def _encode_new_memory(
        self,
        current_vision_feats,
        feat_sizes,
        pred_masks_high_res,
        object_score_logits,
        is_mask_from_pts,
    ):
        """Encode frame features and masks into a new memory representation for video segmentation."""
        B = current_vision_feats[-1].size(1)  # batch size on this frame
        C = self.hidden_dim
        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
        # top-level feature, (HW)BC => BCHW
        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
        if self.non_overlap_masks_for_mem_enc and not self.training:
            # optionally, apply non-overlapping constraints to the masks (it's applied
            # in the batch dimension and should only be used during eval, where all
            # the objects come from the same video under batch size 1).
            pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
        # scale the raw mask logits with a temperature before applying sigmoid
        binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
        if binarize and not self.training:
            mask_for_mem = (pred_masks_high_res > 0).float()
        else:
            # apply sigmoid on the raw mask logits to turn them into range (0, 1)
            mask_for_mem = torch.sigmoid(pred_masks_high_res)
        # apply scale and bias terms to the sigmoid probabilities
        if self.sigmoid_scale_for_mem_enc != 1.0:
            mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
        if self.sigmoid_bias_for_mem_enc != 0.0:
            mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
        maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True)  # sigmoid already applied
        maskmem_features = maskmem_out["vision_features"]
        maskmem_pos_enc = maskmem_out["vision_pos_enc"]
        # add a no-object embedding to the spatial memory to indicate
        # that the frame
        # is predicted to be occluded (i.e. no object is appearing in the frame)
        if self.no_obj_embed_spatial is not None:
            is_obj_appearing = (object_score_logits > 0).float()
            maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[
                ..., None, None
            ].expand(*maskmem_features.shape)

        return maskmem_features, maskmem_pos_enc

    def _track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse,
        prev_sam_mask_logits,
    ):
        """Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
        if len(current_vision_feats) > 1:
            high_res_features = [
                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
            ]
        else:
            high_res_features = None
        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
            sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
        else:
            # fused the visual feature with previous memory features in the memory bank
            pix_feat = self._prepare_memory_conditioned_features(
                frame_idx=frame_idx,
                is_init_cond_frame=is_init_cond_frame,
                current_vision_feats=current_vision_feats[-1:],
                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
                feat_sizes=feat_sizes[-1:],
                output_dict=output_dict,
                num_frames=num_frames,
                track_in_reverse=track_in_reverse,
            )
            # apply SAM-style segmentation head
            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
            # e.g. in demo where such logits come from earlier interaction instead of correction sampling
            # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
            if prev_sam_mask_logits is not None:
                assert point_inputs is not None and mask_inputs is None
                mask_inputs = prev_sam_mask_logits
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
            )
        return current_out, sam_outputs, high_res_features, pix_feat

    def _encode_memory_in_output(
        self,
        current_vision_feats,
        feat_sizes,
        point_inputs,
        run_mem_encoder,
        high_res_masks,
        object_score_logits,
        current_out,
    ):
        """Run memory encoder on predicted mask to encode it into a new memory feature for future frames.

        Mutates `current_out` in place, setting "maskmem_features"/"maskmem_pos_enc"
        (both None when the memory encoder is skipped or memory is disabled).
        """
        if run_mem_encoder and self.num_maskmem > 0:
            high_res_masks_for_mem_enc = high_res_masks
            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                current_vision_feats=current_vision_feats,
                feat_sizes=feat_sizes,
                pred_masks_high_res=high_res_masks_for_mem_enc,
                object_score_logits=object_score_logits,
                is_mask_from_pts=(point_inputs is not None),
            )
            current_out["maskmem_features"] = maskmem_features
            current_out["maskmem_pos_enc"] = maskmem_pos_enc
        else:
            current_out["maskmem_features"] = None
            current_out["maskmem_pos_enc"] = None

    def track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
        # Whether to run the memory encoder on the predicted masks. Sometimes we might want
        # to skip the memory encoder with `run_mem_encoder=False`. For example,
        # in demo we might call `track_step` multiple times for each user click,
        # and only encode the memory when the user finalizes their clicks. And in ablation
        # settings like SAM training on static images, we don't need the memory encoder.
        run_mem_encoder=True,
        # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
        prev_sam_mask_logits=None,
    ):
        """Perform a single tracking step, updating object masks and memory features based on current frame inputs."""
        current_out, sam_outputs, _, _ = self._track_step(
            frame_idx,
            is_init_cond_frame,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
            point_inputs,
            mask_inputs,
            output_dict,
            num_frames,
            track_in_reverse,
            prev_sam_mask_logits,
        )
        _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs

        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr
        if not self.training:
            # Only add this in inference (to avoid unused param in activation checkpointing;
            # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
            current_out["object_score_logits"] = object_score_logits

        # Run memory encoder on the predicted mask to encode it into a new memory feature (for use in future frames)
        self._encode_memory_in_output(
            current_vision_feats,
            feat_sizes,
            point_inputs,
            run_mem_encoder,
            high_res_masks,
            object_score_logits,
            current_out,
        )

        return current_out

    def _use_multimask(self, is_init_cond_frame, point_inputs):
        """Determine whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
        num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
        return (
            self.multimask_output_in_sam
            and (is_init_cond_frame or self.multimask_output_for_tracking)
            and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
        )

    @staticmethod
    def _apply_non_overlapping_constraints(pred_masks):
        """Apply non-overlapping constraints to masks, keeping the highest scoring object per location."""
        batch_size = pred_masks.size(0)
        if batch_size == 1:
            return pred_masks

        device = pred_masks.device
        # "max_obj_inds": object index of the object with the highest score at each location
        max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
        # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
        batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
        keep = max_obj_inds == batch_obj_inds
        # suppress overlapping regions' scores below -10.0 so that the foreground regions
        # don't overlap (here sigmoid(-10.0)=4.5398e-05)
        pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
        return pred_masks

    def set_binarize(self, binarize=False):
        """Set binarize for VideoPredictor."""
        self.binarize_mask_from_pts_for_mem_enc = binarize

    def set_imgsz(self, imgsz):
        """Set image size to make model compatible with different image sizes."""
        self.image_size = imgsz[0]
        self.sam_prompt_encoder.input_image_size = imgsz
        self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]  # fixed ViT patch size of 16
diff --git a/tracking/ultralytics/models/sam/modules/tiny_encoder.py b/tracking/ultralytics/models/sam/modules/tiny_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5a3a63d455bb60773d57e8e49d910ec43a0f2ca
--- /dev/null
+++ b/tracking/ultralytics/models/sam/modules/tiny_encoder.py
@@ -0,0 +1,1003 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# --------------------------------------------------------
# TinyViT Model Architecture
# Copyright (c) 2022 Microsoft
# Adapted from LeViT and Swin Transformer
# LeViT: (https://github.com/facebookresearch/levit)
# Swin: (https://github.com/microsoft/swin-transformer)
# Build the TinyViT Model
# --------------------------------------------------------

import itertools
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from ultralytics.nn.modules import LayerNorm2d
from ultralytics.utils.instance import to_2tuple


class Conv2d_BN(torch.nn.Sequential):
    """
    A
    sequential container that performs 2D convolution followed by batch normalization.

    Attributes:
        c (torch.nn.Conv2d): 2D convolution layer.
        bn (torch.nn.BatchNorm2d): Batch normalization layer.

    Methods:
        __init__: Initializes the Conv2d_BN with specified parameters.

    Args:
        a (int): Number of input channels.
        b (int): Number of output channels.
        ks (int): Kernel size for the convolution. Defaults to 1.
        stride (int): Stride for the convolution. Defaults to 1.
        pad (int): Padding for the convolution. Defaults to 0.
        dilation (int): Dilation factor for the convolution. Defaults to 1.
        groups (int): Number of groups for the convolution. Defaults to 1.
        bn_weight_init (float): Initial value for batch normalization weight. Defaults to 1.

    Examples:
        >>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1)
        >>> input_tensor = torch.randn(1, 3, 224, 224)
        >>> output = conv_bn(input_tensor)
        >>> print(output.shape)
    """

    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
        """Initializes a sequential container with 2D convolution followed by batch normalization."""
        super().__init__()
        # conv has no bias since the following BN layer supplies the affine shift
        self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
        bn = torch.nn.BatchNorm2d(b)
        torch.nn.init.constant_(bn.weight, bn_weight_init)
        torch.nn.init.constant_(bn.bias, 0)
        self.add_module("bn", bn)


class PatchEmbed(nn.Module):
    """
    Embeds images into patches and projects them into a specified embedding dimension.

    Attributes:
        patches_resolution (Tuple[int, int]): Resolution of the patches after embedding.
        num_patches (int): Total number of patches.
        in_chans (int): Number of input channels.
        embed_dim (int): Dimension of the embedding.
        seq (nn.Sequential): Sequence of convolutional and activation layers for patch embedding.

    Methods:
        forward: Processes the input tensor through the patch embedding sequence.

    Examples:
        >>> import torch
        >>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU)
        >>> x = torch.randn(1, 3, 224, 224)
        >>> output = patch_embed(x)
        >>> print(output.shape)
    """

    def __init__(self, in_chans, embed_dim, resolution, activation):
        """Initializes patch embedding with convolutional layers for image-to-patch conversion and projection."""
        super().__init__()
        img_size: Tuple[int, int] = to_2tuple(resolution)
        # two stride-2 convs below give an overall 4x spatial downsampling
        self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        n = embed_dim
        self.seq = nn.Sequential(
            Conv2d_BN(in_chans, n // 2, 3, 2, 1),
            activation(),
            Conv2d_BN(n // 2, n, 3, 2, 1),
        )

    def forward(self, x):
        """Processes input tensor through patch embedding sequence, converting images to patch embeddings."""
        return self.seq(x)


class MBConv(nn.Module):
    """
    Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.

    Attributes:
        in_chans (int): Number of input channels.
        hidden_chans (int): Number of hidden channels.
        out_chans (int): Number of output channels.
        conv1 (Conv2d_BN): First convolutional layer.
        act1 (nn.Module): First activation function.
        conv2 (Conv2d_BN): Depthwise convolutional layer.
        act2 (nn.Module): Second activation function.
        conv3 (Conv2d_BN): Final convolutional layer.
        act3 (nn.Module): Third activation function.
        drop_path (nn.Module): Drop path layer (Identity for inference).

    Methods:
        forward: Performs the forward pass through the MBConv layer.

    Examples:
        >>> in_chans, out_chans = 32, 64
        >>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1)
        >>> x = torch.randn(1, in_chans, 56, 56)
        >>> output = mbconv(x)
        >>> print(output.shape)
        torch.Size([1, 64, 56, 56])
    """

    def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
        """Initializes the MBConv layer with specified input/output channels, expansion ratio, and activation."""
        super().__init__()
        self.in_chans = in_chans
        self.hidden_chans = int(in_chans * expand_ratio)
        self.out_chans = out_chans

        self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
        self.act1 = activation()

        # depthwise 3x3 (groups == channels)
        self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, ks=3, stride=1, pad=1, groups=self.hidden_chans)
        self.act2 = activation()

        # bn_weight_init=0 zeroes the residual branch at init so the block starts as identity
        self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
        self.act3 = activation()

        # NOTE: `DropPath` is needed only for training.
        # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path = nn.Identity()

    def forward(self, x):
        """Implements the forward pass of MBConv, applying convolutions and skip connection."""
        shortcut = x
        x = self.conv1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.act2(x)
        x = self.conv3(x)
        x = self.drop_path(x)
        # residual add assumes in_chans == out_chans for the identity shortcut
        x += shortcut
        return self.act3(x)


class PatchMerging(nn.Module):
    """
    Merges neighboring patches in the feature map and projects to a new dimension.

    This class implements a patch merging operation that combines spatial information and adjusts the feature
    dimension. It uses a series of convolutional layers with batch normalization to achieve this.

    Attributes:
        input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.
        dim (int): The input dimension of the feature map.
        out_dim (int): The output dimension after merging and projection.
        act (nn.Module): The activation function used between convolutions.
        conv1 (Conv2d_BN): The first convolutional layer for dimension projection.
        conv2 (Conv2d_BN): The second convolutional layer for spatial merging.
        conv3 (Conv2d_BN): The third convolutional layer for final projection.

    Methods:
        forward: Applies the patch merging operation to the input tensor.

    Examples:
        >>> input_resolution = (56, 56)
        >>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU)
        >>> x = torch.randn(4, 64, 56, 56)
        >>> output = patch_merging(x)
        >>> print(output.shape)
    """

    def __init__(self, input_resolution, dim, out_dim, activation):
        """Initializes the PatchMerging module for merging and projecting neighboring patches in feature maps."""
        super().__init__()

        self.input_resolution = input_resolution
        self.dim = dim
        self.out_dim = out_dim
        self.act = activation()
        self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
        # stride 1 skips spatial downsampling for these specific output widths
        # NOTE(review): presumably these match TinyViT variants where the stage keeps resolution — confirm
        stride_c = 1 if out_dim in {320, 448, 576} else 2
        self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
        self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)

    def forward(self, x):
        """Applies patch merging and dimension projection to the input feature map."""
        if x.ndim == 3:
            # tokens (B, N, C) => image layout (B, C, H, W) before the convs
            H, W = self.input_resolution
            B = len(x)
            # (B, C, H, W)
            x = x.view(B, H, W, -1).permute(0, 3, 1, 2)

        x = self.conv1(x)
        x = self.act(x)

        x = self.conv2(x)
        x = self.act(x)
        x = self.conv3(x)
        # back to token layout (B, N, out_dim)
        return x.flatten(2).transpose(1, 2)


class ConvLayer(nn.Module):
    """
    Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).

    This layer optionally applies downsample operations to the output and supports gradient checkpointing.

    Attributes:
        dim (int): Dimensionality of the input and output.
        input_resolution (Tuple[int, int]): Resolution of the input image.
        depth (int): Number of MBConv layers in the block.
        use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
        blocks (nn.ModuleList): List of MBConv layers.
        downsample (Optional[Callable]): Function for downsampling the output.

    Methods:
        forward: Processes the input through the convolutional layers.

    Examples:
        >>> input_tensor = torch.randn(1, 64, 56, 56)
        >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
        >>> output = conv_layer(input_tensor)
        >>> print(output.shape)
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        activation,
        drop_path=0.0,
        downsample=None,
        use_checkpoint=False,
        out_dim=None,
        conv_expand_ratio=4.0,
    ):
        """
        Initializes the ConvLayer with the given dimensions and settings.

        This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and
        optionally applies downsampling to the output.

        Args:
            dim (int): The dimensionality of the input and output.
            input_resolution (Tuple[int, int]): The resolution of the input image.
            depth (int): The number of MBConv layers in the block.
            activation (nn.Module): Activation function applied after each convolution.
            drop_path (float | List[float]): Drop path rate. Single float or a list of floats for each MBConv.
            downsample (Optional[nn.Module]): Function for downsampling the output. None to skip downsampling.
            use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
            out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
            conv_expand_ratio (float): Expansion ratio for the MBConv layers.

        Examples:
            >>> input_tensor = torch.randn(1, 64, 56, 56)
            >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
            >>> output = conv_layer(input_tensor)
            >>> print(output.shape)
        """
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Build blocks
        self.blocks = nn.ModuleList(
            [
                MBConv(
                    dim,
                    dim,
                    conv_expand_ratio,
                    activation,
                    # per-block drop-path rate when a schedule (list) is provided
                    drop_path[i] if isinstance(drop_path, list) else drop_path,
                )
                for i in range(depth)
            ]
        )

        # Patch merging layer
        self.downsample = (
            None
            if downsample is None
            else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
        )

    def forward(self, x):
        """Processes input through convolutional layers, applying MBConv blocks and optional downsampling."""
        for blk in self.blocks:
            # gradient checkpointing trades compute for activation memory during training
            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
        return x if self.downsample is None else self.downsample(x)


class Mlp(nn.Module):
    """
    Multi-layer Perceptron (MLP) module for transformer architectures.

    This module applies layer normalization, two fully-connected layers with an activation function in between,
    and dropout. It is commonly used in transformer-based architectures.

    Attributes:
        norm (nn.LayerNorm): Layer normalization applied to the input.
        fc1 (nn.Linear): First fully-connected layer.
        fc2 (nn.Linear): Second fully-connected layer.
        act (nn.Module): Activation function applied after the first fully-connected layer.
        drop (nn.Dropout): Dropout layer applied after the activation function.

    Methods:
        forward: Applies the MLP operations on the input tensor.

    Examples:
        >>> import torch
        >>> from torch import nn
        >>> mlp = Mlp(in_features=256, hidden_features=512, out_features=256, act_layer=nn.GELU, drop=0.1)
        >>> x = torch.randn(32, 100, 256)
        >>> output = mlp(x)
        >>> print(output.shape)
        torch.Size([32, 100, 256])
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        """Initializes a multi-layer perceptron with configurable input, hidden, and output dimensions."""
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.norm = nn.LayerNorm(in_features)
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.act = act_layer()
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Applies MLP operations: layer norm, FC layers, activation, and dropout to the input tensor."""
        x = self.norm(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        # dropout is applied after each of the two linear projections
        return self.drop(x)


class Attention(torch.nn.Module):
    """
    Multi-head attention module with spatial awareness and trainable attention biases.

    This module implements a multi-head attention mechanism with support for spatial awareness, applying
    attention biases based on spatial resolution. It includes trainable attention biases for each unique
    offset between spatial positions in the resolution grid.

    Attributes:
        num_heads (int): Number of attention heads.
        scale (float): Scaling factor for attention scores.
        key_dim (int): Dimensionality of the keys and queries.
        nh_kd (int): Product of num_heads and key_dim.
        d (int): Dimensionality of the value vectors.
        dh (int): Product of d and num_heads.
        attn_ratio (float): Attention ratio affecting the dimensions of the value vectors.
        norm (nn.LayerNorm): Layer normalization applied to input.
        qkv (nn.Linear): Linear layer for computing query, key, and value projections.
        proj (nn.Linear): Linear layer for final projection.
        attention_biases (nn.Parameter): Learnable attention biases.
        attention_bias_idxs (Tensor): Indices for attention biases.
        ab (Tensor): Cached attention biases for inference, deleted during training.

    Methods:
        train: Sets the module in training mode and handles the 'ab' attribute.
        forward: Performs the forward pass of the attention mechanism.

    Examples:
        >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
        >>> x = torch.randn(1, 196, 256)
        >>> output = attn(x)
        >>> print(output.shape)
        torch.Size([1, 196, 256])
    """

    def __init__(
        self,
        dim,
        key_dim,
        num_heads=8,
        attn_ratio=4,
        resolution=(14, 14),
    ):
        """
        Initializes the Attention module for multi-head attention with spatial awareness.

        This module implements a multi-head attention mechanism with support for spatial awareness, applying
        attention biases based on spatial resolution. It includes trainable attention biases for each unique
        offset between spatial positions in the resolution grid.

        Args:
            dim (int): The dimensionality of the input and output.
            key_dim (int): The dimensionality of the keys and queries.
            num_heads (int): Number of attention heads.
            attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors.
            resolution (Tuple[int, int]): Spatial resolution of the input feature map.

        Examples:
            >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
            >>> x = torch.randn(1, 196, 256)
            >>> output = attn(x)
            >>> print(output.shape)
            torch.Size([1, 196, 256])
        """
        super().__init__()

        assert isinstance(resolution, tuple) and len(resolution) == 2, "'resolution' argument not tuple of length 2"
        self.num_heads = num_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio
        # fused q/k/v projection width: values (dh) plus keys and queries (nh_kd each)
        h = self.dh + nh_kd * 2

        self.norm = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, h)
        self.proj = nn.Linear(self.dh, dim)

        # one learnable bias per unique |Δrow|, |Δcol| offset between grid positions,
        # shared across all position pairs with the same offset (translation invariance)
        points = list(itertools.product(range(resolution[0]), range(resolution[1])))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False)

    @torch.no_grad()
    def train(self, mode=True):
        """Toggle training mode, managing the cached attention-bias tensor `ab` (built for eval, dropped for training)."""
        super().train(mode)
        if mode and hasattr(self, "ab"):
            # training must index the live parameter so gradients flow; drop the cache
            del self.ab
        else:
            # eval: precompute the gathered biases once to avoid re-indexing per forward
            self.ab = self.attention_biases[:, self.attention_bias_idxs]

    def forward(self, x):  # x
        """Applies multi-head attention with spatial awareness and trainable attention biases."""
        B, N, _ = x.shape  # B, N, C

        # Normalization
        x = self.norm(x)

        qkv = self.qkv(x)
        # (B, N, num_heads, d)
        q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
        # (B, num_heads, N, d)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
+ self.ab = self.ab.to(self.attention_biases.device) + + attn = (q @ k.transpose(-2, -1)) * self.scale + ( + self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab + ) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + return self.proj(x) + + +class TinyViTBlock(nn.Module): + """ + TinyViT Block that applies self-attention and a local convolution to the input. + + This block is a key component of the TinyViT architecture, combining self-attention mechanisms with + local convolutions to process input features efficiently. + + Attributes: + dim (int): The dimensionality of the input and output. + input_resolution (Tuple[int, int]): Spatial resolution of the input feature map. + num_heads (int): Number of attention heads. + window_size (int): Size of the attention window. + mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. + drop_path (nn.Module): Stochastic depth layer, identity function during inference. + attn (Attention): Self-attention module. + mlp (Mlp): Multi-layer perceptron module. + local_conv (Conv2d_BN): Depth-wise local convolution layer. + + Methods: + forward: Processes the input through the TinyViT block. + extra_repr: Returns a string with extra information about the block's parameters. + + Examples: + >>> input_tensor = torch.randn(1, 196, 192) + >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3) + >>> output = block(input_tensor) + >>> print(output.shape) + torch.Size([1, 196, 192]) + """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, + local_conv_size=3, + activation=nn.GELU, + ): + """ + Initializes a TinyViT block with self-attention and local convolution. + + This block is a key component of the TinyViT architecture, combining self-attention mechanisms with + local convolutions to process input features efficiently. 
+ + Args: + dim (int): Dimensionality of the input and output features. + input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width). + num_heads (int): Number of attention heads. + window_size (int): Size of the attention window. Must be greater than 0. + mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. + drop (float): Dropout rate. + drop_path (float): Stochastic depth rate. + local_conv_size (int): Kernel size of the local convolution. + activation (torch.nn.Module): Activation function for MLP. + + Raises: + AssertionError: If window_size is not greater than 0. + AssertionError: If dim is not divisible by num_heads. + + Examples: + >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3) + >>> input_tensor = torch.randn(1, 196, 192) + >>> output = block(input_tensor) + >>> print(output.shape) + torch.Size([1, 196, 192]) + """ + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + assert window_size > 0, "window_size must be greater than 0" + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + # NOTE: `DropPath` is needed only for training. + # self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.drop_path = nn.Identity() + + assert dim % num_heads == 0, "dim must be divisible by num_heads" + head_dim = dim // num_heads + + window_resolution = (window_size, window_size) + self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution) + + mlp_hidden_dim = int(dim * mlp_ratio) + mlp_activation = activation + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=mlp_activation, drop=drop) + + pad = local_conv_size // 2 + self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) + + def forward(self, x): + """Applies self-attention, local convolution, and MLP operations to the input tensor.""" + h, w = self.input_resolution + b, hw, c = x.shape # batch, height*width, channels + assert hw == h * w, "input feature has wrong size" + res_x = x + if h == self.window_size and w == self.window_size: + x = self.attn(x) + else: + x = x.view(b, h, w, c) + pad_b = (self.window_size - h % self.window_size) % self.window_size + pad_r = (self.window_size - w % self.window_size) % self.window_size + padding = pad_b > 0 or pad_r > 0 + if padding: + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = h + pad_b, w + pad_r + nH = pH // self.window_size + nW = pW // self.window_size + + # Window partition + x = ( + x.view(b, nH, self.window_size, nW, self.window_size, c) + .transpose(2, 3) + .reshape(b * nH * nW, self.window_size * self.window_size, c) + ) + x = self.attn(x) + + # Window reverse + x = x.view(b, nH, nW, self.window_size, self.window_size, c).transpose(2, 3).reshape(b, pH, pW, c) + if padding: + x = x[:, :h, :w].contiguous() + + x = x.view(b, hw, c) + + x = res_x + self.drop_path(x) + x = x.transpose(1, 2).reshape(b, c, h, w) + x = self.local_conv(x) + x = x.view(b, c, hw).transpose(1, 2) + + return x + self.drop_path(self.mlp(x)) + + def extra_repr(self) -> str: + """ + Returns a string representation of the TinyViTBlock's parameters. 
+ + This method provides a formatted string containing key information about the TinyViTBlock, including its + dimension, input resolution, number of attention heads, window size, and MLP ratio. + + Returns: + (str): A formatted string containing the block's parameters. + + Examples: + >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0) + >>> print(block.extra_repr()) + dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0 + """ + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" + ) + + +class BasicLayer(nn.Module): + """ + A basic TinyViT layer for one stage in a TinyViT architecture. + + This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks + and an optional downsampling operation. + + Attributes: + dim (int): The dimensionality of the input and output features. + input_resolution (Tuple[int, int]): Spatial resolution of the input feature map. + depth (int): Number of TinyViT blocks in this layer. + use_checkpoint (bool): Whether to use gradient checkpointing to save memory. + blocks (nn.ModuleList): List of TinyViT blocks that make up this layer. + downsample (nn.Module | None): Downsample layer at the end of the layer, if specified. + + Methods: + forward: Processes the input through the layer's blocks and optional downsampling. + extra_repr: Returns a string with the layer's parameters for printing. 
+ + Examples: + >>> input_tensor = torch.randn(1, 3136, 192) + >>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7) + >>> output = layer(input_tensor) + >>> print(output.shape) + torch.Size([1, 784, 384]) + """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, + downsample=None, + use_checkpoint=False, + local_conv_size=3, + activation=nn.GELU, + out_dim=None, + ): + """ + Initializes a BasicLayer in the TinyViT architecture. + + This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to + process feature maps at a specific resolution and dimensionality within the TinyViT model. + + Args: + dim (int): Dimensionality of the input and output features. + input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width). + depth (int): Number of TinyViT blocks in this layer. + num_heads (int): Number of attention heads in each TinyViT block. + window_size (int): Size of the local window for attention computation. + mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. + drop (float): Dropout rate. + drop_path (float | List[float]): Stochastic depth rate. Can be a float or a list of floats for each block. + downsample (nn.Module | None): Downsampling layer at the end of the layer. None to skip downsampling. + use_checkpoint (bool): Whether to use gradient checkpointing to save memory. + local_conv_size (int): Kernel size for the local convolution in each TinyViT block. + activation (nn.Module): Activation function used in the MLP. + out_dim (int | None): Output dimension after downsampling. None means it will be the same as `dim`. + + Raises: + ValueError: If `drop_path` is a list and its length doesn't match `depth`. 
+ + Examples: + >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7) + >>> x = torch.randn(1, 56 * 56, 96) + >>> output = layer(x) + >>> print(output.shape) + """ + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # Build blocks + self.blocks = nn.ModuleList( + [ + TinyViTBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + local_conv_size=local_conv_size, + activation=activation, + ) + for i in range(depth) + ] + ) + + # Patch merging layer + self.downsample = ( + None + if downsample is None + else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation) + ) + + def forward(self, x): + """Processes input through TinyViT blocks and optional downsampling.""" + for blk in self.blocks: + x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x) + return x if self.downsample is None else self.downsample(x) + + def extra_repr(self) -> str: + """Returns a string with the layer's parameters for printing.""" + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class TinyViT(nn.Module): + """ + TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction. + + This class implements the TinyViT model, which combines elements of vision transformers and convolutional + neural networks for improved efficiency and performance on vision tasks. + + Attributes: + img_size (int): Input image size. + num_classes (int): Number of classification classes. + depths (List[int]): Number of blocks in each stage. + num_layers (int): Total number of layers in the network. + mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. 
+ patch_embed (PatchEmbed): Module for patch embedding. + patches_resolution (Tuple[int, int]): Resolution of embedded patches. + layers (nn.ModuleList): List of network layers. + norm_head (nn.LayerNorm): Layer normalization for the classifier head. + head (nn.Linear): Linear layer for final classification. + neck (nn.Sequential): Neck module for feature refinement. + + Methods: + set_layer_lr_decay: Sets layer-wise learning rate decay. + _init_weights: Initializes weights for linear and normalization layers. + no_weight_decay_keywords: Returns keywords for parameters that should not use weight decay. + forward_features: Processes input through the feature extraction layers. + forward: Performs a forward pass through the entire network. + + Examples: + >>> model = TinyViT(img_size=224, num_classes=1000) + >>> x = torch.randn(1, 3, 224, 224) + >>> features = model.forward_features(x) + >>> print(features.shape) + torch.Size([1, 256, 64, 64]) + """ + + def __init__( + self, + img_size=224, + in_chans=3, + num_classes=1000, + embed_dims=(96, 192, 384, 768), + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + window_sizes=(7, 7, 14, 7), + mlp_ratio=4.0, + drop_rate=0.0, + drop_path_rate=0.1, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=1.0, + ): + """ + Initializes the TinyViT model. + + This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of + attention and convolution blocks, and a classification head. + + Args: + img_size (int): Size of the input image. + in_chans (int): Number of input channels. + num_classes (int): Number of classes for classification. + embed_dims (Tuple[int, int, int, int]): Embedding dimensions for each stage. + depths (Tuple[int, int, int, int]): Number of blocks in each stage. + num_heads (Tuple[int, int, int, int]): Number of attention heads in each stage. + window_sizes (Tuple[int, int, int, int]): Window sizes for each stage. 
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. + drop_rate (float): Dropout rate. + drop_path_rate (float): Stochastic depth rate. + use_checkpoint (bool): Whether to use checkpointing to save memory. + mbconv_expand_ratio (float): Expansion ratio for MBConv layer. + local_conv_size (int): Kernel size for local convolutions. + layer_lr_decay (float): Layer-wise learning rate decay factor. + + Examples: + >>> model = TinyViT(img_size=224, num_classes=1000) + >>> x = torch.randn(1, 3, 224, 224) + >>> output = model(x) + >>> print(output.shape) + torch.Size([1, 1000]) + """ + super().__init__() + self.img_size = img_size + self.num_classes = num_classes + self.depths = depths + self.num_layers = len(depths) + self.mlp_ratio = mlp_ratio + + activation = nn.GELU + + self.patch_embed = PatchEmbed( + in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation + ) + + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # Stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # Build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + kwargs = dict( + dim=embed_dims[i_layer], + input_resolution=( + patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), + patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), + ), + # input_resolution=(patches_resolution[0] // (2 ** i_layer), + # patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)], + activation=activation, + ) + if i_layer == 0: + layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs) + else: + layer = BasicLayer( + 
num_heads=num_heads[i_layer], + window_size=window_sizes[i_layer], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + local_conv_size=local_conv_size, + **kwargs, + ) + self.layers.append(layer) + + # Classifier head + self.norm_head = nn.LayerNorm(embed_dims[-1]) + self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + + # Init weights + self.apply(self._init_weights) + self.set_layer_lr_decay(layer_lr_decay) + self.neck = nn.Sequential( + nn.Conv2d( + embed_dims[-1], + 256, + kernel_size=1, + bias=False, + ), + LayerNorm2d(256), + nn.Conv2d( + 256, + 256, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(256), + ) + + def set_layer_lr_decay(self, layer_lr_decay): + """Sets layer-wise learning rate decay for the TinyViT model based on depth.""" + decay_rate = layer_lr_decay + + # Layers -> blocks (depth) + depth = sum(self.depths) + lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)] + + def _set_lr_scale(m, scale): + """Sets the learning rate scale for each layer in the model based on the layer's depth.""" + for p in m.parameters(): + p.lr_scale = scale + + self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0])) + i = 0 + for layer in self.layers: + for block in layer.blocks: + block.apply(lambda x: _set_lr_scale(x, lr_scales[i])) + i += 1 + if layer.downsample is not None: + layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1])) + assert i == depth + for m in [self.norm_head, self.head]: + m.apply(lambda x: _set_lr_scale(x, lr_scales[-1])) + + for k, p in self.named_parameters(): + p.param_name = k + + def _check_lr_scale(m): + """Checks if the learning rate scale attribute is present in module's parameters.""" + for p in m.parameters(): + assert hasattr(p, "lr_scale"), p.param_name + + self.apply(_check_lr_scale) + + @staticmethod + def _init_weights(m): + """Initializes weights for linear and normalization layers in the TinyViT model.""" + if isinstance(m, nn.Linear): + # 
NOTE: This initialization is needed only for training. + # trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay_keywords(self): + """Returns a set of keywords for parameters that should not use weight decay.""" + return {"attention_biases"} + + def forward_features(self, x): + """Processes input through feature extraction layers, returning spatial features.""" + x = self.patch_embed(x) # x input is (N, C, H, W) + + x = self.layers[0](x) + start_i = 1 + + for i in range(start_i, len(self.layers)): + layer = self.layers[i] + x = layer(x) + batch, _, channel = x.shape + x = x.view(batch, self.patches_resolution[0] // 4, self.patches_resolution[1] // 4, channel) + x = x.permute(0, 3, 1, 2) + return self.neck(x) + + def forward(self, x): + """Performs the forward pass through the TinyViT model, extracting features from the input image.""" + return self.forward_features(x) + + def set_imgsz(self, imgsz=[1024, 1024]): + """Set image size to make model compatible with different image sizes.""" + imgsz = [s // 4 for s in imgsz] + self.patches_resolution = imgsz + for i, layer in enumerate(self.layers): + input_resolution = ( + imgsz[0] // (2 ** (i - 1 if i == 3 else i)), + imgsz[1] // (2 ** (i - 1 if i == 3 else i)), + ) + layer.input_resolution = input_resolution + if layer.downsample is not None: + layer.downsample.input_resolution = input_resolution + if isinstance(layer, BasicLayer): + for b in layer.blocks: + b.input_resolution = input_resolution diff --git a/tracking/ultralytics/models/sam/modules/transformer.py b/tracking/ultralytics/models/sam/modules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..8dac461803d6b6eb1d8e586ab57fc412d7abeb08 --- /dev/null +++ b/tracking/ultralytics/models/sam/modules/transformer.py @@ -0,0 +1,351 @@ +# Ultralytics 🚀 
AGPL-3.0 License - https://ultralytics.com/license + +import math +from typing import Tuple, Type + +import torch +from torch import Tensor, nn + +from ultralytics.nn.modules import MLPBlock + + +class TwoWayTransformer(nn.Module): + """ + A Two-Way Transformer module for simultaneous attention to image and query points. + + This class implements a specialized transformer decoder that attends to an input image using queries with + supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point + cloud processing. + + Attributes: + depth (int): Number of layers in the transformer. + embedding_dim (int): Channel dimension for input embeddings. + num_heads (int): Number of heads for multihead attention. + mlp_dim (int): Internal channel dimension for the MLP block. + layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer. + final_attn_token_to_image (Attention): Final attention layer from queries to image. + norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries. + + Methods: + forward: Processes image and point embeddings through the transformer. + + Examples: + >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048) + >>> image_embedding = torch.randn(1, 256, 32, 32) + >>> image_pe = torch.randn(1, 256, 32, 32) + >>> point_embedding = torch.randn(1, 100, 256) + >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding) + >>> print(output_queries.shape, output_image.shape) + """ + + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + Initialize a Two-Way Transformer for simultaneous attention to image and query points. + + Args: + depth (int): Number of layers in the transformer. + embedding_dim (int): Channel dimension for input embeddings. 
+ num_heads (int): Number of heads for multihead attention. Must divide embedding_dim. + mlp_dim (int): Internal channel dimension for the MLP block. + activation (Type[nn.Module]): Activation function to use in the MLP block. + attention_downsample_rate (int): Downsampling rate for attention mechanism. + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Process image and point embeddings through the Two-Way Transformer. + + Args: + image_embedding (Tensor): Image to attend to, with shape (B, embedding_dim, H, W). + image_pe (Tensor): Positional encoding to add to the image, with same shape as image_embedding. + point_embedding (Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim). + + Returns: + queries (Tensor): Processed point embeddings with shape (B, N_points, embedding_dim). + keys (Tensor): Processed image embeddings with shape (B, H*W, embedding_dim). 
+ """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + """ + A two-way attention block for simultaneous attention to image and query points. + + This class implements a specialized transformer block with four main layers: self-attention on sparse inputs, + cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense + inputs to sparse inputs. + + Attributes: + self_attn (Attention): Self-attention layer for queries. + norm1 (nn.LayerNorm): Layer normalization after self-attention. + cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys. + norm2 (nn.LayerNorm): Layer normalization after token-to-image attention. + mlp (MLPBlock): MLP block for transforming query embeddings. + norm3 (nn.LayerNorm): Layer normalization after MLP block. + norm4 (nn.LayerNorm): Layer normalization after image-to-token attention. + cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries. + skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer. + + Methods: + forward: Applies self-attention and cross-attention to queries and keys. 
+ + Examples: + >>> embedding_dim, num_heads = 256, 8 + >>> block = TwoWayAttentionBlock(embedding_dim, num_heads) + >>> queries = torch.randn(1, 100, embedding_dim) + >>> keys = torch.randn(1, 1000, embedding_dim) + >>> query_pe = torch.randn(1, 100, embedding_dim) + >>> key_pe = torch.randn(1, 1000, embedding_dim) + >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe) + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points. + + This block implements a specialized transformer layer with four main components: self-attention on sparse + inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention + of dense inputs to sparse inputs. + + Args: + embedding_dim (int): Channel dimension of the embeddings. + num_heads (int): Number of attention heads in the attention layers. + mlp_dim (int): Hidden dimension of the MLP block. + activation (Type[nn.Module]): Activation function for the MLP block. + attention_downsample_rate (int): Downsampling rate for the attention mechanism. + skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer. 
+ """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]: + """ + Apply two-way attention to process query and key embeddings in a transformer block. + + Args: + queries (Tensor): Query embeddings with shape (B, N_queries, embedding_dim). + keys (Tensor): Key embeddings with shape (B, N_keys, embedding_dim). + query_pe (Tensor): Positional encodings for queries with same shape as queries. + key_pe (Tensor): Positional encodings for keys with same shape as keys. + + Returns: + queries (Tensor): Processed query embeddings with shape (B, N_queries, embedding_dim). + keys (Tensor): Processed key embeddings with shape (B, N_keys, embedding_dim). 
+ """ + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer with downscaling capability for embedding size after projection. + + This class implements a multi-head attention mechanism with the option to downsample the internal + dimension of queries, keys, and values. + + Attributes: + embedding_dim (int): Dimensionality of input embeddings. + kv_in_dim (int): Dimensionality of key and value inputs. + internal_dim (int): Internal dimension after downsampling. + num_heads (int): Number of attention heads. + q_proj (nn.Linear): Linear projection for queries. + k_proj (nn.Linear): Linear projection for keys. + v_proj (nn.Linear): Linear projection for values. + out_proj (nn.Linear): Linear projection for output. + + Methods: + _separate_heads: Separates input tensor into attention heads. + _recombine_heads: Recombines separated attention heads. + forward: Computes attention output for given query, key, and value tensors. 
+ + Examples: + >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2) + >>> q = torch.randn(1, 100, 256) + >>> k = v = torch.randn(1, 50, 256) + >>> output = attn(q, k, v) + >>> print(output.shape) + torch.Size([1, 100, 256]) + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + kv_in_dim: int = None, + ) -> None: + """ + Initialize the Attention module with specified dimensions and settings. + + Args: + embedding_dim (int): Dimensionality of input embeddings. + num_heads (int): Number of attention heads. + downsample_rate (int): Factor by which internal dimensions are downsampled. + kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim. + + Raises: + AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate). + """ + super().__init__() + self.embedding_dim = embedding_dim + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 
+ + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + @staticmethod + def _separate_heads(x: Tensor, num_heads: int) -> Tensor: + """Separate the input tensor into the specified number of attention heads.""" + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + @staticmethod + def _recombine_heads(x: Tensor) -> Tensor: + """Recombine separated attention heads into a single tensor.""" + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + """ + Apply multi-head attention to query, key, and value tensors with optional downsampling. + + Args: + q (Tensor): Query tensor with shape (B, N_q, embedding_dim). + k (Tensor): Key tensor with shape (B, N_k, embedding_dim). + v (Tensor): Value tensor with shape (B, N_k, embedding_dim). + + Returns: + (Tensor): Output tensor after attention with shape (B, N_q, embedding_dim). 
+ """ + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + return self.out_proj(out) diff --git a/tracking/ultralytics/models/sam/modules/utils.py b/tracking/ultralytics/models/sam/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..346bd68a364f5ecc1ae47bfdba686d9be2e8a12f --- /dev/null +++ b/tracking/ultralytics/models/sam/modules/utils.py @@ -0,0 +1,293 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from typing import Tuple + +import torch +import torch.nn.functional as F + + +def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): + """ + Select the closest conditioning frames to a given frame index. + + Args: + frame_idx (int): Current frame index. + cond_frame_outputs (Dict[int, Any]): Dictionary of conditioning frame outputs keyed by frame indices. + max_cond_frame_num (int): Maximum number of conditioning frames to select. + + Returns: + (Tuple[Dict[int, Any], Dict[int, Any]]): A tuple containing two dictionaries: + - selected_outputs: Selected items from cond_frame_outputs. + - unselected_outputs: Items not selected from cond_frame_outputs. 
+ + Examples: + >>> frame_idx = 5 + >>> cond_frame_outputs = {1: "a", 3: "b", 7: "c", 9: "d"} + >>> max_cond_frame_num = 2 + >>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num) + >>> print(selected) + {3: 'b', 7: 'c'} + >>> print(unselected) + {1: 'a', 9: 'd'} + """ + if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: + selected_outputs = cond_frame_outputs + unselected_outputs = {} + else: + assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" + selected_outputs = {} + + # The closest conditioning frame before `frame_idx` (if any) + idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) + if idx_before is not None: + selected_outputs[idx_before] = cond_frame_outputs[idx_before] + + # The closest conditioning frame after `frame_idx` (if any) + idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) + if idx_after is not None: + selected_outputs[idx_after] = cond_frame_outputs[idx_after] + + # Add other temporally closest conditioning frames until reaching a total + # of `max_cond_frame_num` conditioning frames. 
+ num_remain = max_cond_frame_num - len(selected_outputs) + inds_remain = sorted( + (t for t in cond_frame_outputs if t not in selected_outputs), + key=lambda x: abs(x - frame_idx), + )[:num_remain] + selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) + unselected_outputs = {t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs} + + return selected_outputs, unselected_outputs + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """Generate 1D sinusoidal positional embeddings for given positions and dimensions.""" + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +def init_t_xy(end_x: int, end_y: int): + """Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.""" + t = torch.arange(end_x * end_y, dtype=torch.float32) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode="floor").float() + return t_x, t_y + + +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): + """Compute axial complex exponential positional encodings for 2D spatial positions in a grid.""" + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + """Reshape frequency tensor for broadcasting with input tensor, ensuring dimensional compatibility.""" + ndim = x.ndim + assert 0 <= 1 < 
ndim + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_enc( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + repeat_freqs_k: bool = False, +): + """Apply rotary positional encoding to query and key tensors using complex-valued frequency components.""" + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + if xk_ is None: + # No keys to rotate, due to dropout + return xq_out.type_as(xq).to(xq.device), xk + # Repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-2] // xq_.shape[-2] + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) + + +def window_partition(x, window_size): + """ + Partition input tensor into non-overlapping windows with padding if needed. + + Args: + x (torch.Tensor): Input tensor with shape (B, H, W, C). + window_size (int): Size of each window. + + Returns: + (Tuple[torch.Tensor, Tuple[int, int]]): A tuple containing: + - windows (torch.Tensor): Partitioned windows with shape (B * num_windows, window_size, window_size, C). + - (Hp, Wp) (Tuple[int, int]): Padded height and width before partition. 
+ + Examples: + >>> x = torch.randn(1, 16, 16, 3) + >>> windows, (Hp, Wp) = window_partition(x, window_size=4) + >>> print(windows.shape, Hp, Wp) + torch.Size([16, 4, 4, 3]) 16 16 + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Unpartition windowed sequences into original sequences and remove padding. + + This function reverses the windowing process, reconstructing the original input from windowed segments + and removing any padding that was added during the windowing process. + + Args: + windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size, + window_size, C), where B is the batch size, num_windows is the number of windows, window_size is + the size of each window, and C is the number of channels. + window_size (int): Size of each window. + pad_hw (Tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing. + hw (Tuple[int, int]): Original height and width (H, W) of the input before padding and windowing. + + Returns: + (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W + are the original height and width, and C is the number of channels. 
+ + Examples: + >>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels + >>> pad_hw = (16, 16) # Padded height and width + >>> hw = (15, 14) # Original height and width + >>> x = window_unpartition(windows, window_size=8, pad_hw=pad_hw, hw=hw) + >>> print(x.shape) + torch.Size([1, 15, 14, 64]) + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Extract relative positional embeddings based on query and key sizes. + + Args: + q_size (int): Size of the query. + k_size (int): Size of the key. + rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative + distance and C is the embedding dimension. + + Returns: + (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size, + k_size, C). + + Examples: + >>> q_size, k_size = 8, 16 + >>> rel_pos = torch.randn(31, 64) # 31 = 2 * max(8, 16) - 1 + >>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos) + >>> print(extracted_pos.shape) + torch.Size([8, 16, 64]) + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. 
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Add decomposed Relative Positional Embeddings to the attention map. + + This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2 + paper. It enhances the attention mechanism by incorporating spatial relationships between query and key + positions. + + Args: + attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w). + q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C). + rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C). + q_size (Tuple[int, int]): Spatial sequence size of query q as (q_h, q_w). + k_size (Tuple[int, int]): Spatial sequence size of key k as (k_h, k_w). + + Returns: + (torch.Tensor): Updated attention map with added relative positional embeddings, shape + (B, q_h * q_w, k_h * k_w). 
+ + Examples: + >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8 + >>> attn = torch.rand(B, q_h * q_w, k_h * k_w) + >>> q = torch.rand(B, q_h * q_w, C) + >>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C) + >>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C) + >>> q_size, k_size = (q_h, q_w), (k_h, k_w) + >>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size) + >>> print(updated_attn.shape) + torch.Size([1, 64, 64]) + + References: + https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( + B, q_h * q_w, k_h * k_w + ) + + return attn diff --git a/tracking/ultralytics/models/sam/predict.py b/tracking/ultralytics/models/sam/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..b33bc5904faa22581a2f6844dc8cc24e83dc62ed --- /dev/null +++ b/tracking/ultralytics/models/sam/predict.py @@ -0,0 +1,1602 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Generate predictions using the Segment Anything Model (SAM). + +SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance. +This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation +using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image +segmentation tasks. 
+""" + +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn.functional as F + +from ultralytics.data.augment import LetterBox +from ultralytics.engine.predictor import BasePredictor +from ultralytics.engine.results import Results +from ultralytics.utils import DEFAULT_CFG, ops +from ultralytics.utils.torch_utils import select_device, smart_inference_mode + +from .amg import ( + batch_iterator, + batched_mask_to_box, + build_all_layer_point_grids, + calculate_stability_score, + generate_crop_boxes, + is_box_near_crop_edge, + remove_small_regions, + uncrop_boxes_xyxy, + uncrop_masks, +) +from .build import build_sam + + +class Predictor(BasePredictor): + """ + Predictor class for SAM, enabling real-time image segmentation with promptable capabilities. + + This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image + segmentation tasks. It supports various input prompts like points, bounding boxes, and masks for + fine-grained control over segmentation results. + + Attributes: + args (SimpleNamespace): Configuration arguments for the predictor. + model (torch.nn.Module): The loaded SAM model. + device (torch.device): The device (CPU or GPU) on which the model is loaded. + im (torch.Tensor): The preprocessed input image. + features (torch.Tensor): Extracted image features. + prompts (dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks). + segment_all (bool): Flag to indicate if full image segmentation should be performed. + mean (torch.Tensor): Mean values for image normalization. + std (torch.Tensor): Standard deviation values for image normalization. + + Methods: + preprocess: Prepares input images for model inference. + pre_transform: Performs initial transformations on the input image. + inference: Performs segmentation inference based on input prompts. + prompt_inference: Internal function for prompt-based segmentation inference. 
+ generate: Generates segmentation masks for an entire image. + setup_model: Initializes the SAM model for inference. + get_model: Builds and returns a SAM model. + postprocess: Post-processes model outputs to generate final results. + setup_source: Sets up the data source for inference. + set_image: Sets and preprocesses a single image for inference. + get_im_features: Extracts image features using the SAM image encoder. + set_prompts: Sets prompts for subsequent inference. + reset_image: Resets the current image and its features. + remove_small_regions: Removes small disconnected regions and holes from masks. + + Examples: + >>> predictor = Predictor() + >>> predictor.setup_model(model_path="sam_model.pt") + >>> predictor.set_image("image.jpg") + >>> bboxes = [[100, 100, 200, 200]] + >>> results = predictor(bboxes=bboxes) + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initialize the Predictor with configuration, overrides, and callbacks. + + Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or + callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True + for optimal results. + + Args: + cfg (dict): Configuration dictionary containing default settings. + overrides (Dict | None): Dictionary of values to override default configuration. + _callbacks (Dict | None): Dictionary of callback functions to customize behavior. 
+ + Examples: + >>> predictor_example = Predictor(cfg=DEFAULT_CFG) + >>> predictor_example_with_imgsz = Predictor(overrides={"imgsz": 640}) + >>> predictor_example_with_callback = Predictor(_callbacks={"on_predict_start": custom_callback}) + """ + if overrides is None: + overrides = {} + overrides.update(dict(task="segment", mode="predict", batch=1)) + super().__init__(cfg, overrides, _callbacks) + self.args.retina_masks = True + self.im = None + self.features = None + self.prompts = {} + self.segment_all = False + + def preprocess(self, im): + """ + Preprocess the input image for model inference. + + This method prepares the input image by applying transformations and normalization. It supports both + torch.Tensor and list of np.ndarray as input formats. + + Args: + im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays. + + Returns: + im (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype. + + Examples: + >>> predictor = Predictor() + >>> image = torch.rand(1, 3, 640, 640) + >>> preprocessed_image = predictor.preprocess(image) + """ + if self.im is not None: + return self.im + not_tensor = not isinstance(im, torch.Tensor) + if not_tensor: + im = np.stack(self.pre_transform(im)) + im = im[..., ::-1].transpose((0, 3, 1, 2)) + im = np.ascontiguousarray(im) + im = torch.from_numpy(im) + + im = im.to(self.device) + im = im.half() if self.model.fp16 else im.float() + if not_tensor: + im = (im - self.mean) / self.std + return im + + def pre_transform(self, im): + """ + Perform initial transformations on the input image for preprocessing. + + This method applies transformations such as resizing to prepare the image for further preprocessing. + Currently, batched inference is not supported; hence the list length should be 1. + + Args: + im (List[np.ndarray]): List containing a single image in HWC numpy array format. 
+
+        Returns:
+            (List[np.ndarray]): List containing the transformed image.
+
+        Raises:
+            AssertionError: If the input list contains more than one image.
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> image = np.random.rand(480, 640, 3)  # Single HWC image
+            >>> transformed = predictor.pre_transform([image])
+            >>> print(len(transformed))
+            1
+        """
+        assert len(im) == 1, "SAM model does not currently support batched inference"
+        letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
+        return [letterbox(image=x) for x in im]
+
+    def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
+        """
+        Perform image segmentation inference based on the given input cues, using the currently loaded image.
+
+        This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt
+        encoder, and mask decoder for real-time and promptable segmentation tasks.
+
+        Args:
+            im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
+            bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format.
+            points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
+            labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background.
+            masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256.
+            multimask_output (bool): Flag to return multiple masks. Helpful for ambiguous prompts.
+            *args (Any): Additional positional arguments.
+            **kwargs (Any): Additional keyword arguments.
+
+        Returns:
+            (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks.
+            (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
+            (np.ndarray): Bounding boxes per mask, present only when no prompts are given and generate() is used.
+ + Examples: + >>> predictor = Predictor() + >>> predictor.setup_model(model_path="sam_model.pt") + >>> predictor.set_image("image.jpg") + >>> results = predictor(bboxes=[[0, 0, 100, 100]]) + """ + # Override prompts if any stored in self.prompts + bboxes = self.prompts.pop("bboxes", bboxes) + points = self.prompts.pop("points", points) + masks = self.prompts.pop("masks", masks) + labels = self.prompts.pop("labels", labels) + + if all(i is None for i in [bboxes, points, masks]): + return self.generate(im, *args, **kwargs) + + return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output) + + def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False): + """ + Performs image segmentation inference based on input cues using SAM's specialized architecture. + + This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation. + It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks. + + Args: + im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W). + bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels. + labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background. + masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256. + multimask_output (bool): Flag to return multiple masks for ambiguous prompts. + + Raises: + AssertionError: If the number of points don't match the number of labels, in case labels were passed. + + Returns: + (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks. + (np.ndarray): Quality scores predicted by the model for each mask, with length C. 
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> im = torch.rand(1, 3, 1024, 1024)
+            >>> bboxes = [[100, 100, 200, 200]]
+            >>> masks, scores = predictor.prompt_inference(im, bboxes=bboxes)
+        """
+        features = self.get_im_features(im) if self.features is None else self.features
+
+        bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
+        points = (points, labels) if points is not None else None
+        # Embed prompts
+        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
+
+        # Predict masks
+        pred_masks, pred_scores = self.model.mask_decoder(
+            image_embeddings=features,
+            image_pe=self.model.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+        )
+
+        # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
+        # `d` could be 1 or 3 depends on `multimask_output`.
+        return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
+
+    def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
+        """
+        Prepares and transforms the input prompts for processing based on the destination shape.
+
+        Args:
+            dst_shape (tuple): The target shape (height, width) for the prompts.
+            bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
+            points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
+            labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background.
+            masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array.
+
+        Raises:
+            AssertionError: If the number of points don't match the number of labels, in case labels were passed.
+
+        Returns:
+            (tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
+ """ + src_shape = self.batch[1][0].shape[:2] + r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) + # Transform input prompts + if points is not None: + points = torch.as_tensor(points, dtype=torch.float32, device=self.device) + points = points[None] if points.ndim == 1 else points + # Assuming labels are all positive if users don't pass labels. + if labels is None: + labels = np.ones(points.shape[:-1]) + labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device) + assert points.shape[-2] == labels.shape[-1], ( + f"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}." + ) + points *= r + if points.ndim == 2: + # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1) + points, labels = points[:, None, :], labels[:, None] + if bboxes is not None: + bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device) + bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes + bboxes *= r + if masks is not None: + masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1) + return bboxes, points, labels, masks + + def generate( + self, + im, + crop_n_layers=0, + crop_overlap_ratio=512 / 1500, + crop_downscale_factor=1, + point_grids=None, + points_stride=32, + points_batch_size=64, + conf_thres=0.88, + stability_score_thresh=0.95, + stability_score_offset=0.95, + crop_nms_thresh=0.7, + ): + """ + Perform image segmentation using the Segment Anything Model (SAM). + + This method segments an entire image into constituent parts by leveraging SAM's advanced architecture + and real-time performance capabilities. It can optionally work on image crops for finer segmentation. + + Args: + im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W). + crop_n_layers (int): Number of layers for additional mask predictions on image crops. + crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers. 
+ crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer. + point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1]. + points_stride (int): Number of points to sample along each side of the image. + points_batch_size (int): Batch size for the number of points processed simultaneously. + conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction. + stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on stability. + stability_score_offset (float): Offset value for calculating stability score. + crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops. + + Returns: + pred_masks (torch.Tensor): Segmented masks with shape (N, H, W). + pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,). + pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4). + + Examples: + >>> predictor = Predictor() + >>> im = torch.rand(1, 3, 1024, 1024) # Example input image + >>> masks, scores, boxes = predictor.generate(im) + """ + import torchvision # scope for faster 'import ultralytics' + + self.segment_all = True + ih, iw = im.shape[2:] + crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio) + if point_grids is None: + point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor) + pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], [] + for crop_region, layer_idx in zip(crop_regions, layer_idxs): + x1, y1, x2, y2 = crop_region + w, h = x2 - x1, y2 - y1 + area = torch.tensor(w * h, device=im.device) + points_scale = np.array([[w, h]]) # w, h + # Crop image and interpolate to input size + crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False) + # (num_points, 2) + points_for_image = point_grids[layer_idx] * points_scale + crop_masks, crop_scores, crop_bboxes = [], [], [] 
            for (points,) in batch_iterator(points_batch_size, points_for_image):
                pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
                # Interpolate predicted masks to input size
                pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0]
                idx = pred_score > conf_thres
                pred_mask, pred_score = pred_mask[idx], pred_score[idx]

                stability_score = calculate_stability_score(
                    pred_mask, self.model.mask_threshold, stability_score_offset
                )
                idx = stability_score > stability_score_thresh
                pred_mask, pred_score = pred_mask[idx], pred_score[idx]
                # Bool type is much more memory-efficient.
                pred_mask = pred_mask > self.model.mask_threshold
                # (N, 4)
                pred_bbox = batched_mask_to_box(pred_mask).float()
                keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih])
                if not torch.all(keep_mask):
                    pred_bbox, pred_mask, pred_score = pred_bbox[keep_mask], pred_mask[keep_mask], pred_score[keep_mask]

                crop_masks.append(pred_mask)
                crop_bboxes.append(pred_bbox)
                crop_scores.append(pred_score)

            # Do nms within this crop
            crop_masks = torch.cat(crop_masks)
            crop_bboxes = torch.cat(crop_bboxes)
            crop_scores = torch.cat(crop_scores)
            keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou)  # NMS
            crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region)
            crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw)
            crop_scores = crop_scores[keep]

            pred_masks.append(crop_masks)
            pred_bboxes.append(crop_bboxes)
            pred_scores.append(crop_scores)
            region_areas.append(area.expand(len(crop_masks)))

        pred_masks = torch.cat(pred_masks)
        pred_bboxes = torch.cat(pred_bboxes)
        pred_scores = torch.cat(pred_scores)
        region_areas = torch.cat(region_areas)

        # Remove duplicate masks between crops
        if len(crop_regions) > 1:
            # score = 1/area: NMS then prefers predictions from smaller (higher-resolution) crops
            scores = 1 / region_areas
            keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)
            pred_masks, pred_bboxes, pred_scores = pred_masks[keep], pred_bboxes[keep], pred_scores[keep]

        return pred_masks, pred_scores, pred_bboxes

    def setup_model(self, model=None, verbose=True):
        """
        Initializes the Segment Anything Model (SAM) for inference.

        This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
        parameters for image normalization and other Ultralytics compatibility settings.

        Args:
            model (torch.nn.Module | None): A pretrained SAM model. If None, a new model is built based on config.
            verbose (bool): If True, prints selected device information.

        Examples:
            >>> predictor = Predictor()
            >>> predictor.setup_model(model=sam_model, verbose=True)
        """
        device = select_device(self.args.device, verbose=verbose)
        if model is None:
            model = self.get_model()
        model.eval()
        self.model = model.to(device)
        self.device = device
        # NOTE(review): per-channel pixel mean/std on the 0-255 scale; presumably applied
        # during preprocessing — confirm against the base predictor's preprocess().
        self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
        self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)

        # Ultralytics compatibility settings
        self.model.pt = False
        self.model.triton = False
        self.model.stride = 32
        self.model.fp16 = False
        self.done_warmup = True

    def get_model(self):
        """Retrieves or builds the Segment Anything Model (SAM) for image segmentation tasks."""
        return build_sam(self.args.model)

    def postprocess(self, preds, img, orig_imgs):
        """
        Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.

        This method scales masks and boxes to the original image size and applies a threshold to the mask
        predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks.

        Args:
            preds (Tuple[torch.Tensor]): The output from SAM model inference, containing:
                - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W).
                - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1).
                - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True.
            img (torch.Tensor): The processed input image tensor with shape (C, H, W).
            orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.

        Returns:
            results (List[Results]): List of Results objects containing detection masks, bounding boxes, and other
                metadata for each processed image.

        Examples:
            >>> predictor = Predictor()
            >>> preds = predictor.inference(img)
            >>> results = predictor.postprocess(preds, img, orig_imgs)
        """
        # (N, 1, H, W), (N, 1)
        pred_masks, pred_scores = preds[:2]
        pred_bboxes = preds[2] if self.segment_all else None
        names = dict(enumerate(str(i) for i in range(len(pred_masks))))

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]):
            if len(masks) == 0:
                # Empty prediction: emit a (0, 6) box tensor (xyxy + conf + cls) so Results stays well-formed
                masks, pred_bboxes = None, torch.zeros((0, 6), device=pred_masks.device)
            else:
                masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
                masks = masks > self.model.mask_threshold  # to bool
                if pred_bboxes is not None:
                    pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
                else:
                    pred_bboxes = batched_mask_to_box(masks)
                # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.
                cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
                pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
            results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
        # Reset segment-all mode.
        self.segment_all = False
        return results

    def setup_source(self, source):
        """
        Sets up the data source for inference.

        This method configures the data source from which images will be fetched for inference. It supports
        various input types such as image files, directories, video files, and other compatible data sources.

        Args:
            source (str | Path | None): The path or identifier for the image data source. Can be a file path,
                directory path, URL, or other supported source types.

        Examples:
            >>> predictor = Predictor()
            >>> predictor.setup_source("path/to/images")
            >>> predictor.setup_source("video.mp4")
            >>> predictor.setup_source(None)  # Uses default source if available

        Notes:
            - If source is None, the method may use a default source if configured.
            - The method adapts to different source types and prepares them for subsequent inference steps.
            - Supported source types may include local files, directories, URLs, and video streams.
        """
        # A None source is a no-op, leaving any previously configured dataset untouched
        if source is not None:
            super().setup_source(source)

    def set_image(self, image):
        """
        Preprocesses and sets a single image for inference.

        This method prepares the model for inference on a single image by setting up the model if not already
        initialized, configuring the data source, and preprocessing the image for feature extraction. It
        ensures that only one image is set at a time and extracts image features for subsequent use.

        Args:
            image (str | np.ndarray): Path to the image file as a string, or a numpy array representing
                an image read by cv2.

        Raises:
            AssertionError: If more than one image is attempted to be set.

        Examples:
            >>> predictor = Predictor()
            >>> predictor.set_image("path/to/image.jpg")
            >>> predictor.set_image(cv2.imread("path/to/image.jpg"))

        Notes:
            - This method should be called before performing inference on a new image.
            - The extracted features are stored in the `self.features` attribute for later use.
        """
        if self.model is None:
            self.setup_model(model=None)
        self.setup_source(image)
        assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
        # Only the first (and only) batch is consumed; its features are cached for later prompts
        for batch in self.dataset:
            im = self.preprocess(batch[1])
            self.features = self.get_im_features(im)
            break

    def get_im_features(self, im):
        """Extracts image features using the SAM model's image encoder for subsequent mask prediction."""
        assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
            f"SAM models only support square image size, but got {self.imgsz}."
        )
        self.model.set_imgsz(self.imgsz)
        return self.model.image_encoder(im)

    def set_prompts(self, prompts):
        """Sets prompts for subsequent inference operations."""
        self.prompts = prompts

    def reset_image(self):
        """Resets the current image and its features, clearing them for subsequent inference."""
        self.im = None
        self.features = None

    @staticmethod
    def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
        """
        Remove small disconnected regions and holes from segmentation masks.

        This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM).
        It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
        Suppression (NMS) to eliminate any newly created duplicate boxes.

        Args:
            masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of
                masks, H is height, and W is width.
            min_area (int): Minimum area threshold for removing disconnected regions and holes. Regions smaller than
                this will be removed.
            nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes.

        Returns:
            new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
            keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
        Examples:
            >>> masks = torch.rand(5, 640, 640) > 0.5  # 5 random binary masks
            >>> new_masks, keep = remove_small_regions(masks, min_area=100, nms_thresh=0.7)
            >>> print(f"Original masks: {masks.shape}, Processed masks: {new_masks.shape}")
            >>> print(f"Indices of kept masks: {keep}")
        """
        import torchvision  # scope for faster 'import ultralytics'

        if len(masks) == 0:
            return masks

        # Filter small disconnected regions and holes
        new_masks = []
        scores = []
        for mask in masks:
            mask = mask.cpu().numpy().astype(np.uint8)
            # NOTE: this resolves to the module-level `remove_small_regions` helper (the
            # staticmethod shadows the name only at class scope, not inside this body).
            mask, changed = remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Give score=0 to changed masks and 1 to unchanged masks so NMS prefers masks not needing postprocessing
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates
        new_masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(new_masks)
        keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh)

        return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep


class SAM2Predictor(Predictor):
    """
    SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.

    This class extends the base Predictor class to implement SAM2-specific functionality for image
    segmentation tasks. It provides methods for model initialization, feature extraction, and
    prompt-based inference.

    Attributes:
        _bb_feat_sizes (List[Tuple[int, int]]): Feature sizes for different backbone levels.
        model (torch.nn.Module): The loaded SAM2 model.
        device (torch.device): The device (CPU or GPU) on which the model is loaded.
        features (Dict[str, torch.Tensor]): Cached image features for efficient inference.
        segment_all (bool): Flag to indicate if all segments should be predicted.
        prompts (dict): Dictionary to store various types of prompts for inference.

    Methods:
        get_model: Retrieves and initializes the SAM2 model.
        prompt_inference: Performs image segmentation inference based on various prompts.
        set_image: Preprocesses and sets a single image for inference.
        get_im_features: Extracts and processes image features using SAM2's image encoder.

    Examples:
        >>> predictor = SAM2Predictor(cfg)
        >>> predictor.set_image("path/to/image.jpg")
        >>> bboxes = [[100, 100, 200, 200]]
        >>> result = predictor(bboxes=bboxes)[0]
        >>> print(f"Predicted {len(result.masks)} masks with average score {result.boxes.conf.mean():.2f}")
    """

    # Spatial sizes of backbone feature maps (highest to lowest resolution);
    # recomputed from imgsz in get_im_features.
    _bb_feat_sizes = [
        (256, 256),
        (128, 128),
        (64, 64),
    ]

    def get_model(self):
        """Retrieves and initializes the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
        return build_sam(self.args.model)

    def prompt_inference(
        self,
        im,
        bboxes=None,
        points=None,
        labels=None,
        masks=None,
        multimask_output=False,
        img_idx=-1,
    ):
        """
        Performs image segmentation inference based on various prompts using SAM2 architecture.

        This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images
        based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and
        multi-object prediction scenarios.

        Args:
            im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
            bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4).
            points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels.
            labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
            masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).
            multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
            img_idx (int): Index of the image in the batch to process.

        Returns:
            (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
            (np.ndarray): Quality scores for each mask, with length C.

        Examples:
            >>> predictor = SAM2Predictor(cfg)
            >>> image = torch.rand(1, 3, 640, 640)
            >>> bboxes = [[100, 100, 200, 200]]
            >>> result = predictor(image, bboxes=bboxes)[0]
            >>> print(f"Generated {result.masks.shape[0]} masks with average score {result.boxes.conf.mean():.2f}")

        Notes:
            - The method supports batched inference for multiple objects when points or bboxes are provided.
            - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions.
            - When both bboxes and points are provided, they are merged into a single 'points' input for the model.
        """
        # Reuse cached features from set_image() when available
        features = self.get_im_features(im) if self.features is None else self.features

        points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
        points = (points, labels) if points is not None else None

        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
            points=points,
            boxes=None,
            masks=masks,
        )
        # Predict masks
        batched_mode = points is not None and points[0].shape[0] > 1  # multi object prediction
        high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
        pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
            image_embeddings=features["image_embed"][img_idx].unsqueeze(0),
            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=batched_mode,
            high_res_features=high_res_features,
        )
        # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
        # `d` could be 1 or 3 depends on `multimask_output`.
        return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)

    def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
        """
        Prepares and transforms the input prompts for processing based on the destination shape.

        Args:
            dst_shape (tuple): The target shape (height, width) for the prompts.
            bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
            points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
            labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
            masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array.

        Raises:
            AssertionError: If the number of points don't match the number of labels, in case labels were passed.

        Returns:
            (tuple): A tuple containing transformed points, labels, and masks.
        """
        bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
        if bboxes is not None:
            # Each box becomes a pair of corner points; NOTE(review): labels 2/3 appear to be
            # SAM2's special box-corner labels — confirm against sam_prompt_encoder semantics.
            bboxes = bboxes.view(-1, 2, 2)
            bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
            # NOTE: merge "boxes" and "points" into a single "points" input
            # (where boxes are added at the beginning) to model.sam_prompt_encoder
            if points is not None:
                points = torch.cat([bboxes, points], dim=1)
                labels = torch.cat([bbox_labels, labels], dim=1)
            else:
                points, labels = bboxes, bbox_labels
        return points, labels, masks

    def set_image(self, image):
        """
        Preprocesses and sets a single image for inference using the SAM2 model.

        This method initializes the model if not already done, configures the data source to the specified image,
        and preprocesses the image for feature extraction. It supports setting only one image at a time.

        Args:
            image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image.

        Raises:
            AssertionError: If more than one image is attempted to be set.

        Examples:
            >>> predictor = SAM2Predictor()
            >>> predictor.set_image("path/to/image.jpg")
            >>> predictor.set_image(np.array([...]))  # Using a numpy array

        Notes:
            - This method must be called before performing any inference on a new image.
            - The method caches the extracted features for efficient subsequent inferences on the same image.
            - Only one image can be set at a time. To process multiple images, call this method for each new image.
        """
        if self.model is None:
            self.setup_model(model=None)
        self.setup_source(image)
        assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
        for batch in self.dataset:
            im = self.preprocess(batch[1])
            self.features = self.get_im_features(im)
            break

    def get_im_features(self, im):
        """Extracts image features from the SAM image encoder for subsequent processing."""
        assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
            f"SAM 2 models only support square image size, but got {self.imgsz}."
        )
        self.model.set_imgsz(self.imgsz)
        # Backbone feature map sizes at strides 4, 8 and 16 relative to imgsz
        self._bb_feat_sizes = [[x // (4 * i) for x in self.imgsz] for i in [1, 2, 4]]

        backbone_out = self.model.forward_image(im)
        _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
        if self.model.directly_add_no_mem_embed:
            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
        feats = [
            feat.permute(1, 2, 0).view(1, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}


class SAM2VideoPredictor(SAM2Predictor):
    """
    SAM2VideoPredictor to handle user interactions with videos and manage inference states.
    This class extends the functionality of SAM2Predictor to support video processing and maintains
    the state of inference operations. It includes configurations for managing non-overlapping masks,
    clearing memory for non-conditional inputs, and setting up callbacks for prediction events.

    Attributes:
        inference_state (dict): A dictionary to store the current state of inference operations.
        non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping.
        clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs.
        clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios.
        callbacks (dict): A dictionary of callbacks for various prediction lifecycle events.

    Args:
        cfg (dict, Optional): Configuration settings for the predictor. Defaults to DEFAULT_CFG.
        overrides (dict, Optional): Additional configuration overrides. Defaults to None.
        _callbacks (list, Optional): Custom callbacks to be added. Defaults to None.

    Note:
        The `fill_hole_area` attribute is defined but not used in the current implementation.
    """

    # fill_hole_area = 8  # not used

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize the predictor with configuration and optional overrides.

        This constructor initializes the SAM2VideoPredictor with a given configuration, applies any
        specified overrides, and sets up the inference state along with certain flags
        that control the behavior of the predictor.

        Args:
            cfg (dict): Configuration dictionary containing default settings.
            overrides (Dict | None): Dictionary of values to override default configuration.
            _callbacks (Dict | None): Dictionary of callback functions to customize behavior.

        Examples:
            >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
            >>> predictor_example_with_imgsz = SAM2VideoPredictor(overrides={"imgsz": 640})
            >>> predictor_example_with_callback = SAM2VideoPredictor(_callbacks={"on_predict_start": custom_callback})
        """
        super().__init__(cfg, overrides, _callbacks)
        self.inference_state = {}
        self.non_overlap_masks = True
        self.clear_non_cond_mem_around_input = False
        self.clear_non_cond_mem_for_multi_obj = False
        # Lazily initialize the video inference state once prediction starts
        self.callbacks["on_predict_start"].append(self.init_state)

    def get_model(self):
        """
        Retrieves and configures the model with binarization enabled.

        Note:
            This method overrides the base class implementation to set the binarize flag to True.
        """
        model = super().get_model()
        model.set_binarize(True)
        return model

    def inference(self, im, bboxes=None, points=None, labels=None, masks=None):
        """
        Perform image segmentation inference based on the given input cues, using the currently loaded image. This
        method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
        mask decoder for real-time and promptable segmentation tasks.

        Args:
            im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
            masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.

        Returns:
            (np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks.
            (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
        """
        # Override prompts if any stored in self.prompts
        bboxes = self.prompts.pop("bboxes", bboxes)
        points = self.prompts.pop("points", points)
        masks = self.prompts.pop("masks", masks)

        frame = self.dataset.frame
        self.inference_state["im"] = im
        output_dict = self.inference_state["output_dict"]
        if len(output_dict["cond_frame_outputs"]) == 0:  # initialize prompts
            points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
            if points is not None:
                # One tracked object per prompt; obj_id doubles as the object index
                for i in range(len(points)):
                    self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame)
            elif masks is not None:
                for i in range(len(masks)):
                    self.add_new_prompts(obj_id=i, masks=masks[[i]], frame_idx=frame)
            self.propagate_in_video_preflight()

        consolidated_frame_inds = self.inference_state["consolidated_frame_inds"]
        batch_size = len(self.inference_state["obj_idx_to_id"])
        # Still empty after prompt initialization means no prompts were supplied at all
        if len(output_dict["cond_frame_outputs"]) == 0:
            raise RuntimeError("No points are provided; please add points first")

        if frame in consolidated_frame_inds["cond_frame_outputs"]:
            storage_key = "cond_frame_outputs"
            current_out = output_dict[storage_key][frame]
            if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):
                # clear non-conditioning memory of the surrounding frames
                self._clear_non_cond_mem_around_input(frame)
        elif frame in consolidated_frame_inds["non_cond_frame_outputs"]:
            storage_key = "non_cond_frame_outputs"
            current_out = output_dict[storage_key][frame]
        else:
            storage_key = "non_cond_frame_outputs"
            current_out = self._run_single_frame_inference(
                output_dict=output_dict,
                frame_idx=frame,
                batch_size=batch_size,
                is_init_cond_frame=False,
                point_inputs=None,
                mask_inputs=None,
                reverse=False,
                run_mem_encoder=True,
            )
            output_dict[storage_key][frame] = current_out
        # Create slices of per-object outputs for subsequent interaction with each
        # individual object after tracking.
        self._add_output_per_object(frame, current_out, storage_key)
        self.inference_state["frames_already_tracked"].append(frame)
        pred_masks = current_out["pred_masks"].flatten(0, 1)
        pred_masks = pred_masks[(pred_masks > self.model.mask_threshold).sum((1, 2)) > 0]  # filter blank masks

        return pred_masks, torch.ones(len(pred_masks), dtype=pred_masks.dtype, device=pred_masks.device)

    def postprocess(self, preds, img, orig_imgs):
        """
        Post-processes the predictions to apply non-overlapping constraints if required.

        This method extends the post-processing functionality by applying non-overlapping constraints
        to the predicted masks if the `non_overlap_masks` flag is set to True. This ensures that
        the masks do not overlap, which can be useful for certain applications.

        Args:
            preds (Tuple[torch.Tensor]): The predictions from the model.
            img (torch.Tensor): The processed image tensor.
            orig_imgs (List[np.ndarray]): The original images before processing.

        Returns:
            results (list): The post-processed predictions.

        Note:
            If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks.
        """
        results = super().postprocess(preds, img, orig_imgs)
        if self.non_overlap_masks:
            for result in results:
                if result.masks is None or len(result.masks) == 0:
                    continue
                result.masks.data = self.model._apply_non_overlapping_constraints(result.masks.data.unsqueeze(0))[0]
        return results

    @smart_inference_mode()
    def add_new_prompts(
        self,
        obj_id,
        points=None,
        labels=None,
        masks=None,
        frame_idx=0,
    ):
        """
        Adds new points or masks to a specific frame for a given object ID.

        This method updates the inference state with new prompts (points or masks) for a specified
        object and frame index. It ensures that the prompts are either points or masks, but not both,
        and updates the internal state accordingly.
        It also handles the generation of new segmentations
        based on the provided prompts and the existing state.

        Args:
            obj_id (int): The ID of the object to which the prompts are associated.
            points (torch.Tensor, Optional): The coordinates of the points of interest. Defaults to None.
            labels (torch.Tensor, Optional): The labels corresponding to the points. Defaults to None.
            masks (torch.Tensor, optional): Binary masks for the object. Defaults to None.
            frame_idx (int, optional): The index of the frame to which the prompts are applied. Defaults to 0.

        Returns:
            (tuple): A tuple containing the flattened predicted masks and a tensor of ones indicating the number of objects.

        Raises:
            AssertionError: If both `masks` and `points` are provided, or neither is provided.

        Note:
            - Only one type of prompt (either points or masks) can be added per call.
            - If the frame is being tracked for the first time, it is treated as an initial conditioning frame.
            - The method handles the consolidation of outputs and resizing of masks to the original video resolution.
        """
        assert (masks is None) ^ (points is None), "'masks' and 'points' prompts are not compatible with each other."
        obj_idx = self._obj_id_to_idx(obj_id)

        # Record the new prompt for this (object, frame) and drop any stale prompt of the other kind
        point_inputs = None
        pop_key = "point_inputs_per_obj"
        if points is not None:
            point_inputs = {"point_coords": points, "point_labels": labels}
            self.inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs
            pop_key = "mask_inputs_per_obj"
        self.inference_state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks
        self.inference_state[pop_key][obj_idx].pop(frame_idx, None)
        # If this frame hasn't been tracked before, we treat it as an initial conditioning
        # frame, meaning that the inputs points are to generate segments on this frame without
        # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
        # the input points will be used to correct the already tracked masks.
        is_init_cond_frame = frame_idx not in self.inference_state["frames_already_tracked"]
        obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx]
        obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx]
        # Add a frame to conditioning output if it's an initial conditioning frame or
        # if the model sees all frames receiving clicks/mask as conditioning frames.
        is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

        # Get any previously predicted mask logits on this object and feed it along with
        # the new clicks into the SAM mask decoder.
        prev_sam_mask_logits = None
        # lookup temporary output dict first, which contains the most recent output
        # (if not found, then lookup conditioning and non-conditioning frame output)
        if point_inputs is not None:
            prev_out = (
                obj_temp_output_dict[storage_key].get(frame_idx)
                or obj_output_dict["cond_frame_outputs"].get(frame_idx)
                or obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
            )

            if prev_out is not None and prev_out.get("pred_masks") is not None:
                prev_sam_mask_logits = prev_out["pred_masks"].to(device=self.device, non_blocking=True)
                # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
                prev_sam_mask_logits.clamp_(-32.0, 32.0)
        current_out = self._run_single_frame_inference(
            output_dict=obj_output_dict,  # run on the slice of a single object
            frame_idx=frame_idx,
            batch_size=1,  # run on the slice of a single object
            is_init_cond_frame=is_init_cond_frame,
            point_inputs=point_inputs,
            mask_inputs=masks,
            reverse=False,
            # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
            # at the beginning of `propagate_in_video` (after user finalize their clicks). This
            # allows us to enforce non-overlapping constraints on all objects before encoding
            # them into memory.
            run_mem_encoder=False,
            prev_sam_mask_logits=prev_sam_mask_logits,
        )
        # Add the output to the output dict (to be used as future memory)
        obj_temp_output_dict[storage_key][frame_idx] = current_out

        # Resize the output mask to the original video resolution
        consolidated_out = self._consolidate_temp_output_across_obj(
            frame_idx,
            is_cond=is_cond,
            run_mem_encoder=False,
        )
        pred_masks = consolidated_out["pred_masks"].flatten(0, 1)
        # NOTE(review): pred_masks is flattened twice (here and above); for a (N, 1, H, W)
        # consolidated tensor the second flatten(0, 1) merges N and H — confirm intended.
        return pred_masks.flatten(0, 1), torch.ones(1, dtype=pred_masks.dtype, device=pred_masks.device)

    @smart_inference_mode()
    def propagate_in_video_preflight(self):
        """
        Prepare inference_state and consolidate temporary outputs before tracking.

        This method marks the start of tracking, disallowing the addition of new objects until the session is reset.
        It consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`.
        Additionally, it clears non-conditioning memory around input frames and ensures that the state is consistent
        with the provided inputs.
        """
        # Tracking has started and we don't allow adding new objects until session is reset.
        self.inference_state["tracking_has_started"] = True
        batch_size = len(self.inference_state["obj_idx_to_id"])

        # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
        # add them into "output_dict".
        temp_output_dict_per_obj = self.inference_state["temp_output_dict_per_obj"]
        output_dict = self.inference_state["output_dict"]
        # "consolidated_frame_inds" contains indices of those frames where consolidated
        # temporary outputs have been added (either in this call or any previous calls
        # to `propagate_in_video_preflight`).
        consolidated_frame_inds = self.inference_state["consolidated_frame_inds"]
        # NOTE(review): iterating a set literal; order is deterministic for bools in CPython,
        # but a list [False, True] would make the intent explicit.
        for is_cond in {False, True}:
            # Separately consolidate conditioning and non-conditioning temp outputs
            storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
            # Find all the frames that contain temporary outputs for any objects
            # (these should be the frames that have just received clicks for mask inputs
            # via `add_new_points` or `add_new_mask`)
            temp_frame_inds = set()
            for obj_temp_output_dict in temp_output_dict_per_obj.values():
                temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
            consolidated_frame_inds[storage_key].update(temp_frame_inds)
            # consolidate the temporary output across all objects on this frame
            for frame_idx in temp_frame_inds:
                consolidated_out = self._consolidate_temp_output_across_obj(
                    frame_idx, is_cond=is_cond, run_mem_encoder=True
                )
                # merge them into "output_dict" and also create per-object slices
                output_dict[storage_key][frame_idx] = consolidated_out
                self._add_output_per_object(frame_idx, consolidated_out, storage_key)
                if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):
                    # clear non-conditioning memory of the surrounding frames
                    self._clear_non_cond_mem_around_input(frame_idx)

            # clear temporary outputs in `temp_output_dict_per_obj`
            for obj_temp_output_dict in temp_output_dict_per_obj.values():
                obj_temp_output_dict[storage_key].clear()

        # edge case: if an output is added to "cond_frame_outputs", we remove any prior
        # output on the same frame in "non_cond_frame_outputs"
        for frame_idx in output_dict["cond_frame_outputs"]:
            output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
        for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
            for frame_idx in obj_output_dict["cond_frame_outputs"]:
                obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
        for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
            assert frame_idx in output_dict["cond_frame_outputs"]
            consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)

        # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
        # with either points or mask inputs (which should be true under a correct workflow).
        all_consolidated_frame_inds = (
            consolidated_frame_inds["cond_frame_outputs"] | consolidated_frame_inds["non_cond_frame_outputs"]
        )
        input_frames_inds = set()
        for point_inputs_per_frame in self.inference_state["point_inputs_per_obj"].values():
            input_frames_inds.update(point_inputs_per_frame.keys())
        for mask_inputs_per_frame in self.inference_state["mask_inputs_per_obj"].values():
            input_frames_inds.update(mask_inputs_per_frame.keys())
        assert all_consolidated_frame_inds == input_frames_inds

    @staticmethod
    def init_state(predictor):
        """
        Initialize an inference state for the predictor.

        This function sets up the initial state required for performing inference on video data.
        It includes initializing various dictionaries and ordered dictionaries that will store
        inputs, outputs, and other metadata relevant to the tracking process.

        Args:
            predictor (SAM2VideoPredictor): The predictor object for which to initialize the state.
        """
        if len(predictor.inference_state) > 0:  # means initialized
            return
        assert predictor.dataset is not None
        assert predictor.dataset.mode == "video"

        inference_state = {
            "num_frames": predictor.dataset.frames,
            "point_inputs_per_obj": {},  # inputs points on each frame
            "mask_inputs_per_obj": {},  # inputs mask on each frame
            "constants": {},  # values that don't change across frames (so we only need to hold one copy of them)
            # mapping between client-side object id and model-side object index
            "obj_id_to_idx": OrderedDict(),
            "obj_idx_to_id": OrderedDict(),
            "obj_ids": [],
            # A storage to hold the model's tracking results and states on each frame
            "output_dict": {
                "cond_frame_outputs": {},  # maps frame_idx to that frame's output dict
                "non_cond_frame_outputs": {},  # maps frame_idx to that frame's output dict
            },
            # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
            "output_dict_per_obj": {},
            # A temporary storage to hold new outputs when user interact with a frame
            # to add clicks or mask (it's merged into "output_dict" before propagation starts)
            "temp_output_dict_per_obj": {},
            # Frames that already holds consolidated outputs from click or mask inputs
            # (we directly use their consolidated outputs during tracking)
            "consolidated_frame_inds": {
                "cond_frame_outputs": set(),  # set containing frame indices
                "non_cond_frame_outputs": set(),  # set containing frame indices
            },
            # metadata for each tracking frame (e.g. which direction it's tracked)
            "tracking_has_started": False,
            "frames_already_tracked": [],
        }
        predictor.inference_state = inference_state

    def get_im_features(self, im, batch=1):
        """
        Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks.

        Args:
            im (torch.Tensor): The input image tensor.
            batch (int, optional): The batch size for expanding features if there are multiple prompts. Defaults to 1.
        Returns:
            vis_feats (torch.Tensor): The visual features extracted from the image.
            vis_pos_embed (torch.Tensor): The positional embeddings for the visual features.
            feat_sizes (List(Tuple[int])): A list containing the sizes of the extracted features.

        Note:
            - If `batch` is greater than 1, the features are expanded to fit the batch size.
            - The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features.
        """
        backbone_out = self.model.forward_image(im)
        if batch > 1:  # expand features if there's more than one prompt
            for i, feat in enumerate(backbone_out["backbone_fpn"]):
                backbone_out["backbone_fpn"][i] = feat.expand(batch, -1, -1, -1)
            for i, pos in enumerate(backbone_out["vision_pos_enc"]):
                pos = pos.expand(batch, -1, -1, -1)
                backbone_out["vision_pos_enc"][i] = pos
        _, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out)
        return vis_feats, vis_pos_embed, feat_sizes

    def _obj_id_to_idx(self, obj_id):
        """
        Map client-side object id to model-side object index.

        Args:
            obj_id (int): The unique identifier of the object provided by the client side.

        Returns:
            obj_idx (int): The index of the object on the model side.

        Raises:
            RuntimeError: If an attempt is made to add a new object after tracking has started.

        Note:
            - The method updates or retrieves mappings between object IDs and indices stored in
              `inference_state`.
            - It ensures that new objects can only be added before tracking commences.
            - It maintains two-way mappings between IDs and indices (`obj_id_to_idx` and `obj_idx_to_id`).
            - Additional data structures are initialized for the new object to store inputs and outputs.
        """
        obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None)
        if obj_idx is not None:
            return obj_idx

        # This is a new object id not sent to the server before. We only allow adding
        # new objects *before* the tracking starts.
        allow_new_object = not self.inference_state["tracking_has_started"]
        if allow_new_object:
            # get the next object slot
            obj_idx = len(self.inference_state["obj_id_to_idx"])
            self.inference_state["obj_id_to_idx"][obj_id] = obj_idx
            self.inference_state["obj_idx_to_id"][obj_idx] = obj_id
            self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"])
            # set up input and output structures for this object
            self.inference_state["point_inputs_per_obj"][obj_idx] = {}
            self.inference_state["mask_inputs_per_obj"][obj_idx] = {}
            self.inference_state["output_dict_per_obj"][obj_idx] = {
                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            }
            self.inference_state["temp_output_dict_per_obj"][obj_idx] = {
                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            }
            return obj_idx
        else:
            raise RuntimeError(
                f"Cannot add new object id {obj_id} after tracking starts. "
                f"All existing object ids: {self.inference_state['obj_ids']}. "
                f"Please call 'reset_state' to restart from scratch."
            )

    def _run_single_frame_inference(
        self,
        output_dict,
        frame_idx,
        batch_size,
        is_init_cond_frame,
        point_inputs,
        mask_inputs,
        reverse,
        run_mem_encoder,
        prev_sam_mask_logits=None,
    ):
        """
        Run tracking on a single frame based on current inputs and previous memory.

        Args:
            output_dict (dict): The dictionary containing the output states of the tracking process.
            frame_idx (int): The index of the current frame.
            batch_size (int): The batch size for processing the frame.
            is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame.
            point_inputs (dict, Optional): Input points and their labels. Defaults to None.
            mask_inputs (torch.Tensor, Optional): Input binary masks. Defaults to None.
            reverse (bool): Indicates if the tracking should be performed in reverse order.
            run_mem_encoder (bool): Indicates if the memory encoder should be executed.
            prev_sam_mask_logits (torch.Tensor, Optional): Previous mask logits for the current object. Defaults to None.

        Returns:
            current_out (dict): A dictionary containing the output of the tracking step, including updated features and predictions.

        Raises:
            AssertionError: If both `point_inputs` and `mask_inputs` are provided simultaneously
                (passing neither is allowed; only passing both violates the assertion below).

        Note:
            - The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive.
            - The method retrieves image features using the `get_im_features` method.
            - The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored.
            - The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements.
        """
        # Retrieve correct image features
        current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(
            self.inference_state["im"], batch_size
        )

        # point and mask should not appear as input simultaneously on the same frame
        assert point_inputs is None or mask_inputs is None
        current_out = self.model.track_step(
            frame_idx=frame_idx,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=point_inputs,
            mask_inputs=mask_inputs,
            output_dict=output_dict,
            num_frames=self.inference_state["num_frames"],
            track_in_reverse=reverse,
            run_mem_encoder=run_mem_encoder,
            prev_sam_mask_logits=prev_sam_mask_logits,
        )

        # store memory features in half precision to cut the per-frame memory footprint
        maskmem_features = current_out["maskmem_features"]
        if maskmem_features is not None:
            current_out["maskmem_features"] = maskmem_features.to(
                dtype=torch.float16, device=self.device, non_blocking=True
            )
        # NOTE: Do not support the `fill_holes_in_mask_scores` function since it needs cuda extensions
        # potentially fill holes in the predicted masks
        # if self.fill_hole_area > 0:
        #     pred_masks = current_out["pred_masks"].to(self.device, non_blocking=True)
        #     pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area)

        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        current_out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(current_out["maskmem_pos_enc"])
        return current_out

    def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):
        """
        Caches and manages the positional encoding for mask memory across frames and objects.

        This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for
        mask memory, which is constant across frames and objects, thus reducing the amount of
        redundant information stored during an inference session. It checks if the positional
        encoding has already been cached; if not, it caches a slice of the provided encoding.
        If the batch size is greater than one, it expands the cached positional encoding to match
        the current batch size.

        Args:
            out_maskmem_pos_enc (List[torch.Tensor] or None): The positional encoding for mask memory.
                Should be a list of tensors or None.

        Returns:
            out_maskmem_pos_enc (List[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.

        Note:
            - The method assumes that `out_maskmem_pos_enc` is a list of tensors or None.
            - Only a single object's slice is cached since the encoding is the same across objects.
            - The method checks if the positional encoding has already been cached in the session's constants.
            - If the batch size is greater than one, the cached encoding is expanded to fit the batch size.
        """
        model_constants = self.inference_state["constants"]
        # "out_maskmem_pos_enc" should be either a list of tensors or None
        if out_maskmem_pos_enc is not None:
            if "maskmem_pos_enc" not in model_constants:
                assert isinstance(out_maskmem_pos_enc, list)
                # only take the slice for one object, since it's same across objects
                maskmem_pos_enc = [x[:1].clone() for x in out_maskmem_pos_enc]
                model_constants["maskmem_pos_enc"] = maskmem_pos_enc
            else:
                maskmem_pos_enc = model_constants["maskmem_pos_enc"]
            # expand the cached maskmem_pos_enc to the actual batch size
            batch_size = out_maskmem_pos_enc[0].size(0)
            if batch_size > 1:
                out_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc]
        return out_maskmem_pos_enc

    def _consolidate_temp_output_across_obj(
        self,
        frame_idx,
        is_cond=False,
        run_mem_encoder=False,
    ):
        """
        Consolidates per-object temporary outputs into a single output for all objects.

        This method combines the temporary outputs for each object on a given frame into a unified
        output. It fills in any missing objects either from the main output dictionary or leaves
        placeholders if they do not exist in the main output. Optionally, it can re-run the memory
        encoder after applying non-overlapping constraints to the object scores.

        Args:
            frame_idx (int): The index of the frame for which to consolidate outputs.
            is_cond (bool, Optional): Indicates if the frame is considered a conditioning frame.
                Defaults to False.
            run_mem_encoder (bool, Optional): Specifies whether to run the memory encoder after
                consolidating the outputs. Defaults to False.

        Returns:
            consolidated_out (dict): A consolidated output dictionary containing the combined results for all objects.

        Note:
            - The method initializes the consolidated output with placeholder values for missing objects.
            - It searches for outputs in both the temporary and main output dictionaries.
            - If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder.
            - The `maskmem_features` and `maskmem_pos_enc` are only populated when `run_mem_encoder` is True.
        """
        batch_size = len(self.inference_state["obj_idx_to_id"])
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

        # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
        # will be added when rerunning the memory encoder after applying non-overlapping
        # constraints to object scores. Its "pred_masks" are prefilled with a large
        # negative value (NO_OBJ_SCORE) to represent missing objects.
        consolidated_out = {
            "maskmem_features": None,
            "maskmem_pos_enc": None,
            "pred_masks": torch.full(
                size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
                fill_value=-1024.0,
                dtype=torch.float32,
                device=self.device,
            ),
            "obj_ptr": torch.full(
                size=(batch_size, self.model.hidden_dim),
                fill_value=-1024.0,
                dtype=torch.float32,
                device=self.device,
            ),
            "object_score_logits": torch.full(
                size=(batch_size, 1),
                # default to 10.0 for object_score_logits, i.e. assuming the object is
                # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
                fill_value=10.0,
                dtype=torch.float32,
                device=self.device,
            ),
        }
        for obj_idx in range(batch_size):
            obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx]
            obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx]
            out = (
                obj_temp_output_dict[storage_key].get(frame_idx)
                # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
                # we fall back and look up its previous output in "output_dict_per_obj".
                # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
                # "output_dict_per_obj" to find a previous output for this object.
                or obj_output_dict["cond_frame_outputs"].get(frame_idx)
                or obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
            )
            # If the object doesn't appear in "output_dict_per_obj" either, we skip it
            # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
            # placeholder above) and set its object pointer to be a dummy pointer.
            if out is None:
                # Fill in dummy object pointers for those objects without any inputs or
                # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
                # i.e. when we need to build the memory for tracking).
                if run_mem_encoder:
                    # fill object pointer with a dummy pointer (based on an empty mask)
                    consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = self._get_empty_mask_ptr(frame_idx)
                continue
            # Add the temporary object output mask to consolidated output mask
            consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = out["pred_masks"]
            consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]

        # Optionally, apply non-overlapping constraints on the consolidated scores and rerun the memory encoder
        if run_mem_encoder:
            high_res_masks = F.interpolate(
                consolidated_out["pred_masks"],
                size=self.imgsz,
                mode="bilinear",
                align_corners=False,
            )
            if self.model.non_overlap_masks_for_mem_enc:
                high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks)
            consolidated_out["maskmem_features"], consolidated_out["maskmem_pos_enc"] = self._run_memory_encoder(
                batch_size=batch_size,
                high_res_masks=high_res_masks,
                is_mask_from_pts=True,  # these frames are what the user interacted with
                object_score_logits=consolidated_out["object_score_logits"],
            )

        return consolidated_out

    def _get_empty_mask_ptr(self, frame_idx):
        """
        Get a dummy object pointer based on an empty mask on the current frame.

        Args:
            frame_idx (int): The index of the current frame for which to generate the dummy object pointer.
        Returns:
            (torch.Tensor): A tensor representing the dummy object pointer generated from the empty mask.
        """
        # Retrieve correct image features
        current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(self.inference_state["im"])

        # Feed the empty mask and image feature above to get a dummy object pointer
        current_out = self.model.track_step(
            frame_idx=frame_idx,
            is_init_cond_frame=True,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=None,
            # A dummy (empty) mask with a single object
            mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device),
            output_dict={},
            num_frames=self.inference_state["num_frames"],
            track_in_reverse=False,
            run_mem_encoder=False,
            prev_sam_mask_logits=None,
        )
        return current_out["obj_ptr"]

    def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts):
        """
        Run the memory encoder on masks.

        This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their
        memory also needs to be computed again with the memory encoder.

        Args:
            batch_size (int): The batch size for processing the frame.
            high_res_masks (torch.Tensor): High-resolution masks for which to compute the memory.
            object_score_logits (torch.Tensor): Logits representing the object scores.
            is_mask_from_pts (bool): Indicates if the mask is derived from point interactions.

        Returns:
            (tuple[torch.Tensor, torch.Tensor]): A tuple containing the encoded mask features and positional encoding.
        """
        # Retrieve correct image features
        current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size)
        maskmem_features, maskmem_pos_enc = self.model._encode_new_memory(
            current_vision_feats=current_vision_feats,
            feat_sizes=feat_sizes,
            pred_masks_high_res=high_res_masks,
            is_mask_from_pts=is_mask_from_pts,
            object_score_logits=object_score_logits,
        )

        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc)
        # features are stored in half precision, consistent with `_run_single_frame_inference`
        return maskmem_features.to(dtype=torch.float16, device=self.device, non_blocking=True), maskmem_pos_enc

    def _add_output_per_object(self, frame_idx, current_out, storage_key):
        """
        Split a multi-object output into per-object output slices and add them into `output_dict_per_obj`.

        The resulting slices share the same tensor storage.

        Args:
            frame_idx (int): The index of the current frame.
            current_out (dict): The current output dictionary containing multi-object outputs.
            storage_key (str): The key used to store the output in the per-object output dictionary.
        """
        maskmem_features = current_out["maskmem_features"]
        assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)

        maskmem_pos_enc = current_out["maskmem_pos_enc"]
        assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)

        for obj_idx, obj_output_dict in self.inference_state["output_dict_per_obj"].items():
            # slice (not index) so each per-object view keeps a batch dimension of 1
            obj_slice = slice(obj_idx, obj_idx + 1)
            obj_out = {
                "maskmem_features": None,
                "maskmem_pos_enc": None,
                "pred_masks": current_out["pred_masks"][obj_slice],
                "obj_ptr": current_out["obj_ptr"][obj_slice],
            }
            if maskmem_features is not None:
                obj_out["maskmem_features"] = maskmem_features[obj_slice]
            if maskmem_pos_enc is not None:
                obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
            obj_output_dict[storage_key][frame_idx] = obj_out

    def _clear_non_cond_mem_around_input(self, frame_idx):
        """
        Remove the non-conditioning memory around the input frame.

        When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain outdated
        object appearance information and could confuse the model. This method clears those non-conditioning memories
        surrounding the interacted frame to avoid giving the model both old and new information about the object.

        Args:
            frame_idx (int): The index of the current frame where user interaction occurred.
        """
        r = self.model.memory_temporal_stride_for_eval
        # the cleared window spans num_maskmem memory slots on each side, at stride r
        frame_idx_begin = frame_idx - r * self.model.num_maskmem
        frame_idx_end = frame_idx + r * self.model.num_maskmem
        for t in range(frame_idx_begin, frame_idx_end + 1):
            self.inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None)
            for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
                obj_output_dict["non_cond_frame_outputs"].pop(t, None)
diff --git a/tracking/ultralytics/models/utils/__init__.py b/tracking/ultralytics/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..77a19dcf0f8093de419453747db2e7e719f96349
--- /dev/null
+++ b/tracking/ultralytics/models/utils/__init__.py
@@ -0,0 +1 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
diff --git a/tracking/ultralytics/models/utils/loss.py b/tracking/ultralytics/models/utils/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..0902053f30a5ba7bedc3733a174025fcfc19d9f2
--- /dev/null
+++ b/tracking/ultralytics/models/utils/loss.py
@@ -0,0 +1,408 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import torch
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.utils.loss import FocalLoss, VarifocalLoss
from ultralytics.utils.metrics import bbox_iou

from .ops import HungarianMatcher


class DETRLoss(nn.Module):
    """
    DETR (DEtection TRansformer) Loss class for calculating various loss components.

    This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the
    DETR object detection model.

    Attributes:
        nc (int): Number of classes.
        loss_gain (dict): Coefficients for different loss components.
        aux_loss (bool): Whether to compute auxiliary losses.
        use_fl (bool): Whether to use FocalLoss.
        use_vfl (bool): Whether to use VarifocalLoss.
        use_uni_match (bool): Whether to use a fixed layer for auxiliary branch label assignment.
        uni_match_ind (int): Index of fixed layer to use if use_uni_match is True.
        matcher (HungarianMatcher): Object to compute matching cost and indices.
        fl (FocalLoss | None): Focal Loss object if use_fl is True, otherwise None.
        vfl (VarifocalLoss | None): Varifocal Loss object if use_vfl is True, otherwise None.
        device (torch.device): Device on which tensors are stored.
    """

    def __init__(
        self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
    ):
        """
        Initialize DETR loss function with customizable components and gains.

        Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. Supports auxiliary
        losses and various loss types.

        Args:
            nc (int): Number of classes.
            loss_gain (dict): Coefficients for different loss components.
            aux_loss (bool): Whether to use auxiliary losses from each decoder layer.
            use_fl (bool): Whether to use FocalLoss.
            use_vfl (bool): Whether to use VarifocalLoss.
            use_uni_match (bool): Whether to use fixed layer for auxiliary branch label assignment.
            uni_match_ind (int): Index of fixed layer for uni_match.
        """
        super().__init__()

        if loss_gain is None:
            loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
        self.nc = nc
        self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
        self.loss_gain = loss_gain
        self.aux_loss = aux_loss
        self.fl = FocalLoss() if use_fl else None
        self.vfl = VarifocalLoss() if use_vfl else None

        self.use_uni_match = use_uni_match
        self.uni_match_ind = uni_match_ind
        # device is set lazily in forward() from the predictions' device
        self.device = None

    def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
        """Compute classification loss based on predictions, target values, and ground truth scores."""
        # Logits: [b, query, num_classes], gt_class: list[[n, 1]]
        name_class = f"loss_class{postfix}"
        bs, nq = pred_scores.shape[:2]
        # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1]  # (bs, num_queries, num_classes)
        one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
        one_hot.scatter_(2, targets.unsqueeze(-1), 1)
        one_hot = one_hot[..., :-1]
        gt_scores = gt_scores.view(bs, nq, 1) * one_hot

        if self.fl:
            if num_gts and self.vfl:
                loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
            else:
                loss_cls = self.fl(pred_scores, one_hot.float())
            loss_cls /= max(num_gts, 1) / nq
        else:
            loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum()  # YOLO CLS loss

        return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}

    def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
        """Compute bounding box and GIoU losses for predicted and ground truth bounding boxes."""
        # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
        name_bbox = f"loss_bbox{postfix}"
        name_giou = f"loss_giou{postfix}"

        loss = {}
        if len(gt_bboxes) == 0:
            # no ground truth: both box losses are zero
            loss[name_bbox] = torch.tensor(0.0, device=self.device)
            loss[name_giou] = torch.tensor(0.0, device=self.device)
            return loss

        loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes,
                                                             gt_bboxes, reduction="sum") / len(gt_bboxes)
        loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
        loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
        loss[name_giou] = self.loss_gain["giou"] * loss[name_giou]
        return {k: v.squeeze() for k, v in loss.items()}

    # This function is for future RT-DETR Segment models
    # def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
    #     # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
    #     name_mask = f'loss_mask{postfix}'
    #     name_dice = f'loss_dice{postfix}'
    #
    #     loss = {}
    #     if sum(len(a) for a in gt_mask) == 0:
    #         loss[name_mask] = torch.tensor(0., device=self.device)
    #         loss[name_dice] = torch.tensor(0., device=self.device)
    #         return loss
    #
    #     num_gts = len(gt_mask)
    #     src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
    #     src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
    #     # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
    #     loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
    #                                                                     torch.tensor([num_gts], dtype=torch.float32))
    #     loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
    #     return loss

    # This function is for future RT-DETR Segment models
    # @staticmethod
    # def _dice_loss(inputs, targets, num_gts):
    #     inputs = F.sigmoid(inputs).flatten(1)
    #     targets = targets.flatten(1)
    #     numerator = 2 * (inputs * targets).sum(1)
    #     denominator = inputs.sum(-1) + targets.sum(-1)
    #     loss = 1 - (numerator + 1) / (denominator + 1)
    #     return loss.sum() / num_gts

    def _get_loss_aux(
        self,
        pred_bboxes,
        pred_scores,
        gt_bboxes,
        gt_cls,
        gt_groups,
        match_indices=None,
        postfix="",
        masks=None,
        gt_mask=None,
    ):
        """
        Get auxiliary losses for intermediate decoder layers.
        Args:
            pred_bboxes (torch.Tensor): Predicted bounding boxes from auxiliary layers.
            pred_scores (torch.Tensor): Predicted scores from auxiliary layers.
            gt_bboxes (torch.Tensor): Ground truth bounding boxes.
            gt_cls (torch.Tensor): Ground truth classes.
            gt_groups (List[int]): Number of ground truths per image.
            match_indices (List[tuple], optional): Pre-computed matching indices.
            postfix (str): String to append to loss names.
            masks (torch.Tensor, optional): Predicted masks if using segmentation.
            gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.

        Returns:
            (dict): Dictionary of auxiliary losses.
        """
        # NOTE: loss class, bbox, giou, mask, dice
        loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
        if match_indices is None and self.use_uni_match:
            # use one fixed decoder layer's assignment for all auxiliary layers
            match_indices = self.matcher(
                pred_bboxes[self.uni_match_ind],
                pred_scores[self.uni_match_ind],
                gt_bboxes,
                gt_cls,
                gt_groups,
                masks=masks[self.uni_match_ind] if masks is not None else None,
                gt_mask=gt_mask,
            )
        for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
            aux_masks = masks[i] if masks is not None else None
            loss_ = self._get_loss(
                aux_bboxes,
                aux_scores,
                gt_bboxes,
                gt_cls,
                gt_groups,
                masks=aux_masks,
                gt_mask=gt_mask,
                postfix=postfix,
                match_indices=match_indices,
            )
            loss[0] += loss_[f"loss_class{postfix}"]
            loss[1] += loss_[f"loss_bbox{postfix}"]
            loss[2] += loss_[f"loss_giou{postfix}"]
            # if masks is not None and gt_mask is not None:
            #     loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
            #     loss[3] += loss_[f'loss_mask{postfix}']
            #     loss[4] += loss_[f'loss_dice{postfix}']

        loss = {
            f"loss_class_aux{postfix}": loss[0],
            f"loss_bbox_aux{postfix}": loss[1],
            f"loss_giou_aux{postfix}": loss[2],
        }
        # if masks is not None and gt_mask is not None:
        #     loss[f'loss_mask_aux{postfix}'] = loss[3]
        #     loss[f'loss_dice_aux{postfix}'] = loss[4]
        return loss

    @staticmethod
    def _get_index(match_indices):
        """
        Extract batch indices, source indices, and destination indices from match indices.

        Args:
            match_indices (List[tuple]): List of tuples containing matched indices.

        Returns:
            (tuple): Tuple containing (batch_idx, src_idx) and dst_idx.
        """
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
        src_idx = torch.cat([src for (src, _) in match_indices])
        dst_idx = torch.cat([dst for (_, dst) in match_indices])
        return (batch_idx, src_idx), dst_idx

    def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
        """
        Assign predicted bounding boxes to ground truth bounding boxes based on match indices.

        Args:
            pred_bboxes (torch.Tensor): Predicted bounding boxes.
            gt_bboxes (torch.Tensor): Ground truth bounding boxes.
            match_indices (List[tuple]): List of tuples containing matched indices.

        Returns:
            (tuple): Tuple containing assigned predictions and ground truths.
        """
        pred_assigned = torch.cat(
            [
                t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
                for t, (i, _) in zip(pred_bboxes, match_indices)
            ]
        )
        gt_assigned = torch.cat(
            [
                t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
                for t, (_, j) in zip(gt_bboxes, match_indices)
            ]
        )
        return pred_assigned, gt_assigned

    def _get_loss(
        self,
        pred_bboxes,
        pred_scores,
        gt_bboxes,
        gt_cls,
        gt_groups,
        masks=None,
        gt_mask=None,
        postfix="",
        match_indices=None,
    ):
        """
        Calculate losses for a single prediction layer.

        Args:
            pred_bboxes (torch.Tensor): Predicted bounding boxes.
            pred_scores (torch.Tensor): Predicted class scores.
            gt_bboxes (torch.Tensor): Ground truth bounding boxes.
            gt_cls (torch.Tensor): Ground truth classes.
            gt_groups (List[int]): Number of ground truths per image.
            masks (torch.Tensor, optional): Predicted masks if using segmentation.
            gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.
            postfix (str): String to append to loss names.
            match_indices (List[tuple], optional): Pre-computed matching indices.

        Returns:
            (dict): Dictionary of losses.
        """
        if match_indices is None:
            match_indices = self.matcher(
                pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
            )

        idx, gt_idx = self._get_index(match_indices)
        pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]

        bs, nq = pred_scores.shape[:2]
        # unmatched queries get the background class index self.nc
        targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
        targets[idx] = gt_cls[gt_idx]

        gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
        if len(gt_bboxes):
            gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)

        return {
            **self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix),
            **self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix),
            # **(self._get_loss_mask(masks, gt_mask, match_indices, postfix) if masks is not None and gt_mask is not None else {})
        }

    def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
        """
        Calculate loss for predicted bounding boxes and scores.

        Args:
            pred_bboxes (torch.Tensor): Predicted bounding boxes, shape [l, b, query, 4].
            pred_scores (torch.Tensor): Predicted class scores, shape [l, b, query, num_classes].
            batch (dict): Batch information containing:
                cls (torch.Tensor): Ground truth classes, shape [num_gts].
                bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4].
                gt_groups (List[int]): Number of ground truths for each image in the batch.
            postfix (str): Postfix for loss names.
            **kwargs (Any): Additional arguments, may include 'match_indices'.

        Returns:
            (dict): Computed losses, including main and auxiliary (if enabled).

        Notes:
            Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
            self.aux_loss is True.
        """
        self.device = pred_bboxes.device
        match_indices = kwargs.get("match_indices", None)
        gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"]

        total_loss = self._get_loss(
            pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
        )

        if self.aux_loss:
            total_loss.update(
                self._get_loss_aux(
                    pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
                )
            )

        return total_loss


class RTDETRDetectionLoss(DETRLoss):
    """
    Real-Time DEtection TRansformer (RT-DETR) Detection Loss class that extends the DETRLoss.

    This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
    an additional denoising training loss when provided with denoising metadata.
    """

    def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
        """
        Forward pass to compute detection loss with optional denoising loss.

        Args:
            preds (tuple): Tuple containing predicted bounding boxes and scores.
            batch (dict): Batch data containing ground truth information.
            dn_bboxes (torch.Tensor, optional): Denoising bounding boxes.
            dn_scores (torch.Tensor, optional): Denoising scores.
            dn_meta (dict, optional): Metadata for denoising.

        Returns:
            (dict): Dictionary containing total loss and denoising loss if applicable.
+ """ + pred_bboxes, pred_scores = preds + total_loss = super().forward(pred_bboxes, pred_scores, batch) + + # Check for denoising metadata to compute denoising training loss + if dn_meta is not None: + dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"] + assert len(batch["gt_groups"]) == len(dn_pos_idx) + + # Get the match indices for denoising + match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"]) + + # Compute the denoising training loss + dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices) + total_loss.update(dn_loss) + else: + # If no denoising metadata is provided, set denoising loss to zero + total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()}) + + return total_loss + + @staticmethod + def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups): + """ + Get match indices for denoising. + + Args: + dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising. + dn_num_group (int): Number of denoising groups. + gt_groups (List[int]): List of integers representing number of ground truths per image. + + Returns: + (List[tuple]): List of tuples containing matched indices for denoising. + """ + dn_match_indices = [] + idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) + for i, num_gt in enumerate(gt_groups): + if num_gt > 0: + gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i] + gt_idx = gt_idx.repeat(dn_num_group) + assert len(dn_pos_idx[i]) == len(gt_idx), "Expected the same length, " + f"but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively." 
+ dn_match_indices.append((dn_pos_idx[i], gt_idx)) + else: + dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long))) + return dn_match_indices diff --git a/tracking/ultralytics/models/utils/ops.py b/tracking/ultralytics/models/utils/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..472c0c36d1794646f45f6f567abeac2e1c7d6050 --- /dev/null +++ b/tracking/ultralytics/models/utils/ops.py @@ -0,0 +1,254 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy.optimize import linear_sum_assignment + +from ultralytics.utils.metrics import bbox_iou +from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh + + +class HungarianMatcher(nn.Module): + """ + A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in an + end-to-end fashion. + + HungarianMatcher performs optimal assignment over the predicted and ground truth bounding boxes using a cost + function that considers classification scores, bounding box coordinates, and optionally, mask predictions. + + Attributes: + cost_gain (dict): Dictionary of cost coefficients: 'class', 'bbox', 'giou', 'mask', and 'dice'. + use_fl (bool): Indicates whether to use Focal Loss for the classification cost calculation. + with_mask (bool): Indicates whether the model makes mask predictions. + num_sample_points (int): The number of sample points used in mask cost calculation. + alpha (float): The alpha factor in Focal Loss calculation. + gamma (float): The gamma factor in Focal Loss calculation. + + Methods: + forward: Computes the assignment between predictions and ground truths for a batch. + _cost_mask: Computes the mask cost and dice cost if masks are predicted. 
+ """ + + def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0): + """Initialize a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes.""" + super().__init__() + if cost_gain is None: + cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1} + self.cost_gain = cost_gain + self.use_fl = use_fl + self.with_mask = with_mask + self.num_sample_points = num_sample_points + self.alpha = alpha + self.gamma = gamma + + def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None): + """ + Forward pass for HungarianMatcher. Computes costs based on prediction and ground truth and finds the optimal + matching between predictions and ground truth based on these costs. + + Args: + pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4). + pred_scores (torch.Tensor): Predicted scores with shape (batch_size, num_queries, num_classes). + gt_cls (torch.Tensor): Ground truth classes with shape (num_gts, ). + gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (num_gts, 4). + gt_groups (List[int]): List of length equal to batch size, containing the number of ground truths for + each image. + masks (torch.Tensor, optional): Predicted masks with shape (batch_size, num_queries, height, width). + gt_mask (List[torch.Tensor], optional): List of ground truth masks, each with shape (num_masks, Height, Width). 
+ + Returns: + (List[Tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple (index_i, index_j), where: + - index_i is the tensor of indices of the selected predictions (in order) + - index_j is the tensor of indices of the corresponding selected ground truth targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + bs, nq, nc = pred_scores.shape + + if sum(gt_groups) == 0: + return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)] + + # We flatten to compute the cost matrices in a batch + # (batch_size * num_queries, num_classes) + pred_scores = pred_scores.detach().view(-1, nc) + pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1) + # (batch_size * num_queries, 4) + pred_bboxes = pred_bboxes.detach().view(-1, 4) + + # Compute the classification cost + pred_scores = pred_scores[:, gt_cls] + if self.use_fl: + neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log()) + pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log()) + cost_class = pos_cost_class - neg_cost_class + else: + cost_class = -pred_scores + + # Compute the L1 cost between boxes + cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt) + + # Compute the GIoU cost between boxes, (bs*num_queries, num_gt) + cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1) + + # Final cost matrix + C = ( + self.cost_gain["class"] * cost_class + + self.cost_gain["bbox"] * cost_bbox + + self.cost_gain["giou"] * cost_giou + ) + # Compute the mask cost and dice cost + if self.with_mask: + C += self._cost_mask(bs, gt_groups, masks, gt_mask) + + # Set invalid values (NaNs and infinities) to 0 (fixes ValueError: matrix contains invalid numeric entries) + 
C[C.isnan() | C.isinf()] = 0.0 + + C = C.view(bs, nq, -1).cpu() + indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))] + gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) # (idx for queries, idx for gt) + return [ + (torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k]) + for k, (i, j) in enumerate(indices) + ] + + # This function is for future RT-DETR Segment models + # def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None): + # assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`' + # # all masks share the same set of points for efficient matching + # sample_points = torch.rand([bs, 1, self.num_sample_points, 2]) + # sample_points = 2.0 * sample_points - 1.0 + # + # out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2) + # out_mask = out_mask.flatten(0, 1) + # + # tgt_mask = torch.cat(gt_mask).unsqueeze(1) + # sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0]) + # tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2]) + # + # with torch.amp.autocast("cuda", enabled=False): + # # binary cross entropy cost + # pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none') + # neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none') + # cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T) + # cost_mask /= self.num_sample_points + # + # # dice cost + # out_mask = F.sigmoid(out_mask) + # numerator = 2 * torch.matmul(out_mask, tgt_mask.T) + # denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0) + # cost_dice = 1 - (numerator + 1) / (denominator + 1) + # + # C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice + # return C + + +def get_cdn_group( + batch, 
num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False
+):
+    """
+    Get contrastive denoising training group with positive and negative samples from ground truths.
+
+    Args:
+        batch (dict): A dict that includes 'cls' (torch.Tensor with shape (num_gts, )), 'bboxes'
+            (torch.Tensor with shape (num_gts, 4)), 'gt_groups' (List[int]) which is a list of batch size length
+            indicating the number of gts of each image.
+        num_classes (int): Number of classes.
+        num_queries (int): Number of queries.
+        class_embed (torch.Tensor): Embedding weights to map class labels to embedding space.
+        num_dn (int, optional): Number of denoising queries.
+        cls_noise_ratio (float, optional): Noise ratio for class labels.
+        box_noise_scale (float, optional): Noise scale for bounding box coordinates.
+        training (bool, optional): If it's in training mode.
+
+    Returns:
+        padding_cls (Optional[torch.Tensor]): The modified class embeddings for denoising.
+        padding_bbox (Optional[torch.Tensor]): The modified bounding boxes for denoising.
+        attn_mask (Optional[torch.Tensor]): The attention mask for denoising.
+        dn_meta (Optional[Dict]): Meta information for denoising.
+    """
+    if (not training) or num_dn <= 0 or batch is None:
+        return None, None, None, None
+    gt_groups = batch["gt_groups"]
+    total_num = sum(gt_groups)
+    max_nums = max(gt_groups)
+    if max_nums == 0:
+        return None, None, None, None
+
+    num_group = num_dn // max_nums
+    num_group = 1 if num_group == 0 else num_group
+    # Pad gt to max_num of a batch
+    bs = len(gt_groups)
+    gt_cls = batch["cls"]  # (bs*num, )
+    gt_bbox = batch["bboxes"]  # bs*num, 4
+    b_idx = batch["batch_idx"]
+
+    # Each group has positive and negative queries.
+ dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, ) + dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4 + dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, ) + + # Positive and negative mask + # (bs*num*num_group, ), the second total_num*num_group part as negative samples + neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num + + if cls_noise_ratio > 0: + # Half of bbox prob + mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5) + idx = torch.nonzero(mask).squeeze(-1) + # Randomly put a new one here + new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device) + dn_cls[idx] = new_label + + if box_noise_scale > 0: + known_bbox = xywh2xyxy(dn_bbox) + + diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale # 2*num_group*bs*num, 4 + + rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0 + rand_part = torch.rand_like(dn_bbox) + rand_part[neg_idx] += 1.0 + rand_part *= rand_sign + known_bbox += rand_part * diff + known_bbox.clip_(min=0.0, max=1.0) + dn_bbox = xyxy2xywh(known_bbox) + dn_bbox = torch.logit(dn_bbox, eps=1e-6) # inverse sigmoid + + num_dn = int(max_nums * 2 * num_group) # total denoising queries + # class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)]) + dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256 + padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device) + padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device) + + map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups]) + pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0) + + map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)]) + padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed + padding_bbox[(dn_b_idx, map_indices)] = dn_bbox + + tgt_size 
= num_dn + num_queries + attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool) + # Match query cannot see the reconstruct + attn_mask[num_dn:, :num_dn] = True + # Reconstruct cannot see each other + for i in range(num_group): + if i == 0: + attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True + if i == num_group - 1: + attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True + else: + attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True + attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True + dn_meta = { + "dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)], + "dn_num_group": num_group, + "dn_num_split": [num_dn, num_queries], + } + + return ( + padding_cls.to(class_embed.device), + padding_bbox.to(class_embed.device), + attn_mask.to(class_embed.device), + dn_meta, + ) diff --git a/tracking/ultralytics/models/yolo/__init__.py b/tracking/ultralytics/models/yolo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..95006d437ccc1f3d6bcb3822b94a08834498b828 --- /dev/null +++ b/tracking/ultralytics/models/yolo/__init__.py @@ -0,0 +1,7 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.models.yolo import classify, detect, obb, pose, segment, world + +from .model import YOLO, YOLOWorld + +__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld" diff --git a/tracking/ultralytics/models/yolo/classify/__init__.py b/tracking/ultralytics/models/yolo/classify/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a10629229f8a8e5769480d004bbdbe42a633e79 --- /dev/null +++ b/tracking/ultralytics/models/yolo/classify/__init__.py @@ -0,0 +1,7 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.models.yolo.classify.predict import ClassificationPredictor 
+from ultralytics.models.yolo.classify.train import ClassificationTrainer +from ultralytics.models.yolo.classify.val import ClassificationValidator + +__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator" diff --git a/tracking/ultralytics/models/yolo/classify/predict.py b/tracking/ultralytics/models/yolo/classify/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..af16bb71a903fc4efb5b49d58fd38ef280d5d466 --- /dev/null +++ b/tracking/ultralytics/models/yolo/classify/predict.py @@ -0,0 +1,78 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import cv2 +import torch +from PIL import Image + +from ultralytics.engine.predictor import BasePredictor +from ultralytics.engine.results import Results +from ultralytics.utils import DEFAULT_CFG, ops + + +class ClassificationPredictor(BasePredictor): + """ + A class extending the BasePredictor class for prediction based on a classification model. + + This predictor handles the specific requirements of classification models, including preprocessing images + and postprocessing predictions to generate classification results. + + Attributes: + args (dict): Configuration arguments for the predictor. + _legacy_transform_name (str): Name of the legacy transform class for backward compatibility. + + Methods: + preprocess: Convert input images to model-compatible format. + postprocess: Process model predictions into Results objects. + + Notes: + - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'. 
+ + Examples: + >>> from ultralytics.utils import ASSETS + >>> from ultralytics.models.yolo.classify import ClassificationPredictor + >>> args = dict(model="yolo11n-cls.pt", source=ASSETS) + >>> predictor = ClassificationPredictor(overrides=args) + >>> predictor.predict_cli() + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.""" + super().__init__(cfg, overrides, _callbacks) + self.args.task = "classify" + self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor" + + def preprocess(self, img): + """Convert input images to model-compatible tensor format with appropriate normalization.""" + if not isinstance(img, torch.Tensor): + is_legacy_transform = any( + self._legacy_transform_name in str(transform) for transform in self.transforms.transforms + ) + if is_legacy_transform: # to handle legacy transforms + img = torch.stack([self.transforms(im) for im in img], dim=0) + else: + img = torch.stack( + [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0 + ) + img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) + return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 + + def postprocess(self, preds, img, orig_imgs): + """ + Process predictions to return Results objects with classification probabilities. + + Args: + preds (torch.Tensor): Raw predictions from the model. + img (torch.Tensor): Input images after preprocessing. + orig_imgs (List[np.ndarray] | torch.Tensor): Original images before preprocessing. + + Returns: + (List[Results]): List of Results objects containing classification results for each image. 
+ """ + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list + orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + + preds = preds[0] if isinstance(preds, (list, tuple)) else preds + return [ + Results(orig_img, path=img_path, names=self.model.names, probs=pred) + for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]) + ] diff --git a/tracking/ultralytics/models/yolo/classify/train.py b/tracking/ultralytics/models/yolo/classify/train.py new file mode 100644 index 0000000000000000000000000000000000000000..c025d1ccc56d62326f8096903f0b6cb13c557ab7 --- /dev/null +++ b/tracking/ultralytics/models/yolo/classify/train.py @@ -0,0 +1,217 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from copy import copy + +import torch + +from ultralytics.data import ClassificationDataset, build_dataloader +from ultralytics.engine.trainer import BaseTrainer +from ultralytics.models import yolo +from ultralytics.nn.tasks import ClassificationModel +from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK +from ultralytics.utils.plotting import plot_images, plot_results +from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first + + +class ClassificationTrainer(BaseTrainer): + """ + A class extending the BaseTrainer class for training based on a classification model. + + This trainer handles the training process for image classification tasks, supporting both YOLO classification models + and torchvision models. + + Attributes: + model (ClassificationModel): The classification model to be trained. + data (dict): Dictionary containing dataset information including class names and number of classes. + loss_names (List[str]): Names of the loss functions used during training. + validator (ClassificationValidator): Validator instance for model evaluation. + + Methods: + set_model_attributes: Set the model's class names from the loaded dataset. 
+ get_model: Return a modified PyTorch model configured for training. + setup_model: Load, create or download model for classification. + build_dataset: Create a ClassificationDataset instance. + get_dataloader: Return PyTorch DataLoader with transforms for image preprocessing. + preprocess_batch: Preprocess a batch of images and classes. + progress_string: Return a formatted string showing training progress. + get_validator: Return an instance of ClassificationValidator. + label_loss_items: Return a loss dict with labelled training loss items. + plot_metrics: Plot metrics from a CSV file. + final_eval: Evaluate trained model and save validation results. + plot_training_samples: Plot training samples with their annotations. + + Examples: + >>> from ultralytics.models.yolo.classify import ClassificationTrainer + >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3) + >>> trainer = ClassificationTrainer(overrides=args) + >>> trainer.train() + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initialize a ClassificationTrainer object with optional configuration overrides and callbacks.""" + if overrides is None: + overrides = {} + overrides["task"] = "classify" + if overrides.get("imgsz") is None: + overrides["imgsz"] = 224 + super().__init__(cfg, overrides, _callbacks) + + def set_model_attributes(self): + """Set the YOLO model's class names from the loaded dataset.""" + self.model.names = self.data["names"] + + def get_model(self, cfg=None, weights=None, verbose=True): + """ + Return a modified PyTorch model configured for training YOLO. + + Args: + cfg (Any): Model configuration. + weights (Any): Pre-trained model weights. + verbose (bool): Whether to display model information. + + Returns: + (ClassificationModel): Configured PyTorch model for classification. 
+ """ + model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1) + if weights: + model.load(weights) + + for m in model.modules(): + if not self.args.pretrained and hasattr(m, "reset_parameters"): + m.reset_parameters() + if isinstance(m, torch.nn.Dropout) and self.args.dropout: + m.p = self.args.dropout # set dropout + for p in model.parameters(): + p.requires_grad = True # for training + return model + + def setup_model(self): + """ + Load, create or download model for classification tasks. + + Returns: + (Any): Model checkpoint if applicable, otherwise None. + """ + import torchvision # scope for faster 'import ultralytics' + + if str(self.model) in torchvision.models.__dict__: + self.model = torchvision.models.__dict__[self.model]( + weights="IMAGENET1K_V1" if self.args.pretrained else None + ) + ckpt = None + else: + ckpt = super().setup_model() + ClassificationModel.reshape_outputs(self.model, self.data["nc"]) + return ckpt + + def build_dataset(self, img_path, mode="train", batch=None): + """ + Create a ClassificationDataset instance given an image path and mode. + + Args: + img_path (str): Path to the dataset images. + mode (str): Dataset mode ('train', 'val', or 'test'). + batch (Any): Batch information (unused in this implementation). + + Returns: + (ClassificationDataset): Dataset for the specified mode. + """ + return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode) + + def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"): + """ + Return PyTorch DataLoader with transforms to preprocess images. + + Args: + dataset_path (str): Path to the dataset. + batch_size (int): Number of images per batch. + rank (int): Process rank for distributed training. + mode (str): 'train', 'val', or 'test' mode. + + Returns: + (torch.utils.data.DataLoader): DataLoader for the specified dataset and mode. 
+ """ + with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP + dataset = self.build_dataset(dataset_path, mode) + + loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank) + # Attach inference transforms + if mode != "train": + if is_parallel(self.model): + self.model.module.transforms = loader.dataset.torch_transforms + else: + self.model.transforms = loader.dataset.torch_transforms + return loader + + def preprocess_batch(self, batch): + """Preprocesses a batch of images and classes.""" + batch["img"] = batch["img"].to(self.device) + batch["cls"] = batch["cls"].to(self.device) + return batch + + def progress_string(self): + """Returns a formatted string showing training progress.""" + return ("\n" + "%11s" * (4 + len(self.loss_names))) % ( + "Epoch", + "GPU_mem", + *self.loss_names, + "Instances", + "Size", + ) + + def get_validator(self): + """Returns an instance of ClassificationValidator for validation.""" + self.loss_names = ["loss"] + return yolo.classify.ClassificationValidator( + self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) + + def label_loss_items(self, loss_items=None, prefix="train"): + """ + Return a loss dict with labelled training loss items tensor. + + Args: + loss_items (torch.Tensor, optional): Loss tensor items. + prefix (str): Prefix to prepend to loss names. + + Returns: + (Dict[str, float] | List[str]): Dictionary of loss items or list of loss keys if loss_items is None. 
+ """ + keys = [f"{prefix}/{x}" for x in self.loss_names] + if loss_items is None: + return keys + loss_items = [round(float(loss_items), 5)] + return dict(zip(keys, loss_items)) + + def plot_metrics(self): + """Plot metrics from a CSV file.""" + plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png + + def final_eval(self): + """Evaluate trained model and save validation results.""" + for f in self.last, self.best: + if f.exists(): + strip_optimizer(f) # strip optimizers + if f is self.best: + LOGGER.info(f"\nValidating {f}...") + self.validator.args.data = self.args.data + self.validator.args.plots = self.args.plots + self.metrics = self.validator(model=f) + self.metrics.pop("fitness", None) + self.run_callbacks("on_fit_epoch_end") + + def plot_training_samples(self, batch, ni): + """ + Plot training samples with their annotations. + + Args: + batch (Dict[str, torch.Tensor]): Batch containing images and class labels. + ni (int): Number of iterations. + """ + plot_images( + images=batch["img"], + batch_idx=torch.arange(len(batch["img"])), + cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) diff --git a/tracking/ultralytics/models/yolo/classify/val.py b/tracking/ultralytics/models/yolo/classify/val.py new file mode 100644 index 0000000000000000000000000000000000000000..8be7b4df77fef6ff377df76a3995bfad28e4dcd3 --- /dev/null +++ b/tracking/ultralytics/models/yolo/classify/val.py @@ -0,0 +1,139 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import torch + +from ultralytics.data import ClassificationDataset, build_dataloader +from ultralytics.engine.validator import BaseValidator +from ultralytics.utils import LOGGER +from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix +from ultralytics.utils.plotting import plot_images + + +class ClassificationValidator(BaseValidator): + """ + A 
class extending the BaseValidator class for validation based on a classification model. + + This validator handles the validation process for classification models, including metrics calculation, + confusion matrix generation, and visualization of results. + + Attributes: + targets (List[torch.Tensor]): Ground truth class labels. + pred (List[torch.Tensor]): Model predictions. + metrics (ClassifyMetrics): Object to calculate and store classification metrics. + names (dict): Mapping of class indices to class names. + nc (int): Number of classes. + confusion_matrix (ConfusionMatrix): Matrix to evaluate model performance across classes. + + Methods: + get_desc: Return a formatted string summarizing classification metrics. + init_metrics: Initialize confusion matrix, class names, and tracking containers. + preprocess: Preprocess input batch by moving data to device. + update_metrics: Update running metrics with model predictions and batch targets. + finalize_metrics: Finalize metrics including confusion matrix and processing speed. + postprocess: Extract the primary prediction from model output. + get_stats: Calculate and return a dictionary of metrics. + build_dataset: Create a ClassificationDataset instance for validation. + get_dataloader: Build and return a data loader for classification validation. + print_results: Print evaluation metrics for the classification model. + plot_val_samples: Plot validation image samples with their ground truth labels. + plot_predictions: Plot images with their predicted class labels. + + Examples: + >>> from ultralytics.models.yolo.classify import ClassificationValidator + >>> args = dict(model="yolo11n-cls.pt", data="imagenet10") + >>> validator = ClassificationValidator(args=args) + >>> validator() + + Notes: + Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'. 
+ """ + + def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): + """Initialize ClassificationValidator with dataloader, save directory, and other parameters.""" + super().__init__(dataloader, save_dir, pbar, args, _callbacks) + self.targets = None + self.pred = None + self.args.task = "classify" + self.metrics = ClassifyMetrics() + + def get_desc(self): + """Return a formatted string summarizing classification metrics.""" + return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc") + + def init_metrics(self, model): + """Initialize confusion matrix, class names, and tracking containers for predictions and targets.""" + self.names = model.names + self.nc = len(model.names) + self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task="classify") + self.pred = [] + self.targets = [] + + def preprocess(self, batch): + """Preprocess input batch by moving data to device and converting to appropriate dtype.""" + batch["img"] = batch["img"].to(self.device, non_blocking=True) + batch["img"] = batch["img"].half() if self.args.half else batch["img"].float() + batch["cls"] = batch["cls"].to(self.device) + return batch + + def update_metrics(self, preds, batch): + """Update running metrics with model predictions and batch targets.""" + n5 = min(len(self.names), 5) + self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu()) + self.targets.append(batch["cls"].type(torch.int32).cpu()) + + def finalize_metrics(self, *args, **kwargs): + """Finalize metrics including confusion matrix and processing speed.""" + self.confusion_matrix.process_cls_preds(self.pred, self.targets) + if self.args.plots: + for normalize in True, False: + self.confusion_matrix.plot( + save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot + ) + self.metrics.speed = self.speed + self.metrics.confusion_matrix = self.confusion_matrix + self.metrics.save_dir = self.save_dir + + def 
postprocess(self, preds): + """Extract the primary prediction from model output if it's in a list or tuple format.""" + return preds[0] if isinstance(preds, (list, tuple)) else preds + + def get_stats(self): + """Calculate and return a dictionary of metrics by processing targets and predictions.""" + self.metrics.process(self.targets, self.pred) + return self.metrics.results_dict + + def build_dataset(self, img_path): + """Create a ClassificationDataset instance for validation.""" + return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split) + + def get_dataloader(self, dataset_path, batch_size): + """Build and return a data loader for classification validation.""" + dataset = self.build_dataset(dataset_path) + return build_dataloader(dataset, batch_size, self.args.workers, rank=-1) + + def print_results(self): + """Print evaluation metrics for the classification model.""" + pf = "%22s" + "%11.3g" * len(self.metrics.keys) # print format + LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5)) + + def plot_val_samples(self, batch, ni): + """Plot validation image samples with their ground truth labels.""" + plot_images( + images=batch["img"], + batch_idx=torch.arange(len(batch["img"])), + cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names, + on_plot=self.on_plot, + ) + + def plot_predictions(self, batch, preds, ni): + """Plot images with their predicted class labels and save the visualization.""" + plot_images( + batch["img"], + batch_idx=torch.arange(len(batch["img"])), + cls=torch.argmax(preds, dim=1), + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred diff --git a/tracking/ultralytics/models/yolo/detect/__init__.py b/tracking/ultralytics/models/yolo/detect/__init__.py new file mode 100644 index 
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .predict import DetectionPredictor
from .train import DetectionTrainer
from .val import DetectionValidator

# Public API of the detect sub-package.
__all__ = (
    "DetectionPredictor",
    "DetectionTrainer",
    "DetectionValidator",
)
+ + Examples: + >>> from ultralytics.utils import ASSETS + >>> from ultralytics.models.yolo.detect import DetectionPredictor + >>> args = dict(model="yolo11n.pt", source=ASSETS) + >>> predictor = DetectionPredictor(overrides=args) + >>> predictor.predict_cli() + """ + + def postprocess(self, preds, img, orig_imgs, **kwargs): + """Post-processes predictions and returns a list of Results objects.""" + preds = ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + self.args.classes, + self.args.agnostic_nms, + max_det=self.args.max_det, + nc=len(self.model.names), + end2end=getattr(self.model, "end2end", False), + rotated=self.args.task == "obb", + ) + + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list + orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + + return self.construct_results(preds, img, orig_imgs, **kwargs) + + def construct_results(self, preds, img, orig_imgs): + """ + Construct a list of Results objects from model predictions. + + Args: + preds (List[torch.Tensor]): List of predicted bounding boxes and scores for each image. + img (torch.Tensor): Batch of preprocessed images used for inference. + orig_imgs (List[np.ndarray]): List of original images before preprocessing. + + Returns: + (List[Results]): List of Results objects containing detection information for each image. + """ + return [ + self.construct_result(pred, img, orig_img, img_path) + for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]) + ] + + def construct_result(self, pred, img, orig_img, img_path): + """ + Construct a single Results object from one image prediction. + + Args: + pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections. + img (torch.Tensor): Preprocessed image tensor used for inference. + orig_img (np.ndarray): Original image before preprocessing. + img_path (str): Path to the original image file. 
+ + Returns: + (Results): Results object containing the original image, image path, class names, and scaled bounding boxes. + """ + pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) + return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]) diff --git a/tracking/ultralytics/models/yolo/detect/train.py b/tracking/ultralytics/models/yolo/detect/train.py new file mode 100644 index 0000000000000000000000000000000000000000..312b5303764d473f5fbc8717e917bc02025b6bb6 --- /dev/null +++ b/tracking/ultralytics/models/yolo/detect/train.py @@ -0,0 +1,217 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import math +import random +from copy import copy + +import numpy as np +import torch.nn as nn + +from ultralytics.data import build_dataloader, build_yolo_dataset +from ultralytics.engine.trainer import BaseTrainer +from ultralytics.models import yolo +from ultralytics.nn.tasks import DetectionModel +from ultralytics.utils import LOGGER, RANK +from ultralytics.utils.plotting import plot_images, plot_labels, plot_results +from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first + + +class DetectionTrainer(BaseTrainer): + """ + A class extending the BaseTrainer class for training based on a detection model. + + This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models + for object detection. + + Attributes: + model (DetectionModel): The YOLO detection model being trained. + data (dict): Dictionary containing dataset information including class names and number of classes. + loss_names (Tuple[str]): Names of the loss components used in training (box_loss, cls_loss, dfl_loss). + + Methods: + build_dataset: Build YOLO dataset for training or validation. + get_dataloader: Construct and return dataloader for the specified mode. + preprocess_batch: Preprocess a batch of images by scaling and converting to float. 
+ set_model_attributes: Set model attributes based on dataset information. + get_model: Return a YOLO detection model. + get_validator: Return a validator for model evaluation. + label_loss_items: Return a loss dictionary with labeled training loss items. + progress_string: Return a formatted string of training progress. + plot_training_samples: Plot training samples with their annotations. + plot_metrics: Plot metrics from a CSV file. + plot_training_labels: Create a labeled training plot of the YOLO model. + auto_batch: Calculate optimal batch size based on model memory requirements. + + Examples: + >>> from ultralytics.models.yolo.detect import DetectionTrainer + >>> args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3) + >>> trainer = DetectionTrainer(overrides=args) + >>> trainer.train() + """ + + def build_dataset(self, img_path, mode="train", batch=None): + """ + Build YOLO Dataset for training or validation. + + Args: + img_path (str): Path to the folder containing images. + mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. + batch (int, optional): Size of batches, this is for `rect`. + + Returns: + (Dataset): YOLO dataset object configured for the specified mode. + """ + gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) + return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs) + + def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"): + """ + Construct and return dataloader for the specified mode. + + Args: + dataset_path (str): Path to the dataset. + batch_size (int): Number of images per batch. + rank (int): Process rank for distributed training. + mode (str): 'train' for training dataloader, 'val' for validation dataloader. + + Returns: + (DataLoader): PyTorch dataloader object. + """ + assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}." 
+ with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP + dataset = self.build_dataset(dataset_path, mode, batch_size) + shuffle = mode == "train" + if getattr(dataset, "rect", False) and shuffle: + LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False") + shuffle = False + workers = self.args.workers if mode == "train" else self.args.workers * 2 + return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader + + def preprocess_batch(self, batch): + """ + Preprocess a batch of images by scaling and converting to float. + + Args: + batch (dict): Dictionary containing batch data with 'img' tensor. + + Returns: + (dict): Preprocessed batch with normalized images. + """ + batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255 + if self.args.multi_scale: + imgs = batch["img"] + sz = ( + random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride)) + // self.stride + * self.stride + ) # size + sf = sz / max(imgs.shape[2:]) # scale factor + if sf != 1: + ns = [ + math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:] + ] # new shape (stretched to gs-multiple) + imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False) + batch["img"] = imgs + return batch + + def set_model_attributes(self): + """Set model attributes based on dataset information.""" + # Nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps) + # self.args.box *= 3 / nl # scale to layers + # self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers + # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers + self.model.nc = self.data["nc"] # attach number of classes to model + self.model.names = self.data["names"] # attach class names to model + self.model.args = self.args # attach hyperparameters to model + # TODO: 
self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc + + def get_model(self, cfg=None, weights=None, verbose=True): + """ + Return a YOLO detection model. + + Args: + cfg (str, optional): Path to model configuration file. + weights (str, optional): Path to model weights. + verbose (bool): Whether to display model information. + + Returns: + (DetectionModel): YOLO detection model. + """ + model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1) + if weights: + model.load(weights) + return model + + def get_validator(self): + """Return a DetectionValidator for YOLO model validation.""" + self.loss_names = "box_loss", "cls_loss", "dfl_loss" + return yolo.detect.DetectionValidator( + self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) + + def label_loss_items(self, loss_items=None, prefix="train"): + """ + Return a loss dict with labeled training loss items tensor. + + Args: + loss_items (List[float], optional): List of loss values. + prefix (str): Prefix for keys in the returned dictionary. + + Returns: + (Dict | List): Dictionary of labeled loss items if loss_items is provided, otherwise list of keys. + """ + keys = [f"{prefix}/{x}" for x in self.loss_names] + if loss_items is not None: + loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats + return dict(zip(keys, loss_items)) + else: + return keys + + def progress_string(self): + """Return a formatted string of training progress with epoch, GPU memory, loss, instances and size.""" + return ("\n" + "%11s" * (4 + len(self.loss_names))) % ( + "Epoch", + "GPU_mem", + *self.loss_names, + "Instances", + "Size", + ) + + def plot_training_samples(self, batch, ni): + """ + Plot training samples with their annotations. + + Args: + batch (dict): Dictionary containing batch data. + ni (int): Number of iterations. 
+ """ + plot_images( + images=batch["img"], + batch_idx=batch["batch_idx"], + cls=batch["cls"].squeeze(-1), + bboxes=batch["bboxes"], + paths=batch["im_file"], + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) + + def plot_metrics(self): + """Plot metrics from a CSV file.""" + plot_results(file=self.csv, on_plot=self.on_plot) # save results.png + + def plot_training_labels(self): + """Create a labeled training plot of the YOLO model.""" + boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0) + cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0) + plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot) + + def auto_batch(self): + """ + Get optimal batch size by calculating memory occupation of model. + + Returns: + (int): Optimal batch size. + """ + train_dataset = self.build_dataset(self.trainset, mode="train", batch=16) + max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4 # 4 for mosaic augmentation + return super().auto_batch(max_num_obj) diff --git a/tracking/ultralytics/models/yolo/detect/val.py b/tracking/ultralytics/models/yolo/detect/val.py new file mode 100644 index 0000000000000000000000000000000000000000..61705e3ac706a5d7986b183c72efdc1fec40b4c6 --- /dev/null +++ b/tracking/ultralytics/models/yolo/detect/val.py @@ -0,0 +1,462 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import os +from pathlib import Path + +import numpy as np +import torch + +from ultralytics.data import build_dataloader, build_yolo_dataset, converter +from ultralytics.engine.validator import BaseValidator +from ultralytics.utils import LOGGER, ops +from ultralytics.utils.checks import check_requirements +from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou +from ultralytics.utils.plotting import output_to_target, plot_images + + +class DetectionValidator(BaseValidator): + """ + A 
class extending the BaseValidator class for validation based on a detection model. + + This class implements validation functionality specific to object detection tasks, including metrics calculation, + prediction processing, and visualization of results. + + Attributes: + nt_per_class (np.ndarray): Number of targets per class. + nt_per_image (np.ndarray): Number of targets per image. + is_coco (bool): Whether the dataset is COCO. + is_lvis (bool): Whether the dataset is LVIS. + class_map (list): Mapping from model class indices to dataset class indices. + metrics (DetMetrics): Object detection metrics calculator. + iouv (torch.Tensor): IoU thresholds for mAP calculation. + niou (int): Number of IoU thresholds. + lb (list): List for storing ground truth labels for hybrid saving. + jdict (list): List for storing JSON detection results. + stats (dict): Dictionary for storing statistics during validation. + + Examples: + >>> from ultralytics.models.yolo.detect import DetectionValidator + >>> args = dict(model="yolo11n.pt", data="coco8.yaml") + >>> validator = DetectionValidator(args=args) + >>> validator() + """ + + def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): + """ + Initialize detection validator with necessary variables and settings. + + Args: + dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation. + save_dir (Path, optional): Directory to save results. + pbar (Any, optional): Progress bar for displaying progress. + args (dict, optional): Arguments for the validator. + _callbacks (list, optional): List of callback functions. 
+ """ + super().__init__(dataloader, save_dir, pbar, args, _callbacks) + self.nt_per_class = None + self.nt_per_image = None + self.is_coco = False + self.is_lvis = False + self.class_map = None + self.args.task = "detect" + self.metrics = DetMetrics(save_dir=self.save_dir) + self.iouv = torch.linspace(0.5, 0.95, 10) # IoU vector for mAP@0.5:0.95 + self.niou = self.iouv.numel() + self.lb = [] # for autolabelling + if self.args.save_hybrid and self.args.task == "detect": + LOGGER.warning( + "WARNING ⚠️ 'save_hybrid=True' will append ground truth to predictions for autolabelling.\n" + "WARNING ⚠️ 'save_hybrid=True' will cause incorrect mAP.\n" + ) + + def preprocess(self, batch): + """ + Preprocess batch of images for YOLO validation. + + Args: + batch (dict): Batch containing images and annotations. + + Returns: + (dict): Preprocessed batch. + """ + batch["img"] = batch["img"].to(self.device, non_blocking=True) + batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255 + for k in ["batch_idx", "cls", "bboxes"]: + batch[k] = batch[k].to(self.device) + + if self.args.save_hybrid and self.args.task == "detect": + height, width = batch["img"].shape[2:] + nb = len(batch["img"]) + bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device) + self.lb = [ + torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1) + for i in range(nb) + ] + + return batch + + def init_metrics(self, model): + """ + Initialize evaluation metrics for YOLO detection validation. + + Args: + model (torch.nn.Module): Model to validate. 
+ """ + val = self.data.get(self.args.split, "") # validation path + self.is_coco = ( + isinstance(val, str) + and "coco" in val + and (val.endswith(f"{os.sep}val2017.txt") or val.endswith(f"{os.sep}test-dev2017.txt")) + ) # is COCO + self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco # is LVIS + self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1, len(model.names) + 1)) + self.args.save_json |= self.args.val and (self.is_coco or self.is_lvis) and not self.training # run final val + self.names = model.names + self.nc = len(model.names) + self.end2end = getattr(model, "end2end", False) + self.metrics.names = self.names + self.metrics.plot = self.args.plots + self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf) + self.seen = 0 + self.jdict = [] + self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[]) + + def get_desc(self): + """Return a formatted string summarizing class metrics of YOLO model.""" + return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)") + + def postprocess(self, preds): + """ + Apply Non-maximum suppression to prediction outputs. + + Args: + preds (torch.Tensor): Raw predictions from the model. + + Returns: + (List[torch.Tensor]): Processed predictions after NMS. + """ + return ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + labels=self.lb, + nc=self.nc, + multi_label=True, + agnostic=self.args.single_cls or self.args.agnostic_nms, + max_det=self.args.max_det, + end2end=self.end2end, + rotated=self.args.task == "obb", + ) + + def _prepare_batch(self, si, batch): + """ + Prepare a batch of images and annotations for validation. + + Args: + si (int): Batch index. + batch (dict): Batch data containing images and annotations. + + Returns: + (dict): Prepared batch with processed annotations. 
+ """ + idx = batch["batch_idx"] == si + cls = batch["cls"][idx].squeeze(-1) + bbox = batch["bboxes"][idx] + ori_shape = batch["ori_shape"][si] + imgsz = batch["img"].shape[2:] + ratio_pad = batch["ratio_pad"][si] + if len(cls): + bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes + ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad) # native-space labels + return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad} + + def _prepare_pred(self, pred, pbatch): + """ + Prepare predictions for evaluation against ground truth. + + Args: + pred (torch.Tensor): Model predictions. + pbatch (dict): Prepared batch information. + + Returns: + (torch.Tensor): Prepared predictions in native space. + """ + predn = pred.clone() + ops.scale_boxes( + pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"] + ) # native-space pred + return predn + + def update_metrics(self, preds, batch): + """ + Update metrics with new predictions and ground truth. + + Args: + preds (List[torch.Tensor]): List of predictions from the model. + batch (dict): Batch data containing ground truth. 
+ """ + for si, pred in enumerate(preds): + self.seen += 1 + npr = len(pred) + stat = dict( + conf=torch.zeros(0, device=self.device), + pred_cls=torch.zeros(0, device=self.device), + tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + ) + pbatch = self._prepare_batch(si, batch) + cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") + nl = len(cls) + stat["target_cls"] = cls + stat["target_img"] = cls.unique() + if npr == 0: + if nl: + for k in self.stats.keys(): + self.stats[k].append(stat[k]) + if self.args.plots: + self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls) + continue + + # Predictions + if self.args.single_cls: + pred[:, 5] = 0 + predn = self._prepare_pred(pred, pbatch) + stat["conf"] = predn[:, 4] + stat["pred_cls"] = predn[:, 5] + + # Evaluate + if nl: + stat["tp"] = self._process_batch(predn, bbox, cls) + if self.args.plots: + self.confusion_matrix.process_batch(predn, bbox, cls) + for k in self.stats.keys(): + self.stats[k].append(stat[k]) + + # Save + if self.args.save_json: + self.pred_to_json(predn, batch["im_file"][si]) + if self.args.save_txt: + self.save_one_txt( + predn, + self.args.save_conf, + pbatch["ori_shape"], + self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt", + ) + + def finalize_metrics(self, *args, **kwargs): + """ + Set final values for metrics speed and confusion matrix. + + Args: + *args (Any): Variable length argument list. + **kwargs (Any): Arbitrary keyword arguments. + """ + self.metrics.speed = self.speed + self.metrics.confusion_matrix = self.confusion_matrix + + def get_stats(self): + """ + Calculate and return metrics statistics. + + Returns: + (dict): Dictionary containing metrics results. 
+ """ + stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy + self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc) + self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=self.nc) + stats.pop("target_img", None) + if len(stats): + self.metrics.process(**stats, on_plot=self.on_plot) + return self.metrics.results_dict + + def print_results(self): + """Print training/validation set metrics per class.""" + pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format + LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results())) + if self.nt_per_class.sum() == 0: + LOGGER.warning(f"WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels") + + # Print results per class + if self.args.verbose and not self.training and self.nc > 1 and len(self.stats): + for i, c in enumerate(self.metrics.ap_class_index): + LOGGER.info( + pf % (self.names[c], self.nt_per_image[c], self.nt_per_class[c], *self.metrics.class_result(i)) + ) + + if self.args.plots: + for normalize in True, False: + self.confusion_matrix.plot( + save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot + ) + + def _process_batch(self, detections, gt_bboxes, gt_cls): + """ + Return correct prediction matrix. + + Args: + detections (torch.Tensor): Tensor of shape (N, 6) representing detections where each detection is + (x1, y1, x2, y2, conf, class). + gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground-truth bounding box coordinates. Each + bounding box is of the format: (x1, y1, x2, y2). + gt_cls (torch.Tensor): Tensor of shape (M,) representing target class indices. + + Returns: + (torch.Tensor): Correct prediction matrix of shape (N, 10) for 10 IoU levels. 
+ """ + iou = box_iou(gt_bboxes, detections[:, :4]) + return self.match_predictions(detections[:, 5], gt_cls, iou) + + def build_dataset(self, img_path, mode="val", batch=None): + """ + Build YOLO Dataset. + + Args: + img_path (str): Path to the folder containing images. + mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. + batch (int, optional): Size of batches, this is for `rect`. + + Returns: + (Dataset): YOLO dataset. + """ + return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride) + + def get_dataloader(self, dataset_path, batch_size): + """ + Construct and return dataloader. + + Args: + dataset_path (str): Path to the dataset. + batch_size (int): Size of each batch. + + Returns: + (torch.utils.data.DataLoader): Dataloader for validation. + """ + dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val") + return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1) # return dataloader + + def plot_val_samples(self, batch, ni): + """ + Plot validation image samples. + + Args: + batch (dict): Batch containing images and annotations. + ni (int): Batch index. + """ + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names, + on_plot=self.on_plot, + ) + + def plot_predictions(self, batch, preds, ni): + """ + Plot predicted bounding boxes on input images and save the result. + + Args: + batch (dict): Batch containing images and annotations. + preds (List[torch.Tensor]): List of predictions from the model. + ni (int): Batch index. 
+ """ + plot_images( + batch["img"], + *output_to_target(preds, max_det=self.args.max_det), + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred + + def save_one_txt(self, predn, save_conf, shape, file): + """ + Save YOLO detections to a txt file in normalized coordinates in a specific format. + + Args: + predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class). + save_conf (bool): Whether to save confidence scores. + shape (tuple): Shape of the original image. + file (Path): File path to save the detections. + """ + from ultralytics.engine.results import Results + + Results( + np.zeros((shape[0], shape[1]), dtype=np.uint8), + path=None, + names=self.names, + boxes=predn[:, :6], + ).save_txt(file, save_conf=save_conf) + + def pred_to_json(self, predn, filename): + """ + Serialize YOLO predictions to COCO json format. + + Args: + predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class). + filename (str): Image filename. + """ + stem = Path(filename).stem + image_id = int(stem) if stem.isnumeric() else stem + box = ops.xyxy2xywh(predn[:, :4]) # xywh + box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner + for p, b in zip(predn.tolist(), box.tolist()): + self.jdict.append( + { + "image_id": image_id, + "category_id": self.class_map[int(p[5])], + "bbox": [round(x, 3) for x in b], + "score": round(p[4], 5), + } + ) + + def eval_json(self, stats): + """ + Evaluate YOLO output in JSON format and return performance statistics. + + Args: + stats (dict): Current statistics dictionary. + + Returns: + (dict): Updated statistics dictionary with COCO/LVIS evaluation results. 
+ """ + if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict): + pred_json = self.save_dir / "predictions.json" # predictions + anno_json = ( + self.data["path"] + / "annotations" + / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json") + ) # annotations + pkg = "pycocotools" if self.is_coco else "lvis" + LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...") + try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb + for x in pred_json, anno_json: + assert x.is_file(), f"{x} file not found" + check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3") + if self.is_coco: + from pycocotools.coco import COCO # noqa + from pycocotools.cocoeval import COCOeval # noqa + + anno = COCO(str(anno_json)) # init annotations api + pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) + val = COCOeval(anno, pred, "bbox") + else: + from lvis import LVIS, LVISEval + + anno = LVIS(str(anno_json)) # init annotations api + pred = anno._load_json(str(pred_json)) # init predictions api (must pass string, not Path) + val = LVISEval(anno, pred, "bbox") + val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval + val.evaluate() + val.accumulate() + val.summarize() + if self.is_lvis: + val.print_results() # explicitly call print_results + # update mAP50-95 and mAP50 + stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = ( + val.stats[:2] if self.is_coco else [val.results["AP50"], val.results["AP"]] + ) + except Exception as e: + LOGGER.warning(f"{pkg} unable to run: {e}") + return stats diff --git a/tracking/ultralytics/models/yolo/model.py b/tracking/ultralytics/models/yolo/model.py new file mode 100644 index 0000000000000000000000000000000000000000..5bf9aaa1e8f253c81894d654d983d7ca0a353a33 --- /dev/null +++ b/tracking/ultralytics/models/yolo/model.py @@ -0,0 +1,110 @@ +# Ultralytics 
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from pathlib import Path

from ultralytics.engine.model import Model
from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
from ultralytics.utils import ROOT, yaml_load


class YOLO(Model):
    """YOLO (You Only Look Once) object detection model."""

    def __init__(self, model="yolo11n.pt", task=None, verbose=False):
        """
        Initialize a YOLO model, switching to YOLOWorld if the model filename contains '-world'.

        Args:
            model (str | Path): Model file to load or create, e.g. 'yolo11n.pt' or 'yolo11n.yaml'.
            task (str, optional): Task to run; inferred from the model when None.
            verbose (bool): Print additional information during initialization.
        """
        path = Path(model)
        if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
            # Re-brand this instance as a YOLOWorld by adopting the new instance's class and state,
            # so YOLO("...-world.pt") transparently yields a world-model wrapper.
            new_instance = YOLOWorld(path, verbose=verbose)
            self.__class__ = type(new_instance)
            self.__dict__ = new_instance.__dict__
        else:
            # Continue with default YOLO initialization
            super().__init__(model=model, task=task, verbose=verbose)

    @property
    def task_map(self):
        """Map head to model, trainer, validator, and predictor classes."""
        return {
            "classify": {
                "model": ClassificationModel,
                "trainer": yolo.classify.ClassificationTrainer,
                "validator": yolo.classify.ClassificationValidator,
                "predictor": yolo.classify.ClassificationPredictor,
            },
            "detect": {
                "model": DetectionModel,
                "trainer": yolo.detect.DetectionTrainer,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
            },
            "segment": {
                "model": SegmentationModel,
                "trainer": yolo.segment.SegmentationTrainer,
                "validator": yolo.segment.SegmentationValidator,
                "predictor": yolo.segment.SegmentationPredictor,
            },
            "pose": {
                "model": PoseModel,
                "trainer": yolo.pose.PoseTrainer,
                "validator": yolo.pose.PoseValidator,
                "predictor": yolo.pose.PosePredictor,
            },
            "obb": {
                "model": OBBModel,
                "trainer": yolo.obb.OBBTrainer,
                "validator": yolo.obb.OBBValidator,
                "predictor": yolo.obb.OBBPredictor,
            },
        }


class YOLOWorld(Model):
    """YOLO-World object detection model."""

    def __init__(self, model="yolov8s-world.pt", verbose=False) -> None:
        """
        Initialize YOLOv8-World model with a pre-trained model file.

        Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default
        COCO class names.

        Args:
            model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
            verbose (bool): If True, prints additional information during initialization.
        """
        super().__init__(model=model, task="detect", verbose=verbose)

        # Assign default COCO class names when there are no custom names
        if not hasattr(self.model, "names"):
            self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")

    @property
    def task_map(self):
        """Map head to model, validator, and predictor classes."""
        return {
            "detect": {
                "model": WorldModel,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
                "trainer": yolo.world.WorldTrainer,
            }
        }

    def set_classes(self, classes):
        """
        Set the model's class names for detection.

        Args:
            classes (List[str]): A list of categories i.e. ["person"].
        """
        self.model.set_classes(classes)
        # Drop the background entry (" ") from the displayed names WITHOUT mutating the
        # caller's list (the previous implementation removed it in place, a surprising side effect).
        classes = [c for c in classes if c != " "]
        self.model.names = classes

        # Keep an already-built predictor in sync with the new class names
        if self.predictor:
            self.predictor.model.names = classes


# --- file: tracking/ultralytics/models/yolo/obb/__init__.py ---
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
#
# from .predict import OBBPredictor
# from .train import OBBTrainer
# from .val import OBBValidator
#
# __all__ = "OBBPredictor", "OBBTrainer", "OBBValidator"
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import torch

from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops


class OBBPredictor(DetectionPredictor):
    """
    A DetectionPredictor specialization for Oriented Bounding Box (OBB) models.

    Processes images through an OBB model and packages the rotated-box detections into
    Results objects.

    Attributes:
        args (namespace): Configuration arguments for the predictor.
        model (torch.nn.Module): The loaded YOLO OBB model.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.yolo.obb import OBBPredictor
        >>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
        >>> predictor = OBBPredictor(overrides=args)
        >>> predictor.predict_cli()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize the predictor and force the task to 'obb'."""
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "obb"

    def construct_result(self, pred, img, orig_img, img_path):
        """
        Build the Results object for one image from its raw OBB predictions.

        Args:
            pred (torch.Tensor): Predictions with the last dimension laid out as
                [x, y, w, h, confidence, class_id, angle].
            img (torch.Tensor): Preprocessed image batch of shape (B, C, H, W).
            orig_img (np.ndarray): Original image before preprocessing.
            img_path (str): Path to the original image.

        Returns:
            (Results): Result carrying the original image, path, class names, and oriented boxes.
        """
        # Assemble (x, y, w, h, angle) and canonicalize the rotated boxes.
        centers_sizes, angles = pred[:, :4], pred[:, -1:]
        rotated = ops.regularize_rboxes(torch.cat([centers_sizes, angles], dim=-1))
        # Rescale centers/sizes from network-input space back to the original image.
        rotated[:, :4] = ops.scale_boxes(img.shape[2:], rotated[:, :4], orig_img.shape, xywh=True)
        # Final layout expected by Results: (x, y, w, h, angle, conf, cls).
        boxes_with_scores = torch.cat([rotated, pred[:, 4:6]], dim=-1)
        return Results(orig_img, path=img_path, names=self.model.names, obb=boxes_with_scores)
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from copy import copy

from ultralytics.models import yolo
from ultralytics.nn.tasks import OBBModel
from ultralytics.utils import DEFAULT_CFG, RANK


class OBBTrainer(yolo.detect.DetectionTrainer):
    """
    A DetectionTrainer specialization for Oriented Bounding Box (OBB) models.

    Attributes:
        loss_names (Tuple[str]): Names of the loss components used during training.

    Methods:
        get_model: Return OBBModel initialized with specified config and weights.
        get_validator: Return an instance of OBBValidator for validation of YOLO model.

    Examples:
        >>> from ultralytics.models.yolo.obb import OBBTrainer
        >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
        >>> trainer = OBBTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize an OBBTrainer, forcing the task to 'obb'."""
        if overrides is None:
            overrides = {}
        overrides["task"] = "obb"
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return an OBBModel built from `cfg`, optionally loading `weights`."""
        # Only the rank -1 (non-DDP / main) process prints the model summary.
        model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def get_validator(self):
        """Return an OBBValidator configured for this trainer's test loader."""
        self.loss_names = "box_loss", "cls_loss", "dfl_loss"
        validator_args = copy(self.args)
        return yolo.obb.OBBValidator(
            self.test_loader, save_dir=self.save_dir, args=validator_args, _callbacks=self.callbacks
        )
class OBBValidator(DetectionValidator):
    """
    A DetectionValidator specialization for Oriented Bounding Box (OBB) models.

    Evaluates models that predict rotated bounding boxes, commonly used for aerial and
    satellite imagery where objects can appear at various orientations.

    Attributes:
        args (dict): Configuration arguments for the validator.
        metrics (OBBMetrics): Metrics object for evaluating OBB model performance.
        is_dota (bool): Flag indicating whether the validation dataset is in DOTA format.

    Examples:
        >>> from ultralytics.models.yolo.obb import OBBValidator
        >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml")
        >>> validator = OBBValidator(args=args)
        >>> validator(model=args["model"])
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        self.args.task = "obb"
        self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True)

    def init_metrics(self, model):
        """Initialize evaluation metrics and detect whether the dataset is DOTA-format."""
        super().init_metrics(model)
        val = self.data.get(self.args.split, "")  # validation path
        self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format

    def _process_batch(self, detections, gt_bboxes, gt_cls):
        """
        Compute the correct-prediction matrix using probabilistic IoU for rotated boxes.

        Args:
            detections (torch.Tensor): Shape (N, 7) as (x1, y1, x2, y2, conf, class, angle).
            gt_bboxes (torch.Tensor): Shape (M, 5) as (x1, y1, x2, y2, angle).
            gt_cls (torch.Tensor): Shape (M,) ground-truth class labels.

        Returns:
            (torch.Tensor): Correct-prediction matrix of shape (N, 10) over 10 IoU levels.

        Note:
            Relies on `batch_probiou` to calculate IoU between rotated detections and ground truth.
        """
        iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
        return self.match_predictions(detections[:, 5], gt_cls, iou)

    def _prepare_batch(self, si, batch):
        """Prepare ground truth of image `si` for OBB validation: denormalize and rescale boxes."""
        idx = batch["batch_idx"] == si
        cls = batch["cls"][idx].squeeze(-1)
        bbox = batch["bboxes"][idx]
        ori_shape = batch["ori_shape"][si]
        imgsz = batch["img"].shape[2:]
        ratio_pad = batch["ratio_pad"][si]
        if len(cls):
            # Denormalize xywh to pixels of the letterboxed image, then rescale to native image space.
            bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
            ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}

    def _prepare_pred(self, pred, pbatch):
        """Return a copy of `pred` with boxes rescaled to original image dimensions."""
        predn = pred.clone()
        ops.scale_boxes(
            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
        )  # native-space pred
        return predn

    def plot_predictions(self, batch, preds, ni):
        """Plot predicted rotated bounding boxes on input images and save the result."""
        plot_images(
            batch["img"],
            *output_to_rotated_target(preds, max_det=self.args.max_det),
            paths=batch["im_file"],
            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )  # pred

    def pred_to_json(self, predn, filename):
        """Serialize predictions of one image to COCO-style JSON with rotated-box info."""
        stem = Path(filename).stem
        image_id = int(stem) if stem.isnumeric() else stem
        rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
        poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)  # 4-corner polygon per box
        for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
            self.jdict.append(
                {
                    "image_id": image_id,
                    "category_id": self.class_map[int(predn[i, 5].item())],
                    "score": round(predn[i, 4].item(), 5),
                    "rbox": [round(x, 3) for x in r],
                    "poly": [round(x, 3) for x in b],
                }
            )

    def save_one_txt(self, predn, save_conf, shape, file):
        """Save detections of one image to a txt file in normalized coordinates via Results."""
        import numpy as np

        from ultralytics.engine.results import Results

        rboxes = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
        # xywh, r, conf, cls
        obb = torch.cat([rboxes, predn[:, 4:6]], dim=-1)
        Results(
            np.zeros((shape[0], shape[1]), dtype=np.uint8),
            path=None,
            names=self.names,
            obb=obb,
        ).save_txt(file, save_conf=save_conf)

    def eval_json(self, stats):
        """Evaluate YOLO output in JSON format and save predictions in DOTA format."""
        if self.args.save_json and self.is_dota and len(self.jdict):
            import json
            import re
            from collections import defaultdict

            pred_json = self.save_dir / "predictions.json"  # predictions
            pred_txt = self.save_dir / "predictions_txt"  # predictions
            pred_txt.mkdir(parents=True, exist_ok=True)
            # Context manager closes the file promptly (previous code leaked the handle).
            with open(pred_json, encoding="utf-8") as f:
                data = json.load(f)
            # Save split results
            LOGGER.info(f"Saving predictions with DOTA format to {pred_txt}...")
            for d in data:
                image_id = d["image_id"]
                score = d["score"]
                classname = self.names[d["category_id"] - 1].replace(" ", "-")
                p = d["poly"]

                with open(f"{pred_txt / f'Task1_{classname}'}.txt", "a", encoding="utf-8") as f:
                    f.write(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
            # Save merged results, this could result slightly lower map than using official merging script,
            # because of the probiou calculation.
            pred_merged_txt = self.save_dir / "predictions_merged_txt"  # predictions
            pred_merged_txt.mkdir(parents=True, exist_ok=True)
            merged_results = defaultdict(list)
            pattern = re.compile(r"\d+___\d+")  # hoisted: compile once instead of per prediction
            LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...")
            for d in data:
                # DOTA split images are named "<base>__<scale>__<x>___<y>"; recover the patch offset.
                image_id = d["image_id"].split("__")[0]
                x, y = (int(c) for c in pattern.findall(d["image_id"])[0].split("___"))
                bbox, score, cls = d["rbox"], d["score"], d["category_id"] - 1
                bbox[0] += x
                bbox[1] += y
                bbox.extend([score, cls])
                merged_results[image_id].append(bbox)
            for image_id, bbox in merged_results.items():
                bbox = torch.tensor(bbox)
                max_wh = torch.max(bbox[:, :2]).item() * 2
                c = bbox[:, 6:7] * max_wh  # classes, offset so per-class NMS does not cross classes
                scores = bbox[:, 5]  # scores
                b = bbox[:, :5].clone()
                b[:, :2] += c
                # 0.3 could get results close to the ones from official merging script, even slightly better.
                i = ops.nms_rotated(b, scores, 0.3)
                bbox = bbox[i]

                b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
                for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
                    classname = self.names[int(x[-1])].replace(" ", "-")
                    p = [round(i, 3) for i in x[:-2]]  # poly
                    score = round(x[-2], 3)

                    with open(f"{pred_merged_txt / f'Task1_{classname}'}.txt", "a", encoding="utf-8") as f:
                        f.write(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")

        return stats


# --- file: tracking/ultralytics/models/yolo/pose/__init__.py ---
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
#
# from .predict import PosePredictor
# from .train import PoseTrainer
# from .val import PoseValidator
#
# __all__ = "PoseTrainer", "PoseValidator", "PosePredictor"
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops


class PosePredictor(DetectionPredictor):
    """
    A DetectionPredictor specialization for pose-estimation models.

    Adds keypoint extraction on top of the standard object-detection pipeline inherited
    from DetectionPredictor.

    Attributes:
        args (namespace): Configuration arguments for the predictor.
        model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.yolo.pose import PosePredictor
        >>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
        >>> predictor = PosePredictor(overrides=args)
        >>> predictor.predict_cli()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize the predictor, force task 'pose', and warn when running on Apple MPS."""
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "pose"
        device = self.args.device
        if isinstance(device, str) and device.lower() == "mps":
            LOGGER.warning(
                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                "See https://github.com/ultralytics/ultralytics/issues/4031."
            )

    def construct_result(self, pred, img, orig_img, img_path):
        """
        Build the Results object for one image, attaching rescaled keypoints.

        Args:
            pred (torch.Tensor): Predictions of shape (N, 6+K*D) where N is the number of detections,
                K the number of keypoints, and D the keypoint dimension.
            img (torch.Tensor): Preprocessed image batch of shape (B, C, H, W).
            orig_img (np.ndarray): Original unprocessed image.
            img_path (str): Path to the original image file.

        Returns:
            (Results): Result with image, path, class names, boxes, and keypoints.
        """
        result = super().construct_result(pred, img, orig_img, img_path)
        raw_kpts = pred[:, 6:]
        # Reshape to the model's declared keypoint layout when there are detections.
        kpts = raw_kpts.view(len(pred), *self.model.kpt_shape) if len(pred) else raw_kpts
        # Map keypoint coordinates from network-input space back to the original image.
        kpts = ops.scale_coords(img.shape[2:], kpts, orig_img.shape)
        result.update(keypoints=kpts)
        return result
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from copy import copy

from ultralytics.models import yolo
from ultralytics.nn.tasks import PoseModel
from ultralytics.utils import DEFAULT_CFG, LOGGER
from ultralytics.utils.plotting import plot_images, plot_results


class PoseTrainer(yolo.detect.DetectionTrainer):
    """
    A DetectionTrainer specialization for YOLO pose-estimation models.

    Handles model creation, validation, and visualization of pose keypoints alongside
    bounding boxes.

    Attributes:
        args (dict): Configuration arguments for training.
        model (PoseModel): The pose estimation model being trained.
        data (dict): Dataset configuration including keypoint shape information.
        loss_names (Tuple[str]): Names of the loss components used in training.

    Examples:
        >>> from ultralytics.models.yolo.pose import PoseTrainer
        >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
        >>> trainer = PoseTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize a PoseTrainer, force task 'pose', and warn when running on Apple MPS."""
        if overrides is None:
            overrides = {}
        overrides["task"] = "pose"
        super().__init__(cfg, overrides, _callbacks)

        device = self.args.device
        if isinstance(device, str) and device.lower() == "mps":
            LOGGER.warning(
                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                "See https://github.com/ultralytics/ultralytics/issues/4031."
            )

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return a PoseModel built from `cfg` and dataset keypoint shape, optionally loading `weights`."""
        model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
        if weights:
            model.load(weights)
        return model

    def set_model_attributes(self):
        """Propagate the dataset's keypoint shape onto the model."""
        super().set_model_attributes()
        self.model.kpt_shape = self.data["kpt_shape"]

    def get_validator(self):
        """Return a PoseValidator configured for this trainer's test loader."""
        self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
        return yolo.pose.PoseValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def plot_training_samples(self, batch, ni):
        """Plot a batch of training samples annotated with classes, boxes, and keypoints."""
        plot_images(
            batch["img"],
            batch["batch_idx"],
            batch["cls"].squeeze(-1),
            batch["bboxes"],
            kpts=batch["keypoints"],
            paths=batch["im_file"],
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

    def plot_metrics(self):
        """Plot training/validation metrics to results.png."""
        plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png
check_requirements +from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou +from ultralytics.utils.plotting import output_to_target, plot_images + + +class PoseValidator(DetectionValidator): + """ + A class extending the DetectionValidator class for validation based on a pose model. + + This validator is specifically designed for pose estimation tasks, handling keypoints and implementing + specialized metrics for pose evaluation. + + Attributes: + sigma (np.ndarray): Sigma values for OKS calculation, either from OKS_SIGMA or ones divided by number of keypoints. + kpt_shape (List[int]): Shape of the keypoints, typically [17, 3] for COCO format. + args (dict): Arguments for the validator including task set to "pose". + metrics (PoseMetrics): Metrics object for pose evaluation. + + Methods: + preprocess: Preprocesses batch data for pose validation. + get_desc: Returns description of evaluation metrics. + init_metrics: Initializes pose metrics for the model. + _prepare_batch: Prepares a batch for processing. + _prepare_pred: Prepares and scales predictions for evaluation. + update_metrics: Updates metrics with new predictions. + _process_batch: Processes batch to compute IoU between detections and ground truth. + plot_val_samples: Plots validation samples with ground truth annotations. + plot_predictions: Plots model predictions. + save_one_txt: Saves detections to a text file. + pred_to_json: Converts predictions to COCO JSON format. + eval_json: Evaluates model using COCO JSON format. 
+ + Examples: + >>> from ultralytics.models.yolo.pose import PoseValidator + >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml") + >>> validator = PoseValidator(args=args) + >>> validator() + """ + + def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): + """Initialize a PoseValidator object with custom parameters and assigned attributes.""" + super().__init__(dataloader, save_dir, pbar, args, _callbacks) + self.sigma = None + self.kpt_shape = None + self.args.task = "pose" + self.metrics = PoseMetrics(save_dir=self.save_dir) + if isinstance(self.args.device, str) and self.args.device.lower() == "mps": + LOGGER.warning( + "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " + "See https://github.com/ultralytics/ultralytics/issues/4031." + ) + + def preprocess(self, batch): + """Preprocess batch by converting keypoints data to float and moving it to the device.""" + batch = super().preprocess(batch) + batch["keypoints"] = batch["keypoints"].to(self.device).float() + return batch + + def get_desc(self): + """Return description of evaluation metrics in string format.""" + return ("%22s" + "%11s" * 10) % ( + "Class", + "Images", + "Instances", + "Box(P", + "R", + "mAP50", + "mAP50-95)", + "Pose(P", + "R", + "mAP50", + "mAP50-95)", + ) + + def init_metrics(self, model): + """Initialize pose estimation metrics for YOLO model.""" + super().init_metrics(model) + self.kpt_shape = self.data["kpt_shape"] + is_pose = self.kpt_shape == [17, 3] + nkpt = self.kpt_shape[0] + self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt + self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[]) + + def _prepare_batch(self, si, batch): + """Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.""" + pbatch = super()._prepare_batch(si, batch) + kpts = batch["keypoints"][batch["batch_idx"] == si] + h, w = pbatch["imgsz"] + kpts = 
kpts.clone() + kpts[..., 0] *= w + kpts[..., 1] *= h + kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]) + pbatch["kpts"] = kpts + return pbatch + + def _prepare_pred(self, pred, pbatch): + """Prepare and scale keypoints in predictions for pose processing.""" + predn = super()._prepare_pred(pred, pbatch) + nk = pbatch["kpts"].shape[1] + pred_kpts = predn[:, 6:].view(len(predn), nk, -1) + ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]) + return predn, pred_kpts + + def update_metrics(self, preds, batch): + """ + Update metrics with new predictions and ground truth data. + + This method processes each prediction, compares it with ground truth, and updates various statistics + for performance evaluation. + + Args: + preds (List[torch.Tensor]): List of prediction tensors from the model. + batch (dict): Batch data containing images and ground truth annotations. + """ + for si, pred in enumerate(preds): + self.seen += 1 + npr = len(pred) + stat = dict( + conf=torch.zeros(0, device=self.device), + pred_cls=torch.zeros(0, device=self.device), + tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + ) + pbatch = self._prepare_batch(si, batch) + cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") + nl = len(cls) + stat["target_cls"] = cls + stat["target_img"] = cls.unique() + if npr == 0: + if nl: + for k in self.stats.keys(): + self.stats[k].append(stat[k]) + if self.args.plots: + self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls) + continue + + # Predictions + if self.args.single_cls: + pred[:, 5] = 0 + predn, pred_kpts = self._prepare_pred(pred, pbatch) + stat["conf"] = predn[:, 4] + stat["pred_cls"] = predn[:, 5] + + # Evaluate + if nl: + stat["tp"] = self._process_batch(predn, bbox, cls) + stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, 
pbatch["kpts"])
+                if self.args.plots:
+                    self.confusion_matrix.process_batch(predn, bbox, cls)
+
+            for k in self.stats.keys():
+                self.stats[k].append(stat[k])
+
+            # Save
+            if self.args.save_json:
+                self.pred_to_json(predn, batch["im_file"][si])
+            if self.args.save_txt:
+                self.save_one_txt(
+                    predn,
+                    pred_kpts,
+                    self.args.save_conf,
+                    pbatch["ori_shape"],
+                    self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
+                )
+
+    def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None):
+        """
+        Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
+
+        Args:
+            detections (torch.Tensor): Tensor with shape (N, 6) representing detection boxes and scores, where each
+                detection is of the format (x1, y1, x2, y2, conf, class).
+            gt_bboxes (torch.Tensor): Tensor with shape (M, 4) representing ground truth bounding boxes, where each
+                box is of the format (x1, y1, x2, y2).
+            gt_cls (torch.Tensor): Tensor with shape (M,) representing ground truth class indices.
+            pred_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing predicted keypoints, where
+                51 corresponds to 17 keypoints each having 3 values.
+            gt_kpts (torch.Tensor | None): Optional tensor with shape (M, 51) representing ground truth keypoints.
+
+        Returns:
+            (torch.Tensor): A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
+                where N is the number of detections.
+
+        Notes:
+            `0.53` scale factor used in area computation is referenced from
+            https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384. 
+ """ + if pred_kpts is not None and gt_kpts is not None: + # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384 + area = ops.xyxy2xywh(gt_bboxes)[:, 2:].prod(1) * 0.53 + iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area) + else: # boxes + iou = box_iou(gt_bboxes, detections[:, :4]) + + return self.match_predictions(detections[:, 5], gt_cls, iou) + + def plot_val_samples(self, batch, ni): + """Plot and save validation set samples with ground truth bounding boxes and keypoints.""" + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + kpts=batch["keypoints"], + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names, + on_plot=self.on_plot, + ) + + def plot_predictions(self, batch, preds, ni): + """Plot and save model predictions with bounding boxes and keypoints.""" + pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0) + plot_images( + batch["img"], + *output_to_target(preds, max_det=self.args.max_det), + kpts=pred_kpts, + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred + + def save_one_txt(self, predn, pred_kpts, save_conf, shape, file): + """Save YOLO detections to a txt file in normalized coordinates in a specific format.""" + from ultralytics.engine.results import Results + + Results( + np.zeros((shape[0], shape[1]), dtype=np.uint8), + path=None, + names=self.names, + boxes=predn[:, :6], + keypoints=pred_kpts, + ).save_txt(file, save_conf=save_conf) + + def pred_to_json(self, predn, filename): + """Convert YOLO predictions to COCO JSON format.""" + stem = Path(filename).stem + image_id = int(stem) if stem.isnumeric() else stem + box = ops.xyxy2xywh(predn[:, :4]) # xywh + box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner + for p, b in zip(predn.tolist(), box.tolist()): + self.jdict.append( + { + "image_id": 
image_id, + "category_id": self.class_map[int(p[5])], + "bbox": [round(x, 3) for x in b], + "keypoints": p[6:], + "score": round(p[4], 5), + } + ) + + def eval_json(self, stats): + """Evaluate object detection model using COCO JSON format.""" + if self.args.save_json and self.is_coco and len(self.jdict): + anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations + pred_json = self.save_dir / "predictions.json" # predictions + LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...") + try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb + check_requirements("pycocotools>=2.0.6") + from pycocotools.coco import COCO # noqa + from pycocotools.cocoeval import COCOeval # noqa + + for x in anno_json, pred_json: + assert x.is_file(), f"{x} file not found" + anno = COCO(str(anno_json)) # init annotations api + pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) + for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "keypoints")]): + if self.is_coco: + eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval + eval.evaluate() + eval.accumulate() + eval.summarize() + idx = i * 4 + 2 + stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[ + :2 + ] # update mAP50-95 and mAP50 + except Exception as e: + LOGGER.warning(f"pycocotools unable to run: {e}") + return stats diff --git a/tracking/ultralytics/models/yolo/segment/__init__.py b/tracking/ultralytics/models/yolo/segment/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36a921a9a36a0f0a0bf1bf03be9014c6886f6c6e --- /dev/null +++ b/tracking/ultralytics/models/yolo/segment/__init__.py @@ -0,0 +1,7 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .predict import SegmentationPredictor +from .train import SegmentationTrainer +from .val import 
SegmentationValidator + +__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator" diff --git a/tracking/ultralytics/models/yolo/segment/predict.py b/tracking/ultralytics/models/yolo/segment/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..11a752567fb761b19da881e043e7cad11cd35361 --- /dev/null +++ b/tracking/ultralytics/models/yolo/segment/predict.py @@ -0,0 +1,88 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.engine.results import Results +from ultralytics.models.yolo.detect.predict import DetectionPredictor +from ultralytics.utils import DEFAULT_CFG, ops + + +class SegmentationPredictor(DetectionPredictor): + """ + A class extending the DetectionPredictor class for prediction based on a segmentation model. + + This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the + prediction results. + + Attributes: + args (dict): Configuration arguments for the predictor. + model (torch.nn.Module): The loaded YOLO segmentation model. + batch (list): Current batch of images being processed. + + Methods: + postprocess: Applies non-max suppression and processes detections. + construct_results: Constructs a list of result objects from predictions. + construct_result: Constructs a single result object from a prediction. 
+ + Examples: + >>> from ultralytics.utils import ASSETS + >>> from ultralytics.models.yolo.segment import SegmentationPredictor + >>> args = dict(model="yolo11n-seg.pt", source=ASSETS) + >>> predictor = SegmentationPredictor(overrides=args) + >>> predictor.predict_cli() + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initialize the SegmentationPredictor with configuration, overrides, and callbacks.""" + super().__init__(cfg, overrides, _callbacks) + self.args.task = "segment" + + def postprocess(self, preds, img, orig_imgs): + """Apply non-max suppression and process detections for each image in the input batch.""" + # Extract protos - tuple if PyTorch model or array if exported + protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1] + return super().postprocess(preds[0], img, orig_imgs, protos=protos) + + def construct_results(self, preds, img, orig_imgs, protos): + """ + Construct a list of result objects from the predictions. + + Args: + preds (List[torch.Tensor]): List of predicted bounding boxes, scores, and masks. + img (torch.Tensor): The image after preprocessing. + orig_imgs (List[np.ndarray]): List of original images before preprocessing. + protos (List[torch.Tensor]): List of prototype masks. + + Returns: + (List[Results]): List of result objects containing the original images, image paths, class names, + bounding boxes, and masks. + """ + return [ + self.construct_result(pred, img, orig_img, img_path, proto) + for pred, orig_img, img_path, proto in zip(preds, orig_imgs, self.batch[0], protos) + ] + + def construct_result(self, pred, img, orig_img, img_path, proto): + """ + Construct a single result object from the prediction. + + Args: + pred (np.ndarray): The predicted bounding boxes, scores, and masks. + img (torch.Tensor): The image after preprocessing. + orig_img (np.ndarray): The original image before preprocessing. + img_path (str): The path to the original image. 
+ proto (torch.Tensor): The prototype masks. + + Returns: + (Results): Result object containing the original image, image path, class names, bounding boxes, and masks. + """ + if not len(pred): # save empty boxes + masks = None + elif self.args.retina_masks: + pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) + masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC + else: + masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC + pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) + if masks is not None: + keep = masks.sum((-2, -1)) > 0 # only keep predictions with masks + pred, masks = pred[keep], masks[keep] + return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks) diff --git a/tracking/ultralytics/models/yolo/segment/train.py b/tracking/ultralytics/models/yolo/segment/train.py new file mode 100644 index 0000000000000000000000000000000000000000..01a7a2e0919eb6c7f5c619affad1f704b5c4c0bc --- /dev/null +++ b/tracking/ultralytics/models/yolo/segment/train.py @@ -0,0 +1,65 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from copy import copy + +from ultralytics.models import yolo +from ultralytics.nn.tasks import SegmentationModel +from ultralytics.utils import DEFAULT_CFG, RANK +from ultralytics.utils.plotting import plot_images, plot_results + + +class SegmentationTrainer(yolo.detect.DetectionTrainer): + """ + A class extending the DetectionTrainer class for training based on a segmentation model. + + This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific + functionality including model initialization, validation, and visualization. + + Attributes: + loss_names (Tuple[str]): Names of the loss components used during training. 
+ + Examples: + >>> from ultralytics.models.yolo.segment import SegmentationTrainer + >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3) + >>> trainer = SegmentationTrainer(overrides=args) + >>> trainer.train() + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initialize a SegmentationTrainer object with given arguments.""" + if overrides is None: + overrides = {} + overrides["task"] = "segment" + super().__init__(cfg, overrides, _callbacks) + + def get_model(self, cfg=None, weights=None, verbose=True): + """Return SegmentationModel initialized with specified config and weights.""" + model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1) + if weights: + model.load(weights) + + return model + + def get_validator(self): + """Return an instance of SegmentationValidator for validation of YOLO model.""" + self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss" + return yolo.segment.SegmentationValidator( + self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) + + def plot_training_samples(self, batch, ni): + """Creates a plot of training sample images with labels and box coordinates.""" + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + masks=batch["masks"], + paths=batch["im_file"], + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) + + def plot_metrics(self): + """Plots training/val metrics.""" + plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png diff --git a/tracking/ultralytics/models/yolo/segment/val.py b/tracking/ultralytics/models/yolo/segment/val.py new file mode 100644 index 0000000000000000000000000000000000000000..4d078f36de61ca2c6e4b9a47ea7ca81fbc010a9c --- /dev/null +++ b/tracking/ultralytics/models/yolo/segment/val.py @@ -0,0 +1,392 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from 
multiprocessing.pool import ThreadPool +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F + +from ultralytics.models.yolo.detect import DetectionValidator +from ultralytics.utils import LOGGER, NUM_THREADS, ops +from ultralytics.utils.checks import check_requirements +from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou +from ultralytics.utils.plotting import output_to_target, plot_images + + +class SegmentationValidator(DetectionValidator): + """ + A class extending the DetectionValidator class for validation based on a segmentation model. + + This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions + to compute metrics such as mAP for both detection and segmentation tasks. + + Attributes: + plot_masks (list): List to store masks for plotting. + process (callable): Function to process masks based on save_json and save_txt flags. + args (namespace): Arguments for the validator. + metrics (SegmentMetrics): Metrics calculator for segmentation tasks. + stats (dict): Dictionary to store statistics during validation. + + Examples: + >>> from ultralytics.models.yolo.segment import SegmentationValidator + >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml") + >>> validator = SegmentationValidator(args=args) + >>> validator() + """ + + def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): + """ + Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics. + + Args: + dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation. + save_dir (Path, optional): Directory to save results. + pbar (Any, optional): Progress bar for displaying progress. + args (namespace, optional): Arguments for the validator. + _callbacks (list, optional): List of callback functions. 
+ """ + super().__init__(dataloader, save_dir, pbar, args, _callbacks) + self.plot_masks = None + self.process = None + self.args.task = "segment" + self.metrics = SegmentMetrics(save_dir=self.save_dir) + + def preprocess(self, batch): + """Preprocess batch by converting masks to float and sending to device.""" + batch = super().preprocess(batch) + batch["masks"] = batch["masks"].to(self.device).float() + return batch + + def init_metrics(self, model): + """ + Initialize metrics and select mask processing function based on save_json flag. + + Args: + model (torch.nn.Module): Model to validate. + """ + super().init_metrics(model) + self.plot_masks = [] + if self.args.save_json: + check_requirements("pycocotools>=2.0.6") + # more accurate vs faster + self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask + self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[]) + + def get_desc(self): + """Return a formatted description of evaluation metrics.""" + return ("%22s" + "%11s" * 10) % ( + "Class", + "Images", + "Instances", + "Box(P", + "R", + "mAP50", + "mAP50-95)", + "Mask(P", + "R", + "mAP50", + "mAP50-95)", + ) + + def postprocess(self, preds): + """ + Post-process YOLO predictions and return output detections with proto. + + Args: + preds (list): Raw predictions from the model. + + Returns: + p (torch.Tensor): Processed detection predictions. + proto (torch.Tensor): Prototype masks for segmentation. + """ + p = super().postprocess(preds[0]) + proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported + return p, proto + + def _prepare_batch(self, si, batch): + """ + Prepare a batch for training or inference by processing images and targets. + + Args: + si (int): Batch index. + batch (dict): Batch data containing images and targets. + + Returns: + (dict): Prepared batch with processed images and targets. 
+ """ + prepared_batch = super()._prepare_batch(si, batch) + midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si + prepared_batch["masks"] = batch["masks"][midx] + return prepared_batch + + def _prepare_pred(self, pred, pbatch, proto): + """ + Prepare predictions for evaluation by processing bounding boxes and masks. + + Args: + pred (torch.Tensor): Raw predictions from the model. + pbatch (dict): Prepared batch data. + proto (torch.Tensor): Prototype masks for segmentation. + + Returns: + predn (torch.Tensor): Processed bounding box predictions. + pred_masks (torch.Tensor): Processed mask predictions. + """ + predn = super()._prepare_pred(pred, pbatch) + pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"]) + return predn, pred_masks + + def update_metrics(self, preds, batch): + """ + Update metrics with the current batch predictions and targets. + + Args: + preds (list): Predictions from the model. + batch (dict): Batch data containing images and targets. 
+ """ + for si, (pred, proto) in enumerate(zip(preds[0], preds[1])): + self.seen += 1 + npr = len(pred) + stat = dict( + conf=torch.zeros(0, device=self.device), + pred_cls=torch.zeros(0, device=self.device), + tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + ) + pbatch = self._prepare_batch(si, batch) + cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") + nl = len(cls) + stat["target_cls"] = cls + stat["target_img"] = cls.unique() + if npr == 0: + if nl: + for k in self.stats.keys(): + self.stats[k].append(stat[k]) + if self.args.plots: + self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls) + continue + + # Masks + gt_masks = pbatch.pop("masks") + # Predictions + if self.args.single_cls: + pred[:, 5] = 0 + predn, pred_masks = self._prepare_pred(pred, pbatch, proto) + stat["conf"] = predn[:, 4] + stat["pred_cls"] = predn[:, 5] + + # Evaluate + if nl: + stat["tp"] = self._process_batch(predn, bbox, cls) + stat["tp_m"] = self._process_batch( + predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True + ) + if self.args.plots: + self.confusion_matrix.process_batch(predn, bbox, cls) + + for k in self.stats.keys(): + self.stats[k].append(stat[k]) + + pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8) + if self.args.plots and self.batch_i < 3: + self.plot_masks.append(pred_masks[:50].cpu()) # Limit plotted items for speed + if pred_masks.shape[0] > 50: + LOGGER.warning("WARNING ⚠️ Limiting validation plots to first 50 items per image for speed...") + + # Save + if self.args.save_json: + self.pred_to_json( + predn, + batch["im_file"][si], + ops.scale_image( + pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), + pbatch["ori_shape"], + ratio_pad=batch["ratio_pad"][si], + ), + ) + if self.args.save_txt: + self.save_one_txt( + predn, + pred_masks, + self.args.save_conf, + pbatch["ori_shape"], + self.save_dir / "labels" / 
f"{Path(batch['im_file'][si]).stem}.txt", + ) + + def finalize_metrics(self, *args, **kwargs): + """Set speed and confusion matrix for evaluation metrics.""" + self.metrics.speed = self.speed + self.metrics.confusion_matrix = self.confusion_matrix + + def _process_batch(self, detections, gt_bboxes, gt_cls, pred_masks=None, gt_masks=None, overlap=False, masks=False): + """ + Compute correct prediction matrix for a batch based on bounding boxes and optional masks. + + Args: + detections (torch.Tensor): Tensor of shape (N, 6) representing detected bounding boxes and + associated confidence scores and class indices. Each row is of the format [x1, y1, x2, y2, conf, class]. + gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground truth bounding box coordinates. + Each row is of the format [x1, y1, x2, y2]. + gt_cls (torch.Tensor): Tensor of shape (M,) representing ground truth class indices. + pred_masks (torch.Tensor, optional): Tensor representing predicted masks, if available. The shape should + match the ground truth masks. + gt_masks (torch.Tensor, optional): Tensor of shape (M, H, W) representing ground truth masks, if available. + overlap (bool): Flag indicating if overlapping masks should be considered. + masks (bool): Flag indicating if the batch contains mask data. + + Returns: + (torch.Tensor): A correct prediction matrix of shape (N, 10), where 10 represents different IoU levels. + + Note: + - If `masks` is True, the function computes IoU between predicted and ground truth masks. + - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU. 
+ + Examples: + >>> detections = torch.tensor([[25, 30, 200, 300, 0.8, 1], [50, 60, 180, 290, 0.75, 0]]) + >>> gt_bboxes = torch.tensor([[24, 29, 199, 299], [55, 65, 185, 295]]) + >>> gt_cls = torch.tensor([1, 0]) + >>> correct_preds = validator._process_batch(detections, gt_bboxes, gt_cls) + """ + if masks: + if overlap: + nl = len(gt_cls) + index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1 + gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640) + gt_masks = torch.where(gt_masks == index, 1.0, 0.0) + if gt_masks.shape[1:] != pred_masks.shape[1:]: + gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0] + gt_masks = gt_masks.gt_(0.5) + iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1)) + else: # boxes + iou = box_iou(gt_bboxes, detections[:, :4]) + + return self.match_predictions(detections[:, 5], gt_cls, iou) + + def plot_val_samples(self, batch, ni): + """ + Plot validation samples with bounding box labels and masks. + + Args: + batch (dict): Batch data containing images and targets. + ni (int): Batch index. + """ + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + masks=batch["masks"], + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names, + on_plot=self.on_plot, + ) + + def plot_predictions(self, batch, preds, ni): + """ + Plot batch predictions with masks and bounding boxes. + + Args: + batch (dict): Batch data containing images. + preds (list): Predictions from the model. + ni (int): Batch index. 
+ """ + plot_images( + batch["img"], + *output_to_target(preds[0], max_det=50), # not set to self.args.max_det due to slow plotting speed + torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks, + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred + self.plot_masks.clear() + + def save_one_txt(self, predn, pred_masks, save_conf, shape, file): + """ + Save YOLO detections to a txt file in normalized coordinates in a specific format. + + Args: + predn (torch.Tensor): Predictions in the format [x1, y1, x2, y2, conf, cls]. + pred_masks (torch.Tensor): Predicted masks. + save_conf (bool): Whether to save confidence scores. + shape (tuple): Original image shape. + file (Path): File path to save the detections. + """ + from ultralytics.engine.results import Results + + Results( + np.zeros((shape[0], shape[1]), dtype=np.uint8), + path=None, + names=self.names, + boxes=predn[:, :6], + masks=pred_masks, + ).save_txt(file, save_conf=save_conf) + + def pred_to_json(self, predn, filename, pred_masks): + """ + Save one JSON result for COCO evaluation. + + Args: + predn (torch.Tensor): Predictions in the format [x1, y1, x2, y2, conf, cls]. + filename (str): Image filename. + pred_masks (numpy.ndarray): Predicted masks. 
+ + Examples: + >>> result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236} + """ + from pycocotools.mask import encode # noqa + + def single_encode(x): + """Encode predicted masks as RLE and append results to jdict.""" + rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0] + rle["counts"] = rle["counts"].decode("utf-8") + return rle + + stem = Path(filename).stem + image_id = int(stem) if stem.isnumeric() else stem + box = ops.xyxy2xywh(predn[:, :4]) # xywh + box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner + pred_masks = np.transpose(pred_masks, (2, 0, 1)) + with ThreadPool(NUM_THREADS) as pool: + rles = pool.map(single_encode, pred_masks) + for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())): + self.jdict.append( + { + "image_id": image_id, + "category_id": self.class_map[int(p[5])], + "bbox": [round(x, 3) for x in b], + "score": round(p[4], 5), + "segmentation": rles[i], + } + ) + + def eval_json(self, stats): + """Return COCO-style object detection evaluation metrics.""" + if self.args.save_json and self.is_coco and len(self.jdict): + anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations + pred_json = self.save_dir / "predictions.json" # predictions + LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...") + try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb + check_requirements("pycocotools>=2.0.6") + from pycocotools.coco import COCO # noqa + from pycocotools.cocoeval import COCOeval # noqa + + for x in anno_json, pred_json: + assert x.is_file(), f"{x} file not found" + anno = COCO(str(anno_json)) # init annotations api + pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) + for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "segm")]): + if self.is_coco: + eval.params.imgIds = [int(Path(x).stem) for x in 
self.dataloader.dataset.im_files] # im to eval + eval.evaluate() + eval.accumulate() + eval.summarize() + idx = i * 4 + 2 + stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[ + :2 + ] # update mAP50-95 and mAP50 + except Exception as e: + LOGGER.warning(f"pycocotools unable to run: {e}") + return stats diff --git a/tracking/ultralytics/models/yolo/world/__init__.py b/tracking/ultralytics/models/yolo/world/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4380d244602c1db758195f87f7fb2c6aa8141536 --- /dev/null +++ b/tracking/ultralytics/models/yolo/world/__init__.py @@ -0,0 +1,5 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .train import WorldTrainer + +__all__ = ["WorldTrainer"] diff --git a/tracking/ultralytics/models/yolo/world/train.py b/tracking/ultralytics/models/yolo/world/train.py new file mode 100644 index 0000000000000000000000000000000000000000..207a451d1027bd0434de35fd3e34bb86c29c2e19 --- /dev/null +++ b/tracking/ultralytics/models/yolo/world/train.py @@ -0,0 +1,119 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import itertools + +from ultralytics.data import build_yolo_dataset +from ultralytics.models import yolo +from ultralytics.nn.tasks import WorldModel +from ultralytics.utils import DEFAULT_CFG, RANK, checks +from ultralytics.utils.torch_utils import de_parallel + + +def on_pretrain_routine_end(trainer): + """Callback to set up model classes and text encoder at the end of the pretrain routine.""" + if RANK in {-1, 0}: + # Set class names for evaluation + names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())] + de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False) + device = next(trainer.model.parameters()).device + trainer.text_model, _ = trainer.clip.load("ViT-B/32", device=device) + for p in trainer.text_model.parameters(): + p.requires_grad_(False) + + +class 
WorldTrainer(yolo.detect.DetectionTrainer):
+    """
+    A class to fine-tune a world model on a closed-set dataset.
+
+    This trainer extends the DetectionTrainer to support training YOLO World models, which combine
+    visual and textual features for improved object detection and understanding.
+
+    Attributes:
+        clip (module): The CLIP module for text-image understanding.
+        text_model (module): The text encoder model from CLIP.
+        model (WorldModel): The YOLO World model being trained.
+        data (dict): Dataset configuration containing class information.
+        args (dict): Training arguments and configuration.
+
+    Examples:
+        >>> from ultralytics.models.yolo.world import WorldTrainer
+        >>> args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
+        >>> trainer = WorldTrainer(overrides=args)
+        >>> trainer.train()
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """
+        Initialize a WorldTrainer object with given arguments.
+
+        Args:
+            cfg (dict): Configuration for the trainer.
+            overrides (dict, optional): Configuration overrides.
+            _callbacks (list, optional): List of callback functions.
+        """
+        if overrides is None:
+            overrides = {}
+        super().__init__(cfg, overrides, _callbacks)
+
+        # Import and assign clip
+        try:
+            import clip
+        except ImportError:
+            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
+            import clip
+        self.clip = clip
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """
+        Return WorldModel initialized with specified config and weights.
+
+        Args:
+            cfg (Dict | str, optional): Model configuration.
+            weights (str, optional): Path to pretrained weights.
+            verbose (bool): Whether to display model info.
+
+        Returns:
+            (WorldModel): Initialized WorldModel.
+        """
+        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
+        # NOTE: Following the official config, nc hard-coded to 80 for now. 
+ model = WorldModel( + cfg["yaml_file"] if isinstance(cfg, dict) else cfg, + ch=3, + nc=min(self.data["nc"], 80), + verbose=verbose and RANK == -1, + ) + if weights: + model.load(weights) + self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end) + + return model + + def build_dataset(self, img_path, mode="train", batch=None): + """ + Build YOLO Dataset for training or validation. + + Args: + img_path (str): Path to the folder containing images. + mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. + batch (int, optional): Size of batches, this is for `rect`. + + Returns: + (Dataset): YOLO dataset configured for training or validation. + """ + gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) + return build_yolo_dataset( + self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train" + ) + + def preprocess_batch(self, batch): + """Preprocess a batch of images and text for YOLOWorld training.""" + batch = super().preprocess_batch(batch) + + # Add text features + texts = list(itertools.chain(*batch["texts"])) + text_token = self.clip.tokenize(texts).to(batch["img"].device) + txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype) # torch.float32 + txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True) + batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1]) + return batch diff --git a/tracking/ultralytics/models/yolo/world/train_world.py b/tracking/ultralytics/models/yolo/world/train_world.py new file mode 100644 index 0000000000000000000000000000000000000000..f78f1e93f02e70526912bc08faf377877eb5005d --- /dev/null +++ b/tracking/ultralytics/models/yolo/world/train_world.py @@ -0,0 +1,134 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_dataset +from 
class WorldTrainerFromScratch(WorldTrainer):
    """
    A class extending the WorldTrainer for training a world model from scratch on open-set datasets.

    This trainer specializes in handling mixed datasets including both object detection and grounding datasets,
    supporting training YOLO-World models with combined vision-language capabilities.

    Attributes:
        cfg (dict): Configuration dictionary with default parameters for model training.
        overrides (dict): Dictionary of parameter overrides to customize the configuration.
        _callbacks (list): List of callback functions to be executed during different stages of training.

    Examples:
        >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
        >>> from ultralytics import YOLOWorld
        >>> data = dict(
        ...     train=dict(
        ...         yolo_data=["Objects365.yaml"],
        ...         grounding_data=[
        ...             dict(
        ...                 img_path="../datasets/flickr30k/images",
        ...                 json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
        ...             ),
        ...             dict(
        ...                 img_path="../datasets/GQA/images",
        ...                 json_file="../datasets/GQA/final_mixed_train_no_coco.json",
        ...             ),
        ...         ],
        ...     ),
        ...     val=dict(yolo_data=["lvis.yaml"]),
        ... )
        >>> model = YOLOWorld("yolov8s-worldv2.yaml")
        >>> model.train(data=data, trainer=WorldTrainerFromScratch)
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize a WorldTrainerFromScratch object with given configuration and callbacks."""
        if overrides is None:
            overrides = {}
        super().__init__(cfg, overrides, _callbacks)

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Build YOLO Dataset for training or validation.

        This method constructs appropriate datasets based on the mode and input paths, handling both
        standard YOLO datasets and grounding datasets with different formats.

        Args:
            img_path (List[str] | str): Path to the folder containing images or list of paths.
            mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
            batch (int, optional): Size of batches, used for rectangular training/validation.

        Returns:
            (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.
        """
        # Grid size: largest model stride, floored at 32 (0 when the model is not built yet).
        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
        if mode != "train":
            return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
        # Train mode: img_path is a mixed list — plain path strings become standard multi-modal
        # YOLO datasets; dict entries describe grounding datasets (img_path + json_file).
        dataset = [
            build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
            if isinstance(im_path, str)
            else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
            for im_path in img_path
        ]
        # Wrap in a concat dataset only when there is more than one source.
        return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]

    def get_dataset(self):
        """
        Get train and validation paths from data dictionary.

        Processes the data configuration to extract paths for training and validation datasets,
        handling both YOLO detection datasets and grounding datasets.

        Returns:
            (str): Train dataset path.
            (str): Validation dataset path.

        Raises:
            AssertionError: If train or validation datasets are not found, or if validation has multiple datasets.
        """
        final_data = {}
        data_yaml = self.args.data
        assert data_yaml.get("train", False), "train dataset not found"  # object365.yaml
        assert data_yaml.get("val", False), "validation dataset not found"  # lvis.yaml
        # Resolve every listed YOLO dataset yaml into its full dataset dict.
        data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
        assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
        # LVIS validates on its 'minival' split; everything else uses 'val'.
        val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
        for d in data["val"]:
            if d.get("minival") is None:  # for lvis dataset
                continue
            # Make the minival annotation path absolute relative to the dataset root.
            d["minival"] = str(d["path"] / d["minival"])
        for s in ["train", "val"]:
            final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
            # save grounding data if there's one
            grounding_data = data_yaml[s].get("grounding_data")
            if grounding_data is None:
                continue
            grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data]
            for g in grounding_data:
                assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
            final_data[s] += grounding_data
        # NOTE: to make training work properly, set `nc` and `names`
        final_data["nc"] = data["val"][0]["nc"]
        final_data["names"] = data["val"][0]["names"]
        self.data = final_data
        return final_data["train"], final_data["val"][0]

    def plot_training_labels(self):
        """Do not plot labels for YOLO-World training."""
        pass

    def final_eval(self):
        """
        Perform final evaluation and validation for the YOLO-World model.

        Configures the validator with appropriate dataset and split information before running evaluation.

        Returns:
            (dict): Dictionary containing evaluation metrics and results.
        """
        val = self.args.data["val"]["yolo_data"][0]
        self.validator.args.data = val
        # Mirror the split choice made in get_dataset(): LVIS evaluates on 'minival'.
        self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
        return super().final_eval()
def check_class_names(names):
    """Check class names and convert to dict format if needed."""
    if isinstance(names, list):
        # A bare list is positional: map index -> name.
        names = dict(enumerate(names))
    if isinstance(names, dict):
        # Normalize keys to int (e.g. '0' -> 0) and values to str (e.g. True -> 'True').
        normalized = {}
        for key, value in names.items():
            normalized[int(key)] = str(value)
        names = normalized
        n = len(names)
        if max(names.keys()) >= n:
            raise KeyError(
                f"{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices "
                f"{min(names.keys())}-{max(names.keys())} defined in your dataset YAML."
            )
        first = names[0]
        if isinstance(first, str) and first.startswith("n0"):  # ImageNet synset codes, i.e. 'n01440764'
            lookup = yaml_load(ROOT / "cfg/datasets/ImageNet.yaml")["map"]  # human-readable names
            names = {idx: lookup[code] for idx, code in names.items()}
    return names


def default_class_names(data=None):
    """Applies default class names to an input YAML file or returns numerical class names."""
    fallback = {i: f"class{i}" for i in range(999)}  # generic numeric names
    if not data:
        return fallback
    try:
        return yaml_load(check_yaml(data))["names"]
    except Exception:
        # Best-effort: any failure reading the YAML falls back to generic names.
        return fallback
It supports a wide + range of formats, each with specific naming conventions as outlined below: + + Supported Formats and Naming Conventions: + | Format | File Suffix | + | --------------------- | ----------------- | + | PyTorch | *.pt | + | TorchScript | *.torchscript | + | ONNX Runtime | *.onnx | + | ONNX OpenCV DNN | *.onnx (dnn=True) | + | OpenVINO | *openvino_model/ | + | CoreML | *.mlpackage | + | TensorRT | *.engine | + | TensorFlow SavedModel | *_saved_model/ | + | TensorFlow GraphDef | *.pb | + | TensorFlow Lite | *.tflite | + | TensorFlow Edge TPU | *_edgetpu.tflite | + | PaddlePaddle | *_paddle_model/ | + | MNN | *.mnn | + | NCNN | *_ncnn_model/ | + | IMX | *_imx_model/ | + | RKNN | *_rknn_model/ | + + Attributes: + model (torch.nn.Module): The loaded YOLO model. + device (torch.device): The device (CPU or GPU) on which the model is loaded. + task (str): The type of task the model performs (detect, segment, classify, pose). + names (dict): A dictionary of class names that the model can detect. + stride (int): The model stride, typically 32 for YOLO models. + fp16 (bool): Whether the model uses half-precision (FP16) inference. + + Methods: + forward: Run inference on an input image. + from_numpy: Convert numpy array to tensor. + warmup: Warm up the model with a dummy input. + _model_type: Determine the model type from file path. + + Examples: + >>> model = AutoBackend(weights="yolov8n.pt", device="cuda") + >>> results = model(img) + """ + + @torch.no_grad() + def __init__( + self, + weights="yolo11n.pt", + device=torch.device("cpu"), + dnn=False, + data=None, + fp16=False, + batch=1, + fuse=True, + verbose=True, + ): + """ + Initialize the AutoBackend for inference. + + Args: + weights (str | torch.nn.Module): Path to the model weights file or a module instance. Defaults to 'yolo11n.pt'. + device (torch.device): Device to run the model on. Defaults to CPU. + dnn (bool): Use OpenCV DNN module for ONNX inference. Defaults to False. 
+ data (str | Path | optional): Path to the additional data.yaml file containing class names. + fp16 (bool): Enable half-precision inference. Supported only on specific backends. Defaults to False. + batch (int): Batch-size to assume for inference. + fuse (bool): Fuse Conv2D + BatchNorm layers for optimization. Defaults to True. + verbose (bool): Enable verbose logging. Defaults to True. + """ + super().__init__() + w = str(weights[0] if isinstance(weights, list) else weights) + nn_module = isinstance(weights, torch.nn.Module) + ( + pt, + jit, + onnx, + xml, + engine, + coreml, + saved_model, + pb, + tflite, + edgetpu, + tfjs, + paddle, + mnn, + ncnn, + imx, + rknn, + triton, + ) = self._model_type(w) + fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16 + nhwc = coreml or saved_model or pb or tflite or edgetpu or rknn # BHWC formats (vs torch BCWH) + stride = 32 # default stride + end2end, dynamic = False, False + model, metadata, task = None, None, None + + # Set device + cuda = torch.cuda.is_available() and device.type != "cpu" # use CUDA + if cuda and not any([nn_module, pt, jit, engine, onnx, paddle]): # GPU dataloader formats + device = torch.device("cpu") + cuda = False + + # Download if not local + if not (pt or triton or nn_module): + w = attempt_download_asset(w) + + # In-memory PyTorch model + if nn_module: + model = weights.to(device) + if fuse: + model = model.fuse(verbose=verbose) + if hasattr(model, "kpt_shape"): + kpt_shape = model.kpt_shape # pose-only + stride = max(int(model.stride.max()), 32) # model stride + names = model.module.names if hasattr(model, "module") else model.names # get class names + model.half() if fp16 else model.float() + self.model = model # explicitly assign for to(), cpu(), cuda(), half() + pt = True + + # PyTorch + elif pt: + from ultralytics.nn.tasks import attempt_load_weights + + model = attempt_load_weights( + weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse + ) + 
if hasattr(model, "kpt_shape"): + kpt_shape = model.kpt_shape # pose-only + stride = max(int(model.stride.max()), 32) # model stride + names = model.module.names if hasattr(model, "module") else model.names # get class names + model.half() if fp16 else model.float() + self.model = model # explicitly assign for to(), cpu(), cuda(), half() + + # TorchScript + elif jit: + import torchvision # noqa - https://github.com/ultralytics/ultralytics/pull/19747 + + LOGGER.info(f"Loading {w} for TorchScript inference...") + extra_files = {"config.txt": ""} # model metadata + model = torch.jit.load(w, _extra_files=extra_files, map_location=device) + model.half() if fp16 else model.float() + if extra_files["config.txt"]: # load metadata dict + metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items())) + + # ONNX OpenCV DNN + elif dnn: + LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...") + check_requirements("opencv-python>=4.5.4") + net = cv2.dnn.readNetFromONNX(w) + + # ONNX Runtime and IMX + elif onnx or imx: + LOGGER.info(f"Loading {w} for ONNX Runtime inference...") + check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime")) + if IS_RASPBERRYPI or IS_JETSON: + # Fix 'numpy.linalg._umath_linalg' has no attribute '_ilp64' for TF SavedModel on RPi and Jetson + check_requirements("numpy==1.23.5") + import onnxruntime + + providers = ["CPUExecutionProvider"] + if cuda: + if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): + providers.insert(0, "CUDAExecutionProvider") + else: # Only log warning if CUDA was requested but unavailable + LOGGER.warning("WARNING ⚠️ Failed to start ONNX Runtime with CUDA. 
Using CPU...") + device = torch.device("cpu") + cuda = False + LOGGER.info(f"Using ONNX Runtime {providers[0]}") + if onnx: + session = onnxruntime.InferenceSession(w, providers=providers) + else: + check_requirements( + ["model-compression-toolkit==2.1.1", "sony-custom-layers[torch]==0.2.0", "onnxruntime-extensions"] + ) + w = next(Path(w).glob("*.onnx")) + LOGGER.info(f"Loading {w} for ONNX IMX inference...") + import mct_quantizers as mctq + from sony_custom_layers.pytorch.object_detection import nms_ort # noqa + + session = onnxruntime.InferenceSession( + w, mctq.get_ort_session_options(), providers=["CPUExecutionProvider"] + ) + task = "detect" + + output_names = [x.name for x in session.get_outputs()] + metadata = session.get_modelmeta().custom_metadata_map + dynamic = isinstance(session.get_outputs()[0].shape[0], str) + fp16 = "float16" in session.get_inputs()[0].type + if not dynamic: + io = session.io_binding() + bindings = [] + for output in session.get_outputs(): + out_fp16 = "float16" in output.type + y_tensor = torch.empty(output.shape, dtype=torch.float16 if out_fp16 else torch.float32).to(device) + io.bind_output( + name=output.name, + device_type=device.type, + device_id=device.index if cuda else 0, + element_type=np.float16 if out_fp16 else np.float32, + shape=tuple(y_tensor.shape), + buffer_ptr=y_tensor.data_ptr(), + ) + bindings.append(y_tensor) + + # OpenVINO + elif xml: + LOGGER.info(f"Loading {w} for OpenVINO inference...") + check_requirements("openvino>=2024.0.0,!=2025.0.0") + import openvino as ov + + core = ov.Core() + w = Path(w) + if not w.is_file(): # if not *.xml + w = next(w.glob("*.xml")) # get *.xml file from *_openvino_model dir + ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin")) + if ov_model.get_parameters()[0].get_layout().empty: + ov_model.get_parameters()[0].set_layout(ov.Layout("NCHW")) + + # OpenVINO inference modes are 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT' + 
inference_mode = "CUMULATIVE_THROUGHPUT" if batch > 1 else "LATENCY" + LOGGER.info(f"Using OpenVINO {inference_mode} mode for batch={batch} inference...") + ov_compiled_model = core.compile_model( + ov_model, + device_name="AUTO", # AUTO selects best available device, do not modify + config={"PERFORMANCE_HINT": inference_mode}, + ) + input_name = ov_compiled_model.input().get_any_name() + metadata = w.parent / "metadata.yaml" + + # TensorRT + elif engine: + LOGGER.info(f"Loading {w} for TensorRT inference...") + + if IS_JETSON and check_version(PYTHON_VERSION, "<=3.8.0"): + # fix error: `np.bool` was a deprecated alias for the builtin `bool` for JetPack 4 with Python <= 3.8.0 + check_requirements("numpy==1.23.5") + + try: + import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download + except ImportError: + if LINUX: + check_requirements("tensorrt>7.0.0,!=10.1.0") + import tensorrt as trt # noqa + check_version(trt.__version__, ">=7.0.0", hard=True) + check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239") + if device.type == "cpu": + device = torch.device("cuda:0") + Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr")) + logger = trt.Logger(trt.Logger.INFO) + # Read file + with open(w, "rb") as f, trt.Runtime(logger) as runtime: + try: + meta_len = int.from_bytes(f.read(4), byteorder="little") # read metadata length + metadata = json.loads(f.read(meta_len).decode("utf-8")) # read metadata + except UnicodeDecodeError: + f.seek(0) # engine file may lack embedded Ultralytics metadata + dla = metadata.get("dla", None) + if dla is not None: + runtime.DLA_core = int(dla) + model = runtime.deserialize_cuda_engine(f.read()) # read engine + + # Model context + try: + context = model.create_execution_context() + except Exception as e: # model is None + LOGGER.error(f"ERROR: TensorRT model exported with a different version than {trt.__version__}\n") + raise e + + bindings = 
OrderedDict() + output_names = [] + fp16 = False # default updated below + dynamic = False + is_trt10 = not hasattr(model, "num_bindings") + num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings) + for i in num: + if is_trt10: + name = model.get_tensor_name(i) + dtype = trt.nptype(model.get_tensor_dtype(name)) + is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT + if is_input: + if -1 in tuple(model.get_tensor_shape(name)): + dynamic = True + context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1])) + if dtype == np.float16: + fp16 = True + else: + output_names.append(name) + shape = tuple(context.get_tensor_shape(name)) + else: # TensorRT < 10.0 + name = model.get_binding_name(i) + dtype = trt.nptype(model.get_binding_dtype(i)) + is_input = model.binding_is_input(i) + if model.binding_is_input(i): + if -1 in tuple(model.get_binding_shape(i)): # dynamic + dynamic = True + context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[1])) + if dtype == np.float16: + fp16 = True + else: + output_names.append(name) + shape = tuple(context.get_binding_shape(i)) + im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) + bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr())) + binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) + batch_size = bindings["images"].shape[0] # if dynamic, this is instead max batch size + + # CoreML + elif coreml: + LOGGER.info(f"Loading {w} for CoreML inference...") + import coremltools as ct + + model = ct.models.MLModel(w) + metadata = dict(model.user_defined_metadata) + + # TF SavedModel + elif saved_model: + LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...") + import tensorflow as tf + + keras = False # assume TF1 saved_model + model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w) + metadata = Path(w) / "metadata.yaml" + + # TF GraphDef + elif pb: # 
https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt + LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...") + import tensorflow as tf + + from ultralytics.engine.exporter import gd_outputs + + def wrap_frozen_graph(gd, inputs, outputs): + """Wrap frozen graphs for deployment.""" + x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped + ge = x.graph.as_graph_element + return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs)) + + gd = tf.Graph().as_graph_def() # TF GraphDef + with open(w, "rb") as f: + gd.ParseFromString(f.read()) + frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd)) + try: # find metadata in SavedModel alongside GraphDef + metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml")) + except StopIteration: + pass + + # TFLite or TFLite Edge TPU + elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python + try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu + from tflite_runtime.interpreter import Interpreter, load_delegate + except ImportError: + import tensorflow as tf + + Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate + if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime + device = device[3:] if str(device).startswith("tpu") else ":0" + LOGGER.info(f"Loading {w} on device {device[1:]} for TensorFlow Lite Edge TPU inference...") + delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[ + platform.system() + ] + interpreter = Interpreter( + model_path=w, + experimental_delegates=[load_delegate(delegate, options={"device": device})], + ) + device = "cpu" # Required, otherwise PyTorch will try to use the wrong device + else: # TFLite + LOGGER.info(f"Loading {w} for TensorFlow Lite inference...") + interpreter = 
Interpreter(model_path=w) # load TFLite model + interpreter.allocate_tensors() # allocate + input_details = interpreter.get_input_details() # inputs + output_details = interpreter.get_output_details() # outputs + # Load metadata + try: + with zipfile.ZipFile(w, "r") as model: + meta_file = model.namelist()[0] + metadata = ast.literal_eval(model.read(meta_file).decode("utf-8")) + except zipfile.BadZipFile: + pass + + # TF.js + elif tfjs: + raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.") + + # PaddlePaddle + elif paddle: + LOGGER.info(f"Loading {w} for PaddlePaddle inference...") + check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle") + import paddle.inference as pdi # noqa + + w = Path(w) + if not w.is_file(): # if not *.pdmodel + w = next(w.rglob("*.pdmodel")) # get *.pdmodel file from *_paddle_model dir + config = pdi.Config(str(w), str(w.with_suffix(".pdiparams"))) + if cuda: + config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0) + predictor = pdi.create_predictor(config) + input_handle = predictor.get_input_handle(predictor.get_input_names()[0]) + output_names = predictor.get_output_names() + metadata = w.parents[1] / "metadata.yaml" + + # MNN + elif mnn: + LOGGER.info(f"Loading {w} for MNN inference...") + check_requirements("MNN") # requires MNN + import os + + import MNN + + config = {"precision": "low", "backend": "CPU", "numThread": (os.cpu_count() + 1) // 2} + rt = MNN.nn.create_runtime_manager((config,)) + net = MNN.nn.load_module_from_file(w, [], [], runtime_manager=rt, rearrange=True) + + def torch_to_mnn(x): + return MNN.expr.const(x.data_ptr(), x.shape) + + metadata = json.loads(net.get_info()["bizCode"]) + + # NCNN + elif ncnn: + LOGGER.info(f"Loading {w} for NCNN inference...") + check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires NCNN + import ncnn as pyncnn + + net = pyncnn.Net() + net.opt.use_vulkan_compute = cuda + w = Path(w) + if not 
w.is_file(): # if not *.param + w = next(w.glob("*.param")) # get *.param file from *_ncnn_model dir + net.load_param(str(w)) + net.load_model(str(w.with_suffix(".bin"))) + metadata = w.parent / "metadata.yaml" + + # NVIDIA Triton Inference Server + elif triton: + check_requirements("tritonclient[all]") + from ultralytics.utils.triton import TritonRemoteModel + + model = TritonRemoteModel(w) + metadata = model.metadata + + # RKNN + elif rknn: + if not is_rockchip(): + raise OSError("RKNN inference is only supported on Rockchip devices.") + LOGGER.info(f"Loading {w} for RKNN inference...") + check_requirements("rknn-toolkit-lite2") + from rknnlite.api import RKNNLite + + w = Path(w) + if not w.is_file(): # if not *.rknn + w = next(w.rglob("*.rknn")) # get *.rknn file from *_rknn_model dir + rknn_model = RKNNLite() + rknn_model.load_rknn(w) + rknn_model.init_runtime() + metadata = Path(w).parent / "metadata.yaml" + + # Any other format (unsupported) + else: + from ultralytics.engine.exporter import export_formats + + raise TypeError( + f"model='{w}' is not a supported model format. Ultralytics supports: {export_formats()['Format']}\n" + f"See https://docs.ultralytics.com/modes/predict for help." 
+ ) + + # Load external metadata YAML + if isinstance(metadata, (str, Path)) and Path(metadata).exists(): + metadata = yaml_load(metadata) + if metadata and isinstance(metadata, dict): + for k, v in metadata.items(): + if k in {"stride", "batch"}: + metadata[k] = int(v) + elif k in {"imgsz", "names", "kpt_shape", "args"} and isinstance(v, str): + metadata[k] = eval(v) + stride = metadata["stride"] + task = metadata["task"] + batch = metadata["batch"] + imgsz = metadata["imgsz"] + names = metadata["names"] + kpt_shape = metadata.get("kpt_shape") + end2end = metadata.get("args", {}).get("nms", False) + dynamic = metadata.get("args", {}).get("dynamic", dynamic) + elif not (pt or triton or nn_module): + LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'") + + # Check names + if "names" not in locals(): # names missing + names = default_class_names(data) + names = check_class_names(names) + + # Disable gradients + if pt: + for p in model.parameters(): + p.requires_grad = False + + self.__dict__.update(locals()) # assign all variables to self + + def forward(self, im, augment=False, visualize=False, embed=None): + """ + Runs inference on the YOLOv8 MultiBackend model. + + Args: + im (torch.Tensor): The image tensor to perform inference on. + augment (bool): Whether to perform data augmentation during inference. Defaults to False. + visualize (bool): Whether to visualize the output predictions. Defaults to False. + embed (list, optional): A list of feature vectors/embeddings to return. + + Returns: + (torch.Tensor | List[torch.Tensor]): The raw output tensor(s) from the model. 
+ """ + b, ch, h, w = im.shape # batch, channel, height, width + if self.fp16 and im.dtype != torch.float16: + im = im.half() # to FP16 + if self.nhwc: + im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3) + + # PyTorch + if self.pt or self.nn_module: + y = self.model(im, augment=augment, visualize=visualize, embed=embed) + + # TorchScript + elif self.jit: + y = self.model(im) + + # ONNX OpenCV DNN + elif self.dnn: + im = im.cpu().numpy() # torch to numpy + self.net.setInput(im) + y = self.net.forward() + + # ONNX Runtime + elif self.onnx or self.imx: + if self.dynamic: + im = im.cpu().numpy() # torch to numpy + y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im}) + else: + if not self.cuda: + im = im.cpu() + self.io.bind_input( + name="images", + device_type=im.device.type, + device_id=im.device.index if im.device.type == "cuda" else 0, + element_type=np.float16 if self.fp16 else np.float32, + shape=tuple(im.shape), + buffer_ptr=im.data_ptr(), + ) + self.session.run_with_iobinding(self.io) + y = self.bindings + if self.imx: + # boxes, conf, cls + y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None]], axis=-1) + + # OpenVINO + elif self.xml: + im = im.cpu().numpy() # FP32 + + if self.inference_mode in {"THROUGHPUT", "CUMULATIVE_THROUGHPUT"}: # optimized for larger batch-sizes + n = im.shape[0] # number of images in batch + results = [None] * n # preallocate list with None to match the number of images + + def callback(request, userdata): + """Places result in preallocated list using userdata index.""" + results[userdata] = request.results + + # Create AsyncInferQueue, set the callback and start asynchronous inference for each input image + async_queue = self.ov.AsyncInferQueue(self.ov_compiled_model) + async_queue.set_callback(callback) + for i in range(n): + # Start async inference with userdata=i to specify the position in results list + async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, 
userdata=i) # keep image as BCHW + async_queue.wait_all() # wait for all inference requests to complete + y = np.concatenate([list(r.values())[0] for r in results]) + + else: # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1 + y = list(self.ov_compiled_model(im).values()) + + # TensorRT + elif self.engine: + if self.dynamic and im.shape != self.bindings["images"].shape: + if self.is_trt10: + self.context.set_input_shape("images", im.shape) + self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape) + for name in self.output_names: + self.bindings[name].data.resize_(tuple(self.context.get_tensor_shape(name))) + else: + i = self.model.get_binding_index("images") + self.context.set_binding_shape(i, im.shape) + self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape) + for name in self.output_names: + i = self.model.get_binding_index(name) + self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i))) + + s = self.bindings["images"].shape + assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}" + self.binding_addrs["images"] = int(im.data_ptr()) + self.context.execute_v2(list(self.binding_addrs.values())) + y = [self.bindings[x].data for x in sorted(self.output_names)] + + # CoreML + elif self.coreml: + im = im[0].cpu().numpy() + im_pil = Image.fromarray((im * 255).astype("uint8")) + # im = im.resize((192, 320), Image.BILINEAR) + y = self.model.predict({"image": im_pil}) # coordinates are xywh normalized + if "confidence" in y: + raise TypeError( + "Ultralytics only supports inference of non-pipelined CoreML models exported with " + f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export." 
+ ) + # TODO: CoreML NMS inference handling + # from ultralytics.utils.ops import xywh2xyxy + # box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels + # conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32) + # y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1) + y = list(y.values()) + if len(y) == 2 and len(y[1].shape) != 4: # segmentation model + y = list(reversed(y)) # reversed for segmentation models (pred, proto) + + # PaddlePaddle + elif self.paddle: + im = im.cpu().numpy().astype(np.float32) + self.input_handle.copy_from_cpu(im) + self.predictor.run() + y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names] + + # MNN + elif self.mnn: + input_var = self.torch_to_mnn(im) + output_var = self.net.onForward([input_var]) + y = [x.read() for x in output_var] + + # NCNN + elif self.ncnn: + mat_in = self.pyncnn.Mat(im[0].cpu().numpy()) + with self.net.create_extractor() as ex: + ex.input(self.net.input_names()[0], mat_in) + # WARNING: 'output_names' sorted as a temporary fix for https://github.com/pnnx/pnnx/issues/130 + y = [np.array(ex.extract(x)[1])[None] for x in sorted(self.net.output_names())] + + # NVIDIA Triton Inference Server + elif self.triton: + im = im.cpu().numpy() # torch to numpy + y = self.model(im) + + # RKNN + elif self.rknn: + im = (im.cpu().numpy() * 255).astype("uint8") + im = im if isinstance(im, (list, tuple)) else [im] + y = self.rknn_model.inference(inputs=im) + + # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU) + else: + im = im.cpu().numpy() + if self.saved_model: # SavedModel + y = self.model(im, training=False) if self.keras else self.model(im) + if not isinstance(y, list): + y = [y] + elif self.pb: # GraphDef + y = self.frozen_func(x=self.tf.constant(im)) + else: # Lite or Edge TPU + details = self.input_details[0] + is_int = details["dtype"] in {np.int8, np.int16} # is TFLite quantized int8 or int16 model + if is_int: + scale, zero_point = 
details["quantization"] + im = (im / scale + zero_point).astype(details["dtype"]) # de-scale + self.interpreter.set_tensor(details["index"], im) + self.interpreter.invoke() + y = [] + for output in self.output_details: + x = self.interpreter.get_tensor(output["index"]) + if is_int: + scale, zero_point = output["quantization"] + x = (x.astype(np.float32) - zero_point) * scale # re-scale + if x.ndim == 3: # if task is not classification, excluding masks (ndim=4) as well + # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695 + # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models + if x.shape[-1] == 6 or self.end2end: # end-to-end model + x[:, :, [0, 2]] *= w + x[:, :, [1, 3]] *= h + if self.task == "pose": + x[:, :, 6::3] *= w + x[:, :, 7::3] *= h + else: + x[:, [0, 2]] *= w + x[:, [1, 3]] *= h + if self.task == "pose": + x[:, 5::3] *= w + x[:, 6::3] *= h + y.append(x) + # TF segment fixes: export is reversed vs ONNX export and protos are transposed + if len(y) == 2: # segment with (det, proto) output order reversed + if len(y[1].shape) != 4: + y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32) + if y[1].shape[-1] == 6: # end-to-end model + y = [y[1]] + else: + y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160) + y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y] + + # for x in y: + # print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape) # debug shapes + if isinstance(y, (list, tuple)): + if len(self.names) == 999 and (self.task == "segment" or len(y) == 2): # segments and names not defined + nc = y[0].shape[1] - y[1].shape[1] - 4 # y = (1, 32, 160, 160), (1, 116, 8400) + self.names = {i: f"class{i}" for i in range(nc)} + return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y] + else: + return self.from_numpy(y) + + def from_numpy(self, x): + """ + Convert a numpy array 
to a tensor. + + Args: + x (np.ndarray): The array to be converted. + + Returns: + (torch.Tensor): The converted tensor + """ + return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x + + def warmup(self, imgsz=(1, 3, 640, 640)): + """ + Warm up the model by running one forward pass with a dummy input. + + Args: + imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width) + """ + import torchvision # noqa (import here so torchvision import time not recorded in postprocess time) + + warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module + if any(warmup_types) and (self.device.type != "cpu" or self.triton): + im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input + for _ in range(2 if self.jit else 1): + self.forward(im) # warmup + + @staticmethod + def _model_type(p="path/to/model.pt"): + """ + Takes a path to a model file and returns the model type. Possibles types are pt, jit, onnx, xml, engine, coreml, + saved_model, pb, tflite, edgetpu, tfjs, ncnn, mnn, imx or paddle. + + Args: + p (str): Path to the model file. Defaults to path/to/model.pt + + Returns: + (List[bool]): List of booleans indicating the model type. 
+ + Examples: + >>> model = AutoBackend(weights="path/to/model.onnx") + >>> model_type = model._model_type() # returns "onnx" + """ + from ultralytics.engine.exporter import export_formats + + sf = export_formats()["Suffix"] # export suffixes + if not is_url(p) and not isinstance(p, str): + check_suffix(p, sf) # checks + name = Path(p).name + types = [s in name for s in sf] + types[5] |= name.endswith(".mlmodel") # retain support for older Apple CoreML *.mlmodel formats + types[8] &= not types[9] # tflite &= not edgetpu + if any(types): + triton = False + else: + from urllib.parse import urlsplit + + url = urlsplit(p) + triton = bool(url.netloc) and bool(url.path) and url.scheme in {"http", "grpc"} + + return types + [triton] diff --git a/tracking/ultralytics/nn/modules/__init__.py b/tracking/ultralytics/nn/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7c45e31529957936621f1322b9c7a95856c3923 --- /dev/null +++ b/tracking/ultralytics/nn/modules/__init__.py @@ -0,0 +1,165 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Ultralytics modules. + +This module provides access to various neural network components used in Ultralytics models, including convolution blocks, +attention mechanisms, transformer components, and detection/segmentation heads. + +Examples: + Visualize a module with Netron. 
+ >>> from ultralytics.nn.modules import * + >>> import torch + >>> import os + >>> x = torch.ones(1, 128, 40, 40) + >>> m = Conv(128, 128) + >>> f = f"{m._get_name()}.onnx" + >>> torch.onnx.export(m, x, f) + >>> os.system(f"onnxslim {f} {f} && open {f}") # pip install onnxslim +""" + +from .block import ( + C1, + C2, + C2PSA, + C3, + C3TR, + CIB, + DFL, + ELAN1, + PSA, + SPP, + SPPELAN, + SPPF, + A2C2f, + AConv, + ADown, + Attention, + BNContrastiveHead, + Bottleneck, + BottleneckCSP, + C2f, + C2fAttn, + C2fCIB, + C2fPSA, + C3Ghost, + C3k2, + C3x, + CBFuse, + CBLinear, + ContrastiveHead, + GhostBottleneck, + HGBlock, + HGStem, + ImagePoolingAttn, + Proto, + RepC3, + RepNCSPELAN4, + RepVGGDW, + ResNetLayer, + SCDown, + TorchVision, +) +from .conv import ( + CBAM, + ChannelAttention, + Concat, + Conv, + Conv2, + ConvTranspose, + DWConv, + DWConvTranspose2d, + Focus, + GhostConv, + Index, + LightConv, + RepConv, + SpatialAttention, +) +from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect, v10Detect +from .transformer import ( + AIFI, + MLP, + DeformableTransformerDecoder, + DeformableTransformerDecoderLayer, + LayerNorm2d, + MLPBlock, + MSDeformAttn, + TransformerBlock, + TransformerEncoderLayer, + TransformerLayer, +) + +__all__ = ( + "Conv", + "Conv2", + "LightConv", + "RepConv", + "DWConv", + "DWConvTranspose2d", + "ConvTranspose", + "Focus", + "GhostConv", + "ChannelAttention", + "SpatialAttention", + "CBAM", + "Concat", + "TransformerLayer", + "TransformerBlock", + "MLPBlock", + "LayerNorm2d", + "DFL", + "HGBlock", + "HGStem", + "SPP", + "SPPF", + "C1", + "C2", + "C3", + "C2f", + "C3k2", + "SCDown", + "C2fPSA", + "C2PSA", + "C2fAttn", + "C3x", + "C3TR", + "C3Ghost", + "GhostBottleneck", + "Bottleneck", + "BottleneckCSP", + "Proto", + "Detect", + "Segment", + "Pose", + "Classify", + "TransformerEncoderLayer", + "RepC3", + "RTDETRDecoder", + "AIFI", + "DeformableTransformerDecoder", + "DeformableTransformerDecoderLayer", + 
"MSDeformAttn", + "MLP", + "ResNetLayer", + "OBB", + "WorldDetect", + "v10Detect", + "ImagePoolingAttn", + "ContrastiveHead", + "BNContrastiveHead", + "RepNCSPELAN4", + "ADown", + "SPPELAN", + "CBFuse", + "CBLinear", + "AConv", + "ELAN1", + "RepVGGDW", + "CIB", + "C2fCIB", + "Attention", + "PSA", + "TorchVision", + "Index", + "A2C2f", +) diff --git a/tracking/ultralytics/nn/modules/activation.py b/tracking/ultralytics/nn/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..3ef9308a0178573c7b167538454bf5d0cf69a792 --- /dev/null +++ b/tracking/ultralytics/nn/modules/activation.py @@ -0,0 +1,30 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Activation modules.""" + +import torch +import torch.nn as nn + + +class AGLU(nn.Module): + """ + Unified activation function module from https://github.com/kostas1515/AGLU. + + This class implements a parameterized activation function with learnable parameters lambda and kappa. + + Attributes: + act (nn.Softplus): Softplus activation function with negative beta. + lambd (nn.Parameter): Learnable lambda parameter initialized with uniform distribution. + kappa (nn.Parameter): Learnable kappa parameter initialized with uniform distribution. 
+ """ + + def __init__(self, device=None, dtype=None) -> None: + """Initialize the Unified activation function with learnable parameters.""" + super().__init__() + self.act = nn.Softplus(beta=-1.0) + self.lambd = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype))) # lambda parameter + self.kappa = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype))) # kappa parameter + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Compute the forward pass of the Unified activation function.""" + lam = torch.clamp(self.lambd, min=0.0001) # Clamp lambda to avoid division by zero + return torch.exp((1 / lam) * self.act((self.kappa * x) - torch.log(lam))) diff --git a/tracking/ultralytics/nn/modules/block.py b/tracking/ultralytics/nn/modules/block.py new file mode 100644 index 0000000000000000000000000000000000000000..de88a7e19a7830f227ad9bc4d7978011ca1d7acb --- /dev/null +++ b/tracking/ultralytics/nn/modules/block.py @@ -0,0 +1,1861 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Block modules.""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ultralytics.utils.torch_utils import fuse_conv_and_bn + +from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad +from .transformer import TransformerBlock + +__all__ = ( + "DFL", + "HGBlock", + "HGStem", + "SPP", + "SPPF", + "C1", + "C2", + "C3", + "C2f", + "C2fAttn", + "ImagePoolingAttn", + "ContrastiveHead", + "BNContrastiveHead", + "C3x", + "C3TR", + "C3Ghost", + "GhostBottleneck", + "Bottleneck", + "BottleneckCSP", + "Proto", + "RepC3", + "ResNetLayer", + "RepNCSPELAN4", + "ELAN1", + "ADown", + "AConv", + "SPPELAN", + "CBFuse", + "CBLinear", + "C3k2", + "C2fPSA", + "C2PSA", + "RepVGGDW", + "CIB", + "C2fCIB", + "Attention", + "PSA", + "SCDown", + "TorchVision", +) + + +class DFL(nn.Module): + """ + Integral module of Distribution Focal Loss (DFL). 
+ + Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391 + """ + + def __init__(self, c1=16): + """Initialize a convolutional layer with a given number of input channels.""" + super().__init__() + self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False) + x = torch.arange(c1, dtype=torch.float) + self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1)) + self.c1 = c1 + + def forward(self, x): + """Apply the DFL module to input tensor and return transformed output.""" + b, _, a = x.shape # batch, channels, anchors + return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a) + # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a) + + +class Proto(nn.Module): + """YOLOv8 mask Proto module for segmentation models.""" + + def __init__(self, c1, c_=256, c2=32): + """ + Initialize the YOLOv8 mask Proto module with specified number of protos and masks. + + Args: + c1 (int): Input channels. + c_ (int): Intermediate channels. + c2 (int): Output channels (number of protos). + """ + super().__init__() + self.cv1 = Conv(c1, c_, k=3) + self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) # nn.Upsample(scale_factor=2, mode='nearest') + self.cv2 = Conv(c_, c_, k=3) + self.cv3 = Conv(c_, c2) + + def forward(self, x): + """Perform a forward pass through layers using an upsampled input image.""" + return self.cv3(self.cv2(self.upsample(self.cv1(x)))) + + +class HGStem(nn.Module): + """ + StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d. + + https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py + """ + + def __init__(self, c1, cm, c2): + """ + Initialize the StemBlock of PPHGNetV2. + + Args: + c1 (int): Input channels. + cm (int): Middle channels. + c2 (int): Output channels. 
+ """ + super().__init__() + self.stem1 = Conv(c1, cm, 3, 2, act=nn.ReLU()) + self.stem2a = Conv(cm, cm // 2, 2, 1, 0, act=nn.ReLU()) + self.stem2b = Conv(cm // 2, cm, 2, 1, 0, act=nn.ReLU()) + self.stem3 = Conv(cm * 2, cm, 3, 2, act=nn.ReLU()) + self.stem4 = Conv(cm, c2, 1, 1, act=nn.ReLU()) + self.pool = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, ceil_mode=True) + + def forward(self, x): + """Forward pass of a PPHGNetV2 backbone layer.""" + x = self.stem1(x) + x = F.pad(x, [0, 1, 0, 1]) + x2 = self.stem2a(x) + x2 = F.pad(x2, [0, 1, 0, 1]) + x2 = self.stem2b(x2) + x1 = self.pool(x) + x = torch.cat([x1, x2], dim=1) + x = self.stem3(x) + x = self.stem4(x) + return x + + +class HGBlock(nn.Module): + """ + HG_Block of PPHGNetV2 with 2 convolutions and LightConv. + + https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py + """ + + def __init__(self, c1, cm, c2, k=3, n=6, lightconv=False, shortcut=False, act=nn.ReLU()): + """ + Initialize HGBlock with specified parameters. + + Args: + c1 (int): Input channels. + cm (int): Middle channels. + c2 (int): Output channels. + k (int): Kernel size. + n (int): Number of LightConv or Conv blocks. + lightconv (bool): Whether to use LightConv. + shortcut (bool): Whether to use shortcut connection. + act (nn.Module): Activation function. 
+ """ + super().__init__() + block = LightConv if lightconv else Conv + self.m = nn.ModuleList(block(c1 if i == 0 else cm, cm, k=k, act=act) for i in range(n)) + self.sc = Conv(c1 + n * cm, c2 // 2, 1, 1, act=act) # squeeze conv + self.ec = Conv(c2 // 2, c2, 1, 1, act=act) # excitation conv + self.add = shortcut and c1 == c2 + + def forward(self, x): + """Forward pass of a PPHGNetV2 backbone layer.""" + y = [x] + y.extend(m(y[-1]) for m in self.m) + y = self.ec(self.sc(torch.cat(y, 1))) + return y + x if self.add else y + + +class SPP(nn.Module): + """Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729.""" + + def __init__(self, c1, c2, k=(5, 9, 13)): + """ + Initialize the SPP layer with input/output channels and pooling kernel sizes. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + k (Tuple[int, int, int]): Kernel sizes for max pooling. + """ + super().__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1) + self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) + + def forward(self, x): + """Forward pass of the SPP layer, performing spatial pyramid pooling.""" + x = self.cv1(x) + return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) + + +class SPPF(nn.Module): + """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher.""" + + def __init__(self, c1, c2, k=5): + """ + Initialize the SPPF layer with given input/output channels and kernel size. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + k (int): Kernel size. + + Notes: + This module is equivalent to SPP(k=(5, 9, 13)). 
+ """ + super().__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * 4, c2, 1, 1) + self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + + def forward(self, x): + """Apply sequential pooling operations to input and return concatenated feature maps.""" + y = [self.cv1(x)] + y.extend(self.m(y[-1]) for _ in range(3)) + return self.cv2(torch.cat(y, 1)) + + +class C1(nn.Module): + """CSP Bottleneck with 1 convolution.""" + + def __init__(self, c1, c2, n=1): + """ + Initialize the CSP Bottleneck with 1 convolution. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + n (int): Number of convolutions. + """ + super().__init__() + self.cv1 = Conv(c1, c2, 1, 1) + self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n))) + + def forward(self, x): + """Apply convolution and residual connection to input tensor.""" + y = self.cv1(x) + return self.m(y) + y + + +class C2(nn.Module): + """CSP Bottleneck with 2 convolutions.""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """ + Initialize a CSP Bottleneck with 2 convolutions. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + n (int): Number of Bottleneck blocks. + shortcut (bool): Whether to use shortcut connections. + g (int): Groups for convolutions. + e (float): Expansion ratio. 
+ """ + super().__init__() + self.c = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, 2 * self.c, 1, 1) + self.cv2 = Conv(2 * self.c, c2, 1) # optional act=FReLU(c2) + # self.attention = ChannelAttention(2 * self.c) # or SpatialAttention() + self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))) + + def forward(self, x): + """Forward pass through the CSP bottleneck with 2 convolutions.""" + a, b = self.cv1(x).chunk(2, 1) + return self.cv2(torch.cat((self.m(a), b), 1)) + + +class C2f(nn.Module): + """Faster Implementation of CSP Bottleneck with 2 convolutions.""" + + def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): + """ + Initialize a CSP bottleneck with 2 convolutions. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + n (int): Number of Bottleneck blocks. + shortcut (bool): Whether to use shortcut connections. + g (int): Groups for convolutions. + e (float): Expansion ratio. + """ + super().__init__() + self.c = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, 2 * self.c, 1, 1) + self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2) + self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)) + + def forward(self, x): + """Forward pass through C2f layer.""" + y = list(self.cv1(x).chunk(2, 1)) + y.extend(m(y[-1]) for m in self.m) + return self.cv2(torch.cat(y, 1)) + + def forward_split(self, x): + """Forward pass using split() instead of chunk().""" + y = self.cv1(x).split((self.c, self.c), 1) + y = [y[0], y[1]] + y.extend(m(y[-1]) for m in self.m) + return self.cv2(torch.cat(y, 1)) + + +class C3(nn.Module): + """CSP Bottleneck with 3 convolutions.""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """ + Initialize the CSP Bottleneck with 3 convolutions. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + n (int): Number of Bottleneck blocks. 
+ shortcut (bool): Whether to use shortcut connections. + g (int): Groups for convolutions. + e (float): Expansion ratio. + """ + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n))) + + def forward(self, x): + """Forward pass through the CSP bottleneck with 3 convolutions.""" + return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)) + + +class C3x(C3): + """C3 module with cross-convolutions.""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """ + Initialize C3 module with cross-convolutions. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + n (int): Number of Bottleneck blocks. + shortcut (bool): Whether to use shortcut connections. + g (int): Groups for convolutions. + e (float): Expansion ratio. + """ + super().__init__(c1, c2, n, shortcut, g, e) + self.c_ = int(c2 * e) + self.m = nn.Sequential(*(Bottleneck(self.c_, self.c_, shortcut, g, k=((1, 3), (3, 1)), e=1) for _ in range(n))) + + +class RepC3(nn.Module): + """Rep C3.""" + + def __init__(self, c1, c2, n=3, e=1.0): + """ + Initialize CSP Bottleneck with a single convolution. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + n (int): Number of RepConv blocks. + e (float): Expansion ratio. 
+ """ + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.m = nn.Sequential(*[RepConv(c_, c_) for _ in range(n)]) + self.cv3 = Conv(c_, c2, 1, 1) if c_ != c2 else nn.Identity() + + def forward(self, x): + """Forward pass of RepC3 module.""" + return self.cv3(self.m(self.cv1(x)) + self.cv2(x)) + + +class C3TR(C3): + """C3 module with TransformerBlock().""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """ + Initialize C3 module with TransformerBlock. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + n (int): Number of Transformer blocks. + shortcut (bool): Whether to use shortcut connections. + g (int): Groups for convolutions. + e (float): Expansion ratio. + """ + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) + self.m = TransformerBlock(c_, c_, 4, n) + + +class C3Ghost(C3): + """C3 module with GhostBottleneck().""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """ + Initialize C3 module with GhostBottleneck. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + n (int): Number of Ghost bottleneck blocks. + shortcut (bool): Whether to use shortcut connections. + g (int): Groups for convolutions. + e (float): Expansion ratio. + """ + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) # hidden channels + self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n))) + + +class GhostBottleneck(nn.Module): + """Ghost Bottleneck https://github.com/huawei-noah/ghostnet.""" + + def __init__(self, c1, c2, k=3, s=1): + """ + Initialize Ghost Bottleneck module. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + k (int): Kernel size. + s (int): Stride. 
+ """ + super().__init__() + c_ = c2 // 2 + self.conv = nn.Sequential( + GhostConv(c1, c_, 1, 1), # pw + DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw + GhostConv(c_, c2, 1, 1, act=False), # pw-linear + ) + self.shortcut = ( + nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity() + ) + + def forward(self, x): + """Apply skip connection and concatenation to input tensor.""" + return self.conv(x) + self.shortcut(x) + + +class Bottleneck(nn.Module): + """Standard bottleneck.""" + + def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): + """ + Initialize a standard bottleneck module. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + shortcut (bool): Whether to use shortcut connection. + g (int): Groups for convolutions. + k (Tuple[int, int]): Kernel sizes for convolutions. + e (float): Expansion ratio. + """ + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, k[0], 1) + self.cv2 = Conv(c_, c2, k[1], 1, g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + """Apply bottleneck with optional shortcut connection.""" + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class BottleneckCSP(nn.Module): + """CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks.""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """ + Initialize CSP Bottleneck. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + n (int): Number of Bottleneck blocks. + shortcut (bool): Whether to use shortcut connections. + g (int): Groups for convolutions. + e (float): Expansion ratio. 
+ """ + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False) + self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False) + self.cv4 = Conv(2 * c_, c2, 1, 1) + self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3) + self.act = nn.SiLU() + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) + + def forward(self, x): + """Apply CSP bottleneck with 3 convolutions.""" + y1 = self.cv3(self.m(self.cv1(x))) + y2 = self.cv2(x) + return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1)))) + + +class ResNetBlock(nn.Module): + """ResNet block with standard convolution layers.""" + + def __init__(self, c1, c2, s=1, e=4): + """ + Initialize ResNet block. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + s (int): Stride. + e (int): Expansion ratio. + """ + super().__init__() + c3 = e * c2 + self.cv1 = Conv(c1, c2, k=1, s=1, act=True) + self.cv2 = Conv(c2, c2, k=3, s=s, p=1, act=True) + self.cv3 = Conv(c2, c3, k=1, act=False) + self.shortcut = nn.Sequential(Conv(c1, c3, k=1, s=s, act=False)) if s != 1 or c1 != c3 else nn.Identity() + + def forward(self, x): + """Forward pass through the ResNet block.""" + return F.relu(self.cv3(self.cv2(self.cv1(x))) + self.shortcut(x)) + + +class ResNetLayer(nn.Module): + """ResNet layer with multiple ResNet blocks.""" + + def __init__(self, c1, c2, s=1, is_first=False, n=1, e=4): + """ + Initialize ResNet layer. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + s (int): Stride. + is_first (bool): Whether this is the first layer. + n (int): Number of ResNet blocks. + e (int): Expansion ratio. 
+ """ + super().__init__() + self.is_first = is_first + + if self.is_first: + self.layer = nn.Sequential( + Conv(c1, c2, k=7, s=2, p=3, act=True), nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + ) + else: + blocks = [ResNetBlock(c1, c2, s, e=e)] + blocks.extend([ResNetBlock(e * c2, c2, 1, e=e) for _ in range(n - 1)]) + self.layer = nn.Sequential(*blocks) + + def forward(self, x): + """Forward pass through the ResNet layer.""" + return self.layer(x) + + +class MaxSigmoidAttnBlock(nn.Module): + """Max Sigmoid attention block.""" + + def __init__(self, c1, c2, nh=1, ec=128, gc=512, scale=False): + """ + Initialize MaxSigmoidAttnBlock. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + nh (int): Number of heads. + ec (int): Embedding channels. + gc (int): Guide channels. + scale (bool): Whether to use learnable scale parameter. + """ + super().__init__() + self.nh = nh + self.hc = c2 // nh + self.ec = Conv(c1, ec, k=1, act=False) if c1 != ec else None + self.gl = nn.Linear(gc, ec) + self.bias = nn.Parameter(torch.zeros(nh)) + self.proj_conv = Conv(c1, c2, k=3, s=1, act=False) + self.scale = nn.Parameter(torch.ones(1, nh, 1, 1)) if scale else 1.0 + + def forward(self, x, guide): + """ + Forward pass of MaxSigmoidAttnBlock. + + Args: + x (torch.Tensor): Input tensor. + guide (torch.Tensor): Guide tensor. + + Returns: + (torch.Tensor): Output tensor after attention. 
+ """ + bs, _, h, w = x.shape + + guide = self.gl(guide) + guide = guide.view(bs, -1, self.nh, self.hc) + embed = self.ec(x) if self.ec is not None else x + embed = embed.view(bs, self.nh, self.hc, h, w) + + aw = torch.einsum("bmchw,bnmc->bmhwn", embed, guide) + aw = aw.max(dim=-1)[0] + aw = aw / (self.hc**0.5) + aw = aw + self.bias[None, :, None, None] + aw = aw.sigmoid() * self.scale + + x = self.proj_conv(x) + x = x.view(bs, self.nh, -1, h, w) + x = x * aw.unsqueeze(2) + return x.view(bs, -1, h, w) + + +class C2fAttn(nn.Module): + """C2f module with an additional attn module.""" + + def __init__(self, c1, c2, n=1, ec=128, nh=1, gc=512, shortcut=False, g=1, e=0.5): + """ + Initialize C2f module with attention mechanism. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + n (int): Number of Bottleneck blocks. + ec (int): Embedding channels for attention. + nh (int): Number of heads for attention. + gc (int): Guide channels for attention. + shortcut (bool): Whether to use shortcut connections. + g (int): Groups for convolutions. + e (float): Expansion ratio. + """ + super().__init__() + self.c = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, 2 * self.c, 1, 1) + self.cv2 = Conv((3 + n) * self.c, c2, 1) # optional act=FReLU(c2) + self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)) + self.attn = MaxSigmoidAttnBlock(self.c, self.c, gc=gc, ec=ec, nh=nh) + + def forward(self, x, guide): + """ + Forward pass through C2f layer with attention. + + Args: + x (torch.Tensor): Input tensor. + guide (torch.Tensor): Guide tensor for attention. + + Returns: + (torch.Tensor): Output tensor after processing. + """ + y = list(self.cv1(x).chunk(2, 1)) + y.extend(m(y[-1]) for m in self.m) + y.append(self.attn(y[-1], guide)) + return self.cv2(torch.cat(y, 1)) + + def forward_split(self, x, guide): + """ + Forward pass using split() instead of chunk(). + + Args: + x (torch.Tensor): Input tensor. 
+ guide (torch.Tensor): Guide tensor for attention. + + Returns: + (torch.Tensor): Output tensor after processing. + """ + y = list(self.cv1(x).split((self.c, self.c), 1)) + y.extend(m(y[-1]) for m in self.m) + y.append(self.attn(y[-1], guide)) + return self.cv2(torch.cat(y, 1)) + + +class ImagePoolingAttn(nn.Module): + """ImagePoolingAttn: Enhance the text embeddings with image-aware information.""" + + def __init__(self, ec=256, ch=(), ct=512, nh=8, k=3, scale=False): + """ + Initialize ImagePoolingAttn module. + + Args: + ec (int): Embedding channels. + ch (tuple): Channel dimensions for feature maps. + ct (int): Channel dimension for text embeddings. + nh (int): Number of attention heads. + k (int): Kernel size for pooling. + scale (bool): Whether to use learnable scale parameter. + """ + super().__init__() + + nf = len(ch) + self.query = nn.Sequential(nn.LayerNorm(ct), nn.Linear(ct, ec)) + self.key = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec)) + self.value = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec)) + self.proj = nn.Linear(ec, ct) + self.scale = nn.Parameter(torch.tensor([0.0]), requires_grad=True) if scale else 1.0 + self.projections = nn.ModuleList([nn.Conv2d(in_channels, ec, kernel_size=1) for in_channels in ch]) + self.im_pools = nn.ModuleList([nn.AdaptiveMaxPool2d((k, k)) for _ in range(nf)]) + self.ec = ec + self.nh = nh + self.nf = nf + self.hc = ec // nh + self.k = k + + def forward(self, x, text): + """ + Forward pass of ImagePoolingAttn. + + Args: + x (List[torch.Tensor]): List of input feature maps. + text (torch.Tensor): Text embeddings. + + Returns: + (torch.Tensor): Enhanced text embeddings. 
+ """ + bs = x[0].shape[0] + assert len(x) == self.nf + num_patches = self.k**2 + x = [pool(proj(x)).view(bs, -1, num_patches) for (x, proj, pool) in zip(x, self.projections, self.im_pools)] + x = torch.cat(x, dim=-1).transpose(1, 2) + q = self.query(text) + k = self.key(x) + v = self.value(x) + + # q = q.reshape(1, text.shape[1], self.nh, self.hc).repeat(bs, 1, 1, 1) + q = q.reshape(bs, -1, self.nh, self.hc) + k = k.reshape(bs, -1, self.nh, self.hc) + v = v.reshape(bs, -1, self.nh, self.hc) + + aw = torch.einsum("bnmc,bkmc->bmnk", q, k) + aw = aw / (self.hc**0.5) + aw = F.softmax(aw, dim=-1) + + x = torch.einsum("bmnk,bkmc->bnmc", aw, v) + x = self.proj(x.reshape(bs, -1, self.ec)) + return x * self.scale + text + + +class ContrastiveHead(nn.Module): + """Implements contrastive learning head for region-text similarity in vision-language models.""" + + def __init__(self): + """Initialize ContrastiveHead with region-text similarity parameters.""" + super().__init__() + # NOTE: use -10.0 to keep the init cls loss consistency with other losses + self.bias = nn.Parameter(torch.tensor([-10.0])) + self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log()) + + def forward(self, x, w): + """ + Forward function of contrastive learning. + + Args: + x (torch.Tensor): Image features. + w (torch.Tensor): Text features. + + Returns: + (torch.Tensor): Similarity scores. + """ + x = F.normalize(x, dim=1, p=2) + w = F.normalize(w, dim=-1, p=2) + x = torch.einsum("bchw,bkc->bkhw", x, w) + return x * self.logit_scale.exp() + self.bias + + +class BNContrastiveHead(nn.Module): + """ + Batch Norm Contrastive Head for YOLO-World using batch norm instead of l2-normalization. + + Args: + embed_dims (int): Embed dimensions of text and image features. + """ + + def __init__(self, embed_dims: int): + """ + Initialize BNContrastiveHead. + + Args: + embed_dims (int): Embedding dimensions for features. 
+ """ + super().__init__() + self.norm = nn.BatchNorm2d(embed_dims) + # NOTE: use -10.0 to keep the init cls loss consistency with other losses + self.bias = nn.Parameter(torch.tensor([-10.0])) + # use -1.0 is more stable + self.logit_scale = nn.Parameter(-1.0 * torch.ones([])) + + def forward(self, x, w): + """ + Forward function of contrastive learning with batch normalization. + + Args: + x (torch.Tensor): Image features. + w (torch.Tensor): Text features. + + Returns: + (torch.Tensor): Similarity scores. + """ + x = self.norm(x) + w = F.normalize(w, dim=-1, p=2) + x = torch.einsum("bchw,bkc->bkhw", x, w) + return x * self.logit_scale.exp() + self.bias + + +class RepBottleneck(Bottleneck): + """Rep bottleneck.""" + + def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): + """ + Initialize RepBottleneck. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + shortcut (bool): Whether to use shortcut connection. + g (int): Groups for convolutions. + k (Tuple[int, int]): Kernel sizes for convolutions. + e (float): Expansion ratio. + """ + super().__init__(c1, c2, shortcut, g, k, e) + c_ = int(c2 * e) # hidden channels + self.cv1 = RepConv(c1, c_, k[0], 1) + + +class RepCSP(C3): + """Repeatable Cross Stage Partial Network (RepCSP) module for efficient feature extraction.""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """ + Initialize RepCSP layer. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + n (int): Number of RepBottleneck blocks. + shortcut (bool): Whether to use shortcut connections. + g (int): Groups for convolutions. + e (float): Expansion ratio. + """ + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) # hidden channels + self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) + + +class RepNCSPELAN4(nn.Module): + """CSP-ELAN.""" + + def __init__(self, c1, c2, c3, c4, n=1): + """ + Initialize CSP-ELAN layer. + + Args: + c1 (int): Input channels. 
+ c2 (int): Output channels. + c3 (int): Intermediate channels. + c4 (int): Intermediate channels for RepCSP. + n (int): Number of RepCSP blocks. + """ + super().__init__() + self.c = c3 // 2 + self.cv1 = Conv(c1, c3, 1, 1) + self.cv2 = nn.Sequential(RepCSP(c3 // 2, c4, n), Conv(c4, c4, 3, 1)) + self.cv3 = nn.Sequential(RepCSP(c4, c4, n), Conv(c4, c4, 3, 1)) + self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1) + + def forward(self, x): + """Forward pass through RepNCSPELAN4 layer.""" + y = list(self.cv1(x).chunk(2, 1)) + y.extend((m(y[-1])) for m in [self.cv2, self.cv3]) + return self.cv4(torch.cat(y, 1)) + + def forward_split(self, x): + """Forward pass using split() instead of chunk().""" + y = list(self.cv1(x).split((self.c, self.c), 1)) + y.extend(m(y[-1]) for m in [self.cv2, self.cv3]) + return self.cv4(torch.cat(y, 1)) + + +class ELAN1(RepNCSPELAN4): + """ELAN1 module with 4 convolutions.""" + + def __init__(self, c1, c2, c3, c4): + """ + Initialize ELAN1 layer. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + c3 (int): Intermediate channels. + c4 (int): Intermediate channels for convolutions. + """ + super().__init__(c1, c2, c3, c4) + self.c = c3 // 2 + self.cv1 = Conv(c1, c3, 1, 1) + self.cv2 = Conv(c3 // 2, c4, 3, 1) + self.cv3 = Conv(c4, c4, 3, 1) + self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1) + + +class AConv(nn.Module): + """AConv.""" + + def __init__(self, c1, c2): + """ + Initialize AConv module. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + """ + super().__init__() + self.cv1 = Conv(c1, c2, 3, 2, 1) + + def forward(self, x): + """Forward pass through AConv layer.""" + x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True) + return self.cv1(x) + + +class ADown(nn.Module): + """ADown.""" + + def __init__(self, c1, c2): + """ + Initialize ADown module. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. 
+ """ + super().__init__() + self.c = c2 // 2 + self.cv1 = Conv(c1 // 2, self.c, 3, 2, 1) + self.cv2 = Conv(c1 // 2, self.c, 1, 1, 0) + + def forward(self, x): + """Forward pass through ADown layer.""" + x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True) + x1, x2 = x.chunk(2, 1) + x1 = self.cv1(x1) + x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1) + x2 = self.cv2(x2) + return torch.cat((x1, x2), 1) + + +class SPPELAN(nn.Module): + """SPP-ELAN.""" + + def __init__(self, c1, c2, c3, k=5): + """ + Initialize SPP-ELAN block. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + c3 (int): Intermediate channels. + k (int): Kernel size for max pooling. + """ + super().__init__() + self.c = c3 + self.cv1 = Conv(c1, c3, 1, 1) + self.cv2 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + self.cv3 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + self.cv4 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + self.cv5 = Conv(4 * c3, c2, 1, 1) + + def forward(self, x): + """Forward pass through SPPELAN layer.""" + y = [self.cv1(x)] + y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4]) + return self.cv5(torch.cat(y, 1)) + + +class CBLinear(nn.Module): + """CBLinear.""" + + def __init__(self, c1, c2s, k=1, s=1, p=None, g=1): + """ + Initialize CBLinear module. + + Args: + c1 (int): Input channels. + c2s (List[int]): List of output channel sizes. + k (int): Kernel size. + s (int): Stride. + p (int | None): Padding. + g (int): Groups. + """ + super().__init__() + self.c2s = c2s + self.conv = nn.Conv2d(c1, sum(c2s), k, s, autopad(k, p), groups=g, bias=True) + + def forward(self, x): + """Forward pass through CBLinear layer.""" + return self.conv(x).split(self.c2s, dim=1) + + +class CBFuse(nn.Module): + """CBFuse.""" + + def __init__(self, idx): + """ + Initialize CBFuse module. + + Args: + idx (List[int]): Indices for feature selection. 
+ """ + super().__init__() + self.idx = idx + + def forward(self, xs): + """ + Forward pass through CBFuse layer. + + Args: + xs (List[torch.Tensor]): List of input tensors. + + Returns: + (torch.Tensor): Fused output tensor. + """ + target_size = xs[-1].shape[2:] + res = [F.interpolate(x[self.idx[i]], size=target_size, mode="nearest") for i, x in enumerate(xs[:-1])] + return torch.sum(torch.stack(res + xs[-1:]), dim=0) + + +class C3f(nn.Module): + """Faster Implementation of CSP Bottleneck with 2 convolutions.""" + + def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): + """ + Initialize CSP bottleneck layer with two convolutions. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + n (int): Number of Bottleneck blocks. + shortcut (bool): Whether to use shortcut connections. + g (int): Groups for convolutions. + e (float): Expansion ratio. + """ + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.cv3 = Conv((2 + n) * c_, c2, 1) # optional act=FReLU(c2) + self.m = nn.ModuleList(Bottleneck(c_, c_, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)) + + def forward(self, x): + """Forward pass through C3f layer.""" + y = [self.cv2(x), self.cv1(x)] + y.extend(m(y[-1]) for m in self.m) + return self.cv3(torch.cat(y, 1)) + + +class C3k2(C2f): + """Faster Implementation of CSP Bottleneck with 2 convolutions.""" + + def __init__(self, c1, c2, n=1, c3k=False, e=0.5, g=1, shortcut=True): + """ + Initialize C3k2 module. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + n (int): Number of blocks. + c3k (bool): Whether to use C3k blocks. + e (float): Expansion ratio. + g (int): Groups for convolutions. + shortcut (bool): Whether to use shortcut connections. 
+ """ + super().__init__(c1, c2, n, shortcut, g, e) + self.m = nn.ModuleList( + C3k(self.c, self.c, 2, shortcut, g) if c3k else Bottleneck(self.c, self.c, shortcut, g) for _ in range(n) + ) + + +class C3k(C3): + """C3k is a CSP bottleneck module with customizable kernel sizes for feature extraction in neural networks.""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, k=3): + """ + Initialize C3k module. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + n (int): Number of Bottleneck blocks. + shortcut (bool): Whether to use shortcut connections. + g (int): Groups for convolutions. + e (float): Expansion ratio. + k (int): Kernel size. + """ + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) # hidden channels + # self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n))) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n))) + + +class RepVGGDW(torch.nn.Module): + """RepVGGDW is a class that represents a depth wise separable convolutional block in RepVGG architecture.""" + + def __init__(self, ed) -> None: + """ + Initialize RepVGGDW module. + + Args: + ed (int): Input and output channels. + """ + super().__init__() + self.conv = Conv(ed, ed, 7, 1, 3, g=ed, act=False) + self.conv1 = Conv(ed, ed, 3, 1, 1, g=ed, act=False) + self.dim = ed + self.act = nn.SiLU() + + def forward(self, x): + """ + Perform a forward pass of the RepVGGDW block. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor after applying the depth wise separable convolution. + """ + return self.act(self.conv(x) + self.conv1(x)) + + def forward_fuse(self, x): + """ + Perform a forward pass of the RepVGGDW block without fusing the convolutions. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor after applying the depth wise separable convolution. 
+ """ + return self.act(self.conv(x)) + + @torch.no_grad() + def fuse(self): + """ + Fuse the convolutional layers in the RepVGGDW block. + + This method fuses the convolutional layers and updates the weights and biases accordingly. + """ + conv = fuse_conv_and_bn(self.conv.conv, self.conv.bn) + conv1 = fuse_conv_and_bn(self.conv1.conv, self.conv1.bn) + + conv_w = conv.weight + conv_b = conv.bias + conv1_w = conv1.weight + conv1_b = conv1.bias + + conv1_w = torch.nn.functional.pad(conv1_w, [2, 2, 2, 2]) + + final_conv_w = conv_w + conv1_w + final_conv_b = conv_b + conv1_b + + conv.weight.data.copy_(final_conv_w) + conv.bias.data.copy_(final_conv_b) + + self.conv = conv + del self.conv1 + + +class CIB(nn.Module): + """ + Conditional Identity Block (CIB) module. + + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + shortcut (bool, optional): Whether to add a shortcut connection. Defaults to True. + e (float, optional): Scaling factor for the hidden channels. Defaults to 0.5. + lk (bool, optional): Whether to use RepVGGDW for the third convolutional layer. Defaults to False. + """ + + def __init__(self, c1, c2, shortcut=True, e=0.5, lk=False): + """ + Initialize the CIB module. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + shortcut (bool): Whether to use shortcut connection. + e (float): Expansion ratio. + lk (bool): Whether to use RepVGGDW. + """ + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = nn.Sequential( + Conv(c1, c1, 3, g=c1), + Conv(c1, 2 * c_, 1), + RepVGGDW(2 * c_) if lk else Conv(2 * c_, 2 * c_, 3, g=2 * c_), + Conv(2 * c_, c2, 1), + Conv(c2, c2, 3, g=c2), + ) + + self.add = shortcut and c1 == c2 + + def forward(self, x): + """ + Forward pass of the CIB module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor. 
+ """ + return x + self.cv1(x) if self.add else self.cv1(x) + + +class C2fCIB(C2f): + """ + C2fCIB class represents a convolutional block with C2f and CIB modules. + + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + n (int, optional): Number of CIB modules to stack. Defaults to 1. + shortcut (bool, optional): Whether to use shortcut connection. Defaults to False. + lk (bool, optional): Whether to use local key connection. Defaults to False. + g (int, optional): Number of groups for grouped convolution. Defaults to 1. + e (float, optional): Expansion ratio for CIB modules. Defaults to 0.5. + """ + + def __init__(self, c1, c2, n=1, shortcut=False, lk=False, g=1, e=0.5): + """ + Initialize C2fCIB module. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + n (int): Number of CIB modules. + shortcut (bool): Whether to use shortcut connection. + lk (bool): Whether to use local key connection. + g (int): Groups for convolutions. + e (float): Expansion ratio. + """ + super().__init__(c1, c2, n, shortcut, g, e) + self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0, lk=lk) for _ in range(n)) + + +class Attention(nn.Module): + """ + Attention module that performs self-attention on the input tensor. + + Args: + dim (int): The input tensor dimension. + num_heads (int): The number of attention heads. + attn_ratio (float): The ratio of the attention key dimension to the head dimension. + + Attributes: + num_heads (int): The number of attention heads. + head_dim (int): The dimension of each attention head. + key_dim (int): The dimension of the attention key. + scale (float): The scaling factor for the attention scores. + qkv (Conv): Convolutional layer for computing the query, key, and value. + proj (Conv): Convolutional layer for projecting the attended values. + pe (Conv): Convolutional layer for positional encoding. 
+ """ + + def __init__(self, dim, num_heads=8, attn_ratio=0.5): + """ + Initialize multi-head attention module. + + Args: + dim (int): Input dimension. + num_heads (int): Number of attention heads. + attn_ratio (float): Attention ratio for key dimension. + """ + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.key_dim = int(self.head_dim * attn_ratio) + self.scale = self.key_dim**-0.5 + nh_kd = self.key_dim * num_heads + h = dim + nh_kd * 2 + self.qkv = Conv(dim, h, 1, act=False) + self.proj = Conv(dim, dim, 1, act=False) + self.pe = Conv(dim, dim, 3, 1, g=dim, act=False) + + def forward(self, x): + """ + Forward pass of the Attention module. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + (torch.Tensor): The output tensor after self-attention. + """ + B, C, H, W = x.shape + N = H * W + qkv = self.qkv(x) + q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split( + [self.key_dim, self.key_dim, self.head_dim], dim=2 + ) + + attn = (q.transpose(-2, -1) @ k) * self.scale + attn = attn.softmax(dim=-1) + x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W)) + x = self.proj(x) + return x + + +class PSABlock(nn.Module): + """ + PSABlock class implementing a Position-Sensitive Attention block for neural networks. + + This class encapsulates the functionality for applying multi-head attention and feed-forward neural network layers + with optional shortcut connections. + + Attributes: + attn (Attention): Multi-head attention module. + ffn (nn.Sequential): Feed-forward neural network module. + add (bool): Flag indicating whether to add shortcut connections. + + Methods: + forward: Performs a forward pass through the PSABlock, applying attention and feed-forward layers. 
+ + Examples: + Create a PSABlock and perform a forward pass + >>> psablock = PSABlock(c=128, attn_ratio=0.5, num_heads=4, shortcut=True) + >>> input_tensor = torch.randn(1, 128, 32, 32) + >>> output_tensor = psablock(input_tensor) + """ + + def __init__(self, c, attn_ratio=0.5, num_heads=4, shortcut=True) -> None: + """ + Initialize the PSABlock. + + Args: + c (int): Input and output channels. + attn_ratio (float): Attention ratio for key dimension. + num_heads (int): Number of attention heads. + shortcut (bool): Whether to use shortcut connections. + """ + super().__init__() + + self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads) + self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False)) + self.add = shortcut + + def forward(self, x): + """ + Execute a forward pass through PSABlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor after attention and feed-forward processing. + """ + x = x + self.attn(x) if self.add else self.attn(x) + x = x + self.ffn(x) if self.add else self.ffn(x) + return x + + +class PSA(nn.Module): + """ + PSA class for implementing Position-Sensitive Attention in neural networks. + + This class encapsulates the functionality for applying position-sensitive attention and feed-forward networks to + input tensors, enhancing feature extraction and processing capabilities. + + Attributes: + c (int): Number of hidden channels after applying the initial convolution. + cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c. + cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c. + attn (Attention): Attention module for position-sensitive attention. + ffn (nn.Sequential): Feed-forward network for further processing. + + Methods: + forward: Applies position-sensitive attention and feed-forward network to the input tensor. 
+ + Examples: + Create a PSA module and apply it to an input tensor + >>> psa = PSA(c1=128, c2=128, e=0.5) + >>> input_tensor = torch.randn(1, 128, 64, 64) + >>> output_tensor = psa.forward(input_tensor) + """ + + def __init__(self, c1, c2, e=0.5): + """ + Initialize PSA module. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + e (float): Expansion ratio. + """ + super().__init__() + assert c1 == c2 + self.c = int(c1 * e) + self.cv1 = Conv(c1, 2 * self.c, 1, 1) + self.cv2 = Conv(2 * self.c, c1, 1) + + self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 64) + self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1), Conv(self.c * 2, self.c, 1, act=False)) + + def forward(self, x): + """ + Execute forward pass in PSA module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor after attention and feed-forward processing. + """ + a, b = self.cv1(x).split((self.c, self.c), dim=1) + b = b + self.attn(b) + b = b + self.ffn(b) + return self.cv2(torch.cat((a, b), 1)) + + +class C2PSA(nn.Module): + """ + C2PSA module with attention mechanism for enhanced feature extraction and processing. + + This module implements a convolutional block with attention mechanisms to enhance feature extraction and processing + capabilities. It includes a series of PSABlock modules for self-attention and feed-forward operations. + + Attributes: + c (int): Number of hidden channels. + cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c. + cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c. + m (nn.Sequential): Sequential container of PSABlock modules for attention and feed-forward operations. + + Methods: + forward: Performs a forward pass through the C2PSA module, applying attention and feed-forward operations. + + Notes: + This module essentially is the same as PSA module, but refactored to allow stacking more PSABlock modules. 
+ + Examples: + >>> c2psa = C2PSA(c1=256, c2=256, n=3, e=0.5) + >>> input_tensor = torch.randn(1, 256, 64, 64) + >>> output_tensor = c2psa(input_tensor) + """ + + def __init__(self, c1, c2, n=1, e=0.5): + """ + Initialize C2PSA module. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + n (int): Number of PSABlock modules. + e (float): Expansion ratio. + """ + super().__init__() + assert c1 == c2 + self.c = int(c1 * e) + self.cv1 = Conv(c1, 2 * self.c, 1, 1) + self.cv2 = Conv(2 * self.c, c1, 1) + + self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n))) + + def forward(self, x): + """ + Process the input tensor through a series of PSA blocks. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor after processing. + """ + a, b = self.cv1(x).split((self.c, self.c), dim=1) + b = self.m(b) + return self.cv2(torch.cat((a, b), 1)) + + +class C2fPSA(C2f): + """ + C2fPSA module with enhanced feature extraction using PSA blocks. + + This class extends the C2f module by incorporating PSA blocks for improved attention mechanisms and feature extraction. + + Attributes: + c (int): Number of hidden channels. + cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c. + cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c. + m (nn.ModuleList): List of PSA blocks for feature extraction. + + Methods: + forward: Performs a forward pass through the C2fPSA module. + forward_split: Performs a forward pass using split() instead of chunk(). + + Examples: + >>> import torch + >>> from ultralytics.models.common import C2fPSA + >>> model = C2fPSA(c1=64, c2=64, n=3, e=0.5) + >>> x = torch.randn(1, 64, 128, 128) + >>> output = model(x) + >>> print(output.shape) + """ + + def __init__(self, c1, c2, n=1, e=0.5): + """ + Initialize C2fPSA module. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + n (int): Number of PSABlock modules. 
+ e (float): Expansion ratio. + """ + assert c1 == c2 + super().__init__(c1, c2, n=n, e=e) + self.m = nn.ModuleList(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n)) + + +class SCDown(nn.Module): + """ + SCDown module for downsampling with separable convolutions. + + This module performs downsampling using a combination of pointwise and depthwise convolutions, which helps in + efficiently reducing the spatial dimensions of the input tensor while maintaining the channel information. + + Attributes: + cv1 (Conv): Pointwise convolution layer that reduces the number of channels. + cv2 (Conv): Depthwise convolution layer that performs spatial downsampling. + + Methods: + forward: Applies the SCDown module to the input tensor. + + Examples: + >>> import torch + >>> from ultralytics import SCDown + >>> model = SCDown(c1=64, c2=128, k=3, s=2) + >>> x = torch.randn(1, 64, 128, 128) + >>> y = model(x) + >>> print(y.shape) + torch.Size([1, 128, 64, 64]) + """ + + def __init__(self, c1, c2, k, s): + """ + Initialize SCDown module. + + Args: + c1 (int): Input channels. + c2 (int): Output channels. + k (int): Kernel size. + s (int): Stride. + """ + super().__init__() + self.cv1 = Conv(c1, c2, 1, 1) + self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False) + + def forward(self, x): + """ + Apply convolution and downsampling to the input tensor. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Downsampled output tensor. + """ + return self.cv2(self.cv1(x)) + + +class TorchVision(nn.Module): + """ + TorchVision module to allow loading any torchvision model. + + This class provides a way to load a model from the torchvision library, optionally load pre-trained weights, and customize the model by truncating or unwrapping layers. + + Attributes: + m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped. + + Args: + model (str): Name of the torchvision model to load. 
+ weights (str, optional): Pre-trained weights to load. Default is "DEFAULT". + unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate` layers. Default is True. + truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. Default is 2. + split (bool, optional): Returns output from intermediate child modules as list. Default is False. + """ + + def __init__(self, model, weights="DEFAULT", unwrap=True, truncate=2, split=False): + """ + Load the model and weights from torchvision. + + Args: + model (str): Name of the torchvision model to load. + weights (str): Pre-trained weights to load. + unwrap (bool): Whether to unwrap the model. + truncate (int): Number of layers to truncate. + split (bool): Whether to split the output. + """ + import torchvision # scope for faster 'import ultralytics' + + super().__init__() + if hasattr(torchvision.models, "get_model"): + self.m = torchvision.models.get_model(model, weights=weights) + else: + self.m = torchvision.models.__dict__[model](pretrained=bool(weights)) + if unwrap: + layers = list(self.m.children()) + if isinstance(layers[0], nn.Sequential): # Second-level for some models like EfficientNet, Swin + layers = [*list(layers[0].children()), *layers[1:]] + self.m = nn.Sequential(*(layers[:-truncate] if truncate else layers)) + self.split = split + else: + self.split = False + self.m.head = self.m.heads = nn.Identity() + + def forward(self, x): + """ + Forward pass through the model. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor | List[torch.Tensor]): Output tensor or list of tensors. + """ + if self.split: + y = [x] + y.extend(m(y[-1]) for m in self.m) + else: + y = self.m(x) + return y + + +class AAttn(nn.Module): + """ + Area-attention module for YOLO models, providing efficient attention mechanisms. 
+ + This module implements an area-based attention mechanism that processes input features in a spatially-aware manner, + making it particularly effective for object detection tasks. + + Attributes: + area (int): Number of areas the feature map is divided. + num_heads (int): Number of heads into which the attention mechanism is divided. + head_dim (int): Dimension of each attention head. + qkv (Conv): Convolution layer for computing query, key and value tensors. + proj (Conv): Projection convolution layer. + pe (Conv): Position encoding convolution layer. + + Methods: + forward: Applies area-attention to input tensor. + + Examples: + >>> attn = AAttn(dim=256, num_heads=8, area=4) + >>> x = torch.randn(1, 256, 32, 32) + >>> output = attn(x) + >>> print(output.shape) + torch.Size([1, 256, 32, 32]) + """ + + def __init__(self, dim, num_heads, area=1): + """ + Initialize an Area-attention module for YOLO models. + + Args: + dim (int): Number of hidden channels. + num_heads (int): Number of heads into which the attention mechanism is divided. + area (int): Number of areas the feature map is divided, default is 1. + """ + super().__init__() + self.area = area + + self.num_heads = num_heads + self.head_dim = head_dim = dim // num_heads + all_head_dim = head_dim * self.num_heads + + self.qkv = Conv(dim, all_head_dim * 3, 1, act=False) + self.proj = Conv(all_head_dim, dim, 1, act=False) + self.pe = Conv(all_head_dim, dim, 7, 1, 3, g=dim, act=False) + + def forward(self, x): + """ + Process the input tensor through the area-attention. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor after area-attention. 
+ """ + B, C, H, W = x.shape + N = H * W + + qkv = self.qkv(x).flatten(2).transpose(1, 2) + if self.area > 1: + qkv = qkv.reshape(B * self.area, N // self.area, C * 3) + B, N, _ = qkv.shape + q, k, v = ( + qkv.view(B, N, self.num_heads, self.head_dim * 3) + .permute(0, 2, 3, 1) + .split([self.head_dim, self.head_dim, self.head_dim], dim=2) + ) + attn = (q.transpose(-2, -1) @ k) * (self.head_dim**-0.5) + attn = attn.softmax(dim=-1) + x = v @ attn.transpose(-2, -1) + x = x.permute(0, 3, 1, 2) + v = v.permute(0, 3, 1, 2) + + if self.area > 1: + x = x.reshape(B // self.area, N * self.area, C) + v = v.reshape(B // self.area, N * self.area, C) + B, N, _ = x.shape + + x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous() + v = v.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous() + + x = x + self.pe(v) + return self.proj(x) + + +class ABlock(nn.Module): + """ + Area-attention block module for efficient feature extraction in YOLO models. + + This module implements an area-attention mechanism combined with a feed-forward network for processing feature maps. + It uses a novel area-based attention approach that is more efficient than traditional self-attention while + maintaining effectiveness. + + Attributes: + attn (AAttn): Area-attention module for processing spatial features. + mlp (nn.Sequential): Multi-layer perceptron for feature transformation. + + Methods: + _init_weights: Initializes module weights using truncated normal distribution. + forward: Applies area-attention and feed-forward processing to input tensor. + + Examples: + >>> block = ABlock(dim=256, num_heads=8, mlp_ratio=1.2, area=1) + >>> x = torch.randn(1, 256, 32, 32) + >>> output = block(x) + >>> print(output.shape) + torch.Size([1, 256, 32, 32]) + """ + + def __init__(self, dim, num_heads, mlp_ratio=1.2, area=1): + """ + Initialize an Area-attention block module. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of heads into which the attention mechanism is divided. 
+ mlp_ratio (float): Expansion ratio for MLP hidden dimension. + area (int): Number of areas the feature map is divided. + """ + super().__init__() + + self.attn = AAttn(dim, num_heads=num_heads, area=area) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = nn.Sequential(Conv(dim, mlp_hidden_dim, 1), Conv(mlp_hidden_dim, dim, 1, act=False)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + """ + Initialize weights using a truncated normal distribution. + + Args: + m (nn.Module): Module to initialize. + """ + if isinstance(m, nn.Conv2d): + nn.init.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + """ + Forward pass through ABlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor after area-attention and feed-forward processing. + """ + x = x + self.attn(x) + return x + self.mlp(x) + + +class A2C2f(nn.Module): + """ + Area-Attention C2f module for enhanced feature extraction with area-based attention mechanisms. + + This module extends the C2f architecture by incorporating area-attention and ABlock layers for improved feature + processing. It supports both area-attention and standard convolution modes. + + Attributes: + cv1 (Conv): Initial 1x1 convolution layer that reduces input channels to hidden channels. + cv2 (Conv): Final 1x1 convolution layer that processes concatenated features. + gamma (nn.Parameter | None): Learnable parameter for residual scaling when using area attention. + m (nn.ModuleList): List of either ABlock or C3k modules for feature processing. + + Methods: + forward: Processes input through area-attention or standard convolution pathway. 
+ + Examples: + >>> m = A2C2f(512, 512, n=1, a2=True, area=1) + >>> x = torch.randn(1, 512, 32, 32) + >>> output = m(x) + >>> print(output.shape) + torch.Size([1, 512, 32, 32]) + """ + + def __init__(self, c1, c2, n=1, a2=True, area=1, residual=False, mlp_ratio=2.0, e=0.5, g=1, shortcut=True): + """ + Initialize Area-Attention C2f module. + + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + n (int): Number of ABlock or C3k modules to stack. + a2 (bool): Whether to use area attention blocks. If False, uses C3k blocks instead. + area (int): Number of areas the feature map is divided. + residual (bool): Whether to use residual connections with learnable gamma parameter. + mlp_ratio (float): Expansion ratio for MLP hidden dimension. + e (float): Channel expansion ratio for hidden channels. + g (int): Number of groups for grouped convolutions. + shortcut (bool): Whether to use shortcut connections in C3k blocks. + """ + super().__init__() + c_ = int(c2 * e) # hidden channels + assert c_ % 32 == 0, "Dimension of ABlock be a multiple of 32." + + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv((1 + n) * c_, c2, 1) + + self.gamma = nn.Parameter(0.01 * torch.ones(c2), requires_grad=True) if a2 and residual else None + self.m = nn.ModuleList( + nn.Sequential(*(ABlock(c_, c_ // 32, mlp_ratio, area) for _ in range(2))) + if a2 + else C3k(c_, c_, 2, shortcut, g) + for _ in range(n) + ) + + def forward(self, x): + """ + Forward pass through A2C2f layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor after processing. 
+ """ + y = [self.cv1(x)] + y.extend(m(y[-1]) for m in self.m) + y = self.cv2(torch.cat(y, 1)) + if self.gamma is not None: + return x + self.gamma.view(-1, len(self.gamma), 1, 1) * y + return y diff --git a/tracking/ultralytics/nn/modules/conv.py b/tracking/ultralytics/nn/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..9c331368b1551b24084cf83642d1c6253802c976 --- /dev/null +++ b/tracking/ultralytics/nn/modules/conv.py @@ -0,0 +1,714 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Convolution modules.""" + +import math + +import numpy as np +import torch +import torch.nn as nn + +__all__ = ( + "Conv", + "Conv2", + "LightConv", + "DWConv", + "DWConvTranspose2d", + "ConvTranspose", + "Focus", + "GhostConv", + "ChannelAttention", + "SpatialAttention", + "CBAM", + "Concat", + "RepConv", + "Index", +) + + +def autopad(k, p=None, d=1): # kernel, padding, dilation + """Pad to 'same' shape outputs.""" + if d > 1: + k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size + if p is None: + p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad + return p + + +class Conv(nn.Module): + """ + Standard convolution module with batch normalization and activation. + + Attributes: + conv (nn.Conv2d): Convolutional layer. + bn (nn.BatchNorm2d): Batch normalization layer. + act (nn.Module): Activation function layer. + default_act (nn.Module): Default activation function (SiLU). + """ + + default_act = nn.SiLU() # default activation + + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True): + """ + Initialize Conv layer with given parameters. + + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + k (int): Kernel size. + s (int): Stride. + p (int, optional): Padding. + g (int): Groups. + d (int): Dilation. + act (bool | nn.Module): Activation function. 
+ """ + super().__init__() + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False) + self.bn = nn.BatchNorm2d(c2) + self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() + + def forward(self, x): + """ + Apply convolution, batch normalization and activation to input tensor. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor. + """ + return self.act(self.bn(self.conv(x))) + + def forward_fuse(self, x): + """ + Apply convolution and activation without batch normalization. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor. + """ + return self.act(self.conv(x)) + + +class Conv2(Conv): + """ + Simplified RepConv module with Conv fusing. + + Attributes: + conv (nn.Conv2d): Main 3x3 convolutional layer. + cv2 (nn.Conv2d): Additional 1x1 convolutional layer. + bn (nn.BatchNorm2d): Batch normalization layer. + act (nn.Module): Activation function layer. + """ + + def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True): + """ + Initialize Conv2 layer with given parameters. + + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + k (int): Kernel size. + s (int): Stride. + p (int, optional): Padding. + g (int): Groups. + d (int): Dilation. + act (bool | nn.Module): Activation function. + """ + super().__init__(c1, c2, k, s, p, g=g, d=d, act=act) + self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False) # add 1x1 conv + + def forward(self, x): + """ + Apply convolution, batch normalization and activation to input tensor. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor. + """ + return self.act(self.bn(self.conv(x) + self.cv2(x))) + + def forward_fuse(self, x): + """ + Apply fused convolution, batch normalization and activation to input tensor. + + Args: + x (torch.Tensor): Input tensor. 
+ + Returns: + (torch.Tensor): Output tensor. + """ + return self.act(self.bn(self.conv(x))) + + def fuse_convs(self): + """Fuse parallel convolutions.""" + w = torch.zeros_like(self.conv.weight.data) + i = [x // 2 for x in w.shape[2:]] + w[:, :, i[0] : i[0] + 1, i[1] : i[1] + 1] = self.cv2.weight.data.clone() + self.conv.weight.data += w + self.__delattr__("cv2") + self.forward = self.forward_fuse + + +class LightConv(nn.Module): + """ + Light convolution module with 1x1 and depthwise convolutions. + + This implementation is based on the PaddleDetection HGNetV2 backbone. + + Attributes: + conv1 (Conv): 1x1 convolution layer. + conv2 (DWConv): Depthwise convolution layer. + """ + + def __init__(self, c1, c2, k=1, act=nn.ReLU()): + """ + Initialize LightConv layer with given parameters. + + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + k (int): Kernel size for depthwise convolution. + act (nn.Module): Activation function. + """ + super().__init__() + self.conv1 = Conv(c1, c2, 1, act=False) + self.conv2 = DWConv(c2, c2, k, act=act) + + def forward(self, x): + """ + Apply 2 convolutions to input tensor. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor. + """ + return self.conv2(self.conv1(x)) + + +class DWConv(Conv): + """Depth-wise convolution module.""" + + def __init__(self, c1, c2, k=1, s=1, d=1, act=True): + """ + Initialize depth-wise convolution with given parameters. + + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + k (int): Kernel size. + s (int): Stride. + d (int): Dilation. + act (bool | nn.Module): Activation function. + """ + super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act) + + +class DWConvTranspose2d(nn.ConvTranspose2d): + """Depth-wise transpose convolution module.""" + + def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): + """ + Initialize depth-wise transpose convolution with given parameters. 
+ + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + k (int): Kernel size. + s (int): Stride. + p1 (int): Padding. + p2 (int): Output padding. + """ + super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2)) + + +class ConvTranspose(nn.Module): + """ + Convolution transpose module with optional batch normalization and activation. + + Attributes: + conv_transpose (nn.ConvTranspose2d): Transposed convolution layer. + bn (nn.BatchNorm2d | nn.Identity): Batch normalization layer. + act (nn.Module): Activation function layer. + default_act (nn.Module): Default activation function (SiLU). + """ + + default_act = nn.SiLU() # default activation + + def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True): + """ + Initialize ConvTranspose layer with given parameters. + + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + k (int): Kernel size. + s (int): Stride. + p (int): Padding. + bn (bool): Use batch normalization. + act (bool | nn.Module): Activation function. + """ + super().__init__() + self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn) + self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity() + self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() + + def forward(self, x): + """ + Apply transposed convolution, batch normalization and activation to input. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor. + """ + return self.act(self.bn(self.conv_transpose(x))) + + def forward_fuse(self, x): + """ + Apply activation and convolution transpose operation to input. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor. + """ + return self.act(self.conv_transpose(x)) + + +class Focus(nn.Module): + """ + Focus module for concentrating feature information. + + Slices input tensor into 4 parts and concatenates them in the channel dimension. 
+ + Attributes: + conv (Conv): Convolution layer. + """ + + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): + """ + Initialize Focus module with given parameters. + + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + k (int): Kernel size. + s (int): Stride. + p (int, optional): Padding. + g (int): Groups. + act (bool | nn.Module): Activation function. + """ + super().__init__() + self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act) + # self.contract = Contract(gain=2) + + def forward(self, x): + """ + Apply Focus operation and convolution to input tensor. + + Input shape is (b,c,w,h) and output shape is (b,4c,w/2,h/2). + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor. + """ + return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1)) + # return self.conv(self.contract(x)) + + +class GhostConv(nn.Module): + """ + Ghost Convolution module. + + Generates more features with fewer parameters by using cheap operations. + + Attributes: + cv1 (Conv): Primary convolution. + cv2 (Conv): Cheap operation convolution. + + References: + https://github.com/huawei-noah/ghostnet + """ + + def __init__(self, c1, c2, k=1, s=1, g=1, act=True): + """ + Initialize Ghost Convolution module with given parameters. + + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + k (int): Kernel size. + s (int): Stride. + g (int): Groups. + act (bool | nn.Module): Activation function. + """ + super().__init__() + c_ = c2 // 2 # hidden channels + self.cv1 = Conv(c1, c_, k, s, None, g, act=act) + self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act) + + def forward(self, x): + """ + Apply Ghost Convolution to input tensor. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor with concatenated features. 
+ """ + y = self.cv1(x) + return torch.cat((y, self.cv2(y)), 1) + + +class RepConv(nn.Module): + """ + RepConv module with training and deploy modes. + + This module is used in RT-DETR and can fuse convolutions during inference for efficiency. + + Attributes: + conv1 (Conv): 3x3 convolution. + conv2 (Conv): 1x1 convolution. + bn (nn.BatchNorm2d, optional): Batch normalization for identity branch. + act (nn.Module): Activation function. + default_act (nn.Module): Default activation function (SiLU). + + References: + https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py + """ + + default_act = nn.SiLU() # default activation + + def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False): + """ + Initialize RepConv module with given parameters. + + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + k (int): Kernel size. + s (int): Stride. + p (int): Padding. + g (int): Groups. + d (int): Dilation. + act (bool | nn.Module): Activation function. + bn (bool): Use batch normalization for identity branch. + deploy (bool): Deploy mode for inference. + """ + super().__init__() + assert k == 3 and p == 1 + self.g = g + self.c1 = c1 + self.c2 = c2 + self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() + + self.bn = nn.BatchNorm2d(num_features=c1) if bn and c2 == c1 and s == 1 else None + self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False) + self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False) + + def forward_fuse(self, x): + """ + Forward pass for deploy mode. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor. + """ + return self.act(self.conv(x)) + + def forward(self, x): + """ + Forward pass for training mode. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor. 
+ """ + id_out = 0 if self.bn is None else self.bn(x) + return self.act(self.conv1(x) + self.conv2(x) + id_out) + + def get_equivalent_kernel_bias(self): + """ + Calculate equivalent kernel and bias by fusing convolutions. + + Returns: + (tuple): Tuple containing: + - Equivalent kernel (torch.Tensor) + - Equivalent bias (torch.Tensor) + """ + kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1) + kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2) + kernelid, biasid = self._fuse_bn_tensor(self.bn) + return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid + + @staticmethod + def _pad_1x1_to_3x3_tensor(kernel1x1): + """ + Pad a 1x1 kernel to 3x3 size. + + Args: + kernel1x1 (torch.Tensor): 1x1 convolution kernel. + + Returns: + (torch.Tensor): Padded 3x3 kernel. + """ + if kernel1x1 is None: + return 0 + else: + return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1]) + + def _fuse_bn_tensor(self, branch): + """ + Fuse batch normalization with convolution weights. + + Args: + branch (Conv | nn.BatchNorm2d | None): Branch to fuse. 
+ + Returns: + (tuple): Tuple containing: + - Fused kernel (torch.Tensor) + - Fused bias (torch.Tensor) + """ + if branch is None: + return 0, 0 + if isinstance(branch, Conv): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + elif isinstance(branch, nn.BatchNorm2d): + if not hasattr(self, "id_tensor"): + input_dim = self.c1 // self.g + kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32) + for i in range(self.c1): + kernel_value[i, i % input_dim, 1, 1] = 1 + self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device) + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + def fuse_convs(self): + """Fuse convolutions for inference by creating a single equivalent convolution.""" + if hasattr(self, "conv"): + return + kernel, bias = self.get_equivalent_kernel_bias() + self.conv = nn.Conv2d( + in_channels=self.conv1.conv.in_channels, + out_channels=self.conv1.conv.out_channels, + kernel_size=self.conv1.conv.kernel_size, + stride=self.conv1.conv.stride, + padding=self.conv1.conv.padding, + dilation=self.conv1.conv.dilation, + groups=self.conv1.conv.groups, + bias=True, + ).requires_grad_(False) + self.conv.weight.data = kernel + self.conv.bias.data = bias + for para in self.parameters(): + para.detach_() + self.__delattr__("conv1") + self.__delattr__("conv2") + if hasattr(self, "nm"): + self.__delattr__("nm") + if hasattr(self, "bn"): + self.__delattr__("bn") + if hasattr(self, "id_tensor"): + self.__delattr__("id_tensor") + + +class ChannelAttention(nn.Module): + """ + Channel-attention module for feature recalibration. 
+ + Applies attention weights to channels based on global average pooling. + + Attributes: + pool (nn.AdaptiveAvgPool2d): Global average pooling. + fc (nn.Conv2d): Fully connected layer implemented as 1x1 convolution. + act (nn.Sigmoid): Sigmoid activation for attention weights. + + References: + https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet + """ + + def __init__(self, channels: int) -> None: + """ + Initialize Channel-attention module. + + Args: + channels (int): Number of input channels. + """ + super().__init__() + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True) + self.act = nn.Sigmoid() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply channel attention to input tensor. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Channel-attended output tensor. + """ + return x * self.act(self.fc(self.pool(x))) + + +class SpatialAttention(nn.Module): + """ + Spatial-attention module for feature recalibration. + + Applies attention weights to spatial dimensions based on channel statistics. + + Attributes: + cv1 (nn.Conv2d): Convolution layer for spatial attention. + act (nn.Sigmoid): Sigmoid activation for attention weights. + """ + + def __init__(self, kernel_size=7): + """ + Initialize Spatial-attention module. + + Args: + kernel_size (int): Size of the convolutional kernel (3 or 7). + """ + super().__init__() + assert kernel_size in {3, 7}, "kernel size must be 3 or 7" + padding = 3 if kernel_size == 7 else 1 + self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) + self.act = nn.Sigmoid() + + def forward(self, x): + """ + Apply spatial attention to input tensor. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Spatial-attended output tensor. 
+ """ + return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1))) + + +class CBAM(nn.Module): + """ + Convolutional Block Attention Module. + + Combines channel and spatial attention mechanisms for comprehensive feature refinement. + + Attributes: + channel_attention (ChannelAttention): Channel attention module. + spatial_attention (SpatialAttention): Spatial attention module. + """ + + def __init__(self, c1, kernel_size=7): + """ + Initialize CBAM with given parameters. + + Args: + c1 (int): Number of input channels. + kernel_size (int): Size of the convolutional kernel for spatial attention. + """ + super().__init__() + self.channel_attention = ChannelAttention(c1) + self.spatial_attention = SpatialAttention(kernel_size) + + def forward(self, x): + """ + Apply channel and spatial attention sequentially to input tensor. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Attended output tensor. + """ + return self.spatial_attention(self.channel_attention(x)) + + +class Concat(nn.Module): + """ + Concatenate a list of tensors along specified dimension. + + Attributes: + d (int): Dimension along which to concatenate tensors. + """ + + def __init__(self, dimension=1): + """ + Initialize Concat module. + + Args: + dimension (int): Dimension along which to concatenate tensors. + """ + super().__init__() + self.d = dimension + + def forward(self, x): + """ + Concatenate input tensors along specified dimension. + + Args: + x (List[torch.Tensor]): List of input tensors. + + Returns: + (torch.Tensor): Concatenated tensor. + """ + return torch.cat(x, self.d) + + +class Index(nn.Module): + """ + Returns a particular index of the input. + + Attributes: + index (int): Index to select from input. + """ + + def __init__(self, index=0): + """ + Initialize Index module. + + Args: + index (int): Index to select from input. 
+ """ + super().__init__() + self.index = index + + def forward(self, x): + """ + Select and return a particular index from input. + + Args: + x (List[torch.Tensor]): List of input tensors. + + Returns: + (torch.Tensor): Selected tensor. + """ + return x[self.index] diff --git a/tracking/ultralytics/nn/modules/head.py b/tracking/ultralytics/nn/modules/head.py new file mode 100644 index 0000000000000000000000000000000000000000..9a341f4f75ea87238def39c08c2b8890ecddf419 --- /dev/null +++ b/tracking/ultralytics/nn/modules/head.py @@ -0,0 +1,624 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Model head modules.""" + +import copy +import math + +import torch +import torch.nn as nn +from torch.nn.init import constant_, xavier_uniform_ + +from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors + +from .block import DFL, BNContrastiveHead, ContrastiveHead, Proto +from .conv import Conv, DWConv +from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer +from .utils import bias_init_with_prob, linear_init + +__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect" + + +class Detect(nn.Module): + """YOLO Detect head for detection models.""" + + dynamic = False # force grid reconstruction + export = False # export mode + format = None # export format + end2end = False # end2end + max_det = 300 # max_det + shape = None + anchors = torch.empty(0) # init + strides = torch.empty(0) # init + legacy = False # backward compatibility for v3/v5/v8/v9 models + + def __init__(self, nc=80, ch=()): + """Initialize the YOLO detection layer with specified number of classes and channels.""" + super().__init__() + self.nc = nc # number of classes + self.nl = len(ch) # number of detection layers + self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x) + self.no = nc + self.reg_max * 4 # number of outputs per anchor + self.stride = torch.zeros(self.nl) 
# strides computed during build + c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels + self.cv2 = nn.ModuleList( + nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch + ) + self.cv3 = ( + nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch) + if self.legacy + else nn.ModuleList( + nn.Sequential( + nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)), + nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)), + nn.Conv2d(c3, self.nc, 1), + ) + for x in ch + ) + ) + self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity() + + if self.end2end: + self.one2one_cv2 = copy.deepcopy(self.cv2) + self.one2one_cv3 = copy.deepcopy(self.cv3) + + def forward(self, x): + """Concatenates and returns predicted bounding boxes and class probabilities.""" + if self.end2end: + return self.forward_end2end(x) + + for i in range(self.nl): + x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1) + if self.training: # Training path + return x + y = self._inference(x) + return y if self.export else (y, x) + + def forward_end2end(self, x): + """ + Performs forward pass of the v10Detect module. + + Args: + x (tensor): Input tensor. + + Returns: + (dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections. + If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately. 
+ """ + x_detach = [xi.detach() for xi in x] + one2one = [ + torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl) + ] + for i in range(self.nl): + x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1) + if self.training: # Training path + return {"one2many": x, "one2one": one2one} + + y = self._inference(one2one) + y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc) + return y if self.export else (y, {"one2many": x, "one2one": one2one}) + + def _inference(self, x): + """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.""" + # Inference path + shape = x[0].shape # BCHW + x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) + if self.format != "imx" and (self.dynamic or self.shape != shape): + self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) + self.shape = shape + + if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops + box = x_cat[:, : self.reg_max * 4] + cls = x_cat[:, self.reg_max * 4 :] + else: + box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) + + if self.export and self.format in {"tflite", "edgetpu"}: + # Precompute normalization factor to increase numerical stability + # See https://github.com/ultralytics/ultralytics/issues/7371 + grid_h = shape[2] + grid_w = shape[3] + grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1) + norm = self.strides / (self.stride[0] * grid_size) + dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2]) + elif self.export and self.format == "imx": + dbox = self.decode_bboxes( + self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides, xywh=False + ) + return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1) + else: + dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides + + 
return torch.cat((dbox, cls.sigmoid()), 1) + + def bias_init(self): + """Initialize Detect() biases, WARNING: requires stride availability.""" + m = self # self.model[-1] # Detect() module + # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1 + # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency + for a, b, s in zip(m.cv2, m.cv3, m.stride): # from + a[-1].bias.data[:] = 1.0 # box + b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) + if self.end2end: + for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride): # from + a[-1].bias.data[:] = 1.0 # box + b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) + + def decode_bboxes(self, bboxes, anchors, xywh=True): + """Decode bounding boxes.""" + return dist2bbox(bboxes, anchors, xywh=xywh and (not self.end2end), dim=1) + + @staticmethod + def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80): + """ + Post-processes YOLO model predictions. + + Args: + preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension + format [x, y, w, h, class_probs]. + max_det (int): Maximum detections per image. + nc (int, optional): Number of classes. Default: 80. + + Returns: + (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last + dimension format [x, y, w, h, max_class_prob, class_index]. + """ + batch_size, anchors, _ = preds.shape # i.e. 
shape(16,8400,84) + boxes, scores = preds.split([4, nc], dim=-1) + index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1) + boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4)) + scores = scores.gather(dim=1, index=index.repeat(1, 1, nc)) + scores, index = scores.flatten(1).topk(min(max_det, anchors)) + i = torch.arange(batch_size)[..., None] # batch indices + return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1) + + +class Segment(Detect): + """YOLO Segment head for segmentation models.""" + + def __init__(self, nc=80, nm=32, npr=256, ch=()): + """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.""" + super().__init__(nc, ch) + self.nm = nm # number of masks + self.npr = npr # number of protos + self.proto = Proto(ch[0], self.npr, self.nm) # protos + + c4 = max(ch[0] // 4, self.nm) + self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch) + + def forward(self, x): + """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients.""" + p = self.proto(x[0]) # mask protos + bs = p.shape[0] # batch size + + mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients + x = Detect.forward(self, x) + if self.training: + return x, mc, p + return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p)) + + +class OBB(Detect): + """YOLO OBB detection head for detection with rotation models.""" + + def __init__(self, nc=80, ne=1, ch=()): + """Initialize OBB with number of classes `nc` and layer channels `ch`.""" + super().__init__(nc, ch) + self.ne = ne # number of extra parameters + + c4 = max(ch[0] // 4, self.ne) + self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch) + + def forward(self, x): + """Concatenates and 
returns predicted bounding boxes and class probabilities.""" + bs = x[0].shape[0] # batch size + angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits + # NOTE: set `angle` as an attribute so that `decode_bboxes` could use it. + angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4] + # angle = angle.sigmoid() * math.pi / 2 # [0, pi/2] + if not self.training: + self.angle = angle + x = Detect.forward(self, x) + if self.training: + return x, angle + return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle)) + + def decode_bboxes(self, bboxes, anchors): + """Decode rotated bounding boxes.""" + return dist2rbox(bboxes, self.angle, anchors, dim=1) + + +class Pose(Detect): + """YOLO Pose head for keypoints models.""" + + def __init__(self, nc=80, kpt_shape=(17, 3), ch=()): + """Initialize YOLO network with default parameters and Convolutional Layers.""" + super().__init__(nc, ch) + self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) + self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total + + c4 = max(ch[0] // 4, self.nk) + self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch) + + def forward(self, x): + """Perform forward pass through YOLO model and return predictions.""" + bs = x[0].shape[0] # batch size + kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w) + x = Detect.forward(self, x) + if self.training: + return x, kpt + pred_kpt = self.kpts_decode(bs, kpt) + return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt)) + + def kpts_decode(self, bs, kpts): + """Decodes keypoints.""" + ndim = self.kpt_shape[1] + if self.export: + if self.format in { + "tflite", + "edgetpu", + }: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug + # Precompute 
normalization factor to increase numerical stability + y = kpts.view(bs, *self.kpt_shape, -1) + grid_h, grid_w = self.shape[2], self.shape[3] + grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1) + norm = self.strides / (self.stride[0] * grid_size) + a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm + else: + # NCNN fix + y = kpts.view(bs, *self.kpt_shape, -1) + a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides + if ndim == 3: + a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2) + return a.view(bs, self.nk, -1) + else: + y = kpts.clone() + if ndim == 3: + y[:, 2::ndim] = y[:, 2::ndim].sigmoid() # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug) + y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides + y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides + return y + + +class Classify(nn.Module): + """YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).""" + + export = False # export mode + + def __init__(self, c1, c2, k=1, s=1, p=None, g=1): + """Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.""" + super().__init__() + c_ = 1280 # efficientnet_b0 size + self.conv = Conv(c1, c_, k, s, p, g) + self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1) + self.drop = nn.Dropout(p=0.0, inplace=True) + self.linear = nn.Linear(c_, c2) # to x(b,c2) + + def forward(self, x): + """Performs a forward pass of the YOLO model on input image data.""" + if isinstance(x, list): + x = torch.cat(x, 1) + x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1))) + if self.training: + return x + y = x.softmax(1) # get final output + return y if self.export else (y, x) + + +class WorldDetect(Detect): + """Head for integrating YOLO detection models with semantic understanding from text embeddings.""" + + def __init__(self, nc=80, embed=512, with_bn=False, ch=()): + """Initialize YOLO detection layer with nc classes and layer channels ch.""" + 
super().__init__(nc, ch) + c3 = max(ch[0], min(self.nc, 100)) + self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch) + self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch) + + def forward(self, x, text): + """Concatenates and returns predicted bounding boxes and class probabilities.""" + for i in range(self.nl): + x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1) + if self.training: + return x + + # Inference path + shape = x[0].shape # BCHW + x_cat = torch.cat([xi.view(shape[0], self.nc + self.reg_max * 4, -1) for xi in x], 2) + if self.dynamic or self.shape != shape: + self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) + self.shape = shape + + if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops + box = x_cat[:, : self.reg_max * 4] + cls = x_cat[:, self.reg_max * 4 :] + else: + box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) + + if self.export and self.format in {"tflite", "edgetpu"}: + # Precompute normalization factor to increase numerical stability + # See https://github.com/ultralytics/ultralytics/issues/7371 + grid_h = shape[2] + grid_w = shape[3] + grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1) + norm = self.strides / (self.stride[0] * grid_size) + dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2]) + else: + dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides + + y = torch.cat((dbox, cls.sigmoid()), 1) + return y if self.export else (y, x) + + def bias_init(self): + """Initialize Detect() biases, WARNING: requires stride availability.""" + m = self # self.model[-1] # Detect() module + # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1 + # ncf = math.log(0.6 / 
(m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency + for a, b, s in zip(m.cv2, m.cv3, m.stride): # from + a[-1].bias.data[:] = 1.0 # box + # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) + + +class RTDETRDecoder(nn.Module): + """ + Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection. + + This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes + and class labels for objects in an image. It integrates features from multiple layers and runs through a series of + Transformer decoder layers to output the final predictions. + """ + + export = False # export mode + + def __init__( + self, + nc=80, + ch=(512, 1024, 2048), + hd=256, # hidden dim + nq=300, # num queries + ndp=4, # num decoder points + nh=8, # num head + ndl=6, # num decoder layers + d_ffn=1024, # dim of feedforward + dropout=0.0, + act=nn.ReLU(), + eval_idx=-1, + # Training args + nd=100, # num denoising + label_noise_ratio=0.5, + box_noise_scale=1.0, + learnt_init_query=False, + ): + """ + Initializes the RTDETRDecoder module with the given parameters. + + Args: + nc (int): Number of classes. Default is 80. + ch (tuple): Channels in the backbone feature maps. Default is (512, 1024, 2048). + hd (int): Dimension of hidden layers. Default is 256. + nq (int): Number of query points. Default is 300. + ndp (int): Number of decoder points. Default is 4. + nh (int): Number of heads in multi-head attention. Default is 8. + ndl (int): Number of decoder layers. Default is 6. + d_ffn (int): Dimension of the feed-forward networks. Default is 1024. + dropout (float): Dropout rate. Default is 0.0. + act (nn.Module): Activation function. Default is nn.ReLU. + eval_idx (int): Evaluation index. Default is -1. + nd (int): Number of denoising. Default is 100. + label_noise_ratio (float): Label noise ratio. Default is 0.5. 
+ box_noise_scale (float): Box noise scale. Default is 1.0. + learnt_init_query (bool): Whether to learn initial query embeddings. Default is False. + """ + super().__init__() + self.hidden_dim = hd + self.nhead = nh + self.nl = len(ch) # num level + self.nc = nc + self.num_queries = nq + self.num_decoder_layers = ndl + + # Backbone feature projection + self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch) + # NOTE: simplified version but it's not consistent with .pt weights. + # self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch) + + # Transformer module + decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp) + self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx) + + # Denoising part + self.denoising_class_embed = nn.Embedding(nc, hd) + self.num_denoising = nd + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + + # Decoder embedding + self.learnt_init_query = learnt_init_query + if learnt_init_query: + self.tgt_embed = nn.Embedding(nq, hd) + self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2) + + # Encoder head + self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd)) + self.enc_score_head = nn.Linear(hd, nc) + self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3) + + # Decoder head + self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)]) + self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)]) + + self._reset_parameters() + + def forward(self, x, batch=None): + """Runs the forward pass of the module, returning bounding box and classification scores for the input.""" + from ultralytics.models.utils.ops import get_cdn_group + + # Input projection and embedding + feats, shapes = self._get_encoder_input(x) + + # Prepare denoising training + dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group( + batch, + self.nc, + 
self.num_queries, + self.denoising_class_embed.weight, + self.num_denoising, + self.label_noise_ratio, + self.box_noise_scale, + self.training, + ) + + embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox) + + # Decoder + dec_bboxes, dec_scores = self.decoder( + embed, + refer_bbox, + feats, + shapes, + self.dec_bbox_head, + self.dec_score_head, + self.query_pos_head, + attn_mask=attn_mask, + ) + x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta + if self.training: + return x + # (bs, 300, 4+nc) + y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1) + return y if self.export else (y, x) + + def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2): + """Generates anchor bounding boxes for given shapes with specific grid size and validates them.""" + anchors = [] + for i, (h, w) in enumerate(shapes): + sy = torch.arange(end=h, dtype=dtype, device=device) + sx = torch.arange(end=w, dtype=dtype, device=device) + grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx) + grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2) + + valid_WH = torch.tensor([w, h], dtype=dtype, device=device) + grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH # (1, h, w, 2) + wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i) + anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4)) # (1, h*w, 4) + + anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4) + valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1 + anchors = torch.log(anchors / (1 - anchors)) + anchors = anchors.masked_fill(~valid_mask, float("inf")) + return anchors, valid_mask + + def _get_encoder_input(self, x): + """Processes and returns encoder inputs by getting projection features from input and concatenating them.""" + # Get projection features + x = [self.input_proj[i](feat) for i, feat in 
enumerate(x)] + # Get encoder inputs + feats = [] + shapes = [] + for feat in x: + h, w = feat.shape[2:] + # [b, c, h, w] -> [b, h*w, c] + feats.append(feat.flatten(2).permute(0, 2, 1)) + # [nl, 2] + shapes.append([h, w]) + + # [b, h*w, c] + feats = torch.cat(feats, 1) + return feats, shapes + + def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None): + """Generates and prepares the input required for the decoder from the provided features and shapes.""" + bs = feats.shape[0] + # Prepare input for decoder + anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device) + features = self.enc_output(valid_mask * feats) # bs, h*w, 256 + + enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc) + + # Query selection + # (bs, num_queries) + topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1) + # (bs, num_queries) + batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1) + + # (bs, num_queries, 256) + top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1) + # (bs, num_queries, 4) + top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1) + + # Dynamic anchors + static content + refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors + + enc_bboxes = refer_bbox.sigmoid() + if dn_bbox is not None: + refer_bbox = torch.cat([dn_bbox, refer_bbox], 1) + enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1) + + embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_features + if self.training: + refer_bbox = refer_bbox.detach() + if not self.learnt_init_query: + embeddings = embeddings.detach() + if dn_embed is not None: + embeddings = torch.cat([dn_embed, embeddings], 1) + + return embeddings, refer_bbox, enc_bboxes, enc_scores + + def _reset_parameters(self): + """Initializes or resets the parameters of the 
model's various components with predefined weights and biases.""" + # Class and bbox head init + bias_cls = bias_init_with_prob(0.01) / 80 * self.nc + # NOTE: the weight initialization in `linear_init` would cause NaN when training with custom datasets. + # linear_init(self.enc_score_head) + constant_(self.enc_score_head.bias, bias_cls) + constant_(self.enc_bbox_head.layers[-1].weight, 0.0) + constant_(self.enc_bbox_head.layers[-1].bias, 0.0) + for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head): + # linear_init(cls_) + constant_(cls_.bias, bias_cls) + constant_(reg_.layers[-1].weight, 0.0) + constant_(reg_.layers[-1].bias, 0.0) + + linear_init(self.enc_output[0]) + xavier_uniform_(self.enc_output[0].weight) + if self.learnt_init_query: + xavier_uniform_(self.tgt_embed.weight) + xavier_uniform_(self.query_pos_head.layers[0].weight) + xavier_uniform_(self.query_pos_head.layers[1].weight) + for layer in self.input_proj: + xavier_uniform_(layer[0].weight) + + +class v10Detect(Detect): + """ + v10 Detection head from https://arxiv.org/pdf/2405.14458. + + Args: + nc (int): Number of classes. + ch (tuple): Tuple of channel sizes. + + Attributes: + max_det (int): Maximum number of detections. + + Methods: + __init__(self, nc=80, ch=()): Initializes the v10Detect object. + forward(self, x): Performs forward pass of the v10Detect module. + bias_init(self): Initializes biases of the Detect module. 
+ + """ + + end2end = True + + def __init__(self, nc=80, ch=()): + """Initializes the v10Detect object with the specified number of classes and input channels.""" + super().__init__(nc, ch) + c3 = max(ch[0], min(self.nc, 100)) # channels + # Light cls head + self.cv3 = nn.ModuleList( + nn.Sequential( + nn.Sequential(Conv(x, x, 3, g=x), Conv(x, c3, 1)), + nn.Sequential(Conv(c3, c3, 3, g=c3), Conv(c3, c3, 1)), + nn.Conv2d(c3, self.nc, 1), + ) + for x in ch + ) + self.one2one_cv3 = copy.deepcopy(self.cv3) diff --git a/tracking/ultralytics/nn/modules/transformer.py b/tracking/ultralytics/nn/modules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..6d534413d1a121974e753774fd9e5ba148d4d0b9 --- /dev/null +++ b/tracking/ultralytics/nn/modules/transformer.py @@ -0,0 +1,713 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Transformer modules.""" + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import constant_, xavier_uniform_ + +from .conv import Conv +from .utils import _get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch + +__all__ = ( + "TransformerEncoderLayer", + "TransformerLayer", + "TransformerBlock", + "MLPBlock", + "LayerNorm2d", + "AIFI", + "DeformableTransformerDecoder", + "DeformableTransformerDecoderLayer", + "MSDeformAttn", + "MLP", +) + + +class TransformerEncoderLayer(nn.Module): + """ + Defines a single layer of the transformer encoder. + + Attributes: + ma (nn.MultiheadAttention): Multi-head attention module. + fc1 (nn.Linear): First linear layer in the feedforward network. + fc2 (nn.Linear): Second linear layer in the feedforward network. + norm1 (nn.LayerNorm): Layer normalization after attention. + norm2 (nn.LayerNorm): Layer normalization after feedforward network. + dropout (nn.Dropout): Dropout layer for the feedforward network. + dropout1 (nn.Dropout): Dropout layer after attention. 
+ dropout2 (nn.Dropout): Dropout layer after feedforward network. + act (nn.Module): Activation function. + normalize_before (bool): Whether to apply normalization before attention and feedforward. + """ + + def __init__(self, c1, cm=2048, num_heads=8, dropout=0.0, act=nn.GELU(), normalize_before=False): + """ + Initialize the TransformerEncoderLayer with specified parameters. + + Args: + c1 (int): Input dimension. + cm (int): Hidden dimension in the feedforward network. + num_heads (int): Number of attention heads. + dropout (float): Dropout probability. + act (nn.Module): Activation function. + normalize_before (bool): Whether to apply normalization before attention and feedforward. + """ + super().__init__() + from ...utils.torch_utils import TORCH_1_9 + + if not TORCH_1_9: + raise ModuleNotFoundError( + "TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True)." + ) + self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True) + # Implementation of Feedforward model + self.fc1 = nn.Linear(c1, cm) + self.fc2 = nn.Linear(cm, c1) + + self.norm1 = nn.LayerNorm(c1) + self.norm2 = nn.LayerNorm(c1) + self.dropout = nn.Dropout(dropout) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.act = act + self.normalize_before = normalize_before + + @staticmethod + def with_pos_embed(tensor, pos=None): + """Add position embeddings to the tensor if provided.""" + return tensor if pos is None else tensor + pos + + def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None): + """ + Perform forward pass with post-normalization. + + Args: + src (torch.Tensor): Input tensor. + src_mask (torch.Tensor, optional): Mask for the src sequence. + src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch. + pos (torch.Tensor, optional): Positional encoding. + + Returns: + (torch.Tensor): Output tensor after attention and feedforward. 
+ """ + q = k = self.with_pos_embed(src, pos) + src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.fc2(self.dropout(self.act(self.fc1(src)))) + src = src + self.dropout2(src2) + return self.norm2(src) + + def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None): + """ + Perform forward pass with pre-normalization. + + Args: + src (torch.Tensor): Input tensor. + src_mask (torch.Tensor, optional): Mask for the src sequence. + src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch. + pos (torch.Tensor, optional): Positional encoding. + + Returns: + (torch.Tensor): Output tensor after attention and feedforward. + """ + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.fc2(self.dropout(self.act(self.fc1(src2)))) + return src + self.dropout2(src2) + + def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None): + """ + Forward propagates the input through the encoder module. + + Args: + src (torch.Tensor): Input tensor. + src_mask (torch.Tensor, optional): Mask for the src sequence. + src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch. + pos (torch.Tensor, optional): Positional encoding. + + Returns: + (torch.Tensor): Output tensor after transformer encoder layer. + """ + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class AIFI(TransformerEncoderLayer): + """ + Defines the AIFI transformer layer. + + This class extends TransformerEncoderLayer to work with 2D data by adding positional embeddings. 
+ """ + + def __init__(self, c1, cm=2048, num_heads=8, dropout=0, act=nn.GELU(), normalize_before=False): + """ + Initialize the AIFI instance with specified parameters. + + Args: + c1 (int): Input dimension. + cm (int): Hidden dimension in the feedforward network. + num_heads (int): Number of attention heads. + dropout (float): Dropout probability. + act (nn.Module): Activation function. + normalize_before (bool): Whether to apply normalization before attention and feedforward. + """ + super().__init__(c1, cm, num_heads, dropout, act, normalize_before) + + def forward(self, x): + """ + Forward pass for the AIFI transformer layer. + + Args: + x (torch.Tensor): Input tensor with shape [B, C, H, W]. + + Returns: + (torch.Tensor): Output tensor with shape [B, C, H, W]. + """ + c, h, w = x.shape[1:] + pos_embed = self.build_2d_sincos_position_embedding(w, h, c) + # Flatten [B, C, H, W] to [B, HxW, C] + x = super().forward(x.flatten(2).permute(0, 2, 1), pos=pos_embed.to(device=x.device, dtype=x.dtype)) + return x.permute(0, 2, 1).view([-1, c, h, w]).contiguous() + + @staticmethod + def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0): + """ + Build 2D sine-cosine position embedding. + + Args: + w (int): Width of the feature map. + h (int): Height of the feature map. + embed_dim (int): Embedding dimension. + temperature (float): Temperature for the sine/cosine functions. + + Returns: + (torch.Tensor): Position embedding with shape [1, embed_dim, h*w]. 
class TransformerLayer(nn.Module):
    """Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)."""

    def __init__(self, c, num_heads):
        """
        Initialize a self-attention block with learned Q/K/V projections and a two-layer feed-forward net.

        Args:
            c (int): Input and output channel dimension.
            num_heads (int): Number of attention heads.
        """
        super().__init__()
        # Bias-free linear projections; attribute names are kept for checkpoint compatibility.
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x):
        """
        Run self-attention then the feed-forward net, each wrapped in a residual connection.

        Args:
            x (torch.Tensor): Input tensor (sequence-first layout, as expected by nn.MultiheadAttention).

        Returns:
            (torch.Tensor): Output tensor with the same shape as the input.
        """
        attn_out, _ = self.ma(self.q(x), self.k(x), self.v(x))
        x = x + attn_out
        return x + self.fc2(self.fc1(x))
class MLPBlock(nn.Module):
    """Implements a single block of a multi-layer perceptron."""

    def __init__(self, embedding_dim, mlp_dim, act=nn.GELU):
        """
        Initialize the MLPBlock with specified embedding dimension, MLP dimension, and activation function.

        Args:
            embedding_dim (int): Input and output dimension.
            mlp_dim (int): Hidden dimension.
            act (nn.Module): Activation function class; instantiated once here.
        """
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the MLPBlock.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after lin1 -> activation -> lin2, same shape as input.
        """
        return self.lin2(self.act(self.lin1(x)))


class MLP(nn.Module):
    """
    Implements a simple multi-layer perceptron (also called FFN).

    Attributes:
        num_layers (int): Number of layers in the MLP.
        layers (nn.ModuleList): List of linear layers.
        sigmoid (bool): Whether to apply sigmoid to the output.
        act (nn.Module): Activation function applied between hidden layers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act=nn.ReLU, sigmoid=False):
        """
        Initialize the MLP with specified input, hidden, output dimensions and number of layers.

        Args:
            input_dim (int): Input dimension.
            hidden_dim (int): Hidden dimension.
            output_dim (int): Output dimension.
            num_layers (int): Number of layers.
            act (nn.Module): Activation function class; instantiated once here.
            sigmoid (bool): Whether to apply sigmoid to the output.
        """
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
        self.sigmoid = sigmoid
        self.act = act()

    def forward(self, x):
        """
        Forward pass for the entire MLP.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor; passed through sigmoid when `self.sigmoid` is set.
        """
        # Fix: the original called `getattr(self, "act", nn.ReLU())` inside the loop. getattr's default
        # argument is eagerly evaluated, so a fresh nn.ReLU() module was constructed for every hidden
        # layer on every forward call. Resolve the activation once; the hasattr fallback keeps old
        # serialized instances that lack an `act` attribute working.
        act = self.act if hasattr(self, "act") else nn.ReLU()
        for i, layer in enumerate(self.layers):
            x = act(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x.sigmoid() if getattr(self, "sigmoid", False) else x
+ """ + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x): + """ + Perform forward pass for 2D layer normalization. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Normalized output tensor. + """ + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + return self.weight[:, None, None] * x + self.bias[:, None, None] + + +class MSDeformAttn(nn.Module): + """ + Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations. + + https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py + + Attributes: + im2col_step (int): Step size for im2col operations. + d_model (int): Model dimension. + n_levels (int): Number of feature levels. + n_heads (int): Number of attention heads. + n_points (int): Number of sampling points per attention head per feature level. + sampling_offsets (nn.Linear): Linear layer for generating sampling offsets. + attention_weights (nn.Linear): Linear layer for generating attention weights. + value_proj (nn.Linear): Linear layer for projecting values. + output_proj (nn.Linear): Linear layer for projecting output. + """ + + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): + """ + Initialize MSDeformAttn with the given parameters. + + Args: + d_model (int): Model dimension. + n_levels (int): Number of feature levels. + n_heads (int): Number of attention heads. + n_points (int): Number of sampling points per attention head per feature level. 
+ """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}") + _d_per_head = d_model // n_heads + # Better to set _d_per_head to a power of 2 which is more efficient in a CUDA implementation + assert _d_per_head * n_heads == d_model, "`d_model` must be divisible by `n_heads`" + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self._reset_parameters() + + def _reset_parameters(self): + """Reset module parameters.""" + constant_(self.sampling_offsets.weight.data, 0.0) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.n_heads, 1, 1, 2) + .repeat(1, self.n_levels, self.n_points, 1) + ) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.0) + constant_(self.attention_weights.bias.data, 0.0) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.0) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.0) + + def forward(self, query, refer_bbox, value, value_shapes, value_mask=None): + """ + Perform forward pass for multiscale deformable attention. + + https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py + + Args: + query (torch.Tensor): Tensor with shape [bs, query_length, C]. 
class DeformableTransformerDecoderLayer(nn.Module):
    """
    Deformable Transformer Decoder Layer inspired by PaddleDetection and Deformable-DETR implementations.

    https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
    https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py
    """

    def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0.0, act=nn.ReLU(), n_levels=4, n_points=4):
        """
        Initialize the decoder layer: self-attention, deformable cross-attention, and a feed-forward net.

        Args:
            d_model (int): Model dimension.
            n_heads (int): Number of attention heads.
            d_ffn (int): Dimension of the feedforward network.
            dropout (float): Dropout probability.
            act (nn.Module): Activation function instance.
                NOTE(review): a module instance as a default argument is shared across layers that use
                the default — harmless for stateless activations like ReLU, but worth confirming.
            n_levels (int): Number of feature levels.
            n_points (int): Number of sampling points.
        """
        super().__init__()

        # Self attention over the query embeddings.
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Deformable cross attention into the multi-scale image features.
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.act = act
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return `tensor` with `pos` added, or unchanged when no positional embedding is given."""
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        """
        Apply the feed-forward network with a residual connection followed by layer normalization.

        Args:
            tgt (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after FFN.
        """
        ffn_out = self.linear2(self.dropout3(self.act(self.linear1(tgt))))
        return self.norm3(tgt + self.dropout4(ffn_out))

    def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None):
        """
        Perform the forward pass through the entire decoder layer.

        Args:
            embed (torch.Tensor): Input embeddings (batch-first).
            refer_bbox (torch.Tensor): Reference bounding boxes.
            feats (torch.Tensor): Flattened multi-scale feature maps.
            shapes (list): Feature shapes per level.
            padding_mask (torch.Tensor, optional): Padding mask for the features.
            attn_mask (torch.Tensor, optional): Attention mask for self-attention.
            query_pos (torch.Tensor, optional): Query position embeddings.

        Returns:
            (torch.Tensor): Output tensor after the decoder layer.
        """
        # Self attention: nn.MultiheadAttention expects sequence-first tensors, hence the transposes.
        qk = self.with_pos_embed(embed, query_pos).transpose(0, 1)
        attn_out = self.self_attn(qk, qk, embed.transpose(0, 1), attn_mask=attn_mask)[0].transpose(0, 1)
        embed = self.norm1(embed + self.dropout1(attn_out))

        # Deformable cross attention: queries attend to sampled locations around the reference boxes.
        cross_out = self.cross_attn(
            self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, padding_mask
        )
        embed = self.norm2(embed + self.dropout2(cross_out))

        # Feed-forward network.
        return self.forward_ffn(embed)
+ shapes (list): Feature shapes. + bbox_head (nn.Module): Bounding box prediction head. + score_head (nn.Module): Score prediction head. + pos_mlp (nn.Module): Position MLP. + attn_mask (torch.Tensor, optional): Attention mask. + padding_mask (torch.Tensor, optional): Padding mask. + + Returns: + dec_bboxes (torch.Tensor): Decoded bounding boxes. + dec_cls (torch.Tensor): Decoded classification scores. + """ + output = embed + dec_bboxes = [] + dec_cls = [] + last_refined_bbox = None + refer_bbox = refer_bbox.sigmoid() + for i, layer in enumerate(self.layers): + output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox)) + + bbox = bbox_head[i](output) + refined_bbox = torch.sigmoid(bbox + inverse_sigmoid(refer_bbox)) + + if self.training: + dec_cls.append(score_head[i](output)) + if i == 0: + dec_bboxes.append(refined_bbox) + else: + dec_bboxes.append(torch.sigmoid(bbox + inverse_sigmoid(last_refined_bbox))) + elif i == self.eval_idx: + dec_cls.append(score_head[i](output)) + dec_bboxes.append(refined_bbox) + break + + last_refined_bbox = refined_bbox + refer_bbox = refined_bbox.detach() if self.training else refined_bbox + + return torch.stack(dec_bboxes), torch.stack(dec_cls) diff --git a/tracking/ultralytics/nn/modules/utils.py b/tracking/ultralytics/nn/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..944b2affaa6434e05ba66ea64593e175b0cd288a --- /dev/null +++ b/tracking/ultralytics/nn/modules/utils.py @@ -0,0 +1,98 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import copy +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import uniform_ + +__all__ = "multi_scale_deformable_attn_pytorch", "inverse_sigmoid" + + +def _get_clones(module, n): + """Create a list of cloned modules from the given module.""" + return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) + + +def 
def bias_init_with_prob(prior_prob=0.01):
    """Initialize conv/fc bias value according to a given probability value."""
    return float(-np.log((1 - prior_prob) / prior_prob))  # inverse sigmoid of the prior probability


def linear_init(module):
    """Initialize the weights and biases of a linear module uniformly in +/- 1/sqrt(weight.shape[0])."""
    # NOTE(review): the bound uses weight.shape[0] (out_features); upstream convention — confirm intended.
    bound = 1 / math.sqrt(module.weight.shape[0])
    uniform_(module.weight, -bound, bound)
    if getattr(module, "bias", None) is not None:
        uniform_(module.bias, -bound, bound)


def inverse_sigmoid(x, eps=1e-5):
    """Calculate the inverse sigmoid (logit) of a tensor, clamped for numerical stability."""
    x = x.clamp(min=0, max=1)
    numerator = x.clamp(min=eps)
    denominator = (1 - x).clamp(min=eps)
    return torch.log(numerator / denominator)


def multi_scale_deformable_attn_pytorch(
    value: torch.Tensor,
    value_spatial_shapes: torch.Tensor,
    sampling_locations: torch.Tensor,
    attention_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Implement multi-scale deformable attention in PyTorch.

    This function performs deformable attention across multiple feature map scales, allowing the model to attend to
    different spatial locations with learned offsets.

    Args:
        value (torch.Tensor): The value tensor with shape (bs, num_keys, num_heads, embed_dims).
        value_spatial_shapes (torch.Tensor): Spatial shapes of the value tensor with shape (num_levels, 2).
        sampling_locations (torch.Tensor): The sampling locations with shape
            (bs, num_queries, num_heads, num_levels, num_points, 2), in [0, 1] coordinates.
        attention_weights (torch.Tensor): The attention weights with shape
            (bs, num_queries, num_heads, num_levels, num_points).

    Returns:
        (torch.Tensor): The output tensor with shape (bs, num_queries, num_heads * embed_dims).

    References:
        https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py
    """
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, _, num_levels, num_points, _ = sampling_locations.shape
    # Split the flattened value tensor back into one chunk per feature level.
    level_values = value.split([h * w for h, w in value_spatial_shapes], dim=1)
    # grid_sample expects normalized coordinates in [-1, 1].
    grids = 2 * sampling_locations - 1
    sampled = []
    for level, (h, w) in enumerate(value_spatial_shapes):
        # (bs, h*w, num_heads, embed_dims) -> (bs*num_heads, embed_dims, h, w)
        level_value = level_values[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, h, w)
        # (bs, num_queries, num_heads, num_points, 2) -> (bs*num_heads, num_queries, num_points, 2)
        level_grid = grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # (bs*num_heads, embed_dims, num_queries, num_points)
        sampled.append(
            F.grid_sample(level_value, level_grid, mode="bilinear", padding_mode="zeros", align_corners=False)
        )
    # (bs, num_queries, num_heads, num_levels, num_points) -> (bs*num_heads, 1, num_queries, num_levels*num_points)
    weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries, num_levels * num_points)
    # Weighted sum over all levels and points, then fold heads back into the embedding dimension.
    output = (
        (torch.stack(sampled, dim=-2).flatten(-2) * weights)
        .sum(-1)
        .view(bs, num_heads * embed_dims, num_queries)
    )
    return output.transpose(1, 2).contiguous()
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import contextlib
import pickle
import re
import types
from copy import deepcopy
from pathlib import Path

import torch

from ultralytics.nn.modules import (
    AIFI, C1, C2, C2PSA, C3, C3TR, ELAN1, OBB, PSA, SPP, SPPELAN, SPPF, A2C2f, AConv, ADown,
    Bottleneck, BottleneckCSP, C2f, C2fAttn, C2fCIB, C2fPSA, C3Ghost, C3k2, C3x, CBFuse, CBLinear,
    Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Focus,
    GhostBottleneck, GhostConv, HGBlock, HGStem, ImagePoolingAttn, Index, Pose, RepC3, RepConv,
    RepNCSPELAN4, RepVGGDW, ResNetLayer, RTDETRDecoder, SCDown, Segment, TorchVision, WorldDetect,
    v10Detect,
)
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.utils.loss import (
    E2EDetectLoss, v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss,
)
from ultralytics.utils.ops import make_divisible
from ultralytics.utils.plotting import feature_visualization
from ultralytics.utils.torch_utils import (
    fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts, model_info,
    scale_img, time_sync,
)

try:
    import thop
except ImportError:
    thop = None  # conda support without 'ultralytics-thop' installed


class BaseModel(torch.nn.Module):
    """The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""

    def forward(self, x, *args, **kwargs):
        """
        Perform forward pass of the model for either training or inference.

        If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.

        Args:
            x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training.
            *args (Any): Variable length argument list.
            **kwargs (Any): Arbitrary keyword arguments.

        Returns:
            (torch.Tensor): Loss if x is a dict (training), or network predictions (inference).
        """
        if isinstance(x, dict):  # for cases of training and validating while training.
            return self.loss(x, *args, **kwargs)
        return self.predict(x, *args, **kwargs)

    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
        """
        Perform a forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool): Print the computation time of each layer if True.
            visualize (bool): Save the feature maps of the model if True.
            augment (bool): Augment image during prediction.
            embed (list, optional): A list of feature vectors/embeddings to return.

        Returns:
            (torch.Tensor): The last output of the model.
        """
        if augment:
            return self._predict_augment(x)
        return self._predict_once(x, profile, visualize, embed)

    def _predict_once(self, x, profile=False, visualize=False, embed=None):
        """
        Perform a forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool): Print the computation time of each layer if True.
            visualize (bool): Save the feature maps of the model if True.
            embed (list, optional): A list of feature vectors/embeddings to return.

        Returns:
            (torch.Tensor): The last output of the model.
        """
        # y caches per-layer outputs for skip connections; only indices in self.save are kept.
        y, dt, embeddings = [], [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if embed and m.i in embed:
                embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
                if m.i == max(embed):
                    # Early-exit once the deepest requested embedding layer has been reached.
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x

    def _predict_augment(self, x):
        """Perform augmentations on input image x and return augmented inference."""
        LOGGER.warning(
            f"WARNING ⚠️ {self.__class__.__name__} does not support 'augment=True' prediction. "
            f"Reverting to single-scale prediction."
        )
        return self._predict_once(x)

    def _profile_one_layer(self, m, x, dt):
        """
        Profile the computation time and FLOPs of a single layer of the model on a given input.

        Args:
            m (torch.nn.Module): The layer to be profiled.
            x (torch.Tensor): The input data to the layer.
            dt (list): A list to store the computation time of the layer.
        """
        c = m == self.model[-1] and isinstance(x, list)  # is final layer list, copy input as inplace fix
        flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0  # GFLOPs
        t = time_sync()
        # Run 10 times; (elapsed * 100) == average per-run time in milliseconds.
        for _ in range(10):
            m(x.copy() if c else x)
        dt.append((time_sync() - t) * 100)
        if m == self.model[0]:
            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  module")
        LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f}  {m.type}")
        if c:
            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s}  Total")

    def fuse(self, verbose=True):
        """
        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
        efficiency.

        Returns:
            (torch.nn.Module): The fused model is returned.
        """
        if not self.is_fused():
            for m in self.model.modules():
                if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"):
                    if isinstance(m, Conv2):
                        m.fuse_convs()
                    m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                    delattr(m, "bn")  # remove batchnorm
                    m.forward = m.forward_fuse  # update forward
                if isinstance(m, ConvTranspose) and hasattr(m, "bn"):
                    m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
                    delattr(m, "bn")  # remove batchnorm
                    m.forward = m.forward_fuse  # update forward
                if isinstance(m, RepConv):
                    m.fuse_convs()
                    m.forward = m.forward_fuse  # update forward
                if isinstance(m, RepVGGDW):
                    m.fuse()
                    m.forward = m.forward_fuse
            self.info(verbose=verbose)

        return self

    def is_fused(self, thresh=10):
        """
        Check if the model has less than a certain threshold of BatchNorm layers.

        Args:
            thresh (int, optional): The threshold number of BatchNorm layers.

        Returns:
            (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
        """
        bn = tuple(v for k, v in torch.nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
        return sum(isinstance(v, bn) for v in self.modules()) < thresh  # True if < 'thresh' BatchNorm layers in model

    def info(self, detailed=False, verbose=True, imgsz=640):
        """
        Print model information.

        Args:
            detailed (bool): If True, prints out detailed information about the model.
            verbose (bool): If True, prints out the model information.
            imgsz (int): The size of the image that the model will be trained on.
        """
        return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)

    def _apply(self, fn):
        """
        Apply a function to all tensors in the model that are not parameters or registered buffers.

        Args:
            fn (function): The function to apply to the model.

        Returns:
            (BaseModel): An updated BaseModel object.
        """
        self = super()._apply(fn)
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
            # stride/anchors/strides are plain tensors, not buffers, so .to()/.cuda() would miss them.
            m.stride = fn(m.stride)
            m.anchors = fn(m.anchors)
            m.strides = fn(m.strides)
        return self

    def load(self, weights, verbose=True):
        """
        Load weights into the model.

        Args:
            weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
            verbose (bool, optional): Whether to log the transfer progress.
        """
        model = weights["model"] if isinstance(weights, dict) else weights  # torchvision models are not dicts
        csd = model.float().state_dict()  # checkpoint state_dict as FP32
        csd = intersect_dicts(csd, self.state_dict())  # intersect
        self.load_state_dict(csd, strict=False)  # load
        if verbose:
            LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights")

    def loss(self, batch, preds=None):
        """
        Compute loss.

        Args:
            batch (dict): Batch to compute loss on.
            preds (torch.Tensor | List[torch.Tensor], optional): Predictions.
        """
        if getattr(self, "criterion", None) is None:
            self.criterion = self.init_criterion()  # lazily built on first loss computation

        preds = self.forward(batch["img"]) if preds is None else preds
        return self.criterion(preds, batch)

    def init_criterion(self):
        """Initialize the loss criterion for the BaseModel."""
        raise NotImplementedError("compute_loss() needs to be implemented by task heads")
+ """ + super().__init__() + self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict + if self.yaml["backbone"][0][2] == "Silence": + LOGGER.warning( + "WARNING ⚠️ YOLOv9 `Silence` module is deprecated in favor of torch.nn.Identity. " + "Please delete local *.pt file and re-download the latest model checkpoint." + ) + self.yaml["backbone"][0][2] = "nn.Identity" + + # Define model + ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels + if nc and nc != self.yaml["nc"]: + LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") + self.yaml["nc"] = nc # override YAML value + self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist + self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict + self.inplace = self.yaml.get("inplace", True) + self.end2end = getattr(self.model[-1], "end2end", False) + + # Build strides + m = self.model[-1] # Detect() + if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect + s = 256 # 2x min stride + m.inplace = self.inplace + + def _forward(x): + """Perform a forward pass through the model, handling different Detect subclass types accordingly.""" + if self.end2end: + return self.forward(x)["one2many"] + return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x) + + m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))]) # forward + self.stride = m.stride + m.bias_init() # only run once + else: + self.stride = torch.Tensor([32]) # default stride for i.e. RTDETR + + # Init weights, biases + initialize_weights(self) + if verbose: + self.info() + LOGGER.info("") + + def _predict_augment(self, x): + """ + Perform augmentations on input image x and return augmented inference and train outputs. + + Args: + x (torch.Tensor): Input image tensor. + + Returns: + (torch.Tensor): Augmented inference output. 
+ """ + if getattr(self, "end2end", False) or self.__class__.__name__ != "DetectionModel": + LOGGER.warning("WARNING ⚠️ Model does not support 'augment=True', reverting to single-scale prediction.") + return self._predict_once(x) + img_size = x.shape[-2:] # height, width + s = [1, 0.83, 0.67] # scales + f = [None, 3, None] # flips (2-ud, 3-lr) + y = [] # outputs + for si, fi in zip(s, f): + xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max())) + yi = super().predict(xi)[0] # forward + yi = self._descale_pred(yi, fi, si, img_size) + y.append(yi) + y = self._clip_augmented(y) # clip augmented tails + return torch.cat(y, -1), None # augmented inference, train + + @staticmethod + def _descale_pred(p, flips, scale, img_size, dim=1): + """ + De-scale predictions following augmented inference (inverse operation). + + Args: + p (torch.Tensor): Predictions tensor. + flips (int): Flip type (0=none, 2=ud, 3=lr). + scale (float): Scale factor. + img_size (tuple): Original image size (height, width). + dim (int): Dimension to split at. + + Returns: + (torch.Tensor): De-scaled predictions. + """ + p[:, :4] /= scale # de-scale + x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim) + if flips == 2: + y = img_size[0] - y # de-flip ud + elif flips == 3: + x = img_size[1] - x # de-flip lr + return torch.cat((x, y, wh, cls), dim) + + def _clip_augmented(self, y): + """ + Clip YOLO augmented inference tails. + + Args: + y (List[torch.Tensor]): List of detection tensors. + + Returns: + (List[torch.Tensor]): Clipped detection tensors. 
+ """ + nl = self.model[-1].nl # number of detection layers (P3-P5) + g = sum(4**x for x in range(nl)) # grid points + e = 1 # exclude layer count + i = (y[0].shape[-1] // g) * sum(4**x for x in range(e)) # indices + y[0] = y[0][..., :-i] # large + i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices + y[-1] = y[-1][..., i:] # small + return y + + def init_criterion(self): + """Initialize the loss criterion for the DetectionModel.""" + return E2EDetectLoss(self) if getattr(self, "end2end", False) else v8DetectionLoss(self) + + +class OBBModel(DetectionModel): + """YOLO Oriented Bounding Box (OBB) model.""" + + def __init__(self, cfg="yolo11n-obb.yaml", ch=3, nc=None, verbose=True): + """ + Initialize YOLO OBB model with given config and parameters. + + Args: + cfg (str | dict): Model configuration file path or dictionary. + ch (int): Number of input channels. + nc (int, optional): Number of classes. + verbose (bool): Whether to display model information. + """ + super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) + + def init_criterion(self): + """Initialize the loss criterion for the model.""" + return v8OBBLoss(self) + + +class SegmentationModel(DetectionModel): + """YOLO segmentation model.""" + + def __init__(self, cfg="yolo11n-seg.yaml", ch=3, nc=None, verbose=True): + """ + Initialize YOLOv8 segmentation model with given config and parameters. + + Args: + cfg (str | dict): Model configuration file path or dictionary. + ch (int): Number of input channels. + nc (int, optional): Number of classes. + verbose (bool): Whether to display model information. 
+ """ + super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) + + def init_criterion(self): + """Initialize the loss criterion for the SegmentationModel.""" + return v8SegmentationLoss(self) + + +class PoseModel(DetectionModel): + """YOLO pose model.""" + + def __init__(self, cfg="yolo11n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True): + """ + Initialize YOLOv8 Pose model. + + Args: + cfg (str | dict): Model configuration file path or dictionary. + ch (int): Number of input channels. + nc (int, optional): Number of classes. + data_kpt_shape (tuple): Shape of keypoints data. + verbose (bool): Whether to display model information. + """ + if not isinstance(cfg, dict): + cfg = yaml_model_load(cfg) # load model YAML + if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]): + LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}") + cfg["kpt_shape"] = data_kpt_shape + super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) + + def init_criterion(self): + """Initialize the loss criterion for the PoseModel.""" + return v8PoseLoss(self) + + +class ClassificationModel(BaseModel): + """YOLO classification model.""" + + def __init__(self, cfg="yolo11n-cls.yaml", ch=3, nc=None, verbose=True): + """ + Initialize ClassificationModel with YAML, channels, number of classes, verbose flag. + + Args: + cfg (str | dict): Model configuration file path or dictionary. + ch (int): Number of input channels. + nc (int, optional): Number of classes. + verbose (bool): Whether to display model information. + """ + super().__init__() + self._from_yaml(cfg, ch, nc, verbose) + + def _from_yaml(self, cfg, ch, nc, verbose): + """ + Set YOLOv8 model configurations and define the model architecture. + + Args: + cfg (str | dict): Model configuration file path or dictionary. + ch (int): Number of input channels. + nc (int, optional): Number of classes. + verbose (bool): Whether to display model information. 
+ """ + self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict + + # Define model + ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels + if nc and nc != self.yaml["nc"]: + LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") + self.yaml["nc"] = nc # override YAML value + elif not nc and not self.yaml.get("nc", None): + raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.") + self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist + self.stride = torch.Tensor([1]) # no stride constraints + self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict + self.info() + + @staticmethod + def reshape_outputs(model, nc): + """ + Update a TorchVision classification model to class count 'n' if required. + + Args: + model (torch.nn.Module): Model to update. + nc (int): New number of classes. + """ + name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module + if isinstance(m, Classify): # YOLO Classify() head + if m.linear.out_features != nc: + m.linear = torch.nn.Linear(m.linear.in_features, nc) + elif isinstance(m, torch.nn.Linear): # ResNet, EfficientNet + if m.out_features != nc: + setattr(model, name, torch.nn.Linear(m.in_features, nc)) + elif isinstance(m, torch.nn.Sequential): + types = [type(x) for x in m] + if torch.nn.Linear in types: + i = len(types) - 1 - types[::-1].index(torch.nn.Linear) # last torch.nn.Linear index + if m[i].out_features != nc: + m[i] = torch.nn.Linear(m[i].in_features, nc) + elif torch.nn.Conv2d in types: + i = len(types) - 1 - types[::-1].index(torch.nn.Conv2d) # last torch.nn.Conv2d index + if m[i].out_channels != nc: + m[i] = torch.nn.Conv2d( + m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None + ) + + def init_criterion(self): + """Initialize the loss criterion for the ClassificationModel.""" + return 
v8ClassificationLoss() + + +class RTDETRDetectionModel(DetectionModel): + """ + RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class. + + This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both + the training and inference processes. RTDETR is an object detection and tracking model that extends from the + DetectionModel base class. + + Methods: + init_criterion: Initializes the criterion used for loss calculation. + loss: Computes and returns the loss during training. + predict: Performs a forward pass through the network and returns the output. + """ + + def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True): + """ + Initialize the RTDETRDetectionModel. + + Args: + cfg (str | dict): Configuration file name or path. + ch (int): Number of input channels. + nc (int, optional): Number of classes. + verbose (bool): Print additional information during initialization. + """ + super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) + + def init_criterion(self): + """Initialize the loss criterion for the RTDETRDetectionModel.""" + from ultralytics.models.utils.loss import RTDETRDetectionLoss + + return RTDETRDetectionLoss(nc=self.nc, use_vfl=True) + + def loss(self, batch, preds=None): + """ + Compute the loss for the given batch of data. + + Args: + batch (dict): Dictionary containing image and label data. + preds (torch.Tensor, optional): Precomputed model predictions. + + Returns: + (tuple): A tuple containing the total loss and main three losses in a tensor. + """ + if not hasattr(self, "criterion"): + self.criterion = self.init_criterion() + + img = batch["img"] + # NOTE: preprocess gt_bbox and gt_labels to list. 
+ bs = len(img) + batch_idx = batch["batch_idx"] + gt_groups = [(batch_idx == i).sum().item() for i in range(bs)] + targets = { + "cls": batch["cls"].to(img.device, dtype=torch.long).view(-1), + "bboxes": batch["bboxes"].to(device=img.device), + "batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1), + "gt_groups": gt_groups, + } + + preds = self.predict(img, batch=targets) if preds is None else preds + dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1] + if dn_meta is None: + dn_bboxes, dn_scores = None, None + else: + dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2) + dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2) + + dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4) + dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores]) + + loss = self.criterion( + (dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta + ) + # NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses. + return sum(loss.values()), torch.as_tensor( + [loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device + ) + + def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None): + """ + Perform a forward pass through the model. + + Args: + x (torch.Tensor): The input tensor. + profile (bool): If True, profile the computation time for each layer. + visualize (bool): If True, save feature maps for visualization. + batch (dict, optional): Ground truth data for evaluation. + augment (bool): If True, perform data augmentation during inference. + embed (list, optional): A list of feature vectors/embeddings to return. + + Returns: + (torch.Tensor): Model's output tensor. 
+ """ + y, dt, embeddings = [], [], [] # outputs + for m in self.model[:-1]: # except the head part + if m.f != -1: # if not from previous layer + x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers + if profile: + self._profile_one_layer(m, x, dt) + x = m(x) # run + y.append(x if m.i in self.save else None) # save output + if visualize: + feature_visualization(x, m.type, m.i, save_dir=visualize) + if embed and m.i in embed: + embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten + if m.i == max(embed): + return torch.unbind(torch.cat(embeddings, 1), dim=0) + head = self.model[-1] + x = head([y[j] for j in head.f], batch) # head inference + return x + + +class WorldModel(DetectionModel): + """YOLOv8 World Model.""" + + def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True): + """ + Initialize YOLOv8 world model with given config and parameters. + + Args: + cfg (str | dict): Model configuration file path or dictionary. + ch (int): Number of input channels. + nc (int, optional): Number of classes. + verbose (bool): Whether to display model information. + """ + self.txt_feats = torch.randn(1, nc or 80, 512) # features placeholder + self.clip_model = None # CLIP model placeholder + super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) + + def set_classes(self, text, batch=80, cache_clip_model=True): + """ + Set classes in advance so that model could do offline-inference without clip model. + + Args: + text (List[str]): List of class names. + batch (int): Batch size for processing text tokens. + cache_clip_model (bool): Whether to cache the CLIP model. 
+ """ + try: + import clip + except ImportError: + check_requirements("git+https://github.com/ultralytics/CLIP.git") + import clip + + if ( + not getattr(self, "clip_model", None) and cache_clip_model + ): # for backwards compatibility of models lacking clip_model attribute + self.clip_model = clip.load("ViT-B/32")[0] + model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0] + device = next(model.parameters()).device + text_token = clip.tokenize(text).to(device) + txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)] + txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0) + txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True) + self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]) + self.model[-1].nc = len(text) + + def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None): + """ + Perform a forward pass through the model. + + Args: + x (torch.Tensor): The input tensor. + profile (bool): If True, profile the computation time for each layer. + visualize (bool): If True, save feature maps for visualization. + txt_feats (torch.Tensor, optional): The text features, use it if it's given. + augment (bool): If True, perform data augmentation during inference. + embed (list, optional): A list of feature vectors/embeddings to return. + + Returns: + (torch.Tensor): Model's output tensor. 
+ """ + txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype) + if len(txt_feats) != len(x) or self.model[-1].export: + txt_feats = txt_feats.expand(x.shape[0], -1, -1) + ori_txt_feats = txt_feats.clone() + y, dt, embeddings = [], [], [] # outputs + for m in self.model: # except the head part + if m.f != -1: # if not from previous layer + x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers + if profile: + self._profile_one_layer(m, x, dt) + if isinstance(m, C2fAttn): + x = m(x, txt_feats) + elif isinstance(m, WorldDetect): + x = m(x, ori_txt_feats) + elif isinstance(m, ImagePoolingAttn): + txt_feats = m(x, txt_feats) + else: + x = m(x) # run + + y.append(x if m.i in self.save else None) # save output + if visualize: + feature_visualization(x, m.type, m.i, save_dir=visualize) + if embed and m.i in embed: + embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten + if m.i == max(embed): + return torch.unbind(torch.cat(embeddings, 1), dim=0) + return x + + def loss(self, batch, preds=None): + """ + Compute loss. + + Args: + batch (dict): Batch to compute loss on. + preds (torch.Tensor | List[torch.Tensor], optional): Predictions. + """ + if not hasattr(self, "criterion"): + self.criterion = self.init_criterion() + + if preds is None: + preds = self.forward(batch["img"], txt_feats=batch["txt_feats"]) + return self.criterion(preds, batch) + + +class Ensemble(torch.nn.ModuleList): + """Ensemble of models.""" + + def __init__(self): + """Initialize an ensemble of models.""" + super().__init__() + + def forward(self, x, augment=False, profile=False, visualize=False): + """ + Generate the YOLO network's final layer. + + Args: + x (torch.Tensor): Input tensor. + augment (bool): Whether to augment the input. + profile (bool): Whether to profile the model. + visualize (bool): Whether to visualize the features. 
+ + Returns: + (tuple): Tuple containing the concatenated predictions and None. + """ + y = [module(x, augment, profile, visualize)[0] for module in self] + # y = torch.stack(y).max(0)[0] # max ensemble + # y = torch.stack(y).mean(0) # mean ensemble + y = torch.cat(y, 2) # nms ensemble, y shape(B, HW, C) + return y, None # inference, train output + + +# Functions ------------------------------------------------------------------------------------------------------------ + + +@contextlib.contextmanager +def temporary_modules(modules=None, attributes=None): + """ + Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`). + + This function can be used to change the module paths during runtime. It's useful when refactoring code, + where you've moved a module from one location to another, but you still want to support the old import + paths for backwards compatibility. + + Args: + modules (dict, optional): A dictionary mapping old module paths to new module paths. + attributes (dict, optional): A dictionary mapping old module attributes to new module attributes. + + Examples: + >>> with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}): + >>> import old.module # this will now import new.module + >>> from old.module import attribute # this will now import new.module.attribute + + Note: + The changes are only in effect inside the context manager and are undone once the context manager exits. + Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger + applications or libraries. Use this function with caution. 
+ """ + if modules is None: + modules = {} + if attributes is None: + attributes = {} + import sys + from importlib import import_module + + try: + # Set attributes in sys.modules under their old name + for old, new in attributes.items(): + old_module, old_attr = old.rsplit(".", 1) + new_module, new_attr = new.rsplit(".", 1) + setattr(import_module(old_module), old_attr, getattr(import_module(new_module), new_attr)) + + # Set modules in sys.modules under their old name + for old, new in modules.items(): + sys.modules[old] = import_module(new) + + yield + finally: + # Remove the temporary module paths + for old in modules: + if old in sys.modules: + del sys.modules[old] + + +class SafeClass: + """A placeholder class to replace unknown classes during unpickling.""" + + def __init__(self, *args, **kwargs): + """Initialize SafeClass instance, ignoring all arguments.""" + pass + + def __call__(self, *args, **kwargs): + """Run SafeClass instance, ignoring all arguments.""" + pass + + +class SafeUnpickler(pickle.Unpickler): + """Custom Unpickler that replaces unknown classes with SafeClass.""" + + def find_class(self, module, name): + """ + Attempt to find a class, returning SafeClass if not among safe modules. + + Args: + module (str): Module name. + name (str): Class name. + + Returns: + (type): Found class or SafeClass. + """ + safe_modules = ( + "torch", + "collections", + "collections.abc", + "builtins", + "math", + "numpy", + # Add other modules considered safe + ) + if module in safe_modules: + return super().find_class(module, name) + else: + return SafeClass + + +def torch_safe_load(weight, safe_only=False): + """ + Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the + error, logs a warning message, and attempts to install the missing module via the check_requirements() function. + After installation, the function again attempts to load the model using torch.load(). 
def torch_safe_load(weight, safe_only=False):
    """
    Load a PyTorch checkpoint, transparently handling legacy module paths and missing dependencies.

    Attempts torch.load() first; on a ModuleNotFoundError the missing package is installed via
    check_requirements() and loading is retried once.

    Args:
        weight (str): The file path of the PyTorch model.
        safe_only (bool): If True, replace unknown classes with SafeClass during loading.

    Returns:
        ckpt (dict): The loaded model checkpoint.
        file (str): The loaded filename.

    Examples:
        >>> from ultralytics.nn.tasks import torch_safe_load
        >>> ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
    """
    from ultralytics.utils.downloads import attempt_download_asset

    check_suffix(file=weight, suffix=".pt")
    file = attempt_download_asset(weight)  # search online if missing locally

    # Legacy import paths remapped so old checkpoints unpickle against the current package layout
    legacy_modules = {
        "ultralytics.yolo.utils": "ultralytics.utils",
        "ultralytics.yolo.v8": "ultralytics.models.yolo",
        "ultralytics.yolo.data": "ultralytics.data",
    }
    legacy_attributes = {
        "ultralytics.nn.modules.block.Silence": "torch.nn.Identity",  # YOLOv9e
        "ultralytics.nn.tasks.YOLOv10DetectionModel": "ultralytics.nn.tasks.DetectionModel",  # YOLOv10
        "ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss",  # YOLOv10
    }
    try:
        with temporary_modules(modules=legacy_modules, attributes=legacy_attributes):
            if safe_only:
                # Route unpickling through SafeUnpickler so unknown classes become inert SafeClass instances
                safe_pickle = types.ModuleType("safe_pickle")
                safe_pickle.Unpickler = SafeUnpickler
                safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()
                with open(file, "rb") as fh:
                    ckpt = torch.load(fh, pickle_module=safe_pickle)
            else:
                ckpt = torch.load(file, map_location="cpu")

    except ModuleNotFoundError as err:  # err.name is the missing module name
        if err.name == "models":
            # YOLOv5-repo checkpoints cannot be loaded here at all
            raise TypeError(
                emojis(
                    f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained "
                    f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
                    f"YOLOv8 at https://github.com/ultralytics/ultralytics."
                    f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
                    f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'"
                )
            ) from err
        LOGGER.warning(
            f"WARNING ⚠️ {weight} appears to require '{err.name}', which is not in Ultralytics requirements."
            f"\nAutoInstall will run now for '{err.name}' but this feature will be removed in the future."
            f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
            f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolo11n.pt'"
        )
        check_requirements(err.name)  # install missing module
        ckpt = torch.load(file, map_location="cpu")

    if not isinstance(ckpt, dict):
        # File is likely a YOLO instance saved with i.e. torch.save(model, "saved_model.pt")
        LOGGER.warning(
            f"WARNING ⚠️ The file '{weight}' appears to be improperly saved or formatted. "
            f"For optimal results, use model.save('filename.pt') to correctly save YOLO models."
        )
        ckpt = {"model": ckpt.model}

    return ckpt, file
+ """ + ensemble = Ensemble() + for w in weights if isinstance(weights, list) else [weights]: + ckpt, w = torch_safe_load(w) # load ckpt + args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None # combined args + model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model + + # Model compatibility updates + model.args = args # attach args to model + model.pt_path = w # attach *.pt file path to model + model.task = guess_model_task(model) + if not hasattr(model, "stride"): + model.stride = torch.tensor([32.0]) + + # Append + ensemble.append(model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()) # model in eval mode + + # Module updates + for m in ensemble.modules(): + if hasattr(m, "inplace"): + m.inplace = inplace + elif isinstance(m, torch.nn.Upsample) and not hasattr(m, "recompute_scale_factor"): + m.recompute_scale_factor = None # torch 1.11.0 compatibility + + # Return model + if len(ensemble) == 1: + return ensemble[-1] + + # Return ensemble + LOGGER.info(f"Ensemble created with {weights}\n") + for k in "names", "nc", "yaml": + setattr(ensemble, k, getattr(ensemble[0], k)) + ensemble.stride = ensemble[int(torch.argmax(torch.tensor([m.stride.max() for m in ensemble])))].stride + assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}" + return ensemble + + +def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False): + """ + Load a single model weights. + + Args: + weight (str): Model weight path. + device (torch.device, optional): Device to load model to. + inplace (bool): Whether to do inplace operations. + fuse (bool): Whether to fuse model. + + Returns: + (tuple): Tuple containing the model and checkpoint. 
+ """ + ckpt, weight = torch_safe_load(weight) # load ckpt + args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))} # combine model and default args, preferring model args + model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model + + # Model compatibility updates + model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model + model.pt_path = weight # attach *.pt file path to model + model.task = guess_model_task(model) + if not hasattr(model, "stride"): + model.stride = torch.tensor([32.0]) + + model = model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval() # model in eval mode + + # Module updates + for m in model.modules(): + if hasattr(m, "inplace"): + m.inplace = inplace + elif isinstance(m, torch.nn.Upsample) and not hasattr(m, "recompute_scale_factor"): + m.recompute_scale_factor = None # torch 1.11.0 compatibility + + # Return model and ckpt + return model, ckpt + + +def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) + """ + Parse a YOLO model.yaml dictionary into a PyTorch model. + + Args: + d (dict): Model dictionary. + ch (int): Input channels. + verbose (bool): Whether to print model details. + + Returns: + (tuple): Tuple containing the PyTorch model and sorted list of output layers. + """ + import ast + + # Args + legacy = True # backward compatibility for v3/v5/v8/v9 models + max_channels = float("inf") + nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales")) + depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape")) + if scales: + scale = d.get("scale") + if not scale: + scale = tuple(scales.keys())[0] + LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.") + depth, width, max_channels = scales[scale] + + if act: + Conv.default_act = eval(act) # redefine default activation, i.e. 
def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
    """
    Parse a YOLO model.yaml dictionary into a PyTorch model.

    Args:
        d (dict): Model dictionary.
        ch (int): Input channels.
        verbose (bool): Whether to print model details.

    Returns:
        (tuple): Tuple containing the PyTorch model and sorted list of output layers.
    """
    import ast

    # Args
    legacy = True  # backward compatibility for v3/v5/v8/v9 models
    max_channels = float("inf")
    nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
    if scales:
        scale = d.get("scale")
        if not scale:
            # No explicit scale given: fall back to the first declared scale
            scale = tuple(scales.keys())[0]
            LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
        depth, width, max_channels = scales[scale]

    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = torch.nn.SiLU()
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")  # print

    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    # Modules whose first two args are (c1, c2) input/output channels
    base_modules = frozenset(
        {
            Classify,
            Conv,
            ConvTranspose,
            GhostConv,
            Bottleneck,
            GhostBottleneck,
            SPP,
            SPPF,
            C2fPSA,
            C2PSA,
            DWConv,
            Focus,
            BottleneckCSP,
            C1,
            C2,
            C2f,
            C3k2,
            RepNCSPELAN4,
            ELAN1,
            ADown,
            AConv,
            SPPELAN,
            C2fAttn,
            C3,
            C3TR,
            C3Ghost,
            torch.nn.ConvTranspose2d,
            DWConvTranspose2d,
            C3x,
            RepC3,
            PSA,
            SCDown,
            C2fCIB,
            A2C2f,
        }
    )
    repeat_modules = frozenset(  # modules with 'repeat' arguments
        {
            BottleneckCSP,
            C1,
            C2,
            C2f,
            C3k2,
            C2fAttn,
            C3,
            C3TR,
            C3Ghost,
            C3x,
            RepC3,
            C2fPSA,
            C2fCIB,
            C2PSA,
            A2C2f,
        }
    )
    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
        # Resolve module class from torch.nn, torchvision.ops, or this module's globals
        m = (
            getattr(torch.nn, m[3:])
            if "nn." in m
            else getattr(__import__("torchvision").ops, m[16:])
            if "torchvision.ops." in m
            else globals()[m]
        )  # get module
        for j, a in enumerate(args):
            if isinstance(a, str):
                # String args may name a local (e.g. 'nc') or be a literal to evaluate
                with contextlib.suppress(ValueError):
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in base_modules:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            if m is C2fAttn:  # set 1) embed channels and 2) num heads
                args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
                args[2] = int(max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2])

            args = [c1, c2, *args[1:]]
            if m in repeat_modules:
                args.insert(2, n)  # number of repeats
                n = 1
            if m is C3k2:  # for M/L/X sizes
                legacy = False
                if scale in "mlx":
                    args[3] = True
            if m is A2C2f:
                legacy = False
                if scale in "lx":  # for L/X sizes
                    args.extend((True, 1.2))
        elif m is AIFI:
            args = [ch[f], *args]
        elif m in frozenset({HGStem, HGBlock}):
            c1, cm, c2 = ch[f], args[0], args[1]
            args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:
                args.insert(4, n)  # number of repeats
                n = 1
        elif m is ResNetLayer:
            c2 = args[1] if args[3] else args[1] * 4
        elif m is torch.nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in frozenset({Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}):
            args.append([ch[x] for x in f])
            if m is Segment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
            if m in {Detect, Segment, Pose, OBB}:
                m.legacy = legacy
        elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
            args.insert(1, [ch[x] for x in f])
        elif m is CBLinear:
            c2 = args[0]
            c1 = ch[f]
            args = [c1, c2, *args[1:]]
        elif m is CBFuse:
            c2 = ch[f[-1]]
        elif m in frozenset({TorchVision, Index}):
            c2 = args[0]
            c1 = ch[f]
            args = [*args[1:]]
        else:
            c2 = ch[f]

        m_ = torch.nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace("__main__.", "")  # module type
        m_.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        if verbose:
            LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f} {t:<45}{str(args):<30}")  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []  # reset: input channels list starts from the first layer's output
        ch.append(c2)
    return torch.nn.Sequential(*layers), sorted(save)


def yaml_model_load(path):
    """
    Load a YOLOv8 model from a YAML file.

    Args:
        path (str | Path): Path to the YAML file.

    Returns:
        (dict): Model dictionary.
    """
    path = Path(path)
    if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
        # Legacy P6 naming (e.g. yolov5n6) is rewritten to the -p6 suffix form
        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
        LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
        path = path.with_name(new_stem + path.suffix)

    unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    d = yaml_load(yaml_file)  # model dict
    d["scale"] = guess_model_scale(path)
    d["yaml_file"] = str(path)
    return d


def guess_model_scale(model_path):
    """
    Extract the size character n, s, m, l, or x of the model's scale from the model path.

    Args:
        model_path (str | Path): The path to the YOLO model's YAML file.

    Returns:
        (str): The size character of the model's scale (n, s, m, l, or x), or "" when no scale is present.
    """
    try:
        return re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem).group(1)  # returns n, s, m, l, or x
    except AttributeError:
        return ""
+ """ + + def cfg2task(cfg): + """Guess from YAML dictionary.""" + m = cfg["head"][-1][-2].lower() # output module name + if m in {"classify", "classifier", "cls", "fc"}: + return "classify" + if "detect" in m: + return "detect" + if m == "segment": + return "segment" + if m == "pose": + return "pose" + if m == "obb": + return "obb" + + # Guess from model cfg + if isinstance(model, dict): + with contextlib.suppress(Exception): + return cfg2task(model) + # Guess from PyTorch model + if isinstance(model, torch.nn.Module): # PyTorch model + for x in "model.args", "model.model.args", "model.model.model.args": + with contextlib.suppress(Exception): + return eval(x)["task"] + for x in "model.yaml", "model.model.yaml", "model.model.model.yaml": + with contextlib.suppress(Exception): + return cfg2task(eval(x)) + for m in model.modules(): + if isinstance(m, Segment): + return "segment" + elif isinstance(m, Classify): + return "classify" + elif isinstance(m, Pose): + return "pose" + elif isinstance(m, OBB): + return "obb" + elif isinstance(m, (Detect, WorldDetect, v10Detect)): + return "detect" + + # Guess from model filename + if isinstance(model, (str, Path)): + model = Path(model) + if "-seg" in model.stem or "segment" in model.parts: + return "segment" + elif "-cls" in model.stem or "classify" in model.parts: + return "classify" + elif "-pose" in model.stem or "pose" in model.parts: + return "pose" + elif "-obb" in model.stem or "obb" in model.parts: + return "obb" + elif "detect" in model.parts: + return "detect" + + # Unable to determine task from model + LOGGER.warning( + "WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. " + "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'." 
+ ) + return "detect" # assume detect diff --git a/tracking/ultralytics/solutions/__init__.py b/tracking/ultralytics/solutions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..96f4b677ae5808931a791a0089d2c652b6cdec34 --- /dev/null +++ b/tracking/ultralytics/solutions/__init__.py @@ -0,0 +1,38 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .ai_gym import AIGym +from .analytics import Analytics +from .distance_calculation import DistanceCalculation +from .heatmap import Heatmap +from .instance_segmentation import InstanceSegmentation +from .object_blurrer import ObjectBlurrer +from .object_counter import ObjectCounter +from .object_cropper import ObjectCropper +from .parking_management import ParkingManagement, ParkingPtsSelection +from .queue_management import QueueManager +from .region_counter import RegionCounter +from .security_alarm import SecurityAlarm +from .speed_estimation import SpeedEstimator +from .streamlit_inference import Inference +from .trackzone import TrackZone +from .vision_eye import VisionEye + +__all__ = ( + "ObjectCounter", + "ObjectCropper", + "ObjectBlurrer", + "AIGym", + "RegionCounter", + "SecurityAlarm", + "Heatmap", + "InstanceSegmentation", + "VisionEye", + "SpeedEstimator", + "DistanceCalculation", + "QueueManager", + "ParkingManagement", + "ParkingPtsSelection", + "Analytics", + "Inference", + "TrackZone", +) diff --git a/tracking/ultralytics/solutions/ai_gym.py b/tracking/ultralytics/solutions/ai_gym.py new file mode 100644 index 0000000000000000000000000000000000000000..981895df4d1c9ad28013967b9a53cd63a3520324 --- /dev/null +++ b/tracking/ultralytics/solutions/ai_gym.py @@ -0,0 +1,122 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults + + +class AIGym(BaseSolution): + """ + A class to manage gym steps of people in a real-time video stream based on their 
class AIGym(BaseSolution):
    """
    A class to manage gym steps of people in a real-time video stream based on their poses.

    This class extends BaseSolution to monitor workouts using YOLO pose estimation models. It tracks and counts
    repetitions of exercises based on predefined angle thresholds for up and down positions.

    Attributes:
        count (List[int]): Repetition counts for each detected person.
        angle (List[float]): Current angle of the tracked body part for each person.
        stage (List[str]): Current exercise stage ('up', 'down', or '-') for each person.
        initial_stage (str | None): Initial stage of the exercise.
        up_angle (float): Angle threshold for considering the 'up' position of an exercise.
        down_angle (float): Angle threshold for considering the 'down' position of an exercise.
        kpts (List[int]): Indices of keypoints used for angle calculation.

    Methods:
        process: Processes a frame to detect poses, calculate angles, and count repetitions.

    Examples:
        >>> gym = AIGym(model="yolo11n-pose.pt")
        >>> image = cv2.imread("gym_scene.jpg")
        >>> results = gym.process(image)
        >>> processed_image = results.plot_im
        >>> cv2.imshow("Processed Image", processed_image)
        >>> cv2.waitKey(0)
    """

    def __init__(self, **kwargs):
        """
        Initialize AIGym for workout monitoring using pose estimation and predefined angles.

        Args:
            **kwargs (Any): Keyword arguments passed to the parent class constructor.
                model (str): Model name or path, defaults to "yolo11n-pose.pt".
        """
        kwargs["model"] = kwargs.get("model", "yolo11n-pose.pt")
        super().__init__(**kwargs)

        # Per-person state; lists grow as new people appear in the frame
        self.count = []  # repetition count per tracked person
        self.angle = []  # current joint angle per tracked person
        self.stage = []  # exercise stage per tracked person ('up', 'down', or '-')

        # Extract details from CFG once for later use
        self.initial_stage = None
        self.up_angle = float(self.CFG["up_angle"])  # angle above which the pose counts as 'up'
        self.down_angle = float(self.CFG["down_angle"])  # angle below which the pose counts as 'down'
        self.kpts = self.CFG["kpts"]  # user-selected keypoint indices for the monitored exercise

    def process(self, im0):
        """
        Monitor workouts using Ultralytics YOLO Pose Model.

        This function processes an input image to track and analyze human poses for workout monitoring. It uses
        the YOLO Pose model to detect keypoints, estimate angles, and count repetitions based on predefined
        angle thresholds.

        Args:
            im0 (np.ndarray): Input image for processing.

        Returns:
            (SolutionResults): Contains processed image `plot_im`,
                'workout_count' (list of completed reps),
                'workout_stage' (list of current stages),
                'workout_angle' (list of angles), and
                'total_tracks' (total number of tracked individuals).

        Examples:
            >>> gym = AIGym()
            >>> image = cv2.imread("workout.jpg")
            >>> results = gym.process(image)
            >>> processed_image = results.plot_im
        """
        annotator = SolutionAnnotator(im0, line_width=self.line_width)  # initialize annotator

        self.extract_tracks(im0)  # extract tracks (bounding boxes, classes, and masks)
        tracks = self.tracks[0]

        if tracks.boxes.id is not None:
            # Grow per-person state lists when new people enter the frame
            num_new = len(tracks) - len(self.count)
            if num_new > 0:
                self.angle.extend([0] * num_new)
                self.count.extend([0] * num_new)
                self.stage.extend(["-"] * num_new)

            # Enumerate over keypoints
            for idx, kpt_data in enumerate(reversed(tracks.keypoints.data)):
                # Estimate the joint angle from the three configured keypoints
                joints = [kpt_data[int(self.kpts[i])].cpu() for i in range(3)]
                self.angle[idx] = annotator.estimate_pose_angle(*joints)
                annotator.draw_specific_kpts(kpt_data, self.kpts, radius=self.line_width * 3)

                # Stage/count state machine: a full 'up' -> 'down' transition counts one repetition
                if self.angle[idx] < self.down_angle:
                    if self.stage[idx] == "up":
                        self.count[idx] += 1
                    self.stage[idx] = "down"
                elif self.angle[idx] > self.up_angle:
                    self.stage[idx] = "up"

                # Display angle, count, and stage text
                annotator.plot_angle_and_count_and_stage(
                    angle_text=self.angle[idx],  # angle text for display
                    count_text=self.count[idx],  # count text for workouts
                    stage_text=self.stage[idx],  # stage position text
                    center_kpt=kpt_data[int(self.kpts[1])],  # center keypoint for display
                )
        plot_im = annotator.result()
        self.display_output(plot_im)  # display output image, if environment supports display

        # Return SolutionResults
        return SolutionResults(
            plot_im=plot_im,
            workout_count=self.count,
            workout_stage=self.stage,
            workout_angle=self.angle,
            total_tracks=len(self.track_ids),
        )
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from itertools import cycle

import cv2
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure

from ultralytics.solutions.solutions import BaseSolution, SolutionResults  # Import a parent class


class Analytics(BaseSolution):
    """
    A class for creating and updating various types of charts for visual analytics.

    This class extends BaseSolution to provide functionality for generating line, bar, pie, and area charts
    based on object detection and tracking data.

    Attributes:
        type (str): The type of analytics chart to generate ('line', 'bar', 'pie', or 'area').
        x_label (str): Label for the x-axis.
        y_label (str): Label for the y-axis.
        bg_color (str): Background color of the chart frame.
        fg_color (str): Foreground color of the chart frame.
        title (str): Title of the chart window.
        max_points (int): Maximum number of data points to display on the chart.
        fontsize (int): Font size for text display.
        color_cycle (cycle): Cyclic iterator for chart colors.
        total_counts (int): Total count of detected objects (used for line charts).
        clswise_count (Dict[str, int]): Dictionary for class-wise object counts.
        fig (Figure): Matplotlib figure object for the chart.
        ax (Axes): Matplotlib axes object for the chart.
        canvas (FigureCanvas): Canvas for rendering the chart.
        lines (dict): Dictionary to store line objects for area charts.
        color_mapping (Dict[str, str]): Dictionary mapping class labels to colors for consistent visualization.

    Methods:
        process: Process image data and update the chart.
        update_graph: Update the chart with new data points.

    Examples:
        >>> analytics = Analytics(analytics_type="line")
        >>> frame = cv2.imread("image.jpg")
        >>> results = analytics.process(frame, frame_number=1)
        >>> cv2.imshow("Analytics", results.plot_im)
    """

    def __init__(self, **kwargs):
        """Initialize Analytics class with various chart types for visual data representation."""
        super().__init__(**kwargs)

        self.type = self.CFG["analytics_type"]  # extract type of analytics
        self.x_label = "Classes" if self.type in {"bar", "pie"} else "Frame#"
        self.y_label = "Total Counts"

        # Predefined data
        self.bg_color = "#F3F3F3"  # background color of frame
        self.fg_color = "#111E68"  # foreground color of frame
        self.title = "Ultralytics Solutions"  # window name
        self.max_points = 45  # maximum points to be drawn on window
        self.fontsize = 25  # text font size for display
        figsize = (12.8, 7.2)  # Set output image size 1280 * 720
        self.color_cycle = cycle(["#DD00BA", "#042AFF", "#FF4447", "#7D24FF", "#BD00FF"])

        self.total_counts = 0  # count variable for storing total counts i.e. for line
        self.clswise_count = {}  # dictionary for class-wise counts

        # Ensure line and area chart
        if self.type in {"line", "area"}:
            self.lines = {}
            self.fig = Figure(facecolor=self.bg_color, figsize=figsize)
            self.canvas = FigureCanvas(self.fig)  # Set common axis properties
            self.ax = self.fig.add_subplot(111, facecolor=self.bg_color)
            if self.type == "line":
                (self.line,) = self.ax.plot([], [], color="cyan", linewidth=self.line_width)
        elif self.type in {"bar", "pie"}:
            # Initialize bar or pie plot
            self.fig, self.ax = plt.subplots(figsize=figsize, facecolor=self.bg_color)
            self.canvas = FigureCanvas(self.fig)  # Set common axis properties
            self.ax.set_facecolor(self.bg_color)
            self.color_mapping = {}

            if self.type == "pie":  # Ensure pie chart is circular
                self.ax.axis("equal")

    def process(self, im0, frame_number):
        """
        Process image data and run object tracking to update analytics charts.

        Args:
            im0 (np.ndarray): Input image for processing.
            frame_number (int): Video frame number for plotting the data.

        Returns:
            (SolutionResults): Contains processed image `plot_im`, 'total_tracks' (int, total number of tracked objects)
                and 'classwise_count' (dict, per-class object count).

        Raises:
            ModuleNotFoundError: If an unsupported chart type is specified.

        Examples:
            >>> analytics = Analytics(analytics_type="line")
            >>> frame = np.zeros((480, 640, 3), dtype=np.uint8)
            >>> results = analytics.process(frame, frame_number=1)
        """
        self.extract_tracks(im0)  # Extract tracks
        if self.type == "line":
            # Line chart plots one point per frame: the total number of detections
            for _ in self.boxes:
                self.total_counts += 1
            plot_im = self.update_graph(frame_number=frame_number)
            self.total_counts = 0  # reset per-frame total after plotting
        elif self.type in {"pie", "bar", "area"}:
            # These chart types plot per-class counts for the current frame
            self.clswise_count = {}
            for cls in self.clss:
                if self.names[int(cls)] in self.clswise_count:
                    self.clswise_count[self.names[int(cls)]] += 1
                else:
                    self.clswise_count[self.names[int(cls)]] = 1
            plot_im = self.update_graph(frame_number=frame_number, count_dict=self.clswise_count, plot=self.type)
        else:
            raise ModuleNotFoundError(f"{self.type} chart is not supported ❌")

        # return output dictionary with summary for more usage
        return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids), classwise_count=self.clswise_count)

    def update_graph(self, frame_number, count_dict=None, plot="line"):
        """
        Update the graph with new data for single or multiple classes.

        Args:
            frame_number (int): The current frame number.
            count_dict (Dict[str, int] | None): Dictionary with class names as keys and counts as values for multiple
                classes. If None, updates a single line graph.
            plot (str): Type of the plot. Options are 'line', 'bar', 'pie', or 'area'.

        Returns:
            (np.ndarray): Updated image containing the graph.

        Examples:
            >>> analytics = Analytics(analytics_type="bar")
            >>> frame_num = 10
            >>> results_dict = {"person": 5, "car": 3}
            >>> updated_image = analytics.update_graph(frame_num, results_dict, plot="bar")
        """
        if count_dict is None:
            # Single line update
            x_data = np.append(self.line.get_xdata(), float(frame_number))
            y_data = np.append(self.line.get_ydata(), float(self.total_counts))

            # Keep only the most recent max_points samples
            if len(x_data) > self.max_points:
                x_data, y_data = x_data[-self.max_points :], y_data[-self.max_points :]

            self.line.set_data(x_data, y_data)
            self.line.set_label("Counts")
            self.line.set_color("#7b0068")  # Pink color
            self.line.set_marker("*")
            self.line.set_markersize(self.line_width * 5)
        else:
            labels = list(count_dict.keys())
            counts = list(count_dict.values())
            if plot == "area":
                color_cycle = cycle(["#DD00BA", "#042AFF", "#FF4447", "#7D24FF", "#BD00FF"])
                # Multiple lines or area update: recover history from the existing axis lines
                x_data = self.ax.lines[0].get_xdata() if self.ax.lines else np.array([])
                y_data_dict = {key: np.array([]) for key in count_dict.keys()}
                if self.ax.lines:
                    for line, key in zip(self.ax.lines, count_dict.keys()):
                        y_data_dict[key] = line.get_ydata()

                x_data = np.append(x_data, float(frame_number))
                max_length = len(x_data)
                for key in count_dict.keys():
                    y_data_dict[key] = np.append(y_data_dict[key], float(count_dict[key]))
                    # Pad shorter series (classes that appeared later) with zeros to match x_data
                    if len(y_data_dict[key]) < max_length:
                        y_data_dict[key] = np.pad(y_data_dict[key], (0, max_length - len(y_data_dict[key])))
                if len(x_data) > self.max_points:
                    # Drop the oldest sample once the window is full
                    x_data = x_data[1:]
                    for key in count_dict.keys():
                        y_data_dict[key] = y_data_dict[key][1:]

                self.ax.clear()
                for key, y_data in y_data_dict.items():
                    color = next(color_cycle)
                    self.ax.fill_between(x_data, y_data, color=color, alpha=0.7)
                    self.ax.plot(
                        x_data,
                        y_data,
                        color=color,
                        linewidth=self.line_width,
                        marker="o",
                        markersize=self.line_width * 5,
                        label=f"{key} Data Points",
                    )
            if plot == "bar":
                self.ax.clear()  # clear bar data
                for label in labels:  # Map labels to colors
                    if label not in self.color_mapping:
                        self.color_mapping[label] = next(self.color_cycle)
                colors = [self.color_mapping[label] for label in labels]
                bars = self.ax.bar(labels, counts, color=colors)
                for bar, count in zip(bars, counts):
                    self.ax.text(
                        bar.get_x() + bar.get_width() / 2,
                        bar.get_height(),
                        str(count),
                        ha="center",
                        va="bottom",
                        color=self.fg_color,
                    )
                # Create the legend using labels from the bars
                for bar, label in zip(bars, labels):
                    bar.set_label(label)  # Assign label to each bar
                self.ax.legend(loc="upper left", fontsize=13, facecolor=self.fg_color, edgecolor=self.fg_color)
            if plot == "pie":
                total = sum(counts)
                percentages = [size / total * 100 for size in counts]
                start_angle = 90
                self.ax.clear()

                # Create pie chart and create legend labels with percentages
                wedges, _ = self.ax.pie(
                    counts, labels=labels, startangle=start_angle, textprops={"color": self.fg_color}, autopct=None
                )
                legend_labels = [f"{label} ({percentage:.1f}%)" for label, percentage in zip(labels, percentages)]

                # Assign the legend using the wedges and manually created labels
                self.ax.legend(wedges, legend_labels, title="Classes", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1))
                self.fig.subplots_adjust(left=0.1, right=0.75)  # Adjust layout to fit the legend

        # Common plot settings
        self.ax.set_facecolor("#f0f0f0")  # Set to light gray or any other color you like
        self.ax.set_title(self.title, color=self.fg_color, fontsize=self.fontsize)
        self.ax.set_xlabel(self.x_label, color=self.fg_color, fontsize=self.fontsize - 3)
        self.ax.set_ylabel(self.y_label, color=self.fg_color, fontsize=self.fontsize - 3)

        # Add and format legend
        legend = self.ax.legend(loc="upper left", fontsize=13, facecolor=self.bg_color, edgecolor=self.bg_color)
        for text in legend.get_texts():
            text.set_color(self.fg_color)

        # Redraw graph, update view, capture, and display the updated plot
        self.ax.relim()
        self.ax.autoscale_view()
        self.canvas.draw()
        im0 = np.array(self.canvas.renderer.buffer_rgba())
        # NOTE(review): the buffer is sliced to 3 channels before the RGBA2BGR conversion — confirm intended
        im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR)
        self.display_output(im0)

        return im0  # Return the image
+ + Examples: + >>> distance_calc = DistanceCalculation() + >>> frame = cv2.imread("frame.jpg") + >>> results = distance_calc.process(frame) + >>> cv2.imshow("Distance Calculation", results.plot_im) + >>> cv2.waitKey(0) + """ + + def __init__(self, **kwargs): + """Initializes the DistanceCalculation class for measuring object distances in video streams.""" + super().__init__(**kwargs) + + # Mouse event information + self.left_mouse_count = 0 + self.selected_boxes = {} + self.centroids = [] # Store centroids of selected objects + + def mouse_event_for_distance(self, event, x, y, flags, param): + """ + Handles mouse events to select regions in a real-time video stream for distance calculation. + + Args: + event (int): Type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN). + x (int): X-coordinate of the mouse pointer. + y (int): Y-coordinate of the mouse pointer. + flags (int): Flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY, cv2.EVENT_FLAG_SHIFTKEY). + param (Any): Additional parameters passed to the function. + + Examples: + >>> # Assuming 'dc' is an instance of DistanceCalculation + >>> cv2.setMouseCallback("window_name", dc.mouse_event_for_distance) + """ + if event == cv2.EVENT_LBUTTONDOWN: + self.left_mouse_count += 1 + if self.left_mouse_count <= 2: + for box, track_id in zip(self.boxes, self.track_ids): + if box[0] < x < box[2] and box[1] < y < box[3] and track_id not in self.selected_boxes: + self.selected_boxes[track_id] = box + + elif event == cv2.EVENT_RBUTTONDOWN: + self.selected_boxes = {} + self.left_mouse_count = 0 + + def process(self, im0): + """ + Processes a video frame and calculates the distance between two selected bounding boxes. + + This method extracts tracks from the input frame, annotates bounding boxes, and calculates the distance + between two user-selected objects if they have been chosen. + + Args: + im0 (numpy.ndarray): The input image frame to process. 
+ + Returns: + (SolutionResults): Contains processed image `plot_im`, `total_tracks` (int) representing the total number + of tracked objects, and `pixels_distance` (float) representing the distance between selected objects + in pixels. + + Examples: + >>> import numpy as np + >>> from ultralytics.solutions import DistanceCalculation + >>> dc = DistanceCalculation() + >>> frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) + >>> results = dc.process(frame) + >>> print(f"Distance: {results.pixels_distance:.2f} pixels") + """ + self.extract_tracks(im0) # Extract tracks + annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator + + pixels_distance = 0 + # Iterate over bounding boxes, track ids and classes index + for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss): + annotator.box_label(box, color=colors(int(cls), True), label=self.names[int(cls)]) + + # Update selected boxes if they're being tracked + if len(self.selected_boxes) == 2: + for trk_id in self.selected_boxes.keys(): + if trk_id == track_id: + self.selected_boxes[track_id] = box + + if len(self.selected_boxes) == 2: + # Calculate centroids of selected boxes + self.centroids.extend( + [[int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2)] for box in self.selected_boxes.values()] + ) + # Calculate Euclidean distance between centroids + pixels_distance = math.sqrt( + (self.centroids[0][0] - self.centroids[1][0]) ** 2 + (self.centroids[0][1] - self.centroids[1][1]) ** 2 + ) + annotator.plot_distance_and_line(pixels_distance, self.centroids) + + self.centroids = [] # Reset centroids for next frame + plot_im = annotator.result() + self.display_output(plot_im) # Display output with base class function + cv2.setMouseCallback("Ultralytics Solutions", self.mouse_event_for_distance) + + # Return SolutionResults with processed image and calculated metrics + return SolutionResults(plot_im=plot_im, pixels_distance=pixels_distance, 
total_tracks=len(self.track_ids)) diff --git a/tracking/ultralytics/solutions/heatmap.py b/tracking/ultralytics/solutions/heatmap.py new file mode 100644 index 0000000000000000000000000000000000000000..f005c2da2072f44f98f3980b2968be3d2fa654ea --- /dev/null +++ b/tracking/ultralytics/solutions/heatmap.py @@ -0,0 +1,129 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import cv2 +import numpy as np + +from ultralytics.solutions.object_counter import ObjectCounter +from ultralytics.solutions.solutions import SolutionAnnotator, SolutionResults + + +class Heatmap(ObjectCounter): + """ + A class to draw heatmaps in real-time video streams based on object tracks. + + This class extends the ObjectCounter class to generate and visualize heatmaps of object movements in video + streams. It uses tracked object positions to create a cumulative heatmap effect over time. + + Attributes: + initialized (bool): Flag indicating whether the heatmap has been initialized. + colormap (int): OpenCV colormap used for heatmap visualization. + heatmap (np.ndarray): Array storing the cumulative heatmap data. + annotator (SolutionAnnotator): Object for drawing annotations on the image. + + Methods: + heatmap_effect: Calculate and update the heatmap effect for a given bounding box. + process: Generate and apply the heatmap effect to each frame. + + Examples: + >>> from ultralytics.solutions import Heatmap + >>> heatmap = Heatmap(model="yolo11n.pt", colormap=cv2.COLORMAP_JET) + >>> frame = cv2.imread("frame.jpg") + >>> processed_frame = heatmap.process(frame) + """ + + def __init__(self, **kwargs): + """ + Initialize the Heatmap class for real-time video stream heatmap generation based on object tracks. + + Args: + **kwargs (Any): Keyword arguments passed to the parent ObjectCounter class. 
+ """ + super().__init__(**kwargs) + + self.initialized = False # Flag for heatmap initialization + if self.region is not None: # Check if user provided the region coordinates + self.initialize_region() + + # Store colormap + self.colormap = cv2.COLORMAP_PARULA if self.CFG["colormap"] is None else self.CFG["colormap"] + self.heatmap = None + + def heatmap_effect(self, box): + """ + Efficiently calculate heatmap area and effect location for applying colormap. + + Args: + box (List[float]): Bounding box coordinates [x0, y0, x1, y1]. + """ + x0, y0, x1, y1 = map(int, box) + radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2 + + # Create a meshgrid with region of interest (ROI) for vectorized distance calculations + xv, yv = np.meshgrid(np.arange(x0, x1), np.arange(y0, y1)) + + # Calculate squared distances from the center + dist_squared = (xv - ((x0 + x1) // 2)) ** 2 + (yv - ((y0 + y1) // 2)) ** 2 + + # Create a mask of points within the radius + within_radius = dist_squared <= radius_squared + + # Update only the values within the bounding box in a single vectorized operation + self.heatmap[y0:y1, x0:x1][within_radius] += 2 + + def process(self, im0): + """ + Generate heatmap for each frame using Ultralytics. + + Args: + im0 (np.ndarray): Input image array for processing. + + Returns: + (SolutionResults): Contains processed image `plot_im`, + 'in_count' (int, count of objects entering the region), + 'out_count' (int, count of objects exiting the region), + 'classwise_count' (dict, per-class object count), and + 'total_tracks' (int, total number of tracked objects). 
+ """ + if not self.initialized: + self.heatmap = np.zeros_like(im0, dtype=np.float32) * 0.99 + self.initialized = True # Initialize heatmap only once + + self.extract_tracks(im0) # Extract tracks + self.annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator + + # Iterate over bounding boxes, track ids and classes index + for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss): + # Apply heatmap effect for the bounding box + self.heatmap_effect(box) + + if self.region is not None: + self.annotator.draw_region(reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2) + self.store_tracking_history(track_id, box) # Store track history + self.store_classwise_counts(cls) # Store classwise counts in dict + current_centroid = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2) + # Get previous position if available + prev_position = None + if len(self.track_history[track_id]) > 1: + prev_position = self.track_history[track_id][-2] + self.count_objects(current_centroid, track_id, prev_position, cls) # Perform object counting + + plot_im = self.annotator.result() + if self.region is not None: + self.display_counts(plot_im) # Display the counts on the frame + + # Normalize, apply colormap to heatmap and combine with original image + if self.track_data.id is not None: + normalized_heatmap = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8) + colored_heatmap = cv2.applyColorMap(normalized_heatmap, self.colormap) + plot_im = cv2.addWeighted(plot_im, 0.5, colored_heatmap, 0.5, 0) + + self.display_output(plot_im) # Display output with base class function + + # Return SolutionResults + return SolutionResults( + plot_im=plot_im, + in_count=self.in_count, + out_count=self.out_count, + classwise_count=self.classwise_counts, + total_tracks=len(self.track_ids), + ) diff --git a/tracking/ultralytics/solutions/instance_segmentation.py b/tracking/ultralytics/solutions/instance_segmentation.py new file mode 
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
from ultralytics.utils.plotting import colors


class InstanceSegmentation(BaseSolution):
    """
    A class to manage instance segmentation in images or video streams.

    This class extends the BaseSolution class and provides functionality for performing instance
    segmentation, including drawing segmented masks with bounding boxes and labels.

    Attributes:
        model (str): The segmentation model to use for inference.
        line_width (int): Width of the bounding box and text lines.
        names (Dict[int, str]): Dictionary mapping class indices to class names.
        clss (List[int]): List of detected class indices.
        track_ids (List[int]): List of track IDs for detected instances.
        masks (List[numpy.ndarray]): List of segmentation masks for detected instances.

    Methods:
        process: Process the input image to perform instance segmentation and annotate results.
        extract_tracks: Extract tracks including bounding boxes, classes, and masks from model predictions.

    Examples:
        >>> segmenter = InstanceSegmentation()
        >>> frame = cv2.imread("frame.jpg")
        >>> results = segmenter.process(frame)
        >>> print(f"Total segmented instances: {results.total_tracks}")
    """

    def __init__(self, **kwargs):
        """
        Initialize the InstanceSegmentation class for detecting and annotating segmented instances.

        Args:
            **kwargs (Any): Keyword arguments passed to the BaseSolution parent class.
                model (str): Model name or path, defaults to "yolo11n-seg.pt".
        """
        kwargs["model"] = kwargs.get("model", "yolo11n-seg.pt")  # default to a segmentation checkpoint
        super().__init__(**kwargs)

    def process(self, im0):
        """
        Perform instance segmentation on the input image and annotate the results.

        Args:
            im0 (numpy.ndarray): The input image for segmentation.

        Returns:
            (SolutionResults): Object containing the annotated image and total number of tracked instances.

        Examples:
            >>> segmenter = InstanceSegmentation()
            >>> frame = cv2.imread("image.jpg")
            >>> results = segmenter.process(frame)
            >>> print(f"Total segmented instances: {results.total_tracks}")
        """
        self.extract_tracks(im0)  # extract tracks (bounding boxes, classes, and masks)
        annotator = SolutionAnnotator(im0, self.line_width)

        # Iterate over detected classes, track IDs, and segmentation masks
        if self.masks is None:
            self.LOGGER.warning("⚠️ No masks detected! Ensure you're using a supported Ultralytics segmentation model.")
        else:
            for cls, t_id, mask in zip(self.clss, self.track_ids, self.masks):
                # Annotate the image with segmentation mask, mask color, and label
                annotator.segmentation_mask(mask=mask, mask_color=colors(t_id, True), label=self.names[cls])

        plot_im = annotator.result()
        self.display_output(plot_im)  # display the annotated output using the base class function

        # Return SolutionResults
        return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license


import cv2

from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
from ultralytics.utils import LOGGER
from ultralytics.utils.plotting import colors


class ObjectBlurrer(BaseSolution):
    """
    A class to manage the blurring of detected objects in a real-time video stream.

    This class extends the BaseSolution class and provides functionality for blurring objects based
    on detected bounding boxes. The blurred areas are updated directly in the input image, allowing
    for privacy preservation or other effects.

    Attributes:
        blur_ratio (int): The intensity of the blur effect applied to detected objects (higher values create more blur).
        iou (float): Intersection over Union threshold for object detection.
        conf (float): Confidence threshold for object detection.

    Methods:
        process: Applies a blurring effect to detected objects in the input image.
        extract_tracks: Extracts tracking information from detected objects.
        display_output: Displays the processed output image.

    Examples:
        >>> blurrer = ObjectBlurrer()
        >>> frame = cv2.imread("frame.jpg")
        >>> processed_results = blurrer.process(frame)
        >>> print(f"Total blurred objects: {processed_results.total_tracks}")
    """

    def __init__(self, **kwargs):
        """
        Initialize the ObjectBlurrer class for applying a blur effect to detected objects.

        Args:
            **kwargs (Any): Keyword arguments passed to the parent class and for configuration.
                blur_ratio (float): Intensity of the blur effect (0.1-1.0, default=0.5).
        """
        super().__init__(**kwargs)
        blur_ratio = kwargs.get("blur_ratio", 0.5)
        if blur_ratio < 0.1:
            LOGGER.warning("⚠️ blur ratio cannot be less than 0.1, updating it to default value 0.5")
            blur_ratio = 0.5
        # Stored as an integer kernel size for cv2.blur (e.g. 0.5 -> (50, 50)).
        self.blur_ratio = int(blur_ratio * 100)

    def process(self, im0):
        """
        Apply a blurring effect to detected objects in the input image.

        This method extracts tracking information, applies blur to regions corresponding to detected
        objects, and annotates the image with bounding boxes.

        Args:
            im0 (numpy.ndarray): The input image containing detected objects.

        Returns:
            (SolutionResults): Object containing the processed image and number of tracked objects.
                - plot_im (numpy.ndarray): The annotated output image with blurred objects.
                - total_tracks (int): The total number of tracked objects in the frame.

        Examples:
            >>> blurrer = ObjectBlurrer()
            >>> frame = cv2.imread("image.jpg")
            >>> results = blurrer.process(frame)
            >>> print(f"Blurred {results.total_tracks} objects")
        """
        self.extract_tracks(im0)  # extract tracks
        annotator = SolutionAnnotator(im0, self.line_width)

        # Iterate over bounding boxes and classes
        for box, cls in zip(self.boxes, self.clss):
            # Compute integer pixel bounds once for both the crop and the write-back.
            x0, y0, x1, y1 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
            # cv2.blur raises on an empty ROI, so skip degenerate (zero-area) boxes.
            if x1 > x0 and y1 > y0:
                blur_obj = cv2.blur(im0[y0:y1, x0:x1], (self.blur_ratio, self.blur_ratio))
                im0[y0:y1, x0:x1] = blur_obj  # write the blurred patch back in place
            # int(cls) keeps label/color lookup consistent with the other solutions.
            annotator.box_label(box, label=self.names[int(cls)], color=colors(int(cls), True))

        plot_im = annotator.result()
        self.display_output(plot_im)  # display the output using the base class function

        # Return a SolutionResults
        return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
from ultralytics.utils.plotting import colors


class ObjectCounter(BaseSolution):
    """
    Count objects moving in and out of a user-defined region in a real-time video stream.

    Extends BaseSolution. The region may be a line segment (two points) or a polygon (three
    or more points); each tracked object is counted at most once, when its trajectory crosses
    the line or its centroid enters the polygon.

    Attributes:
        in_count (int): Counter for objects moving inward.
        out_count (int): Counter for objects moving outward.
        counted_ids (List[int]): Track IDs that have already been counted.
        classwise_counts (Dict[str, Dict[str, int]]): Per-class {"IN": n, "OUT": n} counters.
        region_initialized (bool): Whether the counting region has been initialized.
        show_in (bool): Whether to display the inward count.
        show_out (bool): Whether to display the outward count.

    Methods:
        count_objects: Counts objects within a polygonal or linear region.
        store_classwise_counts: Initializes class-wise counters if not already present.
        display_counts: Displays object counts on the frame.
        process: Processes a frame and updates counts.

    Examples:
        >>> counter = ObjectCounter()
        >>> frame = cv2.imread("frame.jpg")
        >>> results = counter.process(frame)
        >>> print(f"Inward count: {counter.in_count}, Outward count: {counter.out_count}")
    """

    def __init__(self, **kwargs):
        """Initialize counters, bookkeeping and display flags for region-based object counting."""
        super().__init__(**kwargs)

        self.in_count = 0  # objects moving inward
        self.out_count = 0  # objects moving outward
        self.counted_ids = []  # track IDs already counted
        self.classwise_counts = {}  # per-class IN/OUT counters
        self.region_initialized = False  # region setup happens on the first frame

        self.show_in = self.CFG["show_in"]
        self.show_out = self.CFG["show_out"]

    def count_objects(self, current_centroid, track_id, prev_position, cls):
        """
        Count an object crossing a linear region or entering a polygonal region.

        Args:
            current_centroid (Tuple[float, float]): Centroid (x, y) in the current frame.
            track_id (int): Unique identifier of the tracked object.
            prev_position (Tuple[float, float]): Centroid (x, y) in the previous frame.
            cls (int): Class index for classwise count updates.

        Examples:
            >>> counter = ObjectCounter()
            >>> counter.count_objects((140, 240), 1, (120, 220), 0)
        """
        if prev_position is None or track_id in self.counted_ids:
            return  # nothing to compare against, or this track was already counted

        if len(self.region) == 2:  # linear region (a line segment)
            segment = self.LineString(self.region)
            # Only count when the object's trajectory actually crosses the segment.
            if not segment.intersects(self.LineString([prev_position, current_centroid])):
                return
            # A mostly-vertical segment is crossed horizontally (judge on x),
            # a mostly-horizontal one vertically (judge on y).
            span_x = abs(self.region[0][0] - self.region[1][0])
            span_y = abs(self.region[0][1] - self.region[1][1])
            going_in = (
                current_centroid[0] > prev_position[0]  # moving right
                if span_x < span_y
                else current_centroid[1] > prev_position[1]  # moving downward
            )
            self._register_count(track_id, cls, going_in)

        elif len(self.region) > 2:  # polygonal region
            if not self.Polygon(self.region).contains(self.Point(current_centroid)):
                return
            # Direction is judged along the polygon's narrow axis.
            xs = [p[0] for p in self.region]
            ys = [p[1] for p in self.region]
            region_width = max(xs) - min(xs)
            region_height = max(ys) - min(ys)
            going_in = (
                current_centroid[0] > prev_position[0]  # moving right
                if region_width < region_height
                else current_centroid[1] > prev_position[1]  # moving downward
            )
            self._register_count(track_id, cls, going_in)

    def _register_count(self, track_id, cls, going_in):
        """Apply one IN/OUT increment (total and classwise) and mark the track as counted."""
        if going_in:
            self.in_count += 1
            self.classwise_counts[self.names[cls]]["IN"] += 1
        else:
            self.out_count += 1
            self.classwise_counts[self.names[cls]]["OUT"] += 1
        self.counted_ids.append(track_id)

    def store_classwise_counts(self, cls):
        """
        Initialize class-wise counters for a specific object class if not already present.

        Args:
            cls (int): Class index for classwise count updates.

        Examples:
            >>> counter = ObjectCounter()
            >>> counter.store_classwise_counts(0)
            >>> print(counter.classwise_counts)
            {'person': {'IN': 0, 'OUT': 0}}
        """
        self.classwise_counts.setdefault(self.names[cls], {"IN": 0, "OUT": 0})

    def display_counts(self, plot_im):
        """
        Display object counts on the input frame.

        Args:
            plot_im (numpy.ndarray): The image or frame to display counts on.
        """
        labels_dict = {}
        for key, value in self.classwise_counts.items():
            if value["IN"] == 0 and value["OUT"] == 0:
                continue  # skip classes with no activity
            parts = []
            if self.show_in:
                parts.append("IN " + str(value["IN"]))
            if self.show_out:
                parts.append("OUT " + str(value["OUT"]))
            labels_dict[key.capitalize()] = " ".join(parts)

        if labels_dict:
            self.annotator.display_analytics(plot_im, labels_dict, (104, 31, 17), (255, 255, 255), 10)

    def process(self, im0):
        """
        Process one frame: draw region and boxes, update tracks, and count crossings.

        Args:
            im0 (numpy.ndarray): The input image or frame to be processed.

        Returns:
            (SolutionResults): Contains `plot_im`, 'in_count', 'out_count',
                'classwise_count' (dict) and 'total_tracks' (int).

        Examples:
            >>> counter = ObjectCounter()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> results = counter.process(frame)
        """
        if not self.region_initialized:
            self.initialize_region()
            self.region_initialized = True

        self.extract_tracks(im0)  # extract tracks
        self.annotator = SolutionAnnotator(im0, line_width=self.line_width)

        # Draw the counting region once per frame.
        self.annotator.draw_region(reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2)

        for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
            self.annotator.box_label(box, label=self.names[cls], color=colors(cls, True))
            self.store_tracking_history(track_id, box)  # record track history
            self.store_classwise_counts(cls)  # make sure this class has counters

            current_centroid = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)
            prev_position = None
            if len(self.track_history[track_id]) > 1:
                prev_position = self.track_history[track_id][-2]
            self.count_objects(current_centroid, track_id, prev_position, cls)

        plot_im = self.annotator.result()
        self.display_counts(plot_im)  # overlay the counts on the frame
        self.display_output(plot_im)  # display via the base class helper

        return SolutionResults(
            plot_im=plot_im,
            in_count=self.in_count,
            out_count=self.out_count,
            classwise_count=self.classwise_counts,
            total_tracks=len(self.track_ids),
        )
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import os
from pathlib import Path

from ultralytics.solutions.solutions import BaseSolution, SolutionResults
from ultralytics.utils.plotting import save_one_box


class ObjectCropper(BaseSolution):
    """
    A class to manage the cropping of detected objects in a real-time video stream or images.

    This class extends the BaseSolution class and provides functionality for cropping objects based
    on detected bounding boxes. The cropped images are saved to a specified directory for further
    analysis or usage.

    Attributes:
        crop_dir (str): Directory where cropped object images are stored.
        crop_idx (int): Counter for the total number of cropped objects.
        iou (float): IoU (Intersection over Union) threshold for non-maximum suppression.
        conf (float): Confidence threshold for filtering detections.

    Methods:
        process: Crops detected objects from the input image and saves them to the output directory.

    Examples:
        >>> cropper = ObjectCropper()
        >>> frame = cv2.imread("frame.jpg")
        >>> processed_results = cropper.process(frame)
        >>> print(f"Total cropped objects: {cropper.crop_idx}")
    """

    def __init__(self, **kwargs):
        """
        Initialize the ObjectCropper class for cropping objects from detected bounding boxes.

        Args:
            **kwargs (Any): Keyword arguments passed to the parent class and used for configuration.
                crop_dir (str): Path to the directory for saving cropped object images.
        """
        super().__init__(**kwargs)

        self.crop_dir = kwargs.get("crop_dir", "cropped-detections")  # directory for storing cropped detections
        # makedirs(exist_ok=True) is race-free and also creates missing parent directories,
        # unlike the exists()+mkdir() pair, which fails for nested paths and can race.
        os.makedirs(self.crop_dir, exist_ok=True)
        if self.CFG["show"]:
            self.LOGGER.info(
                f"⚠️ show=True disabled for crop solution, results will be saved in the directory named: {self.crop_dir}"
            )
        self.crop_idx = 0  # counter for total cropped objects
        self.iou = self.CFG["iou"]
        self.conf = self.CFG["conf"] if self.CFG["conf"] is not None else 0.25

    def process(self, im0):
        """
        Crop detected objects from the input image and save them as separate images.

        Args:
            im0 (numpy.ndarray): The input image containing detected objects.

        Returns:
            (SolutionResults): A SolutionResults object containing the total number of cropped objects
                and the processed image.

        Examples:
            >>> cropper = ObjectCropper()
            >>> frame = cv2.imread("image.jpg")
            >>> results = cropper.process(frame)
            >>> print(f"Total cropped objects: {results.total_crop_objects}")
        """
        # Run prediction directly (no tracking needed for cropping).
        results = self.model.predict(
            im0, classes=self.classes, conf=self.conf, iou=self.iou, device=self.CFG["device"]
        )[0]

        for box in results.boxes:
            self.crop_idx += 1
            save_one_box(
                box.xyxy,
                im0,
                file=Path(self.crop_dir) / f"crop_{self.crop_idx}.jpg",
                BGR=True,  # frames are BGR (OpenCV convention)
            )

        # Return SolutionResults
        return SolutionResults(plot_im=im0, total_crop_objects=self.crop_idx)
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import json

import cv2
import numpy as np

from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
from ultralytics.utils import LOGGER
from ultralytics.utils.checks import check_imshow


class ParkingPtsSelection:
    """
    A class for selecting and managing parking zone points on images using a Tkinter-based UI.

    This class provides functionality to upload an image, select points to define parking zones,
    and save the selected points to a JSON file. It uses Tkinter for the graphical user interface.

    Attributes:
        tk (module): The Tkinter module for GUI operations.
        filedialog (module): Tkinter's filedialog module for file selection operations.
        messagebox (module): Tkinter's messagebox module for displaying message boxes.
        master (tk.Tk): The main Tkinter window.
        canvas (tk.Canvas): The canvas widget for displaying the image and drawing bounding boxes.
        image (PIL.Image.Image): The uploaded image.
        canvas_image (ImageTk.PhotoImage): The image displayed on the canvas.
        rg_data (List[List[Tuple[int, int]]]): List of bounding boxes, each defined by 4 points.
        current_box (List[Tuple[int, int]]): Temporary storage for the points of the current bounding box.
        imgw (int): Original width of the uploaded image.
        imgh (int): Original height of the uploaded image.
        canvas_max_width (int): Maximum width of the canvas.
        canvas_max_height (int): Maximum height of the canvas.

    Methods:
        initialize_properties: Initializes the necessary properties.
        upload_image: Uploads an image, resizes it to fit the canvas, and displays it.
        on_canvas_click: Handles mouse clicks to add points for bounding boxes.
        draw_box: Draws a bounding box on the canvas.
        remove_last_bounding_box: Removes the last bounding box and redraws the canvas.
        redraw_canvas: Redraws the canvas with the image and all bounding boxes.
        save_to_json: Saves the bounding boxes to a JSON file.

    Examples:
        >>> parking_selector = ParkingPtsSelection()
        >>> # Use the GUI to upload an image, select parking zones, and save the data
    """

    def __init__(self):
        """Initialize the ParkingPtsSelection class, setting up UI and properties for parking zone point selection."""
        try:  # check if tkinter installed
            import tkinter as tk
            from tkinter import filedialog, messagebox
        except ImportError:  # display error with recommendations
            import platform

            install_cmd = {
                "Linux": "sudo apt install python3-tk (Debian/Ubuntu) | sudo dnf install python3-tkinter (Fedora) | "
                "sudo pacman -S tk (Arch)",
                "Windows": "reinstall Python and enable the checkbox `tcl/tk and IDLE` on **Optional Features** during installation",
                "Darwin": "reinstall Python from https://www.python.org/downloads/mac-osx/ or `brew install python-tk`",
            }.get(platform.system(), "Unknown OS. Check your Python installation.")

            LOGGER.warning(f"WARNING ⚠️ Tkinter is not configured or supported. Potential fix: {install_cmd}")
            return

        if not check_imshow(warn=True):
            return

        self.tk, self.filedialog, self.messagebox = tk, filedialog, messagebox
        self.master = self.tk.Tk()  # reference to the main application window
        self.master.title("Ultralytics Parking Zones Points Selector")
        self.master.resizable(False, False)

        self.canvas = self.tk.Canvas(self.master, bg="white")  # canvas for images/graphics
        self.canvas.pack(side=self.tk.BOTTOM)

        self.image = None  # the loaded PIL image
        self.canvas_image = None  # PhotoImage currently shown on the canvas
        self.canvas_max_width = None  # maximum allowed canvas width
        self.canvas_max_height = None  # maximum allowed canvas height
        self.rg_data = None  # list of completed 4-point boxes
        self.current_box = None  # points of the box being drawn
        self.imgh = None  # original image height
        self.imgw = None  # original image width

        # Button frame with buttons
        button_frame = self.tk.Frame(self.master)
        button_frame.pack(side=self.tk.TOP)

        for text, cmd in [
            ("Upload Image", self.upload_image),
            ("Remove Last BBox", self.remove_last_bounding_box),
            ("Save", self.save_to_json),
        ]:
            self.tk.Button(button_frame, text=text, command=cmd).pack(side=self.tk.LEFT)

        self.initialize_properties()
        self.master.mainloop()

    def initialize_properties(self):
        """Initialize properties for image, canvas, bounding boxes, and dimensions."""
        self.image = self.canvas_image = None
        self.rg_data, self.current_box = [], []
        self.imgw = self.imgh = 0
        self.canvas_max_width, self.canvas_max_height = 1280, 720

    def upload_image(self):
        """Upload and display an image on the canvas, resizing it to fit within specified dimensions."""
        from PIL import Image, ImageTk  # scope because ImageTk requires tkinter package

        # askopenfilename returns "" when the user cancels the dialog; opening that
        # path would raise, so bail out before calling Image.open.
        filename = self.filedialog.askopenfilename(filetypes=[("Image Files", "*.png *.jpg *.jpeg")])
        if not filename:
            return
        self.image = Image.open(filename)

        self.imgw, self.imgh = self.image.size
        # Fit the image inside the canvas limits while preserving aspect ratio.
        aspect_ratio = self.imgw / self.imgh
        canvas_width = (
            min(self.canvas_max_width, self.imgw) if aspect_ratio > 1 else int(self.canvas_max_height * aspect_ratio)
        )
        canvas_height = (
            min(self.canvas_max_height, self.imgh) if aspect_ratio <= 1 else int(canvas_width / aspect_ratio)
        )

        self.canvas.config(width=canvas_width, height=canvas_height)
        self.canvas_image = ImageTk.PhotoImage(self.image.resize((canvas_width, canvas_height)))
        self.canvas.create_image(0, 0, anchor=self.tk.NW, image=self.canvas_image)
        # NOTE(review): the source showed bind("") which is an invalid (empty) Tk event
        # sequence — restored to the left-click event; confirm against upstream.
        self.canvas.bind("<Button-1>", self.on_canvas_click)

        self.rg_data.clear()
        self.current_box.clear()

    def on_canvas_click(self, event):
        """Handle mouse clicks to add points for bounding boxes on the canvas."""
        self.current_box.append((event.x, event.y))
        self.canvas.create_oval(event.x - 3, event.y - 3, event.x + 3, event.y + 3, fill="red")
        if len(self.current_box) == 4:  # a box is complete after four corner clicks
            self.rg_data.append(self.current_box.copy())
            self.draw_box(self.current_box)
            self.current_box.clear()

    def draw_box(self, box):
        """Draw a bounding box on the canvas using the provided coordinates."""
        for i in range(4):
            self.canvas.create_line(box[i], box[(i + 1) % 4], fill="blue", width=2)

    def remove_last_bounding_box(self):
        """Remove the last bounding box from the list and redraw the canvas."""
        if not self.rg_data:
            self.messagebox.showwarning("Warning", "No bounding boxes to remove.")
            return
        self.rg_data.pop()
        self.redraw_canvas()

    def redraw_canvas(self):
        """Redraw the canvas with the image and all bounding boxes."""
        self.canvas.delete("all")
        self.canvas.create_image(0, 0, anchor=self.tk.NW, image=self.canvas_image)
        for box in self.rg_data:
            self.draw_box(box)

    def save_to_json(self):
        """Save the selected parking zone points to a JSON file with scaled coordinates."""
        # Scale canvas coordinates back to the original image resolution.
        scale_w, scale_h = self.imgw / self.canvas.winfo_width(), self.imgh / self.canvas.winfo_height()
        data = [{"points": [(int(x * scale_w), int(y * scale_h)) for x, y in box]} for box in self.rg_data]

        with open("bounding_boxes.json", "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)
        self.messagebox.showinfo("Success", "Bounding boxes saved to bounding_boxes.json")
+ json (List[Dict]): Loaded JSON data containing parking region information. + pr_info (Dict[str, int]): Dictionary storing parking information (Occupancy and Available spaces). + arc (Tuple[int, int, int]): RGB color tuple for available region visualization. + occ (Tuple[int, int, int]): RGB color tuple for occupied region visualization. + dc (Tuple[int, int, int]): RGB color tuple for centroid visualization of detected objects. + + Methods: + process: Processes the input image for parking lot management and visualization. + + Examples: + >>> from ultralytics.solutions import ParkingManagement + >>> parking_manager = ParkingManagement(model="yolo11n.pt", json_file="parking_regions.json") + >>> print(f"Occupied spaces: {parking_manager.pr_info['Occupancy']}") + >>> print(f"Available spaces: {parking_manager.pr_info['Available']}") + """ + + def __init__(self, **kwargs): + """Initialize the parking management system with a YOLO model and visualization settings.""" + super().__init__(**kwargs) + + self.json_file = self.CFG["json_file"] # Load JSON data + if self.json_file is None: + LOGGER.warning("❌ json_file argument missing. Parking region details required.") + raise ValueError("❌ Json file path can not be empty") + + with open(self.json_file) as f: + self.json = json.load(f) + + self.pr_info = {"Occupancy": 0, "Available": 0} # dictionary for parking information + + self.arc = (0, 0, 255) # available region color + self.occ = (0, 255, 0) # occupied region color + self.dc = (255, 0, 189) # centroid color for each box + + def process(self, im0): + """ + Process the input image for parking lot management and visualization. + + This function analyzes the input image, extracts tracks, and determines the occupancy status of parking + regions defined in the JSON file. It annotates the image with occupied and available parking spots, + and updates the parking information. + + Args: + im0 (np.ndarray): The input inference image. 
+ + Returns: + (SolutionResults): Contains processed image `plot_im`, 'filled_slots' (number of occupied parking slots), + 'available_slots' (number of available parking slots), and 'total_tracks' (total number of tracked objects). + + Examples: + >>> parking_manager = ParkingManagement(json_file="parking_regions.json") + >>> image = cv2.imread("parking_lot.jpg") + >>> results = parking_manager.process(image) + """ + self.extract_tracks(im0) # extract tracks from im0 + es, fs = len(self.json), 0 # empty slots, filled slots + annotator = SolutionAnnotator(im0, self.line_width) # init annotator + + for region in self.json: + # Convert points to a NumPy array with the correct dtype and reshape properly + pts_array = np.array(region["points"], dtype=np.int32).reshape((-1, 1, 2)) + rg_occupied = False # occupied region initialization + for box, cls in zip(self.boxes, self.clss): + xc, yc = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2) + dist = cv2.pointPolygonTest(pts_array, (xc, yc), False) + if dist >= 0: + # cv2.circle(im0, (xc, yc), radius=self.line_width * 4, color=self.dc, thickness=-1) + annotator.display_objects_labels( + im0, self.model.names[int(cls)], (104, 31, 17), (255, 255, 255), xc, yc, 10 + ) + rg_occupied = True + break + fs, es = (fs + 1, es - 1) if rg_occupied else (fs, es) + # Plotting regions + cv2.polylines(im0, [pts_array], isClosed=True, color=self.occ if rg_occupied else self.arc, thickness=2) + + self.pr_info["Occupancy"], self.pr_info["Available"] = fs, es + + annotator.display_analytics(im0, self.pr_info, (104, 31, 17), (255, 255, 255), 10) + + plot_im = annotator.result() + self.display_output(plot_im) # display output with base class function + + # Return SolutionResults + return SolutionResults( + plot_im=plot_im, + filled_slots=self.pr_info["Occupancy"], + available_slots=self.pr_info["Available"], + total_tracks=len(self.track_ids), + ) diff --git a/tracking/ultralytics/solutions/queue_management.py 
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
from ultralytics.utils.plotting import colors


class QueueManager(BaseSolution):
    """
    Manages queue counting in real-time video streams based on object tracks.

    This class extends BaseSolution to provide functionality for tracking and counting objects within a specified
    region in video frames.

    Attributes:
        counts (int): The current count of objects in the queue.
        rect_color (Tuple[int, int, int]): RGB color tuple for drawing the queue region rectangle.
        region_length (int): The number of points defining the queue region.
        track_line (List[Tuple[int, int]]): List of track line coordinates.
        track_history (Dict[int, List[Tuple[int, int]]]): Dictionary storing tracking history for each object.

    Methods:
        initialize_region: Initializes the queue region.
        process: Processes a single frame for queue management.
        extract_tracks: Extracts object tracks from the current frame.
        store_tracking_history: Stores the tracking history for an object.
        display_output: Displays the processed output.

    Examples:
        >>> cap = cv2.VideoCapture("path/to/video.mp4")
        >>> queue_manager = QueueManager(region=[100, 100, 200, 200, 300, 300])
        >>> while cap.isOpened():
        >>>     success, im0 = cap.read()
        >>>     if not success:
        >>>         break
        >>>     results = queue_manager.process(im0)
    """

    def __init__(self, **kwargs):
        """Initializes the QueueManager with parameters for tracking and counting objects in a video stream."""
        super().__init__(**kwargs)
        self.initialize_region()
        self.counts = 0  # running per-frame queue count
        self.rect_color = (255, 255, 255)  # color used when drawing the queue region
        self.region_length = len(self.region)  # cached so we know if the region is a polygon (>= 3 points)

    def process(self, im0):
        """
        Process queue management for a single frame of video.

        Args:
            im0 (numpy.ndarray): Input image for processing, typically a frame from a video stream.

        Returns:
            (SolutionResults): Contains processed image `im0`, 'queue_count' (int, number of objects in the queue) and
                'total_tracks' (int, total number of tracked objects).

        Examples:
            >>> queue_manager = QueueManager()
            >>> frame = cv2.imread("frame.jpg")
            >>> results = queue_manager.process(frame)
        """
        self.counts = 0  # the queue count is recomputed from scratch on every frame
        self.extract_tracks(im0)

        annotator = SolutionAnnotator(im0, line_width=self.line_width)
        annotator.draw_region(reg_pts=self.region, color=self.rect_color, thickness=self.line_width * 2)

        for bbox, tid, category in zip(self.boxes, self.track_ids, self.clss):
            # Annotate the detection and record its center in the per-track history
            annotator.box_label(bbox, label=self.names[category], color=colors(tid, True))
            self.store_tracking_history(tid, bbox)

            history = self.track_history.get(tid, [])
            # The object counts toward the queue only once it has a previous position
            # (i.e. at least two history points) and its latest center lies inside the
            # polygonal region; a 2-point region is a line and cannot contain points.
            previous = history[-2] if len(history) > 1 else None
            if self.region_length >= 3 and previous and self.r_s.contains(self.Point(self.track_line[-1])):
                self.counts += 1

        # Overlay the per-frame queue total on the region
        annotator.queue_counts_display(
            f"Queue Counts : {str(self.counts)}",
            points=self.region,
            region_color=self.rect_color,
            txt_color=(104, 31, 17),
        )
        plot_im = annotator.result()
        self.display_output(plot_im)

        return SolutionResults(plot_im=plot_im, queue_count=self.counts, total_tracks=len(self.track_ids))
from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
from ultralytics.utils.plotting import colors


class RegionCounter(BaseSolution):
    """
    A class for real-time counting of objects within user-defined regions in a video stream.

    This class inherits from `BaseSolution` and provides functionality to define polygonal regions in a video frame,
    track objects, and count those objects that pass through each defined region. Useful for applications requiring
    counting in specified areas, such as monitoring zones or segmented sections.

    Attributes:
        region_template (dict): Template for creating new counting regions with default attributes including name,
            polygon coordinates, and display colors.
        counting_regions (list): List storing all defined regions, where each entry is based on `region_template`
            and includes specific region settings like name, coordinates, and color.
        region_counts (dict): Dictionary storing the count of objects for each named region.

    Methods:
        add_region: Adds a new counting region with specified attributes.
        process: Processes video frames to count objects in each region.
    """

    def __init__(self, **kwargs):
        """Initializes the RegionCounter class for real-time counting in different regions of video streams."""
        super().__init__(**kwargs)
        self.region_template = {
            "name": "Default Region",
            "polygon": None,
            "counts": 0,
            "dragging": False,
            "region_color": (255, 255, 255),
            "text_color": (0, 0, 0),
        }
        self.region_counts = {}
        self.counting_regions = []

    def add_region(self, name, polygon_points, region_color, text_color):
        """
        Add a new region to the counting list based on the provided template with specific attributes.

        Args:
            name (str): Name assigned to the new region.
            polygon_points (List[Tuple]): List of (x, y) coordinates defining the region's polygon.
            region_color (tuple): BGR color for region visualization.
            text_color (tuple): BGR color for the text within the region.
        """
        region = self.region_template.copy()
        region.update(
            {
                "name": name,
                "polygon": self.Polygon(polygon_points),
                "region_color": region_color,
                "text_color": text_color,
            }
        )
        self.counting_regions.append(region)

    def process(self, im0):
        """
        Process the input frame to detect and count objects within each defined region.

        Args:
            im0 (np.ndarray): Input image frame where objects and regions are annotated.

        Returns:
            (SolutionResults): Contains processed image `plot_im`, 'total_tracks' (int, total number of tracked
                objects), and 'region_counts' (dict, counts of objects per region).
        """
        self.extract_tracks(im0)
        annotator = SolutionAnnotator(im0, line_width=self.line_width)

        # Ensure self.region is initialized and structured as a dictionary
        if not isinstance(self.region, dict):
            self.region = {"Region#01": self.region or self.initialize_region()}

        # Draw regions every frame, but register each region for counting only once.
        # BUG FIX: the previous code called add_region() unconditionally per frame,
        # so counting_regions grew without bound and every object was counted once
        # per accumulated duplicate region.
        for idx, (region_name, reg_pts) in enumerate(self.region.items(), start=1):
            color = colors(idx, True)
            annotator.draw_region(reg_pts, color, self.line_width * 2)
            if len(self.counting_regions) < len(self.region):
                self.add_region(region_name, reg_pts, color, annotator.get_txt_color())

        # Prepare regions for containment check (prepared geometries are cached per region)
        for region in self.counting_regions:
            if "prepared_polygon" not in region:
                region["prepared_polygon"] = self.prep(region["polygon"])

        # Convert bounding boxes to NumPy array for center points
        boxes_np = np.array([((box[0] + box[2]) / 2, (box[1] + box[3]) / 2) for box in self.boxes], dtype=np.float32)
        points = [self.Point(pt) for pt in boxes_np]  # Convert centers to Point objects

        # Process bounding boxes & check containment
        if points:
            for (point, cls), box in zip(zip(points, self.clss), self.boxes):
                annotator.box_label(box, label=self.names[cls], color=colors(cls))

                for region in self.counting_regions:
                    if region["prepared_polygon"].contains(point):
                        region["counts"] += 1

        # Display region counts and publish them. BUG FIX: region_counts is now
        # refreshed for every region each frame (previously it was only written when
        # a region contained an object, so counts went stale instead of dropping to 0).
        for region in self.counting_regions:
            self.region_counts[region["name"]] = region["counts"]
            annotator.text_label(
                region["polygon"].bounds,
                label=str(region["counts"]),
                color=region["region_color"],
                txt_color=region["text_color"],
            )
            region["counts"] = 0  # Reset for next frame
        plot_im = annotator.result()
        self.display_output(plot_im)

        return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids), region_counts=self.region_counts)


# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.utils import LOGGER


class SecurityAlarm(BaseSolution):
    """
    A class to manage security alarm functionalities for real-time monitoring.

    This class extends the BaseSolution class and provides features to monitor objects in a frame, send email
    notifications when specific thresholds are exceeded for total detections, and annotate the output frame for
    visualization.

    Attributes:
        email_sent (bool): Flag to track if an email has already been sent for the current event.
        records (int): Threshold for the number of detected objects to trigger an alert.
        server (smtplib.SMTP): SMTP server connection for sending email alerts.
        to_email (str): Recipient's email address for alerts.
        from_email (str): Sender's email address for alerts.

    Methods:
        authenticate: Set up email server authentication for sending alerts.
        send_email: Send an email notification with details and an image attachment.
        process: Monitor the frame, process detections, and trigger alerts if thresholds are crossed.

    Examples:
        >>> security = SecurityAlarm()
        >>> security.authenticate("abc@gmail.com", "1111222233334444", "xyz@gmail.com")
        >>> frame = cv2.imread("frame.jpg")
        >>> results = security.process(frame)
    """

    def __init__(self, **kwargs):
        """
        Initialize the SecurityAlarm class with parameters for real-time object monitoring.

        Args:
            **kwargs (Any): Additional keyword arguments passed to the parent class.
        """
        super().__init__(**kwargs)
        self.email_sent = False
        self.records = self.CFG["records"]
        self.server = None
        self.to_email = ""
        self.from_email = ""

    def authenticate(self, from_email, password, to_email):
        """
        Authenticate the email server for sending alert notifications.

        Args:
            from_email (str): Sender's email address.
            password (str): Password for the sender's email account.
            to_email (str): Recipient's email address.

        This method initializes a secure connection with the SMTP server and logs in using the provided credentials.

        Examples:
            >>> alarm = SecurityAlarm()
            >>> alarm.authenticate("sender@example.com", "password123", "recipient@example.com")
        """
        import smtplib

        # Pass host and port separately; the previous "smtp.gmail.com: 587" string
        # relied on smtplib's host:port parsing tolerating the stray space.
        self.server = smtplib.SMTP("smtp.gmail.com", 587)
        self.server.starttls()  # upgrade the connection to TLS before authenticating
        self.server.login(from_email, password)
        self.to_email = to_email
        self.from_email = from_email

    def send_email(self, im0, records=5):
        """
        Send an email notification with an image attachment indicating the number of objects detected.

        Args:
            im0 (numpy.ndarray): The input image or frame to be attached to the email.
            records (int): The number of detected objects to be included in the email message.

        This method encodes the input image, composes the email message with details about the detection, and sends it
        to the specified recipient.

        Examples:
            >>> alarm = SecurityAlarm()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> alarm.send_email(frame, records=10)
        """
        from email.mime.image import MIMEImage
        from email.mime.multipart import MIMEMultipart
        from email.mime.text import MIMEText

        import cv2

        img_bytes = cv2.imencode(".jpg", im0)[1].tobytes()  # Encode the image as JPEG

        # Create the email
        message = MIMEMultipart()
        message["From"] = self.from_email
        message["To"] = self.to_email
        message["Subject"] = "Security Alert"

        # Add the text message body
        message_body = f"Ultralytics ALERT!!! {records} objects have been detected!!"
        message.attach(MIMEText(message_body))

        # Attach the image
        image_attachment = MIMEImage(img_bytes, name="ultralytics.jpg")
        message.attach(image_attachment)

        # Send the email; failures are logged rather than raised so frame processing continues
        try:
            self.server.send_message(message)
            LOGGER.info("✅ Email sent successfully!")
        except Exception as e:
            LOGGER.error(f"❌ Failed to send email: {e}")

    def process(self, im0):
        """
        Monitor the frame, process object detections, and trigger alerts if thresholds are exceeded.

        Args:
            im0 (numpy.ndarray): The input image or frame to be processed and annotated.

        Returns:
            (SolutionResults): Contains processed image `plot_im`, 'total_tracks' (total number of tracked objects) and
                'email_sent' (whether an email alert was triggered).

        This method processes the input frame, extracts detections, annotates the frame with bounding boxes, and sends
        an email notification if the number of detected objects surpasses the specified threshold and an alert has not
        already been sent.

        Examples:
            >>> alarm = SecurityAlarm()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> results = alarm.process(frame)
        """
        self.extract_tracks(im0)  # Extract tracks
        annotator = SolutionAnnotator(im0, line_width=self.line_width)  # Initialize annotator

        # Iterate over bounding boxes and classes index
        for box, cls in zip(self.boxes, self.clss):
            # Draw bounding box
            annotator.box_label(box, label=self.names[cls], color=colors(cls, True))

        total_det = len(self.clss)
        if total_det > self.records and not self.email_sent:  # Only send email if not sent before
            self.send_email(im0, total_det)
            self.email_sent = True

        plot_im = annotator.result()
        self.display_output(plot_im)  # Display output with base class function

        # Return a SolutionResults
        return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids), email_sent=self.email_sent)
+ Point (shapely.geometry.Point): Class for creating point geometries. + CFG (dict): Configuration dictionary loaded from a YAML file and updated with kwargs. + region (List[Tuple[int, int]]): List of coordinate tuples defining a region of interest. + line_width (int): Width of lines used in visualizations. + model (ultralytics.YOLO): Loaded YOLO model instance. + names (Dict[int, str]): Dictionary mapping class indices to class names. + env_check (bool): Flag indicating whether the environment supports image display. + track_history (collections.defaultdict): Dictionary to store tracking history for each object. + + Methods: + extract_tracks: Apply object tracking and extract tracks from an input image. + store_tracking_history: Store object tracking history for a given track ID and bounding box. + initialize_region: Initialize the counting region and line segment based on configuration. + display_output: Display the results of processing, including showing frames or saving results. + + Examples: + >>> solution = BaseSolution(model="yolo11n.pt", region=[(0, 0), (100, 0), (100, 100), (0, 100)]) + >>> solution.initialize_region() + >>> image = cv2.imread("image.jpg") + >>> solution.extract_tracks(image) + >>> solution.display_output(image) + """ + + def __init__(self, is_cli=False, **kwargs): + """ + Initializes the BaseSolution class with configuration settings and the YOLO model. + + Args: + is_cli (bool): Enables CLI mode if set to True. + **kwargs (Any): Additional configuration parameters that override defaults. 
+ """ + check_requirements("shapely>=2.0.0") + from shapely.geometry import LineString, Point, Polygon + from shapely.prepared import prep + + self.LineString = LineString + self.Polygon = Polygon + self.Point = Point + self.prep = prep + self.annotator = None # Initialize annotator + self.tracks = None + self.track_data = None + self.boxes = [] + self.clss = [] + self.track_ids = [] + self.track_line = None + self.masks = None + self.r_s = None + + self.LOGGER = LOGGER # Store logger object to be used in multiple solution classes + + # Load config and update with args + DEFAULT_SOL_DICT.update(kwargs) + DEFAULT_CFG_DICT.update(kwargs) + self.CFG = {**DEFAULT_SOL_DICT, **DEFAULT_CFG_DICT} + self.LOGGER.info(f"Ultralytics Solutions: ✅ {DEFAULT_SOL_DICT}") + + self.region = self.CFG["region"] # Store region data for other classes usage + self.line_width = ( + self.CFG["line_width"] if self.CFG["line_width"] is not None else 2 + ) # Store line_width for usage + + # Load Model and store classes names + if self.CFG["model"] is None: + self.CFG["model"] = "yolo11n.pt" + self.model = YOLO(self.CFG["model"]) + self.names = self.model.names + self.classes = self.CFG["classes"] + + self.track_add_args = { # Tracker additional arguments for advance configuration + k: self.CFG[k] for k in ["iou", "conf", "device", "max_det", "half", "tracker", "device", "verbose"] + } # verbose must be passed to track method; setting it False in YOLO still logs the track information. + + if is_cli and self.CFG["source"] is None: + d_s = "solutions_ci_demo.mp4" if "-pose" not in self.CFG["model"] else "solution_ci_pose_demo.mp4" + self.LOGGER.warning(f"⚠️ WARNING: source not provided. 
using default source {ASSETS_URL}/{d_s}") + from ultralytics.utils.downloads import safe_download + + safe_download(f"{ASSETS_URL}/{d_s}") # download source from ultralytics assets + self.CFG["source"] = d_s # set default source + + # Initialize environment and region setup + self.env_check = check_imshow(warn=True) + self.track_history = defaultdict(list) + + def extract_tracks(self, im0): + """ + Applies object tracking and extracts tracks from an input image or frame. + + Args: + im0 (np.ndarray): The input image or frame. + + Examples: + >>> solution = BaseSolution() + >>> frame = cv2.imread("path/to/image.jpg") + >>> solution.extract_tracks(frame) + """ + self.tracks = self.model.track(source=im0, persist=True, classes=self.classes, **self.track_add_args) + self.track_data = self.tracks[0].obb or self.tracks[0].boxes # Extract tracks for OBB or object detection + + self.masks = ( + self.tracks[0].masks.xy if hasattr(self.tracks[0], "masks") and self.tracks[0].masks is not None else None + ) + + if self.track_data and self.track_data.id is not None: + self.boxes = self.track_data.xyxy.cpu() + self.clss = self.track_data.cls.cpu().tolist() + self.track_ids = self.track_data.id.int().cpu().tolist() + else: + self.LOGGER.warning("WARNING ⚠️ no tracks found!") + self.boxes, self.clss, self.track_ids = [], [], [] + + def store_tracking_history(self, track_id, box): + """ + Stores the tracking history of an object. + + This method updates the tracking history for a given object by appending the center point of its + bounding box to the track line. It maintains a maximum of 30 points in the tracking history. + + Args: + track_id (int): The unique identifier for the tracked object. + box (List[float]): The bounding box coordinates of the object in the format [x1, y1, x2, y2]. 
+ + Examples: + >>> solution = BaseSolution() + >>> solution.store_tracking_history(1, [100, 200, 300, 400]) + """ + # Store tracking history + self.track_line = self.track_history[track_id] + self.track_line.append(((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)) + if len(self.track_line) > 30: + self.track_line.pop(0) + + def initialize_region(self): + """Initialize the counting region and line segment based on configuration settings.""" + if self.region is None: + self.region = [(20, 400), (1080, 400), (1080, 360), (20, 360)] + self.r_s = ( + self.Polygon(self.region) if len(self.region) >= 3 else self.LineString(self.region) + ) # region or line + + def display_output(self, plot_im): + """ + Display the results of the processing, which could involve showing frames, printing counts, or saving results. + + This method is responsible for visualizing the output of the object detection and tracking process. It displays + the processed frame with annotations, and allows for user interaction to close the display. + + Args: + plot_im (numpy.ndarray): The image or frame that has been processed and annotated. + + Examples: + >>> solution = BaseSolution() + >>> frame = cv2.imread("path/to/image.jpg") + >>> solution.display_output(frame) + + Notes: + - This method will only display output if the 'show' configuration is set to True and the environment + supports image display. + - The display can be closed by pressing the 'q' key. 
+ """ + if self.CFG.get("show") and self.env_check: + cv2.imshow("Ultralytics Solutions", plot_im) + if cv2.waitKey(1) & 0xFF == ord("q"): + cv2.destroyAllWindows() # Closes current frame window + return + + def process(self, *args, **kwargs): + """Process method should be implemented by each Solution subclass.""" + + def __call__(self, *args, **kwargs): + """Allow instances to be called like a function with flexible arguments.""" + result = self.process(*args, **kwargs) # Call the subclass-specific process method + if self.CFG["verbose"]: # extract verbose value to display the output logs if True + LOGGER.info(f"🚀 Results: {result}") + return result + + +class SolutionAnnotator(Annotator): + """ + A specialized annotator class for visualizing and analyzing computer vision tasks. + + This class extends the base Annotator class, providing additional methods for drawing regions, centroids, tracking + trails, and visual annotations for Ultralytics Solutions: https://docs.ultralytics.com/solutions/. + and parking management. + + Attributes: + im (np.ndarray): The image being annotated. + line_width (int): Thickness of lines used in annotations. + font_size (int): Size of the font used for text annotations. + font (str): Path to the font file used for text rendering. + pil (bool): Whether to use PIL for text rendering. + example (str): An example attribute for demonstration purposes. + + Methods: + draw_region: Draws a region using specified points, colors, and thickness. + queue_counts_display: Displays queue counts in the specified region. + display_analytics: Displays overall statistics for parking lot management. + estimate_pose_angle: Calculates the angle between three points in an object pose. + draw_specific_points: Draws specific keypoints on the image. + plot_workout_information: Draws a labeled text box on the image. + plot_angle_and_count_and_stage: Visualizes angle, step count, and stage for workout monitoring. 
+ plot_distance_and_line: Displays the distance between centroids and connects them with a line. + display_objects_labels: Annotates bounding boxes with object class labels. + segmentation_mask: Draws mask for segmented objects and optionally labels them. + sweep_annotator: Visualizes a vertical sweep line and optional label. + visioneye: Maps and connects object centroids to a visual "eye" point. + circle_label: Draws a circular label within a bounding box. + text_label: Draws a rectangular label within a bounding box. + + Examples: + >>> annotator = SolutionAnnotator(image) + >>> annotator.draw_region([(0, 0), (100, 100)], color=(0, 255, 0), thickness=5) + >>> annotator.display_analytics( + ... image, text={"Available Spots": 5}, txt_color=(0, 0, 0), bg_color=(255, 255, 255), margin=10 + ... ) + """ + + def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): + """ + Initializes the SolutionAnnotator class with an image for annotation. + + Args: + im (np.ndarray): The image to be annotated. + line_width (int, optional): Line thickness for drawing on the image. + font_size (int, optional): Font size for text annotations. + font (str, optional): Path to the font file. + pil (bool, optional): Indicates whether to use PIL for rendering text. + example (str, optional): An example parameter for demonstration purposes. + """ + super().__init__(im, line_width, font_size, font, pil, example) + + def draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5): + """ + Draw a region or line on the image. + + Args: + reg_pts (List[Tuple[int, int]]): Region points (for line 2 points, for region 4+ points). + color (Tuple[int, int, int]): RGB color value for the region. + thickness (int): Line thickness for drawing the region. 
+ """ + cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness) + + # Draw small circles at the corner points + for point in reg_pts: + cv2.circle(self.im, (point[0], point[1]), thickness * 2, color, -1) # -1 fills the circle + + def queue_counts_display(self, label, points=None, region_color=(255, 255, 255), txt_color=(0, 0, 0)): + """ + Displays queue counts on an image centered at the points with customizable font size and colors. + + Args: + label (str): Queue counts label. + points (List[Tuple[int, int]]): Region points for center point calculation to display text. + region_color (Tuple[int, int, int]): RGB queue region color. + txt_color (Tuple[int, int, int]): RGB text display color. + """ + x_values = [point[0] for point in points] + y_values = [point[1] for point in points] + center_x = sum(x_values) // len(points) + center_y = sum(y_values) // len(points) + + text_size = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] + text_width = text_size[0] + text_height = text_size[1] + + rect_width = text_width + 20 + rect_height = text_height + 20 + rect_top_left = (center_x - rect_width // 2, center_y - rect_height // 2) + rect_bottom_right = (center_x + rect_width // 2, center_y + rect_height // 2) + cv2.rectangle(self.im, rect_top_left, rect_bottom_right, region_color, -1) + + text_x = center_x - text_width // 2 + text_y = center_y + text_height // 2 + + # Draw text + cv2.putText( + self.im, + label, + (text_x, text_y), + 0, + fontScale=self.sf, + color=txt_color, + thickness=self.tf, + lineType=cv2.LINE_AA, + ) + + def display_analytics(self, im0, text, txt_color, bg_color, margin): + """ + Display the overall statistics for parking lots, object counter etc. + + Args: + im0 (np.ndarray): Inference image. + text (Dict[str, Any]): Labels dictionary. + txt_color (Tuple[int, int, int]): Display color for text foreground. + bg_color (Tuple[int, int, int]): Display color for text background. 
+ margin (int): Gap between text and rectangle for better display. + """ + horizontal_gap = int(im0.shape[1] * 0.02) + vertical_gap = int(im0.shape[0] * 0.01) + text_y_offset = 0 + for label, value in text.items(): + txt = f"{label}: {value}" + text_size = cv2.getTextSize(txt, 0, self.sf, self.tf)[0] + if text_size[0] < 5 or text_size[1] < 5: + text_size = (5, 5) + text_x = im0.shape[1] - text_size[0] - margin * 2 - horizontal_gap + text_y = text_y_offset + text_size[1] + margin * 2 + vertical_gap + rect_x1 = text_x - margin * 2 + rect_y1 = text_y - text_size[1] - margin * 2 + rect_x2 = text_x + text_size[0] + margin * 2 + rect_y2 = text_y + margin * 2 + cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1) + cv2.putText(im0, txt, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA) + text_y_offset = rect_y2 + + @staticmethod + def estimate_pose_angle(a, b, c): + """ + Calculate the angle between three points for workout monitoring. + + Args: + a (List[float]): The coordinates of the first point. + b (List[float]): The coordinates of the second point (vertex). + c (List[float]): The coordinates of the third point. + + Returns: + (float): The angle in degrees between the three points. + """ + a, b, c = np.array(a), np.array(b), np.array(c) + radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0]) + angle = np.abs(radians * 180.0 / np.pi) + if angle > 180.0: + angle = 360 - angle + return angle + + def draw_specific_kpts(self, keypoints, indices=None, radius=2, conf_thresh=0.25): + """ + Draw specific keypoints for gym steps counting. + + Args: + keypoints (List[List[float]]): Keypoints data to be plotted, each in format [x, y, confidence]. + indices (List[int], optional): Keypoint indices to be plotted. + radius (int, optional): Keypoint radius. + conf_thresh (float, optional): Confidence threshold for keypoints. + + Returns: + (np.ndarray): Image with drawn keypoints. 
+ + Note: + Keypoint format: [x, y] or [x, y, confidence]. + Modifies self.im in-place. + """ + indices = indices or [2, 5, 7] + points = [(int(k[0]), int(k[1])) for i, k in enumerate(keypoints) if i in indices and k[2] >= conf_thresh] + + # Draw lines between consecutive points + for start, end in zip(points[:-1], points[1:]): + cv2.line(self.im, start, end, (0, 255, 0), 2, lineType=cv2.LINE_AA) + + # Draw circles for keypoints + for pt in points: + cv2.circle(self.im, pt, radius, (0, 0, 255), -1, lineType=cv2.LINE_AA) + + return self.im + + def plot_workout_information(self, display_text, position, color=(104, 31, 17), txt_color=(255, 255, 255)): + """ + Draw workout text with a background on the image. + + Args: + display_text (str): The text to be displayed. + position (Tuple[int, int]): Coordinates (x, y) on the image where the text will be placed. + color (Tuple[int, int, int], optional): Text background color. + txt_color (Tuple[int, int, int], optional): Text foreground color. + + Returns: + (int): The height of the text. + """ + (text_width, text_height), _ = cv2.getTextSize(display_text, 0, self.sf, self.tf) + + # Draw background rectangle + cv2.rectangle( + self.im, + (position[0], position[1] - text_height - 5), + (position[0] + text_width + 10, position[1] - text_height - 5 + text_height + 10 + self.tf), + color, + -1, + ) + # Draw text + cv2.putText(self.im, display_text, position, 0, self.sf, txt_color, self.tf) + + return text_height + + def plot_angle_and_count_and_stage( + self, angle_text, count_text, stage_text, center_kpt, color=(104, 31, 17), txt_color=(255, 255, 255) + ): + """ + Plot the pose angle, count value, and step stage for workout monitoring. + + Args: + angle_text (str): Angle value for workout monitoring. + count_text (str): Counts value for workout monitoring. + stage_text (str): Stage decision for workout monitoring. + center_kpt (List[int]): Centroid pose index for workout monitoring. 
+ color (Tuple[int, int, int], optional): Text background color. + txt_color (Tuple[int, int, int], optional): Text foreground color. + """ + # Format text + angle_text, count_text, stage_text = f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}" + + # Draw angle, count and stage text + angle_height = self.plot_workout_information( + angle_text, (int(center_kpt[0]), int(center_kpt[1])), color, txt_color + ) + count_height = self.plot_workout_information( + count_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + 20), color, txt_color + ) + self.plot_workout_information( + stage_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + count_height + 40), color, txt_color + ) + + def plot_distance_and_line( + self, pixels_distance, centroids, line_color=(104, 31, 17), centroid_color=(255, 0, 255) + ): + """ + Plot the distance and line between two centroids on the frame. + + Args: + pixels_distance (float): Pixels distance between two bbox centroids. + centroids (List[Tuple[int, int]]): Bounding box centroids data. + line_color (Tuple[int, int, int], optional): Distance line color. + centroid_color (Tuple[int, int, int], optional): Bounding box centroid color. 
+ """ + # Get the text size + text = f"Pixels Distance: {pixels_distance:.2f}" + (text_width_m, text_height_m), _ = cv2.getTextSize(text, 0, self.sf, self.tf) + + # Define corners with 10-pixel margin and draw rectangle + cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 20, 25 + text_height_m + 20), line_color, -1) + + # Calculate the position for the text with a 10-pixel margin and draw text + text_position = (25, 25 + text_height_m + 10) + cv2.putText( + self.im, + text, + text_position, + 0, + self.sf, + (255, 255, 255), + self.tf, + cv2.LINE_AA, + ) + + cv2.line(self.im, centroids[0], centroids[1], line_color, 3) + cv2.circle(self.im, centroids[0], 6, centroid_color, -1) + cv2.circle(self.im, centroids[1], 6, centroid_color, -1) + + def display_objects_labels(self, im0, text, txt_color, bg_color, x_center, y_center, margin): + """ + Display the bounding boxes labels in parking management app. + + Args: + im0 (np.ndarray): Inference image. + text (str): Object/class name. + txt_color (Tuple[int, int, int]): Display color for text foreground. + bg_color (Tuple[int, int, int]): Display color for text background. + x_center (float): The x position center point for bounding box. + y_center (float): The y position center point for bounding box. + margin (int): The gap between text and rectangle for better display. 
+ """ + text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0] + text_x = x_center - text_size[0] // 2 + text_y = y_center + text_size[1] // 2 + + rect_x1 = text_x - margin + rect_y1 = text_y - text_size[1] - margin + rect_x2 = text_x + text_size[0] + margin + rect_y2 = text_y + margin + cv2.rectangle( + im0, + (int(rect_x1), int(rect_y1)), + (int(rect_x2), int(rect_y2)), + tuple(map(int, bg_color)), # Ensure color values are int + -1, + ) + + cv2.putText( + im0, + text, + (int(text_x), int(text_y)), + 0, + self.sf, + tuple(map(int, txt_color)), # Ensure color values are int + self.tf, + lineType=cv2.LINE_AA, + ) + + def segmentation_mask(self, mask, mask_color=(255, 0, 255), label=None, alpha=0.5): + """ + Draw an optimized segmentation mask with smooth corners, highlighted edge, and dynamic text box size. + + Args: + mask (np.ndarray): A 2D array of shape (N, 2) containing the object mask. + mask_color (Tuple[int, int, int]): RGB color for the mask. + label (str, optional): Text label for the object. + alpha (float): Transparency level (0 = fully transparent, 1 = fully opaque). 
+ """ + if mask.size == 0: + return + + overlay = self.im.copy() + mask = np.int32([mask]) + + # Approximate polygon for smooth corners with epsilon + refined_mask = cv2.approxPolyDP(mask, 0.002 * cv2.arcLength(mask, True), True) + + # Apply a highlighter effect by drawing a thick outer shadow + cv2.polylines(overlay, [refined_mask], isClosed=True, color=mask_color, thickness=self.lw * 3) + cv2.fillPoly(overlay, [refined_mask], mask_color) # draw mask with primary color + + # Apply an inner glow effect for extra clarity + cv2.polylines(overlay, [refined_mask], isClosed=True, color=mask_color, thickness=self.lw) + + self.im = cv2.addWeighted(overlay, alpha, self.im, 1 - alpha, 0) # blend overlay with the original image + + # Draw label if provided + if label: + text_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf, self.tf) + text_x, text_y = refined_mask[0][0][0], refined_mask[0][0][1] + rect_start, rect_end = (text_x - 5, text_y - text_size[1] - 5), (text_x + text_size[0] + 5, text_y + 5) + cv2.rectangle(self.im, rect_start, rect_end, mask_color, -1) + cv2.putText( + self.im, + label, + (text_x, text_y), + cv2.FONT_HERSHEY_SIMPLEX, + self.sf, + self.get_txt_color(mask_color), + self.tf, + ) + + def sweep_annotator(self, line_x=0, line_y=0, label=None, color=(221, 0, 186), txt_color=(255, 255, 255)): + """ + Draw a sweep annotation line and an optional label. + + Args: + line_x (int): The x-coordinate of the sweep line. + line_y (int): The y-coordinate limit of the sweep line. + label (str, optional): Text label to be drawn in center of sweep line. If None, no label is drawn. + color (Tuple[int, int, int]): RGB color for the line and label background. + txt_color (Tuple[int, int, int]): RGB color for the label text. 
+ """ + # Draw the sweep line + cv2.line(self.im, (line_x, 0), (line_x, line_y), color, self.tf * 2) + + # Draw label, if provided + if label: + (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf, self.tf) + cv2.rectangle( + self.im, + (line_x - text_width // 2 - 10, line_y // 2 - text_height // 2 - 10), + (line_x + text_width // 2 + 10, line_y // 2 + text_height // 2 + 10), + color, + -1, + ) + cv2.putText( + self.im, + label, + (line_x - text_width // 2, line_y // 2 + text_height // 2), + cv2.FONT_HERSHEY_SIMPLEX, + self.sf, + txt_color, + self.tf, + ) + + def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255)): + """ + Perform pinpoint human-vision eye mapping and plotting. + + Args: + box (List[float]): Bounding box coordinates in format [x1, y1, x2, y2]. + center_point (Tuple[int, int]): Center point for vision eye view. + color (Tuple[int, int, int]): Object centroid and line color. + pin_color (Tuple[int, int, int]): Visioneye point color. + """ + center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2) + cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1) + cv2.circle(self.im, center_bbox, self.tf * 2, color, -1) + cv2.line(self.im, center_point, center_bbox, color, self.tf) + + def circle_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=2): + """ + Draw a label with a background circle centered within a given bounding box. + + Args: + box (Tuple[float, float, float, float]): The bounding box coordinates (x1, y1, x2, y2). + label (str): The text label to be displayed. + color (Tuple[int, int, int]): The background color of the circle (B, G, R). + txt_color (Tuple[int, int, int]): The color of the text (R, G, B). + margin (int): The margin between the text and the circle border. 
+ """ + # If label have more than 3 characters, skip other characters, due to circle size + if len(label) > 3: + print( + f"Length of label is {len(label)}, initial 3 label characters will be considered for circle annotation!" + ) + label = label[:3] + + # Calculate the center of the box + x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2) + # Get the text size + text_size = cv2.getTextSize(str(label), cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0] + # Calculate the required radius to fit the text with the margin + required_radius = int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin + # Draw the circle with the required radius + cv2.circle(self.im, (x_center, y_center), required_radius, color, -1) + # Calculate the position for the text + text_x = x_center - text_size[0] // 2 + text_y = y_center + text_size[1] // 2 + # Draw the text + cv2.putText( + self.im, + str(label), + (text_x, text_y), + cv2.FONT_HERSHEY_SIMPLEX, + self.sf - 0.15, + self.get_txt_color(color, txt_color), + self.tf, + lineType=cv2.LINE_AA, + ) + + def text_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=5): + """ + Draw a label with a background rectangle centered within a given bounding box. + + Args: + box (Tuple[float, float, float, float]): The bounding box coordinates (x1, y1, x2, y2). + label (str): The text label to be displayed. + color (Tuple[int, int, int]): The background color of the rectangle (B, G, R). + txt_color (Tuple[int, int, int]): The color of the text (R, G, B). + margin (int): The margin between the text and the rectangle border. 
+ """ + # Calculate the center of the bounding box + x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2) + # Get the size of the text + text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.1, self.tf)[0] + # Calculate the top-left corner of the text (to center it) + text_x = x_center - text_size[0] // 2 + text_y = y_center + text_size[1] // 2 + # Calculate the coordinates of the background rectangle + rect_x1 = text_x - margin + rect_y1 = text_y - text_size[1] - margin + rect_x2 = text_x + text_size[0] + margin + rect_y2 = text_y + margin + # Draw the background rectangle + cv2.rectangle(self.im, (rect_x1, rect_y1), (rect_x2, rect_y2), color, -1) + # Draw the text on top of the rectangle + cv2.putText( + self.im, + label, + (text_x, text_y), + cv2.FONT_HERSHEY_SIMPLEX, + self.sf - 0.1, + self.get_txt_color(color, txt_color), + self.tf, + lineType=cv2.LINE_AA, + ) + + +class SolutionResults: + """ + A class to encapsulate the results of Ultralytics Solutions. + + This class is designed to store and manage various outputs generated by the solution pipeline, including counts, + angles, and workout stages. + + Attributes: + plot_im (np.ndarray): Processed image with counts, blurred, or other effects from solutions. + in_count (int): The total number of "in" counts in a video stream. + out_count (int): The total number of "out" counts in a video stream. + classwise_count (Dict[str, int]): A dictionary containing counts of objects categorized by class. + queue_count (int): The count of objects in a queue or waiting area. + workout_count (int): The count of workout repetitions. + workout_angle (float): The angle calculated during a workout exercise. + workout_stage (str): The current stage of the workout. + pixels_distance (float): The calculated distance in pixels between two points or objects. + available_slots (int): The number of available slots in a monitored area. 
+ filled_slots (int): The number of filled slots in a monitored area. + email_sent (bool): A flag indicating whether an email notification was sent. + total_tracks (int): The total number of tracked objects. + region_counts (dict): The count of objects within a specific region. + speed_dict (Dict[str, float]): A dictionary containing speed information for tracked objects. + total_crop_objects (int): Total number of cropped objects using ObjectCropper class. + """ + + def __init__(self, **kwargs): + """ + Initialize a SolutionResults object with default or user-specified values. + + Args: + **kwargs (Any): Optional arguments to override default attribute values. + """ + self.plot_im = None + self.in_count = 0 + self.out_count = 0 + self.classwise_count = {} + self.queue_count = 0 + self.workout_count = 0 + self.workout_angle = 0.0 + self.workout_stage = None + self.pixels_distance = 0.0 + self.available_slots = 0 + self.filled_slots = 0 + self.email_sent = False + self.total_tracks = 0 + self.region_counts = {} + self.speed_dict = {} + self.total_crop_objects = 0 + + # Override with user-defined values + self.__dict__.update(kwargs) + + def __str__(self): + """ + Return a formatted string representation of the SolutionResults object. + + Returns: + (str): A string representation listing non-null attributes. 
+ """ + attrs = { + k: v + for k, v in self.__dict__.items() + if k != "plot_im" and v not in [None, {}, 0, 0.0, False] # Exclude `plot_im` explicitly + } + return f"SolutionResults({', '.join(f'{k}={v}' for k, v in attrs.items())})" diff --git a/tracking/ultralytics/solutions/speed_estimation.py b/tracking/ultralytics/solutions/speed_estimation.py new file mode 100644 index 0000000000000000000000000000000000000000..f477de19215e61dda80463247135ed19c7cd6cbf --- /dev/null +++ b/tracking/ultralytics/solutions/speed_estimation.py @@ -0,0 +1,113 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from time import time + +import numpy as np + +from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults +from ultralytics.utils.plotting import colors + + +class SpeedEstimator(BaseSolution): + """ + A class to estimate the speed of objects in a real-time video stream based on their tracks. + + This class extends the BaseSolution class and provides functionality for estimating object speeds using + tracking data in video streams. + + Attributes: + spd (Dict[int, float]): Dictionary storing speed data for tracked objects. + trkd_ids (List[int]): List of tracked object IDs that have already been speed-estimated. + trk_pt (Dict[int, float]): Dictionary storing previous timestamps for tracked objects. + trk_pp (Dict[int, Tuple[float, float]]): Dictionary storing previous positions for tracked objects. + region (List[Tuple[int, int]]): List of points defining the speed estimation region. + track_line (List[Tuple[float, float]]): List of points representing the object's track. + r_s (LineString): LineString object representing the speed estimation region. + + Methods: + initialize_region: Initializes the speed estimation region. + process: Processes input frames to estimate object speeds. + store_tracking_history: Stores the tracking history for an object. + extract_tracks: Extracts tracks from the current frame. 
+ display_output: Displays the output with annotations. + + Examples: + >>> estimator = SpeedEstimator() + >>> frame = cv2.imread("frame.jpg") + >>> results = estimator.process(frame) + >>> cv2.imshow("Speed Estimation", results.plot_im) + """ + + def __init__(self, **kwargs): + """ + Initialize the SpeedEstimator object with speed estimation parameters and data structures. + + Args: + **kwargs (Any): Additional keyword arguments passed to the parent class. + """ + super().__init__(**kwargs) + + self.initialize_region() # Initialize speed region + + self.spd = {} # Dictionary for speed data + self.trkd_ids = [] # List for already speed-estimated and tracked IDs + self.trk_pt = {} # Dictionary for tracks' previous timestamps + self.trk_pp = {} # Dictionary for tracks' previous positions + + def process(self, im0): + """ + Process an input frame to estimate object speeds based on tracking data. + + Args: + im0 (np.ndarray): Input image for processing with shape (H, W, C) for RGB images. + + Returns: + (SolutionResults): Contains processed image `plot_im` and `total_tracks` (number of tracked objects). 
+ + Examples: + >>> estimator = SpeedEstimator() + >>> image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) + >>> results = estimator.process(image) + """ + self.extract_tracks(im0) # Extract tracks + annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator + + # Draw speed estimation region + annotator.draw_region(reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2) + + for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss): + self.store_tracking_history(track_id, box) # Store track history + + # Initialize tracking data for new objects + if track_id not in self.trk_pt: + self.trk_pt[track_id] = 0 + if track_id not in self.trk_pp: + self.trk_pp[track_id] = self.track_line[-1] + + # Prepare label with speed if available, otherwise use class name + speed_label = f"{int(self.spd[track_id])} km/h" if track_id in self.spd else self.names[int(cls)] + annotator.box_label(box, label=speed_label, color=colors(track_id, True)) # Draw bounding box + + # Determine if object is crossing the speed estimation region + if self.LineString([self.trk_pp[track_id], self.track_line[-1]]).intersects(self.r_s): + direction = "known" + else: + direction = "unknown" + + # Calculate speed for objects crossing the region for the first time + if direction == "known" and track_id not in self.trkd_ids: + self.trkd_ids.append(track_id) + time_difference = time() - self.trk_pt[track_id] + if time_difference > 0: + # Calculate speed based on vertical displacement and time + self.spd[track_id] = np.abs(self.track_line[-1][1] - self.trk_pp[track_id][1]) / time_difference + + # Update tracking data for next frame + self.trk_pt[track_id] = time() + self.trk_pp[track_id] = self.track_line[-1] + + plot_im = annotator.result() + self.display_output(plot_im) # Display output with base class function + + # Return results with processed image and tracking summary + return SolutionResults(plot_im=plot_im, 
total_tracks=len(self.track_ids)) diff --git a/tracking/ultralytics/solutions/streamlit_inference.py b/tracking/ultralytics/solutions/streamlit_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..1fbbca663e3b1bff3204bb33b5b663404b260559 --- /dev/null +++ b/tracking/ultralytics/solutions/streamlit_inference.py @@ -0,0 +1,196 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import io +from typing import Any + +import cv2 + +from ultralytics import YOLO +from ultralytics.utils import LOGGER +from ultralytics.utils.checks import check_requirements +from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS + + +class Inference: + """ + A class to perform object detection, image classification, image segmentation and pose estimation inference. + + This class provides functionalities for loading models, configuring settings, uploading video files, and performing + real-time inference using Streamlit and Ultralytics YOLO models. + + Attributes: + st (module): Streamlit module for UI creation. + temp_dict (dict): Temporary dictionary to store the model path and other configuration. + model_path (str): Path to the loaded model. + model (YOLO): The YOLO model instance. + source (str): Selected video source (webcam or video file). + enable_trk (str): Enable tracking option ("Yes" or "No"). + conf (float): Confidence threshold for detection. + iou (float): IoU threshold for non-maximum suppression. + org_frame (Any): Container for the original frame to be displayed. + ann_frame (Any): Container for the annotated frame to be displayed. + vid_file_name (str | int): Name of the uploaded video file or webcam index. + selected_ind (List[int]): List of selected class indices for detection. + + Methods: + web_ui: Sets up the Streamlit web interface with custom HTML elements. + sidebar: Configures the Streamlit sidebar for model and inference settings. + source_upload: Handles video file uploads through the Streamlit interface. 
+ configure: Configures the model and loads selected classes for inference. + inference: Performs real-time object detection inference. + + Examples: + >>> inf = Inference(model="path/to/model.pt") # Model is an optional argument + >>> inf.inference() + """ + + def __init__(self, **kwargs: Any): + """ + Initialize the Inference class, checking Streamlit requirements and setting up the model path. + + Args: + **kwargs (Any): Additional keyword arguments for model configuration. + """ + check_requirements("streamlit>=1.29.0") # scope imports for faster ultralytics package load speeds + import streamlit as st + + self.st = st # Reference to the Streamlit module + self.source = None # Video source selection (webcam or video file) + self.enable_trk = False # Flag to toggle object tracking + self.conf = 0.25 # Confidence threshold for detection + self.iou = 0.45 # Intersection-over-Union (IoU) threshold for non-maximum suppression + self.org_frame = None # Container for the original frame display + self.ann_frame = None # Container for the annotated frame display + self.vid_file_name = None # Video file name or webcam index + self.selected_ind = [] # List of selected class indices for detection + self.model = None # YOLO model instance + + self.temp_dict = {"model": None, **kwargs} + self.model_path = None # Model file path + if self.temp_dict["model"] is not None: + self.model_path = self.temp_dict["model"] + + LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}") + + def web_ui(self): + """Sets up the Streamlit web interface with custom HTML elements.""" + menu_style_cfg = """""" # Hide main menu style + + # Main title of streamlit application + main_title_cfg = """

Ultralytics YOLO Streamlit Application

""" + + # Subtitle of streamlit application + sub_title_cfg = """

Experience real-time object detection on your webcam with the power + of Ultralytics YOLO! 🚀

""" + + # Set html page configuration and append custom HTML + self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide") + self.st.markdown(menu_style_cfg, unsafe_allow_html=True) + self.st.markdown(main_title_cfg, unsafe_allow_html=True) + self.st.markdown(sub_title_cfg, unsafe_allow_html=True) + + def sidebar(self): + """Configure the Streamlit sidebar for model and inference settings.""" + with self.st.sidebar: # Add Ultralytics LOGO + logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg" + self.st.image(logo, width=250) + + self.st.sidebar.title("User Configuration") # Add elements to vertical setting menu + self.source = self.st.sidebar.selectbox( + "Video", + ("webcam", "video"), + ) # Add source selection dropdown + self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) # Enable object tracking + self.conf = float( + self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01) + ) # Slider for confidence + self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01)) # Slider for NMS threshold + + col1, col2 = self.st.columns(2) # Create two columns for displaying frames + self.org_frame = col1.empty() # Container for original frame + self.ann_frame = col2.empty() # Container for annotated frame + + def source_upload(self): + """Handle video file uploads through the Streamlit interface.""" + self.vid_file_name = "" + if self.source == "video": + vid_file = self.st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"]) + if vid_file is not None: + g = io.BytesIO(vid_file.read()) # BytesIO Object + with open("ultralytics.mp4", "wb") as out: # Open temporary file as bytes + out.write(g.read()) # Read bytes into file + self.vid_file_name = "ultralytics.mp4" + elif self.source == "webcam": + self.vid_file_name = 0 # Use webcam index 0 + + def configure(self): + """Configure the model and load selected classes for 
inference.""" + # Add dropdown menu for model selection + available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")] + if self.model_path: # If user provided the custom model, insert model without suffix as *.pt is added later + available_models.insert(0, self.model_path.split(".pt")[0]) + selected_model = self.st.sidebar.selectbox("Model", available_models) + + with self.st.spinner("Model is downloading..."): + self.model = YOLO(f"{selected_model.lower()}.pt") # Load the YOLO model + class_names = list(self.model.names.values()) # Convert dictionary to list of class names + self.st.success("Model loaded successfully!") + + # Multiselect box with class names and get indices of selected classes + selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3]) + self.selected_ind = [class_names.index(option) for option in selected_classes] + + if not isinstance(self.selected_ind, list): # Ensure selected_options is a list + self.selected_ind = list(self.selected_ind) + + def inference(self): + """Perform real-time object detection inference on video or webcam feed.""" + self.web_ui() # Initialize the web interface + self.sidebar() # Create the sidebar + self.source_upload() # Upload the video source + self.configure() # Configure the app + + if self.st.sidebar.button("Start"): + stop_button = self.st.button("Stop") # Button to stop the inference + cap = cv2.VideoCapture(self.vid_file_name) # Capture the video + if not cap.isOpened(): + self.st.error("Could not open webcam or video source.") + return + + while cap.isOpened(): + success, frame = cap.read() + if not success: + self.st.warning("Failed to read frame from webcam. 
Please verify the webcam is connected properly.") + break + + # Process frame with model + if self.enable_trk == "Yes": + results = self.model.track( + frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True + ) + else: + results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind) + + annotated_frame = results[0].plot() # Add annotations on frame + + if stop_button: + cap.release() # Release the capture + self.st.stop() # Stop streamlit app + + self.org_frame.image(frame, channels="BGR") # Display original frame + self.ann_frame.image(annotated_frame, channels="BGR") # Display processed frame + + cap.release() # Release the capture + cv2.destroyAllWindows() # Destroy all OpenCV windows + + +if __name__ == "__main__": + import sys # Import the sys module for accessing command-line arguments + + # Check if a model name is provided as a command-line argument + args = len(sys.argv) + model = sys.argv[1] if args > 1 else None # Assign first argument as the model name if provided + # Create an instance of the Inference class and run inference + Inference(model=model).inference() diff --git a/tracking/ultralytics/solutions/trackzone.py b/tracking/ultralytics/solutions/trackzone.py new file mode 100644 index 0000000000000000000000000000000000000000..74bb7fdb8e13220f39029b74e40c81a7ea73a818 --- /dev/null +++ b/tracking/ultralytics/solutions/trackzone.py @@ -0,0 +1,86 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import cv2 +import numpy as np + +from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults +from ultralytics.utils.plotting import colors + + +class TrackZone(BaseSolution): + """ + A class to manage region-based object tracking in a video stream. + + This class extends the BaseSolution class and provides functionality for tracking objects within a specific region + defined by a polygonal area. Objects outside the region are excluded from tracking. 
+ + Attributes: + region (np.ndarray): The polygonal region for tracking, represented as a convex hull of points. + line_width (int): Width of the lines used for drawing bounding boxes and region boundaries. + names (List[str]): List of class names that the model can detect. + boxes (List[np.ndarray]): Bounding boxes of tracked objects. + track_ids (List[int]): Unique identifiers for each tracked object. + clss (List[int]): Class indices of tracked objects. + + Methods: + process: Processes each frame of the video, applying region-based tracking. + extract_tracks: Extracts tracking information from the input frame. + display_output: Displays the processed output. + + Examples: + >>> tracker = TrackZone() + >>> frame = cv2.imread("frame.jpg") + >>> results = tracker.process(frame) + >>> cv2.imshow("Tracked Frame", results.plot_im) + """ + + def __init__(self, **kwargs): + """ + Initialize the TrackZone class for tracking objects within a defined region in video streams. + + Args: + **kwargs (Any): Additional keyword arguments passed to the parent class. + """ + super().__init__(**kwargs) + default_region = [(150, 150), (1130, 150), (1130, 570), (150, 570)] + self.region = cv2.convexHull(np.array(self.region or default_region, dtype=np.int32)) + + def process(self, im0): + """ + Process the input frame to track objects within a defined region. + + This method initializes the annotator, creates a mask for the specified region, extracts tracks + only from the masked area, and updates tracking information. Objects outside the region are ignored. + + Args: + im0 (np.ndarray): The input image or frame to be processed. + + Returns: + (SolutionResults): Contains processed image `plot_im` and `total_tracks` (int) representing the + total number of tracked objects within the defined region. 
+ + Examples: + >>> tracker = TrackZone() + >>> frame = cv2.imread("path/to/image.jpg") + >>> results = tracker.process(frame) + """ + annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator + + # Create a mask for the region and extract tracks from the masked image + mask = np.zeros_like(im0[:, :, 0]) + mask = cv2.fillPoly(mask, [self.region], 255) + masked_frame = cv2.bitwise_and(im0, im0, mask=mask) + self.extract_tracks(masked_frame) + + # Draw the region boundary + cv2.polylines(im0, [self.region], isClosed=True, color=(255, 255, 255), thickness=self.line_width * 2) + + # Iterate over boxes, track ids, classes indexes list and draw bounding boxes + for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss): + annotator.box_label(box, label=f"{self.names[cls]}:{track_id}", color=colors(track_id, True)) + + plot_im = annotator.result() + self.display_output(plot_im) # display output with base class function + + # Return a SolutionResults + return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids)) diff --git a/tracking/ultralytics/solutions/vision_eye.py b/tracking/ultralytics/solutions/vision_eye.py new file mode 100644 index 0000000000000000000000000000000000000000..282dccb379ac24e0ff049a1159631c1c9e2cd111 --- /dev/null +++ b/tracking/ultralytics/solutions/vision_eye.py @@ -0,0 +1,69 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + + +from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults +from ultralytics.utils.plotting import colors + + +class VisionEye(BaseSolution): + """ + A class to manage object detection and vision mapping in images or video streams. + + This class extends the BaseSolution class and provides functionality for detecting objects, + mapping vision points, and annotating results with bounding boxes and labels. + + Attributes: + vision_point (Tuple[int, int]): Coordinates (x, y) where vision will view objects and draw tracks. 
+ + Methods: + process: Process the input image to detect objects, annotate them, and apply vision mapping. + + Examples: + >>> vision_eye = VisionEye() + >>> frame = cv2.imread("frame.jpg") + >>> results = vision_eye.process(frame) + >>> print(f"Total detected instances: {results.total_tracks}") + """ + + def __init__(self, **kwargs): + """ + Initialize the VisionEye class for detecting objects and applying vision mapping. + + Args: + **kwargs (Any): Keyword arguments passed to the parent class and for configuring vision_point. + """ + super().__init__(**kwargs) + # Set the vision point where the system will view objects and draw tracks + self.vision_point = kwargs.get("vision_point", (30, 30)) + + def process(self, im0): + """ + Perform object detection, vision mapping, and annotation on the input image. + + Args: + im0 (numpy.ndarray): The input image for detection and annotation. + + Returns: + (SolutionResults): Object containing the annotated image and tracking statistics. + - plot_im: Annotated output image with bounding boxes and vision mapping + - total_tracks: Number of tracked objects in the frame + + Examples: + >>> vision_eye = VisionEye() + >>> frame = cv2.imread("image.jpg") + >>> results = vision_eye.process(frame) + >>> print(f"Detected {results.total_tracks} objects") + """ + self.extract_tracks(im0) # Extract tracks (bounding boxes, classes, and masks) + annotator = SolutionAnnotator(im0, self.line_width) + + for cls, t_id, box in zip(self.clss, self.track_ids, self.boxes): + # Annotate the image with bounding boxes, labels, and vision mapping + annotator.box_label(box, label=self.names[cls], color=colors(int(t_id), True)) + annotator.visioneye(box, self.vision_point) + + plot_im = annotator.result() + self.display_output(plot_im) # Display the annotated output using the base class function + + # Return a SolutionResults object with the annotated image and tracking statistics + return SolutionResults(plot_im=plot_im, 
total_tracks=len(self.track_ids)) diff --git a/tracking/ultralytics/trackers/README.md b/tracking/ultralytics/trackers/README.md new file mode 100644 index 0000000000000000000000000000000000000000..18d45e8c81258263d418469a6f4788f8cf5d0b73 --- /dev/null +++ b/tracking/ultralytics/trackers/README.md @@ -0,0 +1,314 @@ +# Multi-Object Tracking with Ultralytics YOLO + +YOLOv8 trackers visualization + +Object tracking in the realm of video analytics is a critical task that not only identifies the location and class of objects within the frame but also maintains a unique ID for each detected object as the video progresses. The applications are limitless—ranging from surveillance and security to real-time sports analytics. + +## Why Choose Ultralytics YOLO for Object Tracking? + +The output from Ultralytics trackers is consistent with standard object detection but has the added value of object IDs. This makes it easy to track objects in video streams and perform subsequent analytics. Here's why you should consider using Ultralytics YOLO for your object tracking needs: + +- **Efficiency:** Process video streams in real-time without compromising accuracy. +- **Flexibility:** Supports multiple tracking algorithms and configurations. +- **Ease of Use:** Simple Python API and CLI options for quick integration and deployment. +- **Customizability:** Easy to use with custom trained YOLO models, allowing integration into domain-specific applications. + +**Video Tutorial:** [Object Detection and Tracking with Ultralytics YOLO](https://www.youtube.com/embed/hHyHmOtmEgs?si=VNZtXmm45Nb9s-N-). + +## Features at a Glance + +Ultralytics YOLO extends its object detection features to provide robust and versatile object tracking: + +- **Real-Time Tracking:** Seamlessly track objects in high-frame-rate videos. +- **Multiple Tracker Support:** Choose from a variety of established tracking algorithms. 
+- **Customizable Tracker Configurations:** Tailor the tracking algorithm to meet specific requirements by adjusting various parameters. + +## Available Trackers + +Ultralytics YOLO supports the following tracking algorithms. They can be enabled by passing the relevant YAML configuration file such as `tracker=tracker_type.yaml`: + +- [BoT-SORT](https://github.com/NirAharon/BoT-SORT) - Use `botsort.yaml` to enable this tracker. +- [ByteTrack](https://github.com/ifzhang/ByteTrack) - Use `bytetrack.yaml` to enable this tracker. + +The default tracker is BoT-SORT. + +## Tracking + +To run the tracker on video streams, use a trained Detect, Segment or Pose model such as YOLO11n, YOLO11n-seg and YOLO11n-pose. + +#### Python + +```python +from ultralytics import YOLO + +# Load an official or custom model +model = YOLO("yolo11n.pt") # Load an official Detect model +model = YOLO("yolo11n-seg.pt") # Load an official Segment model +model = YOLO("yolo11n-pose.pt") # Load an official Pose model +model = YOLO("path/to/best.pt") # Load a custom trained model + +# Perform tracking with the model +results = model.track(source="https://youtu.be/LNwODJXcvt4", show=True) # Tracking with default tracker +results = model.track( + source="https://youtu.be/LNwODJXcvt4", show=True, tracker="bytetrack.yaml" +) # Tracking with ByteTrack tracker +``` + +#### CLI + +```bash +# Perform tracking with various models using the command line interface +yolo track model=yolo11n.pt source="https://youtu.be/LNwODJXcvt4" # Official Detect model +yolo track model=yolo11n-seg.pt source="https://youtu.be/LNwODJXcvt4" # Official Segment model +yolo track model=yolo11n-pose.pt source="https://youtu.be/LNwODJXcvt4" # Official Pose model +yolo track model=path/to/best.pt source="https://youtu.be/LNwODJXcvt4" # Custom trained model + +# Track using ByteTrack tracker +yolo track model=path/to/best.pt tracker="bytetrack.yaml" +``` + +As can be seen in the above usage, tracking is available for all Detect, Segment 
and Pose models run on videos or streaming sources. + +## Configuration + +### Tracking Arguments + +Tracking configuration shares properties with Predict mode, such as `conf`, `iou`, and `show`. For further configurations, refer to the [Predict](https://docs.ultralytics.com/modes/predict/) model page. + +#### Python + +```python +from ultralytics import YOLO + +# Configure the tracking parameters and run the tracker +model = YOLO("yolo11n.pt") +results = model.track(source="https://youtu.be/LNwODJXcvt4", conf=0.3, iou=0.5, show=True) +``` + +#### CLI + +```bash +# Configure tracking parameters and run the tracker using the command line interface +yolo track model=yolo11n.pt source="https://youtu.be/LNwODJXcvt4" conf=0.3, iou=0.5 show +``` + +### Tracker Selection + +Ultralytics also allows you to use a modified tracker configuration file. To do this, simply make a copy of a tracker config file (for example, `custom_tracker.yaml`) from [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers) and modify any configurations (except the `tracker_type`) as per your needs. + +#### Python + +```python +from ultralytics import YOLO + +# Load the model and run the tracker with a custom configuration file +model = YOLO("yolo11n.pt") +results = model.track(source="https://youtu.be/LNwODJXcvt4", tracker="custom_tracker.yaml") +``` + +#### CLI + +```bash +# Load the model and run the tracker with a custom configuration file using the command line interface +yolo track model=yolo11n.pt source="https://youtu.be/LNwODJXcvt4" tracker='custom_tracker.yaml' +``` + +For a comprehensive list of tracking arguments, refer to the [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers) page. + +## Python Examples + +### Persisting Tracks Loop + +Here is a Python script using OpenCV (`cv2`) and YOLO11 to run object tracking on video frames. 
This script still assumes you have already installed the necessary packages (`opencv-python` and `ultralytics`). The `persist=True` argument tells the tracker than the current image or frame is the next in a sequence and to expect tracks from the previous image in the current image. + +#### Python + +```python +import cv2 + +from ultralytics import YOLO + +# Load the YOLO11 model +model = YOLO("yolo11n.pt") + +# Open the video file +video_path = "path/to/video.mp4" +cap = cv2.VideoCapture(video_path) + +# Loop through the video frames +while cap.isOpened(): + # Read a frame from the video + success, frame = cap.read() + + if success: + # Run YOLO11 tracking on the frame, persisting tracks between frames + results = model.track(frame, persist=True) + + # Visualize the results on the frame + annotated_frame = results[0].plot() + + # Display the annotated frame + cv2.imshow("YOLO11 Tracking", annotated_frame) + + # Break the loop if 'q' is pressed + if cv2.waitKey(1) & 0xFF == ord("q"): + break + else: + # Break the loop if the end of the video is reached + break + +# Release the video capture object and close the display window +cap.release() +cv2.destroyAllWindows() +``` + +Please note the change from `model(frame)` to `model.track(frame)`, which enables object tracking instead of simple detection. This modified script will run the tracker on each frame of the video, visualize the results, and display them in a window. The loop can be exited by pressing 'q'. + +### Plotting Tracks Over Time + +Visualizing object tracks over consecutive frames can provide valuable insights into the movement patterns and behavior of detected objects within a video. With Ultralytics YOLO11, plotting these tracks is a seamless and efficient process. + +In the following example, we demonstrate how to utilize YOLO11's tracking capabilities to plot the movement of detected objects across multiple video frames. 
This script involves opening a video file, reading it frame by frame, and utilizing the YOLO model to identify and track various objects. By retaining the center points of the detected bounding boxes and connecting them, we can draw lines that represent the paths followed by the tracked objects. + +#### Python + +```python +from collections import defaultdict + +import cv2 +import numpy as np + +from ultralytics import YOLO + +# Load the YOLO11 model +model = YOLO("yolo11n.pt") + +# Open the video file +video_path = "path/to/video.mp4" +cap = cv2.VideoCapture(video_path) + +# Store the track history +track_history = defaultdict(lambda: []) + +# Loop through the video frames +while cap.isOpened(): + # Read a frame from the video + success, frame = cap.read() + + if success: + # Run YOLO11 tracking on the frame, persisting tracks between frames + results = model.track(frame, persist=True) + + # Get the boxes and track IDs + boxes = results[0].boxes.xywh.cpu() + track_ids = results[0].boxes.id.int().cpu().tolist() + + # Visualize the results on the frame + annotated_frame = results[0].plot() + + # Plot the tracks + for box, track_id in zip(boxes, track_ids): + x, y, w, h = box + track = track_history[track_id] + track.append((float(x), float(y))) # x, y center point + if len(track) > 30: # retain 90 tracks for 90 frames + track.pop(0) + + # Draw the tracking lines + points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2)) + cv2.polylines( + annotated_frame, + [points], + isClosed=False, + color=(230, 230, 230), + thickness=10, + ) + + # Display the annotated frame + cv2.imshow("YOLO11 Tracking", annotated_frame) + + # Break the loop if 'q' is pressed + if cv2.waitKey(1) & 0xFF == ord("q"): + break + else: + # Break the loop if the end of the video is reached + break + +# Release the video capture object and close the display window +cap.release() +cv2.destroyAllWindows() +``` + +### Multithreaded Tracking + +Multithreaded tracking provides the capability to run 
object tracking on multiple video streams simultaneously. This is particularly useful when handling multiple video inputs, such as from multiple surveillance cameras, where concurrent processing can greatly enhance efficiency and performance. + +In the provided Python script, we make use of Python's `threading` module to run multiple instances of the tracker concurrently. Each thread is responsible for running the tracker on one video file, and all the threads run simultaneously in the background. + +To ensure that each thread receives the correct parameters (the video file and the model to use), we define a function `run_tracker_in_thread` that accepts these parameters and contains the main tracking loop. This function reads the video frame by frame, runs the tracker, and displays the results. + +Two different models are used in this example: `yolo11n.pt` and `yolo11n-seg.pt`, each tracking objects in a different video file. The video files are specified in `video_file1` and `video_file2`. + +The `daemon=True` parameter in `threading.Thread` means that these threads will be closed as soon as the main program finishes. We then start the threads with `start()` and use `join()` to make the main thread wait until both tracker threads have finished. + +Finally, after all threads have completed their task, the windows displaying the results are closed using `cv2.destroyAllWindows()`. 
+ +#### Python + +```python +import threading + +import cv2 + +from ultralytics import YOLO + + +def run_tracker_in_thread(filename, model): + """Starts multi-thread tracking on video from `filename` using `model` and displays results frame by frame.""" + video = cv2.VideoCapture(filename) + frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + for _ in range(frames): + ret, frame = video.read() + if ret: + results = model.track(source=frame, persist=True) + res_plotted = results[0].plot() + cv2.imshow("p", res_plotted) + if cv2.waitKey(1) == ord("q"): + break + + +# Load the models +model1 = YOLO("yolo11n.pt") +model2 = YOLO("yolo11n-seg.pt") + +# Define the video files for the trackers +video_file1 = "path/to/video1.mp4" +video_file2 = "path/to/video2.mp4" + +# Create the tracker threads +tracker_thread1 = threading.Thread(target=run_tracker_in_thread, args=(video_file1, model1), daemon=True) +tracker_thread2 = threading.Thread(target=run_tracker_in_thread, args=(video_file2, model2), daemon=True) + +# Start the tracker threads +tracker_thread1.start() +tracker_thread2.start() + +# Wait for the tracker threads to finish +tracker_thread1.join() +tracker_thread2.join() + +# Clean up and close windows +cv2.destroyAllWindows() +``` + +This example can easily be extended to handle more video files and models by creating more threads and applying the same methodology. + +## Contribute New Trackers + +Are you proficient in multi-object tracking and have successfully implemented or adapted a tracking algorithm with Ultralytics YOLO? We invite you to contribute to our Trackers section in [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers)! Your real-world applications and solutions could be invaluable for users working on tracking tasks. 
+ +By contributing to this section, you help expand the scope of tracking solutions available within the Ultralytics YOLO framework, adding another layer of functionality and utility for the community. + +To initiate your contribution, please refer to our [Contributing Guide](https://docs.ultralytics.com/help/contributing/) for comprehensive instructions on submitting a Pull Request (PR) 🛠️. We are excited to see what you bring to the table! + +Together, let's enhance the tracking capabilities of the Ultralytics YOLO ecosystem 🙏! diff --git a/tracking/ultralytics/trackers/__init__.py b/tracking/ultralytics/trackers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2919511ba506cf9887d4fcd1014f4a57263f36ba --- /dev/null +++ b/tracking/ultralytics/trackers/__init__.py @@ -0,0 +1,7 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .bot_sort import BOTSORT +from .byte_tracker import BYTETracker +from .track import register_tracker + +__all__ = "register_tracker", "BOTSORT", "BYTETracker" # allow simpler import diff --git a/tracking/ultralytics/trackers/basetrack.py b/tracking/ultralytics/trackers/basetrack.py new file mode 100644 index 0000000000000000000000000000000000000000..156483feb76816bad73275e2b12df35e1fb945af --- /dev/null +++ b/tracking/ultralytics/trackers/basetrack.py @@ -0,0 +1,124 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Module defines the base classes and structures for object tracking in YOLO.""" + +from collections import OrderedDict + +import numpy as np + + +class TrackState: + """ + Enumeration class representing the possible states of an object being tracked. + + Attributes: + New (int): State when the object is newly detected. + Tracked (int): State when the object is successfully tracked in subsequent frames. + Lost (int): State when the object is no longer tracked. + Removed (int): State when the object is removed from tracking. 
+ + Examples: + >>> state = TrackState.New + >>> if state == TrackState.New: + >>> print("Object is newly detected.") + """ + + New = 0 + Tracked = 1 + Lost = 2 + Removed = 3 + + +class BaseTrack: + """ + Base class for object tracking, providing foundational attributes and methods. + + Attributes: + _count (int): Class-level counter for unique track IDs. + track_id (int): Unique identifier for the track. + is_activated (bool): Flag indicating whether the track is currently active. + state (TrackState): Current state of the track. + history (OrderedDict): Ordered history of the track's states. + features (list): List of features extracted from the object for tracking. + curr_feature (Any): The current feature of the object being tracked. + score (float): The confidence score of the tracking. + start_frame (int): The frame number where tracking started. + frame_id (int): The most recent frame ID processed by the track. + time_since_update (int): Frames passed since the last update. + location (tuple): The location of the object in the context of multi-camera tracking. + + Methods: + end_frame: Returns the ID of the last frame where the object was tracked. + next_id: Increments and returns the next global track ID. + activate: Abstract method to activate the track. + predict: Abstract method to predict the next state of the track. + update: Abstract method to update the track with new data. + mark_lost: Marks the track as lost. + mark_removed: Marks the track as removed. + reset_id: Resets the global track ID counter. + + Examples: + Initialize a new track and mark it as lost: + >>> track = BaseTrack() + >>> track.mark_lost() + >>> print(track.state) # Output: 2 (TrackState.Lost) + """ + + _count = 0 + + def __init__(self): + """ + Initialize a new track with a unique ID and foundational tracking attributes. 
+ + Examples: + Initialize a new track + >>> track = BaseTrack() + >>> print(track.track_id) + 0 + """ + self.track_id = 0 + self.is_activated = False + self.state = TrackState.New + self.history = OrderedDict() + self.features = [] + self.curr_feature = None + self.score = 0 + self.start_frame = 0 + self.frame_id = 0 + self.time_since_update = 0 + self.location = (np.inf, np.inf) + + @property + def end_frame(self): + """Returns the ID of the most recent frame where the object was tracked.""" + return self.frame_id + + @staticmethod + def next_id(): + """Increment and return the next unique global track ID for object tracking.""" + BaseTrack._count += 1 + return BaseTrack._count + + def activate(self, *args): + """Activates the track with provided arguments, initializing necessary attributes for tracking.""" + raise NotImplementedError + + def predict(self): + """Predicts the next state of the track based on the current state and tracking model.""" + raise NotImplementedError + + def update(self, *args, **kwargs): + """Updates the track with new observations and data, modifying its state and attributes accordingly.""" + raise NotImplementedError + + def mark_lost(self): + """Marks the track as lost by updating its state to TrackState.Lost.""" + self.state = TrackState.Lost + + def mark_removed(self): + """Marks the track as removed by setting its state to TrackState.Removed.""" + self.state = TrackState.Removed + + @staticmethod + def reset_id(): + """Reset the global track ID counter to its initial value.""" + BaseTrack._count = 0 diff --git a/tracking/ultralytics/trackers/bot_sort.py b/tracking/ultralytics/trackers/bot_sort.py new file mode 100644 index 0000000000000000000000000000000000000000..6be2b0bfd3873b8b2f6a5107a4e3ef8de6746c50 --- /dev/null +++ b/tracking/ultralytics/trackers/bot_sort.py @@ -0,0 +1,234 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from collections import deque + +import numpy as np + +from .basetrack import 
TrackState +from .byte_tracker import BYTETracker, STrack +from .utils import matching +from .utils.gmc import GMC +from .utils.kalman_filter import KalmanFilterXYWH + + +class BOTrack(STrack): + """ + An extended version of the STrack class for YOLOv8, adding object tracking features. + + This class extends the STrack class to include additional functionalities for object tracking, such as feature + smoothing, Kalman filter prediction, and reactivation of tracks. + + Attributes: + shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack. + smooth_feat (np.ndarray): Smoothed feature vector. + curr_feat (np.ndarray): Current feature vector. + features (deque): A deque to store feature vectors with a maximum length defined by `feat_history`. + alpha (float): Smoothing factor for the exponential moving average of features. + mean (np.ndarray): The mean state of the Kalman filter. + covariance (np.ndarray): The covariance matrix of the Kalman filter. + + Methods: + update_features: Update features vector and smooth it using exponential moving average. + predict: Predict the mean and covariance using Kalman filter. + re_activate: Reactivate a track with updated features and optionally new ID. + update: Update the track with new detection and frame ID. + tlwh: Property that gets the current position in tlwh format `(top left x, top left y, width, height)`. + multi_predict: Predict the mean and covariance of multiple object tracks using shared Kalman filter. + convert_coords: Convert tlwh bounding box coordinates to xywh format. + tlwh_to_xywh: Convert bounding box to xywh format `(center x, center y, width, height)`. 
+ + Examples: + Create a BOTrack instance and update its features + >>> bo_track = BOTrack(tlwh=[100, 50, 80, 40], score=0.9, cls=1, feat=np.random.rand(128)) + >>> bo_track.predict() + >>> new_track = BOTrack(tlwh=[110, 60, 80, 40], score=0.85, cls=1, feat=np.random.rand(128)) + >>> bo_track.update(new_track, frame_id=2) + """ + + shared_kalman = KalmanFilterXYWH() + + def __init__(self, tlwh, score, cls, feat=None, feat_history=50): + """ + Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features. + + Args: + tlwh (np.ndarray): Bounding box coordinates in tlwh format (top left x, top left y, width, height). + score (float): Confidence score of the detection. + cls (int): Class ID of the detected object. + feat (np.ndarray | None): Feature vector associated with the detection. + feat_history (int): Maximum length of the feature history deque. + + Examples: + Initialize a BOTrack object with bounding box, score, class ID, and feature vector + >>> tlwh = np.array([100, 50, 80, 120]) + >>> score = 0.9 + >>> cls = 1 + >>> feat = np.random.rand(128) + >>> bo_track = BOTrack(tlwh, score, cls, feat) + """ + super().__init__(tlwh, score, cls) + + self.smooth_feat = None + self.curr_feat = None + if feat is not None: + self.update_features(feat) + self.features = deque([], maxlen=feat_history) + self.alpha = 0.9 + + def update_features(self, feat): + """Update the feature vector and apply exponential moving average smoothing.""" + feat /= np.linalg.norm(feat) + self.curr_feat = feat + if self.smooth_feat is None: + self.smooth_feat = feat + else: + self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat + self.features.append(feat) + self.smooth_feat /= np.linalg.norm(self.smooth_feat) + + def predict(self): + """Predict the object's future state using the Kalman filter to update its mean and covariance.""" + mean_state = self.mean.copy() + if self.state != TrackState.Tracked: + mean_state[6] = 0 + 
mean_state[7] = 0 + + self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) + + def re_activate(self, new_track, frame_id, new_id=False): + """Reactivate a track with updated features and optionally assign a new ID.""" + if new_track.curr_feat is not None: + self.update_features(new_track.curr_feat) + super().re_activate(new_track, frame_id, new_id) + + def update(self, new_track, frame_id): + """Update the track with new detection information and the current frame ID.""" + if new_track.curr_feat is not None: + self.update_features(new_track.curr_feat) + super().update(new_track, frame_id) + + @property + def tlwh(self): + """Return the current bounding box position in `(top left x, top left y, width, height)` format.""" + if self.mean is None: + return self._tlwh.copy() + ret = self.mean[:4].copy() + ret[:2] -= ret[2:] / 2 + return ret + + @staticmethod + def multi_predict(stracks): + """Predict the mean and covariance for multiple object tracks using a shared Kalman filter.""" + if len(stracks) <= 0: + return + multi_mean = np.asarray([st.mean.copy() for st in stracks]) + multi_covariance = np.asarray([st.covariance for st in stracks]) + for i, st in enumerate(stracks): + if st.state != TrackState.Tracked: + multi_mean[i][6] = 0 + multi_mean[i][7] = 0 + multi_mean, multi_covariance = BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance) + for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): + stracks[i].mean = mean + stracks[i].covariance = cov + + def convert_coords(self, tlwh): + """Convert tlwh bounding box coordinates to xywh format.""" + return self.tlwh_to_xywh(tlwh) + + @staticmethod + def tlwh_to_xywh(tlwh): + """Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format.""" + ret = np.asarray(tlwh).copy() + ret[:2] += ret[2:] / 2 + return ret + + +class BOTSORT(BYTETracker): + """ + An extended version of the BYTETracker class for YOLOv8, designed for 
object tracking with ReID and GMC algorithm. + + Attributes: + proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections. + appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections. + encoder (Any): Object to handle ReID embeddings, set to None if ReID is not enabled. + gmc (GMC): An instance of the GMC algorithm for data association. + args (Any): Parsed command-line arguments containing tracking parameters. + + Methods: + get_kalmanfilter: Return an instance of KalmanFilterXYWH for object tracking. + init_track: Initialize track with detections, scores, and classes. + get_dists: Get distances between tracks and detections using IoU and (optionally) ReID. + multi_predict: Predict and track multiple objects with YOLOv8 model. + reset: Reset the BOTSORT tracker to its initial state. + + Examples: + Initialize BOTSORT and process detections + >>> bot_sort = BOTSORT(args, frame_rate=30) + >>> bot_sort.init_track(dets, scores, cls, img) + >>> bot_sort.multi_predict(tracks) + + Note: + The class is designed to work with the YOLOv8 object detection model and supports ReID only if enabled via args. + """ + + def __init__(self, args, frame_rate=30): + """ + Initialize BOTSORT object with ReID module and GMC algorithm. + + Args: + args (object): Parsed command-line arguments containing tracking parameters. + frame_rate (int): Frame rate of the video being processed. 
+ + Examples: + Initialize BOTSORT with command-line arguments and a specified frame rate: + >>> args = parse_args() + >>> bot_sort = BOTSORT(args, frame_rate=30) + """ + super().__init__(args, frame_rate) + # ReID module + self.proximity_thresh = args.proximity_thresh + self.appearance_thresh = args.appearance_thresh + + if args.with_reid: + # Haven't supported BoT-SORT(reid) yet + self.encoder = None + self.gmc = GMC(method=args.gmc_method) + + def get_kalmanfilter(self): + """Return an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process.""" + return KalmanFilterXYWH() + + def init_track(self, dets, scores, cls, img=None): + """Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features.""" + if len(dets) == 0: + return [] + if self.args.with_reid and self.encoder is not None: + features_keep = self.encoder.inference(img, dets) + return [BOTrack(xyxy, s, c, f) for (xyxy, s, c, f) in zip(dets, scores, cls, features_keep)] # detections + else: + return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections + + def get_dists(self, tracks, detections): + """Calculate distances between tracks and detections using IoU and optionally ReID embeddings.""" + dists = matching.iou_distance(tracks, detections) + dists_mask = dists > self.proximity_thresh + + if self.args.fuse_score: + dists = matching.fuse_score(dists, detections) + + if self.args.with_reid and self.encoder is not None: + emb_dists = matching.embedding_distance(tracks, detections) / 2.0 + emb_dists[emb_dists > self.appearance_thresh] = 1.0 + emb_dists[dists_mask] = 1.0 + dists = np.minimum(dists, emb_dists) + return dists + + def multi_predict(self, tracks): + """Predict the mean and covariance of multiple object tracks using a shared Kalman filter.""" + BOTrack.multi_predict(tracks) + + def reset(self): + """Reset the BOTSORT tracker to its initial state, clearing all tracked objects and internal 
states.""" + super().reset() + self.gmc.reset_params() diff --git a/tracking/ultralytics/trackers/byte_tracker.py b/tracking/ultralytics/trackers/byte_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..c639008febcffa5e45ebe0bb8027d6ef127b8db0 --- /dev/null +++ b/tracking/ultralytics/trackers/byte_tracker.py @@ -0,0 +1,476 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import numpy as np + +from ..utils import LOGGER +from ..utils.ops import xywh2ltwh +from .basetrack import BaseTrack, TrackState +from .utils import matching +from .utils.kalman_filter import KalmanFilterXYAH + + +class STrack(BaseTrack): + """ + Single object tracking representation that uses Kalman filtering for state estimation. + + This class is responsible for storing all the information regarding individual tracklets and performs state updates + and predictions based on Kalman filter. + + Attributes: + shared_kalman (KalmanFilterXYAH): Shared Kalman filter used across all STrack instances for prediction. + _tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box. + kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track. + mean (np.ndarray): Mean state estimate vector. + covariance (np.ndarray): Covariance of state estimate. + is_activated (bool): Boolean flag indicating if the track has been activated. + score (float): Confidence score of the track. + tracklet_len (int): Length of the tracklet. + cls (Any): Class label for the object. + idx (int): Index or identifier for the object. + frame_id (int): Current frame ID. + start_frame (int): Frame where the object was first detected. + + Methods: + predict(): Predict the next state of the object using Kalman filter. + multi_predict(stracks): Predict the next states for multiple tracks. + multi_gmc(stracks, H): Update multiple track states using a homography matrix. 
+ activate(kalman_filter, frame_id): Activate a new tracklet. + re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet. + update(new_track, frame_id): Update the state of a matched track. + convert_coords(tlwh): Convert bounding box to x-y-aspect-height format. + tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format. + + Examples: + Initialize and activate a new track + >>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls="person") + >>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1) + """ + + shared_kalman = KalmanFilterXYAH() + + def __init__(self, xywh, score, cls): + """ + Initialize a new STrack instance. + + Args: + xywh (List[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where + (x, y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id. + score (float): Confidence score of the detection. + cls (Any): Class label for the detected object. + + Examples: + >>> xywh = [100.0, 150.0, 50.0, 75.0, 1] + >>> score = 0.9 + >>> cls = "person" + >>> track = STrack(xywh, score, cls) + """ + super().__init__() + # xywh+idx or xywha+idx + assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}" + self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32) + self.kalman_filter = None + self.mean, self.covariance = None, None + self.is_activated = False + + self.score = score + self.tracklet_len = 0 + self.cls = cls + self.idx = xywh[-1] + self.angle = xywh[4] if len(xywh) == 6 else None + + def predict(self): + """Predicts the next state (mean and covariance) of the object using the Kalman filter.""" + mean_state = self.mean.copy() + if self.state != TrackState.Tracked: + mean_state[7] = 0 + self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) + + @staticmethod + def multi_predict(stracks): + """Perform multi-object predictive tracking using Kalman filter for the provided list of STrack 
instances.""" + if len(stracks) <= 0: + return + multi_mean = np.asarray([st.mean.copy() for st in stracks]) + multi_covariance = np.asarray([st.covariance for st in stracks]) + for i, st in enumerate(stracks): + if st.state != TrackState.Tracked: + multi_mean[i][7] = 0 + multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance) + for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): + stracks[i].mean = mean + stracks[i].covariance = cov + + @staticmethod + def multi_gmc(stracks, H=np.eye(2, 3)): + """Update state tracks positions and covariances using a homography matrix for multiple tracks.""" + if len(stracks) > 0: + multi_mean = np.asarray([st.mean.copy() for st in stracks]) + multi_covariance = np.asarray([st.covariance for st in stracks]) + + R = H[:2, :2] + R8x8 = np.kron(np.eye(4, dtype=float), R) + t = H[:2, 2] + + for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): + mean = R8x8.dot(mean) + mean[:2] += t + cov = R8x8.dot(cov).dot(R8x8.transpose()) + + stracks[i].mean = mean + stracks[i].covariance = cov + + def activate(self, kalman_filter, frame_id): + """Activate a new tracklet using the provided Kalman filter and initialize its state and covariance.""" + self.kalman_filter = kalman_filter + self.track_id = self.next_id() + self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh)) + + self.tracklet_len = 0 + self.state = TrackState.Tracked + if frame_id == 1: + self.is_activated = True + self.frame_id = frame_id + self.start_frame = frame_id + + def re_activate(self, new_track, frame_id, new_id=False): + """Reactivates a previously lost track using new detection data and updates its state and attributes.""" + self.mean, self.covariance = self.kalman_filter.update( + self.mean, self.covariance, self.convert_coords(new_track.tlwh) + ) + self.tracklet_len = 0 + self.state = TrackState.Tracked + self.is_activated = True + self.frame_id = frame_id + if new_id: 
+ self.track_id = self.next_id() + self.score = new_track.score + self.cls = new_track.cls + self.angle = new_track.angle + self.idx = new_track.idx + + def update(self, new_track, frame_id): + """ + Update the state of a matched track. + + Args: + new_track (STrack): The new track containing updated information. + frame_id (int): The ID of the current frame. + + Examples: + Update the state of a track with new detection information + >>> track = STrack([100, 200, 50, 80, 0.9, 1]) + >>> new_track = STrack([105, 205, 55, 85, 0.95, 1]) + >>> track.update(new_track, 2) + """ + self.frame_id = frame_id + self.tracklet_len += 1 + + new_tlwh = new_track.tlwh + self.mean, self.covariance = self.kalman_filter.update( + self.mean, self.covariance, self.convert_coords(new_tlwh) + ) + self.state = TrackState.Tracked + self.is_activated = True + + self.score = new_track.score + self.cls = new_track.cls + self.angle = new_track.angle + self.idx = new_track.idx + + def convert_coords(self, tlwh): + """Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent.""" + return self.tlwh_to_xyah(tlwh) + + @property + def tlwh(self): + """Returns the bounding box in top-left-width-height format from the current state estimate.""" + if self.mean is None: + return self._tlwh.copy() + ret = self.mean[:4].copy() + ret[2] *= ret[3] + ret[:2] -= ret[2:] / 2 + return ret + + @property + def xyxy(self): + """Converts bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format.""" + ret = self.tlwh.copy() + ret[2:] += ret[:2] + return ret + + @staticmethod + def tlwh_to_xyah(tlwh): + """Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format.""" + ret = np.asarray(tlwh).copy() + ret[:2] += ret[2:] / 2 + ret[2] /= ret[3] + return ret + + @property + def xywh(self): + """Returns the current position of the bounding box in (center x, center y, width, height) format.""" + ret = 
np.asarray(self.tlwh).copy() + ret[:2] += ret[2:] / 2 + return ret + + @property + def xywha(self): + """Returns position in (center x, center y, width, height, angle) format, warning if angle is missing.""" + if self.angle is None: + LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.") + return self.xywh + return np.concatenate([self.xywh, self.angle[None]]) + + @property + def result(self): + """Returns the current tracking results in the appropriate bounding box format.""" + coords = self.xyxy if self.angle is None else self.xywha + return coords.tolist() + [self.track_id, self.score, self.cls, self.idx] + + def __repr__(self): + """Returns a string representation of the STrack object including start frame, end frame, and track ID.""" + return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})" + + +class BYTETracker: + """ + BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking. + + This class encapsulates the functionality for initializing, updating, and managing the tracks for detected objects in a + video sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for + predicting the new object locations, and performs data association. + + Attributes: + tracked_stracks (List[STrack]): List of successfully activated tracks. + lost_stracks (List[STrack]): List of lost tracks. + removed_stracks (List[STrack]): List of removed tracks. + frame_id (int): The current frame ID. + args (Namespace): Command-line arguments. + max_time_lost (int): The maximum frames for a track to be considered as 'lost'. + kalman_filter (KalmanFilterXYAH): Kalman Filter object. + + Methods: + update(results, img=None): Updates object tracker with new detections. + get_kalmanfilter(): Returns a Kalman filter object for tracking bounding boxes. + init_track(dets, scores, cls, img=None): Initialize object tracking with detections. 
+ get_dists(tracks, detections): Calculates the distance between tracks and detections. + multi_predict(tracks): Predicts the location of tracks. + reset_id(): Resets the ID counter of STrack. + joint_stracks(tlista, tlistb): Combines two lists of stracks. + sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list. + remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU. + + Examples: + Initialize BYTETracker and update with detection results + >>> tracker = BYTETracker(args, frame_rate=30) + >>> results = yolo_model.detect(image) + >>> tracked_objects = tracker.update(results) + """ + + def __init__(self, args, frame_rate=30): + """ + Initialize a BYTETracker instance for object tracking. + + Args: + args (Namespace): Command-line arguments containing tracking parameters. + frame_rate (int): Frame rate of the video sequence. + + Examples: + Initialize BYTETracker with command-line arguments and a frame rate of 30 + >>> args = Namespace(track_buffer=30) + >>> tracker = BYTETracker(args, frame_rate=30) + """ + self.tracked_stracks = [] # type: list[STrack] + self.lost_stracks = [] # type: list[STrack] + self.removed_stracks = [] # type: list[STrack] + + self.frame_id = 0 + self.args = args + self.max_time_lost = int(frame_rate / 30.0 * args.track_buffer) + self.kalman_filter = self.get_kalmanfilter() + self.reset_id() + + def update(self, results, img=None): + """Updates the tracker with new detections and returns the current list of tracked objects.""" + self.frame_id += 1 + activated_stracks = [] + refind_stracks = [] + lost_stracks = [] + removed_stracks = [] + + scores = results.conf + bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh + # Add index + bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1) + cls = results.cls + + remain_inds = scores >= self.args.track_high_thresh + inds_low = scores > self.args.track_low_thresh + inds_high = 
scores < self.args.track_high_thresh + + inds_second = inds_low & inds_high + dets_second = bboxes[inds_second] + dets = bboxes[remain_inds] + scores_keep = scores[remain_inds] + scores_second = scores[inds_second] + cls_keep = cls[remain_inds] + cls_second = cls[inds_second] + + detections = self.init_track(dets, scores_keep, cls_keep, img) + # Add newly detected tracklets to tracked_stracks + unconfirmed = [] + tracked_stracks = [] # type: list[STrack] + for track in self.tracked_stracks: + if not track.is_activated: + unconfirmed.append(track) + else: + tracked_stracks.append(track) + # Step 2: First association, with high score detection boxes + strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks) + # Predict the current location with KF + self.multi_predict(strack_pool) + if hasattr(self, "gmc") and img is not None: + warp = self.gmc.apply(img, dets) + STrack.multi_gmc(strack_pool, warp) + STrack.multi_gmc(unconfirmed, warp) + + dists = self.get_dists(strack_pool, detections) + matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh) + + for itracked, idet in matches: + track = strack_pool[itracked] + det = detections[idet] + if track.state == TrackState.Tracked: + track.update(det, self.frame_id) + activated_stracks.append(track) + else: + track.re_activate(det, self.frame_id, new_id=False) + refind_stracks.append(track) + # Step 3: Second association, with low score detection boxes association the untrack to the low score detections + detections_second = self.init_track(dets_second, scores_second, cls_second, img) + r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked] + # TODO + dists = matching.iou_distance(r_tracked_stracks, detections_second) + matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5) + for itracked, idet in matches: + track = r_tracked_stracks[itracked] + det = detections_second[idet] + if track.state == 
TrackState.Tracked: + track.update(det, self.frame_id) + activated_stracks.append(track) + else: + track.re_activate(det, self.frame_id, new_id=False) + refind_stracks.append(track) + + for it in u_track: + track = r_tracked_stracks[it] + if track.state != TrackState.Lost: + track.mark_lost() + lost_stracks.append(track) + # Deal with unconfirmed tracks, usually tracks with only one beginning frame + detections = [detections[i] for i in u_detection] + dists = self.get_dists(unconfirmed, detections) + matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7) + for itracked, idet in matches: + unconfirmed[itracked].update(detections[idet], self.frame_id) + activated_stracks.append(unconfirmed[itracked]) + for it in u_unconfirmed: + track = unconfirmed[it] + track.mark_removed() + removed_stracks.append(track) + # Step 4: Init new stracks + for inew in u_detection: + track = detections[inew] + if track.score < self.args.new_track_thresh: + continue + track.activate(self.kalman_filter, self.frame_id) + activated_stracks.append(track) + # Step 5: Update state + for track in self.lost_stracks: + if self.frame_id - track.end_frame > self.max_time_lost: + track.mark_removed() + removed_stracks.append(track) + + self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked] + self.tracked_stracks = self.joint_stracks(self.tracked_stracks, activated_stracks) + self.tracked_stracks = self.joint_stracks(self.tracked_stracks, refind_stracks) + self.lost_stracks = self.sub_stracks(self.lost_stracks, self.tracked_stracks) + self.lost_stracks.extend(lost_stracks) + self.lost_stracks = self.sub_stracks(self.lost_stracks, self.removed_stracks) + self.tracked_stracks, self.lost_stracks = self.remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks) + self.removed_stracks.extend(removed_stracks) + if len(self.removed_stracks) > 1000: + self.removed_stracks = self.removed_stracks[-999:] # clip remove stracks to 1000 maximum 
+ + return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32) + + def get_kalmanfilter(self): + """Returns a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH.""" + return KalmanFilterXYAH() + + def init_track(self, dets, scores, cls, img=None): + """Initializes object tracking with given detections, scores, and class labels using the STrack algorithm.""" + return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections + + def get_dists(self, tracks, detections): + """Calculates the distance between tracks and detections using IoU and optionally fuses scores.""" + dists = matching.iou_distance(tracks, detections) + if self.args.fuse_score: + dists = matching.fuse_score(dists, detections) + return dists + + def multi_predict(self, tracks): + """Predict the next states for multiple tracks using Kalman filter.""" + STrack.multi_predict(tracks) + + @staticmethod + def reset_id(): + """Resets the ID counter for STrack instances to ensure unique track IDs across tracking sessions.""" + STrack.reset_id() + + def reset(self): + """Resets the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter.""" + self.tracked_stracks = [] # type: list[STrack] + self.lost_stracks = [] # type: list[STrack] + self.removed_stracks = [] # type: list[STrack] + self.frame_id = 0 + self.kalman_filter = self.get_kalmanfilter() + self.reset_id() + + @staticmethod + def joint_stracks(tlista, tlistb): + """Combines two lists of STrack objects into a single list, ensuring no duplicates based on track IDs.""" + exists = {} + res = [] + for t in tlista: + exists[t.track_id] = 1 + res.append(t) + for t in tlistb: + tid = t.track_id + if not exists.get(tid, 0): + exists[tid] = 1 + res.append(t) + return res + + @staticmethod + def sub_stracks(tlista, tlistb): + """Filters out the stracks present in the second list from the first list.""" + track_ids_b = 
{t.track_id for t in tlistb} + return [t for t in tlista if t.track_id not in track_ids_b] + + @staticmethod + def remove_duplicate_stracks(stracksa, stracksb): + """Removes duplicate stracks from two lists based on Intersection over Union (IoU) distance.""" + pdist = matching.iou_distance(stracksa, stracksb) + pairs = np.where(pdist < 0.15) + dupa, dupb = [], [] + for p, q in zip(*pairs): + timep = stracksa[p].frame_id - stracksa[p].start_frame + timeq = stracksb[q].frame_id - stracksb[q].start_frame + if timep > timeq: + dupb.append(q) + else: + dupa.append(p) + resa = [t for i, t in enumerate(stracksa) if i not in dupa] + resb = [t for i, t in enumerate(stracksb) if i not in dupb] + return resa, resb diff --git a/tracking/ultralytics/trackers/track.py b/tracking/ultralytics/trackers/track.py new file mode 100644 index 0000000000000000000000000000000000000000..67d0633f418f58b0978cd177a39b928e0b667667 --- /dev/null +++ b/tracking/ultralytics/trackers/track.py @@ -0,0 +1,105 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from functools import partial +from pathlib import Path + +import torch + +from ultralytics.utils import IterableSimpleNamespace, yaml_load +from ultralytics.utils.checks import check_yaml + +from .bot_sort import BOTSORT +from .byte_tracker import BYTETracker + +# A mapping of tracker types to corresponding tracker classes +TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT} + + +def on_predict_start(predictor: object, persist: bool = False) -> None: + """ + Initialize trackers for object tracking during prediction. + + Args: + predictor (object): The predictor object to initialize trackers for. + persist (bool): Whether to persist the trackers if they already exist. + + Raises: + AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'. + ValueError: If the task is 'classify' as classification doesn't support tracking. 
+ + Examples: + Initialize trackers for a predictor object: + >>> predictor = SomePredictorClass() + >>> on_predict_start(predictor, persist=True) + """ + if predictor.args.task == "classify": + raise ValueError("❌ Classification doesn't support 'mode=track'") + + if hasattr(predictor, "trackers") and persist: + return + + tracker = check_yaml(predictor.args.tracker) + cfg = IterableSimpleNamespace(**yaml_load(tracker)) + + if cfg.tracker_type not in {"bytetrack", "botsort"}: + raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'") + + trackers = [] + for _ in range(predictor.dataset.bs): + tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30) + trackers.append(tracker) + if predictor.dataset.mode != "stream": # only need one tracker for other modes + break + predictor.trackers = trackers + predictor.vid_path = [None] * predictor.dataset.bs # for determining when to reset tracker on new video + + +def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None: + """ + Postprocess detected boxes and update with object tracking. + + Args: + predictor (object): The predictor object containing the predictions. + persist (bool): Whether to persist the trackers if they already exist. 
+ + Examples: + Postprocess predictions and update with tracking + >>> predictor = YourPredictorClass() + >>> on_predict_postprocess_end(predictor, persist=True) + """ + is_obb = predictor.args.task == "obb" + is_stream = predictor.dataset.mode == "stream" + for i, result in enumerate(predictor.results): + tracker = predictor.trackers[i if is_stream else 0] + vid_path = predictor.save_dir / Path(result.path).name + if not persist and predictor.vid_path[i if is_stream else 0] != vid_path: + predictor.vid_path[i if is_stream else 0] = vid_path + + det = (result.obb.data if is_obb else result.boxes.data).cpu().numpy() + if len(det) == 0: + continue + tracks = tracker.update(det, result.orig_img) + if len(tracks) == 0: + continue + idx = tracks[:, -1].astype(int) + predictor.results[i] = result[idx] + + update_args = {"obb" if is_obb else "boxes": torch.as_tensor(tracks[:, :-1])} + predictor.results[i].update(**update_args) + + +def register_tracker(model: object, persist: bool) -> None: + """ + Register tracking callbacks to the model for object tracking during prediction. + + Args: + model (object): The model object to register tracking callbacks for. + persist (bool): Whether to persist the trackers if they already exist. 
+ + Examples: + Register tracking callbacks to a YOLO model + >>> model = YOLOModel() + >>> register_tracker(model, persist=True) + """ + model.add_callback("on_predict_start", partial(on_predict_start, persist=persist)) + model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist)) diff --git a/tracking/ultralytics/trackers/utils/__init__.py b/tracking/ultralytics/trackers/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77a19dcf0f8093de419453747db2e7e719f96349 --- /dev/null +++ b/tracking/ultralytics/trackers/utils/__init__.py @@ -0,0 +1 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license diff --git a/tracking/ultralytics/trackers/utils/gmc.py b/tracking/ultralytics/trackers/utils/gmc.py new file mode 100644 index 0000000000000000000000000000000000000000..0d64f233c906a69de6e1f1149ff7bcdf9ba4c515 --- /dev/null +++ b/tracking/ultralytics/trackers/utils/gmc.py @@ -0,0 +1,376 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import copy + +import cv2 +import numpy as np + +from ultralytics.utils import LOGGER + + +class GMC: + """ + Generalized Motion Compensation (GMC) class for tracking and object detection in video frames. + + This class provides methods for tracking and detecting objects based on several tracking algorithms including ORB, + SIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency. + + Attributes: + method (str): The tracking method to use. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'. + downscale (int): Factor by which to downscale the frames for processing. + prevFrame (np.ndarray): Previous frame for tracking. + prevKeyPoints (list): Keypoints from the previous frame. + prevDescriptors (np.ndarray): Descriptors from the previous frame. + initializedFirstFrame (bool): Flag indicating if the first frame has been processed. 
+ + Methods: + apply: Apply the chosen method to a raw frame and optionally use provided detections. + apply_ecc: Apply the ECC algorithm to a raw frame. + apply_features: Apply feature-based methods like ORB or SIFT to a raw frame. + apply_sparseoptflow: Apply the Sparse Optical Flow method to a raw frame. + reset_params: Reset the internal parameters of the GMC object. + + Examples: + Create a GMC object and apply it to a frame + >>> gmc = GMC(method="sparseOptFlow", downscale=2) + >>> frame = np.array([[1, 2, 3], [4, 5, 6]]) + >>> processed_frame = gmc.apply(frame) + >>> print(processed_frame) + array([[1, 2, 3], + [4, 5, 6]]) + """ + + def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None: + """ + Initialize a Generalized Motion Compensation (GMC) object with tracking method and downscale factor. + + Args: + method (str): The tracking method to use. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'. + downscale (int): Downscale factor for processing frames. 
+ + Examples: + Initialize a GMC object with the 'sparseOptFlow' method and a downscale factor of 2 + >>> gmc = GMC(method="sparseOptFlow", downscale=2) + """ + super().__init__() + + self.method = method + self.downscale = max(1, downscale) + + if self.method == "orb": + self.detector = cv2.FastFeatureDetector_create(20) + self.extractor = cv2.ORB_create() + self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING) + + elif self.method == "sift": + self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20) + self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20) + self.matcher = cv2.BFMatcher(cv2.NORM_L2) + + elif self.method == "ecc": + number_of_iterations = 5000 + termination_eps = 1e-6 + self.warp_mode = cv2.MOTION_EUCLIDEAN + self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps) + + elif self.method == "sparseOptFlow": + self.feature_params = dict( + maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3, useHarrisDetector=False, k=0.04 + ) + + elif self.method in {"none", "None", None}: + self.method = None + else: + raise ValueError(f"Error: Unknown GMC method: {method}") + + self.prevFrame = None + self.prevKeyPoints = None + self.prevDescriptors = None + self.initializedFirstFrame = False + + def apply(self, raw_frame: np.ndarray, detections: list = None) -> np.ndarray: + """ + Apply object detection on a raw frame using the specified method. + + Args: + raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C). + detections (List | None): List of detections to be used in the processing. + + Returns: + (np.ndarray): Transformation matrix with shape (2, 3). 
+ + Examples: + >>> gmc = GMC(method="sparseOptFlow") + >>> raw_frame = np.random.rand(480, 640, 3) + >>> transformation_matrix = gmc.apply(raw_frame) + >>> print(transformation_matrix.shape) + (2, 3) + """ + if self.method in {"orb", "sift"}: + return self.apply_features(raw_frame, detections) + elif self.method == "ecc": + return self.apply_ecc(raw_frame) + elif self.method == "sparseOptFlow": + return self.apply_sparseoptflow(raw_frame) + else: + return np.eye(2, 3) + + def apply_ecc(self, raw_frame: np.ndarray) -> np.ndarray: + """ + Apply the ECC (Enhanced Correlation Coefficient) algorithm to a raw frame for motion compensation. + + Args: + raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C). + + Returns: + (np.ndarray): Transformation matrix with shape (2, 3). + + Examples: + >>> gmc = GMC(method="ecc") + >>> processed_frame = gmc.apply_ecc(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])) + >>> print(processed_frame) + [[1. 0. 0.] + [0. 1. 0.]] + """ + height, width, _ = raw_frame.shape + frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) + H = np.eye(2, 3, dtype=np.float32) + + # Downscale image + if self.downscale > 1.0: + frame = cv2.GaussianBlur(frame, (3, 3), 1.5) + frame = cv2.resize(frame, (width // self.downscale, height // self.downscale)) + + # Handle first frame + if not self.initializedFirstFrame: + # Initialize data + self.prevFrame = frame.copy() + + # Initialization done + self.initializedFirstFrame = True + + return H + + # Run the ECC algorithm. The results are stored in warp_matrix. + # (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria) + try: + (_, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1) + except Exception as e: + LOGGER.warning(f"WARNING: find transform failed. 
Set warp as identity {e}") + + return H + + def apply_features(self, raw_frame: np.ndarray, detections: list = None) -> np.ndarray: + """ + Apply feature-based methods like ORB or SIFT to a raw frame. + + Args: + raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C). + detections (List | None): List of detections to be used in the processing. + + Returns: + (np.ndarray): Transformation matrix with shape (2, 3). + + Examples: + >>> gmc = GMC(method="orb") + >>> raw_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) + >>> transformation_matrix = gmc.apply_features(raw_frame) + >>> print(transformation_matrix.shape) + (2, 3) + """ + height, width, _ = raw_frame.shape + frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) + H = np.eye(2, 3) + + # Downscale image + if self.downscale > 1.0: + frame = cv2.resize(frame, (width // self.downscale, height // self.downscale)) + width = width // self.downscale + height = height // self.downscale + + # Find the keypoints + mask = np.zeros_like(frame) + mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255 + if detections is not None: + for det in detections: + tlbr = (det[:4] / self.downscale).astype(np.int_) + mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0 + + keypoints = self.detector.detect(frame, mask) + + # Compute the descriptors + keypoints, descriptors = self.extractor.compute(frame, keypoints) + + # Handle first frame + if not self.initializedFirstFrame: + # Initialize data + self.prevFrame = frame.copy() + self.prevKeyPoints = copy.copy(keypoints) + self.prevDescriptors = copy.copy(descriptors) + + # Initialization done + self.initializedFirstFrame = True + + return H + + # Match descriptors + knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2) + + # Filter matches based on smallest spatial distance + matches = [] + spatialDistances = [] + + maxSpatialDistance = 0.25 * np.array([width, height]) + + # Handle empty matches case + 
if len(knnMatches) == 0: + # Store to next iteration + self.prevFrame = frame.copy() + self.prevKeyPoints = copy.copy(keypoints) + self.prevDescriptors = copy.copy(descriptors) + + return H + + for m, n in knnMatches: + if m.distance < 0.9 * n.distance: + prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt + currKeyPointLocation = keypoints[m.trainIdx].pt + + spatialDistance = ( + prevKeyPointLocation[0] - currKeyPointLocation[0], + prevKeyPointLocation[1] - currKeyPointLocation[1], + ) + + if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and ( + np.abs(spatialDistance[1]) < maxSpatialDistance[1] + ): + spatialDistances.append(spatialDistance) + matches.append(m) + + meanSpatialDistances = np.mean(spatialDistances, 0) + stdSpatialDistances = np.std(spatialDistances, 0) + + inliers = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances + + goodMatches = [] + prevPoints = [] + currPoints = [] + for i in range(len(matches)): + if inliers[i, 0] and inliers[i, 1]: + goodMatches.append(matches[i]) + prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt) + currPoints.append(keypoints[matches[i].trainIdx].pt) + + prevPoints = np.array(prevPoints) + currPoints = np.array(currPoints) + + # Draw the keypoint matches on the output image + # if False: + # import matplotlib.pyplot as plt + # matches_img = np.hstack((self.prevFrame, frame)) + # matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR) + # W = self.prevFrame.shape[1] + # for m in goodMatches: + # prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_) + # curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_) + # curr_pt[0] += W + # color = np.random.randint(0, 255, 3) + # color = (int(color[0]), int(color[1]), int(color[2])) + # + # matches_img = cv2.line(matches_img, prev_pt, curr_pt, tuple(color), 1, cv2.LINE_AA) + # matches_img = cv2.circle(matches_img, prev_pt, 2, tuple(color), -1) + # matches_img = cv2.circle(matches_img, curr_pt, 2, tuple(color), -1) + 
# + # plt.figure() + # plt.imshow(matches_img) + # plt.show() + + # Find rigid matrix + if prevPoints.shape[0] > 4: + H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC) + + # Handle downscale + if self.downscale > 1.0: + H[0, 2] *= self.downscale + H[1, 2] *= self.downscale + else: + LOGGER.warning("WARNING: not enough matching points") + + # Store to next iteration + self.prevFrame = frame.copy() + self.prevKeyPoints = copy.copy(keypoints) + self.prevDescriptors = copy.copy(descriptors) + + return H + + def apply_sparseoptflow(self, raw_frame: np.ndarray) -> np.ndarray: + """ + Apply Sparse Optical Flow method to a raw frame. + + Args: + raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C). + + Returns: + (np.ndarray): Transformation matrix with shape (2, 3). + + Examples: + >>> gmc = GMC() + >>> result = gmc.apply_sparseoptflow(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])) + >>> print(result) + [[1. 0. 0.] + [0. 1. 0.]] + """ + height, width, _ = raw_frame.shape + frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) + H = np.eye(2, 3) + + # Downscale image + if self.downscale > 1.0: + frame = cv2.resize(frame, (width // self.downscale, height // self.downscale)) + + # Find the keypoints + keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params) + + # Handle first frame + if not self.initializedFirstFrame or self.prevKeyPoints is None: + self.prevFrame = frame.copy() + self.prevKeyPoints = copy.copy(keypoints) + self.initializedFirstFrame = True + return H + + # Find correspondences + matchedKeypoints, status, _ = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None) + + # Leave good correspondences only + prevPoints = [] + currPoints = [] + + for i in range(len(status)): + if status[i]: + prevPoints.append(self.prevKeyPoints[i]) + currPoints.append(matchedKeypoints[i]) + + prevPoints = np.array(prevPoints) + currPoints = np.array(currPoints) + + # Find rigid 
matrix + if (prevPoints.shape[0] > 4) and (prevPoints.shape[0] == currPoints.shape[0]): + H, _ = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC) + + if self.downscale > 1.0: + H[0, 2] *= self.downscale + H[1, 2] *= self.downscale + else: + LOGGER.warning("WARNING: not enough matching points") + + self.prevFrame = frame.copy() + self.prevKeyPoints = copy.copy(keypoints) + + return H + + def reset_params(self) -> None: + """Reset the internal parameters including previous frame, keypoints, and descriptors.""" + self.prevFrame = None + self.prevKeyPoints = None + self.prevDescriptors = None + self.initializedFirstFrame = False diff --git a/tracking/ultralytics/trackers/utils/kalman_filter.py b/tracking/ultralytics/trackers/utils/kalman_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..75d6ac2cec1e246ba0aa74c05f74a99b274373cf --- /dev/null +++ b/tracking/ultralytics/trackers/utils/kalman_filter.py @@ -0,0 +1,493 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import numpy as np +import scipy.linalg + + +class KalmanFilterXYAH: + """ + A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter. + + Implements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space + (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their + respective velocities. Object motion follows a constant velocity model, and bounding box location (x, y, a, h) is + taken as a direct observation of the state space (linear observation model). + + Attributes: + _motion_mat (np.ndarray): The motion matrix for the Kalman filter. + _update_mat (np.ndarray): The update matrix for the Kalman filter. + _std_weight_position (float): Standard deviation weight for position. + _std_weight_velocity (float): Standard deviation weight for velocity. 
class KalmanFilterXYAH:
    """
    Kalman filter for boxes parameterized as center (x, y), aspect ratio a, and height h.

    The 8-dimensional state (x, y, a, h, vx, vy, va, vh) follows a constant-velocity motion
    model; (x, y, a, h) is observed directly through a linear measurement model. Measurement
    and process noise scale with the box height.

    Examples:
        >>> kf = KalmanFilterXYAH()
        >>> mean, covariance = kf.initiate(np.array([100, 200, 1.5, 50]))
    """

    def __init__(self):
        """Build the constant-velocity motion/update matrices and noise weights."""
        ndim, dt = 4, 1.0

        # Motion model: position += dt * velocity; velocities persist.
        self._motion_mat = np.eye(2 * ndim)
        self._motion_mat[:ndim, ndim:] = dt * np.eye(ndim)
        # Observation model: project out the positional half of the state.
        self._update_mat = np.eye(ndim, 2 * ndim)

        # Uncertainty weights relative to the current state estimate.
        self._std_weight_position = 1.0 / 20
        self._std_weight_velocity = 1.0 / 160

    def initiate(self, measurement: np.ndarray):
        """
        Create a new track state from a single (x, y, a, h) measurement.

        Args:
            measurement (np.ndarray): Box (x, y, a, h) with center (x, y), aspect ratio, height.

        Returns:
            (np.ndarray): Mean vector (8,) with velocities initialized to zero.
            (np.ndarray): Covariance matrix (8, 8).
        """
        mean = np.concatenate([measurement, np.zeros_like(measurement)])

        wp, wv = self._std_weight_position, self._std_weight_velocity
        h = measurement[3]
        std = np.array([2 * wp * h, 2 * wp * h, 1e-2, 2 * wp * h,
                        10 * wv * h, 10 * wv * h, 1e-5, 10 * wv * h])
        return mean, np.diag(np.square(std))

    def predict(self, mean: np.ndarray, covariance: np.ndarray):
        """
        Run the prediction step for a single state.

        Args:
            mean (np.ndarray): State mean (8,).
            covariance (np.ndarray): State covariance (8, 8).

        Returns:
            (np.ndarray): Predicted mean (8,).
            (np.ndarray): Predicted covariance (8, 8).
        """
        wp, wv = self._std_weight_position, self._std_weight_velocity
        h = mean[3]
        std = np.array([wp * h, wp * h, 1e-2, wp * h, wv * h, wv * h, 1e-5, wv * h])
        motion_cov = np.diag(np.square(std))

        predicted_mean = self._motion_mat @ mean
        predicted_cov = self._motion_mat @ covariance @ self._motion_mat.T + motion_cov
        return predicted_mean, predicted_cov

    def project(self, mean: np.ndarray, covariance: np.ndarray):
        """
        Project a state distribution into measurement space.

        Args:
            mean (np.ndarray): State mean (8,).
            covariance (np.ndarray): State covariance (8, 8).

        Returns:
            (np.ndarray): Projected mean (4,).
            (np.ndarray): Projected covariance (4, 4) including measurement noise.
        """
        wp = self._std_weight_position
        h = mean[3]
        innovation_cov = np.diag(np.square(np.array([wp * h, wp * h, 1e-1, wp * h])))

        projected_mean = self._update_mat @ mean
        projected_cov = self._update_mat @ covariance @ self._update_mat.T
        return projected_mean, projected_cov + innovation_cov

    def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):
        """
        Vectorized prediction step for N states.

        Args:
            mean (np.ndarray): State means, shape (N, 8).
            covariance (np.ndarray): State covariances, shape (N, 8, 8).

        Returns:
            (np.ndarray): Predicted means (N, 8).
            (np.ndarray): Predicted covariances (N, 8, 8).
        """
        wp, wv = self._std_weight_position, self._std_weight_velocity
        h = mean[:, 3]
        ones = np.ones_like(h)
        stds = np.stack([wp * h, wp * h, 1e-2 * ones, wp * h,
                         wv * h, wv * h, 1e-5 * ones, wv * h])
        sqr = np.square(stds).T

        motion_cov = np.asarray([np.diag(row) for row in sqr])

        predicted_mean = mean @ self._motion_mat.T
        left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
        predicted_cov = np.dot(left, self._motion_mat.T) + motion_cov
        return predicted_mean, predicted_cov

    def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray):
        """
        Run the correction step against one measurement.

        Args:
            mean (np.ndarray): Predicted state mean (8,).
            covariance (np.ndarray): Predicted state covariance (8, 8).
            measurement (np.ndarray): Observed box (x, y, a, h).

        Returns:
            (np.ndarray): Corrected state mean (8,).
            (np.ndarray): Corrected state covariance (8, 8).
        """
        projected_mean, projected_cov = self.project(mean, covariance)

        # Solve for the Kalman gain via Cholesky factorization of the projected covariance.
        chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
        kalman_gain = scipy.linalg.cho_solve(
            (chol_factor, lower), (covariance @ self._update_mat.T).T, check_finite=False
        ).T
        innovation = measurement - projected_mean

        new_mean = mean + kalman_gain @ innovation
        new_covariance = covariance - kalman_gain @ projected_cov @ kalman_gain.T
        return new_mean, new_covariance

    def gating_distance(
        self,
        mean: np.ndarray,
        covariance: np.ndarray,
        measurements: np.ndarray,
        only_position: bool = False,
        metric: str = "maha",
    ) -> np.ndarray:
        """
        Compute squared distances between a state distribution and N measurements.

        With `only_position=False` the chi-square distribution has 4 degrees of freedom,
        otherwise 2; thresholds can be taken from `chi2inv95`.

        Args:
            mean (np.ndarray): State mean (8,).
            covariance (np.ndarray): State covariance (8, 8).
            measurements (np.ndarray): (N, 4) measurements in (x, y, a, h) format.
            only_position (bool): Compare box centers only.
            metric (str): 'gaussian' (squared Euclidean) or 'maha' (squared Mahalanobis).

        Returns:
            (np.ndarray): Length-N array of squared distances.

        Raises:
            ValueError: If `metric` is not 'gaussian' or 'maha'.
        """
        mean, covariance = self.project(mean, covariance)
        if only_position:
            mean, covariance = mean[:2], covariance[:2, :2]
            measurements = measurements[:, :2]

        d = measurements - mean
        if metric == "gaussian":
            return np.sum(d * d, axis=1)
        if metric == "maha":
            cholesky_factor = np.linalg.cholesky(covariance)
            z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
            return np.sum(z * z, axis=0)  # squared Mahalanobis
        raise ValueError("Invalid distance metric")
class KalmanFilterXYWH(KalmanFilterXYAH):
    """
    Kalman filter for boxes parameterized as center (x, y), width w, and height h.

    Identical motion/observation structure to KalmanFilterXYAH, but noise scales with both
    width and height instead of height alone.

    Examples:
        >>> kf = KalmanFilterXYWH()
        >>> mean, covariance = kf.initiate(np.array([100, 50, 20, 40]))
    """

    def initiate(self, measurement: np.ndarray):
        """
        Create a new track state from a single (x, y, w, h) measurement.

        Args:
            measurement (np.ndarray): Box (x, y, w, h) with center (x, y), width, height.

        Returns:
            (np.ndarray): Mean vector (8,) with velocities initialized to zero.
            (np.ndarray): Covariance matrix (8, 8).
        """
        mean = np.concatenate([measurement, np.zeros_like(measurement)])

        wp, wv = self._std_weight_position, self._std_weight_velocity
        w, h = measurement[2], measurement[3]
        std = np.array([2 * wp * w, 2 * wp * h, 2 * wp * w, 2 * wp * h,
                        10 * wv * w, 10 * wv * h, 10 * wv * w, 10 * wv * h])
        return mean, np.diag(np.square(std))

    def predict(self, mean, covariance):
        """
        Run the prediction step for a single state.

        Args:
            mean (np.ndarray): State mean (8,).
            covariance (np.ndarray): State covariance (8, 8).

        Returns:
            (np.ndarray): Predicted mean (8,).
            (np.ndarray): Predicted covariance (8, 8).
        """
        wp, wv = self._std_weight_position, self._std_weight_velocity
        w, h = mean[2], mean[3]
        std = np.array([wp * w, wp * h, wp * w, wp * h, wv * w, wv * h, wv * w, wv * h])
        motion_cov = np.diag(np.square(std))

        predicted_mean = self._motion_mat @ mean
        predicted_cov = self._motion_mat @ covariance @ self._motion_mat.T + motion_cov
        return predicted_mean, predicted_cov

    def project(self, mean, covariance):
        """
        Project a state distribution into measurement space.

        Args:
            mean (np.ndarray): State mean (8,).
            covariance (np.ndarray): State covariance (8, 8).

        Returns:
            (np.ndarray): Projected mean (4,).
            (np.ndarray): Projected covariance (4, 4) including measurement noise.
        """
        wp = self._std_weight_position
        w, h = mean[2], mean[3]
        innovation_cov = np.diag(np.square(np.array([wp * w, wp * h, wp * w, wp * h])))

        projected_mean = self._update_mat @ mean
        projected_cov = self._update_mat @ covariance @ self._update_mat.T
        return projected_mean, projected_cov + innovation_cov

    def multi_predict(self, mean, covariance):
        """
        Vectorized prediction step for N states.

        Args:
            mean (np.ndarray): State means, shape (N, 8).
            covariance (np.ndarray): State covariances, shape (N, 8, 8).

        Returns:
            (np.ndarray): Predicted means (N, 8).
            (np.ndarray): Predicted covariances (N, 8, 8).
        """
        wp, wv = self._std_weight_position, self._std_weight_velocity
        w, h = mean[:, 2], mean[:, 3]
        stds = np.stack([wp * w, wp * h, wp * w, wp * h, wv * w, wv * h, wv * w, wv * h])
        sqr = np.square(stds).T

        motion_cov = np.asarray([np.diag(row) for row in sqr])

        predicted_mean = mean @ self._motion_mat.T
        left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
        predicted_cov = np.dot(left, self._motion_mat.T) + motion_cov
        return predicted_mean, predicted_cov

    def update(self, mean, covariance, measurement):
        """
        Run the correction step; the math is shared with the XYAH variant.

        Args:
            mean (np.ndarray): Predicted state mean (8,).
            covariance (np.ndarray): Predicted state covariance (8, 8).
            measurement (np.ndarray): Observed box (x, y, w, h).

        Returns:
            (np.ndarray): Corrected state mean (8,).
            (np.ndarray): Corrected state covariance (8, 8).
        """
        return super().update(mean, covariance, measurement)
def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple:
    """
    Solve the assignment problem over a cost matrix with a validity threshold.

    Args:
        cost_matrix (np.ndarray): Cost values for assignments, shape (N, M).
        thresh (float): Maximum cost for a pairing to count as a match.
        use_lap (bool): Solve with lap.lapjv when True; otherwise use
            scipy.optimize.linear_sum_assignment.

    Returns:
        matched_indices: Matched (row, col) index pairs.
        unmatched_a: Unmatched row indices.
        unmatched_b: Unmatched column indices.

    Examples:
        >>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        >>> matched, ua, ub = linear_assignment(cost_matrix, 5.0, use_lap=True)
    """
    # Degenerate case: nothing to assign on one or both sides.
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))

    if use_lap:
        # lap.lapjv enforces the threshold internally via cost_limit.
        _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
        matches = [[row, col] for row, col in enumerate(x) if col >= 0]
        return matches, np.where(x < 0)[0], np.where(y < 0)[0]

    # Hungarian solver, then drop pairings whose cost exceeds the threshold.
    rows, cols = scipy.optimize.linear_sum_assignment(cost_matrix)
    matches = np.asarray([[r, c] for r, c in zip(rows, cols) if cost_matrix[r, c] <= thresh])
    if len(matches) == 0:
        return matches, list(np.arange(cost_matrix.shape[0])), list(np.arange(cost_matrix.shape[1]))
    unmatched_a = list(frozenset(np.arange(cost_matrix.shape[0])) - frozenset(matches[:, 0]))
    unmatched_b = list(frozenset(np.arange(cost_matrix.shape[1])) - frozenset(matches[:, 1]))
    return matches, unmatched_a, unmatched_b
+ + Examples: + Compute IoU distance between two sets of tracks + >>> atracks = [np.array([0, 0, 10, 10]), np.array([20, 20, 30, 30])] + >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])] + >>> cost_matrix = iou_distance(atracks, btracks) + """ + if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray): + atlbrs = atracks + btlbrs = btracks + else: + atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks] + btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks] + + ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32) + if len(atlbrs) and len(btlbrs): + if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5: + ious = batch_probiou( + np.ascontiguousarray(atlbrs, dtype=np.float32), + np.ascontiguousarray(btlbrs, dtype=np.float32), + ).numpy() + else: + ious = bbox_ioa( + np.ascontiguousarray(atlbrs, dtype=np.float32), + np.ascontiguousarray(btlbrs, dtype=np.float32), + iou=True, + ) + return 1 - ious # cost matrix + + +def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray: + """ + Compute distance between tracks and detections based on embeddings. + + Args: + tracks (List[STrack]): List of tracks, where each track contains embedding features. + detections (List[BaseTrack]): List of detections, where each detection contains embedding features. + metric (str): Metric for distance computation. Supported metrics include 'cosine', 'euclidean', etc. + + Returns: + (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks + and M is the number of detections. 
+ + Examples: + Compute the embedding distance between tracks and detections using cosine metric + >>> tracks = [STrack(...), STrack(...)] # List of track objects with embedding features + >>> detections = [BaseTrack(...), BaseTrack(...)] # List of detection objects with embedding features + >>> cost_matrix = embedding_distance(tracks, detections, metric="cosine") + """ + cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32) + if cost_matrix.size == 0: + return cost_matrix + det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32) + # for i, track in enumerate(tracks): + # cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric)) + track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32) + cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Normalized features + return cost_matrix + + +def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray: + """ + Fuse cost matrix with detection scores to produce a single similarity matrix. + + Args: + cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M). + detections (List[BaseTrack]): List of detections, each containing a score attribute. + + Returns: + (np.ndarray): Fused similarity matrix with shape (N, M). 
+ + Examples: + Fuse a cost matrix with detection scores + >>> cost_matrix = np.random.rand(5, 10) # 5 tracks and 10 detections + >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)] + >>> fused_matrix = fuse_score(cost_matrix, detections) + """ + if cost_matrix.size == 0: + return cost_matrix + iou_sim = 1 - cost_matrix + det_scores = np.array([det.score for det in detections]) + det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0) + fuse_sim = iou_sim * det_scores + return 1 - fuse_sim # fuse_cost diff --git a/tracking/ultralytics/utils/__init__.py b/tracking/ultralytics/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4931afe2557dffea74fbdfbeaa70c3d2f17a88 --- /dev/null +++ b/tracking/ultralytics/utils/__init__.py @@ -0,0 +1,1366 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import contextlib +import importlib.metadata +import inspect +import json +import logging.config +import os +import platform +import re +import subprocess +import sys +import threading +import time +import uuid +import warnings +from pathlib import Path +from threading import Lock +from types import SimpleNamespace +from typing import Union +from urllib.parse import unquote + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch +import tqdm +import yaml + +from ultralytics import __version__ +from ultralytics.utils.patches import imread, imshow, imwrite, torch_load, torch_save # for patches + +# PyTorch Multi-GPU DDP Constants +RANK = int(os.getenv("RANK", -1)) +LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html + +# Other Constants +ARGV = sys.argv or ["", ""] # sometimes sys.argv = [] +FILE = Path(__file__).resolve() +ROOT = FILE.parents[1] # YOLO +ASSETS = ROOT / "assets" # default images +ASSETS_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0" # assets GitHub URL 
+DEFAULT_CFG_PATH = ROOT / "cfg/default.yaml" +DEFAULT_SOL_CFG_PATH = ROOT / "cfg/solutions/default.yaml" # Ultralytics solutions yaml path +NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLO multiprocessing threads +AUTOINSTALL = str(os.getenv("YOLO_AUTOINSTALL", True)).lower() == "true" # global auto-install mode +VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true" # global verbose mode +TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar format +LOGGING_NAME = "ultralytics" +MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans +ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans +PYTHON_VERSION = platform.python_version() +TORCH_VERSION = torch.__version__ +TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision +IS_VSCODE = os.environ.get("TERM_PROGRAM", False) == "vscode" +RKNN_CHIPS = frozenset( + { + "rk3588", + "rk3576", + "rk3566", + "rk3568", + "rk3562", + "rv1103", + "rv1106", + "rv1103b", + "rv1106b", + "rk2118", + } +) # Rockchip processors available for export +HELP_MSG = """ + Examples for running Ultralytics: + + 1. Install the ultralytics package: + + pip install ultralytics + + 2. Use the Python SDK: + + from ultralytics import YOLO + + # Load a model + model = YOLO("yolo11n.yaml") # build a new model from scratch + model = YOLO("yolo11n.pt") # load a pretrained model (recommended for training) + + # Use the model + results = model.train(data="coco8.yaml", epochs=3) # train the model + results = model.val() # evaluate model performance on the validation set + results = model("https://ultralytics.com/images/bus.jpg") # predict on an image + success = model.export(format="onnx") # export the model to ONNX format + + 3. 
Use the command line interface (CLI): + + Ultralytics 'yolo' CLI commands use the following syntax: + + yolo TASK MODE ARGS + + Where TASK (optional) is one of [detect, segment, classify, pose, obb] + MODE (required) is one of [train, val, predict, export, track, benchmark] + ARGS (optional) are any number of custom "arg=value" pairs like "imgsz=320" that override defaults. + See all ARGS at https://docs.ultralytics.com/usage/cfg or with "yolo cfg" + + - Train a detection model for 10 epochs with an initial learning_rate of 0.01 + yolo detect train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01 + + - Predict a YouTube video using a pretrained segmentation model at image size 320: + yolo segment predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320 + + - Val a pretrained detection model at batch-size 1 and image size 640: + yolo detect val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640 + + - Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required) + yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128 + + - Run special commands: + yolo help + yolo checks + yolo version + yolo settings + yolo copy-cfg + yolo cfg + + Docs: https://docs.ultralytics.com + Community: https://community.ultralytics.com + GitHub: https://github.com/ultralytics/ultralytics + """ + +# Settings and Environment Variables +torch.set_printoptions(linewidth=320, precision=4, profile="default") +np.set_printoptions(linewidth=320, formatter=dict(float_kind="{:11.5g}".format)) # format short g, %precision=5 +cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader) +os.environ["NUMEXPR_MAX_THREADS"] = str(NUM_THREADS) # NumExpr max threads +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # suppress verbose TF compiler warnings in Colab +os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR" # suppress "NNPACK.cpp could not initialize NNPACK" warnings +os.environ["KINETO_LOG_LEVEL"] = "5" # suppress 
verbose PyTorch profiler output when computing FLOPs + +if TQDM_RICH := str(os.getenv("YOLO_TQDM_RICH", False)).lower() == "true": + from tqdm import rich + + +class TQDM(rich.tqdm if TQDM_RICH else tqdm.tqdm): + """ + A custom TQDM progress bar class that extends the original tqdm functionality. + + This class modifies the behavior of the original tqdm progress bar based on global settings and provides + additional customization options. + + Attributes: + disable (bool): Whether to disable the progress bar. Determined by the global VERBOSE setting and + any passed 'disable' argument. + bar_format (str): The format string for the progress bar. Uses the global TQDM_BAR_FORMAT if not + explicitly set. + + Methods: + __init__: Initializes the TQDM object with custom settings. + + Examples: + >>> from ultralytics.utils import TQDM + >>> for i in TQDM(range(100)): + ... # Your processing code here + ... pass + """ + + def __init__(self, *args, **kwargs): + """ + Initializes a custom TQDM progress bar. + + This class extends the original tqdm class to provide customized behavior for Ultralytics projects. + + Args: + *args (Any): Variable length argument list to be passed to the original tqdm constructor. + **kwargs (Any): Arbitrary keyword arguments to be passed to the original tqdm constructor. + + Notes: + - The progress bar is disabled if VERBOSE is False or if 'disable' is explicitly set to True in kwargs. + - The default bar format is set to TQDM_BAR_FORMAT unless overridden in kwargs. + + Examples: + >>> from ultralytics.utils import TQDM + >>> for i in TQDM(range(100)): + ... # Your code here + ... 
pass + """ + warnings.filterwarnings("ignore", category=tqdm.TqdmExperimentalWarning) # suppress tqdm.rich warning + kwargs["disable"] = not VERBOSE or kwargs.get("disable", False) + kwargs.setdefault("bar_format", TQDM_BAR_FORMAT) # override default value if passed + super().__init__(*args, **kwargs) + + +class SimpleClass: + """ + A simple base class for creating objects with string representations of their attributes. + + This class provides a foundation for creating objects that can be easily printed or represented as strings, + showing all their non-callable attributes. It's useful for debugging and introspection of object states. + + Methods: + __str__: Returns a human-readable string representation of the object. + __repr__: Returns a machine-readable string representation of the object. + __getattr__: Provides a custom attribute access error message with helpful information. + + Examples: + >>> class MyClass(SimpleClass): + ... def __init__(self): + ... self.x = 10 + ... self.y = "hello" + >>> obj = MyClass() + >>> print(obj) + __main__.MyClass object with attributes: + + x: 10 + y: 'hello' + + Notes: + - This class is designed to be subclassed. It provides a convenient way to inspect object attributes. + - The string representation includes the module and class name of the object. + - Callable attributes and attributes starting with an underscore are excluded from the string representation. 
+ """ + + def __str__(self): + """Return a human-readable string representation of the object.""" + attr = [] + for a in dir(self): + v = getattr(self, a) + if not callable(v) and not a.startswith("_"): + if isinstance(v, SimpleClass): + # Display only the module and class name for subclasses + s = f"{a}: {v.__module__}.{v.__class__.__name__} object" + else: + s = f"{a}: {repr(v)}" + attr.append(s) + return f"{self.__module__}.{self.__class__.__name__} object with attributes:\n\n" + "\n".join(attr) + + def __repr__(self): + """Return a machine-readable string representation of the object.""" + return self.__str__() + + def __getattr__(self, attr): + """Custom attribute access error message with helpful information.""" + name = self.__class__.__name__ + raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") + + +class IterableSimpleNamespace(SimpleNamespace): + """ + An iterable SimpleNamespace class that provides enhanced functionality for attribute access and iteration. + + This class extends the SimpleNamespace class with additional methods for iteration, string representation, + and attribute access. It is designed to be used as a convenient container for storing and accessing + configuration parameters. + + Methods: + __iter__: Returns an iterator of key-value pairs from the namespace's attributes. + __str__: Returns a human-readable string representation of the object. + __getattr__: Provides a custom attribute access error message with helpful information. + get: Retrieves the value of a specified key, or a default value if the key doesn't exist. + + Examples: + >>> cfg = IterableSimpleNamespace(a=1, b=2, c=3) + >>> for k, v in cfg: + ... 
print(f"{k}: {v}") + a: 1 + b: 2 + c: 3 + >>> print(cfg) + a=1 + b=2 + c=3 + >>> cfg.get("b") + 2 + >>> cfg.get("d", "default") + 'default' + + Notes: + This class is particularly useful for storing configuration parameters in a more accessible + and iterable format compared to a standard dictionary. + """ + + def __iter__(self): + """Return an iterator of key-value pairs from the namespace's attributes.""" + return iter(vars(self).items()) + + def __str__(self): + """Return a human-readable string representation of the object.""" + return "\n".join(f"{k}={v}" for k, v in vars(self).items()) + + def __getattr__(self, attr): + """Custom attribute access error message with helpful information.""" + name = self.__class__.__name__ + raise AttributeError( + f""" + '{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics + 'default.yaml' file.\nPlease update your code with 'pip install -U ultralytics' and if necessary replace + {DEFAULT_CFG_PATH} with the latest version from + https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml + """ + ) + + def get(self, key, default=None): + """Return the value of the specified key if it exists; otherwise, return the default value.""" + return getattr(self, key, default) + + +def plt_settings(rcparams=None, backend="Agg"): + """ + Decorator to temporarily set rc parameters and the backend for a plotting function. + + Args: + rcparams (dict, optional): Dictionary of rc parameters to set. + backend (str, optional): Name of the backend to use. Defaults to 'Agg'. + + Returns: + (Callable): Decorated function with temporarily set rc parameters and backend. + + Examples: + >>> @plt_settings({"font.size": 12}) + >>> def plot_function(): + ... plt.figure() + ... plt.plot([1, 2, 3]) + ... plt.show() + + >>> with plt_settings({"font.size": 12}): + ... plt.figure() + ... plt.plot([1, 2, 3]) + ... 
plt.show() + """ + if rcparams is None: + rcparams = {"font.size": 11} + + def decorator(func): + """Decorator to apply temporary rc parameters and backend to a function.""" + + def wrapper(*args, **kwargs): + """Sets rc parameters and backend, calls the original function, and restores the settings.""" + original_backend = plt.get_backend() + switch = backend.lower() != original_backend.lower() + if switch: + plt.close("all") # auto-close()ing of figures upon backend switching is deprecated since 3.8 + plt.switch_backend(backend) + + # Plot with backend and always revert to original backend + try: + with plt.rc_context(rcparams): + result = func(*args, **kwargs) + finally: + if switch: + plt.close("all") + plt.switch_backend(original_backend) + return result + + return wrapper + + return decorator + + +def set_logging(name="LOGGING_NAME", verbose=True): + """ + Sets up logging with UTF-8 encoding and configurable verbosity. + + This function configures logging for the Ultralytics library, setting the appropriate logging level and + formatter based on the verbosity flag and the current process rank. It handles special cases for Windows + environments where UTF-8 encoding might not be the default. + + Args: + name (str): Name of the logger. Defaults to "LOGGING_NAME". + verbose (bool): Flag to set logging level to INFO if True, ERROR otherwise. Defaults to True. + + Returns: + (logging.Logger): Configured logger object. + + Examples: + >>> set_logging(name="ultralytics", verbose=True) + >>> logger = logging.getLogger("ultralytics") + >>> logger.info("This is an info message") + + Notes: + - On Windows, this function attempts to reconfigure stdout to use UTF-8 encoding if possible. + - If reconfiguration is not possible, it falls back to a custom formatter that handles non-UTF-8 environments. + - The function sets up a StreamHandler with the appropriate formatter and level. + - The logger's propagate flag is set to False to prevent duplicate logging in parent loggers. 
+ """ + level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR # rank in world for Multi-GPU trainings + + # Configure the console (stdout) encoding to UTF-8, with checks for compatibility + formatter = logging.Formatter("%(message)s") # Default formatter + if WINDOWS and hasattr(sys.stdout, "encoding") and sys.stdout.encoding != "utf-8": + + class CustomFormatter(logging.Formatter): + def format(self, record): + """Format log records with UTF-8 encoding for Windows compatibility.""" + return emojis(super().format(record)) + + try: + # Attempt to reconfigure stdout to use UTF-8 encoding if possible + if hasattr(sys.stdout, "reconfigure"): + sys.stdout.reconfigure(encoding="utf-8") + # For environments where reconfigure is not available, wrap stdout in a TextIOWrapper + elif hasattr(sys.stdout, "buffer"): + import io + + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + else: + formatter = CustomFormatter("%(message)s") + except Exception as e: + print(f"Creating custom formatter for non UTF-8 environments due to {e}") + formatter = CustomFormatter("%(message)s") + + # Create and configure the StreamHandler with the appropriate formatter and level + stream_handler = logging.StreamHandler(sys.stdout) + stream_handler.setFormatter(formatter) + stream_handler.setLevel(level) + + # Set up the logger + logger = logging.getLogger(name) + logger.setLevel(level) + logger.addHandler(stream_handler) + logger.propagate = False + return logger + + +# Set logger +LOGGER = set_logging(LOGGING_NAME, verbose=VERBOSE) # define globally (used in train.py, val.py, predict.py, etc.) 
+for logger in "sentry_sdk", "urllib3.connectionpool": + logging.getLogger(logger).setLevel(logging.CRITICAL + 1) + + +def emojis(string=""): + """Return platform-dependent emoji-safe version of string.""" + return string.encode().decode("ascii", "ignore") if WINDOWS else string + + +class ThreadingLocked: + """ + A decorator class for ensuring thread-safe execution of a function or method. + + This class can be used as a decorator to make sure that if the decorated function is called from multiple threads, + only one thread at a time will be able to execute the function. + + Attributes: + lock (threading.Lock): A lock object used to manage access to the decorated function. + + Examples: + >>> from ultralytics.utils import ThreadingLocked + >>> @ThreadingLocked() + >>> def my_function(): + ... # Your code here + """ + + def __init__(self): + """Initialize the decorator class with a threading lock.""" + self.lock = threading.Lock() + + def __call__(self, f): + """Run thread-safe execution of function or method.""" + from functools import wraps + + @wraps(f) + def decorated(*args, **kwargs): + """Applies thread-safety to the decorated function or method.""" + with self.lock: + return f(*args, **kwargs) + + return decorated + + +def yaml_save(file="data.yaml", data=None, header=""): + """ + Save YAML data to a file. + + Args: + file (str, optional): File name. Default is 'data.yaml'. + data (dict): Data to save in YAML format. + header (str, optional): YAML header to add. + + Returns: + (None): Data is saved to the specified file. 
+ """ + if data is None: + data = {} + file = Path(file) + if not file.parent.exists(): + # Create parent directories if they don't exist + file.parent.mkdir(parents=True, exist_ok=True) + + # Convert Path objects to strings + valid_types = int, float, str, bool, list, tuple, dict, type(None) + for k, v in data.items(): + if not isinstance(v, valid_types): + data[k] = str(v) + + # Dump data to file in YAML format + with open(file, "w", errors="ignore", encoding="utf-8") as f: + if header: + f.write(header) + yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True) + + +def yaml_load(file="data.yaml", append_filename=False): + """ + Load YAML data from a file. + + Args: + file (str, optional): File name. Default is 'data.yaml'. + append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False. + + Returns: + (dict): YAML data and file name. + """ + assert Path(file).suffix in {".yaml", ".yml"}, f"Attempting to load non-YAML file {file} with yaml_load()" + with open(file, errors="ignore", encoding="utf-8") as f: + s = f.read() # string + + # Remove special characters + if not s.isprintable(): + s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+", "", s) + + # Add YAML filename to dict and return + data = yaml.safe_load(s) or {} # always return a dict (yaml.safe_load() may return None for empty files) + if append_filename: + data["yaml_file"] = str(file) + return data + + +def yaml_print(yaml_file: Union[str, Path, dict]) -> None: + """ + Pretty prints a YAML file or a YAML-formatted dictionary. + + Args: + yaml_file: The file path of the YAML file or a YAML-formatted dictionary. 
+ + Returns: + (None) + """ + yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file + dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True, width=float("inf")) + LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}") + + +# Default configuration +DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH) +DEFAULT_SOL_DICT = yaml_load(DEFAULT_SOL_CFG_PATH) # Ultralytics solutions configuration +for k, v in DEFAULT_CFG_DICT.items(): + if isinstance(v, str) and v.lower() == "none": + DEFAULT_CFG_DICT[k] = None +DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys() +DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT) + + +def read_device_model() -> str: + """ + Reads the device model information from the system and caches it for quick access. + + Returns: + (str): Kernel release information. + """ + return platform.release().lower() + + +def is_ubuntu() -> bool: + """ + Check if the OS is Ubuntu. + + Returns: + (bool): True if OS is Ubuntu, False otherwise. + """ + try: + with open("/etc/os-release") as f: + return "ID=ubuntu" in f.read() + except FileNotFoundError: + return False + + +def is_colab(): + """ + Check if the current script is running inside a Google Colab notebook. + + Returns: + (bool): True if running inside a Colab notebook, False otherwise. + """ + return "COLAB_RELEASE_TAG" in os.environ or "COLAB_BACKEND_VERSION" in os.environ + + +def is_kaggle(): + """ + Check if the current script is running inside a Kaggle kernel. + + Returns: + (bool): True if running inside a Kaggle kernel, False otherwise. + """ + return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com" + + +def is_jupyter(): + """ + Check if the current script is running inside a Jupyter Notebook. + + Returns: + (bool): True if running inside a Jupyter Notebook, False otherwise. 
+ + Note: + - Only works on Colab and Kaggle, other environments like Jupyterlab and Paperspace are not reliably detectable. + - "get_ipython" in globals() method suffers false positives when IPython package installed manually. + """ + return IS_COLAB or IS_KAGGLE + + +def is_runpod(): + """ + Check if the current script is running inside a RunPod container. + + Returns: + (bool): True if running in RunPod, False otherwise. + """ + return "RUNPOD_POD_ID" in os.environ + + +def is_docker() -> bool: + """ + Determine if the script is running inside a Docker container. + + Returns: + (bool): True if the script is running inside a Docker container, False otherwise. + """ + try: + with open("/proc/self/cgroup") as f: + return "docker" in f.read() + except Exception: + return False + + +def is_raspberrypi() -> bool: + """ + Determines if the Python environment is running on a Raspberry Pi. + + Returns: + (bool): True if running on a Raspberry Pi, False otherwise. + """ + return "rpi" in DEVICE_MODEL + + +def is_jetson() -> bool: + """ + Determines if the Python environment is running on an NVIDIA Jetson device. + + Returns: + (bool): True if running on an NVIDIA Jetson device, False otherwise. + """ + return "tegra" in DEVICE_MODEL + + +def is_online() -> bool: + """ + Check internet connectivity by attempting to connect to a known online host. + + Returns: + (bool): True if connection is successful, False otherwise. + """ + try: + assert str(os.getenv("YOLO_OFFLINE", "")).lower() != "true" # check if ENV var YOLO_OFFLINE="True" + import socket + + for dns in ("1.1.1.1", "8.8.8.8"): # check Cloudflare and Google DNS + socket.create_connection(address=(dns, 80), timeout=2.0).close() + return True + except Exception: + return False + + +def is_pip_package(filepath: str = __name__) -> bool: + """ + Determines if the file at the given filepath is part of a pip package. + + Args: + filepath (str): The filepath to check. 
+ + Returns: + (bool): True if the file is part of a pip package, False otherwise. + """ + import importlib.util + + # Get the spec for the module + spec = importlib.util.find_spec(filepath) + + # Return whether the spec is not None and the origin is not None (indicating it is a package) + return spec is not None and spec.origin is not None + + +def is_dir_writeable(dir_path: Union[str, Path]) -> bool: + """ + Check if a directory is writeable. + + Args: + dir_path (str | Path): The path to the directory. + + Returns: + (bool): True if the directory is writeable, False otherwise. + """ + return os.access(str(dir_path), os.W_OK) + + +def is_pytest_running(): + """ + Determines whether pytest is currently running or not. + + Returns: + (bool): True if pytest is running, False otherwise. + """ + return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(ARGV[0]).stem) + + +def is_github_action_running() -> bool: + """ + Determine if the current environment is a GitHub Actions runner. + + Returns: + (bool): True if the current environment is a GitHub Actions runner, False otherwise. + """ + return "GITHUB_ACTIONS" in os.environ and "GITHUB_WORKFLOW" in os.environ and "RUNNER_OS" in os.environ + + +def get_git_dir(): + """ + Determines whether the current file is part of a git repository and if so, returns the repository root directory. + + Returns: + (Path | None): Git root directory if found or None if not found. + """ + for d in Path(__file__).parents: + if (d / ".git").is_dir(): + return d + + +def is_git_dir(): + """ + Determines whether the current file is part of a git repository. + + Returns: + (bool): True if current file is part of a git repository. + """ + return GIT_DIR is not None + + +def get_git_origin_url(): + """ + Retrieves the origin URL of a git repository. + + Returns: + (str | None): The origin URL of the git repository or None if not git directory. 
+ """ + if IS_GIT_DIR: + try: + origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"]) + return origin.decode().strip() + except subprocess.CalledProcessError: + return None + + +def get_git_branch(): + """ + Returns the current git branch name. If not in a git repository, returns None. + + Returns: + (str | None): The current git branch name or None if not a git directory. + """ + if IS_GIT_DIR: + try: + origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + return origin.decode().strip() + except subprocess.CalledProcessError: + return None + + +def get_default_args(func): + """ + Returns a dictionary of default arguments for a function. + + Args: + func (callable): The function to inspect. + + Returns: + (dict): A dictionary where each key is a parameter name, and each value is the default value of that parameter. + """ + signature = inspect.signature(func) + return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty} + + +def get_ubuntu_version(): + """ + Retrieve the Ubuntu version if the OS is Ubuntu. + + Returns: + (str): Ubuntu version or None if not an Ubuntu OS. + """ + if is_ubuntu(): + try: + with open("/etc/os-release") as f: + return re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())[1] + except (FileNotFoundError, AttributeError): + return None + + +def get_user_config_dir(sub_dir="Ultralytics"): + """ + Return the appropriate config directory based on the environment operating system. + + Args: + sub_dir (str): The name of the subdirectory to create. + + Returns: + (Path): The path to the user config directory. 
+ """ + if WINDOWS: + path = Path.home() / "AppData" / "Roaming" / sub_dir + elif MACOS: # macOS + path = Path.home() / "Library" / "Application Support" / sub_dir + elif LINUX: + path = Path.home() / ".config" / sub_dir + else: + raise ValueError(f"Unsupported operating system: {platform.system()}") + + # GCP and AWS lambda fix, only /tmp is writeable + if not is_dir_writeable(path.parent): + LOGGER.warning( + f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD." + "Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path." + ) + path = Path("/tmp") / sub_dir if is_dir_writeable("/tmp") else Path().cwd() / sub_dir + + # Create the subdirectory if it does not exist + path.mkdir(parents=True, exist_ok=True) + + return path + + +# Define constants (required below) +DEVICE_MODEL = read_device_model() # is_jetson() and is_raspberrypi() depend on this constant +ONLINE = is_online() +IS_COLAB = is_colab() +IS_KAGGLE = is_kaggle() +IS_DOCKER = is_docker() +IS_JETSON = is_jetson() +IS_JUPYTER = is_jupyter() +IS_PIP_PACKAGE = is_pip_package() +IS_RASPBERRYPI = is_raspberrypi() +GIT_DIR = get_git_dir() +IS_GIT_DIR = is_git_dir() +USER_CONFIG_DIR = Path(os.getenv("YOLO_CONFIG_DIR") or get_user_config_dir()) # Ultralytics settings dir +SETTINGS_FILE = USER_CONFIG_DIR / "settings.json" + + +def colorstr(*input): + r""" + Colors a string based on the provided color and style arguments. Utilizes ANSI escape codes. + See https://en.wikipedia.org/wiki/ANSI_escape_code for more details. + + This function can be called in two ways: + - colorstr('color', 'style', 'your string') + - colorstr('your string') + + In the second form, 'blue' and 'bold' will be applied by default. + + Args: + *input (str | Path): A sequence of strings where the first n-1 strings are color and style arguments, + and the last string is the one to be colored. 
+ + Supported Colors and Styles: + Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white' + Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow', + 'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white' + Misc: 'end', 'bold', 'underline' + + Returns: + (str): The input string wrapped with ANSI escape codes for the specified color and style. + + Examples: + >>> colorstr("blue", "bold", "hello world") + >>> "\033[34m\033[1mhello world\033[0m" + """ + *args, string = input if len(input) > 1 else ("blue", "bold", input[0]) # color arguments, string + colors = { + "black": "\033[30m", # basic colors + "red": "\033[31m", + "green": "\033[32m", + "yellow": "\033[33m", + "blue": "\033[34m", + "magenta": "\033[35m", + "cyan": "\033[36m", + "white": "\033[37m", + "bright_black": "\033[90m", # bright colors + "bright_red": "\033[91m", + "bright_green": "\033[92m", + "bright_yellow": "\033[93m", + "bright_blue": "\033[94m", + "bright_magenta": "\033[95m", + "bright_cyan": "\033[96m", + "bright_white": "\033[97m", + "end": "\033[0m", # misc + "bold": "\033[1m", + "underline": "\033[4m", + } + return "".join(colors[x] for x in args) + f"{string}" + colors["end"] + + +def remove_colorstr(input_string): + """ + Removes ANSI escape codes from a string, effectively un-coloring it. + + Args: + input_string (str): The string to remove color and style from. + + Returns: + (str): A new string with all ANSI escape codes removed. + + Examples: + >>> remove_colorstr(colorstr("blue", "bold", "hello world")) + >>> "hello world" + """ + ansi_escape = re.compile(r"\x1B\[[0-9;]*[A-Za-z]") + return ansi_escape.sub("", input_string) + + +class TryExcept(contextlib.ContextDecorator): + """ + Ultralytics TryExcept class. Use as @TryExcept() decorator or 'with TryExcept():' context manager. 
+ + Examples: + As a decorator: + >>> @TryExcept(msg="Error occurred in func", verbose=True) + >>> def func(): + >>> # Function logic here + >>> pass + + As a context manager: + >>> with TryExcept(msg="Error occurred in block", verbose=True): + >>> # Code block here + >>> pass + """ + + def __init__(self, msg="", verbose=True): + """Initialize TryExcept class with optional message and verbosity settings.""" + self.msg = msg + self.verbose = verbose + + def __enter__(self): + """Executes when entering TryExcept context, initializes instance.""" + pass + + def __exit__(self, exc_type, value, traceback): + """Defines behavior when exiting a 'with' block, prints error message if necessary.""" + if self.verbose and value: + LOGGER.warning(f"{self.msg}{': ' if self.msg else ''}{value}") + return True + + +class Retry(contextlib.ContextDecorator): + """ + Retry class for function execution with exponential backoff. + + Can be used as a decorator to retry a function on exceptions, up to a specified number of times with an + exponentially increasing delay between retries. 
+ + Examples: + Example usage as a decorator: + >>> @Retry(times=3, delay=2) + >>> def test_func(): + >>> # Replace with function logic that may raise exceptions + >>> return True + """ + + def __init__(self, times=3, delay=2): + """Initialize Retry class with specified number of retries and delay.""" + self.times = times + self.delay = delay + self._attempts = 0 + + def __call__(self, func): + """Decorator implementation for Retry with exponential backoff.""" + + def wrapped_func(*args, **kwargs): + """Applies retries to the decorated function or method.""" + self._attempts = 0 + while self._attempts < self.times: + try: + return func(*args, **kwargs) + except Exception as e: + self._attempts += 1 + print(f"Retry {self._attempts}/{self.times} failed: {e}") + if self._attempts >= self.times: + raise e + time.sleep(self.delay * (2**self._attempts)) # exponential backoff delay + + return wrapped_func + + +def threaded(func): + """ + Multi-threads a target function by default and returns the thread or function result. + + This decorator provides flexible execution of the target function, either in a separate thread or synchronously. + By default, the function runs in a thread, but this can be controlled via the 'threaded=False' keyword argument + which is removed from kwargs before calling the function. + + Args: + func (callable): The function to be potentially executed in a separate thread. + + Returns: + (callable): A wrapper function that either returns a daemon thread or the direct function result. + + Example: + >>> @threaded + ... def process_data(data): + ... 
return data
        >>>
        >>> thread = process_data(my_data)  # Runs in background thread
        >>> result = process_data(my_data, threaded=False)  # Runs synchronously, returns function result
    """

    def wrapper(*args, **kwargs):
        """Multi-threads a given function based on 'threaded' kwarg and returns the thread or function result."""
        if kwargs.pop("threaded", True):  # run in thread
            # Daemon thread: will not block interpreter shutdown
            thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
            thread.start()
            return thread
        else:
            return func(*args, **kwargs)

    return wrapper


def set_sentry():
    """
    Initialize the Sentry SDK for error tracking and reporting.

    Only used if sentry_sdk package is installed and sync=True in settings. Run 'yolo settings' to see and update
    settings.

    Conditions required to send errors (ALL conditions must be met or no errors will be reported):
        - sentry_sdk package is installed
        - sync=True in YOLO settings
        - pytest is not running
        - running in a pip package installation
        - running in a non-git directory
        - running with rank -1 or 0
        - online environment
        - CLI used to run package (checked with 'yolo' as the name of the main CLI command)
    """
    if (
        not SETTINGS["sync"]
        or RANK not in {-1, 0}
        or Path(ARGV[0]).name != "yolo"
        or TESTS_RUNNING
        or not ONLINE
        or not IS_PIP_PACKAGE
        or IS_GIT_DIR
    ):
        return
    # If sentry_sdk package is not installed then return and do not use Sentry
    try:
        import sentry_sdk  # noqa
    except ImportError:
        return

    def before_send(event, hint):
        """
        Modify the event before sending it to Sentry based on specific exception types and messages.

        Args:
            event (dict): The event dictionary containing information about the error.
            hint (dict): A dictionary containing additional information about the error.

        Returns:
            dict: The modified event or None if the event should not be sent to Sentry.
        """
        # Filter out user-interrupt, missing-file and OOM errors (not actionable crash reports)
        if "exc_info" in hint:
            exc_type, exc_value, _ = hint["exc_info"]
            if exc_type in {KeyboardInterrupt, FileNotFoundError} or "out of memory" in str(exc_value):
                return None  # do not send event

        event["tags"] = {
            "sys_argv": ARGV[0],
            "sys_argv_name": Path(ARGV[0]).name,
            "install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other",
            "os": ENVIRONMENT,
        }
        return event

    sentry_sdk.init(
        dsn="https://888e5a0778212e1d0314c37d4b9aae5d@o4504521589325824.ingest.us.sentry.io/4504521592406016",
        debug=False,
        auto_enabling_integrations=False,
        traces_sample_rate=1.0,
        release=__version__,
        environment="runpod" if is_runpod() else "production",
        before_send=before_send,
        ignore_errors=[KeyboardInterrupt, FileNotFoundError],
    )
    sentry_sdk.set_user({"id": SETTINGS["uuid"]})  # SHA-256 anonymized UUID hash


class JSONDict(dict):
    """
    A dictionary-like class that provides JSON persistence for its contents.

    This class extends the built-in dictionary to automatically save its contents to a JSON file whenever they are
    modified. It ensures thread-safe operations using a lock.

    Attributes:
        file_path (Path): The path to the JSON file used for persistence.
        lock (threading.Lock): A lock object to ensure thread-safe operations.

    Methods:
        _load: Loads the data from the JSON file into the dictionary.
        _save: Saves the current state of the dictionary to the JSON file.
        __setitem__: Stores a key-value pair and persists it to disk.
        __delitem__: Removes an item and updates the persistent storage.
        update: Updates the dictionary and persists changes.
        clear: Clears all entries and updates the persistent storage.
    Examples:
        >>> json_dict = JSONDict("data.json")
        >>> json_dict["key"] = "value"
        >>> print(json_dict["key"])
        value
        >>> del json_dict["key"]
        >>> json_dict.update({"new_key": "new_value"})
        >>> json_dict.clear()
    """

    def __init__(self, file_path: Union[str, Path] = "data.json"):
        """Initialize a JSONDict object with a specified file path for JSON persistence."""
        super().__init__()
        self.file_path = Path(file_path)
        self.lock = Lock()  # serializes all mutations so concurrent writers do not interleave _save() calls
        self._load()

    def _load(self):
        """Load the data from the JSON file into the dictionary."""
        # Best-effort load: a corrupt or unreadable file leaves the dict empty rather than raising
        try:
            if self.file_path.exists():
                with open(self.file_path) as f:
                    self.update(json.load(f))
        except json.JSONDecodeError:
            print(f"Error decoding JSON from {self.file_path}. Starting with an empty dictionary.")
        except Exception as e:
            print(f"Error reading from {self.file_path}: {e}")

    def _save(self):
        """Save the current state of the dictionary to the JSON file."""
        try:
            self.file_path.parent.mkdir(parents=True, exist_ok=True)
            with open(self.file_path, "w", encoding="utf-8") as f:
                json.dump(dict(self), f, indent=2, default=self._json_default)
        except Exception as e:
            print(f"Error writing to {self.file_path}: {e}")

    @staticmethod
    def _json_default(obj):
        """Handle JSON serialization of Path objects."""
        if isinstance(obj, Path):
            return str(obj)
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def __setitem__(self, key, value):
        """Store a key-value pair and persist to disk."""
        with self.lock:
            super().__setitem__(key, value)
            self._save()

    def __delitem__(self, key):
        """Remove an item and update the persistent storage."""
        with self.lock:
            super().__delitem__(key)
            self._save()

    def __str__(self):
        """Return a pretty-printed JSON string representation of the dictionary."""
        contents = json.dumps(dict(self), indent=2, ensure_ascii=False, default=self._json_default)
        return f'JSONDict("{self.file_path}"):\n{contents}'

    def update(self, *args, **kwargs):
        """Update the dictionary and persist changes."""
        with self.lock:
            super().update(*args, **kwargs)
            self._save()

    def clear(self):
        """Clear all entries and update the persistent storage."""
        with self.lock:
            super().clear()
            self._save()


class SettingsManager(JSONDict):
    """
    SettingsManager class for managing and persisting Ultralytics settings.

    This class extends JSONDict to provide JSON persistence for settings, ensuring thread-safe operations and default
    values. It validates settings on initialization and provides methods to update or reset settings.

    Attributes:
        file (Path): The path to the JSON file used for persistence.
        version (str): The version of the settings schema.
        defaults (dict): A dictionary containing default settings.
        help_msg (str): A help message for users on how to view and update settings.

    Methods:
        _validate_settings: Validates the current settings and resets if necessary.
        update: Updates settings, validating keys and types.
        reset: Resets the settings to default and saves them.
+ + Examples: + Initialize and update settings: + >>> settings = SettingsManager() + >>> settings.update(runs_dir="/new/runs/dir") + >>> print(settings["runs_dir"]) + /new/runs/dir + """ + + def __init__(self, file=SETTINGS_FILE, version="0.0.6"): + """Initializes the SettingsManager with default settings and loads user settings.""" + import hashlib + + from ultralytics.utils.torch_utils import torch_distributed_zero_first + + root = GIT_DIR or Path() + datasets_root = (root.parent if GIT_DIR and is_dir_writeable(root.parent) else root).resolve() + + self.file = Path(file) + self.version = version + self.defaults = { + "settings_version": version, # Settings schema version + "datasets_dir": str(datasets_root / "datasets"), # Datasets directory + "weights_dir": str(root / "weights"), # Model weights directory + "runs_dir": str(root / "runs"), # Experiment runs directory + "uuid": hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(), # SHA-256 anonymized UUID hash + "sync": True, # Enable synchronization + "api_key": "", # Ultralytics API Key + "openai_api_key": "", # OpenAI API Key + "clearml": True, # ClearML integration + "comet": True, # Comet integration + "dvc": True, # DVC integration + "hub": True, # Ultralytics HUB integration + "mlflow": True, # MLflow integration + "neptune": True, # Neptune integration + "raytune": True, # Ray Tune integration + "tensorboard": True, # TensorBoard logging + "wandb": False, # Weights & Biases logging + "vscode_msg": True, # VSCode messaging + } + + self.help_msg = ( + f"\nView Ultralytics Settings with 'yolo settings' or at '{self.file}'" + "\nUpdate Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. " + "For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings." 
+ ) + + with torch_distributed_zero_first(RANK): + super().__init__(self.file) + + if not self.file.exists() or not self: # Check if file doesn't exist or is empty + LOGGER.info(f"Creating new Ultralytics Settings v{version} file ✅ {self.help_msg}") + self.reset() + + self._validate_settings() + + def _validate_settings(self): + """Validate the current settings and reset if necessary.""" + correct_keys = frozenset(self.keys()) == frozenset(self.defaults.keys()) + correct_types = all(isinstance(self.get(k), type(v)) for k, v in self.defaults.items()) + correct_version = self.get("settings_version", "") == self.version + + if not (correct_keys and correct_types and correct_version): + LOGGER.warning( + "WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem " + f"with your settings or a recent ultralytics package update. {self.help_msg}" + ) + self.reset() + + if self.get("datasets_dir") == self.get("runs_dir"): + LOGGER.warning( + f"WARNING ⚠️ Ultralytics setting 'datasets_dir: {self.get('datasets_dir')}' " + f"must be different than 'runs_dir: {self.get('runs_dir')}'. " + f"Please change one to avoid possible issues during training. {self.help_msg}" + ) + + def __setitem__(self, key, value): + """Updates one key: value pair.""" + self.update({key: value}) + + def update(self, *args, **kwargs): + """Updates settings, validating keys and types.""" + for arg in args: + if isinstance(arg, dict): + kwargs.update(arg) + for k, v in kwargs.items(): + if k not in self.defaults: + raise KeyError(f"No Ultralytics setting '{k}'. {self.help_msg}") + t = type(self.defaults[k]) + if not isinstance(v, t): + raise TypeError( + f"Ultralytics setting '{k}' must be '{t.__name__}' type, not '{type(v).__name__}'. 
{self.help_msg}" + ) + super().update(*args, **kwargs) + + def reset(self): + """Resets the settings to default and saves them.""" + self.clear() + self.update(self.defaults) + + +def deprecation_warn(arg, new_arg=None): + """Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument.""" + msg = f"WARNING ⚠️ '{arg}' is deprecated and will be removed in in the future." + if new_arg is not None: + msg += f" Use '{new_arg}' instead." + LOGGER.warning(msg) + + +def clean_url(url): + """Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt.""" + url = Path(url).as_posix().replace(":/", "://") # Pathlib turns :// -> :/, as_posix() for Windows + return unquote(url).split("?")[0] # '%2F' to '/', split https://url.com/file.txt?auth + + +def url2file(url): + """Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt.""" + return Path(clean_url(url)).name + + +def vscode_msg(ext="ultralytics.ultralytics-snippets") -> str: + """Display a message to install Ultralytics-Snippets for VS Code if not already installed.""" + path = (USER_CONFIG_DIR.parents[2] if WINDOWS else USER_CONFIG_DIR.parents[1]) / ".vscode/extensions" + obs_file = path / ".obsolete" # file tracks uninstalled extensions, while source directory remains + installed = any(path.glob(f"{ext}*")) and ext not in (obs_file.read_text("utf-8") if obs_file.exists() else "") + url = "https://docs.ultralytics.com/integrations/vscode" + return "" if installed else f"{colorstr('VS Code:')} view Ultralytics VS Code Extension ⚡ at {url}" + + +# Run below code on utils init ------------------------------------------------------------------------------------ + +# Check first-install steps +PREFIX = colorstr("Ultralytics: ") +SETTINGS = SettingsManager() # initialize settings +PERSISTENT_CACHE = JSONDict(USER_CONFIG_DIR / "persistent_cache.json") # initialize persistent cache +DATASETS_DIR = Path(SETTINGS["datasets_dir"]) # global datasets directory 
+WEIGHTS_DIR = Path(SETTINGS["weights_dir"]) # global weights directory +RUNS_DIR = Path(SETTINGS["runs_dir"]) # global runs directory +ENVIRONMENT = ( + "Colab" + if IS_COLAB + else "Kaggle" + if IS_KAGGLE + else "Jupyter" + if IS_JUPYTER + else "Docker" + if IS_DOCKER + else platform.system() +) +TESTS_RUNNING = is_pytest_running() or is_github_action_running() +set_sentry() + +# Apply monkey patches +torch.load = torch_load +torch.save = torch_save +if WINDOWS: + # Apply cv2 patches for non-ASCII and non-UTF characters in image paths + cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow diff --git a/tracking/ultralytics/utils/autobatch.py b/tracking/ultralytics/utils/autobatch.py new file mode 100644 index 0000000000000000000000000000000000000000..87c7e41c24ec12299ffcec47e446a5e3ed35455c --- /dev/null +++ b/tracking/ultralytics/utils/autobatch.py @@ -0,0 +1,106 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch.""" + +import os +from copy import deepcopy + +import numpy as np +import torch + +from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr +from ultralytics.utils.torch_utils import autocast, profile + + +def check_train_batch_size(model, imgsz=640, amp=True, batch=-1, max_num_obj=1): + """ + Compute optimal YOLO training batch size using the autobatch() function. + + Args: + model (torch.nn.Module): YOLO model to check batch size for. + imgsz (int, optional): Image size used for training. + amp (bool, optional): Use automatic mixed precision if True. + batch (float, optional): Fraction of GPU memory to use. If -1, use default. + max_num_obj (int, optional): The maximum number of objects from dataset. + + Returns: + (int): Optimal batch size computed using the autobatch() function. + + Notes: + If 0.0 < batch < 1.0, it's used as the fraction of GPU memory to use. + Otherwise, a default fraction of 0.6 is used. 
+ """ + with autocast(enabled=amp): + return autobatch( + deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6, max_num_obj=max_num_obj + ) + + +def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch, max_num_obj=1): + """ + Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory. + + Args: + model (torch.nn.Module): YOLO model to compute batch size for. + imgsz (int, optional): The image size used as input for the YOLO model. + fraction (float, optional): The fraction of available CUDA memory to use. + batch_size (int, optional): The default batch size to use if an error is detected. + max_num_obj (int, optional): The maximum number of objects from dataset. + + Returns: + (int): The optimal batch size. + """ + # Check device + prefix = colorstr("AutoBatch: ") + LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.") + device = next(model.parameters()).device # get model device + if device.type in {"cpu", "mps"}: + LOGGER.info(f"{prefix} ⚠️ intended for CUDA devices, using default batch-size {batch_size}") + return batch_size + if torch.backends.cudnn.benchmark: + LOGGER.info(f"{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}") + return batch_size + + # Inspect CUDA memory + gb = 1 << 30 # bytes to GiB (1024 ** 3) + d = f"CUDA:{os.getenv('CUDA_VISIBLE_DEVICES', '0').strip()[0]}" # 'CUDA:0' + properties = torch.cuda.get_device_properties(device) # device properties + t = properties.total_memory / gb # GiB total + r = torch.cuda.memory_reserved(device) / gb # GiB reserved + a = torch.cuda.memory_allocated(device) / gb # GiB allocated + f = t - (r + a) # GiB free + LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free") + + # Profile batch sizes + batch_sizes = [1, 2, 4, 8, 16] if t < 16 else [1, 2, 4, 8, 16, 32, 64] + 
try: + img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes] + results = profile(img, model, n=1, device=device, max_num_obj=max_num_obj) + + # Fit a solution + xy = [ + [x, y[2]] + for i, (x, y) in enumerate(zip(batch_sizes, results)) + if y # valid result + and isinstance(y[2], (int, float)) # is numeric + and 0 < y[2] < t # between 0 and GPU limit + and (i == 0 or not results[i - 1] or y[2] > results[i - 1][2]) # first item or increasing memory + ] + fit_x, fit_y = zip(*xy) if xy else ([], []) + p = np.polyfit(np.log(fit_x), np.log(fit_y), deg=1) # first-degree polynomial fit in log space + b = int(round(np.exp((np.log(f * fraction) - p[1]) / p[0]))) # y intercept (optimal batch size) + if None in results: # some sizes failed + i = results.index(None) # first fail index + if b >= batch_sizes[i]: # y intercept above failure point + b = batch_sizes[max(i - 1, 0)] # select prior safe point + if b < 1 or b > 1024: # b outside of safe range + LOGGER.info(f"{prefix}WARNING ⚠️ batch={b} outside safe range, using default batch-size {batch_size}.") + b = batch_size + + fraction = (np.exp(np.polyval(p, np.log(b))) + r + a) / t # predicted fraction + LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅") + return b + except Exception as e: + LOGGER.warning(f"{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.") + return batch_size + finally: + torch.cuda.empty_cache() diff --git a/tracking/ultralytics/utils/benchmarks.py b/tracking/ultralytics/utils/benchmarks.py new file mode 100644 index 0000000000000000000000000000000000000000..2c7084d2c1aff4572cdaf94428bdd7617c3964d8 --- /dev/null +++ b/tracking/ultralytics/utils/benchmarks.py @@ -0,0 +1,699 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Benchmark a YOLO model formats for speed and accuracy. 

Usage:
    from ultralytics.utils.benchmarks import ProfileModels, benchmark
    ProfileModels(['yolo11n.yaml', 'yolov8s.yaml']).profile()
    benchmark(model='yolo11n.pt', imgsz=160)

Format | `format=argument` | Model
--- | --- | ---
PyTorch | - | yolo11n.pt
TorchScript | `torchscript` | yolo11n.torchscript
ONNX | `onnx` | yolo11n.onnx
OpenVINO | `openvino` | yolo11n_openvino_model/
TensorRT | `engine` | yolo11n.engine
CoreML | `coreml` | yolo11n.mlpackage
TensorFlow SavedModel | `saved_model` | yolo11n_saved_model/
TensorFlow GraphDef | `pb` | yolo11n.pb
TensorFlow Lite | `tflite` | yolo11n.tflite
TensorFlow Edge TPU | `edgetpu` | yolo11n_edgetpu.tflite
TensorFlow.js | `tfjs` | yolo11n_web_model/
PaddlePaddle | `paddle` | yolo11n_paddle_model/
MNN | `mnn` | yolo11n.mnn
NCNN | `ncnn` | yolo11n_ncnn_model/
RKNN | `rknn` | yolo11n_rknn_model/
"""

import glob
import os
import platform
import re
import shutil
import time
from pathlib import Path

import numpy as np
import torch.cuda
import yaml

from ultralytics import YOLO, YOLOWorld
from ultralytics.cfg import TASK2DATA, TASK2METRIC
from ultralytics.engine.exporter import export_formats
from ultralytics.utils import ARM64, ASSETS, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR
from ultralytics.utils.checks import IS_PYTHON_3_13, check_imgsz, check_requirements, check_yolo, is_rockchip
from ultralytics.utils.downloads import safe_download
from ultralytics.utils.files import file_size
from ultralytics.utils.torch_utils import get_cpu_info, select_device


def benchmark(
    model=WEIGHTS_DIR / "yolo11n.pt",
    data=None,
    imgsz=160,
    half=False,
    int8=False,
    device="cpu",
    verbose=False,
    eps=1e-3,
    format="",
):
    """
    Benchmark a YOLO model across different formats for speed and accuracy.

    Args:
        model (str | Path): Path to the model file or directory.
        data (str | None): Dataset to evaluate on, inherited from TASK2DATA if not passed.
        imgsz (int): Image size for the benchmark.
        half (bool): Use half-precision for the model if True.
        int8 (bool): Use int8-precision for the model if True.
        device (str): Device to run the benchmark on, either 'cpu' or 'cuda'.
        verbose (bool | float): If True or a float, assert benchmarks pass with given metric.
        eps (float): Epsilon value for divide by zero prevention.
        format (str): Export format for benchmarking. If not supplied all formats are benchmarked.

    Returns:
        (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size, metric,
            and inference time.

    Examples:
        Benchmark a YOLO model with default settings:
        >>> from ultralytics.utils.benchmarks import benchmark
        >>> benchmark(model="yolo11n.pt", imgsz=640)
    """
    imgsz = check_imgsz(imgsz)
    assert imgsz[0] == imgsz[1] if isinstance(imgsz, list) else True, "benchmark() only supports square imgsz."

    import pandas as pd  # scope for faster 'import ultralytics'

    pd.options.display.max_columns = 10
    pd.options.display.width = 120
    device = select_device(device, verbose=False)
    if isinstance(model, (str, Path)):
        model = YOLO(model)
    is_end2end = getattr(model.model.model[-1], "end2end", False)
    data = data or TASK2DATA[model.task]  # task to dataset, i.e. coco8.yaml for task=detect
    key = TASK2METRIC[model.task]  # task to metric, i.e. metrics/mAP50-95(B) for task=detect

    y = []
    t0 = time.time()

    format_arg = format.lower()
    if format_arg:
        formats = frozenset(export_formats()["Argument"])
        assert format in formats, f"Expected format to be one of {formats}, but got '{format_arg}'."
    # NOTE: the index i below is coupled to the row order of export_formats(); the per-format
    # assertions rely on that ordering
    for i, (name, format, suffix, cpu, gpu, _) in enumerate(zip(*export_formats().values())):
        emoji, filename = "❌", None  # export defaults
        try:
            if format_arg and format_arg != format:
                continue

            # Checks
            if i == 7:  # TF GraphDef
                assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task"
            elif i == 9:  # Edge TPU
                assert LINUX and not ARM64, "Edge TPU export only supported on non-aarch64 Linux"
            elif i in {5, 10}:  # CoreML and TF.js
                assert MACOS or (LINUX and not ARM64), (
                    "CoreML and TF.js export only supported on macOS and non-aarch64 Linux"
                )
            if i in {5}:  # CoreML
                assert not IS_PYTHON_3_13, "CoreML not supported on Python 3.13"
            if i in {6, 7, 8}:  # TF SavedModel, TF GraphDef, and TFLite
                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
            if i in {9, 10}:  # TF EdgeTPU and TF.js
                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
            if i == 11:  # Paddle
                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
                assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet"
                assert LINUX or MACOS, "Windows Paddle exports not supported yet"
            if i == 12:  # MNN
                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 MNN exports not supported yet"
            if i == 13:  # NCNN
                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
            if i == 14:  # IMX
                assert not is_end2end
                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 IMX exports not supported"
                assert model.task == "detect", "IMX only supported for detection task"
                assert "C2f" in model.__str__(), "IMX only supported for YOLOv8"
            if i == 15:  # RKNN
                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 RKNN exports not supported yet"
                assert not is_end2end, "End-to-end models not supported by RKNN yet"
                assert LINUX, "RKNN only supported on Linux"
                assert not is_rockchip(), "RKNN Inference only supported on Rockchip devices"
            if "cpu" in device.type:
                assert cpu, "inference not supported on CPU"
            if "cuda" in device.type:
                assert gpu, "inference not supported on GPU"

            # Export
            if format == "-":
                filename = model.pt_path or model.ckpt_path or model.model_name
                exported_model = model  # PyTorch format
            else:
                filename = model.export(
                    imgsz=imgsz, format=format, half=half, int8=int8, data=data, device=device, verbose=False
                )
                exported_model = YOLO(filename, task=model.task)
                assert suffix in str(filename), "export failed"
            emoji = "❎"  # indicates export succeeded

            # Predict
            assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
            assert i not in {9, 10}, "inference not supported"  # Edge TPU and TF.js are unsupported
            assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13"  # CoreML
            if i in {13}:
                assert not is_end2end, "End-to-end torch.topk operation is not supported for NCNN prediction yet"
            exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half, verbose=False)

            # Validate
            results = exported_model.val(
                data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, int8=int8, verbose=False
            )
            metric, speed = results.results_dict[key], results.speed["inference"]
            fps = round(1000 / (speed + eps), 2)  # frames per second
            y.append([name, "✅", round(file_size(filename), 1), round(metric, 4), round(speed, 2), fps])
        except Exception as e:
            # A failed assertion means "format unsupported here" and is recorded as a row, not re-raised
            if verbose:
                assert type(e) is AssertionError, f"Benchmark failure for {name}: {e}"
            LOGGER.warning(f"ERROR ❌️ Benchmark failure for {name}: {e}")
            y.append([name, emoji, round(file_size(filename), 1), None, None, None])  # mAP, t_inference

    # Print results
    check_yolo(device=device)  # print system info
    df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)", "FPS"])

    name = model.model_name
    dt = time.time() - t0
    legend = "Benchmarks legend: - ✅ Success - ❎ Export passed but validation failed - ❌️ Export failed"
    s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({dt:.2f}s)\n{legend}\n{df.fillna('-')}\n"
    LOGGER.info(s)
    with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
        f.write(s)

    if verbose and isinstance(verbose, float):
        metrics = df[key].array  # values to compare to floor
        floor = verbose  # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
        assert all(x > floor for x in metrics if pd.notna(x)), f"Benchmark failure: metric(s) < floor {floor}"

    return df


class RF100Benchmark:
    """
    Benchmark YOLO model performance across various formats for speed and accuracy.

    This class provides functionality to benchmark YOLO models on the RF100 dataset collection.

    Attributes:
        ds_names (List[str]): Names of datasets used for benchmarking.
        ds_cfg_list (List[Path]): List of paths to dataset configuration files.
        rf (Roboflow): Roboflow instance for accessing datasets.
        val_metrics (List[str]): Metrics used for validation.

    Methods:
        set_key: Set Roboflow API key for accessing datasets.
        parse_dataset: Parse dataset links and download datasets.
        fix_yaml: Fix train and validation paths in YAML files.
        evaluate: Evaluate model performance on validation results.
    """

    def __init__(self):
        """Initialize the RF100Benchmark class for benchmarking YOLO model performance across various formats."""
        self.ds_names = []
        self.ds_cfg_list = []
        self.rf = None  # populated lazily by set_key()
        self.val_metrics = ["class", "images", "targets", "precision", "recall", "map50", "map95"]

    def set_key(self, api_key):
        """
        Set Roboflow API key for processing.

        Args:
            api_key (str): The API key.
        Examples:
            Set the Roboflow API key for accessing datasets:
            >>> benchmark = RF100Benchmark()
            >>> benchmark.set_key("your_roboflow_api_key")
        """
        check_requirements("roboflow")
        from roboflow import Roboflow

        self.rf = Roboflow(api_key=api_key)

    def parse_dataset(self, ds_link_txt="datasets_links.txt"):
        """
        Parse dataset links and download datasets.

        Args:
            ds_link_txt (str): Path to the file containing dataset links.

        Returns:
            ds_names (List[str]): List of dataset names.
            ds_cfg_list (List[Path]): List of paths to dataset configuration files.

        Examples:
            >>> benchmark = RF100Benchmark()
            >>> benchmark.set_key("api_key")
            >>> benchmark.parse_dataset("datasets_links.txt")
        """
        # Recreate a clean 'rf-100' working directory, then chdir into it for all downloads
        (shutil.rmtree("rf-100"), os.mkdir("rf-100")) if os.path.exists("rf-100") else os.mkdir("rf-100")
        os.chdir("rf-100")
        os.mkdir("ultralytics-benchmarks")
        safe_download("https://github.com/ultralytics/assets/releases/download/v0.0.0/datasets_links.txt")

        with open(ds_link_txt, encoding="utf-8") as file:
            for line in file:
                try:
                    # Expected link shape: https://<host>/<workspace>/<project>/<version>
                    _, url, workspace, project, version = re.split("/+", line.strip())
                    self.ds_names.append(project)
                    proj_version = f"{project}-{version}"
                    if not Path(proj_version).exists():
                        self.rf.workspace(workspace).project(project).version(version).download("yolov8")
                    else:
                        print("Dataset already downloaded.")
                    self.ds_cfg_list.append(Path.cwd() / proj_version / "data.yaml")
                except Exception:
                    # Malformed link lines are skipped silently by design (best-effort parsing)
                    continue

        return self.ds_names, self.ds_cfg_list

    @staticmethod
    def fix_yaml(path):
        """Fix the train and validation paths in a given YAML file."""
        with open(path, encoding="utf-8") as file:
            yaml_data = yaml.safe_load(file)
        yaml_data["train"] = "train/images"
        yaml_data["val"] = "valid/images"
        with open(path, "w", encoding="utf-8") as file:
            yaml.safe_dump(yaml_data, file)

    def evaluate(self, yaml_path, val_log_file, eval_log_file, list_ind):
        """
        Evaluate model performance on validation results.

        Args:
            yaml_path (str): Path to the YAML configuration file.
            val_log_file (str): Path to the validation log file.
            eval_log_file (str): Path to the evaluation log file.
            list_ind (int): Index of the current dataset in the list.

        Returns:
            (float): The mean average precision (mAP) value for the evaluated model.

        Examples:
            Evaluate a model on a specific dataset
            >>> benchmark = RF100Benchmark()
            >>> benchmark.evaluate("path/to/data.yaml", "path/to/val_log.txt", "path/to/eval_log.txt", 0)
        """
        # NOTE(review): the docstring documents a float return but the method body never returns a
        # value (it only appends to eval_log_file) — confirm which is intended
        skip_symbols = ["🚀", "⚠️", "💡", "❌"]
        with open(yaml_path, encoding="utf-8") as stream:
            class_names = yaml.safe_load(stream)["names"]
        with open(val_log_file, encoding="utf-8") as f:
            lines = f.readlines()
            eval_lines = []
            for line in lines:
                if any(symbol in line for symbol in skip_symbols):
                    continue
                entries = line.split(" ")
                entries = list(filter(lambda val: val != "", entries))
                entries = [e.strip("\n") for e in entries]
                # One dict is appended per matching entry token, so a line can yield duplicates
                eval_lines.extend(
                    {
                        "class": entries[0],
                        "images": entries[1],
                        "targets": entries[2],
                        "precision": entries[3],
                        "recall": entries[4],
                        "map50": entries[5],
                        "map95": entries[6],
                    }
                    for e in entries
                    if e in class_names or (e == "all" and "(AP)" not in entries and "(AR)" not in entries)
                )
        map_val = 0.0
        if len(eval_lines) > 1:
            print("There's more dicts")
            for lst in eval_lines:
                if lst["class"] == "all":
                    map_val = lst["map50"]
        else:
            print("There's only one dict res")
            map_val = [res["map50"] for res in eval_lines][0]

        with open(eval_log_file, "a", encoding="utf-8") as f:
            f.write(f"{self.ds_names[list_ind]}: {map_val}\n")


class ProfileModels:
    """
    ProfileModels class for profiling different models on ONNX and TensorRT.

    This class profiles the performance of different models, returning results such as model speed and FLOPs.

    Attributes:
        paths (List[str]): Paths of the models to profile.
        num_timed_runs (int): Number of timed runs for the profiling.
def __init__(
    self,
    paths: list,
    num_timed_runs=100,
    num_warmup_runs=10,
    min_time=60,
    imgsz=640,
    half=True,
    trt=True,
    device=None,
):
    """
    Initialize the ProfileModels class for profiling models.

    Args:
        paths (List[str]): List of paths of the models to be profiled.
        num_timed_runs (int): Number of timed runs for the profiling.
        num_warmup_runs (int): Number of warmup runs before the actual profiling starts.
        min_time (float): Minimum time in seconds for profiling a model.
        imgsz (int): Size of the image used during profiling.
        half (bool): Whether to use FP16 half-precision for TensorRT profiling.
        trt (bool): Whether to profile using TensorRT.
        device (torch.device | None): Device used for profiling; auto-detected when None.

    Notes:
        FP16 'half' argument option removed for ONNX as slower on CPU than FP32.
    """
    self.paths = paths
    self.num_timed_runs = num_timed_runs
    self.num_warmup_runs = num_warmup_runs
    self.min_time = min_time
    self.imgsz = imgsz
    self.half = half
    self.trt = trt  # run TensorRT profiling
    # Prefer GPU 0 when CUDA is available, otherwise fall back to CPU
    self.device = device or torch.device(0 if torch.cuda.is_available() else "cpu")


def profile(self):
    """
    Profile YOLO models for speed across ONNX and TensorRT formats.

    Returns:
        (List[Dict]): One profiling-results dictionary per model, or None when no files match.
    """
    model_files = self.get_files()
    if not model_files:
        print("No matching *.pt or *.onnx files found.")
        return

    table_rows = []
    output = []
    for file in model_files:
        engine_file = file.with_suffix(".engine")
        suffix = file.suffix
        if suffix in {".pt", ".yaml", ".yml"}:
            model = YOLO(str(file))
            model.fuse()  # to report correct params and GFLOPs in model.info()
            model_info = model.info()
            # Export a TensorRT engine only when requested, on GPU, and not already cached
            if self.trt and self.device.type != "cpu" and not engine_file.is_file():
                engine_file = model.export(
                    format="engine",
                    half=self.half,
                    imgsz=self.imgsz,
                    device=self.device,
                    verbose=False,
                )
            onnx_file = model.export(
                format="onnx",
                imgsz=self.imgsz,
                device=self.device,
                verbose=False,
            )
        elif suffix == ".onnx":
            model_info = self.get_onnx_model_info(file)
            onnx_file = file
        else:
            continue  # unsupported file type

        t_engine = self.profile_tensorrt_model(str(engine_file))
        t_onnx = self.profile_onnx_model(str(onnx_file))
        table_rows.append(self.generate_table_row(file.stem, t_onnx, t_engine, model_info))
        output.append(self.generate_results_dict(file.stem, t_onnx, t_engine, model_info))

    self.print_table(table_rows)
    return output
t_onnx, t_engine, model_info)) + output.append(self.generate_results_dict(file.stem, t_onnx, t_engine, model_info)) + + self.print_table(table_rows) + return output + + def get_files(self): + """ + Return a list of paths for all relevant model files given by the user. + + Returns: + (List[Path]): List of Path objects for the model files. + """ + files = [] + for path in self.paths: + path = Path(path) + if path.is_dir(): + extensions = ["*.pt", "*.onnx", "*.yaml"] + files.extend([file for ext in extensions for file in glob.glob(str(path / ext))]) + elif path.suffix in {".pt", ".yaml", ".yml"}: # add non-existing + files.append(str(path)) + else: + files.extend(glob.glob(str(path))) + + print(f"Profiling: {sorted(files)}") + return [Path(file) for file in sorted(files)] + + @staticmethod + def get_onnx_model_info(onnx_file: str): + """Extracts metadata from an ONNX model file including parameters, GFLOPs, and input shape.""" + return 0.0, 0.0, 0.0, 0.0 # return (num_layers, num_params, num_gradients, num_flops) + + @staticmethod + def iterative_sigma_clipping(data, sigma=2, max_iters=3): + """ + Apply iterative sigma clipping to data to remove outliers. + + Args: + data (numpy.ndarray): Input data array. + sigma (float): Number of standard deviations to use for clipping. + max_iters (int): Maximum number of iterations for the clipping process. + + Returns: + (numpy.ndarray): Clipped data array with outliers removed. + """ + data = np.array(data) + for _ in range(max_iters): + mean, std = np.mean(data), np.std(data) + clipped_data = data[(data > mean - sigma * std) & (data < mean + sigma * std)] + if len(clipped_data) == len(data): + break + data = clipped_data + return data + + def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3): + """ + Profile YOLO model performance with TensorRT, measuring average run time and standard deviation. + + Args: + engine_file (str): Path to the TensorRT engine file. 
def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):
    """
    Profile a TensorRT engine, measuring average inference time and its standard deviation.

    Args:
        engine_file (str): Path to the TensorRT engine file.
        eps (float): Small epsilon value to prevent division by zero.

    Returns:
        mean_time (float): Mean inference time in milliseconds.
        std_time (float): Standard deviation of inference time in milliseconds.
    """
    if not self.trt or not Path(engine_file).is_file():
        return 0.0, 0.0

    # Model and input
    model = YOLO(engine_file)
    dummy = np.zeros((self.imgsz, self.imgsz, 3), dtype=np.uint8)  # use uint8 for Classify

    # Warmup: three rounds, keeping the elapsed time of the last round
    elapsed = 0.0
    for _ in range(3):
        t0 = time.time()
        for _ in range(self.num_warmup_runs):
            model(dummy, imgsz=self.imgsz, verbose=False)
        elapsed = time.time() - t0

    # Number of runs: whichever is larger, enough to fill min_time or num_timed_runs * 50
    num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs * 50)

    # Timed runs
    timings = []
    for _ in TQDM(range(num_runs), desc=engine_file):
        results = model(dummy, imgsz=self.imgsz, verbose=False)
        timings.append(results[0].speed["inference"])  # already in milliseconds

    timings = self.iterative_sigma_clipping(np.array(timings), sigma=2, max_iters=3)  # sigma clipping
    return np.mean(timings), np.std(timings)


def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
    """
    Profile an ONNX model on CPU, measuring average inference time and its standard deviation.

    Args:
        onnx_file (str): Path to the ONNX model file.
        eps (float): Small epsilon value to prevent division by zero.

    Returns:
        mean_time (float): Mean inference time in milliseconds.
        std_time (float): Standard deviation of inference time in milliseconds.
    """
    check_requirements("onnxruntime")
    import onnxruntime as ort

    # Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    sess_options.intra_op_num_threads = 8  # Limit the number of threads
    sess = ort.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"])

    input_tensor = sess.get_inputs()[0]
    input_type = input_tensor.type
    dynamic = not all(isinstance(dim, int) and dim >= 0 for dim in input_tensor.shape)  # dynamic input shape
    input_shape = (1, 3, self.imgsz, self.imgsz) if dynamic else input_tensor.shape

    # Map ONNX datatype string to a numpy dtype; order matters ("float16" before "float")
    for key, np_dtype in (
        ("float16", np.float16),
        ("float", np.float32),
        ("double", np.float64),
        ("int64", np.int64),
        ("int32", np.int32),
    ):
        if key in input_type:
            input_dtype = np_dtype
            break
    else:
        raise ValueError(f"Unsupported ONNX datatype {input_type}")

    input_data = np.random.rand(*input_shape).astype(input_dtype)
    input_name = input_tensor.name
    output_name = sess.get_outputs()[0].name

    # Warmup: three rounds, keeping the elapsed time of the last round
    elapsed = 0.0
    for _ in range(3):
        t0 = time.time()
        for _ in range(self.num_warmup_runs):
            sess.run([output_name], {input_name: input_data})
        elapsed = time.time() - t0

    # Number of runs: whichever is larger, enough to fill min_time or num_timed_runs
    num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs)

    # Timed runs
    timings = []
    for _ in TQDM(range(num_runs), desc=onnx_file):
        t0 = time.time()
        sess.run([output_name], {input_name: input_data})
        timings.append((time.time() - t0) * 1000)  # Convert to milliseconds

    timings = self.iterative_sigma_clipping(np.array(timings), sigma=2, max_iters=5)  # sigma clipping
    return np.mean(timings), np.std(timings)
np.mean(run_times), np.std(run_times) + + def generate_table_row(self, model_name, t_onnx, t_engine, model_info): + """ + Generate a table row string with model performance metrics. + + Args: + model_name (str): Name of the model. + t_onnx (tuple): ONNX model inference time statistics (mean, std). + t_engine (tuple): TensorRT engine inference time statistics (mean, std). + model_info (tuple): Model information (layers, params, gradients, flops). + + Returns: + (str): Formatted table row string with model metrics. + """ + layers, params, gradients, flops = model_info + return ( + f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.1f}±{t_onnx[1]:.1f} ms | {t_engine[0]:.1f}±" + f"{t_engine[1]:.1f} ms | {params / 1e6:.1f} | {flops:.1f} |" + ) + + @staticmethod + def generate_results_dict(model_name, t_onnx, t_engine, model_info): + """ + Generate a dictionary of profiling results. + + Args: + model_name (str): Name of the model. + t_onnx (tuple): ONNX model inference time statistics (mean, std). + t_engine (tuple): TensorRT engine inference time statistics (mean, std). + model_info (tuple): Model information (layers, params, gradients, flops). + + Returns: + (dict): Dictionary containing profiling results. + """ + layers, params, gradients, flops = model_info + return { + "model/name": model_name, + "model/parameters": params, + "model/GFLOPs": round(flops, 3), + "model/speed_ONNX(ms)": round(t_onnx[0], 3), + "model/speed_TensorRT(ms)": round(t_engine[0], 3), + } + + @staticmethod + def print_table(table_rows): + """ + Print a formatted table of model profiling results. + + Args: + table_rows (List[str]): List of formatted table row strings. + """ + gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU" + headers = [ + "Model", + "size
@staticmethod
def print_table(table_rows):
    """
    Print a formatted markdown-style table of model profiling results.

    Args:
        table_rows (List[str]): Pre-formatted table row strings.
    """
    gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
    # NOTE(review): these header strings were recovered from a line-mangled source;
    # the embedded newlines may originally have been markup (e.g. <br>) — confirm upstream.
    headers = [
        "Model",
        "size\n(pixels)",
        "mAPval\n50-95",
        f"Speed\nCPU ({get_cpu_info()}) ONNX\n(ms)",
        f"Speed\n{gpu} TensorRT\n(ms)",
        "params\n(M)",
        "FLOPs\n(B)",
    ]
    header_line = "|" + "|".join(f" {h} " for h in headers) + "|"
    divider = "|" + "|".join("-" * (len(h) + 2) for h in headers) + "|"

    print(f"\n\n{header_line}")
    print(divider)
    for line in table_rows:
        print(line)
pass + + +def on_before_zero_grad(trainer): + """Called before the gradients are set to zero.""" + pass + + +def on_train_batch_end(trainer): + """Called at the end of each training batch.""" + pass + + +def on_train_epoch_end(trainer): + """Called at the end of each training epoch.""" + pass + + +def on_fit_epoch_end(trainer): + """Called at the end of each fit epoch (train + val).""" + pass + + +def on_model_save(trainer): + """Called when the model is saved.""" + pass + + +def on_train_end(trainer): + """Called when the training ends.""" + pass + + +def on_params_update(trainer): + """Called when the model parameters are updated.""" + pass + + +def teardown(trainer): + """Called during the teardown of the training process.""" + pass + + +# Validator callbacks -------------------------------------------------------------------------------------------------- + + +def on_val_start(validator): + """Called when the validation starts.""" + pass + + +def on_val_batch_start(validator): + """Called at the start of each validation batch.""" + pass + + +def on_val_batch_end(validator): + """Called at the end of each validation batch.""" + pass + + +def on_val_end(validator): + """Called when the validation ends.""" + pass + + +# Predictor callbacks -------------------------------------------------------------------------------------------------- + + +def on_predict_start(predictor): + """Called when the prediction starts.""" + pass + + +def on_predict_batch_start(predictor): + """Called at the start of each prediction batch.""" + pass + + +def on_predict_batch_end(predictor): + """Called at the end of each prediction batch.""" + pass + + +def on_predict_postprocess_end(predictor): + """Called after the post-processing of the prediction ends.""" + pass + + +def on_predict_end(predictor): + """Called when the prediction ends.""" + pass + + +# Exporter callbacks --------------------------------------------------------------------------------------------------- + + +def 
on_export_start(exporter): + """Called when the model export starts.""" + pass + + +def on_export_end(exporter): + """Called when the model export ends.""" + pass + + +default_callbacks = { + # Run in trainer + "on_pretrain_routine_start": [on_pretrain_routine_start], + "on_pretrain_routine_end": [on_pretrain_routine_end], + "on_train_start": [on_train_start], + "on_train_epoch_start": [on_train_epoch_start], + "on_train_batch_start": [on_train_batch_start], + "optimizer_step": [optimizer_step], + "on_before_zero_grad": [on_before_zero_grad], + "on_train_batch_end": [on_train_batch_end], + "on_train_epoch_end": [on_train_epoch_end], + "on_fit_epoch_end": [on_fit_epoch_end], # fit = train + val + "on_model_save": [on_model_save], + "on_train_end": [on_train_end], + "on_params_update": [on_params_update], + "teardown": [teardown], + # Run in validator + "on_val_start": [on_val_start], + "on_val_batch_start": [on_val_batch_start], + "on_val_batch_end": [on_val_batch_end], + "on_val_end": [on_val_end], + # Run in predictor + "on_predict_start": [on_predict_start], + "on_predict_batch_start": [on_predict_batch_start], + "on_predict_postprocess_end": [on_predict_postprocess_end], + "on_predict_batch_end": [on_predict_batch_end], + "on_predict_end": [on_predict_end], + # Run in exporter + "on_export_start": [on_export_start], + "on_export_end": [on_export_end], +} + + +def get_default_callbacks(): + """ + Return a copy of the default_callbacks dictionary with lists as default values. + + Returns: + (defaultdict): A defaultdict with keys from default_callbacks and empty lists as default values. + """ + return defaultdict(list, deepcopy(default_callbacks)) + + +def add_integration_callbacks(instance): + """ + Add integration callbacks from various sources to the instance's callbacks. + + Args: + instance (Trainer | Predictor | Validator | Exporter): An object with a 'callbacks' attribute that is a + dictionary of callback lists. 
+ """ + # Load HUB callbacks + from .hub import callbacks as hub_cb + + callbacks_list = [hub_cb] + + # Load training callbacks + if "Trainer" in instance.__class__.__name__: + from .clearml import callbacks as clear_cb + from .comet import callbacks as comet_cb + from .dvc import callbacks as dvc_cb + from .mlflow import callbacks as mlflow_cb + from .neptune import callbacks as neptune_cb + from .raytune import callbacks as tune_cb + from .tensorboard import callbacks as tb_cb + from .wb import callbacks as wb_cb + + callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb]) + + # Add the callbacks to the callbacks dictionary + for callbacks in callbacks_list: + for k, v in callbacks.items(): + if v not in instance.callbacks[k]: + instance.callbacks[k].append(v) diff --git a/tracking/ultralytics/utils/callbacks/clearml.py b/tracking/ultralytics/utils/callbacks/clearml.py new file mode 100644 index 0000000000000000000000000000000000000000..a89d6fb93ca05da352ade3ed3420c8897e092c68 --- /dev/null +++ b/tracking/ultralytics/utils/callbacks/clearml.py @@ -0,0 +1,153 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING + +try: + assert not TESTS_RUNNING # do not log pytest + assert SETTINGS["clearml"] is True # verify integration is enabled + import clearml + from clearml import Task + + assert hasattr(clearml, "__version__") # verify package is not directory + +except (ImportError, AssertionError): + clearml = None + + +def _log_debug_samples(files, title: str = "Debug Samples") -> None: + """ + Log files (images) as debug samples in the ClearML task. + + Args: + files (List[Path]): A list of file paths in PosixPath format. + title (str): A title that groups together images with the same values. 
def _log_debug_samples(files, title: str = "Debug Samples") -> None:
    """
    Log files (images) as debug samples in the ClearML task.

    Args:
        files (List[Path]): A list of file paths in PosixPath format.
        title (str): A title that groups together images with the same values.
    """
    import re

    if task := Task.current_task():
        for f in files:
            if f.exists():
                it = re.search(r"_batch(\d+)", f.name)
                # Guard against filenames without a '_batch<N>' suffix: the original
                # called it.group() unconditionally and raised AttributeError when
                # re.search returned None.
                iteration = int(it.group(1)) if it else 0
                series = f.name.replace(it.group(), "") if it else f.name
                task.get_logger().report_image(
                    title=title, series=series, local_path=str(f), iteration=iteration
                )


def _log_plot(title: str, plot_path: str) -> None:
    """
    Log an image as a plot in the plot section of ClearML.

    Args:
        title (str): The title of the plot.
        plot_path (str): The path to the saved image file.
    """
    import matplotlib.image as mpimg
    import matplotlib.pyplot as plt

    img = mpimg.imread(plot_path)
    fig = plt.figure()
    ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[])  # no ticks
    ax.imshow(img)

    Task.current_task().get_logger().report_matplotlib_figure(
        title=title, series="", figure=fig, report_interactive=False
    )


def on_pretrain_routine_start(trainer) -> None:
    """Runs at start of pretraining routine; initializes and connects/logs task to ClearML."""
    try:
        if task := Task.current_task():
            # WARNING: make sure the automatic pytorch and matplotlib bindings are disabled!
            # We are logging these plots and model files manually in the integration
            from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
            from clearml.binding.matplotlib_bind import PatchedMatplotlib

            PatchPyTorchModelIO.update_current_task(None)
            PatchedMatplotlib.update_current_task(None)
        else:
            task = Task.init(
                project_name=trainer.args.project or "Ultralytics",
                task_name=trainer.args.name,
                tags=["Ultralytics"],
                output_uri=True,
                reuse_last_task_id=False,
                auto_connect_frameworks={"pytorch": False, "matplotlib": False},
            )
            LOGGER.warning(
                "ClearML Initialized a new task. If you want to run remotely, "
                "please add clearml-init and connect your arguments before initializing YOLO."
            )
        task.connect(vars(trainer.args), name="General")
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}")


def on_train_epoch_end(trainer) -> None:
    """Logs debug samples for the first epoch of YOLO training and reports current training progress."""
    if task := Task.current_task():
        # Log debug samples only once, on the first post-mosaic epoch
        if trainer.epoch == 1:
            _log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic")
        # Report the current training progress
        for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items():
            task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch)
        for k, v in trainer.lr.items():
            task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch)


def on_fit_epoch_end(trainer) -> None:
    """Reports model information to logger at the end of an epoch."""
    if task := Task.current_task():
        # Report epoch time and validation metrics
        task.get_logger().report_scalar(
            title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch
        )
        for k, v in trainer.metrics.items():
            task.get_logger().report_scalar("val", k, v, iteration=trainer.epoch)
        if trainer.epoch == 0:
            from ultralytics.utils.torch_utils import model_info_for_loggers

            for k, v in model_info_for_loggers(trainer).items():
                task.get_logger().report_single_value(k, v)


def on_val_end(validator) -> None:
    """Logs validation results including labels and predictions."""
    if Task.current_task():
        # Log val_labels and val_pred
        _log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation")


def on_train_end(trainer) -> None:
    """Logs final model and its name on training completion."""
    if task := Task.current_task():
        # Log final results, CM matrix + PR plots
        files = [
            "results.png",
            "confusion_matrix.png",
            "confusion_matrix_normalized.png",
            *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
        ]
        files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()]  # filter
        for f in files:
            _log_plot(title=f.stem, plot_path=f)
        # Report final metrics
        for k, v in trainer.validator.metrics.results_dict.items():
            task.get_logger().report_single_value(k, v)
        # Log the final model
        task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)


# Only register handlers when the clearml import succeeded
callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_epoch_end": on_train_epoch_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_val_end": on_val_end,
        "on_train_end": on_train_end,
    }
    if clearml
    else {}
)
def _get_comet_mode() -> str:
    """Return the Comet mode from the deprecated COMET_MODE env var, defaulting to 'online'."""
    comet_mode = os.getenv("COMET_MODE")
    if comet_mode is None:
        return "online"
    # COMET_MODE is deprecated in favor of COMET_START_ONLINE; warn but honor it
    LOGGER.warning(
        "WARNING ⚠️ The COMET_MODE environment variable is deprecated. "
        "Please use COMET_START_ONLINE to set the Comet experiment mode. "
        "To start an offline Comet experiment, use 'export COMET_START_ONLINE=0'. "
        "If COMET_START_ONLINE is not set or is set to '1', an online Comet experiment will be created."
    )
    return comet_mode


def _get_comet_model_name() -> str:
    """Return the model name for Comet from COMET_MODEL_NAME, defaulting to 'Ultralytics'."""
    return os.getenv("COMET_MODEL_NAME", "Ultralytics")


def _get_eval_batch_logging_interval() -> int:
    """Return the evaluation batch logging interval from the environment (default 1)."""
    return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))


def _get_max_image_predictions_to_log() -> int:
    """Return the maximum number of image predictions to log from the environment (default 100)."""
    return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))


def _scale_confidence_score(score: float) -> float:
    """Scale a confidence score by the COMET_MAX_CONFIDENCE_SCORE factor (default 100)."""
    return score * float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0))


def _should_log_confusion_matrix() -> bool:
    """Return True when COMET_EVAL_LOG_CONFUSION_MATRIX is set to 'true' (default false)."""
    return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"


def _should_log_image_predictions() -> bool:
    """Return True unless COMET_EVAL_LOG_IMAGE_PREDICTIONS is set to a non-'true' value (default true)."""
    return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true"


def _resume_or_create_experiment(args: SimpleNamespace) -> None:
    """
    Resume a CometML experiment or create a new one based on args.

    Ensures that the experiment object is only created in a single process during distributed training.
    """
    if RANK not in {-1, 0}:
        return  # only the main process owns the experiment

    # Set environment variable (if not set by the user) to configure the Comet experiment's online mode
    # under the hood. If COMET_START_ONLINE is set by the user it will override COMET_MODE value.
    if os.getenv("COMET_START_ONLINE") is None:
        os.environ["COMET_START_ONLINE"] = "1" if _get_comet_mode() != "offline" else "0"

    try:
        _project_name = os.getenv("COMET_PROJECT_NAME", args.project)
        experiment = comet_ml.start(project_name=_project_name)
        experiment.log_parameters(vars(args))
        experiment.log_others(
            {
                "eval_batch_logging_interval": _get_eval_batch_logging_interval(),
                "log_confusion_matrix_on_eval": _should_log_confusion_matrix(),
                "log_image_predictions": _should_log_image_predictions(),
                "max_image_predictions": _get_max_image_predictions_to_log(),
            }
        )
        experiment.log_other("Created from", "ultralytics")

    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}")
def _fetch_trainer_metadata(trainer) -> dict:
    """
    Return per-epoch bookkeeping for Comet logging: current epoch, step, and asset-saving flags.

    Args:
        trainer: Ultralytics trainer with epoch, train_loader, batch_size, epochs, and args attributes.

    Returns:
        (dict): Keys curr_epoch, curr_step, save_assets, final_epoch.
    """
    curr_epoch = trainer.epoch + 1

    train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size
    curr_step = curr_epoch * train_num_steps_per_epoch
    final_epoch = curr_epoch == trainer.epochs

    save = trainer.args.save
    save_period = trainer.args.save_period
    # Short-circuit on save_period > 0 BEFORE taking the modulo: the original computed
    # `curr_epoch % save_period` unconditionally and raised ZeroDivisionError when
    # save_period == 0 (a valid "never save periodically" setting).
    save_assets = save and save_period > 0 and curr_epoch % save_period == 0 and not final_epoch

    return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch)


def _scale_bounding_box_to_original_image_shape(
    box, resized_image_shape, original_image_shape, ratio_pad
) -> List[float]:
    """
    Rescale bounding-box labels from the resized training shape back to the original image shape.

    YOLO resizes images during training and the label values are normalized based on this resized shape.
    """
    resized_image_height, resized_image_width = resized_image_shape

    # Convert normalized xywh format predictions to xyxy in resized scale format
    box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width)
    # Scale box predictions from resized image scale back to original image scale
    box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad)
    # Convert bounding box format from xyxy to xywh for Comet logging
    box = ops.xyxy2xywh(box)
    # Adjust xy center to correspond top-left corner
    box[:2] -= box[2:] / 2
    return box.tolist()


def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None) -> Optional[dict]:
    """
    Format ground truth annotations for detection.

    Returns:
        (Optional[dict]): {"name": "ground_truth", "data": [...]} or None when the image has no boxes.
    """
    indices = batch["batch_idx"] == img_idx
    bboxes = batch["bboxes"][indices]
    if len(bboxes) == 0:
        LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes labels")
        return None

    cls_labels = batch["cls"][indices].squeeze(1).tolist()
    if class_name_map:
        cls_labels = [str(class_name_map[label]) for label in cls_labels]

    original_image_shape = batch["ori_shape"][img_idx]
    resized_image_shape = batch["resized_shape"][img_idx]
    ratio_pad = batch["ratio_pad"][img_idx]

    data = []
    for box, label in zip(bboxes, cls_labels):
        box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
        data.append(
            {
                "boxes": [box],
                "label": f"gt_{label}",
                "score": _scale_confidence_score(1.0),
            }
        )

    return {"name": "ground_truth", "data": data}


def _format_prediction_annotations(image_path, metadata, class_label_map=None, class_map=None) -> Optional[dict]:
    """
    Format YOLO predictions for object detection visualization.

    Returns:
        (Optional[dict]): {"name": "prediction", "data": [...]} or None when the image has no predictions.
    """
    stem = image_path.stem
    image_id = int(stem) if stem.isnumeric() else stem

    predictions = metadata.get(image_id)
    if not predictions:
        LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes predictions")
        return None

    label_index_offset = 0
    if class_map is not None:
        # offset to align indices of class labels (starting from zero)
        # with prediction's category ID indices (can start from one)
        label_index_offset = sorted(class_map)[0]

    try:
        # import pycotools utilities to decompress annotations for various tasks, e.g. segmentation
        from pycocotools.mask import decode  # noqa
    except ImportError:
        decode = None

    data = []
    for prediction in predictions:
        boxes = prediction["bbox"]
        score = _scale_confidence_score(prediction["score"])
        cls_label = prediction["category_id"]
        if class_label_map:
            cls_label = str(class_label_map[cls_label - label_index_offset])

        annotation_data = {"boxes": [boxes], "label": cls_label, "score": score}

        if decode is not None:
            # do segmentation processing only if we are able to decode it
            segments = prediction.get("segmentation", None)
            if segments is not None:
                segments = _extract_segmentation_annotation(segments, decode)
                if segments is not None:
                    annotation_data["points"] = segments

        data.append(annotation_data)

    return {"name": "prediction", "data": data}
+ """ + try: + mask = decode(segmentation_raw) + contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + annotations = [np.array(polygon).squeeze() for polygon in contours if len(polygon) >= 3] + return [annotation.ravel().tolist() for annotation in annotations] + except Exception as e: + LOGGER.warning(f"COMET WARNING: Failed to extract segmentation annotation: {e}") + return None + + +def _fetch_annotations( + img_idx, image_path, batch, prediction_metadata_map, class_label_map, class_map +) -> Optional[List]: + """Join the ground truth and prediction annotations if they exist.""" + ground_truth_annotations = _format_ground_truth_annotations_for_detection( + img_idx, image_path, batch, class_label_map + ) + prediction_annotations = _format_prediction_annotations( + image_path, prediction_metadata_map, class_label_map, class_map + ) + + annotations = [ + annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None + ] + return [annotations] if annotations else None + + +def _create_prediction_metadata_map(model_predictions) -> dict: + """Create metadata map for model predictions by groupings them based on image ID.""" + pred_metadata_map = {} + for prediction in model_predictions: + pred_metadata_map.setdefault(prediction["image_id"], []) + pred_metadata_map[prediction["image_id"]].append(prediction) + + return pred_metadata_map + + +def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch) -> None: + """Log the confusion matrix to Comet experiment.""" + conf_mat = trainer.validator.confusion_matrix.matrix + names = list(trainer.data["names"].values()) + ["background"] + experiment.log_confusion_matrix( + matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step + ) + + +def _log_images(experiment, image_paths, curr_step, annotations=None) -> None: + """Logs images to the experiment with optional annotations.""" + if annotations: + for image_path, 
annotation in zip(image_paths, annotations): + experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation) + + else: + for image_path in image_paths: + experiment.log_image(image_path, name=image_path.stem, step=curr_step) + + +def _log_image_predictions(experiment, validator, curr_step) -> None: + """Logs predicted boxes for a single image during training.""" + global _comet_image_prediction_count + + task = validator.args.task + if task not in COMET_SUPPORTED_TASKS: + return + + jdict = validator.jdict + if not jdict: + return + + predictions_metadata_map = _create_prediction_metadata_map(jdict) + dataloader = validator.dataloader + class_label_map = validator.names + class_map = getattr(validator, "class_map", None) + + batch_logging_interval = _get_eval_batch_logging_interval() + max_image_predictions = _get_max_image_predictions_to_log() + + for batch_idx, batch in enumerate(dataloader): + if (batch_idx + 1) % batch_logging_interval != 0: + continue + + image_paths = batch["im_file"] + for img_idx, image_path in enumerate(image_paths): + if _comet_image_prediction_count >= max_image_predictions: + return + + image_path = Path(image_path) + annotations = _fetch_annotations( + img_idx, + image_path, + batch, + predictions_metadata_map, + class_label_map, + class_map=class_map, + ) + _log_images( + experiment, + [image_path], + curr_step, + annotations=annotations, + ) + _comet_image_prediction_count += 1 + + +def _log_plots(experiment, trainer) -> None: + """Logs evaluation plots and label plots for the experiment.""" + plot_filenames = None + if isinstance(trainer.validator.metrics, SegmentMetrics) and trainer.validator.metrics.task == "segment": + plot_filenames = [ + trainer.save_dir / f"{prefix}{plots}.png" + for plots in EVALUATION_PLOT_NAMES + for prefix in SEGMENT_METRICS_PLOT_PREFIX + ] + elif isinstance(trainer.validator.metrics, PoseMetrics): + plot_filenames = [ + trainer.save_dir / f"{prefix}{plots}.png" + for plots 
in EVALUATION_PLOT_NAMES + for prefix in POSE_METRICS_PLOT_PREFIX + ] + elif isinstance(trainer.validator.metrics, (DetMetrics, OBBMetrics)): + plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES] + + if plot_filenames is not None: + _log_images(experiment, plot_filenames, None) + + confusion_matrix_filenames = [trainer.save_dir / f"{plots}.png" for plots in CONFUSION_MATRIX_PLOT_NAMES] + _log_images(experiment, confusion_matrix_filenames, None) + + if not isinstance(trainer.validator.metrics, ClassifyMetrics): + label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES] + _log_images(experiment, label_plot_filenames, None) + + +def _log_model(experiment, trainer) -> None: + """Log the best-trained model to Comet.ml.""" + model_name = _get_comet_model_name() + experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name="best.pt", overwrite=True) + + +def _log_image_batches(experiment, trainer, curr_step: int) -> None: + """Log samples of images batches for train, validation, and test.""" + _log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step) + _log_images(experiment, trainer.save_dir.glob("val_batch*.jpg"), curr_step) + + +def on_pretrain_routine_start(trainer) -> None: + """Creates or resumes a CometML experiment at the start of a YOLO pre-training routine.""" + _resume_or_create_experiment(trainer.args) + + +def on_train_epoch_end(trainer) -> None: + """Log metrics and save batch images at the end of training epochs.""" + experiment = comet_ml.get_running_experiment() + if not experiment: + return + + metadata = _fetch_trainer_metadata(trainer) + curr_epoch = metadata["curr_epoch"] + curr_step = metadata["curr_step"] + + experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch) + + +def on_fit_epoch_end(trainer) -> None: + """Logs model assets at the end of each epoch.""" + experiment = 
comet_ml.get_running_experiment() + if not experiment: + return + + metadata = _fetch_trainer_metadata(trainer) + curr_epoch = metadata["curr_epoch"] + curr_step = metadata["curr_step"] + save_assets = metadata["save_assets"] + + experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch) + experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch) + if curr_epoch == 1: + from ultralytics.utils.torch_utils import model_info_for_loggers + + experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch) + + if not save_assets: + return + + _log_model(experiment, trainer) + if _should_log_confusion_matrix(): + _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch) + if _should_log_image_predictions(): + _log_image_predictions(experiment, trainer.validator, curr_step) + + +def on_train_end(trainer) -> None: + """Perform operations at the end of training.""" + experiment = comet_ml.get_running_experiment() + if not experiment: + return + + metadata = _fetch_trainer_metadata(trainer) + curr_epoch = metadata["curr_epoch"] + curr_step = metadata["curr_step"] + plots = trainer.args.plots + + _log_model(experiment, trainer) + if plots: + _log_plots(experiment, trainer) + + _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch) + _log_image_predictions(experiment, trainer.validator, curr_step) + _log_image_batches(experiment, trainer, curr_step) + experiment.end() + + global _comet_image_prediction_count + _comet_image_prediction_count = 0 + + +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_end": on_train_end, + } + if comet_ml + else {} +) diff --git a/tracking/ultralytics/utils/callbacks/dvc.py b/tracking/ultralytics/utils/callbacks/dvc.py new file mode 100644 index 0000000000000000000000000000000000000000..72abbfe4d4cc9a536024afe392cc9f935b92a649 --- /dev/null +++ 
b/tracking/ultralytics/utils/callbacks/dvc.py @@ -0,0 +1,146 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from pathlib import Path + +from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks + +try: + assert not TESTS_RUNNING # do not log pytest + assert SETTINGS["dvc"] is True # verify integration is enabled + import dvclive + + assert checks.check_version("dvclive", "2.11.0", verbose=True) + + import os + import re + + # DVCLive logger instance + live = None + _processed_plots = {} + + # `on_fit_epoch_end` is called on final validation (probably need to be fixed) for now this is the way we + # distinguish final evaluation of the best model vs last epoch validation + _training_epoch = False + +except (ImportError, AssertionError, TypeError): + dvclive = None + + +def _log_images(path: Path, prefix: str = "") -> None: + """Logs images at specified path with an optional prefix using DVCLive.""" + if live: + name = path.name + + # Group images by batch to enable sliders in UI + if m := re.search(r"_batch(\d+)", name): + ni = m[1] + new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem) + name = (Path(new_stem) / ni).with_suffix(path.suffix) + + live.log_image(os.path.join(prefix, name), path) + + +def _log_plots(plots: dict, prefix: str = "") -> None: + """Logs plot images for training progress if they have not been previously processed.""" + for name, params in plots.items(): + timestamp = params["timestamp"] + if _processed_plots.get(name) != timestamp: + _log_images(name, prefix) + _processed_plots[name] = timestamp + + +def _log_confusion_matrix(validator) -> None: + """Logs the confusion matrix for the given validator using DVCLive.""" + targets = [] + preds = [] + matrix = validator.confusion_matrix.matrix + names = list(validator.names.values()) + if validator.confusion_matrix.task == "detect": + names += ["background"] + + for ti, pred in enumerate(matrix.T.astype(int)): + for pi, num in enumerate(pred): + 
targets.extend([names[ti]] * num) + preds.extend([names[pi]] * num) + + live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True) + + +def on_pretrain_routine_start(trainer) -> None: + """Initializes DVCLive logger for training metadata during pre-training routine.""" + try: + global live + live = dvclive.Live(save_dvc_exp=True, cache_images=True) + LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).") + except Exception as e: + LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}") + + +def on_pretrain_routine_end(trainer) -> None: + """Logs plots related to the training process at the end of the pretraining routine.""" + _log_plots(trainer.plots, "train") + + +def on_train_start(trainer) -> None: + """Logs the training parameters if DVCLive logging is active.""" + if live: + live.log_params(trainer.args) + + +def on_train_epoch_start(trainer) -> None: + """Sets the global variable _training_epoch value to True at the start of training each epoch.""" + global _training_epoch + _training_epoch = True + + +def on_fit_epoch_end(trainer) -> None: + """Logs training metrics and model info, and advances to next step on the end of each fit epoch.""" + global _training_epoch + if live and _training_epoch: + all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr} + for metric, value in all_metrics.items(): + live.log_metric(metric, value) + + if trainer.epoch == 0: + from ultralytics.utils.torch_utils import model_info_for_loggers + + for metric, value in model_info_for_loggers(trainer).items(): + live.log_metric(metric, value, plot=False) + + _log_plots(trainer.plots, "train") + _log_plots(trainer.validator.plots, "val") + + live.next_step() + _training_epoch = False + + +def on_train_end(trainer) -> None: + """Logs the best metrics, plots, and confusion matrix at the end of training if 
DVCLive is active.""" + if live: + # At the end log the best metrics. It runs validator on the best model internally. + all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr} + for metric, value in all_metrics.items(): + live.log_metric(metric, value, plot=False) + + _log_plots(trainer.plots, "val") + _log_plots(trainer.validator.plots, "val") + _log_confusion_matrix(trainer.validator) + + if trainer.best.exists(): + live.log_artifact(trainer.best, copy=True, type="model") + + live.end() + + +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_pretrain_routine_end": on_pretrain_routine_end, + "on_train_start": on_train_start, + "on_train_epoch_start": on_train_epoch_start, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_end": on_train_end, + } + if dvclive + else {} +) diff --git a/tracking/ultralytics/utils/callbacks/hub.py b/tracking/ultralytics/utils/callbacks/hub.py new file mode 100644 index 0000000000000000000000000000000000000000..fd4a9a43c7ecf93221e1c722b4f305e9fd57e263 --- /dev/null +++ b/tracking/ultralytics/utils/callbacks/hub.py @@ -0,0 +1,108 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import json +from time import time + +from ultralytics.hub import HUB_WEB_ROOT, PREFIX, HUBTrainingSession, events +from ultralytics.utils import LOGGER, RANK, SETTINGS + + +def on_pretrain_routine_start(trainer): + """Create a remote Ultralytics HUB session to log local model training.""" + if RANK in {-1, 0} and SETTINGS["hub"] is True and SETTINGS["api_key"] and trainer.hub_session is None: + trainer.hub_session = HUBTrainingSession.create_session(trainer.args.model, trainer.args) + + +def on_pretrain_routine_end(trainer): + """Initialize timers for upload rate limiting before training begins.""" + if session := getattr(trainer, "hub_session", None): + # Start timer for upload rate limit + session.timers = {"metrics": time(), "ckpt": time()} # start 
timer for session rate limiting + + +def on_fit_epoch_end(trainer): + """Upload training progress metrics to Ultralytics HUB at the end of each epoch.""" + if session := getattr(trainer, "hub_session", None): + # Upload metrics after validation ends + all_plots = { + **trainer.label_loss_items(trainer.tloss, prefix="train"), + **trainer.metrics, + } + if trainer.epoch == 0: + from ultralytics.utils.torch_utils import model_info_for_loggers + + all_plots = {**all_plots, **model_info_for_loggers(trainer)} + + session.metrics_queue[trainer.epoch] = json.dumps(all_plots) + + # If any metrics failed to upload previously, add them to the queue to attempt uploading again + if session.metrics_upload_failed_queue: + session.metrics_queue.update(session.metrics_upload_failed_queue) + + if time() - session.timers["metrics"] > session.rate_limits["metrics"]: + session.upload_metrics() + session.timers["metrics"] = time() # reset timer + session.metrics_queue = {} # reset queue + + +def on_model_save(trainer): + """Upload model checkpoints to Ultralytics HUB with rate limiting.""" + if session := getattr(trainer, "hub_session", None): + # Upload checkpoints with rate limiting + is_best = trainer.best_fitness == trainer.fitness + if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]: + LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model.id}") + session.upload_model(trainer.epoch, trainer.last, is_best) + session.timers["ckpt"] = time() # reset timer + + +def on_train_end(trainer): + """Upload final model and metrics to Ultralytics HUB at the end of training.""" + if session := getattr(trainer, "hub_session", None): + # Upload final model and metrics with exponential standoff + LOGGER.info(f"{PREFIX}Syncing final model...") + session.upload_model( + trainer.epoch, + trainer.best, + map=trainer.metrics.get("metrics/mAP50-95(B)", 0), + final=True, + ) + session.alive = False # stop heartbeats + LOGGER.info(f"{PREFIX}Done ✅\n{PREFIX}View model 
at {session.model_url} 🚀") + + +def on_train_start(trainer): + """Run events on train start.""" + events(trainer.args) + + +def on_val_start(validator): + """Run events on validation start.""" + events(validator.args) + + +def on_predict_start(predictor): + """Run events on predict start.""" + events(predictor.args) + + +def on_export_start(exporter): + """Run events on export start.""" + events(exporter.args) + + +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_pretrain_routine_end": on_pretrain_routine_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_model_save": on_model_save, + "on_train_end": on_train_end, + "on_train_start": on_train_start, + "on_val_start": on_val_start, + "on_predict_start": on_predict_start, + "on_export_start": on_export_start, + } + if SETTINGS["hub"] is True + else {} +) # verify hub is enabled before registering callbacks diff --git a/tracking/ultralytics/utils/callbacks/mlflow.py b/tracking/ultralytics/utils/callbacks/mlflow.py new file mode 100644 index 0000000000000000000000000000000000000000..b3876e04dfc9fbefc495a1dabf829df2dd4bac8f --- /dev/null +++ b/tracking/ultralytics/utils/callbacks/mlflow.py @@ -0,0 +1,137 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +MLflow Logging for Ultralytics YOLO. + +This module enables MLflow logging for Ultralytics YOLO. It logs metrics, parameters, and model artifacts. +For setting up, a tracking URI should be specified. The logging can be customized using environment variables. + +Commands: + 1. To set a project name: + `export MLFLOW_EXPERIMENT_NAME=` or use the project= argument + + 2. To set a run name: + `export MLFLOW_RUN=` or use the name= argument + + 3. To start a local MLflow server: + mlflow server --backend-store-uri runs/mlflow + It will by default start a local server at http://127.0.0.1:5000. + To specify a different URI, set the MLFLOW_TRACKING_URI environment variable. + + 4. 
To kill all running MLflow server instances: + ps aux | grep 'mlflow' | grep -v 'grep' | awk '{print $2}' | xargs kill -9 +""" + +from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr + +try: + import os + + assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "") # do not log pytest + assert SETTINGS["mlflow"] is True # verify integration is enabled + import mlflow + + assert hasattr(mlflow, "__version__") # verify package is not directory + from pathlib import Path + + PREFIX = colorstr("MLflow: ") + +except (ImportError, AssertionError): + mlflow = None + + +def sanitize_dict(x: dict) -> dict: + """Sanitize dictionary keys by removing parentheses and converting values to floats.""" + return {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()} + + +def on_pretrain_routine_end(trainer): + """ + Log training parameters to MLflow at the end of the pretraining routine. + + This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI, + experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters + from the trainer. + + Args: + trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log. + + Global: + mlflow: The imported mlflow module to use for logging. + + Environment Variables: + MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'. + MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project. + MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name. + MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after the end of training. 
+ """ + global mlflow + + uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow") + LOGGER.debug(f"{PREFIX} tracking uri: {uri}") + mlflow.set_tracking_uri(uri) + + # Set experiment and run names + experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/Ultralytics" + run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name + mlflow.set_experiment(experiment_name) + + mlflow.autolog() + try: + active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name) + LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}") + if Path(uri).is_dir(): + LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'") + LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'") + mlflow.log_params(dict(trainer.args)) + except Exception as e: + LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n{PREFIX}WARNING ⚠️ Not tracking this run") + + +def on_train_epoch_end(trainer): + """Log training metrics at the end of each train epoch to MLflow.""" + if mlflow: + mlflow.log_metrics( + metrics={ + **sanitize_dict(trainer.lr), + **sanitize_dict(trainer.label_loss_items(trainer.tloss, prefix="train")), + }, + step=trainer.epoch, + ) + + +def on_fit_epoch_end(trainer): + """Log training metrics at the end of each fit epoch to MLflow.""" + if mlflow: + mlflow.log_metrics(metrics=sanitize_dict(trainer.metrics), step=trainer.epoch) + + +def on_train_end(trainer): + """Log model artifacts at the end of the training.""" + if not mlflow: + return + mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt + for f in trainer.save_dir.glob("*"): # log all other files in save_dir + if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}: + mlflow.log_artifact(str(f)) + keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true" + if keep_run_active: + LOGGER.info(f"{PREFIX}mlflow run still 
alive, remember to close it using mlflow.end_run()") + else: + mlflow.end_run() + LOGGER.debug(f"{PREFIX}mlflow run ended") + + LOGGER.info( + f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n{PREFIX}disable with 'yolo settings mlflow=False'" + ) + + +callbacks = ( + { + "on_pretrain_routine_end": on_pretrain_routine_end, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_end": on_train_end, + } + if mlflow + else {} +) diff --git a/tracking/ultralytics/utils/callbacks/neptune.py b/tracking/ultralytics/utils/callbacks/neptune.py new file mode 100644 index 0000000000000000000000000000000000000000..8c4ed25726790ed65d6c935c4cac361ade1f4b5c --- /dev/null +++ b/tracking/ultralytics/utils/callbacks/neptune.py @@ -0,0 +1,118 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + + +from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING + +try: + assert not TESTS_RUNNING # do not log pytest + assert SETTINGS["neptune"] is True # verify integration is enabled + + import neptune + from neptune.types import File + + assert hasattr(neptune, "__version__") + + run = None # NeptuneAI experiment logger instance + +except (ImportError, AssertionError): + neptune = None + + +def _log_scalars(scalars: dict, step: int = 0) -> None: + """Log scalars to the NeptuneAI experiment logger.""" + if run: + for k, v in scalars.items(): + run[k].append(value=v, step=step) + + +def _log_images(imgs_dict: dict, group: str = "") -> None: + """Log images to the NeptuneAI experiment logger.""" + if run: + for k, v in imgs_dict.items(): + run[f"{group}/{k}"].upload(File(v)) + + +def _log_plot(title: str, plot_path: str) -> None: + """ + Log plots to the NeptuneAI experiment logger. + + Args: + title (str): Title of the plot. + plot_path (str): Path to the saved image file. 
+ """ + import matplotlib.image as mpimg + import matplotlib.pyplot as plt + + img = mpimg.imread(plot_path) + fig = plt.figure() + ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks + ax.imshow(img) + run[f"Plots/{title}"].upload(fig) + + +def on_pretrain_routine_start(trainer) -> None: + """Callback function called before the training routine starts.""" + try: + global run + run = neptune.init_run( + project=trainer.args.project or "Ultralytics", + name=trainer.args.name, + tags=["Ultralytics"], + ) + run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()} + except Exception as e: + LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}") + + +def on_train_epoch_end(trainer) -> None: + """Callback function called at end of each training epoch.""" + _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1) + _log_scalars(trainer.lr, trainer.epoch + 1) + if trainer.epoch == 1: + _log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic") + + +def on_fit_epoch_end(trainer) -> None: + """Callback function called at end of each fit (train+val) epoch.""" + if run and trainer.epoch == 0: + from ultralytics.utils.torch_utils import model_info_for_loggers + + run["Configuration/Model"] = model_info_for_loggers(trainer) + _log_scalars(trainer.metrics, trainer.epoch + 1) + + +def on_val_end(validator) -> None: + """Callback function called at end of each validation.""" + if run: + # Log val_labels and val_pred + _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation") + + +def on_train_end(trainer) -> None: + """Callback function called at end of training.""" + if run: + # Log final results, CM matrix + PR plots + files = [ + "results.png", + "confusion_matrix.png", + "confusion_matrix_normalized.png", + *(f"{x}_curve.png" for x in ("F1", 
"PR", "P", "R")), + ] + files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter + for f in files: + _log_plot(title=f.stem, plot_path=f) + # Log the final model + run[f"weights/{trainer.args.name or trainer.args.task}/{trainer.best.name}"].upload(File(str(trainer.best))) + + +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_val_end": on_val_end, + "on_train_end": on_train_end, + } + if neptune + else {} +) diff --git a/tracking/ultralytics/utils/callbacks/raytune.py b/tracking/ultralytics/utils/callbacks/raytune.py new file mode 100644 index 0000000000000000000000000000000000000000..5e84135ee18cc310042bb8711ddbb2c4bb539ff2 --- /dev/null +++ b/tracking/ultralytics/utils/callbacks/raytune.py @@ -0,0 +1,40 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.utils import SETTINGS + +try: + assert SETTINGS["raytune"] is True # verify integration is enabled + import ray + from ray import tune + from ray.air import session + +except (ImportError, AssertionError): + tune = None + + +def on_fit_epoch_end(trainer): + """ + Sends training metrics to Ray Tune at end of each epoch. + + This function checks if a Ray Tune session is active and reports the current training metrics along with the + epoch number to Ray Tune's session. + + Args: + trainer (ultralytics.engine.trainer.BaseTrainer): The Ultralytics trainer object containing metrics and epochs. 
+ + Examples: + >>> # Called automatically by the Ultralytics training loop + >>> on_fit_epoch_end(trainer) + """ + if ray.train._internal.session.get_session(): # check if Ray Tune session is active + metrics = trainer.metrics + session.report({**metrics, **{"epoch": trainer.epoch + 1}}) + + +callbacks = ( + { + "on_fit_epoch_end": on_fit_epoch_end, + } + if tune + else {} +) diff --git a/tracking/ultralytics/utils/callbacks/tensorboard.py b/tracking/ultralytics/utils/callbacks/tensorboard.py new file mode 100644 index 0000000000000000000000000000000000000000..39f25346a2aa05f371f3de5820351d54dedfcc7a --- /dev/null +++ b/tracking/ultralytics/utils/callbacks/tensorboard.py @@ -0,0 +1,106 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr + +try: + # WARNING: do not move SummaryWriter import due to protobuf bug https://github.com/ultralytics/ultralytics/pull/4674 + from torch.utils.tensorboard import SummaryWriter + + assert not TESTS_RUNNING # do not log pytest + assert SETTINGS["tensorboard"] is True # verify integration is enabled + WRITER = None # TensorBoard SummaryWriter instance + PREFIX = colorstr("TensorBoard: ") + + # Imports below only required if TensorBoard enabled + import warnings + from copy import deepcopy + + from ultralytics.utils.torch_utils import de_parallel, torch + +except (ImportError, AssertionError, TypeError, AttributeError): + # TypeError for handling 'Descriptors cannot not be created directly.' 
protobuf errors in Windows + # AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed + SummaryWriter = None + + +def _log_scalars(scalars: dict, step: int = 0) -> None: + """Logs scalar values to TensorBoard.""" + if WRITER: + for k, v in scalars.items(): + WRITER.add_scalar(k, v, step) + + +def _log_tensorboard_graph(trainer) -> None: + """Log model graph to TensorBoard.""" + # Input image + imgsz = trainer.args.imgsz + imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz + p = next(trainer.model.parameters()) # for device, type + im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning + warnings.simplefilter("ignore", category=torch.jit.TracerWarning) # suppress jit trace warning + + # Try simple method first (YOLO) + try: + trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes + WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), []) + LOGGER.info(f"{PREFIX}model graph visualization added ✅") + return + + except Exception: + # Fallback to TorchScript export steps (RTDETR) + try: + model = deepcopy(de_parallel(trainer.model)) + model.eval() + model = model.fuse(verbose=False) + for m in model.modules(): + if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class) + m.export = True + m.format = "torchscript" + model(im) # dry run + WRITER.add_graph(torch.jit.trace(model, im, strict=False), []) + LOGGER.info(f"{PREFIX}model graph visualization added ✅") + except Exception as e: + LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}") + + +def on_pretrain_routine_start(trainer) -> None: + """Initialize TensorBoard logging with SummaryWriter.""" + if SummaryWriter: + try: + global WRITER + WRITER = SummaryWriter(str(trainer.save_dir)) + 
LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/") + except Exception as e: + LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}") + + +def on_train_start(trainer) -> None: + """Log TensorBoard graph.""" + if WRITER: + _log_tensorboard_graph(trainer) + + +def on_train_epoch_end(trainer) -> None: + """Logs scalar statistics at the end of a training epoch.""" + _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1) + _log_scalars(trainer.lr, trainer.epoch + 1) + + +def on_fit_epoch_end(trainer) -> None: + """Logs epoch metrics at end of training epoch.""" + _log_scalars(trainer.metrics, trainer.epoch + 1) + + +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_start": on_train_start, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_epoch_end": on_train_epoch_end, + } + if SummaryWriter + else {} +) diff --git a/tracking/ultralytics/utils/callbacks/wb.py b/tracking/ultralytics/utils/callbacks/wb.py new file mode 100644 index 0000000000000000000000000000000000000000..24e748b9d9f59db4d212bc5331a6f73090a6014e --- /dev/null +++ b/tracking/ultralytics/utils/callbacks/wb.py @@ -0,0 +1,170 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.utils import SETTINGS, TESTS_RUNNING +from ultralytics.utils.torch_utils import model_info_for_loggers + +try: + assert not TESTS_RUNNING # do not log pytest + assert SETTINGS["wandb"] is True # verify integration is enabled + import wandb as wb + + assert hasattr(wb, "__version__") # verify package is not directory + _processed_plots = {} + +except (ImportError, AssertionError): + wb = None + + +def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"): + """ + Create and log a custom metric visualization to wandb.plot.pr_curve. 
+ + This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall + curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across + different classes. + + Args: + x (list): Values for the x-axis; expected to have length N. + y (list): Corresponding values for the y-axis; also expected to have length N. + classes (list): Labels identifying the class of each point; length N. + title (str): Title for the plot; defaults to 'Precision Recall Curve'. + x_title (str): Label for the x-axis; defaults to 'Recall'. + y_title (str): Label for the y-axis; defaults to 'Precision'. + + Returns: + (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization. + """ + import pandas # scope for faster 'import ultralytics' + + df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3) + fields = {"x": "x", "y": "y", "class": "class"} + string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title} + return wb.plot_table( + "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields + ) + + +def _plot_curve( + x, + y, + names=None, + id="precision-recall", + title="Precision Recall Curve", + x_title="Recall", + y_title="Precision", + num_x=100, + only_mean=False, +): + """ + Log a metric curve visualization. + + This function generates a metric curve based on input data and logs the visualization to wandb. + The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag. + + Args: + x (np.ndarray): Data points for the x-axis with length N. + y (np.ndarray): Corresponding data points for the y-axis with shape (C, N), where C is the number of classes. + names (list): Names of the classes corresponding to the y-axis data; length C. + id (str): Unique identifier for the logged data in wandb. + title (str): Title for the visualization plot. 
+ x_title (str): Label for the x-axis. + y_title (str): Label for the y-axis. + num_x (int): Number of interpolated data points for visualization. + only_mean (bool): Flag to indicate if only the mean curve should be plotted. + + Notes: + The function leverages the '_custom_table' function to generate the actual visualization. + """ + import numpy as np + + # Create new x + if names is None: + names = [] + x_new = np.linspace(x[0], x[-1], num_x).round(5) + + # Create arrays for logging + x_log = x_new.tolist() + y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist() + + if only_mean: + table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title]) + wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)}) + else: + classes = ["mean"] * len(x_log) + for i, yi in enumerate(y): + x_log.extend(x_new) # add new x + y_log.extend(np.interp(x_new, x, yi)) # interpolate y to new x + classes.extend([names[i]] * len(x_new)) # add class names + wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False) + + +def _log_plots(plots, step): + """Logs plots from the input dictionary if they haven't been logged already at the specified step.""" + for name, params in plots.copy().items(): # shallow copy to prevent plots dict changing during iteration + timestamp = params["timestamp"] + if _processed_plots.get(name) != timestamp: + wb.run.log({name.stem: wb.Image(str(name))}, step=step) + _processed_plots[name] = timestamp + + +def on_pretrain_routine_start(trainer): + """Initiate and start wandb project if module is present.""" + if not wb.run: + wb.init( + project=str(trainer.args.project).replace("/", "-") if trainer.args.project else "Ultralytics", + name=str(trainer.args.name).replace("/", "-"), + config=vars(trainer.args), + ) + + +def on_fit_epoch_end(trainer): + """Log training metrics and model information at the end of an epoch.""" + wb.run.log(trainer.metrics, step=trainer.epoch + 1) + 
_log_plots(trainer.plots, step=trainer.epoch + 1) + _log_plots(trainer.validator.plots, step=trainer.epoch + 1) + if trainer.epoch == 0: + wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1) + + +def on_train_epoch_end(trainer): + """Log metrics and save images at the end of each training epoch.""" + wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1) + wb.run.log(trainer.lr, step=trainer.epoch + 1) + if trainer.epoch == 1: + _log_plots(trainer.plots, step=trainer.epoch + 1) + + +def on_train_end(trainer): + """Save the best model as an artifact and log final plots at the end of training.""" + _log_plots(trainer.validator.plots, step=trainer.epoch + 1) + _log_plots(trainer.plots, step=trainer.epoch + 1) + art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model") + if trainer.best.exists(): + art.add_file(trainer.best) + wb.run.log_artifact(art, aliases=["best"]) + # Check if we actually have plots to save + if trainer.args.plots and hasattr(trainer.validator.metrics, "curves_results"): + for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results): + x, y, x_title, y_title = curve_values + _plot_curve( + x, + y, + names=list(trainer.validator.metrics.names.values()), + id=f"curves/{curve_name}", + title=curve_name, + x_title=x_title, + y_title=y_title, + ) + wb.run.finish() # required or run continues on dashboard + + +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_end": on_train_end, + } + if wb + else {} +) diff --git a/tracking/ultralytics/utils/checks.py b/tracking/ultralytics/utils/checks.py new file mode 100644 index 0000000000000000000000000000000000000000..3410449556c17b8846955cb33bf71aecf8a87489 --- /dev/null +++ b/tracking/ultralytics/utils/checks.py @@ -0,0 +1,894 @@ +# Ultralytics 🚀 AGPL-3.0 License - 
https://ultralytics.com/license + +import glob +import inspect +import math +import os +import platform +import re +import shutil +import subprocess +import time +from importlib import metadata +from pathlib import Path +from typing import Optional + +import cv2 +import numpy as np +import requests +import torch + +from ultralytics.utils import ( + ARM64, + ASSETS, + AUTOINSTALL, + IS_COLAB, + IS_GIT_DIR, + IS_KAGGLE, + IS_PIP_PACKAGE, + LINUX, + LOGGER, + MACOS, + ONLINE, + PYTHON_VERSION, + RKNN_CHIPS, + ROOT, + TORCHVISION_VERSION, + USER_CONFIG_DIR, + WINDOWS, + Retry, + SimpleNamespace, + ThreadingLocked, + TryExcept, + clean_url, + colorstr, + downloads, + emojis, + is_github_action_running, + url2file, +) + + +def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""): + """ + Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'. + + Args: + file_path (Path): Path to the requirements.txt file. + package (str, optional): Python package to use instead of requirements.txt file. + + Returns: + (List[SimpleNamespace]): List of parsed requirements as SimpleNamespace objects with `name` and `specifier` attributes. 
+ + Examples: + >>> from ultralytics.utils.checks import parse_requirements + >>> parse_requirements(package="ultralytics") + """ + if package: + requires = [x for x in metadata.distribution(package).requires if "extra == " not in x] + else: + requires = Path(file_path).read_text().splitlines() + + requirements = [] + for line in requires: + line = line.strip() + if line and not line.startswith("#"): + line = line.split("#")[0].strip() # ignore inline comments + if match := re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line): + requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else "")) + + return requirements + + +def parse_version(version="0.0.0") -> tuple: + """ + Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. + + Args: + version (str): Version string, i.e. '2.0.1+cpu' + + Returns: + (tuple): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1) + """ + try: + return tuple(map(int, re.findall(r"\d+", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1) + except Exception as e: + LOGGER.warning(f"WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}") + return 0, 0, 0 + + +def is_ascii(s) -> bool: + """ + Check if a string is composed of only ASCII characters. + + Args: + s (str): String to be checked. + + Returns: + (bool): True if the string is composed only of ASCII characters, False otherwise. + """ + # Convert list, tuple, None, etc. to string + s = str(s) + + # Check if the string is composed of only ASCII characters + return all(ord(c) < 128 for c in s) + + +def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0): + """ + Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the + stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value. + + Args: + imgsz (int | List[int]): Image size. 
+ stride (int): Stride value. + min_dim (int): Minimum number of dimensions. + max_dim (int): Maximum number of dimensions. + floor (int): Minimum allowed value for image size. + + Returns: + (List[int] | int): Updated image size. + """ + # Convert stride to integer if it is a tensor + stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride) + + # Convert image size to list if it is an integer + if isinstance(imgsz, int): + imgsz = [imgsz] + elif isinstance(imgsz, (list, tuple)): + imgsz = list(imgsz) + elif isinstance(imgsz, str): # i.e. '640' or '[640,640]' + imgsz = [int(imgsz)] if imgsz.isnumeric() else eval(imgsz) + else: + raise TypeError( + f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. " + f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'" + ) + + # Apply max_dim + if len(imgsz) > max_dim: + msg = ( + "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " + "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'" + ) + if max_dim != 1: + raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}") + LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}") + imgsz = [max(imgsz)] + # Make image size a multiple of the stride + sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz] + + # Print warning message if image size was updated + if sz != imgsz: + LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}") + + # Add missing dimensions if necessary + sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz + + return sz + + +def check_version( + current: str = "0.0.0", + required: str = "0.0.0", + name: str = "version", + hard: bool = False, + verbose: bool = False, + msg: str = "", +) -> bool: + """ + Check current version against the required version or range. 
+ + Args: + current (str): Current version or package name to get version from. + required (str): Required version or range (in pip-style format). + name (str): Name to be used in warning message. + hard (bool): If True, raise an AssertionError if the requirement is not met. + verbose (bool): If True, print warning message if requirement is not met. + msg (str): Extra message to display if verbose. + + Returns: + (bool): True if requirement is met, False otherwise. + + Examples: + Check if current version is exactly 22.04 + >>> check_version(current="22.04", required="==22.04") + + Check if current version is greater than or equal to 22.04 + >>> check_version(current="22.10", required="22.04") # assumes '>=' inequality if none passed + + Check if current version is less than or equal to 22.04 + >>> check_version(current="22.04", required="<=22.04") + + Check if current version is between 20.04 (inclusive) and 22.04 (exclusive) + >>> check_version(current="21.10", required=">20.04,<22.04") + """ + if not current: # if current is '' or None + LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.") + return True + elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics' + try: + name = current # assigned package name to 'name' arg + current = metadata.version(current) # get version string from package name + except metadata.PackageNotFoundError as e: + if hard: + raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed")) from e + else: + return False + + if not required: # if required is '' or None + return True + + if "sys_platform" in required and ( # i.e. 
required='<2.4.0,>=1.8.0; sys_platform == "win32"' + (WINDOWS and "win32" not in required) + or (LINUX and "linux" not in required) + or (MACOS and "macos" not in required and "darwin" not in required) + ): + return True + + op = "" + version = "" + result = True + c = parse_version(current) # '1.2.3' -> (1, 2, 3) + for r in required.strip(",").split(","): + op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups() # split '>=22.04' -> ('>=', '22.04') + if not op: + op = ">=" # assume >= if no op passed + v = parse_version(version) # '1.2.3' -> (1, 2, 3) + if op == "==" and c != v: + result = False + elif op == "!=" and c == v: + result = False + elif op == ">=" and not (c >= v): + result = False + elif op == "<=" and not (c <= v): + result = False + elif op == ">" and not (c > v): + result = False + elif op == "<" and not (c < v): + result = False + if not result: + warning = f"WARNING ⚠️ {name}{required} is required, but {name}=={current} is currently installed {msg}" + if hard: + raise ModuleNotFoundError(emojis(warning)) # assert version requirements met + if verbose: + LOGGER.warning(warning) + return result + + +def check_latest_pypi_version(package_name="ultralytics"): + """ + Returns the latest version of a PyPI package without downloading or installing it. + + Args: + package_name (str): The name of the package to find the latest version for. + + Returns: + (str): The latest version of the package. + """ + try: + requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning + response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3) + if response.status_code == 200: + return response.json()["info"]["version"] + except Exception: + return None + + +def check_pip_update_available(): + """ + Checks if a new version of the ultralytics package is available on PyPI. + + Returns: + (bool): True if an update is available, False otherwise. 
+ """ + if ONLINE and IS_PIP_PACKAGE: + try: + from ultralytics import __version__ + + latest = check_latest_pypi_version() + if check_version(__version__, f"<{latest}"): # check if current version is < latest version + LOGGER.info( + f"New https://pypi.org/project/ultralytics/{latest} available 😃 " + f"Update with 'pip install -U ultralytics'" + ) + return True + except Exception: + pass + return False + + +@ThreadingLocked() +def check_font(font="Arial.ttf"): + """ + Find font locally or download to user's configuration directory if it does not already exist. + + Args: + font (str): Path or name of font. + + Returns: + (Path): Resolved font file path. + """ + from matplotlib import font_manager + + # Check USER_CONFIG_DIR + name = Path(font).name + file = USER_CONFIG_DIR / name + if file.exists(): + return file + + # Check system fonts + matches = [s for s in font_manager.findSystemFonts() if font in s] + if any(matches): + return matches[0] + + # Download to USER_CONFIG_DIR if missing + url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{name}" + if downloads.is_url(url, check=True): + downloads.safe_download(url=url, file=file) + return file + + +def check_python(minimum: str = "3.8.0", hard: bool = True, verbose: bool = False) -> bool: + """ + Check current python version against the required minimum version. + + Args: + minimum (str): Required minimum version of python. + hard (bool): If True, raise an AssertionError if the requirement is not met. + verbose (bool): If True, print warning message if requirement is not met. + + Returns: + (bool): Whether the installed Python version meets the minimum constraints. + """ + return check_version(PYTHON_VERSION, minimum, name="Python", hard=hard, verbose=verbose) + + +@TryExcept() +def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""): + """ + Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed. 
+ + Args: + requirements (Union[Path, str, List[str]]): Path to a requirements.txt file, a single package requirement as a + string, or a list of package requirements as strings. + exclude (Tuple[str]): Tuple of package names to exclude from checking. + install (bool): If True, attempt to auto-update packages that don't meet requirements. + cmds (str): Additional commands to pass to the pip install command when auto-updating. + + Examples: + >>> from ultralytics.utils.checks import check_requirements + + Check a requirements.txt file + >>> check_requirements("path/to/requirements.txt") + + Check a single package + >>> check_requirements("ultralytics>=8.0.0") + + Check multiple packages + >>> check_requirements(["numpy", "ultralytics>=8.0.0"]) + """ + prefix = colorstr("red", "bold", "requirements:") + if isinstance(requirements, Path): # requirements.txt file + file = requirements.resolve() + assert file.exists(), f"{prefix} {file} not found, check failed." + requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude] + elif isinstance(requirements, str): + requirements = [requirements] + + pkgs = [] + for r in requirements: + r_stripped = r.split("/")[-1].replace(".git", "") # replace git+https://org/repo.git -> 'repo' + match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped) + name, required = match[1], match[2].strip() if match[2] else "" + try: + assert check_version(metadata.version(name), required) # exception if requirements not met + except (AssertionError, metadata.PackageNotFoundError): + pkgs.append(r) + + @Retry(times=2, delay=1) + def attempt_install(packages, commands): + """Attempt pip install command with retries on failure.""" + return subprocess.check_output(f"pip install --no-cache-dir {packages} {commands}", shell=True).decode() + + s = " ".join(f'"{x}"' for x in pkgs) # console string + if s: + if install and AUTOINSTALL: # check environment variable + n = len(pkgs) # number of packages updates + 
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...") + try: + t = time.time() + assert ONLINE, "AutoUpdate skipped (offline)" + LOGGER.info(attempt_install(s, cmds)) + dt = time.time() - t + LOGGER.info( + f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n" + f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n" + ) + except Exception as e: + LOGGER.warning(f"{prefix} ❌ {e}") + return False + else: + return False + + return True + + +def check_torchvision(): + """ + Checks the installed versions of PyTorch and Torchvision to ensure they're compatible. + + This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according + to the compatibility table based on: https://github.com/pytorch/vision#installation. + """ + compatibility_table = { + "2.6": ["0.21"], + "2.5": ["0.20"], + "2.4": ["0.19"], + "2.3": ["0.18"], + "2.2": ["0.17"], + "2.1": ["0.16"], + "2.0": ["0.15"], + "1.13": ["0.14"], + "1.12": ["0.13"], + } + + # Check major and minor versions + v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2]) + if v_torch in compatibility_table: + compatible_versions = compatibility_table[v_torch] + v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2]) + if all(v_torchvision != v for v in compatible_versions): + print( + f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n" + f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or " + "'pip install -U torch torchvision' to update both.\n" + "For a full compatibility table see https://github.com/pytorch/vision#installation" + ) + + +def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""): + """ + Check file(s) for acceptable suffix. + + Args: + file (str | List[str]): File or list of files to check. 
+ suffix (str | Tuple[str]): Acceptable suffix or tuple of suffixes. + msg (str): Additional message to display in case of error. + """ + if file and suffix: + if isinstance(suffix, str): + suffix = (suffix,) + for f in file if isinstance(file, (list, tuple)) else [file]: + s = Path(f).suffix.lower().strip() # file suffix + if len(s): + assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}, not {s}" + + +def check_yolov5u_filename(file: str, verbose: bool = True): + """ + Replace legacy YOLOv5 filenames with updated YOLOv5u filenames. + + Args: + file (str): Filename to check and potentially update. + verbose (bool): Whether to print information about the replacement. + + Returns: + (str): Updated filename. + """ + if "yolov3" in file or "yolov5" in file: + if "u.yaml" in file: + file = file.replace("u.yaml", ".yaml") # i.e. yolov5nu.yaml -> yolov5n.yaml + elif ".pt" in file and "u" not in file: + original_file = file + file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file) # i.e. yolov5n.pt -> yolov5nu.pt + file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file) # i.e. yolov5n6.pt -> yolov5n6u.pt + file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt + if file != original_file and verbose: + LOGGER.info( + f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are " + f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs " + f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n" + ) + return file + + +def check_model_file_from_stem(model="yolo11n"): + """ + Return a model filename from a valid model stem. + + Args: + model (str): Model stem to check. + + Returns: + (str | Path): Model filename with appropriate suffix. + """ + if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS: + return Path(model).with_suffix(".pt") # add suffix, i.e. 
yolo11n -> yolo11n.pt + else: + return model + + +def check_file(file, suffix="", download=True, download_dir=".", hard=True): + """ + Search/download file (if necessary) and return path. + + Args: + file (str): File name or path. + suffix (str): File suffix to check. + download (bool): Whether to download the file if it doesn't exist locally. + download_dir (str): Directory to download the file to. + hard (bool): Whether to raise an error if the file is not found. + + Returns: + (str): Path to the file. + """ + check_suffix(file, suffix) # optional + file = str(file).strip() # convert to string and strip spaces + file = check_yolov5u_filename(file) # yolov5n -> yolov5nu + if ( + not file + or ("://" not in file and Path(file).exists()) # '://' check required in Windows Python<3.10 + or file.lower().startswith("grpc://") + ): # file exists or gRPC Triton images + return file + elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download + url = file # warning: Pathlib turns :// -> :/ + file = Path(download_dir) / url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth + if file.exists(): + LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists + else: + downloads.safe_download(url=url, file=file, unzip=False) + return str(file) + else: # search + files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file + if not files and hard: + raise FileNotFoundError(f"'{file}' does not exist") + elif len(files) > 1 and hard: + raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}") + return files[0] if len(files) else [] # return file + + +def check_yaml(file, suffix=(".yaml", ".yml"), hard=True): + """ + Search/download YAML file (if necessary) and return path, checking suffix. + + Args: + file (str): File name or path. + suffix (tuple): Acceptable file suffixes. 
+ hard (bool): Whether to raise an error if the file is not found. + + Returns: + (str): Path to the YAML file. + """ + return check_file(file, suffix, hard=hard) + + +def check_is_path_safe(basedir, path): + """ + Check if the resolved path is under the intended directory to prevent path traversal. + + Args: + basedir (Path | str): The intended directory. + path (Path | str): The path to check. + + Returns: + (bool): True if the path is safe, False otherwise. + """ + base_dir_resolved = Path(basedir).resolve() + path_resolved = Path(path).resolve() + + return path_resolved.exists() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts + + +def check_imshow(warn=False): + """ + Check if environment supports image displays. + + Args: + warn (bool): Whether to warn if environment doesn't support image displays. + + Returns: + (bool): True if environment supports image displays, False otherwise. + """ + try: + if LINUX: + assert not IS_COLAB and not IS_KAGGLE + assert "DISPLAY" in os.environ, "The DISPLAY environment variable isn't set." + cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image + cv2.waitKey(1) + cv2.destroyAllWindows() + cv2.waitKey(1) + return True + except Exception as e: + if warn: + LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}") + return False + + +def check_yolo(verbose=True, device=""): + """ + Return a human-readable YOLO software and hardware summary. + + Args: + verbose (bool): Whether to print verbose information. + device (str): Device to use for YOLO. 
+ """ + import psutil + + from ultralytics.utils.torch_utils import select_device + + if IS_COLAB: + shutil.rmtree("sample_data", ignore_errors=True) # remove colab /sample_data directory + + if verbose: + # System info + gib = 1 << 30 # bytes per GiB + ram = psutil.virtual_memory().total + total, used, free = shutil.disk_usage("/") + s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)" + try: + from IPython import display + + display.clear_output() # clear display if notebook + except ImportError: + pass + else: + s = "" + + select_device(device=device, newline=False) + LOGGER.info(f"Setup complete ✅ {s}") + + +def collect_system_info(): + """ + Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA. + + Returns: + (dict): Dictionary containing system information. + """ + import psutil + + from ultralytics.utils import ENVIRONMENT # scope to avoid circular import + from ultralytics.utils.torch_utils import get_cpu_info, get_gpu_info + + gib = 1 << 30 # bytes per GiB + cuda = torch and torch.cuda.is_available() + check_yolo() + total, used, free = shutil.disk_usage("/") + + info_dict = { + "OS": platform.platform(), + "Environment": ENVIRONMENT, + "Python": PYTHON_VERSION, + "Install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other", + "Path": str(ROOT), + "RAM": f"{psutil.virtual_memory().total / gib:.2f} GB", + "Disk": f"{(total - free) / gib:.1f}/{total / gib:.1f} GB", + "CPU": get_cpu_info(), + "CPU count": os.cpu_count(), + "GPU": get_gpu_info(index=0) if cuda else None, + "GPU count": torch.cuda.device_count() if cuda else None, + "CUDA": torch.version.cuda if cuda else None, + } + LOGGER.info("\n" + "\n".join(f"{k:<20}{v}" for k, v in info_dict.items()) + "\n") + + package_info = {} + for r in parse_requirements(package="ultralytics"): + try: + current = metadata.version(r.name) + is_met = "✅ " if check_version(current, str(r.specifier), name=r.name, 
hard=True) else "❌ " + except metadata.PackageNotFoundError: + current = "(not installed)" + is_met = "❌ " + package_info[r.name] = f"{is_met}{current}{r.specifier}" + LOGGER.info(f"{r.name:<20}{package_info[r.name]}") + + info_dict["Package Info"] = package_info + + if is_github_action_running(): + github_info = { + "RUNNER_OS": os.getenv("RUNNER_OS"), + "GITHUB_EVENT_NAME": os.getenv("GITHUB_EVENT_NAME"), + "GITHUB_WORKFLOW": os.getenv("GITHUB_WORKFLOW"), + "GITHUB_ACTOR": os.getenv("GITHUB_ACTOR"), + "GITHUB_REPOSITORY": os.getenv("GITHUB_REPOSITORY"), + "GITHUB_REPOSITORY_OWNER": os.getenv("GITHUB_REPOSITORY_OWNER"), + } + LOGGER.info("\n" + "\n".join(f"{k}: {v}" for k, v in github_info.items())) + info_dict["GitHub Info"] = github_info + + return info_dict + + +def check_amp(model): + """ + Checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO11 model. + + If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP + results, so AMP will be disabled during training. + + Args: + model (nn.Module): A YOLO11 model instance. + + Returns: + (bool): Returns True if the AMP functionality works correctly with YOLO11 model, else False. + + Examples: + >>> from ultralytics import YOLO + >>> from ultralytics.utils.checks import check_amp + >>> model = YOLO("yolo11n.pt").model.cuda() + >>> check_amp(model) + """ + from ultralytics.utils.torch_utils import autocast + + device = next(model.parameters()).device # get model device + prefix = colorstr("AMP: ") + if device.type in {"cpu", "mps"}: + return False # AMP only used on CUDA devices + else: + # GPUs that have issues with AMP + pattern = re.compile( + r"(nvidia|geforce|quadro|tesla).*?(1660|1650|1630|t400|t550|t600|t1000|t1200|t2000|k40m)", re.IGNORECASE + ) + + gpu = torch.cuda.get_device_name(device) + if bool(pattern.search(gpu)): + LOGGER.warning( + f"{prefix}checks failed ❌. 
AMP training on {gpu} GPU may cause " + f"NaN losses or zero-mAP results, so AMP will be disabled during training." + ) + return False + + def amp_allclose(m, im): + """All close FP32 vs AMP results.""" + batch = [im] * 8 + imgsz = max(256, int(model.stride.max() * 4)) # max stride P5-32 and P6-64 + a = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # FP32 inference + with autocast(enabled=True): + b = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # AMP inference + del m + return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance + + im = ASSETS / "bus.jpg" # image to check + LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks...") + warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False." + try: + from ultralytics import YOLO + + assert amp_allclose(YOLO("yolo11n.pt"), im) + LOGGER.info(f"{prefix}checks passed ✅") + except ConnectionError: + LOGGER.warning( + f"{prefix}checks skipped ⚠️. Offline and unable to download YOLO11n for AMP checks. {warning_msg}" + ) + except (AttributeError, ModuleNotFoundError): + LOGGER.warning( + f"{prefix}checks skipped ⚠️. " + f"Unable to load YOLO11n for AMP checks due to possible Ultralytics package modifications. {warning_msg}" + ) + except AssertionError: + LOGGER.warning( + f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to " + f"NaN losses or zero-mAP results, so AMP will be disabled during training." + ) + return False + return True + + +def git_describe(path=ROOT): # path must be a directory + """ + Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe. + + Args: + path (Path): Path to git repository. + + Returns: + (str): Human-readable git description. 
+ """ + try: + return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1] + except Exception: + return "" + + +def print_args(args: Optional[dict] = None, show_file=True, show_func=False): + """ + Print function arguments (optional args dict). + + Args: + args (dict, optional): Arguments to print. + show_file (bool): Whether to show the file name. + show_func (bool): Whether to show the function name. + """ + + def strip_auth(v): + """Clean longer Ultralytics HUB URLs by stripping potential authentication information.""" + return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v + + x = inspect.currentframe().f_back # previous frame + file, _, func, _, _ = inspect.getframeinfo(x) + if args is None: # get args automatically + args, _, _, frm = inspect.getargvalues(x) + args = {k: v for k, v in frm.items() if k in args} + try: + file = Path(file).resolve().relative_to(ROOT).with_suffix("") + except ValueError: + file = Path(file).stem + s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "") + LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items())) + + +def cuda_device_count() -> int: + """ + Get the number of NVIDIA GPUs available in the environment. + + Returns: + (int): The number of NVIDIA GPUs available. + """ + try: + # Run the nvidia-smi command and capture its output + output = subprocess.check_output( + ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8" + ) + + # Take the first line and strip any leading/trailing white space + first_line = output.strip().split("\n")[0] + + return int(first_line) + except (subprocess.CalledProcessError, FileNotFoundError, ValueError): + # If the command fails, nvidia-smi is not found, or output is not an integer, assume no GPUs are available + return 0 + + +def cuda_is_available() -> bool: + """ + Check if CUDA is available in the environment. 
+ + Returns: + (bool): True if one or more NVIDIA GPUs are available, False otherwise. + """ + return cuda_device_count() > 0 + + +def is_rockchip(): + """ + Check if the current environment is running on a Rockchip SoC. + + Returns: + (bool): True if running on a Rockchip SoC, False otherwise. + """ + if LINUX and ARM64: + try: + with open("/proc/device-tree/compatible") as f: + dev_str = f.read() + *_, soc = dev_str.split(",") + if soc.replace("\x00", "") in RKNN_CHIPS: + return True + except OSError: + return False + else: + return False + + +def is_sudo_available() -> bool: + """ + Check if the sudo command is available in the environment. + + Returns: + (bool): True if the sudo command is available, False otherwise. + """ + if WINDOWS: + return False + cmd = "sudo --version" + return subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0 + + +# Run checks and define constants +check_python("3.8", hard=False, verbose=True) # check python version +check_torchvision() # check torch-torchvision compatibility + +# Define constants +IS_PYTHON_MINIMUM_3_10 = check_python("3.10", hard=False) +IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12") +IS_PYTHON_3_13 = PYTHON_VERSION.startswith("3.13") diff --git a/tracking/ultralytics/utils/dist.py b/tracking/ultralytics/utils/dist.py new file mode 100644 index 0000000000000000000000000000000000000000..2b7715a02f2025a116b66f2758e9cc069d2c9f6c --- /dev/null +++ b/tracking/ultralytics/utils/dist.py @@ -0,0 +1,85 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import os +import shutil +import socket +import sys +import tempfile + +from . import USER_CONFIG_DIR +from .torch_utils import TORCH_1_9 + + +def find_free_network_port() -> int: + """ + Find a free port on localhost. + + It is useful in single-node training when we don't want to connect to a real main node but have to set the + `MASTER_PORT` environment variable. 

    Returns:
        (int): The available network port number.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))  # port 0 asks the OS for any free port
        return s.getsockname()[1]  # port


def generate_ddp_file(trainer):
    """Generates a DDP file and returns its file name."""
    module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)

    # The generated script re-imports the trainer class and replays its overrides,
    # so each torch.distributed worker can reconstruct an identical trainer.
    content = f"""
# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
overrides = {vars(trainer.args)}

if __name__ == "__main__":
    from {module} import {name}
    from ultralytics.utils import DEFAULT_CFG_DICT

    cfg = DEFAULT_CFG_DICT.copy()
    cfg.update(save_dir='')  # handle the extra key 'save_dir'
    trainer = {name}(cfg=cfg, overrides=overrides)
    trainer.args.model = "{getattr(trainer.hub_session, "model_url", trainer.args.model)}"
    results = trainer.train()
"""
    (USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
    with tempfile.NamedTemporaryFile(
        prefix="_temp_",
        suffix=f"{id(trainer)}.py",  # id(trainer) ties the file to this trainer for ddp_cleanup
        mode="w+",
        encoding="utf-8",
        dir=USER_CONFIG_DIR / "DDP",
        delete=False,  # keep the file for the spawned torch.distributed process
    ) as file:
        file.write(content)
    return file.name


def generate_ddp_command(world_size, trainer):
    """
    Generate command for distributed training.

    Args:
        world_size (int): Number of processes to spawn for distributed training.
        trainer (object): The trainer object containing configuration for distributed training.

    Returns:
        cmd (List[str]): The command to execute for distributed training.
        file (str): Path to the temporary file created for DDP training.
+ """ + import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218 + + if not trainer.resume: + shutil.rmtree(trainer.save_dir) # remove the save_dir + file = generate_ddp_file(trainer) + dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch" + port = find_free_network_port() + cmd = [sys.executable, "-m", dist_cmd, "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file] + return cmd, file + + +def ddp_cleanup(trainer, file): + """Delete temp file if created.""" + if f"{id(trainer)}.py" in file: # if temp_file suffix in file + os.remove(file) diff --git a/tracking/ultralytics/utils/downloads.py b/tracking/ultralytics/utils/downloads.py new file mode 100644 index 0000000000000000000000000000000000000000..8b75721bb2891398f2affbbb77e779f5fb21b3ec --- /dev/null +++ b/tracking/ultralytics/utils/downloads.py @@ -0,0 +1,488 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import re +import shutil +import subprocess +from itertools import repeat +from multiprocessing.pool import ThreadPool +from pathlib import Path +from urllib import parse, request + +import requests +import torch + +from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file + +# Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets +GITHUB_ASSETS_REPO = "ultralytics/assets" +GITHUB_ASSETS_NAMES = ( + [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb", "-oiv7")] + + [f"yolo11{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")] + + [f"yolo12{k}{suffix}.pt" for k in "nsmlx" for suffix in ("",)] # detect models only currently + + [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")] + + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")] + + [f"yolov8{k}-world.pt" for k in "smlx"] + + [f"yolov8{k}-worldv2.pt" for k in "smlx"] + + [f"yolov9{k}.pt" 
for k in "tsmce"] + + [f"yolov10{k}.pt" for k in "nsmblx"] + + [f"yolo_nas_{k}.pt" for k in "sml"] + + [f"sam_{k}.pt" for k in "bl"] + + [f"FastSAM-{k}.pt" for k in "sx"] + + [f"rtdetr-{k}.pt" for k in "lx"] + + ["mobile_sam.pt"] + + ["calibration_image_sample_data_20x128x128x3_float32.npy.zip"] +) +GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES] + + +def is_url(url, check=False): + """ + Validates if the given string is a URL and optionally checks if the URL exists online. + + Args: + url (str): The string to be validated as a URL. + check (bool, optional): If True, performs an additional check to see if the URL exists online. + Defaults to False. + + Returns: + (bool): Returns True for a valid URL. If 'check' is True, also returns True if the URL exists online. + Returns False otherwise. + + Examples: + >>> valid = is_url("https://www.example.com") + """ + try: + url = str(url) + result = parse.urlparse(url) + assert all([result.scheme, result.netloc]) # check if is url + if check: + with request.urlopen(url) as response: + return response.getcode() == 200 # check if exists online + return True + except Exception: + return False + + +def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")): + """ + Delete all ".DS_store" files in a specified directory. + + Args: + path (str, optional): The directory path where the ".DS_store" files should be deleted. + files_to_delete (tuple): The files to be deleted. + + Examples: + >>> from ultralytics.utils.downloads import delete_dsstore + >>> delete_dsstore("path/to/dir") + + Notes: + ".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They + are hidden system files and can cause issues when transferring files between different operating systems. 
+ """ + for file in files_to_delete: + matches = list(Path(path).rglob(file)) + LOGGER.info(f"Deleting {file} files: {matches}") + for f in matches: + f.unlink() + + +def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True): + """ + Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is + named after the directory and placed alongside it. + + Args: + directory (str | Path): The path to the directory to be zipped. + compress (bool): Whether to compress the files while zipping. Default is True. + exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX'). + progress (bool, optional): Whether to display a progress bar. Defaults to True. + + Returns: + (Path): The path to the resulting zip file. + + Examples: + >>> from ultralytics.utils.downloads import zip_directory + >>> file = zip_directory("path/to/dir") + """ + from zipfile import ZIP_DEFLATED, ZIP_STORED, ZipFile + + delete_dsstore(directory) + directory = Path(directory) + if not directory.is_dir(): + raise FileNotFoundError(f"Directory '{directory}' does not exist.") + + # Unzip with progress bar + files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)] + zip_file = directory.with_suffix(".zip") + compression = ZIP_DEFLATED if compress else ZIP_STORED + with ZipFile(zip_file, "w", compression) as f: + for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress): + f.write(file, file.relative_to(directory)) + + return zip_file # return path to zip file + + +def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True): + """ + Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list. 
+ + If the zipfile does not contain a single top-level directory, the function will create a new + directory with the same name as the zipfile (without the extension) to extract its contents. + If a path is not provided, the function will use the parent directory of the zipfile as the default path. + + Args: + file (str | Path): The path to the zipfile to be extracted. + path (str | Path, optional): The path to extract the zipfile to. Defaults to None. + exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX'). + exist_ok (bool, optional): Whether to overwrite existing contents if they exist. Defaults to False. + progress (bool, optional): Whether to display a progress bar. Defaults to True. + + Raises: + BadZipFile: If the provided file does not exist or is not a valid zipfile. + + Returns: + (Path): The path to the directory where the zipfile was extracted. + + Examples: + >>> from ultralytics.utils.downloads import unzip_file + >>> directory = unzip_file("path/to/file.zip") + """ + from zipfile import BadZipFile, ZipFile, is_zipfile + + if not (Path(file).exists() and is_zipfile(file)): + raise BadZipFile(f"File '{file}' does not exist or is a bad zip file.") + if path is None: + path = Path(file).parent # default path + + # Unzip the file contents + with ZipFile(file) as zipObj: + files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)] + top_level_dirs = {Path(f).parts[0] for f in files} + + # Decide to unzip directly or unzip into a directory + unzip_as_dir = len(top_level_dirs) == 1 # (len(files) > 1 and not files[0].endswith("/")) + if unzip_as_dir: + # Zip has 1 top-level directory + extract_path = path # i.e. ../datasets + path = Path(path) / list(top_level_dirs)[0] # i.e. extract coco8/ dir to ../datasets/ + else: + # Zip has multiple files at top level + path = extract_path = Path(path) / Path(file).stem # i.e. 
extract multiple files to ../datasets/coco8/ + + # Check if destination directory already exists and contains files + if path.exists() and any(path.iterdir()) and not exist_ok: + # If it exists and is not empty, return the path without unzipping + LOGGER.warning(f"WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.") + return path + + for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress): + # Ensure the file is within the extract_path to avoid path traversal security vulnerability + if ".." in Path(f).parts: + LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.") + continue + zipObj.extract(f, extract_path) + + return path # return unzip dir + + +def check_disk_space(url="https://ultralytics.com/assets/coco8.zip", path=Path.cwd(), sf=1.5, hard=True): + """ + Check if there is sufficient disk space to download and store a file. + + Args: + url (str, optional): The URL to the file. Defaults to 'https://ultralytics.com/assets/coco8.zip'. + path (str | Path, optional): The path or drive to check the available free space on. + sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 1.5. + hard (bool, optional): Whether to throw an error or not on insufficient disk space. Defaults to True. + + Returns: + (bool): True if there is sufficient disk space, False otherwise. 
+ """ + try: + r = requests.head(url) # response + assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}" # check response + except Exception: + return True # requests issue, default to True + + # Check file size + gib = 1 << 30 # bytes per GiB + data = int(r.headers.get("Content-Length", 0)) / gib # file size (GB) + total, used, free = (x / gib for x in shutil.disk_usage(path)) # bytes + + if data * sf < free: + return True # sufficient space + + # Insufficient space + text = ( + f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, " + f"Please free {data * sf - free:.1f} GB additional disk space and try again." + ) + if hard: + raise MemoryError(text) + LOGGER.warning(text) + return False + + +def get_google_drive_file_info(link): + """ + Retrieves the direct download link and filename for a shareable Google Drive file link. + + Args: + link (str): The shareable link of the Google Drive file. + + Returns: + (str): Direct download URL for the Google Drive file. + (str): Original filename of the Google Drive file. If filename extraction fails, returns None. + + Examples: + >>> from ultralytics.utils.downloads import get_google_drive_file_info + >>> link = "https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link" + >>> url, filename = get_google_drive_file_info(link) + """ + file_id = link.split("/d/")[1].split("/view")[0] + drive_url = f"https://drive.google.com/uc?export=download&id={file_id}" + filename = None + + # Start session + with requests.Session() as session: + response = session.get(drive_url, stream=True) + if "quota exceeded" in str(response.content.lower()): + raise ConnectionError( + emojis( + f"❌ Google Drive file download quota exceeded. " + f"Please try again later or download this file manually at {link}." 
+ ) + ) + for k, v in response.cookies.items(): + if k.startswith("download_warning"): + drive_url += f"&confirm={v}" # v is token + if cd := response.headers.get("content-disposition"): + filename = re.findall('filename="(.+)"', cd)[0] + return drive_url, filename + + +def safe_download( + url, + file=None, + dir=None, + unzip=True, + delete=False, + curl=False, + retry=3, + min_bytes=1e0, + exist_ok=False, + progress=True, +): + """ + Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file. + + Args: + url (str): The URL of the file to be downloaded. + file (str, optional): The filename of the downloaded file. + If not provided, the file will be saved with the same name as the URL. + dir (str | Path, optional): The directory to save the downloaded file. + If not provided, the file will be saved in the current working directory. + unzip (bool, optional): Whether to unzip the downloaded file. Default: True. + delete (bool, optional): Whether to delete the downloaded file after unzipping. Default: False. + curl (bool, optional): Whether to use curl command line tool for downloading. Default: False. + retry (int, optional): The number of times to retry the download in case of failure. Default: 3. + min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered + a successful download. Default: 1E0. + exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False. + progress (bool, optional): Whether to display a progress bar during the download. Default: True. + + Returns: + (Path | str): The path to the downloaded file or extracted directory. 
+ + Examples: + >>> from ultralytics.utils.downloads import safe_download + >>> link = "https://ultralytics.com/assets/bus.jpg" + >>> path = safe_download(link) + """ + gdrive = url.startswith("https://drive.google.com/") # check if the URL is a Google Drive link + if gdrive: + url, file = get_google_drive_file_info(url) + + f = Path(dir or ".") / (file or url2file(url)) # URL converted to filename + if "://" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10) + f = Path(url) # filename + elif not f.is_file(): # URL and file do not exist + uri = (url if gdrive else clean_url(url)).replace( # cleaned and aliased url + "https://github.com/ultralytics/assets/releases/download/v0.0.0/", + "https://ultralytics.com/assets/", # assets alias + ) + desc = f"Downloading {uri} to '{f}'" + LOGGER.info(f"{desc}...") + f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing + check_disk_space(url, path=f.parent) + for i in range(retry + 1): + try: + if curl or i > 0: # curl download with retry, continue + s = "sS" * (not progress) # silent + r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode + assert r == 0, f"Curl return value {r}" + else: # urllib download + method = "torch" + if method == "torch": + torch.hub.download_url_to_file(url, f, progress=progress) + else: + with request.urlopen(url) as response, TQDM( + total=int(response.getheader("Content-Length", 0)), + desc=desc, + disable=not progress, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as pbar: + with open(f, "wb") as f_opened: + for data in response: + f_opened.write(data) + pbar.update(len(data)) + + if f.exists(): + if f.stat().st_size > min_bytes: + break # success + f.unlink() # remove partial downloads + except Exception as e: + if i == 0 and not is_online(): + raise ConnectionError(emojis(f"❌ Download failure for {uri}. 
Environment is not online.")) from e + elif i >= retry: + raise ConnectionError(emojis(f"❌ Download failure for {uri}. Retry limit reached.")) from e + LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {uri}...") + + if unzip and f.exists() and f.suffix in {"", ".zip", ".tar", ".gz"}: + from zipfile import is_zipfile + + unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place + if is_zipfile(f): + unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip + elif f.suffix in {".tar", ".gz"}: + LOGGER.info(f"Unzipping {f} to {unzip_dir}...") + subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True) + if delete: + f.unlink() # remove zip + return unzip_dir + return f + + +def get_github_assets(repo="ultralytics/assets", version="latest", retry=False): + """ + Retrieve the specified version's tag and assets from a GitHub repository. If the version is not specified, the + function fetches the latest release assets. + + Args: + repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'. + version (str, optional): The release version to fetch assets from. Defaults to 'latest'. + retry (bool, optional): Flag to retry the request in case of a failure. Defaults to False. + + Returns: + (str): The release tag. + (List[str]): A list of asset names. + + Examples: + >>> tag, assets = get_github_assets(repo="ultralytics/assets", version="latest") + """ + if version != "latest": + version = f"tags/{version}" # i.e. 
tags/v6.2 + url = f"https://api.github.com/repos/{repo}/releases/{version}" + r = requests.get(url) # github api + if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded + r = requests.get(url) # try again + if r.status_code != 200: + LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}") + return "", [] + data = r.json() + return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolo11n.pt', 'yolov8s.pt', ...] + + +def attempt_download_asset(file, repo="ultralytics/assets", release="v8.3.0", **kwargs): + """ + Attempt to download a file from GitHub release assets if it is not found locally. + + Args: + file (str | Path): The filename or file path to be downloaded. + repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'. + release (str, optional): The specific release version to be downloaded. Defaults to 'v8.3.0'. + **kwargs (Any): Additional keyword arguments for the download process. + + Returns: + (str): The path to the downloaded file. + + Examples: + >>> file_path = attempt_download_asset("yolo11n.pt", repo="ultralytics/assets", release="latest") + """ + from ultralytics.utils import SETTINGS # scoped for circular import + + # YOLOv3/5u updates + file = str(file) + file = checks.check_yolov5u_filename(file) + file = Path(file.strip().replace("'", "")) + if file.exists(): + return str(file) + elif (SETTINGS["weights_dir"] / file).exists(): + return str(SETTINGS["weights_dir"] / file) + else: + # URL specified + name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc. + download_url = f"https://github.com/{repo}/releases/download" + if str(file).startswith(("http:/", "https:/")): # download + url = str(file).replace(":/", "://") # Pathlib turns :// -> :/ + file = url2file(name) # parse authentication https://url.com/file.txt?auth... 
+ if Path(file).is_file(): + LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists + else: + safe_download(url=url, file=file, min_bytes=1e5, **kwargs) + + elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES: + safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs) + + else: + tag, assets = get_github_assets(repo, release) + if not assets: + tag, assets = get_github_assets(repo) # latest release + if name in assets: + safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs) + + return str(file) + + +def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False): + """ + Downloads files from specified URLs to a given directory. Supports concurrent downloads if multiple threads are + specified. + + Args: + url (str | List[str]): The URL or list of URLs of the files to be downloaded. + dir (Path, optional): The directory where the files will be saved. Defaults to the current working directory. + unzip (bool, optional): Flag to unzip the files after downloading. Defaults to True. + delete (bool, optional): Flag to delete the zip files after extraction. Defaults to False. + curl (bool, optional): Flag to use curl for downloading. Defaults to False. + threads (int, optional): Number of threads to use for concurrent downloads. Defaults to 1. + retry (int, optional): Number of retries in case of download failure. Defaults to 3. + exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False. 
+ + Examples: + >>> download("https://ultralytics.com/assets/example.zip", dir="path/to/dir", unzip=True) + """ + dir = Path(dir) + dir.mkdir(parents=True, exist_ok=True) # make directory + if threads > 1: + with ThreadPool(threads) as pool: + pool.map( + lambda x: safe_download( + url=x[0], + dir=x[1], + unzip=unzip, + delete=delete, + curl=curl, + retry=retry, + exist_ok=exist_ok, + progress=threads <= 1, + ), + zip(url, repeat(dir)), + ) + pool.close() + pool.join() + else: + for u in [url] if isinstance(url, (str, Path)) else url: + safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok) diff --git a/tracking/ultralytics/utils/errors.py b/tracking/ultralytics/utils/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..8cb7aae13f1acc7cfae790ac8aa2246f50485e5e --- /dev/null +++ b/tracking/ultralytics/utils/errors.py @@ -0,0 +1,22 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.utils import emojis + + +class HUBModelError(Exception): + """ + Custom exception class for handling errors related to model fetching in Ultralytics YOLO. + + This exception is raised when a requested model is not found or cannot be retrieved. + The message is also processed to include emojis for better user experience. + + Attributes: + message (str): The error message displayed when the exception is raised. + + Note: + The message is automatically processed through the 'emojis' function from the 'ultralytics.utils' package. + """ + + def __init__(self, message="Model not found. 
Please check model URL and try again."): + """Create an exception for when a model is not found.""" + super().__init__(emojis(message)) diff --git a/tracking/ultralytics/utils/files.py b/tracking/ultralytics/utils/files.py new file mode 100644 index 0000000000000000000000000000000000000000..495bf71d2016ddb43333654b96da7bbb20e26b8a --- /dev/null +++ b/tracking/ultralytics/utils/files.py @@ -0,0 +1,221 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import contextlib +import glob +import os +import shutil +import tempfile +from contextlib import contextmanager +from datetime import datetime +from pathlib import Path + + +class WorkingDirectory(contextlib.ContextDecorator): + """ + A context manager and decorator for temporarily changing the working directory. + + This class allows for the temporary change of the working directory using a context manager or decorator. + It ensures that the original working directory is restored after the context or decorated function completes. + + Attributes: + dir (Path | str): The new directory to switch to. + cwd (Path): The original current working directory before the switch. + + Methods: + __enter__: Changes the current directory to the specified directory. + __exit__: Restores the original working directory on context exit. 

    Examples:
        Using as a context manager:
        >>> with WorkingDirectory('/path/to/new/dir'):
        >>>     # Perform operations in the new directory
        >>>     pass

        Using as a decorator:
        >>> @WorkingDirectory('/path/to/new/dir')
        >>> def some_function():
        >>>     # Perform operations in the new directory
        >>>     pass
    """

    def __init__(self, new_dir):
        """Sets the working directory to 'new_dir' upon instantiation for use with context managers or decorators."""
        self.dir = new_dir  # new dir
        self.cwd = Path.cwd().resolve()  # current dir, captured now so it can be restored on exit

    def __enter__(self):
        """Changes the current working directory to the specified directory upon entering the context."""
        os.chdir(self.dir)

    def __exit__(self, exc_type, exc_val, exc_tb):  # noqa
        """Restores the original working directory when exiting the context."""
        os.chdir(self.cwd)


@contextmanager
def spaces_in_path(path):
    """
    Context manager to handle paths with spaces in their names.

    If a path contains spaces, it replaces them with underscores, copies the file/directory to the new path, executes
    the context code block, then copies the file/directory back to its original location.

    Args:
        path (str | Path): The original path that may contain spaces.

    Yields:
        (Path | str): Temporary path with spaces replaced by underscores if spaces were present, otherwise the original path.

    Examples:
        >>> with spaces_in_path('/path/with spaces') as new_path:
        >>>     # Your code here
        >>>     pass
    """
    # If path has spaces, replace them with underscores
    if " " in str(path):
        string = isinstance(path, str)  # input type, so the yielded value matches (str in -> str out)
        path = Path(path)

        # Create a temporary directory and construct the new path
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = Path(tmp_dir) / path.name.replace(" ", "_")

            # Copy file/directory
            if path.is_dir():
                # tmp_path.mkdir(parents=True, exist_ok=True)
                shutil.copytree(path, tmp_path)
            elif path.is_file():
                tmp_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(path, tmp_path)

            try:
                # Yield the temporary path
                yield str(tmp_path) if string else tmp_path

            finally:
                # Copy file/directory back so changes made under the temp path are preserved
                if tmp_path.is_dir():
                    shutil.copytree(tmp_path, path, dirs_exist_ok=True)
                elif tmp_path.is_file():
                    shutil.copy2(tmp_path, path)  # Copy back the file

    else:
        # If there are no spaces, just yield the original path
        yield path


def increment_path(path, exist_ok=False, sep="", mkdir=False):
    """
    Increment a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.

    If the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to
    the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
    number will be appended directly to the end of the path.

    Args:
        path (str | Path): Path to increment.
        exist_ok (bool): If True, the path will not be incremented and returned as-is.
        sep (str): Separator to use between the path and the incrementation number.
        mkdir (bool): Create a directory if it does not exist.

    Returns:
        (Path): Incremented path.
+ + Examples: + Increment a directory path: + >>> from pathlib import Path + >>> path = Path("runs/exp") + >>> new_path = increment_path(path) + >>> print(new_path) + runs/exp2 + + Increment a file path: + >>> path = Path("runs/exp/results.txt") + >>> new_path = increment_path(path) + >>> print(new_path) + runs/exp/results2.txt + """ + path = Path(path) # os-agnostic + if path.exists() and not exist_ok: + path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "") + + # Method 1 + for n in range(2, 9999): + p = f"{path}{sep}{n}{suffix}" # increment path + if not os.path.exists(p): + break + path = Path(p) + + if mkdir: + path.mkdir(parents=True, exist_ok=True) # make directory + + return path + + +def file_age(path=__file__): + """Return days since the last modification of the specified file.""" + dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta + return dt.days # + dt.seconds / 86400 # fractional days + + +def file_date(path=__file__): + """Returns the file modification date in 'YYYY-M-D' format.""" + t = datetime.fromtimestamp(Path(path).stat().st_mtime) + return f"{t.year}-{t.month}-{t.day}" + + +def file_size(path): + """Returns the size of a file or directory in megabytes (MB).""" + if isinstance(path, (str, Path)): + mb = 1 << 20 # bytes to MiB (1024 ** 2) + path = Path(path) + if path.is_file(): + return path.stat().st_size / mb + elif path.is_dir(): + return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb + return 0.0 + + +def get_latest_run(search_dir="."): + """Returns the path to the most recent 'last.pt' file in the specified directory for resuming training.""" + last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True) + return max(last_list, key=os.path.getctime) if last_list else "" + + +def update_models(model_names=("yolo11n.pt",), source_dir=Path("."), update_names=False): + """ + Update and re-save specified YOLO models in an 'updated_models' subdirectory. 
from collections import abc
from itertools import repeat
from numbers import Number
from typing import List

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

try:
    from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
except ImportError:  # pragma: no cover - allows standalone use outside the package
    pass


def _ntuple(n):
    """From PyTorch internals: build a parser that repeats a scalar n times and passes iterables through."""

    def parse(x):
        """Return x unchanged if iterable, otherwise an n-tuple of x."""
        return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)
to_4tuple = _ntuple(4)

# `xyxy` means left top and right bottom
# `xywh` means center x, center y and width, height (YOLO format)
# `ltwh` means left top and width, height (COCO format)
_formats = ["xyxy", "xywh", "ltwh"]

__all__ = ("Bboxes", "Instances")  # tuple or list


class Bboxes:
    """
    A class for handling bounding boxes.

    Supports the 'xyxy', 'xywh', and 'ltwh' bounding box formats. Box data is stored as a
    2D numpy array of shape (N, 4).

    Attributes:
        bboxes (np.ndarray): The bounding boxes, shape (N, 4).
        format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh').

    Note:
        This class does not handle normalization or denormalization of bounding boxes.
    """

    def __init__(self, bboxes, format="xyxy") -> None:
        """
        Initialize the Bboxes class with bounding box data in a specified format.

        Args:
            bboxes (np.ndarray): Array of bounding boxes with shape (N, 4) or (4,).
            format (str): Format of the bounding boxes, one of 'xyxy', 'xywh', or 'ltwh'.
        """
        assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
        bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes  # promote a single box to (1, 4)
        assert bboxes.ndim == 2
        assert bboxes.shape[1] == 4
        self.bboxes = bboxes
        self.format = format

    def convert(self, format):
        """
        Convert bounding box format from one type to another (no-op if already in `format`).

        Args:
            format (str): Target format for conversion, one of 'xyxy', 'xywh', or 'ltwh'.
        """
        assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
        if self.format == format:
            return
        elif self.format == "xyxy":
            func = xyxy2xywh if format == "xywh" else xyxy2ltwh
        elif self.format == "xywh":
            func = xywh2xyxy if format == "xyxy" else xywh2ltwh
        else:
            func = ltwh2xyxy if format == "xyxy" else ltwh2xywh
        self.bboxes = func(self.bboxes)
        self.format = format

    def areas(self):
        """Return the area of each box, interpreting coordinates according to the current format."""
        return (
            (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])  # format xyxy
            if self.format == "xyxy"
            else self.bboxes[:, 3] * self.bboxes[:, 2]  # format xywh or ltwh
        )

    def mul(self, scale):
        """
        Multiply bounding box coordinates in place by scale factor(s).

        Args:
            scale (int | tuple | list): Scale factor(s) for the four coordinates.
                If a number, the same scale is applied to all coordinates.
        """
        if isinstance(scale, Number):
            scale = to_4tuple(scale)
        assert isinstance(scale, (tuple, list))
        assert len(scale) == 4
        self.bboxes[:, 0] *= scale[0]
        self.bboxes[:, 1] *= scale[1]
        self.bboxes[:, 2] *= scale[2]
        self.bboxes[:, 3] *= scale[3]

    def add(self, offset):
        """
        Add offset(s) to bounding box coordinates in place.

        Args:
            offset (int | tuple | list): Offset(s) for the four coordinates.
                If a number, the same offset is applied to all coordinates.
        """
        if isinstance(offset, Number):
            offset = to_4tuple(offset)
        assert isinstance(offset, (tuple, list))
        assert len(offset) == 4
        self.bboxes[:, 0] += offset[0]
        self.bboxes[:, 1] += offset[1]
        self.bboxes[:, 2] += offset[2]
        self.bboxes[:, 3] += offset[3]

    def __len__(self):
        """Return the number of boxes."""
        return len(self.bboxes)

    @classmethod
    def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
        """
        Concatenate a list of Bboxes objects into a single Bboxes object.

        Args:
            boxes_list (List[Bboxes]): A list of Bboxes objects to concatenate.
            axis (int, optional): The axis along which to concatenate the bounding boxes.

        Returns:
            (Bboxes): A new Bboxes object containing the concatenated bounding boxes.

        Note:
            The input should be a list or tuple of Bboxes objects.
        """
        assert isinstance(boxes_list, (list, tuple))
        if not boxes_list:
            # FIX: np.empty(0) has shape (0,), which the constructor promotes to (1, 0)
            # and then fails the shape[1] == 4 assert; an explicit (0, 4) array is valid.
            return cls(np.empty((0, 4)))
        assert all(isinstance(box, Bboxes) for box in boxes_list)

        if len(boxes_list) == 1:
            return boxes_list[0]
        return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))

    def __getitem__(self, index) -> "Bboxes":
        """
        Retrieve a specific bounding box or a set of bounding boxes using indexing.

        Args:
            index (int | slice | np.ndarray): The index, slice, or boolean array to select
                the desired bounding boxes.

        Returns:
            (Bboxes): A new Bboxes object containing the selected bounding boxes.

        Raises:
            AssertionError: If the indexed bounding boxes do not form a 2-dimensional matrix.

        Note:
            When using boolean indexing, make sure to provide a boolean array with the same
            length as the number of bounding boxes.
        """
        if isinstance(index, int):
            return Bboxes(self.bboxes[index].reshape(1, -1))
        b = self.bboxes[index]
        assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
        return Bboxes(b)


class Instances:
    """
    Container for bounding boxes, segments, and keypoints of detected objects in an image.

    Attributes:
        _bboxes (Bboxes): Internal object for handling bounding box operations.
        keypoints (np.ndarray): Keypoints with shape (N, 17, 3) in format (x, y, visible), or None.
        normalized (bool): Flag indicating whether the coordinates are normalized.
        segments (np.ndarray): Segments array with shape (N, M, 2) after resampling.

    Examples:
        >>> instances = Instances(
        ...     bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]),
        ...     segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])],
        ...     keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]]),
        ... )
    """

    def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
        """
        Initialize the object with bounding boxes, segments, and keypoints.

        Args:
            bboxes (np.ndarray): Bounding boxes, shape (N, 4).
            segments (list | np.ndarray, optional): Segmentation masks. A list of (Mi, 2)
                arrays is resampled and stacked to (N, M, 2); an ndarray is stored as-is;
                None becomes an empty (0, 1000, 2) array.
            keypoints (np.ndarray, optional): Keypoints, shape (N, 17, 3) as (x, y, visible).
            bbox_format (str, optional): Format of bboxes ('xyxy', 'xywh', or 'ltwh').
            normalized (bool, optional): Whether the coordinates are normalized.
        """
        self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
        self.keypoints = keypoints
        self.normalized = normalized
        # FIX: the previous code stored `segments` verbatim, so the default None crashed
        # nearly every method (scale, flipud, __getitem__, ...) with a TypeError. Normalize
        # to an array as the class docstring promises; ndarrays pass through unchanged so
        # internal callers (__getitem__, concatenate) keep their exact previous behavior.
        if segments is None:
            segments = np.zeros((0, 1000, 2), dtype=np.float32)
        elif isinstance(segments, list):
            segments = (
                np.stack(resample_segments(segments), axis=0)
                if len(segments)
                else np.zeros((0, 1000, 2), dtype=np.float32)
            )
        self.segments = segments

    def convert_bbox(self, format):
        """
        Convert bounding box format.

        Args:
            format (str): Target format for conversion, one of 'xyxy', 'xywh', or 'ltwh'.
        """
        self._bboxes.convert(format=format)

    @property
    def bbox_areas(self):
        """Calculate the area of bounding boxes."""
        return self._bboxes.areas()

    def scale(self, scale_w, scale_h, bbox_only=False):
        """
        Scale coordinates by given factors.

        Args:
            scale_w (float): Scale factor for width.
            scale_h (float): Scale factor for height.
            bbox_only (bool, optional): Whether to scale only bounding boxes.
        """
        self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
        if bbox_only:
            return
        self.segments[..., 0] *= scale_w
        self.segments[..., 1] *= scale_h
        if self.keypoints is not None:
            self.keypoints[..., 0] *= scale_w
            self.keypoints[..., 1] *= scale_h

    def denormalize(self, w, h):
        """
        Convert normalized coordinates to absolute coordinates (no-op if already absolute).

        Args:
            w (int): Image width.
            h (int): Image height.
        """
        if not self.normalized:
            return
        self._bboxes.mul(scale=(w, h, w, h))
        self.segments[..., 0] *= w
        self.segments[..., 1] *= h
        if self.keypoints is not None:
            self.keypoints[..., 0] *= w
            self.keypoints[..., 1] *= h
        self.normalized = False

    def normalize(self, w, h):
        """
        Convert absolute coordinates to normalized coordinates (no-op if already normalized).

        Args:
            w (int): Image width.
            h (int): Image height.
        """
        if self.normalized:
            return
        self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
        self.segments[..., 0] /= w
        self.segments[..., 1] /= h
        if self.keypoints is not None:
            self.keypoints[..., 0] /= w
            self.keypoints[..., 1] /= h
        self.normalized = True

    def add_padding(self, padw, padh):
        """
        Add padding to coordinates; requires absolute (denormalized) coordinates.

        Args:
            padw (int): Padding width.
            padh (int): Padding height.
        """
        assert not self.normalized, "you should add padding with absolute coordinates."
        self._bboxes.add(offset=(padw, padh, padw, padh))
        self.segments[..., 0] += padw
        self.segments[..., 1] += padh
        if self.keypoints is not None:
            self.keypoints[..., 0] += padw
            self.keypoints[..., 1] += padh

    def __getitem__(self, index) -> "Instances":
        """
        Retrieve a specific instance or a set of instances using indexing.

        Args:
            index (int | slice | np.ndarray): The index, slice, or boolean array to select the desired instances.

        Returns:
            (Instances): A new Instances object containing the selected boxes, segments, and keypoints if present.

        Note:
            When using boolean indexing, make sure to provide a boolean array with the same
            length as the number of instances.
        """
        segments = self.segments[index] if len(self.segments) else self.segments
        keypoints = self.keypoints[index] if self.keypoints is not None else None
        bboxes = self.bboxes[index]
        bbox_format = self._bboxes.format
        return Instances(
            bboxes=bboxes,
            segments=segments,
            keypoints=keypoints,
            bbox_format=bbox_format,
            normalized=self.normalized,
        )

    def flipud(self, h):
        """
        Flip coordinates vertically.

        Args:
            h (int): Image height.
        """
        if self._bboxes.format == "xyxy":
            # y1/y2 swap roles after the flip, so copy before overwriting
            y1 = self.bboxes[:, 1].copy()
            y2 = self.bboxes[:, 3].copy()
            self.bboxes[:, 1] = h - y2
            self.bboxes[:, 3] = h - y1
        else:
            self.bboxes[:, 1] = h - self.bboxes[:, 1]
        self.segments[..., 1] = h - self.segments[..., 1]
        if self.keypoints is not None:
            self.keypoints[..., 1] = h - self.keypoints[..., 1]

    def fliplr(self, w):
        """
        Flip coordinates horizontally.

        Args:
            w (int): Image width.
        """
        if self._bboxes.format == "xyxy":
            # x1/x2 swap roles after the flip, so copy before overwriting
            x1 = self.bboxes[:, 0].copy()
            x2 = self.bboxes[:, 2].copy()
            self.bboxes[:, 0] = w - x2
            self.bboxes[:, 2] = w - x1
        else:
            self.bboxes[:, 0] = w - self.bboxes[:, 0]
        self.segments[..., 0] = w - self.segments[..., 0]
        if self.keypoints is not None:
            self.keypoints[..., 0] = w - self.keypoints[..., 0]

    def clip(self, w, h):
        """
        Clip coordinates to stay within image boundaries.

        Args:
            w (int): Image width.
            h (int): Image height.
        """
        ori_format = self._bboxes.format
        self.convert_bbox(format="xyxy")  # clipping is only meaningful on corner coordinates
        self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
        self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
        if ori_format != "xyxy":
            self.convert_bbox(format=ori_format)
        self.segments[..., 0] = self.segments[..., 0].clip(0, w)
        self.segments[..., 1] = self.segments[..., 1].clip(0, h)
        if self.keypoints is not None:
            self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
            self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)

    def remove_zero_area_boxes(self):
        """
        Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height.

        Returns:
            (np.ndarray): Boolean array indicating which boxes were kept.
        """
        good = self.bbox_areas > 0
        if not all(good):
            self._bboxes = self._bboxes[good]
            if len(self.segments):
                self.segments = self.segments[good]
            if self.keypoints is not None:
                self.keypoints = self.keypoints[good]
        return good

    def update(self, bboxes, segments=None, keypoints=None):
        """
        Update instance variables; omitted arguments leave the current values in place.

        Args:
            bboxes (np.ndarray): New bounding boxes.
            segments (np.ndarray, optional): New segments.
            keypoints (np.ndarray, optional): New keypoints.
        """
        self._bboxes = Bboxes(bboxes, format=self._bboxes.format)
        if segments is not None:
            self.segments = segments
        if keypoints is not None:
            self.keypoints = keypoints

    def __len__(self):
        """Return the length of the instance list."""
        return len(self.bboxes)

    @classmethod
    def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
        """
        Concatenate a list of Instances objects into a single Instances object.

        Args:
            instances_list (List[Instances]): A list of Instances objects to concatenate.
            axis (int, optional): The axis along which the arrays will be concatenated.

        Returns:
            (Instances): A new Instances object containing the concatenated bounding boxes,
                segments, and keypoints if present.

        Note:
            The `Instances` objects in the list should have the same properties, such as
            the format of the bounding boxes, whether keypoints are present, and if the
            coordinates are normalized.
        """
        assert isinstance(instances_list, (list, tuple))
        if not instances_list:
            # FIX: np.empty(0) fails the (N, 4) shape assert in Bboxes.__init__;
            # an explicit (0, 4) array yields a valid empty Instances object.
            return cls(np.empty((0, 4)))
        assert all(isinstance(instance, Instances) for instance in instances_list)

        if len(instances_list) == 1:
            return instances_list[0]

        use_keypoint = instances_list[0].keypoints is not None
        bbox_format = instances_list[0]._bboxes.format
        normalized = instances_list[0].normalized

        cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
        seg_len = [b.segments.shape[1] for b in instances_list]
        if len(frozenset(seg_len)) > 1:  # resample segments if there's different length
            max_len = max(seg_len)
            cat_segments = np.concatenate(
                [
                    resample_segments(list(b.segments), max_len)
                    if len(b.segments)
                    else np.zeros((0, max_len, 2), dtype=np.float32)  # re-generating empty segments
                    for b in instances_list
                ],
                axis=axis,
            )
        else:
            cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
        cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
        return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)

    @property
    def bboxes(self):
        """Return bounding boxes."""
        return self._bboxes.bboxes


class VarifocalLoss(nn.Module):
    """
    Varifocal loss by Zhang et al.

    https://arxiv.org/abs/2008.13367.
    """

    def __init__(self):
        """Initialize the VarifocalLoss class."""
        super().__init__()

    @staticmethod
    def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
        """Compute varifocal loss between predictions and ground truth."""
        # Negative samples are down-weighted by the predicted score; positives keep the gt score
        weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
        with autocast(enabled=False):  # BCE in full precision for numerical stability
            loss = (
                (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
                .mean(1)
                .sum()
            )
        return loss
class DFLoss(nn.Module):
    """Distribution Focal Loss (DFL), https://ieeexplore.ieee.org/document/9792391."""

    def __init__(self, reg_max=16) -> None:
        """Store the number of distribution bins (reg_max)."""
        super().__init__()
        self.reg_max = reg_max

    def __call__(self, pred_dist, target):
        """
        Compute the DFL term for each target coordinate.

        Each continuous target is spread over its two neighbouring integer bins, weighted
        by proximity, and scored with cross-entropy against the predicted per-bin logits.
        """
        # Keep targets strictly inside the representable bin range [0, reg_max - 1)
        target = target.clamp_(0, self.reg_max - 1 - 0.01)
        lo = target.long()  # left (floor) bin index
        hi = lo + 1  # right bin index
        w_lo = hi - target  # weight toward the left bin
        w_hi = 1 - w_lo  # weight toward the right bin
        ce_lo = F.cross_entropy(pred_dist, lo.view(-1), reduction="none").view(lo.shape)
        ce_hi = F.cross_entropy(pred_dist, hi.view(-1), reduction="none").view(lo.shape)
        return (ce_lo * w_lo + ce_hi * w_hi).mean(-1, keepdim=True)


class BboxLoss(nn.Module):
    """Bounding-box criterion combining a weighted IoU loss with optional DFL."""

    def __init__(self, reg_max=16):
        """Create the IoU/DFL criterion; DFL is disabled when reg_max <= 1."""
        super().__init__()
        self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """Return (iou_loss, dfl_loss) computed over the foreground anchors only."""
        # Per-anchor weight: summed assigned class scores of the positive anchors
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

        if self.dfl_loss is None:
            return loss_iou, torch.tensor(0.0).to(pred_dist.device)

        target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
        loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
        return loss_iou, loss_dfl.sum() / target_scores_sum


class RotatedBboxLoss(BboxLoss):
    """Bounding-box criterion for rotated boxes, using probabilistic IoU."""

    def __init__(self, reg_max):
        """Delegate to BboxLoss with the same bin count."""
        super().__init__(reg_max)

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """Return (iou_loss, dfl_loss) for rotated boxes over the foreground anchors."""
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

        if self.dfl_loss is None:
            return loss_iou, torch.tensor(0.0).to(pred_dist.device)

        # Only the first 4 channels (xywh) participate in DFL; the angle channel is excluded
        target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
        loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
        return loss_iou, loss_dfl.sum() / target_scores_sum
def forward(self, pred_kpts, gt_kpts, kpt_mask, area): + """Calculate keypoint loss factor and Euclidean distance loss for keypoints.""" + d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2) + kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9) + # e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula + e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2) # from cocoeval + return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean() + + +class v8DetectionLoss: + """Criterion class for computing training losses for YOLOv8 object detection.""" + + def __init__(self, model, tal_topk=10): # model must be de-paralleled + """Initialize v8DetectionLoss with model parameters and task-aligned assignment settings.""" + device = next(model.parameters()).device # get model device + h = model.args # hyperparameters + + m = model.model[-1] # Detect() module + self.bce = nn.BCEWithLogitsLoss(reduction="none") + self.hyp = h + self.stride = m.stride # model strides + self.nc = m.nc # number of classes + self.no = m.nc + m.reg_max * 4 + self.reg_max = m.reg_max + self.device = device + + self.use_dfl = m.reg_max > 1 + + self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0) + self.bbox_loss = BboxLoss(m.reg_max).to(device) + self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device) + + def preprocess(self, targets, batch_size, scale_tensor): + """Preprocess targets by converting to tensor format and scaling coordinates.""" + nl, ne = targets.shape + if nl == 0: + out = torch.zeros(batch_size, 0, ne - 1, device=self.device) + else: + i = targets[:, 0] # image index + _, counts = i.unique(return_counts=True) + counts = counts.to(dtype=torch.int32) + out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device) + for j in range(batch_size): + matches = i == j + if n := matches.sum(): + out[j, :n] = targets[matches, 1:] + 
out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor)) + return out + + def bbox_decode(self, anchor_points, pred_dist): + """Decode predicted object bounding box coordinates from anchor points and distribution.""" + if self.use_dfl: + b, a, c = pred_dist.shape # batch, anchors, channels + pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype)) + # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype)) + # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2) + return dist2bbox(pred_dist, anchor_points, xywh=False) + + def __call__(self, preds, batch): + """Calculate the sum of the loss for box, cls and dfl multiplied by batch size.""" + loss = torch.zeros(3, device=self.device) # box, cls, dfl + feats = preds[1] if isinstance(preds, tuple) else preds + pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( + (self.reg_max * 4, self.nc), 1 + ) + + pred_scores = pred_scores.permute(0, 2, 1).contiguous() + pred_distri = pred_distri.permute(0, 2, 1).contiguous() + + dtype = pred_scores.dtype + batch_size = pred_scores.shape[0] + imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w) + anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) + + # Targets + targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) + targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) + gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy + mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0) + + # Pboxes + pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) + # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1) + # dfl_conf = (dfl_conf.amax(-1).mean(-1) + 
dfl_conf.amax(-1).amin(-1)) / 2 + + _, target_bboxes, target_scores, fg_mask, _ = self.assigner( + # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2, + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), + anchor_points * stride_tensor, + gt_labels, + gt_bboxes, + mask_gt, + ) + + target_scores_sum = max(target_scores.sum(), 1) + + # Cls loss + # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way + loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE + + # Bbox loss + if fg_mask.sum(): + target_bboxes /= stride_tensor + loss[0], loss[2] = self.bbox_loss( + pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask + ) + + loss[0] *= self.hyp.box # box gain + loss[1] *= self.hyp.cls # cls gain + loss[2] *= self.hyp.dfl # dfl gain + + return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) + + +class v8SegmentationLoss(v8DetectionLoss): + """Criterion class for computing training losses for YOLOv8 segmentation.""" + + def __init__(self, model): # model must be de-paralleled + """Initialize the v8SegmentationLoss class with model parameters and mask overlap setting.""" + super().__init__(model) + self.overlap = model.args.overlap_mask + + def __call__(self, preds, batch): + """Calculate and return the combined loss for detection and segmentation.""" + loss = torch.zeros(4, device=self.device) # box, cls, dfl + feats, pred_masks, proto = preds if len(preds) == 3 else preds[1] + batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width + pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( + (self.reg_max * 4, self.nc), 1 + ) + + # B, grids, .. 
+ pred_scores = pred_scores.permute(0, 2, 1).contiguous() + pred_distri = pred_distri.permute(0, 2, 1).contiguous() + pred_masks = pred_masks.permute(0, 2, 1).contiguous() + + dtype = pred_scores.dtype + imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w) + anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) + + # Targets + try: + batch_idx = batch["batch_idx"].view(-1, 1) + targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1) + targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) + gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy + mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0) + except RuntimeError as e: + raise TypeError( + "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n" + "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, " + "i.e. 'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\nVerify your dataset is a " + "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' " + "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help." 
        ) from e

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        # Assign ground truths to anchors; predictions are detached so the assigner does not backprop
        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)  # clamp to >=1 to avoid division by zero

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        if fg_mask.sum():
            # Bbox loss
            loss[0], loss[3] = self.bbox_loss(
                pred_distri,
                pred_bboxes,
                anchor_points,
                target_bboxes / stride_tensor,
                target_scores,
                target_scores_sum,
                fg_mask,
            )
            # Masks loss
            masks = batch["masks"].to(self.device).float()
            if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample GT masks to prototype resolution
                masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]

            loss[1] = self.calculate_segmentation_loss(
                fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
            )

        # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
        else:
            loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.box  # seg gain
        loss[2] *= self.hyp.cls  # cls gain
        loss[3] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, seg, cls, dfl)

    @staticmethod
    def single_mask_loss(
        gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the instance segmentation loss for a single image.

        Args:
            gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects.
            pred (torch.Tensor): Predicted mask coefficients of shape (n, 32).
            proto (torch.Tensor): Prototype masks of shape (32, H, W).
            xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4).
            area (torch.Tensor): Area of each ground truth bounding box of shape (n,).

        Returns:
            (torch.Tensor): The calculated mask loss for a single image.

        Notes:
            The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
            predicted masks from the prototype masks and predicted mask coefficients.
        """
        pred_mask = torch.einsum("in,nhw->ihw", pred, proto)  # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
        # Crop BCE map to each GT box, average inside the box, normalize by box area
        return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()

    def calculate_segmentation_loss(
        self,
        fg_mask: torch.Tensor,
        masks: torch.Tensor,
        target_gt_idx: torch.Tensor,
        target_bboxes: torch.Tensor,
        batch_idx: torch.Tensor,
        proto: torch.Tensor,
        pred_masks: torch.Tensor,
        imgsz: torch.Tensor,
        overlap: bool,
    ) -> torch.Tensor:
        """
        Calculate the loss for instance segmentation.

        Args:
            fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
            masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W).
            target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
            target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
            batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
            proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
            pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
            imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
            overlap (bool): Whether the masks in `masks` tensor overlap.

        Returns:
            (torch.Tensor): The calculated loss for instance segmentation.

        Notes:
            The batch loss can be computed for improved speed at higher memory usage.
            For example, pred_mask can be computed as follows:
                pred_mask = torch.einsum('in,nhw->ihw', pred, proto)  # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
        """
        _, _, mask_h, mask_w = proto.shape
        loss = 0

        # Normalize to 0-1 (imgsz is (h, w), so index [1, 0, 1, 0] gives (w, h, w, h))
        target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]

        # Areas of target bboxes
        marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)

        # Normalize to mask size
        mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)

        for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)):
            fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
            if fg_mask_i.any():
                mask_idx = target_gt_idx_i[fg_mask_i]
                if overlap:
                    # In overlap mode each instance is encoded as pixel value (index + 1) in a single mask
                    gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
                    gt_mask = gt_mask.float()
                else:
                    gt_mask = masks[batch_idx.view(-1) == i][mask_idx]

                loss += self.single_mask_loss(
                    gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
                )

            # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
            else:
                loss += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

        return loss / fg_mask.sum()


class v8PoseLoss(v8DetectionLoss):
    """Criterion class for computing training losses for YOLOv8 pose estimation."""

    def __init__(self, model):  # model must be de-paralleled
        """Initialize v8PoseLoss with model parameters and keypoint-specific loss functions."""
        super().__init__(model)
        self.kpt_shape = model.model[-1].kpt_shape
        self.bce_pose = nn.BCEWithLogitsLoss()
        is_pose = self.kpt_shape == [17, 3]  # COCO-style 17 keypoints with visibility
        nkpt = self.kpt_shape[0]  # number of keypoints
        # COCO OKS sigmas for the 17-keypoint layout, otherwise a uniform weighting
        sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
        self.keypoint_loss = KeypointLoss(sigmas=sigmas)

    def __call__(self, preds, batch):
        """Calculate the total loss and detach it for pose estimation."""
        loss = torch.zeros(5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
        feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # B, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        batch_size = pred_scores.shape[0]
        batch_idx = batch["batch_idx"].view(-1, 1)
        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)  # in-place gt_: non-zero rows mark real labels

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
        pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)  # clamp to >=1 to avoid division by zero

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[4] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
            # Scale normalized keypoint coordinates up to pixel space (imgsz is (h, w))
            keypoints = batch["keypoints"].to(self.device).float().clone()
            keypoints[..., 0] *= imgsz[1]
            keypoints[..., 1] *= imgsz[0]

            loss[1], loss[2] = self.calculate_keypoints_loss(
                fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
            )

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.pose  # pose gain
        loss[2] *= self.hyp.kobj  # kobj gain
        loss[3] *= self.hyp.cls  # cls gain
        loss[4] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, pose, kobj, cls, dfl)

    @staticmethod
    def kpts_decode(anchor_points, pred_kpts):
        """Decode predicted keypoints to image coordinates."""
        y = pred_kpts.clone()
        y[..., :2] *= 2.0
        y[..., 0] += anchor_points[:, [0]] - 0.5
        y[..., 1] += anchor_points[:, [1]] - 0.5
        return y

    def calculate_keypoints_loss(
        self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
    ):
        """
        Calculate the keypoints loss for the model.

        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
        a binary classification loss that classifies whether a keypoint is present or not.

        Args:
            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
            keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
            batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).

        Returns:
            kpts_loss (torch.Tensor): The keypoints loss.
            kpts_obj_loss (torch.Tensor): The keypoints object loss.
        """
        batch_idx = batch_idx.flatten()
        batch_size = len(masks)

        # Find the maximum number of keypoints in a single image
        max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()

        # Create a tensor to hold batched keypoints
        batched_keypoints = torch.zeros(
            (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
        )

        # TODO: any idea how to vectorize this?
        # Fill batched_keypoints with keypoints based on batch_idx
        for i in range(batch_size):
            keypoints_i = keypoints[batch_idx == i]
            batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i

        # Expand dimensions of target_gt_idx to match the shape of batched_keypoints
        target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)

        # Use target_gt_idx_expanded to select keypoints from batched_keypoints
        selected_keypoints = batched_keypoints.gather(
            1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
        )

        # Divide coordinates by stride
        selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)

        kpts_loss = 0
        kpts_obj_loss = 0

        if masks.any():
            gt_kpt = selected_keypoints[masks]
            area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
            pred_kpt = pred_kpts[masks]
            # Visibility flag (3rd channel) marks valid keypoints; 2-dim keypoints are all treated as visible
            kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
            kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss

            if pred_kpt.shape[-1] == 3:
                kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss

        return kpts_loss,
kpts_obj_loss


class v8ClassificationLoss:
    """Criterion class for computing training losses for classification."""

    def __call__(self, preds, batch):
        """Compute the classification loss between predictions and true labels."""
        preds = preds[1] if isinstance(preds, (list, tuple)) else preds
        loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
        loss_items = loss.detach()
        return loss, loss_items


class v8OBBLoss(v8DetectionLoss):
    """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""

    def __init__(self, model):
        """Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
        super().__init__(model)
        self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)

    def preprocess(self, targets, batch_size, scale_tensor):
        """Preprocess targets for oriented bounding box detection."""
        if targets.shape[0] == 0:
            out = torch.zeros(batch_size, 0, 6, device=self.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            # Pad each image's targets to the max count; zero rows act as padding (6 = cls + xywhr)
            out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
            for j in range(batch_size):
                matches = i == j
                if n := matches.sum():
                    bboxes = targets[matches, 2:]
                    bboxes[..., :4].mul_(scale_tensor)  # in-place scale of xywh; angle is left untouched
                    out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
        return out

    def __call__(self, preds, batch):
        """Calculate and return the loss for oriented bounding box detection."""
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
        batch_size = pred_angle.shape[0]  # batch size, number of masks, mask height, mask width
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # b, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_angle = pred_angle.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # targets
        try:
            batch_idx = batch["batch_idx"].view(-1, 1)
            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
            rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
            targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
        except RuntimeError as e:
            raise TypeError(
                "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
                "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
                "i.e. 'yolo train model=yolo11n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
                "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
                "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
            ) from e

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # xyxy, (b, h*w, 4)

        bboxes_for_assigner = pred_bboxes.clone().detach()
        # Only the first four elements need to be scaled
        bboxes_for_assigner[..., :4] *= stride_tensor
        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            bboxes_for_assigner.type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)  # clamp to >=1 to avoid division by zero

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes[..., :4] /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
        else:
            # Keep pred_angle in the graph on empty batches to avoid DDP unused-gradient errors
            loss[0] += (pred_angle * 0).sum()

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.cls  # cls gain
        loss[2] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)

    def bbox_decode(self, anchor_points, pred_dist, pred_angle):
        """
        Decode predicted object bounding box coordinates from anchor points and distribution.

        Args:
            anchor_points (torch.Tensor): Anchor points, (h*w, 2).
            pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
            pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).

        Returns:
            (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
        """
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
        return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)


class E2EDetectLoss:
    """Criterion class for computing training losses for end-to-end detection."""

    def __init__(self, model):
        """Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model."""
        self.one2many = v8DetectionLoss(model, tal_topk=10)
        self.one2one = v8DetectionLoss(model, tal_topk=1)

    def __call__(self, preds, batch):
        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
        preds = preds[1] if isinstance(preds, tuple) else preds
        one2many = preds["one2many"]
        loss_one2many = self.one2many(one2many, batch)
        one2one = preds["one2one"]
        loss_one2one = self.one2one(one2one, batch)
        return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]
diff --git a/tracking/ultralytics/utils/metrics.py b/tracking/ultralytics/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ad8b5a9519876eaa2094c7e45d362d06d145578
--- /dev/null
+++ b/tracking/ultralytics/utils/metrics.py
@@ -0,0 +1,1351 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Model validation metrics."""

import math
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from ultralytics.utils import LOGGER, SimpleClass, TryExcept, plt_settings

# COCO Object Keypoint Similarity per-keypoint sigmas (17 keypoints), divided by 10 per cocoeval convention
OKS_SIGMA = (
    np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])
    / 10.0
)


def bbox_ioa(box1, box2, iou=False, eps=1e-7):
    """
    Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.

    Args:
        box1 (np.ndarray): A numpy array of shape (n, 4) representing n bounding boxes.
        box2 (np.ndarray): A numpy array of shape (m, 4) representing m bounding boxes.
        iou (bool): Calculate the standard IoU if True else return inter_area/box2_area.
        eps (float, optional): A small value to avoid division by zero.

    Returns:
        (np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area.
    """
    # Get the coordinates of bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
    b2_x1, b2_y1, b2_x2, b2_y2 = box2.T

    # Intersection area (broadcast n x m via [:, None])
    inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * (
        np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)
    ).clip(0)

    # Box2 area
    area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
    if iou:
        box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
        area = area + box1_area[:, None] - inter_area

    # Intersection over box2 area
    return inter_area / (area + eps)


def box_iou(box1, box2, eps=1e-7):
    """
    Calculate intersection-over-union (IoU) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py.

    Args:
        box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
        box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
        eps (float, optional): A small value to avoid division by zero.

    Returns:
        (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
    """
    # NOTE: Need .float() to get accurate iou values
    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    (a1, a2), (b1, b2) = box1.float().unsqueeze(1).chunk(2, 2), box2.float().unsqueeze(0).chunk(2, 2)
    inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)

    # IoU = inter / (area1 + area2 - inter)
    return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)


def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    """
    Calculate the Intersection over Union (IoU) between bounding boxes.

    This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.
    For instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4).
    Internally, the code will split the last dimension into (x, y, w, h) if `xywh=True`,
    or (x1, y1, x2, y2) if `xywh=False`.

    Args:
        box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
        box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
        xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
            (x1, y1, x2, y2) format.
        GIoU (bool, optional): If True, calculate Generalized IoU.
        DIoU (bool, optional): If True, calculate Distance IoU.
        CIoU (bool, optional): If True, calculate Complete IoU.
        eps (float, optional): A small value to avoid division by zero.

    Returns:
        (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
    """
    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * (
        b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)
    ).clamp_(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    if CIoU or DIoU or GIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw.pow(2) + ch.pow(2) + eps  # convex diagonal squared
            rho2 = (
                (b2_x1 + b2_x2 - b1_x1 - b1_x2).pow(2) + (b2_y1 + b2_y2 - b1_y1 - b1_y2).pow(2)
            ) / 4  # center dist**2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    return iou  # IoU


def mask_iou(mask1, mask2, eps=1e-7):
    """
    Calculate masks IoU.

    Args:
        mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
            product of image width and height.
        mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
            product of image width and height.
        eps (float, optional): A small value to avoid division by zero.

    Returns:
        (torch.Tensor): A tensor of shape (N, M) representing masks IoU.
    """
    intersection = torch.matmul(mask1, mask2.T).clamp_(0)
    union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection  # (area1 + area2) - intersection
    return intersection / (union + eps)


def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
    """
    Calculate Object Keypoint Similarity (OKS).

    Args:
        kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.
        kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.
        area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth.
        sigma (list): A list containing 17 values representing keypoint scales.
        eps (float, optional): A small value to avoid division by zero.

    Returns:
        (torch.Tensor): A tensor of shape (N, M) representing keypoint similarities.
    """
    d = (kpt1[:, None, :, 0] - kpt2[..., 0]).pow(2) + (kpt1[:, None, :, 1] - kpt2[..., 1]).pow(2)  # (N, M, 17)
    sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype)  # (17, )
    kpt_mask = kpt1[..., 2] != 0  # (N, 17)
    e = d / ((2 * sigma).pow(2) * (area[:, None, None] + eps) * 2)  # from cocoeval
    # e = d / ((area[None, :, None] + eps) * sigma) ** 2 / 2  # from formula
    return ((-e).exp() * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)


def _get_covariance_matrix(boxes):
    """
    Generate covariance matrix from oriented bounding boxes.

    Args:
        boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.

    Returns:
        (torch.Tensor): Covariance matrices corresponding to original rotated bounding boxes.
    """
    # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.
    gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1)
    a, b, c = gbbs.split(1, dim=-1)
    cos = c.cos()
    sin = c.sin()
    cos2 = cos.pow(2)
    sin2 = sin.pow(2)
    return a * cos2 + b * sin2, a * sin2 + b * cos2, (a - b) * cos * sin


def probiou(obb1, obb2, CIoU=False, eps=1e-7):
    """
    Calculate probabilistic IoU between oriented bounding boxes.

    Args:
        obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
        obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr.
        CIoU (bool, optional): If True, calculate CIoU.
        eps (float, optional): Small value to avoid division by zero.

    Returns:
        (torch.Tensor): OBB similarities, shape (N,).

    Notes:
        - OBB format: [center_x, center_y, width, height, rotation_angle].
        - Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf.
    """
    x1, y1 = obb1[..., :2].split(1, dim=-1)
    x2, y2 = obb2[..., :2].split(1, dim=-1)
    a1, b1, c1 = _get_covariance_matrix(obb1)
    a2, b2, c2 = _get_covariance_matrix(obb2)

    t1 = (
        ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)
    ) * 0.25
    t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5
    t3 = (
        ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))
        / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)
        + eps
    ).log() * 0.5
    bd = (t1 + t2 + t3).clamp(eps, 100.0)  # Bhattacharyya distance, clamped for numerical stability
    hd = (1.0 - (-bd).exp() + eps).sqrt()  # Hellinger distance
    iou = 1 - hd
    if CIoU:  # only include the wh aspect ratio part
        w1, h1 = obb1[..., 2:4].split(1, dim=-1)
        w2, h2 = obb2[..., 2:4].split(1, dim=-1)
        v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)
        with torch.no_grad():
            alpha = v / (v - iou + (1 + eps))
        return iou - v * alpha  # CIoU
    return iou


def batch_probiou(obb1, obb2, eps=1e-7):
    """
    Calculate the probabilistic IoU between oriented bounding boxes.

    Args:
        obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
        obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
        eps (float, optional): A small value to avoid division by zero.

    Returns:
        (torch.Tensor): A tensor of shape (N, M) representing obb similarities.

    References:
        https://arxiv.org/pdf/2106.06072v1.pdf
    """
    obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1
    obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2

    x1, y1 = obb1[..., :2].split(1, dim=-1)
    x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1))
    a1, b1, c1 = _get_covariance_matrix(obb1)
    a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))

    t1 = (
        ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)
    ) * 0.25
    t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5
    t3 = (
        ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))
        / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)
        + eps
    ).log() * 0.5
    bd = (t1 + t2 + t3).clamp(eps, 100.0)
    hd = (1.0 - (-bd).exp() + eps).sqrt()
    return 1 - hd


def smooth_bce(eps=0.1):
    """
    Compute smoothed positive and negative Binary Cross-Entropy targets.

    Args:
        eps (float, optional): The epsilon value for label smoothing.

    Returns:
        (tuple): A tuple containing the positive and negative label smoothing BCE targets.

    References:
        https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
    """
    return 1.0 - 0.5 * eps, 0.5 * eps


class ConfusionMatrix:
    """
    A class for calculating and updating a confusion matrix for object detection and classification tasks.
+ + Attributes: + task (str): The type of task, either 'detect' or 'classify'. + matrix (np.ndarray): The confusion matrix, with dimensions depending on the task. + nc (int): The number of classes. + conf (float): The confidence threshold for detections. + iou_thres (float): The Intersection over Union threshold. + """ + + def __init__(self, nc, conf=0.25, iou_thres=0.45, task="detect"): + """ + Initialize a ConfusionMatrix instance. + + Args: + nc (int): Number of classes. + conf (float, optional): Confidence threshold for detections. + iou_thres (float, optional): IoU threshold for matching detections to ground truth. + task (str, optional): Type of task, either 'detect' or 'classify'. + """ + self.task = task + self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc)) + self.nc = nc # number of classes + self.conf = 0.25 if conf in {None, 0.001} else conf # apply 0.25 if default val conf is passed + self.iou_thres = iou_thres + + def process_cls_preds(self, preds, targets): + """ + Update confusion matrix for classification task. + + Args: + preds (Array[N, min(nc,5)]): Predicted class labels. + targets (Array[N, 1]): Ground truth class labels. + """ + preds, targets = torch.cat(preds)[:, 0], torch.cat(targets) + for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()): + self.matrix[p][t] += 1 + + def process_batch(self, detections, gt_bboxes, gt_cls): + """ + Update confusion matrix for object detection task. + + Args: + detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information. + Each row should contain (x1, y1, x2, y2, conf, class) + or with an additional element `angle` when it's obb. + gt_bboxes (Array[M, 4]| Array[N, 5]): Ground truth bounding boxes with xyxy/xyxyr format. + gt_cls (Array[M]): The class labels. 
+ """ + if gt_cls.shape[0] == 0: # Check if labels is empty + if detections is not None: + detections = detections[detections[:, 4] > self.conf] + detection_classes = detections[:, 5].int() + for dc in detection_classes: + self.matrix[dc, self.nc] += 1 # false positives + return + if detections is None: + gt_classes = gt_cls.int() + for gc in gt_classes: + self.matrix[self.nc, gc] += 1 # background FN + return + + detections = detections[detections[:, 4] > self.conf] + gt_classes = gt_cls.int() + detection_classes = detections[:, 5].int() + is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5 # with additional `angle` dimension + iou = ( + batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1)) + if is_obb + else box_iou(gt_bboxes, detections[:, :4]) + ) + + x = torch.where(iou > self.iou_thres) + if x[0].shape[0]: + matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() + if x[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + else: + matches = np.zeros((0, 3)) + + n = matches.shape[0] > 0 + m0, m1, _ = matches.transpose().astype(int) + for i, gc in enumerate(gt_classes): + j = m0 == i + if n and sum(j) == 1: + self.matrix[detection_classes[m1[j]], gc] += 1 # correct + else: + self.matrix[self.nc, gc] += 1 # true background + + for i, dc in enumerate(detection_classes): + if not any(m1 == i): + self.matrix[dc, self.nc] += 1 # predicted background + + def matrix(self): + """Return the confusion matrix.""" + return self.matrix + + def tp_fp(self): + """ + Return true positives and false positives. + + Returns: + (tuple): True positives and false positives. 
+ """ + tp = self.matrix.diagonal() # true positives + fp = self.matrix.sum(1) - tp # false positives + # fn = self.matrix.sum(0) - tp # false negatives (missed detections) + return (tp[:-1], fp[:-1]) if self.task == "detect" else (tp, fp) # remove background class if task=detect + + @TryExcept("WARNING ⚠️ ConfusionMatrix plot failure") + @plt_settings() + def plot(self, normalize=True, save_dir="", names=(), on_plot=None): + """ + Plot the confusion matrix using seaborn and save it to a file. + + Args: + normalize (bool): Whether to normalize the confusion matrix. + save_dir (str): Directory where the plot will be saved. + names (tuple): Names of classes, used as labels on the plot. + on_plot (func): An optional callback to pass plots path and data when they are rendered. + """ + import seaborn # scope for faster 'import ultralytics' + + array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns + array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) + + fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True) + nc, nn = self.nc, len(names) # number of classes, names + seaborn.set_theme(font_scale=1.0 if nc < 50 else 0.8) # for label size + labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels + ticklabels = (list(names) + ["background"]) if labels else "auto" + with warnings.catch_warnings(): + warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered + seaborn.heatmap( + array, + ax=ax, + annot=nc < 30, + annot_kws={"size": 8}, + cmap="Blues", + fmt=".2f" if normalize else ".0f", + square=True, + vmin=0.0, + xticklabels=ticklabels, + yticklabels=ticklabels, + ).set_facecolor((1, 1, 1)) + title = "Confusion Matrix" + " Normalized" * normalize + ax.set_xlabel("True") + ax.set_ylabel("Predicted") + ax.set_title(title) + plot_fname = Path(save_dir) / f"{title.lower().replace(' ', '_')}.png" + fig.savefig(plot_fname, dpi=250) + plt.close(fig) + 
        if on_plot:
            on_plot(plot_fname)  # notify caller (e.g. logger callbacks) that the plot file exists

    def print(self):
        """Print the confusion matrix to the console."""
        for i in range(self.matrix.shape[0]):
            LOGGER.info(" ".join(map(str, self.matrix[i])))


def smooth(y, f=0.05):
    """Box filter of fraction f."""
    nf = round(len(y) * f * 2) // 2 + 1  # number of filter elements (must be odd)
    p = np.ones(nf // 2)  # ones padding
    yp = np.concatenate((p * y[0], y, p * y[-1]), 0)  # y padded with edge values so output length equals input
    return np.convolve(yp, np.ones(nf) / nf, mode="valid")  # y-smoothed


@plt_settings()
def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=None):
    """
    Plot precision-recall curve.

    Args:
        px (np.ndarray): X values for the PR curve.
        py (np.ndarray): Y values for the PR curve.
        ap (np.ndarray): Average precision values.
        save_dir (Path, optional): Path to save the plot.
        names (dict, optional): Dictionary mapping class indices to class names.
        on_plot (callable, optional): Function to call after plot is saved.
    """
    # NOTE(review): mutable default `names={}` is read-only here, but consider `names=None` — confirm project style.
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    py = np.stack(py, axis=1)

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(py.T):
            ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}")  # plot(recall, precision)
    else:
        ax.plot(px, py, linewidth=1, color="grey")  # plot(recall, precision)

    ax.plot(px, py.mean(1), linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5")
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")  # legend outside the axes, to the right
    ax.set_title("Precision-Recall Curve")
    fig.savefig(save_dir, dpi=250)
    plt.close(fig)
    if on_plot:
        on_plot(save_dir)


@plt_settings()
def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confidence", ylabel="Metric", on_plot=None):
    """
    Plot metric-confidence curve.
+ + Args: + px (np.ndarray): X values for the metric-confidence curve. + py (np.ndarray): Y values for the metric-confidence curve. + save_dir (Path, optional): Path to save the plot. + names (dict, optional): Dictionary mapping class indices to class names. + xlabel (str, optional): X-axis label. + ylabel (str, optional): Y-axis label. + on_plot (callable, optional): Function to call after plot is saved. + """ + fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) + + if 0 < len(names) < 21: # display per-class legend if < 21 classes + for i, y in enumerate(py): + ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric) + else: + ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric) + + y = smooth(py.mean(0), 0.05) + ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}") + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left") + ax.set_title(f"{ylabel}-Confidence Curve") + fig.savefig(save_dir, dpi=250) + plt.close(fig) + if on_plot: + on_plot(save_dir) + + +def compute_ap(recall, precision): + """ + Compute the average precision (AP) given the recall and precision curves. + + Args: + recall (list): The recall curve. + precision (list): The precision curve. + + Returns: + (float): Average precision. + (np.ndarray): Precision envelope curve. + (np.ndarray): Modified recall curve with sentinel values added at the beginning and end. 
+ """ + # Append sentinel values to beginning and end + mrec = np.concatenate(([0.0], recall, [1.0])) + mpre = np.concatenate(([1.0], precision, [0.0])) + + # Compute the precision envelope + mpre = np.flip(np.maximum.accumulate(np.flip(mpre))) + + # Integrate area under curve + method = "interp" # methods: 'continuous', 'interp' + if method == "interp": + x = np.linspace(0, 1, 101) # 101-point interp (COCO) + ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate + else: # 'continuous' + i = np.where(mrec[1:] != mrec[:-1])[0] # points where x-axis (recall) changes + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve + + return ap, mpre, mrec + + +def ap_per_class( + tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names={}, eps=1e-16, prefix="" +): + """ + Compute the average precision per class for object detection evaluation. + + Args: + tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False). + conf (np.ndarray): Array of confidence scores of the detections. + pred_cls (np.ndarray): Array of predicted classes of the detections. + target_cls (np.ndarray): Array of true classes of the detections. + plot (bool, optional): Whether to plot PR curves or not. + on_plot (func, optional): A callback to pass plots path and data when they are rendered. + save_dir (Path, optional): Directory to save the PR curves. + names (dict, optional): Dict of class names to plot PR curves. + eps (float, optional): A small value to avoid division by zero. + prefix (str, optional): A prefix string for saving the plot files. + + Returns: + tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class. + fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. + p (np.ndarray): Precision values at threshold given by max F1 metric for each class. + r (np.ndarray): Recall values at threshold given by max F1 metric for each class. 
+ f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. + ap (np.ndarray): Average precision for each class at different IoU thresholds. + unique_classes (np.ndarray): An array of unique classes that have data. + p_curve (np.ndarray): Precision curves for each class. + r_curve (np.ndarray): Recall curves for each class. + f1_curve (np.ndarray): F1-score curves for each class. + x (np.ndarray): X-axis values for the curves. + prec_values (np.ndarray): Precision values at mAP@0.5 for each class. + """ + # Sort by objectness + i = np.argsort(-conf) + tp, conf, pred_cls = tp[i], conf[i], pred_cls[i] + + # Find unique classes + unique_classes, nt = np.unique(target_cls, return_counts=True) + nc = unique_classes.shape[0] # number of classes, number of detections + + # Create Precision-Recall curve and compute AP for each class + x, prec_values = np.linspace(0, 1, 1000), [] + + # Average precision, precision and recall curves + ap, p_curve, r_curve = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000)) + for ci, c in enumerate(unique_classes): + i = pred_cls == c + n_l = nt[ci] # number of labels + n_p = i.sum() # number of predictions + if n_p == 0 or n_l == 0: + continue + + # Accumulate FPs and TPs + fpc = (1 - tp[i]).cumsum(0) + tpc = tp[i].cumsum(0) + + # Recall + recall = tpc / (n_l + eps) # recall curve + r_curve[ci] = np.interp(-x, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases + + # Precision + precision = tpc / (tpc + fpc) # precision curve + p_curve[ci] = np.interp(-x, -conf[i], precision[:, 0], left=1) # p at pr_score + + # AP from recall-precision curve + for j in range(tp.shape[1]): + ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j]) + if j == 0: + prec_values.append(np.interp(x, mrec, mpre)) # precision at mAP@0.5 + + prec_values = np.array(prec_values) if prec_values else np.zeros((1, 1000)) # (nc, 1000) + + # Compute F1 (harmonic mean of precision and recall) + 
f1_curve = 2 * p_curve * r_curve / (p_curve + r_curve + eps) + names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data + names = dict(enumerate(names)) # to dict + if plot: + plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot) + plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot) + plot_mc_curve(x, p_curve, save_dir / f"{prefix}P_curve.png", names, ylabel="Precision", on_plot=on_plot) + plot_mc_curve(x, r_curve, save_dir / f"{prefix}R_curve.png", names, ylabel="Recall", on_plot=on_plot) + + i = smooth(f1_curve.mean(0), 0.1).argmax() # max F1 index + p, r, f1 = p_curve[:, i], r_curve[:, i], f1_curve[:, i] # max-F1 precision, recall, F1 values + tp = (r * nt).round() # true positives + fp = (tp / (p + eps) - tp).round() # false positives + return tp, fp, p, r, f1, ap, unique_classes.astype(int), p_curve, r_curve, f1_curve, x, prec_values + + +class Metric(SimpleClass): + """ + Class for computing evaluation metrics for YOLOv8 model. + + Attributes: + p (list): Precision for each class. Shape: (nc,). + r (list): Recall for each class. Shape: (nc,). + f1 (list): F1 score for each class. Shape: (nc,). + all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10). + ap_class_index (list): Index of class for each AP score. Shape: (nc,). + nc (int): Number of classes. + + Methods: + ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or []. + ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or []. + mp(): Mean precision of all classes. Returns: Float. + mr(): Mean recall of all classes. Returns: Float. + map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float. + map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float. + map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. 
Returns: Float. + mean_results(): Mean of results, returns mp, mr, map50, map. + class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i]. + maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,). + fitness(): Model fitness as a weighted combination of metrics. Returns: Float. + update(results): Update metric attributes with new evaluation results. + """ + + def __init__(self) -> None: + """Initialize a Metric instance for computing evaluation metrics for the YOLOv8 model.""" + self.p = [] # (nc, ) + self.r = [] # (nc, ) + self.f1 = [] # (nc, ) + self.all_ap = [] # (nc, 10) + self.ap_class_index = [] # (nc, ) + self.nc = 0 + + @property + def ap50(self): + """ + Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes. + + Returns: + (np.ndarray, list): Array of shape (nc,) with AP50 values per class, or an empty list if not available. + """ + return self.all_ap[:, 0] if len(self.all_ap) else [] + + @property + def ap(self): + """ + Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes. + + Returns: + (np.ndarray, list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available. + """ + return self.all_ap.mean(1) if len(self.all_ap) else [] + + @property + def mp(self): + """ + Return the Mean Precision of all classes. + + Returns: + (float): The mean precision of all classes. + """ + return self.p.mean() if len(self.p) else 0.0 + + @property + def mr(self): + """ + Return the Mean Recall of all classes. + + Returns: + (float): The mean recall of all classes. + """ + return self.r.mean() if len(self.r) else 0.0 + + @property + def map50(self): + """ + Return the mean Average Precision (mAP) at an IoU threshold of 0.5. + + Returns: + (float): The mAP at an IoU threshold of 0.5. + """ + return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0 + + @property + def map75(self): + """ + Return the mean Average Precision (mAP) at an IoU threshold of 0.75. 
+ + Returns: + (float): The mAP at an IoU threshold of 0.75. + """ + return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0 + + @property + def map(self): + """ + Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05. + + Returns: + (float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05. + """ + return self.all_ap.mean() if len(self.all_ap) else 0.0 + + def mean_results(self): + """Return mean of results, mp, mr, map50, map.""" + return [self.mp, self.mr, self.map50, self.map] + + def class_result(self, i): + """Return class-aware result, p[i], r[i], ap50[i], ap[i].""" + return self.p[i], self.r[i], self.ap50[i], self.ap[i] + + @property + def maps(self): + """Return mAP of each class.""" + maps = np.zeros(self.nc) + self.map + for i, c in enumerate(self.ap_class_index): + maps[c] = self.ap[i] + return maps + + def fitness(self): + """Return model fitness as a weighted combination of metrics.""" + w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95] + return (np.array(self.mean_results()) * w).sum() + + def update(self, results): + """ + Update the evaluation metrics with a new set of results. + + Args: + results (tuple): A tuple containing evaluation metrics: + - p (list): Precision for each class. + - r (list): Recall for each class. + - f1 (list): F1 score for each class. + - all_ap (list): AP scores for all classes and all IoU thresholds. + - ap_class_index (list): Index of class for each AP score. + - p_curve (list): Precision curve for each class. + - r_curve (list): Recall curve for each class. + - f1_curve (list): F1 curve for each class. + - px (list): X values for the curves. + - prec_values (list): Precision values for each class. 
+ """ + ( + self.p, + self.r, + self.f1, + self.all_ap, + self.ap_class_index, + self.p_curve, + self.r_curve, + self.f1_curve, + self.px, + self.prec_values, + ) = results + + @property + def curves(self): + """Return a list of curves for accessing specific metrics curves.""" + return [] + + @property + def curves_results(self): + """Return a list of curves for accessing specific metrics curves.""" + return [ + [self.px, self.prec_values, "Recall", "Precision"], + [self.px, self.f1_curve, "Confidence", "F1"], + [self.px, self.p_curve, "Confidence", "Precision"], + [self.px, self.r_curve, "Confidence", "Recall"], + ] + + +class DetMetrics(SimpleClass): + """ + Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP). + + Attributes: + save_dir (Path): A path to the directory where the output plots will be saved. + plot (bool): A flag that indicates whether to plot precision-recall curves for each class. + names (dict): A dictionary of class names. + box (Metric): An instance of the Metric class for storing detection results. + speed (dict): A dictionary for storing execution times of different parts of the detection process. + task (str): The task type, set to 'detect'. + """ + + def __init__(self, save_dir=Path("."), plot=False, names={}) -> None: + """ + Initialize a DetMetrics instance with a save directory, plot flag, and class names. + + Args: + save_dir (Path, optional): Directory to save plots. + plot (bool, optional): Whether to plot precision-recall curves. + names (dict, optional): Dictionary mapping class indices to names. + """ + self.save_dir = save_dir + self.plot = plot + self.names = names + self.box = Metric() + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + self.task = "detect" + + def process(self, tp, conf, pred_cls, target_cls, on_plot=None): + """ + Process predicted results for object detection and update metrics. 
+ + Args: + tp (np.ndarray): True positive array. + conf (np.ndarray): Confidence array. + pred_cls (np.ndarray): Predicted class indices array. + target_cls (np.ndarray): Target class indices array. + on_plot (callable, optional): Function to call after plots are generated. + """ + results = ap_per_class( + tp, + conf, + pred_cls, + target_cls, + plot=self.plot, + save_dir=self.save_dir, + names=self.names, + on_plot=on_plot, + )[2:] + self.box.nc = len(self.names) + self.box.update(results) + + @property + def keys(self): + """Return a list of keys for accessing specific metrics.""" + return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"] + + def mean_results(self): + """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95.""" + return self.box.mean_results() + + def class_result(self, i): + """Return the result of evaluating the performance of an object detection model on a specific class.""" + return self.box.class_result(i) + + @property + def maps(self): + """Return mean Average Precision (mAP) scores per class.""" + return self.box.maps + + @property + def fitness(self): + """Return the fitness of box object.""" + return self.box.fitness() + + @property + def ap_class_index(self): + """Return the average precision index per class.""" + return self.box.ap_class_index + + @property + def results_dict(self): + """Return dictionary of computed performance metrics and statistics.""" + return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness])) + + @property + def curves(self): + """Return a list of curves for accessing specific metrics curves.""" + return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"] + + @property + def curves_results(self): + """Return dictionary of computed performance metrics and statistics.""" + return self.box.curves_results + + +class SegmentMetrics(SimpleClass): + """ + Calculates and aggregates 
detection and segmentation metrics over a given set of classes. + + Attributes: + save_dir (Path): Path to the directory where the output plots should be saved. + plot (bool): Whether to save the detection and segmentation plots. + names (dict): Dictionary of class names. + box (Metric): An instance of the Metric class to calculate box detection metrics. + seg (Metric): An instance of the Metric class to calculate mask segmentation metrics. + speed (dict): Dictionary to store the time taken in different phases of inference. + task (str): The task type, set to 'segment'. + """ + + def __init__(self, save_dir=Path("."), plot=False, names=()) -> None: + """ + Initialize a SegmentMetrics instance with a save directory, plot flag, and class names. + + Args: + save_dir (Path, optional): Directory to save plots. + plot (bool, optional): Whether to plot precision-recall curves. + names (dict, optional): Dictionary mapping class indices to names. + """ + self.save_dir = save_dir + self.plot = plot + self.names = names + self.box = Metric() + self.seg = Metric() + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + self.task = "segment" + + def process(self, tp, tp_m, conf, pred_cls, target_cls, on_plot=None): + """ + Process the detection and segmentation metrics over the given set of predictions. + + Args: + tp (np.ndarray): True positive array for boxes. + tp_m (np.ndarray): True positive array for masks. + conf (np.ndarray): Confidence array. + pred_cls (np.ndarray): Predicted class indices array. + target_cls (np.ndarray): Target class indices array. + on_plot (callable, optional): Function to call after plots are generated. 
+ """ + results_mask = ap_per_class( + tp_m, + conf, + pred_cls, + target_cls, + plot=self.plot, + on_plot=on_plot, + save_dir=self.save_dir, + names=self.names, + prefix="Mask", + )[2:] + self.seg.nc = len(self.names) + self.seg.update(results_mask) + results_box = ap_per_class( + tp, + conf, + pred_cls, + target_cls, + plot=self.plot, + on_plot=on_plot, + save_dir=self.save_dir, + names=self.names, + prefix="Box", + )[2:] + self.box.nc = len(self.names) + self.box.update(results_box) + + @property + def keys(self): + """Return a list of keys for accessing metrics.""" + return [ + "metrics/precision(B)", + "metrics/recall(B)", + "metrics/mAP50(B)", + "metrics/mAP50-95(B)", + "metrics/precision(M)", + "metrics/recall(M)", + "metrics/mAP50(M)", + "metrics/mAP50-95(M)", + ] + + def mean_results(self): + """Return the mean metrics for bounding box and segmentation results.""" + return self.box.mean_results() + self.seg.mean_results() + + def class_result(self, i): + """Return classification results for a specified class index.""" + return self.box.class_result(i) + self.seg.class_result(i) + + @property + def maps(self): + """Return mAP scores for object detection and semantic segmentation models.""" + return self.box.maps + self.seg.maps + + @property + def fitness(self): + """Return the fitness score for both segmentation and bounding box models.""" + return self.seg.fitness() + self.box.fitness() + + @property + def ap_class_index(self): + """ + Return the class indices. + + Boxes and masks have the same ap_class_index. 
+ """ + return self.box.ap_class_index + + @property + def results_dict(self): + """Return results of object detection model for evaluation.""" + return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness])) + + @property + def curves(self): + """Return a list of curves for accessing specific metrics curves.""" + return [ + "Precision-Recall(B)", + "F1-Confidence(B)", + "Precision-Confidence(B)", + "Recall-Confidence(B)", + "Precision-Recall(M)", + "F1-Confidence(M)", + "Precision-Confidence(M)", + "Recall-Confidence(M)", + ] + + @property + def curves_results(self): + """Return dictionary of computed performance metrics and statistics.""" + return self.box.curves_results + self.seg.curves_results + + +class PoseMetrics(SegmentMetrics): + """ + Calculates and aggregates detection and pose metrics over a given set of classes. + + Attributes: + save_dir (Path): Path to the directory where the output plots should be saved. + plot (bool): Whether to save the detection and pose plots. + names (dict): Dictionary of class names. + box (Metric): An instance of the Metric class to calculate box detection metrics. + pose (Metric): An instance of the Metric class to calculate pose metrics. + speed (dict): Dictionary to store the time taken in different phases of inference. + task (str): The task type, set to 'pose'. + + Methods: + process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions. + mean_results(): Returns the mean of the detection and segmentation metrics over all the classes. + class_result(i): Returns the detection and segmentation metrics of class `i`. + maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95. + fitness: Returns the fitness scores, which are a single weighted combination of metrics. + ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP). 
+ results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score. + """ + + def __init__(self, save_dir=Path("."), plot=False, names=()) -> None: + """ + Initialize the PoseMetrics class with directory path, class names, and plotting options. + + Args: + save_dir (Path, optional): Directory to save plots. + plot (bool, optional): Whether to plot precision-recall curves. + names (dict, optional): Dictionary mapping class indices to names. + """ + super().__init__(save_dir, plot, names) + self.save_dir = save_dir + self.plot = plot + self.names = names + self.box = Metric() + self.pose = Metric() + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + self.task = "pose" + + def process(self, tp, tp_p, conf, pred_cls, target_cls, on_plot=None): + """ + Process the detection and pose metrics over the given set of predictions. + + Args: + tp (np.ndarray): True positive array for boxes. + tp_p (np.ndarray): True positive array for keypoints. + conf (np.ndarray): Confidence array. + pred_cls (np.ndarray): Predicted class indices array. + target_cls (np.ndarray): Target class indices array. + on_plot (callable, optional): Function to call after plots are generated. 
+ """ + results_pose = ap_per_class( + tp_p, + conf, + pred_cls, + target_cls, + plot=self.plot, + on_plot=on_plot, + save_dir=self.save_dir, + names=self.names, + prefix="Pose", + )[2:] + self.pose.nc = len(self.names) + self.pose.update(results_pose) + results_box = ap_per_class( + tp, + conf, + pred_cls, + target_cls, + plot=self.plot, + on_plot=on_plot, + save_dir=self.save_dir, + names=self.names, + prefix="Box", + )[2:] + self.box.nc = len(self.names) + self.box.update(results_box) + + @property + def keys(self): + """Return list of evaluation metric keys.""" + return [ + "metrics/precision(B)", + "metrics/recall(B)", + "metrics/mAP50(B)", + "metrics/mAP50-95(B)", + "metrics/precision(P)", + "metrics/recall(P)", + "metrics/mAP50(P)", + "metrics/mAP50-95(P)", + ] + + def mean_results(self): + """Return the mean results of box and pose.""" + return self.box.mean_results() + self.pose.mean_results() + + def class_result(self, i): + """Return the class-wise detection results for a specific class i.""" + return self.box.class_result(i) + self.pose.class_result(i) + + @property + def maps(self): + """Return the mean average precision (mAP) per class for both box and pose detections.""" + return self.box.maps + self.pose.maps + + @property + def fitness(self): + """Return combined fitness score for pose and box detection.""" + return self.pose.fitness() + self.box.fitness() + + @property + def curves(self): + """Return a list of curves for accessing specific metrics curves.""" + return [ + "Precision-Recall(B)", + "F1-Confidence(B)", + "Precision-Confidence(B)", + "Recall-Confidence(B)", + "Precision-Recall(P)", + "F1-Confidence(P)", + "Precision-Confidence(P)", + "Recall-Confidence(P)", + ] + + @property + def curves_results(self): + """Return dictionary of computed performance metrics and statistics.""" + return self.box.curves_results + self.pose.curves_results + + +class ClassifyMetrics(SimpleClass): + """ + Class for computing classification metrics including 
top-1 and top-5 accuracy. + + Attributes: + top1 (float): The top-1 accuracy. + top5 (float): The top-5 accuracy. + speed (dict): A dictionary containing the time taken for each step in the pipeline. + task (str): The task type, set to 'classify'. + """ + + def __init__(self) -> None: + """Initialize a ClassifyMetrics instance.""" + self.top1 = 0 + self.top5 = 0 + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + self.task = "classify" + + def process(self, targets, pred): + """ + Process target classes and predicted classes to compute metrics. + + Args: + targets (torch.Tensor): Target classes. + pred (torch.Tensor): Predicted classes. + """ + pred, targets = torch.cat(pred), torch.cat(targets) + correct = (targets[:, None] == pred).float() + acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy + self.top1, self.top5 = acc.mean(0).tolist() + + @property + def fitness(self): + """Return mean of top-1 and top-5 accuracies as fitness score.""" + return (self.top1 + self.top5) / 2 + + @property + def results_dict(self): + """Return a dictionary with model's performance metrics and fitness score.""" + return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness])) + + @property + def keys(self): + """Return a list of keys for the results_dict property.""" + return ["metrics/accuracy_top1", "metrics/accuracy_top5"] + + @property + def curves(self): + """Return a list of curves for accessing specific metrics curves.""" + return [] + + @property + def curves_results(self): + """Return a list of curves for accessing specific metrics curves.""" + return [] + + +class OBBMetrics(SimpleClass): + """ + Metrics for evaluating oriented bounding box (OBB) detection. + + Attributes: + save_dir (Path): Path to the directory where the output plots should be saved. + plot (bool): Whether to save the detection plots. + names (dict): Dictionary of class names. 
+ box (Metric): An instance of the Metric class for storing detection results. + speed (dict): A dictionary for storing execution times of different parts of the detection process. + + References: + https://arxiv.org/pdf/2106.06072.pdf + """ + + def __init__(self, save_dir=Path("."), plot=False, names=()) -> None: + """ + Initialize an OBBMetrics instance with directory, plotting, and class names. + + Args: + save_dir (Path, optional): Directory to save plots. + plot (bool, optional): Whether to plot precision-recall curves. + names (dict, optional): Dictionary mapping class indices to names. + """ + self.save_dir = save_dir + self.plot = plot + self.names = names + self.box = Metric() + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + + def process(self, tp, conf, pred_cls, target_cls, on_plot=None): + """ + Process predicted results for object detection and update metrics. + + Args: + tp (np.ndarray): True positive array. + conf (np.ndarray): Confidence array. + pred_cls (np.ndarray): Predicted class indices array. + target_cls (np.ndarray): Target class indices array. + on_plot (callable, optional): Function to call after plots are generated. 
+ """ + results = ap_per_class( + tp, + conf, + pred_cls, + target_cls, + plot=self.plot, + save_dir=self.save_dir, + names=self.names, + on_plot=on_plot, + )[2:] + self.box.nc = len(self.names) + self.box.update(results) + + @property + def keys(self): + """Return a list of keys for accessing specific metrics.""" + return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"] + + def mean_results(self): + """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95.""" + return self.box.mean_results() + + def class_result(self, i): + """Return the result of evaluating the performance of an object detection model on a specific class.""" + return self.box.class_result(i) + + @property + def maps(self): + """Return mean Average Precision (mAP) scores per class.""" + return self.box.maps + + @property + def fitness(self): + """Return the fitness of box object.""" + return self.box.fitness() + + @property + def ap_class_index(self): + """Return the average precision index per class.""" + return self.box.ap_class_index + + @property + def results_dict(self): + """Return dictionary of computed performance metrics and statistics.""" + return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness])) + + @property + def curves(self): + """Return a list of curves for accessing specific metrics curves.""" + return [] + + @property + def curves_results(self): + """Return a list of curves for accessing specific metrics curves.""" + return [] diff --git a/tracking/ultralytics/utils/ops.py b/tracking/ultralytics/utils/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0f0c8c07b7ea0b9011e50fc7defb78eaa6753203 --- /dev/null +++ b/tracking/ultralytics/utils/ops.py @@ -0,0 +1,875 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import contextlib +import math +import re +import time + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F 

from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import batch_probiou


class Profile(contextlib.ContextDecorator):
    """
    YOLOv8 Profile class. Use as a decorator with @Profile() or as a context manager with 'with Profile():'.

    Attributes:
        t (float): Accumulated time.
        device (torch.device): Device used for model inference.
        cuda (bool): Whether CUDA is being used.

    Examples:
        >>> from ultralytics.utils.ops import Profile
        >>> with Profile(device=device) as dt:
        ...     pass  # slow operation here
        >>> print(dt)  # prints "Elapsed time is 9.5367431640625e-07 s"
    """

    def __init__(self, t=0.0, device: torch.device = None):
        """
        Initialize the Profile class.

        Args:
            t (float): Initial time.
            device (torch.device): Device used for model inference.
        """
        self.t = t
        self.device = device
        # CUDA detection is string-based so both torch.device and plain "cuda:0" strings work
        self.cuda = bool(device and str(device).startswith("cuda"))

    def __enter__(self):
        """Start timing."""
        self.start = self.time()
        return self

    def __exit__(self, type, value, traceback):  # noqa
        """Stop timing."""
        self.dt = self.time() - self.start  # delta-time
        self.t += self.dt  # accumulate dt across repeated uses of the same Profile instance

    def __str__(self):
        """Returns a human-readable string representing the accumulated elapsed time in the profiler."""
        return f"Elapsed time is {self.t} s"

    def time(self):
        """Get current time."""
        if self.cuda:
            # Synchronize so queued CUDA kernels are included in the measured interval
            torch.cuda.synchronize(self.device)
        return time.perf_counter()


def segment2box(segment, width=640, height=640):
    """
    Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy).

    Args:
        segment (torch.Tensor): The segment label.
        width (int): The width of the image.
        height (int): The height of the image.

    Returns:
        (np.ndarray): The minimum and maximum x and y values of the segment.
+ """ + x, y = segment.T # segment xy + # any 3 out of 4 sides are outside the image, clip coordinates first, https://github.com/ultralytics/ultralytics/pull/18294 + if np.array([x.min() < 0, y.min() < 0, x.max() > width, y.max() > height]).sum() >= 3: + x = x.clip(0, width) + y = y.clip(0, height) + inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height) + x = x[inside] + y = y[inside] + return ( + np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) + if any(x) + else np.zeros(4, dtype=segment.dtype) + ) # xyxy + + +def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False): + """ + Rescale bounding boxes from img1_shape to img0_shape. + + Args: + img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width). + boxes (torch.Tensor): The bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2). + img0_shape (tuple): The shape of the target image, in the format of (height, width). + ratio_pad (tuple): A tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be + calculated based on the size difference between the two images. + padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular + rescaling. + xywh (bool): The box format is xywh or not. + + Returns: + (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2). 
+ """ + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = ( + round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), + round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1), + ) # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + if padding: + boxes[..., 0] -= pad[0] # x padding + boxes[..., 1] -= pad[1] # y padding + if not xywh: + boxes[..., 2] -= pad[0] # x padding + boxes[..., 3] -= pad[1] # y padding + boxes[..., :4] /= gain + return clip_boxes(boxes, img0_shape) + + +def make_divisible(x, divisor): + """ + Returns the nearest number that is divisible by the given divisor. + + Args: + x (int): The number to make divisible. + divisor (int | torch.Tensor): The divisor. + + Returns: + (int): The nearest number divisible by the divisor. + """ + if isinstance(divisor, torch.Tensor): + divisor = int(divisor.max()) # to int + return math.ceil(x / divisor) * divisor + + +def nms_rotated(boxes, scores, threshold=0.45, use_triu=True): + """ + NMS for oriented bounding boxes using probiou and fast-nms. + + Args: + boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr. + scores (torch.Tensor): Confidence scores, shape (N,). + threshold (float): IoU threshold. + use_triu (bool): Whether to use `torch.triu` operator. It'd be useful for disable it + when exporting obb models to some formats that do not support `torch.triu`. + + Returns: + (torch.Tensor): Indices of boxes to keep after NMS. 
+ """ + sorted_idx = torch.argsort(scores, descending=True) + boxes = boxes[sorted_idx] + ious = batch_probiou(boxes, boxes) + if use_triu: + ious = ious.triu_(diagonal=1) + # pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1) + # NOTE: handle the case when len(boxes) hence exportable by eliminating if-else condition + pick = torch.nonzero((ious >= threshold).sum(0) <= 0).squeeze_(-1) + else: + n = boxes.shape[0] + row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n) + col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1) + upper_mask = row_idx < col_idx + ious = ious * upper_mask + # Zeroing these scores ensures the additional indices would not affect the final results + scores[~((ious >= threshold).sum(0) <= 0)] = 0 + # NOTE: return indices with fixed length to avoid TFLite reshape error + pick = torch.topk(scores, scores.shape[0]).indices + return sorted_idx[pick] + + +def non_max_suppression( + prediction, + conf_thres=0.25, + iou_thres=0.45, + classes=None, + agnostic=False, + multi_label=False, + labels=(), + max_det=300, + nc=0, # number of classes (optional) + max_time_img=0.05, + max_nms=30000, + max_wh=7680, + in_place=True, + rotated=False, + end2end=False, +): + """ + Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box. + + Args: + prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes) + containing the predicted boxes, classes, and masks. The tensor should be in the format + output by a model, such as YOLO. + conf_thres (float): The confidence threshold below which boxes will be filtered out. + Valid values are between 0.0 and 1.0. + iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS. + Valid values are between 0.0 and 1.0. + classes (List[int]): A list of class indices to consider. If None, all classes will be considered. 
+ agnostic (bool): If True, the model is agnostic to the number of classes, and all + classes will be considered as one. + multi_label (bool): If True, each box may have multiple labels. + labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner + list contains the apriori labels for a given image. The list should be in the format + output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2). + max_det (int): The maximum number of boxes to keep after NMS. + nc (int): The number of classes output by the model. Any indices after this will be considered masks. + max_time_img (float): The maximum time (seconds) for processing one image. + max_nms (int): The maximum number of boxes into torchvision.ops.nms(). + max_wh (int): The maximum box width and height in pixels. + in_place (bool): If True, the input prediction tensor will be modified in place. + rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS. + end2end (bool): If the model doesn't require NMS. + + Returns: + (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of + shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns + (x1, y1, x2, y2, confidence, class, mask1, mask2, ...). + """ + import torchvision # scope for faster 'import ultralytics' + + # Checks + assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0" + assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0" + if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out) + prediction = prediction[0] # select only inference output + if classes is not None: + classes = torch.tensor(classes, device=prediction.device) + + if prediction.shape[-1] == 6 or end2end: # end-to-end model (BNC, i.e. 
1,300,6) + output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction] + if classes is not None: + output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output] + return output + + bs = prediction.shape[0] # batch size (BCN, i.e. 1,84,6300) + nc = nc or (prediction.shape[1] - 4) # number of classes + nm = prediction.shape[1] - nc - 4 # number of masks + mi = 4 + nc # mask start index + xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates + + # Settings + # min_wh = 2 # (pixels) minimum box width and height + time_limit = 2.0 + max_time_img * bs # seconds to quit after + multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) + + prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84) + if not rotated: + if in_place: + prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy + else: + prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1) # xywh to xyxy + + t = time.time() + output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]) and not rotated: + lb = labels[xi] + v = torch.zeros((len(lb), nc + nm + 4), device=x.device) + v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box + v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Detections matrix nx6 (xyxy, conf, cls) + box, cls, mask = x.split((4, nc, nm), 1) + + if multi_label: + i, j = torch.where(cls > conf_thres) + x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1) + else: # best class only + conf, j = cls.max(1, keepdim=True) + x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres] + + # Filter 
by class + if classes is not None: + x = x[(x[:, 5:6] == classes).any(1)] + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + if n > max_nms: # excess boxes + x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes + + # Batched NMS + c = x[:, 5:6] * (0 if agnostic else max_wh) # classes + scores = x[:, 4] # scores + if rotated: + boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1) # xywhr + i = nms_rotated(boxes, scores, iou_thres) + else: + boxes = x[:, :4] + c # boxes (offset by class) + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + i = i[:max_det] # limit detections + + # # Experimental + # merge = False # use merge-NMS + # if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) + # # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + # from .metrics import box_iou + # iou = box_iou(boxes[i], boxes) > iou_thres # IoU matrix + # weights = iou * scores[None] # box weights + # x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + # redundant = True # require redundant detections + # if redundant: + # i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded") + break # time limit exceeded + + return output + + +def clip_boxes(boxes, shape): + """ + Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape. + + Args: + boxes (torch.Tensor | numpy.ndarray): The bounding boxes to clip. + shape (tuple): The shape of the image. + + Returns: + (torch.Tensor | numpy.ndarray): The clipped boxes. 
+ """ + if isinstance(boxes, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug) + boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) # x1 + boxes[..., 1] = boxes[..., 1].clamp(0, shape[0]) # y1 + boxes[..., 2] = boxes[..., 2].clamp(0, shape[1]) # x2 + boxes[..., 3] = boxes[..., 3].clamp(0, shape[0]) # y2 + else: # np.array (faster grouped) + boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2 + boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2 + return boxes + + +def clip_coords(coords, shape): + """ + Clip line coordinates to the image boundaries. + + Args: + coords (torch.Tensor | numpy.ndarray): A list of line coordinates. + shape (tuple): A tuple of integers representing the size of the image in the format (height, width). + + Returns: + (torch.Tensor | numpy.ndarray): Clipped coordinates. + """ + if isinstance(coords, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug) + coords[..., 0] = coords[..., 0].clamp(0, shape[1]) # x + coords[..., 1] = coords[..., 1].clamp(0, shape[0]) # y + else: # np.array (faster grouped) + coords[..., 0] = coords[..., 0].clip(0, shape[1]) # x + coords[..., 1] = coords[..., 1].clip(0, shape[0]) # y + return coords + + +def scale_image(masks, im0_shape, ratio_pad=None): + """ + Takes a mask, and resizes it to the original image size. + + Args: + masks (np.ndarray): Resized and padded masks/images, [h, w, num]/[h, w, 3]. + im0_shape (tuple): The original image shape. + ratio_pad (tuple): The ratio of the padding to the original image. + + Returns: + masks (np.ndarray): The masks that are being returned with shape [h, w, num]. 
+ """ + # Rescale coordinates (xyxy) from im1_shape to im0_shape + im1_shape = masks.shape + if im1_shape[:2] == im0_shape[:2]: + return masks + if ratio_pad is None: # calculate from im0_shape + gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new + pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding + else: + # gain = ratio_pad[0][0] + pad = ratio_pad[1] + top, left = int(pad[1]), int(pad[0]) # y, x + bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0]) + + if len(masks.shape) < 2: + raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}') + masks = masks[top:bottom, left:right] + masks = cv2.resize(masks, (im0_shape[1], im0_shape[0])) + if len(masks.shape) == 2: + masks = masks[:, :, None] + + return masks + + +def xyxy2xywh(x): + """ + Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the + top-left corner and (x2, y2) is the bottom-right corner. + + Args: + x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format. + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format. + """ + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" + y = empty_like(x) # faster than clone/copy + y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center + y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center + y[..., 2] = x[..., 2] - x[..., 0] # width + y[..., 3] = x[..., 3] - x[..., 1] # height + return y + + +def xywh2xyxy(x): + """ + Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the + top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel. + + Args: + x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format. 
+ + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format. + """ + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" + y = empty_like(x) # faster than clone/copy + xy = x[..., :2] # centers + wh = x[..., 2:] / 2 # half width-height + y[..., :2] = xy - wh # top left xy + y[..., 2:] = xy + wh # bottom right xy + return y + + +def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): + """ + Convert normalized bounding box coordinates to pixel coordinates. + + Args: + x (np.ndarray | torch.Tensor): The bounding box coordinates. + w (int): Width of the image. + h (int): Height of the image. + padw (int): Padding width. + padh (int): Padding height. + + Returns: + y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where + x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box. + """ + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" + y = empty_like(x) # faster than clone/copy + y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x + y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y + y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x + y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y + return y + + +def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0): + """ + Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y, + width and height are normalized to image dimensions. + + Args: + x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format. + w (int): The width of the image. + h (int): The height of the image. + clip (bool): If True, the boxes will be clipped to the image boundaries. + eps (float): The minimum value of the box's width and height. 
+ + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format + """ + if clip: + x = clip_boxes(x, (h - eps, w - eps)) + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" + y = empty_like(x) # faster than clone/copy + y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center + y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center + y[..., 2] = (x[..., 2] - x[..., 0]) / w # width + y[..., 3] = (x[..., 3] - x[..., 1]) / h # height + return y + + +def xywh2ltwh(x): + """ + Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates. + + Args: + x (np.ndarray | torch.Tensor): The input tensor with the bounding box coordinates in the xywh format + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x + y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y + return y + + +def xyxy2ltwh(x): + """ + Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right. + + Args: + x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format. + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 2] = x[..., 2] - x[..., 0] # width + y[..., 3] = x[..., 3] - x[..., 1] # height + return y + + +def ltwh2xywh(x): + """ + Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center. + + Args: + x (torch.Tensor): the input tensor + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in the xywh format. 
+ """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] + x[..., 2] / 2 # center x + y[..., 1] = x[..., 1] + x[..., 3] / 2 # center y + return y + + +def xyxyxyxy2xywhr(x): + """ + Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are + returned in radians from 0 to pi/2. + + Args: + x (numpy.ndarray | torch.Tensor): Input box corners [xy1, xy2, xy3, xy4] of shape (n, 8). + + Returns: + (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5). + """ + is_torch = isinstance(x, torch.Tensor) + points = x.cpu().numpy() if is_torch else x + points = points.reshape(len(x), -1, 2) + rboxes = [] + for pts in points: + # NOTE: Use cv2.minAreaRect to get accurate xywhr, + # especially some objects are cut off by augmentations in dataloader. + (cx, cy), (w, h), angle = cv2.minAreaRect(pts) + rboxes.append([cx, cy, w, h, angle / 180 * np.pi]) + return torch.tensor(rboxes, device=x.device, dtype=x.dtype) if is_torch else np.asarray(rboxes) + + +def xywhr2xyxyxyxy(x): + """ + Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should + be in radians from 0 to pi/2. + + Args: + x (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5). + + Returns: + (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2). 
+ """ + cos, sin, cat, stack = ( + (torch.cos, torch.sin, torch.cat, torch.stack) + if isinstance(x, torch.Tensor) + else (np.cos, np.sin, np.concatenate, np.stack) + ) + + ctr = x[..., :2] + w, h, angle = (x[..., i : i + 1] for i in range(2, 5)) + cos_value, sin_value = cos(angle), sin(angle) + vec1 = [w / 2 * cos_value, w / 2 * sin_value] + vec2 = [-h / 2 * sin_value, h / 2 * cos_value] + vec1 = cat(vec1, -1) + vec2 = cat(vec2, -1) + pt1 = ctr + vec1 + vec2 + pt2 = ctr + vec1 - vec2 + pt3 = ctr - vec1 - vec2 + pt4 = ctr - vec1 + vec2 + return stack([pt1, pt2, pt3, pt4], -2) + + +def ltwh2xyxy(x): + """ + Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right. + + Args: + x (np.ndarray | torch.Tensor): The input image. + + Returns: + (np.ndarray | torch.Tensor): The xyxy coordinates of the bounding boxes. + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 2] = x[..., 2] + x[..., 0] # width + y[..., 3] = x[..., 3] + x[..., 1] # height + return y + + +def segments2boxes(segments): + """ + Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh). + + Args: + segments (list): List of segments, each segment is a list of points, each point is a list of x, y coordinates. + + Returns: + (np.ndarray): The xywh coordinates of the bounding boxes. + """ + boxes = [] + for s in segments: + x, y = s.T # segment xy + boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy + return xyxy2xywh(np.array(boxes)) # cls, xywh + + +def resample_segments(segments, n=1000): + """ + Inputs a list of segments (n,2) and returns a list of segments (n,2) up-sampled to n points each. + + Args: + segments (list): A list of (n,2) arrays, where n is the number of points in the segment. + n (int): Number of points to resample the segment to. + + Returns: + segments (list): The resampled segments. 
+ """ + for i, s in enumerate(segments): + if len(s) == n: + continue + s = np.concatenate((s, s[0:1, :]), axis=0) + x = np.linspace(0, len(s) - 1, n - len(s) if len(s) < n else n) + xp = np.arange(len(s)) + x = np.insert(x, np.searchsorted(x, xp), xp) if len(s) < n else x + segments[i] = ( + np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T + ) # segment xy + return segments + + +def crop_mask(masks, boxes): + """ + Crop masks to bounding boxes. + + Args: + masks (torch.Tensor): [n, h, w] tensor of masks. + boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form. + + Returns: + (torch.Tensor): Cropped masks. + """ + _, h, w = masks.shape + x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1) + r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w) + c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1) + + return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2)) + + +def process_mask(protos, masks_in, bboxes, shape, upsample=False): + """ + Apply masks to bounding boxes using the output of the mask head. + + Args: + protos (torch.Tensor): A tensor of shape [mask_dim, mask_h, mask_w]. + masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS. + bboxes (torch.Tensor): A tensor of shape [n, 4], where n is the number of masks after NMS. + shape (tuple): A tuple of integers representing the size of the input image in the format (h, w). + upsample (bool): A flag to indicate whether to upsample the mask to the original image size. + + Returns: + (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w + are the height and width of the input image. The mask is applied to the bounding boxes. 
+ """ + c, mh, mw = protos.shape # CHW + ih, iw = shape + masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) # CHW + width_ratio = mw / iw + height_ratio = mh / ih + + downsampled_bboxes = bboxes.clone() + downsampled_bboxes[:, 0] *= width_ratio + downsampled_bboxes[:, 2] *= width_ratio + downsampled_bboxes[:, 3] *= height_ratio + downsampled_bboxes[:, 1] *= height_ratio + + masks = crop_mask(masks, downsampled_bboxes) # CHW + if upsample: + masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW + return masks.gt_(0.0) + + +def process_mask_native(protos, masks_in, bboxes, shape): + """ + Apply masks to bounding boxes using the output of the mask head with native upsampling. + + Args: + protos (torch.Tensor): [mask_dim, mask_h, mask_w]. + masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms. + bboxes (torch.Tensor): [n, 4], n is number of masks after nms. + shape (tuple): The size of the input image (h,w). + + Returns: + (torch.Tensor): The returned masks with dimensions [h, w, n]. + """ + c, mh, mw = protos.shape # CHW + masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) + masks = scale_masks(masks[None], shape)[0] # CHW + masks = crop_mask(masks, bboxes) # CHW + return masks.gt_(0.0) + + +def scale_masks(masks, shape, padding=True): + """ + Rescale segment masks to shape. + + Args: + masks (torch.Tensor): (N, C, H, W). + shape (tuple): Height and width. + padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular + rescaling. + + Returns: + (torch.Tensor): Rescaled masks. 
+ """ + mh, mw = masks.shape[2:] + gain = min(mh / shape[0], mw / shape[1]) # gain = old / new + pad = [mw - shape[1] * gain, mh - shape[0] * gain] # wh padding + if padding: + pad[0] /= 2 + pad[1] /= 2 + top, left = (int(pad[1]), int(pad[0])) if padding else (0, 0) # y, x + bottom, right = (int(mh - pad[1]), int(mw - pad[0])) + masks = masks[..., top:bottom, left:right] + + masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False) # NCHW + return masks + + +def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True): + """ + Rescale segment coordinates (xy) from img1_shape to img0_shape. + + Args: + img1_shape (tuple): The shape of the image that the coords are from. + coords (torch.Tensor): The coords to be scaled of shape n,2. + img0_shape (tuple): The shape of the image that the segmentation is being applied to. + ratio_pad (tuple): The ratio of the image size to the padded image size. + normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. + padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular + rescaling. + + Returns: + coords (torch.Tensor): The scaled coordinates. + """ + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + if padding: + coords[..., 0] -= pad[0] # x padding + coords[..., 1] -= pad[1] # y padding + coords[..., 0] /= gain + coords[..., 1] /= gain + coords = clip_coords(coords, img0_shape) + if normalize: + coords[..., 0] /= img0_shape[1] # width + coords[..., 1] /= img0_shape[0] # height + return coords + + +def regularize_rboxes(rboxes): + """ + Regularize rotated boxes in range [0, pi/2]. 
+ + Args: + rboxes (torch.Tensor): Input boxes of shape(N, 5) in xywhr format. + + Returns: + (torch.Tensor): The regularized boxes. + """ + x, y, w, h, t = rboxes.unbind(dim=-1) + # Swap edge if t >= pi/2 while not being symmetrically opposite + swap = t % math.pi >= math.pi / 2 + w_ = torch.where(swap, h, w) + h_ = torch.where(swap, w, h) + t = t % (math.pi / 2) + return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes + + +def masks2segments(masks, strategy="all"): + """ + Convert masks to segments. + + Args: + masks (torch.Tensor): The output of the model, which is a tensor of shape (batch_size, 160, 160). + strategy (str): 'all' or 'largest'. + + Returns: + (list): List of segment masks. + """ + from ultralytics.data.converter import merge_multi_segment + + segments = [] + for x in masks.int().cpu().numpy().astype("uint8"): + c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] + if c: + if strategy == "all": # merge and concatenate all segments + c = ( + np.concatenate(merge_multi_segment([x.reshape(-1, 2) for x in c])) + if len(c) > 1 + else c[0].reshape(-1, 2) + ) + elif strategy == "largest": # select largest segment + c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2) + else: + c = np.zeros((0, 2)) # no segments found + segments.append(c.astype("float32")) + return segments + + +def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray: + """ + Convert a batch of FP32 torch tensors (0.0-1.0) to a NumPy uint8 array (0-255), changing from BCHW to BHWC layout. + + Args: + batch (torch.Tensor): Input tensor batch of shape (Batch, Channels, Height, Width) and dtype torch.float32. + + Returns: + (np.ndarray): Output NumPy array batch of shape (Batch, Height, Width, Channels) and dtype uint8. + """ + return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy() + + +def clean_str(s): + """ + Cleans a string by replacing special characters with '_' character. 
+ + Args: + s (str): A string needing special characters replaced. + + Returns: + (str): A string with special characters replaced by an underscore _. + """ + return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s) + + +def empty_like(x): + """Creates empty torch.Tensor or np.ndarray with same shape as input and float32 dtype.""" + return ( + torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32) + ) diff --git a/tracking/ultralytics/utils/patches.py b/tracking/ultralytics/utils/patches.py new file mode 100644 index 0000000000000000000000000000000000000000..11b8d927e981d9fadf7438aaeb89242451aae242 --- /dev/null +++ b/tracking/ultralytics/utils/patches.py @@ -0,0 +1,106 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Monkey patches to update/extend functionality of existing functions.""" + +import time +from pathlib import Path + +import cv2 +import numpy as np +import torch + +# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------ +_imshow = cv2.imshow # copy to avoid recursion errors + + +def imread(filename: str, flags: int = cv2.IMREAD_COLOR): + """ + Read an image from a file. + + Args: + filename (str): Path to the file to read. + flags (int, optional): Flag that can take values of cv2.IMREAD_*. + + Returns: + (np.ndarray): The read image. + """ + return cv2.imdecode(np.fromfile(filename, np.uint8), flags) + + +def imwrite(filename: str, img: np.ndarray, params=None): + """ + Write an image to a file. + + Args: + filename (str): Path to the file to write. + img (np.ndarray): Image to write. + params (List[int], optional): Additional parameters for image encoding. + + Returns: + (bool): True if the file was written, False otherwise. 
+ """ + try: + cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename) + return True + except Exception: + return False + + +def imshow(winname: str, mat: np.ndarray): + """ + Display an image in the specified window. + + Args: + winname (str): Name of the window. + mat (np.ndarray): Image to be shown. + """ + _imshow(winname.encode("unicode_escape").decode(), mat) + + +# PyTorch functions ---------------------------------------------------------------------------------------------------- +_torch_load = torch.load # copy to avoid recursion errors +_torch_save = torch.save + + +def torch_load(*args, **kwargs): + """ + Load a PyTorch model with updated arguments to avoid warnings. + + This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings. + + Args: + *args (Any): Variable length argument list to pass to torch.load. + **kwargs (Any): Arbitrary keyword arguments to pass to torch.load. + + Returns: + (Any): The loaded PyTorch object. + + Note: + For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False' + if the argument is not provided, to avoid deprecation warnings. + """ + from ultralytics.utils.torch_utils import TORCH_1_13 + + if TORCH_1_13 and "weights_only" not in kwargs: + kwargs["weights_only"] = False + + return _torch_load(*args, **kwargs) + + +def torch_save(*args, **kwargs): + """ + Save PyTorch objects with retry mechanism for robustness. + + This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur + due to device flushing delays or antivirus scanning. + + Args: + *args (Any): Positional arguments to pass to torch.save. + **kwargs (Any): Keyword arguments to pass to torch.save. 
            time.sleep((2**i) / 2)  # exponential backoff: 0.5s, 1.0s, 2.0s
        >>> colors(5, True)  # returns (221, 111, 255), the BGR order of hex ff6fdd
        """Initialize the palette from the 20 Ultralytics brand hex color codes."""
+ limb_color (List[int]): Color palette for limbs. + kpt_color (List[int]): Color palette for keypoints. + dark_colors (set): Set of colors considered dark for text contrast. + light_colors (set): Set of colors considered light for text contrast. + + Examples: + >>> from ultralytics.utils.plotting import Annotator + >>> im0 = cv2.imread("test.png") + >>> annotator = Annotator(im0, line_width=10) + """ + + def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): + """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs.""" + non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic + input_is_pil = isinstance(im, Image.Image) + self.pil = pil or non_ascii or input_is_pil + self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2) + if self.pil: # use PIL + self.im = im if input_is_pil else Image.fromarray(im) + self.draw = ImageDraw.Draw(self.im) + try: + font = check_font("Arial.Unicode.ttf" if non_ascii else font) + size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12) + self.font = ImageFont.truetype(str(font), size) + except Exception: + self.font = ImageFont.load_default() + # Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string) + if check_version(pil_version, "9.2.0"): + self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height + else: # use cv2 + assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." 
+ self.im = im if im.flags.writeable else im.copy() + self.tf = max(self.lw - 1, 1) # font thickness + self.sf = self.lw / 3 # font scale + # Pose + self.skeleton = [ + [16, 14], + [14, 12], + [17, 15], + [15, 13], + [12, 13], + [6, 12], + [7, 13], + [6, 7], + [6, 8], + [7, 9], + [8, 10], + [9, 11], + [2, 3], + [1, 2], + [1, 3], + [2, 4], + [3, 5], + [4, 6], + [5, 7], + ] + + self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]] + self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]] + self.dark_colors = { + (235, 219, 11), + (243, 243, 243), + (183, 223, 0), + (221, 111, 255), + (0, 237, 204), + (68, 243, 0), + (255, 255, 0), + (179, 255, 1), + (11, 255, 162), + } + self.light_colors = { + (255, 42, 4), + (79, 68, 255), + (255, 0, 189), + (255, 180, 0), + (186, 0, 221), + (0, 192, 38), + (255, 36, 125), + (104, 0, 123), + (108, 27, 255), + (47, 109, 252), + (104, 31, 17), + } + + def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)): + """ + Assign text color based on background color. + + Args: + color (tuple, optional): The background color of the rectangle for text (B, G, R). + txt_color (tuple, optional): The color of the text (R, G, B). + + Returns: + (tuple): Text color for label. + + Examples: + >>> from ultralytics.utils.plotting import Annotator + >>> im0 = cv2.imread("test.png") + >>> annotator = Annotator(im0, line_width=10) + >>> annotator.get_txt_color(color=(104, 31, 17)) # return (255, 255, 255) + """ + if color in self.dark_colors: + return 104, 31, 17 + elif color in self.light_colors: + return 255, 255, 255 + else: + return txt_color + + def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): + """ + Draw a bounding box on an image with a given label. + + Args: + box (tuple): The bounding box coordinates (x1, y1, x2, y2). + label (str, optional): The text label to be displayed. 
+ color (tuple, optional): The background color of the rectangle (B, G, R). + txt_color (tuple, optional): The color of the text (R, G, B). + rotated (bool, optional): Whether the task is oriented bounding box detection. + + Examples: + >>> from ultralytics.utils.plotting import Annotator + >>> im0 = cv2.imread("test.png") + >>> annotator = Annotator(im0, line_width=10) + >>> annotator.box_label(box=[10, 20, 30, 40], label="person") + """ + txt_color = self.get_txt_color(color, txt_color) + if isinstance(box, torch.Tensor): + box = box.tolist() + if self.pil or not is_ascii(label): + if rotated: + p1 = box[0] + self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) # PIL requires tuple box + else: + p1 = (box[0], box[1]) + self.draw.rectangle(box, width=self.lw, outline=color) # box + if label: + w, h = self.font.getsize(label) # text width, height + outside = p1[1] >= h # label fits outside box + if p1[0] > self.im.size[0] - w: # size is (w, h), check if label extend beyond right side of image + p1 = self.im.size[0] - w, p1[1] + self.draw.rectangle( + (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1), + fill=color, + ) + # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0 + self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font) + else: # cv2 + if rotated: + p1 = [int(b) for b in box[0]] + cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw) # cv2 requires nparray box + else: + p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) + cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) + if label: + w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height + h += 3 # add pixels to pad text + outside = p1[1] >= h # label fits outside box + if p1[0] > self.im.shape[1] - w: # shape is (h, w), check if label extend beyond 
        mcs = masks_color.max(dim=0).values  # shape(h,w,3); max over n collapses the mask dimension
            kpts (torch.Tensor): Keypoints with shape [nkpt, ndim], where ndim is 2 (x, y) or 3 (x, y, confidence).
width=1): + """Add rectangle to image (PIL-only).""" + self.draw.rectangle(xy, fill, outline, width) + + def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False): + """ + Add text to an image using PIL or cv2. + + Args: + xy (List[int]): Top-left coordinates for text placement. + text (str): Text to be drawn. + txt_color (tuple, optional): Text color (R, G, B). + anchor (str, optional): Text anchor position ('top' or 'bottom'). + box_style (bool, optional): Whether to draw text with a background box. + """ + if anchor == "bottom": # start y from font bottom + w, h = self.font.getsize(text) # text width, height + xy[1] += 1 - h + if self.pil: + if box_style: + w, h = self.font.getsize(text) + self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color) + # Using `txt_color` for background and draw fg with white color + txt_color = (255, 255, 255) + if "\n" in text: + lines = text.split("\n") + _, h = self.font.getsize(text) + for line in lines: + self.draw.text(xy, line, fill=txt_color, font=self.font) + xy[1] += h + else: + self.draw.text(xy, text, fill=txt_color, font=self.font) + else: + if box_style: + w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height + h += 3 # add pixels to pad text + outside = xy[1] >= h # label fits outside box + p2 = xy[0] + w, xy[1] - h if outside else xy[1] + h + cv2.rectangle(self.im, xy, p2, txt_color, -1, cv2.LINE_AA) # filled + # Using `txt_color` for background and draw fg with white color + txt_color = (255, 255, 255) + cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA) + + def fromarray(self, im): + """Update self.im from a numpy array.""" + self.im = im if isinstance(im, Image.Image) else Image.fromarray(im) + self.draw = ImageDraw.Draw(self.im) + + def result(self): + """Return annotated image as array.""" + return np.asarray(self.im) + + def show(self, title=None): + """Show the annotated 
image.""" + im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert numpy array to PIL Image with RGB to BGR + if IS_COLAB or IS_KAGGLE: # can not use IS_JUPYTER as will run for all ipython environments + try: + display(im) # noqa - display() function only available in ipython environments + except ImportError as e: + LOGGER.warning(f"Unable to display image in Jupyter notebooks: {e}") + else: + im.show(title=title) + + def save(self, filename="image.jpg"): + """Save the annotated image to 'filename'.""" + cv2.imwrite(filename, np.asarray(self.im)) + + @staticmethod + def get_bbox_dimension(bbox=None): + """ + Calculate the dimensions and area of a bounding box. + + Args: + bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max). + + Returns: + width (float): Width of the bounding box. + height (float): Height of the bounding box. + area (float): Area enclosed by the bounding box. + + Examples: + >>> from ultralytics.utils.plotting import Annotator + >>> im0 = cv2.imread("test.png") + >>> annotator = Annotator(im0, line_width=10) + >>> annotator.get_bbox_dimension(bbox=[10, 20, 30, 40]) + """ + x_min, y_min, x_max, y_max = bbox + width = x_max - x_min + height = y_max - y_min + return width, height, width * height + + +@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395 +@plt_settings() +def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None): + """ + Plot training labels including class histograms and box statistics. + + Args: + boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height]. + cls (np.ndarray): Class indices. + names (dict, optional): Dictionary mapping class indices to class names. + save_dir (Path, optional): Directory to save the plot. + on_plot (Callable, optional): Function to call after plot is saved. 
+ """ + import pandas # scope for faster 'import ultralytics' + import seaborn # scope for faster 'import ultralytics' + + # Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings + warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight") + warnings.filterwarnings("ignore", category=FutureWarning) + + # Plot dataset labels + LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ") + nc = int(cls.max() + 1) # number of classes + boxes = boxes[:1000000] # limit to 1M boxes + x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"]) + + # Seaborn correlogram + seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9)) + plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200) + plt.close() + + # Matplotlib labels + ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() + y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) + for i in range(nc): + y[2].patches[i].set_color([x / 255 for x in colors(i)]) + ax[0].set_ylabel("instances") + if 0 < len(names) < 30: + ax[0].set_xticks(range(len(names))) + ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10) + else: + ax[0].set_xlabel("classes") + seaborn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9) + seaborn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9) + + # Rectangles + boxes[:, 0:2] = 0.5 # center + boxes = ops.xywh2xyxy(boxes) * 1000 + img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255) + for cls, box in zip(cls[:500], boxes[:500]): + ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot + ax[1].imshow(img) + ax[1].axis("off") + + for a in [0, 1, 2, 3]: + for s in ["top", "right", "left", "bottom"]: + ax[a].spines[s].set_visible(False) + + fname = save_dir / "labels.jpg" + plt.savefig(fname, dpi=200) + plt.close() + if on_plot: + on_plot(fname) + + +def 
save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True): + """ + Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop. + + This function takes a bounding box and an image, and then saves a cropped portion of the image according + to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding + adjustments to the bounding box. + + Args: + xyxy (torch.Tensor or list): A tensor or list representing the bounding box in xyxy format. + im (np.ndarray): The input image. + file (Path, optional): The path where the cropped image will be saved. + gain (float, optional): A multiplicative factor to increase the size of the bounding box. + pad (int, optional): The number of pixels to add to the width and height of the bounding box. + square (bool, optional): If True, the bounding box will be transformed into a square. + BGR (bool, optional): If True, the image will be saved in BGR format, otherwise in RGB. + save (bool, optional): If True, the cropped image will be saved to disk. + + Returns: + (np.ndarray): The cropped image. 
+ + Examples: + >>> from ultralytics.utils.plotting import save_one_box + >>> xyxy = [50, 50, 150, 150] + >>> im = cv2.imread("image.jpg") + >>> cropped_im = save_one_box(xyxy, im, file="cropped.jpg", square=True) + """ + if not isinstance(xyxy, torch.Tensor): # may be list + xyxy = torch.stack(xyxy) + b = ops.xyxy2xywh(xyxy.view(-1, 4)) # boxes + if square: + b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square + b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad + xyxy = ops.xywh2xyxy(b).long() + xyxy = ops.clip_boxes(xyxy, im.shape) + crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)] + if save: + file.parent.mkdir(parents=True, exist_ok=True) # make directory + f = str(increment_path(file).with_suffix(".jpg")) + # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue + Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB + return crop + + +@threaded +def plot_images( + images: Union[torch.Tensor, np.ndarray], + batch_idx: Union[torch.Tensor, np.ndarray], + cls: Union[torch.Tensor, np.ndarray], + bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32), + confs: Optional[Union[torch.Tensor, np.ndarray]] = None, + masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8), + kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32), + paths: Optional[List[str]] = None, + fname: str = "images.jpg", + names: Optional[Dict[int, str]] = None, + on_plot: Optional[Callable] = None, + max_size: int = 1920, + max_subplots: int = 16, + save: bool = True, + conf_thres: float = 0.25, +) -> Optional[np.ndarray]: + """ + Plot image grid with labels, bounding boxes, masks, and keypoints. + + Args: + images: Batch of images to plot. Shape: (batch_size, channels, height, width). + batch_idx: Batch indices for each detection. Shape: (num_detections,). 
+ cls: Class labels for each detection. Shape: (num_detections,). + bboxes: Bounding boxes for each detection. Shape: (num_detections, 4) or (num_detections, 5) for rotated boxes. + confs: Confidence scores for each detection. Shape: (num_detections,). + masks: Instance segmentation masks. Shape: (num_detections, height, width) or (1, height, width). + kpts: Keypoints for each detection. Shape: (num_detections, 51). + paths: List of file paths for each image in the batch. + fname: Output filename for the plotted image grid. + names: Dictionary mapping class indices to class names. + on_plot: Optional callback function to be called after saving the plot. + max_size: Maximum size of the output image grid. + max_subplots: Maximum number of subplots in the image grid. + save: Whether to save the plotted image grid to a file. + conf_thres: Confidence threshold for displaying detections. + + Returns: + (np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise. + + Note: + This function supports both tensor and numpy array inputs. It will automatically + convert tensor inputs to numpy arrays for processing. 
+ """ + if isinstance(images, torch.Tensor): + images = images.cpu().float().numpy() + if isinstance(cls, torch.Tensor): + cls = cls.cpu().numpy() + if isinstance(bboxes, torch.Tensor): + bboxes = bboxes.cpu().numpy() + if isinstance(masks, torch.Tensor): + masks = masks.cpu().numpy().astype(int) + if isinstance(kpts, torch.Tensor): + kpts = kpts.cpu().numpy() + if isinstance(batch_idx, torch.Tensor): + batch_idx = batch_idx.cpu().numpy() + + bs, _, h, w = images.shape # batch size, _, height, width + bs = min(bs, max_subplots) # limit plot images + ns = np.ceil(bs**0.5) # number of subplots (square) + if np.max(images[0]) <= 1: + images *= 255 # de-normalise (optional) + + # Build Image + mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init + for i in range(bs): + x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin + mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0) + + # Resize (optional) + scale = max_size / ns / max(h, w) + if scale < 1: + h = math.ceil(scale * h) + w = math.ceil(scale * w) + mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h))) + + # Annotate + fs = int((h + w) * ns * 0.01) # font size + fs = max(fs, 18) # ensure that the font size is large enough to be easily readable. 
+ annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names) + for i in range(bs): + x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin + annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders + if paths: + annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames + if len(cls) > 0: + idx = batch_idx == i + classes = cls[idx].astype("int") + labels = confs is None + + if len(bboxes): + boxes = bboxes[idx] + conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred) + if len(boxes): + if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1 + boxes[..., [0, 2]] *= w # scale to pixels + boxes[..., [1, 3]] *= h + elif scale < 1: # absolute coords need scale if image scales + boxes[..., :4] *= scale + boxes[..., 0] += x + boxes[..., 1] += y + is_obb = boxes.shape[-1] == 5 # xywhr + boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes) + for j, box in enumerate(boxes.astype(np.int64).tolist()): + c = classes[j] + color = colors(c) + c = names.get(c, c) if names else c + if labels or conf[j] > conf_thres: + label = f"{c}" if labels else f"{c} {conf[j]:.1f}" + annotator.box_label(box, label, color=color, rotated=is_obb) + + elif len(classes): + for c in classes: + color = colors(c) + c = names.get(c, c) if names else c + annotator.text((x, y), f"{c}", txt_color=color, box_style=True) + + # Plot keypoints + if len(kpts): + kpts_ = kpts[idx].copy() + if len(kpts_): + if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01: # if normalized with tolerance .01 + kpts_[..., 0] *= w # scale to pixels + kpts_[..., 1] *= h + elif scale < 1: # absolute coords need scale if image scales + kpts_ *= scale + kpts_[..., 0] += x + kpts_[..., 1] += y + for j in range(len(kpts_)): + if labels or conf[j] > conf_thres: + annotator.kpts(kpts_[j], conf_thres=conf_thres) + + # Plot masks + if len(masks): + if 
idx.shape[0] == masks.shape[0]: # overlap_masks=False + image_masks = masks[idx] + else: # overlap_masks=True + image_masks = masks[[i]] # (1, 640, 640) + nl = idx.sum() + index = np.arange(nl).reshape((nl, 1, 1)) + 1 + image_masks = np.repeat(image_masks, nl, axis=0) + image_masks = np.where(image_masks == index, 1.0, 0.0) + + im = np.asarray(annotator.im).copy() + for j in range(len(image_masks)): + if labels or conf[j] > conf_thres: + color = colors(classes[j]) + mh, mw = image_masks[j].shape + if mh != h or mw != w: + mask = image_masks[j].astype(np.uint8) + mask = cv2.resize(mask, (w, h)) + mask = mask.astype(bool) + else: + mask = image_masks[j].astype(bool) + try: + im[y : y + h, x : x + w, :][mask] = ( + im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6 + ) + except Exception: + pass + annotator.fromarray(im) + if not save: + return np.asarray(annotator.im) + annotator.im.save(fname) # save + if on_plot: + on_plot(fname) + + +@plt_settings() +def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None): + """ + Plot training results from a results CSV file. The function supports various types of data including segmentation, + pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located. + + Args: + file (str, optional): Path to the CSV file containing the training results. + dir (str, optional): Directory where the CSV file is located if 'file' is not provided. + segment (bool, optional): Flag to indicate if the data is for segmentation. + pose (bool, optional): Flag to indicate if the data is for pose estimation. + classify (bool, optional): Flag to indicate if the data is for classification. + on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument. 
+ + Examples: + >>> from ultralytics.utils.plotting import plot_results + >>> plot_results("path/to/results.csv", segment=True) + """ + import pandas as pd # scope for faster 'import ultralytics' + from scipy.ndimage import gaussian_filter1d + + save_dir = Path(file).parent if file else Path(dir) + if classify: + fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True) + index = [2, 5, 3, 4] + elif segment: + fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True) + index = [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13] + elif pose: + fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True) + index = [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14] + else: + fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True) + index = [2, 3, 4, 5, 6, 9, 10, 11, 7, 8] + ax = ax.ravel() + files = list(save_dir.glob("results*.csv")) + assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot." + for f in files: + try: + data = pd.read_csv(f) + s = [x.strip() for x in data.columns] + x = data.values[:, 0] + for i, j in enumerate(index): + y = data.values[:, j].astype("float") + # y[y == 0] = np.nan # don't show zero values + ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results + ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line + ax[i].set_title(s[j], fontsize=12) + # if j in {8, 9, 10}: # share train and val loss y axes + # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5]) + except Exception as e: + LOGGER.warning(f"WARNING: Plotting error for {f}: {e}") + ax[1].legend() + fname = save_dir / "results.png" + fig.savefig(fname, dpi=200) + plt.close() + if on_plot: + on_plot(fname) + + +def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"): + """ + Plot a scatter plot with points colored based on a 2D histogram. + + Args: + v (array-like): Values for the x-axis. 
+ f (array-like): Values for the y-axis. + bins (int, optional): Number of bins for the histogram. + cmap (str, optional): Colormap for the scatter plot. + alpha (float, optional): Alpha for the scatter plot. + edgecolors (str, optional): Edge colors for the scatter plot. + + Examples: + >>> v = np.random.rand(100) + >>> f = np.random.rand(100) + >>> plt_color_scatter(v, f) + """ + # Calculate 2D histogram and corresponding colors + hist, xedges, yedges = np.histogram2d(v, f, bins=bins) + colors = [ + hist[ + min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1), + min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1), + ] + for i in range(len(v)) + ] + + # Scatter plot + plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors) + + +def plot_tune_results(csv_file="tune_results.csv"): + """ + Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key + in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots. + + Args: + csv_file (str, optional): Path to the CSV file containing the tuning results. 
+ + Examples: + >>> plot_tune_results("path/to/tune_results.csv") + """ + import pandas as pd # scope for faster 'import ultralytics' + from scipy.ndimage import gaussian_filter1d + + def _save_one_file(file): + """Save one matplotlib plot to 'file'.""" + plt.savefig(file, dpi=200) + plt.close() + LOGGER.info(f"Saved {file}") + + # Scatter plots for each hyperparameter + csv_file = Path(csv_file) + data = pd.read_csv(csv_file) + num_metrics_columns = 1 + keys = [x.strip() for x in data.columns][num_metrics_columns:] + x = data.values + fitness = x[:, 0] # fitness + j = np.argmax(fitness) # max fitness index + n = math.ceil(len(keys) ** 0.5) # columns and rows in plot + plt.figure(figsize=(10, 10), tight_layout=True) + for i, k in enumerate(keys): + v = x[:, i + num_metrics_columns] + mu = v[j] # best single result + plt.subplot(n, n, i + 1) + plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none") + plt.plot(mu, fitness.max(), "k+", markersize=15) + plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9}) # limit to 40 characters + plt.tick_params(axis="both", labelsize=8) # Set axis label size to 8 + if i % n != 0: + plt.yticks([]) + _save_one_file(csv_file.with_name("tune_scatter_plots.png")) + + # Fitness vs iteration + x = range(1, len(fitness) + 1) + plt.figure(figsize=(10, 6), tight_layout=True) + plt.plot(x, fitness, marker="o", linestyle="none", label="fitness") + plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2) # smoothing line + plt.title("Fitness vs Iteration") + plt.xlabel("Iteration") + plt.ylabel("Fitness") + plt.grid(True) + plt.legend() + _save_one_file(csv_file.with_name("tune_fitness.png")) + + +def output_to_target(output, max_det=300): + """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting.""" + targets = [] + for i, o in enumerate(output): + box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1) + j = torch.full((conf.shape[0], 1), i) + 
targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1)) + targets = torch.cat(targets, 0).numpy() + return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1] + + +def output_to_rotated_target(output, max_det=300): + """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting.""" + targets = [] + for i, o in enumerate(output): + box, conf, cls, angle = o[:max_det].cpu().split((4, 1, 1, 1), 1) + j = torch.full((conf.shape[0], 1), i) + targets.append(torch.cat((j, cls, box, angle, conf), 1)) + targets = torch.cat(targets, 0).numpy() + return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1] + + +def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")): + """ + Visualize feature maps of a given model module during inference. + + Args: + x (torch.Tensor): Features to be visualized. + module_type (str): Module type. + stage (int): Module stage within the model. + n (int, optional): Maximum number of feature maps to plot. + save_dir (Path, optional): Directory to save results. + """ + for m in {"Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"}: # all model heads + if m in module_type: + return + if isinstance(x, torch.Tensor): + _, channels, height, width = x.shape # batch, channels, height, width + if height > 1 and width > 1: + f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename + + blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels + n = min(n, channels) # number of plots + _, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols + ax = ax.ravel() + plt.subplots_adjust(wspace=0.05, hspace=0.05) + for i in range(n): + ax[i].imshow(blocks[i].squeeze()) # cmap='gray' + ax[i].axis("off") + + LOGGER.info(f"Saving {f}... 
({n}/{channels})") + plt.savefig(f, dpi=300, bbox_inches="tight") + plt.close() + np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy()) # npy save diff --git a/tracking/ultralytics/utils/tal.py b/tracking/ultralytics/utils/tal.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa4801fced98383d1a617d8470785603703a0d6 --- /dev/null +++ b/tracking/ultralytics/utils/tal.py @@ -0,0 +1,416 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import torch +import torch.nn as nn + +from . import LOGGER +from .checks import check_version +from .metrics import bbox_iou, probiou +from .ops import xywhr2xyxyxyxy + +TORCH_1_10 = check_version(torch.__version__, "1.10.0") + + +class TaskAlignedAssigner(nn.Module): + """ + A task-aligned assigner for object detection. + + This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both + classification and localization information. + + Attributes: + topk (int): The number of top candidates to consider. + num_classes (int): The number of object classes. + bg_idx (int): Background class index. + alpha (float): The alpha parameter for the classification component of the task-aligned metric. + beta (float): The beta parameter for the localization component of the task-aligned metric. + eps (float): A small value to prevent division by zero. + """ + + def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9): + """Initialize a TaskAlignedAssigner object with customizable hyperparameters.""" + super().__init__() + self.topk = topk + self.num_classes = num_classes + self.bg_idx = num_classes + self.alpha = alpha + self.beta = beta + self.eps = eps + + @torch.no_grad() + def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt): + """ + Compute the task-aligned assignment. + + Args: + pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes). 
+ pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4). + anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2). + gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1). + gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4). + mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1). + + Returns: + target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors). + target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4). + target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes). + fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors). + target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors). + + References: + https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py + """ + self.bs = pd_scores.shape[0] + self.n_max_boxes = gt_bboxes.shape[1] + device = gt_bboxes.device + + if self.n_max_boxes == 0: + return ( + torch.full_like(pd_scores[..., 0], self.bg_idx), + torch.zeros_like(pd_bboxes), + torch.zeros_like(pd_scores), + torch.zeros_like(pd_scores[..., 0]), + torch.zeros_like(pd_scores[..., 0]), + ) + + try: + return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt) + except torch.OutOfMemoryError: + # Move tensors to CPU, compute, then move back to original device + LOGGER.warning("WARNING: CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU") + cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)] + result = self._forward(*cpu_tensors) + return tuple(t.to(device) for t in result) + + def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt): + """ + Compute the task-aligned assignment. 
+ + Args: + pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes). + pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4). + anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2). + gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1). + gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4). + mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1). + + Returns: + target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors). + target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4). + target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes). + fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors). + target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors). + """ + mask_pos, align_metric, overlaps = self.get_pos_mask( + pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt + ) + + target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes) + + # Assigned target + target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask) + + # Normalize + align_metric *= mask_pos + pos_align_metrics = align_metric.amax(dim=-1, keepdim=True) # b, max_num_obj + pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True) # b, max_num_obj + norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1) + target_scores = target_scores * norm_align_metric + + return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx + + def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt): + """ + Get positive mask for each ground truth box. 
+ + Args: + pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes). + pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4). + gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1). + gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4). + anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2). + mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1). + + Returns: + mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w). + align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w). + overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w). + """ + mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes) + # Get anchor_align metric, (b, max_num_obj, h*w) + align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt) + # Get topk_metric mask, (b, max_num_obj, h*w) + mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool()) + # Merge all mask to a final mask, (b, max_num_obj, h*w) + mask_pos = mask_topk * mask_in_gts * mask_gt + + return mask_pos, align_metric, overlaps + + def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt): + """ + Compute alignment metric given predicted and ground truth bounding boxes. + + Args: + pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes). + pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4). + gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1). + gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4). 
+ mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, h*w). + + Returns: + align_metric (torch.Tensor): Alignment metric combining classification and localization. + overlaps (torch.Tensor): IoU overlaps between predicted and ground truth boxes. + """ + na = pd_bboxes.shape[-2] + mask_gt = mask_gt.bool() # b, max_num_obj, h*w + overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device) + bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device) + + ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj + ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes) # b, max_num_obj + ind[1] = gt_labels.squeeze(-1) # b, max_num_obj + # Get the scores of each grid for each gt cls + bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt] # b, max_num_obj, h*w + + # (b, max_num_obj, 1, 4), (b, 1, h*w, 4) + pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt] + gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt] + overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes) + + align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta) + return align_metric, overlaps + + def iou_calculation(self, gt_bboxes, pd_bboxes): + """ + Calculate IoU for horizontal bounding boxes. + + Args: + gt_bboxes (torch.Tensor): Ground truth boxes. + pd_bboxes (torch.Tensor): Predicted boxes. + + Returns: + (torch.Tensor): IoU values between each pair of boxes. + """ + return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0) + + def select_topk_candidates(self, metrics, largest=True, topk_mask=None): + """ + Select the top-k candidates based on the given metrics. 
+ + Args: + metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, + max_num_obj is the maximum number of objects, and h*w represents the + total number of anchor points. + largest (bool): If True, select the largest values; otherwise, select the smallest values. + topk_mask (torch.Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where + topk is the number of top candidates to consider. If not provided, + the top-k values are automatically computed based on the given metrics. + + Returns: + (torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates. + """ + # (b, max_num_obj, topk) + topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest) + if topk_mask is None: + topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs) + # (b, max_num_obj, topk) + topk_idxs.masked_fill_(~topk_mask, 0) + + # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w) + count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device) + ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device) + for k in range(self.topk): + # Expand topk_idxs for each value of k and add 1 at the specified positions + count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones) + # Filter invalid bboxes + count_tensor.masked_fill_(count_tensor > 1, 0) + + return count_tensor.to(metrics.dtype) + + def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask): + """ + Compute target labels, target bounding boxes, and target scores for the positive anchor points. + + Args: + gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the + batch size and max_num_obj is the maximum number of objects. + gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4). 
+ target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive + anchor points, with shape (b, h*w), where h*w is the total + number of anchor points. + fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive + (foreground) anchor points. + + Returns: + target_labels (torch.Tensor): Shape (b, h*w), containing the target labels for positive anchor points. + target_bboxes (torch.Tensor): Shape (b, h*w, 4), containing the target bounding boxes for positive + anchor points. + target_scores (torch.Tensor): Shape (b, h*w, num_classes), containing the target scores for positive + anchor points. + """ + # Assigned target labels, (b, 1) + batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None] + target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w) + target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w) + + # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4) + target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx] + + # Assigned target scores + target_labels.clamp_(0) + + # 10x faster than F.one_hot() + target_scores = torch.zeros( + (target_labels.shape[0], target_labels.shape[1], self.num_classes), + dtype=torch.int64, + device=target_labels.device, + ) # (b, h*w, 80) + target_scores.scatter_(2, target_labels.unsqueeze(-1), 1) + + fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80) + target_scores = torch.where(fg_scores_mask > 0, target_scores, 0) + + return target_labels, target_bboxes, target_scores + + @staticmethod + def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9): + """ + Select positive anchor centers within ground truth bounding boxes. + + Args: + xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2). + gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4). + eps (float, optional): Small value for numerical stability. Defaults to 1e-9. 
+ + Returns: + (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w). + + Note: + b: batch size, n_boxes: number of ground truth boxes, h: height, w: width. + Bounding box format: [x_min, y_min, x_max, y_max]. + """ + n_anchors = xy_centers.shape[0] + bs, n_boxes, _ = gt_bboxes.shape + lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom + bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1) + return bbox_deltas.amin(3).gt_(eps) + + @staticmethod + def select_highest_overlaps(mask_pos, overlaps, n_max_boxes): + """ + Select anchor boxes with highest IoU when assigned to multiple ground truths. + + Args: + mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w). + overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w). + n_max_boxes (int): Maximum number of ground truth boxes. + + Returns: + target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w). + fg_mask (torch.Tensor): Foreground mask, shape (b, h*w). + mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w). 
+ """ + # Convert (b, n_max_boxes, h*w) -> (b, h*w) + fg_mask = mask_pos.sum(-2) + if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes + mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w) + max_overlaps_idx = overlaps.argmax(1) # (b, h*w) + + is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device) + is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1) + + mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w) + fg_mask = mask_pos.sum(-2) + # Find each grid serve which gt(index) + target_gt_idx = mask_pos.argmax(-2) # (b, h*w) + return target_gt_idx, fg_mask, mask_pos + + +class RotatedTaskAlignedAssigner(TaskAlignedAssigner): + """Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric.""" + + def iou_calculation(self, gt_bboxes, pd_bboxes): + """Calculate IoU for rotated bounding boxes.""" + return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0) + + @staticmethod + def select_candidates_in_gts(xy_centers, gt_bboxes): + """ + Select the positive anchor center in gt for rotated bounding boxes. + + Args: + xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2). + gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5). + + Returns: + (torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w). 
+ """ + # (b, n_boxes, 5) --> (b, n_boxes, 4, 2) + corners = xywhr2xyxyxyxy(gt_bboxes) + # (b, n_boxes, 1, 2) + a, b, _, d = corners.split(1, dim=-2) + ab = b - a + ad = d - a + + # (b, n_boxes, h*w, 2) + ap = xy_centers - a + norm_ab = (ab * ab).sum(dim=-1) + norm_ad = (ad * ad).sum(dim=-1) + ap_dot_ab = (ap * ab).sum(dim=-1) + ap_dot_ad = (ap * ad).sum(dim=-1) + return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad) # is_in_box + + +def make_anchors(feats, strides, grid_cell_offset=0.5): + """Generate anchors from features.""" + anchor_points, stride_tensor = [], [] + assert feats is not None + dtype, device = feats[0].dtype, feats[0].device + for i, stride in enumerate(strides): + h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1])) + sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x + sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y + sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx) + anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2)) + stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) + return torch.cat(anchor_points), torch.cat(stride_tensor) + + +def dist2bbox(distance, anchor_points, xywh=True, dim=-1): + """Transform distance(ltrb) to box(xywh or xyxy).""" + lt, rb = distance.chunk(2, dim) + x1y1 = anchor_points - lt + x2y2 = anchor_points + rb + if xywh: + c_xy = (x1y1 + x2y2) / 2 + wh = x2y2 - x1y1 + return torch.cat((c_xy, wh), dim) # xywh bbox + return torch.cat((x1y1, x2y2), dim) # xyxy bbox + + +def bbox2dist(anchor_points, bbox, reg_max): + """Transform bbox(xyxy) to dist(ltrb).""" + x1y1, x2y2 = bbox.chunk(2, -1) + return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01) # dist (lt, rb) + + +def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1): + """ + Decode predicted rotated bounding 
box coordinates from anchor points and distribution. + + Args: + pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4). + pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1). + anchor_points (torch.Tensor): Anchor points with shape (h*w, 2). + dim (int, optional): Dimension along which to split. Defaults to -1. + + Returns: + (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4). + """ + lt, rb = pred_dist.split(2, dim=dim) + cos, sin = torch.cos(pred_angle), torch.sin(pred_angle) + # (bs, h*w, 1) + xf, yf = ((rb - lt) / 2).split(1, dim=dim) + x, y = xf * cos - yf * sin, xf * sin + yf * cos + xy = torch.cat([x, y], dim=dim) + anchor_points + return torch.cat([xy, lt + rb], dim=dim) diff --git a/tracking/ultralytics/utils/torch_utils.py b/tracking/ultralytics/utils/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5da31c1e61cc8ddf54ce10f734e58706412a162a --- /dev/null +++ b/tracking/ultralytics/utils/torch_utils.py @@ -0,0 +1,959 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import gc +import math +import os +import random +import time +from contextlib import contextmanager +from copy import deepcopy +from datetime import datetime +from pathlib import Path +from typing import Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from ultralytics.utils import ( + DEFAULT_CFG_DICT, + DEFAULT_CFG_KEYS, + LOGGER, + NUM_THREADS, + PYTHON_VERSION, + TORCHVISION_VERSION, + WINDOWS, + __version__, + colorstr, +) +from ultralytics.utils.checks import check_version + +try: + import thop +except ImportError: + thop = None # conda support without 'ultralytics-thop' installed + +# Version checks (all default to version>=min_version) +TORCH_1_9 = check_version(torch.__version__, "1.9.0") +TORCH_1_13 = check_version(torch.__version__, "1.13.0") +TORCH_2_0 = 
check_version(torch.__version__, "2.0.0") +TORCH_2_4 = check_version(torch.__version__, "2.4.0") +TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0") +TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0") +TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0") +TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0") +if WINDOWS and check_version(torch.__version__, "==2.4.0"): # reject version 2.4.0 on Windows + LOGGER.warning( + "WARNING ⚠️ Known issue with torch==2.4.0 on Windows with CPU, recommend upgrading to torch>=2.4.1 to resolve " + "https://github.com/ultralytics/ultralytics/issues/15049" + ) + + +@contextmanager +def torch_distributed_zero_first(local_rank: int): + """Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first.""" + initialized = dist.is_available() and dist.is_initialized() + + if initialized and local_rank not in {-1, 0}: + dist.barrier(device_ids=[local_rank]) + yield + if initialized and local_rank == 0: + dist.barrier(device_ids=[local_rank]) + + +def smart_inference_mode(): + """Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator.""" + + def decorate(fn): + """Applies appropriate torch decorator for inference mode based on torch version.""" + if TORCH_1_9 and torch.is_inference_mode_enabled(): + return fn # already in inference_mode, act as a pass-through + else: + return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn) + + return decorate + + +def autocast(enabled: bool, device: str = "cuda"): + """ + Get the appropriate autocast context manager based on PyTorch version and AMP setting. + + This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both + older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions. + + Args: + enabled (bool): Whether to enable automatic mixed precision. 
+ device (str, optional): The device to use for autocast. Defaults to 'cuda'. + + Returns: + (torch.amp.autocast): The appropriate autocast context manager. + + Notes: + - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`. + - For older versions, it uses `torch.cuda.autocast`. + + Examples: + >>> with autocast(enabled=True): + ... # Your mixed precision operations here + ... pass + """ + if TORCH_1_13: + return torch.amp.autocast(device, enabled=enabled) + else: + return torch.cuda.amp.autocast(enabled) + + +def get_cpu_info(): + """Return a string with system CPU information, i.e. 'Apple M2'.""" + from ultralytics.utils import PERSISTENT_CACHE # avoid circular import error + + if "cpu_info" not in PERSISTENT_CACHE: + try: + import cpuinfo # pip install py-cpuinfo + + k = "brand_raw", "hardware_raw", "arch_string_raw" # keys sorted by preference + info = cpuinfo.get_cpu_info() # info dict + string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown") + PERSISTENT_CACHE["cpu_info"] = string.replace("(R)", "").replace("CPU ", "").replace("@ ", "") + except Exception: + pass + return PERSISTENT_CACHE.get("cpu_info", "unknown") + + +def get_gpu_info(index): + """Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'.""" + properties = torch.cuda.get_device_properties(index) + return f"{properties.name}, {properties.total_memory / (1 << 20):.0f}MiB" + + +def select_device(device="", batch=0, newline=False, verbose=True): + """ + Select the appropriate PyTorch device based on the provided arguments. + + The function takes a string specifying the device or a torch.device object and returns a torch.device object + representing the selected device. The function also validates the number of available devices and raises an + exception if the requested device(s) are not available. + + Args: + device (str | torch.device, optional): Device string or torch.device object. 
+ Options are 'None', 'cpu', or 'cuda', or '0' or '0,1,2,3'. Defaults to an empty string, which auto-selects + the first available GPU, or CPU if no GPU is available. + batch (int, optional): Batch size being used in your model. Defaults to 0. + newline (bool, optional): If True, adds a newline at the end of the log string. Defaults to False. + verbose (bool, optional): If True, logs the device information. Defaults to True. + + Returns: + (torch.device): Selected device. + + Raises: + ValueError: If the specified device is not available or if the batch size is not a multiple of the number of + devices when using multiple GPUs. + + Examples: + >>> select_device("cuda:0") + device(type='cuda', index=0) + + >>> select_device("cpu") + device(type='cpu') + + Note: + Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use. + """ + if isinstance(device, torch.device) or str(device).startswith("tpu"): + return device + + s = f"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} " + device = str(device).lower() + for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ": + device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1' + cpu = device == "cpu" + mps = device in {"mps", "mps:0"} # Apple Metal Performance Shaders (MPS) + if cpu or mps: + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False + elif device: # non-cpu device requested + if device == "cuda": + device = "0" + if "," in device: + device = ",".join([x for x in device.split(",") if x]) # remove sequential commas, i.e. 
"0,,1" -> "0,1" + visible = os.environ.get("CUDA_VISIBLE_DEVICES", None) + os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available() + if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))): + LOGGER.info(s) + install = ( + "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no " + "CUDA devices are seen by torch.\n" + if torch.cuda.device_count() == 0 + else "" + ) + raise ValueError( + f"Invalid CUDA 'device={device}' requested." + f" Use 'device=cpu' or pass valid CUDA device(s) if available," + f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n" + f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}" + f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}" + f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n" + f"{install}" + ) + + if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available + devices = device.split(",") if device else "0" # i.e. "0,1" -> ["0", "1"] + n = len(devices) # device count + if n > 1: # multi-GPU + if batch < 1: + raise ValueError( + "AutoBatch with batch<1 not supported for Multi-GPU training, " + "please specify a valid batch size, i.e. batch=16." + ) + if batch >= 0 and batch % n != 0: # check batch_size is divisible by device_count + raise ValueError( + f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or " + f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}." 
+ ) + space = " " * (len(s) + 1) + for i, d in enumerate(devices): + s += f"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\n" # bytes to MB + arg = "cuda:0" + elif mps and TORCH_2_0 and torch.backends.mps.is_available(): + # Prefer MPS if available + s += f"MPS ({get_cpu_info()})\n" + arg = "mps" + else: # revert to CPU + s += f"CPU ({get_cpu_info()})\n" + arg = "cpu" + + if arg in {"cpu", "mps"}: + torch.set_num_threads(NUM_THREADS) # reset OMP_NUM_THREADS for cpu training + if verbose: + LOGGER.info(s if newline else s.rstrip()) + return torch.device(arg) + + +def time_sync(): + """PyTorch-accurate time.""" + if torch.cuda.is_available(): + torch.cuda.synchronize() + return time.time() + + +def fuse_conv_and_bn(conv, bn): + """Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/.""" + fusedconv = ( + nn.Conv2d( + conv.in_channels, + conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + dilation=conv.dilation, + groups=conv.groups, + bias=True, + ) + .requires_grad_(False) + .to(conv.weight.device) + ) + + # Prepare filters + w_conv = conv.weight.view(conv.out_channels, -1) + w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) + fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape)) + + # Prepare spatial bias + b_conv = torch.zeros(conv.weight.shape[0], device=conv.weight.device) if conv.bias is None else conv.bias + b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) + fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) + + return fusedconv + + +def fuse_deconv_and_bn(deconv, bn): + """Fuse ConvTranspose2d() and BatchNorm2d() layers.""" + fuseddconv = ( + nn.ConvTranspose2d( + deconv.in_channels, + deconv.out_channels, + kernel_size=deconv.kernel_size, + stride=deconv.stride, + padding=deconv.padding, + output_padding=deconv.output_padding, + dilation=deconv.dilation, + 
groups=deconv.groups, + bias=True, + ) + .requires_grad_(False) + .to(deconv.weight.device) + ) + + # Prepare filters + w_deconv = deconv.weight.view(deconv.out_channels, -1) + w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) + fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape)) + + # Prepare spatial bias + b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias + b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) + fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) + + return fuseddconv + + +def model_info(model, detailed=False, verbose=True, imgsz=640): + """ + Print and return detailed model information layer by layer. + + Args: + model (nn.Module): Model to analyze. + detailed (bool, optional): Whether to print detailed layer information. Defaults to False. + verbose (bool, optional): Whether to print model information. Defaults to True. + imgsz (int | List, optional): Input image size. Defaults to 640. + + Returns: + (Tuple[int, int, int, float]): Number of layers, parameters, gradients, and GFLOPs. 
    """
    if not verbose:
        return
    n_p = get_num_params(model)  # number of parameters
    n_g = get_num_gradients(model)  # number of gradients
    # Leaf modules only (no child modules) define the "layer" count reported in the summary
    layers = __import__("collections").OrderedDict((n, m) for n, m in model.named_modules() if len(m._modules) == 0)
    n_l = len(layers)  # number of layers
    if detailed:
        # NOTE(review): header has 8 labeled columns, but parameter rows below print a 9th trailing
        # dtype field (width 15) that is unlabeled here — confirm whether a 'dtype' header is intended.
        h = f"{'layer':>5}{'name':>40}{'type':>20}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}"
        LOGGER.info(h)
        for i, (mn, m) in enumerate(layers.items()):
            mn = mn.replace("module_list.", "")  # shorten module path for display
            mt = m.__class__.__name__
            if len(m._parameters):  # layer owns learnable parameters: one row per parameter tensor
                for pn, p in m.named_parameters():
                    LOGGER.info(
                        f"{i:>5g}{f'{mn}.{pn}':>40}{mt:>20}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20}{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype).replace('torch.', ''):>15}"
                    )
            else:  # layers with no learnable params
                LOGGER.info(f"{i:>5g}{mn:>40}{mt:>20}{False!r:>10}{0:>12g}{str([]):>20}{'-':>10}{'-':>10}{'-':>15}")

    flops = get_flops(model, imgsz)  # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
    fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
    fs = f", {flops:.1f} GFLOPs" if flops else ""  # omit GFLOPs suffix when FLOPs could not be computed (0.0)
    yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
    model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
    LOGGER.info(f"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}")
    return n_l, n_p, n_g, flops


def get_num_params(model):
    """Return the total number of parameters in a YOLO model."""
    return sum(x.numel() for x in model.parameters())


def get_num_gradients(model):
    """Return the total number of parameters with gradients (requires_grad=True) in a YOLO model."""
    return sum(x.numel() for x in model.parameters() if x.requires_grad)


def model_info_for_loggers(trainer):
    """
    Return model info dict with useful model information.
+ + Args: + trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing model and validation data. + + Returns: + (dict): Dictionary containing model parameters, GFLOPs, and inference speeds. + + Examples: + YOLOv8n info for loggers + >>> results = { + ... "model/parameters": 3151904, + ... "model/GFLOPs": 8.746, + ... "model/speed_ONNX(ms)": 41.244, + ... "model/speed_TensorRT(ms)": 3.211, + ... "model/speed_PyTorch(ms)": 18.755, + ...} + """ + if trainer.args.profile: # profile ONNX and TensorRT times + from ultralytics.utils.benchmarks import ProfileModels + + results = ProfileModels([trainer.last], device=trainer.device).profile()[0] + results.pop("model/name") + else: # only return PyTorch times from most recent validation + results = { + "model/parameters": get_num_params(trainer.model), + "model/GFLOPs": round(get_flops(trainer.model), 3), + } + results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3) + return results + + +def get_flops(model, imgsz=640): + """ + Return a YOLO model's FLOPs. + + Args: + model (nn.Module): The model to calculate FLOPs for. + imgsz (int | List[int], optional): Input image size. Defaults to 640. + + Returns: + (float): The model's FLOPs in billions. + """ + if not thop: + return 0.0 # if not installed return 0.0 GFLOPs + + try: + model = de_parallel(model) + p = next(model.parameters()) + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] # expand if int/float + try: + # Use stride size for input tensor + stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32 # max stride + im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format + flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # stride GFLOPs + return flops * imgsz[0] / stride * imgsz[1] / stride # imgsz GFLOPs + except Exception: + # Use actual image size for input tensor (i.e. 
required for RTDETR models) + im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format + return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # imgsz GFLOPs + except Exception: + return 0.0 + + +def get_flops_with_torch_profiler(model, imgsz=640): + """ + Compute model FLOPs using torch profiler (alternative to thop package, but 2-10x slower). + + Args: + model (nn.Module): The model to calculate FLOPs for. + imgsz (int | List[int], optional): Input image size. Defaults to 640. + + Returns: + (float): The model's FLOPs in billions. + """ + if not TORCH_2_0: # torch profiler implemented in torch>=2.0 + return 0.0 + model = de_parallel(model) + p = next(model.parameters()) + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] # expand if int/float + try: + # Use stride size for input tensor + stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2 # max stride + im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format + with torch.profiler.profile(with_flops=True) as prof: + model(im) + flops = sum(x.flops for x in prof.key_averages()) / 1e9 + flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs + except Exception: + # Use actual image size for input tensor (i.e. 
required for RTDETR models) + im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format + with torch.profiler.profile(with_flops=True) as prof: + model(im) + flops = sum(x.flops for x in prof.key_averages()) / 1e9 + return flops + + +def initialize_weights(model): + """Initialize model weights to random values.""" + for m in model.modules(): + t = type(m) + if t is nn.Conv2d: + pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif t is nn.BatchNorm2d: + m.eps = 1e-3 + m.momentum = 0.03 + elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}: + m.inplace = True + + +def scale_img(img, ratio=1.0, same_shape=False, gs=32): + """ + Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple. + + Args: + img (torch.Tensor): Input image tensor. + ratio (float, optional): Scaling ratio. Defaults to 1.0. + same_shape (bool, optional): Whether to maintain the same shape. Defaults to False. + gs (int, optional): Grid size for padding. Defaults to 32. + + Returns: + (torch.Tensor): Scaled and padded image tensor. + """ + if ratio == 1.0: + return img + h, w = img.shape[2:] + s = (int(h * ratio), int(w * ratio)) # new size + img = F.interpolate(img, size=s, mode="bilinear", align_corners=False) # resize + if not same_shape: # pad/crop img + h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w)) + return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean + + +def copy_attr(a, b, include=(), exclude=()): + """ + Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes. + + Args: + a (object): Destination object to copy attributes to. + b (object): Source object to copy attributes from. + include (tuple, optional): Attributes to include. If empty, all attributes are included. Defaults to (). + exclude (tuple, optional): Attributes to exclude. Defaults to (). 
    """
    for k, v in b.__dict__.items():
        # Skip: attributes outside an explicit include list, private (underscore) attributes,
        # and anything explicitly excluded.
        if (len(include) and k not in include) or k.startswith("_") or k in exclude:
            continue
        else:
            setattr(a, k, v)


def get_latest_opset():
    """
    Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity.

    Returns:
        (int): The ONNX opset version.
    """
    if TORCH_1_13:
        # If the PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset'
        # attribute names exposed on torch.onnx (e.g. 'symbolic_opset18' -> 18).
        return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1
    # Otherwise for PyTorch<=1.12 return the corresponding predefined opset
    version = torch.onnx.producer_version.rsplit(".", 1)[0]  # i.e. '2.3'
    return {"1.12": 15, "1.11": 14, "1.10": 13, "1.9": 12, "1.8": 12}.get(version, 12)


def intersect_dicts(da, db, exclude=()):
    """
    Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values.

    Args:
        da (dict): First dictionary.
        db (dict): Second dictionary.
        exclude (tuple, optional): Key substrings to exclude (substring match, not exact). Defaults to ().

    Returns:
        (dict): Dictionary of intersecting keys with matching shapes.
    """
    return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}


def is_parallel(model):
    """
    Returns True if model is of type DP or DDP.

    Args:
        model (nn.Module): Model to check.

    Returns:
        (bool): True if model is DataParallel or DistributedDataParallel.
    """
    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))


def de_parallel(model):
    """
    De-parallelize a model: returns single-GPU model if model is of type DP or DDP.

    Args:
        model (nn.Module): Model to de-parallelize.

    Returns:
        (nn.Module): De-parallelized model.
    """
    return model.module if is_parallel(model) else model


def one_cycle(y1=0.0, y2=1.0, steps=100):
    """
    Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.

    Args:
        y1 (float, optional): Initial value. Defaults to 0.0.
        y2 (float, optional): Final value. Defaults to 1.0.
        steps (int, optional): Number of steps. Defaults to 100.

    Returns:
        (function): Lambda function for computing the sinusoidal ramp.
    """
    return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1


def init_seeds(seed=0, deterministic=False):
    """
    Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html.

    Args:
        seed (int, optional): Random seed. Defaults to 0.
        deterministic (bool, optional): Whether to set deterministic algorithms. Defaults to False.
    """
    # Seed every RNG the training stack may draw from: Python, NumPy, and torch (CPU + all CUDA devices)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for Multi-GPU, exception safe
    # torch.backends.cudnn.benchmark = True  # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
    if deterministic:
        if TORCH_2_0:
            torch.use_deterministic_algorithms(True, warn_only=True)  # warn if deterministic is not possible
            torch.backends.cudnn.deterministic = True
            # cuBLAS requires this env var for deterministic behavior on CUDA >= 10.2
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
            os.environ["PYTHONHASHSEED"] = str(seed)
        else:
            LOGGER.warning("WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.")
    else:
        # Explicitly roll back any previously-applied deterministic configuration
        unset_deterministic()


def unset_deterministic():
    """Unsets all the configurations applied for deterministic training."""
    torch.use_deterministic_algorithms(False)
    torch.backends.cudnn.deterministic = False
    os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
    os.environ.pop("PYTHONHASHSEED", None)


class ModelEMA:
    """
    Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models.
    Keeps a moving average of everything in the model state_dict (parameters and buffers).
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    To disable EMA set the `enabled` attribute to `False`.

    Attributes:
        ema (nn.Module): Copy of the model in evaluation mode.
        updates (int): Number of EMA updates.
        decay (function): Decay function that determines the EMA weight.
        enabled (bool): Whether EMA is enabled.
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        """
        Initialize EMA for 'model' with given arguments.

        Args:
            model (nn.Module): Model to create EMA for.
            decay (float, optional): Maximum EMA decay rate. Defaults to 0.9999.
            tau (int, optional): EMA decay time constant. Defaults to 2000.
            updates (int, optional): Initial number of updates. Defaults to 0.
        """
        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        # Decay ramps from 0 toward `decay` as updates grow, so early epochs track the model closely
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)  # EMA weights are never trained directly
        self.enabled = True

    def update(self, model):
        """
        Update EMA parameters.

        Args:
            model (nn.Module): Model to update EMA from.
        """
        if self.enabled:
            self.updates += 1
            d = self.decay(self.updates)

            msd = de_parallel(model).state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:  # true for FP16 and FP32
                    # In-place blend: ema = d * ema + (1 - d) * model, without reallocating tensors
                    v *= d
                    v += (1 - d) * msd[k].detach()
                # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'

    def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
        """
        Updates attributes and saves stripped model with optimizer removed.

        Args:
            model (nn.Module): Model to update attributes from.
            include (tuple, optional): Attributes to include. Defaults to ().
+ exclude (tuple, optional): Attributes to exclude. Defaults to ("process_group", "reducer"). + """ + if self.enabled: + copy_attr(self.ema, model, include, exclude) + + +def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict: + """ + Strip optimizer from 'f' to finalize training, optionally save as 's'. + + Args: + f (str | Path): File path to model to strip the optimizer from. Defaults to 'best.pt'. + s (str, optional): File path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten. + updates (dict, optional): A dictionary of updates to overlay onto the checkpoint before saving. + + Returns: + (dict): The combined checkpoint dictionary. + + Examples: + >>> from pathlib import Path + >>> from ultralytics.utils.torch_utils import strip_optimizer + >>> for f in Path("path/to/model/checkpoints").rglob("*.pt"): + >>> strip_optimizer(f) + """ + try: + x = torch.load(f, map_location=torch.device("cpu")) + assert isinstance(x, dict), "checkpoint is not a Python dictionary" + assert "model" in x, "'model' missing from checkpoint" + except Exception as e: + LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}") + return {} + + metadata = { + "date": datetime.now().isoformat(), + "version": __version__, + "license": "AGPL-3.0 License (https://ultralytics.com/license)", + "docs": "https://docs.ultralytics.com", + } + + # Update model + if x.get("ema"): + x["model"] = x["ema"] # replace model with EMA + if hasattr(x["model"], "args"): + x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict + if hasattr(x["model"], "criterion"): + x["model"].criterion = None # strip loss criterion + x["model"].half() # to FP16 + for p in x["model"].parameters(): + p.requires_grad = False + + # Update other keys + args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})} # combine args + for k in "optimizer", "best_fitness", "ema", "updates": # keys + x[k] = None + 
    x["epoch"] = -1  # mark checkpoint as finished training
    x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
    # x['model'].args = x['train_args']

    # Save: later dicts take precedence, so caller-supplied `updates` can override anything
    combined = {**metadata, **x, **(updates or {})}
    torch.save(combined, s or f)  # combine dicts (prefer to the right)
    mb = os.path.getsize(s or f) / 1e6  # file size
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
    return combined


def convert_optimizer_state_dict_to_fp16(state_dict):
    """
    Converts the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.

    Args:
        state_dict (dict): Optimizer state dictionary.

    Returns:
        (dict): Converted optimizer state dictionary with FP16 tensors.
    """
    for state in state_dict["state"].values():
        for k, v in state.items():
            # 'step' is deliberately left untouched; only FP32 tensors are demoted to half precision
            if k != "step" and isinstance(v, torch.Tensor) and v.dtype is torch.float32:
                state[k] = v.half()

    return state_dict


@contextmanager
def cuda_memory_usage(device=None):
    """
    Monitor and manage CUDA memory usage.

    This function checks if CUDA is available and, if so, empties the CUDA cache to free up unused memory.
    It then yields a dictionary containing memory usage information, which can be updated by the caller.
    Finally, it updates the dictionary with the amount of memory reserved by CUDA on the specified device.

    Args:
        device (torch.device, optional): The CUDA device to query memory usage for. Defaults to None.

    Yields:
        (dict): A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory.
    """
    cuda_info = dict(memory=0)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        try:
            yield cuda_info
        finally:
            # Record reserved memory even if the caller's body raised
            cuda_info["memory"] = torch.cuda.memory_reserved(device)
    else:
        # No CUDA: yield the zeroed dict so callers can use the context manager unconditionally
        yield cuda_info


def profile(input, ops, n=10, device=None, max_num_obj=0):
    """
    Ultralytics speed, memory and FLOPs profiler.
+ + Args: + input (torch.Tensor | List[torch.Tensor]): Input tensor(s) to profile. + ops (nn.Module | List[nn.Module]): Model or list of operations to profile. + n (int, optional): Number of iterations to average. Defaults to 10. + device (str | torch.device, optional): Device to profile on. Defaults to None. + max_num_obj (int, optional): Maximum number of objects for simulation. Defaults to 0. + + Returns: + (list): Profile results for each operation. + + Examples: + >>> from ultralytics.utils.torch_utils import profile + >>> input = torch.randn(16, 3, 640, 640) + >>> m1 = lambda x: x * torch.sigmoid(x) + >>> m2 = nn.SiLU() + >>> profile(input, [m1, m2], n=100) # profile over 100 iterations + """ + results = [] + if not isinstance(device, torch.device): + device = select_device(device) + LOGGER.info( + f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}" + f"{'input':>24s}{'output':>24s}" + ) + gc.collect() # attempt to free unused memory + torch.cuda.empty_cache() + for x in input if isinstance(input, list) else [input]: + x = x.to(device) + x.requires_grad = True + for m in ops if isinstance(ops, list) else [ops]: + m = m.to(device) if hasattr(m, "to") else m # device + m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m + tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward + try: + flops = thop.profile(deepcopy(m), inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs + except Exception: + flops = 0 + + try: + mem = 0 + for _ in range(n): + with cuda_memory_usage(device) as cuda_info: + t[0] = time_sync() + y = m(x) + t[1] = time_sync() + try: + (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward() + t[2] = time_sync() + except Exception: # no backward method + # print(e) # for debug + t[2] = float("nan") + mem += cuda_info["memory"] / 1e9 # (GB) + tf += (t[1] - t[0]) * 1000 / n # ms per op forward + tb += (t[2] - t[1]) * 1000 / n # ms 
per op backward + if max_num_obj: # simulate training with predictions per image grid (for AutoBatch) + with cuda_memory_usage(device) as cuda_info: + torch.randn( + x.shape[0], + max_num_obj, + int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())), + device=device, + dtype=torch.float32, + ) + mem += cuda_info["memory"] / 1e9 # (GB) + s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes + p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters + LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}") + results.append([p, flops, mem, tf, tb, s_in, s_out]) + except Exception as e: + LOGGER.info(e) + results.append(None) + finally: + gc.collect() # attempt to free unused memory + torch.cuda.empty_cache() + return results + + +class EarlyStopping: + """ + Early stopping class that stops training when a specified number of epochs have passed without improvement. + + Attributes: + best_fitness (float): Best fitness value observed. + best_epoch (int): Epoch where best fitness was observed. + patience (int): Number of epochs to wait after fitness stops improving before stopping. + possible_stop (bool): Flag indicating if stopping may occur next epoch. + """ + + def __init__(self, patience=50): + """ + Initialize early stopping object. + + Args: + patience (int, optional): Number of epochs to wait after fitness stops improving before stopping. + """ + self.best_fitness = 0.0 # i.e. mAP + self.best_epoch = 0 + self.patience = patience or float("inf") # epochs to wait after fitness stops improving to stop + self.possible_stop = False # possible stop may occur next epoch + + def __call__(self, epoch, fitness): + """ + Check whether to stop training. 
+ + Args: + epoch (int): Current epoch of training + fitness (float): Fitness value of current epoch + + Returns: + (bool): True if training should stop, False otherwise + """ + if fitness is None: # check if fitness=None (happens when val=False) + return False + + if fitness > self.best_fitness or self.best_fitness == 0: # allow for early zero-fitness stage of training + self.best_epoch = epoch + self.best_fitness = fitness + delta = epoch - self.best_epoch # epochs without improvement + self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch + stop = delta >= self.patience # stop training if patience exceeded + if stop: + prefix = colorstr("EarlyStopping: ") + LOGGER.info( + f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. " + f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n" + f"To update EarlyStopping(patience={self.patience}) pass a new patience value, " + f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping." + ) + return stop + + +class FXModel(nn.Module): + """ + A custom model class for torch.fx compatibility. + + This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph + manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper + copying. + + Attributes: + model (nn.Module): The original model's layers. + """ + + def __init__(self, model): + """ + Initialize the FXModel. + + Args: + model (nn.Module): The original model to wrap for torch.fx compatibility. + """ + super().__init__() + copy_attr(self, model) + # Explicitly set `model` since `copy_attr` somehow does not copy it. + self.model = model.model + + def forward(self, x): + """ + Forward pass through the model. + + This method performs the forward pass through the model, handling the dependencies between layers and saving + intermediate outputs. 
+ + Args: + x (torch.Tensor): The input tensor to the model. + + Returns: + (torch.Tensor): The output tensor from the model. + """ + y = [] # outputs + for m in self.model: + if m.f != -1: # if not from previous layer + # from earlier layers + x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] + x = m(x) # run + y.append(x) # save output + return x diff --git a/tracking/ultralytics/utils/triton.py b/tracking/ultralytics/utils/triton.py new file mode 100644 index 0000000000000000000000000000000000000000..13b009b8646db0d61a18ac774621155c928320cf --- /dev/null +++ b/tracking/ultralytics/utils/triton.py @@ -0,0 +1,103 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from typing import List +from urllib.parse import urlsplit + +import numpy as np + + +class TritonRemoteModel: + """ + Client for interacting with a remote Triton Inference Server model. + + This class provides a convenient interface for sending inference requests to a Triton Inference Server + and processing the responses. + + Attributes: + endpoint (str): The name of the model on the Triton server. + url (str): The URL of the Triton server. + triton_client: The Triton client (either HTTP or gRPC). + InferInput: The input class for the Triton client. + InferRequestedOutput: The output request class for the Triton client. + input_formats (List[str]): The data types of the model inputs. + np_input_formats (List[type]): The numpy data types of the model inputs. + input_names (List[str]): The names of the model inputs. + output_names (List[str]): The names of the model outputs. + metadata: The metadata associated with the model. 
+ + Examples: + Initialize a Triton client with HTTP + >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http") + Make inference with numpy arrays + >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32)) + """ + + def __init__(self, url: str, endpoint: str = "", scheme: str = ""): + """ + Initialize the TritonRemoteModel. + + Arguments may be provided individually or parsed from a collective 'url' argument of the form + ://// + + Args: + url (str): The URL of the Triton server. + endpoint (str): The name of the model on the Triton server. + scheme (str): The communication scheme ('http' or 'grpc'). + """ + if not endpoint and not scheme: # Parse all args from URL string + splits = urlsplit(url) + endpoint = splits.path.strip("/").split("/")[0] + scheme = splits.scheme + url = splits.netloc + + self.endpoint = endpoint + self.url = url + + # Choose the Triton client based on the communication scheme + if scheme == "http": + import tritonclient.http as client # noqa + + self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False) + config = self.triton_client.get_model_config(endpoint) + else: + import tritonclient.grpc as client # noqa + + self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False) + config = self.triton_client.get_model_config(endpoint, as_json=True)["config"] + + # Sort output names alphabetically, i.e. 'output0', 'output1', etc. 
+ config["output"] = sorted(config["output"], key=lambda x: x.get("name")) + + # Define model attributes + type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8} + self.InferRequestedOutput = client.InferRequestedOutput + self.InferInput = client.InferInput + self.input_formats = [x["data_type"] for x in config["input"]] + self.np_input_formats = [type_map[x] for x in self.input_formats] + self.input_names = [x["name"] for x in config["input"]] + self.output_names = [x["name"] for x in config["output"]] + self.metadata = eval(config.get("parameters", {}).get("metadata", {}).get("string_value", "None")) + + def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]: + """ + Call the model with the given inputs. + + Args: + *inputs (np.ndarray): Input data to the model. + + Returns: + (List[np.ndarray]): Model outputs with the same dtype as the input. + """ + infer_inputs = [] + input_format = inputs[0].dtype + for i, x in enumerate(inputs): + if x.dtype != self.np_input_formats[i]: + x = x.astype(self.np_input_formats[i]) + infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", "")) + infer_input.set_data_from_numpy(x) + infer_inputs.append(infer_input) + + infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names] + outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs) + + return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names] diff --git a/tracking/ultralytics/utils/tuner.py b/tracking/ultralytics/utils/tuner.py new file mode 100644 index 0000000000000000000000000000000000000000..39069aab2e4593ed4df67f749fe5918d0ac47191 --- /dev/null +++ b/tracking/ultralytics/utils/tuner.py @@ -0,0 +1,145 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_cfg, get_save_dir +from ultralytics.utils 
def run_ray_tune(
    model,
    space: dict = None,
    grace_period: int = 10,
    gpu_per_trial: int = None,
    max_samples: int = 10,
    **train_args,
):
    """
    Run hyperparameter tuning using Ray Tune.

    Args:
        model (YOLO): Model to run the tuner on.
        space (dict, optional): The hyperparameter search space.
        grace_period (int, optional): The grace period in epochs of the ASHA scheduler.
        gpu_per_trial (int, optional): The number of GPUs to allocate per trial.
        max_samples (int, optional): The maximum number of trials to run.
        **train_args (Any): Additional arguments to pass to the `train()` method.

    Returns:
        (ray.tune.ResultGrid): The results of the hyperparameter search.

    Examples:
        >>> from ultralytics import YOLO
        >>> model = YOLO("yolo11n.pt")  # Load a YOLO11n model

        Start tuning hyperparameters for YOLO11n training on the COCO8 dataset
        >>> result_grid = model.tune(data="coco8.yaml", use_ray=True)
    """
    LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
    # NOTE: `train_args` is collected via **kwargs so it is always a dict;
    # a former `if train_args is None` guard was dead code and has been removed.

    try:
        checks.check_requirements("ray[tune]")

        import ray
        from ray import tune
        from ray.air import RunConfig
        from ray.air.integrations.wandb import WandbLoggerCallback
        from ray.tune.schedulers import ASHAScheduler
    except ImportError:
        raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]"')

    try:
        import wandb

        assert hasattr(wandb, "__version__")
    except (ImportError, AssertionError):
        wandb = False  # sentinel: disables the W&B logging callback below

    checks.check_version(ray.__version__, ">=2.0.0", "ray")
    default_space = {
        # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
        "lr0": tune.uniform(1e-5, 1e-1),
        "lrf": tune.uniform(0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
        "momentum": tune.uniform(0.6, 0.98),  # SGD momentum/Adam beta1
        "weight_decay": tune.uniform(0.0, 0.001),  # optimizer weight decay
        "warmup_epochs": tune.uniform(0.0, 5.0),  # warmup epochs (fractions ok)
        "warmup_momentum": tune.uniform(0.0, 0.95),  # warmup initial momentum
        "box": tune.uniform(0.02, 0.2),  # box loss gain
        "cls": tune.uniform(0.2, 4.0),  # cls loss gain (scale with pixels)
        "hsv_h": tune.uniform(0.0, 0.1),  # image HSV-Hue augmentation (fraction)
        "hsv_s": tune.uniform(0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
        "hsv_v": tune.uniform(0.0, 0.9),  # image HSV-Value augmentation (fraction)
        "degrees": tune.uniform(0.0, 45.0),  # image rotation (+/- deg)
        "translate": tune.uniform(0.0, 0.9),  # image translation (+/- fraction)
        "scale": tune.uniform(0.0, 0.9),  # image scale (+/- gain)
        "shear": tune.uniform(0.0, 10.0),  # image shear (+/- deg)
        "perspective": tune.uniform(0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
        "flipud": tune.uniform(0.0, 1.0),  # image flip up-down (probability)
        "fliplr": tune.uniform(0.0, 1.0),  # image flip left-right (probability)
        "bgr": tune.uniform(0.0, 1.0),  # image channel BGR (probability)
        "mosaic": tune.uniform(0.0, 1.0),  # image mosaic (probability)
        "mixup": tune.uniform(0.0, 1.0),  # image mixup (probability)
        "copy_paste": tune.uniform(0.0, 1.0),  # segment copy-paste (probability)
    }

    # Put the model in the Ray object store so every trial can fetch a copy
    task = model.task
    model_in_store = ray.put(model)

    def _tune(config):
        """Train the YOLO model with the specified hyperparameters."""
        model_to_train = ray.get(model_in_store)  # get the model from ray store for tuning
        model_to_train.reset_callbacks()
        config.update(train_args)
        results = model_to_train.train(**config)
        return results.results_dict

    # Get search space
    if not space:
        space = default_space
        LOGGER.warning("WARNING ⚠️ search space not provided, using default search space.")

    # Get dataset
    data = train_args.get("data", TASK2DATA[task])
    space["data"] = data
    if "data" not in train_args:
        LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".')

    # Define the trainable function with allocated resources
    trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})

    # Define the ASHA scheduler for hyperparameter search
    asha_scheduler = ASHAScheduler(
        time_attr="epoch",
        metric=TASK2METRIC[task],
        mode="max",
        max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,
        grace_period=grace_period,
        reduction_factor=3,
    )

    # Define the callbacks for the hyperparameter search
    tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []

    # Create the Ray Tune hyperparameter search tuner
    tune_dir = get_save_dir(
        get_cfg(DEFAULT_CFG, train_args), name=train_args.pop("name", "tune")
    ).resolve()  # must be absolute dir
    tune_dir.mkdir(parents=True, exist_ok=True)
    tuner = tune.Tuner(
        trainable_with_resources,
        param_space=space,
        tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
        run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir),
    )

    # Run the hyperparameter search
    tuner.fit()

    # Get the results of the hyperparameter search
    results = tuner.get_results()

    # Shut down Ray to clean up workers
    ray.shutdown()

    return results
def split_dataset(src_fldr: Path, percent_to_delete: float = 0.5) -> tuple:
    """
    Copy the dataset to a new location, drop the first `percent_to_delete` fraction of
    frames from each sequence, and renumber the remaining frames to start at 1.

    Args:
        src_fldr (Path): Source folder containing the dataset (a 'MOT17' split folder).
        percent_to_delete (float): Fraction of leading frames/annotations to remove.

    Returns:
        tuple: (dst_fldr, new_benchmark_name) — the destination folder of the reduced
        dataset and its benchmark name (e.g. 'MOT17-50').
    """
    # Ensure source path is a Path object
    src_fldr = Path(src_fldr)

    # Destination path: replace "MOT17" with e.g. "MOT17-50" in the source path
    new_benchmark_name = f'MOT17-{int(percent_to_delete * 100)}'
    dst_fldr = Path(str(src_fldr).replace('MOT17', new_benchmark_name))

    # Copy the dataset to the new location if it doesn't already exist
    if not dst_fldr.exists():
        dst_fldr.mkdir(parents=True)
        # rglob walks top-down, so directories are created before their files
        for item in src_fldr.rglob('*'):
            if item.is_dir():
                (dst_fldr / item.relative_to(src_fldr)).mkdir(parents=True, exist_ok=True)
            else:
                (dst_fldr / item.relative_to(src_fldr)).write_bytes(item.read_bytes())

    # List all sequences in the destination folder
    seq_paths = [f for f in dst_fldr.iterdir() if f.is_dir()]

    # Iterate over each sequence and remove the leading portion of images and annotations
    for seq_path in seq_paths:
        seq_gt_path = seq_path / 'gt' / 'gt.txt'

        if not seq_gt_path.exists():
            print(f"Ground truth file not found for {seq_path}. Skipping...")
            continue

        df = pd.read_csv(seq_gt_path, sep=",", header=None)
        nr_seq_imgs = df[0].unique().max()  # highest annotated frame number
        split = int(nr_seq_imgs * (1 - percent_to_delete))

        # If the max frame number is already <= the split point, the sequence was
        # split on a previous run
        if nr_seq_imgs <= split:
            print(f'Sequence {seq_path} already split. Skipping...')
            continue

        print(f'Number of annotated frames in {seq_path}: Keeping from frame {split + 1} to {nr_seq_imgs}')

        # Keep rows from the ground truth file beyond the split point
        df = df[df[0] > split]

        # Adjust the frame indices to start from 1
        df[0] = df[0] - split

        df.to_csv(seq_gt_path, header=None, index=False, sep=',')

        # Remove images before the split point
        jpg_folder_path = seq_path / 'img1'
        for jpg_path in jpg_folder_path.glob('*.jpg'):
            # Frame number is the zero-padded stem, e.g. '000300.jpg' -> 300
            if int(jpg_path.stem) <= split:
                jpg_path.unlink()

        # Rename the remaining images to a continuous sequence starting from 1
        remaining_jpg_paths = sorted(jpg_folder_path.glob('*.jpg'))
        for new_index, jpg_path in enumerate(remaining_jpg_paths, start=1):
            new_jpg_name = f"{new_index:06}.jpg"  # zero-padded to 6 digits
            jpg_path.rename(jpg_folder_path / new_jpg_name)

        remaining_images = len(list(jpg_folder_path.glob('*.jpg')))
        print(f'Number of images in {seq_path} after delete: {remaining_images}')

    return dst_fldr, new_benchmark_name


def download_mot_eval_tools(val_tools_path):
    """
    Download the official evaluation tools for MOT metrics from the TrackEval GitHub repository.

    Parameters:
        val_tools_path (Path): Destination folder for the evaluation tools.

    Returns:
        None. Clones the evaluation tools repository and replaces deprecated numpy
        type aliases (np.float/np.int/np.bool) with native Python types.
    """
    val_tools_url = "https://github.com/JonathonLuiten/TrackEval"

    try:
        # Clone the repository
        Repo.clone_from(val_tools_url, val_tools_path)
        LOGGER.debug('Official MOT evaluation repo downloaded successfully.')
    except exc.GitError as err:
        LOGGER.debug(f'Evaluation repo already downloaded or an error occurred: {err}')

    # Fix deprecated np.float, np.int & np.bool by replacing them with native Python types
    deprecated_types = {'np.float': 'float', 'np.int': 'int', 'np.bool': 'bool'}

    for file_path in val_tools_path.rglob('*'):
        if file_path.suffix in {'.py', '.txt'}:  # only consider .py and .txt files
            try:
                content = file_path.read_text(encoding='utf-8')
                updated_content = content
                for old_type, new_type in deprecated_types.items():
                    updated_content = updated_content.replace(old_type, new_type)

                if updated_content != content:  # Only write back if there were changes
                    file_path.write_text(updated_content, encoding='utf-8')
                    LOGGER.info(f'Replaced deprecated types in {file_path}.')
            except Exception as e:
                LOGGER.error(f'Error processing {file_path}: {e}')
+ """ + url = f'https://motchallenge.net/data/{benchmark}.zip' + zip_dst = val_tools_path / f'{benchmark}.zip' + + retries = 0 # Initialize retry counter + + response = None + while retries <= max_retries: + try: + response = requests.head(url, allow_redirects=True) + # Consider any status code less than 400 (e.g., 200, 302) as indicating that the resource exists + if response.status_code < 400: + # Get the total size of the file from the server + total_size_in_bytes = int(response.headers.get('content-length', 0)) + + # Check if there is already a partially or fully downloaded file + if zip_dst.exists(): + current_size = zip_dst.stat().st_size + + # If the file is fully downloaded, skip the download + if current_size >= total_size_in_bytes: + LOGGER.info(f"{benchmark}.zip is already fully downloaded.") + return zip_dst + + # If the file is partially downloaded, set the range header to resume + resume_header = {'Range': f'bytes={current_size}-'} + LOGGER.info(f"Resuming download for {benchmark}.zip from byte {current_size}...") + else: + current_size = 0 + resume_header = {} + + # Start or resume the download + response = requests.get(url, headers=resume_header, stream=True) + response.raise_for_status() # Check for HTTP request errors + + with open(zip_dst, 'ab') as file, tqdm( + desc=zip_dst.name, + total=total_size_in_bytes, + initial=current_size, + unit='iB', + unit_scale=True, + unit_divisor=1024, + ) as bar: + for data in response.iter_content(chunk_size=1024): + size = file.write(data) + bar.update(size) + + LOGGER.info(f'{benchmark}.zip downloaded successfully.') + return zip_dst # If download is successful, return the path + + else: + LOGGER.warning(f'{benchmark} is not downloadable from {url}') + return None + + except (requests.HTTPError, requests.ConnectionError) as e: + if response and response.status_code == 416: # Handle "Requested Range Not Satisfiable" error + LOGGER.info(f"{benchmark}.zip is already fully downloaded.") + return zip_dst + 
LOGGER.error(f'Error occurred while downloading {benchmark}.zip: {e}') + retries += 1 + wait_time = backoff_factor ** retries + LOGGER.info(f"Retrying download in {wait_time} seconds... (Attempt {retries} of {max_retries})") + time.sleep(wait_time) # Exponential backoff delay + + except Exception as e: + LOGGER.error(f'An unexpected error occurred: {e}') + retries += 1 + wait_time = backoff_factor ** retries + LOGGER.info(f"Retrying download in {wait_time} seconds... (Attempt {retries} of {max_retries})") + time.sleep(wait_time) # Exponential backoff delay + + LOGGER.error(f"Failed to download {benchmark}.zip after {max_retries} retries.") + return None + + +def unzip_mot_dataset(zip_path, val_tools_path, benchmark): + """ + Unzip a downloaded MOT dataset zip file into the specified directory. + + Parameters: + zip_path (Path): Path to the downloaded MOT benchmark zip file. + val_tools_path (Path): Base path to the destination folder where the dataset will be unzipped. + benchmark (str): The MOT benchmark that was downloaded (e.g., 'MOT20', 'MOT17'). + + Returns: + None + """ + if zip_path is None: + LOGGER.warning(f'No zip file. Skipping unzipping') + return None + + extract_path = val_tools_path / 'data' / benchmark + if not extract_path.exists(): + try: + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + # folder will be called as the original fetched file + zip_ref.extractall(val_tools_path / 'data') + + LOGGER.info(f'{benchmark}.zip unzipped successfully.') + except zipfile.BadZipFile: + LOGGER.error(f'{zip_path.name} is corrupted. 
def set_gt_fps(opt, seq_paths):
    """
    Build per-sequence 'gt_temp.txt' files containing only the ground-truth rows
    whose frame numbers appear in the experiment's 'seqs_frame_nums.json'.

    Parameters:
    - opt: Options object; only `opt.exp_folder_path` is read here.
    - seq_paths: List of `<seq>/img1` paths; the sequence name is taken from the parent.
    """
    frames_json = opt.exp_folder_path / 'seqs_frame_nums.json'
    with open(frames_json, 'r') as fh:
        frames_per_seq = json.load(fh)

    for seq_img_path in seq_paths:
        seq_dir = seq_img_path.parent
        kept_frames = frames_per_seq[seq_dir.name]

        gt_dir = seq_dir / 'gt'
        temp_gt = gt_dir / 'gt_temp.txt'
        # Work on a copy so the original gt.txt stays untouched
        shutil.copy(gt_dir / 'gt.txt', temp_gt)

        rows = np.loadtxt(temp_gt, delimiter=',')
        keep_mask = np.isin(rows[:, 0], kept_frames)
        # NOTE(review): np.savetxt's default float format is used here (scientific
        # notation); TrackEval parses it as floats — confirm if exact gt formatting matters.
        np.savetxt(temp_gt, rows[keep_mask], delimiter=',')


def eval_setup(opt, val_tools_path):
    """
    Initialize and set up evaluation paths for MOT challenge datasets.

    Prepares the directories and paths needed for evaluating tracking algorithms on
    MOT datasets (e.g. MOT17) or custom ones (e.g. MOT17-mini): sequence folders,
    ground truth, the results save directory, and the MOT-challenge results folder.

    Parameters:
    - opt: Options with attributes benchmark (str), split (str), project (str),
      name (str) and exp_folder_path, dictating dataset, split and output naming.
    - val_tools_path: str or Path pointing to the base directory of the validation tools.

    Returns:
    - seq_paths: List of `<seq>/img1` Paths for the sequences to evaluate.
    - save_dir: Path where evaluation results will be saved.
    - MOT_results_folder: Path where MOT-challenge-formatted results belong.
    - gt_folder: Path where the ground-truth data lives.
    """
    val_tools_path = Path(val_tools_path)

    # Default layout: sequences live under the downloaded benchmark data
    mot_seqs_path = val_tools_path / 'data' / opt.benchmark / opt.split
    # MOT17-mini ships with the repo assets instead of the downloaded data
    if opt.benchmark == 'MOT17-mini':
        mot_seqs_path = ROOT / 'assets' / opt.benchmark / opt.split
    gt_folder = mot_seqs_path  # ground truth sits next to the sequences

    seq_paths = [p / 'img1' for p in mot_seqs_path.iterdir() if p.is_dir()]

    # Restrict ground truth to the frames actually processed (FPS subsampling)
    set_gt_fps(opt, seq_paths)

    save_dir = Path(opt.project) / opt.name

    MOT_results_folder = val_tools_path / 'data' / 'trackers' / 'mot_challenge' / opt.benchmark / save_dir.name / 'data'
    MOT_results_folder.mkdir(parents=True, exist_ok=True)  # Ensure directory exists

    return seq_paths, save_dir, MOT_results_folder, gt_folder
+ + Parameters: + - results (Union[Results, np.ndarray]): Tracking results for the current frame. + - frame_idx (int): The zero-based index of the frame being processed. + + Returns: + - np.ndarray: An array containing the MOT formatted results for the frame. + """ + + # Check if results are not empty + if results.size != 0: + if isinstance(results, np.ndarray): + # Convert numpy array results to MOT format + tlwh = ops.xyxy2ltwh(results[:, 0:4]) + frame_idx_column = np.full((results.shape[0], 1), frame_idx, dtype=np.int32) + mot_results = np.column_stack(( + frame_idx_column, # frame index + results[:, 4].astype(np.int32), # track id + tlwh.round().astype(np.int32), # top,left,width,height + np.ones((results.shape[0], 1), dtype=np.int32), # "not ignored" + results[:, 6].astype(np.int32), # class + results[:, 5], # confidence (float) + )) + return mot_results + else: + # Convert ultralytics results to MOT format + num_detections = len(results.boxes) + frame_indices = torch.full((num_detections, 1), frame_idx + 1, dtype=torch.int32) + not_ignored = torch.ones((num_detections, 1), dtype=torch.int32) + + mot_results = torch.cat([ + frame_indices, # frame index + results.boxes.id.unsqueeze(1).astype(np.int32), # track id + ops.xyxy2ltwh(results.boxes.xyxy).astype(np.int32), ## top,left,width,height + not_ignored, # "not ignored" + results.boxes.cls.unsqueeze(1).astype(np.int32), # class + results.boxes.conf.unsqueeze(1).astype(np.float32), # confidence (float) + ], dim=1) + + return mot_results.numpy() + + +def write_mot_results(txt_path: Path, mot_results: np.ndarray) -> None: + """ + Writes the MOT challenge formatted results to a text file. + + Parameters: + - txt_path (Path): The path to the text file where results are saved. + - mot_results (np.ndarray): An array containing the MOT formatted results. + + Note: The text file will be created if it does not exist, and the directory + path to the file will be created as well if necessary. 
+ """ + if mot_results is not None: + # Ensure the parent directory of the txt_path exists + txt_path.parent.mkdir(parents=True, exist_ok=True) + + # Ensure the file exists before opening + txt_path.touch(exist_ok=True) + + if mot_results.size != 0: + # Open the file in append mode and save the MOT results + with open(str(txt_path), 'a') as file: + np.savetxt(file, mot_results, fmt='%d,%d,%d,%d,%d,%d,%d,%d,%.6f') diff --git a/tracking/val.py b/tracking/val.py new file mode 100644 index 0000000000000000000000000000000000000000..f32874d221b77beaa0a9270629e3a504d8dad2f9 --- /dev/null +++ b/tracking/val.py @@ -0,0 +1,600 @@ +# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license + +import argparse +import subprocess +from pathlib import Path +import numpy as np +from tqdm import tqdm +import configparser +import shutil +import json +import queue +import select +import re +import os +import torch +from functools import partial +import threading +import sys +import copy +import concurrent.futures + +from boxmot import TRACKERS +from boxmot.tracker_zoo import create_tracker +from boxmot.utils import ROOT, WEIGHTS, TRACKER_CONFIGS, logger as LOGGER, EXAMPLES, DATA +from boxmot.utils.checks import RequirementsChecker +from boxmot.utils.torch_utils import select_device +from boxmot.utils.misc import increment_path +from boxmot.postprocessing.gsi import gsi + +from ultralytics import YOLO +from ultralytics.data.loaders import LoadImagesAndVideos + +from tracking.detectors import (get_yolo_inferer, default_imgsz, + is_ultralytics_model, is_yolox_model) +from tracking.utils import convert_to_mot_format, write_mot_results, download_mot_eval_tools, download_mot_dataset, unzip_mot_dataset, eval_setup, split_dataset +from boxmot.appearance.reid.auto_backend import ReidAutoBackend + +checker = RequirementsChecker() +checker.check_packages(('ultralytics @ git+https://github.com/mikel-brostrom/ultralytics.git', )) # install + + +def cleanup_mot17(data_dir, keep_detection='FRCNN'): + """ 
def cleanup_mot17(data_dir, keep_detection='FRCNN'):
    """
    Clean up the MOT17 dataset to resemble the MOT16 format by keeping only one
    detection folder per sequence. Skips sequences that have already been cleaned.

    Args:
    - data_dir (str): Path to the MOT17 train directory.
    - keep_detection (str): Detection type to keep ('DPM', 'FRCNN' or 'SDP').
      Default is 'FRCNN' (docstring previously said 'DPM', contradicting the signature).
    """

    # Get all folders in the train directory
    all_dirs = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))]

    # Identify unique sequences by removing detection suffixes (e.g. MOT17-02-FRCNN -> MOT17-02)
    unique_sequences = set(seq.split('-')[0] + '-' + seq.split('-')[1] for seq in all_dirs)

    for seq in unique_sequences:
        # Directory path to the cleaned sequence
        cleaned_seq_dir = os.path.join(data_dir, seq)

        # Skip if the sequence is already cleaned
        if os.path.exists(cleaned_seq_dir):
            print(f"Sequence {seq} is already cleaned. Skipping.")
            continue

        # Directories for each detection method
        seq_dirs = [os.path.join(data_dir, d)
                    for d in all_dirs if d.startswith(seq)]

        # Directory path for the detection folder to keep
        keep_dir = os.path.join(data_dir, f"{seq}-{keep_detection}")

        if os.path.exists(keep_dir):
            # Move the directory to a new name (removing the detection suffix)
            shutil.move(keep_dir, cleaned_seq_dir)
            print(f"Moved {keep_dir} to {cleaned_seq_dir}")

            # Remove other detection directories
            for seq_dir in seq_dirs:
                if os.path.exists(seq_dir) and seq_dir != keep_dir:
                    shutil.rmtree(seq_dir)
                    print(f"Removed {seq_dir}")
        else:
            print(f"Directory for {seq} with {keep_detection} detection does not exist. Skipping.")

    print("MOT17 Cleanup completed!")


def prompt_overwrite(path_type: str, path: str, ci: bool = True) -> bool:
    """
    Prompt the user to confirm overwriting an existing file.

    Args:
        path_type (str): Type of the path (e.g., 'Detections and Embeddings', 'MOT Result').
        path (str): The path to check.
        ci (bool): If True, automatically reuse the existing file without prompting
            (for CI / no-UI environments).

    Returns:
        bool: True if the user confirms the overwrite, False otherwise (including timeout).
    """
    if ci:
        LOGGER.debug(f"{path_type} {path} already exists. Use existing due to no UI mode.")
        return False

    def input_with_timeout(prompt, timeout=3.0):
        """Read a y/N answer from stdin, giving up (=> False) after `timeout` seconds."""
        print(prompt, end='', flush=True)

        result = []
        input_received = threading.Event()

        def get_input():
            user_input = sys.stdin.readline().strip().lower()
            result.append(user_input)
            input_received.set()

        input_thread = threading.Thread(target=get_input)
        input_thread.daemon = True  # Ensure thread does not prevent program exit
        input_thread.start()
        input_thread.join(timeout)

        if input_received.is_set():
            return result[0] in ['y', 'yes']
        else:
            print("\nNo response, not proceeding with overwrite...")
            return False

    return input_with_timeout(f"{path_type} {path} already exists. Overwrite? [y/N]: ")
+ """ + WEIGHTS.mkdir(parents=True, exist_ok=True) + + if args.imgsz is None: + args.imgsz = default_imgsz(y) + + yolo = YOLO( + y if is_ultralytics_model(y) + else 'yolov8n.pt', + ) + + results = yolo( + source=source, + conf=args.conf, + iou=args.iou, + agnostic_nms=args.agnostic_nms, + stream=True, + device=args.device, + verbose=False, + exist_ok=args.exist_ok, + project=args.project, + name=args.name, + classes=args.classes, + imgsz=args.imgsz, + vid_stride=args.vid_stride, + ) + + if not is_ultralytics_model(y): + m = get_yolo_inferer(y) + yolo_model = m(model=y, device=yolo.predictor.device, + args=yolo.predictor.args) + yolo.predictor.model = yolo_model + + # If current model is YOLOX, change the preprocess and postprocess + if is_yolox_model(y): + # add callback to save image paths for further processing + yolo.add_callback("on_predict_batch_start", + lambda p: yolo_model.update_im_paths(p)) + yolo.predictor.preprocess = ( + lambda im: yolo_model.preprocess(im=im)) + yolo.predictor.postprocess = ( + lambda preds, im, im0s: + yolo_model.postprocess(preds=preds, im=im, im0s=im0s)) + + reids = [] + for r in args.reid_model: + reid_model = ReidAutoBackend(weights=args.reid_model, + device=yolo.predictor.device, + half=args.half).model + reids.append(reid_model) + embs_path = args.project / 'dets_n_embs' / y.stem / 'embs' / r.stem / (source.parent.name + '.txt') + embs_path.parent.mkdir(parents=True, exist_ok=True) + embs_path.touch(exist_ok=True) + + if os.path.getsize(embs_path) > 0: + open(embs_path, 'w').close() + + yolo.predictor.custom_args = args + + dets_path = args.project / 'dets_n_embs' / y.stem / 'dets' / (source.parent.name + '.txt') + dets_path.parent.mkdir(parents=True, exist_ok=True) + dets_path.touch(exist_ok=True) + + if os.path.getsize(dets_path) > 0: + open(dets_path, 'w').close() + + with open(str(dets_path), 'ab+') as f: + np.savetxt(f, [], fmt='%f', header=str(source)) + + for frame_idx, r in enumerate(tqdm(results, desc="Frames")): + 
nr_dets = len(r.boxes) + frame_idx = torch.full((1, 1), frame_idx + 1).repeat(nr_dets, 1) + img = r.orig_img + + dets = np.concatenate( + [ + frame_idx, + r.boxes.xyxy.to('cpu'), + r.boxes.conf.unsqueeze(1).to('cpu'), + r.boxes.cls.unsqueeze(1).to('cpu'), + ], axis=1 + ) + + # Filter dets with incorrect boxes: (x2 < x1 or y2 < y1) + boxes = r.boxes.xyxy.to('cpu').numpy().round().astype(int) + boxes_filter = ((np.maximum(0, boxes[:, 0]) < np.minimum(boxes[:, 2], img.shape[1])) & + (np.maximum(0, boxes[:, 1]) < np.minimum(boxes[:, 3], img.shape[0]))) + dets = dets[boxes_filter] + + with open(str(dets_path), 'ab+') as f: + np.savetxt(f, dets, fmt='%f') + + for reid, reid_model_name in zip(reids, args.reid_model): + embs = reid.get_features(dets[:, 1:5], img) + embs_path = args.project / "dets_n_embs" / y.stem / 'embs' / reid_model_name.stem / (source.parent.name + '.txt') + with open(str(embs_path), 'ab+') as f: + np.savetxt(f, embs, fmt='%f') + + +def generate_mot_results(args: argparse.Namespace, config_dict: dict = None) -> dict[str, np.ndarray]: + """ + Generates MOT results for the specified arguments and configuration. + + Args: + args (Namespace): Parsed command line arguments. + config_dict (dict, optional): Additional configuration dictionary. 
def generate_mot_results(args: argparse.Namespace, config_dict: dict = None) -> dict[str, np.ndarray]:
    """
    Run a tracker over pre-computed detections/embeddings and write MOT results.

    Args:
        args (Namespace): Parsed command line arguments (must carry dets_file_path,
            embs_file_path, exp_folder_path, tracking_method, reid_model, device, fps).
        config_dict (dict, optional): Tracker configuration overrides.

    Returns:
        dict[str, np.ndarray]: {sequence_name: frame numbers used for MOT}.
    """
    args.device = select_device(args.device)
    tracker = create_tracker(
        args.tracking_method,
        TRACKER_CONFIGS / (args.tracking_method + '.yaml'),
        args.reid_model[0].with_suffix('.pt'),
        args.device,
        False,
        False,
        config_dict
    )

    # First line of the dets file is a '# <source>' header written at generation time
    with open(args.dets_file_path, 'r') as file:
        source = Path(file.readline().strip().replace("# ", ""))

    dets = np.loadtxt(args.dets_file_path, skiprows=1)
    embs = np.loadtxt(args.embs_file_path)
    dets_n_embs = np.concatenate([dets, embs], axis=1)

    dataset = LoadImagesAndVideos(source)
    txt_path = args.exp_folder_path / (source.parent.name + '.txt')

    # Frame subsampling: derive a step from the requested FPS vs the sequence FPS
    if args.fps:
        seq_conf = configparser.ConfigParser()
        seq_conf.read(source.parent / 'seqinfo.ini')
        orig_fps = int(seq_conf.get("Sequence", "frameRate"))

        if orig_fps < args.fps:
            LOGGER.warning(f"Original FPS ({orig_fps}) is lower than "
                           f"requested FPS ({args.fps}) for sequence "
                           f"{source.parent.name}. Using original FPS.")
            target_fps = orig_fps
        else:
            target_fps = args.fps
        step = orig_fps / target_fps
    else:
        step = 1

    # Frame numbers to process, at the chosen step
    frame_nums = np.arange(1, len(dataset) + 1, step).astype(int).tolist()
    seq_frame_nums = {source.parent.name: frame_nums.copy()}

    collected = []
    for frame_num, batch in enumerate(tqdm(dataset, desc=source.parent.name), 1):
        # Skip frames not in the subsampled list; pop as we consume them
        if frame_nums:
            if frame_num < frame_nums[0]:
                continue
            frame_nums.pop(0)

        im = batch[1][0]
        frame_rows = dets_n_embs[dets_n_embs[:, 0] == frame_num]
        tracks = tracker.update(frame_rows[:, 1:7], im, frame_rows[:, 7:])

        if tracks.size > 0:
            collected.append(convert_to_mot_format(tracks, frame_num))

    all_mot_results = np.vstack(collected) if collected else np.empty((0, 0))
    write_mot_results(txt_path, all_mot_results)

    return seq_frame_nums


def parse_mot_results(results: str) -> dict:
    """
    Extract the COMBINED HOTA, MOTA and IDF1 values from run_mot_challenge.py output.

    Args:
        results (str): MOT results as a string.

    Returns:
        dict: {"HOTA": float, "MOTA": float, "IDF1": float}.
    """
    # Segments after the 3rd 'COMBINED' marker each start with the metric value
    sections = results.split('COMBINED')[2:-1]
    scores = [float(re.findall(r"[-+]?(?:\d*\.*\d+)", section)[0]) for section in sections]
    return dict(zip(["HOTA", "MOTA", "IDF1"], scores))
+ save_dir (Path): Directory to save evaluation results. + MOT_results_folder (Path): Folder containing MOT results. + gt_folder (Path): Folder containing ground truth data. + metrics (list, optional): List of metrics to use for evaluation. Defaults to ["HOTA", "CLEAR", "Identity"]. + + Returns: + str: Standard output from the evaluation script. + """ + + d = [seq_path.parent.name for seq_path in seq_paths] + + args = [ + sys.executable, EXAMPLES / 'val_utils' / 'scripts' / 'run_mot_challenge.py', + "--GT_FOLDER", str(gt_folder), + "--BENCHMARK", "", + "--TRACKERS_FOLDER", args.exp_folder_path, + "--TRACKERS_TO_EVAL", "", + "--SPLIT_TO_EVAL", "train", + "--METRICS", *metrics, + "--USE_PARALLEL", "True", + "--TRACKER_SUB_FOLDER", "", + "--NUM_PARALLEL_CORES", str(4), + "--SKIP_SPLIT_FOL", "True", + "--GT_LOC_FORMAT", "{gt_folder}/{seq}/gt/gt_temp.txt", + "--SEQ_INFO", *d + ] + + p = subprocess.Popen( + args=args, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + + stdout, stderr = p.communicate() + + if stderr: + print("Standard Error:\n", stderr) + return stdout + + +def run_generate_dets_embs(opt: argparse.Namespace) -> None: + """ + Runs the generate_dets_embs function for all YOLO models and source directories. + + Args: + opt (Namespace): Parsed command line arguments. 
+ """ + mot_folder_paths = sorted([item for item in Path(opt.source).iterdir()]) + for y in opt.yolo_model: + for i, mot_folder_path in enumerate(mot_folder_paths): + dets_path = Path(opt.project) / 'dets_n_embs' / y.stem / 'dets' / (mot_folder_path.name + '.txt') + embs_path = Path(opt.project) / 'dets_n_embs' / y.stem / 'embs' / (opt.reid_model[0].stem) / (mot_folder_path.name + '.txt') + if dets_path.exists() and embs_path.exists(): + if prompt_overwrite('Detections and Embeddings', dets_path, opt.ci): + LOGGER.debug(f'Overwriting detections and embeddings for {mot_folder_path}...') + else: + LOGGER.debug(f'Skipping generation for {mot_folder_path} as they already exist.') + continue + LOGGER.debug(f'Generating detections and embeddings for data under {mot_folder_path} [{i + 1}/{len(mot_folder_paths)} seqs]') + generate_dets_embs(opt, y, source=mot_folder_path / 'img1') + + +def process_single_mot(opt: argparse.Namespace, d: Path, e: Path, evolve_config: dict): + # Create a deep copy of opt so each task works independently + new_opt = copy.deepcopy(opt) + new_opt.dets_file_path = d + new_opt.embs_file_path = e + frames_dict = generate_mot_results(new_opt, evolve_config) + return frames_dict + +def run_generate_mot_results(opt: argparse.Namespace, evolve_config: dict = None) -> None: + """ + Runs the generate_mot_results function for all YOLO models and detection/embedding files + in parallel. 
+ """ + + for y in opt.yolo_model: + exp_folder_path = opt.project / 'mot' / (f"{y.stem}_{opt.reid_model[0].stem}_{opt.tracking_method}") + exp_folder_path = increment_path(path=exp_folder_path, sep="_", exist_ok=False) + opt.exp_folder_path = exp_folder_path + + mot_folder_names = [item.stem for item in Path(opt.source).iterdir()] + + dets_folder = opt.project / "dets_n_embs" / y.stem / 'dets' + embs_folder = opt.project / "dets_n_embs" / y.stem / 'embs' / opt.reid_model[0].stem + + dets_file_paths = sorted([ + item for item in dets_folder.glob('*.txt') + if not item.name.startswith('.') and item.stem in mot_folder_names + ]) + embs_file_paths = sorted([ + item for item in embs_folder.glob('*.txt') + if not item.name.startswith('.') and item.stem in mot_folder_names + ]) + + LOGGER.info(f"\nStarting tracking on:\n\t{opt.source}\nwith preloaded dets\n\t({dets_folder.relative_to(ROOT)})\nand embs\n\t({embs_folder.relative_to(ROOT)})\nusing\n\t{opt.tracking_method}") + + tasks = [] + # Create a thread pool to run each file pair in parallel + with concurrent.futures.ThreadPoolExecutor() as executor: + for d, e in zip(dets_file_paths, embs_file_paths): + mot_result_path = exp_folder_path / (d.stem + '.txt') + if mot_result_path.exists(): + if prompt_overwrite('MOT Result', mot_result_path, opt.ci): + LOGGER.info(f'Overwriting MOT result for {d.stem}...') + else: + LOGGER.info(f'Skipping MOT result generation for {d.stem} as it already exists.') + continue + # Submit the task to process this file pair in parallel + tasks.append(executor.submit(process_single_mot, opt, d, e, evolve_config)) + + # Dict with {seq_name: [frame_nums]} + seqs_frame_nums = {} + # Wait for all tasks to complete and log any exceptions + for future in concurrent.futures.as_completed(tasks): + try: + seqs_frame_nums.update(future.result()) + except Exception as exc: + LOGGER.error(f'Error processing file pair: {exc}') + + # Postprocess data with gsi if requested + if opt.gsi: + 
gsi(mot_results_folder=opt.exp_folder_path) + + with open(opt.exp_folder_path / 'seqs_frame_nums.json', 'w') as f: + json.dump(seqs_frame_nums, f) + + +def run_trackeval(opt: argparse.Namespace) -> dict: + """ + Runs the trackeval function to evaluate tracking results. + + Args: + opt (Namespace): Parsed command line arguments. + """ + seq_paths, save_dir, MOT_results_folder, gt_folder = eval_setup(opt, opt.val_tools_path) + trackeval_results = trackeval(opt, seq_paths, save_dir, MOT_results_folder, gt_folder) + hota_mota_idf1 = parse_mot_results(trackeval_results) + if opt.verbose: + LOGGER.info(trackeval_results) + with open(opt.tracking_method + "_output.json", "w") as outfile: + outfile.write(json.dumps(hota_mota_idf1)) + LOGGER.info(json.dumps(hota_mota_idf1)) + return hota_mota_idf1 + + +def run_all(opt: argparse.Namespace) -> None: + """ + Runs all stages of the pipeline: generate_dets_embs, generate_mot_results, and trackeval. + + Args: + opt (Namespace): Parsed command line arguments. 
+ """ + run_generate_dets_embs(opt) + run_generate_mot_results(opt) + run_trackeval(opt) + + +def parse_opt() -> argparse.Namespace: + parser = argparse.ArgumentParser() + + # Global arguments + parser.add_argument('--yolo-model', nargs='+', type=Path, default=[WEIGHTS / 'yolov8n.pt'], help='yolo model path') + parser.add_argument('--reid-model', nargs='+', type=Path, default=[WEIGHTS / 'osnet_x0_25_msmt17.pt'], help='reid model path') + parser.add_argument('--source', type=str, help='file/dir/URL/glob, 0 for webcam') + parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=None, help='inference size h,w') + parser.add_argument('--fps', type=int, default=None, help='video frame-rate') + parser.add_argument('--conf', type=float, default=0.01, help='min confidence threshold') + parser.add_argument('--iou', type=float, default=0.7, help='intersection over union (IoU) threshold for NMS') + parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--classes', nargs='+', type=int, default=0, help='filter by class: --classes 0, or --classes 0 2 3') + parser.add_argument('--project', default=ROOT / 'runs', type=Path, help='save results to project/name') + parser.add_argument('--name', default='', help='save results to project/name') + parser.add_argument('--exist-ok', action='store_true', default=True, help='existing project/name ok, do not increment') + parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') + parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride') + parser.add_argument('--ci', action='store_true', help='Automatically reuse existing due to no UI in CI') + parser.add_argument('--tracking-method', type=str, default='deepocsort', help='deepocsort, botsort, strongsort, ocsort, bytetrack, imprassoc, boosttrack') + parser.add_argument('--dets-file-path', type=Path, help='path to detections file') + 
parser.add_argument('--embs-file-path', type=Path, help='path to embeddings file') + parser.add_argument('--exp-folder-path', type=Path, help='path to experiment folder') + parser.add_argument('--verbose', action='store_true', help='print results') + parser.add_argument('--agnostic-nms', default=False, action='store_true', help='class-agnostic NMS') + parser.add_argument('--gsi', action='store_true', help='apply Gaussian smooth interpolation postprocessing') + parser.add_argument('--n-trials', type=int, default=4, help='nr of trials for evolution') + parser.add_argument('--objectives', type=str, nargs='+', default=["HOTA", "MOTA", "IDF1"], help='set of objective metrics: HOTA,MOTA,IDF1') + parser.add_argument('--val-tools-path', type=Path, default=EXAMPLES / 'val_utils', help='path to store trackeval repo in') + parser.add_argument('--split-dataset', action='store_true', help='Use the second half of the dataset') + + subparsers = parser.add_subparsers(dest='command') + + # Subparser for generate_dets_embs + generate_dets_embs_parser = subparsers.add_parser('generate_dets_embs', help='Generate detections and embeddings') + generate_dets_embs_parser.add_argument('--source', type=str, required=True, help='file/dir/URL/glob, 0 for webcam') + generate_dets_embs_parser.add_argument('--yolo-model', nargs='+', type=Path, default=WEIGHTS / 'yolov8n.pt', help='yolo model path') + generate_dets_embs_parser.add_argument('--reid-model', nargs='+', type=Path, default=WEIGHTS / 'osnet_x0_25_msmt17.pt', help='reid model path') + generate_dets_embs_parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w') + generate_dets_embs_parser.add_argument('--classes', nargs='+', type=int, default=0, help='filter by class: --classes 0, or --classes 0 2 3') + + # Subparser for generate_mot_results + generate_mot_results_parser = subparsers.add_parser('generate_mot_results', help='Generate MOT results') + 
generate_mot_results_parser.add_argument('--yolo-model', nargs='+', type=Path, default=WEIGHTS / 'yolov8n.pt', help='yolo model path') + generate_mot_results_parser.add_argument('--reid-model', nargs='+', type=Path, default=WEIGHTS / 'osnet_x0_25_msmt17.pt', help='reid model path') + generate_mot_results_parser.add_argument('--tracking-method', type=str, default='deepocsort', help='deepocsort, botsort, strongsort, ocsort, bytetrack, imprassoc, boosttrack') + generate_mot_results_parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w') + + # Subparser for trackeval + trackeval_parser = subparsers.add_parser('trackeval', help='Evaluate tracking results') + trackeval_parser.add_argument('--source', type=str, required=True, help='file/dir/URL/glob, 0 for webcam') + trackeval_parser.add_argument('--exp-folder-path', type=Path, required=True, help='path to experiment folder') + + opt = parser.parse_args() + source_path = Path(opt.source) + opt.benchmark, opt.split = source_path.parent.name, source_path.name + + return opt + + +if __name__ == "__main__": + opt = parse_opt() + + # download MOT benchmark + download_mot_eval_tools(opt.val_tools_path) + + if not Path(opt.source).exists(): + zip_path = download_mot_dataset(opt.val_tools_path, opt.benchmark) + unzip_mot_dataset(zip_path, opt.val_tools_path, opt.benchmark) + + if opt.benchmark == 'MOT17': + cleanup_mot17(opt.source) + + if opt.split_dataset: + opt.source, opt.benchmark = split_dataset(opt.source) + + if opt.command == 'generate_dets_embs': + run_generate_dets_embs(opt) + elif opt.command == 'generate_mot_results': + run_generate_mot_results(opt) + elif opt.command == 'trackeval': + run_trackeval(opt) + else: + run_all(opt) diff --git a/tracking/weights/.gitkeep b/tracking/weights/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391