+ >>> print(inst_obj.name) # -> Will print John
+
+ Args:
+ clazz (type | Callable[[Any], Any] | str): class type or functor or
+ class string path.
+ **kwargs (ArgsType): Kwargs to pass to the class constructor.
+
+ Returns:
+ ConfigDict: _description_
+ """
+ class_path = resolve_class_name(clazz)
+ if class_path is None or len(kwargs) == 0:
+ return ConfigDict({"class_path": class_path})
+ return ConfigDict(
+ {"class_path": class_path, "init_args": ConfigDict(kwargs)}
+ )
+
+
+def delay_instantiation(instantiable: ConfigDict) -> ConfigDict:
+ """Delays the instantiation of the given configuration object.
+
+ This is a somewhat hacky way to delay the initialization of the optimizer
+ configuration object. It works by replacing the class_path with _class_path
+ which basically tells the instantiate_classes function to not instantiate
+ the class. Instead, it returns a function that can be called to instantiate
+ the class
+
+ Args:
+ instantiable (ConfigDict): The configuration object to delay the
+ instantiation of.
+ """
+ instantiable["_class_path"] = instantiable["class_path"]
+ del instantiable["class_path"]
+
+ return class_config(DelayedInstantiator, instantiable=instantiable)
+
+
+class DelayedInstantiator:
+ """Class that delays the instantiation of the given configuration object.
+
+ This is a somewhat hacky way to delay the initialization of the optimizer
+ configuration object. It works by replacing the class_path with _class_path
+ which basically tells the instantiate_classes function to not instantiate
+ the class. Instead, it returns a function that can be called to instantiate
+ the class.
+
+ Args:
+ instantiable (ConfigDict): The configuration object to delay the
+ instantiation of.
+ """
+
+ def __init__(self, instantiable: ConfigDict) -> None:
+ """Instantiates the DelayedInstantiator."""
+ self.instantiable = instantiable
+
+ def __call__(self, **kwargs: ArgsType) -> Any: # type: ignore
+ """Instantiates the configuration object."""
+ instantiable = class_config(
+ self.instantiable["_class_path"],
+ **self.instantiable.get("init_args", {}),
+ )
+
+ return instantiate_classes(instantiable, **kwargs)
+
+
+def instantiate_classes(data: ConfigDict | FieldReference, **kwargs: ArgsType) -> ConfigDict | Any: # type: ignore # pylint: disable=line-too-long
+ """Instantiates all classes in a given ConfigDict.
+
+ This function iterates over the configuration data and instantiates
+ all classes. Class defintions are provided by a config dict that has
+ the following structure:
+
+ {
+ 'data_path': 'path.to.my.class.Class',
+ 'init_args': ConfigDict(
+ {
+ 'arg1': 'value1',
+ 'arg2': 'value2',
+ }
+ )
+ }
+
+ Args:
+ data (ConfigDict | FieldReference): The general configuration object.
+ **kwargs (ArgsType): Additional arguments to pass to the class
+ constructor.
+
+ Returns:
+ ConfigDict | Any: The instantiated objects.
+ """
+ if isinstance(data, FieldReference): # De-Reference the field reference
+ data = data.get()
+
+ assert isinstance(data, ConfigDict), "Data must be a ConfigDict."
+
+ if isinstance(data, FieldConfigDict):
+ data.value_mode() # make sure data is in value mode
+
+ if len(kwargs) > 0:
+ if "init_args" not in data:
+ data["init_args"] = ConfigDict(kwargs)
+ else:
+ for k, v in kwargs.items():
+ data["init_args"][k] = v
+
+ resolved_data = copy_and_resolve_references(data)
+ instantiated_objects = _instantiate_classes(resolved_data)
+ return instantiated_objects
+
+
+def copy_and_resolve_references( # type: ignore
+ data: Any, visit_map: dict[int, Any] | None = None
+) -> Any:
+ """Returns a ConfigDict copy with FieldReferences replaced by values.
+
+ If the object is a FrozenConfigDict, the copy returned is also a
+ FrozenConfigDict. However, note that FrozenConfigDict should already have
+ FieldReferences resolved to values, so this method effectively produces
+ a deep copy.
+
+ Note: This method is overwritten from the ConfigDict class and allows to
+ also resolve FieldReferences in list, tuple and dict.
+
+ Args:
+ data (Any): object to copy.
+ visit_map (dict[int, Any]): A mapping from ConfigDict object ids to
+ their copy. Method is recursive in nature, and it will call
+ "copy_and_resolve_references(visit_map)" on each encountered
+ object, unless it is already in visit_map.
+
+ Returns:
+ Any: ConfigDict copy with previous FieldReferences replaced by values.
+ """
+ if isinstance(data, FieldReference):
+ data = data.get()
+
+ if is_namedtuple(data):
+ return type(data)(
+ **{
+ key: copy_and_resolve_references(getattr(data, key))
+ for key in get_all_keys(data)
+ }
+ )
+
+ if isinstance(data, (list, tuple)):
+ return type(data)(
+ copy_and_resolve_references(value, visit_map) for value in data
+ )
+
+ if isinstance(data, dict):
+ return {
+ k: copy_and_resolve_references(v, visit_map)
+ for k, v in data.items()
+ }
+
+ if not isinstance(data, ConfigDict):
+ return data
+
+ visit_map = visit_map or {}
+ config_dict = ConfigDict()
+
+ # copy attributes
+ super(ConfigDict, config_dict).__setattr__(
+ "_convert_dict", config_dict.convert_dict
+ )
+ visit_map[id(config_dict)] = config_dict
+
+ for key, value in data._fields.items():
+ if isinstance(value, FieldReference):
+ value = value.get()
+
+ if id(value) in visit_map:
+ value = visit_map[id(value)]
+
+ elif isinstance(value, ConfigDict):
+ value = copy_and_resolve_references(value, visit_map)
+
+ elif is_namedtuple(value):
+ value = type(value)(
+ **{
+ key: copy_and_resolve_references(getattr(value, key))
+ for key in get_all_keys(value)
+ }
+ )
+
+ elif isinstance(value, (list, tuple)):
+ value = type(value)(
+ copy_and_resolve_references(v, visit_map) for v in value
+ )
+
+ elif isinstance(value, dict):
+ value = {
+ k: copy_and_resolve_references(v, visit_map)
+ for k, v in value.items()
+ }
+
+ if isinstance(data, FrozenConfigDict):
+ config_dict._frozen_setattr( # pylint:disable=protected-access
+ key, value
+ )
+ else:
+ config_dict[key] = value
+
+ # copy attributes
+ super(ConfigDict, config_dict).__setattr__("_locked", data.is_locked)
+ super(ConfigDict, config_dict).__setattr__("_type_safe", data.is_type_safe)
+ return config_dict
+
+
+def _get_index(data: Any) -> Any: # type: ignore
+ """Internal function to generate a Sequence of indexes for a given object.
+
+ Example:
+ >>> [data[idx] for idx in _get_index(data)]
+
+ Args:
+ data (Any): The data entry to get an index for.
+
+ Returns:
+ Any: Iterable that can be used to index the data entry using e.g.
+ [data[idx] for idx in _get_index(data)]
+ """
+ if isinstance(data, (list, tuple)):
+ return range(len(data))
+ return data
+
+
+def _instantiate_classes(data: Any) -> Any: # type: ignore
+ """Instantiates all classes in a given data.
+
+ Data could be ConfigDict, FieldReference, tuple, list or dict.
+
+ This is the recursive implementation of the 'instantiate_classes'.
+
+ This function iterates over the configuration data and instantiates
+ all classes. Class defintions are provided by a config dict that has
+ the following structure:
+
+ {
+ 'data_path': 'path.to.my.class.Class',
+ 'init_args': ConfigDict(
+ {
+ 'arg1': 'value1',
+ 'arg2': 'value2',
+ }
+ )
+ }
+
+ Args:
+ data (Any): The general configuration object.
+
+ Returns:
+ Any: The ConfigDict with all classes intialized. Or, if the top level
+ element is a class config, the returned element will be the
+ instantiated class.
+ """
+ if isinstance(data, FieldReference):
+ data = data.get()
+
+ if not isinstance(data, (ConfigDict, dict, list, tuple)):
+ return data
+
+ for key in _get_index(data):
+ value = data[key]
+
+ if isinstance(value, FieldReference):
+ value = value.get()
+
+ if isinstance(value, (ConfigDict, dict)):
+ if isinstance(data, ConfigDict):
+ with data.ignore_type():
+ data[key] = _instantiate_classes(value)
+ else:
+ data[key] = _instantiate_classes(value)
+
+ elif is_namedtuple(value):
+ if isinstance(data, ConfigDict):
+ with data.ignore_type():
+ data[key] = type(value)(
+ **{
+ key: _instantiate_classes(getattr(value, key))
+ for key in get_all_keys(value)
+ }
+ )
+ else:
+ data[key] = type(value)(
+ **{
+ key: _instantiate_classes(getattr(value, key))
+ for key in get_all_keys(value)
+ }
+ )
+
+ elif isinstance(value, (list, tuple)):
+ if isinstance(data, ConfigDict):
+ with data.ignore_type():
+ data[key] = type(value)(
+ _instantiate_classes(value[idx])
+ for idx in range(len(value))
+ )
+ else:
+ data[key] = type(value)(
+ _instantiate_classes(value[idx])
+ for idx in range(len(value))
+ )
+
+ # Instantiate classs
+ if "class_path" in data and not isinstance(data["class_path"], ConfigDict):
+ module_name, class_name = data["class_path"].rsplit(".", 1)
+ init_args = data.get("init_args", {})
+
+ # Convert ConfigDict to normal dictionary
+ if isinstance(init_args, ConfigDict):
+ init_args = init_args.to_dict()
+
+ module = importlib.import_module(module_name)
+ # Instantiate class
+ clazz = getattr(module, class_name)(**init_args)
+ return clazz
+
+ return data
diff --git a/vis4d/config/registry.py b/vis4d/config/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d39870a9829dbef145b39199a72bad652f8d7da
--- /dev/null
+++ b/vis4d/config/registry.py
@@ -0,0 +1,262 @@
+"""Utility function for registering config files."""
+
+from __future__ import annotations
+
+import glob
+import os
+import pathlib
+import warnings
+from typing import Callable, Union
+
+import yaml
+from ml_collections import ConfigDict
+from ml_collections.config_flags.config_flags import _LoadConfigModule
+
+from vis4d.common.dict import flatten_dict, get_dict_nested
+from vis4d.common.typing import ArgsType
+from vis4d.common.util import create_did_you_mean_msg
+from vis4d.config.config_dict import FieldConfigDict
+from vis4d.zoo import AVAILABLE_MODELS
+
+MODEL_ZOO_FOLDER = str(
+ (pathlib.Path(os.path.dirname(__file__)) / ".." / "zoo").resolve()
+)
+
+# Paths that are used to search for config files.
+REGISTERED_CONFIG_PATHS = [MODEL_ZOO_FOLDER]
+
+
+TFunc = Callable[[ArgsType], ArgsType]
+TfuncConfDict = Union[Callable[[ArgsType], ConfigDict], type]
+
+
+def register_config(
+ category: str, name: str
+) -> Callable[[TfuncConfDict], None]:
+ """Register a config in the model zoo for the given name and category.
+
+ The config will then be available via `get_config_by_name` utilities and
+ located in the AVAILABLE_MODELS dictionary located at
+ [category][name].
+
+ Args:
+ category: Category of the config.
+ name: Name of the config.
+
+ Returns:
+ The decorator.
+ """
+
+ def decorator(fnc_or_clazz: TfuncConfDict) -> None:
+ """Decorator for registering a config.
+
+ Args:
+ fnc_or_clazz: Function or class to register. If a function is
+ passed, it will be wrapped in a class and the class will be
+ registered. If a class is passed, it will be registered
+ directly.
+ """
+ if callable(fnc_or_clazz):
+ # Directly annotated get_config function. Wrap it and register it.
+ class Wrapper:
+ """Wrapper class."""
+
+ def get_config(
+ self, *args: ArgsType, **kwargs: ArgsType
+ ) -> ConfigDict:
+ """Resolves the get_config function."""
+ return fnc_or_clazz(*args, **kwargs)
+
+ module = Wrapper()
+ else:
+ # Directly annotated class. Register it.
+ module = fnc_or_clazz
+
+ # Register the config
+ if category not in AVAILABLE_MODELS:
+ AVAILABLE_MODELS[category] = {}
+
+ assert isinstance(AVAILABLE_MODELS[category], dict)
+
+ AVAILABLE_MODELS[category][name] = module
+
+ return decorator
+
+
+def _resolve_config_path(path: str) -> str:
+ """Resolve the path of a config file.
+
+ Args:
+ path: Name or path of the config.
+ If the config is not found at this location,
+ the function will look for the config in the model zoo folder.
+
+ Returns:
+ The resolved path of the config.
+
+ Raises:
+ ValueError: If the config is not found.
+ """
+ if os.path.exists(path):
+ return path
+
+ # Check for duplicate paths.
+ found_paths: list[str] = []
+ all_paths = []
+
+ for p in REGISTERED_CONFIG_PATHS:
+ paths = sorted(
+ glob.glob(
+ os.path.join(p, f"**/*{ os.path.splitext(path)[-1]}"),
+ recursive=True,
+ )
+ )
+ print(
+ paths,
+ "lookup",
+ os.path.join(p, f"**/*{ os.path.splitext(path)[-1]}"),
+ )
+ for cfg_path in paths:
+ if cfg_path.endswith(path):
+ found_paths.append(cfg_path)
+ all_paths.extend(paths)
+
+ if len(found_paths) > 1:
+ warnings.warn(
+ f"Found multiple paths for config {path}:"
+ f"{found_paths}. Will load the config from the first one!"
+ )
+ elif len(found_paths) == 0:
+ hint = create_did_you_mean_msg(
+ [*all_paths, *[os.path.basename(p) for p in all_paths]], path
+ )
+ raise ValueError(
+ f"Could not find config {path}. \n"
+ f"The file does not exists at the path {path} or "
+ f"in the dedicated locations at {REGISTERED_CONFIG_PATHS}. \n"
+ f"Please check the path or add the config to the model zoo. \n"
+ f"Current working directory: {os.getcwd()}\n {hint}"
+ )
+ return found_paths[0]
+
+
+def _load_yaml_config(name_or_path: str) -> FieldConfigDict:
+ """Loads a .yaml configuration file.
+
+ Args:
+ name_or_path: Name or path of the config.
+ If the config is not found at this location, $
+ the function will look for the config in the model zoo folder.
+
+ Returns:
+ The config for the experiment.
+ """
+ path = _resolve_config_path(name_or_path)
+ with open(path, "r", encoding="utf-8") as yaml_file:
+ return FieldConfigDict(yaml.load(yaml_file, Loader=yaml.UnsafeLoader))
+
+
+def _load_py_config(
+ name_or_path: str, *args: ArgsType, method_name: str = "get_config"
+) -> ConfigDict:
+ """Loads a .py configuration file.
+
+ Args:
+ name_or_path: Name or path of the config.
+ If the config is not found at this location,
+ the function will look for the config in the model zoo folder.
+ *args: Additional arguments to pass to the config.
+ method_name: Name of the method to call from the file to get the
+ config. Defaults to "get_config".
+
+ Returns:
+ The config for the experiment.
+ """
+ path = _resolve_config_path(name_or_path)
+ config_module = _LoadConfigModule(f"{os.path.basename(path)}_config", path)
+ cfg = getattr(config_module, method_name)(*args)
+ assert isinstance(cfg, ConfigDict)
+ return cfg
+
+
+def _get_registered_configs(
+ config_name: str, *args: ArgsType, method_name: str = "get_config"
+) -> ConfigDict:
+ """Get a model from the registered config locations.
+
+ Args:
+ config_name: Name of the config. This can either be
+ the full path of the config relative to the registered locations
+ or the name of the config.
+ If the config matches multiple configs (e.g. if there are two
+ conflicting config a/cfg and b/cfg) or if it is not found,
+ a ValueError is raised.
+ *args: Additional arguments to pass to the config.
+ method_name: Name of the method to call from the file to get the
+ config. Defaults to "get_config".
+
+ Raises:
+ ValueError: If the config is not found.
+
+ Returns:
+ The Config.
+ """
+ models = flatten_dict(AVAILABLE_MODELS, os.path.sep)
+ # check if there is an absolute match for the config
+ if config_name in models:
+ module = get_dict_nested(
+ AVAILABLE_MODELS, config_name.split(os.path.sep)
+ )
+ return getattr(module, method_name)(*args)
+ # check if there is a partial match for the config
+ matches = {}
+ for model in models:
+ if model.endswith(config_name):
+ matches[model] = get_dict_nested(
+ AVAILABLE_MODELS, model.split(os.path.sep)
+ )
+
+ if len(matches) > 1:
+ raise ValueError(
+ f"Found multiple configs matching {config_name}:"
+ f"{matches.keys()}.\nPlease specify a unique config name."
+ )
+ if len(matches) == 0:
+ msg = create_did_you_mean_msg(
+ [*models, *[os.path.basename(m) for m in models]], config_name
+ )
+ raise ValueError(msg)
+
+ module = list(matches.values())[0]
+ return getattr(module, method_name)(*args)
+
+
+def get_config_by_name(
+ name_or_path: str, *args: ArgsType, method_name: str = "get_config"
+) -> ConfigDict:
+ """Get a config by name or path.
+
+ Args:
+ name_or_path: Name or path of the config.
+ If the path has a .yaml or .py extension, the function will
+ load the config from the file.
+ Otherwise, the function will try to resolve the config from the
+ registered config locations. You can specify a config by its full
+ registered path (e.g. "a/b/cfg") or by its name (e.g. "cfg").
+ *args: Additional arguments to pass to the config.
+ method_name: Name of the method to call from the file to get the
+ config. Defaults to "get_config".
+
+ Returns:
+ The config.
+
+ Raises:
+ ValueError: If the config is not found.
+ """
+ if name_or_path.endswith(".yaml"):
+ return _load_yaml_config(name_or_path)
+ if name_or_path.endswith(".py"):
+ return _load_py_config(name_or_path, *args, method_name=method_name)
+ return _get_registered_configs(
+ name_or_path, *args, method_name=method_name
+ )
diff --git a/vis4d/config/show_connection.py b/vis4d/config/show_connection.py
new file mode 100644
index 0000000000000000000000000000000000000000..e348815f9a2d12e46f16469d0631f3ff59e74414
--- /dev/null
+++ b/vis4d/config/show_connection.py
@@ -0,0 +1,551 @@
+"""Show connected components in the config."""
+
+from __future__ import annotations
+
+import inspect
+from typing import Any, TypedDict, get_type_hints
+
+from absl import app # pylint: disable=no-name-in-module
+from torch import nn
+
+from vis4d.common.typing import ArgsType
+from vis4d.engine.callbacks import (
+ Callback,
+ EvaluatorCallback,
+ VisualizerCallback,
+)
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.engine.flag import _CONFIG
+from vis4d.engine.loss_module import LossModule
+from vis4d.eval.base import Evaluator
+from vis4d.vis.base import Visualizer
+
+from .config_dict import instantiate_classes
+
+
+# Types
+class DataConnectionInfo(TypedDict):
+ """Internal type def for visualization.
+
+ This defines a block component
+ """
+
+ in_keys: list[str]
+ out_keys: list[str]
+ name: str
+
+
+# Private Functions
+def _rename_ds(name: str) -> str:
+ """Replaces data with d and prediction with p.
+
+ Use this to remap the datasources to shorter names.
+
+ Args:
+ name: Name to remap
+
+ Returns:
+ remapped name
+ """
+ return name.replace("data", "d").replace("prediction", "p")
+
+
+def _get_model_conn_infos(
+ model: nn.Module,
+) -> dict[str, DataConnectionInfo]:
+ """Returns the connection infos for a pytorch Model.
+
+ Requires "forward_train" and "forward_test" to be defined and properly
+ typed!
+
+ Args:
+ model: Model to extract data from
+
+ Returns:
+ train_connections, test_connections
+ """
+ train_t = get_type_hints(model.forward_train)["return"]
+ test_t = get_type_hints(model.forward_test)["return"]
+
+ train_connection_info = DataConnectionInfo(
+ in_keys=sorted(
+ list(inspect.signature(model.forward).parameters.keys())
+ ),
+ out_keys=[
+ "-" + e for e in sorted(resolve_named_tuple(train_t, prefix=""))
+ ],
+ name=model.__class__.__name__,
+ )
+
+ test_connection_info = DataConnectionInfo(
+ in_keys=sorted(
+ list(inspect.signature(model.forward).parameters.keys())
+ ),
+ out_keys=[
+ "
-" + e for e in sorted(resolve_named_tuple(test_t, prefix=""))
+ ],
+ name=model.__class__.__name__,
+ )
+ return {"train": train_connection_info, "test": test_connection_info}
+
+
+def _get_loss_connection_infos(loss: LossModule) -> list[DataConnectionInfo]:
+ """Returns the connection infos for a loss.
+
+ Args:
+ loss (LossModule): Custom loss module with .forward()
+
+ Returns:
+ DataConnectionInfo for the loss.
+ """
+ loss_connection_info = []
+ for l in loss.losses:
+ loss_out = []
+ loss_in = []
+ for entry, value in l["connector"].key_mapping.items():
+ loss_out.append(f"{entry}")
+ loss_in.append(f"<{_rename_ds(value['source'])}>-" + value["key"])
+
+ loss_connection_info.append(
+ DataConnectionInfo(
+ in_keys=loss_in, out_keys=loss_out, name=l["name"]
+ )
+ )
+
+ return loss_connection_info
+
+
+def _get_vis_connection_infos(
+ visualizer: Visualizer,
+) -> DataConnectionInfo:
+ """Returns the connection infos for a visualizer.
+
+ Args:
+ visualizer: Visualizer to extract data from
+
+ Returns:
+ DataConnectionInfo for the visualizer.
+ """
+ return DataConnectionInfo(
+ in_keys=sorted(
+ list(inspect.signature(visualizer.process).parameters.keys())
+ ),
+ out_keys=[],
+ name=visualizer.__class__.__name__,
+ )
+
+
+def _get_evaluator_connection_infos(
+ evaluator: Evaluator,
+) -> DataConnectionInfo:
+ """Returns the connection infos for an evaluator.
+
+ Args:
+ evaluator: Evaluator to extract data from
+
+ Returns:
+ DataConnectionInfo for the evaluator.
+ """
+ return DataConnectionInfo(
+ in_keys=sorted(
+ list(inspect.signature(evaluator.process).parameters.keys())
+ ),
+ out_keys=[],
+ name=evaluator.__class__.__name__,
+ )
+
+
+def _get_data_connector_infos(
+ data_connector: DataConnector, name: str
+) -> DataConnectionInfo:
+ """Returns the connection infos for a DataConnector.
+
+ Args:
+ data_connector (DataConnector): Data connector to extract data.
+ name (str): Name of the data connector.
+
+ Returns:
+ DataConnectionInfo for the data connector.
+ """
+ return DataConnectionInfo(
+ in_keys=["-" + e for e in list(data_connector.key_mapping.keys())],
+ out_keys=list(data_connector.key_mapping.values()),
+ name=name,
+ )
+
+
+def _get_cb_connection_infos(
+ name: str,
+ cb_data_connector: None | CallbackConnector = None,
+) -> DataConnectionInfo | None:
+ """Returns the connection infos for a callback."""
+ if cb_data_connector is not None:
+ eval_out = []
+ eval_in = []
+ for entry, value in cb_data_connector.key_mapping.items():
+ eval_out.append(f"{entry}")
+ eval_in.append(f"<{_rename_ds(value['source'])}>-" + value["key"])
+ return DataConnectionInfo(
+ in_keys=eval_in, out_keys=eval_out, name=name
+ )
+ return None
+
+
+def _get_with_color(key: str, warn_unconnected: bool = True) -> str:
+ """Prepends colors for internal vsiualization."""
+ if "*" in key:
+ # We connected this one
+ return f"\033[94m{key}\033[00m"
+ if "" in key: # key comes from data
+ return f"\033[90m{key}\033[00m"
+
+ # comes from prediction and is not connected
+ if warn_unconnected:
+ return f"\u001b[33m{key}\033[00m"
+ return f"\033[00m{key}\033[00m"
+
+
+# API Functions
+def print_box(
+ title: str, inputs: list[str], outputs: list[str], use_color: bool = True
+) -> str:
+ """Prints a box with title and in/outputs.
+
+ Args:
+ title: Title to plot in the middle.
+ inputs: inputs to plot on the left.
+ outputs: Outputs to plot on the right.
+ use_color: Whether to use color in the output.
+
+ Returns:
+ str: The box as a string.
+
+ Example:
+ --------------
+ -boxes2d | | *boxes2d
+ -boxes2d_classes | | *boxes2d_classes
+ -images | Train Data | *images
+ -input_hw | | *input_hw
+ --------------
+ """
+ len_title = len(title) + 4
+
+ n_lines = max(len(inputs), len(outputs))
+
+ max_len_inputs = max([0] + [len(inp) for inp in inputs])
+ max_len_outputs = max([0] + [len(out) for out in outputs])
+
+ divider = (
+ " " * (max_len_inputs + 1)
+ + "-" * len_title
+ + " " * (max_len_outputs + 1)
+ )
+ lines = divider + "\n"
+ for idx in range(n_lines):
+ in_data = inputs[idx] if len(inputs) > idx else ""
+ # left pad
+ in_key = " " * (max_len_inputs - len(in_data)) + in_data
+
+ out_data = outputs[idx] if len(outputs) > idx else ""
+ # right pad
+ out_key = out_data + " " * (max_len_outputs - len(out_data))
+
+ # title in middle
+ line = ""
+ line += _get_with_color(in_key)
+ line += " | "
+ line += " " * len(title) if idx != n_lines // 2 else title
+ line += " | "
+ line += _get_with_color(out_key) if use_color else out_key
+
+ lines += line + "\n"
+
+ lines += divider + "\n"
+ return lines
+
+
+def resolve_named_tuple( # type:ignore
+ clazz: Any, prefix: str = ""
+) -> list[str]:
+ """Returns all fields defined in the clazz t.
+
+ Use this to get all fields defined for an e.g. Named Tuple.
+
+ Args:
+ clazz: Class that should be resolved
+ prefix: Prefix to prepend (will be prefix.)
+
+ Returns:
+ List with all fields and prefixes prepended.
+
+ Examples:
+ >>> Person = namedtuple("Person", ["name", "age", "gender"])
+ >>> Address = namedtuple("Address", ["street", "city", "zipcode"])
+
+ >>> resolve_named_tuple(clazz=Person, prefix="person")
+ ["person.name", "person.age", "person.gender"]
+
+ >>> resolve_named_tuple(clazz=Address, prefix="address")
+ ["address.street", "address.city", "address.zipcode"]
+
+ >>> resolve_named_tuple(clazz=Person, prefix="")
+ ["name", "age", "gender"]
+
+ With more complex types:
+ >>> User = namedtuple("User", ["name", "address"])
+ >>> user = User(name=Person(name="John"), address=Address(street="str", city="zrh", zipcode="1"))
+
+ >>> resolve_named_tuple(clazz=user, prefix="user")
+ ["user.name.name", "user.address.street", "user.address.city",
+ "user.address.zipcode"]
+
+
+
+ """
+ fields = []
+ if hasattr(clazz, "_fields"):
+ for f in clazz._fields:
+ p = f"{prefix}.{f}" if len(prefix) > 0 else f
+ fields += resolve_named_tuple(getattr(clazz, f), prefix=p)
+ return fields
+ return [prefix]
+
+
+def connect_components(
+ in_info: DataConnectionInfo, out_info: DataConnectionInfo
+) -> None:
+ """Marks two components as connected.
+
+ Checks if they have intersecting keys and marks them as matched.
+ Updates the components inplace.
+
+ Args:
+ in_info (DataConnectionInfo): Input DataConnection
+ out_info (DataConnectionInfo): Ouput DataConnection
+ """
+ out_keys = []
+ for out in out_info["in_keys"]:
+ out = out.replace("*", "")
+ out_keys.append(out.split(".")[0])
+
+ # Check connection
+ for idx, key in enumerate(in_info["out_keys"]):
+ key = key.replace("*", "")
+ for o_idx, o_key in enumerate(out_keys):
+ if key == o_key:
+ in_info["out_keys"][idx] = "*" + key
+ out_info["in_keys"][o_idx] = (
+ " " + out_info["in_keys"][o_idx].replace("*", "") + "*"
+ )
+
+
+def prints_datagraph_for_config(
+ model: nn.Module,
+ train_data_connector: DataConnector,
+ test_data_connector: DataConnector,
+ loss: LossModule,
+ callbacks: list[Callback],
+) -> str:
+ """Shows the setup of the configuration objects.
+
+ For each components, plots which inputs is connected to which output.
+ Connected components are marked with "*". Use this to debug your
+ configuration setup.
+
+ Note, that data loaded from the dataset are highlighted with and data
+ from model predictions with .
+
+ Args:
+ model (nn.Module): Model to plot.
+ train_data_connector (DataConnector): Train data connector to plot.
+ test_data_connector (DataConnector): Test data connector to plot.
+ loss (LossModule): Loss to plot.
+ callbacks (list[Callback]): Callbacks to plot.
+
+ Returns:
+ str: The datagraph as a string, that can be printed to the console.
+
+ Example:
+ The following is train datagraph for FasterRCNN with COCO.
+ Inputs loaded from dataset are marked with and predictions
+ with . Unconnected inputs are missing a (*) sign.
+
+ >>> dg = prints_datagraph_for_config(model, train_data_connector, test_data_connector, loss, callbacks)))
+ >>> print(dg)
+ ```
+ # TODO: check if this is correct
+ ===================================
+ = Training Loop =
+ ===================================
+ --------------
+ -boxes2d | | *boxes2d
+ -boxes2d_classes | | *boxes2d_classes
+ -images | Train Data | *images
+ -input_hw | | *input_hw
+ --------------
+ --------------
+ boxes2d* | | -proposals
+ boxes2d_classes* | |
-roi
+ images* | | *
-rpn
+ input_hw* | FasterRCNN |
-sampled_proposals
+ original_hw | |
-sampled_target_indices
+ | |
-sampled_targets
+ --------------
+ -----------
+
-rpn.cls* | | cls_outs
+ -input_hw | | images_hw
+ -rpn.box* | RPNLoss | reg_outs
+ -boxes2d | | target_boxes
+ -----------
+ ------------
+ -sampled_proposals.boxes | | boxes
+
-sampled_targets.labels | | boxes_mask
+
-roi.cls_score | | class_outs
+
-roi.bbox_pred | RCNNLoss | regression_outs
+
-sampled_targets.boxes | | target_boxes
+
-sampled_targets.classes | | target_classes
+ ------------
+ ===================================
+ = Testing Loop =
+ ===================================
+ -------------
+ -images | | *images
+ -input_hw | Test Data | *input_hw
+ -original_hw | | *original_hw
+ -------------
+ --------------
+ boxes2d | | -boxes
+ boxes2d_classes | |
-class_ids
+ images* | FasterRCNN |
-scores
+ input_hw* | |
+ original_hw* | |
+ --------------
+ ===================================
+ = Callbacks =
+ ===================================
+ -------------------------
+ -original_images | | *images
+ -sample_names | | *image_names
+ -boxes | BoundingBoxVisualizer | *boxes
+
-scores | | *scores
+
-class_ids | | *class_ids
+ -------------------------
+ ----------------------
+ -sample_names | | *coco_image_id
+ -boxes | | *pred_boxes
+
-scores | COCODetectEvaluator | *pred_scores
+
-class_ids | | *pred_classes
+ ----------------------
+ ```
+ """
+ model_connection_info = _get_model_conn_infos(model)
+
+ # TODO: support more data connectors
+ assert isinstance(train_data_connector, DataConnector) and isinstance(
+ test_data_connector, DataConnector
+ ), "Only DataConnector is supported."
+ train_data_connection_info = _get_data_connector_infos(
+ train_data_connector, name="Train Data"
+ )
+ test_data_connection_info = _get_data_connector_infos(
+ test_data_connector, name="Test Data"
+ )
+
+ loss_info = _get_loss_connection_infos(loss)
+ log_str = ""
+
+ # connect components
+ log_str += "=" * 35 + "\n"
+ log_str += "=" + " " * 10 + "Training Loop" + " " * 10 + "=" + "\n"
+ log_str += "=" * 35 + "\n"
+
+ train_components = [
+ train_data_connection_info,
+ model_connection_info["train"],
+ ] + loss_info
+
+ for inp, out in zip(train_components[:-1], train_components[1:]):
+ connect_components(inp, out)
+ for e in train_components:
+ log_str += print_box(e["name"], e["in_keys"], e["out_keys"])
+
+ log_str += "=" * 35 + "\n"
+ log_str += "=" + " " * 10 + "Testing Loop " + " " * 10 + "=" + "\n"
+ log_str += "=" * 35 + "\n"
+
+ test_components = [
+ test_data_connection_info,
+ model_connection_info["test"],
+ ]
+
+ for inp, out in zip(test_components[:-1], test_components[1:]):
+ connect_components(inp, out)
+
+ for e in test_components:
+ log_str += print_box(e["name"], e["in_keys"], e["out_keys"])
+
+ # TODO: Add support for more callbacks and handle train_connector
+ log_str += "=" * 35 + "\n"
+ log_str += "=" + " " * 12 + "Callbacks" + " " * 12 + "=" + "\n"
+ log_str += "=" * 35 + "\n"
+
+ # evaluator and visualizer
+ callback_components: list[DataConnectionInfo] = []
+
+ for cb in callbacks:
+ if isinstance(cb, EvaluatorCallback):
+ evaluator = cb.evaluator
+
+ connect_info = _get_evaluator_connection_infos(evaluator)
+ component = _get_cb_connection_infos(
+ cb.evaluator.__class__.__name__, cb.test_connector
+ )
+
+ # found matching connector
+ if component is not None:
+ connect_components(component, connect_info)
+ callback_components.append(component)
+
+ if isinstance(cb, VisualizerCallback):
+ visualizer = cb.visualizer
+
+ connect_info = _get_vis_connection_infos(visualizer)
+
+ component = _get_cb_connection_infos(
+ cb.visualizer.__class__.__name__, cb.test_connector
+ )
+
+ # found matching connector
+ if component is not None:
+ connect_components(component, connect_info)
+ callback_components.append(component)
+
+ for e in callback_components:
+ log_str += print_box(e["name"], e["in_keys"], e["out_keys"])
+
+ return log_str
+
+
+def main(
+ argv: ArgsType, # pylint: disable=unused-argument
+) -> None: # pragma: no cover
+ """Main entry point to show connected components in the config.
+
+ >>> python -m vis4d.config.show_connection --config vis4d/zoo/faster_rcnn/faster_rcnn_coco.py
+ """
+ config = _CONFIG.value
+
+ train_data_connector = instantiate_classes(config.train_data_connector)
+ test_data_connector = instantiate_classes(config.test_data_connector)
+ loss = instantiate_classes(config.loss)
+ model = instantiate_classes(config.model)
+ callbacks = [instantiate_classes(cb) for cb in config.callbacks]
+
+ dg = prints_datagraph_for_config(
+ model, train_data_connector, test_data_connector, loss, callbacks
+ )
+ print(dg)
+
+
+if __name__ == "__main__": # pragma: no cover
+ app.run(main)
diff --git a/vis4d/config/sweep.py b/vis4d/config/sweep.py
new file mode 100644
index 0000000000000000000000000000000000000000..83ed954cb6a5695ab05adb5fd4b97398812ce14f
--- /dev/null
+++ b/vis4d/config/sweep.py
@@ -0,0 +1,44 @@
+"""Helper functions for creating sweep configurations."""
+
+from __future__ import annotations
+
+from ml_collections import ConfigDict
+
+from vis4d.common.typing import ArgsType
+
+
+def grid_search(
+ param_names: list[str] | str,
+ param_values: list[ArgsType] | list[list[ArgsType]],
+) -> ConfigDict:
+ """Linear grid search configuration over a list of parameters.
+
+ Returns a configuration object that can be used to perform a grid search
+ over a list of parameters.
+
+ Args:
+ param_names (list[str] | str): The name of the parameters to be
+ sampled.
+ param_values (list[Any] | list[list[Any]]): The values which
+ should be sampled.
+
+ Example:
+ >>> grid_search("params.lr", [0.001, 0.01, 0.1])
+
+
+ Returns:
+ ConfigDict: The configuration object that can be used to perform a grid
+ search.
+ """
+ if isinstance(param_names, str):
+ param_names = [param_names]
+ param_values = [param_values]
+
+ assert len(param_names) == len(param_values)
+
+ config = ConfigDict()
+ config.method = "grid"
+ config.sampling_args = []
+ for name, values in zip(param_names, param_values):
+ config.sampling_args.append([name, values])
+ return config
diff --git a/vis4d/config/typing.py b/vis4d/config/typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9161dfd05f5420d1af6629321b78f8562bf0e1e
--- /dev/null
+++ b/vis4d/config/typing.py
@@ -0,0 +1,194 @@
+"""Type definitions for configuration files."""
+
+from __future__ import annotations
+
+from typing import Any, TypedDict
+
+from ml_collections import ConfigDict, FieldReference
+from typing_extensions import NotRequired
+
+from .config_dict import FieldConfigDict
+
+
+class ParamGroupCfg(TypedDict):
+ """Parameter group config.
+
+ Attributes:
+ custom_keys (list[str]): List of custom keys.
+ lr_mult (NotRequired[float]): Learning rate multiplier.
+ decay_mult (NotRequired[float]): Weight Decay multiplier.
+ """
+
+ custom_keys: list[str]
+ lr_mult: NotRequired[float]
+ decay_mult: NotRequired[float]
+ norm_decay_mult: NotRequired[float]
+ bias_decay_mult: NotRequired[float]
+
+
+class DataConfig(ConfigDict): # type: ignore
+ """Configuration for a data set.
+
+ This data object is used to configure the training and test data of an
+ experiment. In particular, the train_dataloader and test_dataloader
+ need to be config dicts that can be instantiated as a dataloader.
+
+ Attributes:
+ train_dataloader (ConfigDict): Configuration for the training
+ dataloader.
+ test_dataloader (ConfigDict): Configuration for the test dataloader.
+
+
+ Example:
+ >>> from vis4d.config.types import DataConfig
+ >>> from vis4d.zoo.base import class_config
+ >>> from my_package.data import MyDataLoader
+ >>> cfg = DataConfig()
+ >>> cfg.train_dataloader = class_config(MyDataLoader, ...)
+ """
+
+ train_dataloader: ConfigDict
+ test_dataloader: ConfigDict
+
+
+class LrSchedulerConfig(ConfigDict): # type: ignore
+ """Configuration for a learning rate scheduler.
+
+ Attributes:
+ scheduler (ConfigDict): Configuration for the learning rate scheduler.
+ begin (int): Begin epoch.
+ end (int): End epoch.
+ epoch_based (bool): Whether the learning rate scheduler is epoch based
+ or step based.
+ convert_epochs_to_steps (bool): Whether to convert the begin and end
+ for a step based scheduler to steps automatically based on length
+ of train dataloader. Enables users to set the iteration breakpoints
+ as epochs. Defaults to False.
+ convert_attributes (list[str] | None): List of attributes in the
+ scheduler that should be converted to steps. Defaults to None.
+ """
+
+ scheduler: ConfigDict
+ begin: int
+ end: int
+ epoch_based: bool
+ convert_epochs_to_steps: bool = False
+ convert_attributes: list[str] | None = None
+
+
+class OptimizerConfig(ConfigDict): # type: ignore
+ """Configuration for an optimizer.
+
+ Attributes:
+ optimizer (ConfigDict): Configuration for the optimizer.
+ lr_scheduler (list[LrSchedulerConfig] | None): Configuration for the
+ learning rate scheduler.
+ param_groups (list[ParamGroupCfg] | None): Configuration for the
+ parameter groups.
+ """
+
+ optimizer: ConfigDict
+ lr_scheduler: list[LrSchedulerConfig] | None
+ param_groups: list[ParamGroupCfg] | None
+
+
+class ExperimentParameters(FieldConfigDict):
+ """Parameters for an experiment.
+
+ Attributes:
+ samples_per_gpu (int): Number of samples per GPU.
+ workers_per_gpu (int): Number of workers per GPU.
+ """
+
+ samples_per_gpu: int
+ workers_per_gpu: int
+
+
+class ExperimentConfig(FieldConfigDict):
+ """Configuration for an experiment.
+
+ This data object is used to configure an experiment. It contains the
+ minimal required configuration to run an experiment. In particular, the
+ data, model, optimizers, and loss need to be config dicts that can be
+ instantiated as a data set, model, optimizer, and loss function,
+ respectively.
+
+ Attributes:
+ work_dir (str | FieldReference): The working directory for the
+ experiment.
+ experiment_name (str | FieldReference): The name of the experiment.
+ timestamp (str | FieldReference): The timestamp of the experiment.
+ version (str | FieldReference): The version of the experiment.
+ output_dir (str | FieldReference): The output directory for the
+ experiment.
+ seed (int | FieldReference): The random seed for the experiment.
+ log_every_n_steps (int | FieldReference): The number of steps after
+ which the logs should be written.
+ use_tf32 (bool | FieldReference): Whether to use tf32.
+ benchmark (bool | FieldReference): Whether to enable benchmarking.
+ params (ExperimentParameters): Configuration for the experiment
+ parameters.
+ data (DataConfig): Configuration for the dataset.
+ model (FieldConfigDictOrRef): Configuration for the model.
+ loss (FieldConfigDictOrRef): Configuration for the loss function.
+ optimizers (list[OptimizerConfig]): Configuration for the optimizers.
+ data_connector (FieldConfigDictOrRef): Configuration for the data
+ connector.
+ callbacks (list[FieldConfigDictOrRef]): Configuration for the
+ callbacks which are used in the engine.
+ """
+
+ # General
+ work_dir: str | FieldReference
+ experiment_name: str | FieldReference
+ timestamp: str | FieldReference
+ version: str | FieldReference
+ output_dir: str | FieldReference
+ seed: int | FieldReference
+ log_every_n_steps: int | FieldReference
+ use_tf32: bool | FieldReference
+ benchmark: bool | FieldReference
+ tf32_matmul_precision: str | FieldReference
+
+ params: ExperimentParameters
+
+ # Data
+ data: DataConfig
+
+ # Model
+ model: ConfigDict
+
+ # Loss
+ loss: ConfigDict
+
+ # Optimizer
+ optimizers: list[OptimizerConfig]
+
+ # Data connector
+ train_data_connector: ConfigDict
+ test_data_connector: ConfigDict
+
+ # Callbacks
+ callbacks: list[ConfigDict]
+
+
+class ParameterSweepConfig(FieldConfigDict):
+ """Configuration for a parameter sweep.
+
+ Confguration object for a parameter sweep. It contains the minimal required
+ configuration to run a parameter sweep.
+
+ Attributes:
+ method (str): Sweep method that should be used (e.g. grid)
+ sampling_args (list[tuple[str, Any]]): Arguments that should be passed
+ to the sweep method. E.g. for grid, this would be a list of tuples
+ of the form (parameter_name, parameter_values).
+ suffix (str): Suffix that should be appended to the output directory.
+ This will be interpreted as a string template and can contain
+ references to the sampling_args.
+ E.g. "lr_{lr:.2e}_bs_{batch_size}".
+ """
+
+ method: str | FieldReference
+ sampling_args: list[tuple[str, Any]] | FieldReference # type: ignore
+ suffix: str | FieldReference = ""
diff --git a/vis4d/data/__init__.py b/vis4d/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..efa31a3fb5f1097d5b7e0863488ee8c19141f916
--- /dev/null
+++ b/vis4d/data/__init__.py
@@ -0,0 +1,10 @@
+"""The data package defines the full data pipeline.
+
+We provide dataset implementations in the `datasets` submodule that return a
+common data format `DictData`. This data format is used by the pre-processing
+functions in the submodule `transforms`. The preprocessing functions are
+composed with the datasets in `DataPipe`. Optionally, a reference view sampler
+can be added here. The `DataPipe` is input to `torch.data.DataLoader`, for
+which we provide utility functions for instantiation that handle also
+batch-wise preprocessing and batch collation.
+"""
diff --git a/vis4d/data/__pycache__/__init__.cpython-311.pyc b/vis4d/data/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1356948d291f0dc6d86759bd8ad5f72df5baae05
Binary files /dev/null and b/vis4d/data/__pycache__/__init__.cpython-311.pyc differ
diff --git a/vis4d/data/__pycache__/const.cpython-311.pyc b/vis4d/data/__pycache__/const.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1c51a3c233bb63003ef80eaa7deded1a5318907e
Binary files /dev/null and b/vis4d/data/__pycache__/const.cpython-311.pyc differ
diff --git a/vis4d/data/__pycache__/typing.cpython-311.pyc b/vis4d/data/__pycache__/typing.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c8957ee83a181a87b23cc17254e1252d25ae595
Binary files /dev/null and b/vis4d/data/__pycache__/typing.cpython-311.pyc differ
diff --git a/vis4d/data/cbgs.py b/vis4d/data/cbgs.py
new file mode 100644
index 0000000000000000000000000000000000000000..d087ad5419aac8158fef8aeb9490fdf9ee820ad7
--- /dev/null
+++ b/vis4d/data/cbgs.py
@@ -0,0 +1,153 @@
+"""Class-balanced Grouping and Sampling for 3D Object Detection.
+
+Implementation of `Class-balanced Grouping and Sampling for Point Cloud 3D
+Object Detection `_.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+from torch.utils.data import Dataset
+
+from vis4d.common.distributed import broadcast, rank_zero_only
+from vis4d.common.logging import rank_zero_info
+from vis4d.common.time import Timer
+
+from .datasets.util import print_class_histogram
+from .reference import MultiViewDataset
+from .typing import DictDataOrList
+
+
+# TODO: Support sensor selection.
+class CBGSDataset(Dataset[DictDataOrList]):
+ """Balance the number of scenes under different classes."""
+
+ def __init__(
+ self,
+ dataset: Dataset[DictDataOrList],
+ class_map: dict[str, int],
+ ignore: int = -1,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.dataset = dataset
+ self.has_reference = isinstance(dataset, MultiViewDataset)
+ self.cat2id = dict(sorted(class_map.items(), key=lambda x: x[1]))
+ self.ignore = ignore
+
+ rank_zero_info("Wrapping dataset with CBGS...")
+ sample_indices = self._get_sample_indices()
+ self.sample_indices = broadcast(sample_indices)
+
+ def _show_histogram(
+ self,
+ sample_indices: list[int],
+ sample_frequencies: list[dict[str, int]],
+ ) -> None:
+ """Show class histogram."""
+ frequencies = {cat: 0 for cat in self.cat2id.keys()}
+
+ for idx in sample_indices:
+ freq = sample_frequencies[idx]
+ for box3d_class in freq:
+ frequencies[box3d_class] += freq[box3d_class]
+
+ print_class_histogram(frequencies)
+
+ def _get_class_sample_indices(
+ self,
+ ) -> tuple[dict[int, list[int]], list[dict[str, int]]]:
+ """Get sample indices."""
+ class_sample_indices: dict[int, list[int]] = {
+ cat_id: [] for cat_id in self.cat2id.values()
+ }
+ sample_frequencies = []
+ inv_class_map = {v: k for k, v in self.cat2id.items()}
+
+ # Handle the case that dataset is already wrapped.
+ if hasattr(self.dataset, "dataset"):
+ dataset = self.dataset.dataset
+ else:
+ dataset = self.dataset
+
+ for idx in range(len(dataset)):
+ assert hasattr(
+ dataset, "get_cat_ids"
+ ), "The dataset must have a method `get_cat_ids` to get cat ids."
+ cat_ids = dataset.get_cat_ids(idx)
+ cur_cats = {}
+ frequencies = {cat: 0 for cat in self.cat2id.keys()}
+
+ for cat_id in cat_ids:
+ if cat_id != self.ignore:
+ cur_cats[cat_id] = [idx]
+ frequencies[inv_class_map[cat_id]] += 1
+
+ sample_frequencies.append(frequencies)
+ for cat_id, index in cur_cats.items():
+ class_sample_indices[cat_id] += index
+
+ return class_sample_indices, sample_frequencies
+
+ @rank_zero_only
+ def _get_sample_indices(self) -> list[int]:
+ """Load sample indices.
+
+ Returns:
+ list[int]: List of indices after class sampling.
+ """
+ t = Timer()
+ (
+ class_sample_indices,
+ sample_frequencies,
+ ) = self._get_class_sample_indices()
+
+ duplicated_samples = sum(
+ len(v) for _, v in class_sample_indices.items()
+ )
+ class_distribution = {
+ k: len(v) / duplicated_samples
+ for k, v in class_sample_indices.items()
+ }
+
+ sample_indices = []
+
+ frac = 1.0 / len(self.cat2id)
+ ratios = [
+ frac / v if v > 0 else 1 for v in class_distribution.values()
+ ]
+ for cls_inds, ratio in zip(
+ list(class_sample_indices.values()), ratios
+ ):
+ sample_indices += np.random.choice(
+ cls_inds, int(len(cls_inds) * ratio)
+ ).tolist()
+
+ self._show_histogram(sample_indices, sample_frequencies)
+
+ rank_zero_info(
+ f"Generating {len(sample_indices)} CBGS samples takes "
+ + f"{t.time():.2f} seconds."
+ )
+
+ return sample_indices
+
+ def __len__(self) -> int:
+ """Return the length of sample indices.
+
+ Returns:
+ int: Length of sample indices.
+ """
+ return len(self.sample_indices)
+
+ def __getitem__(self, idx: int) -> DictDataOrList:
+ """Get original dataset idx according to the given index.
+
+ Args:
+ idx (int): The index of self.sample_indices.
+
+ Returns:
+ DictDataOrList: Data of the corresponding index.
+ """
+ ori_index = self.sample_indices[idx]
+ return self.dataset[ori_index]
diff --git a/vis4d/data/const.py b/vis4d/data/const.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb26daad3479d70c8bb366af928f9a302cc9166f
--- /dev/null
+++ b/vis4d/data/const.py
@@ -0,0 +1,179 @@
+"""Defines data related constants.
+
+While the datasets can hold arbitrary data types and formats, this file
+provides some constants that are used to define a common data format which is
+helpful to use for better data transformation.
+"""
+
+from dataclasses import dataclass
+from enum import Enum
+
+# A custom value to distinguish instance ID and category ID; need to be greater
+# than the number of categories. For a pixel in the panoptic result map:
+# panaptic_id = instance_id * INSTANCE_OFFSET + category_id
+INSTANCE_OFFSET = 1000
+
+
+class AxisMode(Enum):
+ """Enum for choosing among different coordinate frame conventions.
+
+ ROS: The coordinate frame aligns with the right hand rule:
+ - x axis points forward.
+ - y axis points left.
+ - z axis points up.
+ See also: https://www.ros.org/reps/rep-0103.html#axis-orientation
+
+ OpenCV: The coordinate frame aligns with a camera coordinate system:
+ - x axis points right.
+ - y axis points down.
+ - z axis points forward.
+ See also: https://docs.opencv.org/3.4/d9/d0c/group__calib3d.html
+
+ LiDAR: The coordinate frame aligns with a LiDAR coordinate system:
+ - x axis points right.
+ - y axis points forward.
+ - z axis points up.
+ See also: https://www.nuscenes.org/nuscenes#data-collection
+ """
+
+ ROS = 0
+ OPENCV = 1
+ LIDAR = 2
+
+
+@dataclass
+class CommonKeys:
+ """Common supported keys for DictData.
+
+ While DictData can hold arbitrary keys of data, we define a common set of
+ keys where we expect a pre-defined format to enable the usage of common
+ data pre-processing operations among different datasets.
+
+ General Info:
+ - sample_names (str): Name of the sample.
+
+ If the dataset contains videos:
+ - sequence_names (str): The name of the sequence.
+ - frame_ids (int): The temporal frame index of the sample.
+
+ Image Based Inputs:
+ - images (NDArrayF32): Image of shape [1, H, W, C].
+ - input_hw (Tuple[int, int]): Shape of image in (height, width) after
+ transformations.
+ - original_images (NDArrayF32): Original image of shape [1, H, W, C].
+ - original_hw (Tuple[int, int]): Shape of original image in
+ (height, width).
+
+ Image Classification:
+ - categories (NDArrayI64): Class labels of shape [1, ].
+
+ 2D Object Detection:
+ - boxes2d (NDArrayF32): 2D bounding boxes of shape [N, 4] in xyxy
+ format.
+ - boxes2d_classes (NDArrayI64): Classes of 2D bounding boxes of shape
+ [N,].
+ - boxes2d_names (List[str]): Names of 2D bounding box classes, same
+ order as `boxes2d_classes`.
+
+ 2D Object Tracking:
+ - boxes2d_track_ids (NDArrayI64): Tracking IDs of 2D bounding boxes of
+ shape [N,].
+
+ Segmentation:
+ - masks (NDArrayUI8): Segmentation masks of shape [N, H, W].
+ - seg_masks (NDArrayUI8): Semantic segmentation masks [H, W].
+ - instance_masks (NDArrayUI8): Instance segmentation masks of shape
+ [N, H, W].
+ - panoptic_masks (NDArrayI64): Panoptic segmentation masks [H, W].
+
+ Depth Estimation:
+ - depth_maps (NDArrayF32): Depth maps of shape [H, W].
+
+ Optical Flow:
+ - optical_flows (NDArrayF32): Optical flow maps of shape [H, W, 2].
+
+ Sensor Calibration:
+ - intrinsics (NDArrayF32): Intrinsic sensor calibration. Shape [3, 3].
+ - extrinsics (NDArrayF32): Extrinsic sensor calibration, transformation
+ of sensor to world coordinate frame. Shape [4, 4].
+ - axis_mode (AxisMode): Coordinate convention of the current sensor.
+ - timestamp (int): Sensor timestamp in Unix format.
+
+ 3D Point Cloud Data:
+ - points3d (NDArrayF32): 3D pointcloud data, assumed to be [N, 3] and
+ in sensor frame.
+ - colors3d (NDArrayF32): Associated color values for each point [N, 3].
+
+ 3D Point Cloud Annotations:
+ - semantics3d (NDArrayI64): Semantic classes of 3D points [N, 1].
+ - instances3d (NDArrayI64): Instance IDs of 3D points [N, 1].
+
+ 3D Object Detection:
+ - boxes3d (NDArrayF32): 3D bounding boxes of shape [N, 10], each
+ consists of center (XYZ), dimensions (WLH), and orientation
+ quaternion (WXYZ).
+ - boxes3d_classes (NDArrayI64): Associated semantic classes of 3D
+ bounding boxes of shape [N,].
+ - boxes3d_names (List[str]): Names of 3D bounding box classes, same
+ order as `boxes3d_classes`.
+ - boxes3d_track_ids (NDArrayI64): Associated tracking IDs of 3D
+ bounding boxes of shape [N,].
+ - boxes3d_velocities (NDArrayF32): Associated velocities of 3D bounding
+ boxes of shape [N, 3], where each velocity is in the form of
+ (vx, vy, vz).
+ """
+
+ # General Info
+ sample_names = "sample_names"
+ sequence_names = "sequence_names"
+ frame_ids = "frame_ids"
+
+ # image based inputs
+ images = "images"
+ input_hw = "input_hw"
+ original_images = "original_images"
+ original_hw = "original_hw"
+
+ # Image Classification
+ categories = "categories"
+
+ # 2D Object Detection
+ boxes2d = "boxes2d"
+ boxes2d_classes = "boxes2d_classes"
+ boxes2d_names = "boxes2d_names"
+
+ # 2D Object Tracking
+ boxes2d_track_ids = "boxes2d_track_ids"
+
+ # Segmentation
+ masks = "masks"
+ seg_masks = "seg_masks"
+ instance_masks = "instance_masks"
+ panoptic_masks = "panoptic_masks"
+
+ # Depth Estimation
+ depth_maps = "depth_maps"
+
+ # Optical Flow
+ optical_flows = "optical_flows"
+
+ # Sensor Calibration
+ intrinsics = "intrinsics"
+ extrinsics = "extrinsics"
+ axis_mode = "axis_mode"
+ timestamp = "timestamp"
+
+ # 3D Point Cloud Data
+ points3d = "points3d"
+ colors3d = "colors3d"
+
+ # 3D Point Cloud Annotations
+ semantics3d = "semantics3d"
+ instances3d = "instances3d"
+
+ # 3D Object Detection
+ boxes3d = "boxes3d"
+ boxes3d_classes = "boxes3d_classes"
+ boxes3d_names = "boxes3d_names"
+ boxes3d_track_ids = "boxes3d_track_ids"
+ boxes3d_velocities = "boxes3d_velocities"
diff --git a/vis4d/data/data_pipe.py b/vis4d/data/data_pipe.py
new file mode 100644
index 0000000000000000000000000000000000000000..323c74f4bb63b1ef6f3afef784161bf36a27e8f6
--- /dev/null
+++ b/vis4d/data/data_pipe.py
@@ -0,0 +1,139 @@
+"""DataPipe wraps datasets to share the prepossessing pipeline."""
+
+from __future__ import annotations
+
+import random
+from collections.abc import Callable, Iterable
+
+from torch.utils.data import ConcatDataset, Dataset
+
+from .reference import MultiViewDataset
+from .transforms.base import TFunctor
+from .typing import DictData, DictDataOrList
+
+
+class DataPipe(ConcatDataset[DictDataOrList]):
+ """DataPipe class.
+
+ This class wraps one or multiple instances of a PyTorch Dataset so that the
+ preprocessing steps can be shared across those datasets. Composes dataset
+ and the preprocessing pipeline.
+ """
+
+ def __init__(
+ self,
+ datasets: Dataset[DictDataOrList] | Iterable[Dataset[DictDataOrList]],
+ preprocess_fn: Callable[
+ [list[DictData]], list[DictData]
+ ] = lambda x: x,
+ ):
+ """Creates an instance of the class.
+
+ Args:
+ datasets (Dataset | Iterable[Dataset]): Dataset(s) to be wrapped by
+ this data pipeline.
+ preprocess_fn (Callable[[list[DictData]], list[DictData]]):
+ Preprocessing function of a single sample. It takes a list of
+ samples and returns a list of samples. Defaults to identity
+ function.
+ """
+ if isinstance(datasets, Dataset):
+ datasets = [datasets]
+ super().__init__(datasets)
+ self.preprocess_fn = preprocess_fn
+
+ self.has_reference = any(
+ _check_reference(dataset) for dataset in datasets
+ )
+
+ if self.has_reference and not all(
+ _check_reference(dataset) for dataset in datasets
+ ):
+ raise ValueError(
+ "All datasets must be MultiViewDataset / has reference if "
+ + "one of them is."
+ )
+
+ def __getitem__(self, idx: int) -> DictDataOrList:
+ """Wrap getitem to apply augmentations."""
+ samples = super().__getitem__(idx)
+ if isinstance(samples, list):
+ return self.preprocess_fn(samples)
+
+ return self.preprocess_fn([samples])[0]
+
+
+class MultiSampleDataPipe(DataPipe):
+ """MultiSampleDataPipe class.
+
+ This class wraps DataPipe to support augmentations that require multiple
+ images (e.g., Mosaic and Mixup) by sampling additional indices for each
+ image. NUM_SAMPLES needs to be defined as a class attribute for transforms
+ that require multi-sample augmentation.
+ """
+
+ def __init__(
+ self,
+ datasets: Dataset[DictDataOrList] | Iterable[Dataset[DictDataOrList]],
+ preprocess_fn: list[list[TFunctor]],
+ ):
+ """Creates an instance of the class.
+
+ Args:
+ datasets (Dataset | Iterable[Dataset]): Dataset(s) to be wrapped by
+ this data pipeline.
+ preprocess_fn (list[list[TFunctor]]): Preprocessing functions of a
+ single sample. Different than DataPipe, this is a list of lists
+ of transformation functions. The inner list is for transforms
+ that needs to share the same sampled indices (e.g.,
+ GenMosaicParameters and MosaicImages), and the outer list is
+ for different transforms.
+ """
+ super().__init__(datasets)
+ self.preprocess_fns = preprocess_fn
+
+ def _sample_indices(self, idx: int, num_samples: int) -> list[int]:
+ """Sample additional indices for multi-sample augmentation."""
+ indices = [idx]
+ for _ in range(1, num_samples):
+ indices.append(random.randint(0, len(self) - 1))
+ return indices
+
+ def __getitem__(self, idx: int) -> DictDataOrList:
+ """Wrap getitem to apply augmentations."""
+ samples = super(DataPipe, self).__getitem__(idx)
+ if not isinstance(samples, list):
+ samples = [samples]
+ single_view = True
+ else:
+ single_view = False
+
+ for preprocess_fn in self.preprocess_fns:
+ if hasattr(preprocess_fn[0], "NUM_SAMPLES"):
+ num_samples = preprocess_fn[0].NUM_SAMPLES
+ aug_inds = self._sample_indices(idx, num_samples)
+ add_samples = [
+ super(DataPipe, self).__getitem__(ind)
+ for ind in aug_inds[1:]
+ ]
+ prep_samples = []
+ for i, samp in enumerate(samples):
+ prep_samples.append(samp)
+ prep_samples += [
+ s[i] if isinstance(s, list) else s for s in add_samples
+ ]
+ else:
+ num_samples = 1
+ prep_samples = samples
+ for prep_fn in preprocess_fn:
+ prep_samples = prep_fn.apply_to_data(prep_samples) # type: ignore # pylint: disable=line-too-long
+ samples = prep_samples[::num_samples]
+ return samples[0] if single_view else samples
+
+
+def _check_reference(dataset: Dataset[DictDataOrList]) -> bool:
+ """Check if the datasets have reference."""
+ has_reference = (
+ dataset.has_reference if hasattr(dataset, "has_reference") else False
+ )
+ return has_reference or isinstance(dataset, MultiViewDataset)
diff --git a/vis4d/data/datasets/__init__.py b/vis4d/data/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb7083d7b05f96c2fa6e262cb48a4c5377122c42
--- /dev/null
+++ b/vis4d/data/datasets/__init__.py
@@ -0,0 +1 @@
+"""Datasets module."""
diff --git a/vis4d/data/datasets/base.py b/vis4d/data/datasets/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0647f262534618ba40394ad44f1fbf7e1da0834
--- /dev/null
+++ b/vis4d/data/datasets/base.py
@@ -0,0 +1,118 @@
+"""Base dataset classes.
+
+We implement a typed version of the PyTorch dataset class here. In addition, we
+provide a number of Mixin classes which a dataset can inherit from to implement
+additional functionality.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import TypedDict
+
+from torch.utils.data import Dataset as TorchDataset
+
+from vis4d.common.typing import ArgsType
+from vis4d.data.io.base import DataBackend
+from vis4d.data.io.file import FileBackend
+from vis4d.data.typing import DictData
+
+
+class Dataset(TorchDataset[DictData]):
+ """Basic pytorch dataset with defined return type."""
+
+ # Dataset metadata.
+ DESCRIPTION = ""
+ HOMEPAGE = ""
+ PAPER = ""
+ LICENSE = ""
+
+ # List of all keys supported by this dataset.
+ KEYS: Sequence[str] = []
+
+ def __init__(
+ self,
+ image_channel_mode: str = "RGB",
+ data_backend: None | DataBackend = None,
+ ) -> None:
+ """Initialize dataset.
+
+ Args:
+ image_channel_mode (str): Image channel mode to use. Default: RGB.
+ data_backend (None | DataBackend): Data backend to use.
+ Default: None.
+ """
+ self.image_channel_mode = image_channel_mode
+ self.data_backend = (
+ data_backend if data_backend is not None else FileBackend()
+ )
+
+ def __len__(self) -> int:
+ """Return length of dataset."""
+ raise NotImplementedError
+
+ def __getitem__(self, idx: int) -> DictData:
+ """Convert single element at given index into Vis4D data format."""
+ raise NotImplementedError
+
+ def validate_keys(self, keys_to_load: Sequence[str]) -> None:
+ """Validate that all keys to load are supported.
+
+ Args:
+ keys_to_load (list[str]): List of keys to load.
+
+ Raises:
+ ValueError: Raise if any key is not defined in AVAILABLE_KEYS.
+ """
+ for k in keys_to_load:
+ if k not in self.KEYS:
+ raise ValueError(f"Key '{k}' is not supported!")
+
+
+class VideoMapping(TypedDict):
+ """Grouped dataset sample indices and frame indices."""
+
+ video_to_indices: dict[str, list[int]]
+ video_to_frame_ids: dict[str, list[int]]
+
+
+class VideoDataset(Dataset):
+ """Video datasets.
+
+ Provides video_mapping attribute for video based interface and reference
+ view samplers.
+ """
+
+ def __init__(self, *args: ArgsType, **kwargs: ArgsType) -> None:
+ """Initialize dataset."""
+ super().__init__(*args, **kwargs)
+ self.video_mapping: VideoMapping = {
+ "video_to_indices": {},
+ "video_to_frame_ids": {},
+ }
+
+ def _sort_video_mapping(self, video_mapping: VideoMapping) -> VideoMapping:
+ """Sort video mapping by frame ids."""
+ video_to_indices = video_mapping["video_to_indices"]
+ video_to_frame_ids = video_mapping["video_to_frame_ids"]
+
+ for seq in video_to_indices:
+ sorted_zipped = sorted(
+ list(zip(video_to_indices[seq], video_to_frame_ids[seq])),
+ key=lambda x: x[1],
+ )
+ sorted_indices, sorted_frame_ids = zip(*sorted_zipped)
+ video_mapping["video_to_indices"][seq] = list(sorted_indices)
+ video_mapping["video_to_frame_ids"][seq] = list(sorted_frame_ids)
+
+ return video_mapping
+
+ def _generate_video_mapping(self) -> VideoMapping:
+ """Group dataset sample by their associated video ID.
+
+ The sample index is an integer while video IDs are string.
+
+ Returns:
+ VideoMapping: Mapping of video IDs to sample indices and frame IDs.
+ """
+ raise NotImplementedError
diff --git a/vis4d/data/datasets/bdd100k.py b/vis4d/data/datasets/bdd100k.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3d2c1919cb4a7b4443777f654526e3ddd48ef1d
--- /dev/null
+++ b/vis4d/data/datasets/bdd100k.py
@@ -0,0 +1,126 @@
+"""BDD100K dataset."""
+
+from vis4d.common.imports import BDD100K_AVAILABLE, SCALABEL_AVAILABLE
+
+from .scalabel import Scalabel
+
+bdd100k_det_map = {
+ "pedestrian": 0,
+ "rider": 1,
+ "car": 2,
+ "truck": 3,
+ "bus": 4,
+ "train": 5,
+ "motorcycle": 6,
+ "bicycle": 7,
+ "traffic light": 8,
+ "traffic sign": 9,
+}
+bdd100k_track_map = {
+ "pedestrian": 0,
+ "rider": 1,
+ "car": 2,
+ "truck": 3,
+ "bus": 4,
+ "train": 5,
+ "motorcycle": 6,
+ "bicycle": 7,
+}
+bdd100k_seg_map = {
+ "road": 0,
+ "sidewalk": 1,
+ "building": 2,
+ "wall": 3,
+ "fence": 4,
+ "pole": 5,
+ "traffic light": 6,
+ "traffic sign": 7,
+ "vegetation": 8,
+ "terrain": 9,
+ "sky": 10,
+ "person": 11,
+ "rider": 12,
+ "car": 13,
+ "truck": 14,
+ "bus": 15,
+ "train": 16,
+ "motorcycle": 17,
+ "bicycle": 18,
+}
+bdd100k_panseg_map = {
+ "dynamic": 0,
+ "ego vehicle": 1,
+ "ground": 2,
+ "static": 3,
+ "parking": 4,
+ "rail track": 5,
+ "road": 6,
+ "sidewalk": 7,
+ "bridge": 8,
+ "building": 9,
+ "fence": 10,
+ "garage": 11,
+ "guard rail": 12,
+ "tunnel": 13,
+ "wall": 14,
+ "banner": 15,
+ "billboard": 16,
+ "lane divider": 17,
+ "parking sign": 18,
+ "pole": 19,
+ "polegroup": 20,
+ "street light": 21,
+ "traffic cone": 22,
+ "traffic device": 23,
+ "traffic light": 24,
+ "traffic sign": 25,
+ "traffic sign frame": 26,
+ "terrain": 27,
+ "vegetation": 28,
+ "sky": 29,
+ "person": 30,
+ "rider": 31,
+ "bicycle": 32,
+ "bus": 33,
+ "car": 34,
+ "caravan": 35,
+ "motorcycle": 36,
+ "trailer": 37,
+ "train": 38,
+ "truck": 39,
+}
+
+if BDD100K_AVAILABLE and SCALABEL_AVAILABLE:
+ from bdd100k.common.utils import load_bdd100k_config
+ from bdd100k.label.to_scalabel import bdd100k_to_scalabel
+ from scalabel.label.io import load
+ from scalabel.label.typing import Dataset as ScalabelData
+else:
+ raise ImportError("bdd100k or scalabel is not installed.")
+
+
+class BDD100K(Scalabel):
+ """BDD100K type dataset, based on Scalabel."""
+
+ DESCRIPTION = """BDD100K is a large-scale dataset for driving scene
+ understanding."""
+ HOMEPAGE = "https://www.bdd100k.com/"
+ PAPER = "https://arxiv.org/abs/1805.04687"
+ LICENSE = "https://www.bdd100k.com/license"
+
+ def _generate_mapping(self) -> ScalabelData:
+ """Generate data mapping."""
+ bdd100k_anns = load(self.annotation_path)
+ if self.config_path is None:
+ return bdd100k_anns # pragma: no cover
+ frames = bdd100k_anns.frames
+ assert isinstance(self.config_path, str)
+ bdd100k_cfg = load_bdd100k_config(self.config_path)
+ scalabel_frames = bdd100k_to_scalabel(frames, bdd100k_cfg)
+ return ScalabelData(
+ frames=scalabel_frames, config=bdd100k_cfg.scalabel, groups=None
+ )
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset."""
+ return f"BDD100KDataset {self.data_root}"
diff --git a/vis4d/data/datasets/coco.py b/vis4d/data/datasets/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..44d60528d5f360048fba388ebb10a66d2dd3cd40
--- /dev/null
+++ b/vis4d/data/datasets/coco.py
@@ -0,0 +1,365 @@
+"""COCO dataset."""
+
+from __future__ import annotations
+
+import contextlib
+import io
+import os
+from collections.abc import Sequence
+
+import numpy as np
+import pycocotools.mask as maskUtils
+from pycocotools.coco import COCO as COCOAPI
+
+from vis4d.common.typing import ArgsType, DictStrAny
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.typing import DictData
+
+from .base import Dataset
+from .util import CacheMappingMixin, get_category_names, im_decode
+
+# COCO detection
+coco_det_map = {
+ "person": 0,
+ "bicycle": 1,
+ "car": 2,
+ "motorcycle": 3,
+ "airplane": 4,
+ "bus": 5,
+ "train": 6,
+ "truck": 7,
+ "boat": 8,
+ "traffic light": 9,
+ "fire hydrant": 10,
+ "stop sign": 11,
+ "parking meter": 12,
+ "bench": 13,
+ "bird": 14,
+ "cat": 15,
+ "dog": 16,
+ "horse": 17,
+ "sheep": 18,
+ "cow": 19,
+ "elephant": 20,
+ "bear": 21,
+ "zebra": 22,
+ "giraffe": 23,
+ "backpack": 24,
+ "umbrella": 25,
+ "handbag": 26,
+ "tie": 27,
+ "suitcase": 28,
+ "frisbee": 29,
+ "skis": 30,
+ "snowboard": 31,
+ "sports ball": 32,
+ "kite": 33,
+ "baseball bat": 34,
+ "baseball glove": 35,
+ "skateboard": 36,
+ "surfboard": 37,
+ "tennis racket": 38,
+ "bottle": 39,
+ "wine glass": 40,
+ "cup": 41,
+ "fork": 42,
+ "knife": 43,
+ "spoon": 44,
+ "bowl": 45,
+ "banana": 46,
+ "apple": 47,
+ "sandwich": 48,
+ "orange": 49,
+ "broccoli": 50,
+ "carrot": 51,
+ "hot dog": 52,
+ "pizza": 53,
+ "donut": 54,
+ "cake": 55,
+ "chair": 56,
+ "couch": 57,
+ "potted plant": 58,
+ "bed": 59,
+ "dining table": 60,
+ "toilet": 61,
+ "tv": 62,
+ "laptop": 63,
+ "mouse": 64,
+ "remote": 65,
+ "keyboard": 66,
+ "cell phone": 67,
+ "microwave": 68,
+ "oven": 69,
+ "toaster": 70,
+ "sink": 71,
+ "refrigerator": 72,
+ "book": 73,
+ "clock": 74,
+ "vase": 75,
+ "scissors": 76,
+ "teddy bear": 77,
+ "hair drier": 78,
+ "toothbrush": 79,
+}
+
+# COCO segmentation categories
+coco_seg_map = {
+ "background": 0,
+ "airplane": 1,
+ "bicycle": 2,
+ "bird": 3,
+ "boat": 4,
+ "bottle": 5,
+ "bus": 6,
+ "car": 7,
+ "cat": 8,
+ "chair": 9,
+ "cow": 10,
+ "dining table": 11,
+ "dog": 12,
+ "horse": 13,
+ "motorcycle": 14,
+ "person": 15,
+ "potted plant": 16,
+ "sheep": 17,
+ "couch": 18,
+ "train": 19,
+ "tv": 20,
+}
+
+
+class COCO(CacheMappingMixin, Dataset):
+ """COCO dataset class."""
+
+ DESCRIPTION = """COCO is a large-scale object detection, segmentation, and
+ captioning dataset."""
+ HOMEPAGE = "http://cocodataset.org"
+ PAPER = "http://arxiv.org/abs/1405.0312"
+ LICENSE = "BY-NC-SA 2.0"
+
+ KEYS = [
+ K.images,
+ K.input_hw,
+ K.original_images,
+ K.original_hw,
+ K.sample_names,
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.instance_masks,
+ K.seg_masks,
+ ]
+
+ def __init__(
+ self,
+ data_root: str,
+ keys_to_load: Sequence[str] = (
+ K.images,
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.instance_masks,
+ ),
+ split: str = "train2017",
+ remove_empty: bool = False,
+ use_pascal_voc_cats: bool = False,
+ cache_as_binary: bool = False,
+ cached_file_path: str | None = None,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Initialize the COCO dataset.
+
+ Args:
+ data_root (str): Path to the root directory of the dataset.
+ keys_to_load (tuple[str, ...]): Keys to load from the dataset.
+ split (split): Which split to load. Default: "train2017".
+ remove_empty (bool): Whether to remove images with no annotations.
+ use_pascal_voc_cats (bool): Whether to use Pascal VOC categories.
+ cache_as_binary (bool): Whether to cache the dataset as binary.
+ Default: False.
+ cached_file_path (str | None): Path to a cached file. If cached
+ file exist then it will load it instead of generating the data
+ mapping. Default: None.
+ """
+ super().__init__(**kwargs)
+
+ self.data_root = data_root
+ self.keys_to_load = keys_to_load
+ self.split = split
+ self.remove_empty = remove_empty
+ self.use_pascal_voc_cats = use_pascal_voc_cats
+
+ # handling keys to load
+ self.validate_keys(keys_to_load)
+
+ self.load_annotations = (
+ K.boxes2d in keys_to_load
+ or K.boxes2d_classes in keys_to_load
+ or K.instance_masks in keys_to_load
+ or K.seg_masks in keys_to_load
+ )
+
+ self.data, _ = self._load_mapping(
+ self._generate_data_mapping,
+ self._filter_data,
+ cache_as_binary=cache_as_binary,
+ cached_file_path=cached_file_path,
+ )
+
+ if self.use_pascal_voc_cats:
+ self.category_names = get_category_names(coco_seg_map)
+ else:
+ self.category_names = get_category_names(coco_det_map)
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset."""
+ return (
+ f"COCODataset(root={self.data_root}, split={self.split}, "
+ f"use_pascal_voc_cats={self.use_pascal_voc_cats})"
+ )
+
+ def _filter_data(self, data: list[DictStrAny]) -> list[DictStrAny]:
+ """Remove empty samples."""
+ if self.remove_empty:
+ samples = []
+ for sample in data:
+ if len(sample["anns"]) > 0:
+ samples.append(sample)
+ return samples
+ return data
+
+ def _generate_data_mapping(self) -> list[DictStrAny]:
+ """Generate coco dataset mapping."""
+ annotation_file = os.path.join(
+ self.data_root, "annotations", "instances_" + self.split + ".json"
+ )
+ with contextlib.redirect_stdout(io.StringIO()):
+ coco_api = COCOAPI(annotation_file)
+ cat_ids = sorted(coco_api.getCatIds())
+ cats_map = {c["id"]: c["name"] for c in coco_api.loadCats(cat_ids)}
+ if self.use_pascal_voc_cats:
+ voc_cats = set(coco_seg_map.keys())
+
+ img_ids = sorted(coco_api.imgs.keys())
+ imgs = coco_api.loadImgs(img_ids)
+ samples = []
+ for img_id, img in zip(img_ids, imgs):
+ anns = coco_api.imgToAnns[img_id]
+ if self.use_pascal_voc_cats:
+ anns = [
+ ann
+ for ann in anns
+ if cats_map[ann["category_id"]] in voc_cats
+ ]
+ for ann in anns:
+ cat_name = cats_map[ann["category_id"]]
+ if self.use_pascal_voc_cats:
+ ann["category_id"] = coco_seg_map[cat_name]
+ else:
+ ann["category_id"] = coco_det_map[cat_name]
+ samples.append({"img_id": img_id, "img": img, "anns": anns})
+ return samples
+
+ def __len__(self) -> int:
+ """Return length of dataset."""
+ return len(self.data)
+
+ def __getitem__(self, idx: int) -> DictData:
+ """Transform coco sample to vis4d input format.
+
+ Returns:
+ DataDict[DataKeys, Union[torch.Tensor, Dict[Any]]]
+ """
+ data = self.data[idx]
+ img_h, img_w = data["img"]["height"], data["img"]["width"]
+
+ dict_data: DictData = {}
+
+ if K.images in self.keys_to_load:
+ img_path = os.path.join(
+ self.data_root, self.split, data["img"]["file_name"]
+ )
+ im_bytes = self.data_backend.get(img_path)
+ img = im_decode(im_bytes, mode=self.image_channel_mode)
+ img_ = np.ascontiguousarray(img, dtype=np.float32)[None]
+ assert (img_h, img_w) == img_.shape[
+ 1:3
+ ], "Image's shape doesn't match annotation."
+
+ dict_data[K.sample_names] = data["img"]["id"]
+ dict_data[K.images] = img_
+ dict_data[K.input_hw] = [img_h, img_w]
+
+ if K.original_images in self.keys_to_load:
+ dict_data[K.original_images] = img_
+ dict_data[K.original_hw] = [img_h, img_w]
+
+ if self.load_annotations:
+ boxes = []
+ classes = []
+ masks = []
+
+ for ann in data["anns"]:
+ if K.boxes2d in self.keys_to_load:
+ x1, y1, width, height = ann["bbox"]
+ x2, y2 = x1 + width, y1 + height
+ boxes.append((x1, y1, x2, y2))
+ if (
+ K.boxes2d in self.keys_to_load
+ or K.boxes2d_classes in self.keys_to_load
+ or K.seg_masks in self.keys_to_load
+ ):
+ classes.append(ann["category_id"])
+
+ if (
+ K.seg_masks in self.keys_to_load
+ or K.instance_masks in self.keys_to_load
+ ):
+ mask_ann = ann.get("segmentation", None)
+ if mask_ann is not None:
+ if isinstance(mask_ann, list):
+ rles = maskUtils.frPyObjects(
+ mask_ann, img_h, img_w
+ )
+ rle = maskUtils.merge(rles)
+ elif isinstance(mask_ann["counts"], list):
+ # uncompressed RLE
+ rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
+ else:
+ # RLE
+ rle = mask_ann
+ masks.append(maskUtils.decode(rle))
+ else: # pragma: no cover
+ masks.append(np.empty((img_h, img_w), dtype=np.uint8))
+
+ box_tensor = (
+ np.empty((0, 4), dtype=np.float32)
+ if not boxes
+ else np.array(boxes, dtype=np.float32)
+ )
+ mask_tensor = (
+ np.empty((0, img_h, img_w), dtype=np.uint8)
+ if not masks
+ else np.ascontiguousarray(masks, dtype=np.uint8)
+ )
+
+ if K.boxes2d in self.keys_to_load:
+ dict_data[K.boxes2d] = box_tensor
+
+ if K.boxes2d_classes in self.keys_to_load:
+ dict_data[K.boxes2d_classes] = np.array(
+ classes, dtype=np.int64
+ )
+
+ if K.instance_masks in self.keys_to_load:
+ dict_data[K.instance_masks] = mask_tensor
+
+ if K.seg_masks in self.keys_to_load:
+ seg_masks = (
+ mask_tensor * np.array(classes)[:, None, None]
+ ).max(axis=0)
+ seg_masks = seg_masks.astype(np.int64)
+ seg_masks[mask_tensor.sum(0) > 1] = 255 # discard overlapped
+ dict_data[K.seg_masks] = seg_masks[None]
+
+ dict_data[K.boxes2d_names] = self.category_names
+
+ return dict_data
diff --git a/vis4d/data/datasets/imagenet.py b/vis4d/data/datasets/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ae6622154262f222110b49bea5e3b0a48790051
--- /dev/null
+++ b/vis4d/data/datasets/imagenet.py
@@ -0,0 +1,145 @@
+"""ImageNet 1k dataset."""
+
+from __future__ import annotations
+
+import os
+import pickle
+import tarfile
+from collections.abc import Sequence
+
+import numpy as np
+
+from vis4d.common.logging import rank_zero_info
+from vis4d.common.time import Timer
+from vis4d.common.typing import ArgsType
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.typing import DictData
+
+from .base import Dataset
+from .util import im_decode, to_onehot
+
+
+class ImageNet(Dataset):
+ """ImageNet 1K dataset."""
+
+ DESCRIPTION = """ImageNet is a large visual database designed for use in
+ visual object recognition software research."""
+ HOMEPAGE = "http://www.image-net.org/"
+ PAPER = "http://www.image-net.org/papers/imagenet_cvpr09.pdf"
+ LICENSE = "http://www.image-net.org/terms-of-use"
+
+ KEYS = [K.images, K.categories]
+
+ def __init__(
+ self,
+ data_root: str,
+ keys_to_load: Sequence[str] = (K.images, K.categories),
+ split: str = "train",
+ num_classes: int = 1000,
+ use_sample_lists: bool = False,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Initialize ImageNet dataset.
+
+ Args:
+ data_root (str): Path to root directory of dataset.
+ keys_to_load (list[str], optional): List of keys to load. Defaults
+ to (K.images, K.categories).
+ split (str, optional): Dataset split to load. Defaults to "train".
+ num_classes (int, optional): Number of classes to load. Defaults to
+ 1000.
+ use_sample_lists (bool, optional): Whether to use sample lists for
+ loading the dataset. Defaults to False.
+
+ NOTE: The dataset is expected to be in the following format:
+ data_root
+ ├── train.pkl # Sample lists for training set (optional)
+ ├── val.pkl # Sample lists for validation set (optional)
+ ├── train
+ │ ├── n01440764.tar
+ │ ├── ...
+ └── val
+ ├── n01440764.tar
+ ├── ...
+ With each tar file containing the images of a single class. The
+ images are expected to be in ".JPEG" extension.
+
+ Currently, we are not using the DataBackend for loading the tars to
+ avoid keeping too many file pointers open at the same time.
+ """
+ super().__init__(**kwargs)
+ self.data_root = data_root
+ self.keys_to_load = keys_to_load
+ self.split = split
+ self.num_classes = num_classes
+ self.use_sample_lists = use_sample_lists
+ self.data_infos: list[tuple[tarfile.TarInfo, int]] = []
+ self._classes: list[str] = []
+ self._load_data_infos()
+
+ def _load_data_infos(self) -> None:
+ """Load data infos from disk."""
+ timer = Timer()
+ # Load tar files
+ for file in os.listdir(os.path.join(self.data_root, self.split)):
+ if file.endswith(".tar"):
+ self._classes.append(file)
+ assert len(self._classes) == self.num_classes, (
+ f"Expected {self.num_classes} classes, but found "
+ f"{len(self._classes)} tar files."
+ )
+ self._classes = sorted(self._classes)
+
+ sample_list_path = os.path.join(self.data_root, f"{self.split}.pkl")
+ if self.use_sample_lists and os.path.exists(sample_list_path):
+ with open(sample_list_path, "rb") as f:
+ sample_list = pickle.load(f)[0]
+ if sample_list[-1][1] == self.num_classes - 1:
+ self.data_infos = sample_list
+ else:
+ raise ValueError(
+ "Sample list does not match the number of classes. "
+ "Please regenerate the sample list or set "
+ "use_sample_lists=False."
+ )
+ # If sample lists are not available, generate them on the fly.
+ else:
+ for class_idx, file in enumerate(self._classes):
+ with tarfile.open(
+ os.path.join(self.data_root, self.split, file)
+ ) as f:
+ members = f.getmembers()
+ for member in members:
+ if member.isfile() and member.name.endswith(".JPEG"):
+ self.data_infos.append((member, class_idx))
+
+ rank_zero_info(f"Loading {self} takes {timer.time():.2f} seconds.")
+
+ def __len__(self) -> int:
+ """Return length of dataset."""
+ return len(self.data_infos)
+
+ def __getitem__(self, idx: int) -> DictData:
+ """Convert single element at given index into Vis4D data format."""
+ member, class_idx = self.data_infos[idx]
+ with tarfile.open(
+ os.path.join(self.data_root, self.split, self._classes[class_idx]),
+ mode="r:*", # unexclusive read mode
+ ) as f:
+ im_bytes = f.extractfile(member)
+ assert im_bytes is not None, f"Could not extract {member.name}!"
+ image = im_decode(im_bytes.read())
+
+ data_dict: DictData = {}
+ if K.images in self.keys_to_load:
+ data_dict[K.images] = np.ascontiguousarray(
+ image, dtype=np.float32
+ )[np.newaxis, ...]
+ image_hw = image.shape[:2]
+ data_dict[K.input_hw] = image_hw
+ data_dict[K.original_hw] = image_hw
+ if K.categories in self.keys_to_load:
+ data_dict[K.categories] = to_onehot(
+ np.array(class_idx, dtype=np.int64), self.num_classes
+ )
+ return data_dict
diff --git a/vis4d/data/datasets/nuscenes.py b/vis4d/data/datasets/nuscenes.py
new file mode 100644
index 0000000000000000000000000000000000000000..a80270bfce21436046bcc3d335a208fff6819e8a
--- /dev/null
+++ b/vis4d/data/datasets/nuscenes.py
@@ -0,0 +1,1011 @@
+"""NuScenes multi-sensor video dataset."""
+
+from __future__ import annotations
+
+import os
+from collections import defaultdict
+from collections.abc import Sequence
+
+import numpy as np
+import torch
+from scipy.spatial.transform import Rotation as R
+from tqdm import tqdm
+
+from vis4d.common.imports import NUSCENES_AVAILABLE
+from vis4d.common.logging import rank_zero_info
+from vis4d.common.time import Timer
+from vis4d.common.typing import (
+ ArgsType,
+ DictStrAny,
+ NDArrayBool,
+ NDArrayF32,
+ NDArrayI64,
+)
+from vis4d.data.const import AxisMode
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.typing import DictData
+from vis4d.op.geometry.projection import generate_depth_map
+from vis4d.op.geometry.transform import (
+ inverse_rigid_transform,
+ transform_points,
+)
+
+from .base import VideoDataset, VideoMapping
+from .util import CacheMappingMixin, im_decode, print_class_histogram
+
+if NUSCENES_AVAILABLE:
+ from nuscenes import NuScenes as NuScenesDevkit
+ from nuscenes.can_bus.can_bus_api import NuScenesCanBus
+ from nuscenes.eval.common.utils import quaternion_yaw
+ from nuscenes.eval.detection.utils import category_to_detection_name
+ from nuscenes.scripts.export_2d_annotations_as_json import (
+ post_process_coords,
+ )
+ from nuscenes.utils.data_classes import Quaternion
+ from nuscenes.utils.geometry_utils import (
+ box_in_image,
+ transform_matrix,
+ view_points,
+ )
+ from nuscenes.utils.splits import create_splits_scenes
+else:
+ raise ImportError("nusenes-devkit is not available.")
+
+nuscenes_class_map = {
+ "bicycle": 0,
+ "motorcycle": 1,
+ "pedestrian": 2,
+ "bus": 3,
+ "car": 4,
+ "trailer": 5,
+ "truck": 6,
+ "construction_vehicle": 7,
+ "traffic_cone": 8,
+ "barrier": 9,
+}
+
+nuscenes_attribute_map = {
+ "cycle.with_rider": 0,
+ "cycle.without_rider": 1,
+ "pedestrian.moving": 2,
+ "pedestrian.standing": 3,
+ "pedestrian.sitting_lying_down": 4,
+ "vehicle.moving": 5,
+ "vehicle.parked": 6,
+ "vehicle.stopped": 7,
+ "": 8,
+}
+
+nuscenes_detection_range_map = {
+ "bicycle": 40,
+ "motorcycle": 40,
+ "pedestrian": 40,
+ "bus": 50,
+ "car": 50,
+ "trailer": 50,
+ "truck": 50,
+ "construction_vehicle": 50,
+ "traffic_cone": 30,
+ "barrier": 30,
+}
+
+
+def _get_extrinsics(
+ ego_pose: DictStrAny, car_from_sensor: DictStrAny
+) -> NDArrayF32:
+ """Get NuScenes sensor to global extrinsics."""
+ global_from_car = transform_matrix(
+ ego_pose["translation"],
+ Quaternion(ego_pose["rotation"]),
+ inverse=False,
+ )
+ car_from_sensor_ = transform_matrix(
+ car_from_sensor["translation"],
+ Quaternion(car_from_sensor["rotation"]),
+ inverse=False,
+ )
+ extrinsics = np.dot(global_from_car, car_from_sensor_).astype(np.float32)
+ return extrinsics
+
+
+class NuScenes(CacheMappingMixin, VideoDataset):
+ """NuScenes multi-sensor video dataset.
+
+ This dataset loads both LiDAR and camera inputs from the NuScenes dataset
+ into the Vis4D expected format for multi-sensor, video datasets.
+ """
+
+ DESCRIPTION = "NuScenes multi-sensor driving video dataset."
+ HOMEPAGE = "https://www.nuscenes.org/"
+ PAPER = "https://arxiv.org/abs/1903.11027"
+ LICENSE = "https://www.nuscenes.org/license"
+
+ KEYS = [
+ K.images,
+ K.original_hw,
+ K.input_hw,
+ K.intrinsics,
+ K.extrinsics,
+ K.timestamp,
+ K.axis_mode,
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.boxes2d_track_ids,
+ K.boxes3d,
+ K.boxes3d_classes,
+ K.boxes3d_track_ids,
+ ]
+
+ SENSORS = [
+ "LIDAR_TOP",
+ "CAM_FRONT",
+ "CAM_FRONT_LEFT",
+ "CAM_FRONT_RIGHT",
+ "CAM_BACK",
+ "CAM_BACK_LEFT",
+ "CAM_BACK_RIGHT",
+ ]
+
+ CAMERAS = [
+ "CAM_FRONT",
+ "CAM_FRONT_LEFT",
+ "CAM_FRONT_RIGHT",
+ "CAM_BACK",
+ "CAM_BACK_LEFT",
+ "CAM_BACK_RIGHT",
+ ]
+
+ def __init__(
+ self,
+ data_root: str,
+ keys_to_load: Sequence[str] = (
+ K.images,
+ K.boxes2d,
+ K.boxes3d,
+ ),
+ sensors: Sequence[str] = (
+ "LIDAR_TOP",
+ "CAM_FRONT",
+ "CAM_FRONT_LEFT",
+ "CAM_FRONT_RIGHT",
+ "CAM_BACK",
+ "CAM_BACK_LEFT",
+ "CAM_BACK_RIGHT",
+ ),
+ version: str = "v1.0-trainval",
+ split: str = "train",
+ max_sweeps: int = 10,
+ skip_empty_samples: bool = False,
+ point_based_filter: bool = False,
+ distance_based_filter: bool = False,
+ cache_as_binary: bool = False,
+ cached_file_path: str | None = None,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ data_root (str): Root directory of nuscenes data in original
+ format.
+ keys_to_load (tuple[str, ...]): Keys to load from the dataset.
+ Defaults to (K.images, K.boxes2d, K.boxes3d).
+ sensors (Sequence[str, ...]): Which sensor to load. Defaults
+ to ("LIDAR_TOP", "CAM_FRONT", "CAM_FRONT_LEFT",
+ "CAM_FRONT_RIGHT", "CAM_BACK", "CAM_BACK_LEFT",
+ "CAM_BACK_RIGHT").
+ version (str, optional): Version of the data to load. Defaults to
+ "v1.0-trainval".
+ split (str, optional): Split of the data to load. Defaults to
+ "train".
+ max_sweeps (int, optional): Maximum number of sweeps for a single
+ key-frame to load. Defaults to 10.
+ skip_empty_samples (bool, optional): Whether to skip samples
+ without annotations. Defaults to False.
+ point_based_filter (bool, optional): Whether to filter out
+ samples based on the number of points in the point cloud.
+ Defaults to False.
+ distance_based_filter (bool, optional): Whether to filter out
+ samples based on the distance of the object from the ego
+ vehicle. Defaults to False.
+ cache_as_binary (bool): Whether to cache the dataset as binary.
+ Default: False.
+ cached_file_path (str | None): Path to a cached file. If cached
+ file exist then it will load it instead of generating the data
+ mapping. Default: None.
+ """
+ super().__init__(**kwargs)
+ self.data_root = data_root
+ self.keys_to_load = keys_to_load
+ self.sensors = sensors
+ self._check_version_and_split(version, split)
+ self.max_sweeps = max_sweeps
+ self.skip_empty_samples = skip_empty_samples
+
+ self.point_based_filter = point_based_filter
+ self.distance_based_filter = distance_based_filter
+
+ # Load annotations
+ self.samples, self.original_len = self._load_mapping(
+ self._generate_data_mapping,
+ self._filter_data,
+ cache_as_binary=cache_as_binary,
+ cached_file_path=cached_file_path,
+ )
+
+ # Generate video mapping
+ self.video_mapping = self._generate_video_mapping()
+
+ # Needed for CBGS
+ def get_cat_ids(self, idx: int) -> list[int]:
+ """Return the samples."""
+ return self.samples[idx]["LIDAR_TOP"]["annotations"]["boxes3d_classes"]
+
+ def _filter_data(self, data: list[DictStrAny]) -> list[DictStrAny]:
+ """Remove empty samples."""
+ if self.split == "test":
+ return data
+
+ samples = []
+ frequencies = {cat: 0 for cat in nuscenes_class_map}
+ inv_nuscenes_class_map = {v: k for k, v in nuscenes_class_map.items()}
+
+ t = Timer()
+ for sample in data:
+ (
+ _,
+ boxes3d,
+ boxes3d_classes,
+ boxes3d_attributes,
+ boxes3d_track_ids,
+ boxes3d_velocities,
+ ) = self._filter_boxes(sample["LIDAR_TOP"]["annotations"])
+
+ sample["LIDAR_TOP"]["annotations"]["boxes3d"] = boxes3d
+ sample["LIDAR_TOP"]["annotations"][
+ "boxes3d_classes"
+ ] = boxes3d_classes
+ sample["LIDAR_TOP"]["annotations"][
+ "boxes3d_attributes"
+ ] = boxes3d_attributes
+ sample["LIDAR_TOP"]["annotations"][
+ "boxes3d_track_ids"
+ ] = boxes3d_track_ids
+ sample["LIDAR_TOP"]["annotations"][
+ "boxes3d_velocities"
+ ] = boxes3d_velocities
+
+ for box3d_class in boxes3d_classes:
+ frequencies[inv_nuscenes_class_map[box3d_class]] += 1
+
+ for cam in NuScenes.CAMERAS:
+ (
+ mask,
+ boxes3d,
+ boxes3d_classes,
+ boxes3d_attributes,
+ boxes3d_track_ids,
+ boxes3d_velocities,
+ ) = self._filter_boxes(sample[cam]["annotations"])
+
+ sample[cam]["annotations"]["boxes3d"] = boxes3d
+ sample[cam]["annotations"]["boxes3d_classes"] = boxes3d_classes
+ sample[cam]["annotations"][
+ "boxes3d_attributes"
+ ] = boxes3d_attributes
+ sample[cam]["annotations"][
+ "boxes3d_track_ids"
+ ] = boxes3d_track_ids
+ sample[cam]["annotations"][
+ "boxes3d_velocities"
+ ] = boxes3d_velocities
+ sample[cam]["annotations"]["boxes2d"] = sample[cam][
+ "annotations"
+ ]["boxes2d"][mask]
+
+ if self.skip_empty_samples:
+ if len(sample["LIDAR_TOP"]["annotations"]["boxes3d"]) > 0:
+ samples.append(sample)
+ else:
+ samples.append(sample)
+
+ rank_zero_info(
+ f"Preprocessing {len(data)} frames takes {t.time():.2f}"
+ " seconds."
+ )
+
+ print_class_histogram(frequencies)
+
+ if self.skip_empty_samples:
+ rank_zero_info(
+ f"Filtered {len(data) - len(samples)} empty frames."
+ )
+
+ return samples
+
+ def _check_version_and_split(self, version: str, split: str) -> None:
+ """Check that the version and split are valid."""
+ assert version in {
+ "v1.0-trainval",
+ "v1.0-test",
+ "v1.0-mini",
+ }, f"Invalid version {version} for NuScenes!"
+ self.version = version
+
+ if "mini" in version:
+ valid_splits = {"mini_train", "mini_val"}
+ elif "test" in version:
+ valid_splits = {"test"}
+ else:
+ valid_splits = {"train", "val"}
+
+ assert (
+ split in valid_splits
+ ), f"Invalid split {split} for NuScenes {version}!"
+ self.split = split
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset."""
+ return f"NuScenesDataset {self.version} {self.split}"
+
+ def _generate_video_mapping(self) -> VideoMapping:
+ """Group dataset sample indices by their associated video ID.
+
+ The sample index is an integer while video IDs are string.
+
+ Returns:
+ VideoMapping: Mapping of video IDs to sample indices and frame IDs.
+ """
+ video_to_indices: dict[str, list[int]] = defaultdict(list)
+ video_to_frame_ids: dict[str, list[int]] = defaultdict(list)
+ for i, sample in enumerate(self.samples): # type: ignore
+ seq = sample["scene_name"]
+ video_to_indices[seq].append(i)
+ video_to_frame_ids[seq].append(sample["frame_ids"])
+
+ return self._sort_video_mapping(
+ {
+ "video_to_indices": video_to_indices,
+ "video_to_frame_ids": video_to_frame_ids,
+ }
+ )
+
+ def _generate_data_mapping(self) -> list[DictStrAny]:
+ """Generate data mapping.
+
+ Returns:
+ List[DictStrAny]: List of items required to load for a single
+ dataset sample.
+ """
+ data = NuScenesDevkit(
+ version=self.version, dataroot=self.data_root, verbose=False
+ )
+
+ can_bus_data = NuScenesCanBus(dataroot=self.data_root)
+
+ frames = []
+ instance_tokens: list[str] = []
+
+ scene_names_per_split = create_splits_scenes()
+
+ scenes = [
+ scene
+ for scene in data.scene
+ if scene["name"] in scene_names_per_split[self.split]
+ ]
+
+ for scene in tqdm(scenes):
+ scene_name = scene["name"]
+ frame_ids = 0
+ sample_token = scene["first_sample_token"]
+ while sample_token:
+ frame = {}
+ sample = data.get("sample", sample_token)
+
+ frame["scene_name"] = scene_name
+ frame["token"] = sample["token"]
+ frame["frame_ids"] = frame_ids
+
+ sd_rec = data.get("sample_data", sample["data"]["LIDAR_TOP"])
+
+ # Can bus data
+ can_bus = self._load_can_bus_data(
+ scene_name, can_bus_data, sample["timestamp"]
+ )
+
+ pose_record = data.get("ego_pose", sd_rec["ego_pose_token"])
+ rotation = Quaternion(pose_record["rotation"])
+ translation = pose_record["translation"]
+
+ can_bus[:3] = translation
+ can_bus[3:7] = rotation
+ patch_angle = quaternion_yaw(rotation) / np.pi * 180
+ patch_angle += 360 if patch_angle < 0 else 0
+ can_bus[-2] = patch_angle / 180 * np.pi
+ can_bus[-1] = patch_angle
+
+ frame["can_bus"] = can_bus
+
+ # LIDAR data
+ lidar_token = sample["data"]["LIDAR_TOP"]
+
+ frame["LIDAR_TOP"] = self._load_lidar_data(data, lidar_token)
+
+ if self.split != "test":
+ frame["LIDAR_TOP"]["annotations"] = self._load_annotations(
+ data,
+ frame["LIDAR_TOP"]["extrinsics"],
+ sample["anns"],
+ instance_tokens,
+ axis_mode=AxisMode.LIDAR,
+ )
+
+ # obtain sweeps for a single key-frame
+ sweeps: list[DictStrAny] = []
+ while len(sweeps) < self.max_sweeps:
+ if sd_rec["prev"] != "":
+ sweep = self._load_lidar_data(data, sd_rec["prev"])
+ sweeps.append(sweep)
+ sd_rec = data.get("sample_data", sd_rec["prev"])
+ else:
+ break
+ frame["LIDAR_TOP"]["sweeps"] = sweeps
+
+ # Get the sample data for each camera
+ for cam in self.CAMERAS:
+ cam_token = sample["data"][cam]
+
+ frame[cam] = self._load_cam_data(data, cam_token)
+
+ if self.split != "test":
+ frame[cam]["annotations"] = self._load_annotations(
+ data,
+ frame[cam]["extrinsics"],
+ sample["anns"],
+ instance_tokens,
+ axis_mode=AxisMode.OPENCV,
+ export_2d_annotations=True,
+ intrinsics=frame[cam]["intrinsics"],
+ image_hw=frame[cam]["image_hw"],
+ )
+
+ # TODO add RADAR, Map
+
+ frames.append(frame)
+
+ sample_token = sample["next"]
+ frame_ids += 1
+
+ return frames
+
+ def _load_can_bus_data(
+ self,
+ scene_name: str,
+ can_bus_data: NuScenesCanBus,
+ sample_timestamp: int,
+ ) -> list[float]:
+ """Load can bus data."""
+ try:
+ pose_list = can_bus_data.get_messages(scene_name, "pose")
+ except: # pylint: disable=bare-except
+ # server scenes do not have can bus information.
+ return [0.0] * 18
+
+ # during each scene, the first timestamp of can_bus may be large than
+ # the first sample's timestamp
+ can_bus = []
+ last_pose = pose_list[0]
+ for pose in pose_list:
+ if pose["utime"] > sample_timestamp:
+ break
+ last_pose = pose
+
+ last_pose.pop("utime")
+ pos = last_pose.pop("pos")
+ rotation = last_pose.pop("orientation")
+ can_bus.extend(pos)
+ can_bus.extend(rotation)
+
+ # 16 elements
+ for key in last_pose.keys():
+ can_bus.extend(last_pose[key])
+ can_bus.extend([0.0, 0.0])
+
+ return can_bus
+
+ def _load_lidar_data(
+ self, data: NuScenesDevkit, lidar_token: str
+ ) -> DictStrAny:
+ """Load LiDAR data.
+
+ Args:
+ data (NuScenesDevkit): NuScenes toolkit.
+ lidar_token (str): LiDAR token.
+
+ Returns:
+ DictStrAny: LiDAR data.
+ """
+ lidar_data = data.get("sample_data", lidar_token)
+
+ sample_name = (
+ lidar_data["filename"].split("/")[-1].replace(".pcd.bin", "")
+ )
+
+ lidar_path = os.path.join(self.data_root, lidar_data["filename"])
+
+ calibration_lidar = data.get(
+ "calibrated_sensor", lidar_data["calibrated_sensor_token"]
+ )
+
+ ego_pose = data.get("ego_pose", lidar_data["ego_pose_token"])
+
+ extrinsics = _get_extrinsics(ego_pose, calibration_lidar)
+
+ return {
+ "sample_name": sample_name,
+ "lidar_path": lidar_path,
+ "extrinsics": extrinsics,
+ "timestamp": lidar_data["timestamp"],
+ }
+
+ def _load_cam_data(
+ self, data: NuScenesDevkit, cam_token: str
+ ) -> DictStrAny:
+ """Load camera data.
+
+ Args:
+ data (NuScenesDevkit): NuScenes toolkit.
+ cam_token (str): Camera token.
+
+ Returns:
+ DictStrAny: Camera data containing the sample name, image path,
+ image height and width, intrinsics, extrinsics, and
+ timestamp.
+ """
+ cam_data = data.get("sample_data", cam_token)
+
+ sample_name = (
+ cam_data["filename"]
+ .split("/")[-1]
+ .replace(f".{cam_data['fileformat']}", "")
+ )
+
+ image_path = os.path.join(self.data_root, cam_data["filename"])
+
+ calibration_cam = data.get(
+ "calibrated_sensor", cam_data["calibrated_sensor_token"]
+ )
+
+ intrinsics = np.array(
+ calibration_cam["camera_intrinsic"], dtype=np.float32
+ )
+
+ ego_pose = data.get("ego_pose", cam_data["ego_pose_token"])
+ extrinsics = _get_extrinsics(ego_pose, calibration_cam)
+
+ return {
+ "sample_name": sample_name,
+ "image_path": image_path,
+ "image_hw": (cam_data["height"], cam_data["width"]),
+ "intrinsics": intrinsics,
+ "extrinsics": extrinsics,
+ "timestamp": cam_data["timestamp"],
+ }
+
+ def _load_annotations(
+ self,
+ data: NuScenesDevkit,
+ extrinsics: NDArrayF32,
+ ann_tokens: list[str],
+ instance_tokens: list[str],
+ axis_mode: AxisMode = AxisMode.ROS,
+ export_2d_annotations: bool = False,
+ intrinsics: NDArrayF32 | None = None,
+ image_hw: tuple[int, int] | None = None,
+ ) -> DictStrAny:
+ """Load annonations."""
+ boxes3d = np.empty((1, 10), dtype=np.float32)[1:]
+ boxes3d_classes = np.empty((1,), dtype=np.int64)[1:]
+ boxes3d_attributes = np.empty((1,), dtype=np.int64)[1:]
+ boxes3d_track_ids = np.empty((1,), dtype=np.int64)[1:]
+ boxes3d_velocities = np.empty((1, 3), dtype=np.float32)[1:]
+ boxes3d_num_lidar_pts = np.empty((1,), dtype=np.int64)[1:]
+ boxes3d_num_radar_pts = np.empty((1,), dtype=np.int64)[1:]
+
+ if export_2d_annotations:
+ assert (
+ axis_mode == AxisMode.OPENCV
+ ), "2D annotations are only supported in camera coordinates."
+ assert intrinsics is not None, "Intrinsics must be provided."
+ boxes2d = np.empty((1, 4), dtype=np.float32)[1:]
+
+ sensor_from_global = inverse_rigid_transform(
+ torch.from_numpy(extrinsics)
+ )
+ translation = sensor_from_global[:3, 3].numpy()
+ rotation = Quaternion(
+ matrix=sensor_from_global[:3, :3].numpy(), atol=1e-5
+ )
+
+ for ann_token in ann_tokens:
+ ann_info = data.get("sample_annotation", ann_token)
+ box3d_class = category_to_detection_name(ann_info["category_name"])
+
+ if box3d_class is None:
+ continue
+
+ # 3D box in global coordinates
+ box3d = data.get_box(ann_info["token"])
+
+ # Get 3D box velocity
+ box3d.velocity = data.box_velocity(ann_info["token"])
+
+ # Move 3D box to sensor coordinates
+ box3d.rotate(rotation)
+ box3d.translate(translation)
+
+ if export_2d_annotations:
+ assert (
+ image_hw is not None
+ ), "Image height and width must be provided."
+ if not box_in_image(
+ box3d, intrinsics, (image_hw[1], image_hw[0])
+ ):
+ continue
+
+ # Number of points in the 3D box
+ boxes3d_num_lidar_pts = np.concatenate(
+ [
+ boxes3d_num_lidar_pts,
+ np.array([ann_info["num_lidar_pts"]], dtype=np.int64),
+ ]
+ )
+ boxes3d_num_radar_pts = np.concatenate(
+ [
+ boxes3d_num_radar_pts,
+ np.array([ann_info["num_radar_pts"]], dtype=np.int64),
+ ]
+ )
+
+ # Get 2D box
+ if export_2d_annotations:
+ corner_coords = (
+ view_points(box3d.corners(), intrinsics, True)
+ .T[:, :2]
+ .tolist()
+ )
+
+ boxes2d = np.concatenate(
+ [
+ boxes2d,
+ np.array(
+ [post_process_coords(corner_coords)],
+ dtype=np.float32,
+ ),
+ ]
+ )
+
+ # Get 3D box yaw. Use extrinsic rotation to align with PyTorch3D.
+ if axis_mode == AxisMode.OPENCV:
+ yaw = -box3d.orientation.yaw_pitch_roll[0]
+ x, y, z, w = R.from_euler("XYZ", [0, yaw, 0]).as_quat()
+ else:
+ yaw = box3d.orientation.yaw_pitch_roll[0]
+ x, y, z, w = R.from_euler("XYZ", [0, 0, yaw]).as_quat()
+
+ orientation = Quaternion([w, x, y, z])
+
+ boxes3d = np.concatenate(
+ [
+ boxes3d,
+ np.array(
+ [[*box3d.center, *box3d.wlh, *orientation.elements]],
+ dtype=np.float32,
+ ),
+ ]
+ )
+
+ # Get 3D box class id
+ boxes3d_classes = np.concatenate(
+ [
+ boxes3d_classes,
+ np.array(
+ [nuscenes_class_map[box3d_class]], dtype=np.int64
+ ),
+ ]
+ )
+
+ # Get 3D box attribute id
+ if len(ann_info["attribute_tokens"]) == 0:
+ box3d_attr = ""
+ else:
+ box3d_attr = data.get(
+ "attribute", ann_info["attribute_tokens"][0]
+ )["name"]
+ boxes3d_attributes = np.concatenate(
+ [
+ boxes3d_attributes,
+ np.array(
+ [nuscenes_attribute_map[box3d_attr]], dtype=np.int64
+ ),
+ ]
+ )
+
+ # Get 3D box track id
+ instance_token = data.get("sample_annotation", box3d.token)[
+ "instance_token"
+ ]
+ if not instance_token in instance_tokens:
+ instance_tokens.append(instance_token)
+ track_id = instance_tokens.index(instance_token)
+
+ boxes3d_track_ids = np.concatenate(
+ [boxes3d_track_ids, np.array([track_id], dtype=np.int64)]
+ )
+
+ # 3D bounding box velocity
+ velocity = box3d.velocity.astype(np.float32)
+ if np.any(np.isnan(velocity)):
+ velocity = np.zeros(3, dtype=np.float32)
+
+ boxes3d_velocities = np.concatenate(
+ [boxes3d_velocities, velocity[None]]
+ )
+
+ annotations = {
+ "boxes3d": boxes3d,
+ "boxes3d_classes": boxes3d_classes,
+ "boxes3d_attributes": boxes3d_attributes,
+ "boxes3d_track_ids": boxes3d_track_ids,
+ "boxes3d_velocities": boxes3d_velocities,
+ "boxes3d_num_lidar_pts": boxes3d_num_lidar_pts,
+ "boxes3d_num_radar_pts": boxes3d_num_radar_pts,
+ }
+
+ if export_2d_annotations:
+ annotations["boxes2d"] = boxes2d
+
+ return annotations
+
+ def _accumulate_sweeps(
+ self,
+ points: NDArrayF32,
+ lidar2global: NDArrayF32,
+ sweeps: list[DictStrAny],
+ ) -> NDArrayF32:
+ """Accumulate LiDAR sweeps."""
+ if len(sweeps) == 0:
+ return points
+
+ global2lidar = inverse_rigid_transform(torch.from_numpy(lidar2global))
+
+ points_sweeps = [torch.from_numpy(points)]
+ for sweep in sweeps:
+ points_bytes = self.data_backend.get(sweep["lidar_path"])
+ lidar_points = np.frombuffer(
+ bytearray(points_bytes), dtype=np.float32
+ )
+ lidar_points = lidar_points.reshape(-1, 5)[:, :3]
+
+ # Transform LiDAR points to global frame
+ global_lidar_points = transform_points(
+ torch.from_numpy(lidar_points),
+ torch.from_numpy(sweep["extrinsics"]),
+ )
+
+ # Transform LiDAR points to current LiDAR frame
+ current_lidar_points = transform_points(
+ global_lidar_points, global2lidar
+ )
+
+ points_sweeps.append(current_lidar_points)
+
+ return torch.cat(points_sweeps).numpy()
+
+ def _load_depth_map(
+ self,
+ points_lidar: NDArrayF32,
+ lidar2global: NDArrayF32,
+ cam2global: NDArrayF32,
+ intrinsics: NDArrayF32,
+ image_hw: tuple[int, int],
+ ) -> NDArrayF32:
+ """Load depth map.
+
+ Args:
+ points_lidar (NDArrayF32): LiDAR points.
+ lidar2global (NDArrayF32): LiDAR to global extrinsics.
+ cam2global (NDArrayF32): Camera to global extrinsics.
+ intrinsics (NDArrayF32): Camera intrinsic matrix.
+ image_hw (tuple[int, int]): Image height and width.
+
+ Returns:
+ NDArrayF32: Depth map.
+ """
+ cam2global_ = torch.from_numpy(cam2global)
+ lidar2global_ = torch.from_numpy(lidar2global)
+ intrinsics_ = torch.from_numpy(intrinsics)
+ points_lidar_ = torch.from_numpy(np.copy(points_lidar))
+
+ lidar2cam = torch.matmul(torch.inverse(cam2global_), lidar2global_)
+ cam2img = torch.eye(4, 4)
+ cam2img[:3, :3] = intrinsics_
+ points_cam = points_lidar_[:, :3] @ (lidar2cam[:3, :3].T) + lidar2cam[
+ :3, 3
+ ].unsqueeze(0)
+
+ depth_map = generate_depth_map(points_cam, intrinsics_, image_hw)
+ return depth_map.numpy()
+
+ def _filter_boxes(
+ self, annotations: DictStrAny
+ ) -> tuple[
+ NDArrayBool, NDArrayF32, NDArrayI64, NDArrayI64, NDArrayI64, NDArrayF32
+ ]:
+ """Load boxes."""
+ valid_mask = np.full(annotations["boxes3d"].shape[0], True)
+
+ if self.point_based_filter:
+ boxes3d_num_lidar_pts = annotations["boxes3d_num_lidar_pts"]
+ boxes3d_num_radar_pts = annotations["boxes3d_num_radar_pts"]
+ valid_mask = np.logical_and(
+ (boxes3d_num_lidar_pts + boxes3d_num_radar_pts) > 0, valid_mask
+ )
+
+ if self.distance_based_filter:
+ raise NotImplementedError(
+ "Distance based filter not implemented yet"
+ )
+
+ boxes3d = annotations["boxes3d"][valid_mask]
+ boxes3d_classes = annotations["boxes3d_classes"][valid_mask]
+ boxes3d_attributes = annotations["boxes3d_attributes"][valid_mask]
+ boxes3d_track_ids = annotations["boxes3d_track_ids"][valid_mask]
+ boxes3d_velocities = annotations["boxes3d_velocities"][valid_mask]
+
+ return (
+ valid_mask,
+ boxes3d,
+ boxes3d_classes,
+ boxes3d_attributes,
+ boxes3d_track_ids,
+ boxes3d_velocities,
+ )
+
+ def __len__(self) -> int:
+ """Length."""
+ return len(self.samples)
+
+ def __getitem__(self, idx: int) -> DictData:
+ """Get single sample.
+
+ Args:
+ idx (int): Index of sample.
+
+ Returns:
+ DictData: sample at index in Vis4D input format.
+ """
+ sample = self.samples[idx]
+ data_dict: DictData = {}
+
+ # metadata
+ data_dict["token"] = sample["token"]
+ data_dict[K.frame_ids] = sample["frame_ids"]
+ data_dict[K.sequence_names] = sample["scene_name"]
+ data_dict["can_bus"] = sample["can_bus"]
+
+ if "LIDAR_TOP" in self.sensors:
+ lidar_data = sample["LIDAR_TOP"]
+
+ # load LiDAR frame
+ data_dict["LIDAR_TOP"] = {
+ K.sample_names: lidar_data["sample_name"],
+ K.timestamp: lidar_data["timestamp"],
+ K.extrinsics: lidar_data["extrinsics"],
+ K.axis_mode: AxisMode.LIDAR,
+ }
+
+ if (
+ K.points3d in self.keys_to_load
+ or K.depth_maps in self.keys_to_load
+ ):
+ points_bytes = self.data_backend.get(lidar_data["lidar_path"])
+ lidar_points = np.frombuffer(
+ bytearray(points_bytes), dtype=np.float32
+ )
+ lidar_points = lidar_points.reshape(-1, 5)[:, :3]
+
+ lidar_points = self._accumulate_sweeps(
+ lidar_points,
+ lidar_data["extrinsics"],
+ lidar_data["sweeps"],
+ )
+
+ if K.points3d in self.keys_to_load:
+ data_dict["LIDAR_TOP"][K.points3d] = lidar_points
+
+ if K.boxes3d in self.keys_to_load:
+ data_dict["LIDAR_TOP"][K.boxes3d] = lidar_data["annotations"][
+ "boxes3d"
+ ]
+ data_dict["LIDAR_TOP"][K.boxes3d_classes] = lidar_data[
+ "annotations"
+ ]["boxes3d_classes"]
+ data_dict["LIDAR_TOP"][K.boxes3d_track_ids] = lidar_data[
+ "annotations"
+ ]["boxes3d_track_ids"]
+ data_dict["LIDAR_TOP"][K.boxes3d_velocities] = lidar_data[
+ "annotations"
+ ]["boxes3d_velocities"]
+ data_dict["LIDAR_TOP"]["attributes"] = lidar_data[
+ "annotations"
+ ]["boxes3d_attributes"]
+
+ # load camera frame
+ for cam in NuScenes.CAMERAS:
+ if cam in self.sensors:
+ cam_data = sample[cam]
+
+ data_dict[cam] = {K.timestamp: cam_data["timestamp"]}
+
+ if K.images in self.keys_to_load:
+ im_bytes = self.data_backend.get(cam_data["image_path"])
+ image = np.ascontiguousarray(
+ im_decode(im_bytes, mode=self.image_channel_mode),
+ dtype=np.float32,
+ )[None]
+
+ data_dict[cam][K.images] = image
+ data_dict[cam][K.input_hw] = cam_data["image_hw"]
+ data_dict[cam][K.sample_names] = cam_data["sample_name"]
+ data_dict[cam][K.intrinsics] = cam_data["intrinsics"]
+ data_dict[cam][K.extrinsics] = cam_data["extrinsics"]
+ data_dict[cam][K.axis_mode] = AxisMode.OPENCV
+
+ if K.original_images in self.keys_to_load:
+ data_dict[cam][K.original_images] = image
+ data_dict[cam][K.original_hw] = cam_data["image_hw"]
+
+ if (
+ K.boxes3d in self.keys_to_load
+ or K.boxes2d in self.keys_to_load
+ ):
+ if K.boxes3d in self.keys_to_load:
+ data_dict[cam][K.boxes3d] = cam_data["annotations"][
+ "boxes3d"
+ ]
+ data_dict[cam][K.boxes3d_classes] = cam_data[
+ "annotations"
+ ]["boxes3d_classes"]
+ data_dict[cam][K.boxes3d_track_ids] = cam_data[
+ "annotations"
+ ]["boxes3d_track_ids"]
+ data_dict[cam][K.boxes3d_velocities] = cam_data[
+ "annotations"
+ ]["boxes3d_velocities"]
+ data_dict[cam]["attributes"] = cam_data["annotations"][
+ "boxes3d_attributes"
+ ]
+
+ if K.boxes2d in self.keys_to_load:
+ boxes2d = cam_data["annotations"]["boxes2d"]
+
+ data_dict[cam][K.boxes2d] = boxes2d
+ data_dict[cam][K.boxes2d_classes] = data_dict[cam][
+ K.boxes3d_classes
+ ]
+ data_dict[cam][K.boxes2d_track_ids] = data_dict[cam][
+ K.boxes3d_track_ids
+ ]
+
+ if K.depth_maps in self.keys_to_load:
+ depth_maps = self._load_depth_map(
+ lidar_points,
+ lidar_data["extrinsics"],
+ cam_data["extrinsics"],
+ cam_data["intrinsics"],
+ cam_data["image_hw"],
+ )
+
+ data_dict[cam][K.depth_maps] = depth_maps
+
+ return data_dict
diff --git a/vis4d/data/datasets/nuscenes_detection.py b/vis4d/data/datasets/nuscenes_detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..decd7e7b5f885717f2d9e06cf4fe33e60c89f58e
--- /dev/null
+++ b/vis4d/data/datasets/nuscenes_detection.py
@@ -0,0 +1,113 @@
+"""NuScenes multi-sensor video dataset."""
+
+from __future__ import annotations
+
+import json
+
+import numpy as np
+
+from vis4d.common.typing import ArgsType, DictStrAny, NDArrayF32, NDArrayI64
+from vis4d.data.typing import DictData
+
+from .nuscenes import NuScenes, nuscenes_class_map
+
+
+class NuScenesDetection(NuScenes):
+ """NuScenes detection dataset."""
+
+ def __init__(
+ self,
+ pure_detection: str,
+ score_thres: float = 0.05,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class."""
+ self.pure_detection = pure_detection
+ self.score_thres = score_thres
+
+ with open(self.pure_detection, encoding="utf-8") as f:
+ self.predictions = json.load(f)
+
+ super().__init__(**kwargs)
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset."""
+ return (
+ f"NuScenesDetection {self.version} {self.split} using "
+ + f"{self.pure_detection}"
+ )
+
+ def _load_pred(
+ self, preds: list[DictStrAny]
+ ) -> tuple[NDArrayF32, NDArrayI64, NDArrayF32, NDArrayF32]:
+ """Load nuscenes format prediction."""
+ boxes3d = np.empty((1, 10), dtype=np.float32)[1:]
+ boxes3d_classes = np.empty((1,), dtype=np.int64)[1:]
+ boxes3d_scores = np.empty((1,), dtype=np.float32)[1:]
+ boxes3d_velocities = np.empty((1, 3), dtype=np.float32)[1:]
+
+ for pred in preds:
+ if pred["detection_name"] not in nuscenes_class_map:
+ continue
+
+ if float(pred["detection_score"]) <= self.score_thres:
+ continue
+
+ boxes3d = np.concatenate(
+ [
+ boxes3d,
+ np.array(
+ [
+ [
+ *pred["translation"],
+ *pred["size"],
+ *pred["rotation"],
+ ]
+ ],
+ dtype=np.float32,
+ ),
+ ]
+ )
+ boxes3d_classes = np.concatenate(
+ [
+ boxes3d_classes,
+ np.array(
+ [nuscenes_class_map[pred["detection_name"]]],
+ dtype=np.int64,
+ ),
+ ]
+ )
+ boxes3d_scores = np.concatenate(
+ [
+ boxes3d_scores,
+ np.array([pred["detection_score"]], dtype=np.float32),
+ ]
+ )
+ boxes3d_velocities = np.concatenate(
+ [
+ boxes3d_velocities,
+ np.array([[*pred["velocity"], 0]], dtype=np.float32),
+ ]
+ )
+
+ return boxes3d, boxes3d_classes, boxes3d_scores, boxes3d_velocities
+
+ def __getitem__(self, idx: int) -> DictData:
+ """Get single sample.
+
+ Args:
+ idx (int): Index of sample.
+
+ Returns:
+ DictData: sample at index in Vis4D input format.
+ """
+ data_dict = super().__getitem__(idx)
+
+ (
+ data_dict["LIDAR_TOP"]["pred_boxes3d"],
+ data_dict["LIDAR_TOP"]["pred_boxes3d_classes"],
+ data_dict["LIDAR_TOP"]["pred_boxes3d_scores"],
+ data_dict["LIDAR_TOP"]["pred_boxes3d_velocities"],
+ ) = self._load_pred(self.predictions["results"][data_dict["token"]])
+
+ return data_dict
diff --git a/vis4d/data/datasets/nuscenes_mono.py b/vis4d/data/datasets/nuscenes_mono.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a48f9222cefba099e7a445c6486acca34ba385c
--- /dev/null
+++ b/vis4d/data/datasets/nuscenes_mono.py
@@ -0,0 +1,248 @@
+"""NuScenes monocular dataset."""
+
+from __future__ import annotations
+
+import numpy as np
+from tqdm import tqdm
+
+from vis4d.common.imports import NUSCENES_AVAILABLE
+from vis4d.common.logging import rank_zero_info
+from vis4d.common.time import Timer
+from vis4d.common.typing import ArgsType, DictStrAny
+from vis4d.data.const import AxisMode
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.typing import DictData
+
+from .nuscenes import NuScenes, nuscenes_class_map
+from .util import im_decode, print_class_histogram
+
+if NUSCENES_AVAILABLE:
+ from nuscenes import NuScenes as NuScenesDevkit
+ from nuscenes.utils.splits import create_splits_scenes
+else:
+ raise ImportError("nusenes-devkit is not available.")
+
+
+class NuScenesMono(NuScenes):
+ """NuScenes monocular dataset."""
+
+ def __init__(self, *args: ArgsType, **kwargs: ArgsType) -> None:
+ """Initialize the dataset."""
+ super().__init__(*args, **kwargs)
+
+ # Needed for CBGS
+ def get_cat_ids(self, idx: int) -> list[int]:
+ """Return the samples."""
+ return self.samples[idx]["CAM"]["annotations"]["boxes3d_classes"]
+
+ def _filter_data(self, data: list[DictStrAny]) -> list[DictStrAny]:
+ """Remove empty samples."""
+ samples = []
+ frequencies = {cat: 0 for cat in nuscenes_class_map}
+ inv_nuscenes_class_map = {v: k for k, v in nuscenes_class_map.items()}
+
+ t = Timer()
+ for sample in data:
+ (
+ mask,
+ boxes3d,
+ boxes3d_classes,
+ boxes3d_attributes,
+ boxes3d_track_ids,
+ boxes3d_velocities,
+ ) = self._filter_boxes(sample["CAM"]["annotations"])
+
+ sample["CAM"]["annotations"]["boxes3d"] = boxes3d
+ sample["CAM"]["annotations"]["boxes3d_classes"] = boxes3d_classes
+ sample["CAM"]["annotations"][
+ "boxes3d_attributes"
+ ] = boxes3d_attributes
+ sample["CAM"]["annotations"][
+ "boxes3d_track_ids"
+ ] = boxes3d_track_ids
+ sample["CAM"]["annotations"][
+ "boxes3d_velocities"
+ ] = boxes3d_velocities
+ sample["CAM"]["annotations"]["boxes2d"] = sample["CAM"][
+ "annotations"
+ ]["boxes2d"][mask]
+
+ for box3d_class in boxes3d_classes:
+ frequencies[inv_nuscenes_class_map[box3d_class]] += 1
+
+ if self.skip_empty_samples:
+ if len(sample["CAM"]["annotations"]["boxes3d"]) > 0:
+ samples.append(sample)
+ else:
+ samples.append(sample)
+
+ rank_zero_info(
+ f"Preprocessing {len(data)} frames takes {t.time():.2f}"
+ " seconds."
+ )
+
+ print_class_histogram(frequencies)
+
+ if self.skip_empty_samples:
+ rank_zero_info(
+ f"Filtered {len(data) - len(samples)} empty frames."
+ )
+
+ return samples
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset."""
+ return f"NuScenes Monocular Dataset {self.version} {self.split}"
+
+ def _generate_data_mapping(self) -> list[DictStrAny]:
+ """Generate data mapping.
+
+ Returns:
+ List[DictStrAny]: List of items required to load for a single
+ dataset sample.
+ """
+ data = NuScenesDevkit(
+ version=self.version, dataroot=self.data_root, verbose=False
+ )
+
+ frames = []
+ instance_tokens: list[str] = []
+
+ scene_names_per_split = create_splits_scenes()
+
+ scenes = [
+ scene
+ for scene in data.scene
+ if scene["name"] in scene_names_per_split[self.split]
+ ]
+
+ for scene in tqdm(scenes):
+ scene_name = scene["name"]
+ frame_ids = 0
+ sample_token = scene["first_sample_token"]
+ while sample_token:
+ sample = data.get("sample", sample_token)
+
+ # LIDAR data
+ lidar_token = sample["data"]["LIDAR_TOP"]
+
+ lidar_data = self._load_lidar_data(data, lidar_token)
+ lidar_data["annotations"] = self._load_annotations(
+ data,
+ lidar_data["extrinsics"],
+ sample["anns"],
+ instance_tokens,
+ )
+
+ # TODO add RADAR, Map data
+
+ # Get the sample data for each camera
+ for cam in self.CAMERAS:
+ frame: DictStrAny = {}
+ frame["scene_name"] = f"{scene_name}_{cam}"
+ frame["token"] = sample["token"]
+ frame["frame_ids"] = frame_ids
+
+ frame["LIDAR_TOP"] = lidar_data
+
+ cam_token = sample["data"][cam]
+
+ frame["CAM"] = self._load_cam_data(data, cam_token)
+ frame["CAM"]["annotations"] = self._load_annotations(
+ data,
+ frame["CAM"]["extrinsics"],
+ sample["anns"],
+ instance_tokens,
+ axis_mode=AxisMode.OPENCV,
+ export_2d_annotations=True,
+ intrinsics=frame["CAM"]["intrinsics"],
+ image_hw=frame["CAM"]["image_hw"],
+ )
+
+ frames.append(frame)
+
+ sample_token = sample["next"]
+ frame_ids += 1
+
+ return frames
+
+ def __getitem__(self, idx: int) -> DictData:
+ """Get single sample.
+
+ Args:
+ idx (int): Index of sample.
+
+ Returns:
+ DictData: sample at index in Vis4D input format.
+ """
+ sample = self.samples[idx]
+ data_dict: DictData = {}
+
+ if K.depth_maps in self.keys_to_load:
+ lidar_data = sample["LIDAR_TOP"]
+
+ points_bytes = self.data_backend.get(lidar_data["lidar_path"])
+ points = np.frombuffer(points_bytes, dtype=np.float32)
+ points = points.reshape(-1, 5)[:, :3]
+
+ if K.depth_maps in self.keys_to_load:
+ lidar_to_global = lidar_data["extrinsics"]
+
+ # load camera frame
+ data_dict = {
+ "token": sample["token"],
+ K.sequence_names: sample["scene_name"],
+ K.frame_ids: sample["frame_ids"],
+ K.timestamp: sample["CAM"]["timestamp"],
+ }
+
+ if K.images in self.keys_to_load:
+ im_bytes = self.data_backend.get(sample["CAM"]["image_path"])
+ image = np.ascontiguousarray(
+ im_decode(im_bytes), dtype=np.float32
+ )[None]
+
+ data_dict[K.images] = image
+ data_dict[K.input_hw] = sample["CAM"]["image_hw"]
+ data_dict[K.sample_names] = sample["CAM"]["sample_name"]
+ data_dict[K.intrinsics] = sample["CAM"]["intrinsics"]
+
+ if K.original_images in self.keys_to_load:
+ data_dict[K.original_images] = image
+ data_dict[K.original_hw] = sample["CAM"]["image_hw"]
+
+ if K.boxes3d in self.keys_to_load or K.boxes2d in self.keys_to_load:
+ if K.boxes3d in self.keys_to_load:
+ data_dict[K.boxes3d] = sample["CAM"]["annotations"]["boxes3d"]
+ data_dict[K.boxes3d_classes] = sample["CAM"]["annotations"][
+ "boxes3d_classes"
+ ]
+ data_dict[K.boxes3d_track_ids] = sample["CAM"]["annotations"][
+ "boxes3d_track_ids"
+ ]
+ data_dict[K.boxes3d_velocities] = sample["CAM"]["annotations"][
+ "boxes3d_velocities"
+ ]
+ data_dict["attributes"] = sample["CAM"]["annotations"][
+ "boxes3d_attributes"
+ ]
+ data_dict[K.extrinsics] = sample["CAM"]["extrinsics"]
+ data_dict[K.axis_mode] = AxisMode.OPENCV
+
+ if K.boxes2d in self.keys_to_load:
+ data_dict[K.boxes2d] = sample["CAM"]["annotations"]["boxes2d"]
+ data_dict[K.boxes2d_classes] = data_dict[K.boxes3d_classes]
+ data_dict[K.boxes2d_track_ids] = data_dict[K.boxes3d_track_ids]
+
+ if K.depth_maps in self.keys_to_load:
+ depth_maps = self._load_depth_map(
+ points,
+ lidar_to_global,
+ sample["CAM"]["extrinsics"],
+ sample["CAM"]["intrinsics"],
+ sample["CAM"]["image_hw"],
+ )
+
+ data_dict[K.depth_maps] = depth_maps
+
+ return data_dict
diff --git a/vis4d/data/datasets/nuscenes_trajectory.py b/vis4d/data/datasets/nuscenes_trajectory.py
new file mode 100644
index 0000000000000000000000000000000000000000..a155ac837d63c1b7aa22de795cab09bf2be4319d
--- /dev/null
+++ b/vis4d/data/datasets/nuscenes_trajectory.py
@@ -0,0 +1,264 @@
+"""NuScenes trajectory dataset."""
+
+from __future__ import annotations
+
+import json
+
+import numpy as np
+from scipy.spatial.distance import cdist
+from tqdm import tqdm
+
+from vis4d.common.imports import NUSCENES_AVAILABLE
+from vis4d.common.logging import rank_zero_info
+from vis4d.common.typing import DictStrAny, NDArrayF32
+from vis4d.data.typing import DictData
+
+from .base import Dataset
+from .util import CacheMappingMixin
+
+if NUSCENES_AVAILABLE:
+ from nuscenes import NuScenes as NuScenesDevkit
+ from nuscenes.eval.detection.utils import category_to_detection_name
+ from nuscenes.utils.data_classes import Quaternion
+ from nuscenes.utils.splits import create_splits_scenes
+else:
+ raise ImportError("nusenes-devkit is not available.")
+
+
+class NuScenesTrajectory(CacheMappingMixin, Dataset):
+ """NuScenes Trajectory dataset with given detection results.
+
+ It will generate a trajectory data pair with minimum sequence length. The
+ detection results will be matched with the ground truth trajectory
+ according to the BEV distance.
+ """
+
+ def __init__(
+ self,
+ detector: str,
+ pure_detection: str,
+ data_root: str,
+ version: str = "v1.0-trainval",
+ split: str = "train",
+ min_seq_len: int = 10,
+ cache_as_binary: bool = False,
+ cached_file_path: str | None = None,
+ ) -> None:
+ """Init dataset.
+
+ Args:
+ detector (str): The detector name.
+ pure_detection (str): The path to the pure detection results. It
+ should be the same format as nuScenes submission format.
+ data_root (str): The root path of the dataset.
+ version (str, optional): The version of the dataset. Defaults to
+ "v1.0-trainval".
+ split (str, optional): The split of the dataset. Defaults to
+ "train".
+ min_seq_len (int, optional): The minimum sequence length of the
+ trajectory. Defaults to 10.
+ cache_as_binary (bool, optional): Whether to cache the dataset as
+ binary. Defaults to False.
+ cached_file_path (str | None, optional): The path to the cached
+ file. Defaults to None.
+ """
+ super().__init__()
+ self.data_root = data_root
+ self.version = version
+ self.split = split
+
+ self.detector = detector
+ self.min_seq_len = min_seq_len
+
+ self.pure_detection = pure_detection
+
+ # Load trajectories
+ self.samples, _ = self._load_mapping(
+ self._generate_data_mapping,
+ cache_as_binary=cache_as_binary,
+ cached_file_path=cached_file_path,
+ )
+ rank_zero_info(f"Generated {len(self.samples)} trajectories.")
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset."""
+ return f"NuScenes Trajectory Data with {self.detector} detection"
+
+ def _match_gt_pred(
+ self,
+ gt_world: NDArrayF32,
+ gt_class: str,
+ predictions: list[DictStrAny],
+ ) -> tuple[NDArrayF32, bool]:
+ """Match gt and pred according to BEV center distance.
+
+ If the distance is less than 2 meters, the prediction will be used
+ instead of the ground truth.
+ """
+ if len(predictions) > 0:
+ same_class_preds = [
+ pred
+ for pred in predictions
+ if pred["detection_name"] == gt_class
+ ]
+
+ if len(same_class_preds) > 0:
+ preds_center = [
+ pred["translation"][:2] for pred in same_class_preds
+ ]
+ distance_matrix = (
+ cdist( # pylint: disable=unsubscriptable-object
+ gt_world[:, :2],
+ np.array(preds_center).reshape(-1, 2),
+ )[0]
+ )
+
+ if distance_matrix[distance_matrix.argmin()] <= 2:
+ match_pred = same_class_preds[distance_matrix.argmin()]
+
+ # WLH -> HWL
+ w, l, h = match_pred["size"]
+ dimensions = [h, w, l]
+ yaw = Quaternion(match_pred["rotation"]).yaw_pitch_roll[0]
+
+ pred_world = np.array(
+ [
+ [
+ *match_pred["translation"],
+ *dimensions,
+ yaw,
+ match_pred["detection_score"],
+ ]
+ ],
+ dtype=np.float32,
+ )
+
+ return pred_world, False
+
+ return gt_world, True
+
+ def _generate_data_mapping(self) -> list[dict[str, NDArrayF32]]:
+ """Generate trajectories predction and groundtruth.
+
+ Trajectories will be generated for each scene. Each trajectory consists
+ of [x, y, z, h, w, l, yaw, score] in world coordinate.
+
+ Returns:
+ list[dict[str, NDArrayF32]]: The list of trajectories.
+ """
+ data = NuScenesDevkit(
+ version=self.version, dataroot=self.data_root, verbose=False
+ )
+
+ scene_names_per_split = create_splits_scenes()
+
+ scenes = [
+ scene
+ for scene in data.scene
+ if scene["name"] in scene_names_per_split[self.split]
+ ]
+
+ instance_tokens = []
+
+ with open(self.pure_detection, "r", encoding="utf-8") as f:
+ predictions = json.load(f)
+
+ num_gt_boxes = 0
+ num_pred_boxes = 0
+ total_traj = []
+ for scene in tqdm(scenes):
+ local_traj: dict[int, dict[str, list[NDArrayF32]]] = {}
+
+ sample_token = scene["first_sample_token"]
+ while sample_token:
+ sample = data.get("sample", sample_token)
+
+ preds = predictions["results"][sample_token]
+
+ for ann_token in sample["anns"]:
+ ann_info = data.get("sample_annotation", ann_token)
+ box3d_class = category_to_detection_name(
+ ann_info["category_name"]
+ )
+
+ if box3d_class is None:
+ continue
+
+ box3d = data.get_box(ann_info["token"])
+
+ instance_token = data.get(
+ "sample_annotation", box3d.token
+ )["instance_token"]
+
+ if not instance_token in instance_tokens:
+ instance_tokens.append(instance_token)
+ track_id = instance_tokens.index(instance_token)
+
+ if track_id not in local_traj:
+ local_traj[track_id] = {"gt": [], "pred": []}
+
+ # WLH -> HWL
+ w, l, h = box3d.wlh
+ dimensions = [h, w, l]
+ yaw = box3d.orientation.yaw_pitch_roll[0]
+
+ gt_world = np.array(
+ [[*box3d.center, *dimensions, yaw, 1.0]],
+ dtype=np.float32,
+ )
+
+ local_traj[track_id]["gt"].append(gt_world)
+
+ matched_pred, is_gt = self._match_gt_pred(
+ gt_world, box3d_class, preds
+ )
+ local_traj[track_id]["pred"].append(matched_pred)
+
+ if is_gt:
+ num_gt_boxes += 1
+ else:
+ num_pred_boxes += 1
+
+ sample_token = sample["next"]
+
+ for _, traj in local_traj.items():
+ if len(traj["gt"]) >= self.min_seq_len:
+ trajectory = {
+ "gt": np.concatenate(traj["gt"]),
+ "pred": np.concatenate(traj["pred"]),
+ }
+ total_traj.append(trajectory)
+
+ rank_zero_info(f"Use {num_gt_boxes} gt boxes.")
+ rank_zero_info(f"Use {num_pred_boxes} pred boxes.")
+
+ return total_traj
+
+ def __len__(self) -> int:
+ """Return the length of the dataset."""
+ return len(self.samples)
+
+ def __getitem__(self, idx: int) -> DictData:
+ """Return the item at the given index.
+
+ The trajectory will be randomly cropped to the minimum sequence length.
+ """
+ trajectory = self.samples[idx]
+ data_dict: DictData = {}
+
+ traj_len = len(trajectory["gt"])
+
+ if traj_len > self.min_seq_len:
+ first_frame = np.random.randint(traj_len - self.min_seq_len)
+ else:
+ first_frame = 0
+
+ data_dict["gt_traj"] = trajectory["gt"][
+ first_frame : first_frame + self.min_seq_len
+ ]
+
+ data_dict["pred_traj"] = trajectory["pred"][
+ first_frame : first_frame + self.min_seq_len
+ ]
+
+ return data_dict
diff --git a/vis4d/data/datasets/s3dis.py b/vis4d/data/datasets/s3dis.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad2657d0a69d32e4b2fba4d53bfaa10f21e1c82e
--- /dev/null
+++ b/vis4d/data/datasets/s3dis.py
@@ -0,0 +1,276 @@
+"""Stanford 3D indoor dataset."""
+
+from __future__ import annotations
+
+import copy
+import glob
+import os
+from collections.abc import Sequence
+from io import BytesIO
+
+import numpy as np
+import pandas as pd
+import torch
+
+from vis4d.common.typing import ArgsType, DictStrAny
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.typing import DictData
+
+from .base import Dataset
+from .util import CacheMappingMixin
+
+
+class S3DIS(CacheMappingMixin, Dataset):
+ """S3DIS dataset class."""
+
+ DESCRIPTION = """S3DIS is a large-scale indoor pointcloud dataset."""
+ HOMEPAGE = "https://buildingparser.stanford.edu/dataset.html"
+ PAPER = (
+ "https://openaccess.thecvf.com/content_cvpr_2016/papers/"
+ "Armeni_3D_Semantic_Parsing_CVPR_2016_paper.pdf"
+ )
+ LICENSE = "CC BY-NC-SA 4.0"
+
+ KEYS = [
+ K.points3d,
+ K.colors3d,
+ K.semantics3d,
+ K.instances3d,
+ ]
+
+ CLASS_NAME_TO_IDX = {
+ "ceiling": 0,
+ "floor": 1,
+ "wall": 2,
+ "beam": 3,
+ "column": 4,
+ "window": 5,
+ "door": 6,
+ "chair": 7,
+ "table": 8,
+ "bookcase": 9,
+ "sofa": 10,
+ "board": 11,
+ "clutter": 12,
+ }
+
+ CLASS_COUNTS = torch.Tensor(
+ [
+ 3370714,
+ 2856755,
+ 4919229,
+ 318158,
+ 375640,
+ 478001,
+ 974733,
+ 650464,
+ 791496,
+ 88727,
+ 1284130,
+ 229758,
+ 2272837,
+ ]
+ )
+
+ AVAILABLE_KEYS: Sequence[str] = (
+ K.points3d,
+ K.colors3d,
+ K.semantics3d,
+ K.instances3d,
+ )
+
+ COLOR_MAPPING = torch.tensor(
+ [
+ [152, 223, 138],
+ [31, 119, 180],
+ [188, 189, 34],
+ [140, 86, 75],
+ [255, 152, 150],
+ [214, 39, 40],
+ [197, 176, 213],
+ [23, 190, 207],
+ [178, 76, 76],
+ [247, 182, 210],
+ [66, 188, 102],
+ [219, 219, 141],
+ [140, 57, 197],
+ [202, 185, 52],
+ ]
+ )
+
+ def __init__(
+ self,
+ data_root: str,
+ split: str = "trainNoArea5",
+ keys_to_load: Sequence[str] = (
+ K.points3d,
+ K.colors3d,
+ K.semantics3d,
+ K.instances3d,
+ ),
+ cache_points: bool = True,
+ cache_as_binary: bool = False,
+ cached_file_path: str | None = None,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates a new S3DIS dataset.
+
+ Args:
+ data_root (str): Path to S3DIS folder
+ split (str): which split to load. Must either be trainNoArea[1-6]
+ or testArea[1-6]. e.g. trainNoArea5 will load all areas except
+ area 5 and testArea5 will only load area 5.
+ keys_to_load (list[str]): What kind of data should be loaded
+ (e.g. colors, xyz, semantics, ...)
+ cache_points (bool): If true caches loaded points instead of
+ reading them from the disk every time.
+ cache_as_binary (bool): Whether to cache the dataset as binary.
+ Default: False.
+ cached_file_path (str | None): Path to a cached file. If cached
+ file exist then it will load it instead of generating the data
+ mapping. Default: None.
+
+ Raises:
+ ValueError: If requested split is malformed.
+ """
+ super().__init__(**kwargs)
+
+ self.data_root = data_root
+ self.split = split
+
+ self.areas: list[str] = [
+ "Area_1",
+ "Area_2",
+ "Area_3",
+ "Area_4",
+ "Area_5",
+ "Area_6",
+ ]
+ area_number = int(self.split.split("Area")[-1])
+ if "trainNoArea" in self.split:
+ self.areas.remove(self.areas[area_number - 1])
+ elif "testArea" in self.split:
+ self.areas = [self.areas[area_number - 1]]
+ else:
+ raise ValueError("Unknown split: ", self.split)
+
+ self.data, _ = self._load_mapping(
+ self._generate_data_mapping,
+ cache_as_binary=cache_as_binary,
+ cached_file_path=cached_file_path,
+ )
+ self.keys_to_load = keys_to_load
+
+ # Cache
+ self.cache_points = cache_points
+ self._cache: dict[int, DictData] = {}
+
+ @property
+ def num_classes(self) -> int:
+ """The number of classes int he datset."""
+ return len(S3DIS.CLASS_NAME_TO_IDX)
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset."""
+ return f"S3DIS(root={self.data_root}, split={self.split})"
+
+ def _generate_data_mapping(self) -> list[DictStrAny]:
+ """Generate 3dis dataset mapping."""
+ data: list[DictStrAny] = []
+ for area in self.areas:
+ for room_path in glob.glob(
+ os.path.join(self.data_root, area + "/*")
+ ):
+ room_data: DictStrAny = {}
+ if not os.path.isdir(room_path):
+ continue
+
+ for anns in glob.glob(
+ os.path.join(room_path, "Annotations/*.txt")
+ ):
+ instance_id = os.path.basename(anns.replace(".txt", ""))
+ sem_name = instance_id.split("_")[0]
+ room_data[instance_id] = {
+ "class_label": S3DIS.CLASS_NAME_TO_IDX.get(
+ sem_name, 12
+ ),
+ "path": anns,
+ }
+ data.append(room_data)
+
+ return data
+
+ def __len__(self) -> int:
+ """Length of the datset."""
+ return len(self.data)
+
+ def __getitem__(self, idx: int) -> DictData:
+ """Transform s3dis sample to vis4d input format.
+
+ Returns:
+ coordinates: 3D Poitns coordinate Shape(n x 3)
+ colors: 3D Point colors Shape(n x 3)
+ Semantic Classes: 3D Point classes Shape(n x 1)
+
+ Raises:
+ ValueError: If a requested key does not exist in this dataset.
+ """
+ data = self.data[idx]
+
+ # Cache data
+ if self.cache_points and idx in self._cache:
+ return copy.deepcopy(self._cache[idx])
+
+ coords = np.zeros((0, 3), dtype=np.float32)
+ color = np.zeros((0, 3), dtype=np.float32)
+ semantic_ids = np.zeros((0, 1), dtype=int)
+ instance_ids = np.zeros((0, 1), dtype=int)
+
+ for values in data.values():
+ data_path = values["path"]
+ instance_id = int(
+ values["path"].split("_")[-1].replace(".txt", "")
+ )
+ np_data = pd.read_csv(
+ BytesIO(self.data_backend.get(data_path)),
+ header=None,
+ delimiter=" ",
+ ).values.astype(np.float32)
+
+ if K.points3d in self.keys_to_load:
+ coords = np.vstack([coords, np_data[:, :3]])
+ if K.colors3d in self.keys_to_load:
+ color = np.vstack([color, np_data[:, 3:]])
+ if K.semantics3d in self.keys_to_load:
+ semantic_ids = np.vstack(
+ [
+ semantic_ids,
+ np.ones((np_data.shape[0], 1), dtype=int)
+ * values["class_label"],
+ ]
+ )
+ if K.instances3d in self.keys_to_load:
+ instance_ids = np.vstack(
+ [
+ instance_ids,
+ np.ones((np_data.shape[0], 1), dtype=int)
+ * instance_id,
+ ]
+ )
+
+ data = {}
+ for key in self.keys_to_load:
+ if key == K.points3d:
+ data[key] = coords
+ elif key == K.colors3d:
+ data[key] = color / 255.0
+ elif key == K.semantics3d:
+ data[key] = semantic_ids.squeeze(-1)
+ elif key == K.instances3d:
+ data[key] = instance_ids.squeeze(-1)
+ else:
+ raise ValueError(f"Can not load data for key: {key}")
+
+ if self.cache_points:
+ self._cache[idx] = copy.deepcopy(data)
+ return data
diff --git a/vis4d/data/datasets/scalabel.py b/vis4d/data/datasets/scalabel.py
new file mode 100644
index 0000000000000000000000000000000000000000..746cd2be84473897d6c8ba5e41681589569b3eb1
--- /dev/null
+++ b/vis4d/data/datasets/scalabel.py
@@ -0,0 +1,753 @@
+"""Scalabel type dataset."""
+
+from __future__ import annotations
+
+import os
+from collections import defaultdict
+from collections.abc import Callable, Sequence
+from typing import Union
+
+import numpy as np
+import torch
+
+from vis4d.common.distributed import broadcast
+from vis4d.common.imports import SCALABEL_AVAILABLE
+from vis4d.common.logging import rank_zero_info
+from vis4d.common.time import Timer
+from vis4d.common.typing import (
+ ArgsType,
+ ListAny,
+ NDArrayF32,
+ NDArrayI64,
+ NDArrayUI8,
+)
+from vis4d.data.const import AxisMode
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.datasets.util import CacheMappingMixin, DatasetFromList
+from vis4d.data.io import DataBackend
+from vis4d.data.typing import DictData
+from vis4d.op.geometry.rotation import (
+ euler_angles_to_matrix,
+ matrix_to_quaternion,
+)
+
+from .base import VideoDataset, VideoMapping
+from .util import DatasetFromList, im_decode, ply_decode, print_class_histogram
+
+if SCALABEL_AVAILABLE:
+ from scalabel.label.io import load, load_label_config
+ from scalabel.label.transforms import (
+ box2d_to_xyxy,
+ poly2ds_to_mask,
+ rle_to_mask,
+ )
+ from scalabel.label.typing import (
+ Config,
+ )
+ from scalabel.label.typing import Dataset as ScalabelData
+ from scalabel.label.typing import (
+ Extrinsics,
+ Frame,
+ ImageSize,
+ Intrinsics,
+ Label,
+ )
+ from scalabel.label.utils import (
+ check_crowd,
+ check_ignored,
+ get_leaf_categories,
+ get_matrix_from_extrinsics,
+ get_matrix_from_intrinsics,
+ )
+else:
+ raise ImportError("scalabel is not installed.")
+
+
+def load_intrinsics(intrinsics: Intrinsics) -> NDArrayF32:
+ """Transform intrinsic camera matrix according to augmentations."""
+ return get_matrix_from_intrinsics(intrinsics).astype(np.float32)
+
+
+def load_extrinsics(extrinsics: Extrinsics) -> NDArrayF32:
+ """Transform extrinsics from Scalabel to Vis4D."""
+ return get_matrix_from_extrinsics(extrinsics).astype(np.float32)
+
+
+def load_image(
+ url: str, backend: DataBackend, image_channel_mode: str
+) -> NDArrayF32:
+ """Load image tensor from url."""
+ im_bytes = backend.get(url)
+ image = im_decode(im_bytes, mode=image_channel_mode)
+ return np.ascontiguousarray(image, dtype=np.float32)[None]
+
+
+def load_pointcloud(url: str, backend: DataBackend) -> NDArrayF32:
+ """Load pointcloud tensor from url."""
+ assert url.endswith(".ply"), "Only PLY files are supported now."
+ ply_bytes = backend.get(url)
+ pointcloud = ply_decode(ply_bytes)
+ return pointcloud.astype(np.float32)
+
+
+def instance_ids_to_global(
+ frames: list[Frame], local_instance_ids: dict[str, list[str]]
+) -> None:
+ """Use local (per video) instance ids to produce global ones."""
+ video_names = list(local_instance_ids.keys())
+ for frame_id, ann in enumerate(frames):
+ if ann.labels is None: # pragma: no cover
+ continue
+ for label in ann.labels:
+ assert label.attributes is not None
+ if not check_crowd(label) and not check_ignored(label):
+ video_name = (
+ ann.videoName
+ if ann.videoName is not None
+ else "no-video-" + str(frame_id)
+ )
+ sum_previous_vids = sum(
+ (
+ len(local_instance_ids[v])
+ for v in video_names[: video_names.index(video_name)]
+ )
+ )
+ label.attributes["instance_id"] = (
+ sum_previous_vids
+ + local_instance_ids[video_name].index(label.id)
+ )
+
+
+def add_data_path(data_root: str, frames: list[Frame]) -> None:
+ """Add filepath to frame using data_root."""
+ for ann in frames:
+ assert ann.name is not None
+ if ann.url is None:
+ if ann.videoName is not None:
+ ann.url = os.path.join(data_root, ann.videoName, ann.name)
+ else:
+ ann.url = os.path.join(data_root, ann.name)
+ else:
+ ann.url = os.path.join(data_root, ann.url)
+
+
+def discard_labels_outside_set(
+ dataset: list[Frame], class_set: list[str]
+) -> None:
+ """Discard labels outside given set of classes.
+
+ Args:
+ dataset (list[Frame]): List of frames to filter.
+ class_set (list[str]): List of classes to keep.
+ """
+ for frame in dataset:
+ remove_anns = []
+ if frame.labels is not None:
+ for i, ann in enumerate(frame.labels):
+ if not ann.category in class_set:
+ remove_anns.append(i)
+ for i in reversed(remove_anns):
+ frame.labels.pop(i)
+
+
+def remove_empty_samples(frames: list[Frame]) -> list[Frame]:
+ """Remove empty samples."""
+ new_frames = []
+ for frame in frames:
+ if frame.labels is None:
+ continue
+ labels_used = []
+ for label in frame.labels:
+ assert label.attributes is not None and label.category is not None
+ if not check_crowd(label) and not check_ignored(label):
+ labels_used.append(label)
+
+ if len(labels_used) != 0:
+ frame.labels = labels_used
+ new_frames.append(frame)
+ rank_zero_info(f"Filtered {len(frames) - len(new_frames)} empty frames.")
+ del frames
+ return new_frames
+
+
+def prepare_labels(
+ frames: list[Frame],
+ class_list: list[str],
+ global_instance_ids: bool = False,
+) -> dict[str, int]:
+ """Add category id and instance id to labels, return class frequencies.
+
+ Args:
+ frames (list[Frame]): List of frames.
+ class_list (list[str]): List of classes.
+ global_instance_ids (bool): Whether to use global instance ids.
+ Defaults to False.
+ """
+ instance_ids: dict[str, list[str]] = defaultdict(list)
+ frequencies = {cat: 0 for cat in class_list}
+ for frame_id, ann in enumerate(frames):
+ if ann.labels is None: # pragma: no cover
+ continue
+
+ for label in ann.labels:
+ attr: dict[str, bool | int | float | str] = {}
+ if label.attributes is not None:
+ attr = label.attributes
+
+ if check_crowd(label) or check_ignored(label):
+ continue
+
+ assert label.category is not None
+ frequencies[label.category] += 1
+ video_name = (
+ ann.videoName
+ if ann.videoName is not None
+ else "no-video-" + str(frame_id)
+ )
+ if label.id not in instance_ids[video_name]:
+ instance_ids[video_name].append(label.id)
+ attr["instance_id"] = instance_ids[video_name].index(label.id)
+ label.attributes = attr
+
+ if global_instance_ids:
+ instance_ids_to_global(frames, instance_ids)
+
+ return frequencies
+
+
+def filter_frames_by_attributes(
+ frames: list[Frame],
+ attributes_to_load: Sequence[dict[str, str | float]] | None,
+) -> list[Frame]:
+ """Filter frames based on attributes."""
+ if attributes_to_load is None:
+ return frames
+ filtered_frames: list[Frame] = []
+ for frame in frames:
+ for attribute_dict in attributes_to_load:
+ if hasattr(frame, "attributes") and frame.attributes is not None:
+ if all(
+ frame.attributes.get(key) == value
+ for key, value in attribute_dict.items()
+ ):
+ filtered_frames.append(frame)
+ break
+ else:
+ raise ValueError(
+ "Attribute to load is specified but no attributes "
+ "are found in the frame."
+ )
+ return filtered_frames
+
+
+# Not using | operator because of a bug in Python 3.9
+# https://bugs.python.org/issue42233
+CategoryMap = Union[dict[str, int], dict[str, dict[str, int]]]
+
+
+class Scalabel(CacheMappingMixin, VideoDataset):
+ """Scalabel type dataset.
+
+ This class loads scalabel format data into Vis4D.
+ """
+
+ def __init__(
+ self,
+ data_root: str,
+ annotation_path: str,
+ keys_to_load: Sequence[str] = (K.images, K.boxes2d),
+ category_map: None | CategoryMap = None,
+ config_path: None | str | Config = None,
+ global_instance_ids: bool = False,
+ bg_as_class: bool = False,
+ skip_empty_samples: bool = False,
+ attributes_to_load: Sequence[dict[str, str | float]] | None = None,
+ cache_as_binary: bool = False,
+ cached_file_path: str | None = None,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ data_root (str): Root directory of the data.
+ annotation_path (str): Path to the annotation json(s).
+ keys_to_load (Sequence[str, ...], optional): Keys to load from the
+ dataset. Defaults to (K.images, K.boxes2d).
+ category_map (None | CategoryMap, optional): Mapping from a
+ Scalabel category string to an integer index. If None, the
+ standard mapping in the dataset config will be used. Defaults
+ to None.
+ config_path (None | str | Config, optional): Path to the dataset
+ config, can be added if it is not provided together with the
+ labels or should be modified. Defaults to None.
+ global_instance_ids (bool): Whether to convert tracking IDs of
+ annotations into dataset global IDs or stay with local,
+ per-video IDs. Defaults to false.
+ bg_as_class (bool): Whether to include background pixels as an
+ additional class for masks.
+ skip_empty_samples (bool): Whether to skip samples without
+ annotations.
+ attributes_to_load (Sequence[dict[str, str]]): List of attributes
+ dictionaries to load. Each dictionary is a mapping from the
+ attribute name to its desired value. If any of the attributes
+ dictionaries is matched, the corresponding frame will be
+ loaded. Defaults to None.
+ cache_as_binary (bool): Whether to cache the dataset as binary.
+ Default: False.
+ cached_file_path (str | None): Path to a cached file. If cached
+ file exist then it will load it instead of generating the data
+ mapping. Default: None.
+ """
+ super().__init__(**kwargs)
+ assert SCALABEL_AVAILABLE, "Scalabel is not installed."
+ self.data_root = data_root
+ self.annotation_path = annotation_path
+ self.keys_to_load = keys_to_load
+ self.global_instance_ids = global_instance_ids
+ self.bg_as_class = bg_as_class
+ self.config_path = config_path
+ self.skip_empty_samples = skip_empty_samples
+
+ self.cats_name2id: dict[str, dict[str, int]] = {}
+ self.category_map = category_map
+
+ self.attributes_to_load = attributes_to_load
+
+ self.frames, self.cfg = self._load_mapping(
+ self._generate_mapping,
+ remove_empty_samples,
+ cache_as_binary=cache_as_binary,
+ cached_file_path=cached_file_path,
+ )
+
+ assert self.cfg is not None, (
+ "No dataset configuration found. Please provide a configuration "
+ "via config_path."
+ )
+
+ if self.category_map is None:
+ class_list = list(
+ c.name for c in get_leaf_categories(self.cfg.categories)
+ )
+ self.category_map = {c: i for i, c in enumerate(class_list)}
+ self._setup_categories()
+ self.video_mapping = self._generate_video_mapping()
+
+ def _generate_video_mapping(self) -> VideoMapping:
+ """Group all dataset sample indices (int) by their video ID (str).
+
+ Returns:
+ VideoMapping: Mapping of video IDs to sample indices and frame IDs.
+ """
+ video_to_indices: dict[str, list[int]] = defaultdict(list)
+ video_to_frame_ids: dict[str, list[int]] = defaultdict(list)
+ for idx, frame in enumerate(self.frames): # type: ignore
+ if frame.videoName is not None:
+ assert (
+ frame.frameIndex is not None
+ ), "found videoName but no frameIndex!"
+ video_to_indices[frame.videoName].append(idx)
+ video_to_frame_ids[frame.videoName].append(frame.frameIndex)
+
+ return self._sort_video_mapping(
+ {
+ "video_to_indices": video_to_indices,
+ "video_to_frame_ids": video_to_frame_ids,
+ }
+ )
+
+ def _setup_categories(self) -> None:
+ """Setup categories."""
+ assert self.category_map is not None
+ for target in self.keys_to_load:
+ if isinstance(list(self.category_map.values())[0], int):
+ self.cats_name2id[target] = self.category_map # type: ignore
+ else:
+ assert (
+ target in self.category_map
+ ), f"Target={target} not specified in category_mapping"
+ target_map = self.category_map[target]
+ assert isinstance(target_map, dict)
+ self.cats_name2id[target] = target_map
+
+ def _load_mapping( # type: ignore
+ self,
+ generate_map_func: Callable[[], ScalabelData],
+ filter_func: Callable[[ListAny], ListAny] = lambda x: x,
+ cache_as_binary: bool = True,
+ cached_file_path: str | None = None,
+ ) -> tuple[DatasetFromList, Config]:
+ """Load cached mapping or generate if not exists."""
+ timer = Timer()
+ data = self._load_mapping_data(
+ generate_map_func, cache_as_binary, cached_file_path
+ )
+ if data is not None:
+ frames, cfg = data.frames, data.config
+
+ add_data_path(self.data_root, frames)
+ rank_zero_info(f"Loading {self} takes {timer.time():.2f} seconds.")
+
+ if self.category_map is None:
+ class_list = list(
+ c.name for c in get_leaf_categories(cfg.categories)
+ )
+ self.category_map = {c: i for i, c in enumerate(class_list)}
+ else:
+ class_list = list(self.category_map.keys())
+
+ assert len(set(class_list)) == len(
+ class_list
+ ), "Class names are not unique!"
+
+ discard_labels_outside_set(frames, class_list)
+
+ frames = filter_frames_by_attributes(
+ frames, self.attributes_to_load
+ )
+
+ if self.skip_empty_samples:
+ frames = filter_func(frames)
+
+ t = Timer()
+ frequencies = prepare_labels(
+ frames,
+ class_list,
+ global_instance_ids=self.global_instance_ids,
+ )
+ rank_zero_info(
+ f"Preprocessing {len(frames)} frames takes {t.time():.2f}"
+ " seconds."
+ )
+ print_class_histogram(frequencies)
+ frames_dataset = DatasetFromList(frames)
+ else:
+ frames_dataset = None
+ cfg = None
+ frames_dataset = broadcast(frames_dataset)
+ cfg = broadcast(cfg)
+ assert frames_dataset is not None
+ return frames_dataset, cfg
+
+ def _generate_mapping(self) -> ScalabelData:
+ """Generate data mapping."""
+ data = load(self.annotation_path)
+ if self.config_path is not None:
+ if isinstance(self.config_path, str):
+ data.config = load_label_config(self.config_path)
+ else:
+ data.config = self.config_path
+ return data
+
+ def _load_inputs(self, frame: Frame) -> DictData:
+ """Load inputs given a scalabel frame."""
+ data: DictData = {}
+ if K.images in self.keys_to_load:
+ assert frame.url is not None, "url is None!"
+ image = load_image(
+ frame.url, self.data_backend, self.image_channel_mode
+ )
+ input_hw = (image.shape[1], image.shape[2])
+ data[K.images] = image
+ data[K.input_hw] = input_hw
+
+ # Original image
+ data[K.original_images] = image
+ data[K.original_hw] = input_hw
+
+ data[K.axis_mode] = AxisMode.OPENCV
+ data[K.frame_ids] = frame.frameIndex
+
+ data[K.sample_names] = frame.name
+ data[K.sequence_names] = frame.videoName
+
+ if K.points3d in self.keys_to_load:
+ assert frame.url is not None, "url is None!"
+ data[K.points3d] = load_pointcloud(frame.url, self.data_backend)
+
+ if frame.intrinsics is not None and K.intrinsics in self.keys_to_load:
+ data[K.intrinsics] = load_intrinsics(frame.intrinsics)
+
+ if frame.extrinsics is not None and K.extrinsics in self.keys_to_load:
+ data[K.extrinsics] = load_extrinsics(frame.extrinsics)
+ return data
+
+ def _add_annotations(self, frame: Frame, data: DictData) -> None:
+ """Add annotations given a scalabel frame and a data dictionary."""
+ labels_used, instid_map = [], {}
+ if frame.labels is not None:
+ for label in frame.labels:
+ assert (
+ label.attributes is not None and label.category is not None
+ )
+ if not check_crowd(label) and not check_ignored(label):
+ labels_used.append(label)
+ if label.id not in instid_map:
+ instid_map[label.id] = int(
+ label.attributes["instance_id"]
+ )
+
+ image_size = (
+ ImageSize(height=data[K.input_hw][0], width=data[K.input_hw][1])
+ if K.input_hw in data
+ else frame.size
+ )
+
+ if K.boxes2d in self.keys_to_load:
+ cats_name2id = self.cats_name2id[K.boxes2d]
+ boxes2d, classes, track_ids = boxes2d_from_scalabel(
+ labels_used, cats_name2id, instid_map
+ )
+ data[K.boxes2d] = boxes2d
+ data[K.boxes2d_classes] = classes
+ data[K.boxes2d_track_ids] = track_ids
+
+ if K.instance_masks in self.keys_to_load:
+ # NOTE: instance masks' mapping is consistent with boxes2d
+ cats_name2id = self.cats_name2id[K.instance_masks]
+ instance_masks = instance_masks_from_scalabel(
+ labels_used, cats_name2id, image_size
+ )
+ data[K.instance_masks] = instance_masks
+
+ if K.seg_masks in self.keys_to_load:
+ sem_map = self.cats_name2id[K.seg_masks]
+ semantic_masks = semantic_masks_from_scalabel(
+ labels_used, sem_map, image_size, self.bg_as_class
+ )
+ data[K.seg_masks] = semantic_masks
+
+ if K.boxes3d in self.keys_to_load:
+ boxes3d, classes, track_ids = boxes3d_from_scalabel(
+ labels_used, self.cats_name2id[K.boxes3d], instid_map
+ )
+ data[K.boxes3d] = boxes3d
+ data[K.boxes3d_classes] = classes
+ data[K.boxes3d_track_ids] = track_ids
+
+ def __len__(self) -> int:
+ """Length of dataset."""
+ return len(self.frames)
+
+ def __getitem__(self, index: int) -> DictData:
+ """Get item from dataset at given index."""
+ frame = self.frames[index]
+ data = self._load_inputs(frame)
+
+ # load annotations to input sample
+ self._add_annotations(frame, data)
+
+ return data
+
+
+def boxes2d_from_scalabel(
+ labels: list[Label],
+ class_to_idx: dict[str, int],
+ label_id_to_idx: dict[str, int] | None = None,
+) -> tuple[NDArrayF32, NDArrayI64, NDArrayI64]:
+ """Convert from scalabel format to Vis4D.
+
+ NOTE: The box definition in Scalabel includes x2y2 in the box area, whereas
+ Vis4D and other software libraries like detectron2 and mmdet do not include
+ this, which is why we convert via box2d_to_xyxy.
+
+ Args:
+ labels (list[Label]): list of scalabel labels.
+ class_to_idx (dict[str, int]): mapping from class name to index.
+ label_id_to_idx (dict[str, int] | None, optional): mapping from label
+ id to index. Defaults to None.
+
+ Returns:
+ tuple[NDArrayF32, NDArrayI64, NDArrayI64]: boxes, classes, track_ids
+ """
+ box_list, cls_list, idx_list = [], [], []
+ for i, label in enumerate(labels):
+ box, box_cls, l_id = label.box2d, label.category, label.id
+ if box is None:
+ continue
+ if box_cls in class_to_idx:
+ cls_list.append(class_to_idx[box_cls])
+ else:
+ continue
+
+ box_list.append(box2d_to_xyxy(box))
+ idx = label_id_to_idx[l_id] if label_id_to_idx is not None else i
+ idx_list.append(idx)
+
+ if len(box_list) == 0:
+ return (
+ np.empty((0, 4), dtype=np.float32),
+ np.empty((0,), dtype=np.int64),
+ np.empty((0,), dtype=np.int64),
+ )
+
+ box_tensor = np.array(box_list, dtype=np.float32)
+ class_ids = np.array(cls_list, dtype=np.int64)
+ track_ids = np.array(idx_list, dtype=np.int64)
+ return box_tensor, class_ids, track_ids
+
+
+def instance_masks_from_scalabel(
+ labels: list[Label],
+ class_to_idx: dict[str, int],
+ image_size: ImageSize | None = None,
+) -> NDArrayUI8:
+ """Convert instance masks from scalabel format to Vis4D.
+
+ Args:
+ labels (list[Label]): list of scalabel labels.
+ class_to_idx (dict[str, int]): mapping from class name to index.
+ image_size (ImageSize, optional): image size. Defaults to None.
+
+ Returns:
+ NDArrayUI8: instance masks.
+ """
+ bitmask_list = []
+ for _, label in enumerate(labels):
+ if label.category not in class_to_idx: # pragma: no cover
+ continue # skip unknown classes
+ if label.poly2d is None and label.rle is None:
+ continue
+ if label.rle is not None:
+ bitmask = rle_to_mask(label.rle)
+ elif label.poly2d is not None:
+ assert (
+ image_size is not None
+ ), "image size must be specified for masks with polygons!"
+ bitmask_raw = poly2ds_to_mask(image_size, label.poly2d)
+ bitmask: NDArrayUI8 = (bitmask_raw > 0).astype( # type: ignore
+ bitmask_raw.dtype
+ )
+ else:
+ raise ValueError("No mask found in label.")
+ bitmask_list.append(bitmask)
+ if len(bitmask_list) == 0: # pragma: no cover
+ return np.empty((0, 0, 0), dtype=np.uint8)
+ mask_array = np.array(bitmask_list, dtype=np.uint8)
+ return mask_array
+
+
+def nhw_to_hwc_mask(
+ masks: NDArrayUI8, class_ids: NDArrayI64, ignore_class: int = 255
+) -> NDArrayUI8:
+ """Convert N binary HxW masks to HxW semantic mask.
+
+ Args:
+ masks (NDArrayUI8): Masks with shape [N, H, W].
+ class_ids (NDArrayI64): Class IDs with shape [N, 1].
+ ignore_class (int, optional): Ignore label. Defaults to 255.
+
+ Returns:
+ NDArrayUI8: Masks with shape [H, W], where each location indicate the
+ class label.
+ """
+ hwc_mask = np.full(masks.shape[1:], ignore_class, dtype=masks.dtype)
+ for mask, cat_id in zip(masks, class_ids):
+ hwc_mask[mask > 0] = cat_id
+ return hwc_mask
+
+
+def semantic_masks_from_scalabel(
+ labels: list[Label],
+ class_to_idx: dict[str, int],
+ image_size: ImageSize | None = None,
+ bg_as_class: bool = False,
+) -> NDArrayUI8:
+ """Convert masks from scalabel format to Vis4D.
+
+ Args:
+ labels (list[Label]): list of scalabel labels.
+ class_to_idx (dict[str, int]): mapping from class name to index.
+ image_size (ImageSize, optional): image size. Defaults to None.
+ bg_as_class (bool, optional): whether to include background as a class.
+ Defaults to False.
+
+ Returns:
+ NDArrayUI8: instance masks.
+ """
+ bitmask_list, cls_list = [], []
+ if bg_as_class:
+ foreground: NDArrayUI8 | None = None
+ for _, label in enumerate(labels):
+ if label.poly2d is None and label.rle is None:
+ continue
+ mask_cls = label.category
+ if mask_cls in class_to_idx:
+ cls_list.append(class_to_idx[mask_cls])
+ else: # pragma: no cover
+ continue # skip unknown classes
+ if label.rle is not None:
+ bitmask = rle_to_mask(label.rle)
+ elif label.poly2d is not None:
+ assert (
+ image_size is not None
+ ), "image size must be specified for masks with polygons!"
+ bitmask_raw = poly2ds_to_mask(image_size, label.poly2d)
+ bitmask: NDArrayUI8 = (bitmask_raw > 0).astype( # type: ignore
+ bitmask_raw.dtype
+ )
+ else:
+ raise ValueError("No mask found in label.")
+ bitmask_list.append(bitmask)
+ if bg_as_class:
+ foreground = (
+ bitmask
+ if foreground is None
+ else np.logical_or(foreground, bitmask)
+ )
+ if bg_as_class:
+ if foreground is None: # pragma: no cover
+ assert image_size is not None
+ foreground = np.zeros(
+ (image_size.height, image_size.width), dtype=np.uint8
+ )
+ bitmask_list.append(np.logical_not(foreground))
+ assert "background" in class_to_idx, (
+ '"bg_as_class" requires "background" class to be '
+ "in category_mapping"
+ )
+ cls_list.append(class_to_idx["background"])
+ if len(bitmask_list) == 0: # pragma: no cover
+ return np.empty((0, 0), dtype=np.uint8)
+ mask_array = np.array(bitmask_list, dtype=np.uint8)
+ class_ids = np.array(cls_list, dtype=np.int64)
+ return nhw_to_hwc_mask(mask_array, class_ids)
+
+
+def boxes3d_from_scalabel(
+ labels: list[Label],
+ class_to_idx: dict[str, int],
+ label_id_to_idx: dict[str, int] | None = None,
+) -> tuple[NDArrayF32, NDArrayI64, NDArrayI64]:
+ """Convert 3D bounding boxes from scalabel format to Vis4D."""
+ box_list, cls_list, idx_list = [], [], []
+ for i, label in enumerate(labels):
+ box, box_cls, l_id = label.box3d, label.category, label.id
+ if box is None:
+ continue
+ if box_cls in class_to_idx:
+ cls_list.append(class_to_idx[box_cls])
+ else:
+ continue
+
+ quaternion = (
+ matrix_to_quaternion(
+ euler_angles_to_matrix(torch.tensor([box.orientation]))
+ )[0]
+ .numpy()
+ .tolist()
+ )
+ box_list.append([*box.location, *box.dimension, *quaternion])
+ idx = label_id_to_idx[l_id] if label_id_to_idx is not None else i
+ idx_list.append(idx)
+
+ if len(box_list) == 0:
+ return (
+ np.empty((0, 10), dtype=np.float32),
+ np.empty((0,), dtype=np.int64),
+ np.empty((0,), dtype=np.int64),
+ )
+ box_tensor = np.array(box_list, dtype=np.float32)
+ class_ids = np.array(cls_list, dtype=np.int64)
+ track_ids = np.array(idx_list, dtype=np.int64)
+ return box_tensor, class_ids, track_ids
diff --git a/vis4d/data/datasets/shift.py b/vis4d/data/datasets/shift.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd79e2fced301293f937e01579b54cc57bc834e6
--- /dev/null
+++ b/vis4d/data/datasets/shift.py
@@ -0,0 +1,621 @@
+"""SHIFT dataset."""
+
+from __future__ import annotations
+
+import json
+import multiprocessing
+import os
+from collections.abc import Sequence
+from functools import partial
+
+import numpy as np
+from tqdm import tqdm
+
+from vis4d.common.imports import SCALABEL_AVAILABLE
+from vis4d.common.logging import rank_zero_info
+from vis4d.common.typing import NDArrayF32, NDArrayI64, NDArrayNumber
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.datasets.base import VideoDataset
+from vis4d.data.datasets.util import im_decode, npy_decode
+from vis4d.data.io import DataBackend, FileBackend, HDF5Backend, ZipBackend
+from vis4d.data.typing import DictData
+
+from .base import VideoDataset, VideoMapping
+from .scalabel import Scalabel
+
+shift_det_map = {
+ "pedestrian": 0,
+ "car": 1,
+ "truck": 2,
+ "bus": 3,
+ "motorcycle": 4,
+ "bicycle": 5,
+}
+shfit_track_map = {
+ "pedestrian": 0,
+ "car": 1,
+ "truck": 2,
+ "bus": 3,
+ "motorcycle": 4,
+ "bicycle": 5,
+}
+shift_seg_map = {
+ "unlabeled": 0,
+ "building": 1,
+ "fence": 2,
+ "other": 3,
+ "pedestrian": 4,
+ "pole": 5,
+ "road line": 6,
+ "road": 7,
+ "sidewalk": 8,
+ "vegetation": 9,
+ "vehicle": 10,
+ "wall": 11,
+ "traffic sign": 12,
+ "sky": 13,
+ "ground": 14,
+ "bridge": 15,
+ "rail track": 16,
+ "guard rail": 17,
+ "traffic light": 18,
+ "static": 19,
+ "dynamic": 20,
+ "water": 21,
+ "terrain": 22,
+}
+shift_seg_ignore = [
+ "unlabeled",
+ "other",
+ "ground",
+ "bridge",
+ "rail track",
+ "guard rail",
+ "static",
+ "dynamic",
+ "water",
+]
+
+if SCALABEL_AVAILABLE:
+ from scalabel.label.io import parse
+ from scalabel.label.typing import Config
+ from scalabel.label.typing import Dataset as ScalabelData
+else:
+ raise ImportError("scalabel is not installed.")
+
+
+def _get_extension(backend: DataBackend) -> str:
+ """Get the appropriate file extension for the given backend."""
+ if isinstance(backend, HDF5Backend):
+ return ".hdf5"
+ if isinstance(backend, ZipBackend):
+ return ".zip"
+ if isinstance(backend, FileBackend): # pragma: no cover
+ return ""
+ raise ValueError(f"Unsupported backend {backend}.") # pragma: no cover
+
+
+class _SHIFTScalabelLabels(Scalabel):
+ """Helper class for labels in SHIFT that are stored in Scalabel format."""
+
+ VIEWS = [
+ "front",
+ "center",
+ "left_45",
+ "left_90",
+ "right_45",
+ "right_90",
+ "left_stereo",
+ ]
+
+ def __init__(
+ self,
+ data_root: str,
+ split: str,
+ data_file: str = "",
+ keys_to_load: Sequence[str] = (K.images, K.boxes2d),
+ attributes_to_load: Sequence[dict[str, str | float]] | None = None,
+ annotation_file: str = "",
+ view: str = "front",
+ framerate: str = "images",
+ shift_type: str = "discrete",
+ skip_empty_frames: bool = False,
+ backend: DataBackend = HDF5Backend(),
+ verbose: bool = False,
+ num_workers: int = 1,
+ ) -> None:
+ """Initialize SHIFT dataset for one view.
+
+ Args:
+ data_root (str): Path to the root directory of the dataset.
+ split (str): Which data split to load.
+ data_file (str): Path to the data archive file. Default: "".
+ keys_to_load (Sequence[str]): List of keys to load.
+ Default: (K.images, K.boxes2d).
+ attributes_to_load (Sequence[dict[str, str | float]] | None):
+ List of attributes to load. Default: None.
+ annotation_file (str): Path to the annotation file. Default: "".
+ view (str): Which view to load. Default: "front". Options: "front",
+ "center", "left_45", "left_90", "right_45", "right_90", and
+ "left_stereo".
+ framerate (str): Which framerate to load. Default: "images".
+ shift_type (str): Which shift type to load. Default: "discrete".
+ Options: "discrete", "continuous/1x", "continuous/10x", and
+ "continuous/100x".
+ skip_empty_frames (bool): Whether to skip frames with no
+ instance annotations. Default: False.
+ backend (DataBackend): Backend to use for loading data. Default:
+ HDF5Backend().
+ verbose (bool): Whether to print verbose logs. Default: False.
+ num_workers (int): Number of workers to use for loading data.
+ Default: 1.
+ """
+ self.verbose = verbose
+ self.num_workers = num_workers
+
+ # Validate input
+ assert split in {"train", "val", "test"}, f"Invalid split '{split}'"
+ assert view in _SHIFTScalabelLabels.VIEWS, f"Invalid view '{view}'"
+
+ # Set attributes
+ ext = _get_extension(backend)
+ if shift_type.startswith("continuous"):
+ shift_speed = shift_type.split("/")[-1]
+ annotation_path = os.path.join(
+ data_root,
+ "continuous",
+ framerate,
+ shift_speed,
+ split,
+ view,
+ annotation_file,
+ )
+ data_path = os.path.join(
+ data_root,
+ "continuous",
+ framerate,
+ shift_speed,
+ split,
+ view,
+ f"{data_file}{ext}",
+ )
+ else:
+ annotation_path = os.path.join(
+ data_root, "discrete", framerate, split, view, annotation_file
+ )
+ data_path = os.path.join(
+ data_root,
+ "discrete",
+ framerate,
+ split,
+ view,
+ f"{data_file}{ext}",
+ )
+ super().__init__(
+ data_path,
+ annotation_path,
+ data_backend=backend,
+ keys_to_load=keys_to_load,
+ attributes_to_load=attributes_to_load,
+ skip_empty_samples=skip_empty_frames,
+ )
+
+ def _generate_mapping(self) -> ScalabelData:
+ """Generate data mapping."""
+ # Skipping validation for much faster loading
+ if self.verbose:
+ rank_zero_info(
+ "Loading annotation from '%s' ...", self.annotation_path
+ )
+ return self._load(self.annotation_path)
+
+ def _load(self, filepath: str) -> ScalabelData:
+ """Load labels from a json file or a folder of json files."""
+ raw_frames: list[DictData] = []
+ raw_groups: list[DictData] = []
+ if not os.path.exists(filepath):
+ raise FileNotFoundError(f"{filepath} does not exist.")
+
+ def process_file(filepath: str) -> DictData | None:
+ raw_cfg = None
+ with open(filepath, mode="r", encoding="utf-8") as fp:
+ content = json.load(fp)
+ if isinstance(content, dict):
+ raw_frames.extend(content["frames"])
+ if "groups" in content and content["groups"] is not None:
+ raw_groups.extend(content["groups"])
+ if "config" in content and content["config"] is not None:
+ raw_cfg = content["config"]
+ elif isinstance(content, list):
+ raw_frames.extend(content)
+ else:
+ raise TypeError(
+ "The input file contains neither dict nor list."
+ )
+
+ rank_zero_info(
+ "Loading SHIFT annotation from '%s' Done.", filepath
+ )
+ return raw_cfg
+
+ cfg = None
+ if os.path.isfile(filepath) and filepath.endswith("json"):
+ ret_cfg = process_file(filepath)
+ if ret_cfg is not None:
+ cfg = ret_cfg
+ else:
+ raise TypeError("Inputs must be a folder or a JSON file.")
+
+ config = None
+ if cfg is not None:
+ config = Config(**cfg)
+
+ parse_func = partial(parse, validate_frames=False)
+ if self.num_workers > 1:
+ with multiprocessing.Pool(self.num_workers) as pool:
+ frames = []
+ with tqdm(total=len(raw_frames)) as pbar:
+ for result in pool.imap_unordered(
+ parse_func, raw_frames, chunksize=1000
+ ):
+ frames.append(result)
+ pbar.update()
+ else:
+ frames = [parse_func(frame) for frame in raw_frames]
+ return ScalabelData(frames=frames, config=config, groups=None)
+
+
+class SHIFT(VideoDataset):
+ """SHIFT dataset class, supporting multiple tasks and views."""
+
+ DESCRIPTION = """SHIFT Dataset, a synthetic driving dataset for continuous
+ multi-task domain adaptation"""
+ HOMEPAGE = "https://www.vis.xyz/shift/"
+ PAPER = "https://arxiv.org/abs/2206.08367"
+ LICENSE = "CC BY-NC-SA 4.0"
+
+ KEYS = [
+ # Inputs
+ K.images,
+ K.original_hw,
+ K.input_hw,
+ K.points3d,
+ # Scalabel formatted annotations
+ K.intrinsics,
+ K.extrinsics,
+ K.timestamp,
+ K.axis_mode,
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.boxes2d_track_ids,
+ K.instance_masks,
+ K.boxes3d,
+ K.boxes3d_classes,
+ K.boxes3d_track_ids,
+ # Bit masks
+ K.seg_masks,
+ K.depth_maps,
+ K.optical_flows,
+ ]
+
+ VIEWS = [
+ "front",
+ "center",
+ "left_45",
+ "left_90",
+ "right_45",
+ "right_90",
+ "left_stereo",
+ ]
+
+ DATA_GROUPS = {
+ "img": [
+ K.images,
+ K.original_hw,
+ K.input_hw,
+ K.intrinsics,
+ ],
+ "det_2d": [
+ K.timestamp,
+ K.axis_mode,
+ K.extrinsics,
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.boxes2d_track_ids,
+ ],
+ "det_3d": [
+ K.boxes3d,
+ K.boxes3d_classes,
+ K.boxes3d_track_ids,
+ ],
+ "det_insseg_2d": [
+ K.instance_masks,
+ ],
+ "semseg": [
+ K.seg_masks,
+ ],
+ "depth": [
+ K.depth_maps,
+ ],
+ "flow": [
+ K.optical_flows,
+ ],
+ "lidar": [
+ K.points3d,
+ ],
+ }
+
+ GROUPS_IN_SCALABEL = ["det_2d", "det_3d", "det_insseg_2d"]
+
+ def __init__(
+ self,
+ data_root: str,
+ split: str,
+ keys_to_load: Sequence[str] = (K.images, K.boxes2d),
+ views_to_load: Sequence[str] = ("front",),
+ attributes_to_load: Sequence[dict[str, str | float]] | None = None,
+ framerate: str = "images",
+ shift_type: str = "discrete",
+ skip_empty_frames: bool = False,
+ backend: DataBackend = HDF5Backend(),
+ num_workers: int = 1,
+ verbose: bool = False,
+ ) -> None:
+ """Initialize SHIFT dataset."""
+ super().__init__(data_backend=backend)
+ # Validate input
+ assert split in {"train", "val", "test"}, f"Invalid split '{split}'."
+ assert framerate in {
+ "images",
+ "videos",
+ }, f"Invalid framerate '{framerate}'. Must be 'images' or 'videos'."
+ assert shift_type in {
+ "discrete",
+ "continuous/1x",
+ "continuous/10x",
+ "continuous/100x",
+ }, (
+ f"Invalid shift_type '{shift_type}'. "
+ "Must be one of 'discrete', 'continuous/1x', 'continuous/10x', "
+ "or 'continuous/100x'."
+ )
+ self.validate_keys(keys_to_load)
+
+ # Set attributes
+ self.data_root = data_root
+ self.split = split
+ self.keys_to_load = keys_to_load
+ self.views_to_load = views_to_load
+ self.attributes_to_load = attributes_to_load
+ self.framerate = framerate
+ self.shift_type = shift_type
+ self.backend = backend
+ self.verbose = verbose
+ self.ext = _get_extension(backend)
+ if self.shift_type.startswith("continuous"):
+ shift_speed = self.shift_type.split("/")[-1]
+ self.annotation_base = os.path.join(
+ self.data_root,
+ "continuous",
+ self.framerate,
+ shift_speed,
+ self.split,
+ )
+ else:
+ self.annotation_base = os.path.join(
+ self.data_root, self.shift_type, self.framerate, self.split
+ )
+ if self.verbose:
+ print(f"Base: {self.annotation_base}. Backend: {self.backend}")
+
+ # Get the data groups' classes that need to be loaded
+ self._data_groups_to_load = self._get_data_groups(keys_to_load)
+ if "det_2d" not in self._data_groups_to_load:
+ raise ValueError(
+ "In current implementation, the 'det_2d' data group must be "
+ "loaded to load any other data group."
+ )
+
+ self.scalabel_datasets = {}
+ for view in self.views_to_load:
+ if view == "center":
+ # Load lidar data, only available for center view
+ self.scalabel_datasets["center/lidar"] = _SHIFTScalabelLabels(
+ data_root=self.data_root,
+ split=self.split,
+ data_file="lidar",
+ annotation_file="det_3d.json",
+ view=view,
+ framerate=self.framerate,
+ shift_type=self.shift_type,
+ keys_to_load=(K.points3d, *self.DATA_GROUPS["det_3d"]),
+ attributes_to_load=self.attributes_to_load,
+ skip_empty_frames=skip_empty_frames,
+ backend=backend,
+ num_workers=num_workers,
+ verbose=verbose,
+ )
+ else:
+ # Skip the lidar data group, which is loaded separately
+ image_loaded = False
+ for group in self._data_groups_to_load:
+ name = f"{view}/{group}"
+ keys_to_load = list(self.DATA_GROUPS[group])
+ # Load the image data group only once
+ if not image_loaded:
+ keys_to_load.extend(self.DATA_GROUPS["img"])
+ image_loaded = True
+ self.scalabel_datasets[name] = _SHIFTScalabelLabels(
+ data_root=self.data_root,
+ split=self.split,
+ data_file="img",
+ annotation_file=f"{group}.json",
+ view=view,
+ framerate=self.framerate,
+ shift_type=self.shift_type,
+ keys_to_load=keys_to_load,
+ attributes_to_load=self.attributes_to_load,
+ skip_empty_frames=skip_empty_frames,
+ backend=backend,
+ num_workers=num_workers,
+ verbose=verbose,
+ )
+
+ self.video_mapping = self._generate_video_mapping()
+
+ def validate_keys(self, keys_to_load: Sequence[str]) -> None:
+ """Validate that all keys to load are supported."""
+ for k in keys_to_load:
+ if k not in self.KEYS:
+ raise ValueError(f"Key '{k}' is not supported!")
+
+ def _get_data_groups(self, keys_to_load: Sequence[str]) -> list[str]:
+ """Get the data groups that need to be loaded from Scalabel."""
+ data_groups = ["det_2d"]
+ for data_group, group_keys in self.DATA_GROUPS.items():
+ if data_group in self.GROUPS_IN_SCALABEL:
+ # If the data group is loaded by Scalabel, add it to the list
+ if any(key in group_keys for key in keys_to_load):
+ data_groups.append(data_group)
+ return list(set(data_groups))
+
+ def _load(
+ self, view: str, data_group: str, file_ext: str, video: str, frame: str
+ ) -> NDArrayNumber:
+ """Load data from the given data group."""
+ frame_number = frame.split("_")[0]
+ filepath = os.path.join(
+ self.annotation_base,
+ view,
+ f"{data_group}{self.ext}",
+ video,
+ f"{frame_number}_{data_group}_{view}.{file_ext}",
+ )
+ if data_group == "semseg":
+ return self._load_semseg(filepath)
+ if data_group == "depth":
+ return self._load_depth(filepath)
+ if data_group == "flow":
+ return self._load_flow(filepath)
+ raise ValueError(
+ f"Invalid data group '{data_group}'"
+ ) # pragma: no cover
+
+ def _load_semseg(self, filepath: str) -> NDArrayI64:
+ """Load semantic segmentation data."""
+ im_bytes = self.backend.get(filepath)
+ image = im_decode(im_bytes)[..., 0]
+ return image.astype(np.int64)
+
+ def _load_depth(
+ self, filepath: str, depth_factor: float = 16777.216 # 256 ^ 3 / 1000
+ ) -> NDArrayF32:
+ """Load depth data."""
+ assert depth_factor > 0, "Max depth value must be greater than 0."
+
+ im_bytes = self.backend.get(filepath)
+ image = im_decode(im_bytes)
+ if image.shape[2] > 3: # pragma: no cover
+ image = image[:, :, :3]
+ image = image.astype(np.float32)
+
+ # Convert to depth
+ depth = (
+ image[:, :, 2] * 256 * 256 + image[:, :, 1] * 256 + image[:, :, 0]
+ )
+ return np.ascontiguousarray(depth / depth_factor, dtype=np.float32)
+
+ def _load_flow(self, filepath: str) -> NDArrayF32:
+ """Load optical flow data."""
+ npy_bytes = self.backend.get(filepath)
+ flow = npy_decode(npy_bytes, key="flow")
+ flow = flow[:, :, [1, 0]] # Convert to (u, v) format
+ flow *= flow.shape[1] # Scale to image size (1280)
+ if self.framerate == "images":
+ flow *= 10.0 # NOTE: Scale to 1 fps approximately
+ return flow.astype(np.float32)
+
+ def _get_frame_key(self, idx: int) -> tuple[str, str]:
+ """Get the frame identifier (video name, frame name) by index."""
+ if len(self.scalabel_datasets) > 0:
+ frames = self.scalabel_datasets[
+ list(self.scalabel_datasets.keys())[0]
+ ].frames
+ return frames[idx].videoName, frames[idx].name
+ raise ValueError("No Scalabel file has been loaded.")
+
+ def __len__(self) -> int:
+ """Get the number of samples in the dataset."""
+ if len(self.scalabel_datasets) > 0:
+ return len(
+ self.scalabel_datasets[list(self.scalabel_datasets.keys())[0]]
+ )
+ raise ValueError(
+ "No Scalabel file has been loaded."
+ ) # pragma: no cover
+
+ def _generate_video_mapping(self) -> VideoMapping:
+ """Group all dataset sample indices (int) by their video ID (str).
+
+ Returns:
+ VideoMapping: Mapping of video IDs to sample indices and frame IDs.
+
+ Raises:
+ ValueError: If no Scalabel file has been loaded.
+ """
+ if len(self.scalabel_datasets) > 0:
+ return self.scalabel_datasets[
+ list(self.scalabel_datasets.keys())[0]
+ ].video_mapping
+ raise ValueError("No Scalabel file has been loaded.")
+
+ def __getitem__(self, idx: int) -> DictData:
+ """Get single sample.
+
+ Args:
+ idx (int): Index of sample.
+
+ Returns:
+ DictData: sample at index in Vis4D input format.
+ """
+ # load camera frames
+ data_dict = {}
+
+ # metadata
+ video_name, frame_name = self._get_frame_key(idx)
+ data_dict[K.sample_names] = frame_name
+ data_dict[K.sequence_names] = video_name
+ data_dict[K.frame_ids] = frame_name.split("_")[0]
+
+ for view in self.views_to_load:
+ data_dict_view = {}
+
+ if view == "center":
+ # Lidar is only available in the center view
+ if K.points3d in self.keys_to_load:
+ data_dict_view.update(
+ self.scalabel_datasets["center/lidar"][idx]
+ )
+ else:
+ # Load data from Scalabel
+ for group in self._data_groups_to_load:
+ data_dict_view.update(
+ self.scalabel_datasets[f"{view}/{group}"][idx]
+ )
+
+ # Load data from bit masks
+ if K.seg_masks in self.keys_to_load:
+ data_dict_view[K.seg_masks] = self._load(
+ view, "semseg", "png", video_name, frame_name
+ )
+ if K.depth_maps in self.keys_to_load:
+ data_dict_view[K.depth_maps] = self._load(
+ view, "depth", "png", video_name, frame_name
+ )
+ if K.optical_flows in self.keys_to_load:
+ data_dict_view[K.optical_flows] = self._load(
+ view, "flow", "npz", video_name, frame_name
+ )
+ data_dict[view] = data_dict_view # type: ignore
+
+ return data_dict
diff --git a/vis4d/data/datasets/torchvision.py b/vis4d/data/datasets/torchvision.py
new file mode 100644
index 0000000000000000000000000000000000000000..a23c974b56f472079261807bee5bc816850e5a20
--- /dev/null
+++ b/vis4d/data/datasets/torchvision.py
@@ -0,0 +1,130 @@
+"""Provides functionalities to wrap torchvision datasets."""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import Any
+
+import numpy as np
+from PIL.Image import Image
+from torchvision.datasets import VisionDataset
+from torchvision.transforms import ToTensor
+
+from ..const import CommonKeys as K
+from ..typing import DictData
+from .base import Dataset
+
+
+class TorchvisionDataset(Dataset):
+ """Wrapper for torchvision datasets.
+
+ This class wraps torchvision datasets and converts them to the format that
+ is expected by the vis4d framework.
+
+ The return of the torchvisons dataset is passed to the data_converter,
+ which needs to be provided by the user. The data_converter is expected to
+ return a DictData object following the vis4d conventions.
+
+ For well defined dataformats, such as classification, there
+ are already implemented wrappers that can be used. See
+ `TorchvisionClassificationDataset` for an example.
+ """
+
+ def __init__( # type: ignore
+ self,
+ torchvision_ds: VisionDataset,
+ data_converter: Callable[[Any], DictData],
+ ) -> None:
+ """Creates a new instance of the class.
+
+ Args:
+ torchvision_ds (VisionDataset): Torchvision dataset that should be
+ converted.
+ data_converter (Callable[[Any], DictData]): Function that
+ converts the output of the torchvision datasets __getitem__
+ to the format expected by the vis4d framework.
+ """
+ super().__init__()
+ self.torchvision_ds = torchvision_ds
+ self.data_converter = data_converter
+
+ def __getitem__(self, idx: int) -> DictData:
+ """Returns a new sample from the dataset.
+
+ Args:
+ idx (int): Index of the sample.
+
+ Returns:
+ DictData: Data in vis4d format.
+ """
+ return self.data_converter(self.torchvision_ds[idx])
+
+ def __len__(self) -> int:
+ """Returns the number of samples in the dataset.
+
+ Returns:
+ int: Length of the dataset.
+ """
+ return len(self.torchvision_ds)
+
+
+class TorchvisionClassificationDataset(TorchvisionDataset):
+ """Wrapper for torchvision classification datasets.
+
+ This class wraps torchvision classification datasets and converts them to
+ the format that is expected by the vis4d framework.
+
+ It expects the torchvision dataset to return a tuple of (image, class_id)
+ where the image is a PIL Image and the class_id is an integer.
+
+ If you want to use a torchvision dataset that returns a different format,
+ you can provide a custom data_converter function to the
+ `TorchvisionDataset` class.
+
+ The returned sample will have the following key, values:
+ images: ndarray of dimension (1, H, W, C)
+ categories: ndarray of dimension 1.
+
+ Example:
+ >>> from torchvision.datasets.mnist import MNIST
+ >>> ds = TorchvisionClassificationDataset(
+ >>> MNIST("data/mnist_ds", train=False)
+ >>> )
+ >>> data = next(iter(ds))
+ >>> print(data.keys)
+ dict_keys(['images', 'categories'])
+ """
+
+ def __init__(self, detection_ds: VisionDataset) -> None:
+ """Creates a new instance of the class.
+
+ Args:
+ detection_ds (VisionDataset): Torchvision dataset that
+ returns a tuple of (image, class_id) where the image is a PIL
+ Image and the class_id is an integer.
+ """
+ img_to_tensor = ToTensor()
+
+ def _data_converter(img_and_target: tuple[Image, int]) -> DictData:
+ """Converts the output of a torchvision dataset.
+
+ The output is converted to the format expected by the vis4d
+ framework.
+
+ Args:
+ img_and_target (tuple[Image, int]): Output of the datasets
+ __getitem__ method.
+
+ Returns:
+ DictData: Sample in vis4d format.
+ """
+ img, class_id = img_and_target
+ data: DictData = {}
+ data[K.images] = (
+ img_to_tensor(img).unsqueeze(0).permute(0, 2, 3, 1).numpy()
+ )
+ data[K.categories] = np.array([class_id], dtype=np.int64)
+
+ return data
+
+ super().__init__(detection_ds, _data_converter)
diff --git a/vis4d/data/datasets/util.py b/vis4d/data/datasets/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdb232259b9f8a584f8d7ca4db3ba8f8b08bd1d5
--- /dev/null
+++ b/vis4d/data/datasets/util.py
@@ -0,0 +1,367 @@
+"""Utility functions for datasets."""
+
+from __future__ import annotations
+
+import copy
+import itertools
+import os
+import pickle
+from collections.abc import Callable, Sequence
+from datetime import datetime
+from io import BytesIO
+from typing import Any
+
+import numpy as np
+import plyfile
+from PIL import Image, ImageOps
+from tabulate import tabulate
+from termcolor import colored
+from torch.utils.data import Dataset
+
+from vis4d.common.distributed import broadcast, rank_zero_only
+from vis4d.common.imports import OPENCV_AVAILABLE
+from vis4d.common.logging import rank_zero_info
+from vis4d.common.time import Timer
+from vis4d.common.typing import (
+ DictStrAny,
+ ListAny,
+ NDArrayFloat,
+ NDArrayI64,
+ NDArrayUI8,
+)
+
+from ..typing import DictData
+
+if OPENCV_AVAILABLE:
+ from cv2 import ( # pylint: disable=no-member,no-name-in-module
+ COLOR_BGR2RGB,
+ IMREAD_COLOR,
+ IMREAD_GRAYSCALE,
+ cvtColor,
+ imdecode,
+ )
+else:
+ raise ImportError("cv2 is not installed.")
+
+
+def im_decode(
+ im_bytes: bytes, mode: str = "RGB", backend: str = "PIL"
+) -> NDArrayUI8:
+ """Decode to image (numpy array, RGB) from bytes."""
+ assert mode in {
+ "BGR",
+ "RGB",
+ "L",
+ }, f"{mode} not supported for image decoding!"
+ if backend == "PIL":
+ pil_img_file = Image.open(BytesIO(bytearray(im_bytes)))
+ pil_img = ImageOps.exif_transpose(pil_img_file)
+ assert pil_img is not None, "Image could not be loaded!"
+ if pil_img.mode == "L": # pragma: no cover
+ if mode == "L":
+ img: NDArrayUI8 = np.array(pil_img)[..., None]
+ else:
+ # convert grayscale image to RGB
+ pil_img = pil_img.convert("RGB")
+ elif mode == "L": # pragma: no cover
+ raise ValueError("Cannot convert colorful image to grayscale!")
+ if mode == "BGR": # pragma: no cover
+ img = np.array(pil_img)[..., [2, 1, 0]]
+ elif mode == "RGB":
+ img = np.array(pil_img)
+ elif backend == "cv2": # pragma: no cover
+ if not OPENCV_AVAILABLE:
+ raise ImportError(
+ "Please install opencv-python to use cv2 backend!"
+ )
+ img_np: NDArrayUI8 = np.frombuffer(im_bytes, np.uint8)
+ img = imdecode( # type: ignore
+ img_np, IMREAD_GRAYSCALE if mode == "L" else IMREAD_COLOR
+ )
+ if mode == "RGB":
+ cvtColor(img, COLOR_BGR2RGB, img)
+ else:
+ raise NotImplementedError(f"Image backend {backend} not known!")
+ return img
+
+
+def ply_decode(ply_bytes: bytes, mode: str = "XYZI") -> NDArrayFloat:
+ """Decode to point clouds (numpy array) from bytes.
+
+ Args:
+ ply_bytes (bytes): The bytes of the ply file.
+ mode (str, optional): The point format of the ply file. If "XYZI", the
+ intensity channel will be included, otherwise only the XYZ
+ coordinates. Defaults to "XYZI".
+ """
+ assert mode in {
+ "XYZ",
+ "XYZI",
+ }, f"{mode} not supported for points decoding!"
+
+ plydata = plyfile.PlyData.read(BytesIO(bytearray(ply_bytes)))
+ num_points = plydata["vertex"].count
+ num_channels = 3 if mode == "XYZ" else 4
+ points = np.zeros((num_points, num_channels), dtype=np.float32)
+
+ points[:, 0] = plydata["vertex"].data["x"]
+ points[:, 1] = plydata["vertex"].data["y"]
+ points[:, 2] = plydata["vertex"].data["z"]
+ if mode == "XYZI":
+ points[:, 3] = plydata["vertex"].data["intensity"]
+ return points
+
+
+def npy_decode(npy_bytes: bytes, key: str | None = None) -> NDArrayFloat:
+ """Decode to numpy array from npy/npz file bytes."""
+ data = np.load(BytesIO(bytearray(npy_bytes)))
+ if key is not None:
+ data = data[key]
+ return data
+
+
+def filter_by_keys(
+ data_dict: DictData, keys_to_keep: Sequence[str]
+) -> DictData:
+ """Filter a dictionary by keys.
+
+ Args:
+ data_dict (DictData): The dictionary to filter.
+ keys_to_keep (list[str]): The keys to keep.
+
+ Returns:
+ DictData: The filtered dictionary.
+ """
+ return {key: data_dict[key] for key in keys_to_keep if key in data_dict}
+
+
+def get_used_data_groups(
+ data_groups: dict[str, list[str]], keys: list[str]
+) -> list[str]:
+ """Get the data groups that are used by the given keys.
+
+ Args:
+ data_groups (dict[str, list[str]]): The data groups.
+ keys (list[str]): The keys to check.
+
+ Returns:
+ list[str]: The used data groups.
+ """
+ used_groups = []
+ for group_name, group_keys in data_groups.items():
+ if not group_keys:
+ continue
+ if any(key in keys for key in group_keys):
+ used_groups.append(group_name)
+ return used_groups
+
+
+def to_onehot(categories: NDArrayI64, num_classes: int) -> NDArrayFloat:
+ """Transform integer categorical labels to onehot vectors.
+
+ Args:
+ categories (NDArrayI64): Integer categorical labels of shape (N, ).
+ num_classes (int): Number of classes.
+
+ Returns:
+ NDArrayFloat: Onehot vector of shape (N, num_classes).
+ """
+ _eye = np.eye(num_classes, dtype=np.float32)
+ return _eye[categories]
+
+
+class CacheMappingMixin:
+ """Caches a mapping for fast I/O and multi-processing.
+
+ This class provides functionality for caching a mapping from dataset index
+ requested by a call on __getitem__ to a dictionary that holds relevant
+ information for loading the sample in question from the disk.
+ Caching the mapping reduces startup time by loading the mapping instead of
+ re-computing it at every startup.
+
+ NOTE: Make sure your annotations file is up-to-date. Otherwise, the mapping
+ will be wrong and you will get wrong samples.
+ """
+
+ @rank_zero_only
+ def _load_mapping_data(
+ self,
+ generate_map_func: Callable[[], list[DictStrAny]],
+ cache_as_binary: bool,
+ cached_file_path: str | None,
+ ) -> ListAny:
+ """Load possibly cached mapping via generate_map_func.
+
+ Args:
+ generate_map_func (Callable[[], list[DictStrAny]]): The function
+ that generates the mapping.
+ cache_as_binary (bool): Whether to cache the mapping as binary.
+ cached_file_path (str | None): The path to the cached mapping file.
+ """
+ if cache_as_binary:
+ assert (
+ cached_file_path is not None
+ ), "cached_file_path must be set if cache_as_binary is True!"
+ if not os.path.exists(cached_file_path):
+ rank_zero_info(
+ f"Did not find {cached_file_path}, generating it..."
+ )
+ data = generate_map_func()
+ os.makedirs(os.path.dirname(cached_file_path), exist_ok=True)
+ with open(cached_file_path, "wb") as file:
+ file.write(pickle.dumps(data))
+ else:
+ dt = datetime.fromtimestamp(os.stat(cached_file_path).st_mtime)
+ rank_zero_info(
+ f"Found {cached_file_path} generated at {dt.isoformat()} "
+ + "and loading it..."
+ )
+ with open(cached_file_path, "rb") as file:
+ data = pickle.loads(file.read())
+ else:
+ rank_zero_info(f"Generating {self} data mapping...")
+ data = generate_map_func()
+ return data
+
+ def _load_mapping(
+ self,
+ generate_map_func: Callable[[], list[DictStrAny]],
+ filter_func: Callable[[ListAny], ListAny] = lambda x: x,
+ cache_as_binary: bool = False,
+ cached_file_path: str | None = None,
+ ) -> tuple[DatasetFromList, int]:
+ """Load cached mapping or generate if not exists.
+
+ Args:
+ generate_map_func (Callable[[], list[DictStrAny]]): The function
+ that generates the mapping.
+ filter_func (Callable[[ListAny], ListAny], optional): The function
+ that filters the mapping. Defaults to lambda x: x.
+ cache_as_binary (bool, optional): Whether to cache the mapping as
+ binary. Defaults to True.
+ cached_file_path (str | None, optional): The path to the cached
+ mapping file. Defaults to None.
+ """
+ timer = Timer()
+ dataset = self._load_mapping_data(
+ generate_map_func, cache_as_binary, cached_file_path
+ )
+ original_len = 0
+ if dataset is not None:
+ original_len = len(dataset)
+ dataset = filter_func(dataset)
+ dataset = DatasetFromList(dataset)
+ dataset = broadcast(dataset)
+ original_len = broadcast(original_len)
+ rank_zero_info(f"Loading {self} takes {timer.time():.2f} seconds.")
+ return dataset, original_len
+
+
+# reference:
+# https://github.com/facebookresearch/detectron2/blob/7f8f29deae278b75625872c8a0b00b74129446ac/detectron2/data/common.py#L109
+class DatasetFromList(Dataset): # type: ignore
+ """Wrap a list to a torch Dataset.
+
+ We serialize and wrap big python objects in a torch.Dataset due to a
+ memory leak when dealing with large python objects using multiple workers.
+ See: https://github.com/pytorch/pytorch/issues/13246
+ """
+
+ def __init__(
+ self, lst: ListAny, deepcopy: bool = False, serialize: bool = True
+ ):
+ """Creates an instance of the class.
+
+ Args:
+ lst: a list which contains elements to produce.
+ deepcopy: whether to deepcopy the element when producing it, s.t.
+ the result can be modified in place without affecting the source
+ in the list.
+ serialize: whether to hold memory using serialized objects. When
+ enabled, data loader workers can use shared RAM from master
+ process instead of making a copy.
+ """
+ self._copy = deepcopy
+ self._serialize = serialize
+
+ def _serialize(data: Any) -> NDArrayUI8: # type: ignore
+ """Serialize python object to numpy array."""
+ buffer = pickle.dumps(data, protocol=-1)
+ return np.frombuffer(buffer, dtype=np.uint8)
+
+ if self._serialize:
+ self._lst = [_serialize(x) for x in lst]
+ self._addr: NDArrayI64 = np.asarray(
+ [len(x) for x in self._lst], dtype=np.int64
+ )
+ self._addr = np.cumsum(self._addr)
+ self._lst = np.concatenate(self._lst) # type: ignore
+ else:
+ self._lst = lst # pragma: no cover
+
+ def __len__(self) -> int:
+ """Return len of list."""
+ if self._serialize:
+ return len(self._addr)
+ return len(self._lst) # pragma: no cover
+
+ def __getitem__(self, idx: int) -> Any: # type: ignore
+ """Return item of list at idx."""
+ if self._serialize:
+ start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
+ end_addr = self._addr[idx].item()
+ bytes_ = memoryview(self._lst[start_addr:end_addr]) # type: ignore
+ return pickle.loads(bytes_)
+ if self._copy: # pragma: no cover
+ return copy.deepcopy(self._lst[idx])
+
+ return self._lst[idx] # pragma: no cover
+
+
+def print_class_histogram(class_frequencies: dict[str, int]) -> None:
+ """Prints out given class frequencies."""
+ if len(class_frequencies) == 0: # pragma: no cover
+ return
+
+ class_names = list(class_frequencies.keys())
+ frequencies = list(class_frequencies.values())
+ num_classes = len(class_names)
+
+ n_cols = min(6, len(class_names) * 2)
+
+ def short_name(name: str) -> str:
+ """Make long class names shorter."""
+ if len(name) > 13:
+ return name[:11] + ".." # pragma: no cover
+ return name
+
+ data = list(
+ itertools.chain(
+ *[
+ [short_name(class_names[i]), int(v)]
+ for i, v in enumerate(frequencies)
+ ]
+ )
+ )
+ total_num_instances = sum(data[1::2]) # type: ignore
+ data.extend([None] * (n_cols - (len(data) % n_cols)))
+ if num_classes > 1:
+ data.extend(["total", total_num_instances])
+
+ table = tabulate(
+ itertools.zip_longest(*[data[i::n_cols] for i in range(n_cols)]),
+ headers=["category", "#instances"] * (n_cols // 2),
+ tablefmt="pipe",
+ numalign="left",
+ stralign="center",
+ )
+
+ rank_zero_info(
+ f"Distribution of instances among all {num_classes} categories:\n"
+ + colored(table, "cyan")
+ )
+
+
+def get_category_names(det_mapping: dict[str, int]) -> list[str]:
+ """Get category names from a mapping of category names to ids."""
+ return sorted(det_mapping, key=det_mapping.get) # type: ignore
diff --git a/vis4d/data/io/__init__.py b/vis4d/data/io/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f167ac5fc413a17d4248f900d3eb0545fb8fe2f3
--- /dev/null
+++ b/vis4d/data/io/__init__.py
@@ -0,0 +1,13 @@
+"""Init io module."""
+
+from .base import DataBackend
+from .file import FileBackend
+from .hdf5 import HDF5Backend
+from .zip import ZipBackend
+
+__all__ = [
+ "DataBackend",
+ "HDF5Backend",
+ "FileBackend",
+ "ZipBackend",
+]
diff --git a/vis4d/data/io/base.py b/vis4d/data/io/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e2c7fee9066ce6b575df96f412e2ff50ff31948
--- /dev/null
+++ b/vis4d/data/io/base.py
@@ -0,0 +1,84 @@
+"""Backends for the data types a dataset of interest is saved in.
+
+Those can be used to load data from diverse storage backends, e.g. from HDF5
+files which are more suitable for data centers. The naive backend is the
+FileBackend, which loads from / saves to file naively.
+"""
+
+from abc import abstractmethod
+from typing import Literal
+
+
+class DataBackend:
+ """Abstract class of storage backends.
+
+ All backends need to implement three functions: get(), set() and exists().
+ get() reads the file as a byte stream and set() writes a byte stream to a
+ file. exists() checks if a certain filepath exists.
+ """
+
+ @abstractmethod
+ def set(
+ self, filepath: str, content: bytes, mode: Literal["w", "a"]
+ ) -> None:
+ """Set the file content at the given filepath.
+
+ Args:
+ filepath (str): The filepath to store the data at.
+ content (bytes): The content to store as bytes.
+ mode (str): The mode to open the file in.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def get(self, filepath: str) -> bytes:
+ """Get the file content at the given filepath as bytes.
+
+ Args:
+ filepath (str): The filepath to retrieve the data from."
+
+ Returns:
+ bytes: The content of the file as bytes.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def exists(self, filepath: str) -> bool:
+ """Check if filepath exists.
+
+ Args:
+ filepath (str): The filepath to check.
+
+ Returns:
+ bool: True if the filepath exists, False otherwise.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def isfile(self, filepath: str) -> bool:
+ """Check if filepath is a file.
+
+ Args:
+ filepath (str): The filepath to check.
+
+ Returns:
+ bool: True if the filepath is a file, False otherwise.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def listdir(self, filepath: str) -> list[str]:
+ """List all files in a directory.
+
+ Args:
+ filepath (str): The directory to list.
+
+ Returns:
+ list[str]: A list of all files in the directory.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def close(self) -> None:
+ """Close all opened files in the backend."""
+ raise NotImplementedError
diff --git a/vis4d/data/io/file.py b/vis4d/data/io/file.py
new file mode 100644
index 0000000000000000000000000000000000000000..abfbe85dde75ea07fcdb33c62510234ba8fa68cc
--- /dev/null
+++ b/vis4d/data/io/file.py
@@ -0,0 +1,83 @@
+"""Standard backend for local files on a hard drive.
+
+This backends loads data from and saves data to the local hard drive.
+"""
+
+import os
+from typing import Literal
+
+from .base import DataBackend
+
+
+class FileBackend(DataBackend):
+ """Raw file from hard disk data backend."""
+
+ def isfile(self, filepath: str) -> bool:
+ """Check if filepath is a file.
+
+ Args:
+ filepath (str): Path to file.
+
+ Returns:
+ bool: True if file exists, False otherwise.
+ """
+ return os.path.isfile(filepath)
+
+ def listdir(self, filepath: str) -> list[str]:
+ """List all files in the directory.
+
+ Args:
+ filepath (str): Path to file.
+
+ Returns:
+ list[str]: List of all files in the directory.
+ """
+ return sorted(os.listdir(filepath))
+
+ def exists(self, filepath: str) -> bool:
+ """Check if filepath exists.
+
+ Args:
+ filepath (str): Path to file.
+
+ Returns:
+ bool: True if file exists, False otherwise.
+ """
+ return os.path.exists(filepath)
+
+ def set(
+ self, filepath: str, content: bytes, mode: Literal["w", "a"] = "w"
+ ) -> None:
+ """Write the file content to disk.
+
+ Args:
+ filepath (str): Path to file.
+ content (bytes): Content to write in bytes.
+ mode (Literal["w", "a"], optional): Overwrite or append mode.
+ Defaults to "w".
+ """
+ os.makedirs(os.path.dirname(filepath), exist_ok=True)
+ mode_binary: Literal["wb", "ab"] = "wb" if mode == "w" else "ab"
+ with open(filepath, mode_binary) as f:
+ f.write(content)
+
+ def get(self, filepath: str) -> bytes:
+ """Get file content as bytes.
+
+ Args:
+ filepath (str): Path to file.
+
+ Raises:
+ FileNotFoundError: If filepath does not exist.
+
+ Returns:
+ bytes: File content as bytes.
+ """
+ if not self.exists(filepath):
+ raise FileNotFoundError(f"File not found:" f" {filepath}")
+ with open(filepath, "rb") as f:
+ value_buf = f.read()
+ return value_buf
+
+ def close(self) -> None:
+ """No need to close manually."""
diff --git a/vis4d/data/io/hdf5.py b/vis4d/data/io/hdf5.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7eee55a75b830d3ed40055416cd56d8261eec0f
--- /dev/null
+++ b/vis4d/data/io/hdf5.py
@@ -0,0 +1,242 @@
+"""Hdf5 data backend.
+
+This backend works with filepaths pointing to valid HDF5 files. We assume that
+the given HDF5 file contains the whole dataset associated to this backend.
+"""
+
+from __future__ import annotations
+
+import os
+from typing import Literal
+
+import numpy as np
+
+from vis4d.common.imports import H5PY_AVAILABLE
+
+from .base import DataBackend
+
+if H5PY_AVAILABLE:
+ import h5py
+ from h5py import File
+else:
+ raise ImportError("Please install h5py to enable HDF5Backend.")
+
+
+class HDF5Backend(DataBackend):
+ """Backend for loading data from HDF5 files.
+
+ This backend works with filepaths pointing to valid HDF5 files. We assume
+ that the given HDF5 file contains the whole dataset associated to this
+ backend.
+
+ You can use the provided script at vis4d/data/datasets/to_hdf5.py to
+ convert your dataset to the expected hdf5 format before using this backend.
+ """
+
+ def __init__(self) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ if not H5PY_AVAILABLE:
+ raise ImportError("Please install h5py to enable HDF5Backend.")
+ self.db_cache: dict[str, File] = {}
+
+ @staticmethod
+ def _get_hdf5_path(
+ filepath: str, allow_omitted_ext: bool = True
+ ) -> tuple[str, list[str]]:
+ """Get .hdf5 path and keys from filepath.
+
+ Args:
+ filepath (str): The filepath to retrieve the data from.
+ Should have the following format: 'path/to/file.hdf5/key1/key2'
+ allow_omitted_ext (bool, optional): Whether to allow omitted
+ extension, in which case the backend will try to append
+ '.hdf5' to the filepath. Defaults to True.
+
+ Returns:
+ tuple[str, list[str]]: The .hdf5 path and the keys to retrieve.
+
+ Examples:
+ >>> HDF5Backend._get_hdf5_path("path/to/file.hdf5/key1/key2")
+ ("path/to/file.hdf5", ["key2", "key1"])
+ >>> HDF5Backend._get_hdf5_path("path/to/file/key1/key2", True)
+ ("path/to/file.hdf5", ["key2", "key1"]) # if file.hdf5 exists and
+ # is a valid hdf5 file
+ """
+ filepath_as_list = filepath.split("/")
+ keys = []
+
+ while True:
+ if filepath.endswith(".hdf5") or filepath == "":
+ break
+ if allow_omitted_ext and h5py.is_hdf5(filepath + ".hdf5"):
+ filepath = filepath + ".hdf5"
+ break
+ keys.append(filepath_as_list.pop())
+ filepath = "/".join(filepath_as_list)
+ return filepath, keys
+
+ def exists(self, filepath: str) -> bool:
+ """Check if filepath exists.
+
+ Args:
+ filepath (str): Path to file.
+
+ Returns:
+ bool: True if file exists, False otherwise.
+ """
+ hdf5_path, keys = self._get_hdf5_path(filepath)
+ if not os.path.exists(hdf5_path):
+ return False
+ value_buf = self._get_client(hdf5_path, "r")
+
+ while keys:
+ value_buf = value_buf.get(keys.pop())
+ if value_buf is None:
+ return False
+ return True
+
+ def set(
+ self, filepath: str, content: bytes, mode: Literal["w", "a"] = "a"
+ ) -> None:
+ """Set the file content.
+
+ Args:
+ filepath: path/to/file.hdf5/key1/key2/key3
+ content: Bytes to be written to entry key3 within group key2
+ within another group key1, for example.
+ mode: "w" to overwrite the file, "a" to append to it.
+
+ Raises:
+ ValueError: If filepath is not a valid .hdf5 file
+ """
+ if ".hdf5" not in filepath:
+ raise ValueError(f"{filepath} not a valid .hdf5 filepath!")
+ hdf5_path, keys_str = filepath.split(".hdf5")
+ key_list = keys_str.split("/")
+ file = self._get_client(hdf5_path + ".hdf5", mode)
+ if len(key_list) > 1:
+ group_str = "/".join(key_list[:-1])
+ if group_str == "":
+ group_str = "/"
+
+ group = file[group_str]
+ key = key_list[-1]
+ group.create_dataset(
+ key, data=np.frombuffer(content, dtype="uint8")
+ )
+
+ def _get_client(self, hdf5_path: str, mode: str) -> File:
+ """Get HDF5 client from path.
+
+ Args:
+ hdf5_path (str): Path to HDF5 file.
+ mode (str): Mode to open the file in.
+
+ Returns:
+ File: the hdf5 file.
+ """
+ if hdf5_path not in self.db_cache:
+ client = File(hdf5_path, mode, swmr=True, libver="latest")
+ self.db_cache[hdf5_path] = [client, mode]
+ else:
+ client, current_mode = self.db_cache[hdf5_path]
+ if current_mode != mode:
+ client.close()
+ client = File(hdf5_path, mode, swmr=True, libver="latest")
+ self.db_cache[hdf5_path] = [client, mode]
+ return client
+
+ def get(self, filepath: str) -> bytes:
+ """Get values according to the filepath as bytes.
+
+ Args:
+ filepath (str): The path to the file. It consists of an HDF5 path
+ together with the relative path inside it, e.g.: "/path/to/
+ file.hdf5/key/subkey/data". If no .hdf5 given inside filepath,
+ the function will search for the first .hdf5 file present in
+ the path, i.e. "/path/to/file/key/subkey/data" will also /key/
+ subkey/data from /path/to/file.hdf5.
+
+ Raises:
+ FileNotFoundError: If no suitable file exists.
+ ValueError: If key not found inside hdf5 file.
+
+ Returns:
+ bytes: The file content in bytes
+ """
+ hdf5_path, keys = self._get_hdf5_path(filepath)
+
+ if not os.path.exists(hdf5_path):
+ raise FileNotFoundError(
+ f"Corresponding HDF5 file not found:" f" {filepath}"
+ )
+ value_buf = self._get_client(hdf5_path, "r")
+ url = "/".join(reversed(keys))
+ while keys:
+ value_buf = value_buf.get(keys.pop())
+ if value_buf is None:
+ raise ValueError(f"Value {url} not found in {hdf5_path}!")
+
+ return bytes(value_buf[()])
+
+ def isfile(self, filepath: str) -> bool:
+ """Check if filepath is a file.
+
+ Args:
+ filepath (str): Path to file.
+
+ Raises:
+ FileNotFoundError: If no suitable file exists.
+ ValueError: If key not found inside hdf5 file.
+
+ Returns:
+ bool: True if file exists, False otherwise.
+ """
+ hdf5_path, keys = self._get_hdf5_path(filepath)
+ if not os.path.exists(hdf5_path):
+ raise FileNotFoundError(
+ f"Corresponding HDF5 file not found:" f" {filepath}"
+ )
+ value_buf = self._get_client(hdf5_path, "r")
+ url = "/".join(reversed(keys))
+ while keys:
+ value_buf = value_buf.get(keys.pop())
+ if value_buf is None:
+ raise ValueError(f"Value {url} not found in {hdf5_path}!")
+ return not isinstance(value_buf, h5py.Group)
+
+ def listdir(self, filepath: str) -> list[str]:
+ """List all files in the given directory.
+
+ Args:
+ filepath (str): Path to directory.
+
+ Raises:
+ FileNotFoundError: If no suitable file exists.
+ ValueError: If key not found inside hdf5 file.
+
+ Returns:
+ list[str]: List of files in the given directory.
+ """
+ hdf5_path, keys = self._get_hdf5_path(filepath)
+ if not os.path.exists(hdf5_path):
+ raise FileNotFoundError(
+ f"Corresponding HDF5 file not found:" f" {filepath}"
+ )
+ value_buf = self._get_client(hdf5_path, "r")
+ url = "/".join(reversed(keys))
+ while keys:
+ value_buf = value_buf.get(keys.pop())
+ if value_buf is None:
+ raise ValueError(f"Value {url} not found in {hdf5_path}!")
+ if not isinstance(value_buf, h5py.Group):
+ raise ValueError(f"Value {url} is not a group in {hdf5_path}!")
+
+ return sorted(list(value_buf.keys()))
+
+ def close(self) -> None:
+ """Close all opened HDF5 files."""
+ for client, _ in self.db_cache.values():
+ client.close()
+ self.db_cache.clear()
diff --git a/vis4d/data/io/to_hdf5.py b/vis4d/data/io/to_hdf5.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a2161a5ba4d73de179b80347a639de443022acc
--- /dev/null
+++ b/vis4d/data/io/to_hdf5.py
@@ -0,0 +1,76 @@
+"""Script to convert a dataset to hdf5 format."""
+
+from __future__ import annotations
+
+import argparse
+import os
+
+import numpy as np
+from tqdm import tqdm
+
+from vis4d.common.imports import H5PY_AVAILABLE
+
+if H5PY_AVAILABLE:
+ import h5py
+else:
+ raise ImportError("Please install h5py to enable HDF5Backend.")
+
+
+def convert_dataset(source_dir: str) -> None:
+ """Convert a dataset to HDF5 format.
+
+ This function converts an arbitary dictionary to an HDF5 file. The keys
+ inside the HDF5 file preserve the directory structure of the original.
+
+ As an example, if you convert "/path/to/dataset" to HDF5, the resulting
+ file will be: "/path/to/dataset.hdf5". The file "relative/path/to/file"
+ will be stored at "relative/path/to/file" inside /path/to/dataset.hdf5.
+
+ Args:
+ source_dir (str): The path to the dataset to convert.
+ """
+ if not os.path.exists(source_dir):
+ raise FileNotFoundError(f"No such file or directory: {source_dir}")
+
+ source_dir = os.path.join(source_dir, "") # must end with trailing slash
+ hdf5_path = source_dir.rstrip("/") + ".hdf5"
+ if os.path.exists(hdf5_path):
+ print(f"File {hdf5_path} already exists! Skipping {source_dir}")
+ return
+
+ print(f"Converting dataset at: {source_dir}")
+ hdf5_file = h5py.File(hdf5_path, mode="w")
+ sub_dirs = list(os.walk(source_dir))
+ file_count = sum(len(files) for (_, _, files) in sub_dirs)
+
+ with tqdm(total=file_count) as pbar:
+ for root, _, files in sub_dirs:
+ g_name = root.replace(source_dir, "")
+ g = hdf5_file.create_group(g_name) if g_name else hdf5_file
+ for f in files:
+ filepath = os.path.join(root, f)
+ if os.path.isfile(filepath):
+ with open(filepath, "rb") as fp:
+ file_content = fp.read()
+ g.create_dataset(
+ f, data=np.frombuffer(file_content, dtype="uint8")
+ )
+ pbar.update()
+
+ hdf5_file.close()
+ print("done.")
+
+
+if __name__ == "__main__": # pragma: no cover
+ parser = argparse.ArgumentParser(
+ description="Converts a dataset at the specified path to hdf5. The "
+ "local directory structure is preserved in the hdf5 file."
+ )
+ parser.add_argument(
+ "-p",
+ "--path",
+ required=True,
+ help="path to the root folder of a specific dataset to convert",
+ )
+ args = parser.parse_args()
+ convert_dataset(args.path)
diff --git a/vis4d/data/io/util.py b/vis4d/data/io/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d23dbc6bf8df17c54f730346c9e5d4231f3f89c
--- /dev/null
+++ b/vis4d/data/io/util.py
@@ -0,0 +1,21 @@
+"""Data I/O Utilities."""
+
+from __future__ import annotations
+
+import sys
+
+
+def str_decode(str_bytes: bytes, encoding: None | str = None) -> str:
+ """Decode to string from bytes.
+
+ Args:
+ str_bytes (bytes): Bytes to decode.
+ encoding (None | str): Encoding to use. Defaults to None which is
+ equivalent to sys.getdefaultencoding().
+
+ Returns:
+ str: Decoded string.
+ """
+ if encoding is None:
+ encoding = sys.getdefaultencoding()
+ return str_bytes.decode(encoding)
diff --git a/vis4d/data/io/zip.py b/vis4d/data/io/zip.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ea4f7cddc438ca8699bdc5e557081504956d647
--- /dev/null
+++ b/vis4d/data/io/zip.py
@@ -0,0 +1,206 @@
+"""Zip data backend.
+
+This backend works with filepaths pointing to valid Zip files. We assume that
+the given Zip file contains the whole dataset associated to this backend.
+"""
+
+from __future__ import annotations
+
+import os
+import zipfile
+from typing import Literal
+from zipfile import ZipFile
+
+from .base import DataBackend
+
+
+class ZipBackend(DataBackend):
+ """Backend for loading data from Zip files.
+
+ This backend works with filepaths pointing to valid Zip files. We assume
+ that the given Zip file contains the whole dataset associated to this
+ backend.
+ """
+
+ def __init__(self) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.db_cache: dict[str, tuple[ZipFile, str]] = {}
+
+ @staticmethod
+ def _get_zip_path(
+ filepath: str, allow_omitted_ext: bool = True
+ ) -> tuple[str, list[str]]:
+ """Get .zip path and keys from filepath.
+
+ Args:
+ filepath (str): The filepath to retrieve the data from.
+ Should have the following format: 'path/to/file.zip/key1/key2'
+ allow_omitted_ext (bool, optional): Whether to allow omitted
+ extension, in which case the backend will try to append
+ '.zip' to the filepath. Defaults to True.
+
+ Returns:
+ tuple[str, list[str]]: The .hdf5 path and the keys to retrieve.
+
+ Examples:
+ >>> _get_zip_path("path/to/file.zip/key1/key2")
+ ("path/to/file.zip", ["key2", "key1"])
+ >>> _get_zip_path("path/to/file/key1/key2", True)
+ ("path/to/file.zip", ["key2", "key1"]) # if file.hdf5 exists and
+ # is a valid hdf5 file
+ """
+ filepath_as_list = filepath.split("/")
+ keys = []
+
+ while True:
+ if filepath.endswith(".zip") or filepath == "":
+ break
+ if allow_omitted_ext and zipfile.is_zipfile(filepath + ".zip"):
+ filepath = filepath + ".zip"
+ break
+ keys.append(filepath_as_list.pop())
+ filepath = "/".join(filepath_as_list)
+ return filepath, keys
+
+ def exists(self, filepath: str) -> bool:
+ """Check if filepath exists.
+
+ Args:
+ filepath (str): Path to file.
+
+ Returns:
+ bool: True if file exists, False otherwise.
+ """
+ zip_path, keys = self._get_zip_path(filepath)
+ if not os.path.exists(zip_path):
+ return False
+ file = self._get_client(zip_path, "r")
+ url = "/".join(reversed(keys))
+ return url in file.namelist()
+
+ def set(
+ self, filepath: str, content: bytes, mode: Literal["w", "a"] = "w"
+ ) -> None:
+ """Write the file content to the zip file.
+
+ Args:
+ filepath: path/to/file.zip/key1/key2/key3
+ content: Bytes to be written to entry key3 within group key2
+ within another group key1, for example.
+ mode: Mode to open the file in. "w" for writing a file, "a" for
+ appending to existing file.
+
+ Raises:
+ ValueError: If filepath is not a valid .zip file
+ NotImplementedError: If the method is not implemented.
+ """
+ if ".zip" not in filepath:
+ raise ValueError(f"{filepath} not a valid .zip filepath!")
+
+ zip_path, keys = self._get_zip_path(filepath)
+ zip_file = self._get_client(zip_path, mode)
+ url = "/".join(reversed(keys))
+ zip_file.writestr(url, content)
+
+ def _get_client(
+ self, zip_path: str, mode: Literal["r", "w", "a", "x"]
+ ) -> ZipFile:
+ """Get Zip client from path.
+
+ Args:
+ zip_path (str): Path to Zip file.
+ mode (str): Mode to open the file in.
+
+ Returns:
+ ZipFile: the hdf5 file.
+ """
+ assert len(mode) == 1, "Mode must be a single character for zip file."
+ if zip_path not in self.db_cache:
+ os.makedirs(os.path.dirname(zip_path), exist_ok=True)
+ client = ZipFile(zip_path, mode)
+ self.db_cache[zip_path] = (client, mode)
+ else:
+ client, current_mode = self.db_cache[zip_path]
+ if current_mode != mode:
+ client.close()
+ client = ZipFile( # pylint:disable=consider-using-with
+ zip_path, mode
+ )
+ self.db_cache[zip_path] = (client, mode)
+ return client
+
+ def get(self, filepath: str) -> bytes:
+ """Get values according to the filepath as bytes.
+
+ Args:
+ filepath (str): The path to the file. It consists of an Zip path
+ together with the relative path inside it, e.g.: "/path/to/
+ file.zip/key/subkey/data". If no .zip given inside filepath,
+ the function will search for the first .zip file present in
+ the path, i.e. "/path/to/file/key/subkey/data" will also /key/
+ subkey/data from /path/to/file.zip.
+
+ Raises:
+ ZipFileNotFoundError: If no suitable file exists.
+ OSError: If the file cannot be opened.
+ ValueError: If key not found inside zip file.
+
+ Returns:
+ bytes: The file content in bytes
+ """
+ zip_path, keys = self._get_zip_path(filepath)
+
+ if not os.path.exists(zip_path):
+ raise FileNotFoundError(
+ f"Corresponding zip file not found:" f" {filepath}"
+ )
+ zip_file = self._get_client(zip_path, "r")
+ url = "/".join(reversed(keys))
+ try:
+ with zip_file.open(url) as zf:
+ content = zf.read()
+ except KeyError as e:
+ raise ValueError(f"Value '{url}' not found in {zip_path}!") from e
+ return bytes(content)
+
+ def listdir(self, filepath: str) -> list[str]:
+ """List all files in the given directory.
+
+ Args:
+ filepath (str): The path to the directory.
+
+ Returns:
+ list[str]: List of all files in the given directory.
+ """
+ zip_path, keys = self._get_zip_path(filepath)
+ zip_file = self._get_client(zip_path, "r")
+ url = "/".join(reversed(keys))
+ files = [
+ os.path.basename(key)
+ for key in zip_file.namelist()
+ if key.startswith(url) and os.path.basename(key) != ""
+ ]
+ return sorted(files)
+
+ def isfile(self, filepath: str) -> bool:
+ """Check if filepath is a file.
+
+ Args:
+ filepath (str): Path to file.
+
+ Returns:
+ bool: True if file exists, False otherwise.
+ """
+ zip_path, keys = self._get_zip_path(filepath)
+ if not os.path.exists(zip_path):
+ return False
+ zip_file = self._get_client(zip_path, "r")
+ url = "/".join(reversed(keys))
+ return url in zip_file.namelist()
+
+ def close(self) -> None:
+ """Close all opened Zip files."""
+ for client, _ in self.db_cache.values():
+ client.close()
+ self.db_cache = {}
diff --git a/vis4d/data/iterable.py b/vis4d/data/iterable.py
new file mode 100644
index 0000000000000000000000000000000000000000..6caca4656dc11088e8abdd25e19ae9c001dc097b
--- /dev/null
+++ b/vis4d/data/iterable.py
@@ -0,0 +1,100 @@
+"""Iterable datasets."""
+
+from __future__ import annotations
+
+import math
+from collections.abc import Callable, Iterator
+
+import numpy as np
+from torch.utils.data import Dataset, IterableDataset, get_worker_info
+
+from .typing import DictData
+
+
+class SubdividingIterableDataset(IterableDataset[DictData]):
+ """Subdivides a given dataset into smaller chunks.
+
+ This also adds a field called 'index' (DataKeys.index) to the data
+ struct in order to relate the data to the source index.
+
+ Example: Given a dataset (ds) that outputs tensors of the shape (10, 3):
+ sub_ds = SubdividingIterableDataset(ds, n_samples_per_batch = 5)
+
+ next(iter(sub_ds))['key'].shape
+ >> torch.Size([5, 3])
+
+ next(DataLoader(sub_ds, batch_size = 4))['key'].shape
+ >> torch.size([4,5,3])
+
+ Assuming the dataset returns two entries with shape (10,3):
+ [e['index'].item() for e in sub_ds]
+ >> [0,0,1,1]
+ """
+
+ def __init__(
+ self,
+ dataset: Dataset[DictData],
+ n_samples_per_batch: int,
+ preprocess_fn: Callable[
+ [list[DictData]], list[DictData]
+ ] = lambda x: x,
+ ) -> None:
+ """Creates a new Dataset.
+
+ Args:
+ dataset (Dataset): The dataset which should be subdivided.
+ n_samples_per_batch: How many samples each batch should contain.
+ The first dimension of dataset[0].shape must be divisible by
+ this number.
+ preprocess_fn (Callable[[list[DictData]], list[DictData]):
+ Preprocessing function. Defaults to identity.
+ """
+ super().__init__()
+ self.dataset = dataset
+ self.n_samples_per_batch = n_samples_per_batch
+ self.preprocess_fn = preprocess_fn
+
+ def __getitem__(self, index: int) -> DictData:
+ """Indexing is not supported for IterableDatasets."""
+ raise NotImplementedError("IterableDataset does not support indeing")
+
+ def __iter__(self) -> Iterator[DictData]:
+ """Iterates over the dataset, supporting distributed sampling."""
+ worker_info = get_worker_info()
+ if worker_info is None:
+ # not distributed
+ num_workers = 1
+ worker_id = 0
+ else: # pragma: no cover
+ num_workers = worker_info.num_workers
+ worker_id = worker_info.id
+
+ assert hasattr(
+ self.dataset, "__len__"
+ ), "Dataset must have __len__ in order to be subdivided."
+ n_samples = len(self.dataset)
+ for i in range(math.ceil(n_samples / num_workers)):
+ data_idx = i * num_workers + worker_id
+ if data_idx >= n_samples:
+ continue
+ data_sample = self.dataset[data_idx]
+
+ n_elements = list((data_sample.values()))[0].shape[0]
+ for idx in range(int(n_elements / self.n_samples_per_batch)):
+ # This is kind of ugly
+ # this field defines from which source the data was loaded
+ # (first entry, second entry, ...)
+ # this is required if we e.g. want to subdivide a room that is
+ # too big into equal sized chunks and stick them back together
+ # for visualizaton
+ out_data: DictData = {"source_index": np.ndarray([data_idx])}
+ for key in data_sample:
+ start_idx = idx * self.n_samples_per_batch
+ end_idx = (idx + 1) * self.n_samples_per_batch
+ if (len(data_sample[key])) < self.n_samples_per_batch:
+ out_data[key] = data_sample[key]
+ else:
+ out_data[key] = data_sample[key][
+ start_idx:end_idx, ...
+ ]
+ yield self.preprocess_fn([out_data])[0]
diff --git a/vis4d/data/loader.py b/vis4d/data/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ecf3c08953a7fe61507cc185176329629111586
--- /dev/null
+++ b/vis4d/data/loader.py
@@ -0,0 +1,281 @@
+"""Dataloader utility functions."""
+
+from __future__ import annotations
+
+import random
+import warnings
+from collections.abc import Callable, Sequence
+
+import numpy as np
+import torch
+from torch.utils.data import (
+ DataLoader,
+ Dataset,
+ RandomSampler,
+ SequentialSampler,
+)
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data.sampler import Sampler
+
+from vis4d.common.distributed import get_rank, get_world_size
+
+from .const import CommonKeys as K
+from .data_pipe import DataPipe
+from .datasets.base import VideoDataset
+from .samplers import AspectRatioBatchSampler, VideoInferenceSampler
+from .transforms import compose
+from .transforms.to_tensor import ToTensor
+from .typing import DictData, DictDataOrList
+
+DEFAULT_COLLATE_KEYS = (
+ K.seg_masks,
+ K.extrinsics,
+ K.intrinsics,
+ K.depth_maps,
+ K.optical_flows,
+ K.categories,
+)
+
+
+def default_collate(
+ batch: list[DictData],
+ collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
+ sensors: Sequence[str] | None = None,
+) -> DictData:
+ """Default batch collate.
+
+ It will concatenate images and stack seg_masks, extrinsics, intrinsics,
+ and depth_maps. Other keys will be put into a list.
+
+ Args:
+ batch (list[DictData]): List of data dicts.
+ collate_keys (Sequence[str]): Keys to be collated. Default is
+ DEFAULT_COLLATE_KEYS.
+ sensors (Sequence[str] | None): List of sensors to collate. If is not
+ None will raise an error. Default is None.
+
+ Returns:
+ DictData: Collated data dict.
+ """
+ assert sensors is None, "If specified sensors, use multi_sensor_collate."
+
+ data: DictData = {}
+ for key in batch[0]:
+ try:
+ if key == "transforms": # skip transform parameters
+ continue
+ if key in [K.images]:
+ data[key] = torch.cat([b[key] for b in batch])
+ elif key in collate_keys:
+ data[key] = torch.stack([b[key] for b in batch], 0)
+ else:
+ data[key] = [b[key] for b in batch]
+ except RuntimeError as e:
+ raise RuntimeError(f"Error collating key {key}") from e
+ return data
+
+
+def multi_sensor_collate(
+ batch: list[DictData],
+ collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
+ sensors: Sequence[str] | None = None,
+) -> DictData:
+ """Default multi-sensor batch collate.
+
+ Args:
+ batch (list[DictData]): List of data dicts. Each data dict contains
+ data from multiple sensors.
+ collate_keys (Sequence[str]): Keys to be collated. Default is
+ DEFAULT_COLLATE_KEYS.
+ sensors (Sequence[str] | None): List of sensors to collate. If None,
+ will raise an error. Default is None.
+
+ Returns:
+ DictData: Collated data dict.
+ """
+ assert (
+ sensors is not None
+ ), "If not specified sensors, use default_collate."
+
+ collated_batch: DictData = {}
+
+ # For each sensor, collate the batch. Other keys will be put into a list.
+ for key in batch[0]:
+ inner_batch = [b[key] for b in batch]
+ if key in sensors:
+ collated_batch[key] = default_collate(inner_batch, collate_keys)
+ else:
+ collated_batch[key] = inner_batch
+ return collated_batch
+
+
+def default_pipeline(data: list[DictData]) -> list[DictData]:
+ """Default data pipeline."""
+ return compose([ToTensor()])(data)
+
+
+def build_train_dataloader(
+ dataset: DataPipe,
+ samples_per_gpu: int = 1,
+ workers_per_gpu: int = 1,
+ batchprocess_fn: Callable[
+ [list[DictData]], list[DictData]
+ ] = default_pipeline,
+ collate_fn: Callable[
+ [list[DictData], Sequence[str]], DictData
+ ] = default_collate,
+ collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
+ sensors: Sequence[str] | None = None,
+ pin_memory: bool = True,
+ shuffle: bool | None = True,
+ drop_last: bool = False,
+ seed: int | None = None,
+ aspect_ratio_grouping: bool = False,
+ sampler: Sampler | None = None, # type: ignore
+ disable_subprocess_warning: bool = False,
+) -> DataLoader[DictDataOrList]:
+ """Build training dataloader."""
+ assert isinstance(dataset, DataPipe), "dataset must be a DataPipe"
+
+ def _collate_fn_single(data: list[DictData]) -> DictData:
+ """Collates data from single view dataset."""
+ return collate_fn( # type: ignore
+ batch=batchprocess_fn(data),
+ collate_keys=collate_keys,
+ sensors=sensors,
+ )
+
+ def _collate_fn_multi(data: list[list[DictData]]) -> list[DictData]:
+ """Collates data from multi view dataset."""
+ views = []
+ for view_idx in range(len(data[0])):
+ view = collate_fn( # type: ignore
+ batch=batchprocess_fn([d[view_idx] for d in data]),
+ collate_keys=collate_keys,
+ sensors=sensors,
+ )
+ views.append(view)
+ return views
+
+ def _worker_init_fn(worker_id: int) -> None:
+ """Will be called on each worker after seeding and before data loading.
+
+ Args:
+ worker_id (int): Worker id in [0, num_workers - 1].
+ """
+ if seed is not None:
+ # The seed of each worker equals to
+ # num_workers * rank + worker_id + user_seed
+ worker_seed = workers_per_gpu * get_rank() + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
+ torch.manual_seed(worker_seed)
+ if disable_subprocess_warning and worker_id != 0:
+ warnings.simplefilter("ignore")
+
+ if sampler is None:
+ if get_world_size() > 1:
+ assert isinstance(
+ shuffle, bool
+ ), "When using distributed training, shuffle must be a boolean."
+ sampler = DistributedSampler(
+ dataset, shuffle=shuffle, drop_last=drop_last
+ )
+ shuffle = False
+ drop_last = False
+ elif shuffle:
+ sampler = RandomSampler(dataset)
+ shuffle = False
+ else:
+ sampler = SequentialSampler(dataset)
+
+ batch_sampler = None
+ if aspect_ratio_grouping:
+ batch_sampler = AspectRatioBatchSampler(
+ sampler, batch_size=samples_per_gpu, drop_last=drop_last
+ )
+ samples_per_gpu = 1
+ shuffle = None
+ drop_last = False
+ sampler = None
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=samples_per_gpu,
+ num_workers=workers_per_gpu,
+ collate_fn=(
+ _collate_fn_multi if dataset.has_reference else _collate_fn_single
+ ),
+ sampler=sampler,
+ batch_sampler=batch_sampler,
+ worker_init_fn=_worker_init_fn,
+ persistent_workers=workers_per_gpu > 0,
+ pin_memory=pin_memory,
+ shuffle=shuffle,
+ drop_last=drop_last,
+ )
+ return dataloader
+
+
+def build_inference_dataloaders(
+ datasets: Dataset[DictDataOrList] | list[Dataset[DictDataOrList]],
+ samples_per_gpu: int = 1,
+ workers_per_gpu: int = 1,
+ video_based_inference: bool = False,
+ batchprocess_fn: Callable[
+ [list[DictData]], list[DictData]
+ ] = default_pipeline,
+ collate_fn: Callable[
+ [list[DictData], Sequence[str]], DictData
+ ] = default_collate,
+ collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
+ sensors: Sequence[str] | None = None,
+) -> list[DataLoader[DictDataOrList]]:
+ """Build dataloaders for test / predict."""
+
+ def _collate_fn(data: list[DictData]) -> DictData:
+ """Collates data for inference."""
+ return collate_fn( # type: ignore
+ batch=batchprocess_fn(data),
+ collate_keys=collate_keys,
+ sensors=sensors,
+ )
+
+ if isinstance(datasets, Dataset):
+ datasets_ = [datasets]
+ else:
+ datasets_ = datasets
+
+ dataloaders = []
+ for dataset in datasets_:
+ sampler: DistributedSampler[list[int]] | None
+ if get_world_size() > 1:
+ if video_based_inference:
+ if isinstance(dataset, DataPipe):
+ assert (
+ len(dataset.datasets) == 1
+ ), "DDP Vdieo Inference only support a single dataset."
+ current_dataset = dataset.datasets[0]
+ else:
+ current_dataset = dataset
+
+ assert isinstance(
+ current_dataset, VideoDataset
+ ), "Video based inference needs a VideoDataset."
+ sampler = VideoInferenceSampler(current_dataset)
+ else:
+ sampler = DistributedSampler(dataset)
+ else:
+ sampler = None
+
+ test_dataloader = DataLoader(
+ dataset,
+ batch_size=samples_per_gpu,
+ num_workers=workers_per_gpu,
+ sampler=sampler,
+ shuffle=False,
+ collate_fn=_collate_fn,
+ persistent_workers=workers_per_gpu > 0,
+ )
+ dataloaders.append(test_dataloader)
+ return dataloaders
diff --git a/vis4d/data/reference.py b/vis4d/data/reference.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b27f7862f7b8c3e21de1b770363f22be67422ac
--- /dev/null
+++ b/vis4d/data/reference.py
@@ -0,0 +1,239 @@
+"""Reference View Sampling.
+
+These Classes sample reference views from a dataset that contains videos.
+This is usually used when a model needs multiple samples of a video during
+training.
+"""
+
+from __future__ import annotations
+
+from abc import abstractmethod
+from typing import Callable, List
+
+import numpy as np
+from torch.utils.data import Dataset
+
+from .const import CommonKeys as K
+from .datasets.base import VideoDataset
+from .typing import DictData
+
+SortingFunc = Callable[[DictData, list[DictData]], List[DictData]]
+
+
+def sort_key_first(
+ cur_sample: DictData, ref_data: list[DictData]
+) -> list[DictData]:
+ """Sort views as key first."""
+ return [cur_sample, *ref_data]
+
+
+def sort_temporal(
+ cur_sample: DictData, ref_data: list[DictData]
+) -> list[DictData]:
+ """Sort views temporally."""
+ return sorted([cur_sample, *ref_data], key=lambda x: x[K.frame_ids])
+
+
+class ReferenceViewSampler:
+ """Base reference view sampler."""
+
+ def __init__(self, num_ref_samples: int) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ num_ref_samples (int): Number of reference views to sample.
+ """
+ self.num_ref_samples = num_ref_samples
+
+ @abstractmethod
+ def __call__(
+ self,
+ key_dataset_index: int,
+ indices_in_video: list[int],
+ frame_ids: list[int],
+ ) -> list[int]:
+ """Sample num_ref_samples reference view indices.
+
+ Args:
+ key_index (int): Index of key view in the video.
+ indices_in_video (list[int]): All dataset indices in the video.
+ frame_ids (list[int]): Frame ids of all views in the video.
+
+ Returns:
+ list[int]: dataset indices of reference views.
+ """
+ raise NotImplementedError
+
+
+class SequentialViewSampler(ReferenceViewSampler):
+ """Sequential View Sampler."""
+
+ def __call__(
+ self,
+ key_dataset_index: int,
+ indices_in_video: list[int],
+ frame_ids: list[int],
+ ) -> list[int]:
+ """Sample sequential reference views."""
+ assert len(frame_ids) >= self.num_ref_samples + 1
+
+ key_index = indices_in_video.index(key_dataset_index)
+
+ right = key_index + 1 + self.num_ref_samples
+ if right <= len(indices_in_video):
+ ref_dataset_indices = indices_in_video[key_index + 1 : right]
+ else:
+ left = key_index - (right - len(indices_in_video))
+ ref_dataset_indices = (
+ indices_in_video[left:key_index]
+ + indices_in_video[key_index + 1 :]
+ )
+ return ref_dataset_indices
+
+
+class UniformViewSampler(ReferenceViewSampler):
+ """View Sampler that chooses reference views uniform at random."""
+
+ def __init__(self, scope: int, num_ref_samples: int) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ scope (int): Define scope of neighborhood to key view to sample
+ from.
+ num_ref_samples (int): Number of reference views to sample.
+ """
+ super().__init__(num_ref_samples)
+ if scope != 0 and scope < num_ref_samples // 2:
+ raise ValueError("Scope must be higher than num_ref_imgs / 2.")
+ self.scope = scope
+
+ def _get_valid_indices(
+ self, key_index: int, indices_in_video: list[int], frame_ids: list[int]
+ ) -> list[int]:
+ """Get valid indices in video."""
+ key_fid = frame_ids[key_index]
+ min_fid = max(0, key_fid - self.scope)
+ max_fid = min(key_fid + self.scope, frame_ids[-1])
+
+ return [
+ ind
+ for i, ind in enumerate(indices_in_video)
+ if min_fid <= frame_ids[i] <= max_fid and i != key_index
+ ]
+
+ def __call__(
+ self,
+ key_dataset_index: int,
+ indices_in_video: list[int],
+ frame_ids: list[int],
+ ) -> list[int]:
+ """Uniformly sample reference views."""
+ if self.scope > 0:
+ key_index = indices_in_video.index(key_dataset_index)
+
+ valid_indices = self._get_valid_indices(
+ key_index, indices_in_video, frame_ids
+ )
+
+ if len(valid_indices) > 0:
+ assert len(valid_indices) >= self.num_ref_samples
+ return np.random.choice(
+ valid_indices, self.num_ref_samples, replace=False
+ ).tolist()
+
+ return [key_dataset_index] * self.num_ref_samples
+
+
+class MultiViewDataset(Dataset[list[DictData]]):
+ """Dataset that samples reference views from a video dataset."""
+
+ def __init__(
+ self,
+ dataset: VideoDataset,
+ sampler: ReferenceViewSampler,
+ sort_fn: SortingFunc = sort_key_first,
+ num_retry: int = 3,
+ match_key: str = K.boxes2d_track_ids,
+ skip_nomatch_samples: bool = False,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ dataset (Dataset): Video dataset to sample from.
+ sampler (ReferenceViewSampler): Sampler that samples reference
+ views.
+ sort_fn (SortingFunc, optional): Function that sorts key and
+ reference views. Defaults to sort_key_first.
+ num_retry (int, optional): Number of retries if no match is found.
+ Defaults to 3.
+ match_key (str, optional): Key to match reference views with key
+ view. Defaults to K.boxes2d_track_ids.
+ skip_nomatch_samples (bool, optional): Whether to skip samples
+ where no match is found. Defaults to False.
+ """
+ self.dataset = dataset
+ self.sampler = sampler
+ self.sort_fn = sort_fn
+ self.num_retry = num_retry
+ self.match_key = match_key
+ self.skip_nomatch_samples = skip_nomatch_samples
+
+ def has_matches(
+ self, key_data: DictData, ref_data: list[DictData]
+ ) -> bool:
+ """Check if key / ref data have matches."""
+ key_target = key_data[self.match_key]
+ for ref_view in ref_data:
+ ref_target = ref_view[self.match_key]
+ match = np.equal(
+ np.expand_dims(key_target, axis=1), ref_target[None]
+ )
+ if match.any():
+ return True
+ return False # pragma: no cover
+
+ def __len__(self) -> int:
+ """Get length of dataset."""
+ return len(self.dataset)
+
+ def get_ref_data(self, ref_indices: list[int]) -> list[DictData]:
+ """Get reference data from dataset."""
+ ref_data = []
+ for ref_index in ref_indices:
+ ref_sample = self.dataset[ref_index]
+ ref_sample["keyframes"] = False
+ ref_data.append(ref_sample)
+
+ assert self.sampler.num_ref_samples == len(ref_data)
+ return ref_data
+
+ def __getitem__(self, index: int) -> list[DictData]:
+ """Get item from dataset."""
+ cur_sample = self.dataset[index]
+ cur_sample["keyframes"] = True
+
+ indices_in_video = self.dataset.video_mapping["video_to_indices"][
+ cur_sample[K.sequence_names]
+ ]
+ frame_ids = self.dataset.video_mapping["video_to_frame_ids"][
+ cur_sample[K.sequence_names]
+ ]
+
+ if self.sampler.num_ref_samples > 0:
+ for _ in range(self.num_retry):
+ ref_indices = self.sampler(index, indices_in_video, frame_ids)
+
+ ref_data = self.get_ref_data(ref_indices)
+
+ if self.skip_nomatch_samples and not (
+ self.has_matches(cur_sample, ref_data)
+ ):
+ continue
+
+ return self.sort_fn(cur_sample, ref_data)
+
+ ref_indices = [index] * self.sampler.num_ref_samples
+ ref_data = self.get_ref_data(ref_indices)
+ return [cur_sample, *ref_data]
+
+ return [cur_sample]
diff --git a/vis4d/data/resample.py b/vis4d/data/resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..e31b0c6361f6e78c30cd4eb20f135e903151dad7
--- /dev/null
+++ b/vis4d/data/resample.py
@@ -0,0 +1,77 @@
+"""Resample index to recover the original dataset length."""
+
+from __future__ import annotations
+
+import numpy as np
+from torch.utils.data import Dataset
+
+from vis4d.common.logging import rank_zero_info
+
+from .reference import MultiViewDataset
+from .typing import DictDataOrList
+
+
+class ResampleDataset(Dataset[DictDataOrList]):
+ """Dataset wrapper to recover the filtered samples through resampling.
+
+ In MMEngine and Detectron2, the dataset might return None when the sample
+ has no valid annotations. They will resample the index and try to get the
+ valid training data. The length of dataset will be different depends on
+ whether filtering the empty samples first.
+
+ This dataset wrapper resamples the index to recover the original dataset
+ length (before filter empty frames) to align with the other codebases'
+ implementation.
+
+ https://github.com/open-mmlab/mmengine/blob/main/mmengine/dataset/base_dataset.py#L411
+ https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/common.py#L96
+ """
+
+ def __init__(self, dataset: Dataset[DictDataOrList]) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.dataset = dataset
+ self.has_reference = isinstance(dataset, MultiViewDataset)
+ self.valid_len = len(dataset) # type: ignore
+
+ # Handle the case that dataset is already wrapped.
+ if hasattr(self.dataset, "dataset"):
+ _dataset = self.dataset.dataset
+ else:
+ _dataset = self.dataset
+
+ assert hasattr(_dataset, "original_len"), (
+ "The dataset must have the attribute `original_len` to resample "
+ + "index to recover the original length."
+ )
+ self.original_len = _dataset.original_len
+
+ rank_zero_info(
+ f"Recover {_dataset} to {self.original_len} samples by resampling "
+ + "index."
+ )
+
+ def __len__(self) -> int:
+ """Return the length of dataset.
+
+ Returns:
+ int: Length of dataset.
+ """
+ return self.original_len
+
+ def __getitem__(self, idx: int) -> DictDataOrList:
+ """Get original dataset idx according to the given index.
+
+ Resample index to recover the original dataset length.
+
+ Args:
+ idx (int): The index of original dataset length.
+
+ Returns:
+ DictDataOrList: Data of the corresponding index.
+ """
+ if idx < self.valid_len:
+ index = idx
+ else:
+ index = np.random.randint(0, self.valid_len)
+ return self.dataset[index]
diff --git a/vis4d/data/samplers.py b/vis4d/data/samplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7401fcc2db25c7c5f356278674e5e8949e0f443
--- /dev/null
+++ b/vis4d/data/samplers.py
@@ -0,0 +1,156 @@
+"""Vis4D data samplers."""
+
+from __future__ import annotations
+
+from collections.abc import Iterator
+
+import numpy as np
+from torch.utils.data import Dataset
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data.sampler import BatchSampler, Sampler
+
+from vis4d.data.const import CommonKeys as K
+
+from .datasets.base import VideoDataset
+from .typing import DictDataOrList
+
+
+class VideoInferenceSampler(
+ DistributedSampler[list[int]]
+): # pragma: no cover # No unittest for distributed setting.
+ """Produce sequence ordered indices for inference across all workers.
+
+ Inference needs to run on the __exact__ set of sequences and their
+ respective samples, therefore if the sequences are not divisible by the
+ number of workers or if they have different length, the sampler
+ produces different number of samples on different workers.
+ """
+
+ def __init__(
+ self,
+ dataset: Dataset[DictDataOrList],
+ num_replicas: None | int = None,
+ rank: None | int = None,
+ shuffle: bool = True,
+ seed: int = 0,
+ drop_last: bool = False,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ dataset (Dataset): Inference dataset.
+ num_replicas (int, optional): Number of processes participating in
+ distributed training. By default, :attr:`world_size` is
+ retrieved from the current distributed group.
+ rank (int, optional): Rank of the current process within
+ :attr:`num_replicas`. By default, :attr:`rank` is retrieved
+ from the current distributed group.
+ shuffle (bool, optional): If ``True`` (default), sampler will
+ shuffle the indices.
+ seed (int, optional): random seed used to shuffle the sampler if
+ :attr:`shuffle=True`. This number should be identical across
+ all processes in the distributed group. Default: ``0``.
+ drop_last (bool, optional): if ``True``, then the sampler will drop
+ the tail of the data to make it evenly divisible across the
+ number of replicas. If ``False``, the sampler will add extra
+ indices to make the data evenly divisible across the replicas.
+ Default: ``False``.
+ """
+ super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
+ assert isinstance(dataset, VideoDataset)
+ self.sequences = list(dataset.video_mapping["video_to_indices"])
+ self.num_seqs = len(self.sequences)
+ assert self.num_seqs >= self.num_replicas, (
+ f"Number of sequences ({self.num_seqs}) must be greater or "
+ f"equal to number of replicas ({self.num_replicas})!"
+ )
+ chunks = np.array_split(self.sequences, self.num_replicas)
+ self._local_seqs = chunks[self.rank]
+ self._local_idcs: list[int] = []
+ for seq in self._local_seqs:
+ self._local_idcs.extend(
+ dataset.video_mapping["video_to_indices"][seq]
+ )
+
+ def __iter__(self) -> Iterator[list[int]]:
+ """Iteration method."""
+ return iter(self._local_idcs) # type: ignore
+
+ def __len__(self) -> int:
+ """Return length of sampler instance."""
+ return len(self._local_idcs)
+
+
+class AspectRatioBatchSampler(BatchSampler):
+ """A sampler wrapper for grouping images with similar aspect ratio.
+
+ Moidified from:
+ https://github.com/open-mmlab/mmdetection/blob/main/mmdet/datasets/samplers/batch_sampler.py
+
+ Args:
+ sampler (Sampler): Base sampler.
+ batch_size (int): Size of mini-batch.
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
+ its size would be less than ``batch_size``.
+ """
+
+ def __init__(
+ self,
+ sampler: Sampler, # type: ignore
+ batch_size: int,
+ drop_last: bool = False,
+ ) -> None:
+ """Creates an instance of the class."""
+ if not isinstance(sampler, Sampler):
+ raise TypeError(
+ "sampler should be an instance of ``Sampler``, "
+ f"but got {sampler}"
+ )
+
+ super().__init__(sampler, batch_size, drop_last)
+
+ # two groups for w < h and w >= h
+ self._aspect_ratio_buckets: list[list[int]] = [[] for _ in range(2)]
+
+ def __iter__(self) -> Iterator[list[int]]:
+ """Iteration method."""
+ for idx in self.sampler:
+ if hasattr(self.sampler, "dataset"):
+ data_dict = self.sampler.dataset[idx]
+ elif hasattr(self.sampler, "data_source"):
+ data_dict = self.sampler.data_source[idx]
+ else:
+ raise ValueError(
+ "sampler should have dataset or data_source attribute"
+ )
+ height, width = data_dict[K.input_hw]
+ bucket_id = 0 if width < height else 1
+ bucket = self._aspect_ratio_buckets[bucket_id]
+ bucket.append(idx)
+ # yield a batch of indices in the same aspect ratio group
+ if len(bucket) == self.batch_size:
+ yield bucket[:]
+ del bucket[:]
+
+ # yield the rest data and reset the bucket
+ left_data = (
+ self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[1]
+ )
+ self._aspect_ratio_buckets = [[] for _ in range(2)]
+ while len(left_data) > 0:
+ if len(left_data) <= self.batch_size:
+ if not self.drop_last:
+ yield left_data[:]
+ left_data = []
+ else:
+ yield left_data[: self.batch_size]
+ left_data = left_data[self.batch_size :]
+
+ def __len__(self) -> int:
+ """Return length of sampler instance."""
+ if self.drop_last:
+ return len(self.sampler) // self.batch_size # type: ignore
+
+ return (
+ len(self.sampler) + self.batch_size - 1 # type: ignore
+ ) // self.batch_size
diff --git a/vis4d/data/transforms/__init__.py b/vis4d/data/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..59f70b93c8bacdc7e7f439214c81236985465264
--- /dev/null
+++ b/vis4d/data/transforms/__init__.py
@@ -0,0 +1,5 @@
+"""Transforms."""
+
+from .base import RandomApply, Transform, compose
+
+__all__ = ["Transform", "RandomApply", "compose"]
diff --git a/vis4d/data/transforms/__pycache__/__init__.cpython-311.pyc b/vis4d/data/transforms/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2f97d8d001485ca71033275df79285446c239c2f
Binary files /dev/null and b/vis4d/data/transforms/__pycache__/__init__.cpython-311.pyc differ
diff --git a/vis4d/data/transforms/__pycache__/base.cpython-311.pyc b/vis4d/data/transforms/__pycache__/base.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b999e4559784cda0663ad91e77b6d31c1731273b
Binary files /dev/null and b/vis4d/data/transforms/__pycache__/base.cpython-311.pyc differ
diff --git a/vis4d/data/transforms/__pycache__/normalize.cpython-311.pyc b/vis4d/data/transforms/__pycache__/normalize.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e9f57f46b17f4f29d1a15a9974ee0786c28cbfe
Binary files /dev/null and b/vis4d/data/transforms/__pycache__/normalize.cpython-311.pyc differ
diff --git a/vis4d/data/transforms/__pycache__/pad.cpython-311.pyc b/vis4d/data/transforms/__pycache__/pad.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..51e47e25b5f4feae685eb67e8c8e0bc7da04afb6
Binary files /dev/null and b/vis4d/data/transforms/__pycache__/pad.cpython-311.pyc differ
diff --git a/vis4d/data/transforms/__pycache__/resize.cpython-311.pyc b/vis4d/data/transforms/__pycache__/resize.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ffd0b059ad4bf3ce2e6c060d786f91e749f1cd32
Binary files /dev/null and b/vis4d/data/transforms/__pycache__/resize.cpython-311.pyc differ
diff --git a/vis4d/data/transforms/__pycache__/to_tensor.cpython-311.pyc b/vis4d/data/transforms/__pycache__/to_tensor.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..646112b7bdd3340b1f0b820c33ea7920c7558605
Binary files /dev/null and b/vis4d/data/transforms/__pycache__/to_tensor.cpython-311.pyc differ
diff --git a/vis4d/data/transforms/affine.py b/vis4d/data/transforms/affine.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4c1bacc109b220fee73fa8cac2b083b93e7ff70
--- /dev/null
+++ b/vis4d/data/transforms/affine.py
@@ -0,0 +1,314 @@
+"""Affine transformation.
+
+Modified from mmdetection (https://github.com/open-mmlab/mmdetection).
+"""
+
+from __future__ import annotations
+
+import math
+import random
+from typing import TypedDict
+
+import numpy as np
+import torch
+
+from vis4d.common.imports import OPENCV_AVAILABLE
+from vis4d.common.typing import NDArrayF32, NDArrayI64
+from vis4d.data.const import CommonKeys as K
+from vis4d.op.box.box2d import bbox_clip, bbox_project
+
+from .base import Transform
+from .crop import _get_keep_mask
+
+if OPENCV_AVAILABLE:
+ import cv2
+else:
+ raise ImportError("Please install opencv-python to use this module.")
+
+
+class AffineParam(TypedDict):
+ """Parameters for Affine."""
+
+ warp_matrix: NDArrayF32
+ height: int
+ width: int
+
+
+def get_rotation_matrix(rotate_degrees: float) -> NDArrayF32:
+ """Generate rotation matrix.
+
+ Args:
+ rotate_degrees (float): Rotation degrees.
+ """
+ radian = math.radians(rotate_degrees)
+ rotation_matrix = np.array(
+ [
+ [np.cos(radian), -np.sin(radian), 0.0],
+ [np.sin(radian), np.cos(radian), 0.0],
+ [0.0, 0.0, 1.0],
+ ],
+ dtype=np.float32,
+ )
+ return rotation_matrix
+
+
+def get_scaling_matrix(scale_ratio: float) -> NDArrayF32:
+ """Generate scaling matrix.
+
+ Args:
+ scale_ratio (float): Scale ratio.
+ """
+ scaling_matrix = np.array(
+ [[scale_ratio, 0.0, 0.0], [0.0, scale_ratio, 0.0], [0.0, 0.0, 1.0]],
+ dtype=np.float32,
+ )
+ return scaling_matrix
+
+
+def get_shear_matrix(
+ x_shear_degrees: float, y_shear_degrees: float
+) -> NDArrayF32:
+ """Generate shear matrix.
+
+ Args:
+ x_shear_degrees (float): X shear degrees.
+ y_shear_degrees (float): Y shear degrees.
+ """
+ x_radian = math.radians(x_shear_degrees)
+ y_radian = math.radians(y_shear_degrees)
+ shear_matrix = np.array(
+ [
+ [1, np.tan(x_radian), 0.0],
+ [np.tan(y_radian), 1, 0.0],
+ [0.0, 0.0, 1.0],
+ ],
+ dtype=np.float32,
+ )
+ return shear_matrix
+
+
+def get_translation_matrix(x_trans: float, y_trans: float) -> NDArrayF32:
+ """Generate translation matrix.
+
+ Args:
+ x_trans (float): X translation.
+ y_trans (float): Y translation.
+ """
+ translation_matrix = np.array(
+ [[1, 0.0, x_trans], [0.0, 1, y_trans], [0.0, 0.0, 1.0]],
+ dtype=np.float32,
+ )
+ return translation_matrix
+
+
+@Transform(K.input_hw, ["transforms.affine"])
+class GenAffineParameters:
+ """Random affine transform data augmentation.
+
+ This operation randomly generates affine transform matrix which including
+ rotation, translation, shear, and scaling transforms.
+ """
+
+ def __init__(
+ self,
+ max_rotate_degree: float = 10.0,
+ max_translate_ratio: float = 0.1,
+ scaling_ratio_range: tuple[float, float] = (0.5, 1.5),
+ max_shear_degree: float = 2.0,
+ border: tuple[int, int] = (0, 0),
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ max_rotate_degree (float): Maximum degrees of rotation transform.
+ Defaults to 10.
+ max_translate_ratio (float): Maximum ratio of translation.
+ Defaults to 0.1.
+ scaling_ratio_range (tuple[float]): Min and max ratio of
+ scaling transform. Defaults to (0.5, 1.5).
+ max_shear_degree (float): Maximum degrees of shear
+ transform. Defaults to 2.
+ border (tuple[int, int]): Distance from height and width sides of
+ input image to adjust output shape. Only used in mosaic
+ dataset. Defaults to (0, 0).
+ """
+ assert 0 <= max_translate_ratio <= 1
+ assert scaling_ratio_range[0] <= scaling_ratio_range[1]
+ assert scaling_ratio_range[0] > 0
+ self.max_rotate_degree = max_rotate_degree
+ self.max_translate_ratio = max_translate_ratio
+ self.scaling_ratio_range = scaling_ratio_range
+ self.max_shear_degree = max_shear_degree
+ self.border = border
+
+ def _get_random_homography_matrix(
+ self, height: int, width: int
+ ) -> NDArrayF32:
+ """Generate random homography matrix."""
+ # Rotation
+ rotation_degree = random.uniform(
+ -self.max_rotate_degree, self.max_rotate_degree
+ )
+ rotation_matrix = get_rotation_matrix(rotation_degree)
+
+ # Scaling
+ scaling_ratio = random.uniform(
+ self.scaling_ratio_range[0], self.scaling_ratio_range[1]
+ )
+ scaling_matrix = get_scaling_matrix(scaling_ratio)
+
+ # Shear
+ x_degree = random.uniform(
+ -self.max_shear_degree, self.max_shear_degree
+ )
+ y_degree = random.uniform(
+ -self.max_shear_degree, self.max_shear_degree
+ )
+ shear_matrix = get_shear_matrix(x_degree, y_degree)
+
+ # Translation
+ trans_x = (
+ random.uniform(-self.max_translate_ratio, self.max_translate_ratio)
+ * width
+ )
+ trans_y = (
+ random.uniform(-self.max_translate_ratio, self.max_translate_ratio)
+ * height
+ )
+ translate_matrix = get_translation_matrix(trans_x, trans_y)
+
+ warp_matrix = (
+ translate_matrix @ shear_matrix @ rotation_matrix @ scaling_matrix
+ )
+ return warp_matrix
+
+ def __call__(self, input_hw: list[tuple[int, int]]) -> list[AffineParam]:
+ """Compute the parameters and put them in the data dict."""
+ img_shape = input_hw[0]
+ height = img_shape[0] + self.border[0] * 2
+ width = img_shape[1] + self.border[1] * 2
+
+ warp_matrix = self._get_random_homography_matrix(height, width)
+ return [
+ AffineParam(warp_matrix=warp_matrix, height=height, width=width)
+ ] * len(input_hw)
+
+
+@Transform(
+ [
+ K.images,
+ "transforms.affine.warp_matrix",
+ "transforms.affine.height",
+ "transforms.affine.width",
+ ],
+ [K.images, K.input_hw],
+)
+class AffineImages:
+ """Affine Images."""
+
+ def __init__(
+ self,
+ border_val: tuple[int, int, int] = (114, 114, 114),
+ as_int: bool = False,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ border_val (tuple[int, int, int]): Border padding values of 3
+ channels. Defaults to (114, 114, 114).
+ as_int (bool): Whether to convert the image to int. Defaults to
+ False.
+ """
+ self.border_val = border_val
+ self.as_int = as_int
+
+ def __call__(
+ self,
+ images: list[NDArrayF32],
+ warp_matrix_list: list[NDArrayF32],
+ height_list: list[int],
+ width_list: list[int],
+ ) -> tuple[list[NDArrayF32], list[tuple[int, int]]]:
+ """Affine a list of image of dimensions [N, H, W, C]."""
+ input_hw_list = []
+ for i, (image, warp_matrix, height, width) in enumerate(
+ zip(images, warp_matrix_list, height_list, width_list)
+ ):
+ image = image[0].astype(np.uint8) if self.as_int else image[0]
+ image = cv2.warpPerspective( # pylint: disable=no-member, unsubscriptable-object, line-too-long
+ image,
+ warp_matrix,
+ dsize=(width, height),
+ borderValue=self.border_val,
+ )[
+ None, ...
+ ].astype(
+ np.float32
+ )
+
+ images[i] = image
+ input_hw_list.append((height, width))
+ return images, input_hw_list
+
+
+@Transform(
+ in_keys=[
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.boxes2d_track_ids,
+ "transforms.affine.warp_matrix",
+ "transforms.affine.height",
+ "transforms.affine.width",
+ ],
+ out_keys=[K.boxes2d, K.boxes2d_classes, K.boxes2d_track_ids],
+)
+class AffineBoxes2D:
+ """Apply Affine to a list of 2D bounding boxes."""
+
+ def __init__(self, bbox_clip_border: bool = True) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ bbox_clip_border (bool, optional): Whether to clip the objects
+ outside the border of the image. In some dataset like MOT17,
+ the gt bboxes are allowed to cross the border of images.
+ Therefore, we don't need to clip the gt bboxes in these cases.
+ Defaults to True.
+ """
+ self.bbox_clip_border = bbox_clip_border
+
+ def __call__(
+ self,
+ boxes: list[NDArrayF32],
+ classes: list[NDArrayI64],
+ track_ids: list[NDArrayI64] | None,
+ warp_matrix_list: list[NDArrayF32],
+ height_list: list[int],
+ width_list: list[int],
+ ) -> tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None]:
+ """Apply Affine to 2D bounding boxes."""
+ for i, (box, class_, warp_matrix, height, width) in enumerate(
+ zip(
+ boxes,
+ classes,
+ warp_matrix_list,
+ height_list,
+ width_list,
+ )
+ ):
+ box_ = bbox_project(
+ torch.from_numpy(box), torch.from_numpy(warp_matrix)
+ )
+ if self.bbox_clip_border:
+ box_ = bbox_clip(box_, (height, width))
+ boxes[i] = box_.numpy()
+
+ keep_mask = _get_keep_mask(
+ boxes[i], np.array([0, 0, width, height])
+ )
+ boxes[i] = boxes[i][keep_mask]
+ classes[i] = class_[keep_mask]
+ if track_ids is not None:
+ track_ids[i] = track_ids[i][keep_mask]
+
+ return boxes, classes, track_ids
diff --git a/vis4d/data/transforms/autoaugment.py b/vis4d/data/transforms/autoaugment.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e47ad9f98846af3cce04aae27f5c4e9d57c0d5e
--- /dev/null
+++ b/vis4d/data/transforms/autoaugment.py
@@ -0,0 +1,209 @@
+"""A wrap for timm transforms."""
+
+from __future__ import annotations
+
+from typing import Union
+
+import numpy as np
+from PIL import Image
+
+from vis4d.common.imports import TIMM_AVAILABLE
+from vis4d.common.typing import NDArrayUI8
+from vis4d.data.const import CommonKeys as K
+
+from .base import Transform
+
+if TIMM_AVAILABLE:
+ from timm.data.auto_augment import (
+ _RAND_INCREASING_TRANSFORMS,
+ _RAND_TRANSFORMS,
+ AugMixAugment,
+ AutoAugment,
+ RandAugment,
+ augmix_ops,
+ auto_augment_policy,
+ rand_augment_ops,
+ )
+else:
+ raise ImportError("timm is not installed.")
+
+AugOp = Union[AutoAugment, RandAugment, AugMixAugment]
+
+
+def _apply_aug(images: NDArrayUI8, aug_op: AugOp) -> NDArrayUI8:
+ """Apply augmentation to a batch of images with shape [N, H, W, C]."""
+ assert images.shape[-1] == 3, "Images must be in RGB format."
+ imgs: list[Image.Image] = []
+ for img in images:
+ # convert to uint8 if necessary
+ if img.dtype != np.uint8:
+ img = img.astype(np.uint8)
+ imgs.append(aug_op(Image.fromarray(img)))
+ return np.stack([np.array(img).astype(np.float32) for img in imgs])
+
+
+@Transform(K.images, K.images)
+class _AutoAug:
+ """Apply Timm's AutoAugment to a image array."""
+
+ def __init__(self) -> None:
+ self.aug_op: AugOp | None = None
+
+ def _create(self, policy: str, hparams: dict[str, float]) -> AugOp:
+ """Create augmentation op."""
+ aa_policy = auto_augment_policy(policy, hparams=hparams)
+ return AutoAugment(aa_policy)
+
+ def __call__(self, images: list[NDArrayUI8]) -> list[NDArrayUI8]:
+ """Execute the transform."""
+ assert self.aug_op is not None, "Augmentation op is not created."
+ for i, img in enumerate(images):
+ images[i] = _apply_aug(img, self.aug_op)
+ return images
+
+
+class AutoAugV0(_AutoAug):
+ """Apply Timm's AutoAugment (policy=v0) to a image array."""
+
+ def __init__(self, magnitude_std: float = 0.5):
+ """Create an instance of AutoAug.
+
+ Args:
+ magnitude_std (float, optional): Standard deviation of the
+ magnitude for random autoaugment. Defaults to 0.5.
+ """
+ super().__init__()
+ self.aug_op = self._create("v0", {"magnitude_std": magnitude_std})
+
+
+class AutoAugOriginal(_AutoAug):
+ """Apply Timm's AutoAugment (policy=original) to a image array."""
+
+ def __init__(self, magnitude_std: float = 0.5):
+ """Create an instance of AutoAug.
+
+ Args:
+ magnitude_std (float, optional): Standard deviation of the
+ magnitude for random autoaugment. Defaults to 0.5.
+ """
+ super().__init__()
+ self.aug_op = self._create(
+ "original", {"magnitude_std": magnitude_std}
+ )
+
+
+@Transform(K.images, K.images)
+class RandAug:
+ """Apply Timm's RandomAugment to a image tensor."""
+
+ def __init__(
+ self,
+ magnitude: int = 10,
+ num_layers: int = 2,
+ use_increasing: bool = False,
+ magnitude_std: float = 0.5,
+ ):
+ """Create an instance of RandAug.
+
+ Args:
+ magnitude (int): Level of magnitude for augments, ranging from 1 to
+ 9.
+ num_layers (int, optional): Number of layers for rand augment.
+ Defaults to 2.
+ use_increasing (bool, optional): Whether to use increasing setting
+ for transforms. Defaults to False.
+ magnitude_std (float, optional): Standard deviation of the
+ magnitude for random autoaugment. Defaults to 0.5.
+
+ Returns:
+ Callable: A function that takes a tensor of shape [N, C, H, W] and
+ returns a tensor of the same shape.
+
+ Example:
+ Rand augment with magnitude 9. (`https://arxiv.org/abs/1909.13719`)
+ >>> rand_augment(magnitude=9)
+ """
+ super().__init__()
+ assert TIMM_AVAILABLE, "timm is not installed."
+ self.magnitude = magnitude
+ self.num_layers = num_layers
+ self.use_increasing = use_increasing
+ self.magnitude_std = magnitude_std
+ hparams = {"magnitude_std": self.magnitude_std}
+
+ if self.use_increasing:
+ transforms = _RAND_INCREASING_TRANSFORMS
+ else:
+ transforms = _RAND_TRANSFORMS
+ ra_ops = rand_augment_ops(
+ magnitude=self.magnitude, hparams=hparams, transforms=transforms
+ )
+ self.aug_op = RandAugment(ra_ops, self.num_layers)
+
+ def __call__(self, images: list[NDArrayUI8]) -> list[NDArrayUI8]:
+ """Execute the transform."""
+ for i, img in enumerate(images):
+ images[i] = _apply_aug(img, self.aug_op)
+ return images
+
+
+@Transform(K.images, K.images)
+class AugMix:
+ """Apply Timm's AugMix to a image tensor."""
+
+ def __init__(
+ self,
+ magnitude: int = 10,
+ width: int = 3,
+ alpha: float = 1.0,
+ depth: int = -1,
+ blended: bool = True,
+ magnitude_std: float = 0.5,
+ ):
+ """Create an instance of AugMix.
+
+ Args:
+ magnitude (int): Level of magnitude, ranging from 1 to 9.
+ width (int, optional): Width of the augmentation chain. Defaults to
+ 3.
+ alpha (float, optional): Alpha for beta distribution. Defaults to
+ 1.0.
+ depth (int, optional): Depth of the augmentation chain. Defaults to
+ -1.
+ blended (bool, optional): Whether to blend the original image with
+ the augmented image. Defaults to True.
+ magnitude_std (float, optional): Standard deviation of the
+ magnitude for random autoaugment. Defaults to 0.5.
+
+ Returns:
+ Callable: A function that takes a tensor of shape [N, C, H, W] and
+ returns a tensor of the same shape.
+
+ Example:
+ Augmix with magnitude 9. (`https://arxiv.org/abs/1912.02781`)
+ >>> augmix(magnitude=9)
+ """
+ super().__init__()
+ assert TIMM_AVAILABLE, "timm is not installed."
+ self.magnitude = magnitude
+ self.width = width
+ self.alpha = alpha
+ self.depth = depth
+ self.blended = blended
+ self.magnitude_std = magnitude_std
+ hparams = {"magnitude_std": self.magnitude_std}
+
+ am_ops = augmix_ops(magnitude=self.magnitude, hparams=hparams)
+ self.aug_op = AugMixAugment(
+ am_ops,
+ alpha=self.alpha,
+ width=self.width,
+ depth=self.depth,
+ blended=self.blended,
+ )
+
+ def __call__(self, images: list[NDArrayUI8]) -> list[NDArrayUI8]:
+ """Execute the transform."""
+ for i, img in enumerate(images):
+ images[i] = _apply_aug(img, self.aug_op)
+ return images
diff --git a/vis4d/data/transforms/base.py b/vis4d/data/transforms/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec1b2d8b52a10eecf2bbb79d512070bada4532e9
--- /dev/null
+++ b/vis4d/data/transforms/base.py
@@ -0,0 +1,227 @@
+"""Basic data augmentation class."""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Sequence
+from typing import TypeVar, no_type_check
+
+import torch
+
+from vis4d.common.dict import get_dict_nested, set_dict_nested
+from vis4d.data.typing import DictData
+
+TFunctor = TypeVar("TFunctor", bound=object) # pylint: disable=invalid-name
+TransformFunction = Callable[[list[DictData]], list[DictData]]
+
+
+class Transform:
+ """Transforms Decorator.
+
+ This class stores which `in_keys` are input to a transformation function
+ and which `out_keys` are overwritten in the data dictionary by the output
+ of this transformation.
+ Nested keys in the data dictionary can be accessed via key.subkey1.subkey2
+ If any of `in_keys` is 'data', the full data dictionary will be forwarded
+ to the transformation.
+ If the only entry in `out_keys` is 'data', the full data dictionary will
+ be updated with the return value of the transformation.
+ For the case of multi-sensor data, the sensors that the transform should be
+ applied can be set via the 'sensors' attribute. By default, we assume
+ a transformation is applied to all sensors.
+ This class will add a 'apply_to_data' method to a given Functor which is
+ used to call it on a DictData object. NOTE: This is an issue for static
+ checking and is not recognized by pylint. It will usually be called in the
+ compose() function and will not be called directly.
+
+ Example:
+ >>> @Transform(in_keys="images", out_keys="images")
+ >>> class MyTransform:
+ >>> def __call__(images: list[np.array]) -> list[np.array]:
+ >>> images = do_something(images)
+ >>> return images
+ >>> my_transform = MyTransform()
+ >>> data = my_transform.apply_to_data(data)
+ """
+
+ def __init__(
+ self,
+ in_keys: Sequence[str] | str,
+ out_keys: Sequence[str] | str,
+ sensors: Sequence[str] | str | None = None,
+ same_on_batch: bool = True,
+ ) -> None:
+ """Creates an instance of Transform.
+
+ Args:
+ in_keys (Sequence[str] | str): Specifies one or multiple (if any)
+ input keys of the data dictionary which should be remapeed to
+ another key. Defaults to None.
+ out_keys (Sequence[str] | str): Specifies one or multiple (if any)
+ output keys of the data dictionary which should be remaped to
+ another key. Defaults to None.
+ sensors (Sequence[str] | str | None, optional): Specifies the
+ sensors this transformation should be applied to. If None, it
+ will be applied to all available sensors. Defaults to None.
+ same_on_batch (bool, optional): Whether to use the same
+ transformation parameters to all sensors / view. Defaults to
+ True.
+ """
+ if isinstance(in_keys, str):
+ in_keys = [in_keys]
+ self.in_keys = in_keys
+
+ if isinstance(out_keys, str):
+ out_keys = [out_keys]
+ self.out_keys = out_keys
+
+ if isinstance(sensors, str):
+ sensors = [sensors]
+ self.sensors = sensors
+
+ self.same_on_batch = same_on_batch
+
+ @no_type_check
+ def __call__(self, transform: TFunctor) -> TFunctor:
+ """Add in_keys / out_keys / sensors / apply_to_data attributes.
+
+ Args:
+ transform (TFunctor): A given Functor.
+
+ Returns:
+ TFunctor: The decorated Functor.
+ """
+ original_init = transform.__init__
+
+ def apply_to_data(
+ self_, input_batch: list[DictData]
+ ) -> list[DictData]:
+ """Wrap function with a handler for input / output keys.
+
+ We use the specified in_keys in order to extract the positional
+ input arguments of a function from the data dictionary, and the
+ out_keys to replace the corresponding values in the output
+ dictionary.
+ """
+
+ def _transform_fn(batch: list[DictData]) -> list[DictData]:
+ in_batch = []
+ for key in self_.in_keys:
+ key_data = []
+ for data in batch:
+ # Optionally allow the function to get the full data
+ # dict as aux input and set default value to None if
+ # key is not found
+ key_data += [
+ (
+ get_dict_nested(
+ data, key.split("."), allow_missing=True
+ )
+ if key != "data"
+ else data
+ )
+ ]
+ if any(d is None for d in key_data):
+ # If any of the data in the batch is None, replace
+ # the input of the key with None.
+ in_batch.append(None)
+ else:
+ in_batch.append(key_data)
+
+ result = self_(*in_batch)
+
+ if len(self_.out_keys) == 1:
+ if self_.out_keys[0] == "data":
+ return result
+ result = [result]
+
+ for key, values in zip(self_.out_keys, result):
+ if values is None:
+ continue
+ for data, value in zip(batch, values):
+ if value is not None:
+ set_dict_nested(data, key.split("."), value)
+ return batch
+
+ if self_.sensors is not None:
+ if self_.same_on_batch:
+ for sensor in self_.sensors:
+ batch_sensor = _transform_fn(
+ [d[sensor] for d in input_batch]
+ )
+ for i, d in enumerate(batch_sensor):
+ input_batch[i][sensor] = d
+ else:
+ for i, data in enumerate(input_batch):
+ for sensor in self_.sensors:
+ input_batch[i][sensor] = _transform_fn(
+ [data[sensor]]
+ )
+ elif self_.same_on_batch:
+ input_batch = _transform_fn(input_batch)
+ else:
+ for i, data in enumerate(input_batch):
+ input_batch[i] = _transform_fn([data])[0]
+
+ return input_batch
+
+ def init(
+ *args,
+ in_keys: Sequence[str] = self.in_keys,
+ out_keys: Sequence[str] = self.out_keys,
+ sensors: Sequence[str] | None = self.sensors,
+ same_on_batch: bool = self.same_on_batch,
+ **kwargs,
+ ):
+ self_ = args[0]
+ original_init(*args, **kwargs)
+ self_.in_keys = in_keys
+ self_.out_keys = out_keys
+ self_.sensors = sensors
+ self_.same_on_batch = same_on_batch
+ self_.apply_to_data = lambda *args, **kwargs: apply_to_data(
+ self_, *args, **kwargs
+ )
+
+ transform.__init__ = init
+ return transform
+
+
+def compose(transforms: list[TFunctor]) -> TransformFunction:
+ """Compose transformations.
+
+ This function composes a given set of transformation functions, i.e. any
+ functor decorated with Transform, into a single transform.
+ """
+
+ def _preprocess_func(batch: list[DictData]) -> list[DictData]:
+ for op in transforms:
+ batch = op.apply_to_data(batch) # type: ignore
+ return batch
+
+ return _preprocess_func
+
+
+@Transform("data", "data")
+class RandomApply:
+ """Randomize the application of a given set of transformations."""
+
+ def __init__(
+ self, transforms: list[TFunctor], probability: float = 0.5
+ ) -> None:
+ """Creates an instance of RandomApply.
+
+ Args:
+ transforms (list[TFunctor]): Transformations that are applied with
+ a given probability.
+ probability (float, optional): Probability to apply
+ transformations. Defaults to 0.5.
+ """
+ self.transforms = transforms
+ self.probability = probability
+
+ def __call__(self, batch: list[DictData]) -> list[DictData]:
+ """Apply transforms with a given probability."""
+ if torch.rand(1) < self.probability:
+ for op in self.transforms:
+ batch = op.apply_to_data(batch) # type: ignore
+ return batch
diff --git a/vis4d/data/transforms/crop.py b/vis4d/data/transforms/crop.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1fa1154e1ac33bfa9415a4b2edc69477f795b89
--- /dev/null
+++ b/vis4d/data/transforms/crop.py
@@ -0,0 +1,529 @@
+"""Crop transformation."""
+
+from __future__ import annotations
+
+import math
+from collections.abc import Callable
+from typing import List, Tuple, TypedDict, Union
+
+import numpy as np
+import torch
+
+from vis4d.common.logging import rank_zero_warn
+from vis4d.common.typing import (
+ NDArrayBool,
+ NDArrayF32,
+ NDArrayI32,
+ NDArrayI64,
+ NDArrayUI8,
+)
+from vis4d.data.const import CommonKeys as K
+from vis4d.op.box.box2d import bbox_intersection
+
+from .base import Transform
+
+CropShape = Union[
+ Tuple[float, float],
+ Tuple[int, int],
+ List[Tuple[float, float]],
+ List[Tuple[int, int]],
+]
+CropFunc = Callable[[int, int, CropShape], Tuple[int, int]]
+
+
+class CropParam(TypedDict):
+ """Parameters for Crop."""
+
+ crop_box: NDArrayI32
+ keep_mask: NDArrayBool
+
+
+def absolute_crop(im_h: int, im_w: int, shape: CropShape) -> tuple[int, int]:
+ """Absolute crop."""
+ assert isinstance(shape, tuple)
+ assert shape[0] > 0 and shape[1] > 0
+ return (min(int(shape[0]), im_h), min(int(shape[1]), im_w))
+
+
+def absolute_range_crop(
+ im_h: int, im_w: int, shape: CropShape
+) -> tuple[int, int]:
+ """Absolute range crop."""
+ assert isinstance(shape, list)
+ assert len(shape) == 2
+ assert shape[1][0] >= shape[0][0]
+ assert shape[1][1] >= shape[0][1]
+
+ for crop in shape:
+ assert crop[0] > 0 and crop[1] > 0
+ shape_min: tuple[int, int] = (int(shape[0][0]), int(shape[0][1]))
+ shape_max: tuple[int, int] = (int(shape[1][0]), int(shape[1][1]))
+
+ crop_h = np.random.randint(
+ min(im_h, shape_min[0]), min(im_h, shape_max[0]) + 1
+ )
+ crop_w = np.random.randint(
+ min(im_w, shape_min[1]), min(im_w, shape_max[1]) + 1
+ )
+ return int(crop_h), int(crop_w)
+
+
+def relative_crop(im_h: int, im_w: int, shape: CropShape) -> tuple[int, int]:
+ """Relative crop."""
+ assert isinstance(shape, tuple)
+ assert 0 < shape[0] <= 1 and 0 < shape[1] <= 1
+ crop_h, crop_w = shape
+ return int(im_h * crop_h + 0.5), int(im_w * crop_w + 0.5)
+
+
+def relative_range_crop(
+ im_h: int, im_w: int, shape: CropShape
+) -> tuple[int, int]:
+ """Relative range crop."""
+ assert isinstance(shape, list)
+ assert len(shape) == 2
+ assert shape[1][0] >= shape[0][0]
+ assert shape[1][1] >= shape[0][1]
+ for crop in shape:
+ assert 0 < crop[0] <= 1 and 0 < crop[1] <= 1
+ scale_min: tuple[float, float] = shape[0]
+ scale_max: tuple[float, float] = shape[1]
+
+ crop_h = np.random.rand() * (scale_max[0] - scale_min[0]) + scale_min[0]
+ crop_w = np.random.rand() * (scale_max[1] - scale_min[1]) + scale_min[1]
+ return int(im_h * crop_h + 0.5), int(im_w * crop_w + 0.5)
+
+
+@Transform(
+ in_keys=[K.input_hw, K.boxes2d, K.seg_masks],
+ out_keys="transforms.crop",
+)
+class GenCropParameters:
+ """Generate the parameters for a crop operation."""
+
+ def __init__(
+ self,
+ shape: CropShape,
+ crop_func: CropFunc = absolute_crop,
+ allow_empty_crops: bool = True,
+ cat_max_ratio: float = 1.0,
+ ignore_index: int = 255,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ shape (CropShape): Image shape to be cropped to in [H, W].
+ crop_func (CropFunc, optional): Function used to generate the size
+ of the crop. Defaults to absolute_crop.
+ allow_empty_crops (bool, optional): Allow crops which result in
+ empty labels. Defaults to True.
+ cat_max_ratio (float, optional): Maximum ratio of a particular
+ class in segmentation masks after cropping. Defaults to 1.0.
+ ignore_index (int, optional): The index to ignore. Defaults to 255.
+ """
+ self.shape = shape
+ self.crop_func = crop_func
+ self.allow_empty_crops = allow_empty_crops
+ self.cat_max_ratio = cat_max_ratio
+ self.ignore_index = ignore_index
+
+ def _get_crop(
+ self, im_h: int, im_w: int, boxes: NDArrayF32 | None = None
+ ) -> tuple[NDArrayI32, NDArrayBool]:
+ """Get the crop parameters."""
+ crop_size = self.crop_func(im_h, im_w, self.shape)
+ crop_box = _sample_crop(im_h, im_w, crop_size)
+ keep_mask = _get_keep_mask(boxes, crop_box)
+ return crop_box, keep_mask
+
+ def __call__(
+ self,
+ input_hw_list: list[tuple[int, int]],
+ boxes_list: list[NDArrayF32] | None,
+ masks_list: list[NDArrayUI8] | None,
+ ) -> list[CropParam]:
+ """Compute the parameters and put them in the data dict."""
+ im_h, im_w = input_hw_list[0]
+ boxes = boxes_list[0] if boxes_list is not None else None
+ masks = masks_list[0] if masks_list is not None else None
+
+ crop_box, keep_mask = self._get_crop(im_h, im_w, boxes)
+ if (boxes is not None and len(boxes) > 0) or masks is not None:
+ # resample crop if conditions not satisfied
+ found_crop = False
+ for _ in range(10):
+ # try resampling 10 times, otherwise use last crop
+ if (self.allow_empty_crops or keep_mask.sum() != 0) and (
+ _check_seg_max_cat(
+ masks, crop_box, self.cat_max_ratio, self.ignore_index
+ )
+ ):
+ found_crop = True
+ break
+ crop_box, keep_mask = self._get_crop(im_h, im_w, boxes)
+ if not found_crop:
+ rank_zero_warn("Random crop not found within 10 resamples.")
+
+ crop_params = [
+ CropParam(crop_box=crop_box, keep_mask=keep_mask)
+ ] * len(input_hw_list)
+
+ return crop_params
+
+
+@Transform([K.input_hw, K.boxes2d], "transforms.crop")
+class GenCentralCropParameters:
+ """Generate the parameters for a central crop operation."""
+
+ def __init__(
+ self,
+ shape: CropShape,
+ crop_func: CropFunc = absolute_crop,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ shape (CropShape): Image shape to be cropped to.
+ crop_func (CropFunc, optional): Function used to generate the size
+ of the crop. Defaults to absolute_crop.
+ """
+ self.shape = shape
+ self.crop_func = crop_func
+
+ def __call__(
+ self,
+ input_hw_list: list[tuple[int, int]],
+ boxes_list: list[NDArrayF32] | None,
+ ) -> list[CropParam]:
+ """Compute the parameters and put them in the data dict."""
+ im_h, im_w = input_hw_list[0]
+ boxes = boxes_list[0] if boxes_list is not None else None
+
+ crop_size = self.crop_func(im_h, im_w, self.shape)
+ crop_box = _get_central_crop(im_h, im_w, crop_size)
+ keep_mask = _get_keep_mask(boxes, crop_box)
+ crop_params = [
+ CropParam(crop_box=crop_box, keep_mask=keep_mask)
+ ] * len(input_hw_list)
+
+ return crop_params
+
+
+@Transform([K.input_hw, K.boxes2d], "transforms.crop")
+class GenRandomSizeCropParameters:
+ """Generate the parameters for a random size crop operation.
+
+ A crop of the original image is made: the crop has a random area (H * W)
+ and a random aspect ratio. Code adapted from torchvision.
+ """
+
+ def __init__(
+ self,
+ scale: tuple[float, float] = (0.08, 1.0),
+ ratio: tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
+ ):
+ """Creates an instance of the class.
+
+ Args:
+ scale (tuple[float, float], optional): Scale range of the cropped
+ area. Defaults to (0.08, 1.0).
+ ratio (tuple[float, float], optional): Aspect ratio range of the
+ cropped area. Defaults to (3.0 / 4.0, 4.0 / 3.0).
+ """
+ self.scale = scale
+ self.ratio = np.array(ratio)
+ self.log_ratio = np.log(self.ratio)
+
+ def get_params(self, height: int, width: int) -> NDArrayI32:
+ """Get parameters for the random size crop."""
+ area = height * width
+ for _ in range(10):
+ target_area = area * np.random.uniform(
+ self.scale[0], self.scale[1]
+ )
+ aspect_ratio = np.exp(
+ np.random.uniform(self.log_ratio[0], self.log_ratio[1])
+ )
+
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ if 0 < w <= width and 0 < h <= height:
+ i = np.random.randint(0, height - h + 1)
+ j = np.random.randint(0, width - w + 1)
+ crop_x1, crop_y1, crop_x2, crop_y2 = i, j, i + h, j + w
+ return np.array([crop_x1, crop_y1, crop_x2, crop_y2])
+
+ # Fallback to central crop
+ in_ratio = float(width) / float(height)
+ if in_ratio < min(self.ratio):
+ w = width
+ h = int(round(w / min(self.ratio)))
+ elif in_ratio > max(self.ratio):
+ h = height
+ w = int(round(h * max(self.ratio)))
+ else: # whole image
+ w = width
+ h = height
+ i = (height - h) // 2
+ j = (width - w) // 2
+ crop_x1, crop_y1, crop_x2, crop_y2 = i, j, i + h, j + w
+ return np.array([crop_x1, crop_y1, crop_x2, crop_y2])
+
+ def __call__(
+ self,
+ input_hw_list: list[tuple[int, int]],
+ boxes_list: list[NDArrayF32] | None,
+ ) -> list[CropParam]:
+ """Compute the parameters and put them in the data dict."""
+ im_h, im_w = input_hw_list[0]
+ boxes = boxes_list[0] if boxes_list is not None else None
+
+ crop_box = self.get_params(im_h, im_w)
+ keep_mask = _get_keep_mask(boxes, crop_box)
+
+ crop_params = [
+ CropParam(crop_box=crop_box, keep_mask=keep_mask)
+ ] * len(input_hw_list)
+
+ return crop_params
+
+
+@Transform([K.images, "transforms.crop.crop_box"], [K.images, K.input_hw])
+class CropImages:
+ """Crop Images."""
+
+ def __call__(
+ self, images: list[NDArrayF32], crop_box_list: list[NDArrayI32]
+ ) -> tuple[list[NDArrayF32], list[tuple[int, int]]]:
+ """Crop a list of image of dimensions [N, H, W, C].
+
+ Args:
+ images (list[NDArrayF32]): The list of image.
+ crop_box (list[NDArrayI32]): The list of box to crop.
+
+ Returns:
+ list[NDArrayF32]: List of cropped image according to parameters.
+ """
+ input_hw_list = []
+ for i, (image, crop_box) in enumerate(zip(images, crop_box_list)):
+ h, w = image.shape[1], image.shape[2]
+ x1, y1, x2, y2 = crop_box
+ crop_w, crop_h = x2 - x1, y2 - y1
+ image = image[:, y1:y2, x1:x2, :]
+ input_hw = (min(crop_h, h), min(crop_w, w))
+
+ images[i] = image
+ input_hw_list.append(input_hw)
+ return images, input_hw_list
+
+
+@Transform(
+ in_keys=[
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.boxes2d_track_ids,
+ "transforms.crop.crop_box",
+ "transforms.crop.keep_mask",
+ ],
+ out_keys=[K.boxes2d, K.boxes2d_classes, K.boxes2d_track_ids],
+)
+class CropBoxes2D:
+ """Crop 2D bounding boxes."""
+
+ def __call__(
+ self,
+ boxes_list: list[NDArrayF32],
+ classes_list: list[NDArrayI64],
+ track_ids_list: list[NDArrayI64] | None,
+ crop_box_list: list[NDArrayI32],
+ keep_mask_list: list[NDArrayBool],
+ ) -> tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None]:
+ """Crop 2D bounding boxes.
+
+ Args:
+ boxes_list (list[NDArrayF32]): The list of bounding boxes to be
+ cropped.
+ classes_list (list[NDArrayI64]): The list of the corresponding
+ classes.
+ track_ids_list (list[NDArrayI64] | None, optional): The list of
+ corresponding tracking IDs. Defaults to None.
+ crop_box_list (list[NDArrayI32]): The list of box to crop.
+ keep_mask_list (list[NDArrayBool]): Which boxes to keep.
+
+ Returns:
+ tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64]] | None:
+ List of cropped bounding boxes according to parameters.
+ """
+ for i, (boxes, classes, crop_box, keep_mask) in enumerate(
+ zip(
+ boxes_list,
+ classes_list,
+ crop_box_list,
+ keep_mask_list,
+ )
+ ):
+ x1, y1 = crop_box[:2]
+ boxes -= np.array([x1, y1, x1, y1])
+
+ boxes_list[i] = boxes[keep_mask]
+ classes_list[i] = classes[keep_mask]
+
+ if track_ids_list is not None:
+ track_ids_list[i] = track_ids_list[i][keep_mask]
+
+ return boxes_list, classes_list, track_ids_list
+
+
+@Transform([K.seg_masks, "transforms.crop.crop_box"], K.seg_masks)
+class CropSegMasks:
+ """Crop segmentation masks."""
+
+ def __call__(
+ self, masks_list: list[NDArrayUI8], crop_box_list: list[NDArrayI32]
+ ) -> list[NDArrayUI8]:
+ """Crop masks."""
+ for i, (masks, crop_box) in enumerate(zip(masks_list, crop_box_list)):
+ x1, y1, x2, y2 = crop_box
+ masks_list[i] = masks[y1:y2, x1:x2]
+ return masks_list
+
+
+@Transform(
+ in_keys=[
+ K.instance_masks,
+ "transforms.crop.crop_box",
+ "transforms.crop.keep_mask",
+ ],
+ out_keys=[K.instance_masks],
+)
+class CropInstanceMasks:
+ """Crop instance segmentation masks."""
+
+ def __call__(
+ self,
+ masks_list: list[NDArrayUI8],
+ crop_box_list: list[NDArrayI32],
+ keep_mask_list: list[NDArrayBool],
+ ) -> list[NDArrayUI8]:
+ """Crop masks."""
+ for i, (masks, crop_box) in enumerate(zip(masks_list, crop_box_list)):
+ x1, y1, x2, y2 = crop_box
+ masks = masks[:, y1:y2, x1:x2]
+ masks_list[i] = masks[keep_mask_list[i]]
+ return masks_list
+
+
+@Transform([K.depth_maps, "transforms.crop.crop_box"], K.depth_maps)
+class CropDepthMaps:
+ """Crop depth maps."""
+
+ def __call__(
+ self, depth_maps: list[NDArrayF32], crop_box_list: list[NDArrayI32]
+ ) -> list[NDArrayF32]:
+ """Crop depth maps."""
+ for i, (depth_map, crop_box) in enumerate(
+ zip(depth_maps, crop_box_list)
+ ):
+ x1, y1, x2, y2 = crop_box
+ depth_maps[i] = depth_map[y1:y2, x1:x2]
+ return depth_maps
+
+
+@Transform([K.optical_flows, "transforms.crop.crop_box"], K.optical_flows)
+class CropOpticalFlows:
+ """Crop optical flows."""
+
+ def __call__(
+ self, optical_flows: list[NDArrayF32], crop_box_list: NDArrayI32
+ ) -> list[NDArrayF32]:
+ """Crop optical flows."""
+ for i, (optical_flow, crop_box) in enumerate(
+ zip(optical_flows, crop_box_list)
+ ):
+ x1, y1, x2, y2 = crop_box
+ optical_flows[i] = optical_flow[y1:y2, x1:x2]
+ return optical_flows
+
+
+@Transform([K.intrinsics, "transforms.crop.crop_box"], K.intrinsics)
+class CropIntrinsics:
+ """Crop Intrinsics."""
+
+ def __call__(
+ self,
+ intrinsics_list: list[NDArrayF32],
+ crop_box_list: list[NDArrayI32],
+ ) -> list[NDArrayF32]:
+ """Crop camera intrinsics."""
+ for i, crop_box in enumerate(crop_box_list):
+ x1, y1 = crop_box[:2]
+ intrinsics_list[i][0, 2] -= x1
+ intrinsics_list[i][1, 2] -= y1
+ return intrinsics_list
+
+
+def _sample_crop(
+ im_h: int, im_w: int, crop_size: tuple[int, int]
+) -> NDArrayI32:
+ """Sample crop parameters according to config."""
+ margin_h = max(im_h - crop_size[0], 0)
+ margin_w = max(im_w - crop_size[1], 0)
+ offset_h = np.random.randint(0, margin_h + 1)
+ offset_w = np.random.randint(0, margin_w + 1)
+ crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
+ crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]
+ return np.array([crop_x1, crop_y1, crop_x2, crop_y2])
+
+
+def _get_central_crop(
+ im_h: int, im_w: int, crop_size: tuple[int, int]
+) -> NDArrayI32:
+ """Get central crop parameters."""
+ margin_h = max(im_h - crop_size[0], 0)
+ margin_w = max(im_w - crop_size[1], 0)
+ offset_h = margin_h // 2
+ offset_w = margin_w // 2
+ crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
+ crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]
+ return np.array([crop_x1, crop_y1, crop_x2, crop_y2])
+
+
+def _get_keep_mask(
+ boxes: NDArrayF32 | None, crop_box: NDArrayI32
+) -> NDArrayBool:
+ """Get mask for 2D annotations to keep."""
+ if boxes is None or len(boxes) == 0:
+ return np.array([], dtype=bool)
+ # will be better to compute mask intersection (if exists) instead
+ overlap = bbox_intersection(
+ torch.tensor(boxes), torch.tensor(crop_box).unsqueeze(0)
+ ).numpy()
+ return overlap.squeeze(-1) > 0
+
+
+def _check_seg_max_cat(
+ masks: NDArrayUI8 | None,
+ crop_box: NDArrayI32,
+ cat_max_ratio: float,
+ ignore_index: int = 255,
+) -> bool:
+ """Check if any category occupies more than cat_max_ratio.
+
+ Args:
+ masks (NDArrayUI8 | None): Segmentation masks.
+ crop_box (NDArrayI32): The box to crop.
+ cat_max_ratio (float): Maximum category ratio.
+ ignore_index (int, optional): The index to ignore. Defaults to 255.
+
+ Returns:
+ bool: True if no category occupies more than cat_max_ratio.
+ """
+ if cat_max_ratio >= 1.0 or masks is None:
+ return True
+ x1, y1, x2, y2 = crop_box
+ crop_masks = masks[y1:y2, x1:x2]
+ cls_ids, cnts = np.unique(crop_masks, return_counts=True)
+ cnts = cnts[cls_ids != ignore_index]
+
+ return (cnts.max() / cnts.sum()) < cat_max_ratio
diff --git a/vis4d/data/transforms/flip.py b/vis4d/data/transforms/flip.py
new file mode 100644
index 0000000000000000000000000000000000000000..42a314e17ef82fe208eecb4016356c87b27f83d0
--- /dev/null
+++ b/vis4d/data/transforms/flip.py
@@ -0,0 +1,359 @@
+"""Horizontal flip augmentation."""
+
+import numpy as np
+import torch
+
+from vis4d.common.typing import NDArrayF32, NDArrayUI8
+from vis4d.data.const import AxisMode
+from vis4d.data.const import CommonKeys as K
+from vis4d.op.geometry.rotation import (
+ euler_angles_to_matrix,
+ matrix_to_euler_angles,
+ matrix_to_quaternion,
+ quaternion_to_matrix,
+)
+
+from .base import Transform
+
+
+@Transform(K.images, K.images)
+class FlipImages:
+ """Flip a list of numpy image array of shape [N, H, W, C]."""
+
+ def __init__(self, direction: str = "horizontal"):
+ """Creates an instance of FlipImage.
+
+ Args:
+ direction (str, optional): Either vertical or horizontal.
+ Defaults to "horizontal".
+
+ Raises:
+ ValueError: If direction is not horizontal or vertical.
+ """
+ if direction not in ["horizontal", "vertical"]:
+ raise ValueError(f"Direction {direction} not known!")
+ self.direction = direction
+
+ def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]:
+ """Execute flipping op.
+
+ Args:
+ image (NDArrayF32): [N, H, W, C] array of image.
+
+ Returns:
+ list[NDArrayF32]: [N, H, W, C] array of flipped image.
+ """
+ for i, image in enumerate(images):
+ image_ = torch.from_numpy(image)
+ if self.direction == "horizontal":
+ images[i] = image_.flip(2).numpy()
+ if self.direction == "vertical":
+ images[i] = image_.flip(1).numpy()
+ return images
+
+
+@Transform(in_keys=(K.boxes2d, K.images), out_keys=(K.boxes2d,))
+class FlipBoxes2D:
+ """Flip a list of 2D bounding boxes."""
+
+ def __init__(self, direction: str = "horizontal"):
+ """Creates an instance of FlipBoxes2D.
+
+ Args:
+ direction (str, optional): Either vertical or horizontal.
+ Defaults to "horizontal".
+
+ Raises:
+ ValueError: If direction is not horizontal or vertical.
+ """
+ if direction not in ["horizontal", "vertical"]:
+ raise ValueError(f"Direction {direction} not known!")
+ self.direction = direction
+
+ def __call__(
+ self, boxes_list: list[NDArrayF32], images: list[NDArrayF32]
+ ) -> list[NDArrayF32]:
+ """Execute flipping op.
+
+ Args:
+ boxes (list[NDArrayF32]): List of [M, 4] array of boxes.
+ image (list[NDArrayF32]): List of [N, H, W, C] array of image.
+
+ Returns:
+ list[NDArrayF32]: List of [M, 4] array of flipped boxes.
+ """
+ for i, (boxes, image) in enumerate(zip(boxes_list, images)):
+ if self.direction == "horizontal":
+ im_width = image.shape[2]
+ tmp = im_width - boxes[..., 2::4]
+ boxes[..., 2::4] = im_width - boxes[..., 0::4]
+ boxes[..., 0::4] = tmp
+ elif self.direction == "vertical":
+ im_height = image.shape[1]
+ tmp = im_height - boxes[..., 3::4]
+ boxes[..., 3::4] = im_height - boxes[..., 1::4]
+ boxes[..., 1::4] = tmp
+ boxes_list[i] = boxes
+ return boxes_list
+
+
+@Transform(K.seg_masks, K.seg_masks)
+class FlipSegMasks:
+ """Flip segmentation masks."""
+
+ def __init__(self, direction: str = "horizontal"):
+ """Creates an instance of FlipSemanticMasks.
+
+ Args:
+ direction (str, optional): Either vertical or horizontal.
+ Defaults to "horizontal".
+
+ Raises:
+ ValueError: If direction is not horizontal or vertical.
+ """
+ if direction not in ["horizontal", "vertical"]:
+ raise ValueError(f"Direction {direction} not known!")
+ self.direction = direction
+
+ def __call__(self, masks: list[NDArrayUI8]) -> list[NDArrayUI8]:
+ """Execute flipping op.
+
+ Args:
+ masks (NDArrayUI8): [H, W] array of masks.
+
+ Returns:
+ list[NDArrayUI8]: [H, W] array of flipped masks.
+ """
+ for i, mask in enumerate(masks):
+ mask_ = torch.from_numpy(mask)
+ if self.direction == "horizontal":
+ mask = mask_.flip(1).numpy()
+ if self.direction == "vertical":
+ mask = mask_.flip(0).numpy()
+ masks[i] = mask
+ return masks
+
+
+@Transform(K.depth_maps, K.depth_maps)
+class FlipDepthMaps:
+ """Flip depth map."""
+
+ def __init__(self, direction: str = "horizontal"):
+ """Creates an instance of FlipDepth.
+
+ Args:
+ direction (str, optional): Either vertical or horizontal.
+ Defaults to "horizontal".
+ """
+ self.direction = direction
+ if direction not in ["horizontal", "vertical"]:
+ raise ValueError(f"Direction {self.direction} not known!")
+
+ def __call__(self, depths: list[NDArrayF32]) -> list[NDArrayF32]:
+ """Execute flipping op.
+
+ Args:
+ depths (list[NDArrayF32]): Each is a [H, W] array of depth.
+
+ Returns:
+ list[NDArrayF32]: Each is a [H, W] array of flipped depth.
+ """
+ for i, depth in enumerate(depths):
+ depth_ = torch.from_numpy(depth)
+ if self.direction == "horizontal":
+ depths[i] = depth_.flip(1).numpy()
+ if self.direction == "vertical":
+ depths[i] = depth_.flip(0).numpy()
+
+ return depths
+
+
+@Transform(K.optical_flows, K.optical_flows)
+class FlipOpticalFlows:
+ """Flip optical flow map."""
+
+ def __init__(self, direction: str = "horizontal"):
+ """Creates an instance of FlipOpticalFlow.
+
+ Args:
+ direction (str, optional): Either vertical or horizontal.
+ Defaults to "horizontal".
+ """
+ self.direction = direction
+ if direction not in ["horizontal", "vertical"]:
+ raise ValueError(f"Direction {self.direction} not known!")
+
+ def __call__(self, flows: list[NDArrayF32]) -> list[NDArrayF32]:
+ """Execute flipping op.
+
+ Args:
+ flows (NDArrayF32): Each is a [H, W, 2] array of optical flow.
+
+ Returns:
+ list[NDArrayF32]: Each is a [H, W, 2] array of flipped optical
+ flow.
+ """
+ for i, flow in enumerate(flows):
+ flow_ = torch.from_numpy(flow)
+ if self.direction == "horizontal":
+ image_flipped = flow_.flip(1).numpy()
+ image_flipped[..., 0] *= -1
+ flows[i] = image_flipped
+ if self.direction == "vertical":
+ image_flipped = flow_.flip(0).numpy()
+ image_flipped[..., 1] *= -1
+ flows[i] = image_flipped
+ return flows
+
+
+@Transform(K.instance_masks, K.instance_masks)
+class FlipInstanceMasks:
+ """Flip instance masks."""
+
+ def __init__(self, direction: str = "horizontal"):
+ """Creates an instance of FlipInstanceMasks.
+
+ Args:
+ direction (str, optional): Either vertical or horizontal.
+ Defaults to "horizontal".
+
+ Raises:
+ ValueError: If direction is not horizontal or vertical.
+ """
+ if direction not in ["horizontal", "vertical"]:
+ raise ValueError(f"Direction {direction} not known!")
+ self.direction = direction
+
+ def __call__(self, masks: list[NDArrayUI8]) -> list[NDArrayUI8]:
+ """Execute flipping op.
+
+ Args:
+ masks (list[NDArrayUI8]): List of [N, H, W] array of masks.
+
+ Returns:
+ list[NDArrayUI8]: List of [N, H, W] array of flipped masks.
+ """
+ for i, mask in enumerate(masks):
+ mask_ = torch.from_numpy(mask)
+ if self.direction == "horizontal":
+ mask = mask_.flip(2).numpy()
+ if self.direction == "vertical":
+ mask = mask_.flip(1).numpy()
+ masks[i] = mask
+ return masks
+
+
+def get_axis(direction: str, axis_mode: AxisMode) -> int:
+ """Get axis number of certain direction given axis mode.
+
+ Args:
+ direction (str): One of horizontal, vertical and lateral.
+ axis_mode (AxisMode): axis mode.
+
+ Returns:
+ int: Number of axis in certain direction.
+ """
+ if direction not in {"horizontal", "lateral", "vertical"}:
+ raise ValueError(f"Direction {direction} not known!")
+ coord_mapping = {
+ AxisMode.ROS: {"horizontal": 0, "lateral": 1, "vertical": 2},
+ AxisMode.OPENCV: {"horizontal": 0, "vertical": 1, "lateral": 2},
+ }
+ return coord_mapping[axis_mode][direction]
+
+
+@Transform(in_keys=(K.boxes3d, K.axis_mode), out_keys=(K.boxes3d,))
+class FlipBoxes3D:
+ """Flip 3D bounding box array."""
+
+ def __init__(self, direction: str = "horizontal"):
+ """Creates an instance of FlipBoxes3D.
+
+ Args:
+ direction (str, optional): Either vertical or horizontal.
+ Defaults to "horizontal".
+ """
+ self.direction = direction
+
+ def __call__(
+ self, boxes_list: list[NDArrayF32], axis_mode_list: list[AxisMode]
+ ) -> list[NDArrayF32]:
+ """Execute flipping."""
+ for i, (boxes, axis_mode) in enumerate(
+ zip(boxes_list, axis_mode_list)
+ ):
+ axis = get_axis(self.direction, axis_mode)
+ angle_dir = (
+ "vertical" if self.direction == "horizontal" else "lateral"
+ )
+ angles_axis = get_axis(angle_dir, axis_mode)
+ boxes[:, axis] *= -1.0
+ angles = matrix_to_euler_angles(
+ quaternion_to_matrix(torch.from_numpy(boxes[:, 6:]))
+ )
+ angles[:, angles_axis] = np.pi - angles[:, angles_axis]
+ boxes[:, 6:] = matrix_to_quaternion(
+ euler_angles_to_matrix(angles)
+ ).numpy()
+
+ boxes_list[i] = boxes
+
+ return boxes_list
+
+
+@Transform(in_keys=(K.points3d, K.axis_mode), out_keys=(K.points3d,))
+class FlipPoints3D:
+ """Flip pointcloud array."""
+
+ def __init__(self, direction: str = "horizontal"):
+ """Creates an instance of FlipBoxes2D.
+
+ Args:
+ direction (str, optional): Either vertical or horizontal.
+ Defaults to "horizontal".
+ """
+ self.direction = direction
+
+ def __call__(
+ self, points3d_list: list[NDArrayF32], axis_mode_list: list[AxisMode]
+ ) -> list[NDArrayF32]:
+ """Execute flipping."""
+ for i, (points3d, axis_mode) in enumerate(
+ zip(points3d_list, axis_mode_list)
+ ):
+ points3d[:, get_axis(self.direction, axis_mode)] *= -1.0
+ points3d_list[i] = points3d
+ return points3d_list
+
+
+@Transform(in_keys=(K.intrinsics, K.images), out_keys=(K.intrinsics,))
+class FlipIntrinsics:
+ """Modify intrinsics for image flip."""
+
+ def __init__(self, direction: str = "horizontal"):
+ """Creates an instance of FlipIntrinsics.
+
+ Args:
+ direction (str, optional): Either vertical or horizontal.
+ Defaults to "horizontal".
+
+ Raises:
+ ValueError: If direction is not horizontal or vertical.
+ """
+ if direction not in ["horizontal", "vertical"]:
+ raise ValueError(f"Direction {direction} not known!")
+ self.direction = direction
+
+ def __call__(
+ self, intrinsics_list: list[NDArrayF32], images: list[NDArrayF32]
+ ) -> list[NDArrayF32]:
+ """Execute flipping."""
+ for i, (intrinsics, image) in enumerate(zip(intrinsics_list, images)):
+ if self.direction == "horizontal":
+ center = image.shape[2] / 2
+ intrinsics[0, 2] = center - intrinsics[0, 2] + center
+ elif self.direction == "vertical":
+ center = image.shape[1] / 2
+ intrinsics[1, 2] = center - intrinsics[1, 2] + center
+ intrinsics_list[i] = intrinsics
+ return intrinsics_list
diff --git a/vis4d/data/transforms/mask.py b/vis4d/data/transforms/mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2b422b6483560b3e8c696b2dcf5cf87c1b2a849
--- /dev/null
+++ b/vis4d/data/transforms/mask.py
@@ -0,0 +1,74 @@
+"""Segmentation/Instance Mask Transform."""
+
+from __future__ import annotations
+
+import numpy as np
+
+from vis4d.common.typing import NDArrayI64, NDArrayUI8
+from vis4d.data.const import CommonKeys as K
+
+from .base import Transform
+
+
+@Transform(
+ in_keys=(K.boxes2d_classes, K.instance_masks),
+ out_keys=K.seg_masks,
+)
+class ConvertInstanceMaskToSegMask:
+ """Merge all instance masks into a single segmentation map."""
+
+ def __call__(
+ self, classes_list: list[NDArrayI64], masks_list: list[NDArrayUI8]
+ ) -> list[NDArrayUI8]:
+ """Execute conversion.
+
+ Args:
+ classes_list (list[NDArrayI64]): List of Array of class ids, shape
+ [N,].
+ masks_list (NDArrayU8): List of array of instance masks, shape
+ [N, H, W].
+
+ Returns:
+ list[NDArrayU8]: List of Segmentation mask, shape [H, W].
+ """
+ seg_masks = []
+ for classes, masks in zip(classes_list, masks_list):
+ classes = np.asarray(classes, dtype=masks.dtype)
+ target = np.max(masks * classes[:, None, None], axis=0)
+ # discard overlapping instances
+ target[np.sum(masks, axis=0) > 1] = 255
+
+ seg_masks.append(target)
+ return seg_masks
+
+
+@Transform(
+ in_keys=K.boxes2d_classes,
+ out_keys=K.boxes2d_classes,
+)
+class RemappingCategories:
+ """Remap classes using a mapping list."""
+
+ def __init__(self, mapping: list[int]):
+ """Initialize remapping.
+
+ Args:
+ mapping (List[int]): List of class ids, such that classes will be
+ mapped to its location in the list.
+ """
+ self.mapping = mapping
+
+ def __call__(self, classes_list: list[NDArrayI64]) -> list[NDArrayI64]:
+ """Execute remapping.
+
+ Args:
+ classes_list (list[NDArrayI64]): List of array of class ids, shape
+ [N,].
+
+ Returns:
+ list[NDArrayI64]: List of array of remapped class ids, shape [N,].
+ """
+ for i, classes in enumerate(classes_list):
+ for j, class_ in enumerate(classes):
+ classes_list[i][j] = self.mapping.index(class_)
+ return classes_list
diff --git a/vis4d/data/transforms/mixup.py b/vis4d/data/transforms/mixup.py
new file mode 100644
index 0000000000000000000000000000000000000000..584d45462261c90b14a6eaa337ddb077713a41c0
--- /dev/null
+++ b/vis4d/data/transforms/mixup.py
@@ -0,0 +1,376 @@
+"""Mixup data augmentation."""
+
+from __future__ import annotations
+
+import random
+from typing import TypedDict
+
+import numpy as np
+import torch
+
+from vis4d.common.typing import NDArrayF32, NDArrayI64
+from vis4d.data.const import CommonKeys as K
+from vis4d.op.box.box2d import bbox_intersection
+
+from .base import Transform
+from .resize import get_resize_shape, resize_image
+
+
+class MixupParam(TypedDict):
+ """Typed dict for mixup parameters.
+
+ The parameters are used to mixup a pair of items in a batch. Usually, the
+ pairs are selected as follows:
+ (0, bs - 1), (1, bs - 2), ..., (bs // 2, bs // 2)
+ where bs is the batch size. The batch size must be even for mixup to work.
+ """
+
+ ratio: float
+ im_shape: tuple[int, int]
+ im_scale: tuple[float, float]
+ other_ori_hw: tuple[int, int]
+ other_new_hw: tuple[int, int]
+ crop_coord: tuple[int, int, int, int]
+ pad_hw: tuple[int, int]
+ pad_value: float
+
+
+@Transform(in_keys=(K.images,), out_keys=("transforms.mixup",))
+class GenMixupParameters:
+ """Generate the parameters for a mixup operation."""
+
+ NUM_SAMPLES = 2
+
+ def __init__(
+ self,
+ out_shape: tuple[int, int],
+ mixup_ratio_dist: str = "beta",
+ alpha: float = 1.0,
+ const_ratio: float = 0.5,
+ scale_range: tuple[float, float] = (1.0, 1.0),
+ pad_value: float = 0.0,
+ ) -> None:
+ """Init function.
+
+ Args:
+ out_shape (tuple[int, int]): Output shape of the mixed up images.
+ mixup_ratio_dist (str, optional): Distribution for sampling the
+ mixup ratio (i.e., lambda). Options are "beta" and "const".
+ Defaults to "beta". If "const", the mixup ratio will be fixed
+ to the value of `const_ratio`. Otherwise, the mixup ratio will
+ be sampled from a beta distribution with parameters `alpha`.
+ alpha (float, optional): Parameter for beta distribution used for
+ sampling the mixup ratio (i.e., lambda). Defaults to 1.0.
+ const_ratio (float, optional): Constant mixup ratio. Defaults to
+ 0.5.
+ scale_range (tuple[float, float], optional): Range for
+ random scale jitter. Defaults to (1.0, 1.0).
+ pad_value (float, optional): Value for padding the mixed up image.
+ Defaults to 0.0.
+ """
+ assert mixup_ratio_dist in {
+ "beta",
+ "const",
+ }, "Mixup ratio distribution must be either 'beta' or 'const'."
+ self.out_shape = out_shape
+ self.mixup_ratio_dist = mixup_ratio_dist
+ self.alpha = alpha
+ self.const_ratio = const_ratio
+ self.scale_range = scale_range
+ self.pad_value = pad_value
+
+ def __call__(self, images: list[NDArrayF32]) -> list[MixupParam]:
+ """Generate parameters for MixUp."""
+ batch_size = len(images)
+ assert batch_size % 2 == 0, "MixUp only supports even batch size."
+
+ if self.mixup_ratio_dist == "beta":
+ ratio = np.random.beta(self.alpha, self.alpha)
+ else:
+ ratio = self.const_ratio
+ jit_factor = random.uniform(*self.scale_range)
+
+ h, w = self.out_shape
+ ori_img, other_img = images[0], images[1]
+ ori_h, ori_w = ori_img.shape[1], ori_img.shape[2]
+ other_ori_h, other_ori_w = other_img.shape[1], other_img.shape[2]
+ other_ori_hw = (other_ori_h, other_ori_w)
+ h_i, w_i = get_resize_shape(other_ori_hw, (h, w), keep_ratio=True)
+ h_i, w_i = int(jit_factor * h_i), int(jit_factor * w_i)
+ pad_shape = (max(h_i, ori_h), max(w_i, ori_w))
+
+ x_offset, y_offset = 0, 0
+ if pad_shape[0] > ori_h:
+ y_offset = random.randint(0, pad_shape[0] - ori_h)
+ if pad_shape[1] > ori_w:
+ x_offset = random.randint(0, pad_shape[1] - ori_w)
+
+ parameter_list = [
+ MixupParam(
+ ratio=ratio,
+ im_scale=(h_i / other_ori_h, w_i / other_ori_w),
+ im_shape=(h_i, w_i),
+ other_ori_hw=other_ori_hw,
+ other_new_hw=(min(h_i, ori_h), min(w_i, ori_w)),
+ pad_hw=pad_shape,
+ pad_value=self.pad_value,
+ crop_coord=(
+ x_offset,
+ y_offset,
+ x_offset + ori_w,
+ y_offset + ori_h,
+ ),
+ )
+ for _ in range(batch_size)
+ ]
+ return parameter_list
+
+
+@Transform(in_keys=(K.images, "transforms.mixup"), out_keys=(K.images,))
+class MixupImages:
+ """Mixup a batch of images."""
+
+ NUM_SAMPLES = 2
+
+ def __init__(
+ self, interpolation: str = "bilinear", imresize_backend: str = "torch"
+ ) -> None:
+ """Init function.
+
+ Args:
+ interpolation (str, optional): Interpolation method for resizing
+ the other image. Defaults to "bilinear".
+ imresize_backend (str): One of torch, cv2. Defaults to torch.
+ """
+ self.interpolation = interpolation
+ self.imresize_backend = imresize_backend
+ assert imresize_backend in {
+ "torch",
+ "cv2",
+ }, f"Invalid imresize backend: {imresize_backend}"
+
+ def __call__(
+ self, images: list[NDArrayF32], mixup_parameters: list[MixupParam]
+ ) -> list[NDArrayF32]:
+ """Execute image mixup operation."""
+ batch_size = len(images)
+ assert (
+ batch_size % self.NUM_SAMPLES == 0
+ ), "Batch size must be even for mixup!"
+
+ mixup_images = []
+ for i in range(0, batch_size, self.NUM_SAMPLES):
+ j = i + 1
+ ori_img, other_img = images[i], images[j]
+ h_i, w_i = mixup_parameters[i]["im_shape"]
+ c = ori_img.shape[-1]
+
+ # resize, scale jitter other image
+ other_img = resize_image(
+ other_img,
+ (h_i, w_i),
+ self.interpolation,
+ backend=self.imresize_backend,
+ )
+
+ # pad, optionally random crop other image
+ padded_img = np.full(
+ (1, *mixup_parameters[i]["pad_hw"], c),
+ fill_value=mixup_parameters[i]["pad_value"],
+ dtype=np.float32,
+ )
+ padded_img[:, :h_i, :w_i, :] = other_img
+ x1_c, y1_c, x2_c, y2_c = mixup_parameters[i]["crop_coord"]
+ padded_cropped_img = padded_img[:, y1_c:y2_c, x1_c:x2_c, :]
+
+ # mix ori and other
+ ratio = mixup_parameters[i]["ratio"]
+ mixup_image = ratio * ori_img + (1 - ratio) * padded_cropped_img
+ mixup_images += [mixup_image for _ in range(self.NUM_SAMPLES)]
+ return mixup_images
+
+
+@Transform(
+ in_keys=(K.categories, "transforms.mixup"), out_keys=(K.categories,)
+)
+class MixupCategories:
+ """Mixup a batch of categories."""
+
+ NUM_SAMPLES = 2
+
+ def __init__(self, num_classes: int, label_smoothing: float = 0.1) -> None:
+ """Creates an instance of MixupCategories.
+
+ Args:
+ num_classes (int): Number of classes.
+ label_smoothing (float, optional): Label smoothing parameter for
+ the mixup of categories. Defaults to 0.1.
+ """
+ self.num_classes = num_classes
+ self.label_smoothing = label_smoothing
+
+ def _label_smoothing(
+ self,
+ cat_1: NDArrayF32,
+ cat_2: NDArrayF32,
+ ratio: float,
+ ) -> NDArrayF32:
+ """Apply label smoothing to two category labels."""
+ lam = np.array(ratio, dtype=np.float32)
+ off_value = np.array(
+ self.label_smoothing / self.num_classes, dtype=np.float32
+ )
+ on_value = np.array(
+ 1 - self.label_smoothing + off_value, dtype=np.float32
+ )
+ categories_1: NDArrayF32 = (
+ np.zeros((self.num_classes,), dtype=np.float32) + off_value
+ )
+ categories_2: NDArrayF32 = (
+ np.zeros((self.num_classes,), dtype=np.float32) + off_value
+ )
+ categories_1 = cat_1 * on_value
+ categories_2 = cat_2 * on_value
+ mixed = categories_1 * lam + categories_2 * (1 - lam)
+ return mixed.astype(np.float32)
+
+ def __call__(
+ self,
+ categories: list[NDArrayF32],
+ mixup_parameters: list[MixupParam],
+ ) -> list[NDArrayF32]:
+ """Execute the categories mixup operation."""
+ batch_size = len(categories)
+ assert (
+ batch_size % self.NUM_SAMPLES == 0
+ ), "Batch size must be even for mixup!"
+
+ smooth_categories = [np.empty(0, dtype=np.float32)] * batch_size
+ for i in range(0, batch_size, self.NUM_SAMPLES):
+ j = i + 1
+ smooth_categories[i] = self._label_smoothing(
+ categories[i], categories[j], mixup_parameters[i]["ratio"]
+ )
+ smooth_categories[j] = smooth_categories[i]
+ return smooth_categories
+
+
+@Transform(
+ in_keys=(
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.boxes2d_track_ids,
+ "transforms.mixup",
+ ),
+ out_keys=(K.boxes2d, K.boxes2d_classes, K.boxes2d_track_ids),
+)
+class MixupBoxes2D:
+ """Mixup a batch of boxes."""
+
+ NUM_SAMPLES = 2
+
+ def __init__(
+ self, clip_inside_image: bool = True, max_track_ids: int = 1000
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ clip_inside_image (bool): Whether to clip the boxes to be inside
+ the image. Defaults to True.
+ max_track_ids (int): The maximum number of track ids. Defaults to
+ 1000.
+ """
+ self.clip_inside_image = clip_inside_image
+ self.max_track_ids = max_track_ids
+
+ def __call__(
+ self,
+ boxes_list: list[NDArrayF32],
+ classes_list: list[NDArrayI64],
+ track_ids_list: list[NDArrayI64] | None,
+ mixup_parameters: list[MixupParam],
+ ) -> tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None]:
+ """Execute the boxes2d mixup operation."""
+ batch_size = len(boxes_list)
+ assert (
+ batch_size % self.NUM_SAMPLES == 0
+ ), "Batch size must be even for mixup!"
+
+ mixup_boxes_list = []
+ mixup_classes_list = []
+ mixup_track_ids_list: list[NDArrayI64] | None = (
+ [] if track_ids_list is not None else None
+ )
+ for i in range(0, batch_size, self.NUM_SAMPLES):
+ j = i + 1
+ ori_boxes, other_boxes = boxes_list[i].copy(), boxes_list[j].copy()
+ ori_classes, other_classes = (
+ classes_list[i].copy(),
+ classes_list[j].copy(),
+ )
+
+ crop_coord = mixup_parameters[i]["crop_coord"]
+ im_scale = mixup_parameters[i]["im_scale"]
+ x1_c, y1_c, _, _ = crop_coord
+
+ if len(other_boxes) == 0:
+ continue
+ # adjust boxes to new image size and origin coord
+ other_boxes[:, [0, 2]] = (
+ im_scale[1] * other_boxes[:, [0, 2]] - x1_c
+ )
+ other_boxes[:, [1, 3]] = (
+ im_scale[0] * other_boxes[:, [1, 3]] - y1_c
+ )
+ # filter boxes outside other image
+ crop_box = torch.tensor(crop_coord).unsqueeze(0)
+ is_overlap = (
+ bbox_intersection(torch.from_numpy(other_boxes), crop_box)
+ .squeeze(-1)
+ .numpy()
+ )
+ other_boxes = other_boxes[is_overlap > 0]
+ other_classes = other_classes[is_overlap > 0]
+
+ # mixup track ids if available
+ if track_ids_list is not None:
+ assert mixup_track_ids_list is not None
+ ori_track_ids = track_ids_list[i].copy()
+ other_track_ids = track_ids_list[j].copy()
+ if (
+ len(ori_track_ids) > 0
+ and max(ori_track_ids) >= self.max_track_ids
+ ) or (
+ len(other_track_ids) > 0
+ and max(other_track_ids) >= self.max_track_ids
+ ):
+ raise ValueError(
+ f"Track id exceeds maximum track id"
+ f"{self.max_track_ids}!"
+ )
+ other_track_ids += self.max_track_ids
+ other_track_ids = other_track_ids[is_overlap > 0]
+ mixup_track_ids: NDArrayI64 = np.concatenate(
+ (ori_track_ids, other_track_ids), 0
+ )
+ mixup_track_ids_list += [
+ mixup_track_ids for _ in range(self.NUM_SAMPLES)
+ ]
+
+ if self.clip_inside_image:
+ new_h, new_w = mixup_parameters[i]["other_new_hw"]
+ other_boxes[:, [0, 2]] = np.clip(
+ other_boxes[:, [0, 2]], 0, new_w
+ )
+ other_boxes[:, [1, 3]] = np.clip(
+ other_boxes[:, [1, 3]], 0, new_h
+ )
+ mixup_boxes = np.concatenate((ori_boxes, other_boxes), axis=0)
+ mixup_classes = np.concatenate(
+ (ori_classes, other_classes), axis=0
+ )
+ mixup_boxes_list += [mixup_boxes for _ in range(self.NUM_SAMPLES)]
+ mixup_classes_list += [
+ mixup_classes for _ in range(self.NUM_SAMPLES)
+ ]
+ return mixup_boxes_list, mixup_classes_list, mixup_track_ids_list
diff --git a/vis4d/data/transforms/mosaic.py b/vis4d/data/transforms/mosaic.py
new file mode 100644
index 0000000000000000000000000000000000000000..60845b868c4408c7ecc48516006709a3fbf1bbd0
--- /dev/null
+++ b/vis4d/data/transforms/mosaic.py
@@ -0,0 +1,358 @@
+"""Mosaic transformation.
+
+Modified from mmdetection (https://github.com/open-mmlab/mmdetection).
+"""
+
+from __future__ import annotations
+
+import random
+from typing import TypedDict
+
+import numpy as np
+
+from vis4d.common.typing import NDArrayF32, NDArrayI64
+from vis4d.data.const import CommonKeys as K
+
+from .base import Transform
+from .crop import _get_keep_mask
+from .resize import resize_image
+
+
+class MosaicParam(TypedDict):
+ """Parameters for Mosaic."""
+
+ out_shape: tuple[int, int]
+ paste_coords: list[tuple[int, int, int, int]]
+ crop_coords: list[tuple[int, int, int, int]]
+ im_shapes: list[tuple[int, int]]
+ im_scales: list[tuple[float, float]]
+
+
+def mosaic_combine(
+ index: int,
+ center: tuple[int, int],
+ im_hw: tuple[int, int],
+ out_shape: tuple[int, int],
+) -> tuple[tuple[int, int, int, int], tuple[int, int, int, int]]:
+ """Compute the mosaic parameters for the image at the current index.
+
+ Index:
+ 0 = top_left, 1 = top_right, 3 = bottom_left, 4 = bottom_right
+ """
+ assert index in {0, 1, 2, 3}
+ if index == 0:
+ # index0 to top left part of image
+ x1, y1, x2, y2 = (
+ max(center[1] - im_hw[1], 0),
+ max(center[0] - im_hw[0], 0),
+ center[1],
+ center[0],
+ )
+ crop_coord = (
+ im_hw[1] - (x2 - x1),
+ im_hw[0] - (y2 - y1),
+ im_hw[1],
+ im_hw[0],
+ )
+ elif index == 1:
+ # index1 to top right part of image
+ x1, y1, x2, y2 = (
+ center[1],
+ max(center[0] - im_hw[0], 0),
+ min(center[1] + im_hw[1], out_shape[1] * 2),
+ center[0],
+ )
+ crop_coord = (
+ 0,
+ im_hw[0] - (y2 - y1),
+ min(im_hw[1], x2 - x1),
+ im_hw[0],
+ )
+ elif index == 2:
+ # index2 to bottom left part of image
+ x1, y1, x2, y2 = (
+ max(center[1] - im_hw[1], 0),
+ center[0],
+ center[1],
+ min(out_shape[0] * 2, center[0] + im_hw[0]),
+ )
+ crop_coord = (
+ im_hw[1] - (x2 - x1),
+ 0,
+ im_hw[1],
+ min(y2 - y1, im_hw[0]),
+ )
+ else:
+ # index3 to bottom right part of image
+ x1, y1, x2, y2 = (
+ center[1],
+ center[0],
+ min(center[1] + im_hw[1], out_shape[1] * 2),
+ min(out_shape[0] * 2, center[0] + im_hw[0]),
+ )
+ crop_coord = 0, 0, min(im_hw[1], x2 - x1), min(y2 - y1, im_hw[0])
+
+ paste_coord = x1, y1, x2, y2
+ return paste_coord, crop_coord
+
+
+@Transform(K.input_hw, ["transforms.mosaic"])
+class GenMosaicParameters:
+ """Generate the parameters for a mosaic operation.
+
+ Given 4 images, mosaic transform combines them into
+ one output image. The output image is composed of the parts from each sub-
+ image.
+
+ mosaic transform
+ center_x
+ +------------------------------+
+ | pad | pad |
+ | +-----------+ |
+ | | | |
+ | | image1 |--------+ |
+ | | | | |
+ | | | image2 | |
+ center_y |----+-------------+-----------|
+ | | cropped | |
+ |pad | image3 | image4 |
+ | | | |
+ +----|-------------+-----------+
+ | |
+ +-------------+
+
+ The mosaic transform steps are as follows:
+
+ 1. Choose the mosaic center as the intersections of 4 images.
+ 2. Get the left top image according to the index, and randomly
+ sample another 3 images from the dataset.
+ 3. Sub image will be cropped if image is larger than mosaic patch.
+
+ Args:
+ out_shape (tuple[int, int]): The output shape of the mosaic transform.
+ center_ratio_range (tuple[float, float]): The range of the ratio of
+ the center of the mosaic patch to the output image size.
+ """
+
+ NUM_SAMPLES = 4
+
+ def __init__(
+ self,
+ out_shape: tuple[int, int],
+ center_ratio_range: tuple[float, float] = (0.5, 1.5),
+ ) -> None:
+ """Creates an instance of the class."""
+ self.out_shape = out_shape
+ self.center_ratio_range = center_ratio_range
+
+ def __call__(self, input_hw: list[tuple[int, int]]) -> list[MosaicParam]:
+ """Compute the parameters and put them in the data dict."""
+ assert (
+ len(input_hw) % self.NUM_SAMPLES == 0
+ ), "Input number of images must be a multiple of 4 for Mosaic."
+ h, w = self.out_shape
+ # mosaic center x, y
+ center_y = int(random.uniform(*self.center_ratio_range) * h)
+ center_x = int(random.uniform(*self.center_ratio_range) * w)
+ center = (center_y, center_x)
+
+ mosaic_params = []
+ for i in range(0, len(input_hw), self.NUM_SAMPLES):
+ paste_coords, crop_coords, im_scales, im_shapes = [], [], [], []
+ for idx, ori_hw in enumerate(input_hw[i : i + self.NUM_SAMPLES]):
+ # compute the resize shape
+ scale_ratio_i = min(h / ori_hw[0], w / ori_hw[1])
+ h_i = int(ori_hw[0] * scale_ratio_i)
+ w_i = int(ori_hw[1] * scale_ratio_i)
+
+ # compute the combine parameters
+ paste_coord, crop_coord = mosaic_combine(
+ idx, center, (h_i, w_i), self.out_shape
+ )
+ paste_coords.append(paste_coord)
+ crop_coords.append(crop_coord)
+ im_shapes.append((h_i, w_i))
+ im_scales.append((scale_ratio_i, scale_ratio_i))
+ mosaic_params += [
+ MosaicParam(
+ out_shape=self.out_shape,
+ paste_coords=paste_coords,
+ crop_coords=crop_coords,
+ im_shapes=im_shapes,
+ im_scales=im_scales,
+ )
+ for _ in range(self.NUM_SAMPLES)
+ ]
+
+ return mosaic_params
+
+
+@Transform(
+ in_keys=[
+ K.images,
+ "transforms.mosaic.out_shape",
+ "transforms.mosaic.paste_coords",
+ "transforms.mosaic.crop_coords",
+ "transforms.mosaic.im_shapes",
+ ],
+ out_keys=[K.images, K.input_hw],
+)
+class MosaicImages:
+ """Apply Mosaic to images."""
+
+ NUM_SAMPLES = 4
+
+ def __init__(
+ self,
+ pad_value: float = 114.0,
+ interpolation: str = "bilinear",
+ imresize_backend: str = "torch",
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ pad_value (float): The value to pad the image with. Defaults to
+ 114.0.
+ interpolation (str): Interpolation mode for resizing image.
+ Defaults to bilinear.
+ imresize_backend (str): One of torch, cv2. Defaults to torch.
+ """
+ self.pad_value = pad_value
+ self.interpolation = interpolation
+ self.imresize_backend = imresize_backend
+ assert imresize_backend in {
+ "torch",
+ "cv2",
+ }, f"Invalid imresize backend: {imresize_backend}"
+
+ def __call__(
+ self,
+ images: list[NDArrayF32],
+ out_shape: list[tuple[int, int]],
+ paste_coords: list[list[tuple[int, int, int, int]]],
+ crop_coords: list[list[tuple[int, int, int, int]]],
+ im_shapes: list[list[tuple[int, int]]],
+ ) -> tuple[list[NDArrayF32], list[tuple[int, int]]]:
+ """Resize an image of dimensions [N, H, W, C]."""
+ h, w = out_shape[0]
+ c = images[0].shape[-1]
+
+ mosaic_imgs = []
+ for i in range(0, len(images), self.NUM_SAMPLES):
+ mosaic_img = np.full(
+ (1, h * 2, w * 2, c), self.pad_value, dtype=np.float32
+ )
+ for idx, img in enumerate(images[i : i + self.NUM_SAMPLES]):
+ # resize current image
+ h_i, w_i = im_shapes[i][idx]
+ img_ = resize_image(
+ img,
+ (h_i, w_i),
+ self.interpolation,
+ backend=self.imresize_backend,
+ )
+
+ x1_p, y1_p, x2_p, y2_p = paste_coords[i][idx]
+ x1_c, y1_c, x2_c, y2_c = crop_coords[i][idx]
+
+ # crop and paste image
+ mosaic_img[:, y1_p:y2_p, x1_p:x2_p, :] = img_[
+ :, y1_c:y2_c, x1_c:x2_c, :
+ ]
+ mosaic_imgs += [mosaic_img for _ in range(self.NUM_SAMPLES)]
+ return mosaic_imgs, [(m.shape[1], m.shape[2]) for m in mosaic_imgs]
+
+
+@Transform(
+ in_keys=[
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.boxes2d_track_ids,
+ "transforms.mosaic.paste_coords",
+ "transforms.mosaic.crop_coords",
+ "transforms.mosaic.im_scales",
+ ],
+ out_keys=[K.boxes2d, K.boxes2d_classes, K.boxes2d_track_ids],
+)
+class MosaicBoxes2D:
+ """Apply Mosaic to a list of 2D bounding boxes."""
+
+ NUM_SAMPLES = 4
+
+ def __init__(
+ self, clip_inside_image: bool = True, max_track_ids: int = 1000
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ clip_inside_image (bool): Whether to clip the boxes to be inside
+ the image. Defaults to True.
+ max_track_ids (int): The maximum number of track ids. Defaults to
+ 1000.
+ """
+ self.clip_inside_image = clip_inside_image
+ self.max_track_ids = max_track_ids
+
+ def __call__(
+ self,
+ boxes: list[NDArrayF32],
+ classes: list[NDArrayI64],
+ track_ids: list[NDArrayI64] | None,
+ paste_coords: list[list[tuple[int, int, int, int]]],
+ crop_coords: list[list[tuple[int, int, int, int]]],
+ im_scales: list[list[tuple[float, float]]],
+ ) -> tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None]:
+ """Apply Mosaic to 2D bounding boxes."""
+ new_boxes, new_classes = [], []
+ new_track_ids: list[NDArrayI64] | None = (
+ [] if track_ids is not None else None
+ )
+ for i in range(0, len(boxes), self.NUM_SAMPLES):
+ for idx in range(self.NUM_SAMPLES):
+ j = i + idx
+
+ x1_p, y1_p, x2_p, y2_p = paste_coords[i][idx]
+ x1_c, y1_c, _, _ = crop_coords[i][idx]
+
+ pw = x1_p - x1_c
+ ph = y1_p - y1_c
+ boxes[j][:, [0, 2]] = (
+ im_scales[i][idx][1] * boxes[j][:, [0, 2]] + pw
+ )
+ boxes[j][:, [1, 3]] = (
+ im_scales[i][idx][0] * boxes[j][:, [1, 3]] + ph
+ )
+
+ keep_mask = _get_keep_mask(
+ boxes[j], np.array([x1_p, y1_p, x2_p, y2_p])
+ )
+ boxes[j] = boxes[j][keep_mask]
+ classes[j] = classes[j][keep_mask]
+ if track_ids is not None:
+ track_ids[j] = track_ids[j][keep_mask].copy()
+ if len(track_ids[j]) > 0:
+ if max(track_ids[j]) >= self.max_track_ids:
+ raise ValueError(
+ f"Track id exceeds maximum track id"
+ f"{self.max_track_ids}!"
+ )
+ track_ids[j] += self.max_track_ids * idx
+
+ if self.clip_inside_image:
+ boxes[j][:, [0, 2]] = boxes[j][:, [0, 2]].clip(x1_p, x2_p)
+ boxes[j][:, [1, 3]] = boxes[j][:, [1, 3]].clip(y1_p, y2_p)
+ new_boxes += [
+ np.concatenate(boxes[i : i + self.NUM_SAMPLES])
+ for _ in range(self.NUM_SAMPLES)
+ ]
+ new_classes += [
+ np.concatenate(classes[i : i + self.NUM_SAMPLES])
+ for _ in range(self.NUM_SAMPLES)
+ ]
+ if track_ids is not None:
+ assert new_track_ids is not None
+ new_track_ids += [
+ np.concatenate(track_ids[i : i + self.NUM_SAMPLES])
+ for _ in range(self.NUM_SAMPLES)
+ ]
+ return new_boxes, new_classes, new_track_ids
diff --git a/vis4d/data/transforms/normalize.py b/vis4d/data/transforms/normalize.py
new file mode 100644
index 0000000000000000000000000000000000000000..79d3845def8b6ab103550794c6cb3cfabee8fd2e
--- /dev/null
+++ b/vis4d/data/transforms/normalize.py
@@ -0,0 +1,50 @@
+"""Normalize Transform."""
+
+from __future__ import annotations
+
+import torch
+
+from vis4d.common.typing import NDArrayF32
+
+from ..const import CommonKeys as K
+from .base import Transform
+
+
+@Transform(K.images, K.images)
+class NormalizeImages:
+ """Normalize a list of image tensor with given mean and std.
+
+ Image tensor is of shape [N, H, W, C] and range (0, 255).
+ """
+
+ def __init__(
+ self,
+ mean: tuple[float, float, float] = (123.675, 116.28, 103.53),
+ std: tuple[float, float, float] = (58.395, 57.12, 57.375),
+ epsilon: float = 1e-08,
+ ) -> None:
+ """Creates an instance of NormalizeImage.
+
+ Args:
+ mean (Tuple[float, float, float], optional): Mean value. Defaults
+ to (123.675, 116.28, 103.53).
+ std (Tuple[float, float, float], optional): Standard deviation
+ value. Defaults to (58.395, 57.12, 57.375).
+ epsilon (float, optional): Epsilon for numerical stability of
+ division. Defaults to 1e-08.
+ """
+ self.mean = mean
+ self.std = std
+ self.epsilon = epsilon
+
+ def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]:
+ """Normalize image tensor."""
+ for i, image in enumerate(images):
+ img = torch.from_numpy(image).permute(0, 3, 1, 2)
+ pixel_mean = torch.tensor(self.mean).view(-1, 1, 1)
+ pixel_std = torch.tensor(self.std).view(-1, 1, 1)
+ img = (img - pixel_mean) / (pixel_std + self.epsilon)
+
+ images[i] = img.permute(0, 2, 3, 1).numpy()
+
+ return images
diff --git a/vis4d/data/transforms/pad.py b/vis4d/data/transforms/pad.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda0f4c99341903adaa71c591ed65bfe8c52abf7
--- /dev/null
+++ b/vis4d/data/transforms/pad.py
@@ -0,0 +1,155 @@
+"""Pad transformation."""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from vis4d.common.typing import NDArrayF32, NDArrayUI8
+from vis4d.data.const import CommonKeys as K
+
+from .base import Transform
+
+
+@Transform(K.images, K.images)
+class PadImages:
+ """Pad batch of images at the bottom right."""
+
+ def __init__(
+ self,
+ stride: int = 32,
+ mode: str = "constant",
+ value: float = 0.0,
+ shape: tuple[int, int] | None = None,
+ pad2square: bool = False,
+ ) -> None:
+ """Creates an instance of PadImage.
+
+ Args:
+ stride (int, optional): Chooses padding size so that the input will
+ be divisible by stride. Defaults to 32.
+ mode (str, optional): Padding mode. One of constant, reflect,
+ replicate or circular. Defaults to "constant".
+ value (float, optional): Value for constant padding.
+ Defaults to 0.0.
+ shape (tuple[int, int], optional): Shape of the padded image
+ (H, W). Defaults to None.
+ pad2square (bool, optional): Pad to square. Defaults to False.
+ """
+ if pad2square:
+ assert (
+ shape is None
+ ), "Cannot specify shape when pad2square is True."
+ self.stride = stride
+ self.mode = mode
+ self.value = value
+ self.shape = shape
+ self.pad2square = pad2square
+
+ def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]:
+ """Pad images to consistent size."""
+ heights = [im.shape[1] for im in images]
+ widths = [im.shape[2] for im in images]
+ max_hw = _get_max_shape(
+ heights, widths, self.stride, self.shape, self.pad2square
+ )
+
+ # generate params for torch pad
+ for i, (image, h, w) in enumerate(zip(images, heights, widths)):
+ pad_param = (0, max_hw[1] - w, 0, max_hw[0] - h)
+ image_ = torch.from_numpy(image).permute(0, 3, 1, 2)
+ image_ = F.pad( # pylint: disable=not-callable
+ image_, pad_param, self.mode, self.value
+ )
+ images[i] = image_.permute(0, 2, 3, 1).numpy()
+ return images
+
+
+@Transform(K.seg_masks, K.seg_masks)
+class PadSegMasks:
+ """Pad batch of segmentation masks at the bottom right."""
+
+ def __init__(
+ self,
+ stride: int = 32,
+ mode: str = "constant",
+ value: int = 255,
+ shape: tuple[int, int] | None = None,
+ pad2square: bool = False,
+ ) -> None:
+ """Creates an instance of PadSegMasks.
+
+ Args:
+ stride (int, optional): Chooses padding size so that the input will
+ be divisible by stride. Defaults to 32.
+ mode (str, optional): Padding mode. One of constant, reflect,
+ replicate or circular. Defaults to "constant".
+ value (float, optional): Value for constant padding.
+ Defaults to 0.0.
+ shape (tuple[int, int], optional): Shape of the padded image
+ (H, W). Defaults to None.
+ pad2square (bool, optional): Pad to square. Defaults to False.
+ """
+ if pad2square:
+ assert (
+ shape is None
+ ), "Cannot specify shape when pad2square is True."
+ self.stride = stride
+ self.mode = mode
+ self.value = value
+ self.shape = shape
+ self.pad2square = pad2square
+
+ def __call__(self, masks: list[NDArrayUI8]) -> list[NDArrayUI8]:
+ """Pad images to consistent size."""
+ heights = [mask.shape[0] for mask in masks]
+ widths = [mask.shape[1] for mask in masks]
+ max_hw = _get_max_shape(
+ heights, widths, self.stride, self.shape, self.pad2square
+ )
+
+ # generate params for torch pad
+ for i, (mask, h, w) in enumerate(zip(masks, heights, widths)):
+ pad_param = ((0, max_hw[0] - h), (0, max_hw[1] - w))
+ masks[i] = np.pad( # type: ignore
+ mask, pad_param, mode=self.mode, constant_values=self.value
+ )
+ return masks
+
+
+def _get_max_shape(
+ heights: list[int],
+ widths: list[int],
+ stride: int,
+ shape: tuple[int, int] | None,
+ pad2square: bool,
+) -> tuple[int, int]:
+ """Get max shape for padding.
+
+ Args:
+ stride (int): Chooses padding size so that the input will be divisible
+ by stride.
+ shape (tuple[int, int], optional): Shape of the padded image (H, W).
+ Defaults to None.
+ heights (list[int]): List of heights of input.
+ widths (list[int]): List of widths of input.
+ pad2square (bool): Pad to square.
+
+ Returns:
+ tuple[int, int]: Max shape for padding.
+ """
+ if pad2square:
+ max_size = max(heights + widths)
+ max_hw = (max_size, max_size)
+ elif shape is not None:
+ max_hw = shape
+ else:
+ max_hw = max(heights), max(widths)
+ max_hw = tuple(_make_divisible(x, stride) for x in max_hw) # type: ignore # pylint: disable=line-too-long
+ return max_hw
+
+
+def _make_divisible(x: int, stride: int) -> int:
+ """Ensure divisibility by stride."""
+ return (x + (stride - 1)) // stride * stride
diff --git a/vis4d/data/transforms/photometric.py b/vis4d/data/transforms/photometric.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0abd454e25e1cea76b6b1dcf297e17c81f78947
--- /dev/null
+++ b/vis4d/data/transforms/photometric.py
@@ -0,0 +1,355 @@
+"""Photometric transforms."""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+
+import numpy as np
+import torch
+import torchvision.transforms.v2.functional as TF
+from torch import Tensor
+
+from vis4d.common.imports import OPENCV_AVAILABLE
+from vis4d.common.typing import NDArrayF32
+from vis4d.data.const import CommonKeys as K
+
+from .base import Transform
+
+if OPENCV_AVAILABLE:
+ import cv2
+else:
+ raise ImportError("cv2 is not installed.")
+
+
+@Transform(K.images, K.images)
+class RandomGamma:
+ """Apply Gamma transformation to images.
+
+ Args:
+ gamma_range (tuple[float, float]): Range of gamma values.
+ image_channel_mode (str, optional): Image channel mode. Defaults to
+ "RGB".
+ """
+
+ def __init__(
+ self,
+ gamma_range: tuple[float, float] = (1.0, 1.0),
+ image_channel_mode: str = "RGB",
+ ) -> None:
+ """Init function for Gamma."""
+ self.gamma_range = gamma_range
+ self.image_channel_mode = image_channel_mode
+ assert image_channel_mode in {"RGB", "BGR"}, (
+ "image_channel_mode should be 'RGB' or 'BGR', "
+ f"got {image_channel_mode}."
+ )
+
+ def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]:
+ """Call function for Gamma transformation."""
+ factor = np.random.uniform(self.gamma_range[0], self.gamma_range[1])
+ return _adjust_images(
+ images, TF.adjust_gamma, factor, self.image_channel_mode
+ )
+
+
+@Transform(K.images, K.images)
+class RandomBrightness:
+ """Apply Brightness transformation to images.
+
+ Args:
+ brightness_range (tuple[float, float]): Range of brightness values.
+ image_channel_mode (str, optional): Image channel mode. Defaults to
+ "RGB".
+ """
+
+ def __init__(
+ self,
+ brightness_range: tuple[float, float] = (1.0, 1.0),
+ image_channel_mode: str = "RGB",
+ ) -> None:
+ """Init function for Brightness."""
+ self.brightness_range = brightness_range
+ self.image_channel_mode = image_channel_mode
+ assert image_channel_mode in {"RGB", "BGR"}, (
+ "image_channel_mode should be 'RGB' or 'BGR', "
+ f"got {image_channel_mode}."
+ )
+
+ def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]:
+ """Call function for Brightness transformation."""
+ factor = np.random.uniform(
+ self.brightness_range[0], self.brightness_range[1]
+ )
+ return _adjust_images(
+ images, TF.adjust_brightness, factor, self.image_channel_mode
+ )
+
+
+@Transform(K.images, K.images)
+class RandomContrast:
+ """Apply Contrast transformation to images.
+
+ Args:
+ contrast_range (tuple[float, float]): Range of contrast values.
+ image_channel_mode (str, optional): Image channel mode. Defaults to
+ "RGB".
+ """
+
+ def __init__(
+ self,
+ contrast_range: tuple[float, float] = (1.0, 1.0),
+ image_channel_mode: str = "RGB",
+ ):
+ """Init function for Contrast."""
+ self.contrast_range = contrast_range
+ self.image_channel_mode = image_channel_mode
+ assert image_channel_mode in {"RGB", "BGR"}, (
+ "image_channel_mode should be 'RGB' or 'BGR', "
+ f"got {image_channel_mode}."
+ )
+
+ def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]:
+ """Call function for Contrast transformation."""
+ factor = np.random.uniform(
+ self.contrast_range[0], self.contrast_range[1]
+ )
+ return _adjust_images(
+ images, TF.adjust_contrast, factor, self.image_channel_mode
+ )
+
+
+@Transform(K.images, K.images)
+class RandomSaturation:
+ """Apply saturation transformation to images.
+
+ Args:
+ saturation_range (tuple[float, float]): Range of saturation values.
+ image_channel_mode (str, optional): Image channel mode. Defaults to
+ "RGB".
+ """
+
+ def __init__(
+ self,
+ saturation_range: tuple[float, float] = (1.0, 1.0),
+ image_channel_mode: str = "RGB",
+ ):
+ """Init function for saturation."""
+ self.saturation_range = saturation_range
+ self.image_channel_mode = image_channel_mode
+ assert image_channel_mode in {"RGB", "BGR"}, (
+ "image_channel_mode should be 'RGB' or 'BGR', "
+ f"got {image_channel_mode}."
+ )
+
+ def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]:
+ """Call function for saturation transformation."""
+ factor = np.random.uniform(
+ self.saturation_range[0], self.saturation_range[1]
+ )
+ return _adjust_images(
+ images, TF.adjust_saturation, factor, self.image_channel_mode
+ )
+
+
+@Transform(K.images, K.images)
+class RandomHue:
+ """Apply hue transformation to images.
+
+ Args:
+ hue_range (tuple[float, float]): Range of hue values.
+ image_channel_mode (str, optional): Image channel mode. Defaults to
+ "RGB".
+ """
+
+ def __init__(
+ self,
+ hue_range: tuple[float, float] = (0.0, 0.0),
+ image_channel_mode: str = "RGB",
+ ):
+ """Init function for hue."""
+ self.hue_range = hue_range
+ self.image_channel_mode = image_channel_mode
+ assert image_channel_mode in {"RGB", "BGR"}, (
+ "image_channel_mode should be 'RGB' or 'BGR', "
+ f"got {image_channel_mode}."
+ )
+
+ def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]:
+ """Call function for Hue transformation."""
+ factor = np.random.uniform(self.hue_range[0], self.hue_range[1])
+ return _adjust_images(
+ images, TF.adjust_hue, factor, self.image_channel_mode
+ )
+
+
+@Transform(K.images, K.images)
+class ColorJitter:
+ """Apply color jitter to images.
+
+ Args:
+ brightness_range (tuple[float, float]): Range of brightness values.
+ contrast_range (tuple[float, float]): Range of contrast values.
+ saturation_range (tuple[float, float]): Range of saturation values.
+ hue_range (tuple[float, float]): Range of hue values.
+ image_channel_mode (str, optional): Image channel mode. Defaults to
+ "RGB".
+ """
+
+ def __init__(
+ self,
+ brightness_range: tuple[float, float] = (0.875, 1.125),
+ contrast_range: tuple[float, float] = (0.5, 1.5),
+ saturation_range: tuple[float, float] = (0.5, 1.5),
+ hue_range: tuple[float, float] = (-0.05, 0.05),
+ image_channel_mode: str = "RGB",
+ ):
+ """Init function for color jitter."""
+ self.brightness_range = brightness_range
+ self.contrast_range = contrast_range
+ self.saturation_range = saturation_range
+ self.hue_range = hue_range
+ self.image_channel_mode = image_channel_mode
+ assert image_channel_mode in {"RGB", "BGR"}, (
+ "image_channel_mode should be 'RGB' or 'BGR', "
+ f"got {image_channel_mode}."
+ )
+
+ def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]:
+ """Call function for Hue transformation."""
+ transform_order = np.random.permutation(4)
+ for transform in transform_order:
+ # apply photometric transforms in a random order
+ if transform == 0:
+ # random brightness
+ bfactor = np.random.uniform(
+ self.brightness_range[0], self.brightness_range[1]
+ )
+ images = _adjust_images(
+ images,
+ TF.adjust_brightness,
+ bfactor,
+ self.image_channel_mode,
+ )
+ elif transform == 1:
+ # random contrast
+ cfactor = np.random.uniform(
+ self.contrast_range[0], self.contrast_range[1]
+ )
+ images = _adjust_images(
+ images,
+ TF.adjust_contrast,
+ cfactor,
+ self.image_channel_mode,
+ )
+ elif transform == 2:
+ # random saturation
+ sfactor = np.random.uniform(
+ self.saturation_range[0], self.saturation_range[1]
+ )
+ images = _adjust_images(
+ images,
+ TF.adjust_saturation,
+ sfactor,
+ self.image_channel_mode,
+ )
+ elif transform == 3:
+ # random hue
+ hfactor = np.random.uniform(
+ self.hue_range[0], self.hue_range[1]
+ )
+ images = _adjust_images(
+ images, TF.adjust_hue, hfactor, self.image_channel_mode
+ )
+ return images
+
+
+def _adjust_images(
+ images: list[NDArrayF32],
+ adjust_func: Callable[[Tensor, float], Tensor],
+ adj_factor: float,
+ image_channel_mode: str = "RGB",
+) -> list[NDArrayF32]:
+ """Apply color transformation to images.
+
+ Args:
+ images (list[NDArrayF32]): Image to be transformed.
+ adjust_func (Callable[[Tensor, float], Tensor]): Function to apply.
+ adj_factor (float): Adjustment factor.
+ image_channel_mode (str, optional): Image channel mode. Defaults to
+ "RGB".
+
+ Returns:
+ list[NDArrayF32]: Transformed image.
+ """
+ for i, image in enumerate(images):
+ if image_channel_mode == "BGR":
+ image = image[..., [2, 1, 0]] # convert to RGB
+ image_ = torch.from_numpy(image).permute(0, 3, 1, 2) / 255.0
+ image_ = adjust_func(image_, adj_factor) * 255.0
+ images[i] = image_.permute(0, 2, 3, 1).numpy()
+ if image_channel_mode == "BGR":
+ images[i] = images[i][..., [2, 1, 0]] # convert back to BGR
+ return images
+
+
+@Transform(K.images, K.images)
+class RandomHSV:
+ """Apply HSV transformation to images.
+
+ Used by YOLOX. Modifed from: https://github.com/Megvii-BaseDetection/YOLOX.
+
+ Args:
+ hue_delta (int): Delta for hue.
+ saturation_delta (int): Delta for saturation.
+ value_delta (int): Delta for value.
+ image_channel_mode (str, optional): Image channel mode. Defaults to
+ "BGR".
+ """
+
+ def __init__(
+ self,
+ hue_delta: int = 5,
+ saturation_delta: int = 30,
+ value_delta: int = 30,
+ image_channel_mode: str = "BGR",
+ ):
+ """Init function for HSV transformation."""
+ assert OPENCV_AVAILABLE, "RandomHSV requires OpenCV to be installed."
+ self.hue_delta = hue_delta
+ self.saturation_delta = saturation_delta
+ self.value_delta = value_delta
+ self.image_channel_mode = image_channel_mode
+ assert image_channel_mode in {"RGB", "BGR"}, (
+ "image_channel_mode should be 'RGB' or 'BGR', "
+ f"got {image_channel_mode}."
+ )
+
+ # pylint: disable=no-member
+ def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]:
+ """Call function for Hue transformation."""
+ for i, image in enumerate(images):
+ image = image[0].astype(np.uint8)
+ if self.image_channel_mode == "BGR":
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+ else:
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
+ image = image.astype(np.int16)
+ hsv_gains = np.random.uniform(-1, 1, 3) * [
+ self.hue_delta,
+ self.saturation_delta,
+ self.value_delta,
+ ]
+ # random selection of h, s, v
+ hsv_gains = (hsv_gains * np.random.randint(0, 2, 3)).astype(
+ np.int16
+ )
+ image[..., 0] = (image[..., 0] + hsv_gains[0]) % 180
+ image[..., 1] = np.clip(image[..., 1] + hsv_gains[1], 0, 255)
+ image[..., 2] = np.clip(image[..., 2] + hsv_gains[2], 0, 255)
+ image = image.astype(np.uint8)
+ if self.image_channel_mode == "BGR":
+ cv2.cvtColor(image, cv2.COLOR_HSV2BGR, dst=image)
+ else:
+ cv2.cvtColor(image, cv2.COLOR_HSV2RGB, dst=image)
+ images[i] = image[None, ...].astype(np.float32)
+ return images
diff --git a/vis4d/data/transforms/point_sampling.py b/vis4d/data/transforms/point_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..d83f39d51f3306ac3adb1841ebc76ae6dc0e8723
--- /dev/null
+++ b/vis4d/data/transforms/point_sampling.py
@@ -0,0 +1,253 @@
+"""Contains different Sampling Trasnforms for pointclouds."""
+
+from __future__ import annotations
+
+import numpy as np
+
+from vis4d.common.typing import NDArrayInt, NDArrayNumber
+from vis4d.data.const import CommonKeys as K
+
+from .base import Transform
+
+
+@Transform(K.points3d, "transforms.sampling_idxs")
+class GenerateSamplingIndices:
+ """Samples num_pts from the first dim of the provided data tensor.
+
+ If num_pts > data.shape[0], the indices will be upsampled with
+ replacement. If num_pts < data.shape[0], the indices will be sampled
+ without replacement.
+ """
+
+ def __init__(self, num_pts: int) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ num_pts (int): Number of indices to sample
+ """
+ self.num_pts = num_pts
+
+ def __call__(self, data_list: list[NDArrayNumber]) -> list[NDArrayInt]:
+ """Samples num_pts from the first dim of the provided data tensor.
+
+ If num_pts > data.shape[0], the indices will be upsampled with
+ replacement. If num_pts < data.shape[0], the indices will be sampled
+ without replacement.
+
+ Args:
+ data_list (list[NDArrayNumber]): Data from which to sample indices.
+
+ Returns:
+ list[NDArrayInt]: List of indices.
+
+ Raises:
+ ValueError: If data is empty.
+ """
+ data = data_list[0]
+
+ if len(data) == 0:
+ raise ValueError("Data sample was empty!")
+
+ if self.num_pts > len(data):
+ return [
+ np.concatenate(
+ [
+ np.arange(len(data)),
+ np.random.randint(
+ 0, len(data), self.num_pts - len(data)
+ ),
+ ]
+ )
+ ] * len(data_list)
+ return [
+ np.random.choice(len(data), self.num_pts, replace=False)
+ ] * len(data_list)
+
+
+@Transform(K.points3d, "transforms.sampling_idxs")
+class GenerateBlockSamplingIndices:
+ """Samples num_pts from the first dim of the provided data tensor.
+
+ Makes sure that the sampled points are within a block of size block_size
+ centered around center_xyz. If num_pts > data.shape[0], the indices will
+ be upsampled with replacement. If num_pts < data.shape[0], the indices
+ will be sampled without replacement.
+ """
+
+ def __init__(
+ self,
+ num_pts: int,
+ block_dimensions: tuple[float, float, float],
+ center_point: tuple[float, float, float] | None = None,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ num_pts (int): Number of indices to sample
+ block_dimensions (tuple[float, float, float]): Dimensions of the
+ block in x,y,z
+ center_point (tuple[float, float, float] | None): Center point of
+ the block in x,y,z. If None, the center will be sampled
+ randomly.
+ """
+ self.block_dimensions = np.asarray(block_dimensions)
+ self.center_point = (
+ np.asarray(center_point) if center_point is not None else None
+ )
+
+ self._idx_sampler = GenerateSamplingIndices(num_pts)
+
+ def __call__(self, data_list: list[NDArrayNumber]) -> list[NDArrayInt]:
+ """Samples num_pts from the first dim of the provided data tensor."""
+ data = data_list[0]
+
+ if self.center_point is None:
+ center_point = data[np.random.choice(len(data), 1)]
+ else:
+ center_point = self.center_point
+
+ max_box = center_point + self.block_dimensions / 2.0
+ min_box = center_point - self.block_dimensions / 2.0
+
+ box_mask = np.logical_and(
+ np.all(data >= min_box, axis=1),
+ np.all(data <= max_box, axis=1),
+ )
+ if box_mask.sum().item() == 0: # No valid data sample found!
+ return [np.array([], dtype=np.int32)] * len(data_list)
+
+ idxs = self._idx_sampler([data[box_mask, ...]])[0]
+
+ masked_idxs = np.arange(data.shape[0])[box_mask]
+ selected_idxs_global = masked_idxs[idxs]
+ return [selected_idxs_global] * len(data_list)
+
+
+@Transform(K.points3d, "transforms.sampling_idxs")
+class GenFullCovBlockSamplingIndices:
+ """Subsamples the pointcloud using blocks of a given size."""
+
+ def __init__(
+ self,
+ num_pts: int,
+ block_dimensions: tuple[float, float, float],
+ min_pts: int = 32,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ num_pts (int): Number of points to sample for each block
+ block_dimensions (tuple[float, float, float]): Dimensions of the
+ block in x,y,z
+ min_pts (int): Minimum number of points in a block to be considered
+ valid
+ """
+ self.num_pts = num_pts
+ self.min_pts = min_pts
+ self.block_dimensions = np.asarray(block_dimensions)
+ self._idx_sampler = GenerateBlockSamplingIndices(
+ num_pts=self.num_pts,
+ block_dimensions=block_dimensions,
+ )
+
+ def __call__(
+ self, coordinates_list: list[NDArrayNumber]
+ ) -> list[NDArrayInt]:
+ """Subsamples the pointcloud using blocks of a given size."""
+ coordinates = coordinates_list[0]
+
+ # Get bounding box for sampling
+ coord_min, coord_max = (
+ np.min(coordinates, axis=0),
+ np.max(coordinates, axis=0),
+ )
+ sampled_idxs = []
+ hwl = coord_max - coord_min
+ num_blocks = np.ceil(hwl / self.block_dimensions).astype(np.int32)
+
+ for idx_x in range(num_blocks[0]):
+ for idx_y in range(num_blocks[1]):
+ for idx_z in range(num_blocks[2]):
+ center_pt = (
+ coord_min
+ + np.array([idx_x, idx_y, idx_z])
+ * self.block_dimensions
+ + self.block_dimensions / 2.0
+ )
+
+ self._idx_sampler.center_point = center_pt
+ selected_idxs = self._idx_sampler([coordinates])[0]
+ if selected_idxs.sum() >= self.min_pts:
+ sampled_idxs.append(selected_idxs)
+ return [np.stack(sampled_idxs)] * len(coordinates_list) # type: ignore
+
+
+@Transform([K.points3d, "transforms.sampling_idxs"], K.points3d)
+class SamplePoints:
+ """Subsamples points randomly.
+
+ Samples 'num_pts' randomly from the provided data tensors using the
+ provided sampling indices.
+
+ This transform is used to sample points from a pointcloud. The indices
+ are generated by the GenerateSamplingIndices transform.
+
+ """
+
+ def __call__(
+ self,
+ data_list: list[NDArrayNumber],
+ selected_idxs_list: list[NDArrayInt],
+ ) -> list[NDArrayNumber]:
+ """Returns data[selected_idxs].
+
+ If the provided indices have two dimension (i.e n_masks, 64), then
+ this operation indices the data n_masks times and returns an array
+ """
+ for i, (data, selected_idxs) in enumerate(
+ zip(data_list, selected_idxs_list)
+ ):
+ assert selected_idxs.ndim <= 2, "Indices must be 1D or 2D"
+ if selected_idxs.ndim == 2:
+ data_list[i] = np.stack(
+ [data[idxs, ...] for idxs in selected_idxs]
+ )
+ else:
+ data_list[i] = data[selected_idxs, ...]
+ return data_list
+
+
+@Transform([K.colors3d, "transforms.sampling_idxs"], K.colors3d)
+class SampleColors(SamplePoints):
+ """Subsamples colors randomly.
+
+ Samples 'num_pts' randomly from the provided data tensors using the
+ provided sampling indices.
+
+ This transform is used to sample colors from a pointcloud. The indices
+ are generated by the GenerateSamplingIndices transform.
+ """
+
+
+@Transform([K.semantics3d, "transforms.sampling_idxs"], K.semantics3d)
+class SampleSemantics(SamplePoints):
+ """Subsamples semantics randomly.
+
+ Samples 'num_pts' randomly from the provided data tensors using the
+ provided sampling indices.
+
+ This transform is used to sample semantics from a pointcloud. The indices
+ are generated by the GenerateSamplingIndices transform.
+ """
+
+
+@Transform([K.instances3d, "transforms.sampling_idxs"], K.instances3d)
+class SampleInstances(SamplePoints):
+ """Subsamples instances randomly.
+
+ Samples 'num_pts' randomly from the provided data tensors using the
+ provided sampling indices.
+
+ This transform is used to sample instances from a pointcloud. The indices
+ are generated by the GenerateSamplingIndices transform.
+ """
diff --git a/vis4d/data/transforms/points.py b/vis4d/data/transforms/points.py
new file mode 100644
index 0000000000000000000000000000000000000000..d981f25927c222daae4476ae966ffa32180d5ec7
--- /dev/null
+++ b/vis4d/data/transforms/points.py
@@ -0,0 +1,269 @@
+"""Pointwise transformations."""
+
+from __future__ import annotations
+
+from typing import TypedDict
+
+import numpy as np
+
+from vis4d.common.typing import NDArrayFloat
+from vis4d.data.const import CommonKeys as K
+
+from .base import Transform
+
+
+@Transform(in_keys=K.points3d, out_keys="transforms.pc_bounds")
+class GenPcBounds:
+ """Extracts the max and min values of the loaded points."""
+
+ def __call__(
+ self, coordinates_list: list[NDArrayFloat]
+ ) -> list[NDArrayFloat]:
+ """Extracts the max and min values of the pointcloud."""
+ coordinates = coordinates_list[0]
+
+ pc_bounds = [np.stack([coordinates.min(0), coordinates.max(0)])] * len(
+ coordinates_list
+ )
+
+ return pc_bounds
+
+
+@Transform(in_keys=(K.points3d, "trasforms.pc_bounds"), out_keys=K.points3d)
+class NormalizeByMaxBounds:
+ """Normalizes the pointcloud by the max bounds."""
+
+ def __init__(self, axes: tuple[int, int, int] = (0, 1, 2)) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ axes (tuple[int, int, int]): Over which axes to apply
+ normalization.
+ """
+ self.axes = axes
+
+ def __call__(
+ self,
+ coords_list: list[NDArrayFloat],
+ pc_bounds_list: list[NDArrayFloat],
+ ) -> list[NDArrayFloat]:
+ """Applies the normalization."""
+ for i, (coords, pc_bounds) in enumerate(
+ zip(coords_list, pc_bounds_list)
+ ):
+ max_bound = np.max(np.abs(pc_bounds), axis=0)
+ for ax in self.axes:
+ coords[:, ax] = coords[:, ax] / max_bound[ax]
+ coords_list[i] = coords
+ return coords_list
+
+
+@Transform(in_keys=K.points3d, out_keys=K.points3d)
+class CenterAndNormalize:
+ """Centers and normalizes the pointcloud."""
+
+ def __init__(self, centering: bool = True, normalize: bool = True) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ centering (bool): Whether to center the pointcloud
+ normalize (bool): Whether to normalize the pointcloud
+ """
+ self.centering = centering
+ self.normalize = normalize
+
+ def __call__(self, coords_list: list[NDArrayFloat]) -> list[NDArrayFloat]:
+ """Applies the Center and Normalization operations."""
+ for i, coords in enumerate(coords_list):
+ if self.centering:
+ coords = coords - np.mean(coords, axis=0)
+ if self.normalize:
+ coords = coords / np.max(np.sqrt(np.sum(coords**2, axis=-1)))
+ coords_list[i] = coords
+ return coords_list
+
+
+@Transform(in_keys=K.points3d, out_keys=K.points3d)
+class AddGaussianNoise:
+ """Adds random normal distributed noise with given std to the data.
+
+ Args:
+ std (float): Standard Deviation of the noise
+ """
+
+ def __init__(self, noise_level: float = 0.01):
+ """Creates an instance of the class.
+
+ Args:
+ noise_level (float): The noise level. Standard deviation for
+ the gaussian noise.
+ """
+ self.noise_level = noise_level
+
+ def __call__(
+ self, coordinates_list: list[NDArrayFloat]
+ ) -> list[NDArrayFloat]:
+ """Adds gaussian noise to the coordiantes."""
+ for i, coordinates in enumerate(coordinates_list):
+ coordinates[i] = (
+ coordinates
+ + np.random.randn(*coordinates.shape) * self.noise_level
+ )
+ return coordinates_list
+
+
+@Transform(in_keys=K.points3d, out_keys=K.points3d)
+class AddUniformNoise:
+ """Adds random normal distributed noise with given std to the data.
+
+ Args:
+ std (float): Standard Deviation of the noise
+ """
+
+ def __init__(self, noise_level: float = 0.01):
+ """Creates an instance of the class.
+
+ Args:
+ noise_level (float): The noise level. Half the range of the
+ uniform noise. The noise is sampled from
+ [-noise_level, noise_level].
+ """
+ self.noise_level = noise_level
+
+ def __call__(
+ self, coordinates_list: list[NDArrayFloat]
+ ) -> list[NDArrayFloat]:
+ """Adds uniform noise to the coordinates."""
+ for i, coordinates in enumerate(coordinates_list):
+ coordinates_list[i] = coordinates + np.random.uniform(
+ -self.noise_level, self.noise_level, coordinates.shape
+ )
+ return coordinates_list
+
+
+class SE3Transform(TypedDict):
+ """Parameters for Resize."""
+
+ translation: NDArrayFloat
+ rotation: NDArrayFloat
+
+
+def _gen_random_se3_transform(
+ translation_min: NDArrayFloat,
+ translation_max: NDArrayFloat,
+ rotation_min: NDArrayFloat,
+ rotation_max: NDArrayFloat,
+) -> SE3Transform:
+ """Creates a random SE3 Transforms.
+
+ The transform is generated by sampling a random translation and
+ rotation from a uniform distribution.
+ """
+ angle = np.random.uniform(rotation_min, rotation_max)
+ translation = np.random.uniform(translation_min, translation_max)
+ cos_x, sin_x = np.cos(angle[0]), np.sin(angle[0])
+ cos_y, sin_y = np.cos(angle[1]), np.sin(angle[1])
+ cos_z, sin_z = np.cos(angle[2]), np.sin(angle[2])
+ rotx = np.array([[1, 0, 0], [0, cos_x, -sin_x], [0, sin_x, cos_x]])
+ roty = np.array([[cos_y, 0, sin_y], [0, 1, 0], [-sin_y, 0, cos_y]])
+ rotz = np.array([[cos_z, -sin_z, 0], [sin_z, cos_z, 0], [0, 0, 1]])
+ rot = np.dot(rotz, np.dot(roty, rotx))
+ return SE3Transform(translation=translation, rotation=rot)
+
+
+@Transform(in_keys=K.points3d, out_keys=K.points3d)
+class ApplySE3Transform:
+ """Applies a given SE3 Transform to the data."""
+
+ def __init__(
+ self,
+ translation_min: tuple[float, float, float] = (0.0, 0.0, 0.0),
+ translation_max: tuple[float, float, float] = (0.0, 0.0, 0.0),
+ rotation_min: tuple[float, float, float] = (0.0, 0.0, 0.0),
+ rotation_max: tuple[float, float, float] = (0.0, 0.0, 0.0),
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ translation_min (tuple[float, float, float]): Minimum translation.
+ translation_max (tuple[float, float, float]): Maximum translation.
+ rotation_min (tuple[float, float, float]): Minimum euler rotation
+ angles [rad]. Applied in the order rot_x -> rot_y -> rot_z.
+ rotation_max (tuple[float, float, float]): Maximum euler rotation
+ angles [rad]. Applied in the order rot_x -> rot_y -> rot_z.
+ """
+ self.translation_min = np.asarray(translation_min)
+ self.translation_max = np.asarray(translation_max)
+ self.rotation_min = np.asarray(rotation_min)
+ self.rotation_max = np.asarray(rotation_max)
+
+ def __call__(
+ self, coordinates_list: list[NDArrayFloat]
+ ) -> list[NDArrayFloat]:
+ """Applies a SE3 Transform."""
+ for i, coordinates in enumerate(coordinates_list):
+ transform = _gen_random_se3_transform(
+ self.translation_min,
+ self.translation_max,
+ self.rotation_min,
+ self.rotation_max,
+ )
+ if coordinates.shape[-1] == 3:
+ coordinates_list[i] = (
+ transform["rotation"] @ coordinates.T
+ ).T + transform["translation"]
+ elif coordinates.shape[-2] == 3:
+ coordinates_list[i] = (
+ transform["rotation"] @ coordinates
+ ).T + transform["translation"]
+ else:
+ raise ValueError(
+ f"Invalid shape for coordinates: {coordinates.shape}"
+ )
+ return coordinates_list
+
+
+class ApplySO3Transform(ApplySE3Transform):
+ """Applies a given SO3 Transform to the data."""
+
+ def __call__(
+ self, coordinates_list: list[NDArrayFloat]
+ ) -> list[NDArrayFloat]:
+ """Applies a given SO3 Transform to the data."""
+ for i, coordinates in enumerate(coordinates_list):
+ transform = _gen_random_se3_transform(
+ self.translation_min,
+ self.translation_max,
+ self.rotation_min,
+ self.rotation_max,
+ )["rotation"]
+ if coordinates.shape[-1] == 3:
+ coordinates_list[i] = (transform @ coordinates.T).T
+ elif coordinates.shape[-2] == 3:
+ coordinates_list[i] = (transform @ coordinates).T
+ else:
+ raise ValueError(
+ f"Invalid shape for coordinates: {coordinates.shape}"
+ )
+ return coordinates_list
+
+
+@Transform(in_keys=K.points3d, out_keys=K.points3d)
+class TransposeChannels:
+ """Transposes some predifined channels."""
+
+ def __init__(self, channels: tuple[int, int] = (-1, -2)):
+ """Creates an instance of the class.
+
+ Args:
+ channels (tuple[int, int]): Tuple of channels to transpose
+ """
+ self.channels = channels
+
+ def __call__(
+ self, coordinates_list: list[NDArrayFloat]
+ ) -> list[NDArrayFloat]:
+ """Transposes some predifined channels."""
+ for i, coordinates in enumerate(coordinates_list):
+ coordinates_list[i] = coordinates.transpose(*self.channels)
+ return coordinates_list
diff --git a/vis4d/data/transforms/post_process.py b/vis4d/data/transforms/post_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6605c201a8b8e43412cc7597dacb7f7c9ad821a
--- /dev/null
+++ b/vis4d/data/transforms/post_process.py
@@ -0,0 +1,161 @@
+"""Post process after transformation."""
+
+from __future__ import annotations
+
+import torch
+
+from vis4d.common.typing import NDArrayF32, NDArrayI64
+from vis4d.data.const import CommonKeys as K
+from vis4d.op.box.box2d import bbox_area, bbox_clip
+
+from .base import Transform
+
+
+@Transform(
+ in_keys=[
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.boxes2d_track_ids,
+ K.input_hw,
+ K.boxes3d,
+ K.boxes3d_classes,
+ K.boxes3d_track_ids,
+ ],
+ out_keys=[
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.boxes2d_track_ids,
+ K.boxes3d,
+ K.boxes3d_classes,
+ K.boxes3d_track_ids,
+ ],
+)
+class PostProcessBoxes2D:
+ """Post process after transformation."""
+
+ def __init__(
+ self, min_area: float = 7.0 * 7.0, clip_bboxes_to_image: bool = True
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ min_area (float): Minimum area of the bounding box. Defaults to
+ 7.0 * 7.0.
+ clip_bboxes_to_image (bool): Whether to clip the bounding boxes to
+ the image size. Defaults to True.
+ """
+ self.min_area = min_area
+ self.clip_bboxes_to_image = clip_bboxes_to_image
+
+ def __call__(
+ self,
+ boxes_list: list[NDArrayF32],
+ classes_list: list[NDArrayI64],
+ track_ids_list: list[NDArrayI64] | None,
+ input_hw_list: list[tuple[int, int]],
+ boxes3d_list: list[NDArrayF32] | None,
+ boxes3d_classes_list: list[NDArrayI64] | None,
+ boxes3d_track_ids_list: list[NDArrayI64] | None,
+ ) -> tuple[
+ list[NDArrayF32],
+ list[NDArrayI64],
+ list[NDArrayI64] | None,
+ list[NDArrayF32] | None,
+ list[NDArrayI64] | None,
+ list[NDArrayI64] | None,
+ ]:
+ """Post process according to boxes2D after transformation.
+
+ Args:
+ boxes_list (list[NDArrayF32]): The bounding boxes to be post
+ processed.
+ classes_list (list[NDArrayF32]): The classes of the bounding boxes.
+ track_ids_list (list[NDArrayI64] | None): The track ids of the
+ bounding boxes.
+ input_hw_list (list[tuple[int, int]]): The height and width of the
+ input image.
+ boxes3d_list (list[NDArrayF32] | None): The 3D bounding boxes to be
+ post processed.
+ boxes3d_classes_list (list[NDArrayI64] | None): The classes of the
+ 3D bounding boxes.
+ boxes3d_track_ids_list (list[NDArrayI64] | None): The track ids of
+ the 3D bounding boxes.
+
+ Returns:
+ tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None,
+ list[NDArrayF32] | None, list[NDArrayI64] | None,
+ list[NDArrayI64] | None]: The post processed results.
+ """
+ new_track_ids: list[NDArrayI64] | None = (
+ [] if track_ids_list is not None else None
+ )
+ new_boxes3d: list[NDArrayF32] | None = (
+ [] if boxes3d_list is not None else None
+ )
+ new_boxes3d_classes: list[NDArrayI64] | None = (
+ [] if boxes3d_classes_list is not None else None
+ )
+ new_boxes3d_track_ids: list[NDArrayI64] | None = (
+ [] if boxes3d_track_ids_list is not None else None
+ )
+ for i, (boxes, classes) in enumerate(zip(boxes_list, classes_list)):
+ boxes_ = torch.from_numpy(boxes)
+ if self.clip_bboxes_to_image:
+ boxes_ = bbox_clip(boxes_, input_hw_list[i])
+
+ keep = (bbox_area(boxes_) >= self.min_area).numpy()
+
+ boxes_list[i] = boxes[keep]
+ classes_list[i] = classes[keep]
+
+ if track_ids_list is not None:
+ assert new_track_ids is not None
+ new_track_ids.append(track_ids_list[i][keep])
+
+ if boxes3d_list is not None:
+ assert new_boxes3d is not None
+ new_boxes3d.append(boxes3d_list[i][keep])
+
+ if boxes3d_classes_list is not None:
+ assert new_boxes3d_classes is not None
+ new_boxes3d_classes.append(boxes3d_classes_list[i][keep])
+
+ if boxes3d_track_ids_list is not None:
+ assert new_boxes3d_track_ids is not None
+ new_boxes3d_track_ids.append(boxes3d_track_ids_list[i][keep])
+
+ return (
+ boxes_list,
+ classes_list,
+ new_track_ids,
+ new_boxes3d,
+ new_boxes3d_classes,
+ new_boxes3d_track_ids,
+ )
+
+
+@Transform(in_keys=[K.boxes2d_track_ids], out_keys=[K.boxes2d_track_ids])
+class RescaleTrackIDs:
+ """Rescale track ids."""
+
+ def __call__(self, track_ids_list: list[NDArrayI64]) -> list[NDArrayI64]:
+ """Rescale the track ids.
+
+ Args:
+ track_ids_list (list[NDArrayI64]): The track ids to be
+ rescaled.
+
+ Returns:
+ list[NDArrayI64]: The rescaled track ids.
+ """
+ track_ids_all: dict[int, int] = {}
+ for track_ids in track_ids_list:
+ for track_id in track_ids:
+ if track_id not in track_ids_all:
+ track_ids_all[track_id] = len(track_ids_all)
+
+ for track_ids in track_ids_list:
+ for i, track_id in enumerate(track_ids):
+ track_ids[i] = track_ids_all[track_id]
+
+ return track_ids_list
diff --git a/vis4d/data/transforms/random_erasing.py b/vis4d/data/transforms/random_erasing.py
new file mode 100644
index 0000000000000000000000000000000000000000..ace6af900b6629625f6ed60b2f4cec6623c73be4
--- /dev/null
+++ b/vis4d/data/transforms/random_erasing.py
@@ -0,0 +1,91 @@
+"""Random erasing data augmentation."""
+
+import numpy as np
+
+from vis4d.common.typing import NDArrayNumber
+from vis4d.data.const import CommonKeys as K
+
+from .base import Transform
+
+
+@Transform(in_keys=K.images, out_keys=K.images)
+class RandomErasing:
+ """Randomly erase a rectangular region in an image tensor."""
+
+ def __init__(
+ self,
+ min_area: float = 0.02,
+ max_area: float = 0.4,
+ min_aspect_ratio: float = 0.3,
+ max_aspect_ratio: float = 1 / 0.3,
+ mean: tuple[float, float, float] = (0.0, 0.0, 0.0),
+ num_attempt: int = 10,
+ ):
+ """Creates an instance of RandomErasing.
+
+ Recommended to use this transform after normalization. The erased
+ region will be filled with the mean value. See
+ `https://arxiv.org/abs/1708.04896`.
+
+ Args:
+ min_area (float, optional): Minimum area of the erased region.
+ Defaults to 0.02.
+ max_area (float, optional): Maximum area of the erased region.
+ Defaults to 0.4.
+ min_aspect_ratio (float, optional): Minimum aspect ratio of the
+ erased region. Defaults to 0.3.
+ max_aspect_ratio (float, optional): Maximum aspect ratio of the
+ erased region. Defaults to 1 / 0.3.
+ mean (tuple[float, float, float], optional): Mean of the dataset.
+ Defaults to (0.0, 0.0, 0.0).
+ num_attempt (int, optional): Number of maximum attempts to find a
+ valid erased region. This is used to avoid infinite attempts of
+ resampling the region, though such cases are very unlikely to
+ happen. Defaults to 10.
+
+ Returns:
+ Callable: A function that takes a tensor of shape [N, H, W, C] and
+ returns a tensor of the same shape.
+ """
+ self.min_area = min_area
+ self.max_area = max_area
+ self.min_aspect_ratio = min_aspect_ratio
+ self.max_aspect_ratio = max_aspect_ratio
+ self.mean = mean
+ self.num_attempt = num_attempt
+
+ def do_erasing(self, images: NDArrayNumber) -> NDArrayNumber:
+ """Execute the random erasing."""
+ fill = np.array(self.mean)
+ for i in range(images.shape[0]):
+ image = images[i]
+ h, w = image.shape[0:2]
+ area = h * w
+
+ for _ in range(self.num_attempt):
+ target_area = (
+ np.random.uniform(self.min_area, self.max_area) * area
+ )
+ aspect_ratio = np.random.uniform(
+ self.min_aspect_ratio, self.max_aspect_ratio
+ )
+ h_erase = int(round(np.sqrt(target_area * aspect_ratio)))
+ w_erase = int(round(np.sqrt(target_area / aspect_ratio)))
+ if w_erase < w and h_erase < h:
+ x_erase = np.random.randint(0, w - w_erase)
+ y_erase = np.random.randint(0, h - h_erase)
+ image[
+ y_erase : y_erase + h_erase,
+ x_erase : x_erase + w_erase,
+ :,
+ ] = fill
+ break
+ return images
+
+ def __call__(
+ self, images_list: list[NDArrayNumber]
+ ) -> list[NDArrayNumber]:
+ """Execute the transform."""
+ for i, images in enumerate(images_list):
+ images_list[i] = self.do_erasing(images)
+ return images_list
diff --git a/vis4d/data/transforms/resize.py b/vis4d/data/transforms/resize.py
new file mode 100644
index 0000000000000000000000000000000000000000..c829f7a0935e771578613e47cdd7dae9f58fc4b0
--- /dev/null
+++ b/vis4d/data/transforms/resize.py
@@ -0,0 +1,539 @@
+"""Resize transformation."""
+
+from __future__ import annotations
+
+import random
+from typing import TypedDict
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from vis4d.common.imports import OPENCV_AVAILABLE
+from vis4d.common.typing import NDArrayF32
+from vis4d.data.const import CommonKeys as K
+from vis4d.op.box.box2d import transform_bbox
+
+from .base import Transform
+
+if OPENCV_AVAILABLE:
+ import cv2
+ from cv2 import ( # pylint: disable=no-member,no-name-in-module
+ INTER_AREA,
+ INTER_CUBIC,
+ INTER_LANCZOS4,
+ INTER_LINEAR,
+ INTER_NEAREST,
+ )
+else:
+ raise ImportError("Please install opencv-python to use this module.")
+
+
+class ResizeParam(TypedDict):
+ """Parameters for Resize."""
+
+ target_shape: tuple[int, int]
+ scale_factor: tuple[float, float]
+
+
+@Transform(K.images, ["transforms.resize", K.input_hw])
+class GenResizeParameters:
+ """Generate the parameters for a resize operation."""
+
+ def __init__(
+ self,
+ shape: tuple[int, int] | list[tuple[int, int]],
+ keep_ratio: bool = False,
+ multiscale_mode: str = "range",
+ scale_range: tuple[float, float] = (1.0, 1.0),
+ align_long_edge: bool = False,
+ resize_short_edge: bool = False,
+ allow_overflow: bool = False,
+ fixed_scale: bool = False,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ shape (tuple[int, int] | list[tuple[int, int]]): Image shape to
+ be resized to in (H, W) format. In multiscale mode 'list',
+ shape represents the list of possible shapes for resizing.
+ keep_ratio (bool, optional): If aspect ratio of the original image
+ should be kept, the new shape will modified to fit the aspect
+ ratio of the original image. Defaults to False.
+ multiscale_mode (str, optional): one of [range, list]. Defaults to
+ "range".
+ scale_range (tuple[float, float], optional): Range of sampled image
+ scales in range mode, e.g. (0.8, 1.2), indicating minimum of
+ 0.8 * shape and maximum of 1.2 * shape. Defaults to (1.0, 1.0).
+ align_long_edge (bool, optional): If keep_ratio=true, this option
+ indicates if shape should be automatically aligned with the
+ long edge of the original image, e.g. original shape=(100, 80),
+ shape to be resized=(100, 200) will yield (125, 100) as new
+ shape. Defaults to False.
+ resize_short_edge (bool, optional): If keep_ratio=true, this option
+ scale the image according to the short edge. e.g. original
+ shape=(80, 100), shape to be resized=(100, 200) will yield
+ (100, 125) as new shape. Defaults to False.
+ allow_overflow (bool, optional): If set to True, we scale the image
+ to the smallest size such that it is no smaller than shape.
+ Otherwise, we scale the image to the largest size such that it
+ is no larger than shape. Defaults to False.
+ fixed_scale (bool, optional): If set to True, we scale the image
+ without offset. Defaults to False.
+ """
+ self.shape = shape
+ self.keep_ratio = keep_ratio
+
+ assert multiscale_mode in {"list", "range"}
+ self.multiscale_mode = multiscale_mode
+
+ assert (
+ scale_range[0] <= scale_range[1]
+ ), f"Invalid scale range: {scale_range[1]} < {scale_range[0]}"
+ self.scale_range = scale_range
+
+ self.align_long_edge = align_long_edge
+ self.resize_short_edge = resize_short_edge
+ self.allow_overflow = allow_overflow
+ self.fixed_scale = fixed_scale
+
+ def _get_target_shape(
+ self, input_shape: tuple[int, int]
+ ) -> tuple[int, int]:
+ """Generate possibly random target shape."""
+ if self.multiscale_mode == "range":
+ assert isinstance(
+ self.shape, tuple
+ ), "Specify shape as tuple when using multiscale mode range."
+ if self.scale_range[0] < self.scale_range[1]: # do multi-scale
+ w_scale = (
+ random.uniform(0, 1)
+ * (self.scale_range[1] - self.scale_range[0])
+ + self.scale_range[0]
+ )
+ h_scale = (
+ random.uniform(0, 1)
+ * (self.scale_range[1] - self.scale_range[0])
+ + self.scale_range[0]
+ )
+ else:
+ h_scale = w_scale = 1.0
+
+ shape = int(self.shape[0] * h_scale), int(self.shape[1] * w_scale)
+ else:
+ assert isinstance(
+ self.shape, list
+ ), "Specify shape as list when using multiscale mode list."
+ shape = random.choice(self.shape)
+
+ return get_resize_shape(
+ input_shape,
+ shape,
+ self.keep_ratio,
+ self.align_long_edge,
+ self.resize_short_edge,
+ self.allow_overflow,
+ self.fixed_scale,
+ )
+
+ def __call__(
+ self, images: list[NDArrayF32]
+ ) -> tuple[list[ResizeParam], list[tuple[int, int]]]:
+ """Compute the parameters and put them in the data dict."""
+ image = images[0]
+
+ im_shape = (image.shape[1], image.shape[2])
+ target_shape = self._get_target_shape(im_shape)
+ scale_factor = (
+ target_shape[1] / im_shape[1],
+ target_shape[0] / im_shape[0],
+ )
+
+ resize_params = [
+ ResizeParam(target_shape=target_shape, scale_factor=scale_factor)
+ ] * len(images)
+ target_shapes = [target_shape] * len(images)
+
+ return resize_params, target_shapes
+
+
+def get_resize_shape(
+ original_shape: tuple[int, int],
+ new_shape: tuple[int, int],
+ keep_ratio: bool = True,
+ align_long_edge: bool = False,
+ resize_short_edge: bool = False,
+ allow_overflow: bool = False,
+ fixed_scale: bool = False,
+) -> tuple[int, int]:
+ """Get shape for resize, considering keep_ratio and align_long_edge.
+
+ Args:
+ original_shape (tuple[int, int]): Original shape in [H, W].
+ new_shape (tuple[int, int]): New shape in [H, W].
+ keep_ratio (bool, optional): Whether to keep the aspect ratio.
+ Defaults to True.
+ align_long_edge (bool, optional): Whether to align the long edge of
+ the original shape with the long edge of the new shape.
+ Defaults to False.
+ resize_short_edge (bool, optional): Whether to resize according to the
+ short edge. Defaults to False.
+ allow_overflow (bool, optional): Whether to allow overflow.
+ Defaults to False.
+ fixed_scale (bool, optional): Whether to use fixed scale.
+
+ Returns:
+ tuple[int, int]: The new shape in [H, W].
+ """
+ h, w = original_shape
+ new_h, new_w = new_shape
+
+ if keep_ratio:
+ if allow_overflow:
+ comp_fn = max
+ else:
+ comp_fn = min
+
+ if align_long_edge:
+ long_edge, short_edge = max(new_shape), min(new_shape)
+ scale_factor = comp_fn(
+ long_edge / max(h, w), short_edge / min(h, w)
+ )
+ elif resize_short_edge:
+ short_edge = min(original_shape)
+ new_short_edge = min(new_shape)
+ scale_factor = new_short_edge / short_edge
+ else:
+ scale_factor = comp_fn(new_w / w, new_h / h)
+
+ if fixed_scale:
+ offset = 0.0
+ else:
+ offset = 0.5
+
+ new_h = int(h * scale_factor + offset)
+ new_w = int(w * scale_factor + offset)
+
+ return new_h, new_w
+
+
+@Transform([K.images, "transforms.resize.target_shape"], K.images)
+class ResizeImages:
+ """Resize Images."""
+
+ def __init__(
+ self,
+ interpolation: str = "bilinear",
+ antialias: bool = False,
+ imresize_backend: str = "torch",
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ interpolation (str, optional): Interpolation method. One of
+ ["nearest", "bilinear", "bicubic"]. Defaults to "bilinear".
+ antialias (bool): Whether to use antialiasing. Defaults to False.
+ imresize_backend (str): One of torch, cv2. Defaults to torch.
+ """
+ self.interpolation = interpolation
+ self.antialias = antialias
+ self.imresize_backend = imresize_backend
+ assert imresize_backend in {
+ "torch",
+ "cv2",
+ }, f"Invalid imresize backend: {imresize_backend}"
+
+ def __call__(
+ self, images: list[NDArrayF32], target_shapes: list[tuple[int, int]]
+ ) -> list[NDArrayF32]:
+ """Resize an image of dimensions [N, H, W, C].
+
+ Args:
+ image (Tensor): The image.
+ target_shape (tuple[int, int]): The target shape after resizing.
+
+ Returns:
+ list[NDArrayF32]: Resized images according to parameters in resize.
+ """
+ for i, (image, target_shape) in enumerate(zip(images, target_shapes)):
+ images[i] = resize_image(
+ image,
+ target_shape,
+ interpolation=self.interpolation,
+ antialias=self.antialias,
+ backend=self.imresize_backend,
+ )
+ return images
+
+
+def resize_image(
+ inputs: NDArrayF32,
+ shape: tuple[int, int],
+ interpolation: str = "bilinear",
+ antialias: bool = False,
+ backend: str = "torch",
+) -> NDArrayF32:
+ """Resize image."""
+ if backend == "torch":
+ image = torch.from_numpy(inputs).permute(0, 3, 1, 2)
+ image = resize_tensor(image, shape, interpolation, antialias)
+ return image.permute(0, 2, 3, 1).numpy()
+
+ if backend == "cv2":
+ cv2_interp_codes = {
+ "nearest": INTER_NEAREST,
+ "bilinear": INTER_LINEAR,
+ "bicubic": INTER_CUBIC,
+ "area": INTER_AREA,
+ "lanczos": INTER_LANCZOS4,
+ }
+ return cv2.resize( # pylint: disable=no-member, unsubscriptable-object
+ inputs[0].astype(np.uint8),
+ (shape[1], shape[0]),
+ interpolation=cv2_interp_codes[interpolation],
+ )[None, ...].astype(np.float32)
+
+ raise ValueError(f"Invalid imresize backend: {backend}")
+
+
+@Transform([K.boxes2d, "transforms.resize.scale_factor"], K.boxes2d)
+class ResizeBoxes2D:
+ """Resize list of 2D bounding boxes."""
+
+ def __call__(
+ self,
+ boxes_list: list[NDArrayF32],
+ scale_factors: list[tuple[float, float]],
+ ) -> list[NDArrayF32]:
+ """Resize 2D bounding boxes.
+
+ Args:
+ boxes_list: (list[NDArrayF32]): The bounding boxes to be resized.
+ scale_factors (list[tuple[float, float]]): scaling factors.
+
+ Returns:
+ list[NDArrayF32]: Resized bounding boxes according to parameters in
+ resize.
+ """
+ for i, (boxes, scale_factor) in enumerate(
+ zip(boxes_list, scale_factors)
+ ):
+ boxes_ = torch.from_numpy(boxes)
+ scale_matrix = torch.eye(3)
+ scale_matrix[0, 0] = scale_factor[0]
+ scale_matrix[1, 1] = scale_factor[1]
+ boxes_list[i] = transform_bbox(scale_matrix, boxes_).numpy()
+ return boxes_list
+
+
+@Transform(
+ [
+ K.depth_maps,
+ "transforms.resize.target_shape",
+ "transforms.resize.scale_factor",
+ ],
+ K.depth_maps,
+)
+class ResizeDepthMaps:
+ """Resize depth maps."""
+
+ def __init__(
+ self,
+ interpolation: str = "nearest",
+ rescale_depth_values: bool = False,
+ check_scale_factors: bool = False,
+ ):
+ """Initialize the transform.
+
+ Args:
+ interpolation (str, optional): Interpolation method. One of
+ ["nearest", "bilinear", "bicubic"]. Defaults to "nearest".
+ rescale_depth_values (bool, optional): If the depth values should
+ be rescaled according to the new scale factor. Defaults to
+ False. This is useful if we want to keep the intrinsic
+ parameters of the camera the same.
+ check_scale_factors (bool, optional): If the scale factors should
+ be checked to ensure they are the same. Defaults to False.
+ If False, the scale factor is assumed to be the same for both
+ dimensions and will just use the first scale factor.
+ """
+ self.interpolation = interpolation
+ self.rescale_depth_values = rescale_depth_values
+ self.check_scale_factors = check_scale_factors
+
+ def __call__(
+ self,
+ depth_maps: list[NDArrayF32],
+ target_shapes: list[tuple[int, int]],
+ scale_factors: list[tuple[float, float]],
+ ) -> list[NDArrayF32]:
+ """Resize depth maps."""
+ for i, (depth_map, target_shape, scale_factor) in enumerate(
+ zip(depth_maps, target_shapes, scale_factors)
+ ):
+ depth_map_ = torch.from_numpy(depth_map)
+ depth_map_ = (
+ resize_tensor(
+ depth_map_.float().unsqueeze(0).unsqueeze(0),
+ target_shape,
+ interpolation=self.interpolation,
+ )
+ .type(depth_map_.dtype)
+ .squeeze(0)
+ .squeeze(0)
+ )
+ if self.rescale_depth_values:
+ if self.check_scale_factors:
+ assert np.isclose(
+ scale_factor[0], scale_factor[1], atol=1e-4
+ ), "Depth map scale factors must be the same"
+ depth_map_ /= scale_factor[0]
+ depth_maps[i] = depth_map_.numpy()
+ return depth_maps
+
+
+@Transform(
+ [
+ K.optical_flows,
+ "transforms.resize.target_shape",
+ "transforms.resize.scale_factor",
+ ],
+ K.optical_flows,
+)
+class ResizeOpticalFlows:
+ """Resize optical flows."""
+
+ def __init__(self, normalized_flow: bool = True):
+ """Create a ResizeOpticalFlows instance.
+
+ Args:
+ normalized_flow (bool): Whether the optical flow is normalized.
+ Defaults to True. If false, the optical flow will be scaled
+ according to the scale factor.
+ """
+ self.normalized_flow = normalized_flow
+
+ def __call__(
+ self,
+ optical_flows: list[NDArrayF32],
+ target_shapes: list[tuple[int, int]],
+ scale_factors: list[tuple[float, float]],
+ ) -> list[NDArrayF32]:
+ """Resize optical flows."""
+ for i, (optical_flow, target_shape, scale_factor) in enumerate(
+ zip(optical_flows, target_shapes, scale_factors)
+ ):
+ optical_flow_ = torch.from_numpy(optical_flow).permute(2, 0, 1)
+ optical_flow_ = (
+ resize_tensor(
+ optical_flow_.float().unsqueeze(0),
+ target_shape,
+ interpolation="bilinear",
+ )
+ .type(optical_flow_.dtype)
+ .squeeze(0)
+ .permute(1, 2, 0)
+ )
+ # scale optical flows
+ if not self.normalized_flow:
+ optical_flow_[:, :, 0] *= scale_factor[0]
+ optical_flow_[:, :, 1] *= scale_factor[1]
+ optical_flows[i] = optical_flow_.numpy()
+ return optical_flows
+
+
+@Transform(
+ [K.instance_masks, "transforms.resize.target_shape"], K.instance_masks
+)
+class ResizeInstanceMasks:
+ """Resize instance segmentation masks."""
+
+ def __call__(
+ self,
+ masks_list: list[NDArrayF32],
+ target_shapes: list[tuple[int, int]],
+ ) -> list[NDArrayF32]:
+ """Resize masks."""
+ for i, (masks, target_shape) in enumerate(
+ zip(masks_list, target_shapes)
+ ):
+ if len(masks) == 0: # handle empty masks
+ continue
+ masks_ = torch.from_numpy(masks)
+ masks_ = (
+ resize_tensor(
+ masks_.float().unsqueeze(1),
+ target_shape,
+ interpolation="nearest",
+ )
+ .type(masks_.dtype)
+ .squeeze(1)
+ )
+ masks_list[i] = masks_.numpy()
+ return masks_list
+
+
+@Transform([K.seg_masks, "transforms.resize.target_shape"], K.seg_masks)
+class ResizeSegMasks:
+ """Resize segmentation masks."""
+
+ def __call__(
+ self,
+ masks_list: list[NDArrayF32],
+ target_shape_list: list[tuple[int, int]],
+ ) -> list[NDArrayF32]:
+ """Resize masks."""
+ for i, (masks, target_shape) in enumerate(
+ zip(masks_list, target_shape_list)
+ ):
+ masks_ = torch.from_numpy(masks)
+ masks_ = (
+ resize_tensor(
+ masks_.float().unsqueeze(0).unsqueeze(0),
+ target_shape,
+ interpolation="nearest",
+ )
+ .type(masks_.dtype)
+ .squeeze(0)
+ .squeeze(0)
+ )
+ masks_list[i] = masks_.numpy()
+ return masks_list
+
+
+@Transform([K.intrinsics, "transforms.resize.scale_factor"], K.intrinsics)
+class ResizeIntrinsics:
+ """Resize Intrinsics."""
+
+ def __call__(
+ self,
+ intrinsics: list[NDArrayF32],
+ scale_factors: list[tuple[float, float]],
+ ) -> list[NDArrayF32]:
+ """Scale camera intrinsics when resizing."""
+ for i, scale_factor in enumerate(scale_factors):
+ scale_matrix = np.eye(3, dtype=np.float32)
+ scale_matrix[0, 0] *= scale_factor[0]
+ scale_matrix[1, 1] *= scale_factor[1]
+ intrinsics[i] = scale_matrix @ intrinsics[i]
+ return intrinsics
+
+
+def resize_tensor(
+ inputs: Tensor,
+ shape: tuple[int, int],
+ interpolation: str = "bilinear",
+ antialias: bool = False,
+) -> Tensor:
+ """Resize Tensor."""
+ assert interpolation in {"nearest", "bilinear", "bicubic"}
+ align_corners = None if interpolation == "nearest" else False
+ output = F.interpolate(
+ inputs,
+ shape,
+ mode=interpolation,
+ align_corners=align_corners,
+ antialias=antialias,
+ )
+ return output
diff --git a/vis4d/data/transforms/select_sensor.py b/vis4d/data/transforms/select_sensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..534ef6ef3f464d1ca8f9ab688a15657d423c55a5
--- /dev/null
+++ b/vis4d/data/transforms/select_sensor.py
@@ -0,0 +1,52 @@
+# pylint: disable=no-member
+"""Select Sensor transformation."""
+from vis4d.data.typing import DictData
+
+from .base import Transform
+
+
+@Transform("data", "data")
+class SelectSensor:
+ """Keep data from one sensor only but keep shared data.
+
+ Note: The input data is assumed to be in the format of DictData[DictData],
+ i.e. a list of data dictionaries, each of which contains a dictionary of
+ either the data from a sensor or the shared data (metadata) for all
+ sensors.
+
+ Example:
+ >>> data = [
+ {
+ "sensor1": {"image": 1, "label": 2},
+ "sensor2": {"image": 1, "label": 2},
+ "meta": 3},
+ },
+ ]
+ >>> tsfm = SelectSensor(
+ sensor="sensor1", sensors=["sensor1", "sensor2"]
+ )
+ >>> tsfm(data)
+ [{"image": 1, "label": 2, "meta": 3},]
+ """
+
+ def __init__(self, selected_sensor: str) -> None:
+ """Creates an instance of SelectSensor.
+
+ Args:
+ selected_sensor (str): The name of the sensor to keep.
+ """
+ self.selected_sensor = selected_sensor
+
+ def __call__(self, batch: list[DictData]) -> list[DictData]:
+ """Select data from one sensor only."""
+ output_batch = []
+ for data in batch:
+ output_data = {}
+ for key in data.keys():
+ if key in self.sensors: # type: ignore
+ if key == self.selected_sensor:
+ output_data.update(data[key])
+ else:
+ output_data[key] = data[key]
+ output_batch.append(output_data)
+ return output_batch
diff --git a/vis4d/data/transforms/to_tensor.py b/vis4d/data/transforms/to_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3ec5e3e3b22491a4471bc21aed64fffff07f8d8
--- /dev/null
+++ b/vis4d/data/transforms/to_tensor.py
@@ -0,0 +1,48 @@
+"""ToTensor transformation."""
+
+import numpy as np
+import torch
+
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.typing import DictData
+
+from .base import Transform
+
+
+def _replace_arrays(data: DictData) -> None:
+ """Replace numpy arrays with tensors."""
+ for key in data.keys():
+ if key in [K.images, K.original_images]:
+ if not data[key].flags.c_contiguous:
+ data[key] = np.ascontiguousarray(
+ data[key].transpose(0, 3, 1, 2)
+ )
+ data[key] = torch.from_numpy(data[key])
+ else:
+ data[key] = (
+ torch.from_numpy(data[key])
+ .permute(0, 3, 1, 2)
+ .contiguous()
+ )
+ elif isinstance(data[key], np.ndarray):
+ data[key] = torch.from_numpy(data[key])
+ elif isinstance(data[key], dict):
+ _replace_arrays(data[key])
+ elif isinstance(data[key], list):
+ for i, entry in enumerate(data[key]):
+ if isinstance(entry, np.ndarray):
+ data[key][i] = torch.from_numpy(entry)
+
+
+@Transform("data", "data")
+class ToTensor:
+ """Transform all entries in a list of DataDict from numpy to torch.
+
+ Note that we reshape K.images from NHWC to NCHW.
+ """
+
+ def __call__(self, batch: list[DictData]) -> list[DictData]:
+ """Transform all entries to tensor."""
+ for data in batch:
+ _replace_arrays(data)
+ return batch
diff --git a/vis4d/data/typing.py b/vis4d/data/typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3da0b03b9b11c6976ab84f7270f3d4301707fc1
--- /dev/null
+++ b/vis4d/data/typing.py
@@ -0,0 +1,17 @@
+"""Type definitions related to the data pipeline.
+
+This file defines the data format `DictData` as an arbitrary dictionary that
+can, in principle, hold arbitrary data. However, we provide `CommonKeys` in
+`vis4d.data.const` to define the input format for commonly used input types,
+so that the data pre-processing pipeline can take advantage of pre-defined
+data formats that are necessary to properly pre-process a given data sample.
+"""
+
+from __future__ import annotations
+
+from typing import Union
+
+from vis4d.common.typing import DictStrAny
+
+DictData = DictStrAny
+DictDataOrList = Union[DictData, list[DictData]]
diff --git a/vis4d/engine/__init__.py b/vis4d/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0f099edf7f3185b4f5f5c2ce63a0e4a37244aee
--- /dev/null
+++ b/vis4d/engine/__init__.py
@@ -0,0 +1 @@
+"""Vis4D run package."""
diff --git a/vis4d/engine/callbacks/__init__.py b/vis4d/engine/callbacks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e95dc07eef537b48ba383a38a40715afe7983037
--- /dev/null
+++ b/vis4d/engine/callbacks/__init__.py
@@ -0,0 +1,25 @@
+"""Callback modules."""
+
+from .base import Callback
+from .ema import EMACallback
+from .evaluator import EvaluatorCallback
+from .logging import LoggingCallback
+from .scheduler import LRSchedulerCallback
+from .visualizer import VisualizerCallback
+from .yolox_callbacks import (
+ YOLOXModeSwitchCallback,
+ YOLOXSyncNormCallback,
+ YOLOXSyncRandomResizeCallback,
+)
+
+__all__ = [
+ "Callback",
+ "EMACallback",
+ "EvaluatorCallback",
+ "LoggingCallback",
+ "VisualizerCallback",
+ "LRSchedulerCallback",
+ "YOLOXModeSwitchCallback",
+ "YOLOXSyncNormCallback",
+ "YOLOXSyncRandomResizeCallback",
+]
diff --git a/vis4d/engine/callbacks/base.py b/vis4d/engine/callbacks/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ce9cea86007b43bfbb75408d4be8574dd654f9a
--- /dev/null
+++ b/vis4d/engine/callbacks/base.py
@@ -0,0 +1,85 @@
+"""Base module for callbacks."""
+
+from __future__ import annotations
+
+import lightning.pytorch as pl
+from torch import Tensor
+
+from vis4d.common.typing import DictStrArrNested
+from vis4d.data.typing import DictData
+from vis4d.engine.connectors import CallbackConnector
+
+
+class Callback(pl.Callback):
+ """Base class for Callbacks."""
+
+ def __init__(
+ self,
+ epoch_based: bool = True,
+ train_connector: None | CallbackConnector = None,
+ test_connector: None | CallbackConnector = None,
+ ) -> None:
+ """Init callback.
+
+ Args:
+ epoch_based (bool, optional): Whether the callback is epoch based.
+ Defaults to False.
+ train_connector (None | CallbackConnector, optional): Defines which
+ kwargs to use during training for different callbacks. Defaults
+ to None.
+ test_connector (None | CallbackConnector, optional): Defines which
+ kwargs to use during testing for different callbacks. Defaults
+ to None.
+ """
+ self.epoch_based = epoch_based
+ self.train_connector = train_connector
+ self.test_connector = test_connector
+
+ def setup(
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
+ ) -> None:
+ """Setup callback."""
+
+ def get_train_callback_inputs(
+ self, outputs: DictData, batch: DictData
+ ) -> dict[str, Tensor | DictStrArrNested]:
+ """Returns the data connector results for training.
+
+ It extracts the required data from prediction and datas and passes it
+ to the next component with the provided new key.
+
+ Args:
+ outputs (DictData): Outputs of the model.
+ batch (DictData): Batch data.
+
+ Returns:
+ dict[str, Tensor | DictStrArrNested]: Data connector results.
+
+ Raises:
+ AssertionError: If train connector is None.
+ """
+ assert self.train_connector is not None, "Train connector is None."
+
+ return self.train_connector(outputs, batch)
+
+ def get_test_callback_inputs(
+ self, outputs: DictData, batch: DictData
+ ) -> dict[str, Tensor | DictStrArrNested]:
+ """Returns the data connector results for inference.
+
+ It extracts the required data from prediction and datas and passes it
+ to the next component with the provided new key.
+
+ Args:
+ outputs (DictData): Outputs of the model.
+ batch (DictData): Batch data.
+
+ Returns:
+ dict[str, Tensor | DictStrArrNested]: Data connector results.
+
+ Raises:
+ AssertionError: If test connector is None.
+ """
+ assert self.test_connector is not None, "Test connector is None."
+
+ return self.test_connector(outputs, batch)
diff --git a/vis4d/engine/callbacks/ema.py b/vis4d/engine/callbacks/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb732fd99804031350dbd5e3e96122b6a79133e3
--- /dev/null
+++ b/vis4d/engine/callbacks/ema.py
@@ -0,0 +1,39 @@
+"""Callback for updating EMA model."""
+
+from __future__ import annotations
+
+import lightning.pytorch as pl
+
+from vis4d.common.distributed import is_module_wrapper
+from vis4d.data.typing import DictData
+from vis4d.model.adapter import ModelEMAAdapter
+
+from .base import Callback
+from .util import get_model
+
+
+class EMACallback(Callback):
+ """Callback for EMA."""
+
+ def on_train_batch_end( # type: ignore
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ outputs: DictData,
+ batch: DictData,
+ batch_idx: int,
+ ) -> None:
+ """Hook to run at the end of a training batch."""
+ model = get_model(pl_module)
+
+ if is_module_wrapper(model):
+ module = model.module
+ else:
+ module = model
+
+ assert isinstance(module, ModelEMAAdapter), (
+ "Model should be wrapped with ModelEMAAdapter when using "
+ "EMACallback."
+ )
+
+ module.update(trainer.global_step)
diff --git a/vis4d/engine/callbacks/evaluator.py b/vis4d/engine/callbacks/evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2430edf10126eb22f5105d674c30bd0163c7c41f
--- /dev/null
+++ b/vis4d/engine/callbacks/evaluator.py
@@ -0,0 +1,193 @@
+"""This module contains utilities for callbacks."""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+import lightning.pytorch as pl
+
+from vis4d.common.distributed import (
+ all_gather_object_cpu,
+ broadcast,
+ rank_zero_only,
+ synchronize,
+)
+from vis4d.common.logging import rank_zero_info
+from vis4d.common.typing import ArgsType, MetricLogs
+from vis4d.data.typing import DictData
+from vis4d.eval.base import Evaluator
+
+from .base import Callback
+
+
+class EvaluatorCallback(Callback):
+ """Callback for model evaluation."""
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ evaluator: Evaluator,
+ metrics_to_eval: list[str] | None = None,
+ save_predictions: bool = False,
+ save_prefix: None | str = None,
+ output_dir: str | None = None,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Init callback.
+
+ Args:
+ evaluator (Evaluator): Evaluator.
+ metrics_to_eval (list[str], Optional): Metrics to evaluate. If
+ None, all metrics in the evaluator will be evaluated. Defaults
+ to None.
+ save_predictions (bool): If the predictions should be saved.
+ Defaults to False.
+ save_prefix (str, Optional): Output directory for saving the
+ evaluation results. Defaults to None.
+ output_dir (str, Optional): Output directory for saving the
+ evaluation results.
+ """
+ super().__init__(*args, **kwargs)
+ self.evaluator = evaluator
+ self.save_predictions = save_predictions
+ self.metrics_to_eval = metrics_to_eval or self.evaluator.metrics
+
+ if self.save_predictions:
+ assert (
+ output_dir is not None
+ ), "If save_predictions is True, save_prefix must be provided."
+
+ output_dir = os.path.join(output_dir, "eval")
+
+ self.output_dir = output_dir
+ self.save_prefix = save_prefix
+
+ def setup(
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
+ ) -> None: # pragma: no cover
+ """Setup callback."""
+ if self.save_predictions:
+ self.output_dir = broadcast(self.output_dir)
+
+ if self.save_prefix is not None:
+ self.output_dir = os.path.join(
+ self.output_dir, self.save_prefix
+ )
+
+ for metric in self.metrics_to_eval:
+ output_dir = os.path.join(self.output_dir, metric)
+ os.makedirs(output_dir, exist_ok=True)
+ self.evaluator.reset()
+
+ def on_validation_batch_end( # type: ignore
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ outputs: Any,
+ batch: Any,
+ batch_idx: int,
+ dataloader_idx: int = 0,
+ ) -> None:
+ """Hook to run at the end of a validation batch."""
+ self.on_test_batch_end(
+ trainer=trainer,
+ pl_module=pl_module,
+ outputs=outputs,
+ batch=batch,
+ batch_idx=batch_idx,
+ dataloader_idx=dataloader_idx,
+ )
+
+ def on_validation_epoch_end(
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule
+ ) -> None:
+ """Wait for on_validation_epoch_end PL hook to call 'evaluate'."""
+ log_dict = self.run_eval()
+
+ for k, v in log_dict.items():
+ pl_module.log(f"val/{k}", v, sync_dist=True, rank_zero_only=True)
+
+ def on_test_batch_end( # type: ignore
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ outputs: DictData,
+ batch: DictData,
+ batch_idx: int,
+ dataloader_idx: int = 0,
+ ) -> None:
+ """Hook to run at the end of a testing batch."""
+ self.evaluator.process_batch(
+ **self.get_test_callback_inputs(outputs, batch)
+ )
+ for metric in self.metrics_to_eval:
+ # Save output predictions in current batch.
+ if self.save_predictions:
+ output_dir = os.path.join(self.output_dir, metric)
+ self.evaluator.save_batch(metric, output_dir)
+
+ def on_test_epoch_end(
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule
+ ) -> None:
+ """Hook to run at the end of a testing epoch."""
+ log_dict = self.run_eval()
+
+ for k, v in log_dict.items():
+ pl_module.log(f"test/{k}", v, sync_dist=True, rank_zero_only=True)
+
+ def run_eval(self) -> MetricLogs:
+ """Run evaluation for the given evaluator."""
+ self.evaluator.gather(all_gather_object_cpu)
+
+ synchronize()
+ self.process()
+
+ log_dict: MetricLogs = {}
+ for metric in self.metrics_to_eval:
+ metric_dict = self.evaluate(metric)
+ metric_dict = broadcast(metric_dict)
+ assert isinstance(metric_dict, dict)
+ log_dict.update(metric_dict)
+
+ self.evaluator.reset()
+
+ return log_dict
+
+ @rank_zero_only
+ def process(self) -> None:
+ """Process the evaluator."""
+ self.evaluator.process()
+
+ @rank_zero_only
+ def evaluate(self, metric: str) -> MetricLogs:
+ """Evaluate the performance after processing all input/output pairs.
+
+ Returns:
+ MetricLogs: A dictionary containing the evaluation results. The
+ keys are formatted as {metric_name}/{key_name}, and the
+ values are the corresponding evaluated values.
+ """
+ rank_zero_info(
+ f"Running evaluator {str(self.evaluator)} with {metric} metric... "
+ )
+ log_dict = {}
+
+ # Save output predictions. This is done here instead of
+ # on_test_batch_end because the evaluator may not have processed
+ # all batches yet.
+ if self.save_predictions:
+ output_dir = os.path.join(self.output_dir, metric)
+ self.evaluator.save(metric, output_dir)
+
+ # Evaluate metric
+ metric_dict, metric_str = self.evaluator.evaluate(metric)
+ for k, v in metric_dict.items():
+ log_k = metric + "/" + k
+ rank_zero_info("%s: %.4f", log_k, v)
+ log_dict[f"{metric}/{k}"] = v
+
+ rank_zero_info("Showing results for metric: %s", metric)
+ rank_zero_info(metric_str)
+
+ return log_dict
diff --git a/vis4d/engine/callbacks/logging.py b/vis4d/engine/callbacks/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..86fcb6e5886b528ae34f7932eaed6b10e62e5e92
--- /dev/null
+++ b/vis4d/engine/callbacks/logging.py
@@ -0,0 +1,165 @@
+"""This module contains utilities for callbacks."""
+
+from __future__ import annotations
+
+from collections import defaultdict
+from typing import Any
+
+import lightning.pytorch as pl
+
+from vis4d.common.logging import rank_zero_info
+from vis4d.common.progress import compose_log_str
+from vis4d.common.time import Timer
+from vis4d.common.typing import ArgsType, MetricLogs
+
+from .base import Callback
+
+
+class LoggingCallback(Callback):
+ """Callback for logging."""
+
+ def __init__(
+ self, *args: ArgsType, refresh_rate: int = 50, **kwargs: ArgsType
+ ) -> None:
+ """Init callback."""
+ super().__init__(*args, **kwargs)
+ self._refresh_rate = refresh_rate
+ self._metrics: dict[str, list[float]] = defaultdict(list)
+ self.train_timer = Timer()
+ self.test_timer = Timer()
+ self.last_step = 0
+
+ def on_train_epoch_start(
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule
+ ) -> None:
+ """Hook to run at the start of a training epoch."""
+ if self.epoch_based:
+ self.train_timer.reset()
+ self.last_step = 0
+ self._metrics.clear()
+ elif trainer.global_step == 0:
+ self.train_timer.reset()
+
+ def on_train_batch_start( # type: ignore
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ batch: Any,
+ batch_idx: int,
+ ) -> None:
+ """Hook to run at the start of a training batch."""
+ if self.train_timer.paused:
+ self.train_timer.resume()
+
+ def on_train_batch_end( # type: ignore
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ outputs: Any,
+ batch: Any,
+ batch_idx: int,
+ ) -> None:
+ """Hook to run at the end of a training batch."""
+ if "metrics" in outputs:
+ for k, v in outputs["metrics"].items():
+ self._metrics[k].append(v)
+
+ if self.epoch_based:
+ cur_iter = batch_idx + 1
+
+ # Resolve float("inf") to -1
+ if isinstance(trainer.num_training_batches, float):
+ total_iters = -1
+ else:
+ total_iters = trainer.num_training_batches
+ else:
+ cur_iter = trainer.global_step + 1
+ total_iters = trainer.max_steps
+
+ if cur_iter % self._refresh_rate == 0 and cur_iter != self.last_step:
+ prefix = (
+ f"Epoch {pl_module.current_epoch + 1}"
+ if self.epoch_based
+ else "Iter"
+ )
+
+ log_dict: MetricLogs = {
+ k: sum(v) / len(v) if len(v) > 0 else float("NaN")
+ for k, v in self._metrics.items()
+ }
+
+ rank_zero_info(
+ compose_log_str(
+ prefix, cur_iter, total_iters, self.train_timer, log_dict
+ )
+ )
+
+ self._metrics.clear()
+ self.last_step = cur_iter
+
+ for k, v in log_dict.items():
+ pl_module.log(f"train/{k}", v, rank_zero_only=True)
+
+ def on_validation_epoch_start(
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule
+ ) -> None:
+ """Hook to run at the start of a validation epoch."""
+ self.test_timer.reset()
+ self.train_timer.pause()
+
+ def on_validation_batch_end( # type: ignore
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ outputs: Any,
+ batch: Any,
+ batch_idx: int,
+ dataloader_idx: int = 0,
+ ) -> None:
+ """Wait for on_validation_batch_end PL hook to call 'process'."""
+ cur_iter = batch_idx + 1
+
+ # Resolve float("inf") to -1
+ if isinstance(trainer.num_val_batches[dataloader_idx], int):
+ total_iters = int(trainer.num_val_batches[dataloader_idx])
+ else:
+ total_iters = -1
+
+ if cur_iter % self._refresh_rate == 0:
+ rank_zero_info(
+ compose_log_str(
+ "Validation", cur_iter, total_iters, self.test_timer
+ )
+ )
+
+ def on_test_epoch_start(
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule
+ ) -> None:
+ """Hook to run at the start of a testing epoch."""
+ self.test_timer.reset()
+ self.train_timer.pause()
+
+ def on_test_batch_end( # type: ignore
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ outputs: Any,
+ batch: Any,
+ batch_idx: int,
+ dataloader_idx: int = 0,
+ ) -> None:
+ """Hook to run at the end of a testing batch."""
+ cur_iter = batch_idx + 1
+
+ # Resolve float("inf") to -1
+ if isinstance(trainer.num_test_batches[dataloader_idx], int):
+ total_iters = int(trainer.num_test_batches[dataloader_idx])
+ else:
+ total_iters = -1
+
+ if cur_iter % self._refresh_rate == 0:
+ rank_zero_info(
+ compose_log_str(
+ "Testing", cur_iter, total_iters, self.test_timer
+ )
+ )
diff --git a/vis4d/engine/callbacks/scheduler.py b/vis4d/engine/callbacks/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8cb0a541b22ffb9cf368aea7b5c03f1d1a86bd4
--- /dev/null
+++ b/vis4d/engine/callbacks/scheduler.py
@@ -0,0 +1,44 @@
+"""Callback to configure learning rate during training."""
+
+from __future__ import annotations
+
+from collections.abc import Iterable
+from typing import Any
+
+import lightning.pytorch as pl
+
+from vis4d.engine.optim.scheduler import LRSchedulerWrapper
+
+from .base import Callback
+
+
+class LRSchedulerCallback(Callback):
+ """Callback to configure learning rate during training."""
+
+ def __init__(self) -> None:
+ """Initialize the callback."""
+ super().__init__()
+ self.last_step = 0
+
+ def on_train_batch_end( # type: ignore
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ outputs: Any,
+ batch: Any,
+ batch_idx: int,
+ ) -> None:
+ """Hook on training batch end."""
+ schedulers = pl_module.lr_schedulers()
+
+ if not isinstance(schedulers, Iterable):
+ schedulers = [schedulers] # type: ignore
+
+ if trainer.global_step != self.last_step:
+ for scheduler in schedulers:
+ if scheduler is None:
+ continue
+ assert isinstance(scheduler, LRSchedulerWrapper)
+ scheduler.step_on_batch(trainer.global_step)
+
+ self.last_step = trainer.global_step
diff --git a/vis4d/engine/callbacks/util.py b/vis4d/engine/callbacks/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..e46038a8ef696a679777bb806d5c008e4a92ab80
--- /dev/null
+++ b/vis4d/engine/callbacks/util.py
@@ -0,0 +1,24 @@
+"""PyTorch Lightning callbacks utilities."""
+
+from __future__ import annotations
+
+import lightning.pytorch as pl
+from torch import nn
+
+from vis4d.engine.loss_module import LossModule
+from vis4d.engine.training_module import TrainingModule
+
+
+def get_model(model: pl.LightningModule) -> nn.Module:
+ """Get model from pl module."""
+ if isinstance(model, TrainingModule):
+ return model.model
+ return model
+
+
+def get_loss_module(loss_module: pl.LightningModule) -> LossModule:
+ """Get loss_module from pl module."""
+ assert hasattr(loss_module, "loss_module") and isinstance(
+ loss_module.loss_module, LossModule
+ ), "Loss module is not set in the training module."
+ return loss_module.loss_module
diff --git a/vis4d/engine/callbacks/visualizer.py b/vis4d/engine/callbacks/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..30179597fb3e653a66457409f34576b82ccc7b38
--- /dev/null
+++ b/vis4d/engine/callbacks/visualizer.py
@@ -0,0 +1,165 @@
+"""This module contains utilities for callbacks."""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+import lightning.pytorch as pl
+
+from vis4d.common.distributed import broadcast, synchronize
+from vis4d.common.typing import ArgsType
+from vis4d.vis.base import Visualizer
+
+from .base import Callback
+
+
+class VisualizerCallback(Callback):
+ """Callback for model visualization."""
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ visualizer: Visualizer,
+ visualize_train: bool = False,
+ show: bool = False,
+ save_to_disk: bool = True,
+ save_prefix: str | None = None,
+ output_dir: str | None = None,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Init callback.
+
+ Args:
+ visualizer (Visualizer): Visualizer.
+ visualize_train (bool): If the training data should be visualized.
+ Defaults to False.
+ show (bool): If the visualizations should be shown. Defaults to
+ False.
+ save_to_disk (bool): If the visualizations should be saved to disk.
+ Defaults to True.
+ save_prefix (str): Output directory prefix for distinguish
+ different visualizations.
+ output_dir (str): Output directory for saving the visualizations.
+ """
+ super().__init__(*args, **kwargs)
+ self.visualizer = visualizer
+ self.visualize_train = visualize_train
+ self.save_prefix = save_prefix
+ self.show = show
+ self.save_to_disk = save_to_disk
+
+ if self.save_to_disk:
+ assert (
+ output_dir is not None
+ ), "If save_to_disk is True, output_dir must be provided."
+
+ output_dir = os.path.join(output_dir, "vis")
+
+ self.output_dir = output_dir
+ self.save_prefix = save_prefix
+
+ def setup(
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
+ ) -> None: # pragma: no cover
+ """Setup callback."""
+ if self.save_to_disk:
+ self.output_dir = broadcast(self.output_dir)
+
+ def on_train_batch_end( # type: ignore
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ outputs: Any,
+ batch: Any,
+ batch_idx: int,
+ ) -> None:
+ """Hook to run at the end of a training batch."""
+ cur_iter = batch_idx + 1
+
+ if self.visualize_train:
+ self.visualizer.process(
+ cur_iter=cur_iter,
+ **self.get_train_callback_inputs(outputs, batch),
+ )
+
+ if self.show:
+ self.visualizer.show(cur_iter=cur_iter)
+
+ if self.save_to_disk:
+ self.save(cur_iter=cur_iter, stage="train")
+
+ self.visualizer.reset()
+
+ def on_validation_batch_end( # type: ignore
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ outputs: Any,
+ batch: Any,
+ batch_idx: int,
+ dataloader_idx: int = 0,
+ ) -> None:
+ """Hook to run at the end of a validation batch."""
+ cur_iter = batch_idx + 1
+
+ self.visualizer.process(
+ cur_iter=cur_iter,
+ **self.get_test_callback_inputs(outputs, batch),
+ )
+
+ if self.show:
+ self.visualizer.show(cur_iter=cur_iter)
+
+ if self.save_to_disk:
+ self.save(cur_iter=cur_iter, stage="val")
+
+ self.visualizer.reset()
+
+ def on_test_batch_end( # type: ignore
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ outputs: Any,
+ batch: Any,
+ batch_idx: int,
+ dataloader_idx: int = 0,
+ ) -> None:
+ """Hook to run at the end of a testing batch."""
+ cur_iter = batch_idx + 1
+
+ self.visualizer.process(
+ cur_iter=cur_iter,
+ **self.get_test_callback_inputs(outputs, batch),
+ )
+
+ if self.show:
+ self.visualizer.show(cur_iter=cur_iter)
+
+ if self.save_to_disk:
+ self.save(cur_iter=cur_iter, stage="test")
+
+ self.visualizer.reset()
+
+ def save(self, cur_iter: int, stage: str) -> None:
+ """Save the visualizer state."""
+ output_folder = os.path.join(self.output_dir, stage)
+
+ if self.save_prefix is not None:
+ output_folder = os.path.join(output_folder, self.save_prefix)
+
+ os.makedirs(output_folder, exist_ok=True)
+
+ self.visualizer.save_to_disk(
+ cur_iter=cur_iter, output_folder=output_folder
+ )
+
+ # TODO: Add support for logging images to WandB.
+ # if get_rank() == 0:
+ # if isinstance(trainer.logger, WandbLogger) and image is not None:
+ # trainer.logger.log_image(
+ # key=f"{self.visualizer}/{cur_iter}",
+ # images=[image],
+ # )
+
+ synchronize()
diff --git a/vis4d/engine/callbacks/yolox_callbacks.py b/vis4d/engine/callbacks/yolox_callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..868e825dc41c52bb930d76b1798d8d503cd9286d
--- /dev/null
+++ b/vis4d/engine/callbacks/yolox_callbacks.py
@@ -0,0 +1,196 @@
+"""YOLOX-specific callbacks."""
+
+from __future__ import annotations
+
+import random
+from collections import OrderedDict
+from typing import Any
+
+import lightning.pytorch as pl
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn.modules.batchnorm import _NormBase
+from torch.utils.data import DataLoader
+
+from vis4d.common.distributed import (
+ all_reduce_dict,
+ broadcast,
+ get_rank,
+ get_world_size,
+ synchronize,
+)
+from vis4d.common.logging import rank_zero_info, rank_zero_warn
+from vis4d.common.typing import ArgsType, DictStrAny
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.data_pipe import DataPipe
+from vis4d.op.detect.yolox import YOLOXHeadLoss
+from vis4d.op.loss.common import l1_loss
+
+from .base import Callback
+from .util import get_loss_module, get_model
+
+
+class YOLOXModeSwitchCallback(Callback):
+ """Callback for switching the mode of YOLOX training."""
+
+ def __init__(
+ self, *args: ArgsType, switch_epoch: int, **kwargs: ArgsType
+ ) -> None:
+ """Init callback.
+
+ Args:
+ switch_epoch (int): Epoch to switch the mode.
+ """
+ super().__init__(*args, **kwargs)
+ self.switch_epoch = switch_epoch
+ self.switched = False
+
+ def on_train_epoch_end(
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule
+ ) -> None:
+ """Hook to run at the end of a training epoch."""
+ if pl_module.current_epoch < self.switch_epoch - 1 or self.switched:
+ # TODO: Make work with resume.
+ return
+
+ loss_module = get_loss_module(pl_module)
+
+ found_loss = False
+ for loss in loss_module.losses:
+ if isinstance(loss["loss"], YOLOXHeadLoss):
+ found_loss = True
+ yolox_loss = loss["loss"]
+ break
+ rank_zero_info(
+ "Switching YOLOX training mode starting next training epoch "
+ "(turning off strong augmentations, adding L1 loss, switching to "
+ "validation every epoch)."
+ )
+ if found_loss:
+ yolox_loss.loss_l1 = l1_loss # set L1 loss function
+ else:
+ rank_zero_warn("YOLOXHeadLoss should be in LossModule.")
+ # Set data pipeline to default DataPipe to skip strong augs.
+ # Switch to checking validation every epoch.
+ dataloader = trainer.train_dataloader
+ assert dataloader is not None
+ new_dataloader = DataLoader(
+ DataPipe(dataloader.dataset.datasets),
+ batch_size=dataloader.batch_size,
+ num_workers=dataloader.num_workers,
+ collate_fn=dataloader.collate_fn,
+ sampler=dataloader.sampler,
+ persistent_workers=dataloader.persistent_workers,
+ pin_memory=dataloader.pin_memory,
+ )
+
+ pl_module.check_val_every_n_epoch = 1 # type: ignore
+
+ # Override train_dataloader method in PL datamodule.
+ # Set reload_dataloaders_every_n_epochs to 1 to use the new
+ # dataloader.
+ def train_dataloader() -> DataLoader: # type: ignore
+ """Return dataloader for training."""
+ return new_dataloader
+
+ pl_module.datamodule.train_dataloader = train_dataloader # type: ignore # pylint: disable=line-too-long
+ pl_module.reload_dataloaders_every_n_epochs = self.switch_epoch # type: ignore # pylint: disable=line-too-long
+
+ self.switched = True
+
+
+def get_norm_states(module: nn.Module) -> DictStrAny:
+ """Get the state_dict of batch norms in the module.
+
+ Args:
+ module (nn.Module): Module to get batch norm states from.
+ """
+ async_norm_states = OrderedDict()
+ for name, child in module.named_modules():
+ if isinstance(child, _NormBase):
+ for k, v in child.state_dict().items():
+ async_norm_states[".".join([name, k])] = v
+ return async_norm_states
+
+
+class YOLOXSyncNormCallback(Callback):
+ """Callback for syncing the norm states of YOLOX training."""
+
+ def on_test_epoch_start(
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule
+ ) -> None:
+ """Hook to run at the beginning of a testing epoch."""
+ if get_world_size() > 1:
+ model = get_model(pl_module)
+ norm_states = get_norm_states(model)
+
+ if len(norm_states) > 0:
+ rank_zero_info("Synced norm states across all processes.")
+ norm_states = all_reduce_dict(norm_states, reduce_op="mean")
+ model.load_state_dict(norm_states, strict=False)
+
+
+class YOLOXSyncRandomResizeCallback(Callback):
+ """Callback for syncing random resize during YOLOX training."""
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ size_list: list[tuple[int, int]],
+ interval: int,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Init callback."""
+ super().__init__(*args, **kwargs)
+ self.size_list = size_list
+ self.interval = interval
+ self.random_shape = size_list[-1]
+
+ def _get_random_shape(self, device: torch.device) -> tuple[int, int]:
+ """Randomly generate shape from size_list and sync across ranks."""
+ shape_tensor = torch.zeros(2, dtype=torch.int).to(device)
+ if get_rank() == 0:
+ random_shape = random.choice(self.size_list)
+ shape_tensor[0], shape_tensor[1] = random_shape[0], random_shape[1]
+ synchronize()
+ shape_tensor = broadcast(shape_tensor, 0)
+ return (int(shape_tensor[0].item()), int(shape_tensor[1].item()))
+
+ def on_train_batch_start( # type: ignore
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ batch: Any,
+ batch_idx: int,
+ ) -> None:
+ """Hook to run at the start of a training batch."""
+ if not isinstance(batch, list):
+ batch = [batch]
+ if (trainer.global_step + 1) % self.interval == 0:
+ self.random_shape = self._get_random_shape(
+ batch[0][K.images].device
+ )
+
+ for b in batch:
+ scale_y = self.random_shape[0] / b[K.images].shape[-2]
+ scale_x = self.random_shape[1] / b[K.images].shape[-1]
+
+ if scale_y == 1 and scale_x == 1:
+ return
+
+ # resize images
+ b[K.images] = F.interpolate(
+ b[K.images],
+ size=self.random_shape,
+ mode="bilinear",
+ align_corners=False,
+ )
+ b[K.input_hw] = [
+ self.random_shape for _ in range(b[K.images].size(0))
+ ]
+
+ # resize boxes
+ for boxes in b[K.boxes2d]:
+ boxes[..., ::2] = boxes[..., ::2] * scale_x
+ boxes[..., 1::2] = boxes[..., 1::2] * scale_y
diff --git a/vis4d/engine/connectors/__init__.py b/vis4d/engine/connectors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..18454f0ffa2f7f1455e24dcb3ad7f321801661f4
--- /dev/null
+++ b/vis4d/engine/connectors/__init__.py
@@ -0,0 +1,31 @@
+"""Data connector for data connection."""
+
+from .base import CallbackConnector, DataConnector, LossConnector
+from .multi_sensor import (
+ MultiSensorCallbackConnector,
+ MultiSensorDataConnector,
+ MultiSensorLossConnector,
+ get_multi_sensor_inputs,
+)
+from .util import (
+ SourceKeyDescription,
+ data_key,
+ get_inputs_for_pred_and_data,
+ pred_key,
+ remap_pred_keys,
+)
+
+__all__ = [
+ "CallbackConnector",
+ "DataConnector",
+ "data_key",
+ "get_multi_sensor_inputs",
+ "get_inputs_for_pred_and_data",
+ "LossConnector",
+ "MultiSensorDataConnector",
+ "MultiSensorCallbackConnector",
+ "MultiSensorLossConnector",
+ "pred_key",
+ "remap_pred_keys",
+ "SourceKeyDescription",
+]
diff --git a/vis4d/engine/connectors/base.py b/vis4d/engine/connectors/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..75e38d6e35e60767e41cca39c599ff98d97a775d
--- /dev/null
+++ b/vis4d/engine/connectors/base.py
@@ -0,0 +1,108 @@
+"""Base data connector to define data structures for data connection."""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+from torch import Tensor
+
+from vis4d.common.typing import DictStrArrNested
+from vis4d.data.typing import DictData, DictDataOrList
+
+from .util import SourceKeyDescription, get_inputs_for_pred_and_data
+
+
+class DataConnector:
+ """Defines which data to pass to which component.
+
+ It extracts the required data from a 'DictData' objects and passes it to
+ the next component with the provided new key.
+ """
+
+ def __init__(self, key_mapping: dict[str, str]):
+ """Initializes the data connector with static remapping of the keys.
+
+ Args:
+ key_mapping (dict[str, str]): Defines which kwargs to pass onto the
+ module.
+
+ Simple Example Configuration:
+
+ >>> train = dict(images = "images", gt = "gt_images)
+ >>> train_data_connector = DataConnector(train)
+ >>> test = dict(images = "images")
+ >>> test_data_connector = DataConnector(test)
+ """
+ self.key_mapping = key_mapping
+
+ def __call__(self, data: DictDataOrList) -> DictData:
+ """Returns the kwargs that are passed to the module.
+
+ Args:
+ data (DictDataorList): The data (e.g. from the dataloader) which
+ contains all data that was loaded.
+
+ Returns:
+ DictData: kwargs that are passed onto the model.
+ """
+ if isinstance(data, list):
+ return {
+ k: [d[v] for d in data] for k, v in self.key_mapping.items()
+ }
+ return {k: data[v] for k, v in self.key_mapping.items()}
+
+
+class LossConnector:
+ """Defines which data to pass to loss module of the training pipeline.
+
+ It extracts the required data from prediction and data and passes it to
+ the next component with the provided new key.
+ """
+
+ def __init__(self, key_mapping: dict[str, SourceKeyDescription]) -> None:
+ """Initializes the data connector with static remapping of the keys."""
+ self.key_mapping = key_mapping
+
+ def __call__(
+ self, prediction: DictData | NamedTuple, data: DictData
+ ) -> dict[str, Tensor | DictStrArrNested]:
+ """Returns the kwargs that are passed to the loss module.
+
+ Args:
+ prediction (DictData | NamedTuple): The output from model.
+ data (DictData): The data dictionary from the dataloader which
+ contains all data that was loaded.
+
+ Returns:
+ dict[str, Tensor | DictStrArrNested]: kwargs that are passed
+ onto the loss.
+ """
+ return get_inputs_for_pred_and_data(self.key_mapping, prediction, data)
+
+
+class CallbackConnector:
+ """Data connector for the callback.
+
+ It extracts the required data from prediction and datas and passes it to
+ the next component with the provided new key.
+ """
+
+ def __init__(self, key_mapping: dict[str, SourceKeyDescription]) -> None:
+ """Initializes the data connector with static remapping of the keys."""
+ self.key_mapping = key_mapping
+
+ def __call__(
+ self, prediction: DictData | NamedTuple, data: DictData
+ ) -> dict[str, Tensor | DictStrArrNested]:
+ """Returns the kwargs that are passed to the callback.
+
+ Args:
+ prediction (DictData | NamedTuple): The output from model.
+ data (DictData): The data dictionary from the dataloader which
+ contains all data that was loaded.
+
+ Returns:
+ dict[str, Tensor | DictStrArrNested]: kwargs that are passed
+ onto the callback.
+ """
+ return get_inputs_for_pred_and_data(self.key_mapping, prediction, data)
diff --git a/vis4d/engine/connectors/multi_sensor.py b/vis4d/engine/connectors/multi_sensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..abfa46fdd8270421fef659e249998071a9cbe892
--- /dev/null
+++ b/vis4d/engine/connectors/multi_sensor.py
@@ -0,0 +1,147 @@
+"""Data connector for multi-sensor dataset."""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+from vis4d.data.typing import DictData, DictDataOrList
+
+from .base import CallbackConnector, DataConnector, LossConnector
+from .util import SourceKeyDescription, get_field_from_prediction
+
+
+class MultiSensorDataConnector(DataConnector):
+ """Data connector for multi-sensor data dict."""
+
+ def __init__(self, key_mapping: dict[str, str | SourceKeyDescription]):
+ """Initializes the data connector with static remapping of the keys.
+
+ Args:
+ key_mapping (dict[str, | SourceKeyDescription]): Defines which
+ kwargs to pass onto the module.
+
+ TODO: Add Simple Example Configuration:
+ """
+ _key_mapping = {}
+ multi_sensor_key_mapping = {}
+
+ for k, v in key_mapping.items():
+ if isinstance(v, dict):
+ sensors = v.get("sensors")
+ if sensors is not None:
+ multi_sensor_key_mapping[k] = v
+ else:
+ _key_mapping[k] = v["key"]
+ else:
+ _key_mapping[k] = v
+
+ super().__init__(_key_mapping)
+ self.multi_sensor_key_mapping = multi_sensor_key_mapping
+
+ def __call__(self, data: DictDataOrList) -> DictData:
+ """Returns the train input for the model."""
+ input_dict = super().__call__(data)
+
+ for target_key, source_key in self.multi_sensor_key_mapping.items():
+ key = source_key["key"]
+ sensors = source_key["sensors"]
+
+ if isinstance(data, list):
+ input_dict[target_key] = [
+ [d[sensor][key] for sensor in sensors] for d in data
+ ]
+ else:
+ input_dict[target_key] = [
+ data[sensor][key] for sensor in sensors
+ ]
+ return input_dict
+
+
+class MultiSensorLossConnector(LossConnector):
+ """Multi-sensor Data connector for loss module of the training pipeline."""
+
+ def __call__(
+ self, prediction: DictData | NamedTuple, data: DictData
+ ) -> DictData:
+ """Returns the kwargs that are passed to the loss module.
+
+ Args:
+ prediction (DictData | NamedTuple): The output from model.
+ data (DictData): The data dictionary from the dataloader which
+ contains all data that was loaded.
+
+ Returns:
+ DictData: kwargs that are passed onto the loss.
+ """
+ return get_multi_sensor_inputs(self.key_mapping, prediction, data)
+
+
+class MultiSensorCallbackConnector(CallbackConnector):
+ """Multi-sensor data connector for the callback."""
+
+ def __call__(
+ self, prediction: DictData | NamedTuple, data: DictData
+ ) -> DictData:
+ """Returns the kwargs that are passed to the callback.
+
+ Args:
+ prediction (DictData | NamedTuple): The output from model.
+ data (DictData): The data dictionary from the dataloader which
+ contains all data that was loaded.
+
+ Returns:
+ DictData: kwargs that are passed onto the callback.
+ """
+ return get_multi_sensor_inputs(self.key_mapping, prediction, data)
+
+
+def get_multi_sensor_inputs(
+ connection_dict: dict[str, SourceKeyDescription],
+ prediction: DictData | NamedTuple,
+ data: DictData,
+) -> DictData:
+ """Extracts multi-sensor input data from the provided SourceKeyDescription.
+
+ Args:
+ connection_dict (dict[str, SourceKeyDescription]): Input Key
+ description which is used to gather and remap data from the
+ two data dicts.
+ prediction (DictData): Dict containing the model prediction output.
+ data (DictData): Dict containing the dataloader output.
+
+ Raises:
+ ValueError: If the datasource is invalid.
+
+ Returns:
+ out (DictData): Dict containing new kwargs consisting of new key name
+ and data extracted from the data dicts.
+ """
+ out: DictData = {}
+ for new_key_name, old_key_name in connection_dict.items():
+ # Assign field from data
+ if old_key_name["source"] == "data":
+ sensors = old_key_name.get("sensors")
+
+ if sensors is None:
+ if old_key_name["key"] not in data:
+ raise ValueError(
+ f"Key {old_key_name['key']} not found in data dict."
+ f" Available keys: {data.keys()}"
+ )
+ out[new_key_name] = data[old_key_name["key"]]
+ else:
+ out[new_key_name] = [
+ data[sensor][old_key_name["key"]] for sensor in sensors
+ ]
+
+ # Assign field from prediction
+ elif old_key_name["source"] == "prediction":
+ out[new_key_name] = get_field_from_prediction(
+ prediction, old_key_name
+ )
+ else:
+ raise ValueError(
+ f"Unknown data source {old_key_name['source']}."
+ f"Available: [prediction, data]"
+ )
+ return out
diff --git a/vis4d/engine/connectors/util.py b/vis4d/engine/connectors/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a101215c9fca108e314a9cc5cde0ea4203bbfc0
--- /dev/null
+++ b/vis4d/engine/connectors/util.py
@@ -0,0 +1,152 @@
+"""Utility functions for the connectors module."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from copy import deepcopy
+from typing import NamedTuple, TypedDict
+
+from torch import Tensor
+from typing_extensions import NotRequired
+
+from vis4d.common.dict import get_dict_nested
+from vis4d.common.named_tuple import get_from_namedtuple, is_namedtuple
+from vis4d.common.typing import DictStrArrNested
+from vis4d.data.typing import DictData
+
+
+class SourceKeyDescription(TypedDict):
+ """Defines a data entry by providing the key and source of the data.
+
+ Attributes:
+ key (str): Key that is used to index data from the specified source
+ source (str): Which datasource to choose from.
+ Options are ['data', 'prediction'] where data referes to the
+ output of the dataloader and prediction refers to the model
+ output
+ sensors (Sequence[str]): Which sensors to use for the data.
+ """
+
+ key: str
+ source: str
+ sensors: NotRequired[Sequence[str]]
+
+
+def remap_pred_keys(
+ info: dict[str, SourceKeyDescription], parent_key: str
+) -> dict[str, SourceKeyDescription]:
+ """Remaps the key of a connection mapping to a new parent key.
+
+ Args:
+ info (SourceKeyDescription): Description to remap.
+ parent_key (str): New parent_key to use.
+
+ Returns:
+ SourceKeyDescription: Description with new key.
+
+ """
+ info = deepcopy(info)
+
+ for value in info.values():
+ if value["source"] == "prediction":
+ value["key"] = parent_key + "." + value["key"]
+ return info
+
+
+def data_key(
+ key: str, sensors: Sequence[str] | None = None
+) -> SourceKeyDescription:
+ """Returns a SourceKeyDescription with data as source.
+
+ Args:
+ key (str): Key to use for the data entry.
+ sensors (Sequence[str] | None, optional): Which sensors to use for the
+ data. Defaults to None.
+
+ Returns:
+ SourceKeyDescription: A SourceKeyDescription with data as source.
+ """
+ if sensors is None:
+ return SourceKeyDescription(key=key, source="data")
+
+ return SourceKeyDescription(key=key, source="data", sensors=sensors)
+
+
+def pred_key(key: str) -> SourceKeyDescription:
+ """Returns a SourceKeyDescription with prediction as source.
+
+ Args:
+ key (str): Key to use for the data entry.
+
+ Returns:
+ SourceKeyDescription: A SourceKeyDescription with prediction as source.
+ """
+ return SourceKeyDescription(key=key, source="prediction")
+
+
+def get_field_from_prediction(
+ prediction: DictData | NamedTuple,
+ old_key_name: SourceKeyDescription,
+) -> Tensor | DictStrArrNested:
+ """Extracts a field from the prediction dict.
+
+ Args:
+ prediction (DictData): Dict containing the model prediction output.
+ old_key_name (SourceKeyDescription): Description of the data to
+ extract.
+
+ Returns:
+ Tensor | DictStrArrNested: Data extracted from the prediction dict.
+ """
+ if is_namedtuple(prediction):
+ return get_from_namedtuple(
+ prediction, old_key_name["key"] # type: ignore
+ )
+
+ old_key = old_key_name["key"]
+ return get_dict_nested(prediction, old_key.split(".")) # type: ignore
+
+
+def get_inputs_for_pred_and_data(
+ connection_dict: dict[str, SourceKeyDescription],
+ prediction: DictData | NamedTuple,
+ data: DictData,
+) -> dict[str, Tensor | DictStrArrNested]:
+ """Extracts input data from the provided SourceKeyDescription.
+
+ Args:
+ connection_dict (dict[str, SourceKeyDescription]): Input Key
+ description which is used to gather and remap data from the
+ two data dicts.
+ prediction (DictData): Dict containing the model prediction output.
+ data (DictData): Dict containing the dataloader output.
+
+ Raises:
+ ValueError: If the datasource is invalid.
+
+ Returns:
+ out (dict[str, Tensor | DictStrArrNested]): Dict containing new kwargs
+ consisting of new key name and data extracted from the data dicts.
+ """
+ out = {}
+ for new_key_name, old_key_name in connection_dict.items():
+ # Assign field from data
+ if old_key_name["source"] == "data":
+ if old_key_name["key"] not in data:
+ raise ValueError(
+ f"Key {old_key_name['key']} not found in data dict."
+ f" Available keys: {data.keys()}"
+ )
+ out[new_key_name] = data[old_key_name["key"]]
+
+ # Assign field from model prediction
+ elif old_key_name["source"] == "prediction":
+ out[new_key_name] = get_field_from_prediction(
+ prediction, old_key_name
+ )
+ else:
+ raise ValueError(
+ f"Unknown data source {old_key_name['source']}."
+ f" Available: [prediction, data]"
+ )
+ return out
diff --git a/vis4d/engine/data_module.py b/vis4d/engine/data_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..de307a6e1519a4de29549c1aa887ff5edfcc9478
--- /dev/null
+++ b/vis4d/engine/data_module.py
@@ -0,0 +1,39 @@
+"""Data module composing the data loading pipeline."""
+
+from __future__ import annotations
+
+import lightning.pytorch as pl
+from torch.utils.data import DataLoader
+
+from vis4d.config import instantiate_classes
+from vis4d.config.typing import DataConfig
+from vis4d.data.typing import DictData
+
+
+class DataModule(pl.LightningDataModule):
+ """DataModule that wraps around the vis4d implementations.
+
+ This is a wrapper around the vis4d implementations that allows to use
+ pytorch-lightning for training and testing.
+ """
+
+ def __init__(self, data_cfg: DataConfig) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.data_cfg = data_cfg
+
+ def train_dataloader(self) -> DataLoader[DictData]:
+ """Return dataloader for training."""
+ if self.trainer is not None and hasattr(self.trainer, "seed"):
+ seed = self.trainer.seed
+ else:
+ seed = None
+ return instantiate_classes(self.data_cfg.train_dataloader, seed=seed)
+
+ def test_dataloader(self) -> list[DataLoader[DictData]]:
+ """Return dataloaders for testing."""
+ return instantiate_classes(self.data_cfg.test_dataloader)
+
+ def val_dataloader(self) -> list[DataLoader[DictData]]:
+ """Return dataloaders for validation."""
+ return self.test_dataloader()
diff --git a/vis4d/engine/flag.py b/vis4d/engine/flag.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c0b9cfc3a3040ae0c29f987674e96e0edc49f6c
--- /dev/null
+++ b/vis4d/engine/flag.py
@@ -0,0 +1,34 @@
+"""Engine Flags."""
+
+from absl import flags
+
+from .parser import DEFINE_config_file
+
+_CONFIG = DEFINE_config_file("config", method_name="get_config")
+_GPUS = flags.DEFINE_integer("gpus", default=0, help="Number of GPUs per node")
+_NODES = flags.DEFINE_integer("nodes", default=1, help="Number of nodes")
+_WANDB = flags.DEFINE_bool(
+ "wandb", default=False, help="If set, use Weights & Biases for logging."
+)
+_CKPT = flags.DEFINE_string("ckpt", default=None, help="Checkpoint path")
+_RESUME = flags.DEFINE_bool("resume", default=False, help="Resume training")
+_SHOW_CONFIG = flags.DEFINE_bool(
+ "print-config", default=False, help="If set, prints the configuration."
+)
+_VIS = flags.DEFINE_bool(
+ "vis",
+ default=False,
+ help="If set, running visualization using visualizer callback.",
+)
+
+
+__all__ = [
+ "_CONFIG",
+ "_GPUS",
+ "_NODES",
+ "_CKPT",
+ "_RESUME",
+ "_SHOW_CONFIG",
+ "_WANDB",
+ "_VIS",
+]
diff --git a/vis4d/engine/loss_module.py b/vis4d/engine/loss_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc72531781898395ea559b5ff35be1fd3683cbe2
--- /dev/null
+++ b/vis4d/engine/loss_module.py
@@ -0,0 +1,215 @@
+"""Loss module maps loss function input keys and controls loss weight."""
+
+from __future__ import annotations
+
+from typing import TypedDict, Union
+
+import torch
+from torch import Tensor, nn
+from typing_extensions import NotRequired
+
+from vis4d.common.named_tuple import is_namedtuple
+from vis4d.common.typing import LossesType
+from vis4d.data.typing import DictData
+from vis4d.engine.connectors import LossConnector
+from vis4d.op.loss.base import Loss
+
+NestedLossesType = Union[dict[str, "NestedLossesType"], LossesType]
+
+
+class LossDefinition(TypedDict):
+ """Loss definition.
+
+ Attributes:
+ loss (Loss | nn.Module): Loss function to use.
+ connector (LossConnector): Connector to use for the loss.
+ weight (float | dict[str, float], optional): Weight to use for the
+ loss.
+ name (str, optional): Name to use for the loss.
+ """
+
+ loss: Loss | nn.Module
+ connector: LossConnector
+ weight: NotRequired[float | dict[str, float]]
+ name: NotRequired[str]
+
+
+def _get_tensors_nested(
+ loss_dict: NestedLossesType, prefix: str = ""
+) -> list[tuple[str, Tensor]]:
+ """Get tensors from loss dict.
+
+ Args:
+ loss_dict (LossesType): Loss dict.
+ prefix (str, optional): Prefix to add to keys. Defaults to "".
+
+ Returns:
+ list[tuple[str, Tensor]]: List of tensors.
+
+ Raises:
+ ValueError: If loss dict contains non-tensor or dict values.
+ """
+ named_tensors: list[tuple[str, Tensor]] = []
+ for key in loss_dict:
+ value = loss_dict[key]
+
+ if isinstance(value, Tensor):
+ named_tensors.append((prefix + key, value))
+ elif isinstance(value, dict):
+ named_tensors.extend(
+ _get_tensors_nested(value, prefix + key + ".")
+ )
+ else:
+ raise ValueError(
+ f"Loss dict must only contain tensors or dicts. "
+ f"Found {type(loss_dict[key])} at {prefix + key}."
+ )
+ return named_tensors
+
+
+class LossModule(nn.Module):
+ """Loss module maps input keys and combines losses with weights.
+
+ This loss combines multiple losses with weights. The loss values are
+ weighted by the corresponding weight and returned as a dictionary.
+ """
+
+ def __init__(
+ self,
+ losses: list[LossDefinition] | LossDefinition,
+ exclude_attributes: list[str] | None = None,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Each loss will be called with arguments matching the kwargs of the loss
+ function through its connector. By default, the weight is set to 1.0.
+
+ Args:
+ losses (list[LossDefinition]): List of loss definitions.
+ exclude_attributes (list[str] | None): List of attributes returned
+ by the losses that should be excluded from the total loss
+ computation. Use it to log metrics that should not be
+ optimised. Defaults to None.
+
+ Example:
+ >>> loss = LossModule(
+ >>> [
+ >>> {
+ >>> "loss": nn.MSELoss(),
+ >>> "weight": 0.7,
+ >>> "connector": LossConnector(
+ >>> {
+ >>> "input": pred_key("input"),
+ >>> "target": data_key("target"),
+ >>> }
+ >>> ),
+ >>> },
+ >>> {
+ >>> "loss": nn.L1Loss(),
+ >>> "weight": 0.3
+ >>> "connector": LossConnector(
+ >>> {
+ >>> "input": pred_key("input"),
+ >>> "target": data_key("target"),
+ >>> }
+ >>> ),
+ >>> },
+ >>> ]
+ >>> )
+ """
+ super().__init__()
+ self.losses: list[LossDefinition] = []
+
+ if not isinstance(losses, list):
+ losses = [losses]
+
+ for loss in losses:
+ assert "loss" in loss, "Loss definition must contain a loss."
+ assert (
+ "connector" in loss
+ ), "Loss definition must contain a connector."
+
+ if "name" not in loss:
+ loss["name"] = loss["loss"].__class__.__name__
+
+ if "weight" not in loss:
+ loss["weight"] = 1.0
+
+ self.losses.append(loss)
+
+ self.exclude_attributes = exclude_attributes
+
+ def forward(
+ self, output: DictData, batch: DictData
+ ) -> tuple[Tensor, dict[str, float]]:
+ """Forward of loss module.
+
+ This function will call all loss functions and return a dictionary
+ containing the loss values. The loss values are weighted by the
+ corresponding weight.
+
+ If two losses have the same name, the name will be appended with
+ two underscores.
+
+ Args:
+ output (DictData): Output of the model.
+ batch (DictData): Batch data.
+
+ Returns:
+ total_loss: The total loss value.
+ metrics: The metrics disctionary.
+ """
+ loss_dict: LossesType = {}
+
+ for loss in self.losses:
+ loss_values_as_dict: LossesType = {}
+ name = loss["name"]
+
+ loss_value = loss["loss"](**loss["connector"](output, batch))
+
+ # Convert loss value to one level dict.
+ if isinstance(loss_value, Tensor):
+ # Loss returned a simple tensor
+ loss_values_as_dict[name] = loss_value
+ elif isinstance(loss_value, dict):
+ # Loss returned a dictionary.
+ for loss_name, loss_value in _get_tensors_nested(
+ loss_value, name + "."
+ ):
+ loss_values_as_dict[loss_name] = loss_value
+ elif is_namedtuple(loss_value):
+ # Loss returned a named tuple.
+ for loss_name, loss_value in zip(
+ loss_value._fields, loss_value
+ ):
+ loss_values_as_dict[name + "." + loss_name] = loss_value
+
+ # Assign values
+ for key, value in loss_values_as_dict.items():
+ if value is None:
+ continue
+
+ if isinstance(loss["weight"], dict):
+ loss_weight = loss["weight"].get(key, 1.0)
+ else:
+ loss_weight = loss["weight"]
+
+ while key in loss_dict:
+ key = "__" + key
+
+ loss_dict[key] = torch.mul(loss_weight, value)
+
+ # Convert loss_dict to total loss and metrics dictionary
+ metrics: dict[str, float] = {}
+ keep_loss_dict: LossesType = {}
+ for k, v in loss_dict.items():
+ metrics[k] = v.detach().cpu().item()
+ if (
+ self.exclude_attributes is None
+ or k not in self.exclude_attributes
+ ):
+ keep_loss_dict[k] = v
+ total_loss: Tensor = sum(keep_loss_dict.values()) # type: ignore
+ metrics["loss"] = total_loss.detach().cpu().item()
+
+ return total_loss, metrics
diff --git a/vis4d/engine/optim/__init__.py b/vis4d/engine/optim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72f88d69eb6f467b93a3f7fbc1104d7c44d2880e
--- /dev/null
+++ b/vis4d/engine/optim/__init__.py
@@ -0,0 +1,17 @@
+"""Optimizer modules."""
+
+from .optimizer import set_up_optimizers
+from .scheduler import (
+ ConstantLR,
+ LRSchedulerWrapper,
+ PolyLR,
+ QuadraticLRWarmup,
+)
+
+__all__ = [
+ "set_up_optimizers",
+ "LRSchedulerWrapper",
+ "PolyLR",
+ "ConstantLR",
+ "QuadraticLRWarmup",
+]
diff --git a/vis4d/engine/optim/optimizer.py b/vis4d/engine/optim/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..33629ee5ba5882cc7fb3a541cbf5e950a767ca71
--- /dev/null
+++ b/vis4d/engine/optim/optimizer.py
@@ -0,0 +1,166 @@
+"""Optimizer."""
+
+from __future__ import annotations
+
+from typing import TypedDict
+
+from torch import nn
+from torch.nn import GroupNorm, LayerNorm
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.modules.instancenorm import _InstanceNorm
+from torch.optim.optimizer import Optimizer
+from typing_extensions import NotRequired
+
+from vis4d.common.logging import rank_zero_info
+from vis4d.config import instantiate_classes
+from vis4d.config.typing import OptimizerConfig, ParamGroupCfg
+
+from .scheduler import LRSchedulerWrapper
+
+
+class ParamGroup(TypedDict):
+ """Parameter dictionary.
+
+ Attributes:
+ params (list[nn.Parameter]): List of parameters.
+ lr (NotRequired[float]): Learning rate.
+ weight_decay (NotRequired[float]): Weight decay.
+ """
+
+ params: list[nn.Parameter]
+ lr: NotRequired[float]
+ weight_decay: NotRequired[float]
+
+
+# TODO: Add true support for multiple optimizers. This will need to
+# modify config to specify which optimizer to use for which module.
+def set_up_optimizers(
+ optimizers_cfg: list[OptimizerConfig],
+ models: list[nn.Module],
+ steps_per_epoch: int = -1,
+) -> tuple[list[Optimizer], list[LRSchedulerWrapper]]:
+ """Set up optimizers."""
+ optimizers = []
+ lr_schedulers = []
+ for optim_cfg, model in zip(optimizers_cfg, models):
+ optimizer = configure_optimizer(optim_cfg, model)
+ optimizers.append(optimizer)
+
+ if optim_cfg.lr_schedulers is not None:
+ lr_schedulers.append(
+ LRSchedulerWrapper(
+ optim_cfg.lr_schedulers, optimizer, steps_per_epoch
+ )
+ )
+
+ return optimizers, lr_schedulers
+
+
+def configure_optimizer(
+ optim_cfg: OptimizerConfig, model: nn.Module
+) -> Optimizer:
+ """Configure optimizer with parameter groups."""
+ param_groups_cfg = optim_cfg.get("param_groups", None)
+
+ if param_groups_cfg is None:
+ return instantiate_classes(
+ optim_cfg.optimizer, params=model.parameters()
+ )
+
+ params = []
+ base_lr = optim_cfg.optimizer["init_args"].lr
+ weight_decay = optim_cfg.optimizer["init_args"].get("weight_decay", None)
+ for group in param_groups_cfg:
+ lr_mult = group.get("lr_mult", 1.0)
+ decay_mult = group.get("decay_mult", 1.0)
+ norm_decay_mult = group.get("norm_decay_mult", None)
+ bias_decay_mult = group.get("bias_decay_mult", None)
+
+ param_group: ParamGroup = {"params": [], "lr": base_lr * lr_mult}
+
+ if weight_decay is not None:
+ if norm_decay_mult is not None:
+ param_group["weight_decay"] = weight_decay * norm_decay_mult
+ elif bias_decay_mult is not None:
+ param_group["weight_decay"] = weight_decay * bias_decay_mult
+ else:
+ param_group["weight_decay"] = weight_decay * decay_mult
+
+ params.append(param_group)
+
+ # Create a param group for the rest of the parameters
+ param_group = {"params": [], "lr": base_lr}
+ if weight_decay is not None:
+ param_group["weight_decay"] = weight_decay
+ params.append(param_group)
+
+ # Add the parameters to the param groups
+ add_params(params, model, param_groups_cfg)
+
+ return instantiate_classes(optim_cfg.optimizer, params=params)
+
+
+def add_params(
+ params: list[ParamGroup],
+ module: nn.Module,
+ param_groups_cfg: list[ParamGroupCfg],
+ prefix: str = "",
+) -> None:
+ """Add all parameters of module to the params list.
+
+ The parameters of the given module will be added to the list of param
+ groups, with specific rules defined by paramwise_cfg.
+
+ Args:
+ params (list[DictStrAny]): A list of param groups, it will be modified
+ in place.
+ module (nn.Module): The module to be added.
+ param_groups_cfg (dict[str, list[str] | float]): The configuration
+ of the param groups.
+ prefix (str): The prefix of the module. Default: ''.
+ """
+ for name, param in module.named_parameters(recurse=False):
+ if not param.requires_grad:
+ params[-1]["params"].append(param)
+ continue
+
+ is_norm = isinstance(
+ module, (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)
+ )
+
+ # if the parameter match one of the custom keys, ignore other rules
+ is_custom = False
+ msg = f"{prefix}.{name}"
+ for i, group in enumerate(param_groups_cfg):
+ for key in group["custom_keys"]:
+ if key not in f"{prefix}.{name}":
+ continue
+ norm_decay_mult = group.get("norm_decay_mult", None)
+ bias_decay_mult = group.get("bias_decay_mult", None)
+ if group.get("lr_mult", None) is not None:
+ msg += f" with lr_mult: {group['lr_mult']}"
+ if norm_decay_mult is not None:
+ if not is_norm:
+ continue
+ msg += f" with norm_decay_mult: {norm_decay_mult}"
+ if bias_decay_mult is not None:
+ if name != "bias":
+ continue
+ msg += f" with bias_decay_mult: {bias_decay_mult}"
+ if group.get("decay_mult", None) is not None:
+ msg += f" with decay_mult: {group['decay_mult']}"
+ params[i]["params"].append(param)
+ is_custom = True
+ break
+ if is_custom:
+ break
+
+ if is_custom:
+ rank_zero_info(msg)
+ else:
+ # add parameter to the last param group
+ params[-1]["params"].append(param)
+
+ for child_name, child_mod in module.named_children():
+ child_prefix = f"{prefix}.{child_name}" if prefix else child_name
+ add_params(params, child_mod, param_groups_cfg, prefix=child_prefix)
diff --git a/vis4d/engine/optim/scheduler.py b/vis4d/engine/optim/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..70288f09b1a9259a7266c8e61ef397f13d167a67
--- /dev/null
+++ b/vis4d/engine/optim/scheduler.py
@@ -0,0 +1,262 @@
+# pylint: disable=no-member
+"""LR schedulers."""
+
+from __future__ import annotations
+
+from typing import TypedDict
+
+from torch.optim.lr_scheduler import LRScheduler
+from torch.optim.optimizer import Optimizer
+
+from vis4d.common.typing import DictStrAny
+from vis4d.config import copy_and_resolve_references, instantiate_classes
+from vis4d.config.typing import LrSchedulerConfig
+
+
+class LRSchedulerDict(TypedDict):
+ """LR scheduler."""
+
+ scheduler: LRScheduler
+ begin: int
+ end: int
+ epoch_based: bool
+
+
+class LRSchedulerWrapper(LRScheduler):
+ """LR scheduler wrapper."""
+
+ def __init__(
+ self,
+ lr_schedulers_cfg: list[LrSchedulerConfig],
+ optimizer: Optimizer,
+ steps_per_epoch: int = -1,
+ ) -> None:
+ """Initialize LRSchedulerWrapper."""
+ self.lr_schedulers_cfg: list[LrSchedulerConfig] = (
+ copy_and_resolve_references(lr_schedulers_cfg)
+ )
+ self.lr_schedulers: dict[int, LRSchedulerDict] = {}
+ super().__init__(optimizer)
+ self.steps_per_epoch = steps_per_epoch
+ self._convert_epochs_to_steps()
+
+ for i, lr_scheduler_cfg in enumerate(self.lr_schedulers_cfg):
+ if lr_scheduler_cfg["begin"] == 0:
+ self._instantiate_lr_scheduler(i, lr_scheduler_cfg)
+
+ def _convert_epochs_to_steps(self) -> None:
+ """Convert epochs to steps."""
+ for lr_scheduler_cfg in self.lr_schedulers_cfg:
+ if (
+ lr_scheduler_cfg["convert_epochs_to_steps"]
+ and not lr_scheduler_cfg["epoch_based"]
+ ):
+ lr_scheduler_cfg["begin"] *= self.steps_per_epoch
+ lr_scheduler_cfg["end"] *= self.steps_per_epoch
+ if lr_scheduler_cfg["convert_attributes"] is not None:
+ for attr in lr_scheduler_cfg["convert_attributes"]:
+ lr_scheduler_cfg["scheduler"]["init_args"][
+ attr
+ ] *= self.steps_per_epoch
+
+ def _instantiate_lr_scheduler(
+ self, scheduler_idx: int, lr_scheduler_cfg: LrSchedulerConfig
+ ) -> None:
+ """Instantiate LR schedulers."""
+ # OneCycleLR needs max_lr to be set
+ if "max_lr" in lr_scheduler_cfg["scheduler"]["init_args"]:
+ lr_scheduler_cfg["scheduler"]["init_args"]["max_lr"] = [
+ pg["lr"] for pg in self.optimizer.param_groups
+ ]
+
+ self.lr_schedulers[scheduler_idx] = {
+ "scheduler": instantiate_classes(
+ lr_scheduler_cfg["scheduler"], optimizer=self.optimizer
+ ),
+ "begin": lr_scheduler_cfg["begin"],
+ "end": lr_scheduler_cfg["end"],
+ "epoch_based": lr_scheduler_cfg["epoch_based"],
+ }
+
+ def get_lr(self) -> list[float]:
+ """Get current learning rate."""
+ lr = []
+ for lr_scheduler in self.lr_schedulers.values():
+ lr.extend(lr_scheduler["scheduler"].get_lr())
+ return lr
+
+ def state_dict(self) -> dict[int, DictStrAny]: # type: ignore
+ """Get state dict."""
+ state_dict = {}
+ for scheduler_idx, lr_scheduler in self.lr_schedulers.items():
+ state_dict[scheduler_idx] = lr_scheduler["scheduler"].state_dict()
+ return state_dict
+
+ def load_state_dict(
+ self, state_dict: dict[int, DictStrAny] # type: ignore
+ ) -> None:
+ """Load state dict."""
+ for scheduler_idx, _state_dict in state_dict.items():
+ # Instantiate the lr scheduler if it is not instantiated yet
+ if not scheduler_idx in self.lr_schedulers:
+ self._instantiate_lr_scheduler(
+ scheduler_idx, self.lr_schedulers_cfg[scheduler_idx]
+ )
+ self.lr_schedulers[scheduler_idx]["scheduler"].load_state_dict(
+ _state_dict
+ )
+
+ def _step_lr(self, lr_scheduler: LRSchedulerDict, step: int) -> None:
+ """Step the learning rate."""
+ if lr_scheduler["begin"] <= step and (
+ lr_scheduler["end"] == -1 or lr_scheduler["end"] >= step
+ ):
+ lr_scheduler["scheduler"].step()
+
+ def step(self, epoch: int | None = None) -> None:
+ """Step on training epoch end."""
+ if epoch is not None:
+ for lr_scheduler in self.lr_schedulers.values():
+ if lr_scheduler["epoch_based"]:
+ self._step_lr(lr_scheduler, epoch)
+
+ for i, lr_scheduler_cfg in enumerate(self.lr_schedulers_cfg):
+ if lr_scheduler_cfg["epoch_based"] and (
+ lr_scheduler_cfg["begin"] == epoch + 1
+ ):
+ self._instantiate_lr_scheduler(i, lr_scheduler_cfg)
+
+ def step_on_batch(self, step: int) -> None:
+ """Step on training batch end."""
+ for lr_scheduler in self.lr_schedulers.values():
+ if not lr_scheduler["epoch_based"]:
+ self._step_lr(lr_scheduler, step)
+
+ for i, lr_scheduler_cfg in enumerate(self.lr_schedulers_cfg):
+ if not lr_scheduler_cfg["epoch_based"] and (
+ lr_scheduler_cfg["begin"] == step
+ ):
+ self._instantiate_lr_scheduler(i, lr_scheduler_cfg)
+
+
+class ConstantLR(LRScheduler):
+ """Constant learning rate scheduler.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ max_steps (int): Maximum number of steps.
+ factor (float): Scale factor. Default: 1.0 / 3.0.
+ last_epoch (int): The index of last epoch. Default: -1.
+ """
+
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ max_steps: int,
+ factor: float = 1.0 / 3.0,
+ last_epoch: int = -1,
+ ):
+ """Initialize ConstantLR."""
+ self.max_steps = max_steps
+ self.factor = factor
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self) -> list[float]:
+ """Compute current learning rate."""
+ step_count = self._step_count - 1
+ if step_count == 0:
+ return [
+ group["lr"] * self.factor
+ for group in self.optimizer.param_groups
+ ]
+ if step_count == self.max_steps:
+ return [
+ group["lr"] * (1.0 / self.factor)
+ for group in self.optimizer.param_groups
+ ]
+ return [group["lr"] for group in self.optimizer.param_groups]
+
+
+class PolyLR(LRScheduler):
+ """Polynomial learning rate decay.
+
+ Example:
+ Assuming lr = 0.001, max_steps = 4, min_lr = 0.0, and power = 1.0, the
+ learning rate will be:
+ lr = 0.001 if step == 0
+ lr = 0.00075 if step == 1
+ lr = 0.00050 if step == 2
+ lr = 0.00025 if step == 3
+ lr = 0.0 if step >= 4
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ max_steps (int): Maximum number of steps.
+ power (float, optional): Power factor. Default: 1.0.
+ min_lr (float): Minimum learning rate. Default: 0.0.
+ last_epoch (int): The index of last epoch. Default: -1.
+ """
+
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ max_steps: int,
+ power: float = 1.0,
+ min_lr: float = 0.0,
+ last_epoch: int = -1,
+ ):
+ """Initialize PolyLRScheduler."""
+ self.max_steps = max_steps
+ self.power = power
+ self.min_lr = min_lr
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self) -> list[float]:
+ """Compute current learning rate."""
+ step_count = self._step_count - 1
+ if step_count == 0 or step_count > self.max_steps:
+ return [group["lr"] for group in self.optimizer.param_groups]
+ decay_factor = (
+ (1.0 - step_count / self.max_steps)
+ / (1.0 - (step_count - 1) / self.max_steps)
+ ) ** self.power
+ return [
+ (group["lr"] - self.min_lr) * decay_factor + self.min_lr
+ for group in self.optimizer.param_groups
+ ]
+
+
+class QuadraticLRWarmup(LRScheduler):
+ """Quadratic learning rate warmup.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ max_steps (int): Maximum number of steps.
+ last_epoch (int): The index of last epoch. Default: -1.
+ """
+
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ max_steps: int,
+ last_epoch: int = -1,
+ ):
+ """Initialize QuadraticLRWarmup."""
+ self.max_steps = max_steps
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self) -> list[float]:
+ """Compute current learning rate."""
+ step_count = self._step_count - 1
+ if step_count >= self.max_steps:
+ return self.base_lrs
+ factors = [
+ base_lr * (2 * step_count + 1) / self.max_steps**2
+ for base_lr in self.base_lrs # pylint: disable=not-an-iterable
+ ]
+ if step_count == 0:
+ return factors
+ return [
+ group["lr"] + factor
+ for factor, group in zip(factors, self.optimizer.param_groups)
+ ]
diff --git a/vis4d/engine/parser.py b/vis4d/engine/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cb728b57513ec09c0caf5a56ecb8e70123b16f2
--- /dev/null
+++ b/vis4d/engine/parser.py
@@ -0,0 +1,218 @@
+"""Parser for config files that can be used with absl flags."""
+
+from __future__ import annotations
+
+import logging
+import re
+import sys
+import traceback
+from typing import Any
+
+from absl import flags
+from ml_collections import ConfigDict, FieldReference
+from ml_collections.config_flags.config_flags import (
+ _ConfigFlag,
+ _ErrorConfig,
+ _LockConfig,
+)
+
+from vis4d.config import copy_and_resolve_references
+from vis4d.config.registry import get_config_by_name
+
+
+class ConfigFileParser(flags.ArgumentParser): # type: ignore
+ """Parser for config files."""
+
+ def __init__(
+ self,
+ name: str,
+ lock_config: bool = True,
+ method_name: str = "get_config",
+ ) -> None:
+ """Initializes the parser.
+
+ Args:
+ name (str): The name of the flag (e.g. config for --config flag)
+ lock_config (bool, optional): Whether or not to lock the config.
+ Defaults to True.
+ method_name (str, optional): Name of the method to call in the
+ config. Defaults to "get_config".
+ """
+ self.name = name
+ self._lock_config = lock_config
+ self.method_name = method_name
+
+ def parse( # pylint: disable=arguments-renamed
+ self, path: str
+ ) -> ConfigDict | _ErrorConfig:
+ """Loads a config module from `path` and returns the `method_name()`.
+
+ This implementation is based on the original ml_collections and
+ modified to allow for a custom method name.
+
+ If a colon is present in `path`, everything to the right of the first
+ colon is passed to `method_name` as an argument. This allows the
+ structure of what
+ is returned to be modified, which is useful when performing complex
+ hyperparameter sweeps.
+
+ Args:
+ path: string, path pointing to the config file to execute. May also
+ contain a config_string argument, e.g. be of the form
+ "config.py:some_configuration".
+
+ Returns:
+ Result of calling `method_name` in the specified module.
+ """
+ # This will be a 2 element list iff extra configuration args are
+ # present.
+ split_path = path.split(":", 1)
+
+ try:
+ config = get_config_by_name(
+ split_path[0],
+ *split_path[1:],
+ method_name=self.method_name,
+ )
+ if config is None:
+ logging.warning(
+ "%s:%s() returned None, did you forget a return "
+ "statement?",
+ path,
+ self.method_name,
+ )
+ except IOError as e:
+ # Don't raise the error unless/until the config is
+ # actually accessed.
+ return _ErrorConfig(e)
+ # Third party flags library catches TypeError and ValueError
+ # and rethrows,
+ # removing useful information unless it is added here (b/63877430):
+ except (TypeError, ValueError) as e:
+ error_trace = traceback.format_exc()
+ raise type(e)(
+ "Error whilst parsing config file:\n\n" + error_trace
+ )
+
+ if self._lock_config:
+ _LockConfig(config)
+
+ return config
+
+ def flag_type(self) -> str:
+ """Returns the type of the flag."""
+ return "config object"
+
+
+def DEFINE_config_file( # pylint: disable=invalid-name
+ name: str,
+ default: str | None = None,
+ help_string: str = "path to config file [.py |.yaml].",
+ lock_config: bool = False,
+ method_name: str = "get_config",
+) -> flags.FlagHolder: # type: ignore
+ """Registers a new flag for a config file.
+
+ Args:
+ name (str): The name of the flag (e.g. config for --config flag)
+ default (str | None, optional): Default Value. Defaults to None.
+ help_string (str, optional): Help String.
+ Defaults to "path to config file.".
+ lock_config (bool, optional): Whether or note to lock the returned
+ config. Defaults to False.
+ method_name (str, optional): Name of the method to call in the config.
+
+ Returns:
+ flags.FlagHolder: Flag holder instance.
+ """
+ parser = ConfigFileParser(
+ name=name, lock_config=lock_config, method_name=method_name
+ )
+ flag = _ConfigFlag(
+ parser=parser,
+ serializer=flags.ArgumentSerializer(),
+ name=name,
+ default=default,
+ help_string=help_string,
+ flag_values=flags.FLAGS,
+ )
+
+ # Get the module name for the frame at depth 1 in the call stack.
+ module_name = sys._getframe( # pylint: disable=protected-access
+ 1
+ ).f_globals.get("__name__", None)
+ module_name = sys.argv[0] if module_name == "__main__" else module_name
+ return flags.DEFINE_flag(flag, flags.FLAGS, module_name=module_name)
+
+
+def pprints_config(data: ConfigDict) -> str:
+ """Converts a Config Dict into a string with a .yaml like structure.
+
+ This function differs from __repr__ of ConfigDict in that it will not
+ encode python classes using binary formats but just prints the __repr__
+ of these classes.
+
+ Args:
+ data (ConfigDict): Configuration dict to convert to string
+
+ Returns:
+ str: A string representation of the ConfigDict
+ """
+ return _pprints_config(copy_and_resolve_references(data))
+
+
+def _pprints_config( # type: ignore
+ data: Any, prefix: str = "", n_indents: int = 1
+) -> str:
+ """Converts a ConfigDict into a string with a YAML like structure.
+
+ This is the recursive implementation of 'pprints_config' and will be called
+ recursively for every element in the dict.
+
+ This function differs from __repr__ of ConfigDict in that it will not
+ encode python classes using binary formats but just prints the __repr__
+ of these classes.
+
+ Args:
+ data (Any): Configuration dict or object to convert to
+ string
+ prefix (str): Prefix to print on each new line
+ n_indents (int): Number of spaces to append for each nester property.
+
+ Returns:
+ str: A string representation of the ConfigDict
+ """
+ string_repr = ""
+ if isinstance(data, FieldReference):
+ data = data.get()
+
+ if not isinstance(data, (dict, ConfigDict, list, tuple, dict)):
+ return str(data)
+
+ string_repr += "\n"
+
+ if isinstance(data, (ConfigDict, dict)):
+ for key in data:
+ value = data[key]
+ string_repr += (
+ prefix
+ + key
+ + ": "
+ + _pprints_config(value, prefix=prefix + " " * n_indents)
+ ) + "\n"
+
+ elif isinstance(data, (list, tuple)):
+ for value in data:
+ string_repr += prefix + "- "
+ if isinstance(value, (ConfigDict, dict)):
+ string_repr += "\n"
+
+ string_repr += (
+ _pprints_config(value, prefix=prefix + " " + " " * n_indents)
+ + "\n"
+ )
+ string_repr += " \n" # Add newline after list for better readability.
+
+ # Clean up some formatting issues using regex. Could be done better
+ string_repr = re.sub("\n\n+", "\n", string_repr)
+ return re.sub("- +\n +", "- ", string_repr)
diff --git a/vis4d/engine/run.py b/vis4d/engine/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..07c723b2d20d63589ce548c365bfb2c0ec97cf10
--- /dev/null
+++ b/vis4d/engine/run.py
@@ -0,0 +1,180 @@
+"""CLI interface using PyTorch Lightning."""
+
+from __future__ import annotations
+
+import logging
+import os.path as osp
+
+import torch
+from absl import app # pylint: disable=no-name-in-module
+from torch.utils.collect_env import get_pretty_env_info
+
+from vis4d.common.logging import dump_config, rank_zero_info, setup_logger
+from vis4d.common.typing import ArgsType
+from vis4d.common.util import set_tf32
+from vis4d.config import instantiate_classes
+from vis4d.config.typing import ExperimentConfig
+from vis4d.engine.callbacks import (
+ Callback,
+ LRSchedulerCallback,
+ VisualizerCallback,
+)
+from vis4d.engine.data_module import DataModule
+from vis4d.engine.flag import (
+ _CKPT,
+ _CONFIG,
+ _GPUS,
+ _NODES,
+ _RESUME,
+ _SHOW_CONFIG,
+ _VIS,
+ _WANDB,
+)
+from vis4d.engine.parser import pprints_config
+from vis4d.engine.trainer import PLTrainer
+from vis4d.engine.training_module import TrainingModule
+
+
+def main(argv: ArgsType) -> None:
+ """Main entry point for the CLI.
+
+ Example to run this script:
+ >>> python -m vis4d.pl.run fit --config configs/faster_rcnn/faster_rcnn_coco.py
+ """
+ # Get config
+ mode = argv[1]
+ assert mode in {"fit", "test"}, f"Invalid mode: {mode}"
+ config: ExperimentConfig = _CONFIG.value
+ num_gpus = _GPUS.value
+ num_nodes = _NODES.value
+
+ # Setup logging
+ logger_vis4d = logging.getLogger("vis4d")
+ logger_pl = logging.getLogger("pytorch_lightning")
+ log_file = osp.join(config.output_dir, f"log_{config.timestamp}.txt")
+ setup_logger(logger_vis4d, log_file)
+ setup_logger(logger_pl, log_file)
+
+ # Dump config
+ config_file = osp.join(
+ config.output_dir, f"config_{config.timestamp}.yaml"
+ )
+ dump_config(config, config_file)
+
+ rank_zero_info("Environment info: %s", get_pretty_env_info())
+
+ # PyTorch Setting
+ set_tf32(config.use_tf32, config.tf32_matmul_precision)
+ torch.hub.set_dir(f"{config.work_dir}/.cache/torch/hub")
+
+ # Setup device
+ if num_gpus > 0:
+ config.pl_trainer.accelerator = "gpu"
+ config.pl_trainer.devices = num_gpus
+ else:
+ config.pl_trainer.accelerator = "cpu"
+ config.pl_trainer.devices = 1
+
+ if num_nodes > 1:
+ config.pl_trainer.num_nodes = num_nodes
+
+ # Wandb
+ config.pl_trainer.wandb = _WANDB.value
+
+ trainer_args = instantiate_classes(config.pl_trainer).to_dict()
+
+ if _SHOW_CONFIG.value:
+ rank_zero_info(pprints_config(config))
+
+ # Instantiate classes
+ if mode == "fit":
+ train_data_connector = instantiate_classes(config.train_data_connector)
+ loss = instantiate_classes(config.loss)
+ else:
+ train_data_connector = None
+ loss = None
+
+ if config.test_data_connector is not None:
+ test_data_connector = instantiate_classes(config.test_data_connector)
+ else:
+ test_data_connector = None
+
+ # Callbacks
+ vis = _VIS.value
+
+ callbacks: list[Callback] = []
+ for cb in config.callbacks:
+ callback = instantiate_classes(cb)
+
+ assert isinstance(callback, Callback), (
+ "Callback must be a subclass of Callback. "
+ f"Provided callback: {cb} is not!"
+ )
+
+ if not vis and isinstance(callback, VisualizerCallback):
+ rank_zero_info(
+ f"{callback.visualizer} is not used. "
+ "Please set --vis=True to use it."
+ )
+ continue
+
+ callbacks.append(callback)
+
+ # Add needed callbacks
+ callbacks.append(LRSchedulerCallback())
+
+ # Checkpoint path
+ ckpt_path = _CKPT.value
+
+ # Resume training
+ resume = _RESUME.value
+ if resume:
+ if ckpt_path is None:
+ resume_ckpt_path = osp.join(
+ config.output_dir, "checkpoints/last.ckpt"
+ )
+ else:
+ resume_ckpt_path = ckpt_path
+ # Check if checkpoint exists, if not start fresh
+ if not osp.exists(resume_ckpt_path):
+ print(f"[vis4d] Checkpoint not found: {resume_ckpt_path}, starting fresh training")
+ resume_ckpt_path = None
+ else:
+ resume_ckpt_path = None
+
+ trainer = PLTrainer(callbacks=callbacks, **trainer_args)
+
+ hyper_params = trainer_args
+
+ if config.get("params", None) is not None:
+ hyper_params.update(config.params.to_dict())
+
+ training_module = TrainingModule(
+ config.model,
+ config.optimizers,
+ loss,
+ train_data_connector,
+ test_data_connector,
+ hyper_params,
+ config.seed,
+ ckpt_path if not resume else None,
+ config.compute_flops,
+ config.check_unused_parameters,
+ )
+ data_module = DataModule(config.data)
+
+ if mode == "fit":
+ trainer.fit(
+ training_module, datamodule=data_module, ckpt_path=resume_ckpt_path
+ )
+ elif mode == "test":
+ trainer.test(training_module, datamodule=data_module, verbose=False)
+
+
+def entrypoint() -> None:
+ """Entry point for the CLI."""
+ app.run(main)
+
+
+if __name__ == "__main__":
+ entrypoint()
diff --git a/vis4d/engine/trainer.py b/vis4d/engine/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..616e593e5823da846e026fb1304449f576550236
--- /dev/null
+++ b/vis4d/engine/trainer.py
@@ -0,0 +1,141 @@
+"""Trainer for PyTorch Lightning."""
+
+from __future__ import annotations
+
+import datetime
+import os.path as osp
+
+from lightning.pytorch import Callback, Trainer
+from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
+from lightning.pytorch.loggers import Logger, TensorBoardLogger
+from lightning.pytorch.loggers.wandb import WandbLogger
+from lightning.pytorch.strategies.ddp import DDPStrategy
+
+from vis4d.common.imports import TENSORBOARD_AVAILABLE
+from vis4d.common.logging import rank_zero_info
+from vis4d.common.typing import ArgsType
+
+
+class PLTrainer(Trainer):
+ """Trainer for PyTorch Lightning."""
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ work_dir: str,
+ exp_name: str,
+ version: str,
+ epoch_based: bool = True,
+ find_unused_parameters: bool = False,
+ save_top_k: int = 1,
+ checkpoint_period: int = 1,
+ checkpoint_callback: ModelCheckpoint | None = None,
+ wandb: bool = False,
+ seed: int = -1,
+ timeout: int = 3600,
+ wandb_id: str | None = None,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Perform some basic common setups at the beginning of a job.
+
+ Args:
+ work_dir: Specific directory to save checkpoints, logs, etc.
+ Integrates with exp_name and version to get output_dir.
+ exp_name: Name of current experiment.
+ version: Version of current experiment.
+ epoch_based: Use epoch-based / iteration-based training. Default is
+ True.
+ find_unused_parameters: Activates PyTorch checking for unused
+ parameters in DDP setting. Default: False, for better
+ performance.
+ save_top_k: Save top k checkpoints. Default: 1 (save last).
+ checkpoint_period: After N epochs / stpes, save out checkpoints.
+ Default: 1.
+ checkpoint_callback: Custom PL checkpoint callback. Default: None.
+ wandb: Use weights and biases logging instead of tensorboard.
+ Default: False.
+ seed (int, optional): The integer value seed for global random
+ state. Defaults to -1. If -1, a random seed will be generated.
+ This will be set by TrainingModule.
+ timeout: Timeout (seconds) for DDP connection. Default: 3600.
+ wandb_id: If using wandb, the id of the run. If None, a new run
+ will be created. Default: None.
+ """
+ self.work_dir = work_dir
+ self.exp_name = exp_name
+ self.version = version
+ self.seed = seed
+
+ self.output_dir = osp.join(work_dir, exp_name, version)
+
+ # setup experiment logging
+ if "logger" not in kwargs or (
+ isinstance(kwargs["logger"], bool) and kwargs["logger"]
+ ):
+ exp_logger: Logger | None = None
+ if wandb: # pragma: no cover
+ exp_logger = WandbLogger(
+ save_dir=work_dir,
+ project=exp_name,
+ name=version,
+ id=wandb_id,
+ )
+ elif TENSORBOARD_AVAILABLE:
+ exp_logger = TensorBoardLogger(
+ save_dir=work_dir,
+ name=exp_name,
+ version=version,
+ default_hp_metric=False,
+ )
+ else:
+ rank_zero_info(
+ "Neither `tensorboard` nor `tensorboardX` is "
+ "available. Running without experiment logger. To log "
+ "your experiments, try `pip install`ing either."
+ )
+ kwargs["logger"] = exp_logger
+
+ callbacks: list[Callback] = []
+
+ # add learning rate / GPU stats monitor (logs to tensorboard)
+ if TENSORBOARD_AVAILABLE or wandb:
+ callbacks += [LearningRateMonitor(logging_interval="step")]
+
+ # Model checkpointer
+ if checkpoint_callback is None:
+ if epoch_based:
+ checkpoint_cb = ModelCheckpoint(
+ dirpath=osp.join(self.output_dir, "checkpoints"),
+ verbose=True,
+ save_last=True,
+ save_top_k=save_top_k,
+ every_n_epochs=checkpoint_period,
+ save_on_train_epoch_end=True,
+ )
+ else:
+ checkpoint_cb = ModelCheckpoint(
+ dirpath=osp.join(self.output_dir, "checkpoints"),
+ verbose=True,
+ save_last=True,
+ save_top_k=save_top_k,
+ every_n_train_steps=checkpoint_period,
+ )
+ else:
+ checkpoint_cb = checkpoint_callback
+ callbacks += [checkpoint_cb]
+
+ kwargs["callbacks"] += callbacks
+
+ # add distributed strategy
+ if kwargs["devices"] == 0:
+ kwargs["accelerator"] = "cpu"
+ kwargs["devices"] = "auto"
+ elif kwargs["devices"] > 1: # pragma: no cover
+ if kwargs["accelerator"] == "gpu":
+ ddp_plugin = DDPStrategy(
+ find_unused_parameters=find_unused_parameters,
+ timeout=datetime.timedelta(timeout),
+ )
+ kwargs["strategy"] = ddp_plugin
+
+ super().__init__(*args, **kwargs)
diff --git a/vis4d/engine/training_module.py b/vis4d/engine/training_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..99a3cf955c1abe1b807f0ec354ba8b8e51fffe84
--- /dev/null
+++ b/vis4d/engine/training_module.py
@@ -0,0 +1,210 @@
+"""LightningModule that wraps around the models, losses and optims."""
+
+from __future__ import annotations
+
+from typing import Any
+
+import lightning.pytorch as pl
+from lightning.pytorch import seed_everything
+from lightning.pytorch.core.optimizer import LightningOptimizer
+from ml_collections import ConfigDict
+from torch import nn
+from torch.optim.optimizer import Optimizer
+
+from vis4d.common.ckpt import load_model_checkpoint
+from vis4d.common.distributed import broadcast
+from vis4d.common.imports import FVCORE_AVAILABLE
+from vis4d.common.logging import rank_zero_info
+from vis4d.common.typing import DictStrAny, GenericFunc
+from vis4d.common.util import init_random_seed
+from vis4d.config import instantiate_classes
+from vis4d.config.typing import OptimizerConfig
+from vis4d.data.typing import DictData
+from vis4d.engine.connectors import DataConnector
+from vis4d.engine.loss_module import LossModule
+from vis4d.engine.optim import LRSchedulerWrapper, set_up_optimizers
+from vis4d.model.adapter.flops import IGNORED_OPS, FlopsModelAdapter
+
+if FVCORE_AVAILABLE:
+ from fvcore.nn import FlopCountAnalysis
+
+
+class TrainingModule(pl.LightningModule):
+ """LightningModule that wraps around the vis4d implementations.
+
+ This is a wrapper around the vis4d implementations that allows to use
+ pytorch-lightning for training and testing.
+ """
+
+ def __init__(
+ self,
+ model_cfg: ConfigDict,
+ optimizers_cfg: list[OptimizerConfig],
+ loss_module: None | LossModule,
+ train_data_connector: None | DataConnector,
+ test_data_connector: None | DataConnector,
+ hyper_parameters: DictStrAny | None = None,
+ seed: int = -1,
+ ckpt_path: None | str = None,
+ compute_flops: bool = False,
+ check_unused_parameters: bool = False,
+ ) -> None:
+ """Initialize the TrainingModule.
+
+ Args:
+ model_cfg: The model config.
+ optimizers_cfg: The optimizers config.
+ loss_module: The loss module.
+ train_data_connector: The data connector to use.
+ test_data_connector: The data connector to use.
+ data_connector: The data connector to use.
+ hyper_parameters (DictStrAny | None, optional): The hyper
+ parameters to use. Defaults to None.
+ seed (int, optional): The integer value seed for global random
+ state. Defaults to -1. If -1, a random seed will be generated.
+ ckpt_path (str, optional): The path to the checkpoint to load.
+ Defaults to None.
+ compute_flops (bool, optional): If to compute the FLOPs of the
+ model. Defaults to False.
+ check_unused_parameters (bool, optional): If to check the
+ unused parameters. Defaults to False.
+ """
+ super().__init__()
+ self.model_cfg = model_cfg
+ self.optimizers_cfg = optimizers_cfg
+ self.loss_module = loss_module
+ self.train_data_connector = train_data_connector
+ self.test_data_connector = test_data_connector
+ self.hyper_parameters = hyper_parameters
+ self.seed = seed
+ self.ckpt_path = ckpt_path
+ self.compute_flops = compute_flops
+ self.check_unused_parameters = check_unused_parameters
+
+ # Create model placeholder
+ self.model: nn.Module
+
+ def setup(self, stage: str) -> None:
+ """Setup the model."""
+ if stage == "fit":
+ if self.seed == -1:
+ self.seed = init_random_seed()
+ self.seed = broadcast(self.seed)
+ self.trainer.seed = self.seed # type: ignore
+
+ seed_everything(self.seed, workers=True)
+ rank_zero_info(f"Global seed set to {self.seed}")
+
+ if self.hyper_parameters is not None:
+ self.hyper_parameters["seed"] = self.seed
+ if "checkpoint_callback" in self.hyper_parameters:
+ self.hyper_parameters.pop("checkpoint_callback")
+ self.save_hyperparameters(self.hyper_parameters)
+
+ # Instantiate the model after the seed has been set
+ self.model = instantiate_classes(self.model_cfg)
+
+ if self.ckpt_path is not None:
+ load_model_checkpoint(
+ self.model,
+ self.ckpt_path,
+ rev_keys=[(r"^model\.", ""), (r"^module\.", "")],
+ )
+
+ def forward( # type: ignore # pylint: disable=arguments-differ
+ self, data: DictData
+ ) -> Any:
+ """Forward pass through the model."""
+ if self.training:
+ assert self.train_data_connector is not None
+ return self.model(**self.train_data_connector(data))
+ assert self.test_data_connector is not None
+ return self.model(**self.test_data_connector(data))
+
+ def training_step( # type: ignore # pylint: disable=arguments-differ,line-too-long,unused-argument
+ self, batch: DictData, batch_idx: int
+ ) -> Any:
+ """Perform a single training step."""
+ assert self.train_data_connector is not None
+ out = self.model(**self.train_data_connector(batch))
+
+ assert self.loss_module is not None
+ total_loss, metrics = self.loss_module(out, batch)
+
+ return {
+ "loss": total_loss,
+ "metrics": metrics,
+ "predictions": out,
+ }
+
+ def validation_step( # pylint: disable=arguments-differ,line-too-long,unused-argument
+ self, batch: DictData, batch_idx: int, dataloader_idx: int = 0
+ ) -> DictData:
+ """Perform a single validation step."""
+ assert self.test_data_connector is not None
+ out = self.model(**self.test_data_connector(batch))
+ return out
+
+ def test_step( # pylint: disable=arguments-differ,line-too-long,unused-argument
+ self, batch: DictData, batch_idx: int, dataloader_idx: int = 0
+ ) -> DictData:
+ """Perform a single test step."""
+ assert self.test_data_connector is not None
+
+ if self.compute_flops:
+ flatten_inputs = [
+ self.test_data_connector(batch)[key]
+ for key in self.test_data_connector(batch)
+ ]
+
+ flops_model = FlopsModelAdapter(
+ self.model, self.test_data_connector
+ )
+
+ if not FVCORE_AVAILABLE:
+ raise RuntimeError(
+ "Please install fvcore to compute FLOPs of the model."
+ )
+
+ flop_analyzer = FlopCountAnalysis( # pylint: disable=possibly-used-before-assignment, line-too-long
+ flops_model, flatten_inputs
+ )
+
+ flop_analyzer.set_op_handle(**{k: None for k in IGNORED_OPS})
+
+ flops = flop_analyzer.total() / 1e9
+
+ rank_zero_info(f"Flops: {flops:.2f} Gflops")
+
+ out = self.model(**self.test_data_connector(batch))
+ return out
+
+ def configure_optimizers(self) -> Any: # type: ignore
+ """Return the optimizer to use."""
+ self.trainer.fit_loop.setup_data()
+ steps_per_epoch = len(self.trainer.train_dataloader) # type: ignore
+ return set_up_optimizers(
+ self.optimizers_cfg, [self.model], steps_per_epoch
+ )
+
+ def lr_scheduler_step( # type: ignore # pylint: disable=arguments-differ,line-too-long,unused-argument
+ self, scheduler: LRSchedulerWrapper, metric: Any | None = None
+ ) -> None:
+ """Perform a step on the lr scheduler."""
+ # TODO: Support metric if needed
+ scheduler.step(self.current_epoch)
+
+ def optimizer_step(
+ self,
+ epoch: int,
+ batch_idx: int,
+ optimizer: Optimizer | LightningOptimizer,
+ optimizer_closure: GenericFunc | None = None,
+ ) -> None:
+ """Optimizer step."""
+ if self.check_unused_parameters:
+ for name, param in self.model.named_parameters():
+ if param.grad is None:
+ rank_zero_info(name)
+
+ optimizer.step(closure=optimizer_closure)
diff --git a/vis4d/eval/__init__.py b/vis4d/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..93c79074b0775cf8cfeafa1938fc436faf38b747
--- /dev/null
+++ b/vis4d/eval/__init__.py
@@ -0,0 +1,5 @@
+"""Evaluation protocols and metrics for different tasks."""
+
+from .base import Evaluator
+
+__all__ = ["Evaluator"]
diff --git a/vis4d/eval/base.py b/vis4d/eval/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b49e6f0a465d03a72b9f8e3064eda4a830cc3e0
--- /dev/null
+++ b/vis4d/eval/base.py
@@ -0,0 +1,121 @@
+"""Vis4D base evaluation."""
+
+from __future__ import annotations
+
+from vis4d.common.typing import GenericFunc, MetricLogs, unimplemented
+
+
+class Evaluator: # pragma: no cover
+ """Abstract evaluator class.
+
+ The evaluator is responsible for evaluating the model on a given dataset.
+ At each end of batches, the process_batch() is called with the model
+ outputs and the batch data to accumulate the data for evaluation. An
+ optional save_batch() can be implemented to save the predictions in the
+ current batch.
+
+ After all batches are processed, the gather() method is called to gather
+ the data from all ranks. Then, the process() method is used to process all
+ the accumulated data that are metrics-independent. Finally, the evaluate()
+ method is called to evaluate the model for the specified metrics and return
+ the results. Optionally, the save() method can be implemented to save the
+ predictions for the specified metrics.
+
+ The following diagram illustrates the evaluation process::
+
+ RANK 0 RANK 1 ...
+
+ x num_batches
+ ┌────────────────────────────────────────────────────────────────┐
+ │ ┌──────────────────────────┐ ┌──────────────────────────┐ │
+ │ │ process_batch(data, ...) │ │ process_batch(data, ...) │ │ <- Process a batch (predictions, labels, etc.)
+ │ └──────────────────────────┘ └──────────────────────────┘ │ and accumulate the data for evaluation.
+ │ ▼ ▼ │
+ │ ┌────────────────────┐ ┌────────────────────┐ │
+ │ │ save_batch(metric) │ │ save_batch(metric) │ │ <- Dump the predictions in a batch for a specified
+ │ └────────────────────┘ └────────────────────┘ │ metric (e.g., for online evaluation).
+ └────────────────┬──────────────────────────────┬────────────────┘
+ ┌─────┴────┐ │
+ │ gather() ├─────────────────────────┘
+ └──────────┘ <- Gather the data from all ranks
+ ▼
+ ┌───────────┐
+ │ process() │ <- Process the data that are
+ └───────────┘ metrics-independent (if any)
+ ▼
+ ┌──────────────────┐
+ │ evaluate(metric) │ <- Evaluate for a specified metric and
+ └──────────────────┘ return the results.
+ ▼
+ ┌──────────────┐
+ │ save(metric) │ <- Dump the predictions for a specified
+ └──────────────┘ metric (e.g., for online evaluation).
+
+ Note:
+ The save_batch() saves the predictions every batch, which is helpful
+ for reducing the memory usage, compared to saving all predictions at
+ once in the save() method. However, the save_batch() is optional and
+ can be omitted if the data can be saved only after all batches are
+ processed.
+ """ # pylint: disable=line-too-long
+
+ @property
+ def metrics(self) -> list[str]:
+ """Return list of metrics to evaluate.
+
+ Returns:
+ list[str]: Metrics to evaluate.
+ """
+ return []
+
+ def gather(self, gather_func: GenericFunc) -> None:
+ """Gather variables in case of distributed setting (if needed).
+
+ Args:
+ gather_func (Callable[[Any], Any]): Gather function.
+ """
+
+ def reset(self) -> None:
+ """Reset evaluator for new round of evaluation.
+
+ Raises:
+ NotImplementedError: This is an abstract class method.
+ """
+ raise NotImplementedError
+
+ # Process a batch of data.
+ process_batch: GenericFunc = unimplemented
+
+ def process(self) -> None:
+ """Process all accumulated data at the end of an epoch, if any."""
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """Evaluate all predictions according to given metric.
+
+ Args:
+ metric (str): Metric to evaluate.
+
+ Raises:
+ NotImplementedError: This is an abstract class method.
+
+ Returns:
+ tuple[MetricLogs, str]: Dictionary of scores to log and a pretty
+ printed string.
+ """
+ raise NotImplementedError
+
+ def save_batch(self, metric: str, output_dir: str) -> None:
+ """Save batch of predictions to file.
+
+ Args:
+ metric (str): Save predictions for the specified metrics.
+ output_dir (str): Output directory.
+ """
+
+ def save(self, metric: str, output_dir: str) -> None:
+ """Save all predictions to file at the end of an epoch.
+
+ Args:
+ metric (str): Save predictions for the specified metrics.
+ output_dir (str): Output directory.
+ """
diff --git a/vis4d/eval/bdd100k/__init__.py b/vis4d/eval/bdd100k/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4eb30ec2752affdfb193ce2417abfdd9caa470d
--- /dev/null
+++ b/vis4d/eval/bdd100k/__init__.py
@@ -0,0 +1,11 @@
+"""BDD100K evaluators."""
+
+from .detect import BDD100KDetectEvaluator
+from .seg import BDD100KSegEvaluator
+from .track import BDD100KTrackEvaluator
+
+__all__ = [
+ "BDD100KDetectEvaluator",
+ "BDD100KSegEvaluator",
+ "BDD100KTrackEvaluator",
+]
diff --git a/vis4d/eval/bdd100k/detect.py b/vis4d/eval/bdd100k/detect.py
new file mode 100644
index 0000000000000000000000000000000000000000..a252c9188977ff9d99027008cb565cc890efc988
--- /dev/null
+++ b/vis4d/eval/bdd100k/detect.py
@@ -0,0 +1,36 @@
+"""BDD100K detection evaluator."""
+
+from __future__ import annotations
+
+from vis4d.common.imports import BDD100K_AVAILABLE
+from vis4d.eval.scalabel import ScalabelDetectEvaluator
+
+if BDD100K_AVAILABLE:
+ from bdd100k.common.utils import load_bdd100k_config
+else:
+ raise ImportError("bdd100k is not installed.")
+
+
+class BDD100KDetectEvaluator(ScalabelDetectEvaluator):
+ """BDD100K 2D detection evaluation class."""
+
+ METRICS_DET = "Det"
+ METRICS_INS_SEG = "InsSeg"
+
+ def __init__(
+ self,
+ annotation_path: str,
+ config_path: str,
+ mask_threshold: float = 0.0,
+ ) -> None:
+ """Initialize the evaluator."""
+ config = load_bdd100k_config(config_path)
+ super().__init__(
+ annotation_path=annotation_path,
+ config=config.scalabel,
+ mask_threshold=mask_threshold,
+ )
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset evaluator."""
+ return "BDD100K Detection Evaluator"
diff --git a/vis4d/eval/bdd100k/seg.py b/vis4d/eval/bdd100k/seg.py
new file mode 100644
index 0000000000000000000000000000000000000000..a874d642bdda8bb0f6df416ac5ba8d9f8eaf49d6
--- /dev/null
+++ b/vis4d/eval/bdd100k/seg.py
@@ -0,0 +1,101 @@
+"""BDD100K segmentation evaluator."""
+
+from __future__ import annotations
+
+import itertools
+from collections.abc import Callable
+from typing import Any
+
+import numpy as np
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.imports import BDD100K_AVAILABLE, SCALABEL_AVAILABLE
+from vis4d.common.typing import ArrayLike, MetricLogs
+from vis4d.data.datasets.bdd100k import bdd100k_seg_map
+
+from ..base import Evaluator
+
+if SCALABEL_AVAILABLE and BDD100K_AVAILABLE:
+ from bdd100k.common.utils import load_bdd100k_config
+ from bdd100k.label.to_scalabel import bdd100k_to_scalabel
+ from scalabel.eval.sem_seg import evaluate_sem_seg
+ from scalabel.label.io import load
+ from scalabel.label.transforms import mask_to_rle
+ from scalabel.label.typing import Frame, Label
+else:
+ raise ImportError("scalabel or bdd100k is not installed.")
+
+
+class BDD100KSegEvaluator(Evaluator):
+ """BDD100K segmentation evaluation class."""
+
+ inverse_seg_map = {v: k for k, v in bdd100k_seg_map.items()}
+
+ def __init__(self, annotation_path: str) -> None:
+ """Initialize the evaluator."""
+ super().__init__()
+ self.annotation_path = annotation_path
+ self.frames: list[Frame] = []
+
+ bdd100k_anns = load(annotation_path)
+ frames = bdd100k_anns.frames
+ self.config = load_bdd100k_config("sem_seg")
+ self.gt_frames = bdd100k_to_scalabel(frames, self.config)
+
+ self.reset()
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset evaluator."""
+ return "BDD100K Segmentation Evaluator"
+
+ @property
+ def metrics(self) -> list[str]:
+ """Supported metrics."""
+ return ["sem_seg"]
+
+ def gather( # type: ignore # pragma: no cover
+ self, gather_func: Callable[[Any], Any]
+ ) -> None:
+ """Gather variables in case of distributed setting (if needed).
+
+ Args:
+ gather_func (Callable[[Any], Any]): Gather function.
+ """
+ all_preds = gather_func(self.frames)
+ if all_preds is not None:
+ self.frames = list(itertools.chain(*all_preds))
+
+ def reset(self) -> None:
+ """Reset the evaluator."""
+ self.frames = []
+
+ def process_batch(
+ self, data_names: list[str], masks_list: list[ArrayLike]
+ ) -> None:
+ """Process tracking results."""
+ masks_numpy = [array_to_numpy(m, None) for m in masks_list] # to numpy
+ for data_name, masks in zip(data_names, masks_numpy):
+ labels = []
+ for i, class_id in enumerate(np.unique(masks)):
+ label = Label(
+ rle=mask_to_rle((masks == class_id).astype(np.uint8)),
+ category=self.inverse_seg_map[int(class_id)],
+ id=str(i),
+ )
+ labels.append(label)
+ frame = Frame(name=data_name, labels=labels)
+ self.frames.append(frame)
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """Evaluate the dataset."""
+ if metric == "sem_seg":
+ results = evaluate_sem_seg(
+ ann_frames=self.gt_frames,
+ pred_frames=self.frames,
+ config=self.config.scalabel,
+ nproc=0,
+ )
+ else:
+ raise NotImplementedError
+
+ return {}, str(results)
diff --git a/vis4d/eval/bdd100k/track.py b/vis4d/eval/bdd100k/track.py
new file mode 100644
index 0000000000000000000000000000000000000000..a71b08f1facc2e199575c9939e99938516483a06
--- /dev/null
+++ b/vis4d/eval/bdd100k/track.py
@@ -0,0 +1,81 @@
+"""BDD100K tracking evaluator."""
+
+from __future__ import annotations
+
+from vis4d.common.imports import BDD100K_AVAILABLE, SCALABEL_AVAILABLE
+from vis4d.common.typing import MetricLogs
+from vis4d.data.datasets.bdd100k import bdd100k_track_map
+
+from ..scalabel.track import ScalabelTrackEvaluator
+
+if SCALABEL_AVAILABLE and BDD100K_AVAILABLE:
+ from bdd100k.common.utils import load_bdd100k_config
+ from bdd100k.label.to_scalabel import bdd100k_to_scalabel
+ from scalabel.eval.detect import evaluate_det
+ from scalabel.eval.mot import acc_single_video_mot, evaluate_track
+ from scalabel.label.io import group_and_sort
+else:
+ raise ImportError("scalabel or bdd100k is not installed.")
+
+
+class BDD100KTrackEvaluator(ScalabelTrackEvaluator):
+ """BDD100K 2D tracking evaluation class."""
+
+ METRICS_DET = "Det"
+ METRICS_TRACK = "Track"
+
+ def __init__(
+ self,
+ annotation_path: str,
+ config_path: str = "box_track",
+ mask_threshold: float = 0.0,
+ ) -> None:
+ """Initialize the evaluator."""
+ config = load_bdd100k_config(config_path)
+ super().__init__(
+ annotation_path=annotation_path,
+ config=config.scalabel,
+ mask_threshold=mask_threshold,
+ )
+ self.gt_frames = bdd100k_to_scalabel(self.gt_frames, config)
+ self.inverse_cat_map = {v: k for k, v in bdd100k_track_map.items()}
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset evaluator."""
+ return "BDD100K Tracking Evaluator"
+
+ @property
+ def metrics(self) -> list[str]:
+ """Supported metrics."""
+ return [self.METRICS_DET, self.METRICS_TRACK]
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """Evaluate the dataset."""
+ assert self.config is not None, "BDD100K config is not loaded."
+ metrics_log: MetricLogs = {}
+ short_description = ""
+
+ if metric == self.METRICS_DET:
+ det_results = evaluate_det(
+ self.gt_frames,
+ self.frames,
+ config=self.config,
+ nproc=0,
+ )
+ for metric_name, metric_value in det_results.summary().items():
+ metrics_log[metric_name] = metric_value
+ short_description += str(det_results) + "\n"
+
+ if metric == self.METRICS_TRACK:
+ track_results = evaluate_track(
+ acc_single_video_mot,
+ gts=group_and_sort(self.gt_frames),
+ results=group_and_sort(self.frames),
+ config=self.config,
+ nproc=1,
+ )
+ for metric_name, metric_value in track_results.summary().items():
+ metrics_log[metric_name] = metric_value
+ short_description += str(track_results) + "\n"
+
+ return metrics_log, short_description
diff --git a/vis4d/eval/coco/__init__.py b/vis4d/eval/coco/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a05a9be3e7c60702df544c15b7935b4f0c133d5e
--- /dev/null
+++ b/vis4d/eval/coco/__init__.py
@@ -0,0 +1,5 @@
+"""Detection evaluators."""
+
+from .detect import COCODetectEvaluator
+
+__all__ = ["COCODetectEvaluator"]
diff --git a/vis4d/eval/coco/detect.py b/vis4d/eval/coco/detect.py
new file mode 100644
index 0000000000000000000000000000000000000000..dce40302e9ed1b99d1249f55fb7393e59010460c
--- /dev/null
+++ b/vis4d/eval/coco/detect.py
@@ -0,0 +1,289 @@
+"""COCO evaluator."""
+
+from __future__ import annotations
+
+import contextlib
+import copy
+import io
+import itertools
+
+import numpy as np
+import pycocotools.mask as maskUtils
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+from terminaltables import AsciiTable
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.logging import rank_zero_warn
+from vis4d.common.typing import (
+ ArrayLike,
+ DictStrAny,
+ GenericFunc,
+ MetricLogs,
+ NDArrayF32,
+ NDArrayI64,
+)
+from vis4d.data.datasets.coco import coco_det_map
+
+from ..base import Evaluator
+
+
+def xyxy_to_xywh(boxes: NDArrayF32) -> NDArrayF32:
+ """Convert Tensor [N, 4] in xyxy format into xywh.
+
+ Args:
+ boxes (NDArrayF32): Bounding boxes in Vis4D format.
+
+ Returns:
+ NDArrayF32: COCO format bounding boxes.
+ """
+ boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
+ boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
+ return boxes
+
+
+class COCOevalV2(COCOeval): # type: ignore
+ """Subclass COCO eval for logging / printing."""
+
+ def summarize(self) -> str:
+ """Capture summary in string.
+
+ Returns:
+ str: Pretty printed string.
+ """
+ f = io.StringIO()
+ with contextlib.redirect_stdout(f):
+ super().summarize()
+ summary_str = "\n" + f.getvalue()
+ return summary_str
+
+
+def predictions_to_coco(
+ cat_map: dict[str, int],
+ coco_id2name: dict[int, str],
+ image_id: int,
+ boxes: NDArrayF32,
+ scores: NDArrayF32,
+ classes: NDArrayI64,
+ masks: None | NDArrayF32 = None,
+) -> list[DictStrAny]:
+ """Convert Vis4D format predictions to COCO format.
+
+ Args:
+ cat_map (dict[str, int]): COCO class name to class ID mapping.
+ coco_id2name (dict[int, str]): COCO class ID to class name mapping.
+ image_id (int): ID of image.
+ boxes (NDArrayF32): Predicted bounding boxes.
+ scores (NDArrayF32): Predicted scores for each box.
+ classes (NDArrayI64): Predicted classes for each box.
+ masks (None | NDArrayF32, optional): Predicted masks. Defaults to
+ None.
+
+ Returns:
+ list[DictStrAny]: Predictions in COCO format.
+ """
+ predictions = []
+ boxes_xyxy = copy.deepcopy(boxes)
+ boxes_xywh = xyxy_to_xywh(boxes_xyxy)
+ for i, (box, score, cls) in enumerate(zip(boxes_xywh, scores, classes)):
+ mask = masks[i] if masks is not None else None
+ xywh = box.tolist()
+ area = float(xywh[2] * xywh[3])
+ annotation = {
+ "image_id": image_id,
+ "bbox": xywh,
+ "area": area,
+ "score": float(score),
+ "category_id": cat_map[coco_id2name[int(cls)]],
+ "iscrowd": 0,
+ }
+ if mask is not None:
+ annotation["segmentation"] = maskUtils.encode(
+ np.array(mask, order="F", dtype="uint8")
+ )
+ annotation["segmentation"]["counts"] = annotation["segmentation"][
+ "counts"
+ ].decode()
+ predictions.append(annotation)
+ return predictions
+
+
+class COCODetectEvaluator(Evaluator):
+ """COCO detection evaluation class."""
+
+ METRIC_DET = "Det"
+ METRIC_INS_SEG = "InsSeg"
+
+ def __init__(
+ self,
+ data_root: str,
+ split: str = "val2017",
+ per_class_eval: bool = False,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ data_root (str): Root directory of data.
+ split (str, optional): COCO data split. Defaults to "val2017".
+ per_class_eval (bool, optional): Per-class evaluation. Defaults to
+ False.
+ """
+ super().__init__()
+ self.per_class_eval = per_class_eval
+ self.coco_id2name = {v: k for k, v in coco_det_map.items()}
+ self.annotation_path = (
+ f"{data_root}/annotations/instances_{split}.json"
+ )
+ with contextlib.redirect_stdout(io.StringIO()):
+ self._coco_gt = COCO(self.annotation_path)
+ coco_gt_cats = self._coco_gt.loadCats(self._coco_gt.getCatIds())
+ self.cat_map = {c["name"]: c["id"] for c in coco_gt_cats}
+ self._predictions: list[DictStrAny] = []
+
+ @property
+ def metrics(self) -> list[str]:
+ """Supported metrics.
+
+ Returns:
+ list[str]: Metrics to evaluate.
+ """
+ return [self.METRIC_DET, self.METRIC_INS_SEG]
+
+ def gather(self, gather_func: GenericFunc) -> None:
+ """Accumulate predictions across processes."""
+ all_preds = gather_func(self._predictions)
+ if all_preds is not None:
+ self._predictions = list(itertools.chain(*all_preds))
+
+ def reset(self) -> None:
+ """Reset the saved predictions to start new round of evaluation."""
+ self._predictions = []
+
+ def process_batch(
+ self,
+ coco_image_id: list[int],
+ pred_boxes: list[ArrayLike],
+ pred_scores: list[ArrayLike],
+ pred_classes: list[ArrayLike],
+ pred_masks: None | list[ArrayLike] = None,
+ ) -> None:
+ """Process sample and convert detections to coco format.
+
+ coco_image_id (list[int]): COCO image ID.
+ pred_boxes (list[ArrayLike]): Predicted bounding boxes.
+ pred_scores (list[ArrayLike]): Predicted scores for each box.
+ pred_classes (list[ArrayLike]): Predicted classes for each box.
+ pred_masks (None | list[ArrayLike], optional): Predicted masks.
+ """
+ for i, (image_id, boxes, scores, classes) in enumerate(
+ zip(coco_image_id, pred_boxes, pred_scores, pred_classes)
+ ):
+ boxes_np = array_to_numpy(boxes, n_dims=None, dtype=np.float32)
+ scores_np = array_to_numpy(scores, n_dims=None, dtype=np.float32)
+ classes_np = array_to_numpy(classes, n_dims=None, dtype=np.int64)
+
+ if pred_masks is not None:
+ masks_np = array_to_numpy(
+ pred_masks[i], n_dims=3, dtype=np.float32
+ )
+ else:
+ masks_np = None
+
+ coco_preds = predictions_to_coco(
+ self.cat_map,
+ self.coco_id2name,
+ image_id,
+ boxes_np,
+ scores_np,
+ classes_np,
+ masks_np,
+ )
+ self._predictions.extend(coco_preds)
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """Evaluate COCO predictions.
+
+ Args:
+ metric (str): Metric to evaluate. Should be "COCO_AP".
+
+ Raises:
+ NotImplementedError: Raised if metric is not "COCO_AP".
+ RuntimeError: Raised if no predictions are available.
+
+ Returns:
+ tuple[MetricLogs, str]: Dictionary of scores to log and a pretty
+ printed string.
+ """
+ if metric not in [self.METRIC_DET, self.METRIC_INS_SEG]:
+ raise NotImplementedError(f"Metric {metric} not known!")
+
+ if len(self._predictions) == 0:
+ rank_zero_warn(
+ "No predictions to evaluate. Make sure to process batch first!"
+ )
+ return {
+ "AP": 0.0,
+ "AP50": 0.0,
+ "AP75": 0.0,
+ "APs": 0.0,
+ "APm": 0.0,
+ "APl": 0.0,
+ }, "No predictions to evaluate."
+
+ if metric == self.METRIC_DET:
+ iou_type = "bbox"
+ _predictions = self._predictions
+ else:
+ # remove bbox for segm evaluation so cocoapi will use mask
+ # area instead of box area
+ iou_type = "segm"
+ _predictions = copy.deepcopy(self._predictions)
+ for pred in _predictions:
+ pred.pop("bbox")
+ coco_dt = self._coco_gt.loadRes(_predictions)
+
+ with contextlib.redirect_stdout(io.StringIO()):
+ assert coco_dt is not None
+ evaluator = COCOevalV2(self._coco_gt, coco_dt, iouType=iou_type)
+ evaluator.evaluate()
+ evaluator.accumulate()
+
+ log_str = evaluator.summarize()
+ metrics = ["AP", "AP50", "AP75", "APs", "APm", "APl"]
+ score_dict = dict(zip(metrics, evaluator.stats))
+
+ if self.per_class_eval:
+ # Compute per-category AP
+ # from https://github.com/facebookresearch/detectron2/
+ precisions = evaluator.eval["precision"]
+ # precision: (iou, recall, cls, area range, max dets)
+ assert len(self._coco_gt.getCatIds()) == precisions.shape[2]
+
+ results_per_category = []
+ for idx, cat_id in enumerate(self._coco_gt.getCatIds()):
+ # area range index 0: all area ranges
+ # max dets index -1: typically 100 per image
+ nm = self._coco_gt.loadCats(cat_id)[0]
+ precision = precisions[:, :, idx, 0, -1]
+ precision = precision[precision > -1]
+ if precision.size:
+ ap = np.mean(precision).item()
+ else:
+ ap = float("nan")
+ results_per_category.append((f'{nm["name"]}', f"{ap:0.3f}"))
+
+ num_columns = min(6, len(results_per_category) * 2)
+ results_flatten = list(itertools.chain(*results_per_category))
+ headers = ["category", "AP"] * (num_columns // 2)
+ results_2d = itertools.zip_longest(
+ *[results_flatten[i::num_columns] for i in range(num_columns)]
+ )
+ table_data = [headers] + list(results_2d)
+ table = AsciiTable(table_data)
+ log_str = f"\n{table.table}\n{log_str}"
+
+ return score_dict, log_str
+
+ def __repr__(self) -> str:
+ """Returns the string representation of the object."""
+ return f"CocoEvaluator(annotation_path={self.annotation_path})"
diff --git a/vis4d/eval/common/__init__.py b/vis4d/eval/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5434fe3032bca59d8c6f6333a2b7f42845d9e28
--- /dev/null
+++ b/vis4d/eval/common/__init__.py
@@ -0,0 +1,15 @@
+"""Common evaluation code."""
+
+from .binary import BinaryEvaluator
+from .cls import ClassificationEvaluator
+from .depth import DepthEvaluator
+from .flow import OpticalFlowEvaluator
+from .seg import SegEvaluator
+
+__all__ = [
+ "ClassificationEvaluator",
+ "DepthEvaluator",
+ "OpticalFlowEvaluator",
+ "BinaryEvaluator",
+ "SegEvaluator",
+]
diff --git a/vis4d/eval/common/binary.py b/vis4d/eval/common/binary.py
new file mode 100644
index 0000000000000000000000000000000000000000..30f661d398097218b9757dae567ccb101bd7c6f7
--- /dev/null
+++ b/vis4d/eval/common/binary.py
@@ -0,0 +1,191 @@
+"""Binary occupancy evaluator."""
+
+from __future__ import annotations
+
+import numpy as np
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.typing import (
+ ArrayLike,
+ MetricLogs,
+ NDArrayBool,
+ NDArrayNumber,
+)
+from vis4d.eval.base import Evaluator
+
+
+def threshold_and_flatten(
+ prediction: NDArrayNumber, target: NDArrayNumber, threshold_value: float
+) -> tuple[NDArrayBool, NDArrayBool]:
+ """Thresholds the predictions based on the provided treshold value.
+
+ Applies the following actions:
+ prediction -> prediction >= threshold_value
+ pred, gt = pred.ravel().bool(), gt.ravel().bool()
+
+ Args:
+ prediction: Prediction array with continuous values
+ target: Grondgtruth values {0,1}
+ threshold_value: Value to use to convert the continuous prediction
+ into binary.
+
+ Returns:
+ tuple of two boolean arrays, prediction and target
+ """
+ prediction_bin: NDArrayBool = prediction >= threshold_value
+ return prediction_bin.ravel().astype(bool), target.ravel().astype(bool)
+
+
+class BinaryEvaluator(Evaluator):
+ """Creates a new Evaluater that evaluates binary predictions."""
+
+ METRIC_BINARY = "BinaryCls"
+
+ KEY_IOU = "IoU"
+ KEY_ACCURACY = "Accuracy"
+ KEY_F1 = "F1"
+ KEY_PRECISION = "Precision"
+ KEY_RECALL = "Recall"
+
+ def __init__(
+ self,
+ threshold: float = 0.5,
+ ) -> None:
+ """Creates a new binary evaluator.
+
+ Args:
+ threshold (float): Threshold for prediction to convert
+ to binary. All prediction that are higher than
+ this value will be assigned the 'True' label
+ """
+ super().__init__()
+ self.threshold = threshold
+ self.reset()
+
+ self.true_positives: list[float] = []
+ self.false_positives: list[float] = []
+ self.true_negatives: list[float] = []
+ self.false_negatives: list[float] = []
+ self.n_samples: list[float] = []
+
+ self.has_samples = False
+
+ def _calc_confusion_matrix(
+ self, prediction: NDArrayBool, target: NDArrayBool
+ ) -> None:
+ """Calculates the confusion matrix and stores them as attributes.
+
+ Args:
+ prediction: the prediction (binary) (N, Pts)
+ target: the groundtruth (binary) (N, Pts)
+ """
+ tp = int(np.sum(np.logical_and(prediction == 1, target == 1)))
+ fp = int(np.sum(np.logical_and(prediction == 1, target == 0)))
+ tn = int(np.sum(np.logical_and(prediction == 0, target == 0)))
+ fn = int(np.sum(np.logical_and(prediction == 0, target == 1)))
+ self.true_positives.append(tp)
+ self.false_positives.append(fp)
+ self.true_negatives.append(tn)
+ self.false_negatives.append(fn)
+ self.n_samples.append(tp + fp + tn + fn)
+
+ @property
+ def metrics(self) -> list[str]:
+ """Supported metrics."""
+ return [self.METRIC_BINARY]
+
+ def reset(self) -> None:
+ """Reset the saved predictions to start new round of evaluation."""
+ self.true_positives = []
+ self.false_positives = []
+ self.true_negatives = []
+ self.false_negatives = []
+ self.n_samples = []
+
+ def process_batch(
+ self,
+ prediction: ArrayLike,
+ groundtruth: ArrayLike,
+ ) -> None:
+ """Processes a new (batch) of predictions.
+
+ Calculates the metrics and caches them internally.
+
+ Args:
+ prediction: the prediction(continuous values or bin) (Batch x Pts)
+ groundtruth: the groundtruth (binary) (Batch x Pts)
+ """
+ pred, gt = threshold_and_flatten(
+ array_to_numpy(prediction, n_dims=None, dtype=np.float32),
+ array_to_numpy(groundtruth, n_dims=None, dtype=np.bool_),
+ self.threshold,
+ )
+
+ # Confusion Matrix
+ self._calc_confusion_matrix(pred, gt)
+ self.has_samples = True
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """Evaluate predictions.
+
+ Returns a dict containing the raw data and a
+ short description string containing a readable result.
+
+ Args:
+ metric (str): Metric to use. See @property metric
+
+ Returns:
+ metric_data, description
+ tuple containing the metric data (dict with metric name and value)
+ as well as a short string with shortened information.
+
+ Raises:
+ RuntimeError: if no data has been registered to be evaluated.
+ ValueError: if metric is not supported.
+ """
+ if not self.has_samples:
+ raise RuntimeError(
+ """No data registered to calculate metric.
+ Register data using .process() first!"""
+ )
+ metric_data: MetricLogs = {}
+ short_description = ""
+
+ if metric == self.METRIC_BINARY:
+ # IoU
+ iou = sum(self.true_positives) / (
+ sum(self.n_samples) - sum(self.true_negatives) + 1e-6
+ )
+ metric_data[self.KEY_IOU] = iou
+ short_description += f"IoU: {iou:.3f}\n"
+
+ # Accuracy
+ acc = (sum(self.true_positives) + sum(self.true_negatives)) / sum(
+ self.n_samples
+ )
+ metric_data[self.KEY_ACCURACY] = acc
+ short_description += f"Accuracy: {acc:.3f}\n"
+
+ # Precision
+ tp_fp = sum(self.true_positives) + sum(self.false_positives)
+ precision = sum(self.true_positives) / tp_fp if tp_fp != 0 else 1
+ metric_data[self.KEY_PRECISION] = precision
+ short_description += f"Precision: {precision:.3f}\n"
+
+ # Recall
+ tp_fn = sum(self.true_positives) + sum(self.false_negatives)
+ recall = sum(self.true_positives) / tp_fn if tp_fn != 0 else 1
+ metric_data[self.KEY_RECALL] = recall
+ short_description += f"Recall: {acc:.3f}\n"
+
+ # F1
+ f1 = 2 * precision * recall / (precision + recall + 1e-8)
+ metric_data[self.KEY_F1] = f1
+ short_description += f"F1: {f1:.3f}\n"
+
+ else:
+ raise ValueError(
+ f"Unsupported metric: {metric}"
+ ) # pragma: no cover
+
+ return metric_data, short_description
diff --git a/vis4d/eval/common/cls.py b/vis4d/eval/common/cls.py
new file mode 100644
index 0000000000000000000000000000000000000000..7db6482c5d2254190f7fdd7618e508040bebf42a
--- /dev/null
+++ b/vis4d/eval/common/cls.py
@@ -0,0 +1,137 @@
+"""Image classification evaluator."""
+
+from __future__ import annotations
+
+import itertools
+
+import numpy as np
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.typing import (
+ ArrayLike,
+ GenericFunc,
+ MetricLogs,
+ NDArrayI64,
+ NDArrayNumber,
+)
+from vis4d.eval.base import Evaluator
+
+from ..metrics.cls import accuracy
+
+
+class ClassificationEvaluator(Evaluator):
+ """Multi-class classification evaluator."""
+
+ METRIC_CLASSIFICATION = "Cls"
+
+ KEY_ACCURACY = "Acc@1"
+ KEY_ACCURACY_TOP5 = "Acc@5"
+
+ def __init__(self) -> None:
+ """Initialize the classification evaluator."""
+ super().__init__()
+ self._metrics_list: list[dict[str, float]] = []
+
+ @property
+ def metrics(self) -> list[str]:
+ """Supported metrics."""
+ return [
+ self.KEY_ACCURACY,
+ self.KEY_ACCURACY_TOP5,
+ ]
+
+ def reset(self) -> None:
+ """Reset evaluator for new round of evaluation."""
+ self._metrics_list = []
+
+ def _is_correct(
+ self, pred: NDArrayNumber, target: NDArrayI64, top_k: int = 1
+ ) -> bool:
+ """Check if the prediction is correct for top-k.
+
+ Args:
+ pred (NDArrayNumber): Prediction logits, in shape (C, ).
+ target (NDArrayI64): Target logits, in shape (1, ).
+ top_k (int, optional): Top-k to check. Defaults to 1.
+
+ Returns:
+ bool: Whether the prediction is correct.
+ """
+ top_k = min(top_k, pred.shape[0])
+ top_k_idx = np.argsort(pred)[-top_k:]
+ return bool(np.any(top_k_idx == target))
+
+ def process_batch( # type: ignore # pylint: disable=arguments-differ
+ self, prediction: ArrayLike, groundtruth: ArrayLike
+ ):
+ """Process a batch of predictions and groundtruths.
+
+ Args:
+ prediction (ArrayLike): Prediction, in shape (N, C).
+ groundtruth (ArrayLike): Groundtruth, in shape (N, ).
+ """
+ pred = array_to_numpy(prediction, n_dims=None, dtype=np.float32)
+ gt = array_to_numpy(groundtruth, n_dims=None, dtype=np.int64)
+ for i in range(pred.shape[0]):
+ self._metrics_list.append(
+ {
+ "top1_correct": accuracy(pred[i], gt[i], top_k=1),
+ "top5_correct": accuracy(pred[i], gt[i], top_k=5),
+ }
+ )
+
+ def gather(self, gather_func: GenericFunc) -> None:
+ """Accumulate predictions across processes."""
+ all_metrics = gather_func(self._metrics_list)
+ if all_metrics is not None:
+ self._metrics_list = list(itertools.chain(*all_metrics))
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """Evaluate predictions.
+
+ Returns a dict containing the raw data and a
+ short description string containing a readable result.
+
+ Args:
+ metric (str): Metric to use. See @property metric
+
+ Returns:
+ metric_data, description
+ tuple containing the metric data (dict with metric name and value)
+ as well as a short string with shortened information.
+
+ Raises:
+ RuntimeError: if no data has been registered to be evaluated.
+ ValueError: if the metric is not supported.
+ """
+ if len(self._metrics_list) == 0:
+ raise RuntimeError(
+ """No data registered to calculate metric.
+ Register data using .process() first!"""
+ )
+ metric_data: MetricLogs = {}
+ short_description = ""
+
+ if metric == self.METRIC_CLASSIFICATION:
+ # Top1 accuracy
+ top1_correct = np.array(
+ [metric["top1_correct"] for metric in self._metrics_list]
+ )
+ top1_acc = np.mean(top1_correct)
+ metric_data[self.KEY_ACCURACY] = top1_acc
+ short_description += f"Top1 Accuracy: {top1_acc:.4f}\n"
+
+ # Top5 accuracy
+ top5_correct = np.array(
+ [metric["top5_correct"] for metric in self._metrics_list]
+ )
+ top5_acc = np.mean(top5_correct)
+ metric_data[self.KEY_ACCURACY_TOP5] = top5_acc
+ short_description += f"Top5 Accuracy: {top5_acc:.4f}\n"
+
+ else:
+ raise ValueError(
+ f"Unsupported metric: {metric}"
+ ) # pragma: no cover
+
+ return metric_data, short_description
diff --git a/vis4d/eval/common/depth.py b/vis4d/eval/common/depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d9b3244c3ba611663d47aacb237336675614919
--- /dev/null
+++ b/vis4d/eval/common/depth.py
@@ -0,0 +1,214 @@
+"""Depth estimation evaluator."""
+
+from __future__ import annotations
+
+import itertools
+
+import numpy as np
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.typing import (
+ ArrayLike,
+ GenericFunc,
+ MetricLogs,
+ NDArrayFloat,
+)
+from vis4d.eval.base import Evaluator
+
+from ..metrics.depth import (
+ absolute_error,
+ absolute_relative_error,
+ delta_p,
+ log_10_error,
+ root_mean_squared_error,
+ root_mean_squared_error_log,
+ scale_invariant_log,
+ squared_relative_error,
+)
+
+
+class DepthEvaluator(Evaluator):
+ """Depth estimation evaluator."""
+
+ METRIC_DEPTH = "Depth"
+
+ KEY_DELTA05 = "d05"
+ KEY_DELTA1 = "d1"
+ KEY_DELTA2 = "d2"
+ KEY_DELTA3 = "d3"
+
+ KEY_ABS_REL = "AbsRel"
+ KEY_ABS_ERR = "AbsErr"
+ KEY_SQ_REL = "SqRel"
+ KEY_RMSE = "RMSE"
+ KEY_RMSE_LOG = "RMSELog"
+ KEY_SILOG = "SILog"
+ KEY_LOG10 = "Log10"
+
+ def __init__(
+ self,
+ min_depth: float = 0.0,
+ max_depth: float = 80.0,
+ scale: float = 1.0,
+ epsilon: float = 1e-3,
+ ) -> None:
+ """Initialize the optical flow evaluator.
+
+ Args:
+ min_depth (float): Minimum depth to evaluate. Defaults to 0.001.
+ max_depth (float): Maximum depth to evaluate. Defaults to 80.0.
+ scale (float): Scale factor for depth. Defaults to 1.0.
+ epsilon (float): Small value to avoid logarithms of small values.
+ Defaults to 1e-3.
+ """
+ super().__init__()
+ self.min_depth = min_depth
+ self.max_depth = max_depth
+ self.epsilon = epsilon
+ self.scale = scale
+ self._metrics_list: list[dict[str, float]] = []
+
+ def __repr__(self) -> str:
+ """Concise representation of the evaluator."""
+ return "Common Depth Evaluator"
+
+ @property
+ def metrics(self) -> list[str]:
+ """Supported metrics."""
+ return [self.METRIC_DEPTH]
+
+ def reset(self) -> None:
+ """Reset evaluator for new round of evaluation."""
+ self._metrics_list = []
+
+ def gather(self, gather_func: GenericFunc) -> None:
+ """Accumulate predictions across processes."""
+ all_metrics = gather_func(self._metrics_list)
+ if all_metrics is not None:
+ self._metrics_list = list(itertools.chain(*all_metrics))
+
+ def _apply_mask(
+ self, prediction: NDArrayFloat, target: NDArrayFloat
+ ) -> tuple[NDArrayFloat, NDArrayFloat]:
+ """Apply mask to prediction and target."""
+ mask = (target > self.min_depth) & (target <= self.max_depth)
+ return prediction[mask], target[mask]
+
+ def process_batch(
+ self, prediction: ArrayLike, groundtruth: ArrayLike
+ ) -> None:
+ """Process a batch of data.
+
+ Args:
+ prediction (np.array): Prediction optical flow, in shape (B, H, W).
+ groundtruth (np.array): Target optical flow, in shape (B, H, W).
+ """
+ preds = (
+ array_to_numpy(prediction, n_dims=None, dtype=np.float32)
+ * self.scale
+ )
+ gts = array_to_numpy(groundtruth, n_dims=None, dtype=np.float32)
+
+ for pred, gt in zip(preds, gts):
+ pred, gt = self._apply_mask(pred, gt)
+ self._metrics_list.append(
+ {
+ self.KEY_ABS_REL: absolute_relative_error(pred, gt),
+ self.KEY_ABS_ERR: absolute_error(pred, gt),
+ self.KEY_SQ_REL: squared_relative_error(pred, gt),
+ self.KEY_RMSE: root_mean_squared_error(pred, gt),
+ self.KEY_RMSE_LOG: root_mean_squared_error_log(pred, gt),
+ self.KEY_SILOG: scale_invariant_log(pred, gt),
+ self.KEY_DELTA05: delta_p(pred, gt, 0.5),
+ self.KEY_DELTA1: delta_p(pred, gt, 1.0),
+ self.KEY_DELTA2: delta_p(pred, gt, 2.0),
+ self.KEY_DELTA3: delta_p(pred, gt, 3.0),
+ self.KEY_LOG10: log_10_error(pred, gt),
+ }
+ )
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """Evaluate predictions.
+
+ Returns a dict containing the raw data and a
+ short description string containing a readablae result.
+
+ Args:
+ metric (str): Metric to use. See @property metric
+
+ Returns:
+ metric_data, description
+ tuple containing the metric data (dict with metric name and value)
+ as well as a short string with shortened information.
+
+ Raises:
+ RuntimeError: if no data has been registered to be evaluated.
+ ValueError: if metric is not supported.
+ """
+ if len(self._metrics_list) == 0:
+ raise RuntimeError(
+ """No data registered to calculate metric.
+ Register data using .process() first!"""
+ )
+ metric_data: MetricLogs = {}
+ short_description = "\n"
+
+ if metric == self.METRIC_DEPTH:
+ abs_rel = np.mean(
+ [x[self.KEY_ABS_REL] for x in self._metrics_list]
+ )
+ metric_data[self.KEY_ABS_REL] = float(abs_rel)
+ short_description += f"Absolute relative error: {abs_rel:.3f}\n"
+
+ abs_err = np.mean(
+ [x[self.KEY_ABS_ERR] for x in self._metrics_list]
+ )
+ metric_data[self.KEY_ABS_ERR] = float(abs_err)
+ short_description += f"Absolute error: {abs_err:.3f}\n"
+
+ sq_rel = np.mean([x[self.KEY_SQ_REL] for x in self._metrics_list])
+ metric_data[self.KEY_SQ_REL] = float(sq_rel)
+ short_description += f"Squared relative error: {sq_rel:.3f}\n"
+
+ rmse = np.mean([x[self.KEY_RMSE] for x in self._metrics_list])
+ metric_data[self.KEY_RMSE] = float(rmse)
+ short_description += f"RMSE: {rmse:.3f}\n"
+
+ rmse_log = np.mean(
+ [x[self.KEY_RMSE_LOG] for x in self._metrics_list]
+ )
+ metric_data[self.KEY_RMSE_LOG] = float(rmse_log)
+ short_description += f"RMSE log: {rmse_log:.3f}\n"
+
+ silog = np.mean([x[self.KEY_SILOG] for x in self._metrics_list])
+ metric_data[self.KEY_SILOG] = float(silog)
+ short_description += f"SILog: {silog:.3f}\n"
+
+ delta05 = np.mean(
+ [x[self.KEY_DELTA05] for x in self._metrics_list]
+ )
+ metric_data[self.KEY_DELTA05] = float(delta05)
+ short_description += f"Delta 0.5: {delta05:.3f}\n"
+
+ delta1 = np.mean([x[self.KEY_DELTA1] for x in self._metrics_list])
+ metric_data[self.KEY_DELTA1] = float(delta1)
+ short_description += f"Delta 1: {delta1:.3f}\n"
+
+ delta2 = np.mean([x[self.KEY_DELTA2] for x in self._metrics_list])
+ metric_data[self.KEY_DELTA2] = float(delta2)
+ short_description += f"Delta 2: {delta2:.3f}\n"
+
+ delta3 = np.mean([x[self.KEY_DELTA3] for x in self._metrics_list])
+ metric_data[self.KEY_DELTA3] = float(delta3)
+ short_description += f"Delta 3: {delta3:.3f}\n"
+
+ log10 = np.mean([x[self.KEY_LOG10] for x in self._metrics_list])
+ metric_data[self.KEY_LOG10] = float(log10)
+ short_description += f"Log10 error: {log10:.3f}\n"
+
+ else:
+ raise ValueError(
+ f"Unsupported metric: {metric}"
+ ) # pragma: no cover
+
+ return metric_data, short_description
diff --git a/vis4d/eval/common/flow.py b/vis4d/eval/common/flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a591587b2b3c4b5c9aa722ee1da0bc5c1f97f12
--- /dev/null
+++ b/vis4d/eval/common/flow.py
@@ -0,0 +1,152 @@
+"""Optical flow evaluator."""
+
+from __future__ import annotations
+
+import itertools
+
+import numpy as np
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.typing import (
+ ArrayLike,
+ GenericFunc,
+ MetricLogs,
+ NDArrayFloat,
+)
+from vis4d.eval.base import Evaluator
+
+from ..metrics.flow import angular_error, end_point_error
+
+
+class OpticalFlowEvaluator(Evaluator):
+ """Optical flow evaluator."""
+
+ METRIC_FLOW = "Flow"
+
+ KEY_ENDPOINT_ERROR = "EPE"
+ KEY_ANGULAR_ERROR = "AE"
+
+ def __init__(
+ self,
+ max_flow: float = 400.0,
+ use_degrees: bool = False,
+ scale: float = 1.0,
+ epsilon: float = 1e-6,
+ ) -> None:
+ """Initialize the optical flow evaluator.
+
+ Args:
+ max_flow (float, optional): Maximum flow value. Defaults to 400.0.
+ use_degrees (bool, optional): Whether to use degrees for angular
+ error. Defaults to False.
+ scale (float, optional): Scale factor for the optical flow.
+ Defaults to 1.0.
+ epsilon (float, optional): Epsilon value for numerical stability.
+ """
+ super().__init__()
+ self.max_flow = max_flow
+ self.use_degrees = use_degrees
+ self.scale = scale
+ self.epsilon = epsilon
+ self._metrics_list: list[dict[str, float]] = []
+
+ @property
+ def metrics(self) -> list[str]:
+ """Supported metrics."""
+ return [
+ OpticalFlowEvaluator.METRIC_FLOW,
+ ]
+
+ def reset(self) -> None:
+ """Reset evaluator for new round of evaluation."""
+ self._metrics_list = []
+
+ def _apply_mask(
+ self, prediction: NDArrayFloat, target: NDArrayFloat
+ ) -> tuple[NDArrayFloat, NDArrayFloat]:
+ """Apply mask to prediction and target."""
+ mask = np.sum(np.abs(target), axis=-1) <= self.max_flow
+ return prediction[mask], target[mask]
+
+ def process_batch(
+ self, prediction: ArrayLike, groundtruth: ArrayLike
+ ) -> None:
+ """Process a batch of data.
+
+ Args:
+ prediction (NDArrayNumber): Prediction optical flow, in shape
+ (N, H, W, 2).
+ groundtruth (NDArrayNumber): Target optical flow, in shape
+ (N, H, W, 2).
+ """
+ preds = (
+ array_to_numpy(prediction, n_dims=None, dtype=np.float32)
+ * self.scale
+ )
+ gts = array_to_numpy(groundtruth, n_dims=None, dtype=np.float32)
+
+ for pred, gt in zip(preds, gts):
+ pred, gt = self._apply_mask(pred, gt)
+ epe = end_point_error(pred, gt)
+ ae = angular_error(pred, gt, self.epsilon)
+ self._metrics_list.append(
+ {
+ OpticalFlowEvaluator.KEY_ENDPOINT_ERROR: epe,
+ OpticalFlowEvaluator.KEY_ANGULAR_ERROR: ae,
+ }
+ )
+
+ def gather(self, gather_func: GenericFunc) -> None:
+ """Accumulate predictions across processes."""
+ all_metrics = gather_func(self._metrics_list)
+ if all_metrics is not None:
+ self._metrics_list = list(itertools.chain(*all_metrics))
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """Evaluate predictions.
+
+ Returns a dict containing the raw data and a
+ short description string containing a readable result.
+
+ Args:
+ metric (str): Metric to use. See @property metric
+
+ Returns:
+ metric_data, description
+ tuple containing the metric data (dict with metric name and value)
+ as well as a short string with shortened information.
+
+ Raises:
+ RuntimeError: if no data has been registered to be evaluated.
+ ValueError: if metric is not supported.
+ """
+ if len(self._metrics_list) == 0:
+ raise RuntimeError(
+ """No data registered to calculate metric.
+ Register data using .process() first!"""
+ )
+ metric_data: MetricLogs = {}
+ short_description = ""
+
+ if metric == OpticalFlowEvaluator.METRIC_FLOW:
+ # EPE
+ epe = np.mean(
+ [x[self.KEY_ENDPOINT_ERROR] for x in self._metrics_list]
+ )
+ metric_data[self.KEY_ENDPOINT_ERROR] = float(epe)
+ short_description = f"EPE: {epe:.3f}"
+
+ # AE
+ ae = np.mean(
+ [x[self.KEY_ANGULAR_ERROR] for x in self._metrics_list]
+ )
+ metric_data[self.KEY_ANGULAR_ERROR] = float(ae)
+ angular_unit = "rad" if not self.use_degrees else "deg"
+ short_description = f"AE: {ae:.3f}{angular_unit}"
+
+ else:
+ raise ValueError(
+ f"Unsupported metric: {metric}"
+ ) # pragma: no cover
+
+ return metric_data, short_description
diff --git a/vis4d/eval/common/seg.py b/vis4d/eval/common/seg.py
new file mode 100644
index 0000000000000000000000000000000000000000..28817a89cf1ef49495f532d8cc6503714bf4294f
--- /dev/null
+++ b/vis4d/eval/common/seg.py
@@ -0,0 +1,185 @@
+"""Common segmentation evaluator."""
+
+from __future__ import annotations
+
+import numpy as np
+from terminaltables import AsciiTable
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.typing import (
+ ArrayLike,
+ MetricLogs,
+ NDArrayI64,
+ NDArrayNumber,
+)
+from vis4d.eval.base import Evaluator
+
+
+class SegEvaluator(Evaluator):
+ """Creates an evaluator that calculates mIoU score and confusion matrix."""
+
+ METRIC_MIOU = "mIoU"
+ METRIC_CONFUSION_MATRIX = "confusion_matrix"
+
+ def __init__(
+ self,
+ num_classes: int | None = None,
+ class_to_ignore: int | None = None,
+ class_mapping: dict[int, str] | None = None,
+ ):
+ """Creates a new evaluator.
+
+ Args:
+ num_classes (int): Number of semantic classes
+ class_to_ignore (int | None): Groundtruth class that should be
+ ignored
+ class_mapping (int): dict mapping each class_id to a readable name
+
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.class_mapping = class_mapping if class_mapping is not None else {}
+ self.class_to_ignore = class_to_ignore
+
+ self._confusion_matrix: NDArrayI64 | None = None
+ self.reset()
+
+ @property
+ def metrics(self) -> list[str]:
+ """Supported metrics."""
+ return [
+ self.METRIC_MIOU,
+ self.METRIC_CONFUSION_MATRIX,
+ ]
+
+ # Taken and modified (added static N) from
+ # https://stackoverflow.com/questions/59080843/faster-method-of-computing-confusion-matrix
+ def calc_confusion_matrix(
+ self, prediction: NDArrayNumber, groundtruth: NDArrayI64
+ ) -> NDArrayI64:
+ """Calculates the confusion matrix for multi class predictions.
+
+ Args:
+ prediction (array): Class predictions
+ groundtruth (array): Groundtruth classes
+
+ Returns:
+ Confusion Matrix of dimension n_classes x n_classes.
+ """
+ y_true = groundtruth.reshape(-1)
+ if prediction.shape != groundtruth.shape:
+ y_pred = np.argmax(prediction, axis=1).reshape(-1)
+ else:
+ y_pred = prediction.reshape(-1)
+ y_pred = y_pred.astype(np.int64)
+
+ if self.class_to_ignore is not None:
+ valid = y_true != self.class_to_ignore
+ y_true = y_true[valid]
+ y_pred = y_pred[valid]
+ if self.num_classes is None:
+ n_classes = np.max(np.max(groundtruth), np.max(y_pred)) + 1
+ else:
+ n_classes = self.num_classes
+
+ y = n_classes * y_true + y_pred
+ y = np.bincount(y, minlength=n_classes * n_classes)
+ return y.reshape(n_classes, n_classes)
+
+ def reset(self) -> None:
+ """Reset the saved predictions to start new round of evaluation."""
+ self._confusion_matrix = None
+
+ def process_batch(
+ self, prediction: ArrayLike, groundtruth: ArrayLike
+ ) -> None:
+ """Process sample and update confusion matrix.
+
+ Args:
+ prediction: Predictions of shape [N,C,...] or [N,...] with
+ C* being any number if channels. Note, C is passed,
+ the prediction is converted to target labels by applying
+ the max operations along the second axis
+ groundtruth: Groundtruth of shape [N_batch, ...] type int
+ """
+ confusion_matrix = self.calc_confusion_matrix(
+ array_to_numpy(prediction, n_dims=None, dtype=np.float32),
+ array_to_numpy(groundtruth, n_dims=None, dtype=np.int64),
+ )
+
+ if self._confusion_matrix is None:
+ self._confusion_matrix = confusion_matrix
+ else:
+ assert (
+ self._confusion_matrix.shape == confusion_matrix.shape
+ ), """Shape of confusion matrix changed during runtime!,
+ Please specify a static number of classes in constructor."""
+ self._confusion_matrix += confusion_matrix
+
+ def _get_class_name_for_idx(self, idx: int) -> str:
+ """Maps a class index to a unique class name.
+
+ Args:
+ idx (int): class index.
+
+ Returns:
+ (str) class name
+ """
+ return self.class_mapping.get(idx, f"class_{idx}")
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """Evaluate predictions.
+
+ Returns a dict containing the raw data and a
+ short description string containing a readable result.
+
+ Args:
+ metric (str): Metric to use. See @property metric.
+
+ Returns:
+ (dict, str) containing the raw data and a short description string.
+
+ Raises:
+ ValueError: If metric is not supported.
+ """
+ assert (
+ self._confusion_matrix is not None
+ ), """Evaluate() needs to process samples first.
+ Please call the process() function before calling evaluate()"""
+
+ metric_data, short_description = {}, ""
+ if metric == self.METRIC_MIOU:
+ # Calculate miou from confusion matrix
+ tp = np.diag(self._confusion_matrix)
+ fp = np.sum(self._confusion_matrix, axis=0) - tp
+ fn = np.sum(self._confusion_matrix, axis=1) - tp
+ iou = tp / (tp + fn + fp) * 100
+ m_iou = np.nanmean(iou)
+
+ iou_class_str = ", ".join(
+ f"{self._get_class_name_for_idx(idx)}: ({d:.3f}%)"
+ for idx, d in enumerate(iou)
+ )
+ metric_data[self.METRIC_MIOU] = m_iou
+ short_description += f"mIoU: {m_iou:.3f}% \n"
+ short_description += iou_class_str + "\n"
+
+ elif metric == self.METRIC_CONFUSION_MATRIX:
+ headers = ["Confusion"] + [
+ self._get_class_name_for_idx(i)
+ for i in range(self._confusion_matrix.shape[0])
+ ]
+ table_data = self._confusion_matrix / (
+ np.sum(self._confusion_matrix, axis=1)
+ )
+ data = list(
+ [f"Class_{idx}"] + list(d) for idx, d in enumerate(table_data)
+ )
+ table = AsciiTable([headers] + data)
+ # TODO, change MetricLogs type for more complex log types as e.g.
+ # confusion matrix
+ short_description += table.table + "\n"
+
+ else:
+ raise ValueError(f"Metric {metric} not supported")
+ return metric_data, short_description
diff --git a/vis4d/eval/kitti/__init__.py b/vis4d/eval/kitti/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eae643cb0372f2507a574267967dd5e7ba5814aa
--- /dev/null
+++ b/vis4d/eval/kitti/__init__.py
@@ -0,0 +1,5 @@
+"""KITTI evaluator."""
+
+from .depth import KITTIDepthEvaluator
+
+__all__ = ["KITTIDepthEvaluator"]
diff --git a/vis4d/eval/kitti/depth.py b/vis4d/eval/kitti/depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..df21b448028df460ef2cd3080724967687de899b
--- /dev/null
+++ b/vis4d/eval/kitti/depth.py
@@ -0,0 +1,87 @@
+"""KITTI evaluation code."""
+
+from __future__ import annotations
+
+import numpy as np
+
+from vis4d.common.typing import NDArrayFloat, NDArrayNumber
+
+from ..common import DepthEvaluator
+
+
+def apply_garg_crop(mask: NDArrayNumber) -> NDArrayNumber:
+ """Apply Garg ECCV16 crop to the mask.
+
+ Args:
+ mask (np.array): Mask to be cropped, in shape (..., H, W).
+
+ Returns:
+ np.array: Cropped mask, in shape (..., H', W').
+ """
+ # crop used by Garg ECCV16
+ h, w = mask.shape[-2:]
+ crop = np.array(
+ [0.40810811 * h, 0.99189189 * h, 0.03594771 * w, 0.96405229 * w]
+ ).astype(np.int32)
+ mask[..., crop[0] : crop[1], crop[2] : crop[3]] = 1
+ return mask
+
+
+def apply_eigen_crop(mask: NDArrayNumber) -> NDArrayNumber:
+ """Apply Eigen NIPS14 crop to the mask.
+
+ Args:
+ mask (np.array): Mask to be cropped, in shape (N, H, W).
+
+ Returns:
+ np.array: Cropped mask, in shape (N, H', W').
+ """
+ # https://github.com/mrharicot/monodepth/utils/evaluate_kitti.py
+ h, w = mask.shape[-2:]
+ crop = np.array(
+ [0.3324324 * h, 0.91351351 * h, 0.0359477 * w, 0.96405229 * w]
+ ).astype(np.int32)
+ mask[..., crop[0] : crop[1], crop[2] : crop[3]] = 1
+ return mask
+
+
+class KITTIDepthEvaluator(DepthEvaluator):
+ """KITTI depth evaluation class."""
+
+ METRIC_DEPTH = "depth"
+
+ def __init__(
+ self,
+ min_depth: float = 0.01,
+ max_depth: float = 80.0,
+ eval_crop: str | None = None,
+ ) -> None:
+ """Initialize KITTI depth evaluator."""
+ super().__init__(min_depth, max_depth)
+ self.eval_crop = eval_crop
+ self.reset()
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset evaluator."""
+ return "KITTI evaluation for depth"
+
+ def _get_eval_mask(self, valid_mask: NDArrayNumber) -> NDArrayNumber:
+ """Do Grag or Eigen cropping for testing."""
+ eval_mask = np.zeros_like(valid_mask)
+ if self.eval_crop == "garg_crop":
+ eval_mask = apply_garg_crop(eval_mask)
+ elif self.eval_crop == "eigen_crop":
+ eval_mask = apply_eigen_crop(eval_mask)
+ else:
+ eval_mask = np.ones_like(valid_mask)
+ return np.logical_and(valid_mask, eval_mask)
+
+ def _apply_mask(
+ self, prediction: NDArrayFloat, target: NDArrayFloat
+ ) -> tuple[NDArrayFloat, NDArrayFloat]:
+ """Apply mask to prediction and target."""
+ valid_mask = (target > self.min_depth) & (target < self.max_depth)
+ eval_mask = self._get_eval_mask(valid_mask)
+ prediction = prediction[eval_mask]
+ target = target[eval_mask]
+ return prediction, target
diff --git a/vis4d/eval/metrics/__init__.py b/vis4d/eval/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2a4dbf7b6c58249fdca883e2ed3a4f8d9268c8b
--- /dev/null
+++ b/vis4d/eval/metrics/__init__.py
@@ -0,0 +1 @@
+"""Eval metrics."""
diff --git a/vis4d/eval/metrics/cls.py b/vis4d/eval/metrics/cls.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a863fb7853644166c45202b45e87c621182c791
--- /dev/null
+++ b/vis4d/eval/metrics/cls.py
@@ -0,0 +1,31 @@
+"""Classification metrics."""
+
+from __future__ import annotations
+
+import numpy as np
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.typing import ArrayLike, ArrayLikeInt
+
+
+def accuracy(
+ prediction: ArrayLike, target: ArrayLikeInt, top_k: int = 1
+) -> float:
+ """Calculate the accuracy of the prediction.
+
+ Args:
+ prediction (ArrayLike): Probabilities (or logits) of shape (N, C) or
+ (C, ).
+ target (ArrayLikeInt): Target of shape (N, ) or (1, ).
+ top_k (int, optional): Top k accuracy. Defaults to 1.
+
+ Returns:
+ float: Accuracy of the prediction, in range [0, 1].
+ """
+ prediction = array_to_numpy(prediction, n_dims=2, dtype=np.float32)
+ target = array_to_numpy(target, n_dims=1, dtype=np.int64)
+ assert prediction.shape[0] == target.shape[0], "Batch size mismatch."
+ top_k = min(top_k, prediction.shape[1])
+ top_k_idx = np.argsort(prediction, axis=1)[:, -top_k:]
+ correct = np.any(top_k_idx == target[:, None], axis=1)
+ return float(np.mean(correct))
diff --git a/vis4d/eval/metrics/depth.py b/vis4d/eval/metrics/depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d783b6c0ff1f55305edcde697b0846264165362
--- /dev/null
+++ b/vis4d/eval/metrics/depth.py
@@ -0,0 +1,146 @@
+"""Depth estimation metrics."""
+
+from __future__ import annotations
+
+import numpy as np
+
+from vis4d.common.typing import ArrayLike
+
+from ..utils import check_shape_match, dense_inputs_to_numpy
+
+
+def absolute_error(prediction: ArrayLike, target: ArrayLike) -> float:
+ """Compute the absolute error.
+
+ Args:
+ prediction (NDArrayNumber): Prediction depth map, in shape (..., H, W).
+ target (NDArrayNumber): Target depth map, in shape (..., H, W).
+
+ Returns:
+ float: Absolute error.
+ """
+ prediction, target = dense_inputs_to_numpy(prediction, target)
+ check_shape_match(prediction, target)
+ return np.mean(np.abs(prediction - target)).item()
+
+
+def squared_relative_error(prediction: ArrayLike, target: ArrayLike) -> float:
+ """Compute the squared relative error.
+
+ Args:
+ prediction (NDArrayNumber): Prediction depth map, in shape (..., H, W).
+ target (NDArrayNumber): Target depth map, in shape (..., H, W).
+
+ Returns:
+ float: Square relative error.
+ """
+ prediction, target = dense_inputs_to_numpy(prediction, target)
+ check_shape_match(prediction, target)
+ return np.mean(np.square(prediction - target) / target).item()
+
+
+def absolute_relative_error(prediction: ArrayLike, target: ArrayLike) -> float:
+ """Compute the absolute relative error.
+
+ Args:
+ prediction (NDArrayNumber): Prediction depth map, in shape (..., H, W).
+ target (NDArrayNumber): Target depth map, in shape (..., H, W).
+
+ Returns:
+ float: Absolute relative error.
+ """
+ prediction, target = dense_inputs_to_numpy(prediction, target)
+ check_shape_match(prediction, target)
+ return np.mean(np.abs(prediction - target) / target).item()
+
+
+def root_mean_squared_error(prediction: ArrayLike, target: ArrayLike) -> float:
+ """Compute the root mean squared error.
+
+ Args:
+ prediction (ArrayLike): Prediction depth map, in shape (..., H, W).
+ target (ArrayLike): Target depth map, in shape (..., H, W).
+
+ Returns:
+ float: Root mean squared error.
+ """
+ prediction, target = dense_inputs_to_numpy(prediction, target)
+ check_shape_match(prediction, target)
+ squared_diff = np.square(prediction - target)
+ return np.sqrt(np.mean(squared_diff)).item()
+
+
+def root_mean_squared_error_log(
+ prediction: ArrayLike, target: ArrayLike, epsilon: float = 1e-8
+) -> float:
+ """Compute the root mean squared error in log space.
+
+ Args:
+ prediction (ArrayLike): Prediction depth map, in shape (H, W).
+ target (ArrayLike): Target depth map, in shape (H, W).
+ epsilon (float, optional): Epsilon to avoid log(0). Defaults to 1e-6.
+
+ Returns:
+ float: Root mean squared error in log space.
+ """
+ prediction, target = dense_inputs_to_numpy(prediction, target)
+ check_shape_match(prediction, target)
+ log_pred = np.log(prediction + epsilon)
+ log_target = np.log(target + epsilon)
+ squared_diff = np.square(log_pred - log_target)
+ return np.sqrt(np.mean(squared_diff)).item()
+
+
+def scale_invariant_log(
+ prediction: ArrayLike, target: ArrayLike, epsilon: float = 1e-8
+) -> float:
+ """Compute the scale invariant log error.
+
+ Args:
+ prediction (ArrayLike): Prediction depth map, in shape (H, W).
+ target (ArrayLike): Target depth map, in shape (H, W).
+ epsilon (float, optional): Epsilon to avoid log(0). Defaults to 1e-6.
+
+ Returns:
+ float: Scale invariant log error.
+ """
+ prediction, target = dense_inputs_to_numpy(prediction, target)
+ check_shape_match(prediction, target)
+ log_diff = np.log(prediction + epsilon) - np.log(target + epsilon)
+ return 100.0 * float(np.sqrt(np.var(log_diff)).mean())
+
+
+def delta_p(
+ prediction: ArrayLike, target: ArrayLike, power: float = 1
+) -> float:
+ """Compute the delta_p metric.
+
+ Args:
+ prediction (ArrayLike): Prediction depth map, in shape (H, W).
+ target (ArrayLike): Target depth map, in shape (H, W).
+ power (float, optional): Power of the threshold. Defaults to 1.
+
+ Returns:
+ float: Delta_p metric.
+ """
+ prediction, target = dense_inputs_to_numpy(prediction, target)
+ check_shape_match(prediction, target)
+ return np.mean(
+ np.maximum((target / prediction), (prediction / target)) < 1.25**power
+ ).item()
+
+
+def log_10_error(prediction: ArrayLike, target: ArrayLike) -> float:
+ """Compute the log_10 error.
+
+ Args:
+ prediction (ArrayLike): Prediction depth map, in shape (H, W).
+ target (ArrayLike): Target depth map, in shape (H, W).
+
+ Returns:
+ float: Log_10 error.
+ """
+ prediction, target = dense_inputs_to_numpy(prediction, target)
+ check_shape_match(prediction, target)
+ log10_diff = np.log10(prediction) - np.log10(target)
+ return np.mean(np.abs(log10_diff)).item()
diff --git a/vis4d/eval/metrics/flow.py b/vis4d/eval/metrics/flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c9e7d13c6d7b4735c4ccd037b78f505d27dc69a
--- /dev/null
+++ b/vis4d/eval/metrics/flow.py
@@ -0,0 +1,47 @@
+"""Depth estimation metrics."""
+
+from __future__ import annotations
+
+import numpy as np
+
+from vis4d.common.typing import ArrayLike
+
+from ..utils import check_shape_match, dense_inputs_to_numpy
+
+
+def end_point_error(prediction: ArrayLike, target: ArrayLike) -> float:
+ """Compute the end point error.
+
+ Args:
+ prediction (ArrayLike): Prediction UV optical flow, in shape (..., 2).
+ target (ArrayLike): Target UV optical flow, in shape (..., 2).
+
+ Returns:
+ float: End point error.
+ """
+ prediction, target = dense_inputs_to_numpy(prediction, target)
+ check_shape_match(prediction, target)
+ squared_sum = np.sum((prediction - target) ** 2, axis=-1)
+ return np.mean(np.sqrt(squared_sum)).item()
+
+
+def angular_error(
+ prediction: ArrayLike, target: ArrayLike, epsilon: float = 1e-6
+) -> float:
+ """Compute the angular error.
+
+ Args:
+ prediction (ArrayLike): Prediction UV optical flow, in shape (..., 2).
+ target (ArrayLike): Target UV optical flow, in shape (..., 2).
+ epsilon (float, optional): Epsilon value for numerical stability.
+
+ Returns:
+ float: Angular error.
+ """
+ prediction, target = dense_inputs_to_numpy(prediction, target)
+ check_shape_match(prediction, target)
+ product = np.sum(prediction * target, axis=-1)
+ pred_norm = np.linalg.norm(prediction, axis=-1)
+ target_norm = np.linalg.norm(target, axis=-1)
+ cos_angle = np.abs(product) / (pred_norm * target_norm + epsilon)
+ return np.mean(np.arccos(np.clip(cos_angle, 0.0, 1.0))).item()
diff --git a/vis4d/eval/nuscenes/__init__.py b/vis4d/eval/nuscenes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fbf75265b626beadbe81b68b09dec0162b8f623
--- /dev/null
+++ b/vis4d/eval/nuscenes/__init__.py
@@ -0,0 +1,6 @@
+"""NuScenes evaluator."""
+
+from .detect3d import NuScenesDet3DEvaluator
+from .track3d import NuScenesTrack3DEvaluator
+
+__all__ = ["NuScenesDet3DEvaluator", "NuScenesTrack3DEvaluator"]
diff --git a/vis4d/eval/nuscenes/detect3d.py b/vis4d/eval/nuscenes/detect3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2817120eca10fea05cbfbb10225843e74cc44e7
--- /dev/null
+++ b/vis4d/eval/nuscenes/detect3d.py
@@ -0,0 +1,338 @@
+"""NuScenes 3D detection evaluation code."""
+
+from __future__ import annotations
+
+import json
+import os
+from collections.abc import Callable
+from typing import Any
+
+import numpy as np
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.imports import NUSCENES_AVAILABLE
+from vis4d.common.logging import rank_zero_warn
+from vis4d.common.typing import ArrayLike, DictStrAny, MetricLogs
+from vis4d.data.datasets.nuscenes import (
+ nuscenes_attribute_map,
+ nuscenes_class_map,
+)
+
+from ..base import Evaluator
+
+if NUSCENES_AVAILABLE:
+ from nuscenes import NuScenes as NuScenesDevkit
+ from nuscenes.eval.detection.config import config_factory
+ from nuscenes.eval.detection.evaluate import NuScenesEval
+ from nuscenes.utils.data_classes import Quaternion
+else:
+ raise ImportError("nuscenes-devkit is not installed.")
+
+
+def _parse_high_level_metrics(
+ mean_ap: float,
+ tp_errors: dict[str, float],
+ nd_score: float,
+ eval_time: float,
+) -> tuple[MetricLogs, list[str]]:
+ """Collect high-level metrics."""
+ log_dict: MetricLogs = {
+ "mAP": mean_ap,
+ "mATE": tp_errors["trans_err"],
+ "mASE": tp_errors["scale_err"],
+ "mAOE": tp_errors["orient_err"],
+ "mAVE": tp_errors["vel_err"],
+ "mAAE": tp_errors["attr_err"],
+ "NDS": nd_score,
+ }
+
+ str_summary_list = ["\nHigh-level metrics:"]
+ for k, v in log_dict.items():
+ str_summary_list.append(f"{k}: {v:.4f}")
+
+ str_summary_list.append(f"Eval time: {eval_time:.1f}s")
+
+ return log_dict, str_summary_list
+
+
+def _parse_per_class_metrics(
+ str_summary_list: list[str],
+ class_aps: dict[str, float],
+ class_tps: dict[str, dict[str, float]],
+) -> list[str]:
+ """Collect per-class metrics."""
+ str_summary_list.append("\nPer-class results:")
+ str_summary_list.append("Object Class\tAP\tATE\tASE\tAOE\tAVE\tAAE")
+
+ for class_name in class_aps.keys():
+ tmp_str_list = [class_name]
+ tmp_str_list.append(f"{class_aps[class_name]:.3f}")
+ tmp_str_list.append(f"{class_tps[class_name]['trans_err']:.3f}")
+ tmp_str_list.append(f"{class_tps[class_name]['scale_err']:.3f}")
+ tmp_str_list.append(f"{class_tps[class_name]['orient_err']:.3f}")
+ tmp_str_list.append(f"{class_tps[class_name]['vel_err']:.3f}")
+ tmp_str_list.append(f"{class_tps[class_name]['attr_err']:.3f}")
+
+ str_summary_list.append("\t".join(tmp_str_list))
+ return str_summary_list
+
+
+class NuScenesDet3DEvaluator(Evaluator):
+ """NuScenes 3D detection evaluation class."""
+
+ inv_nuscenes_attribute_map = {
+ v: k for k, v in nuscenes_attribute_map.items()
+ }
+
+ DefaultAttribute = {
+ "car": "vehicle.parked",
+ "pedestrian": "pedestrian.moving",
+ "trailer": "vehicle.parked",
+ "truck": "vehicle.parked",
+ "bus": "vehicle.moving",
+ "motorcycle": "cycle.without_rider",
+ "construction_vehicle": "vehicle.parked",
+ "bicycle": "cycle.without_rider",
+ "barrier": "",
+ "traffic_cone": "",
+ }
+
+ def __init__(
+ self,
+ data_root: str,
+ version: str,
+ split: str,
+ save_only: bool = False,
+ class_map: dict[str, int] | None = None,
+ metadata: tuple[str, ...] = ("use_camera",),
+ use_default_attr: bool = False,
+ velocity_thres: float = 1.0,
+ ) -> None:
+ """Initialize NuScenes evaluator."""
+ super().__init__()
+ self.data_root = data_root
+ self.version = version
+ self.split = split
+ self.save_only = save_only
+ self.use_default_attr = use_default_attr
+ self.velocity_thres = velocity_thres
+
+ self.meta_data = {
+ "use_camera": False,
+ "use_lidar": False,
+ "use_radar": False,
+ "use_map": False,
+ "use_external": False,
+ }
+
+ for m in metadata:
+ self.meta_data[m] = True
+
+ class_map = class_map or nuscenes_class_map
+ self.inv_nuscenes_class_map = {v: k for k, v in class_map.items()}
+
+ self.output_dir = ""
+ self.detect_3d: DictStrAny = {}
+ self.reset()
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset evaluator."""
+ return "NuScenes 3D Detection Evaluator"
+
+ @property
+ def metrics(self) -> list[str]:
+ """Supported metrics."""
+ return ["detect_3d"]
+
+ def gather( # type: ignore
+ self, gather_func: Callable[[Any], Any]
+ ) -> None:
+ """Gather variables in case of distributed setting (if needed).
+
+ Args:
+ gather_func (Callable[[Any], Any]): Gather function.
+ """
+ detect_3d_list = gather_func(self.detect_3d)
+ if detect_3d_list is not None:
+ collated_detect_3d: DictStrAny = {}
+ for prediction in detect_3d_list:
+ for k, v in prediction.items():
+ if k not in collated_detect_3d:
+ collated_detect_3d[k] = v
+ else:
+ collated_detect_3d[k].extend(v)
+ self.detect_3d = collated_detect_3d
+
+ def reset(self) -> None:
+ """Reset evaluator."""
+ self.detect_3d.clear()
+
+ def get_attributes(self, name: str, velocity: list[float]) -> str:
+ """Get nuScenes attributes."""
+ if self.use_default_attr:
+ return self.DefaultAttribute[name]
+
+ if np.sqrt(velocity[0] ** 2 + velocity[1] ** 2) > self.velocity_thres:
+ if name in {
+ "car",
+ "construction_vehicle",
+ "bus",
+ "truck",
+ "trailer",
+ }:
+ attr = "vehicle.moving"
+ elif name in {"bicycle", "motorcycle"}:
+ attr = "cycle.with_rider"
+ else:
+ attr = self.DefaultAttribute[name]
+ elif name in {"pedestrian"}:
+ attr = "pedestrian.standing"
+ elif name in {"bus"}:
+ attr = "vehicle.stopped"
+ else:
+ attr = self.DefaultAttribute[name]
+ return attr
+
+ def _process_detect_3d(
+ self,
+ token: str,
+ boxes_3d: ArrayLike,
+ velocities: ArrayLike,
+ scores_3d: ArrayLike,
+ class_ids: ArrayLike,
+ attributes: ArrayLike | None = None,
+ ) -> None:
+ """Process 3D detection results."""
+ annos = []
+ boxes_3d_np = array_to_numpy(boxes_3d, n_dims=None, dtype=np.float32)
+ velocities_np = array_to_numpy(
+ velocities, n_dims=None, dtype=np.float32
+ )
+ scores_3d_np = array_to_numpy(scores_3d, n_dims=None, dtype=np.float32)
+ class_ids_np = array_to_numpy(class_ids, n_dims=None, dtype=np.int64)
+
+ if len(boxes_3d_np) != 0:
+ for i, (box_3d, velocity, score_3d, class_id) in enumerate(
+ zip(
+ boxes_3d_np,
+ velocities_np,
+ scores_3d_np,
+ class_ids_np,
+ )
+ ):
+ category = self.inv_nuscenes_class_map[int(class_id)]
+
+ translation = box_3d[0:3]
+
+ dims = box_3d[3:6].tolist()
+ dimension = [d if d >= 0 else 0.1 for d in dims]
+
+ rotation = Quaternion(box_3d[6:].tolist())
+
+ score = float(score_3d)
+
+ velocity_list = velocity.tolist()
+
+ if attributes is None:
+ attribute_name = self.get_attributes(
+ category, velocity_list
+ )
+ else:
+ attribute = array_to_numpy(
+ attributes[i], n_dims=None, dtype=np.int64 # type: ignore # pylint: disable=line-too-long
+ )
+ attribute_name = self.inv_nuscenes_attribute_map[
+ int(attribute)
+ ]
+
+ nusc_anno = {
+ "sample_token": token,
+ "translation": translation.tolist(),
+ "size": dimension,
+ "rotation": rotation.elements.tolist(),
+ "velocity": [velocity_list[0], velocity_list[1]],
+ "detection_name": category,
+ "detection_score": score,
+ "attribute_name": attribute_name,
+ }
+ annos.append(nusc_anno)
+ self.detect_3d[token] = annos
+
+ def process_batch(
+ self,
+ tokens: list[str],
+ boxes_3d: list[ArrayLike],
+ velocities: list[ArrayLike],
+ class_ids: list[ArrayLike],
+ scores_3d: list[ArrayLike],
+ attributes: list[ArrayLike] | None = None,
+ ) -> None:
+ """Process the results."""
+ for i, token in enumerate(tokens):
+ self._process_detect_3d(
+ token,
+ boxes_3d[i],
+ velocities[i],
+ scores_3d[i],
+ class_ids[i],
+ attributes[i] if attributes is not None else None,
+ )
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """Evaluate the results."""
+ assert metric == "detect_3d"
+ if self.save_only:
+ return {}, "Results are saved to the json file."
+
+ try:
+ nusc = NuScenesDevkit(
+ version=self.version,
+ dataroot=self.data_root,
+ verbose=False,
+ )
+
+ nusc_eval = NuScenesEval(
+ nusc,
+ config=config_factory("detection_cvpr_2019"),
+ result_path=f"{self.output_dir}/detect_3d_predictions.json",
+ eval_set=self.split,
+ output_dir=os.path.join(self.output_dir, "detection"),
+ )
+ metrics, _ = nusc_eval.evaluate()
+ metrics_summary = metrics.serialize()
+
+ log_dict, str_summary_list = _parse_high_level_metrics(
+ metrics_summary["mean_ap"],
+ metrics_summary["tp_errors"],
+ metrics_summary["nd_score"],
+ metrics_summary["eval_time"],
+ )
+
+ class_aps = metrics_summary["mean_dist_aps"]
+ class_tps = metrics_summary["label_tp_errors"]
+ str_summary_list = _parse_per_class_metrics(
+ str_summary_list, class_aps, class_tps
+ )
+
+ str_summary = "\n".join(str_summary_list)
+ except Exception as e: # pylint: disable=broad-except
+ error_msg = "".join(e.args)
+ rank_zero_warn(f"Evaluation error: {error_msg}")
+ log_dict = {}
+ str_summary = (
+ "Evaluation failure might be raised due to sanity check"
+ + "or all emtpy boxes."
+ )
+ rank_zero_warn(str_summary)
+ return log_dict, str_summary
+
+ def save(self, metric: str, output_dir: str) -> None:
+ """Save the results to json files."""
+ assert metric == "detect_3d"
+ nusc_annos = {"results": self.detect_3d, "meta": self.meta_data}
+ result_file = f"{output_dir}/detect_3d_predictions.json"
+
+ with open(result_file, mode="w", encoding="utf-8") as f:
+ json.dump(nusc_annos, f)
+
+ self.output_dir = output_dir
diff --git a/vis4d/eval/nuscenes/track3d.py b/vis4d/eval/nuscenes/track3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..77d7f89900e9d8f9995cd3503100a0f2fb3275c9
--- /dev/null
+++ b/vis4d/eval/nuscenes/track3d.py
@@ -0,0 +1,167 @@
+"""NuScenes 3D tracking evaluation code."""
+
+from __future__ import annotations
+
+import json
+from collections.abc import Callable
+from typing import Any
+
+import numpy as np
+from nuscenes.utils.data_classes import Quaternion
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.typing import ArrayLike, DictStrAny, MetricLogs
+from vis4d.data.datasets.nuscenes import nuscenes_class_map
+
+from ..base import Evaluator
+
+
+class NuScenesTrack3DEvaluator(Evaluator):
+ """NuScenes 3D tracking evaluation class."""
+
+ inv_nuscenes_class_map = {v: k for k, v in nuscenes_class_map.items()}
+
+ tracking_cats = [
+ "bicycle",
+ "motorcycle",
+ "pedestrian",
+ "bus",
+ "car",
+ "trailer",
+ "truck",
+ ]
+
+ def __init__(self, metadata: tuple[str, ...] = ("use_camera",)) -> None:
+ """Initialize NuScenes evaluator."""
+ super().__init__()
+ self.meta_data = {
+ "use_camera": False,
+ "use_lidar": False,
+ "use_radar": False,
+ "use_map": False,
+ "use_external": False,
+ }
+
+ for m in metadata:
+ self.meta_data[m] = True
+
+ self.tracks_3d: DictStrAny = {}
+ self.reset()
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset evaluator."""
+ return "NuScenes 3D Tracking Evaluator"
+
+ @property
+ def metrics(self) -> list[str]:
+ """Supported metrics."""
+ return ["track_3d"]
+
+ def gather( # type: ignore
+ self, gather_func: Callable[[Any], Any]
+ ) -> None:
+ """Gather variables in case of distributed setting (if needed).
+
+ Args:
+ gather_func (Callable[[Any], Any]): Gather function.
+ """
+ tracks_3d_list = gather_func(self.tracks_3d)
+ if tracks_3d_list is not None:
+ collated_track_3d: DictStrAny = {}
+ for prediction in tracks_3d_list:
+ for k, v in prediction.items():
+ if k not in collated_track_3d:
+ collated_track_3d[k] = v
+ else:
+ collated_track_3d[k].extend(v)
+ self.tracks_3d = collated_track_3d
+
+ def reset(self) -> None:
+ """Reset evaluator."""
+ self.tracks_3d.clear()
+
+ def _process_track_3d(
+ self,
+ token: str,
+ boxes_3d: ArrayLike,
+ velocities: ArrayLike,
+ scores_3d: ArrayLike,
+ class_ids: ArrayLike,
+ track_ids: ArrayLike,
+ ) -> None:
+ """Process 3D tracking results."""
+ annos = []
+ boxes_3d_np = array_to_numpy(boxes_3d, n_dims=None, dtype=np.float32)
+ velocities_np = array_to_numpy(
+ velocities, n_dims=None, dtype=np.float32
+ )
+ scores_3d_np = array_to_numpy(scores_3d, n_dims=None, dtype=np.float32)
+ class_ids_np = array_to_numpy(class_ids, n_dims=None, dtype=np.int64)
+ track_ids_np = array_to_numpy(track_ids, n_dims=None, dtype=np.int64)
+
+ if len(boxes_3d_np) != 0:
+ for box_3d, velocity, score_3d, class_id, track_id in zip(
+ boxes_3d_np,
+ velocities_np,
+ scores_3d_np,
+ class_ids_np,
+ track_ids_np,
+ ):
+ category = self.inv_nuscenes_class_map[int(class_id)]
+ if not category in self.tracking_cats:
+ continue
+
+ translation = box_3d[0:3]
+
+ dimension = box_3d[3:6]
+
+ rotation = Quaternion(box_3d[6:].tolist())
+
+ score = float(score_3d)
+
+ velocity_list = velocity.tolist()
+
+ nusc_anno = {
+ "sample_token": token,
+ "translation": translation.tolist(),
+ "size": dimension.tolist(),
+ "rotation": rotation.elements.tolist(),
+ "velocity": [velocity_list[0], velocity_list[1]],
+ "tracking_id": int(track_id),
+ "tracking_name": category,
+ "tracking_score": score,
+ }
+ annos.append(nusc_anno)
+ self.tracks_3d[token] = annos
+
+ def process_batch(
+ self,
+ tokens: list[str],
+ boxes_3d: list[ArrayLike],
+ velocities: list[ArrayLike],
+ class_ids: list[ArrayLike],
+ scores_3d: list[ArrayLike],
+ track_ids: list[ArrayLike],
+ ) -> None:
+ """Process the results."""
+ for i, token in enumerate(tokens):
+ self._process_track_3d(
+ token,
+ boxes_3d[i],
+ velocities[i],
+ scores_3d[i],
+ class_ids[i],
+ track_ids[i],
+ )
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """Evaluate the results."""
+ return {}, "Currently only save the json file."
+
+ def save(self, metric: str, output_dir: str) -> None:
+ """Save the results to json files."""
+ nusc_annos = {"results": self.tracks_3d, "meta": self.meta_data}
+ result_file = f"{output_dir}/track_3d_predictions.json"
+
+ with open(result_file, mode="w", encoding="utf-8") as f:
+ json.dump(nusc_annos, f)
diff --git a/vis4d/eval/scalabel/__init__.py b/vis4d/eval/scalabel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d006bf0d3e6760054f6e766ee82419e0a2bf5a6
--- /dev/null
+++ b/vis4d/eval/scalabel/__init__.py
@@ -0,0 +1,11 @@
+"""Scalabel evaluator."""
+
+from .base import ScalabelEvaluator
+from .detect import ScalabelDetectEvaluator
+from .track import ScalabelTrackEvaluator
+
+__all__ = [
+ "ScalabelEvaluator",
+ "ScalabelDetectEvaluator",
+ "ScalabelTrackEvaluator",
+]
diff --git a/vis4d/eval/scalabel/base.py b/vis4d/eval/scalabel/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..d96f60990d97a47992424e93c59e1af1f41223be
--- /dev/null
+++ b/vis4d/eval/scalabel/base.py
@@ -0,0 +1,65 @@
+"""Scalabel base evaluator."""
+
+from __future__ import annotations
+
+import itertools
+from collections.abc import Callable
+from typing import Any
+
+from vis4d.common.imports import SCALABEL_AVAILABLE
+from vis4d.common.typing import MetricLogs
+from vis4d.eval.base import Evaluator
+
+if SCALABEL_AVAILABLE:
+ from scalabel.label.io import load
+ from scalabel.label.typing import Config, Frame
+ from scalabel.label.utils import get_leaf_categories
+else:
+ raise ImportError("scalabel is not installed.")
+
+
+class ScalabelEvaluator(Evaluator):
+ """Scalabel base evaluation class."""
+
+ def __init__(
+ self, annotation_path: str, config: Config | None = None
+ ) -> None:
+ """Initialize the evaluator."""
+ super().__init__()
+ self.annotation_path = annotation_path
+ self.frames: list[Frame] = []
+
+ dataset = load(self.annotation_path, validate_frames=False)
+ self.gt_frames = dataset.frames
+ if config is not None:
+ self.config: Config | None = config
+ else:
+ self.config = dataset.config
+ if self.config is not None and self.config.categories is not None:
+ categories = get_leaf_categories(self.config.categories)
+ self.inverse_cat_map = {
+ cat_id: cat.name for cat_id, cat in enumerate(categories)
+ }
+ else:
+ self.inverse_cat_map = {}
+ self.reset()
+
+ def gather( # type: ignore # pragma: no cover
+ self, gather_func: Callable[[Any], Any]
+ ) -> None:
+ """Gather variables in case of distributed setting (if needed).
+
+ Args:
+ gather_func (Callable[[Any], Any]): Gather function.
+ """
+ all_preds = gather_func(self.frames)
+ if all_preds is not None:
+ self.frames = list(itertools.chain(*all_preds))
+
+ def reset(self) -> None:
+ """Reset the evaluator."""
+ self.frames = []
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """Evaluate the dataset."""
+ raise NotImplementedError
diff --git a/vis4d/eval/scalabel/detect.py b/vis4d/eval/scalabel/detect.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e0a80444a9cfbb24476f7e7183cdc28bf69a0ef
--- /dev/null
+++ b/vis4d/eval/scalabel/detect.py
@@ -0,0 +1,139 @@
+"""Scalabel detection evaluator."""
+
+from __future__ import annotations
+
+import numpy as np
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.imports import SCALABEL_AVAILABLE
+from vis4d.common.typing import ArrayLike, MetricLogs
+
+from .base import ScalabelEvaluator
+
+if SCALABEL_AVAILABLE:
+ from scalabel.eval.detect import evaluate_det
+ from scalabel.eval.ins_seg import evaluate_ins_seg
+ from scalabel.label.transforms import mask_to_rle, xyxy_to_box2d
+ from scalabel.label.typing import Config, Frame, Label
+else:
+ raise ImportError("scalabel is not installed.")
+
+
+class ScalabelDetectEvaluator(ScalabelEvaluator):
+ """Scalabel 2D detection evaluation class."""
+
+ METRICS_DET = "Det"
+ METRICS_INS_SEG = "InsSeg"
+
+ def __init__(
+ self,
+ annotation_path: str,
+ config: Config | None = None,
+ mask_threshold: float = 0.0,
+ ) -> None:
+ """Initialize the evaluator."""
+ super().__init__(annotation_path=annotation_path, config=config)
+ self.mask_threshold = mask_threshold
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset evaluator."""
+ return "Scalabel Detection Evaluator"
+
+ @property
+ def metrics(self) -> list[str]:
+ """Supported metrics."""
+ return [self.METRICS_DET, self.METRICS_INS_SEG]
+
+ def process_batch(
+ self,
+ frame_ids: list[int],
+ sample_names: list[str],
+ sequence_names: list[str],
+ pred_boxes: list[ArrayLike],
+ pred_classes: list[ArrayLike],
+ pred_scores: list[ArrayLike],
+ pred_masks: list[ArrayLike] | None = None,
+ ) -> None:
+ """Process tracking results."""
+ for i, (
+ frame_id,
+ sample_name,
+ sequence_name,
+ boxes,
+ class_ids,
+ scores,
+ ) in enumerate(
+ zip(
+ frame_ids,
+ sample_names,
+ sequence_names,
+ pred_boxes,
+ pred_classes,
+ pred_scores,
+ )
+ ):
+ boxes = array_to_numpy(boxes, n_dims=None, dtype=np.float32)
+ class_ids = array_to_numpy(class_ids, n_dims=None, dtype=np.int64)
+ scores = array_to_numpy(scores, n_dims=None, dtype=np.float32)
+ if pred_masks:
+ masks = array_to_numpy(
+ pred_masks[i], n_dims=None, dtype=np.float32
+ )
+ labels = []
+ for label_id, (box, score, class_id) in enumerate(
+ zip(boxes, scores, class_ids)
+ ):
+ box2d = xyxy_to_box2d(*box.tolist())
+
+ if pred_masks:
+ rle = mask_to_rle(
+ (masks[label_id] > self.mask_threshold).astype(
+ np.uint8
+ )
+ )
+ else:
+ rle = None
+
+ label = Label(
+ id=str(label_id),
+ box2d=box2d,
+ category=(
+ self.inverse_cat_map[int(class_id)]
+ if self.inverse_cat_map != {}
+ else str(class_id)
+ ),
+ score=float(score),
+ rle=rle,
+ )
+ labels.append(label)
+ frame = Frame(
+ name=sample_name,
+ videoName=sequence_name,
+ frameIndex=frame_id,
+ labels=labels,
+ )
+ self.frames.append(frame)
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """Evaluate the dataset."""
+ assert self.config is not None, "Scalabel config is not loaded."
+ metrics_log: MetricLogs = {}
+ short_description = ""
+
+ if metric == self.METRICS_DET:
+ results = evaluate_det(
+ self.gt_frames, self.frames, config=self.config, nproc=0
+ )
+ for metric_name, metric_value in results.summary().items():
+ metrics_log[metric_name] = metric_value
+ short_description += str(results) + "\n"
+
+ if metric == self.METRICS_INS_SEG:
+ results = evaluate_ins_seg(
+ self.gt_frames, self.frames, config=self.config, nproc=0
+ )
+ for metric_name, metric_value in results.summary().items():
+ metrics_log[metric_name] = metric_value
+ short_description += str(results) + "\n"
+
+ return metrics_log, short_description
diff --git a/vis4d/eval/scalabel/track.py b/vis4d/eval/scalabel/track.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb80132c8c45302f530949324b82ba85898cb42a
--- /dev/null
+++ b/vis4d/eval/scalabel/track.py
@@ -0,0 +1,153 @@
+"""Scalabel tracking evaluator."""
+
+from __future__ import annotations
+
+import numpy as np
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.imports import SCALABEL_AVAILABLE
+from vis4d.common.typing import MetricLogs, NDArrayNumber
+
+from .base import ScalabelEvaluator
+
+if SCALABEL_AVAILABLE:
+ from scalabel.eval.mot import acc_single_video_mot, evaluate_track
+ from scalabel.eval.mots import acc_single_video_mots, evaluate_seg_track
+ from scalabel.label.io import group_and_sort
+ from scalabel.label.transforms import mask_to_rle, xyxy_to_box2d
+ from scalabel.label.typing import Config, Frame, Label
+else:
+ raise ImportError("scalabel is not installed.")
+
+
+class ScalabelTrackEvaluator(ScalabelEvaluator):
+ """Scalabel 2D tracking evaluation class."""
+
+ METRICS_TRACK = "MOT"
+ METRICS_SEG_TRACK = "MOTS"
+ METRICS_ALL = "all"
+
+ def __init__(
+ self,
+ annotation_path: str,
+ config: Config | None = None,
+ mask_threshold: float = 0.0,
+ ) -> None:
+ """Initialize the evaluator."""
+ super().__init__(annotation_path=annotation_path, config=config)
+ self.mask_threshold = mask_threshold
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset evaluator."""
+ return "Scalabel Tracking Evaluator"
+
+ @property
+ def metrics(self) -> list[str]:
+ """Supported metrics."""
+ return [self.METRICS_TRACK, self.METRICS_SEG_TRACK]
+
+ def process_batch(
+ self,
+ frame_ids: list[int],
+ sample_names: list[str],
+ sequence_names: list[str],
+ pred_boxes: list[NDArrayNumber],
+ pred_classes: list[NDArrayNumber],
+ pred_scores: list[NDArrayNumber],
+ pred_track_ids: list[NDArrayNumber],
+ pred_masks: list[NDArrayNumber] | None = None,
+ ) -> None:
+ """Process tracking results."""
+ for i, (
+ frame_id,
+ sample_name,
+ sequence_name,
+ boxes,
+ scores,
+ class_ids,
+ track_ids,
+ ) in enumerate(
+ zip(
+ frame_ids,
+ sample_names,
+ sequence_names,
+ pred_boxes,
+ pred_scores,
+ pred_classes,
+ pred_track_ids,
+ )
+ ):
+ boxes = array_to_numpy(boxes, n_dims=None, dtype=np.float32)
+ class_ids = array_to_numpy(class_ids, n_dims=None, dtype=np.int64)
+ scores = array_to_numpy(scores, n_dims=None, dtype=np.float32)
+ if pred_masks:
+ masks = array_to_numpy(
+ pred_masks[i], n_dims=None, dtype=np.float32
+ )
+
+ labels = []
+ for label_id, (box, score, class_id, track_id) in enumerate(
+ zip(boxes, scores, class_ids, track_ids)
+ ):
+ box2d = xyxy_to_box2d(*box.tolist())
+
+ if pred_masks:
+ rle = mask_to_rle(
+ (masks[label_id] > self.mask_threshold).astype(
+ np.uint8
+ )
+ )
+ else:
+ rle = None
+
+ label = Label(
+ box2d=box2d,
+ category=(
+ self.inverse_cat_map[int(class_id)]
+ if self.inverse_cat_map != {}
+ else str(class_id)
+ ),
+ score=float(score),
+ id=str(int(track_id)),
+ rle=rle,
+ )
+ labels.append(label)
+ frame = Frame(
+ name=sample_name,
+ videoName=sequence_name,
+ frameIndex=frame_id,
+ labels=labels,
+ )
+ self.frames.append(frame)
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """Evaluate the dataset."""
+ assert self.config is not None, "config is not set"
+ metrics_log: MetricLogs = {}
+ short_description = ""
+
+ if metric in [self.METRICS_TRACK, self.METRICS_ALL]:
+ results = evaluate_track(
+ acc_single_video_mot,
+ gts=group_and_sort(self.gt_frames),
+ results=group_and_sort(self.frames),
+ config=self.config,
+ nproc=0,
+ )
+ for metric_name, metric_value in results.summary().items():
+ metrics_log[metric_name] = metric_value
+ short_description += str(results) + "\n"
+
+ if metric in [self.METRICS_SEG_TRACK, self.METRICS_ALL]:
+ results = evaluate_seg_track(
+ acc_single_video_mots,
+ gts=group_and_sort(self.gt_frames),
+ results=group_and_sort(self.frames),
+ config=self.config,
+ nproc=0,
+ )
+ for metric_name, metric_value in results.summary().items():
+ metrics_log[metric_name] = metric_value
+ short_description += str(results) + "\n"
+
+ return metrics_log, short_description
diff --git a/vis4d/eval/shift/__init__.py b/vis4d/eval/shift/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1ff5c48a9c3320c36cec71a1b0ace0c70e9e3a4
--- /dev/null
+++ b/vis4d/eval/shift/__init__.py
@@ -0,0 +1,17 @@
+"""SHIFT evaluation metrics."""
+
+from .depth import SHIFTDepthEvaluator
+from .detect import SHIFTDetectEvaluator
+from .flow import SHIFTOpticalFlowEvaluator
+from .multitask_writer import SHIFTMultitaskWriter
+from .seg import SHIFTSegEvaluator
+from .track import SHIFTTrackEvaluator
+
+__all__ = [
+ "SHIFTDepthEvaluator",
+ "SHIFTDetectEvaluator",
+ "SHIFTOpticalFlowEvaluator",
+ "SHIFTSegEvaluator",
+ "SHIFTTrackEvaluator",
+ "SHIFTMultitaskWriter",
+]
diff --git a/vis4d/eval/shift/depth.py b/vis4d/eval/shift/depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..55e4f3708f7d7c0ef3745594910c5f3764360dde
--- /dev/null
+++ b/vis4d/eval/shift/depth.py
@@ -0,0 +1,45 @@
+"""SHIFT depth estimation evaluator."""
+
+from __future__ import annotations
+
+from vis4d.common.typing import NDArrayNumber
+
+from ..common import DepthEvaluator
+
+
+def apply_crop(depth: NDArrayNumber) -> NDArrayNumber:
+ """Apply crop to depth map to match SHIFT evaluation."""
+ return depth[..., 0:740, :]
+
+
+class SHIFTDepthEvaluator(DepthEvaluator):
+ """SHIFT depth estimation evaluation class."""
+
+ def __init__(self, use_eval_crop: bool = True) -> None:
+ """Initialize the evaluator.
+
+ Args:
+ use_eval_crop (bool): Whether to use the evaluation crop.
+ Default: True.
+ """
+ super().__init__(min_depth=0.01, max_depth=80.0)
+ self.use_eval_crop = use_eval_crop
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset evaluator."""
+ return "SHIFT Depth Estimation Evaluator"
+
+ def process_batch( # type: ignore # pylint: disable=arguments-differ
+ self, prediction: NDArrayNumber, groundtruth: NDArrayNumber
+ ) -> None:
+ """Process sample and update confusion matrix.
+
+ Args:
+ prediction: Predictions of shape (N, H, W).
+ groundtruth: Groundtruth of shape (N, H, W).
+ """
+ if self.use_eval_crop:
+ prediction = apply_crop(prediction)
+ groundtruth = apply_crop(groundtruth)
+ print(prediction.shape, groundtruth.shape)
+ super().process_batch(prediction, groundtruth)
diff --git a/vis4d/eval/shift/detect.py b/vis4d/eval/shift/detect.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef4575600c5d5c14155fa5e5328645b9bec55132
--- /dev/null
+++ b/vis4d/eval/shift/detect.py
@@ -0,0 +1,18 @@
+"""SHIFT detection evaluator."""
+
+from __future__ import annotations
+
+from vis4d.data.datasets.shift import shift_det_map
+
+from ..scalabel import ScalabelDetectEvaluator
+
+
+class SHIFTDetectEvaluator(ScalabelDetectEvaluator):
+ """SHIFT detection evaluation class."""
+
+ inverse_det_map = {v: k for k, v in shift_det_map.items()}
+
+ def __init__(self, annotation_path: str) -> None:
+ """Initialize the evaluator."""
+ super().__init__(annotation_path=annotation_path, mask_threshold=0)
+ self.inverse_cat_map = self.inverse_det_map
diff --git a/vis4d/eval/shift/flow.py b/vis4d/eval/shift/flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3c3b5b3dc404398e8852003dbf8a429147fbc24
--- /dev/null
+++ b/vis4d/eval/shift/flow.py
@@ -0,0 +1,19 @@
+"""SHIFT optical flow estimation evaluator."""
+
+from __future__ import annotations
+
+from ..common import OpticalFlowEvaluator
+
+
+class SHIFTOpticalFlowEvaluator(OpticalFlowEvaluator):
+ """SHIFT optical flow estimation evaluation class."""
+
+ def __init__(
+ self,
+ ) -> None:
+ """Initialize the evaluator."""
+ super().__init__(max_flow=200.0, use_degrees=False, scale=1.0)
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset evaluator."""
+ return "SHIFT Optical Flow Estimation Evaluator"
diff --git a/vis4d/eval/shift/multitask_writer.py b/vis4d/eval/shift/multitask_writer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6d228bedcd9fb07ae6a909b1989c753d140373f
--- /dev/null
+++ b/vis4d/eval/shift/multitask_writer.py
@@ -0,0 +1,279 @@
+"""SHIFT result writer."""
+
+from __future__ import annotations
+
+import io
+import itertools
+import json
+import os
+from collections import defaultdict
+
+import numpy as np
+from PIL import Image
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.imports import SCALABEL_AVAILABLE
+from vis4d.common.typing import (
+ ArrayLike,
+ GenericFunc,
+ MetricLogs,
+ NDArrayNumber,
+)
+from vis4d.data.datasets.shift import shift_det_map
+from vis4d.data.io import DataBackend, ZipBackend
+from vis4d.eval.base import Evaluator
+
+if SCALABEL_AVAILABLE:
+ from scalabel.label.transforms import mask_to_rle, xyxy_to_box2d
+ from scalabel.label.typing import Dataset, Frame, Label
+else:
+ raise ImportError("scalabel is not installed.")
+
+
+class SHIFTMultitaskWriter(Evaluator):
+ """SHIFT result writer for online evaluation."""
+
+ inverse_cat_map = {v: k for k, v in shift_det_map.items()}
+
+ def __init__(
+ self,
+ output_dir: str,
+ submission_file: str = "submission.zip",
+ ) -> None:
+ """Creates a new writer.
+
+ Args:
+ output_dir (str): Output directory.
+ submission_file (str): Submission file name. Defaults to
+ "submission.zip".
+ """
+ super().__init__()
+ assert submission_file.endswith(
+ ".zip"
+ ), "Submission file must be a zip file."
+ self.backend: DataBackend = ZipBackend()
+ self.output_path = os.path.join(output_dir, submission_file)
+ self.frames_det_2d: list[Frame] = []
+ self.frames_det_3d: list[Frame] = []
+ self.sample_counts: defaultdict[str, int] = defaultdict(int)
+
+ def _write_sem_mask(
+ self, sem_mask: NDArrayNumber, sample_name: str, video_name: str
+ ) -> None:
+ """Write semantic mask.
+
+ Args:
+ sem_mask (NDArrayNumber): Predicted semantic mask, shape (H, W).
+ sample_name (str): Sample name.
+ video_name (str): Video name.
+ """
+ image = Image.fromarray(sem_mask.astype("uint8"), mode="L")
+ image_bytes = io.BytesIO()
+ image.save(image_bytes, format="PNG")
+ self.backend.set(
+ f"{self.output_path}/semseg/{video_name}/{sample_name}",
+ image_bytes.getvalue(),
+ mode="w",
+ )
+
+ def _write_depth(
+ self, depth_map: NDArrayNumber, sample_name: str, video_name: str
+ ) -> None:
+ """Write depth map.
+
+ Args:
+ depth_map (NDArrayNumber): Predicted depth map, shape (H, W).
+ sample_name (str): Sample name.
+ video_name (str): Video name.
+ """
+ depth_map = np.clip(depth_map / 80.0 * 255.0, 0, 255)
+ image = Image.fromarray(depth_map.astype("uint8"), mode="L")
+ image_bytes = io.BytesIO()
+ image.save(image_bytes, format="PNG")
+ self.backend.set(
+ f"{self.output_path}/depth/{video_name}/{sample_name}",
+ image_bytes.getvalue(),
+ mode="w",
+ )
+
+ def _write_flow(
+ self, flow: NDArrayNumber, sample_name: str, video_name: str
+ ) -> None:
+ """Write semantic mask.
+
+ Args:
+ flow (NDArrayNumber): Predicted optical flow, shape (H, W, 2).
+ sample_name (str): Sample name.
+ video_name (str): Video name.
+ """
+ raise NotImplementedError
+
+ def process_batch(
+ self,
+ frame_ids: list[int],
+ sample_names: list[str],
+ sequence_names: list[str],
+ pred_sem_mask: list[ArrayLike] | None = None,
+ pred_depth: list[ArrayLike] | None = None,
+ pred_flow: list[ArrayLike] | None = None,
+ pred_boxes2d: list[ArrayLike] | None = None,
+ pred_boxes2d_classes: list[ArrayLike] | None = None,
+ pred_boxes2d_scores: list[ArrayLike] | None = None,
+ pred_boxes2d_track_ids: list[ArrayLike] | None = None,
+ pred_instance_masks: list[ArrayLike] | None = None,
+ ) -> None:
+ """Process SHIFT results.
+
+ You can omit some of the predictions if they are not used.
+
+ Args:
+ frame_ids (list[int]): Frame IDs.
+ sample_names (list[str]): Sample names.
+ sequence_names (list[str]): Sequence names.
+ pred_sem_mask (list[ArrayLike], optional): Predicted semantic
+ masks, each in shape (C, H, W) or (H, W). Defaults to None.
+ pred_depth (list[ArrayLike], optional): Predicted depth maps,
+ each in shape (H, W), with meter unit. Defaults to None.
+ pred_flow (list[ArrayLike], optional): Predicted optical flows,
+ each in shape (H, W, 2). Defaults to None.
+ pred_boxes2d (list[ArrayLike], optional): Predicted 2D boxes,
+ each in shape (N, 4). Defaults to None.
+ pred_boxes2d_classes (list[ArrayLike], optional): Predicted
+ 2D box classes, each in shape (N,). Defaults to None.
+ pred_boxes2d_scores (list[ArrayLike], optional): Predicted
+ 2D box scores, each in shape (N,). Defaults to None.
+ pred_boxes2d_track_ids (list[ArrayLike], optional): Predicted
+ 2D box track IDs, each in shape (N,). Defaults to None.
+ pred_instance_masks (list[ArrayLike], optional): Predicted
+ instance masks, each in shape (N, H, W). Defaults to None.
+ """
+ for i, (frame_id, sample_name, sequence_name) in enumerate(
+ zip(frame_ids, sample_names, sequence_names)
+ ):
+ if pred_sem_mask is not None:
+ sem_mask_ = array_to_numpy(
+ pred_sem_mask[i],
+ n_dims=None,
+ dtype=np.float32,
+ )
+ if len(sem_mask_.shape) == 3:
+ sem_mask = sem_mask_.argmax(axis=0)
+ else:
+ sem_mask = sem_mask_.astype(np.uint8)
+ semseg_filename = sample_name.replace(".jpg", ".png").replace(
+ "img", "semseg"
+ )
+ self._write_sem_mask(sem_mask, semseg_filename, sequence_name)
+ self.sample_counts["semseg"] += 1
+ if pred_depth is not None:
+ depth = array_to_numpy(
+ pred_depth[i], n_dims=None, dtype=np.float32
+ )
+ depth_filename = sample_name.replace(".jpg", ".png").replace(
+ "img", "depth"
+ )
+ self._write_depth(depth, depth_filename, sequence_name)
+ self.sample_counts["depth"] += 1
+ if pred_flow is not None:
+ flow = array_to_numpy(
+ pred_flow[i], n_dims=None, dtype=np.float32
+ )
+ self._write_flow(flow, sample_name, sequence_name)
+ self.sample_counts["flow"] += 1
+ if (
+ pred_boxes2d is not None
+ and pred_boxes2d_classes is not None
+ and pred_boxes2d_scores is not None
+ ):
+ labels = []
+ if pred_instance_masks:
+ masks = array_to_numpy(
+ pred_instance_masks[i], n_dims=None, dtype=np.float32
+ )
+ if pred_boxes2d_track_ids:
+ track_ids = array_to_numpy(
+ pred_boxes2d_track_ids[i],
+ n_dims=None,
+ dtype=np.int64,
+ )
+ for box, score, class_id in zip(
+ pred_boxes2d[i],
+ pred_boxes2d_scores[i],
+ pred_boxes2d_classes[i],
+ ):
+ box2d = xyxy_to_box2d(*box.tolist())
+ if pred_instance_masks:
+ rle = mask_to_rle(
+ (masks[class_id] > 0.0).astype(np.uint8)
+ )
+ else:
+ rle = None
+
+ if pred_boxes2d_track_ids:
+ track_id = str(int(track_ids[0]))
+ else:
+ track_id = None
+
+ label = Label(
+ box2d=box2d,
+ category=(
+ self.inverse_cat_map[int(class_id)]
+ if self.inverse_cat_map != {}
+ else str(class_id)
+ ),
+ score=float(score),
+ rle=rle,
+ id=track_id,
+ )
+ labels.append(label)
+ frame = Frame(
+ name=sample_name,
+ videoName=sequence_name,
+ frameIndex=frame_id,
+ labels=labels,
+ )
+ self.frames_det_2d.append(frame)
+ self.sample_counts["det_2d"] += 1
+
+ def gather(self, gather_func: GenericFunc) -> None: # pragma: no cover
+ """Gather variables in case of distributed setting (if needed).
+
+ Args:
+ gather_func (Callable[[Any], Any]): Gather function.
+ """
+ all_preds = gather_func(self.frames_det_2d)
+ if all_preds is not None:
+ self.frames_det_2d = list(itertools.chain(*all_preds))
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """No evaluation locally."""
+ return {}, "No evaluation locally."
+
+ def save(self, metric: str, output_dir: str) -> None:
+ """Save scalabel output to zip file.
+
+ Raises:
+ ValueError: If the number of samples in each category is not the
+ same.
+ """
+ # Check if the sample counts are correct
+ equal_size = True
+ for key in self.sample_counts:
+ if self.sample_counts[key] != len(self.frames_det_2d):
+ equal_size = False
+ break
+ if not equal_size:
+ raise ValueError(
+ "The number of samples in each category is not the same."
+ )
+
+ # Save the 2D detection results
+ if len(self.frames_det_2d) > 0:
+ ds = Dataset(frames=self.frames_det_2d, groups=None, config=None)
+ ds_bytes = json.dumps(ds.dict()).encode("utf-8")
+ self.backend.set(
+ f"{self.output_path}/det_2d.json", ds_bytes, mode="w"
+ )
+
+ self.backend.close()
+ print(f"Saved the submission file at {self.output_path}.")
diff --git a/vis4d/eval/shift/seg.py b/vis4d/eval/shift/seg.py
new file mode 100644
index 0000000000000000000000000000000000000000..939d007a981ad57e834b64e801a05d70f9c68b2c
--- /dev/null
+++ b/vis4d/eval/shift/seg.py
@@ -0,0 +1,48 @@
+"""SHIFT segmentation evaluator."""
+
+from __future__ import annotations
+
+from vis4d.common.typing import NDArrayI64, NDArrayNumber
+from vis4d.data.datasets.shift import shift_seg_ignore, shift_seg_map
+from vis4d.eval.common.seg import SegEvaluator
+
+
+class SHIFTSegEvaluator(SegEvaluator):
+ """SHIFT segmentation evaluation class."""
+
+ inverse_seg_map = {v: k for k, v in shift_seg_map.items()}
+
+ def __init__(self, ignore_classes_as_cityscapes: bool = True) -> None:
+ """Initialize the evaluator."""
+ super().__init__(
+ num_classes=23,
+ class_to_ignore=255,
+ class_mapping=self.inverse_seg_map,
+ )
+ self.ignore_classes_as_cityscapes = ignore_classes_as_cityscapes
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset evaluator."""
+ return "SHIFT Segmentation Evaluator"
+
+ def _prune_class(self, label: NDArrayI64) -> NDArrayI64:
+ """Prune class labels."""
+ for cls in shift_seg_ignore:
+ label[label == shift_seg_map[cls]] = 255
+ return label
+
+ def process_batch( # type: ignore # pylint: disable=arguments-differ
+ self, prediction: NDArrayNumber, groundtruth: NDArrayI64
+ ) -> None:
+ """Process sample and update confusion matrix.
+
+ Args:
+ prediction: Predictions of shape [N,C,...] or [N,...] with
+ C* being any number if channels. Note, C is passed,
+ the prediction is converted to target labels by applying
+ the max operations along the second axis
+ groundtruth: Groundtruth of shape [N_batch, ...] type int
+ """
+ if self.ignore_classes_as_cityscapes:
+ groundtruth = self._prune_class(groundtruth)
+ super().process_batch(prediction, groundtruth)
diff --git a/vis4d/eval/shift/track.py b/vis4d/eval/shift/track.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a9b0d1c50d1838d9d3a29a8ee424317dfff9535
--- /dev/null
+++ b/vis4d/eval/shift/track.py
@@ -0,0 +1,22 @@
+"""SHIFT tracking evaluator."""
+
+from __future__ import annotations
+
+from vis4d.data.datasets.shift import shift_det_map
+
+from ..scalabel import ScalabelTrackEvaluator
+
+
+class SHIFTTrackEvaluator(ScalabelTrackEvaluator):
+ """SHIFT tracking evaluation class."""
+
+ inverse_det_map = {v: k for k, v in shift_det_map.items()}
+
+ def __init__(
+ self, annotation_path: str, mask_threshold: float = 0.0
+ ) -> None:
+ """Initialize the evaluator."""
+ super().__init__(
+ annotation_path=annotation_path, mask_threshold=mask_threshold
+ )
+ self.inverse_cat_map = self.inverse_det_map
diff --git a/vis4d/eval/utils.py b/vis4d/eval/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9911ebd4d481c30b2e479b8d65d5cd3ab54bab4a
--- /dev/null
+++ b/vis4d/eval/utils.py
@@ -0,0 +1,25 @@
+"""Utility functions for evaluation."""
+
+import numpy as np
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.typing import ArrayLike, NDArrayNumber
+
+
+def dense_inputs_to_numpy(
+ prediction: ArrayLike, target: ArrayLike
+) -> tuple[NDArrayNumber, NDArrayNumber]:
+ """Convert dense prediction and target to numpy arrays."""
+ prediction = array_to_numpy(prediction, n_dims=None, dtype=np.float32)
+ target = array_to_numpy(target, n_dims=None, dtype=np.float32)
+ return prediction, target
+
+
+def check_shape_match(
+ prediction: NDArrayNumber, target: NDArrayNumber
+) -> None:
+ """Check if the shape of prediction and target matches."""
+ assert prediction.shape == target.shape, (
+ f"Shape mismatch between prediction {prediction.shape} and target"
+ f"{target.shape}."
+ )
diff --git a/vis4d/model/__init__.py b/vis4d/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9643ab24c6f0ca60c08174bb19903bc1ba1e817f
--- /dev/null
+++ b/vis4d/model/__init__.py
@@ -0,0 +1,6 @@
+"""Model definitions that connect operators and states together.
+
+All the compute should go to operators and the model memories should be kept
+in states. The models are supposed to do minimum job to connect the model
+pipelines.
+"""
diff --git a/vis4d/model/adapter/__init__.py b/vis4d/model/adapter/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..450a3d45261b6526bb6aa4530b207f3bcbd81ea5
--- /dev/null
+++ b/vis4d/model/adapter/__init__.py
@@ -0,0 +1,5 @@
+"""Model adapters."""
+
+from .ema import ModelEMAAdapter, ModelExpEMAAdapter
+
+__all__ = ["ModelEMAAdapter", "ModelExpEMAAdapter"]
diff --git a/vis4d/model/adapter/ema.py b/vis4d/model/adapter/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..573836bc4adade41103738dab475fd74aaafb60a
--- /dev/null
+++ b/vis4d/model/adapter/ema.py
@@ -0,0 +1,118 @@
+"""Exponential Moving Average (EMA) for PyTorch models."""
+
+from __future__ import annotations
+
+import math
+from collections.abc import Callable
+from copy import deepcopy
+from typing import Any
+
+import torch
+from torch import Tensor, nn
+
+from vis4d.common.logging import rank_zero_info
+
+
+class ModelEMAAdapter(nn.Module):
+ """Torch module with Exponential Moving Average (EMA).
+
+ Args:
+ model (nn.Module): Model to apply EMA.
+ decay (float): Decay factor for EMA. Defaults to 0.9998.
+ use_ema_during_test (bool): Use EMA model during testing. Defaults to
+ True.
+ device (torch.device | None): Device to use. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ model: nn.Module,
+ decay: float = 0.9998,
+ use_ema_during_test: bool = True,
+ device: torch.device | None = None,
+ ):
+ """Init ModelEMAAdapter class."""
+ super().__init__()
+ self.model = model
+ self.ema_model = deepcopy(self.model)
+ self.ema_model.eval()
+ for p in self.ema_model.parameters():
+ p.requires_grad_(False)
+ self.decay = decay
+ self.use_ema_during_test = use_ema_during_test
+ self.device = device
+ if self.device is not None:
+ self.ema_model.to(device=device)
+ rank_zero_info("Using model EMA with decay rate %f", self.decay)
+
+ def _update(
+ self, model: nn.Module, update_fn: Callable[[Tensor, Tensor], Tensor]
+ ) -> None:
+ """Update model params."""
+ with torch.no_grad():
+ for ema_v, model_v in zip(
+ self.ema_model.state_dict().values(),
+ model.state_dict().values(),
+ ):
+ if self.device is not None:
+ model_v = model_v.to(device=self.device)
+ ema_v.copy_(update_fn(ema_v, model_v))
+
+ def update(self, steps: int) -> None: # pylint: disable=unused-argument
+ """Update the internal EMA model."""
+ self._update(
+ self.model,
+ update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m,
+ )
+
+ def set(self, model: nn.Module) -> None:
+ """Copy model params into the internal EMA."""
+ self._update(model, update_fn=lambda e, m: m)
+
+ def forward(self, *args: Any, **kwargs: Any) -> Any: # type: ignore
+ """Forward pass with original model."""
+ if self.training or not self.use_ema_during_test:
+ return self.model(*args, **kwargs)
+ return self.ema_model(*args, **kwargs)
+
+
+class ModelExpEMAAdapter(ModelEMAAdapter):
+ """Exponential Moving Average (EMA) with exponential decay strategy.
+
+ Used by YOLOX.
+
+ Args:
+ model (nn.Module): Model to apply EMA.
+ decay (float): Decay factor for EMA. Defaults to 0.9998.
+ warmup_steps (int): Number of warmup steps for decay. Use a smaller
+ decay early in training and gradually anneal to the set decay value
+ to update the EMA model smoothly.
+ use_ema_during_test (bool): Use EMA model during testing. Defaults to
+ True.
+ device (torch.device | None): Device to use. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ model: nn.Module,
+ decay: float = 0.9998,
+ warmup_steps: int = 2000,
+ use_ema_during_test: bool = True,
+ device: torch.device | None = None,
+ ):
+ """Init ModelEMAAdapter class."""
+ super().__init__(model, decay, use_ema_during_test, device)
+ assert (
+ warmup_steps > 0
+ ), f"warmup_steps must be greater than 0, got {warmup_steps}"
+ self.warmup_steps = warmup_steps
+
+ def update(self, steps: int) -> None:
+ """Update the internal EMA model."""
+ decay = self.decay * (
+ 1 - math.exp(-float(1 + steps) / self.warmup_steps)
+ )
+ self._update(
+ self.model,
+ update_fn=lambda e, m: decay * e + (1.0 - decay) * m,
+ )
diff --git a/vis4d/model/adapter/flops.py b/vis4d/model/adapter/flops.py
new file mode 100644
index 0000000000000000000000000000000000000000..170178af8b51f462390dcd59214bc87967173b60
--- /dev/null
+++ b/vis4d/model/adapter/flops.py
@@ -0,0 +1,59 @@
+"""Adapter for counting flops in a model."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from torch import nn
+
+from vis4d.engine.connectors import DataConnector
+
+# Ops to ignore from counting, including elementwise and reduction ops
+IGNORED_OPS = {
+ "aten::add",
+ "aten::add_",
+ "aten::argmax",
+ "aten::argsort",
+ "aten::batch_norm",
+ "aten::constant_pad_nd",
+ "aten::div",
+ "aten::div_",
+ "aten::exp",
+ "aten::log2",
+ "aten::max_pool2d",
+ "aten::meshgrid",
+ "aten::mul",
+ "aten::mul_",
+ "aten::neg",
+ "aten::nonzero_numpy",
+ "aten::reciprocal",
+ "aten::repeat_interleave",
+ "aten::rsub",
+ "aten::sigmoid",
+ "aten::sigmoid_",
+ "aten::softmax",
+ "aten::sort",
+ "aten::sqrt",
+ "aten::sub",
+ "torchvision::nms",
+}
+
+
+class FlopsModelAdapter(nn.Module):
+ """Adapter for the model to count flops."""
+
+ def __init__(
+ self, model: nn.Module, data_connector: DataConnector
+ ) -> None:
+ """Initialize the adapter."""
+ super().__init__()
+ self.model = model
+ self.data_connector = data_connector
+
+ def forward(self, *args: Any) -> Any: # type: ignore
+ """Forward pass through the model."""
+ data_dict = {}
+ for i, key in enumerate(self.data_connector.key_mapping):
+ data_dict[key] = args[0][i]
+
+ return self.model(**data_dict)
diff --git a/vis4d/model/cls/__init__.py b/vis4d/model/cls/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd4e60c12ac67a9c748ef85d3e49c4347f584d4b
--- /dev/null
+++ b/vis4d/model/cls/__init__.py
@@ -0,0 +1,6 @@
+"""Common classes and functions for classification models."""
+
+from .common import ClsOut
+from .vit import ViTClassifer
+
+__all__ = ["ViTClassifer", "ClsOut"]
diff --git a/vis4d/model/cls/common.py b/vis4d/model/cls/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..91aad295018f262a300ce581c6a3668310a66e3f
--- /dev/null
+++ b/vis4d/model/cls/common.py
@@ -0,0 +1,12 @@
+"""Common types for classification models."""
+
+from typing import NamedTuple
+
+import torch
+
+
+class ClsOut(NamedTuple):
+ """Output of the classification results."""
+
+ logits: torch.Tensor # (N, num_classes)
+ probs: torch.Tensor # (N, num_classes)
diff --git a/vis4d/model/cls/vit.py b/vis4d/model/cls/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea4e6366cbf5333d04736f63c9af43db55ae1ffc
--- /dev/null
+++ b/vis4d/model/cls/vit.py
@@ -0,0 +1,122 @@
+"""ViT for classification tasks."""
+
+from __future__ import annotations
+
+import timm.models.vision_transformer as _vision_transformer
+import torch
+from torch import nn
+
+from vis4d.common.ckpt import load_model_checkpoint
+from vis4d.common.typing import ArgsType
+from vis4d.op.base.vit import VisionTransformer, ViT_PRESET
+
+from .common import ClsOut
+
+
+class ViTClassifer(nn.Module):
+ """ViT for classification tasks."""
+
+ def __init__(
+ self,
+ variant: str = "",
+ num_classes: int = 1000,
+ use_global_pooling: bool = False,
+ weights: str | None = None,
+ num_prefix_tokens: int = 1,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Initialize the classification ViT.
+
+ Args:
+ variant (str): Name of the ViT variant. Defaults to "". If the name
+ starts with "timm://", the variant will be loaded from timm's
+ model zoo. Otherwise, the variant will be loaded from the
+ ViT_PRESET dict. If the variant is empty, the default ViT
+ variant will be used. In all cases, the additional keyword
+ arguments will override the default arguments.
+ num_classes (int, optional): Number of classes. Defaults to 1000.
+ use_global_pooling (bool, optional): If to use global pooling.
+ Defaults to False. If set to True, the output of the ViT will
+ be averaged over the spatial dimensions. Otherwise, the first
+ token will be used for classification.
+ weights (str, optional): If to load pretrained weights. If set to
+ "timm", the weights will be loaded from timm's model zoo that
+ matches the variant. If a URL is provided, the weights will be
+ downloaded from the URL. Defaults to None, which means no
+ weights will be loaded.
+ num_prefix_tokens (int, optional): Number of prefix tokens.
+ Defaults to 1.
+ **kwargs: Keyword arguments passed to the ViT model.
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.use_global_pooling = use_global_pooling
+ self.num_prefix_tokens = num_prefix_tokens
+
+ if variant != "":
+ assert variant in ViT_PRESET, (
+ f"Unknown ViT variant: {variant}. "
+ f"Available ViT variants are: {list(ViT_PRESET.keys())}"
+ )
+ preset_kwargs = ViT_PRESET[variant]
+ preset_kwargs["num_classes"] = num_classes
+ preset_kwargs.update(kwargs)
+ self.vit = VisionTransformer(**preset_kwargs) # type: ignore
+ else:
+ # Build ViT from scratch using kwargs
+ preset_kwargs = {}
+ self.vit = VisionTransformer(num_classes=num_classes, **kwargs)
+
+ # Classification head
+ embed_dim = kwargs.get(
+ "embed_dim", preset_kwargs.get("embed_dim", 768)
+ )
+ self.norm = (
+ nn.LayerNorm(embed_dim) if use_global_pooling else nn.Identity()
+ )
+ self.head = (
+ nn.Linear(embed_dim, num_classes)
+ if num_classes > 0
+ else nn.Identity()
+ )
+
+ # Load pretrain weights
+ if weights is not None:
+ if weights.startswith("timm://"):
+ weights = weights.removeprefix("timm://")
+ if "." in weights:
+ model_name, pretrain_tag = weights.split(".")
+ else:
+ model_name = weights
+ pretrain_tag = None
+ assert model_name in _vision_transformer.__dict__, (
+ f"Unknown Timm ViT weights: {model_name}. "
+ f"Available Timm ViT weights are: "
+ f"{list(_vision_transformer.__dict__.keys())}"
+ )
+ _model = _vision_transformer.__dict__[model_name](
+ pretrained=True, pretrained_cfg=pretrain_tag, **kwargs
+ )
+ self.vit.load_state_dict(_model.state_dict(), strict=False)
+ self.norm.load_state_dict(
+ _model.norm.state_dict(), strict=False
+ )
+ self.head.load_state_dict(
+ _model.head.state_dict(), strict=False
+ )
+ else:
+ load_model_checkpoint(self, weights)
+
+ def forward(self, images: torch.Tensor) -> ClsOut:
+ """Forward pass."""
+ feats = self.vit(images)
+ x = feats[-1]
+ if self.use_global_pooling:
+ x = x[:, self.num_prefix_tokens :].mean(dim=1)
+ else:
+ x = x[:, 0]
+ x = self.norm(x)
+ logits = self.head(x)
+ return ClsOut(
+ logits=logits, probs=torch.softmax(logits.detach(), dim=-1)
+ )
diff --git a/vis4d/model/detect/__init__.py b/vis4d/model/detect/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..db4d5a49062d47c24ce4d6621d5a7aa9bee7bf59
--- /dev/null
+++ b/vis4d/model/detect/__init__.py
@@ -0,0 +1 @@
+"""This module contains the model implementations of 2D detectors."""
diff --git a/vis4d/model/detect/faster_rcnn.py b/vis4d/model/detect/faster_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..97a672bafe271e1fc6a59aac253f48b57f4c1c8c
--- /dev/null
+++ b/vis4d/model/detect/faster_rcnn.py
@@ -0,0 +1,178 @@
+"""Faster RCNN model implementation and runtime."""
+
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+from vis4d.common.ckpt import load_model_checkpoint
+from vis4d.op.base import BaseModel, ResNet
+from vis4d.op.box.box2d import scale_and_clip_boxes
+from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder
+from vis4d.op.detect.common import DetOut
+from vis4d.op.detect.faster_rcnn import FasterRCNNHead, FRCNNOut
+from vis4d.op.detect.rcnn import RoI2Det
+from vis4d.op.fpp.fpn import FPN
+
+REV_KEYS = [
+ (r"^backbone\.", "basemodel."),
+ (r"^rpn_head.rpn_reg\.", "faster_rcnn_head.rpn_head.rpn_box."),
+ (r"^rpn_head.rpn_", "faster_rcnn_head.rpn_head.rpn_"),
+ (r"^roi_head.bbox_head\.", "faster_rcnn_head.roi_head."),
+ (r"^neck.lateral_convs\.", "fpn.inner_blocks."),
+ (r"^neck.fpn_convs\.", "fpn.layer_blocks."),
+ (r"\.conv.weight", ".weight"),
+ (r"\.conv.bias", ".bias"),
+]
+
+
+class FasterRCNN(nn.Module):
+ """Faster RCNN model."""
+
+ def __init__(
+ self,
+ num_classes: int,
+ basemodel: BaseModel | None = None,
+ faster_rcnn_head: FasterRCNNHead | None = None,
+ rcnn_box_decoder: DeltaXYWHBBoxDecoder | None = None,
+ weights: None | str = None,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ num_classes (int): Number of object categories.
+ basemodel (BaseModel, optional): Base model network. Defaults to
+ None. If None, will use ResNet50.
+ faster_rcnn_head (FasterRCNNHead, optional): Faster RCNN head.
+ Defaults to None. if None, will use default FasterRCNNHead.
+ rcnn_box_decoder (DeltaXYWHBBoxDecoder, optional): Decoder for RCNN
+ bounding boxes. Defaults to None.
+ weights (str, optional): Weights to load for model. If set to
+ "mmdet", will load MMDetection pre-trained weights. Defaults to
+ None.
+ """
+ super().__init__()
+ self.basemodel = (
+ ResNet(resnet_name="resnet50", pretrained=True, trainable_layers=3)
+ if basemodel is None
+ else basemodel
+ )
+
+ self.fpn = FPN(self.basemodel.out_channels[2:], 256)
+
+ if faster_rcnn_head is None:
+ self.faster_rcnn_head = FasterRCNNHead(num_classes=num_classes)
+ else:
+ self.faster_rcnn_head = faster_rcnn_head
+
+ self.roi2det = RoI2Det(rcnn_box_decoder)
+
+ if weights is not None:
+ if weights == "mmdet":
+ weights = (
+ "mmdet://faster_rcnn/faster_rcnn_r50_fpn_1x_coco/"
+ "faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
+ )
+ if weights.startswith("mmdet://") or weights.startswith(
+ "bdd100k://"
+ ):
+ load_model_checkpoint(self, weights, rev_keys=REV_KEYS)
+ else:
+ load_model_checkpoint(self, weights)
+
+ def forward(
+ self,
+ images: torch.Tensor,
+ input_hw: list[tuple[int, int]],
+ boxes2d: None | list[torch.Tensor] = None,
+ boxes2d_classes: None | list[torch.Tensor] = None,
+ original_hw: None | list[tuple[int, int]] = None,
+ ) -> FRCNNOut | DetOut:
+ """Forward pass.
+
+ Args:
+ images (torch.Tensor): Input images.
+ input_hw (list[tuple[int, int]]): Input image resolutions.
+ boxes2d (None | list[torch.Tensor], optional): Bounding box labels.
+ Required for training. Defaults to None.
+ boxes2d_classes (None | list[torch.Tensor], optional): Class
+ labels. Required for training. Defaults to None.
+ original_hw (None | list[tuple[int, int]], optional): Original
+ image resolutions (before padding and resizing). Required for
+ testing. Defaults to None.
+
+ Returns:
+ FRCNNOut | DetOut: Either raw model outputs (for training) or
+ predicted outputs (for testing).
+ """
+ if self.training:
+ assert boxes2d is not None and boxes2d_classes is not None
+ return self.forward_train(
+ images, input_hw, boxes2d, boxes2d_classes
+ )
+ assert original_hw is not None
+ return self.forward_test(images, input_hw, original_hw)
+
+ def __call__(
+ self,
+ images: torch.Tensor,
+ input_hw: list[tuple[int, int]],
+ boxes2d: None | list[torch.Tensor] = None,
+ boxes2d_classes: None | list[torch.Tensor] = None,
+ original_hw: None | list[tuple[int, int]] = None,
+ ) -> FRCNNOut | DetOut:
+ """Type definition for call implementation."""
+ return self._call_impl(
+ images, input_hw, boxes2d, boxes2d_classes, original_hw
+ )
+
+ def forward_train(
+ self,
+ images: torch.Tensor,
+ images_hw: list[tuple[int, int]],
+ target_boxes: list[torch.Tensor],
+ target_classes: list[torch.Tensor],
+ ) -> FRCNNOut:
+ """Forward training stage.
+
+ Args:
+ images (torch.Tensor): Input images.
+ images_hw (list[tuple[int, int]]): Input image resolutions.
+ target_boxes (list[torch.Tensor]): Bounding box labels.
+ target_classes (list[torch.Tensor]): Class labels.
+
+ Returns:
+ FRCNNOut: Raw model outputs.
+ """
+ features = self.fpn(self.basemodel(images))
+ return self.faster_rcnn_head(
+ features, images_hw, target_boxes, target_classes
+ )
+
+ def forward_test(
+ self,
+ images: torch.Tensor,
+ images_hw: list[tuple[int, int]],
+ original_hw: list[tuple[int, int]],
+ ) -> DetOut:
+ """Forward testing stage.
+
+ Args:
+ images (torch.Tensor): Input images.
+ images_hw (list[tuple[int, int]]): Input image resolutions.
+ original_hw (list[tuple[int, int]]): Original image resolutions
+ (before padding and resizing).
+
+ Returns:
+ DetOut: Predicted outputs.
+ """
+ features = self.fpn(self.basemodel(images))
+ outs = self.faster_rcnn_head(features, images_hw)
+ boxes, scores, class_ids = self.roi2det(
+ *outs.roi, outs.proposals.boxes, images_hw
+ )
+
+ for i, boxs in enumerate(boxes):
+ boxes[i] = scale_and_clip_boxes(boxs, original_hw[i], images_hw[i])
+
+ return DetOut(boxes, scores, class_ids)
diff --git a/vis4d/model/detect/mask_rcnn.py b/vis4d/model/detect/mask_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..22cc29b7632a36df5a4c21b855c0afac53a46d6f
--- /dev/null
+++ b/vis4d/model/detect/mask_rcnn.py
@@ -0,0 +1,219 @@
+"""Mask RCNN model implementation and runtime."""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+import torch
+from torch import nn
+
+from vis4d.common.ckpt import load_model_checkpoint
+from vis4d.op.base import BaseModel, ResNet
+from vis4d.op.box.box2d import apply_mask, scale_and_clip_boxes
+from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder
+from vis4d.op.detect.common import DetOut
+from vis4d.op.detect.faster_rcnn import FasterRCNNHead, FRCNNOut
+from vis4d.op.detect.mask_rcnn import (
+ Det2Mask,
+ MaskOut,
+ MaskRCNNHead,
+ MaskRCNNHeadOut,
+)
+from vis4d.op.detect.rcnn import RoI2Det
+from vis4d.op.fpp.fpn import FPN
+
+
+class MaskDetectionOut(NamedTuple):
+ """Mask detection output."""
+
+ boxes: DetOut
+ masks: MaskOut
+
+
+class MaskRCNNOut(NamedTuple):
+ """Mask RCNN output."""
+
+ boxes: FRCNNOut
+ masks: MaskRCNNHeadOut
+
+
+REV_KEYS = [
+ (r"^backbone\.", "basemodel."),
+ (r"^rpn_head.rpn_reg\.", "rpn_head.rpn_box."),
+ (r"^roi_head.bbox_head\.", "roi_head."),
+ (r"^roi_head.mask_head\.", "mask_head."),
+ (r"^convs\.", "mask_head.convs."),
+ (r"^upsample\.", "mask_head.upsample."),
+ (r"^conv_logits\.", "mask_head.conv_logits."),
+ (r"^roi_head\.", "faster_rcnn_head.roi_head."),
+ (r"^rpn_head\.", "faster_rcnn_head.rpn_head."),
+ (r"^neck.lateral_convs\.", "fpn.inner_blocks."),
+ (r"^neck.fpn_convs\.", "fpn.layer_blocks."),
+ (r"\.conv.weight", ".weight"),
+ (r"\.conv.bias", ".bias"),
+]
+
+
+class MaskRCNN(nn.Module):
+ """Mask RCNN model.
+
+ Args:
+ num_classes (int): Number of classes.
+ basemodel (BaseModel, optional): Base model network. Defaults to
+ None. If None, will use ResNet50.
+ faster_rcnn_head (FasterRCNNHead, optional): Faster RCNN head.
+ Defaults to None. if None, will use default FasterRCNNHead.
+ mask_head (MaskRCNNHead, optional): Mask RCNN head. Defaults to
+ None. if None, will use default MaskRCNNHead.
+ rcnn_box_decoder (DeltaXYWHBBoxDecoder, optional): Decoder for RCNN
+ bounding boxes. Defaults to None.
+ no_overlap (bool, optional): Whether to remove overlapping pixels
+ between masks. Defaults to False.
+ weights (None | str, optional): Weights to load for model. If set
+ to "mmdet", will load MMDetection pre-trained weights.
+ Defaults to None.
+ """
+
+ def __init__(
+ self,
+ num_classes: int,
+ basemodel: BaseModel | None = None,
+ faster_rcnn_head: FasterRCNNHead | None = None,
+ mask_head: MaskRCNNHead | None = None,
+ rcnn_box_decoder: DeltaXYWHBBoxDecoder | None = None,
+ no_overlap: bool = False,
+ weights: None | str = None,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.basemodel = (
+ ResNet(resnet_name="resnet50", pretrained=True, trainable_layers=3)
+ if basemodel is None
+ else basemodel
+ )
+
+ self.fpn = FPN(self.basemodel.out_channels[2:], 256)
+
+ if faster_rcnn_head is None:
+ self.faster_rcnn_head = FasterRCNNHead(num_classes=num_classes)
+ else:
+ self.faster_rcnn_head = faster_rcnn_head
+
+ if mask_head is None:
+ self.mask_head = MaskRCNNHead(num_classes=num_classes)
+ else:
+ self.mask_head = mask_head
+
+ self.transform_outs = RoI2Det(rcnn_box_decoder)
+ self.det2mask = Det2Mask(no_overlap=no_overlap)
+
+ if weights is not None:
+ if weights == "mmdet":
+ weights = (
+ "mmdet://mask_rcnn/mask_rcnn_r50_fpn_2x_coco/"
+ "mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392__segm_mAP-0.354_"
+ "20200505_003907-3e542a40.pth"
+ )
+ if weights.startswith("mmdet://") or weights.startswith(
+ "bdd100k://"
+ ):
+ load_model_checkpoint(self, weights, rev_keys=REV_KEYS)
+ else:
+ load_model_checkpoint(self, weights)
+
+ def forward(
+ self,
+ images: torch.Tensor,
+ input_hw: list[tuple[int, int]],
+ boxes2d: None | list[torch.Tensor] = None,
+ boxes2d_classes: None | list[torch.Tensor] = None,
+ original_hw: None | list[tuple[int, int]] = None,
+ ) -> MaskRCNNOut | MaskDetectionOut:
+ """Forward pass.
+
+ Args:
+ images (torch.Tensor): Input images.
+ input_hw (list[tuple[int, int]]): Input image resolutions.
+ boxes2d (None | list[torch.Tensor], optional): Bounding box
+ labels. Required for training. Defaults to None.
+ boxes2d_classes (None | list[torch.Tensor], optional): Class
+ labels. Required for training. Defaults to None.
+ original_hw (None | list[tuple[int, int]], optional): Original
+ image resolutions (before padding and resizing). Required for
+ testing. Defaults to None.
+
+ Returns:
+ MaskRCNNOut | MaskDetectionOut: Either raw model
+ outputs (for training) or predicted outputs (for testing).
+ """
+ if self.training:
+ assert boxes2d is not None and boxes2d_classes is not None
+ return self.forward_train(
+ images, input_hw, boxes2d, boxes2d_classes
+ )
+ assert original_hw is not None
+ return self.forward_test(images, input_hw, original_hw)
+
+ def forward_train(
+ self,
+ images: torch.Tensor,
+ images_hw: list[tuple[int, int]],
+ target_boxes: list[torch.Tensor],
+ target_classes: list[torch.Tensor],
+ ) -> MaskRCNNOut:
+ """Forward training stage.
+
+ Args:
+ images (torch.Tensor): Input images.
+ images_hw (list[tuple[int, int]]): Input image resolutions.
+ target_boxes (list[torch.Tensor]): Bounding box labels. Required
+ for training. Defaults to None.
+ target_classes (list[torch.Tensor]): Class labels. Required for
+ training. Defaults to None.
+
+ Returns:
+ MaskRCNNOut: Raw model outputs.
+ """
+ features = self.fpn(self.basemodel(images))
+ outputs = self.faster_rcnn_head(
+ features, images_hw, target_boxes, target_classes
+ )
+ assert outputs.sampled_proposals is not None
+ assert outputs.sampled_targets is not None
+ pos_proposals = apply_mask(
+ [torch.eq(label, 1) for label in outputs.sampled_targets.labels],
+ outputs.sampled_proposals.boxes,
+ )[0]
+ mask_outs = self.mask_head(features, pos_proposals)
+ return MaskRCNNOut(outputs, mask_outs)
+
+ def forward_test(
+ self,
+ images: torch.Tensor,
+ images_hw: list[tuple[int, int]],
+ original_hw: list[tuple[int, int]],
+ ) -> MaskDetectionOut:
+ """Forward testing stage.
+
+ Args:
+ images (torch.Tensor): Input images.
+ images_hw (list[tuple[int, int]]): Input image resolutions.
+ original_hw (list[tuple[int, int]]): Original image resolutions
+ (before padding and resizing).
+
+ Returns:
+ MaskDetectionOut: Predicted outputs.
+ """
+ features = self.fpn(self.basemodel(images))
+ outs = self.faster_rcnn_head(features, images_hw)
+ boxes, scores, class_ids = self.transform_outs(
+ *outs.roi, outs.proposals.boxes, images_hw
+ )
+ mask_outs = self.mask_head(features, boxes)
+ for i, boxs in enumerate(boxes):
+ boxes[i] = scale_and_clip_boxes(boxs, original_hw[i], images_hw[i])
+ mask_preds = [m.sigmoid() for m in mask_outs.mask_pred]
+ masks = self.det2mask(
+ mask_preds, boxes, scores, class_ids, original_hw
+ )
+ return MaskDetectionOut(DetOut(boxes, scores, class_ids), masks)
diff --git a/vis4d/model/detect/retinanet.py b/vis4d/model/detect/retinanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..204bd866641d8c085ad40ac21a025bbfb8fcdadb
--- /dev/null
+++ b/vis4d/model/detect/retinanet.py
@@ -0,0 +1,193 @@
+"""RetinaNet model implementation and runtime."""
+
+from __future__ import annotations
+
+from torch import Tensor, nn
+
+from vis4d.common.ckpt import load_model_checkpoint
+from vis4d.common.typing import LossesType
+from vis4d.op.base.resnet import ResNet
+from vis4d.op.box.anchor import AnchorGenerator
+from vis4d.op.box.box2d import scale_and_clip_boxes
+from vis4d.op.box.encoder import DeltaXYWHBBoxEncoder
+from vis4d.op.box.matchers import Matcher
+from vis4d.op.box.samplers import Sampler
+from vis4d.op.detect.common import DetOut
+from vis4d.op.detect.retinanet import (
+ Dense2Det,
+ RetinaNetHead,
+ RetinaNetHeadLoss,
+ RetinaNetOut,
+)
+from vis4d.op.fpp.fpn import FPN, ExtraFPNBlock
+
+REV_KEYS = [
+ (r"^backbone\.", "basemodel."),
+ (r"^bbox_head\.", "retinanet_head."),
+ (r"^neck.lateral_convs\.", "fpn.inner_blocks."),
+ (r"^neck.fpn_convs\.", "fpn.layer_blocks."),
+ (r"^fpn.layer_blocks.3\.", "fpn.extra_blocks.convs.0."),
+ (r"^fpn.layer_blocks.4\.", "fpn.extra_blocks.convs.1."),
+ (r"\.conv.weight", ".weight"),
+ (r"\.conv.bias", ".bias"),
+]
+
+
+class RetinaNet(nn.Module):
+ """RetinaNet wrapper class for checkpointing etc."""
+
+ def __init__(self, num_classes: int, weights: None | str = None) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ num_classes (int): Number of classes.
+ weights (None | str, optional): Weights to load for model. If
+ set to "mmdet", will load MMDetection pre-trained weights.
+ Defaults to None.
+ """
+ super().__init__()
+ self.basemodel = ResNet(
+ "resnet50", pretrained=True, trainable_layers=3
+ )
+ self.fpn = FPN(
+ self.basemodel.out_channels[3:],
+ 256,
+ ExtraFPNBlock(2, 2048, 256, add_extra_convs="on_input"),
+ start_index=3,
+ )
+ self.retinanet_head = RetinaNetHead(
+ num_classes=num_classes, in_channels=256
+ )
+ self.transform_outs = Dense2Det(
+ self.retinanet_head.anchor_generator,
+ self.retinanet_head.box_decoder,
+ num_pre_nms=1000,
+ max_per_img=100,
+ nms_threshold=0.5,
+ score_thr=0.05,
+ )
+
+ if weights == "mmdet":
+ weights = (
+ "mmdet://retinanet/retinanet_r50_fpn_2x_coco/"
+ "retinanet_r50_fpn_2x_coco_20200131-fdb43119.pth"
+ )
+ load_model_checkpoint(self, weights, rev_keys=REV_KEYS)
+ elif weights is not None:
+ load_model_checkpoint(self, weights)
+
+ def forward(
+ self,
+ images: Tensor,
+ input_hw: None | list[tuple[int, int]] = None,
+ original_hw: None | list[tuple[int, int]] = None,
+ ) -> RetinaNetOut | DetOut:
+ """Forward pass.
+
+ Args:
+ images (Tensor): Input images.
+ input_hw (None | list[tuple[int, int]], optional): Input image
+ resolutions. Defaults to None.
+ original_hw (None | list[tuple[int, int]], optional): Original
+ image resolutions (before padding and resizing). Required for
+ testing. Defaults to None.
+
+ Returns:
+ RetinaNetOut | DetOut: Either raw model outputs (for training) or
+ predicted outputs (for testing).
+ """
+ if self.training:
+ return self.forward_train(images)
+ assert input_hw is not None and original_hw is not None
+ return self.forward_test(images, input_hw, original_hw)
+
+ def forward_train(self, images: Tensor) -> RetinaNetOut:
+ """Forward training stage.
+
+ Args:
+ images (Tensor): Input images.
+
+ Returns:
+ RetinaNetOut: Raw model outputs.
+ """
+ features = self.fpn(self.basemodel(images))
+ return self.retinanet_head(features[-5:])
+
+ def forward_test(
+ self,
+ images: Tensor,
+ images_hw: list[tuple[int, int]],
+ original_hw: list[tuple[int, int]],
+ ) -> DetOut:
+ """Forward testing stage.
+
+ Args:
+ images (Tensor): Input images.
+ images_hw (list[tuple[int, int]]): Input image resolutions.
+ original_hw (list[tuple[int, int]]): Original image resolutions
+ (before padding and resizing).
+
+ Returns:
+ DetOut: Predicted outputs.
+ """
+ features = self.fpn(self.basemodel(images))
+ outs = self.retinanet_head(features[-5:])
+ boxes, scores, class_ids = self.transform_outs(
+ cls_outs=outs.cls_score,
+ reg_outs=outs.bbox_pred,
+ images_hw=images_hw,
+ )
+ for i, boxs in enumerate(boxes):
+ boxes[i] = scale_and_clip_boxes(boxs, original_hw[i], images_hw[i])
+ return DetOut(boxes, scores, class_ids)
+
+
+class RetinaNetLoss(nn.Module):
+ """RetinaNet Loss."""
+
+ def __init__(
+ self,
+ anchor_generator: AnchorGenerator,
+ box_encoder: DeltaXYWHBBoxEncoder,
+ box_matcher: Matcher,
+ box_sampler: Sampler,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ anchor_generator (AnchorGenerator): Anchor generator for RPN.
+ box_encoder (BoxEncoder2D): Bounding box encoder.
+ box_matcher (BaseMatcher): Bounding box matcher.
+ box_sampler (BaseSampler): Bounding box sampler.
+ """
+ super().__init__()
+ self.retinanet_loss = RetinaNetHeadLoss(
+ anchor_generator, box_encoder, box_matcher, box_sampler
+ )
+
+ def forward(
+ self,
+ outputs: RetinaNetOut,
+ images_hw: list[tuple[int, int]],
+ target_boxes: list[Tensor],
+ target_classes: list[Tensor],
+ ) -> LossesType:
+ """Forward of loss function.
+
+ Args:
+ outputs (RetinaNetOut): Raw model outputs.
+ images_hw (list[tuple[int, int]]): Input image resolutions.
+ target_boxes (list[Tensor]): Bounding box labels.
+ target_classes (list[Tensor]): Class labels.
+
+ Returns:
+ LossesType: Dictionary of model losses.
+ """
+ losses = self.retinanet_loss(
+ outputs.cls_score,
+ outputs.bbox_pred,
+ target_boxes,
+ images_hw,
+ target_classes,
+ )
+ return losses._asdict()
diff --git a/vis4d/model/detect/yolox.py b/vis4d/model/detect/yolox.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bfd9f7195977948a3b10243aaece0426390edb6
--- /dev/null
+++ b/vis4d/model/detect/yolox.py
@@ -0,0 +1,154 @@
+"""YOLOX model implementation and runtime."""
+
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+from vis4d.common.ckpt import load_model_checkpoint
+from vis4d.op.base import BaseModel, CSPDarknet
+from vis4d.op.box.box2d import scale_and_clip_boxes
+from vis4d.op.detect.common import DetOut
+from vis4d.op.detect.yolox import YOLOXHead, YOLOXOut, YOLOXPostprocess
+from vis4d.op.fpp import YOLOXPAFPN, FeaturePyramidProcessing
+
+REV_KEYS = [
+ (r"^backbone\.", "basemodel."),
+ (r"^bbox_head\.", "yolox_head."),
+ (r"^neck\.", "fpn."),
+ (r"\.bn\.", ".norm."),
+ (r"\.conv.weight", ".weight"),
+ (r"\.conv.bias", ".bias"),
+]
+
+
+class YOLOX(nn.Module):
+ """YOLOX detector."""
+
+ def __init__(
+ self,
+ num_classes: int,
+ basemodel: BaseModel | None = None,
+ fpn: FeaturePyramidProcessing | None = None,
+ yolox_head: YOLOXHead | None = None,
+ postprocessor: YOLOXPostprocess | None = None,
+ weights: None | str = None,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ num_classes (int): Number of classes.
+ basemodel (BaseModel, optional): Base model. Defaults to None. If
+ None, will use CSPDarknet.
+ fpn (FeaturePyramidProcessing, optional): Feature Pyramid
+ Processing. Defaults to None. If None, will use YOLOXPAFPN.
+ yolox_head (YOLOXHead, optional): YOLOX head. Defaults to None. If
+ None, will use YOLOXHead.
+ postprocessor (YOLOXPostprocess, optional): Post processor.
+ Defaults to None. If None, will use YOLOXPostprocess.
+ weights (None | str, optional): Weights to load for model. If
+ set to "mmdet", will load MMDetection pre-trained weights.
+ Defaults to None.
+ """
+ super().__init__()
+ self.basemodel = (
+ CSPDarknet(deepen_factor=0.33, widen_factor=0.5)
+ if basemodel is None
+ else basemodel
+ )
+ self.fpn = (
+ YOLOXPAFPN([128, 256, 512], 128, num_csp_blocks=1)
+ if fpn is None
+ else fpn
+ )
+ self.yolox_head = (
+ YOLOXHead(
+ num_classes=num_classes, in_channels=128, feat_channels=128
+ )
+ if yolox_head is None
+ else yolox_head
+ )
+ self.postprocessor = (
+ YOLOXPostprocess(
+ self.yolox_head.point_generator,
+ self.yolox_head.box_decoder,
+ nms_threshold=0.65,
+ score_thr=0.01,
+ )
+ if postprocessor is None
+ else postprocessor
+ )
+
+ if weights is not None:
+ if weights.startswith("mmdet://") or weights.startswith(
+ "bdd100k://"
+ ):
+ load_model_checkpoint(self, weights, rev_keys=REV_KEYS)
+ else:
+ load_model_checkpoint(self, weights)
+
+ def forward(
+ self,
+ images: torch.Tensor,
+ input_hw: None | list[tuple[int, int]] = None,
+ original_hw: None | list[tuple[int, int]] = None,
+ ) -> YOLOXOut | DetOut:
+ """Forward pass.
+
+ Args:
+ images (torch.Tensor): Input images.
+ input_hw (None | list[tuple[int, int]], optional): Input image
+ resolutions. Defaults to None.
+ original_hw (None | list[tuple[int, int]], optional): Original
+ image resolutions (before padding and resizing). Required for
+ testing. Defaults to None.
+
+ Returns:
+ YOLOXOut | DetOut: Either raw model outputs (for training) or
+ predicted outputs (for testing).
+ """
+ if self.training:
+ return self.forward_train(images)
+ assert input_hw is not None and original_hw is not None
+ return self.forward_test(images, input_hw, original_hw)
+
+ def forward_train(self, images: torch.Tensor) -> YOLOXOut:
+ """Forward training stage.
+
+ Args:
+ images (torch.Tensor): Input images.
+
+ Returns:
+ YOLOXOut: Raw model outputs.
+ """
+ features = self.fpn(self.basemodel(images.contiguous()))
+ return self.yolox_head(features[-3:])
+
+ def forward_test(
+ self,
+ images: torch.Tensor,
+ images_hw: list[tuple[int, int]],
+ original_hw: list[tuple[int, int]],
+ ) -> DetOut:
+ """Forward testing stage.
+
+ Args:
+ images (torch.Tensor): Input images.
+ images_hw (list[tuple[int, int]]): Input image resolutions.
+ original_hw (list[tuple[int, int]]): Original image resolutions
+ (before padding and resizing).
+
+ Returns:
+ DetOut: Predicted outputs.
+ """
+ features = self.fpn(self.basemodel(images))
+ outs = self.yolox_head(features[-3:])
+ boxes, scores, class_ids = self.postprocessor(
+ cls_outs=outs.cls_score,
+ reg_outs=outs.bbox_pred,
+ obj_outs=outs.objectness,
+ images_hw=images_hw,
+ )
+ for i, boxs in enumerate(boxes):
+ boxes[i] = scale_and_clip_boxes(boxs, original_hw[i], images_hw[i])
+ return DetOut(boxes, scores, class_ids)
diff --git a/vis4d/model/detect3d/__init__.py b/vis4d/model/detect3d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..44403d249f007a38297424080538f51d87b50463
--- /dev/null
+++ b/vis4d/model/detect3d/__init__.py
@@ -0,0 +1 @@
+"""3D Detection Models."""
diff --git a/vis4d/model/detect3d/bevformer.py b/vis4d/model/detect3d/bevformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..34f6d6621283ac31d5f88e6d6648801eccfe7bac
--- /dev/null
+++ b/vis4d/model/detect3d/bevformer.py
@@ -0,0 +1,162 @@
+"""BEVFromer model implementation.
+
+This file composes the operations associated with BEVFormer
+`https://arxiv.org/abs/2203.17270` into the full model implementation.
+"""
+
+from __future__ import annotations
+
+import copy
+from typing import TypedDict
+
+import torch
+from torch import Tensor, nn
+
+from vis4d.common.ckpt import load_model_checkpoint
+from vis4d.op.base import BaseModel
+from vis4d.op.detect3d.bevformer import BEVFormerHead, GridMask
+from vis4d.op.detect3d.common import Detect3DOut
+from vis4d.op.fpp.fpn import FPN, ExtraFPNBlock
+
+REV_KEYS = [
+ (r"^img_backbone\.", "basemodel."),
+ (r"^img_neck.lateral_convs\.", "fpn.inner_blocks."),
+ (r"^img_neck.fpn_convs\.", "fpn.layer_blocks."),
+ (r"^fpn.layer_blocks.3\.", "fpn.extra_blocks.convs.0."),
+ (r"\.conv.weight", ".weight"),
+ (r"\.conv.bias", ".bias"),
+]
+
+
+class PrevFrameInfo(TypedDict):
+ """Previous frame information."""
+
+ scene_name: str
+ prev_bev: Tensor | None
+ prev_pos: Tensor
+ prev_angle: Tensor
+
+
+class BEVFormer(nn.Module):
+ """BEVFormer 3D Detector."""
+
+ def __init__(
+ self,
+ basemodel: BaseModel,
+ fpn: FPN | None = None,
+ pts_bbox_head: BEVFormerHead | None = None,
+ weights: str | None = None,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ basemodel (BaseModel): Base model network.
+ fpn (FPN, optional): Feature Pyramid Network. Defaults to None. If
+ None, a default FPN will be used.
+ pts_bbox_head (BEVFormerHead, optional): BEVFormer head. Defaults
+ to None. If None, a default BEVFormer head will be used.
+ weights (str, optional): Path to the checkpoint to load. Defaults
+ to None.
+ """
+ super().__init__()
+ self.basemodel = basemodel
+ self.fpn = fpn or FPN(
+ self.basemodel.out_channels[3:],
+ 256,
+ extra_blocks=ExtraFPNBlock(
+ extra_levels=1, in_channels=256, out_channels=256
+ ),
+ start_index=3,
+ )
+
+ self.grid_mask = GridMask(
+ True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7
+ )
+
+ self.pts_bbox_head = pts_bbox_head or BEVFormerHead()
+
+ # Temporal information
+ self.prev_frame_info = PrevFrameInfo(
+ scene_name="",
+ prev_bev=None,
+ prev_pos=torch.zeros(3),
+ prev_angle=torch.zeros(1),
+ )
+
+ if weights is not None:
+ load_model_checkpoint(self, weights, rev_keys=REV_KEYS)
+
+ def extract_feat(self, images_list: list[Tensor]) -> list[Tensor]:
+ """Extract features of images."""
+ n = len(images_list) # N
+ b = images_list[0].shape[0] # B
+ images = torch.stack(images_list, dim=1) # [B, N, C, H, W]
+ images = images.view(-1, *images.shape[2:]) # [B*N, C, H, W]
+
+ # grid mask
+ if self.training:
+ images = self.grid_mask(images)
+
+ features = self.basemodel(images)
+ features = self.fpn(features)[self.fpn.start_index :]
+
+ img_feats = []
+ for img_feat in features:
+ _, c, h, w = img_feat.size()
+ img_feats.append(img_feat.view(b, n, c, h, w))
+
+ return img_feats
+
+ def forward(
+ self,
+ images: list[Tensor],
+ can_bus: list[list[float]],
+ scene_names: list[str],
+ cam_intrinsics: list[Tensor],
+ cam_extrinsics: list[Tensor],
+ lidar_extrinsics: list[Tensor],
+ ) -> Detect3DOut:
+ """Forward."""
+ # Parse lidar extrinsics from LIDAR sensor data.
+ lidar_extrinsics_tensor = lidar_extrinsics[0]
+ can_bus_tensor = torch.tensor(
+ can_bus, dtype=torch.float32, device=images[0].device
+ )
+
+ if scene_names[0] != self.prev_frame_info["scene_name"]:
+ # the first sample of each scene is truncated
+ self.prev_frame_info["prev_bev"] = None
+
+ # update idx
+ self.prev_frame_info["scene_name"] = scene_names[0]
+
+ # Get the delta of ego position and angle between two timestamps.
+ tmp_pos = copy.deepcopy(can_bus_tensor[0][:3])
+ tmp_angle = copy.deepcopy(can_bus_tensor[0][-1])
+ if self.prev_frame_info["prev_bev"] is not None:
+ can_bus_tensor[0][:3] -= self.prev_frame_info["prev_pos"]
+ can_bus_tensor[0][-1] -= self.prev_frame_info["prev_angle"]
+ else:
+ can_bus_tensor[0][:3] = 0
+ can_bus_tensor[0][-1] = 0
+
+ images_hw = (int(images[0].shape[-2]), int(images[0].shape[-1]))
+ img_feats = self.extract_feat(images)
+
+ out, bev_embed = self.pts_bbox_head(
+ img_feats,
+ can_bus_tensor,
+ images_hw,
+ cam_intrinsics,
+ cam_extrinsics,
+ lidar_extrinsics_tensor,
+ prev_bev=self.prev_frame_info["prev_bev"],
+ )
+
+ # During inference, we save the BEV features and ego motion of each
+ # timestamp.
+ self.prev_frame_info["prev_pos"] = tmp_pos
+ self.prev_frame_info["prev_angle"] = tmp_angle
+ self.prev_frame_info["prev_bev"] = bev_embed
+
+ return out
diff --git a/vis4d/model/motion/__init__.py b/vis4d/model/motion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9db6284e04c7125a12886cd30b4c96b3556552d
--- /dev/null
+++ b/vis4d/model/motion/__init__.py
@@ -0,0 +1 @@
+"""Motion models."""
diff --git a/vis4d/model/motion/velo_lstm.py b/vis4d/model/motion/velo_lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c08ae5e874b30d984467731002394a8ccaa499c
--- /dev/null
+++ b/vis4d/model/motion/velo_lstm.py
@@ -0,0 +1,309 @@
+"""VeloLSTM 3D motion model."""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+import torch
+from torch import Tensor, nn
+
+from vis4d.common.ckpt import load_model_checkpoint
+from vis4d.op.geometry.rotation import acute_angle, normalize_angle
+from vis4d.op.layer.weight_init import xavier_init
+
+
+class VeloLSTMOut(NamedTuple):
+ """VeloLSTM output."""
+
+ loc_preds: Tensor
+ loc_refines: Tensor
+
+
+class VeloLSTM(nn.Module):
+ """Estimating object location in world coordinates.
+
+ Prediction LSTM:
+ Input: 5 frames velocity
+ Output: Next frame location
+ Updating LSTM:
+ Input: predicted location and observed location
+ Output: Refined location
+ """
+
+ def __init__(
+ self,
+ num_frames: int = 5,
+ feature_dim: int = 64,
+ hidden_size: int = 128,
+ num_layers: int = 2,
+ loc_dim: int = 7,
+ dropout: float = 0.1,
+ weights: str | None = None,
+ ) -> None:
+ """Init."""
+ super().__init__()
+ self.num_frames = num_frames
+ self.feature_dim = feature_dim
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.loc_dim = loc_dim
+
+ self.vel2feat = nn.Linear(
+ loc_dim,
+ feature_dim,
+ )
+
+ self.pred_lstm = nn.LSTM(
+ input_size=feature_dim,
+ hidden_size=hidden_size,
+ dropout=dropout,
+ num_layers=num_layers,
+ )
+
+ self.pred2atten = nn.Linear(
+ hidden_size,
+ loc_dim,
+ bias=False,
+ )
+
+ self.conf2feat = nn.Linear(
+ 1,
+ feature_dim,
+ bias=False,
+ )
+
+ self.refine_lstm = nn.LSTM(
+ input_size=3 * feature_dim,
+ hidden_size=hidden_size,
+ dropout=dropout,
+ num_layers=num_layers,
+ )
+
+ self.conf2atten = nn.Linear(
+ hidden_size,
+ loc_dim,
+ bias=False,
+ )
+
+ self._init_weights()
+
+ if weights is not None:
+ load_model_checkpoint(
+ self,
+ weights,
+ map_location="cpu",
+ rev_keys=[(r"^model\.", ""), (r"^module\.", "")],
+ )
+
+ def _init_weights(self) -> None:
+ """Initialize model weights."""
+ xavier_init(self.vel2feat)
+ xavier_init(self.pred2atten)
+ xavier_init(self.conf2feat)
+ xavier_init(self.conf2atten)
+ init_lstm_module(self.pred_lstm)
+ init_lstm_module(self.refine_lstm)
+
+ def init_hidden(
+ self, device: torch.device, batch_size: int = 1
+ ) -> tuple[Tensor, Tensor]:
+ """Initializae hidden state.
+
+ The axes semantics are (num_layers, minibatch_size, hidden_dim)
+ """
+ return (
+ torch.zeros(self.num_layers, batch_size, self.hidden_size).to(
+ device
+ ),
+ torch.zeros(self.num_layers, batch_size, self.hidden_size).to(
+ device
+ ),
+ )
+
+ def refine(
+ self,
+ location: Tensor,
+ observation: Tensor,
+ prev_location: Tensor,
+ confidence: Tensor,
+ hc_0: tuple[Tensor, Tensor],
+ ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
+ """Refine predicted location using single frame estimation at t+1.
+
+ Input:
+ location: (num_batch x loc_dim), location from prediction
+ observation: (num_batch x loc_dim), location from single frame
+ estimation
+ prev_location: (num_batch x loc_dim), refined location
+ confidence: (num_batch X 1), depth estimation confidence
+ hc_0: (num_layers, num_batch, hidden_size), tuple of hidden and
+ cell
+ Middle:
+ loc_embed: (1, num_batch x feature_dim), predicted location feature
+ obs_embed: (1, num_batch x feature_dim), single frame location
+ feature
+ conf_embed: (1, num_batch x feature_dim), depth estimation
+ confidence feature
+ embed: (1, num_batch x 2*feature_dim), location feature
+ out: (1 x num_batch x hidden_size), lstm output
+ Output:
+ hc_n: (num_layers, num_batch, hidden_size), tuple of updated
+ hidden, cell
+ output_pred: (num_batch x loc_dim), predicted location
+ """
+ num_batch = location.shape[0]
+
+ pred_vel = location - prev_location
+ obsv_vel = observation - prev_location
+
+ # Embed feature to hidden_size
+ loc_embed = self.vel2feat(pred_vel).view(num_batch, self.feature_dim)
+ obs_embed = self.vel2feat(obsv_vel).view(num_batch, self.feature_dim)
+ conf_embed = self.conf2feat(confidence).view(
+ num_batch, self.feature_dim
+ )
+ embed = torch.cat(
+ [
+ loc_embed,
+ obs_embed,
+ conf_embed,
+ ],
+ dim=1,
+ ).view(1, num_batch, 3 * self.feature_dim)
+
+ out, (h_n, c_n) = self.refine_lstm(embed, hc_0)
+
+ delta_vel_atten = torch.sigmoid(self.conf2atten(out)).view(
+ num_batch, self.loc_dim
+ )
+
+ output_pred = (
+ delta_vel_atten * obsv_vel
+ + (1.0 - delta_vel_atten) * pred_vel
+ + prev_location
+ )
+
+ return output_pred, (h_n, c_n)
+
+ def predict(
+ self,
+ vel_history: Tensor,
+ location: Tensor,
+ hc_0: tuple[Tensor, Tensor],
+ ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
+ """Predict location at t+1 using updated location at t.
+
+ Input:
+ vel_history: (num_seq, num_batch, loc_dim), velocity from previous
+ num_seq updates
+ location: (num_batch, loc_dim), location from previous update
+ hc_0: (num_layers, num_batch, hidden_size), tuple of hidden and
+ cell
+ Middle:
+ embed: (num_seq, num_batch x feature_dim), location feature
+ out: (num_seq x num_batch x hidden_size), lstm output
+ attention_logit: (num_seq x num_batch x loc_dim), the predicted
+ residual
+ Output:
+ hc_n: (num_layers, num_batch, hidden_size), tuple of updated
+ hidden, cell
+ output_pred: (num_batch x loc_dim), predicted location
+ """
+ num_seq, num_batch, _ = vel_history.shape
+
+ # Embed feature to hidden_size
+ embed = self.vel2feat(vel_history).view(
+ num_seq, num_batch, self.feature_dim
+ )
+
+ out, (h_n, c_n) = self.pred_lstm(embed, hc_0)
+
+ attention_logit = self.pred2atten(out).view(
+ num_seq, num_batch, self.loc_dim
+ )
+ attention = torch.softmax(attention_logit, dim=0)
+
+ output_pred = torch.sum(attention * vel_history, dim=0) + location
+
+ return output_pred, (h_n, c_n)
+
+ def forward(self, pred_traj: Tensor) -> VeloLSTMOut:
+ """Forward of QD3DTrackGraph in training stage."""
+ loc_preds_list = []
+ loc_refines_list = []
+
+ hidden_predict = self.init_hidden(
+ pred_traj.device, batch_size=pred_traj.shape[0]
+ )
+ hidden_refine = self.init_hidden(
+ pred_traj.device, batch_size=pred_traj.shape[0]
+ )
+
+ vel_history = pred_traj.new_zeros(
+ self.num_frames, pred_traj.shape[0], self.loc_dim
+ )
+
+ # Starting condition
+ pred_traj[:, :, 6] = normalize_angle(pred_traj[:, :, 6])
+ prev_refine = pred_traj[:, 0, : self.loc_dim]
+ loc_pred = pred_traj[:, 1, : self.loc_dim]
+
+ # LSTM
+ for i in range(1, pred_traj.shape[1]):
+ # Update
+ loc_pred[:, 6] = normalize_angle(loc_pred[:, 6])
+
+ for batch_id in range(pred_traj.shape[0]):
+ # acute angle
+ loc_pred[batch_id, 6] = acute_angle(
+ loc_pred[batch_id, 6], pred_traj[batch_id, i, 6]
+ )
+
+ loc_refine, hidden_refine = self.refine(
+ loc_pred.detach().clone(),
+ pred_traj[:, i, : self.loc_dim],
+ prev_refine.detach().clone(),
+ pred_traj[:, i, -1].unsqueeze(-1),
+ hidden_refine,
+ )
+ loc_refine[:, 6] = normalize_angle(loc_refine[:, 6])
+
+ if i == 1:
+ vel_history = torch.cat(
+ [(loc_refine - prev_refine).unsqueeze(0)] * self.num_frames
+ )
+ else:
+ vel_history = torch.cat(
+ [vel_history[1:], (loc_refine - prev_refine).unsqueeze(0)],
+ dim=0,
+ )
+ prev_refine = loc_refine
+
+ # Predict
+ loc_pred, hidden_predict = self.predict(
+ vel_history, loc_refine.detach().clone(), hidden_predict
+ )
+ loc_pred[:, 6] = normalize_angle(loc_pred[:, 6])
+
+ loc_refines_list.append(loc_refine)
+ loc_preds_list.append(loc_pred)
+
+ loc_refines = torch.cat(loc_refines_list, dim=1).view(
+ pred_traj.shape[0], -1, self.loc_dim
+ )
+ loc_preds = torch.cat(loc_preds_list, dim=1).view(
+ pred_traj.shape[0], -1, self.loc_dim
+ )
+
+ return VeloLSTMOut(loc_preds=loc_preds, loc_refines=loc_refines)
+
+
+def init_lstm_module(layer: nn.Module) -> None:
+ """Initialize LSTM weights and biases."""
+ for name, param in layer.named_parameters():
+ if "weight_ih" in name:
+ torch.nn.init.xavier_uniform_(param.data)
+ elif "weight_hh" in name:
+ torch.nn.init.orthogonal_(param.data)
+ elif "bias" in name:
+ param.data.fill_(0)
diff --git a/vis4d/model/seg/__init__.py b/vis4d/model/seg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa068c2da72651b6be0701a3f570ca51b462c846
--- /dev/null
+++ b/vis4d/model/seg/__init__.py
@@ -0,0 +1 @@
+"""Semantic segmentation models."""
diff --git a/vis4d/model/seg/fcn_resnet.py b/vis4d/model/seg/fcn_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5a7ff6e1feca6daba60a6b85777f9ce4d1f9a91
--- /dev/null
+++ b/vis4d/model/seg/fcn_resnet.py
@@ -0,0 +1,85 @@
+"""FCN Resnet Implementation."""
+
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+from vis4d.op.base.resnet import ResNet
+from vis4d.op.seg.fcn import FCNHead, FCNOut
+
+REV_KEYS = [
+ (r"^backbone\.", "basemodel."),
+ (r"^aux_classifier\.", "fcn.heads.0."),
+ (r"^classifier\.", "fcn.heads.1."),
+]
+
+
+class FCNResNet(nn.Module):
+ """FCN with ResNet basemodel for semantic segmentation."""
+
+ def __init__(
+ self,
+ base_model: str = "resnet50",
+ num_classes: int = 21,
+ resize: None | tuple[int, int] = (520, 520),
+ ) -> None:
+ """FCN with ResNet basemodel, following torchvision implementation.
+
+ _.
+
+ model: FCNResNet(base_model="resnet50")
+ - dataset: Coco2017
+ - recipe: vis4d/model/segment/FCNResNet_coco_training.py
+ - metrics:
+ - mIoU: 62.52
+ - Acc: 90.50
+ """
+ super().__init__()
+ if base_model.startswith("resnet"):
+ self.basemodel = ResNet(
+ base_model,
+ pretrained=True,
+ replace_stride_with_dilation=[False, True, True],
+ )
+ else:
+ raise ValueError("base model not supported!")
+ self.fcn = FCNHead(
+ self.basemodel.out_channels[4:], num_classes, resize=resize
+ )
+
+ def forward_train(self, images: torch.Tensor) -> FCNOut:
+ """Forward pass for training.
+
+ Args:
+ images (torch.Tensor): Input images.
+
+ Returns:
+ FCNOut: Raw model predictions.
+ """
+ return self.forward(images)
+
+ def forward_test(self, images: torch.Tensor) -> FCNOut:
+ """Forward pass for testing.
+
+ Args:
+ images (torch.Tensor): Input images.
+
+ Returns:
+ FCNOut: Raw model predictions.
+ """
+ return self.forward(images)
+
+ def forward(self, images: torch.Tensor) -> FCNOut:
+ """Forward pass.
+
+ Args:
+ images (torch.Tensor): Input images.
+
+ Returns:
+ FCNOut: Raw model predictions.
+ """
+ features = self.basemodel(images)
+ out = self.fcn(features)
+ return out
diff --git a/vis4d/model/seg/semantic_fpn.py b/vis4d/model/seg/semantic_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b8420840f100f30f88fedb0ae7f4eb7917c2b08
--- /dev/null
+++ b/vis4d/model/seg/semantic_fpn.py
@@ -0,0 +1,152 @@
+"""SemanticFPN Implementation."""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from vis4d.common.ckpt import load_model_checkpoint
+from vis4d.op.base import BaseModel, ResNetV1c
+from vis4d.op.fpp.fpn import FPN
+from vis4d.op.mask.util import clip_mask
+from vis4d.op.seg.semantic_fpn import SemanticFPNHead, SemanticFPNOut
+
+REV_KEYS = [
+ (r"^decode_head\.", "seg_head."),
+ (r"^classifier\.", "fcn.heads.1."),
+ (r"^backbone\.", "basemodel."),
+ (r"^neck.lateral_convs\.", "fpn.inner_blocks."),
+ (r"^neck.fpn_convs\.", "fpn.layer_blocks."),
+ (r"\.conv.weight", ".weight"),
+ (r"\.conv.bias", ".bias"),
+]
+for ki in range(4):
+ for kj in range(5):
+ REV_KEYS += [
+ (
+ rf"^seg_head.scale_heads\.{ki}\.{kj}\.bn\.",
+ f"seg_head.scale_heads.{ki}.{kj}.norm.",
+ )
+ ]
+
+
+class MaskOut(NamedTuple):
+ """Output mask predictions."""
+
+ masks: list[torch.Tensor] # list of masks for each image
+
+
+class SemanticFPN(nn.Module):
+ """Semantic FPN.
+
+ Args:
+ num_classes (int): Number of classes.
+ resize (bool): Resize output to input size.
+ weights (None | str): Pre-trained weights.
+ basemodel (None | BaseModel): Base model to use. If None is passed,
+ this will default to ResNetV1c
+ """
+
+ def __init__(
+ self,
+ num_classes: int,
+ resize: bool = True,
+ weights: None | str = None,
+ basemodel: None | BaseModel = None,
+ ):
+ """Init."""
+ super().__init__()
+ self.resize = resize
+ if basemodel is None:
+ basemodel = ResNetV1c(
+ "resnet50_v1c",
+ pretrained=True,
+ trainable_layers=5,
+ norm_frozen=False,
+ )
+
+ self.basemodel = basemodel
+ self.fpn = FPN(self.basemodel.out_channels[2:], 256, extra_blocks=None)
+ self.seg_head = SemanticFPNHead(num_classes, 256)
+
+ if weights is not None:
+ if weights.startswith("mmseg://") or weights.startswith(
+ "bdd100k://"
+ ):
+ load_model_checkpoint(self, weights, rev_keys=REV_KEYS)
+ else:
+ load_model_checkpoint(self, weights)
+
+ def forward_train(self, images: torch.Tensor) -> SemanticFPNOut:
+ """Forward pass for training.
+
+ Args:
+ images (torch.Tensor): Input images.
+
+ Returns:
+ SemanticFPNOut: Raw model predictions.
+ """
+ features = self.fpn(self.basemodel(images.contiguous()))
+ out = self.seg_head(features)
+ if self.resize:
+ return SemanticFPNOut(
+ outputs=F.interpolate(
+ out.outputs,
+ scale_factor=4,
+ mode="bilinear",
+ align_corners=False,
+ )
+ )
+ return out
+
+ def forward_test(
+ self, images: torch.Tensor, original_hw: list[tuple[int, int]]
+ ) -> MaskOut:
+ """Forward pass for testing.
+
+ Args:
+ images (torch.Tensor): Input images.
+ original_hw (list[tuple[int, int]], optional): Original image
+ resolutions (before padding and resizing). Required for
+ testing.
+
+ Returns:
+ SemanticFPNOut: Raw model predictions.
+ """
+ features = self.fpn(self.basemodel(images))
+ out = self.seg_head(features)
+
+ new_masks = []
+ for i, outputs in enumerate(out.outputs):
+ opt = F.interpolate(
+ outputs.unsqueeze(0),
+ scale_factor=4,
+ mode="bilinear",
+ align_corners=False,
+ ).squeeze(0)
+ new_masks.append(clip_mask(opt, original_hw[i]).argmax(dim=0))
+ return MaskOut(masks=new_masks)
+
+ def forward(
+ self,
+ images: torch.Tensor,
+ original_hw: None | list[tuple[int, int]] = None,
+ ) -> SemanticFPNOut | MaskOut:
+ """Forward pass.
+
+ Args:
+ images (torch.Tensor): Input images.
+ original_hw (None | list[tuple[int, int]], optional): Original
+ image resolutions (before padding and resizing). Required for
+ testing. Defaults to None.
+
+ Returns:
+ MaskOut: Raw model predictions.
+ """
+ if self.training:
+ return self.forward_train(images)
+ assert original_hw is not None
+ return self.forward_test(images, original_hw)
diff --git a/vis4d/model/segment3d/pointnet.py b/vis4d/model/segment3d/pointnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff85d098cfda4caf5ef875d5af2223bd69053f99
--- /dev/null
+++ b/vis4d/model/segment3d/pointnet.py
@@ -0,0 +1,143 @@
+"""Implementation of Pointnet."""
+
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+from vis4d.common.ckpt import load_model_checkpoint
+from vis4d.common.typing import LossesType, ModelOutput
+from vis4d.data.const import CommonKeys
+from vis4d.op.base.pointnet import PointNetSegmentation, PointNetSemanticsOut
+from vis4d.op.loss.orthogonal_transform_loss import (
+ OrthogonalTransformRegularizationLoss,
+)
+
+
+class PointnetSegmentationModel(nn.Module):
+ """Simple Segmentation Model using Pointnet."""
+
+ def __init__(
+ self,
+ num_classes: int = 11,
+ in_dimensions: int = 3,
+ weights: str | None = None,
+ ) -> None:
+ """Simple Segmentation Model using Pointnet.
+
+ Args:
+ num_classes: Number of semantic classes
+ in_dimensions: Input dimension
+ weights: Path to weight file
+ """
+ super().__init__()
+ self.model = PointNetSegmentation(
+ n_classes=num_classes, in_dimensions=in_dimensions
+ )
+ if weights is not None:
+ load_model_checkpoint(self, weights)
+
+ def __call__(
+ self, data: torch.Tensor, target: torch.Tensor | None = None
+ ) -> PointNetSemanticsOut | ModelOutput:
+ """Runs the semantic model.
+
+ Args:
+ data: Input Tensor Shape [N, C, n_pts]
+ target: Target Classes shape [N, n_pts]
+ """
+ return self._call_impl(data, target)
+
+ def forward(
+ self, data: torch.Tensor, target: torch.Tensor | None = None
+ ) -> PointNetSemanticsOut | ModelOutput:
+ """Runs the semantic model.
+
+ Args:
+ data: Input Tensor Shape [N, C, n_pts]
+ target: Target Classes shape [N, n_pts]
+ """
+ if target is not None:
+ return self.forward_train(data, target)
+ return self.forward_test(data)
+
+ def forward_train(
+ self,
+ points: torch.Tensor,
+ target: torch.Tensor,
+ ) -> PointNetSemanticsOut:
+ """Forward training stage.
+
+ Args:
+ points: Input Tensor Shape [N, C, n_pts]
+ target: Target Classes shape [N, n_pts]
+ """
+ out = self.model(points)
+ return out
+
+ def forward_test(
+ self,
+ points: torch.Tensor,
+ ) -> ModelOutput:
+ """Forward test stage.
+
+ Args:
+ points: Input Tensor Shape [N, C, n_pts]
+ """
+ return {
+ CommonKeys.semantics3d: torch.argmax(
+ self.model(points).class_logits, dim=1
+ )
+ }
+
+
+class PointnetSegmentationLoss(nn.Module):
+ """PointnetSegmentationLoss Loss."""
+
+ def __init__(
+ self,
+ regularize_transform: bool = True,
+ ignore_index: int = 255,
+ transform_weight: float = 1e-3,
+ semantic_weights: torch.Tensor | None = None,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ regularize_transform: If true add transforms to loss
+ ignore_index: Semantic class that should be ignored
+ transform_weight: Loss weight factor for transform
+ regularization loss
+ semantic_weights: Classwise weights for semantic loss
+ """
+ super().__init__()
+ self.segmentation_loss = nn.CrossEntropyLoss(
+ weight=semantic_weights, ignore_index=ignore_index
+ )
+ self.transformation_loss = OrthogonalTransformRegularizationLoss()
+ self.regularize_transform = regularize_transform
+ self.transform_weight = transform_weight
+
+ def forward(
+ self, outputs: PointNetSemanticsOut, target: torch.Tensor
+ ) -> LossesType:
+ """Calculates the losss.
+
+ Args:
+ outputs: Pointnet output
+ target: Target Labels
+ """
+ if not self.regularize_transform:
+ dict(
+ segmentation_loss=self.segmentation_loss(
+ outputs.class_logits, target
+ )
+ )
+
+ return dict(
+ segmentation_loss=self.segmentation_loss(
+ outputs.class_logits, target
+ ),
+ transform_loss=self.transform_weight
+ * self.transformation_loss(outputs.transformations),
+ )
diff --git a/vis4d/model/segment3d/pointnetpp.py b/vis4d/model/segment3d/pointnetpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e7a65d926649ef97fdee6648404373b3fd1e029
--- /dev/null
+++ b/vis4d/model/segment3d/pointnetpp.py
@@ -0,0 +1,95 @@
+"""Pointnet++ Implementation."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor, nn
+
+from vis4d.common.ckpt import load_model_checkpoint
+from vis4d.common.typing import LossesType, ModelOutput
+from vis4d.data.const import CommonKeys as K
+from vis4d.op.base.pointnetpp import (
+ PointNet2Segmentation,
+ PointNet2SegmentationOut,
+)
+
+
+class PointNet2SegmentationModel(nn.Module):
+ """PointNet++ Segmentation Model implementaiton."""
+
+ def __init__(
+ self,
+ num_classes: int,
+ in_dimensions: int = 3,
+ weights: str | None = None,
+ ):
+ """Creates a Pointnet+++ Model.
+
+ Args:
+ num_classes (int): Number of classes
+ in_dimensions (int, optional): Input dimensions. Defaults to 3.
+ weights (str, optional): Path to weights. Defaults to None.
+ """
+ super().__init__()
+
+ self.segmentation_model = PointNet2Segmentation(
+ num_classes, in_dimensions
+ )
+
+ if weights is not None:
+ load_model_checkpoint(self, weights)
+
+ def forward(
+ self, points3d: Tensor, semantics3d: Tensor | None = None
+ ) -> PointNet2SegmentationOut | ModelOutput:
+ """Forward pass of the model. Extract semantic predictions.
+
+ Args:
+ points3d (Tensor): Input point shape [b, N, C].
+ semantics3d (torch.Tenosr): Groundtruth semantic labels of
+ shape [b, N]. Defaults to None
+
+ Returns:
+ ModelOutput: Semantic predictions of the model.
+ """
+ x = self.segmentation_model(points3d)
+ if semantics3d is not None:
+ return x
+ class_pred = torch.argmax(x.class_logits, dim=1)
+ return {K.semantics3d: class_pred}
+
+
+class Pointnet2SegmentationLoss(nn.Module):
+ """Pointnet2SegmentationLoss Loss."""
+
+ def __init__(
+ self,
+ ignore_index: int = 255,
+ semantic_weights: Tensor | None = None,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ ignore_index (int, optional): Class Index that should be ignored.
+ Defaults to 255.
+ semantic_weights (Tensor, optional): Weights for each class.
+ """
+ super().__init__()
+ self.segmentation_loss = nn.CrossEntropyLoss(
+ weight=semantic_weights, ignore_index=ignore_index
+ )
+
+ def forward(
+ self, outputs: PointNet2SegmentationOut, semantics3d: Tensor
+ ) -> LossesType:
+ """Calculates the loss.
+
+ Args:
+ outputs (PointNet2SegmentationOut): Model outputs.
+ semantics3d (Tensor): Groundtruth semantic labels.
+ """
+ return dict(
+ segmentation_loss=self.segmentation_loss(
+ outputs.class_logits, semantics3d
+ ),
+ )
diff --git a/vis4d/model/track/__init__.py b/vis4d/model/track/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e64d68e5e16c5acb9535fb1ebc7c3245745ad4f0
--- /dev/null
+++ b/vis4d/model/track/__init__.py
@@ -0,0 +1 @@
+"""Contains the implementation of 2D tracking models."""
diff --git a/vis4d/model/track/qdtrack.py b/vis4d/model/track/qdtrack.py
new file mode 100644
index 0000000000000000000000000000000000000000..a89bbb103f5113a6367ff4273ff3e917de93c871
--- /dev/null
+++ b/vis4d/model/track/qdtrack.py
@@ -0,0 +1,567 @@
+"""Quasi-dense instance similarity learning model."""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+import torch
+from torch import Tensor, nn
+
+from vis4d.common.ckpt import load_model_checkpoint
+from vis4d.model.detect.yolox import REV_KEYS as YOLOX_REV_KEYS
+from vis4d.op.base import BaseModel, CSPDarknet, ResNet
+from vis4d.op.box.box2d import scale_and_clip_boxes
+from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder
+from vis4d.op.box.poolers import MultiScaleRoIAlign
+from vis4d.op.detect.faster_rcnn import FasterRCNNHead, FRCNNOut
+from vis4d.op.detect.rcnn import RoI2Det
+from vis4d.op.detect.yolox import YOLOXHead, YOLOXOut, YOLOXPostprocess
+from vis4d.op.fpp import FPN, YOLOXPAFPN, FeaturePyramidProcessing
+from vis4d.op.track.common import TrackOut
+from vis4d.op.track.qdtrack import (
+ QDSimilarityHead,
+ QDTrackAssociation,
+ QDTrackHead,
+)
+from vis4d.state.track.qdtrack import QDTrackGraph
+
+from .util import split_key_ref_indices
+
+REV_KEYS = [
+ (r"^faster_rcnn_heads\.", "faster_rcnn_head."),
+ (r"^backbone.body\.", "basemodel."),
+ (r"^qdtrack\.", "qdtrack_head."),
+]
+
+
+class FasterRCNNQDTrackOut(NamedTuple):
+ """Output of QDtrack model."""
+
+ detector_out: FRCNNOut
+ key_images_hw: list[tuple[int, int]]
+ key_target_boxes: list[Tensor]
+ key_embeddings: list[Tensor]
+ ref_embeddings: list[list[Tensor]]
+ key_track_ids: list[Tensor]
+ ref_track_ids: list[list[Tensor]]
+
+
+class FasterRCNNQDTrack(nn.Module):
+ """Wrap QDTrack with Faster R-CNN detector."""
+
+ def __init__(
+ self,
+ num_classes: int,
+ basemodel: BaseModel | None = None,
+ faster_rcnn_head: FasterRCNNHead | None = None,
+ rcnn_box_decoder: DeltaXYWHBBoxDecoder | None = None,
+ qdtrack_head: QDTrackHead | None = None,
+ track_graph: QDTrackGraph | None = None,
+ weights: None | str = None,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ num_classes (int): Number of object categories.
+ basemodel (BaseModel, optional): Base model network. Defaults to
+ None. If None, will use ResNet50.
+ faster_rcnn_head (FasterRCNNHead, optional): Faster RCNN head.
+ Defaults to None. if None, will use default FasterRCNNHead.
+ rcnn_box_decoder (DeltaXYWHBBoxDecoder, optional): Decoder for RCNN
+ bounding boxes. Defaults to None.
+ qdtrack_head (QDTrack, optional): QDTrack head. Defaults to None.
+ If None, will use default QDTrackHead.
+ track_graph (QDTrackGraph, optional): Track graph. Defaults to
+ None. If None, will use default QDTrackGraph.
+ weights (str, optional): Weights to load for model.
+ """
+ super().__init__()
+ self.basemodel = (
+ ResNet(resnet_name="resnet50", pretrained=True, trainable_layers=3)
+ if basemodel is None
+ else basemodel
+ )
+
+ self.fpn = FPN(self.basemodel.out_channels[2:], 256)
+
+ if faster_rcnn_head is None:
+ self.faster_rcnn_head = FasterRCNNHead(num_classes=num_classes)
+ else:
+ self.faster_rcnn_head = faster_rcnn_head
+
+ self.roi2det = RoI2Det(rcnn_box_decoder)
+
+ self.qdtrack_head = (
+ QDTrackHead() if qdtrack_head is None else qdtrack_head
+ )
+
+ self.track_graph = (
+ QDTrackGraph() if track_graph is None else track_graph
+ )
+
+ if weights is not None:
+ load_model_checkpoint(
+ self, weights, map_location="cpu", rev_keys=REV_KEYS
+ )
+
+ def forward(
+ self,
+ images: list[Tensor] | Tensor,
+ images_hw: list[list[tuple[int, int]]] | list[tuple[int, int]],
+ original_hw: list[list[tuple[int, int]]] | list[tuple[int, int]],
+ frame_ids: list[list[int]] | list[int],
+ boxes2d: None | list[list[Tensor]] = None,
+ boxes2d_classes: None | list[list[Tensor]] = None,
+ boxes2d_track_ids: None | list[list[Tensor]] = None,
+ keyframes: None | list[list[bool]] = None,
+ ) -> TrackOut | FasterRCNNQDTrackOut:
+ """Forward."""
+ if self.training:
+ assert (
+ isinstance(images, list)
+ and boxes2d is not None
+ and boxes2d_classes is not None
+ and boxes2d_track_ids is not None
+ and keyframes is not None
+ )
+ return self._forward_train(
+ images,
+ images_hw, # type: ignore
+ boxes2d,
+ boxes2d_classes,
+ boxes2d_track_ids,
+ keyframes,
+ )
+ return self._forward_test(images, images_hw, original_hw, frame_ids) # type: ignore # pylint: disable=line-too-long
+
+ def _forward_train(
+ self,
+ images: list[Tensor],
+ images_hw: list[list[tuple[int, int]]],
+ target_boxes: list[list[Tensor]],
+ target_classes: list[list[Tensor]],
+ target_track_ids: list[list[Tensor]],
+ keyframes: list[list[bool]],
+ ) -> FasterRCNNQDTrackOut:
+ """Forward training stage.
+
+ Args:
+ images (list[Tensor]): Input images.
+ images_hw (list[list[tuple[int, int]]]): Input image resolutions.
+ target_boxes (list[list[Tensor]]): Bounding box labels.
+ target_classes (list[list[Tensor]]): Class labels.
+ target_track_ids (list[list[Tensor]]): Track IDs.
+ keyframes (list[list[bool]]): Whether the frame is a keyframe.
+
+ Returns:
+ FasterRCNNQDTrackOut: Raw model outputs.
+ """
+ key_index, ref_indices = split_key_ref_indices(keyframes)
+
+ # feature extraction
+ key_features = self.fpn(self.basemodel(images[key_index]))
+ ref_features = [
+ self.fpn(self.basemodel(images[ref_index]))
+ for ref_index in ref_indices
+ ]
+
+ key_detector_out = self.faster_rcnn_head(
+ key_features,
+ images_hw[key_index],
+ target_boxes[key_index],
+ target_classes[key_index],
+ )
+
+ with torch.no_grad():
+ ref_detector_out = [
+ self.faster_rcnn_head(
+ ref_features[i],
+ images_hw[ref_index],
+ target_boxes[ref_index],
+ target_classes[ref_index],
+ )
+ for i, ref_index in enumerate(ref_indices)
+ ]
+
+ key_proposals = key_detector_out.proposals.boxes
+ ref_proposals = [ref.proposals.boxes for ref in ref_detector_out]
+ key_target_boxes = target_boxes[key_index]
+ ref_target_boxes = [
+ target_boxes[ref_index] for ref_index in ref_indices
+ ]
+ key_target_track_ids = target_track_ids[key_index]
+ ref_target_track_ids = [
+ target_track_ids[ref_index] for ref_index in ref_indices
+ ]
+
+ (
+ key_embeddings,
+ ref_embeddings,
+ key_track_ids,
+ ref_track_ids,
+ ) = self.qdtrack_head(
+ features=[key_features, *ref_features],
+ det_boxes=[key_proposals, *ref_proposals],
+ target_boxes=[key_target_boxes, *ref_target_boxes],
+ target_track_ids=[key_target_track_ids, *ref_target_track_ids],
+ )
+ assert (
+ ref_embeddings is not None
+ and key_track_ids is not None
+ and ref_track_ids is not None
+ )
+
+ return FasterRCNNQDTrackOut(
+ detector_out=key_detector_out,
+ key_images_hw=images_hw[key_index],
+ key_target_boxes=key_target_boxes,
+ key_embeddings=key_embeddings,
+ ref_embeddings=ref_embeddings,
+ key_track_ids=key_track_ids,
+ ref_track_ids=ref_track_ids,
+ )
+
+ def _forward_test(
+ self,
+ images: Tensor,
+ images_hw: list[tuple[int, int]],
+ original_hw: list[tuple[int, int]],
+ frame_ids: list[int],
+ ) -> TrackOut:
+ """Forward inference stage."""
+ features = self.basemodel(images)
+ features = self.fpn(features)
+ detector_out = self.faster_rcnn_head(features, images_hw)
+
+ boxes, scores, class_ids = self.roi2det(
+ *detector_out.roi, detector_out.proposals.boxes, images_hw
+ )
+ embeddings, _, _, _ = self.qdtrack_head(features, boxes)
+
+ tracks = self.track_graph(
+ embeddings, boxes, scores, class_ids, frame_ids
+ )
+
+ for i, boxs in enumerate(tracks.boxes):
+ tracks.boxes[i] = scale_and_clip_boxes(
+ boxs, original_hw[i], images_hw[i]
+ )
+ return tracks
+
+ def __call__(
+ self,
+ images: list[Tensor] | Tensor,
+ images_hw: list[list[tuple[int, int]]] | list[tuple[int, int]],
+ original_hw: list[tuple[int, int]],
+ frame_ids: list[list[int]] | list[int],
+ boxes2d: None | list[list[Tensor]] = None,
+ boxes2d_classes: None | list[list[Tensor]] = None,
+ boxes2d_track_ids: None | list[list[Tensor]] = None,
+ keyframes: None | list[list[bool]] = None,
+ ) -> TrackOut | FasterRCNNQDTrackOut:
+ """Type definition for call implementation."""
+ return self._call_impl(
+ images,
+ images_hw,
+ original_hw,
+ frame_ids,
+ boxes2d,
+ boxes2d_classes,
+ boxes2d_track_ids,
+ keyframes,
+ )
+
+
+class YOLOXQDTrackOut(NamedTuple):
+ """Output of QDtrack YOLOX model."""
+
+ detector_out: YOLOXOut
+ key_images_hw: list[tuple[int, int]]
+ key_target_boxes: list[Tensor]
+ key_target_classes: list[Tensor]
+ key_embeddings: list[Tensor]
+ ref_embeddings: list[list[Tensor]]
+ key_track_ids: list[Tensor]
+ ref_track_ids: list[list[Tensor]]
+
+
+class YOLOXQDTrack(nn.Module):
+ """Wrap QDTrack with YOLOX detector."""
+
+ def __init__(
+ self,
+ num_classes: int,
+ basemodel: BaseModel | None = None,
+ fpn: FeaturePyramidProcessing | None = None,
+ yolox_head: YOLOXHead | None = None,
+ train_postprocessor: YOLOXPostprocess | None = None,
+ test_postprocessor: YOLOXPostprocess | None = None,
+ qdtrack_head: QDTrackHead | None = None,
+ track_graph: QDTrackGraph | None = None,
+ weights: None | str = None,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ num_classes (int): Number of object categories.
+ basemodel (BaseModel, optional): Base model. Defaults to None. If
+ None, will use CSPDarknet.
+ fpn (FeaturePyramidProcessing, optional): Feature Pyramid
+ Processing. Defaults to None. If None, will use YOLOXPAFPN.
+ yolox_head (YOLOXHead, optional): YOLOX head. Defaults to None. If
+ None, will use YOLOXHead.
+ train_postprocessor (YOLOXPostprocess, optional): Post processor
+ for training. Defaults to None. If None, will use
+ YOLOXPostprocess.
+ test_postprocessor (YOLOXPostprocess, optional): Post processor
+ for testing. Defaults to None. If None, will use
+ YOLOXPostprocess.
+ qdtrack_head (QDTrack, optional): QDTrack head. Defaults to None.
+ If None, will use default QDTrackHead.
+ track_graph (QDTrackGraph, optional): Track graph. Defaults to
+ None. If None, will use default QDTrackGraph.
+ weights (str, optional): Weights to load for model.
+ """
+ super().__init__()
+ self.basemodel = (
+ CSPDarknet(deepen_factor=1.33, widen_factor=1.25)
+ if basemodel is None
+ else basemodel
+ )
+ self.fpn = (
+ YOLOXPAFPN([320, 640, 1280], 320, num_csp_blocks=4)
+ if fpn is None
+ else fpn
+ )
+ self.yolox_head = (
+ YOLOXHead(
+ num_classes=num_classes, in_channels=320, feat_channels=320
+ )
+ if yolox_head is None
+ else yolox_head
+ )
+ self.train_postprocessor = (
+ YOLOXPostprocess(
+ self.yolox_head.point_generator,
+ self.yolox_head.box_decoder,
+ nms_threshold=0.7,
+ score_thr=0.0,
+ nms_pre=2000,
+ max_per_img=1000,
+ )
+ if train_postprocessor is None
+ else train_postprocessor
+ )
+ self.test_postprocessor = (
+ YOLOXPostprocess(
+ self.yolox_head.point_generator,
+ self.yolox_head.box_decoder,
+ nms_threshold=0.65,
+ score_thr=0.1,
+ )
+ if test_postprocessor is None
+ else test_postprocessor
+ )
+
+ self.qdtrack_head = (
+ QDTrackHead(
+ QDSimilarityHead(
+ MultiScaleRoIAlign(
+ resolution=[7, 7],
+ strides=[8, 16, 32],
+ sampling_ratio=0,
+ ),
+ in_dim=320,
+ )
+ )
+ if qdtrack_head is None
+ else qdtrack_head
+ )
+
+ self.track_graph = (
+ QDTrackGraph(
+ track=QDTrackAssociation(
+ init_score_thr=0.5, obj_score_thr=0.35
+ )
+ )
+ if track_graph is None
+ else track_graph
+ )
+
+ if weights is not None:
+ load_model_checkpoint(
+ self, weights, map_location="cpu", rev_keys=YOLOX_REV_KEYS
+ )
+
+ def forward(
+ self,
+ images: list[Tensor] | Tensor,
+ images_hw: list[list[tuple[int, int]]] | list[tuple[int, int]],
+ original_hw: list[list[tuple[int, int]]] | list[tuple[int, int]],
+ frame_ids: list[list[int]] | list[int],
+ boxes2d: None | list[list[Tensor]] = None,
+ boxes2d_classes: None | list[list[Tensor]] = None,
+ boxes2d_track_ids: None | list[list[Tensor]] = None,
+ keyframes: None | list[list[bool]] = None,
+ ) -> TrackOut | YOLOXQDTrackOut:
+ """Forward."""
+ if self.training:
+ assert (
+ isinstance(images, list)
+ and boxes2d is not None
+ and boxes2d_classes is not None
+ and boxes2d_track_ids is not None
+ and keyframes is not None
+ )
+ return self._forward_train(
+ images,
+ images_hw, # type: ignore
+ boxes2d,
+ boxes2d_classes,
+ boxes2d_track_ids,
+ keyframes,
+ )
+ return self._forward_test(images, images_hw, original_hw, frame_ids) # type: ignore # pylint: disable=line-too-long
+
+ def _forward_train(
+ self,
+ images: list[Tensor],
+ images_hw: list[list[tuple[int, int]]],
+ target_boxes: list[list[Tensor]],
+ target_classes: list[list[Tensor]],
+ target_track_ids: list[list[Tensor]],
+ keyframes: list[list[bool]],
+ ) -> YOLOXQDTrackOut:
+ """Forward training stage.
+
+ Args:
+ images (list[Tensor]): Input images.
+ images_hw (list[list[tuple[int, int]]]): Input image resolutions.
+ target_boxes (list[list[Tensor]]): Bounding box labels.
+ target_classes (list[list[Tensor]]): Class labels.
+ target_track_ids (list[list[Tensor]]): Track IDs.
+ keyframes (list[list[bool]]): Whether the frame is a keyframe.
+
+ Returns:
+ YOLOXQDTrackOut: Raw model outputs.
+ """
+ key_index, ref_indices = split_key_ref_indices(keyframes)
+
+ # feature extraction
+ key_features = self.fpn(self.basemodel(images[key_index].contiguous()))
+ ref_features = [
+ self.fpn(self.basemodel(images[ref_index].contiguous()))
+ for ref_index in ref_indices
+ ]
+
+ key_detector_out = self.yolox_head(key_features[-3:])
+ key_proposals, _, _ = self.train_postprocessor(
+ cls_outs=key_detector_out.cls_score,
+ reg_outs=key_detector_out.bbox_pred,
+ obj_outs=key_detector_out.objectness,
+ images_hw=images_hw[key_index],
+ )
+
+ with torch.no_grad():
+ ref_detector_out = [
+ self.yolox_head(ref_feat[-3:]) for ref_feat in ref_features
+ ]
+ ref_proposals = [
+ self.train_postprocessor(
+ cls_outs=ref_out.cls_score,
+ reg_outs=ref_out.bbox_pred,
+ obj_outs=ref_out.objectness,
+ images_hw=images_hw[ref_index],
+ )[0]
+ for ref_index, ref_out in zip(ref_indices, ref_detector_out)
+ ]
+
+ key_target_boxes = target_boxes[key_index]
+ ref_target_boxes = [
+ target_boxes[ref_index] for ref_index in ref_indices
+ ]
+ key_target_classes = target_classes[key_index]
+ key_target_track_ids = target_track_ids[key_index]
+ ref_target_track_ids = [
+ target_track_ids[ref_index] for ref_index in ref_indices
+ ]
+
+ (
+ key_embeddings,
+ ref_embeddings,
+ key_track_ids,
+ ref_track_ids,
+ ) = self.qdtrack_head(
+ features=[key_features, *ref_features],
+ det_boxes=[key_proposals, *ref_proposals],
+ target_boxes=[key_target_boxes, *ref_target_boxes],
+ target_track_ids=[key_target_track_ids, *ref_target_track_ids],
+ )
+ assert (
+ ref_embeddings is not None
+ and key_track_ids is not None
+ and ref_track_ids is not None
+ )
+
+ return YOLOXQDTrackOut(
+ detector_out=key_detector_out,
+ key_images_hw=images_hw[key_index],
+ key_target_boxes=key_target_boxes,
+ key_target_classes=key_target_classes,
+ key_embeddings=key_embeddings,
+ ref_embeddings=ref_embeddings,
+ key_track_ids=key_track_ids,
+ ref_track_ids=ref_track_ids,
+ )
+
+ def _forward_test(
+ self,
+ images: torch.Tensor,
+ images_hw: list[tuple[int, int]],
+ original_hw: list[tuple[int, int]],
+ frame_ids: list[int],
+ ) -> TrackOut:
+ """Forward inference stage."""
+ features = self.fpn(self.basemodel(images))
+ outs = self.yolox_head(features[-3:])
+ boxes, scores, class_ids = self.test_postprocessor(
+ cls_outs=outs.cls_score,
+ reg_outs=outs.bbox_pred,
+ obj_outs=outs.objectness,
+ images_hw=images_hw,
+ )
+
+ embeddings, _, _, _ = self.qdtrack_head(features, boxes)
+
+ tracks = self.track_graph(
+ embeddings, boxes, scores, class_ids, frame_ids
+ )
+
+ for i, boxs in enumerate(tracks.boxes):
+ tracks.boxes[i] = scale_and_clip_boxes(
+ boxs, original_hw[i], images_hw[i]
+ )
+ return tracks
+
+ def __call__(
+ self,
+ images: list[Tensor] | Tensor,
+ images_hw: list[list[tuple[int, int]]] | list[tuple[int, int]],
+ original_hw: list[list[tuple[int, int]]] | list[tuple[int, int]],
+ frame_ids: list[list[int]] | list[int],
+ boxes2d: None | list[list[Tensor]] = None,
+ boxes2d_classes: None | list[list[Tensor]] = None,
+ boxes2d_track_ids: None | list[list[Tensor]] = None,
+ keyframes: None | list[list[bool]] = None,
+ ) -> TrackOut | FasterRCNNQDTrackOut:
+ """Type definition for call implementation."""
+ return self._call_impl(
+ images,
+ images_hw,
+ original_hw,
+ frame_ids,
+ boxes2d,
+ boxes2d_classes,
+ boxes2d_track_ids,
+ keyframes,
+ )
diff --git a/vis4d/model/track/util.py b/vis4d/model/track/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f32af99ed1658d56ac2091994756642df4a441a9
--- /dev/null
+++ b/vis4d/model/track/util.py
@@ -0,0 +1,24 @@
+"""Utility functions for track module."""
+
+from __future__ import annotations
+
+
+def split_key_ref_indices(
+ keyframes: list[list[bool]],
+) -> tuple[int, list[int]]:
+ """Get key frame from list of sample attributes."""
+ key_ind = None
+ ref_inds = []
+ for i, is_keys in enumerate(keyframes):
+ assert all(
+ is_keys[0] == is_key for is_key in is_keys
+ ), "Same batch should have the same view."
+ if is_keys[0]:
+ key_ind = i
+ else:
+ ref_inds.append(i)
+
+ assert key_ind is not None, "Key frame not found."
+ assert len(ref_inds) > 0, "No reference frames found."
+
+ return key_ind, ref_inds
diff --git a/vis4d/model/track3d/__init__.py b/vis4d/model/track3d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed2d2ba631ffd586e77b6866ad8cd94e96a24419
--- /dev/null
+++ b/vis4d/model/track3d/__init__.py
@@ -0,0 +1 @@
+"""Contains the implementation of 3D Tracking models."""
diff --git a/vis4d/model/track3d/cc_3dt.py b/vis4d/model/track3d/cc_3dt.py
new file mode 100644
index 0000000000000000000000000000000000000000..972728ccda86d586fb75ace53419d3373ea5c416
--- /dev/null
+++ b/vis4d/model/track3d/cc_3dt.py
@@ -0,0 +1,605 @@
+"""CC-3DT model implementation.
+
+This file composes the operations associated with CC-3DT
+`https://arxiv.org/abs/2212.01247` into the full model implementation.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import NamedTuple
+
+import torch
+from torch import Tensor, nn
+
+from vis4d.data.const import AxisMode
+from vis4d.model.track.qdtrack import FasterRCNNQDTrackOut
+from vis4d.op.base import BaseModel, ResNet
+from vis4d.op.box.anchor import AnchorGenerator
+from vis4d.op.box.box2d import bbox_area, bbox_clip
+from vis4d.op.box.box3d import boxes3d_to_corners, transform_boxes3d
+from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder
+from vis4d.op.detect3d.qd_3dt import QD3DTBBox3DHead, RoI2Det3D
+from vis4d.op.detect3d.util import bev_3d_nms
+from vis4d.op.detect.faster_rcnn import FasterRCNNHead
+from vis4d.op.detect.rcnn import RCNNHead, RoI2Det
+from vis4d.op.fpp import FPN
+from vis4d.op.geometry.projection import project_points
+from vis4d.op.geometry.rotation import (
+ quaternion_to_matrix,
+ rotation_matrix_yaw,
+)
+from vis4d.op.geometry.transform import inverse_rigid_transform
+from vis4d.op.track3d.cc_3dt import (
+ CC3DTrackAssociation,
+ cam_to_global,
+ get_track_3d_out,
+)
+from vis4d.op.track3d.common import Track3DOut
+from vis4d.op.track.qdtrack import QDTrackHead
+from vis4d.state.track3d.cc_3dt import CC3DTrackGraph
+
+from ..track.util import split_key_ref_indices
+
+
+class FasterRCNNCC3DTOut(NamedTuple):
+ """Output of CC-3DT model with Faster R-CNN detector."""
+
+ detector_3d_out: Tensor
+ detector_3d_target: Tensor
+ detector_3d_labels: Tensor
+ qdtrack_out: FasterRCNNQDTrackOut
+
+
+class FasterRCNNCC3DT(nn.Module):
+ """CC-3DT with Faster-RCNN detector."""
+
+ def __init__(
+ self,
+ num_classes: int,
+ basemodel: BaseModel | None = None,
+ faster_rcnn_head: FasterRCNNHead | None = None,
+ rcnn_box_decoder: DeltaXYWHBBoxDecoder | None = None,
+ qdtrack_head: QDTrackHead | None = None,
+ track_graph: CC3DTrackGraph | None = None,
+ pure_det: bool = False,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ num_classes (int): Number of object categories.
+ basemodel (BaseModel, optional): Base model network. Defaults to
+ None. If None, will use ResNet50.
+ faster_rcnn_head (FasterRCNNHead, optional): Faster RCNN head.
+ Defaults to None. if None, will use default FasterRCNNHead.
+ rcnn_box_decoder (DeltaXYWHBBoxDecoder, optional): Decoder for RCNN
+ bounding boxes. Defaults to None.
+ qdtrack_head (QDTrack, optional): QDTrack head. Defaults to None.
+ If None, will use default QDTrackHead.
+ track_graph (CC3DTrackGraph, optional): Track graph. Defaults to
+ None. If None, will use default CC3DTrackGraph.
+ pure_det (bool, optional): Whether to use pure detection. Defaults
+ to False.
+ """
+ super().__init__()
+ self.basemodel = (
+ ResNet(resnet_name="resnet50", pretrained=True, trainable_layers=3)
+ if basemodel is None
+ else basemodel
+ )
+
+ self.fpn = FPN(self.basemodel.out_channels[2:], 256)
+
+ if faster_rcnn_head is None:
+ anchor_generator = AnchorGenerator(
+ scales=[4, 8],
+ ratios=[0.25, 0.5, 1.0, 2.0, 4.0],
+ strides=[4, 8, 16, 32, 64],
+ )
+ roi_head = RCNNHead(num_shared_convs=4, num_classes=num_classes)
+ self.faster_rcnn_head = FasterRCNNHead(
+ num_classes=num_classes,
+ anchor_generator=anchor_generator,
+ roi_head=roi_head,
+ )
+ else:
+ self.faster_rcnn_head = faster_rcnn_head
+
+ self.roi2det = RoI2Det(rcnn_box_decoder)
+
+ self.bbox_3d_head = QD3DTBBox3DHead(num_classes=num_classes)
+
+ self.roi2det_3d = RoI2Det3D()
+
+ self.qdtrack_head = (
+ QDTrackHead() if qdtrack_head is None else qdtrack_head
+ )
+
+ self.track_graph = (
+ CC3DTrackGraph() if track_graph is None else track_graph
+ )
+
+ self.pure_det = pure_det
+
+ def forward(
+ self,
+ images: list[Tensor],
+ images_hw: list[list[tuple[int, int]]],
+ intrinsics: list[Tensor],
+ extrinsics: list[Tensor] | None = None,
+ frame_ids: list[int] | None = None,
+ boxes2d: list[list[Tensor]] | None = None,
+ boxes3d: list[list[Tensor]] | None = None,
+ boxes3d_classes: list[list[Tensor]] | None = None,
+ boxes3d_track_ids: list[list[Tensor]] | None = None,
+ keyframes: None | list[list[bool]] | None = None,
+ ) -> FasterRCNNCC3DTOut | Track3DOut:
+ """Forward."""
+ if self.training:
+ assert (
+ boxes2d is not None
+ and boxes3d is not None
+ and boxes3d_classes is not None
+ and boxes3d_track_ids is not None
+ and keyframes is not None
+ )
+ return self._forward_train(
+ images,
+ images_hw,
+ intrinsics,
+ boxes2d,
+ boxes3d,
+ boxes3d_classes,
+ boxes3d_track_ids,
+ keyframes,
+ )
+
+ assert extrinsics is not None and frame_ids is not None
+ return self._forward_test(
+ images, images_hw, intrinsics, extrinsics, frame_ids
+ )
+
+ def _forward_train(
+ self,
+ images: list[Tensor],
+ images_hw: list[list[tuple[int, int]]],
+ intrinsics: list[Tensor],
+ target_boxes2d: list[list[Tensor]],
+ target_boxes3d: list[list[Tensor]],
+ target_classes: list[list[Tensor]],
+ target_track_ids: list[list[Tensor]],
+ keyframes: list[list[bool]],
+ ) -> FasterRCNNCC3DTOut:
+ """Foward training stage."""
+ key_index, ref_indices = split_key_ref_indices(keyframes)
+
+ # feature extraction
+ key_features = self.fpn(self.basemodel(images[key_index]))
+ ref_features = [
+ self.fpn(self.basemodel(images[ref_index]))
+ for ref_index in ref_indices
+ ]
+
+ key_detector_out = self.faster_rcnn_head(
+ key_features,
+ images_hw[key_index],
+ target_boxes2d[key_index],
+ target_classes[key_index],
+ )
+
+ with torch.no_grad():
+ ref_detector_out = [
+ self.faster_rcnn_head(
+ ref_features[i],
+ images_hw[ref_index],
+ target_boxes2d[ref_index],
+ target_classes[ref_index],
+ )
+ for i, ref_index in enumerate(ref_indices)
+ ]
+
+ key_proposals = key_detector_out.proposals.boxes
+ ref_proposals = [ref.proposals.boxes for ref in ref_detector_out]
+ key_target_boxes = target_boxes2d[key_index]
+ ref_target_boxes = [
+ target_boxes2d[ref_index] for ref_index in ref_indices
+ ]
+ key_target_track_ids = target_track_ids[key_index]
+ ref_target_track_ids = [
+ target_track_ids[ref_index] for ref_index in ref_indices
+ ]
+
+ (
+ key_embeddings,
+ ref_embeddings,
+ key_track_ids,
+ ref_track_ids,
+ ) = self.qdtrack_head(
+ features=[key_features, *ref_features],
+ det_boxes=[key_proposals, *ref_proposals],
+ target_boxes=[key_target_boxes, *ref_target_boxes],
+ target_track_ids=[key_target_track_ids, *ref_target_track_ids],
+ )
+ assert (
+ ref_embeddings is not None
+ and key_track_ids is not None
+ and ref_track_ids is not None
+ )
+
+ predictions, targets, labels = self.bbox_3d_head(
+ features=key_features,
+ det_boxes=key_proposals,
+ intrinsics=intrinsics[key_index],
+ target_boxes=key_target_boxes,
+ target_boxes3d=target_boxes3d[key_index],
+ target_class_ids=target_classes[key_index],
+ )
+ detector_3d_out = torch.cat(predictions)
+ assert targets is not None and labels is not None
+
+ return FasterRCNNCC3DTOut(
+ detector_3d_out=detector_3d_out,
+ detector_3d_target=targets,
+ detector_3d_labels=labels,
+ qdtrack_out=FasterRCNNQDTrackOut(
+ detector_out=key_detector_out,
+ key_images_hw=images_hw[key_index],
+ key_target_boxes=key_target_boxes,
+ key_embeddings=key_embeddings,
+ ref_embeddings=ref_embeddings,
+ key_track_ids=key_track_ids,
+ ref_track_ids=ref_track_ids,
+ ),
+ )
+
+ def _forward_test(
+ self,
+ images_list: list[Tensor],
+ images_hw: list[list[tuple[int, int]]],
+ intrinsics_list: list[Tensor],
+ extrinsics_list: list[Tensor],
+ frame_ids: list[int],
+ ) -> Track3DOut:
+ """Forward inference stage.
+
+ Curretnly only work with single batch per gpu.
+ """
+ # (N, 1, 3, H, W) -> (N, 3, H, W)
+ images = torch.cat(images_list)
+ # (N, 1, 3, 3) -> (N, 3, 3)
+ intrinsics = torch.cat(intrinsics_list)
+ # (N, 1, 4, 4) -> (N, 4, 4)
+ extrinsics = torch.cat(extrinsics_list)
+ # (N, 1) -> (N,)
+ frame_id = frame_ids[0]
+ images_hw_list: list[tuple[int, int]] = sum(images_hw, [])
+
+ features = self.basemodel(images)
+ features = self.fpn(features)
+ _, roi, proposals, _, _, _ = self.faster_rcnn_head(
+ features, images_hw_list
+ )
+
+ boxes_2d_list, scores_2d_list, class_ids_list = self.roi2det(
+ *roi, proposals.boxes, images_hw_list
+ )
+
+ predictions, _, _ = self.bbox_3d_head(
+ features, det_boxes=boxes_2d_list
+ )
+
+ boxes_3d_list, scores_3d_list = self.roi2det_3d(
+ predictions, boxes_2d_list, class_ids_list, intrinsics
+ )
+
+ embeddings_list, _, _, _ = self.qdtrack_head(features, boxes_2d_list)
+
+ # Assign camera id
+ camera_ids_list = []
+ for i, boxes_2d in enumerate(boxes_2d_list):
+ camera_ids_list.append(
+ (torch.mul(torch.ones(len(boxes_2d)), i)).to(boxes_2d.device)
+ )
+
+ # Move 3D boxes to world coordinate
+ boxes_3d_list = cam_to_global(boxes_3d_list, extrinsics)
+
+ # Merge boxes from all cameras
+ boxes_2d = torch.cat(boxes_2d_list)
+ scores_2d = torch.cat(scores_2d_list)
+ camera_ids = torch.cat(camera_ids_list)
+ boxes_3d = torch.cat(boxes_3d_list)
+ scores_3d = torch.cat(scores_3d_list)
+ class_ids = torch.cat(class_ids_list)
+ embeddings = torch.cat(embeddings_list)
+
+ if self.pure_det:
+ return get_track_3d_out(
+ boxes_3d, class_ids, scores_3d, torch.zeros_like(class_ids)
+ )
+
+ # 3D NMS in world coordinate
+ keep_indices = bev_3d_nms(
+ center_x=boxes_3d[:, 0].unsqueeze(1),
+ center_y=boxes_3d[:, 1].unsqueeze(1),
+ width=boxes_3d[:, 4].unsqueeze(1),
+ length=boxes_3d[:, 5].unsqueeze(1),
+ angle=180.0 / torch.pi * boxes_3d[:, 8].unsqueeze(1),
+ scores=scores_2d * scores_3d,
+ )
+
+ boxes_2d = boxes_2d[keep_indices]
+ scores_2d = scores_2d[keep_indices]
+ camera_ids = camera_ids[keep_indices]
+ boxes_3d = boxes_3d[keep_indices]
+ scores_3d = scores_3d[keep_indices]
+ class_ids = class_ids[keep_indices]
+ embeddings = embeddings[keep_indices]
+
+ outs = self.track_graph(
+ boxes_2d,
+ scores_2d,
+ camera_ids,
+ boxes_3d,
+ scores_3d,
+ class_ids,
+ embeddings,
+ frame_id,
+ )
+
+ return outs
+
+ def __call__(
+ self,
+ images: list[Tensor] | Tensor,
+ images_hw: list[list[tuple[int, int]]],
+ intrinsics: list[Tensor] | Tensor,
+ extrinsics: Tensor | None = None,
+ frame_ids: list[list[int]] | None = None,
+ boxes2d: list[list[Tensor]] | None = None,
+ boxes3d: list[list[Tensor]] | None = None,
+ boxes3d_classes: list[list[Tensor]] | None = None,
+ boxes3d_track_ids: list[list[Tensor]] | None = None,
+ keyframes: None | list[list[bool]] | None = None,
+ ) -> FasterRCNNCC3DTOut | Track3DOut:
+ """Type definition for call implementation."""
+ return self._call_impl(
+ images,
+ images_hw,
+ intrinsics,
+ extrinsics,
+ frame_ids,
+ boxes2d,
+ boxes3d,
+ boxes3d_classes,
+ boxes3d_track_ids,
+ keyframes,
+ )
+
+
+class CC3DT(nn.Module):
+ """CC-3DT with custom detection results."""
+
+ def __init__(
+ self,
+ basemodel: BaseModel | None = None,
+ qdtrack_head: QDTrackHead | None = None,
+ track_graph: CC3DTrackGraph | None = None,
+ detection_range: Sequence[float] | None = None,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ basemodel (BaseModel, optional): Base model network. Defaults to
+ None. If None, will use ResNet50.
+ qdtrack_head (QDTrack, optional): QDTrack head. Defaults to None.
+ If None, will use default QDTrackHead.
+ track_graph (CC3DTrackGraph, optional): Track graph. Defaults to
+ None. If None, will use default CC3DTrackGraph.
+ detection_range (Sequence[float], optional): Detection range for
+ each class. Defaults to None.
+ """
+ super().__init__()
+ self.basemodel = (
+ ResNet(resnet_name="resnet50", pretrained=True, trainable_layers=3)
+ if basemodel is None
+ else basemodel
+ )
+
+ self.fpn = FPN(self.basemodel.out_channels[2:], 256)
+
+ self.qdtrack_head = (
+ QDTrackHead() if qdtrack_head is None else qdtrack_head
+ )
+
+ self.track_graph = track_graph or CC3DTrackGraph(
+ track=CC3DTrackAssociation(init_score_thr=0.2, obj_score_thr=0.1),
+ update_3d_score=False,
+ add_backdrops=False,
+ )
+
+ self.detection_range = detection_range
+
+ def forward(
+ self,
+ images_list: list[Tensor],
+ images_hw: list[list[tuple[int, int]]],
+ intrinsics_list: list[Tensor],
+ extrinsics_list: list[Tensor],
+ frame_ids: list[int],
+ pred_boxes3d: list[list[Tensor]],
+ pred_boxes3d_classes: list[list[Tensor]],
+ pred_boxes3d_scores: list[list[Tensor]],
+ pred_boxes3d_velocities: list[list[Tensor]],
+ ) -> Track3DOut:
+ """Forward inference stage.
+
+ Curretnly only work with single batch per gpu.
+ """
+ # (N, 1, 3, H, W) -> (N, 3, H, W)
+ images = torch.cat(images_list)
+ # (N, 1, 3, 3) -> (N, 3, 3)
+ intrinsics = torch.cat(intrinsics_list)
+ # (N, 1, 4, 4) -> (N, 4, 4)
+ extrinsics = torch.cat(extrinsics_list)
+ # (N, 1) -> (N,)
+ frame_id = frame_ids[0]
+ images_hw_list: list[tuple[int, int]] = sum(images_hw, [])
+
+ features = self.basemodel(images)
+ features = self.fpn(features)
+
+ # (1, 1, B,) -> (B,)
+ boxes_3d = pred_boxes3d[0][0]
+ class_ids = pred_boxes3d_classes[0][0]
+ scores_3d = pred_boxes3d_scores[0][0]
+ velocities = pred_boxes3d_velocities[0][0]
+
+ # Get 2D boxes and assign camera id
+ global_to_cams = inverse_rigid_transform(extrinsics)
+
+ boxes_3d_list = []
+ boxes_2d_list = []
+ class_ids_list = []
+ scores_list = []
+ camera_ids_list = []
+ for i, global_to_cam in enumerate(global_to_cams):
+ boxes3d_cam = transform_boxes3d(
+ boxes_3d,
+ global_to_cam,
+ source_axis_mode=AxisMode.ROS,
+ target_axis_mode=AxisMode.OPENCV,
+ )
+
+ corners = boxes3d_to_corners(
+ boxes3d_cam, axis_mode=AxisMode.OPENCV
+ )
+
+ corners_2d = project_points(corners, intrinsics[i])
+
+ boxes_2d = self._to_boxes2d(corners_2d)
+ boxes_2d = bbox_clip(boxes_2d, images_hw_list[i], 1)
+
+ mask = (
+ (boxes3d_cam[:, 2] > 0)
+ & (bbox_area(boxes_2d) > 0)
+ & (
+ bbox_area(boxes_2d)
+ < (images_hw_list[i][0] - 1) * (images_hw_list[i][1] - 1)
+ )
+ & self._filter_distance(class_ids, boxes3d_cam)
+ )
+
+ cc_3dt_boxes_3d = boxes_3d.new_zeros(len(boxes_2d[mask]), 12)
+ cc_3dt_boxes_3d[:, :3] = boxes_3d[mask][:, :3]
+ # WLH -> HWL
+ cc_3dt_boxes_3d[:, 3:6] = boxes_3d[mask][:, [5, 3, 4]]
+ cc_3dt_boxes_3d[:, 6:9] = rotation_matrix_yaw(
+ quaternion_to_matrix(boxes_3d[mask][:, 6:]), AxisMode.ROS
+ )
+ cc_3dt_boxes_3d[:, 9:] = velocities[mask]
+
+ boxes_3d_list.append(cc_3dt_boxes_3d)
+ boxes_2d_list.append(boxes_2d[mask])
+ class_ids_list.append(class_ids[mask])
+ scores_list.append(scores_3d[mask])
+ camera_ids_list.append(
+ (torch.mul(torch.ones(len(cc_3dt_boxes_3d)), i)).to(
+ boxes_2d.device
+ )
+ )
+
+ embeddings_list, _, _, _ = self.qdtrack_head(features, boxes_2d_list)
+
+ boxes_3d = torch.cat(boxes_3d_list)
+ boxes_2d = torch.cat(boxes_2d_list)
+ camera_ids = torch.cat(camera_ids_list)
+ scores = torch.cat(scores_list)
+ class_ids = torch.cat(class_ids_list)
+ embeddings = torch.cat(embeddings_list)
+
+ # Select project boxes2d according to bbox area
+ keep_indices = embeddings.new_ones(len(boxes_3d)).bool()
+ boxes_2d_area = bbox_area(boxes_2d)
+ for i, box3d in enumerate(boxes_3d):
+ for same_idx in (
+ (box3d[:3] == boxes_3d[:, :3]).all(dim=1).nonzero()
+ ):
+ if (
+ same_idx != i
+ and boxes_2d_area[same_idx] > boxes_2d_area[i]
+ ):
+ keep_indices[i] = False
+ break
+
+ boxes_3d = boxes_3d[keep_indices]
+ boxes_2d = boxes_2d[keep_indices]
+ camera_ids = camera_ids[keep_indices]
+ scores = scores[keep_indices]
+ class_ids = class_ids[keep_indices]
+ embeddings = embeddings[keep_indices]
+
+ outs = self.track_graph(
+ boxes_2d,
+ scores,
+ camera_ids,
+ boxes_3d,
+ scores,
+ class_ids,
+ embeddings,
+ frame_id,
+ )
+
+ return outs
+
+ def _to_boxes2d(self, corners_2d: Tensor) -> Tensor:
+ """Project 3D boxes (Camera coordinates) to 2D boxes."""
+ min_x = torch.min(corners_2d[:, :, 0], 1).values.unsqueeze(-1)
+ min_y = torch.min(corners_2d[:, :, 1], 1).values.unsqueeze(-1)
+ max_x = torch.max(corners_2d[:, :, 0], 1).values.unsqueeze(-1)
+ max_y = torch.max(corners_2d[:, :, 1], 1).values.unsqueeze(-1)
+
+ return torch.cat([min_x, min_y, max_x, max_y], dim=1)
+
+ def _filter_distance(
+ self, class_ids: Tensor, boxes3d: Tensor, tolerance: float = 2.0
+ ) -> Tensor:
+ """Filter boxes3d on distance."""
+ if self.detection_range is None:
+ return torch.ones_like(class_ids, dtype=torch.bool)
+
+ return torch.linalg.norm( # pylint: disable=not-callable
+ boxes3d[:, [0, 2]], dim=1
+ ) <= torch.tensor(
+ [
+ self.detection_range[class_id] + tolerance
+ for class_id in class_ids
+ ]
+ ).to(
+ class_ids.device
+ )
+
+ def __call__(
+ self,
+ images_list: list[Tensor],
+ images_hw: list[list[tuple[int, int]]],
+ intrinsics_list: list[Tensor],
+ extrinsics_list: list[Tensor],
+ frame_ids: list[int],
+ pred_boxes3d: list[list[Tensor]],
+ pred_boxes3d_classes: list[list[Tensor]],
+ pred_boxes3d_scores: list[list[Tensor]],
+ pred_boxes3d_velocities: list[list[Tensor]],
+ ) -> Track3DOut:
+ """Type definition for call implementation."""
+ return self._call_impl(
+ images_list,
+ images_hw,
+ intrinsics_list,
+ extrinsics_list,
+ frame_ids,
+ pred_boxes3d,
+ pred_boxes3d_classes,
+ pred_boxes3d_scores,
+ pred_boxes3d_velocities,
+ )
diff --git a/vis4d/op/__init__.py b/vis4d/op/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fecfaf92b8fec6c23fe16b34dfc8f9bcd95370d
--- /dev/null
+++ b/vis4d/op/__init__.py
@@ -0,0 +1,8 @@
+"""Compositional operators used for implementing models.
+
+This is where most of the library APIs are implemented.
+All the operators are functors. They are native PyTorch modules and only have a
+forward member for function invocations. We follow the principle of functional
+programming. The operators don't keep internal states besides the operator
+weights. The operator computation and call has no side effects.
+"""
diff --git a/vis4d/op/__pycache__/__init__.cpython-311.pyc b/vis4d/op/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..efd7c26650690b1ce2842f2a236cc75e94eeb4ce
Binary files /dev/null and b/vis4d/op/__pycache__/__init__.cpython-311.pyc differ
diff --git a/vis4d/op/base/__init__.py b/vis4d/op/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffa6ba27c722bb9b3131185555cae7ebb80aed39
--- /dev/null
+++ b/vis4d/op/base/__init__.py
@@ -0,0 +1,8 @@
+"""Base model module."""
+
+from .base import BaseModel
+from .csp_darknet import CSPDarknet
+from .dla import DLA
+from .resnet import ResNet, ResNetV1c
+
+__all__ = ["BaseModel", "CSPDarknet", "DLA", "ResNet", "ResNetV1c"]
diff --git a/vis4d/op/base/base.py b/vis4d/op/base/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..90222696739151aa78309f04b780a3f0a5f8eaa8
--- /dev/null
+++ b/vis4d/op/base/base.py
@@ -0,0 +1,58 @@
+"""Base model interface."""
+
+from __future__ import annotations
+
+import abc
+
+import torch
+from torch import nn
+
+
+class BaseModel(nn.Module):
+ """Abstract base model for feature extraction."""
+
+ @abc.abstractmethod
+ def forward(self, images: torch.Tensor) -> list[torch.Tensor]:
+ """Base model forward.
+
+ Args:
+ images (Tensor[N, C, H, W]): Image input to process. Expected to be
+ type float32.
+
+ Raises:
+ NotImplementedError: This is an abstract class method.
+
+ Returns:
+ fp (list[torch.Tensor]): The output feature pyramid. The list index
+ represents the level, which has a downsampling ratio of 2^index for
+ most of the cases. fp[2] is the C2 or P2 in the FPN paper
+ (https://arxiv.org/abs/1612.03144). fp[0] is the original image or
+ the feature map with the same resolution. fp[1] may be the copy of
+ the input image if the network doesn't generate the feature map of
+ the resolution.
+ """
+ raise NotImplementedError
+
+ @property
+ @abc.abstractmethod
+ def out_channels(self) -> list[int]:
+ """Get the number of channels for each level of feature pyramid.
+
+ Raises:
+ NotImplementedError: This is an abstract class method.
+
+ Returns:
+ list[int]: Number of channels.
+ """
+ raise NotImplementedError
+
+ def __call__(self, images: torch.Tensor) -> list[torch.Tensor]:
+ """Type definition for call implementation.
+
+ Args:
+ images (torch.Tensor): Image input to process.
+
+ Returns:
+ list[torch.Tensor]: The output feature pyramid.
+ """
+ return self._call_impl(images)
diff --git a/vis4d/op/base/csp_darknet.py b/vis4d/op/base/csp_darknet.py
new file mode 100644
index 0000000000000000000000000000000000000000..830d171f31c5935ee5b6477e5dadac0cbf519e9b
--- /dev/null
+++ b/vis4d/op/base/csp_darknet.py
@@ -0,0 +1,305 @@
+"""CSP-Darknet base network used in YOLOX.
+
+Modified from mmdetection (https://github.com/open-mmlab/mmdetection).
+"""
+
+from __future__ import annotations
+
+import math
+from collections.abc import Sequence
+
+import torch
+from torch import nn
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from vis4d.op.layer.conv2d import Conv2d
+from vis4d.op.layer.csp_layer import CSPLayer
+
+
+class Focus(nn.Module):
+ """Focus width and height information into channel space.
+
+ Args:
+ in_channels (int): The input channels of this Module.
+ out_channels (int): The output channels of this Module.
+ kernel_size (int, optional): The kernel size of the convolution.
+ Defaults to 1.
+ stride (int, optional): The stride of the convolution. Defaults to 1.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 1,
+ stride: int = 1,
+ ):
+ """Init."""
+ super().__init__()
+ self.conv = Conv2d(
+ in_channels * 4,
+ out_channels,
+ kernel_size,
+ stride,
+ padding=(kernel_size - 1) // 2,
+ bias=False,
+ norm=nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03),
+ activation=nn.SiLU(inplace=True),
+ )
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ """Forward pass.
+
+ Args:
+ features (torch.Tensor): The input tensor of shape [B, C, W, H].
+ """
+ patch_top_left = features[..., ::2, ::2]
+ patch_top_right = features[..., ::2, 1::2]
+ patch_bot_left = features[..., 1::2, ::2]
+ patch_bot_right = features[..., 1::2, 1::2]
+ x = torch.cat(
+ (
+ patch_top_left,
+ patch_bot_left,
+ patch_top_right,
+ patch_bot_right,
+ ),
+ dim=1,
+ )
+ return self.conv(x)
+
+
+class SPPBottleneck(nn.Module):
+ """Spatial pyramid pooling layer used in YOLOv3-SPP.
+
+ Args:
+ in_channels (int): Input channels.
+ out_channels (int): Output channels.
+ kernel_sizes (Sequence[int], optional): Sequential of kernel sizes of
+ pooling layers. Defaults to (5, 9, 13).
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_sizes: Sequence[int] = (5, 9, 13),
+ ):
+ """Init."""
+ super().__init__()
+ mid_channels = in_channels // 2
+ self.conv1 = Conv2d(
+ in_channels,
+ mid_channels,
+ 1,
+ stride=1,
+ bias=False,
+ norm=nn.BatchNorm2d(mid_channels, eps=0.001, momentum=0.03),
+ activation=nn.SiLU(inplace=True),
+ )
+ self.poolings = nn.ModuleList(
+ [
+ nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
+ for ks in kernel_sizes
+ ]
+ )
+ conv2_channels = mid_channels * (len(kernel_sizes) + 1)
+ self.conv2 = Conv2d(
+ conv2_channels,
+ out_channels,
+ 1,
+ bias=False,
+ norm=nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03),
+ activation=nn.SiLU(inplace=True),
+ )
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ """Forward pass.
+
+ Args:
+ features (torch.Tensor): Input features.
+ """
+ x = self.conv1(features)
+ x = torch.cat([x] + [pooling(x) for pooling in self.poolings], dim=1)
+ x = self.conv2(x)
+ return x
+
+
+class CSPDarknet(nn.Module):
+ """CSP-Darknet backbone used in YOLOv5 and YOLOX.
+
+ Args:
+ arch (str): Architecture of CSP-Darknet, from {P5, P6}.
+ Default: P5.
+ deepen_factor (float): Depth multiplier, multiply number of
+ blocks in CSP layer by this amount. Default: 1.0.
+ widen_factor (float): Width multiplier, multiply number of
+ channels in each layer by this amount. Default: 1.0.
+ out_indices (Sequence[int]): Output from which stages.
+ Default: (2, 3, 4).
+ frozen_stages (int): Stages to be frozen (stop grad and set eval
+ mode). -1 means not freezing any parameters. Default: -1.
+ use_depthwise (bool): Whether to use depthwise separable convolution.
+ Default: False.
+ arch_ovewrite(list[list[int]], optional): Overwrite default arch
+ settings. Defaults to None.
+ spp_kernal_sizes: (tuple[int]): Sequential of kernel sizes of SPP
+ layers. Default: (5, 9, 13).
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+
+ Example:
+ >>> import torch
+ >>> from vis4d.op.base import CSPDarknet
+ >>> self = CSPDarknet()
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 416, 416)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ ...
+ (1, 256, 52, 52)
+ (1, 512, 26, 26)
+ (1, 1024, 13, 13)
+ """
+
+ # From left to right:
+ # in_channels, out_channels, num_blocks, add_identity, use_spp
+ arch_settings = {
+ "P5": [
+ [64, 128, 3, True, False],
+ [128, 256, 9, True, False],
+ [256, 512, 9, True, False],
+ [512, 1024, 3, False, True],
+ ],
+ "P6": [
+ [64, 128, 3, True, False],
+ [128, 256, 9, True, False],
+ [256, 512, 9, True, False],
+ [512, 768, 3, True, False],
+ [768, 1024, 3, False, True],
+ ],
+ }
+
+ def __init__(
+ self,
+ arch: str = "P5",
+ deepen_factor: float = 1.0,
+ widen_factor: float = 1.0,
+ out_indices: Sequence[int] = (2, 3, 4),
+ frozen_stages: int = -1,
+ arch_ovewrite: list[list[int]] | None = None,
+ spp_kernal_sizes: Sequence[int] = (5, 9, 13),
+ norm_eval: bool = False,
+ ):
+ """Init."""
+ super().__init__()
+ arch_setting = self.arch_settings[arch]
+ if arch_ovewrite:
+ arch_setting = arch_ovewrite
+ assert set(out_indices).issubset(
+ i for i in range(len(arch_setting) + 1)
+ )
+ if frozen_stages not in range(-1, len(arch_setting) + 1):
+ raise ValueError(
+ "frozen_stages must be in range(-1, "
+ "len(arch_setting) + 1). But received "
+ f"{frozen_stages}"
+ )
+
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.norm_eval = norm_eval
+
+ self.stem = Focus(
+ 3, int(arch_setting[0][0] * widen_factor), kernel_size=3
+ )
+ self.layers = ["stem"]
+
+ for i, (
+ in_channels,
+ out_channels,
+ num_blocks,
+ add_identity,
+ use_spp,
+ ) in enumerate(arch_setting):
+ in_channels = int(in_channels * widen_factor)
+ out_channels = int(out_channels * widen_factor)
+ num_blocks = max(round(num_blocks * deepen_factor), 1)
+ stage: list[nn.Module] = []
+ conv_layer = Conv2d(
+ in_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ bias=False,
+ norm=nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03),
+ activation=nn.SiLU(inplace=True),
+ )
+ stage.append(conv_layer)
+ if use_spp:
+ spp = SPPBottleneck(
+ out_channels, out_channels, kernel_sizes=spp_kernal_sizes
+ )
+ stage.append(spp)
+ csp_layer = CSPLayer(
+ out_channels,
+ out_channels,
+ num_blocks=num_blocks,
+ add_identity=bool(add_identity),
+ )
+ stage.append(csp_layer)
+ self.add_module(f"stage{i + 1}", nn.Sequential(*stage))
+ self.layers.append(f"stage{i + 1}")
+ self._init_weights()
+
+ def _init_weights(self) -> None:
+ """Initialize weights."""
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_uniform_(
+ m.weight,
+ a=math.sqrt(5),
+ mode="fan_in",
+ nonlinearity="leaky_relu",
+ )
+
+ def _freeze_stages(self) -> None:
+ """Freeze stages."""
+ if self.frozen_stages >= 0:
+ for i in range(self.frozen_stages + 1):
+ m = getattr(self, self.layers[i])
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def train(self, mode: bool = True) -> CSPDarknet:
+ """Override the train mode for the model.
+
+ Args:
+ mode (bool): Whether to set training mode to True.
+ """
+ super().train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
+ return self
+
+ def forward(self, images: torch.Tensor) -> list[torch.Tensor]:
+ """Forward pass.
+
+ Args:
+ images (torch.Tensor): Input images.
+ """
+ outs = [images, images]
+ x = images
+ for i, layer_name in enumerate(self.layers):
+ layer = getattr(self, layer_name)
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return outs
diff --git a/vis4d/op/base/dla.py b/vis4d/op/base/dla.py
new file mode 100644
index 0000000000000000000000000000000000000000..98ff9d486a26c0440c680a05fe7d9fac1562405c
--- /dev/null
+++ b/vis4d/op/base/dla.py
@@ -0,0 +1,647 @@
+"""DLA base model."""
+
+from __future__ import annotations
+
+import math
+from collections.abc import Sequence
+
+import torch
+from torch import Tensor, nn
+from torch.utils.checkpoint import checkpoint
+
+from vis4d.common.ckpt import load_model_checkpoint
+
+from .base import BaseModel
+
+BN_MOMENTUM = 0.1
+
+DLA_MODEL_PREFIX = "http://dl.yf.io/dla/models/imagenet"
+
+DLA_MODEL_MAPPING = {
+ "dla34": "dla34-ba72cf86.pth",
+ "dla46_c": "dla46_c-2bfd52c3.pth",
+ "dla46x_c": "dla46x_c-d761bae7.pth",
+ "dla60x_c": "dla60x_c-b870c45c.pth",
+ "dla60": "dla60-24839fc4.pth",
+ "dla60x": "dla60x-d15cacda.pth",
+ "dla102": "dla102-d94d9790.pth",
+ "dla102x": "dla102x-ad62be81.pth",
+ "dla102x2": "dla102x2-262837b6.pth",
+ "dla169": "dla169-0914e092.pth",
+}
+
+DLA_ARCH_SETTINGS = { # pylint: disable=consider-using-namedtuple-or-dataclass
+ "dla34": (
+ (1, 1, 1, 2, 2, 1),
+ (16, 32, 64, 128, 256, 512),
+ False,
+ "BasicBlock",
+ ),
+ "dla46_c": (
+ (1, 1, 1, 2, 2, 1),
+ (16, 32, 64, 64, 128, 256),
+ False,
+ "Bottleneck",
+ ),
+ "dla46x_c": (
+ (1, 1, 1, 2, 2, 1),
+ (16, 32, 64, 64, 128, 256),
+ False,
+ "BottleneckX",
+ ),
+ "dla60x_c": (
+ (1, 1, 1, 2, 3, 1),
+ (16, 32, 64, 64, 128, 256),
+ False,
+ "BottleneckX",
+ ),
+ "dla60": (
+ (1, 1, 1, 2, 3, 1),
+ (16, 32, 128, 256, 512, 1024),
+ False,
+ "Bottleneck",
+ ),
+ "dla60x": (
+ (1, 1, 1, 2, 3, 1),
+ (16, 32, 128, 256, 512, 1024),
+ False,
+ "BottleneckX",
+ ),
+ "dla102": (
+ (1, 1, 1, 3, 4, 1),
+ (16, 32, 128, 256, 512, 1024),
+ True,
+ "Bottleneck",
+ ),
+ "dla102x": (
+ (1, 1, 1, 3, 4, 1),
+ (16, 32, 128, 256, 512, 1024),
+ True,
+ "BottleneckX",
+ ),
+ "dla102x2": (
+ (1, 1, 1, 3, 4, 1),
+ (16, 32, 128, 256, 512, 1024),
+ True,
+ "BottleneckX",
+ ),
+ "dla169": (
+ (1, 1, 2, 3, 5, 1),
+ (16, 32, 128, 256, 512, 1024),
+ True,
+ "Bottleneck",
+ ),
+}
+
+
+class BasicBlock(nn.Module):
+ """BasicBlock."""
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ dilation: int = 1,
+ with_cp: bool = False,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.conv1 = nn.Conv2d(
+ inplanes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ bias=False,
+ dilation=dilation,
+ )
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(
+ planes,
+ planes,
+ kernel_size=3,
+ stride=1,
+ padding=dilation,
+ bias=False,
+ dilation=dilation,
+ )
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.stride = stride
+ self.with_cp = with_cp
+
+ def forward(
+ self, input_x: Tensor, residual: None | Tensor = None
+ ) -> Tensor:
+ """Forward."""
+
+ def _inner_forward(
+ input_x: Tensor, residual: None | Tensor = None
+ ) -> Tensor:
+ if residual is None:
+ residual = input_x
+ out = self.conv1(input_x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ out += residual
+
+ return out
+
+ if self.with_cp and input_x.requires_grad:
+ out = checkpoint(
+ _inner_forward, input_x, residual, use_reentrant=True
+ )
+ else:
+ out = _inner_forward(input_x, residual)
+
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ """Bottleneck."""
+
+ expansion = 2
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ dilation: int = 1,
+ with_cp: bool = False,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ expansion = Bottleneck.expansion
+ bottle_planes = planes // expansion
+ self.conv1 = nn.Conv2d(
+ inplanes, bottle_planes, kernel_size=1, bias=False
+ )
+ self.bn1 = nn.BatchNorm2d(bottle_planes, momentum=BN_MOMENTUM)
+ self.conv2 = nn.Conv2d(
+ bottle_planes,
+ bottle_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ bias=False,
+ dilation=dilation,
+ )
+ self.bn2 = nn.BatchNorm2d(bottle_planes, momentum=BN_MOMENTUM)
+ self.conv3 = nn.Conv2d(
+ bottle_planes, planes, kernel_size=1, bias=False
+ )
+ self.bn3 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.stride = stride
+ self.with_cp = with_cp
+
+ def forward(
+ self, input_x: Tensor, residual: None | Tensor = None
+ ) -> Tensor:
+ """Forward."""
+
+ def _inner_forward(
+ input_x: Tensor, residual: None | Tensor = None
+ ) -> Tensor:
+ if residual is None:
+ residual = input_x
+
+ out = self.conv1(input_x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ out += residual
+
+ return out
+
+ if self.with_cp and input_x.requires_grad:
+ out = checkpoint(
+ _inner_forward, input_x, residual, use_reentrant=True
+ )
+ else:
+ out = _inner_forward(input_x, residual)
+
+ out = self.relu(out)
+
+ return out
+
+
+class BottleneckX(nn.Module):
+ """BottleneckX."""
+
+ expansion = 2
+ cardinality = 32
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ dilation: int = 1,
+ with_cp: bool = False,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ cardinality = BottleneckX.cardinality
+ bottle_planes = planes * cardinality // 32
+ self.conv1 = nn.Conv2d(
+ inplanes, bottle_planes, kernel_size=1, bias=False
+ )
+ self.bn1 = nn.BatchNorm2d(bottle_planes, momentum=BN_MOMENTUM)
+ self.conv2 = nn.Conv2d(
+ bottle_planes,
+ bottle_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ bias=False,
+ dilation=dilation,
+ groups=cardinality,
+ )
+ self.bn2 = nn.BatchNorm2d(bottle_planes, momentum=BN_MOMENTUM)
+ self.conv3 = nn.Conv2d(
+ bottle_planes, planes, kernel_size=1, bias=False
+ )
+ self.bn3 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.stride = stride
+ self.with_cp = with_cp
+
+ def forward(
+ self, input_x: Tensor, residual: None | Tensor = None
+ ) -> Tensor:
+ """Forward."""
+
+ def _inner_forward(
+ input_x: Tensor, residual: None | Tensor = None
+ ) -> Tensor:
+ if residual is None:
+ residual = input_x
+
+ out = self.conv1(input_x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ out += residual
+
+ return out
+
+ if self.with_cp and input_x.requires_grad:
+ out = checkpoint(
+ _inner_forward, input_x, residual, use_reentrant=True
+ )
+ else:
+ out = _inner_forward(input_x, residual)
+
+ out = self.relu(out)
+
+ return out
+
+
+class Root(nn.Module):
+ """Root."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ residual: bool,
+ with_cp: bool = False,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ 1,
+ stride=1,
+ bias=False,
+ padding=(kernel_size - 1) // 2,
+ )
+ self.bn = nn.BatchNorm2d( # pylint: disable=invalid-name
+ out_channels, momentum=BN_MOMENTUM
+ )
+ self.relu = nn.ReLU(inplace=True)
+ self.residual = residual
+ self.with_cp = with_cp
+
+ def forward(self, *input_x: Tensor) -> Tensor:
+ """Forward."""
+
+ def _inner_forward(*input_x: Tensor) -> Tensor:
+ feats = self.conv(torch.cat(input_x, 1))
+ feats = self.bn(feats)
+ if self.residual:
+ feats += input_x[0]
+ return feats
+
+ if self.with_cp and input_x[0].requires_grad:
+ feats = checkpoint(_inner_forward, *input_x, use_reentrant=True)
+ else:
+ feats = _inner_forward(*input_x)
+
+ feats = self.relu(feats)
+
+ return feats
+
+
+class Tree(nn.Module):
+ """Tree."""
+
+ def __init__( # pylint: disable=too-many-arguments
+ self,
+ levels: int,
+ block: str,
+ in_channels: int,
+ out_channels: int,
+ stride: int = 1,
+ level_root: bool = False,
+ root_dim: int = 0,
+ root_kernel_size: int = 1,
+ dilation: int = 1,
+ root_residual: bool = False,
+ with_cp: bool = False,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ if block == "BasicBlock":
+ block_c = BasicBlock
+ elif block == "Bottleneck":
+ block_c = Bottleneck # type: ignore
+ elif block == "BottleneckX":
+ block_c = BottleneckX # type: ignore
+ else:
+ raise ValueError(f"Block={block} not yet supported in DLA!")
+ if root_dim == 0:
+ root_dim = 2 * out_channels
+ if level_root:
+ root_dim += in_channels
+ if levels == 1:
+ self.tree1: Tree | BasicBlock = block_c(
+ in_channels,
+ out_channels,
+ stride,
+ dilation=dilation,
+ with_cp=with_cp,
+ )
+ self.tree2: Tree | BasicBlock = block_c(
+ out_channels,
+ out_channels,
+ 1,
+ dilation=dilation,
+ with_cp=with_cp,
+ )
+ self.root = Root(
+ root_dim,
+ out_channels,
+ root_kernel_size,
+ root_residual,
+ with_cp=with_cp,
+ )
+ else:
+ self.tree1 = Tree(
+ levels - 1,
+ block,
+ in_channels,
+ out_channels,
+ stride,
+ root_dim=0,
+ root_kernel_size=root_kernel_size,
+ dilation=dilation,
+ root_residual=root_residual,
+ with_cp=with_cp,
+ )
+ self.tree2 = Tree(
+ levels - 1,
+ block,
+ out_channels,
+ out_channels,
+ root_dim=root_dim + out_channels,
+ root_kernel_size=root_kernel_size,
+ dilation=dilation,
+ root_residual=root_residual,
+ with_cp=with_cp,
+ )
+ self.level_root = level_root
+ self.root_dim = root_dim
+ self.downsample = None
+ self.project = None
+ self.levels = levels
+ if stride > 1:
+ self.downsample = nn.MaxPool2d(stride, stride=stride)
+ if in_channels != out_channels and levels == 1:
+ # NOTE the official impl/weights have project layers in levels > 1
+ # case that are never used, hence 'levels == 1' is added but
+ # pretrained models will need strict=False while loading.
+ self.project = nn.Sequential(
+ nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(out_channels),
+ )
+
+ def forward(
+ self,
+ input_x: Tensor,
+ residual: None | Tensor = None,
+ children: None | list[Tensor] = None,
+ ) -> Tensor:
+ """Forward."""
+ children = [] if children is None else children
+ bottom = self.downsample(input_x) if self.downsample else input_x
+ residual = self.project(bottom) if self.project else bottom
+ if self.level_root:
+ children.append(bottom)
+ input_x1 = self.tree1(input_x, residual)
+ if self.levels == 1:
+ input_x2 = self.tree2(input_x1)
+ input_x = self.root(input_x2, input_x1, *children)
+ else:
+ children.append(input_x1)
+ input_x = self.tree2(input_x1, children=children)
+ return input_x
+
+
+class DLA(BaseModel):
+ """DLA base model."""
+
+ def __init__(
+ self,
+ name: str,
+ out_indices: Sequence[int] = (0, 1, 2, 3),
+ with_cp: bool = False,
+ pretrained: bool = False,
+ weights: None | str = None,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ assert name in DLA_ARCH_SETTINGS, f"{name} is not supported!"
+
+ levels, channels, residual_root, block = DLA_ARCH_SETTINGS[name]
+
+ if name == "dla102x2": # pragma: no cover
+ BottleneckX.cardinality = 64
+
+ self.base_layer = nn.Sequential(
+ nn.Conv2d(
+ 3, channels[0], kernel_size=7, stride=1, padding=3, bias=False
+ ),
+ nn.BatchNorm2d(channels[0], momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True),
+ )
+ self.level0 = self._make_conv_level(
+ channels[0], channels[0], levels[0]
+ )
+ self.level1 = self._make_conv_level(
+ channels[0], channels[1], levels[1], stride=2
+ )
+ self.level2 = Tree(
+ levels[2],
+ block,
+ channels[1],
+ channels[2],
+ 2,
+ level_root=False,
+ root_residual=residual_root,
+ with_cp=with_cp,
+ )
+ self.level3 = Tree(
+ levels[3],
+ block,
+ channels[2],
+ channels[3],
+ 2,
+ level_root=True,
+ root_residual=residual_root,
+ with_cp=with_cp,
+ )
+ self.level4 = Tree(
+ levels[4],
+ block,
+ channels[3],
+ channels[4],
+ 2,
+ level_root=True,
+ root_residual=residual_root,
+ with_cp=with_cp,
+ )
+ self.level5 = Tree(
+ levels[5],
+ block,
+ channels[4],
+ channels[5],
+ 2,
+ level_root=True,
+ root_residual=residual_root,
+ with_cp=with_cp,
+ )
+
+ self.out_indices = out_indices
+ self._out_channels = [channels[i + 2] for i in out_indices]
+
+ if pretrained:
+ if weights is None: # pragma: no cover
+ weights = f"{DLA_MODEL_PREFIX}/{DLA_MODEL_MAPPING[name]}"
+
+ load_model_checkpoint(self, weights)
+
+ else:
+ self._init_weights()
+
+ def _init_weights(self) -> None:
+ """Initialize module weights."""
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2.0 / n))
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ @staticmethod
+ def _make_conv_level(
+ inplanes: int,
+ planes: int,
+ convs: int,
+ stride: int = 1,
+ dilation: int = 1,
+ ) -> nn.Sequential:
+ """Build convolutional level."""
+ modules = []
+ for i in range(convs):
+ modules.extend(
+ [
+ nn.Conv2d(
+ inplanes,
+ planes,
+ kernel_size=3,
+ stride=stride if i == 0 else 1,
+ padding=dilation,
+ bias=False,
+ dilation=dilation,
+ ),
+ nn.BatchNorm2d(planes, momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True),
+ ]
+ )
+ inplanes = planes
+ return nn.Sequential(*modules)
+
+ def forward(self, images: Tensor) -> list[Tensor]:
+ """DLA forward.
+
+ Args:
+ images (Tensor[N, C, H, W]): Image input to process. Expected to
+ type float32 with values ranging 0..255.
+
+ Returns:
+ fp (list[Tensor]): The output feature pyramid. The list index
+ represents the level, which has a downsampling raio of 2^index.
+ """
+ input_x = self.base_layer(images)
+
+ outs = [images, images]
+
+ for i in range(6):
+ input_x = getattr(self, f"level{i}")(input_x)
+
+ if i - 2 in self.out_indices:
+ outs.append(input_x)
+
+ return outs
+
+ @property
+ def out_channels(self) -> list[int]:
+ """Get the numbers of channels for each level of feature pyramid.
+
+ Returns:
+ list[int]: number of channels
+ """
+ return [3, 3] + self._out_channels
diff --git a/vis4d/op/base/pointnet.py b/vis4d/op/base/pointnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..39eb638c84fa54ec8d2cbe42af91dc5fdbc854b7
--- /dev/null
+++ b/vis4d/op/base/pointnet.py
@@ -0,0 +1,408 @@
+"""Operations for PointNet.
+
+Code taken from
+https://github.com/timothylimyl/PointNet-Pytorch/blob/master/pointnet/model.py
+and modified to allow for modular configuration.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Iterable
+from typing import NamedTuple
+
+import torch
+from torch import nn
+
+from vis4d.common.typing import ArgsType
+
+
+class PointNetEncoderOut(NamedTuple):
+ """Output of the PointNetEncoder.
+
+ features: Global features shape [N, feature_dim]
+ pointwise Features: Pointwise features shape [N, last_mlp_dim, n_pts]
+ transformations: list with all transformation matrixes that were used.
+ Shape [N, d, d]
+ """
+
+ features: torch.Tensor
+ pointwise_features: torch.Tensor #
+ transformations: list[ # list with all transformation matrices [[B, d, d]]
+ torch.Tensor
+ ]
+
+
+class PointNetSemanticsLoss(NamedTuple):
+ """Losses for the pointnet semantic segmentation network."""
+
+ semantic_loss: torch.Tensor
+ regularization_loss: torch.Tensor
+
+
+class PointNetSemanticsOut(NamedTuple):
+ """Output of the PointNet Segmentation network."""
+
+ class_logits: torch.Tensor # B, n_classes, n_pts
+ transformations: list[ # list with all transformation matrices [[B, d, d]]
+ torch.Tensor
+ ]
+
+
+class LinearTransform(nn.Module):
+ """Module that learns a linear transformation for a input pointcloud.
+
+ Code taken from
+ https://github.com/timothylimyl/PointNet-Pytorch/blob/master/pointnet/model.py
+ and modified to allow for modular configuration.
+
+ See T-Net in Pointnet publication (https://arxiv.org/pdf/1612.00593.pdf)
+ for more information
+ """
+
+ def __init__(
+ self,
+ in_dimension: int = 3,
+ upsampling_dims: Iterable[int] = (64, 128, 1024),
+ downsampling_dims: Iterable[int] = (1024, 512, 256),
+ norm_cls: str | None = "BatchNorm1d",
+ activation_cls: str = "ReLU",
+ ) -> None:
+ """Creates a new LinearTransform.
+
+ This learns a transformation matrix from data.
+
+ Args:
+ in_dimension (int): input dimension
+ upsampling_dims (Iterable[int]): list of intermediate feature
+ shapes for upsampling
+ downsampling_dims (Iterable[int]): list of intermediate feature
+ shapes for downsampling.
+ Make sure this matches with the
+ last upsampling_dims
+ norm_cls (Optional(str)): class for norm (nn.'norm_cls') or None
+ activation_cls (str): class for activation (nn.'activation_cls')
+ """
+ super().__init__()
+ self.upsampling_dims = list(upsampling_dims)
+ self.downsampling_dims = list(downsampling_dims)
+
+ assert (
+ len(self.upsampling_dims) != 0 and len(self.downsampling_dims) != 0
+ )
+ assert self.upsampling_dims[-1] == self.downsampling_dims[0]
+
+ self.in_dimension_ = in_dimension
+ self.identity: torch.Tensor
+ self.register_buffer(
+ "identity", torch.eye(in_dimension).reshape(1, in_dimension**2)
+ )
+
+ # Create activation
+ self.activation_ = getattr(nn, activation_cls)()
+
+ # Create norms
+ norm_fn: Callable[[int], nn.Module] | None = (
+ getattr(nn, norm_cls) if norm_cls is not None else None
+ )
+
+ if norm_fn is not None:
+ self.norms_ = nn.ModuleList(
+ norm_fn(feature_size)
+ for feature_size in (
+ *upsampling_dims,
+ *self.downsampling_dims[1:],
+ )
+ )
+
+ # Create upsampling layers
+ self.upsampling_layers = nn.ModuleList(
+ [nn.Conv1d(in_dimension, self.upsampling_dims[0], 1)]
+ )
+ for i in range(len(self.upsampling_dims) - 1):
+ self.upsampling_layers.append(
+ nn.Conv1d(
+ self.upsampling_dims[i], self.upsampling_dims[i + 1], 1
+ )
+ )
+
+ # Create downsampling layers
+ self.downsampling_layers = nn.ModuleList(
+ [
+ nn.Linear(
+ self.downsampling_dims[i], self.downsampling_dims[i + 1]
+ )
+ for i in range(len(self.downsampling_dims) - 1)
+ ]
+ )
+ self.downsampling_layers.append(
+ nn.Linear(self.downsampling_dims[-1], in_dimension**2)
+ )
+
+ def __call__(
+ self,
+ features: torch.Tensor,
+ ) -> torch.Tensor:
+ """Type definition for call implementation."""
+ return self._call_impl(features)
+
+ def forward(
+ self,
+ features: torch.Tensor,
+ ) -> torch.Tensor:
+ """Linear Transform forward.
+
+ Args:
+ features (Tensor[B, C, N]): Input features (e.g. points)
+
+ Returns:
+ Learned Canonical Transfomation Matrix for this input.
+ See T-Net in Pointnet publication
+ (https://arxiv.org/pdf/1612.00593.pdf)
+ for further information
+ """
+ batchsize = features.shape[0]
+ # Upsample features
+ for idx, layer in enumerate(self.upsampling_layers):
+ features = layer(features)
+ if self.norms_ is not None:
+ features = self.norms_[idx](features)
+ features = self.activation_(features)
+
+ features = torch.max(features, 2, keepdim=True)[0]
+ features = features.view(-1, self.upsampling_dims[-1])
+
+ # Downsample features
+ for idx, layer in enumerate(self.downsampling_layers):
+ features = layer(features)
+
+ # Do not apply norm and activation for
+ # final layer
+ if idx != len(self.downsampling_layers) - 1:
+ if self.norms_ is not None:
+ norm_idx = idx + len(self.upsampling_layers)
+ features = self.norms_[norm_idx](features)
+ features = self.activation_(features)
+
+ identity_batch = self.identity.repeat(batchsize, 1)
+ transformations = features + identity_batch
+
+ return transformations.view(
+ batchsize, self.in_dimension_, self.in_dimension_
+ )
+
+
+class PointNetEncoder(nn.Module):
+ """PointNetEncoder.
+
+ Encodes a pointcloud and additional features into one feature description
+
+ See pointnet publication for more information
+ (https://arxiv.org/pdf/1612.00593.pdf)
+ """
+
+ def __init__(
+ self,
+ in_dimensions: int = 3,
+ out_dimensions: int = 1024,
+ mlp_dimensions: Iterable[Iterable[int]] = ((64, 64), (64, 128)),
+ norm_cls: str | None = "BatchNorm1d",
+ activation_cls: str = "ReLU",
+ **kwargs: ArgsType,
+ ):
+ """Creates a new PointNetEncoder.
+
+ Args:
+ in_dimensions (int): input dimension (e.g. 3 for xzy, 6 for xzyrgb)
+ out_dimensions (int): output dimensions
+ mlp_dimensions (Iterable[Iterable[int]]):(Dimensions of MLP layers)
+ norm_cls (Optional(str)): class for norm (nn.'norm_cls') or None
+ activation_cls (str): class for activation (nn.'activation_cls')
+ kwargs : See arguments of @LinearTransformStn
+ """
+ super().__init__()
+
+ self.out_dimension = out_dimensions
+
+ # Extend dimensions to upscale from input dimension
+ mlp_dim_list: list[list[int]] = [list(d) for d in mlp_dimensions]
+ mlp_dim_list[0].insert(0, in_dimensions)
+ mlp_dim_list[-1].append(out_dimensions)
+ self.mlp_dimensions = mlp_dim_list
+
+ # Learnable transformation layers.
+ self.trans_layers_ = nn.ModuleList(
+ [
+ LinearTransform(
+ in_dimension=dims[0],
+ norm_cls=norm_cls,
+ activation_cls=activation_cls,
+ **kwargs,
+ )
+ for dims in mlp_dim_list
+ ]
+ )
+
+ # MLP layers
+ self.mlp_layers_ = nn.ModuleList()
+
+ # Create activation
+ activation = getattr(nn, activation_cls)()
+
+ # Create norms
+ norm_fn: Callable[[int], nn.Module] | None = (
+ getattr(nn, norm_cls) if norm_cls is not None else None
+ )
+
+ for mlp_idx, mlp_dims in enumerate(mlp_dim_list):
+ layers: list[nn.Module] = []
+
+ for idx, (in_dim, out_dim) in enumerate(
+ zip(mlp_dims[:-1], mlp_dims[1:])
+ ):
+ # Create MLP
+ layers.append(torch.nn.Conv1d(in_dim, out_dim, 1))
+ # Create BN if needed
+ if norm_fn is not None:
+ layers.append(norm_fn(out_dim))
+
+ # Only add activation if not last layer
+ if (
+ mlp_idx != len(mlp_dim_list) - 1
+ and idx != len(mlp_dims) - 2
+ ):
+ layers.append(activation)
+
+ self.mlp_layers_.append(nn.Sequential(*layers))
+
+ def __call__(self, features: torch.Tensor) -> PointNetEncoderOut:
+ """Type definition for call implementation."""
+ return self._call_impl(features)
+
+ def forward(self, features: torch.Tensor) -> PointNetEncoderOut:
+ """Pointnet encoder forward.
+
+ Args:
+ features (Tensor[B, C, N]): Input features stacked in channels.
+ e.g. raw point inputs: [B, 3, N] , w color : [B, 3+3, N], ...
+
+ Returns:
+ Extracted feature representation for input and all
+ applied transformations.
+ """
+ transforms: list[torch.Tensor] = []
+
+ for block_idx, trans_layer in enumerate(self.trans_layers_):
+ # Apply transformation
+ trans = trans_layer(features)
+ transforms.append(trans)
+ features = features.transpose(2, 1)
+ features = torch.bmm(features, trans)
+ features = features.transpose(2, 1)
+
+ if block_idx == len(self.trans_layers_) - 1:
+ pointwise_features = features.clone()
+
+ # Apply MLP
+ features = self.mlp_layers_[block_idx](features)
+
+ features = torch.max(features, 2, keepdim=True)[0]
+ features = features.view(-1, self.out_dimension)
+
+ return PointNetEncoderOut(
+ features=features,
+ transformations=transforms,
+ pointwise_features=pointwise_features, # pylint: disable=possibly-used-before-assignment, line-too-long
+ )
+
+
+class PointNetSegmentation(nn.Module):
+ """Segmentation network using a simple pointnet as encoder."""
+
+ def __init__(
+ self,
+ n_classes: int,
+ in_dimensions: int = 3,
+ feature_dimension: int = 1024,
+ norm_cls: str = "BatchNorm1d",
+ activation_cls: str = "ReLU",
+ ):
+ """Creates a new Point Net segementation network.
+
+ Args:
+ n_classes (int): Number of semantic classes
+ in_dimensions (int): Input dimension (3 for xyz, 6 xyzrgb, ...)
+ feature_dimension (int): Size of feature from the encoder
+ norm_cls (Optional(str)): class for norm (nn.'norm_cls') or None
+ activation_cls (str): class for activation (nn.'activation_cls')
+
+ Raises:
+ ValueError: If dimensions are invalid
+ """
+ super().__init__()
+ self.in_dimensions = in_dimensions
+
+ self.encoder = PointNetEncoder(
+ in_dimensions=in_dimensions,
+ out_dimensions=feature_dimension,
+ norm_cls=norm_cls,
+ activation_cls=activation_cls,
+ )
+ pc_feat_dim = self.encoder.mlp_dimensions[-1][0]
+
+ # Create activation
+ activation = getattr(nn, activation_cls)()
+
+ # Create norms
+ norm_fn: Callable[[int], nn.Module] = (
+ getattr(nn, norm_cls) if norm_cls is not None else None
+ )
+ self.classifier_dims = [feature_dimension + pc_feat_dim, 512, 256, 128]
+ # Build Model
+ self.classifier = nn.Sequential()
+ for in_dim, out_dim in zip(
+ self.classifier_dims[:-1], self.classifier_dims[1:]
+ ):
+ self.classifier.append(nn.Conv1d(in_dim, out_dim, 1))
+ if norm_fn is not None:
+ self.classifier.append(norm_fn(out_dim))
+ self.classifier.append(activation)
+
+ self.classifier.append(
+ nn.Conv1d(
+ out_dim, # pylint: disable=undefined-loop-variable
+ n_classes,
+ 1,
+ )
+ )
+
+ def __call__(self, points: torch.Tensor) -> PointNetSemanticsOut:
+ """Call function."""
+ return self._call_impl(points)
+
+ def forward(self, points: torch.Tensor) -> PointNetSemanticsOut:
+ """Pointnet Segmenter Forward.
+
+ Args:
+ points (tensor) : inputs points dimension [B, in_dim, n_pts]
+
+ Returns:
+ Returns a list of tensors where the first element is
+ the desired segmentation [B, n_classes, n_pts] and the other
+ elements are the linear transformation matrices which
+ have been used to transform the pointclouds
+ @see LinearTransform
+ """
+ assert points.size(-2) == self.in_dimensions
+ n_pts = points.size(-1)
+ bs = points.size(0)
+ encoder_out = self.encoder(points)
+ global_features = encoder_out.features.view(bs, -1, 1).repeat(
+ 1, 1, n_pts
+ )
+
+ x = torch.cat([global_features, encoder_out.pointwise_features], 1)
+
+ x = self.classifier(x)
+ return PointNetSemanticsOut(
+ class_logits=x, transformations=encoder_out.transformations
+ )
diff --git a/vis4d/op/base/pointnetpp.py b/vis4d/op/base/pointnetpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3f4105fa97f2aafd6eae667c19017d13ef63244
--- /dev/null
+++ b/vis4d/op/base/pointnetpp.py
@@ -0,0 +1,498 @@
+"""Pointnet++ implementation.
+
+based on https://github.com/yanx27/Pointnet_Pointnet2_pytorch
+Added typing and named tuples for convenience.
+
+#TODO write tests
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import NamedTuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+class PointNetSetAbstractionOut(NamedTuple):
+ """Ouput of PointNet set abstraction."""
+
+ coordinates: Tensor # [B, C, S]
+ features: Tensor # [B, D', S]
+
+
+def square_distance(src: Tensor, dst: Tensor) -> Tensor:
+ """Calculate Euclid distance between each two points.
+
+ src^T * dst = xn * xm + yn * ym + zn * zm;
+ sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
+ sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
+ dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
+ = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
+
+ Input:
+ src: source points, [B, N, C]
+ dst: target points, [B, M, C]
+
+ Output:
+ dist: per-point square distance, [B, N, M]
+ """
+ bs, n_pts_in, _ = src.shape
+ _, n_pts_out, _ = dst.shape
+ dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
+ dist += torch.sum(src**2, -1).view(bs, n_pts_in, 1)
+ dist += torch.sum(dst**2, -1).view(bs, 1, n_pts_out)
+ return dist
+
+
+def index_points(points: Tensor, idx: Tensor) -> Tensor:
+ """Indexes points.
+
+ Input:
+ points: input points data, [B, N, C]
+ idx: sample index data, [B, S]
+
+ Return:
+ new_points:, indexed points data, [B, S, C]
+ """
+ device = points.device
+ bs = points.shape[0]
+ view_shape = list(idx.shape)
+ view_shape[1:] = [1] * (len(view_shape) - 1)
+ repeat_shape = list(idx.shape)
+ repeat_shape[0] = 1
+ batch_indices = (
+ torch.arange(bs, dtype=torch.long)
+ .to(device)
+ .view(view_shape)
+ .repeat(repeat_shape)
+ )
+ new_points = points[batch_indices, idx, :]
+ return new_points
+
+
+def farthest_point_sample(xyz: Tensor, npoint: int) -> Tensor:
+ """Farthest point sampling.
+
+ Input:
+ xyz: pointcloud data, [B, N, 3]
+ npoint: number of samples
+
+ Return:
+ centroids: sampled pointcloud index, [B, npoint]
+ """
+ device = xyz.device
+ bs, n_pts, _ = xyz.shape
+ centroids = torch.zeros(bs, npoint, dtype=torch.long).to(device)
+ distance = torch.ones(bs, n_pts).to(device) * 1e10
+ farthest = torch.randint(0, n_pts, (bs,), dtype=torch.long).to(device)
+ batch_indices = torch.arange(bs, dtype=torch.long).to(device)
+ for i in range(npoint):
+ centroids[:, i] = farthest
+ centroid = xyz[batch_indices, farthest, :].view(bs, 1, 3)
+ dist = torch.sum((xyz - centroid) ** 2, -1)
+ mask = dist < distance
+ distance[mask] = dist[mask]
+ farthest = torch.max(distance, -1)[1]
+ return centroids
+
+
+def query_ball_point(
+ radius: float, nsample: int, xyz: Tensor, new_xyz: Tensor
+) -> Tensor:
+ """Query around a ball with given radius.
+
+ Input:
+ radius: local region radius
+ nsample: max sample number in local region
+ xyz: all points, [B, N, 3]
+ new_xyz: query points, [B, S, 3]
+
+ Return:
+ group_idx: grouped points index, [B, S, nsample]
+ """
+ device = xyz.device
+ bs, n_pts_in, _ = xyz.shape
+ _, n_pts_out, _ = new_xyz.shape
+ group_idx = (
+ torch.arange(n_pts_in, dtype=torch.long)
+ .to(device)
+ .view(1, 1, n_pts_in)
+ .repeat([bs, n_pts_out, 1])
+ )
+ sqrdists = square_distance(new_xyz, xyz)
+ group_idx[sqrdists > radius**2] = n_pts_in
+ group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
+ group_first = (
+ group_idx[:, :, 0].view(bs, n_pts_out, 1).repeat([1, 1, nsample])
+ )
+ mask = group_idx == n_pts_in
+ group_idx[mask] = group_first[mask]
+ return group_idx
+
+
+def sample_and_group(
+ npoint: int,
+ radius: float,
+ nsample: int,
+ xyz: Tensor,
+ points: Tensor,
+) -> tuple[Tensor, Tensor]:
+ """Samples and groups.
+
+ Input:
+ npoint: Number of center to sample
+ radius: Grouping Radius
+ nsample: Max number of points to sample for each circle
+ xyz: input points position data, [B, N, 3]
+ points: input points data, [B, N, D]
+
+ Return:
+ new_xyz: sampled points position data, [B, npoint, nsample, 3]
+ new_points: sampled points data, [B, npoint, nsample, 3+D]
+ """
+ bs, _, channels = xyz.shape
+ fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C]
+ new_xyz = index_points(xyz, fps_idx)
+ idx = query_ball_point(radius, nsample, xyz, new_xyz)
+ grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
+ grouped_xyz_norm = grouped_xyz - new_xyz.view(bs, npoint, 1, channels)
+
+ if points is not None:
+ grouped_points = index_points(points, idx)
+ new_points = torch.cat(
+ [grouped_xyz_norm, grouped_points], dim=-1
+ ) # [B, npoint, nsample, C+D]
+ else:
+ new_points = grouped_xyz_norm
+ return new_xyz, new_points
+
+
+def sample_and_group_all(xyz: Tensor, points: Tensor) -> tuple[Tensor, Tensor]:
+ """Sample and groups all.
+
+ Input:
+ xyz: input points position data, [B, N, 3]
+ points: input points data, [B, N, D]
+
+ Return:
+ new_xyz: sampled points position data, [B, 1, 3]
+ new_points: sampled points data, [B, 1, N, 3+D]
+ """
+ device = xyz.device
+ bs, n_pts, channels = xyz.shape
+ new_xyz = torch.zeros(bs, 1, channels).to(device)
+ grouped_xyz = xyz.view(bs, 1, n_pts, channels)
+ if points is not None:
+ new_points = torch.cat(
+ [grouped_xyz, points.view(bs, 1, n_pts, -1)], dim=-1
+ )
+ else:
+ new_points = grouped_xyz
+ return new_xyz, new_points
+
+
+class PointNetSetAbstraction(nn.Module):
+ """PointNet set abstraction layer."""
+
+ def __init__(
+ self,
+ npoint: int,
+ radius: float,
+ nsample: int,
+ in_channel: int,
+ mlp: list[int],
+ group_all: bool,
+ norm_cls: str | None = "BatchNorm2d",
+ ):
+ """Set Abstraction Layer from the Pointnet Architecture.
+
+ Args:
+ npoint: How many points to sample
+ radius: Size of the ball query
+ nsample: Max number of points to group inside circle
+ in_channel: Input channel dimension
+ mlp: Input channel dimension of the mlp layers.
+ E.g. [32 , 32, 64] will use a MLP with three layers
+ group_all: If true, groups all point inside the ball, otherwise
+ samples 'nsample' points.
+ norm_cls (Optional(str)): class for norm (nn.'norm_cls') or None
+ """
+ super().__init__()
+ self.npoint = npoint
+ self.radius = radius
+ self.nsample = nsample
+ self.mlp_convs = nn.ModuleList()
+ self.mlp_bns = nn.ModuleList()
+ last_channel = in_channel
+
+ # Create norms
+ norm_fn: Callable[[int], nn.Module] | None = (
+ getattr(nn, norm_cls) if norm_cls is not None else None
+ )
+
+ for out_channel in mlp:
+ self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
+ if norm_fn is not None:
+ self.mlp_bns.append(norm_fn(out_channel))
+ last_channel = out_channel
+ self.group_all = group_all
+
+ def __call__(
+ self, coordinates: Tensor, features: Tensor
+ ) -> PointNetSetAbstractionOut:
+ """Call function.
+
+ Input:
+ coordinates: input points position data, [B, C, N]
+ features: input points data, [B, D, N]
+
+ Return:
+ PointNetSetAbstractionOut with:
+ coordinates: sampled points position data, [B, C, S]
+ features: sample points feature data, [B, D', S]
+ """
+ return self._call_impl(coordinates, features)
+
+ def forward(
+ self, xyz: Tensor, points: Tensor
+ ) -> PointNetSetAbstractionOut:
+ """Pointnet++ set abstraction layer forward.
+
+ Input:
+ xyz: input points position data, [B, C, N]
+ points: input points data, [B, D, N]
+
+ Return:
+ PointNetSetAbstractionOut with:
+ coordinates: sampled points position data, [B, C, S]
+ features: sample points feature data, [B, D', S]
+ """
+ xyz = xyz.permute(0, 2, 1)
+ if points is not None:
+ points = points.permute(0, 2, 1)
+
+ if self.group_all:
+ new_xyz, new_points = sample_and_group_all(xyz, points)
+ else:
+ new_xyz, new_points = sample_and_group(
+ self.npoint, self.radius, self.nsample, xyz, points
+ )
+ # new_xyz: sampled points position data, [B, npoint, C]
+ # new_points: sampled points data, [B, npoint, nsample, C+D]
+ new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
+ for i, conv in enumerate(self.mlp_convs):
+ bn = self.mlp_bns[i] if len(self.mlp_bns) != 0 else lambda x: x
+ new_points = F.relu(bn(conv(new_points)))
+
+ new_points = torch.max(new_points, 2)[0]
+ new_xyz = new_xyz.permute(0, 2, 1)
+ return PointNetSetAbstractionOut(new_xyz, new_points)
+
+
+class PointNetFeaturePropagation(nn.Module):
+ """Pointnet++ Feature Propagation Layer."""
+
+ def __init__(
+ self,
+ in_channel: int,
+ mlp: list[int],
+ norm_cls: str = "BatchNorm1d",
+ ):
+ """Creates a pointnet++ feature propagation layer.
+
+ Args:
+ in_channel: Number of input channels
+ mlp: list with hidden dimensions of the MLP.
+ norm_cls (Optional(str)): class for norm (nn.'norm_cls') or None
+ """
+ super().__init__()
+ self.mlp_convs = nn.ModuleList()
+ self.mlp_bns = nn.ModuleList()
+
+ # Create norms
+ norm_fn: Callable[[int], nn.Module] = (
+ getattr(nn, norm_cls) if norm_cls is not None else None
+ )
+ last_channel = in_channel
+ for out_channel in mlp:
+ self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
+ if norm_cls is not None:
+ self.mlp_bns.append(norm_fn(out_channel))
+ last_channel = out_channel
+
+ def __call__(
+ self,
+ xyz1: Tensor,
+ xyz2: Tensor,
+ points1: Tensor | None,
+ points2: Tensor,
+ ) -> Tensor:
+ """Call function.
+
+ Input:
+ xyz1: input points position data, [B, C, N]
+ xyz2: sampled input points position data, [B, C, S]
+ points1: input points features, [B, D, N]
+ points2: sampled points features, [B, D, S]
+
+ Return:
+ new_points: upsampled points data, [B, D', N]
+ """
+ return self._call_impl(xyz1, xyz2, points1, points2)
+
+ def forward(
+ self,
+ xyz1: Tensor,
+ xyz2: Tensor,
+ points1: Tensor | None,
+ points2: Tensor,
+ ) -> Tensor:
+ """Forward Implementation.
+
+ Input:
+ xyz1: input points position data, [B, C, N]
+ xyz2: sampled input points position data, [B, C, S]
+ points1: input points features, [B, D, N]
+ points2: sampled points features, [B, D, S]
+
+ Return:
+ new_points: upsampled points data, [B, D', N]
+ """
+ xyz1 = xyz1.permute(0, 2, 1)
+ xyz2 = xyz2.permute(0, 2, 1)
+
+ points2 = points2.permute(0, 2, 1)
+ bs, n_pts, _ = xyz1.shape
+ _, n_out_pts, _ = xyz2.shape
+
+ if n_out_pts == 1:
+ interpolated_points = points2.repeat(1, n_pts, 1)
+ else:
+ dists = square_distance(xyz1, xyz2)
+ dists, idx = dists.sort(dim=-1)
+ dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
+
+ dist_recip: Tensor = 1.0 / (dists + 1e-8)
+ norm = torch.sum(dist_recip, dim=2, keepdim=True)
+ weight = dist_recip / norm
+ interpolated_points = torch.sum(
+ index_points(points2, idx) * weight.view(bs, n_pts, 3, 1),
+ dim=2,
+ )
+
+ if points1 is not None:
+ points1 = points1.permute(0, 2, 1)
+ new_points = torch.cat([points1, interpolated_points], dim=-1)
+ else:
+ new_points = interpolated_points
+
+ new_points = new_points.permute(0, 2, 1)
+ for i, conv in enumerate(self.mlp_convs):
+ bn = self.mlp_bns[i] if len(self.mlp_bns) != 0 else lambda x: x
+ new_points = F.relu(bn(conv(new_points)))
+ return new_points
+
+
+class PointNet2SegmentationOut(NamedTuple):
+ """Prediction for the pointnet++ semantic segmentation network."""
+
+ class_logits: Tensor
+
+
+class PointNet2Segmentation(nn.Module): # TODO, probably move to module?
+ """Pointnet++ Segmentation Network."""
+
+ def __init__(self, num_classes: int, in_channels: int = 3):
+ """Creates a new Pointnet++ for segmentation.
+
+ Args:
+ num_classes: Number of semantic classes
+ in_channels: Number of input channels
+ """
+ super().__init__()
+
+ self.set_abstractions = [
+ PointNetSetAbstraction(
+ 1024, 0.1, 32, in_channels + 3, [32, 32, 64], False
+ ),
+ PointNetSetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128], False),
+ PointNetSetAbstraction(
+ 64, 0.4, 32, 128 + 3, [128, 128, 256], False
+ ),
+ PointNetSetAbstraction(
+ 16, 0.8, 32, 256 + 3, [256, 256, 512], False
+ ),
+ ]
+
+ self.feature_propagations = [
+ PointNetFeaturePropagation(768, [256, 256]),
+ PointNetFeaturePropagation(384, [256, 256]),
+ PointNetFeaturePropagation(320, [256, 128]),
+ PointNetFeaturePropagation(128 + 3, [128, 128, 128]),
+ ]
+
+ # Final convolutions
+ self.conv1 = nn.Conv1d(128, 128, 1)
+ self.bn1 = nn.BatchNorm1d(128)
+ self.drop1 = nn.Dropout(0.5)
+ self.conv2 = nn.Conv1d(128, num_classes, 1)
+ self.in_channels = in_channels
+
+ def __call__(self, xyz: Tensor) -> PointNet2SegmentationOut:
+ """Call implementation.
+
+ Args:
+ xyz: Pointcloud data shaped [N, n_feats, n_pts]
+
+ Returns:
+ PointNet2SegmentationOut, class logits for each point
+ """
+ return self._call_impl(xyz)
+
+ def forward(self, xyz: Tensor) -> PointNet2SegmentationOut:
+ """Predicts the semantic class logits for each point.
+
+ Args:
+ xyz: Pointcloud data shaped [N, n_feats, n_pts]$
+
+ Returns:
+ PointNet2SegmentationOut, class logits for each point
+ """
+ assert xyz.size(1) == self.in_channels
+
+ l0_points = xyz
+ l0_xyz = xyz[:, :3, :]
+
+ set_abstraction_out = PointNetSetAbstractionOut(
+ coordinates=l0_xyz, features=l0_points
+ )
+ outputs: list[PointNetSetAbstractionOut] = [set_abstraction_out]
+
+ for set_abs_layer in self.set_abstractions:
+ set_abstraction_out = set_abs_layer(
+ set_abstraction_out.coordinates, set_abstraction_out.features
+ )
+
+ outputs.append(set_abstraction_out)
+
+ pointwise_features = outputs[-1].features
+ for idx, feature_prop_layer in enumerate(self.feature_propagations):
+ layer_after_out = outputs[-idx - 1] # l4
+ layer_out = outputs[-idx - 2] # l3
+
+ out_features = (
+ layer_out.features if idx < len(outputs) - 1 else None
+ )
+ pointwise_features = feature_prop_layer(
+ layer_out.coordinates,
+ layer_after_out.coordinates,
+ out_features,
+ pointwise_features,
+ )
+
+ x = self.drop1(F.relu(self.bn1(self.conv1(pointwise_features))))
+ x = self.conv2(x)
+ return PointNet2SegmentationOut(class_logits=x)
diff --git a/vis4d/op/base/resnet.py b/vis4d/op/base/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..10e9f51f34f3dfde1cc5c1a6de702abf115f8e02
--- /dev/null
+++ b/vis4d/op/base/resnet.py
@@ -0,0 +1,609 @@
+"""Residual networks base model.
+
+Modified from mmdetection (https://github.com/open-mmlab/mmdetection).
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import torchvision.models.resnet as _resnet
+from torch import Tensor, nn
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.utils.checkpoint import checkpoint
+
+from vis4d.common.ckpt import load_model_checkpoint
+from vis4d.common.typing import ArgsType
+from vis4d.op.layer.util import build_conv_layer, build_norm_layer
+from vis4d.op.layer.weight_init import constant_init, kaiming_init
+
+from .base import BaseModel
+
+
+class BasicBlock(nn.Module):
+ """BasicBlock."""
+
+ expansion = 1
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ dilation: int = 1,
+ downsample: nn.Module | None = None,
+ style: str = "pytorch",
+ use_checkpoint: bool = False,
+ with_dcn: bool = False,
+ norm: str = "BatchNorm2d",
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ assert style in {"pytorch", "caffe"} # No effect for BasicBlock
+ assert not with_dcn, "DCN is not supported for BasicBlock."
+ self.conv1 = build_conv_layer(
+ inplanes,
+ planes,
+ 3,
+ stride=stride,
+ dilation=dilation,
+ padding=dilation,
+ bias=False,
+ )
+ self.bn1 = build_norm_layer(norm, planes)
+ self.conv2 = build_conv_layer(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = build_norm_layer(norm, planes)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.use_checkpoint = use_checkpoint
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward function."""
+
+ def _inner_forward(x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.use_checkpoint and x.requires_grad:
+ out = checkpoint(_inner_forward, x, use_reentrant=True)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ """Bottleneck."""
+
+ expansion = 4
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ dilation: int = 1,
+ downsample: nn.Module | None = None,
+ style: str = "pytorch",
+ use_checkpoint: bool = False,
+ with_dcn: bool = False,
+ norm: str = "BatchNorm2d",
+ ) -> None:
+ """Bottleneck block for ResNet.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
+ it is "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+ super().__init__()
+ self.inplanes = inplanes
+ self.planes = planes
+ self.stride = stride
+ self.dilation = dilation
+ self.use_checkpoint = use_checkpoint
+
+ assert style in {"pytorch", "caffe"}
+ if style == "pytorch":
+ self.conv1_stride = 1
+ self.conv2_stride = stride
+ else:
+ self.conv1_stride = stride
+ self.conv2_stride = 1
+
+ self.conv1 = build_conv_layer(
+ inplanes,
+ planes,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False,
+ )
+ self.bn1 = build_norm_layer(norm, planes)
+
+ self.conv2 = build_conv_layer(
+ planes,
+ planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False,
+ use_dcn=with_dcn,
+ )
+ self.bn2 = build_norm_layer(norm, planes)
+
+ self.conv3 = build_conv_layer(
+ planes,
+ planes * self.expansion,
+ kernel_size=1,
+ bias=False,
+ )
+ self.bn3 = build_norm_layer(norm, planes * self.expansion)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward function."""
+
+ def _inner_forward(x: Tensor) -> Tensor:
+ identity = x
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.use_checkpoint and x.requires_grad:
+ out = checkpoint(_inner_forward, x, use_reentrant=True)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(BaseModel):
+ """ResNet BaseModel."""
+
+ arch_settings = {
+ "resnet18": (18, BasicBlock, (2, 2, 2, 2)),
+ "resnet34": (34, BasicBlock, (3, 4, 6, 3)),
+ "resnet50": (50, Bottleneck, (3, 4, 6, 3)),
+ "resnet101": (101, Bottleneck, (3, 4, 23, 3)),
+ "resnet152": (152, Bottleneck, (3, 8, 36, 3)),
+ }
+
+ def __init__(
+ self,
+ resnet_name: str,
+ in_channels: int = 3,
+ stem_channels: int | None = None,
+ base_channels: int = 64,
+ num_stages: int = 4,
+ strides: Sequence[int] = (1, 2, 2, 2),
+ dilations: Sequence[int] = (1, 1, 1, 1),
+ style: str = "pytorch",
+ deep_stem: bool = False,
+ avg_down: bool = False,
+ trainable_layers: int = 5,
+ norm: str = "BatchNorm2d",
+ norm_frozen: bool = True,
+ stages_with_dcn: Sequence[bool] = (False, False, False, False),
+ replace_stride_with_dilation: Sequence[bool] = (False, False, False),
+ use_checkpoint: bool = False,
+ zero_init_residual: bool = True,
+ pretrained: bool = False,
+ weights: None | str = None,
+ ) -> None:
+ """Create ResNet.
+
+ Args:
+ resnet_name (str): Name of the ResNet variant.
+ in_channels (int): Number of input image channels. Default: 3.
+ stem_channels (int | None): Number of stem channels. If not
+ specified, it will be the same as `base_channels`. Default:
+ None.
+ base_channels (int): Number of base channels of res layer. Default:
+ 64.
+ num_stages (int): Resnet stages. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ Default: (1, 2, 2, 2).
+ dilations (Sequence[int]): Dilation of each stage. Default: (1, 1,
+ 1, 1)
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the
+ stride-two layer is the 3x3 conv layer, otherwise the
+ stride-two layer is the first 1x1 conv layer. Default: pytorch.
+ deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
+ Default: False.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False.
+ trainable_layers (int, optional): Number layers for training or
+ fine-tuning. 5 means all the layers can be fine-tuned. Defaults
+ to 5.
+ norm (str): Normalization layer str. Default: BatchNorm2d, which
+ means using `nn.BatchNorm2d`.
+ norm_frozen (bool): Whether to set norm layers to eval mode. It
+ freezes running stats (mean and var). Note: Effect on
+ Batch Norm and its variants only.
+ stages_with_dcn (Sequence[bool]): Indices of stages with deformable
+ convolutions. Default: (False, False, False, False).
+ replace_stride_with_dilation (Sequence[bool]): Whether to replace
+ stride with dilation. Default: (False, False, False).
+ use_checkpoint (bool): Use checkpoint or not. Using checkpoint will
+ save some memory while slowing down the training speed.
+ Default: False.
+ zero_init_residual (bool): Whether to use zero init for last norm
+ layer in resblocks to let them behave as identity.
+ Default: True.
+ pretrained (bool): Whether to load pretrained weights. Default:
+ False.
+ weights (str, optional): model pretrained path. Default: None
+ """
+ super().__init__()
+ self._norm = norm
+
+ self.zero_init_residual = zero_init_residual
+ if resnet_name not in self.arch_settings:
+ raise KeyError(f"invalid architecture {resnet_name} for ResNet")
+ self.name = resnet_name
+ self.deep_stem = deep_stem
+ self.trainable_layers = trainable_layers
+
+ self.use_checkpoint = use_checkpoint
+ self.norm_frozen = norm_frozen
+
+ depth, self.block, stage_blocks = self.arch_settings[resnet_name]
+ assert isinstance(depth, int)
+
+ self.depth = depth
+ stem_channels = stem_channels or base_channels
+
+ assert 4 >= num_stages >= 1
+ assert len(strides) == len(dilations) == num_stages
+
+ self.stage_blocks = stage_blocks[:num_stages]
+ self.inplanes = stem_channels
+
+ self._make_stem_layer(in_channels, stem_channels)
+
+ self.res_layers = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ if i > 0 and replace_stride_with_dilation[i - 1]:
+ dilation = strides[i]
+ stride = 1
+ else:
+ stride = strides[i]
+ dilation = dilations[i]
+ planes = base_channels * 2**i
+ res_layer = self._make_res_layer(
+ block=self.block, # type: ignore
+ inplanes=self.inplanes,
+ planes=planes,
+ num_blocks=num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=style,
+ avg_down=avg_down,
+ use_checkpoint=use_checkpoint,
+ with_dcn=stages_with_dcn[i],
+ )
+ self.inplanes = planes * self.block.expansion # type: ignore
+ layer_name = f"layer{i + 1}"
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ if pretrained:
+ if weights is None:
+ # default loading the imagenet-1k v1 pre-trained model weights
+ weights = _resnet.__dict__[
+ f"ResNet{depth}_Weights"
+ ].IMAGENET1K_V1.url
+
+ load_model_checkpoint(self, weights)
+ else:
+ self._init_weights()
+
+ self._freeze_stages()
+
+ def _init_weights(self) -> None:
+ """Initialize the weights of module."""
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ constant_init(m, 1)
+
+ if self.zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck) and isinstance(
+ m.bn3.weight, nn.Parameter
+ ):
+ nn.init.constant_(m.bn3.weight, 0)
+ elif isinstance(m, BasicBlock) and isinstance(
+ m.bn2.weight, nn.Parameter
+ ):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_stem_layer(self, in_channels: int, stem_channels: int) -> None:
+ """Make stem layer for ResNet."""
+ if self.deep_stem:
+ self.stem = nn.Sequential(
+ build_conv_layer(
+ in_channels,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False,
+ ),
+ build_norm_layer(self._norm, stem_channels // 2),
+ nn.ReLU(inplace=True),
+ build_conv_layer(
+ stem_channels // 2,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ ),
+ build_norm_layer(self._norm, stem_channels // 2),
+ nn.ReLU(inplace=True),
+ build_conv_layer(
+ stem_channels // 2,
+ stem_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ ),
+ build_norm_layer(self._norm, stem_channels),
+ nn.ReLU(inplace=True),
+ )
+ else:
+ self.conv1 = build_conv_layer(
+ in_channels,
+ stem_channels,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ bias=False,
+ )
+ self.bn1 = build_norm_layer(self._norm, stem_channels)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ def _make_res_layer(
+ self,
+ block: BasicBlock | Bottleneck,
+ inplanes: int,
+ planes: int,
+ num_blocks: int,
+ stride: int,
+ dilation: int,
+ style: str,
+ avg_down: bool,
+ use_checkpoint: bool,
+ with_dcn: bool,
+ ) -> nn.Sequential:
+ """Pack all blocks in a stage into a ``ResLayer``."""
+ layers: list[BasicBlock | Bottleneck] = []
+ downsample: nn.Module | None = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample_list: list[nn.AvgPool2d | nn.Module] = []
+ conv_stride = stride
+ if avg_down:
+ conv_stride = 1
+ downsample_list.append(
+ nn.AvgPool2d(
+ kernel_size=stride,
+ stride=stride,
+ ceil_mode=True,
+ count_include_pad=False,
+ )
+ )
+ downsample_list.extend(
+ [
+ build_conv_layer(
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=conv_stride,
+ bias=False,
+ ),
+ build_norm_layer(self._norm, planes * block.expansion),
+ ]
+ )
+ downsample = nn.Sequential(*downsample_list)
+
+ layers = []
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=stride,
+ dilation=dilation,
+ downsample=downsample,
+ style=style,
+ use_checkpoint=use_checkpoint,
+ with_dcn=with_dcn,
+ norm=self._norm,
+ )
+ )
+ inplanes = planes * block.expansion
+ for _ in range(1, num_blocks):
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=1,
+ dilation=dilation,
+ style=style,
+ use_checkpoint=use_checkpoint,
+ with_dcn=with_dcn,
+ norm=self._norm,
+ )
+ )
+ return nn.Sequential(*layers)
+
+ def _freeze_stages(self) -> None:
+ """Freeze stages param and norm stats."""
+ if self.trainable_layers < 5:
+ if self.deep_stem:
+ self.stem.eval()
+ for param in self.stem.parameters():
+ param.requires_grad = False
+ else:
+ self.bn1.eval()
+ for m in (self.conv1, self.bn1):
+ for param in m.parameters():
+ param.requires_grad = False
+
+ for i in range(1, 5 - self.trainable_layers):
+ m = getattr(self, f"layer{i}")
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def train(self, mode: bool = True) -> ResNet:
+ """Override the train mode for the model."""
+ super().train(mode)
+ self._freeze_stages()
+
+ if mode and self.norm_frozen:
+ for m in self.modules():
+ # trick: eval have effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
+ return self
+
+ @property
+ def out_channels(self) -> list[int]:
+ """Get the number of channels for each level of feature pyramid.
+
+ Returns:
+ list[int]: number of channels
+ """
+ if self.name in {"resnet18", "resnet34"}:
+ # channels = [3, 3] + [64 * 2**i for i in range(4)]
+ channels = [3, 3, 64, 128, 256, 512]
+ else:
+ # channels = [3, 3] + [256 * 2**i for i in range(4)]
+ channels = [3, 3, 256, 512, 1024, 2048]
+ return channels
+
+ def forward(self, images: Tensor) -> list[Tensor]:
+ """Forward function.
+
+ Args:
+ images (Tensor[N, C, H, W]): Image input to process. Expected to
+ type float32 with values ranging 0..255.
+
+ Returns:
+ fp (list[torch.Tensor]): The output feature pyramid. The list index
+ represents the level, which has a downsampling raio of 2^index.
+ fp[0] and fp[1] is a reference to the input images and
+ torchvision resnet downsamples the feature maps by 4 directly.
+ The last feature map downsamples the input image by 64 with a
+ pooling layer on the second last map.
+ """
+ if self.deep_stem:
+ x = self.stem(images)
+ else:
+ x = self.conv1(images)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = [images, images]
+ for _, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ outs.append(x)
+ return outs
+
+
+class ResNetV1c(ResNet):
+ """ResNetV1c variant with a deeper stem.
+
+ Compared with default ResNet, ResNetV1c replaces the 7x7 conv in the input
+ stem with three 3x3 convs. For more details please refer to `Bag of Tricks
+ for Image Classification with Convolutional Neural Networks
+ `.
+ """
+
+ model_urls = {
+ "resnet50_v1c": (
+ "https://download.openmmlab.com/pretrain/third_party/"
+ "resnet50_v1c-2cccc1ad.pth"
+ ),
+ "resnet101_v1c": (
+ "https://download.openmmlab.com/pretrain/third_party/"
+ "resnet101_v1c-e67eebb6.pth"
+ ),
+ }
+
+ def __init__(
+ self,
+ resnet_name: str,
+ pretrained: bool = False,
+ weights: str | None = None,
+ **kwargs: ArgsType,
+ ):
+ """Initialize ResNetV1c.
+
+ Args:
+ resnet_name (str): Name of the resnet model.
+ pretrained (bool, optional): Whether to load ImageNet pre-trained
+ weights. Defaults to False.
+ weights (str, optional): Path to custom pretrained weights.
+ **kwargs: Arguments for ResNet.
+ """
+ assert resnet_name in {
+ "resnet18_v1c",
+ "resnet34_v1c",
+ "resnet50_v1c",
+ "resnet101_v1c",
+ }
+ if pretrained and weights is None:
+ assert resnet_name in {
+ "resnet50_v1c",
+ "resnet101_v1c",
+ }, "Only resnet50_v1c and resnet101_v1c have pretrained weights."
+ weights = self.model_urls[resnet_name]
+
+ super().__init__(
+ resnet_name[:-4],
+ deep_stem=True,
+ pretrained=pretrained,
+ weights=weights,
+ **kwargs,
+ )
diff --git a/vis4d/op/base/unet.py b/vis4d/op/base/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bfbf4d1815f59316f052372f59e6584eb9dcb88
--- /dev/null
+++ b/vis4d/op/base/unet.py
@@ -0,0 +1,169 @@
+"""Unet Implementation based on https://arxiv.org/abs/1505.04597.
+
+Code taken from https://github.com/jaxony/unet-pytorch/blob/master/model.py
+and modified to include typing and custom ops.
+"""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+import torch
+from torch import nn
+
+from vis4d.op.layer.conv2d import UnetDownConv, UnetUpConv
+
+
+class UNetOut(NamedTuple):
+ """Output of the UNet operator.
+
+ logits: Final output of the network without applying softmax
+ intermediate_features: Intermediate features of the upsampling path
+ at different scales.
+ """
+
+ logits: torch.Tensor
+ intermediate_features: list[torch.Tensor]
+
+
+class UNet(nn.Module):
+ """The U-Net is a convolutional encoder-decoder neural network.
+
+ Contextual spatial information (from the decoding,
+ expansive pathway) about an input tensor is merged with
+ information representing the localization of details
+ (from the encoding, compressive pathway).
+
+ Modifications to the original paper:
+ (1) padding is used in 3x3 convolutions to prevent loss
+ of border pixels
+ (2) merging outputs does not require cropping due to (1)
+ (3) residual connections can be used by specifying
+ UNet(merge_mode='add')
+ (4) if non-parametric upsampling is used in the decoder
+ pathway (specified by upmode='upsample'), then an
+ additional 1x1 2d convolution occurs after upsampling
+ to reduce channel dimensionality by a factor of 2.
+ This channel halving happens with the convolution in
+ the tranpose convolution (specified by upmode='transpose')
+ """
+
+ def __init__(
+ self,
+ num_classes: int,
+ in_channels: int = 3,
+ depth: int = 5,
+ start_filts: int = 32,
+ up_mode: str = "transpose",
+ merge_mode: str = "concat",
+ ):
+ """Unet Operator.
+
+ Args:
+ in_channels: int, number of channels in the input tensor.
+ Default is 3 for RGB images.
+ num_classes: int, number of output classes.
+ depth: int, number of MaxPools in the U-Net.
+ start_filts: int, number of convolutional filters for the
+ first conv.
+ up_mode: string, type of upconvolution. Choices: 'transpose'
+ for transpose convolution or 'upsample' for nearest neighbour
+ upsampling.
+ merge_mode: string, how to merge features, can be 'concat' or 'add'
+
+
+ Raises:
+ ValueError: if invalid modes are provided
+ """
+ super().__init__()
+
+ if up_mode in {"transpose", "upsample"}:
+ self.up_mode = up_mode
+ else:
+ raise ValueError(
+ f"{up_mode} is not a valid mode for upsampling. Only"
+ f"'transpose' and 'upsample' are allowed."
+ )
+
+ if merge_mode in {"concat", "add"}:
+ self.merge_mode = merge_mode
+ else:
+ raise ValueError(
+ f'"{up_mode}" is not a valid mode for'
+ f"merging up and down paths. "
+ f'Only "concat" and '
+ f'"add" are allowed.'
+ )
+
+ # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
+ if self.up_mode == "upsample" and self.merge_mode == "add":
+ raise ValueError(
+ 'up_mode "upsample" is incompatible '
+ 'with merge_mode "add" at the moment '
+ "because it doesn't make sense to use "
+ "nearest neighbour to reduce "
+ "depth channels (by half)."
+ )
+
+ self.num_classes = num_classes
+ self.in_channels = in_channels
+ self.start_filts = start_filts
+ self.depth = depth
+
+ self.down_convs: nn.ModuleList = nn.ModuleList()
+
+ # create the encoder pathway and add to a list
+ for i in range(depth):
+ ins = self.in_channels if i == 0 else outs # type: ignore
+ outs = self.start_filts * (2**i)
+ pooling = i < (depth - 1)
+
+ down_conv = UnetDownConv(ins, outs, pooling=pooling)
+ self.down_convs.append(down_conv)
+
+ self.up_convs: nn.ModuleList = nn.ModuleList()
+
+ # create the decoder pathway and add to a list
+ # - careful! decoding only requires depth-1 blocks
+ for i in range(depth - 1):
+ ins = outs
+ outs = ins // 2
+ up_conv = UnetUpConv(
+ ins, outs, up_mode=up_mode, merge_mode=merge_mode
+ )
+ self.up_convs.append(up_conv)
+ self.conv_final = nn.Conv2d(
+ outs, num_classes, kernel_size=1, groups=1, stride=1
+ )
+
+ def __call__(self, data: torch.Tensor) -> UNetOut:
+ """Applies the UNet.
+
+ Args:
+ data (tensor): Input Images into the network shape [N, C, W, H]
+
+ """
+ return self._call_impl(data)
+
+ def forward(self, data: torch.Tensor) -> UNetOut:
+ """Applies the UNet.
+
+ Args:
+ data (tensor): Input Images into the network shape [N, C, W, H]
+ """
+ encoder_outs: list[torch.Tensor] = []
+ inter_feats: list[torch.Tensor] = []
+ # encoder pathway, save outputs for merging
+
+ for down_conv in self.down_convs:
+ out = down_conv(data)
+ data = out.pooled_features
+ encoder_outs.append(out.features)
+
+ for level, up_conv in enumerate(self.up_convs):
+ before_pool = encoder_outs[-(level + 2)]
+ data = up_conv(before_pool, data)
+ inter_feats.append(data)
+
+ logits = self.conv_final(data)
+ return UNetOut(logits=logits, intermediate_features=inter_feats)
diff --git a/vis4d/op/base/vgg.py b/vis4d/op/base/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..a887ea03bb52c764adc8b56580f861665221c68b
--- /dev/null
+++ b/vis4d/op/base/vgg.py
@@ -0,0 +1,107 @@
+"""Residual networks for classification."""
+
+from __future__ import annotations
+
+import torch
+import torchvision.models.vgg as _vgg
+from torchvision.models._utils import IntermediateLayerGetter
+
+from .base import BaseModel
+
+
+class VGG(BaseModel):
+ """Wrapper for torch vision VGG."""
+
+ def __init__(
+ self,
+ vgg_name: str,
+ trainable_layers: None | int = None,
+ pretrained: bool = False,
+ ):
+ """Initialize the VGG base model from torchvision.
+
+ Args:
+ vgg_name (str): name of the VGG variant. Choices in ["vgg11",
+ "vgg13", "vgg16", "vgg19", "vgg11_bn", "vgg13_bn", "vgg16_bn",
+ "vgg19_bn"].
+ trainable_layers (int, optional): Number layers for training or
+ fine-tuning. None means all the layers can be fine-tuned.
+ pretrained (bool, optional): Whether to load ImageNet
+ pre-trained weights. Defaults to False.
+
+ Raises:
+ ValueError: The VGG name is not supported
+ """
+ super().__init__()
+ if vgg_name not in [
+ "vgg11",
+ "vgg13",
+ "vgg16",
+ "vgg19",
+ "vgg11_bn",
+ "vgg13_bn",
+ "vgg16_bn",
+ "vgg19_bn",
+ ]:
+ raise ValueError("The VGG name is not supported!")
+
+ weights = "IMAGENET1K_V1" if pretrained else None
+ vgg = _vgg.__dict__[vgg_name](weights=weights)
+ use_bn = vgg_name[-3:] == "_bn"
+ self._out_channels: list[int] = []
+ returned_layers = []
+ last_channel = -1
+ layer_counter = 0
+
+ vgg_channels = _vgg.cfgs[
+ {"vgg11": "A", "vgg13": "B", "vgg16": "D", "vgg19": "E"}[
+ vgg_name[:5]
+ ]
+ ]
+ for channel in vgg_channels:
+ if channel == "M":
+ returned_layers.append(layer_counter)
+ self._out_channels.append(last_channel)
+ layer_counter += 1
+ else:
+ if use_bn:
+ layer_counter += 3
+ else:
+ layer_counter += 2
+ last_channel = channel
+
+ if trainable_layers is not None:
+ for name, parameter in vgg.features.named_parameters():
+ layer_ind = int(name.split(".")[0])
+ if layer_ind < layer_counter - trainable_layers:
+ parameter.requires_grad_(False)
+
+ return_layers = {str(v): str(i) for i, v in enumerate(returned_layers)}
+ self.body = IntermediateLayerGetter(
+ vgg.features, return_layers=return_layers
+ )
+ self.name = vgg_name
+
+ @property
+ def out_channels(self) -> list[int]:
+ """Get the number of channels for each level of feature pyramid.
+
+ Returns:
+ list[int]: number of channels
+ """
+ return [3, 3, *self._out_channels]
+
+ def forward(self, images: torch.Tensor) -> list[torch.Tensor]:
+ """VGG feature forward without classification head.
+
+ Args:
+ images (Tensor[N, C, H, W]): Image input to process. Expected to
+ type float32 with values ranging 0..255.
+
+ Returns:
+ fp (list[torch.Tensor]): The output feature pyramid. The list index
+ represents the level, which has a downsampling raio of 2^index.
+ fp[0] and fp[1] is a reference to the input images. The last
+ feature map downsamples the input image by 64.
+ """
+ return [images, images, *self.body(images).values()]
diff --git a/vis4d/op/base/vit.py b/vis4d/op/base/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..186a36f4c678d8d033b5bb997051af9b40e8f553
--- /dev/null
+++ b/vis4d/op/base/vit.py
@@ -0,0 +1,271 @@
+"""Residual networks for classification."""
+
+from __future__ import annotations
+
+import torch
+from timm.models import named_apply
+from torch import nn
+
+from vis4d.op.layer.patch_embed import PatchEmbed
+from vis4d.op.layer.transformer import TransformerBlock
+
+from .base import BaseModel
+
+
+def _init_weights_vit_timm( # pylint: disable=unused-argument
+ module: nn.Module, name: str
+) -> None:
+ """Weight initialization, original timm impl (for reproducibility)."""
+ if isinstance(module, nn.Linear):
+ nn.init.trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif hasattr(module, "init_weights"):
+ module.init_weights() # type: ignore
+
+
+ViT_PRESET = { # pylint: disable=consider-using-namedtuple-or-dataclass
+ "vit_tiny_patch16_224": {
+ "patch_size": 16,
+ "embed_dim": 192,
+ "depth": 12,
+ "num_heads": 3,
+ },
+ "vit_small_patch16_224": {
+ "patch_size": 16,
+ "embed_dim": 384,
+ "depth": 12,
+ "num_heads": 6,
+ },
+ "vit_base_patch16_224": {
+ "patch_size": 16,
+ "embed_dim": 768,
+ "depth": 12,
+ "num_heads": 12,
+ },
+ "vit_large_patch16_224": {
+ "patch_size": 16,
+ "embed_dim": 1024,
+ "depth": 24,
+ "num_heads": 16,
+ },
+ "vit_huge_patch16_224": {
+ "patch_size": 16,
+ "embed_dim": 1280,
+ "depth": 32,
+ "num_heads": 16,
+ },
+ "vit_small_patch32_224": {
+ "patch_size": 32,
+ "embed_dim": 384,
+ "depth": 12,
+ "num_heads": 6,
+ },
+ "vit_base_patch32_224": {
+ "patch_size": 32,
+ "embed_dim": 768,
+ "depth": 12,
+ "num_heads": 12,
+ },
+ "vit_large_patch32_224": {
+ "patch_size": 32,
+ "embed_dim": 1024,
+ "depth": 24,
+ "num_heads": 16,
+ },
+ "vit_huge_patch32_224": {
+ "patch_size": 32,
+ "embed_dim": 1280,
+ "depth": 32,
+ "num_heads": 16,
+ },
+}
+
+
+class VisionTransformer(BaseModel):
+ """Vision Transformer (ViT) model without classification head.
+
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for
+ Image Recognition at Scale`
+ - https://arxiv.org/abs/2010.11929
+
+ Adapted from:
+ - pytorch vision transformer impl
+ - timm vision transformer impl
+ """
+
+ def __init__(
+ self,
+ img_size: int = 224,
+ patch_size: int = 16,
+ in_channels: int = 3,
+ num_classes: int = 1000,
+ embed_dim: int = 768,
+ depth: int = 12,
+ num_heads: int = 12,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ init_values: float | None = None,
+ class_token: bool = True,
+ no_embed_class: bool = False,
+ pre_norm: bool = False,
+ pos_drop_rate: float = 0.0,
+ drop_rate: float = 0.0,
+ attn_drop_rate: float = 0.0,
+ drop_path_rate: float = 0.0,
+ norm_layer: nn.Module | None = None,
+ act_layer: nn.Module = nn.GELU(),
+ ) -> None:
+ """Init VisionTransformer.
+
+ Args:
+ img_size (int, optional): Input image size. Defaults to 224.
+ patch_size (int, optional): Patch size. Defaults to 16.
+ in_channels (int, optional): Number of input channels. Defaults to
+ 3.
+ num_classes (int, optional): Number of classes. Defaults to 1000.
+ embed_dim (int, optional): Embedding dimension. Defaults to 768.
+ depth (int, optional): Depth. Defaults to 12.
+ num_heads (int, optional): Number of attention heads. Defaults to
+ 12.
+ mlp_ratio (float, optional): Ratio of MLP hidden dim to embedding
+ dim. Defaults to 4.0.
+ qkv_bias (bool, optional): If to add bias to qkv. Defaults to True.
+ init_values (float, optional): Initial values for layer scale.
+ Defaults to None.
+ class_token (bool, optional): If to add a class token. Defaults to
+ True.
+ no_embed_class (bool, optional): If to not embed class token.
+ Defaults to False.
+ pre_norm (bool, optional): If to use pre-norm. Defaults to False.
+ pos_drop_rate (float, optional): Postional dropout rate. Defaults
+ to 0.0.
+ drop_rate (float, optional): Dropout rate. Defaults to 0.0.
+ attn_drop_rate (float, optional): Attention dropout rate. Defaults
+ to 0.0.
+ drop_path_rate (float, optional): Drop path rate. Defaults to 0.0.
+ embed_layer (nn.Module, optional): Embedding layer. Defaults to
+ PatchEmbed.
+ norm_layer (nn.Module, optional): Normalization layer. If None,
+ nn.LayerNorm is used. Defaults to None.
+ act_layer (nn.Module, optional): Activation layer. Defaults to
+ nn.GELU().
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = (
+ embed_dim # num_features for consistency with other models
+ )
+ self.num_depth = depth
+ self.num_prefix_tokens = 1 if class_token else 0
+ self.no_embed_class = no_embed_class
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=embed_dim,
+ bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
+ )
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = (
+ nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
+ )
+ embed_len = (
+ num_patches
+ if no_embed_class
+ else num_patches + self.num_prefix_tokens
+ )
+ self.pos_embed = nn.Parameter(torch.zeros(1, embed_len, embed_dim))
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
+ self.norm_pre = (
+ nn.LayerNorm(embed_dim, eps=1e-6) if pre_norm else nn.Identity()
+ )
+
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
+ ] # stochastic depth decay rule
+ blocks = [
+ TransformerBlock(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ init_values=init_values,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ )
+ for i in range(depth)
+ ]
+ self.blocks = nn.ModuleList(blocks)
+ self.init_weights()
+
+ def init_weights(self) -> None:
+ """Init weights using timm's implementation."""
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
+ if self.cls_token is not None:
+ nn.init.normal_(self.cls_token, std=1e-6)
+ named_apply(_init_weights_vit_timm, self)
+
+ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
+ """Add positional embeddings."""
+ if self.no_embed_class:
+ # deit-3, updated JAX (big vision)
+ # position embedding does not overlap with class token, add then
+ # concat
+ x = x + self.pos_embed
+ if self.cls_token is not None:
+ x = torch.cat(
+ (self.cls_token.expand(x.shape[0], -1, -1), x), dim=1
+ )
+ else:
+ # original timm, JAX, and deit vit impl
+ # pos_embed has entry for class token, concat then add
+ if self.cls_token is not None:
+ x = torch.cat(
+ (self.cls_token.expand(x.shape[0], -1, -1), x), dim=1
+ )
+ x = x + self.pos_embed
+ return self.pos_drop(x)
+
+ @property
+ def out_channels(self) -> list[int]:
+ """Return the number of output channels per feature level."""
+ return [self.embed_dim] * (self.num_depth + 1)
+
+ def __call__(self, data: torch.Tensor) -> list[torch.Tensor]:
+ """Applies the ViT encoder.
+
+ Args:
+ data (tensor): Input Images into the network shape [N, C, W, H]
+
+ """
+ return self._call_impl(data)
+
+ def forward(self, images: torch.Tensor) -> list[torch.Tensor]:
+ """Forward pass.
+
+ Args:
+ images (torch.Tensor): Input images tensor of shape (B, C, H, W).
+
+ Returns:
+ feats (list[torch.Tensor]): Features of the input images extracted
+ by the ViT encoder. feats[0] is the input images, and feats[1]
+ is the output of the patch embedding layer. The rest of the
+ elements are the outputs of each transformer block, with the
+ shape (B, N, dim), where N is the number of patches, and dim
+ is the embedding dimension. The final element is the output of
+ the ViT encoder.
+ """
+ feats = [images]
+ x = self.patch_embed(images)
+ x = self.norm_pre(self._pos_embed(x))
+ feats.append(x)
+ for blk in self.blocks:
+ x = blk(x)
+ feats.append(x)
+ return feats
diff --git a/vis4d/op/box/__init__.py b/vis4d/op/box/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5fa78566ff8e3e3eb9eca8a7d31f3fe85ee8ebf
--- /dev/null
+++ b/vis4d/op/box/__init__.py
@@ -0,0 +1 @@
+"""Operations on 2D bounding boxes."""
diff --git a/vis4d/op/box/__pycache__/__init__.cpython-311.pyc b/vis4d/op/box/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c502b757743d77862f0557e3e924f015b1c7897
Binary files /dev/null and b/vis4d/op/box/__pycache__/__init__.cpython-311.pyc differ
diff --git a/vis4d/op/box/__pycache__/box2d.cpython-311.pyc b/vis4d/op/box/__pycache__/box2d.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..735c1ba37b969b4fc97635c8b423d0f479b167a9
Binary files /dev/null and b/vis4d/op/box/__pycache__/box2d.cpython-311.pyc differ
diff --git a/vis4d/op/box/__pycache__/box3d.cpython-311.pyc b/vis4d/op/box/__pycache__/box3d.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c21bd9e20cb638d7c30d17d8d6db614b40fae624
Binary files /dev/null and b/vis4d/op/box/__pycache__/box3d.cpython-311.pyc differ
diff --git a/vis4d/op/box/anchor/__init__.py b/vis4d/op/box/anchor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..95be90d588d53e64af33827bfec13970bdc5acac
--- /dev/null
+++ b/vis4d/op/box/anchor/__init__.py
@@ -0,0 +1,6 @@
+"""Anchor and point generators."""
+
+from .anchor_generator import AnchorGenerator, anchor_inside_image
+from .point_generator import MlvlPointGenerator
+
+__all__ = ["AnchorGenerator", "anchor_inside_image", "MlvlPointGenerator"]
diff --git a/vis4d/op/box/anchor/anchor_generator.py b/vis4d/op/box/anchor/anchor_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a4efc383e213af2f5d5473e66103506b273f439
--- /dev/null
+++ b/vis4d/op/box/anchor/anchor_generator.py
@@ -0,0 +1,329 @@
+"""Anchor generator for 2D bounding boxes.
+
+Modified from:
+https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/anchor/anchor_generator.py
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.nn.modules.utils import _pair
+
+from .util import meshgrid
+
+
+def anchor_inside_image(
+ flat_anchors: Tensor, img_shape: tuple[int, int], allowed_border: int = 0
+) -> Tensor:
+ """Check whether the anchors are inside the border.
+
+ Args:
+ flat_anchors (Tensor): Flatten anchors, shape (n, 4).
+ img_shape (tuple(int)): Shape of current image.
+ allowed_border (int): The border to allow the valid anchor.
+ Defaults to 0.
+
+ Returns:
+ Tensor: Flags indicating whether the anchors are inside a valid range.
+ """
+ img_h, img_w = img_shape
+ inside_flags = (
+ (flat_anchors[:, 0] >= -allowed_border)
+ & (flat_anchors[:, 1] >= -allowed_border)
+ & (flat_anchors[:, 2] < img_w + allowed_border)
+ & (flat_anchors[:, 3] < img_h + allowed_border)
+ )
+ return inside_flags
+
+
+class AnchorGenerator:
+ """Standard anchor generator for 2D anchor-based detectors.
+
+ Examples:
+ >>> from vis4d.op.box.anchor import AnchorGenerator
+ >>> self = AnchorGenerator([16], [1.], [1.], [9])
+ >>> all_anchors = self.grid_priors([(2, 2)], device='cpu')
+ >>> print(all_anchors)
+ [tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
+ [11.5000, -4.5000, 20.5000, 4.5000],
+ [-4.5000, 11.5000, 4.5000, 20.5000],
+ [11.5000, 11.5000, 20.5000, 20.5000]])]
+ >>> self = AnchorGenerator([16, 32], [1.], [1.], [9, 18])
+ >>> all_anchors = self.grid_priors([(2, 2), (1, 1)], device='cpu')
+ >>> print(all_anchors)
+ [tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
+ [11.5000, -4.5000, 20.5000, 4.5000],
+ [-4.5000, 11.5000, 4.5000, 20.5000],
+ [11.5000, 11.5000, 20.5000, 20.5000]]), \
+ tensor([[-9., -9., 9., 9.]])]
+ """
+
+ def __init__(
+ self,
+ strides: list[int] | list[tuple[int, int]],
+ ratios: list[float],
+ scales: list[int] | None = None,
+ base_sizes: list[int] | None = None,
+ scale_major: bool = True,
+ octave_base_scale: None | int = None,
+ scales_per_octave: None | int = None,
+ centers: list[tuple[float, float]] | None = None,
+ center_offset: float = 0.0,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ strides (list[int] | list[tuple[int, int]]): Strides of anchors
+ in multiple feature levels in order (w, h).
+ ratios (list[float]): The list of ratios between the height and
+ width of anchors in a single level.
+ scales (list[int] | None): Anchor scales for anchors in a single
+ level. It cannot be set at the same time if `octave_base_scale`
+ and `scales_per_octave` are set.
+ base_sizes (list[int] | None): The basic sizes
+ of anchors in multiple levels.
+ If None is given, strides will be used as base_sizes.
+ (If strides are non square, the shortest stride is taken.)
+ scale_major (bool): Whether to multiply scales first when
+ generating base anchors. If true, the anchors in the same row
+ will have the same scales. By default it is True in V2.0
+ octave_base_scale (int): The base scale of octave.
+ scales_per_octave (int): Number of scales for each octave.
+ `octave_base_scale` and `scales_per_octave` are usually used in
+ retinanet and the `scales` should be None when they are set.
+ centers (list[tuple[float, float]] | None): The centers of the
+ anchor relative to the feature grid center in multiple feature
+ levels. By default it is set to be None and not used. If a list
+ of tuple of float is given, they will be used to shift the
+ centers of anchors.
+ center_offset (float): The offset of center in proportion to
+ anchors' width and height. By default it is 0 in V2.0.
+ """
+ # check center and center_offset
+ if center_offset != 0:
+ assert centers is None, (
+ "center cannot be set when center_offset"
+ f"!=0, {centers} is given."
+ )
+ if not 0 <= center_offset <= 1:
+ raise ValueError(
+ "center_offset should be in range [0, 1], "
+ f"{center_offset} is given."
+ )
+ if centers is not None:
+ assert len(centers) == len(strides), (
+ "The number of strides should be the same as centers, got "
+ f"{strides} and {centers}"
+ )
+
+ # calculate base sizes of anchors
+ self.strides = [_pair(stride) for stride in strides]
+ self.base_sizes = (
+ [min(stride) for stride in self.strides]
+ if base_sizes is None
+ else base_sizes
+ )
+ assert len(self.base_sizes) == len(self.strides), (
+ "The number of strides should be the same as base sizes, got "
+ f"{self.strides} and {self.base_sizes}"
+ )
+
+ # calculate scales of anchors
+ assert (
+ octave_base_scale is not None and scales_per_octave is not None
+ ) ^ (scales is not None), (
+ "scales and octave_base_scale with scales_per_octave cannot"
+ " be set at the same time"
+ )
+ if scales is not None:
+ self.scales = torch.Tensor(scales)
+ elif octave_base_scale is not None and scales_per_octave is not None:
+ octave_scales = np.array(
+ [
+ 2 ** (i / scales_per_octave)
+ for i in range(scales_per_octave)
+ ]
+ )
+ scales = octave_scales * octave_base_scale # type: ignore
+ self.scales = torch.Tensor(scales)
+ else:
+ raise ValueError(
+ "Either scales or octave_base_scale with "
+ "scales_per_octave should be set"
+ )
+
+ self.octave_base_scale = octave_base_scale
+ self.scales_per_octave = scales_per_octave
+ self.ratios = torch.Tensor(ratios)
+ self.scale_major = scale_major
+ self.centers = centers
+ self.center_offset = center_offset
+ self.base_anchors = self.gen_base_anchors()
+
+ @property
+ def num_base_priors(self) -> list[int]:
+ """list[int]: The number of priors at a point on the feature grid."""
+ return [base_anchors.size(0) for base_anchors in self.base_anchors]
+
+ @property
+ def num_levels(self) -> int:
+ """int: number of feature levels that the generator will be applied."""
+ return len(self.strides)
+
+ def gen_base_anchors(self) -> list[Tensor]:
+ """Generate base anchors.
+
+ Returns:
+ list(torch.Tensor): Base anchors of a feature grid in multiple \
+ feature levels.
+ """
+ multi_level_base_anchors = []
+ for i, base_size in enumerate(self.base_sizes):
+ center = None
+ if self.centers is not None:
+ center = self.centers[i]
+ multi_level_base_anchors.append(
+ self.gen_single_level_base_anchors(
+ base_size,
+ scales=self.scales,
+ ratios=self.ratios,
+ center=center,
+ )
+ )
+ return multi_level_base_anchors
+
+ def gen_single_level_base_anchors(
+ self,
+ base_size: int,
+ scales: Tensor,
+ ratios: Tensor,
+ center: tuple[float, float] | None = None,
+ ) -> Tensor:
+ """Generate base anchors of a single level.
+
+ Args:
+ base_size (int): Basic size of an anchor.
+ scales (Tensor): Scales of the anchor.
+ ratios (Tensor): The ratio between between the height
+ and width of anchors in a single level.
+ center (tuple[float], optional): The center of the base anchor
+ related to a single feature grid. Defaults to None.
+
+ Returns:
+ Tensor: Anchors in a single-level feature maps.
+ """
+ width, height = base_size, base_size
+ if center is None:
+ x_center = self.center_offset * width
+ y_center = self.center_offset * height
+ else:
+ x_center, y_center = center
+
+ h_ratios = torch.sqrt(ratios)
+ w_ratios = 1 / h_ratios
+ if self.scale_major:
+ ws = (width * w_ratios[:, None] * scales[None, :]).view(-1)
+ hs = (height * h_ratios[:, None] * scales[None, :]).view(-1)
+ else:
+ ws = (width * scales[:, None] * w_ratios[None, :]).view(-1)
+ hs = (height * scales[:, None] * h_ratios[None, :]).view(-1)
+
+ # use float anchor and the anchor's center is aligned with the
+ # pixel center
+ base_anchors = [
+ x_center - 0.5 * ws,
+ y_center - 0.5 * hs,
+ x_center + 0.5 * ws,
+ y_center + 0.5 * hs,
+ ]
+
+ return torch.stack(base_anchors, dim=-1)
+
+ def grid_priors(
+ self,
+ featmap_sizes: list[tuple[int, int]],
+ dtype: torch.dtype = torch.float32,
+ device: torch.device = torch.device("cpu"),
+ ) -> list[Tensor]:
+ """Generate grid anchors in multiple feature levels.
+
+ Args:
+ featmap_sizes (list[tuple]): List of feature map sizes in
+ multiple feature levels.
+ dtype (torch.dtype): Dtype of priors. Default: torch.float32.
+ device (torch.device): The device where the anchors will be put on.
+
+ Return:
+ list[Tensor]: Anchors in multiple feature levels. The sizes of each
+ tensor should be [N, 4], where
+ N = width * height * num_base_anchors, width and height
+ are the sizes of the corresponding feature level,
+ num_base_anchors is the number of anchors for that level.
+ """
+ assert self.num_levels == len(featmap_sizes)
+ multi_level_anchors = []
+ for i in range(self.num_levels):
+ anchors = self.single_level_grid_priors(
+ featmap_sizes[i], level_idx=i, dtype=dtype, device=device
+ )
+ multi_level_anchors.append(anchors)
+ return multi_level_anchors
+
+ def single_level_grid_priors(
+ self,
+ featmap_size: tuple[int, int],
+ level_idx: int,
+ dtype: torch.dtype = torch.float32,
+ device: torch.device = torch.device("cpu"),
+ ) -> Tensor:
+ """Generate grid anchors of a single level.
+
+ Args:
+ featmap_size (tuple[int, int]): Size of the feature maps.
+ level_idx (int): The index of corresponding feature map level.
+ dtype (torch.dtype, optional): Data type of points. Defaults to
+ torch.float32.
+ device (torch.device): The device the tensor will be put on.
+
+ Returns:
+ Tensor: Anchors in the overall feature maps.
+ """
+ base_anchors = self.base_anchors[level_idx].to(device).to(dtype)
+ feat_h, feat_w = featmap_size
+ stride_w, stride_h = self.strides[level_idx]
+ # First create Range with the default dtype, than convert to
+ # target `dtype` for onnx exporting.
+ shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w
+ shift_y = torch.arange(0, feat_h, device=device).to(dtype) * stride_h
+
+ shift_xx, shift_yy = meshgrid(shift_x, shift_y)
+ shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
+ # first feat_w elements correspond to the first row of shifts
+ # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
+ # shifted anchors (K, A, 4), reshape to (K*A, 4)
+
+ all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
+ all_anchors = all_anchors.view(-1, 4)
+ # first A rows correspond to A anchors of (0, 0) in feature map,
+ # then (0, 1), (0, 2), ...
+ return all_anchors
+
+ def __repr__(self) -> str:
+ """str: a string that describes the module."""
+ indent_str = " "
+ repr_str = self.__class__.__name__ + "(\n"
+ repr_str += f"{indent_str}strides={self.strides},\n"
+ repr_str += f"{indent_str}ratios={self.ratios},\n"
+ repr_str += f"{indent_str}scales={self.scales},\n"
+ repr_str += f"{indent_str}base_sizes={self.base_sizes},\n"
+ repr_str += f"{indent_str}scale_major={self.scale_major},\n"
+ repr_str += f"{indent_str}octave_base_scale="
+ repr_str += f"{self.octave_base_scale},\n"
+ repr_str += f"{indent_str}scales_per_octave="
+ repr_str += f"{self.scales_per_octave},\n"
+ repr_str += f"{indent_str}num_levels={self.num_levels}\n"
+ repr_str += f"{indent_str}centers={self.centers},\n"
+ repr_str += f"{indent_str}center_offset={self.center_offset})"
+ return repr_str
diff --git a/vis4d/op/box/anchor/point_generator.py b/vis4d/op/box/anchor/point_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1bb100595339cf94856a9e0a3ad09be5d7f89ba
--- /dev/null
+++ b/vis4d/op/box/anchor/point_generator.py
@@ -0,0 +1,210 @@
+"""Point generator for 2D bounding boxes.
+
+Modified from:
+https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/anchor/point_generator.py
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+from torch.nn.modules.utils import _pair
+
+from .util import meshgrid
+
+
+class MlvlPointGenerator:
+ """Standard points generator for multi-level feature maps.
+
+ Used for 2D points-based detectors.
+
+ Args:
+ strides (list[int] | list[tuple[int, int]]): Strides of anchors
+ in multiple feature levels in order (w, h).
+ offset (float): The offset of points, the value is normalized with
+ corresponding stride. Defaults to 0.5.
+ """
+
+ def __init__(
+ self, strides: list[int] | list[tuple[int, int]], offset: float = 0.5
+ ):
+ """Init."""
+ self.strides = [_pair(stride) for stride in strides]
+ self.offset = offset
+
+ @property
+ def num_levels(self) -> int:
+ """Number of feature levels."""
+ return len(self.strides)
+
+ @property
+ def num_base_priors(self) -> list[int]:
+ """Number of points at a point on the feature grid."""
+ return [1 for _ in range(len(self.strides))]
+
+ def grid_priors(
+ self,
+ featmap_sizes: list[tuple[int, int]],
+ dtype: torch.dtype = torch.float32,
+ device: torch.device = torch.device("cuda"),
+ with_stride: bool = False,
+ ) -> list[torch.Tensor]:
+ """Generate grid points of multiple feature levels.
+
+ Args:
+ featmap_sizes (list[tuple[int, int]]): List of feature map sizes in
+ multiple feature levels, each (H, W).
+ dtype (torch.dtype): Dtype of priors. Defaults to torch.float32.
+ device (torch.device): The device where the anchors will be put on.
+ Defaults to torch.device("cuda").
+ with_stride (bool): Whether to concatenate the stride to the last
+ dimension of points. Defaults to False,
+
+ Return:
+ list[torch.Tensor]: Points of multiple feature levels.
+ The sizes of each tensor should be (N, 2) when with stride is
+ ``False``, where N = width * height, width and height
+ are the sizes of the corresponding feature level,
+ and the last dimension 2 represent (coord_x, coord_y),
+ otherwise the shape should be (N, 4),
+ and the last dimension 4 represent
+ (coord_x, coord_y, stride_w, stride_h).
+ """
+ assert self.num_levels == len(featmap_sizes)
+ multi_level_priors = []
+ for i in range(self.num_levels):
+ priors = self.single_level_grid_priors(
+ featmap_sizes[i],
+ level_idx=i,
+ dtype=dtype,
+ device=device,
+ with_stride=with_stride,
+ )
+ multi_level_priors.append(priors)
+ return multi_level_priors
+
+ def single_level_grid_priors(
+ self,
+ featmap_size: tuple[int, int],
+ level_idx: int,
+ dtype: torch.dtype = torch.float32,
+ device: torch.device = torch.device("cuda"),
+ with_stride: bool = False,
+ ) -> torch.Tensor:
+ """Generate grid Points of a single level.
+
+ Note:
+ This function is usually called by method ``self.grid_priors``.
+
+ Args:
+ featmap_size (tuple[int, int]): Size of the feature maps, (H, W).
+ level_idx (int): The index of corresponding feature map level.
+ dtype (torch.dtype): Dtype of priors. Defaults to torch.float32.
+ device (torch.device): The device where the tensors will be put on.
+ Defaults to torch.device("cuda").
+ with_stride (bool): Concatenate the stride to the last dimension
+ of points. Defaults to False,
+
+ Return:
+ Tensor: Points of single feature levels.
+ The shape of tensor should be (N, 2) when with stride is
+ ``False``, where N = width * height, width and height
+ are the sizes of the corresponding feature level,
+ and the last dimension 2 represent (coord_x, coord_y),
+ otherwise the shape should be (N, 4),
+ and the last dimension 4 represent
+ (coord_x, coord_y, stride_w, stride_h).
+ """
+ feat_h, feat_w = featmap_size
+ stride_w, stride_h = self.strides[level_idx]
+ shift_x = (
+ torch.arange(0, feat_w, device=device) + self.offset
+ ) * stride_w
+ # keep featmap_size as Tensor instead of int, so that we
+ # can convert to ONNX correctly
+ shift_x = shift_x.to(dtype)
+
+ shift_y = (
+ torch.arange(0, feat_h, device=device) + self.offset
+ ) * stride_h
+ # keep featmap_size as Tensor instead of int, so that we
+ # can convert to ONNX correctly
+ shift_y = shift_y.to(dtype)
+ shift_xx, shift_yy = meshgrid(shift_x, shift_y)
+ if not with_stride:
+ shifts = torch.stack([shift_xx, shift_yy], dim=-1)
+ else:
+ # use `shape[0]` instead of `len(shift_xx)` for ONNX export
+ stride_w = shift_xx.new_full((shift_xx.shape[0],), stride_w).to(
+ dtype
+ )
+ stride_h = shift_xx.new_full((shift_yy.shape[0],), stride_h).to(
+ dtype
+ )
+ shifts = torch.stack(
+ [shift_xx, shift_yy, stride_w, stride_h], dim=-1
+ )
+ all_points = shifts.to(device)
+ return all_points
+
+ def valid_flags(
+ self,
+ featmap_sizes: list[tuple[int, int]],
+ pad_shape: tuple[int, int],
+ device: torch.device = torch.device("cuda"),
+ ) -> list[torch.Tensor]:
+ """Generate valid flags of points of multiple feature levels.
+
+ Args:
+ featmap_sizes (list[tuple[int, int]]): List of feature map sizes in
+ multiple feature levels, each (H, W).
+ pad_shape (tuple[int, int]): The padded shape of the image, (H, W).
+ device (torch.device): The device where the anchors will be put on.
+ Defaults to torch.device("cuda").
+
+ Return:
+ list(torch.Tensor): Valid flags of points of multiple levels.
+ """
+ assert self.num_levels == len(featmap_sizes)
+ multi_level_flags = []
+ for i in range(self.num_levels):
+ point_stride = self.strides[i]
+ feat_h, feat_w = featmap_sizes[i]
+ h, w = pad_shape[:2]
+ valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h)
+ valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w)
+ flags = self.single_level_valid_flags(
+ (feat_h, feat_w), (valid_feat_h, valid_feat_w), device=device
+ )
+ multi_level_flags.append(flags)
+ return multi_level_flags
+
+ def single_level_valid_flags(
+ self,
+ featmap_size: tuple[int, int],
+ valid_size: tuple[int, int],
+ device: torch.device = torch.device("cuda"),
+ ) -> torch.Tensor:
+ """Generate the valid flags of points of a single feature map.
+
+ Args:
+ featmap_size (tuple[int, int]): The size of feature maps, (H, W).
+ valid_size (tuple[int, int]): The valid size of the feature maps,
+ (H, W).
+ device (torch.device, optional): The device where the flags will
+ be put on. Defaults to torch.device("cuda").
+
+ Returns:
+ torch.Tensor: The valid flags of each points in a single level
+ feature map.
+ """
+ feat_h, feat_w = featmap_size
+ valid_h, valid_w = valid_size
+ assert valid_h <= feat_h and valid_w <= feat_w
+ valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
+ valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
+ valid_x[:valid_w] = 1
+ valid_y[:valid_h] = 1
+ valid_xx, valid_yy = meshgrid(valid_x, valid_y)
+ valid = valid_xx & valid_yy
+ return valid
diff --git a/vis4d/op/box/anchor/util.py b/vis4d/op/box/anchor/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa314c86a3bd28cc832d67d11561a9632f241e23
--- /dev/null
+++ b/vis4d/op/box/anchor/util.py
@@ -0,0 +1,27 @@
+"""Anchor utils."""
+
+from __future__ import annotations
+
+from torch import Tensor
+
+
+def meshgrid(
+ x_grid: Tensor, y_grid: Tensor, row_major: bool = True
+) -> tuple[Tensor, Tensor]:
+ """Generate mesh grid of x and y.
+
+ Args:
+ x_grid (Tensor): Grids of x dimension.
+ y_grid (Tensor): Grids of y dimension.
+ row_major (bool, optional): Whether to return y grids first.
+ Defaults to True.
+
+ Returns:
+ tuple[Tensor]: The mesh grids of x and y.
+ """
+ # use shape instead of len to keep tracing while exporting to onnx
+ xx = x_grid.repeat(y_grid.shape[0])
+ yy = y_grid.view(-1, 1).repeat(1, x_grid.shape[0]).view(-1)
+ if row_major:
+ return xx, yy
+ return yy, xx
diff --git a/vis4d/op/box/box2d.py b/vis4d/op/box/box2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..906f876434d20cc1926b9ed561c42b34bf610a62
--- /dev/null
+++ b/vis4d/op/box/box2d.py
@@ -0,0 +1,467 @@
+"""Utility functions for bounding boxes."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor
+from torchvision.ops import batched_nms, nms
+
+from vis4d.common.logging import rank_zero_warn
+from vis4d.op.geometry.transform import transform_points
+
+
+def bbox_scale(
+ boxes: torch.Tensor, scale_factor_xy: tuple[float, float]
+) -> torch.Tensor:
+ """Scale bounding box tensor.
+
+ Args:
+ boxes (torch.Tensor): Bounding boxes with shape [N, 4]
+ scale_factor_xy (tuple[float, float]): Scaling factor for x and y
+
+ Returns:
+ torch.Tensor with bounding boxes scaled by the given factors in
+ x and y direction
+ """
+ boxes[:, [0, 2]] *= scale_factor_xy[0]
+ boxes[:, [1, 3]] *= scale_factor_xy[1]
+ return boxes
+
+
+def bbox_clip(
+ boxes: torch.Tensor,
+ image_hw: tuple[float, float],
+ epsilon: int = 0,
+) -> torch.Tensor:
+ """Clip bounding boxes to image dims.
+
+ Args:
+ boxes (torch.Tensor): Bounding boxes with shape [N, 4]
+ image_hw (tuple[float, float]): Image dimensions.
+ epsilon (int): Epsilon for clipping.
+ Defaults to 0.
+
+ Returns:
+ torch.Tensor: Clipped bounding boxes.
+ """
+ boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp(0, image_hw[1] - epsilon)
+ boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(0, image_hw[0] - epsilon)
+ return boxes
+
+
+def scale_and_clip_boxes(
+ boxes: torch.Tensor,
+ original_hw: tuple[int, int],
+ current_hw: tuple[int, int],
+ clip: bool = True,
+) -> torch.Tensor:
+ """Postprocess boxes by scaling and clipping to given image dims.
+
+ Args:
+ boxes (torch.Tensor): Bounding boxes with shape [N, 4].
+ original_hw (tuple[int, int]): Original height / width of image.
+ current_hw (tuple[int, int]): Current height / width of image.
+ clip (bool): If true, clips box corners to image bounds.
+
+ Returns:
+ torch.Tensor: Rescaled and possibly clipped bounding boxes.
+ """
+ scale_factor = (
+ original_hw[1] / current_hw[1],
+ original_hw[0] / current_hw[0],
+ )
+ boxes = bbox_scale(boxes, scale_factor)
+ if clip:
+ boxes = bbox_clip(boxes, original_hw)
+ return boxes
+
+
+def bbox_area(boxes: torch.Tensor) -> torch.Tensor:
+ """Compute bounding box areas.
+
+ Args:
+ boxes (torch.Tensor): [N, 4] tensor of 2D boxes
+ in format (x1, y1, x2, y2).
+
+ Returns:
+ torch.Tensor: [N,] tensor of box areas.
+ """
+ return (boxes[:, 2] - boxes[:, 0]).clamp(0) * (
+ boxes[:, 3] - boxes[:, 1]
+ ).clamp(0)
+
+
+def bbox_intersection(boxes1: Tensor, boxes2: Tensor) -> torch.Tensor:
+ """Given two lists of boxes of size N and M, compute N x M intersection.
+
+ Args:
+ boxes1: N 2D boxes in format (x1, y1, x2, y2)
+ boxes2: M 2D boxes in format (x1, y1, x2, y2)
+
+ Returns:
+ Tensor: intersection (N, M).
+ """
+ width_height = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) - torch.max(
+ boxes1[:, None, :2], boxes2[:, :2]
+ )
+ width_height.clamp_(min=0)
+ intersection = width_height.prod(dim=2)
+ return intersection
+
+
+def bbox_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
+ """Compute IoU between all pairs of boxes.
+
+ Args:
+ boxes1: N 2D boxes in format (x1, y1, x2, y2)
+ boxes2: M 2D boxes in format (x1, y1, x2, y2)
+
+ Returns:
+ Tensor: IoU (N, M).
+ """
+ area1 = bbox_area(boxes1)
+ area2 = bbox_area(boxes2)
+ inter = bbox_intersection(boxes1, boxes2)
+
+ union = area1[:, None] + area2 - inter
+
+ inter = torch.where(
+ union > 0,
+ inter,
+ torch.zeros(1, dtype=inter.dtype, device=inter.device),
+ )
+
+ iou = torch.where(
+ inter > 0,
+ inter / (area1[:, None] + area2 - inter),
+ torch.zeros(1, dtype=inter.dtype, device=inter.device),
+ )
+ return iou
+
+
+def bbox_intersection_aligned(boxes1: Tensor, boxes2: Tensor) -> torch.Tensor:
+ """Given two lists of boxes both of size N, compute N intersection.
+
+ Args:
+ boxes1: N 2D boxes in format (x1, y1, x2, y2)
+ boxes2: N 2D boxes in format (x1, y1, x2, y2)
+
+ Returns:
+ Tensor: intersection (N).
+ """
+ width_height = torch.min(boxes1[:, 2:], boxes2[:, 2:]) - torch.max(
+ boxes1[:, :2], boxes2[:, :2]
+ )
+ width_height.clamp_(min=0)
+ intersection = width_height.prod(dim=1)
+ return intersection
+
+
+def bbox_iou_aligned(
+ boxes1: torch.Tensor, boxes2: torch.Tensor
+) -> torch.Tensor:
+ """Compute IoU between aligned pairs of boxes.
+
+ The number of boxes in both inputs must be the same.
+
+ Args:
+ boxes1: N 2D boxes in format (x1, y1, x2, y2)
+ boxes2: N 2D boxes in format (x1, y1, x2, y2)
+
+ Returns:
+ Tensor: IoU (N).
+ """
+ area1 = bbox_area(boxes1)
+ area2 = bbox_area(boxes2)
+ inter = bbox_intersection_aligned(boxes1, boxes2)
+
+ iou = torch.where(
+ inter > 0,
+ inter / (area1 + area2 - inter),
+ torch.zeros(1, dtype=inter.dtype, device=inter.device),
+ )
+ return iou
+
+
+def transform_bbox(
+ trans_mat: torch.Tensor, boxes: torch.Tensor
+) -> torch.Tensor:
+ """Apply trans_mat (3, 3) / (B, 3, 3) to (N, 4) / (B, N, 4) xyxy boxes.
+
+ Args:
+ trans_mat (torch.Tensor): Transformation matrix
+ of shape (3,3) or (B,3,3)
+ boxes (torch.Tensor): Bounding boxes of shape (N,4) or (B,N,4)
+
+ Returns:
+ torch.Tensor containing linear transformed bounding boxes. (B?, N, 4)
+ """
+ assert len(trans_mat.shape) == len(
+ boxes.shape
+ ), "trans_mat and boxes must have same number of dimensions!"
+ x1y1 = boxes[..., :2]
+ x2y1 = torch.stack((boxes[..., 2], boxes[..., 1]), -1)
+ x2y2 = boxes[..., 2:]
+ x1y2 = torch.stack((boxes[..., 0], boxes[..., 3]), -1)
+
+ x1y1 = transform_points(x1y1, trans_mat)
+ x2y1 = transform_points(x2y1, trans_mat)
+ x2y2 = transform_points(x2y2, trans_mat)
+ x1y2 = transform_points(x1y2, trans_mat)
+
+ x_all = torch.stack(
+ (x1y1[..., 0], x2y2[..., 0], x2y1[..., 0], x1y2[..., 0]), -1
+ )
+ y_all = torch.stack(
+ (x1y1[..., 1], x2y2[..., 1], x2y1[..., 1], x1y2[..., 1]), -1
+ )
+ transformed_boxes = torch.stack(
+ (
+ x_all.min(dim=-1)[0],
+ y_all.min(dim=-1)[0],
+ x_all.max(dim=-1)[0],
+ y_all.max(dim=-1)[0],
+ ),
+ -1,
+ )
+
+ if len(boxes.shape) == 2:
+ transformed_boxes.squeeze(0)
+ return transformed_boxes
+
+
+# TODO, refactor? move to utils?
+def random_choice(tensor: torch.Tensor, sample_size: int) -> torch.Tensor:
+ """Randomly choose elements from a tensor.
+
+ If sample_size < len(tensor) this function will sample without repetition
+ otherwise certain elements will be repeated.
+
+ Args:
+ tensor (torch.Tensor): Tensor to sample from
+ sample_size (int): Number of elements to sample
+
+ Returns:
+ torch.Tensor containing sample_size randomly sampled entries.
+ """
+ perm = torch.randperm(len(tensor), device=tensor.device)[:sample_size]
+
+ # Additionally sample with repetition
+ if sample_size > len(tensor):
+ remaining_samples = sample_size - len(tensor)
+ perm = torch.concat(
+ [
+ torch.randint(
+ remaining_samples,
+ (remaining_samples,),
+ device=tensor.device,
+ ),
+ perm,
+ ]
+ )
+
+ return tensor[perm]
+
+
+def non_intersection(
+ tensor_a: torch.Tensor, tensor_b: torch.Tensor
+) -> torch.Tensor:
+ """Get the elements of tensor_a that are not present in tensor_b.
+
+ Args:
+ tensor_a (torch.Tensor): First tensor
+ tensor_b (torch.Tensor): Second tensor
+
+ Returns:
+ torch.Tensor containing all elements that occur in both tensors
+ """
+ compareview = tensor_b.repeat(tensor_a.shape[0], 1).T
+ return tensor_a[(compareview != tensor_a).T.prod(1) == 1]
+
+
+def apply_mask(
+ masks: list[torch.Tensor], *args: list[torch.Tensor]
+) -> tuple[list[torch.Tensor], ...]:
+ """Apply given masks (either bool or indices) to given list of tensors.
+
+ Args:
+ masks (list[torch.Tensor]): Masks to apply on tensors.
+ *args (list[torch.Tensor]): List of tensors to apply the masks on.
+
+ Returns:
+ tuple[list[torch.Tensor], ...]: Masked tensor lists.
+ """
+ return tuple(
+ [t[m] if len(t) > 0 else t for t, m in zip(t_list, masks)]
+ for t_list in args
+ )
+
+
+def filter_boxes_by_area(
+ boxes: torch.Tensor, min_area: float = 0.0
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Filter a set of 2D bounding boxes given a minimum area.
+
+ Args:
+ boxes (Tensor): 2D bounding boxes [N, 4].
+ min_area (float, optional): Minimum area. Defaults to 0.0.
+
+ Returns:
+ tuple[Tensor, Tensor]: filtered boxes, boolean mask
+ """
+ if min_area > 0.0:
+ w = boxes[:, 2] - boxes[:, 0]
+ h = boxes[:, 3] - boxes[:, 1]
+ valid_mask = w * h >= min_area
+ if not valid_mask.all():
+ return boxes[valid_mask], valid_mask
+ return boxes, boxes.new_ones((len(boxes),), dtype=torch.bool)
+
+
+def hbox2corner(boxes: Tensor) -> Tensor:
+ """Convert box coordinates from boxes to corners.
+
+ Boxes are represented as (x1, y1, x2, y2).
+ Corners are represented as ((x1, y1), (x2, y1), (x1, y2), (x2, y2)).
+
+ Args:
+ boxes (Tensor): Horizontal box tensor with shape of (..., 4).
+
+ Returns:
+ Tensor: Corner tensor with shape of (..., 4, 2).
+ """
+ x1, y1, x2, y2 = torch.split(boxes, 1, dim=-1)
+ corners = torch.cat([x1, y1, x2, y1, x1, y2, x2, y2], dim=-1)
+ return corners.reshape(*corners.shape[:-1], 4, 2)
+
+
+def corner2hbox(corners: Tensor) -> Tensor:
+ """Convert box coordinates from corners to boxes.
+
+ Boxes are represented as (x1, y1, x2, y2).
+ Corners are represented as ((x1, y1), (x2, y1), (x1, y2), (x2, y2)).
+
+ Args:
+ corners (Tensor): Corner tensor with shape of (..., 4, 2).
+
+ Returns:
+ Tensor: Horizontal box tensor with shape of (..., 4).
+ """
+ if corners.numel() == 0:
+ return corners.new_zeros((0, 4))
+ min_xy = corners.min(dim=-2)[0]
+ max_xy = corners.max(dim=-2)[0]
+ return torch.cat([min_xy, max_xy], dim=-1)
+
+
+def bbox_project(boxes: Tensor, homography_matrix: Tensor) -> Tensor:
+ """Apply geometric transform to boxes in-place.
+
+ Args:
+ boxes (Tensor): Horizontal box tensor with shape of (..., 4).
+ homography_matrix (Tensor): Shape (3, 3) for geometric transformation.
+ """
+ corners = hbox2corner(boxes)
+ corners = torch.cat(
+ [corners, corners.new_ones(*corners.shape[:-1], 1)], dim=-1
+ )
+ corners_t = torch.transpose(corners, -1, -2)
+ corners_t = torch.matmul(homography_matrix, corners_t)
+ corners = torch.transpose(corners_t, -1, -2)
+ # Convert to homogeneous coordinates by normalization
+ corners = corners[..., :2] / corners[..., 2:3]
+ return corner2hbox(corners)
+
+
+def multiclass_nms(
+ multi_bboxes: Tensor,
+ multi_scores: Tensor,
+ score_thr: float,
+ iou_thr: float,
+ max_num: int = -1,
+ class_agnostic: bool = False,
+ split_thr: int = 100000,
+) -> tuple[Tensor, Tensor, Tensor, Tensor]:
+ """Non-maximum suppression with multiple classes.
+
+ Args:
+ multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
+ multi_scores (Tensor): shape (n, #class), where the last column
+ contains scores of the background class, but this will be ignored.
+ score_thr (float): bbox threshold, bboxes with scores lower than it
+ will not be considered.
+ iou_thr (float): NMS IoU threshold
+ max_num (int, optional): if there are more than max_num bboxes after
+ NMS, only top max_num will be kept. Defaults to -1.
+ class_agnostic (bool, optional): whether apply class_agnostic NMS.
+ Defaults to False.
+ split_thr (int, optional): If the number of bboxes is less than
+ split_thr, use class agnostic NMS with class_agnostic=True.
+ Defaults to 100000.
+
+ Returns:
+ tuple: (Tensor, Tensor, Tensor, Tensor): detections (k, 5), scores
+ (k), classes (k) and indices (k).
+
+ Raises:
+ RuntimeError: If there is a onnx error,
+ """
+ num_classes = multi_scores.size(1) - 1
+ # exclude background category
+ if multi_bboxes.shape[1] > 4:
+ bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
+ else:
+ bboxes = multi_bboxes[:, None].expand(
+ multi_scores.size(0), num_classes, 4
+ )
+
+ scores = multi_scores[:, :-1]
+
+ labels = torch.arange(num_classes, dtype=torch.long, device=scores.device)
+ labels = labels.view(1, -1).expand_as(scores)
+
+ bboxes = bboxes.reshape(-1, 4)
+ scores = scores.reshape(-1)
+ labels = labels.reshape(-1)
+
+ if not torch.onnx.is_in_onnx_export():
+ # NonZero not supported in TensorRT
+ # remove low scoring boxes
+ valid_mask = scores > score_thr
+
+ if not torch.onnx.is_in_onnx_export():
+ # NonZero not supported in TensorRT
+ inds = valid_mask.nonzero(as_tuple=False).squeeze(1)
+ bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds]
+ else:
+ # TensorRT NMS plugin has invalid output filled with -1
+ # add dummy data to make detection output correct.
+ bboxes = torch.cat([bboxes, bboxes.new_zeros(1, 4)], dim=0)
+ scores = torch.cat([scores, scores.new_zeros(1)], dim=0)
+ labels = torch.cat([labels, labels.new_zeros(1)], dim=0)
+
+ if bboxes.numel() == 0:
+ if torch.onnx.is_in_onnx_export():
+ raise RuntimeError(
+ "[ONNX Error] Can not record NMS "
+ "as it has not been executed this time"
+ )
+ return bboxes, scores, labels, inds
+
+ if class_agnostic and bboxes.shape[0] < split_thr:
+ keep = nms(bboxes, scores, iou_thr)
+ else:
+ if class_agnostic:
+ rank_zero_warn(
+ f"Number of bboxes is larger than {split_thr}, "
+ "using per-class NMS instead"
+ )
+ keep = batched_nms(bboxes, scores, labels, iou_thr)
+
+ if max_num > 0:
+ keep = keep[:max_num]
+
+ bboxes = bboxes[keep]
+ scores = scores[keep]
+ labels = labels[keep]
+ return bboxes, scores, labels, inds[keep]
diff --git a/vis4d/op/box/box3d.py b/vis4d/op/box/box3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..756d2172709c322e315451dc53d3a69754307a60
--- /dev/null
+++ b/vis4d/op/box/box3d.py
@@ -0,0 +1,144 @@
+"""Utility functions for 3D bounding boxes."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor
+
+from vis4d.data.const import AxisMode
+from vis4d.op.geometry.projection import project_points
+from vis4d.op.geometry.rotation import (
+ euler_angles_to_matrix,
+ matrix_to_quaternion,
+ quaternion_multiply,
+ quaternion_to_matrix,
+ rotate_orientation,
+ rotation_matrix_yaw,
+)
+from vis4d.op.geometry.transform import get_transform_matrix, transform_points
+
+
+def boxes3d_to_corners(boxes3d: Tensor, axis_mode: AxisMode) -> Tensor:
+ """Convert a Tensor of 3D boxes to its respective corner points.
+
+ Args:
+ boxes3d (Tensor): Box parameters. Tensor of shape [N, 10].
+ axis_mode (AxisMode): Coordinate system convention.
+
+ Returns:
+ Tensor: [N, 8, 3] 3D bounding box corner coordinates, in this order:
+
+ (back)
+ (6) +---------+. (7)
+ | ` . | ` .
+ | (4) +---+-----+ (5)
+ | | | |
+ (2) +-----+---+. (3)|
+ ` . | ` . |
+ (0) ` +---------+ (1)
+ (front)
+ """
+ w, l, h = boxes3d[:, 3], boxes3d[:, 4], boxes3d[:, 5]
+ rotation_matrix = quaternion_to_matrix(boxes3d[:, 6:])
+
+ if axis_mode == AxisMode.OPENCV:
+ x_corners = torch.stack(
+ [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2],
+ dim=-1,
+ )
+ y_corners = torch.stack(
+ [h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2],
+ dim=-1,
+ )
+ z_corners = torch.stack(
+ [-w / 2, w / 2, -w / 2, w / 2, -w / 2, w / 2, -w / 2, w / 2],
+ dim=-1,
+ )
+ else:
+ x_corners = torch.stack(
+ [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2],
+ dim=-1,
+ )
+ y_corners = torch.stack(
+ [-w / 2, w / 2, -w / 2, w / 2, -w / 2, w / 2, -w / 2, w / 2],
+ dim=-1,
+ )
+ z_corners = torch.stack(
+ [-h / 2, -h / 2, -h / 2, -h / 2, h / 2, h / 2, h / 2, h / 2],
+ dim=-1,
+ )
+
+ corners = torch.stack([x_corners, y_corners, z_corners], dim=-1)
+ corners = transform_points(
+ corners, get_transform_matrix(rotation_matrix, boxes3d[:, :3])
+ )
+ return corners
+
+
+def boxes3d_in_image(
+ box_corners: Tensor, cam_intrinsics: Tensor, image_hw: tuple[int, int]
+) -> Tensor:
+ """Check if a 3D bounding box is (partially) in an image.
+
+ Args:
+ box_corners (Tensor): [N, 8, 3] Tensor of 3D boxes corners. In OpenCV
+ coordinate frame.
+ cam_intrinsics (Tensor): [3, 3] Camera matrix.
+ image_hw (tuple[int, int]): image height / width.
+
+ Returns:
+ Tensor: [N,] boolean values.
+ """
+ points = project_points(box_corners.view(-1, 3), cam_intrinsics).view(
+ -1, 8, 2
+ )
+ mask = (points[..., 0] >= 0) * (points[..., 0] < image_hw[1]) * (
+ points[..., 1] >= 0
+ ) * (points[..., 1] < image_hw[0]) * box_corners[..., 2] > 0.0
+ mask = mask.any(dim=-1)
+ return mask
+
+
+def transform_boxes3d(
+ boxes3d: Tensor,
+ transform_matrix: Tensor,
+ source_axis_mode: AxisMode,
+ target_axis_mode: AxisMode,
+ only_yaw: bool = True,
+) -> Tensor:
+ """Transform 3D boxes using given transform matrix.
+
+ Args:
+ boxes3d (Tensor): [N, 10] Tensor of 3D boxes.
+ transform_matrix (Tensor): [4, 4] Transform matrix.
+ source_axis_mode (AxisMode): Source coordinate system convention of the
+ boxes.
+ target_axis_mode (AxisMode): Target coordinate system convention of the
+ boxes.
+ only_yaw (bool): Whether to only care about yaw rotation.
+ """
+ boxes3d_transformed = boxes3d.new_zeros(boxes3d.shape)
+ boxes3d_transformed[:, :3] = transform_points(
+ boxes3d[:, :3], transform_matrix
+ )
+ boxes3d_transformed[:, 3:6] = boxes3d[:, 3:6]
+
+ if only_yaw:
+ orientation = rotation_matrix_yaw(
+ quaternion_to_matrix(boxes3d[:, 6:]), source_axis_mode
+ )
+
+ orientation = rotate_orientation(
+ orientation, transform_matrix, axis_mode=target_axis_mode
+ )
+
+ boxes3d_transformed[:, 6:] = matrix_to_quaternion(
+ euler_angles_to_matrix(orientation)
+ )
+ else:
+ rot_quat = matrix_to_quaternion(transform_matrix[:3, :3])
+ boxes3d_transformed[:, 6:] = quaternion_multiply(
+ rot_quat, boxes3d[:, 6:]
+ )
+
+ return boxes3d_transformed
diff --git a/vis4d/op/box/encoder/__init__.py b/vis4d/op/box/encoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..76cb645aae18da37a436e0b2333f4a3648ccb8ad
--- /dev/null
+++ b/vis4d/op/box/encoder/__init__.py
@@ -0,0 +1,12 @@
+"""Init box coder module."""
+
+from .delta_xywh import DeltaXYWHBBoxDecoder, DeltaXYWHBBoxEncoder
+from .qd_3dt import QD3DTBox3DDecoder
+from .yolox import YOLOXBBoxDecoder
+
+__all__ = [
+ "DeltaXYWHBBoxEncoder",
+ "DeltaXYWHBBoxDecoder",
+ "QD3DTBox3DDecoder",
+ "YOLOXBBoxDecoder",
+]
diff --git a/vis4d/op/box/encoder/bevformer.py b/vis4d/op/box/encoder/bevformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6991757366f82d25fd210dc99d032850339c7458
--- /dev/null
+++ b/vis4d/op/box/encoder/bevformer.py
@@ -0,0 +1,119 @@
+"""NMS-Free bounding box coder for BEVFormer."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor
+
+
+class NMSFreeDecoder:
+ """BBox decoder for NMS-free detector."""
+
+ def __init__(
+ self,
+ num_classes: int,
+ post_center_range: list[float],
+ max_num: int = 100,
+ score_threshold: float | None = None,
+ ) -> None:
+ """Initialize NMSFreeDecoder.
+
+ Args:
+ num_classes (int): Number of classes.
+ post_center_range (list[float]): Limit of the center.
+ max_num (int): Max number to be kept. Default: 100.
+ score_threshold (float): Threshold to filter boxes based on score.
+ Default: None.
+ """
+ self.num_classes = num_classes
+ self.post_center_range = post_center_range
+ self.max_num = max_num
+ self.score_threshold = score_threshold
+
+ def __call__(
+ self, cls_scores: Tensor, bbox_preds: Tensor
+ ) -> tuple[Tensor, Tensor, Tensor]:
+ """Decode single batch bboxes.
+
+ Args:
+ cls_scores (Tensor): Outputs from the classification head, in shape
+ of [num_query, cls_out_channels]. Note cls_out_channels
+ should includes background.
+ bbox_preds (Tensor): Outputs from the regression
+ head with normalized coordinate format (cx, cy, w, l, cz, h,
+ rot_sine, rot_cosine, vx, vy). Shape [num_query, 9].
+
+ Returns:
+ tuple[Tensor, Tensor, Tensor]: Decoded boxes (x, y, z, l, w, h,
+ yaw, vx, vy), scores and labels.
+ """
+ cls_scores = cls_scores.sigmoid()
+ scores, indexs = cls_scores.view(-1).topk(self.max_num)
+ labels = indexs % self.num_classes
+ bbox_index = indexs // self.num_classes
+ bbox_preds = bbox_preds[bbox_index]
+
+ final_box_preds = _denormalize_bbox(bbox_preds)
+ final_scores = scores
+ final_preds = labels
+
+ # use score threshold
+ if self.score_threshold is not None:
+ thresh_mask = final_scores > self.score_threshold
+ tmp_score = self.score_threshold
+ while thresh_mask.sum() == 0:
+ tmp_score *= 0.9
+ if tmp_score < 0.01:
+ thresh_mask = final_scores > -1
+ break
+ thresh_mask = final_scores >= tmp_score
+
+ post_center_range = torch.tensor(
+ self.post_center_range, device=scores.device
+ )
+ mask = (final_box_preds[..., :3] >= post_center_range[:3]).all(1)
+ mask &= (final_box_preds[..., :3] <= post_center_range[3:]).all(1)
+
+ if self.score_threshold:
+ mask &= thresh_mask
+
+ boxes3d = final_box_preds[mask]
+ scores = final_scores[mask]
+
+ labels = final_preds[mask]
+
+ return boxes3d, scores, labels
+
+
+def _denormalize_bbox(normalized_bboxes: Tensor) -> Tensor:
+ """Denormalize bounding boxes."""
+ # rotation
+ rot_sine = normalized_bboxes[..., 6:7]
+
+ rot_cosine = normalized_bboxes[..., 7:8]
+ rot = torch.atan2(rot_sine, rot_cosine)
+
+ # center in the bev
+ cx = normalized_bboxes[..., 0:1]
+ cy = normalized_bboxes[..., 1:2]
+ cz = normalized_bboxes[..., 4:5]
+
+ # size
+ w = normalized_bboxes[..., 2:3]
+ l = normalized_bboxes[..., 3:4]
+ h = normalized_bboxes[..., 5:6]
+
+ w = w.exp()
+ l = l.exp()
+ h = h.exp()
+ if normalized_bboxes.size(-1) > 8:
+ # velocity
+ vx = normalized_bboxes[:, 8:9]
+ vy = normalized_bboxes[:, 9:10]
+ denormalized_bboxes = torch.cat(
+ [cx, cy, cz, w, l, h, rot, vx, vy], dim=-1
+ )
+ else:
+ denormalized_bboxes = torch.cat([cx, cy, cz, w, l, h, rot], dim=-1)
+
+ return denormalized_bboxes
diff --git a/vis4d/op/box/encoder/delta_xywh.py b/vis4d/op/box/encoder/delta_xywh.py
new file mode 100644
index 0000000000000000000000000000000000000000..65944620088cb7174938bdd06e6153e81f85be28
--- /dev/null
+++ b/vis4d/op/box/encoder/delta_xywh.py
@@ -0,0 +1,215 @@
+"""XYWH Delta coder for 2D boxes.
+
+Modified from mmdetection (https://github.com/open-mmlab/mmdetection).
+"""
+
+from __future__ import annotations
+
+import math
+
+import torch
+from torch import Tensor
+
+
+class DeltaXYWHBBoxEncoder:
+ """Delta XYWH BBox encoder.
+
+ Following the practice in `R-CNN `_,
+ it encodes bbox (x1, y1, x2, y2) into delta (dx, dy, dw, dh).
+ """
+
+ def __init__(
+ self,
+ target_means: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0),
+ target_stds: tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0),
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ target_means (tuple, optional): Denormalizing means of target for
+ delta coordinates. Defaults to (0.0, 0.0, 0.0, 0.0).
+ target_stds (tuple, optional): Denormalizing standard deviation of
+ target for delta coordinates. Defaults to (1.0, 1.0, 1.0, 1.0).
+ """
+ self.means = target_means
+ self.stds = target_stds
+
+ def __call__(self, boxes: Tensor, targets: Tensor) -> Tensor:
+ """Get box regression transformation deltas.
+
+ Used to transform target boxes into target regression parameters.
+
+ Args:
+ boxes (Tensor): Source boxes, e.g., object proposals.
+ targets (Tensor): Target of the transformation, e.g.,
+ ground-truth boxes.
+
+ Returns:
+ Tensor: Box transformation deltas
+ """
+ assert boxes.size(0) == targets.size(0)
+ assert boxes.size(-1) == targets.size(-1) == 4
+ encoded_bboxes = bbox2delta(boxes, targets, self.means, self.stds)
+ return encoded_bboxes
+
+
+class DeltaXYWHBBoxDecoder:
+ """Delta XYWH BBox decoder.
+
+ Following the practice in `R-CNN `_,
+ it decodes delta (dx, dy, dw, dh) back to original bbox (x1, y1, x2, y2).
+ """
+
+ def __init__(
+ self,
+ target_means: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0),
+ target_stds: tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0),
+ wh_ratio_clip: float = 16 / 1000,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ target_means (tuple, optional): Denormalizing means of target for
+ delta coordinates. Defaults to (0.0, 0.0, 0.0, 0.0).
+ target_stds (tuple, optional): Denormalizing standard deviation of
+ target for delta coordinates. Defaults to (1.0, 1.0, 1.0, 1.0).
+ wh_ratio_clip (float, optional): Maximum aspect ratio for boxes.
+ Defaults to 16/1000.
+ """
+ self.means = target_means
+ self.stds = target_stds
+ self.wh_ratio_clip = wh_ratio_clip
+
+ def __call__(self, boxes: Tensor, box_deltas: Tensor) -> Tensor:
+ """Apply box offset energies box_deltas to boxes.
+
+ Args:
+ boxes (Tensor): Basic boxes. Shape (B, N, 4) or (N, 4)
+ box_deltas (Tensor): Encoded offsets with respect to each roi.
+ Has shape (B, N, num_classes * 4) or (B, N, 4) or
+ (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
+ when rois is a grid of anchors.Offset encoding follows [1]_.
+
+ Returns:
+ Tensor: Decoded boxes.
+ """
+ assert box_deltas.size(0) == boxes.size(0)
+ decoded_boxes = delta2bbox(
+ boxes, box_deltas, self.means, self.stds, self.wh_ratio_clip
+ )
+ return decoded_boxes
+
+
+def bbox2delta(
+ proposals: torch.Tensor,
+ gt_boxes: torch.Tensor,
+ means: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0),
+ stds: tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0),
+) -> Tensor:
+ """Compute deltas of proposals w.r.t. gt.
+
+ We usually compute the deltas of x, y, w, h of proposals w.r.t ground
+ truth boxes to get regression target.
+ This is the inverse function of :func:`delta2bbox`.
+
+ Args:
+ proposals (Tensor): Boxes to be transformed, shape (N, ..., 4).
+ gt_boxes (Tensor): Gt boxes to be used as base, shape (N, ..., 4).
+ means (Sequence[float]): Denormalizing means for delta coordinates.
+ stds (Sequence[float]): Denormalizing standard deviation for delta
+ coordinates.
+
+ Returns:
+ Tensor: deltas with shape (N, 4), where columns represent dx, dy,
+ dw, dh.
+ """
+ assert proposals.size() == gt_boxes.size()
+
+ proposals = proposals.float()
+ gt = gt_boxes.float()
+ px = (proposals[..., 0] + proposals[..., 2]) * 0.5
+ py = (proposals[..., 1] + proposals[..., 3]) * 0.5
+ pw = proposals[..., 2] - proposals[..., 0]
+ ph = proposals[..., 3] - proposals[..., 1]
+
+ gx = (gt[..., 0] + gt[..., 2]) * 0.5
+ gy = (gt[..., 1] + gt[..., 3]) * 0.5
+ gw = gt[..., 2] - gt[..., 0]
+ gh = gt[..., 3] - gt[..., 1]
+
+ dx = (gx - px) / pw
+ dy = (gy - py) / ph
+ dw = torch.log(gw / pw)
+ dh = torch.log(gh / ph)
+ deltas = torch.stack([dx, dy, dw, dh], dim=-1)
+
+ mean_tensor = torch.tensor(means, dtype=deltas.dtype, device=deltas.device)
+ std_tensor = torch.tensor(stds, dtype=deltas.dtype, device=deltas.device)
+ deltas = deltas.sub_(mean_tensor.view(1, -1)).div_(std_tensor.view(1, -1))
+
+ return deltas
+
+
+def delta2bbox(
+ rois: torch.Tensor,
+ deltas: torch.Tensor,
+ means: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0),
+ stds: tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0),
+ wh_ratio_clip: float = 16 / 1000,
+) -> Tensor:
+ """Apply deltas to shift/scale base boxes.
+
+ Typically the rois are anchor or proposed bounding boxes and the deltas are
+ network outputs used to shift/scale those boxes.
+ This is the inverse function of :func:`bbox2delta`.
+
+ Args:
+ rois (Tensor): Boxes to be transformed. Has shape (N, 4).
+ deltas (Tensor): Encoded offsets relative to each roi.
+ Has shape (N, num_classes * 4) or (N, 4). Note
+ N = num_base_anchors * W * H, when rois is a grid of
+ anchors. Offset encoding follows [1]_.
+ means (Sequence[float]): Denormalizing means for delta coordinates.
+ Default (0., 0., 0., 0.).
+ stds (Sequence[float]): Denormalizing standard deviation for delta
+ coordinates. Default (1., 1., 1., 1.).
+ wh_ratio_clip (float): Maximum aspect ratio for boxes. Default
+ 16 / 1000.
+
+ Returns:
+ Tensor: Boxes with shape (N, num_classes * 4) or (N, 4), where 4
+ represent tl_x, tl_y, br_x, br_y.
+
+ References:
+ .. [1] https://arxiv.org/abs/1311.2524
+ """
+ num_boxes, num_classes = deltas.size(0), deltas.size(1) // 4
+ if num_boxes == 0:
+ return deltas
+
+ deltas = deltas.reshape(-1, 4)
+
+ mean_tensor = torch.tensor(means, dtype=deltas.dtype, device=deltas.device)
+ std_tensor = torch.tensor(stds, dtype=deltas.dtype, device=deltas.device)
+ denorm_deltas = deltas * std_tensor.view(1, -1) + mean_tensor.view(1, -1)
+
+ dxy = denorm_deltas[:, :2]
+ dwh = denorm_deltas[:, 2:]
+
+ # Compute width/height of each roi
+ rois_ = rois.repeat(1, num_classes).reshape(-1, 4)
+ pxy = (rois_[:, :2] + rois_[:, 2:]) * 0.5
+ pwh = rois_[:, 2:] - rois_[:, :2]
+
+ dxy_wh = pwh * dxy
+
+ max_ratio = abs(math.log(wh_ratio_clip))
+ dwh = dwh.clamp(min=-max_ratio, max=max_ratio)
+
+ gxy = pxy + dxy_wh
+ gwh = pwh * dwh.exp()
+ x1y1 = gxy - (gwh * 0.5)
+ x2y2 = gxy + (gwh * 0.5)
+ boxes = torch.cat([x1y1, x2y2], dim=-1)
+ boxes = boxes.reshape(num_boxes, -1)
+ return boxes
diff --git a/vis4d/op/box/encoder/qd_3dt.py b/vis4d/op/box/encoder/qd_3dt.py
new file mode 100644
index 0000000000000000000000000000000000000000..7258f10b98a2f04a6407d484f77f8d2416a3e2d7
--- /dev/null
+++ b/vis4d/op/box/encoder/qd_3dt.py
@@ -0,0 +1,159 @@
+"""3D bounding box coder."""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+from torch import Tensor
+
+from vis4d.data.const import AxisMode
+from vis4d.op.geometry.projection import project_points, unproject_points
+from vis4d.op.geometry.rotation import (
+ alpha2yaw,
+ normalize_angle,
+ quaternion_to_matrix,
+ rotation_matrix_yaw,
+ rotation_output_to_alpha,
+ yaw2alpha,
+)
+
+
+class QD3DTBox3DEncoder:
+ """3D bounding box encoder based on qd_3dt."""
+
+ def __init__(
+ self,
+ center_scale: float = 10.0,
+ depth_log_scale: float = 2.0,
+ dim_log_scale: float = 2.0,
+ num_rotation_bins: int = 2,
+ bin_overlap: float = 1 / 6,
+ ) -> None:
+ """Init."""
+ self.center_scale = center_scale
+ self.depth_log_scale = depth_log_scale
+ self.dim_log_scale = dim_log_scale
+ self.num_rotation_bins = num_rotation_bins
+ self.bin_overlap = bin_overlap
+
+ def __call__(
+ self, boxes: Tensor, boxes3d: Tensor, intrinsics: Tensor
+ ) -> Tensor:
+ """Encode deltas between 2D boxes and 3D boxes given intrinsics."""
+ # delta center 2d
+ projected_center_3d = project_points(boxes3d[:, :3], intrinsics)
+ ctr_x = (boxes[:, 0] + boxes[:, 2]) / 2
+ ctr_y = (boxes[:, 1] + boxes[:, 3]) / 2
+ center_2d = torch.stack([ctr_x, ctr_y], -1)
+ delta_center = (projected_center_3d - center_2d) / self.center_scale
+
+ # depth
+ depth = torch.where(
+ boxes3d[:, 2] > 0,
+ torch.log(boxes3d[:, 2]) * self.depth_log_scale,
+ -boxes3d[:, 2].new_ones(1),
+ )
+ depth = depth.unsqueeze(-1)
+
+ # dimensions
+ dims = torch.where(
+ boxes3d[:, 3:6] > 0,
+ torch.log(boxes3d[:, 3:6]) * self.dim_log_scale,
+ boxes3d[:, 3:6].new_ones(1) * 100.0,
+ )
+
+ # WLH -> HWL
+ dims = dims[:, [2, 0, 1]]
+
+ # rotation
+ yaw = rotation_matrix_yaw(
+ quaternion_to_matrix(boxes3d[:, 6:]), axis_mode=AxisMode.OPENCV
+ )[:, 1]
+ alpha = yaw2alpha(yaw, boxes3d[:, :3])
+ bin_cls = torch.zeros(
+ (alpha.shape[0], self.num_rotation_bins), device=alpha.device
+ )
+ bin_res = torch.zeros(
+ (alpha.shape[0], self.num_rotation_bins), device=alpha.device
+ )
+ bin_centers = torch.arange(
+ -np.pi,
+ np.pi,
+ 2 * np.pi / self.num_rotation_bins,
+ device=alpha.device,
+ )
+ bin_centers += np.pi / self.num_rotation_bins
+ for i in range(alpha.shape[0]):
+ overlap_value = (
+ np.pi * 2 / self.num_rotation_bins * self.bin_overlap
+ )
+ alpha_hi = normalize_angle(alpha[i] + overlap_value)
+ alpha_lo = normalize_angle(alpha[i] - overlap_value)
+ for bin_idx in range(self.num_rotation_bins):
+ bin_min = bin_centers[bin_idx] - np.pi / self.num_rotation_bins
+ bin_max = bin_centers[bin_idx] + np.pi / self.num_rotation_bins
+ if (
+ bin_min <= alpha_lo <= bin_max
+ or bin_min <= alpha_hi <= bin_max
+ ):
+ bin_cls[i, bin_idx] = 1
+ bin_res[i, bin_idx] = alpha[i] - bin_centers[bin_idx]
+
+ return torch.cat([delta_center, depth, dims, bin_cls, bin_res], -1)
+
+
+class QD3DTBox3DDecoder:
+ """3D bounding box decoder based on qd_3dt."""
+
+ def __init__(
+ self,
+ center_scale: float = 10.0,
+ depth_log_scale: float = 2.0,
+ dim_log_scale: float = 2.0,
+ num_rotation_bins: int = 2,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.center_scale = center_scale
+ self.depth_log_scale = depth_log_scale
+ self.dim_log_scale = dim_log_scale
+ self.num_rotation_bins = num_rotation_bins
+
+ def __call__(
+ self, boxes_2d: Tensor, boxes_deltas: Tensor, intrinsics: Tensor
+ ) -> Tensor:
+ """Decode the predicted boxes_deltas according to given 2D boxes."""
+ # center
+ delta_center = boxes_deltas[:, 0:2] * self.center_scale
+ ctr_x = (boxes_2d[:, 0] + boxes_2d[:, 2]) / 2
+ ctr_y = (boxes_2d[:, 1] + boxes_2d[:, 3]) / 2
+ boxes_2d_center = torch.stack([ctr_x, ctr_y], -1)
+ center_2d = boxes_2d_center + delta_center
+ depth = torch.exp(boxes_deltas[:, 2:3] / self.depth_log_scale)
+ center_3d = unproject_points(center_2d, depth, intrinsics)
+
+ # dimensions
+ dimensions = torch.exp(boxes_deltas[:, 3:6] / self.dim_log_scale)
+
+ # rot_y
+ alpha = rotation_output_to_alpha(
+ boxes_deltas[:, 6:-1], self.num_rotation_bins
+ )
+ rot_y = alpha2yaw(alpha, center_3d)
+ orientation = torch.stack(
+ [torch.zeros_like(rot_y), rot_y, torch.zeros_like(rot_y)], -1
+ )
+
+ velocities = torch.zeros(
+ (boxes_deltas.shape[0], 3), device=boxes_deltas.device
+ )
+
+ return torch.cat(
+ [
+ center_3d,
+ dimensions,
+ orientation,
+ velocities,
+ ],
+ 1,
+ )
diff --git a/vis4d/op/box/encoder/yolox.py b/vis4d/op/box/encoder/yolox.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9c7dbd7c5168bb9203488f30b17232cd88bbb45
--- /dev/null
+++ b/vis4d/op/box/encoder/yolox.py
@@ -0,0 +1,34 @@
+"""YOLOX decoder for 2D boxes.
+
+Modified from mmdetection (https://github.com/open-mmlab/mmdetection).
+"""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor
+
+
+class YOLOXBBoxDecoder:
+ """YOLOX BBox decoder."""
+
+ def __call__(self, points: Tensor, offsets: Tensor) -> Tensor:
+ """Apply box offsets to points, used by YOLOX.
+
+ Args:
+ points (Tensor): Points. Shape (B, N, 4) or (N, 4).
+ offsets (Tensor): Offsets. Has shape (B, N, 4) or (N, 4).
+
+ Returns:
+ Tensor: Decoded boxes.
+ """
+ xys = (offsets[..., :2] * points[:, 2:]) + points[:, :2]
+ whs = offsets[..., 2:].exp() * points[:, 2:]
+
+ tl_x = xys[..., 0] - whs[..., 0] / 2
+ tl_y = xys[..., 1] - whs[..., 1] / 2
+ br_x = xys[..., 0] + whs[..., 0] / 2
+ br_y = xys[..., 1] + whs[..., 1] / 2
+
+ decoded_bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1)
+ return decoded_bboxes
diff --git a/vis4d/op/box/matchers/__init__.py b/vis4d/op/box/matchers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef1dd157c1e4b3a9ebaf9e9445399b857c0b4003
--- /dev/null
+++ b/vis4d/op/box/matchers/__init__.py
@@ -0,0 +1,7 @@
+"""Matchers package."""
+
+from .base import Matcher, MatchResult
+from .max_iou import MaxIoUMatcher
+from .sim_ota import SimOTAMatcher
+
+__all__ = ["Matcher", "MaxIoUMatcher", "MatchResult", "SimOTAMatcher"]
diff --git a/vis4d/op/box/matchers/base.py b/vis4d/op/box/matchers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..52a6b27735e70101fdc82fe0cb8b9fa5c99098d8
--- /dev/null
+++ b/vis4d/op/box/matchers/base.py
@@ -0,0 +1,37 @@
+"""Matchers."""
+
+import abc
+from typing import NamedTuple
+
+import torch
+from torch import nn
+
+
+class MatchResult(NamedTuple):
+ """Match result class. Stores expected result tensors.
+
+ assigned_gt_indices: torch.Tensor - Tensor of [0, M) where M = num gt
+ assigned_gt_iou: torch.Tensor - Tensor with IoU to assigned GT
+ assigned_labels: torch.Tensor - Tensor of {0, -1, 1} = {neg, ignore, pos}
+ """
+
+ assigned_gt_indices: torch.Tensor
+ assigned_gt_iou: torch.Tensor
+ assigned_labels: torch.Tensor
+
+
+class Matcher(nn.Module):
+ """Base class for box / target matchers."""
+
+ @abc.abstractmethod
+ def forward(
+ self, boxes: torch.Tensor, targets: torch.Tensor
+ ) -> MatchResult:
+ """Match bounding boxes according to their struct."""
+ raise NotImplementedError
+
+ def __call__(
+ self, boxes: torch.Tensor, targets: torch.Tensor
+ ) -> MatchResult:
+ """Type declaration for forward."""
+ return self._call_impl(boxes, targets)
diff --git a/vis4d/op/box/matchers/max_iou.py b/vis4d/op/box/matchers/max_iou.py
new file mode 100644
index 0000000000000000000000000000000000000000..506098db36a210fb67d547214b81eaa931436c70
--- /dev/null
+++ b/vis4d/op/box/matchers/max_iou.py
@@ -0,0 +1,126 @@
+"""Match predictions and targets according to maximum 2D IoU."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor
+
+from ..box2d import bbox_iou
+from .base import Matcher, MatchResult
+
+
+# implementation modified from:
+# https://github.com/facebookresearch/detectron2/
+class MaxIoUMatcher(Matcher):
+ """MaxIoUMatcher class."""
+
+ def __init__(
+ self,
+ thresholds: list[float],
+ labels: list[int],
+ allow_low_quality_matches: bool,
+ min_positive_iou: float = 0.0,
+ ):
+ """Creates an instance of the class."""
+ super().__init__()
+ self.allow_low_quality_matches = allow_low_quality_matches
+ self.min_positive_iou = min_positive_iou
+ if not thresholds[0] > 0:
+ raise ValueError(
+ f"Lowest threshold {thresholds[0]} must be greater than 0!"
+ )
+ eps = 1e-4
+ thresholds.insert(0, 0.0 - eps)
+ thresholds.append(1.0 + eps)
+ if not all(
+ (lo <= hi for (lo, hi) in zip(thresholds[:-1], thresholds[1:]))
+ ):
+ raise ValueError("Thresholds must be in ascending order!")
+
+ assert all(
+ (v in [-1, 0, 1] for v in labels)
+ ), "labels must be in [-1, 0, 1]!"
+ assert (
+ len(labels) == len(thresholds) - 1
+ ), "Labels must be of len(thresholds) + 1."
+ self.thresholds = thresholds
+ self.labels = labels
+
+ def forward(self, boxes: Tensor, targets: Tensor) -> MatchResult:
+ """Match all boxes to targets based on maximum IoU."""
+ if len(targets) == 0:
+ matches = boxes.new_zeros((len(boxes),), dtype=torch.int64)
+ match_labels = boxes.new_zeros((len(boxes),), dtype=torch.int8)
+ match_iou = boxes.new_zeros((len(boxes),))
+ else:
+ # M x N matrix, where M = num gt, N = num proposals
+ match_quality_matrix = bbox_iou(targets, boxes)
+
+ # matches N x 1 = index of assigned gt i.e. range [0, M)
+ # match_labels N x 1, 0 = negative, -1 = ignore, 1 = positive
+ matches, match_labels = self._compute_matches(match_quality_matrix)
+ match_iou = match_quality_matrix[
+ matches, torch.arange(0, len(boxes), device=boxes.device)
+ ]
+
+ return MatchResult(
+ assigned_gt_indices=matches,
+ assigned_labels=match_labels,
+ assigned_gt_iou=match_iou,
+ )
+
+ def _compute_matches(
+ self, match_quality_matrix: Tensor
+ ) -> tuple[Tensor, Tensor]:
+ """Compute matching boxes and their labels w/ match_quality_matrix."""
+ assert match_quality_matrix.dim() == 2
+ if match_quality_matrix.numel() == 0:
+ default_matches = match_quality_matrix.new_full(
+ (match_quality_matrix.shape[1],), 0, dtype=torch.int64
+ )
+ default_match_labels = match_quality_matrix.new_full(
+ (match_quality_matrix.shape[1],),
+ self.labels[0],
+ dtype=torch.int8,
+ )
+ return default_matches, default_match_labels
+
+ assert torch.all(torch.greater_equal(match_quality_matrix, 0))
+
+ # Max over gt elements (dim 0) --> best gt for each prediction
+ matched_vals, matches = match_quality_matrix.max(dim=0)
+
+ match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
+
+ for l, low, high in zip(
+ self.labels, self.thresholds[:-1], self.thresholds[1:]
+ ):
+ low_high = (matched_vals >= low) & (matched_vals < high)
+ match_labels[low_high] = l
+
+ if self.allow_low_quality_matches:
+ _set_low_quality_matches(
+ match_labels, match_quality_matrix, self.min_positive_iou
+ )
+
+ return matches, match_labels
+
+
+def _set_low_quality_matches(
+ match_labels: Tensor,
+ match_quality_matrix: Tensor,
+ min_positive_iou: float = 0.0,
+) -> None:
+ """Set matches for predictions that have only low-quality matches.
+
+ See Sec. 3.1.2 of Faster R-CNN: https://arxiv.org/abs/1506.01497
+ """
+ highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
+ if min_positive_iou > 0:
+ highest_quality_foreach_gt = highest_quality_foreach_gt.clamp(
+ min_positive_iou
+ )
+ pred_inds_with_highest_quality = (
+ match_quality_matrix == highest_quality_foreach_gt[:, None]
+ ).nonzero()[:, 1]
+ match_labels[pred_inds_with_highest_quality] = 1
diff --git a/vis4d/op/box/matchers/sim_ota.py b/vis4d/op/box/matchers/sim_ota.py
new file mode 100644
index 0000000000000000000000000000000000000000..d940ce12c5ac3c77e396aab6de1f1d4ee04ac874
--- /dev/null
+++ b/vis4d/op/box/matchers/sim_ota.py
@@ -0,0 +1,252 @@
+"""SimOTA label assigner.
+
+Modified from mmdetection (https://github.com/open-mmlab/mmdetection).
+"""
+
+from __future__ import annotations
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from vis4d.op.box.box2d import bbox_iou
+
+from .base import MatchResult
+
+INF = 100000.0
+EPS = 1.0e-7
+
+
+class SimOTAMatcher(nn.Module):
+ """SimOTA label assigner used by YOLOX.
+
+ Args:
+ center_radius (float, optional): Ground truth center size to judge
+ whether a prior is in center. Defaults to 2.5.
+ candidate_topk (int, optional): The candidate top-k which used to
+ get top-k ious to calculate dynamic-k. Defaults to 10.
+ iou_weight (float, optional): The scale factor for regression
+ iou cost. Defaults to 3.0.
+ cls_weight (float, optional): The scale factor for classification
+ cost. Defaults to 1.0.
+ """
+
+ def __init__(
+ self,
+ center_radius: float = 2.5,
+ candidate_topk: int = 10,
+ iou_weight: float = 3.0,
+ cls_weight: float = 1.0,
+ ):
+ """Init."""
+ super().__init__()
+ self.center_radius = center_radius
+ self.candidate_topk = candidate_topk
+ self.iou_weight = iou_weight
+ self.cls_weight = cls_weight
+
+ def forward( # pylint: disable=arguments-differ # type: ignore[override]
+ self,
+ pred_scores: Tensor,
+ priors: Tensor,
+ decoded_bboxes: Tensor,
+ gt_bboxes: Tensor,
+ gt_labels: Tensor,
+ ) -> MatchResult:
+ """Assign gt to priors using SimOTA.
+
+ Args:
+ pred_scores (Tensor): Classification scores of one image,
+ a 2D-Tensor with shape [num_priors, num_classes]
+ priors (Tensor): All priors of one image, a 2D-Tensor with shape
+ [num_priors, 4] in [cx, xy, stride_w, stride_y] format.
+ decoded_bboxes (Tensor): Predicted bboxes, a 2D-Tensor with shape
+ [num_priors, 4] in [tl_x, tl_y, br_x, br_y] format.
+ gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
+ with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (Tensor): Ground truth labels of one image, a Tensor
+ with shape [num_gts].
+
+ Returns:
+ MatchResult: The assigned result.
+ """
+ num_gt = gt_bboxes.size(0)
+ num_bboxes = decoded_bboxes.size(0)
+
+ # assign 0 by default
+ assigned_gt_inds = decoded_bboxes.new_full(
+ (num_bboxes,), 0, dtype=torch.long
+ )
+ valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info(
+ priors, gt_bboxes
+ )
+ valid_decoded_bbox = decoded_bboxes[valid_mask]
+ valid_pred_scores = pred_scores[valid_mask]
+ num_valid = valid_decoded_bbox.size(0)
+
+ if num_gt == 0 or num_bboxes == 0 or num_valid == 0:
+ # No ground truth or boxes, return empty assignment
+ assigned_gt_iou = decoded_bboxes.new_zeros((num_bboxes,))
+ if num_gt == 0:
+ # No truth, assign everything to background
+ assigned_gt_inds[:] = 0
+ if gt_labels is None:
+ assigned_labels = None
+ else:
+ assigned_labels = decoded_bboxes.new_full(
+ (num_bboxes,), -1, dtype=torch.long
+ )
+ return MatchResult(
+ assigned_gt_indices=assigned_gt_inds,
+ assigned_labels=assigned_labels,
+ assigned_gt_iou=assigned_gt_iou,
+ )
+
+ pairwise_ious = bbox_iou(valid_decoded_bbox, gt_bboxes)
+ iou_cost = -torch.log(pairwise_ious + EPS)
+
+ gt_onehot_label = (
+ F.one_hot( # pylint: disable=not-callable
+ gt_labels.to(torch.int64), pred_scores.shape[-1]
+ )
+ .float()
+ .unsqueeze(0)
+ .repeat(num_valid, 1, 1)
+ )
+
+ valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1)
+ # disable AMP autocast and calculate BCE with FP32 to avoid overflow
+ with torch.cuda.amp.autocast(enabled=False):
+ cls_cost = (
+ F.binary_cross_entropy(
+ valid_pred_scores.to(dtype=torch.float32),
+ gt_onehot_label,
+ reduction="none",
+ )
+ .sum(-1)
+ .to(dtype=valid_pred_scores.dtype)
+ )
+
+ cost_matrix = (
+ cls_cost * self.cls_weight
+ + iou_cost * self.iou_weight
+ + (~is_in_boxes_and_center) * INF
+ )
+
+ matched_pred_ious, matched_gt_inds = self.dynamic_k_matching(
+ cost_matrix, pairwise_ious, num_gt, valid_mask
+ )
+
+ # convert to MatchResult format
+ assigned_gt_inds[valid_mask] = matched_gt_inds
+ assigned_labels = assigned_gt_inds.new_full((num_bboxes,), -1)
+ assigned_labels[valid_mask] = 1
+ assigned_gt_iou = assigned_gt_inds.new_full(
+ (num_bboxes,), -INF, dtype=torch.float32
+ )
+ assigned_gt_iou[valid_mask] = matched_pred_ious
+ return MatchResult(
+ assigned_gt_indices=assigned_gt_inds,
+ assigned_labels=assigned_labels,
+ assigned_gt_iou=assigned_gt_iou,
+ )
+
+ def get_in_gt_and_in_center_info(
+ self, priors: Tensor, gt_bboxes: Tensor
+ ) -> tuple[Tensor, Tensor]:
+ """Get whether the priors are in gt bboxes and in centers."""
+ num_gt = gt_bboxes.size(0)
+
+ repeated_x = priors[:, 0].unsqueeze(1).repeat(1, num_gt)
+ repeated_y = priors[:, 1].unsqueeze(1).repeat(1, num_gt)
+ repeated_stride_x = priors[:, 2].unsqueeze(1).repeat(1, num_gt)
+ repeated_stride_y = priors[:, 3].unsqueeze(1).repeat(1, num_gt)
+
+ # is prior centers in gt bboxes, shape: [n_prior, n_gt]
+ l_ = repeated_x - gt_bboxes[:, 0]
+ t_ = repeated_y - gt_bboxes[:, 1]
+ r_ = gt_bboxes[:, 2] - repeated_x
+ b_ = gt_bboxes[:, 3] - repeated_y
+
+ deltas = torch.stack([l_, t_, r_, b_], dim=1)
+ is_in_gts = deltas.min(dim=1).values > 0
+ is_in_gts_all = is_in_gts.sum(dim=1) > 0
+
+ # is prior centers in gt centers
+ gt_cxs = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
+ gt_cys = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
+ ct_box_l = gt_cxs - self.center_radius * repeated_stride_x
+ ct_box_t = gt_cys - self.center_radius * repeated_stride_y
+ ct_box_r = gt_cxs + self.center_radius * repeated_stride_x
+ ct_box_b = gt_cys + self.center_radius * repeated_stride_y
+
+ cl_ = repeated_x - ct_box_l
+ ct_ = repeated_y - ct_box_t
+ cr_ = ct_box_r - repeated_x
+ cb_ = ct_box_b - repeated_y
+
+ ct_deltas = torch.stack([cl_, ct_, cr_, cb_], dim=1)
+ is_in_cts = ct_deltas.min(dim=1).values > 0
+ is_in_cts_all = is_in_cts.sum(dim=1) > 0
+
+ # in boxes or in centers, shape: [num_priors]
+ is_in_gts_or_centers = is_in_gts_all | is_in_cts_all
+
+ # both in boxes and centers, shape: [num_fg, num_gt]
+ is_in_boxes_and_centers = (
+ is_in_gts[is_in_gts_or_centers, :]
+ & is_in_cts[is_in_gts_or_centers, :]
+ )
+ return is_in_gts_or_centers, is_in_boxes_and_centers
+
+ def dynamic_k_matching(
+ self,
+ cost: Tensor,
+ pairwise_ious: Tensor,
+ num_gt: int,
+ valid_mask: Tensor,
+ ) -> tuple[Tensor, Tensor]:
+ """Dynamic K matching strategy."""
+ matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
+ # select candidate topk ious for dynamic-k calculation
+ candidate_topk = min(self.candidate_topk, pairwise_ious.size(0))
+ topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
+ # calculate dynamic k for each gt
+ dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
+ for gt_idx in range(num_gt):
+ _, pos_idx = torch.topk(
+ cost[:, gt_idx],
+ k=dynamic_ks[gt_idx].item(), # type: ignore
+ largest=False,
+ )
+ matching_matrix[:, gt_idx][pos_idx] = 1
+
+ del topk_ious, dynamic_ks, pos_idx
+
+ prior_match_gt_mask = matching_matrix.sum(1) > 1
+ if prior_match_gt_mask.sum() > 0:
+ _, cost_argmin = torch.min(cost[prior_match_gt_mask, :], dim=1)
+ matching_matrix[prior_match_gt_mask, :] *= 0
+ matching_matrix[prior_match_gt_mask, cost_argmin] = 1
+ # get foreground mask inside box and center prior
+ fg_mask_inboxes = matching_matrix.sum(1) > 0
+ valid_mask[valid_mask.clone()] = fg_mask_inboxes
+
+ matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
+ matched_pred_ious = (matching_matrix * pairwise_ious).sum(1)[
+ fg_mask_inboxes
+ ]
+ return matched_pred_ious, matched_gt_inds
+
+ def __call__(
+ self,
+ pred_scores: Tensor,
+ priors: Tensor,
+ decoded_bboxes: Tensor,
+ gt_bboxes: Tensor,
+ gt_labels: Tensor,
+ ) -> MatchResult:
+ """Type declaration for forward."""
+ return self._call_impl(
+ pred_scores, priors, decoded_bboxes, gt_bboxes, gt_labels
+ )
diff --git a/vis4d/op/box/poolers/__init__.py b/vis4d/op/box/poolers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ef125f3eac772daf3e0e279812fbd64213f5f16
--- /dev/null
+++ b/vis4d/op/box/poolers/__init__.py
@@ -0,0 +1,15 @@
+"""Init sampler module."""
+
+from .base import RoIPooler
+from .roi_pooler import (
+ MultiScaleRoIAlign,
+ MultiScaleRoIPool,
+ MultiScaleRoIPooler,
+)
+
+__all__ = [
+ "RoIPooler",
+ "MultiScaleRoIAlign",
+ "MultiScaleRoIPool",
+ "MultiScaleRoIPooler",
+]
diff --git a/vis4d/op/box/poolers/base.py b/vis4d/op/box/poolers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..55cc64ed8f281fc29837cf5a7a5765b515865423
--- /dev/null
+++ b/vis4d/op/box/poolers/base.py
@@ -0,0 +1,24 @@
+"""RoI Pooling module base."""
+
+from __future__ import annotations
+
+import abc
+
+import torch
+from torch import nn
+
+
+class RoIPooler(nn.Module):
+ """Base class for RoI poolers."""
+
+ def __init__(self, resolution: tuple[int, int]) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.resolution = resolution
+
+ @abc.abstractmethod
+ def forward(
+ self, features: list[torch.Tensor], boxes: list[torch.Tensor]
+ ) -> torch.Tensor:
+ """Pool features in input bounding boxes from given feature maps."""
+ raise NotImplementedError
diff --git a/vis4d/op/box/poolers/roi_pooler.py b/vis4d/op/box/poolers/roi_pooler.py
new file mode 100644
index 0000000000000000000000000000000000000000..12a133a0e54e3249fc9d6c39680db142d28efe79
--- /dev/null
+++ b/vis4d/op/box/poolers/roi_pooler.py
@@ -0,0 +1,193 @@
+"""Vis4D RoI Pooling module."""
+
+from __future__ import annotations
+
+import abc
+import math
+
+import torch
+from torchvision.ops import roi_align, roi_pool
+
+from vis4d.common.typing import ArgsType
+
+from .base import RoIPooler
+from .utils import assign_boxes_to_levels, boxes_to_tensor
+
+
+# implementation modified from:
+# https://github.com/facebookresearch/detectron2/
+class MultiScaleRoIPooler(RoIPooler):
+ """Wrapper for roi pooling that supports multi-scale feature maps."""
+
+ def __init__(
+ self,
+ resolution: tuple[int, int],
+ strides: list[int],
+ canonical_box_size: int = 224,
+ canonical_level: int = 4,
+ aligned: bool = True,
+ ):
+ """Multi-scale version of arbitrary RoI pooling operations.
+
+ Args:
+ resolution: Pooler resolution.
+ strides: Feature map strides relative to the input.
+ The strides must be powers of 2 and a monotically decreasing
+ geometric sequence with a factor of 1/2.
+ canonical_box_size: Canonical box size in pixels (sqrt(box area)).
+ The default is heuristically defined as 224 pixels in the FPN
+ paper (based on ImageNet pre-training).
+ canonical_level: The feature map level index from which a canonical
+ sized box should be placed. The default is defined as level 4
+ (stride=16) in the FPN paper, i.e., a box of size 224x224 will
+ be placed on the feature with stride=16.
+ The box placement for all boxes will be determined from their
+ sizes w.r.t canonical_box_size. For example, a box whose area
+ is 4x that of a canonical box should be used to pool features
+ from feature level ``canonical_level+1``.
+ aligned (bool): For roi_align op. Shift the box coordinates it by
+ -0.5 for a better alignment with the two neighboring pixel
+ indices.
+ """
+ super().__init__(resolution)
+ self.canonical_level = canonical_level
+ self.canonical_box_size = canonical_box_size
+ self.aligned = aligned
+ self.strides = strides
+
+ # Map scale (defined as 1 / stride) to its feature map level under the
+ # assumption that stride is a power of 2.
+ self.scales = [1 / s for s in self.strides]
+
+ min_level = -(math.log2(self.scales[0]))
+ max_level = -(math.log2(self.scales[-1]))
+ assert math.isclose(min_level, int(min_level)) and math.isclose(
+ max_level, int(max_level)
+ ), "Featuremap stride is not power of 2!"
+ self.min_level = int(min_level)
+ self.max_level = int(max_level)
+ assert (
+ len(self.scales) == self.max_level - self.min_level + 1
+ ), "[ROIPooler] Sizes of input NamedTensors do not form a pyramid!"
+ assert self.min_level >= 0 and self.min_level <= self.max_level
+ assert self.canonical_box_size > 0
+
+ def forward(
+ self, features: list[torch.Tensor], boxes: list[torch.Tensor]
+ ) -> torch.Tensor:
+ """Torchvision based roi pooling operation.
+
+ Args:
+ features: List of image feature tensors (e.g., fpn levels) - NCHW
+ format.
+ boxes: List of proposals (per image).
+
+ Returns:
+ torch.Tensor: NCHW format, where N = num boxes (total),
+ HW is roi size, C is feature dim. Boxes are concatenated along
+ dimension 0 for all batch elements.
+ """
+ assert len(features) == len(self.scales), (
+ f"unequal value, len(strides)={len(self.scales)}, "
+ f"but x is list of {len(features)} Tensors"
+ )
+
+ assert len(boxes) == features[0].shape[0], (
+ f"unequal value, x[0] batch dim 0 is {features[0].shape[0]}, "
+ f"but box_list has length {len(boxes)}"
+ )
+ if len(boxes) == 0:
+ return torch.zeros(
+ (0, features[0].shape[1]) + self.resolution,
+ device=features[0].device,
+ dtype=features[0].dtype,
+ )
+
+ pooler_fmt_boxes = boxes_to_tensor(boxes)
+ if len(self.scales) == 1:
+ return self._pooling_op(
+ features[0],
+ pooler_fmt_boxes,
+ spatial_scale=self.scales[0],
+ )
+
+ level_assignments = assign_boxes_to_levels(
+ boxes,
+ self.min_level,
+ self.max_level,
+ self.canonical_box_size,
+ self.canonical_level,
+ )
+
+ num_boxes = pooler_fmt_boxes.shape[0]
+ num_channels = features[0].shape[1]
+ output_size = self.resolution[0]
+
+ dtype, device = features[0].dtype, features[0].device
+ output = torch.zeros(
+ (num_boxes, num_channels, output_size, output_size),
+ dtype=dtype,
+ device=device,
+ )
+
+ for level, scale in enumerate(self.scales):
+ inds = torch.eq(level_assignments, level).nonzero()[:, 0]
+ pooler_fmt_boxes_level = pooler_fmt_boxes[inds]
+ pooled_features = self._pooling_op(
+ features[level], pooler_fmt_boxes_level, spatial_scale=scale
+ )
+ # Use index_put_ instead of advance indexing
+ # avoids pytorch/issues/49852
+ output.index_put_((inds,), pooled_features)
+
+ return output
+
+ @abc.abstractmethod
+ def _pooling_op(
+ self,
+ inputs: torch.Tensor,
+ boxes: torch.Tensor,
+ spatial_scale: float = 1.0,
+ ) -> torch.Tensor:
+ """Execute pooling op defined in config."""
+ raise NotImplementedError
+
+
+class MultiScaleRoIAlign(MultiScaleRoIPooler):
+ """RoI Align supporting multi-scale inputs."""
+
+ def __init__(
+ self, sampling_ratio: int, *args: ArgsType, **kwargs: ArgsType
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__(*args, **kwargs)
+ self.sampling_ratio = sampling_ratio
+
+ def _pooling_op(
+ self,
+ inputs: torch.Tensor,
+ boxes: torch.Tensor,
+ spatial_scale: float = 1.0,
+ ) -> torch.Tensor:
+ """Roialign wrapper."""
+ return roi_align(
+ inputs,
+ boxes,
+ self.resolution,
+ spatial_scale,
+ self.sampling_ratio,
+ self.aligned,
+ )
+
+
+class MultiScaleRoIPool(MultiScaleRoIPooler):
+ """RoI Pool supporting multi-scale inputs."""
+
+ def _pooling_op(
+ self,
+ inputs: torch.Tensor,
+ boxes: torch.Tensor,
+ spatial_scale: float = 1.0,
+ ) -> torch.Tensor:
+ """Roipool wrapper."""
+ return roi_pool(inputs, boxes, self.resolution, spatial_scale)
diff --git a/vis4d/op/box/poolers/utils.py b/vis4d/op/box/poolers/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dcb5f749cfdac0a9f3fc1e20b975094c7417b08
--- /dev/null
+++ b/vis4d/op/box/poolers/utils.py
@@ -0,0 +1,73 @@
+"""Utility functions for RoI poolers."""
+
+from __future__ import annotations
+
+import torch
+
+from ..box2d import bbox_area
+
+
+def assign_boxes_to_levels(
+ box_lists: list[torch.Tensor],
+ min_level: int,
+ max_level: int,
+ canonical_box_size: int,
+ canonical_level: int,
+) -> torch.Tensor:
+ """Map each box to a feature map level index and return the assignment.
+
+ Args:
+ box_lists: List of Boxes
+ min_level: Smallest feature map level index. The input is considered
+ index 0, the output of stage 1 is index 1, and so.
+ max_level: Largest feature map level index.
+ canonical_box_size: A canonical box size in pixels (sqrt(box area)).
+ canonical_level: The feature map level index on which a
+ canonically-sized box should be placed.
+
+ Returns:
+ Tensor (M,), where M is the total number of boxes in the list. Each
+ element is the feature map index, as an offset from min_level, for the
+ corresponding box (so value i means the box is at self.min_level + i).
+ """
+ box_sizes = torch.sqrt(
+ torch.cat([bbox_area(boxes) for boxes in box_lists])
+ )
+ # Eqn.(1) in FPN paper
+ level_assignments = torch.floor(
+ canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8)
+ )
+ # clamp level to (min, max), in case the box size is too large or too small
+ # for the available feature maps
+ level_assignments = torch.clamp(
+ level_assignments, min=min_level, max=max_level
+ )
+ return level_assignments.to(torch.int64) - min_level
+
+
+def boxes_to_tensor(boxes: list[torch.Tensor]) -> torch.Tensor:
+ """Convert all boxes into the tensor format used by ROI pooling ops.
+
+ Args:
+ boxes: List of Boxes
+
+ Returns:
+ A tensor of shape (M, 5), where M is the total number of boxes
+ aggregated over all N batch images. The 5 columns are
+ (batch index, x0, y0, x1, y1), where batch index is in [0, N).
+ """
+
+ def _fmt_box_list(box_tensor: torch.Tensor, batch_i: int) -> torch.Tensor:
+ repeated_index = torch.full_like(
+ box_tensor[:, :1],
+ batch_i,
+ dtype=box_tensor.dtype,
+ device=box_tensor.device,
+ )
+ return torch.cat((repeated_index, box_tensor), dim=1)
+
+ pooler_fmt_boxes = torch.cat(
+ [_fmt_box_list(boxs[:, :4], i) for i, boxs in enumerate(boxes)],
+ dim=0,
+ )
+ return pooler_fmt_boxes
diff --git a/vis4d/op/box/samplers/__init__.py b/vis4d/op/box/samplers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..716b8f076e87b0fc31e00eb5233ba80b0674b9e5
--- /dev/null
+++ b/vis4d/op/box/samplers/__init__.py
@@ -0,0 +1,15 @@
+"""Init sampler module."""
+
+from .base import Sampler, SamplingResult, match_and_sample_proposals
+from .combined import CombinedSampler
+from .pseudo import PseudoSampler
+from .random import RandomSampler
+
+__all__ = [
+ "Sampler",
+ "CombinedSampler",
+ "RandomSampler",
+ "PseudoSampler",
+ "SamplingResult",
+ "match_and_sample_proposals",
+]
diff --git a/vis4d/op/box/samplers/base.py b/vis4d/op/box/samplers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e41b14eb6dcecee94558e9640cbd3cb0fd7355e
--- /dev/null
+++ b/vis4d/op/box/samplers/base.py
@@ -0,0 +1,71 @@
+"""Interface for Vis4D bounding box samplers."""
+
+from __future__ import annotations
+
+import abc
+from typing import NamedTuple
+
+import torch
+from torch import Tensor, nn
+
+from ..matchers import Matcher, MatchResult
+
+
+class SamplingResult(NamedTuple):
+ """Sampling result class. Stores expected result tensors.
+
+ sampled_box_indices (Tensor): Index of sampled boxes from input.
+ sampled_target_indices (Tensor): Index of assigned target for each
+ positive sampled box.
+ sampled_labels (Tensor): {0, -1, 1} = {neg, ignore, pos}.
+ """
+
+ sampled_box_indices: Tensor
+ sampled_target_indices: Tensor
+ sampled_labels: Tensor
+
+
+class Sampler(nn.Module):
+ """Sampler base class."""
+
+ def __init__(self, batch_size: int, positive_fraction: float) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.batch_size = batch_size
+ self.positive_fraction = positive_fraction
+
+ @abc.abstractmethod
+ def forward(self, matching: MatchResult) -> SamplingResult:
+ """Sample bounding boxes according to their struct."""
+ raise NotImplementedError
+
+ def __call__(self, matching: MatchResult) -> SamplingResult:
+ """Type declaration."""
+ return self._call_impl(matching)
+
+
+def match_and_sample_proposals(
+ matcher: Matcher,
+ sampler: Sampler,
+ proposal_boxes: list[Tensor],
+ target_boxes: list[Tensor],
+) -> tuple[list[Tensor], list[Tensor], list[Tensor]]:
+ """Match proposals to targets and subsample.
+
+ First, match the proposals to targets (ground truth labels) using the
+ matcher. It is usually IoU matcher. The matching labels the proposals with
+ positive or negative to show whether they are matched to an object.
+ Second, the sampler will choose proposals based on certain criteria such as
+ total proposal number and ratio of postives and negatives.
+ """
+ with torch.no_grad():
+ matchings = tuple(
+ matcher(prop_box, tgt_box)
+ for prop_box, tgt_box in zip(proposal_boxes, target_boxes)
+ )
+ sampling_results = tuple(sampler(matchs) for matchs in matchings)
+ return (
+ [s.sampled_box_indices for s in sampling_results],
+ [s.sampled_target_indices for s in sampling_results],
+ [s.sampled_labels for s in sampling_results],
+ )
diff --git a/vis4d/op/box/samplers/combined.py b/vis4d/op/box/samplers/combined.py
new file mode 100644
index 0000000000000000000000000000000000000000..622f7591fd0bcfd9b6205f57ea36f8616a99a1f4
--- /dev/null
+++ b/vis4d/op/box/samplers/combined.py
@@ -0,0 +1,210 @@
+"""Combined Sampler."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor
+
+from vis4d.common.typing import ArgsType
+
+from ..box2d import non_intersection, random_choice
+from ..matchers.base import MatchResult
+from .base import Sampler, SamplingResult
+
+
+class CombinedSampler(Sampler):
+ """Combined sampler. Can have different strategies for pos/neg samples."""
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ pos_strategy: str,
+ neg_strategy: str,
+ neg_pos_ub: float = 3.0,
+ floor_thr: float = -1.0,
+ floor_fraction: float = 0.0,
+ num_bins: int = 3,
+ bg_label: int = 0,
+ **kwargs: ArgsType,
+ ):
+ """Creates an instance of the class."""
+ super().__init__(*args, **kwargs)
+ self.neg_pos_ub = neg_pos_ub
+ self.floor_thr = floor_thr
+ self.floor_fraction = floor_fraction
+ self.num_bins = num_bins
+ self.bg_label = bg_label
+
+ if not pos_strategy in {
+ "instance_balanced",
+ "iou_balanced",
+ } or not neg_strategy in {"instance_balanced", "iou_balanced"}:
+ raise ValueError(
+ "strategies must be in [instance_balanced, iou_balanced]"
+ )
+
+ self.pos_strategy = getattr(self, pos_strategy + "_sampling")
+ self.neg_strategy = getattr(self, neg_strategy + "_sampling")
+
+ @staticmethod
+ def instance_balanced_sampling(
+ idx_tensor: Tensor,
+ assigned_gts: Tensor,
+ assigned_gt_ious: Tensor, # pylint: disable=unused-argument
+ sample_size: int,
+ ) -> Tensor:
+ """Sample indices with balancing according to matched GT instance."""
+ if idx_tensor.numel() <= sample_size:
+ return idx_tensor
+
+ unique_gt_inds = assigned_gts.unique()
+ num_gts = len(unique_gt_inds)
+ num_per_gt = int(sample_size / float(num_gts))
+ sampled_inds_list = []
+ # sample specific amount per gt instance
+ for i in unique_gt_inds:
+ inds = torch.nonzero(assigned_gts == i, as_tuple=False)
+ inds = inds.squeeze(1)
+ if len(inds) > num_per_gt:
+ inds = random_choice(inds, num_per_gt)
+ sampled_inds_list.append(inds)
+ sampled_inds = torch.cat(sampled_inds_list)
+
+ # deal with edge cases
+ if len(sampled_inds) < sample_size:
+ num_extra = sample_size - len(sampled_inds)
+ extra_inds = non_intersection(idx_tensor, sampled_inds)
+ if len(extra_inds) > num_extra:
+ extra_inds = random_choice(extra_inds, num_extra)
+ sampled_inds = torch.cat([sampled_inds, extra_inds])
+ return sampled_inds
+
+ def iou_balanced_sampling(
+ self,
+ idx_tensor: Tensor,
+ assigned_gts: Tensor, # pylint: disable=unused-argument
+ assigned_gt_ious: Tensor,
+ sample_size: int,
+ ) -> Tensor:
+ """Sample indices with balancing according to IoU with matched GT."""
+ if idx_tensor.numel() <= sample_size:
+ return idx_tensor
+
+ # define 'floor' set - set with low iou samples
+ if self.floor_thr >= 0:
+ floor_set = idx_tensor[assigned_gt_ious <= self.floor_thr]
+ iou_sampling_set = idx_tensor[assigned_gt_ious > self.floor_thr]
+ else:
+ floor_set = None
+ iou_sampling_set = idx_tensor[assigned_gt_ious > self.floor_thr]
+
+ num_iou_set_samples = int(sample_size * (1 - self.floor_fraction))
+ if len(iou_sampling_set) > num_iou_set_samples:
+ if self.num_bins >= 2:
+ iou_sampled_inds = self.sample_within_intervals(
+ idx_tensor, assigned_gt_ious, num_iou_set_samples
+ )
+ else:
+ iou_sampled_inds = random_choice(
+ iou_sampling_set, num_iou_set_samples
+ )
+ else:
+ iou_sampled_inds = iou_sampling_set # pragma: no cover
+
+ if floor_set is not None:
+ num_floor_set_samples = sample_size - len(iou_sampled_inds)
+ if len(floor_set) > num_floor_set_samples:
+ sampled_floor_inds = random_choice(
+ floor_set, num_floor_set_samples
+ )
+ else:
+ sampled_floor_inds = floor_set # pragma: no cover
+ sampled_inds = torch.cat([sampled_floor_inds, iou_sampled_inds])
+ else:
+ sampled_inds = iou_sampled_inds
+
+ if len(sampled_inds) < sample_size: # pragma: no cover
+ num_extra = sample_size - len(sampled_inds)
+ extra_inds = non_intersection(idx_tensor, sampled_inds)
+ if len(extra_inds) > num_extra:
+ extra_inds = random_choice(extra_inds, num_extra)
+ sampled_inds = torch.cat([sampled_inds, extra_inds])
+
+ return sampled_inds
+
+ def forward(self, matching: MatchResult) -> SamplingResult:
+ """Sample boxes according to strategies defined in cfg."""
+ pos_sample_size = int(self.batch_size * self.positive_fraction)
+
+ positive_mask: Tensor = (matching.assigned_labels != -1) & (
+ matching.assigned_labels != self.bg_label
+ )
+ negative_mask = torch.eq(matching.assigned_labels, self.bg_label)
+
+ positive = positive_mask.nonzero()[:, 0]
+ negative = negative_mask.nonzero()[:, 0]
+
+ num_pos = min(positive.numel(), pos_sample_size)
+ num_neg = self.batch_size - num_pos
+
+ if self.neg_pos_ub >= 0:
+ neg_upper_bound = int(self.neg_pos_ub * num_pos)
+ num_neg = min(num_neg, neg_upper_bound)
+
+ pos_idx = self.pos_strategy(
+ idx_tensor=positive,
+ assigned_gts=matching.assigned_gt_indices[positive_mask],
+ assigned_gt_ious=matching.assigned_gt_iou[positive_mask],
+ sample_size=num_pos,
+ )
+
+ neg_idx = self.neg_strategy(
+ idx_tensor=negative,
+ assigned_gts=matching.assigned_gt_indices[negative_mask],
+ assigned_gt_ious=matching.assigned_gt_iou[negative_mask],
+ sample_size=num_neg,
+ )
+ sampled_idcs = torch.cat([pos_idx, neg_idx], dim=0)
+
+ return SamplingResult(
+ sampled_box_indices=sampled_idcs,
+ sampled_target_indices=matching.assigned_gt_indices[sampled_idcs],
+ sampled_labels=matching.assigned_labels[sampled_idcs],
+ )
+
+ def sample_within_intervals(
+ self,
+ idx_tensor: Tensor,
+ assigned_gt_ious: Tensor,
+ sample_size: int,
+ ) -> Tensor:
+ """Sample according to N iou intervals where N = num bins."""
+ floor_thr = max(self.floor_thr, 0.0)
+ max_iou = assigned_gt_ious.max()
+ iou_interval = (max_iou - floor_thr) / self.num_bins
+ per_bin_samples = int(sample_size / self.num_bins)
+
+ sampled_inds_list = []
+ for i in range(self.num_bins):
+ start_iou = floor_thr + i * iou_interval
+ end_iou = floor_thr + (i + 1) * iou_interval
+ tmp_set = (
+ (start_iou <= assigned_gt_ious) & (assigned_gt_ious < end_iou)
+ ).nonzero()[:, 0]
+ if len(tmp_set) > per_bin_samples:
+ tmp_sampled_set = random_choice(
+ idx_tensor[tmp_set], per_bin_samples
+ )
+ else:
+ tmp_sampled_set = idx_tensor[tmp_set] # pragma: no cover
+ sampled_inds_list.append(tmp_sampled_set)
+
+ sampled_inds = torch.cat(sampled_inds_list)
+ if len(sampled_inds) < sample_size:
+ num_extra = sample_size - len(sampled_inds)
+ extra_inds = non_intersection(idx_tensor, sampled_inds)
+ if len(extra_inds) > num_extra:
+ extra_inds = random_choice(extra_inds, num_extra)
+ sampled_inds = torch.cat([sampled_inds, extra_inds])
+
+ return sampled_inds
diff --git a/vis4d/op/box/samplers/pseudo.py b/vis4d/op/box/samplers/pseudo.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6e22431c9f606e442b05a3450f246140f814e16
--- /dev/null
+++ b/vis4d/op/box/samplers/pseudo.py
@@ -0,0 +1,35 @@
+"""Pseudo Sampler."""
+
+from __future__ import annotations
+
+import torch
+
+from ..matchers.base import MatchResult
+from .base import Sampler, SamplingResult
+
+
+class PseudoSampler(Sampler):
+ """Pseudo sampler class (does nothing)."""
+
+ def __init__(self) -> None:
+ """Init."""
+ super(Sampler, self).__init__()
+
+ def forward(self, matching: MatchResult) -> SamplingResult:
+ """Sample boxes randomly."""
+ pos_idx, neg_idx = self._sample_labels(matching.assigned_labels)
+ sampled_idcs = torch.cat([pos_idx, neg_idx], dim=0)
+ return SamplingResult(
+ sampled_box_indices=sampled_idcs,
+ sampled_target_indices=matching.assigned_gt_indices[sampled_idcs],
+ sampled_labels=matching.assigned_labels[sampled_idcs],
+ )
+
+ @staticmethod
+ def _sample_labels(
+ labels: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Randomly sample indices from given labels."""
+ positive = ((labels != -1) & (labels != 0)).nonzero()[:, 0]
+ negative = torch.eq(labels, 0).nonzero()[:, 0]
+ return positive, negative
diff --git a/vis4d/op/box/samplers/random.py b/vis4d/op/box/samplers/random.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0bc6d202982f3b096e2519e8fdcf45b5b6dc965
--- /dev/null
+++ b/vis4d/op/box/samplers/random.py
@@ -0,0 +1,63 @@
+"""Random Sampler."""
+
+from __future__ import annotations
+
+import torch
+
+from vis4d.common.typing import ArgsType
+
+from ..matchers.base import MatchResult
+from .base import Sampler, SamplingResult
+
+
+class RandomSampler(Sampler):
+ """Random sampler class."""
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ bg_label: int = 0,
+ **kwargs: ArgsType,
+ ):
+ """Creates an instance of the class."""
+ super().__init__(*args, **kwargs)
+ self.bg_label = bg_label
+
+ def forward(
+ self,
+ matching: MatchResult,
+ ) -> SamplingResult:
+ """Sample boxes randomly."""
+ pos_idx, neg_idx = self._sample_labels(matching.assigned_labels)
+ sampled_idcs = torch.cat([pos_idx, neg_idx], dim=0)
+ return SamplingResult(
+ sampled_box_indices=sampled_idcs,
+ sampled_target_indices=matching.assigned_gt_indices[sampled_idcs],
+ sampled_labels=matching.assigned_labels[sampled_idcs],
+ )
+
+ def _sample_labels(
+ self, labels: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Randomly sample indices from given labels."""
+ positive = ((labels != -1) & (labels != self.bg_label)).nonzero()[:, 0]
+ negative = torch.eq(labels, self.bg_label).nonzero()[:, 0]
+
+ num_pos = int(self.batch_size * self.positive_fraction)
+ # protect against not enough positive examples
+ num_pos = min(positive.numel(), num_pos)
+ num_neg = self.batch_size - num_pos
+ # protect against not enough negative examples
+ num_neg = min(negative.numel(), num_neg)
+
+ # randomly select positive and negative examples
+ perm1 = torch.randperm(positive.numel(), device=positive.device)[
+ :num_pos
+ ]
+ perm2 = torch.randperm(negative.numel(), device=negative.device)[
+ :num_neg
+ ]
+
+ pos_idx = positive[perm1]
+ neg_idx = negative[perm2]
+ return pos_idx, neg_idx
diff --git a/vis4d/op/detect/__init__.py b/vis4d/op/detect/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bddfadc65f29aee0f133843cf877ca84b67eba2
--- /dev/null
+++ b/vis4d/op/detect/__init__.py
@@ -0,0 +1 @@
+"""Detector module."""
diff --git a/vis4d/op/detect/common.py b/vis4d/op/detect/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..377366c121d1cdb7fd819e69cf5c2fd3590aa992
--- /dev/null
+++ b/vis4d/op/detect/common.py
@@ -0,0 +1,18 @@
+"""Common classes and functions for detection."""
+
+from typing import NamedTuple
+
+from torch import Tensor
+
+
+class DetOut(NamedTuple):
+ """Output of the detection model.
+
+ boxes (list[Tensor]): 2D bounding boxes of shape [N, 4] in xyxy format.
+ scores (list[Tensor]): confidence scores of shape [N,].
+ class_ids (list[Tensor]): class ids of shape [N,].
+ """
+
+ boxes: list[Tensor]
+ scores: list[Tensor]
+ class_ids: list[Tensor]
diff --git a/vis4d/op/detect/dense_anchor.py b/vis4d/op/detect/dense_anchor.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d8dd92ee6832223d6afabbaffd584dac289f5fd
--- /dev/null
+++ b/vis4d/op/detect/dense_anchor.py
@@ -0,0 +1,347 @@
+"""Dense anchor-based head."""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from vis4d.common.typing import TorchLossFunc
+from vis4d.op.box.anchor import AnchorGenerator, anchor_inside_image
+from vis4d.op.box.encoder import DeltaXYWHBBoxEncoder
+from vis4d.op.box.matchers import Matcher
+from vis4d.op.box.samplers import Sampler
+from vis4d.op.loss.reducer import SumWeightedLoss
+from vis4d.op.util import unmap
+
+
+class DetectorTargets(NamedTuple):
+ """Targets for first-stage detection."""
+
+ labels: Tensor
+ label_weights: Tensor
+ bbox_targets: Tensor
+ bbox_weights: Tensor
+
+
+def images_to_levels(
+ targets: list[
+ tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]]
+ ],
+) -> list[list[Tensor]]:
+ """Convert targets by image to targets by feature level."""
+ targets_per_level = []
+ for lvl_id in range(len(targets[0][0])):
+ targets_single_level = []
+ for tgt_id in range(len(targets[0])):
+ targets_single_level.append(
+ torch.stack([tgt[tgt_id][lvl_id] for tgt in targets], 0)
+ )
+ targets_per_level.append(targets_single_level)
+ return targets_per_level
+
+
+def get_targets_per_image(
+ target_boxes: Tensor,
+ anchors: Tensor,
+ matcher: Matcher,
+ sampler: Sampler,
+ box_encoder: DeltaXYWHBBoxEncoder,
+ image_hw: tuple[int, int],
+ target_class: Tensor | float = 1.0,
+ allowed_border: int = 0,
+) -> tuple[DetectorTargets, int, int]:
+ """Get targets per batch element, all scales.
+
+ Args:
+ target_boxes (Tensor): (N, 4) Tensor of target boxes for a single
+ image.
+ anchors (Tensor): (M, 4) box priors
+ matcher (Matcher): box matcher matching anchors to targets.
+ sampler (Sampler): box sampler sub-sampling matches.
+ box_encoder (DeltaXYWHBBoxEncoder): Encodes boxes into target
+ regression parameters.
+ image_hw (tuple[int, int]): input image height and width.
+ target_class (Tensor | float, optional): class label(s) of target
+ boxes. Defaults to 1.0.
+ allowed_border (int, optional): Allowed border for sub-sampling anchors
+ that lie inside the input image. Defaults to 0.
+
+ Returns:
+ tuple[DetectorTargets, Tensor, Tensor]: Targets, sum of positives, sum
+ of negatives.
+ """
+ inside_flags = anchor_inside_image(
+ anchors, image_hw, allowed_border=allowed_border
+ )
+ # assign gt and sample anchors
+ anchors = anchors[inside_flags, :]
+
+ matching = matcher(anchors, target_boxes)
+ sampling_result = sampler(matching)
+
+ num_valid_anchors = anchors.size(0)
+ bbox_targets = torch.zeros_like(anchors)
+ bbox_weights = torch.zeros_like(anchors)
+ labels = anchors.new_zeros((num_valid_anchors,))
+ label_weights = anchors.new_zeros(num_valid_anchors)
+
+ positives = torch.eq(sampling_result.sampled_labels, 1)
+ negatives = torch.eq(sampling_result.sampled_labels, 0)
+ pos_inds = sampling_result.sampled_box_indices[positives]
+ pos_target_inds = sampling_result.sampled_target_indices[positives]
+ neg_inds = sampling_result.sampled_box_indices[negatives]
+ if len(pos_inds) > 0:
+ pos_bbox_targets = box_encoder(
+ anchors[pos_inds], target_boxes[pos_target_inds]
+ )
+ bbox_targets[pos_inds] = pos_bbox_targets
+ bbox_weights[pos_inds] = 1.0
+ if isinstance(target_class, float):
+ labels[pos_inds] = target_class
+ else:
+ labels[pos_inds] = target_class[pos_target_inds].float()
+ label_weights[pos_inds] = 1.0
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # map up to original set of anchors
+ num_total_anchors = inside_flags.size(0)
+ labels = unmap(labels, num_total_anchors, inside_flags)
+ label_weights = unmap(label_weights, num_total_anchors, inside_flags)
+ bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+ bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+
+ return (
+ DetectorTargets(labels, label_weights, bbox_targets, bbox_weights),
+ int(positives.sum()),
+ int(negatives.sum()),
+ )
+
+
+def get_targets_per_batch(
+ featmap_sizes: list[tuple[int, int]],
+ target_boxes: list[Tensor],
+ target_class_ids: list[Tensor | float],
+ images_hw: list[tuple[int, int]],
+ anchor_generator: AnchorGenerator,
+ box_encoder: DeltaXYWHBBoxEncoder,
+ box_matcher: Matcher,
+ box_sampler: Sampler,
+ allowed_border: int = 0,
+) -> tuple[list[list[Tensor]], int]:
+ """Get targets for all batch elements, all scales."""
+ device = target_boxes[0].device
+
+ anchor_grids = anchor_generator.grid_priors(featmap_sizes, device=device)
+ num_level_anchors = [anchors.size(0) for anchors in anchor_grids]
+ anchors_all_levels = torch.cat(anchor_grids)
+
+ targets: list[
+ tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]]
+ ] = []
+ num_total_pos, num_total_neg = 0, 0
+ for tgt_box, tgt_cls, image_hw in zip(
+ target_boxes, target_class_ids, images_hw
+ ):
+ target, num_pos, num_neg = get_targets_per_image(
+ tgt_box,
+ anchors_all_levels,
+ box_matcher,
+ box_sampler,
+ box_encoder,
+ image_hw,
+ tgt_cls,
+ allowed_border,
+ )
+ num_total_pos += num_pos
+ num_total_neg += num_neg
+ bbox_targets_per_level = target.bbox_targets.split(num_level_anchors)
+ bbox_weights_per_level = target.bbox_weights.split(num_level_anchors)
+ labels_per_level = target.labels.split(num_level_anchors)
+ label_weights_per_level = target.label_weights.split(num_level_anchors)
+ targets.append(
+ (
+ bbox_targets_per_level,
+ bbox_weights_per_level,
+ labels_per_level,
+ label_weights_per_level,
+ )
+ )
+ targets_per_level = images_to_levels(targets)
+ num_samples = num_total_pos + num_total_neg
+ return targets_per_level, num_samples
+
+
+class DenseAnchorHeadLosses(NamedTuple):
+ """Dense anchor head loss container."""
+
+ loss_cls: Tensor
+ loss_bbox: Tensor
+
+
+class DenseAnchorHeadLoss(nn.Module):
+ """Loss of dense anchor heads.
+
+ For a given set of multi-scale dense outputs, compute the desired target
+ outputs and apply classification and regression losses.
+ The targets are computed with the given target bounding boxes, the
+ anchor grid defined by the anchor generator and the given box encoder.
+ """
+
+ def __init__(
+ self,
+ anchor_generator: AnchorGenerator,
+ box_encoder: DeltaXYWHBBoxEncoder,
+ box_matcher: Matcher,
+ box_sampler: Sampler,
+ loss_cls: TorchLossFunc,
+ loss_bbox: TorchLossFunc,
+ allowed_border: int = 0,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ anchor_generator (AnchorGenerator): Generates anchor grid priors.
+ box_encoder (DeltaXYWHBBoxEncoder): Encodes bounding boxes to
+ the desired network output.
+ box_matcher (Matcher): Box matcher.
+ box_sampler (Sampler): Box sampler.
+ loss_cls (TorchLossFunc): Classification loss.
+ loss_bbox (TorchLossFunc): Bounding box regression loss.
+ allowed_border (int): The border to allow the valid anchor.
+ Defaults to 0.
+ """
+ super().__init__()
+ self.anchor_generator = anchor_generator
+ self.box_encoder = box_encoder
+ self.allowed_border = allowed_border
+ self.matcher = box_matcher
+ self.sampler = box_sampler
+ self.loss_cls = loss_cls
+ self.loss_bbox = loss_bbox
+
+ def _loss_single_scale(
+ self,
+ cls_out: Tensor,
+ reg_out: Tensor,
+ bbox_targets: Tensor,
+ bbox_weights: Tensor,
+ labels: Tensor,
+ label_weights: Tensor,
+ num_total_samples: int,
+ ) -> tuple[Tensor, Tensor]:
+ """Compute losses per scale, all batch elements.
+
+ Args:
+ cls_out (Tensor): [N, C, H, W] tensor of class logits.
+ reg_out (Tensor): [N, C, H, W] tensor of regression params.
+ bbox_targets (Tensor): [H * W, 4] bounding box targets
+ bbox_weights (Tensor): [H * W] per-sample weighting for loss.
+ labels (Tensor): [H * W] classification targets.
+ label_weights (Tensor): [H * W] per-sample weighting for loss.
+ num_total_samples (int): average factor of loss.
+
+ Returns:
+ tuple[Tensor, Tensor]: classification and regression losses.
+ """
+ # classification loss
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+ cls_score = cls_out.permute(0, 2, 3, 1).reshape(labels.size(0), -1)
+ if cls_score.size(1) > 1:
+ labels = F.one_hot( # pylint: disable=not-callable
+ labels.long(), num_classes=cls_score.size(1) + 1
+ )[:, : cls_score.size(1)].float()
+ label_weights = label_weights.repeat(cls_score.size(1)).reshape(
+ -1, cls_score.size(1)
+ )
+ else:
+ cls_score = cls_score.squeeze(1)
+
+ loss_cls = self.loss_cls(cls_score, labels, reduction="none")
+ loss_cls = SumWeightedLoss(label_weights, num_total_samples)(loss_cls)
+
+ # regression loss
+ bbox_targets = bbox_targets.reshape(-1, 4)
+ bbox_weights = bbox_weights.reshape(-1, 4)
+ bbox_pred = reg_out.permute(0, 2, 3, 1).reshape(-1, 4)
+
+ loss_bbox = self.loss_bbox(
+ pred=bbox_pred,
+ target=bbox_targets,
+ reducer=SumWeightedLoss(bbox_weights, num_total_samples),
+ )
+ return loss_cls, loss_bbox
+
+ def forward(
+ self,
+ cls_outs: list[Tensor],
+ reg_outs: list[Tensor],
+ target_boxes: list[Tensor],
+ images_hw: list[tuple[int, int]],
+ target_class_ids: list[Tensor | float] | None = None,
+ ) -> DenseAnchorHeadLosses:
+ """Compute RetinaNet classification and regression losses.
+
+ Args:
+ cls_outs (list[Tensor]): Network classification outputs
+ at all scales.
+ reg_outs (list[Tensor]): Network regression outputs
+ at all scales.
+ target_boxes (list[Tensor]): Target bounding boxes.
+ images_hw (list[tuple[int, int]]): Image dimensions without
+ padding.
+ target_class_ids (list[Tensor] | None, optional): Target
+ class labels.
+
+ Returns:
+ DenseAnchorHeadLosses: Classification and regression losses.
+ """
+ featmap_sizes = [
+ (featmap.size()[-2], featmap.size()[-1]) for featmap in cls_outs
+ ]
+ assert len(featmap_sizes) == self.anchor_generator.num_levels
+ if target_class_ids is None:
+ target_class_ids = [1.0 for _ in range(len(target_boxes))]
+
+ targets_per_level, num_samples = get_targets_per_batch(
+ featmap_sizes,
+ target_boxes,
+ target_class_ids,
+ images_hw,
+ self.anchor_generator,
+ self.box_encoder,
+ self.matcher,
+ self.sampler,
+ self.allowed_border,
+ )
+
+ device = cls_outs[0].device
+ loss_cls_all = torch.tensor(0.0, device=device)
+ loss_bbox_all = torch.tensor(0.0, device=device)
+ for level_id, (cls_out, reg_out) in enumerate(zip(cls_outs, reg_outs)):
+ box_tgt, box_wgt, lbl, lbl_wgt = targets_per_level[level_id]
+ loss_cls, loss_bbox = self._loss_single_scale(
+ cls_out, reg_out, box_tgt, box_wgt, lbl, lbl_wgt, num_samples
+ )
+ loss_cls_all += loss_cls
+ loss_bbox_all += loss_bbox
+ return DenseAnchorHeadLosses(
+ loss_cls=loss_cls_all, loss_bbox=loss_bbox_all
+ )
+
+ def __call__(
+ self,
+ cls_outs: list[Tensor],
+ reg_outs: list[Tensor],
+ target_boxes: list[Tensor],
+ images_hw: list[tuple[int, int]],
+ target_class_ids: list[Tensor] | None = None,
+ ) -> DenseAnchorHeadLosses:
+ """Type definition."""
+ return self._call_impl(
+ cls_outs, reg_outs, target_boxes, images_hw, target_class_ids
+ )
diff --git a/vis4d/op/detect/faster_rcnn.py b/vis4d/op/detect/faster_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..08062d72bdbc09067fd156c60c213618a0a5e4fa
--- /dev/null
+++ b/vis4d/op/detect/faster_rcnn.py
@@ -0,0 +1,228 @@
+"""Faster RCNN detector."""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+import torch
+from torch import nn
+
+from vis4d.op.box.anchor import AnchorGenerator
+from vis4d.op.box.box2d import apply_mask
+from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder
+from vis4d.op.box.matchers import Matcher, MaxIoUMatcher
+from vis4d.op.box.samplers import (
+ RandomSampler,
+ Sampler,
+ match_and_sample_proposals,
+)
+
+from .rcnn import RCNNHead, RCNNOut
+from .rpn import RPN2RoI, RPNHead, RPNOut
+from .typing import Proposals, Targets
+
+
+class FRCNNOut(NamedTuple):
+ """Faster RCNN function call outputs."""
+
+ rpn: RPNOut
+ roi: RCNNOut
+ proposals: Proposals
+ sampled_proposals: Proposals | None
+ sampled_targets: Targets | None
+ sampled_target_indices: list[torch.Tensor] | None
+
+
+class FasterRCNNHead(nn.Module):
+ """This class composes RPN and RCNN head components.
+
+ It generates proposals via RPN and samples those, and runs the RCNN head
+ on the sampled proposals. During training, the sampling process is based
+ on the GT bounding boxes, during inference it is based on objectness score
+ of the proposals.
+ """
+
+ def __init__(
+ self,
+ num_classes: int,
+ anchor_generator: None | AnchorGenerator = None,
+ rpn_box_decoder: None | DeltaXYWHBBoxDecoder = None,
+ box_matcher: None | Matcher = None,
+ box_sampler: None | Sampler = None,
+ roi_head: None | RCNNHead = None,
+ proposal_append_gt: bool = True,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ num_classes (int): Number of object categories.
+ anchor_generator (AnchorGenerator, optional): Custom generator for
+ RPN. Defaults to None.
+ rpn_box_decoder (DeltaXYWHBBoxDecoder, optional): Custom rpn box
+ decoder. Defaults to None.
+ box_matcher (Matcher, optional): Custom box matcher for RCNN stage.
+ Defaults to None.
+ box_sampler (Sampler, optional): Custom box sampler for RCNN stage.
+ Defaults to None.
+ roi_head (RCNNHead, optional): Custom ROI head. Defaults to None.
+ proposal_append_gt (bool): If to append the ground truth boxes for
+ proposal sampling during training. Defaults to True.
+ """
+ super().__init__()
+ if anchor_generator is None:
+ anchor_generator = AnchorGenerator(
+ scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64]
+ )
+
+ self.box_matcher = (
+ MaxIoUMatcher(
+ thresholds=[0.5],
+ labels=[0, 1],
+ allow_low_quality_matches=False,
+ )
+ if box_matcher is None
+ else box_matcher
+ )
+
+ self.box_sampler = (
+ RandomSampler(batch_size=512, positive_fraction=0.25)
+ if box_sampler is None
+ else box_sampler
+ )
+
+ self.proposal_append_gt = proposal_append_gt
+ self.rpn_head = RPNHead(anchor_generator.num_base_priors[0])
+ self.rpn2roi = RPN2RoI(anchor_generator, rpn_box_decoder)
+
+ self.roi_head = (
+ RCNNHead(num_classes=num_classes) if roi_head is None else roi_head
+ )
+
+ @torch.no_grad()
+ def _sample_proposals(
+ self,
+ proposal_boxes: list[torch.Tensor],
+ scores: list[torch.Tensor],
+ target_boxes: list[torch.Tensor],
+ target_classes: list[torch.Tensor],
+ ) -> tuple[Proposals, Targets, list[torch.Tensor]]:
+ """Sample proposals for training of Faster RCNN.
+
+ Args:
+ proposal_boxes (list[torch.Tensor]): Proposals decoded from RPN.
+ scores (list[torch.Tensor]): Scores decoded from RPN.
+ target_boxes (list[torch.Tensor]): All target boxes.
+ target_classes (list[torch.Tensor]): According class labels.
+
+ Returns:
+ tuple[Proposals, Targets]: Sampled proposals, associated targets.
+ """
+ if self.proposal_append_gt:
+ proposal_boxes = [
+ torch.cat([p, t]) for p, t in zip(proposal_boxes, target_boxes)
+ ]
+ scores = [
+ torch.cat([s, s.new_ones(len(t))])
+ for s, t in zip(scores, target_boxes)
+ ]
+
+ (
+ sampled_box_indices,
+ sampled_target_indices,
+ sampled_labels,
+ ) = match_and_sample_proposals(
+ self.box_matcher, self.box_sampler, proposal_boxes, target_boxes
+ )
+
+ sampled_boxes, sampled_scores = apply_mask(
+ sampled_box_indices, proposal_boxes, scores
+ )
+
+ sampled_target_boxes, sampled_target_classes = apply_mask(
+ sampled_target_indices, target_boxes, target_classes
+ )
+
+ sampled_proposals = Proposals(
+ boxes=sampled_boxes, scores=sampled_scores
+ )
+ sampled_targets = Targets(
+ boxes=sampled_target_boxes,
+ classes=sampled_target_classes,
+ labels=sampled_labels,
+ )
+ return sampled_proposals, sampled_targets, sampled_target_indices
+
+ def forward(
+ self,
+ features: list[torch.Tensor],
+ images_hw: list[tuple[int, int]],
+ target_boxes: None | list[torch.Tensor] = None,
+ target_classes: None | list[torch.Tensor] = None,
+ ) -> FRCNNOut:
+ """Faster RCNN forward.
+
+ Args:
+ features (list[torch.Tensor]): Feature pyramid.
+ images_hw (list[tuple[int, int]]): Image sizes without padding.
+ This is necessary for removing the erroneous boxes on the
+ padded regions.
+ target_boxes (None | list[torch.Tensor], optional): Ground truth
+ bounding box locations. Defaults to None.
+ target_classes (None | list[torch.Tensor], optional): Ground truth
+ bounding box classes. Defaults to None.
+
+ Returns:
+ FRCNNReturn: Proposal and RoI outputs.
+ """
+ if target_boxes is not None:
+ assert target_classes is not None
+
+ rpn_out = self.rpn_head(features)
+
+ if target_boxes is not None:
+ assert (
+ target_classes is not None
+ ), "Need target classes for target boxes!"
+ proposal_boxes, scores = self.rpn2roi(
+ rpn_out.cls, rpn_out.box, images_hw
+ )
+
+ (
+ sampled_proposals,
+ sampled_targets,
+ sampled_target_indices,
+ ) = self._sample_proposals(
+ proposal_boxes, scores, target_boxes, target_classes
+ )
+ roi_out = self.roi_head(features, sampled_proposals.boxes)
+ else:
+ proposal_boxes, scores = self.rpn2roi(
+ rpn_out.cls, rpn_out.box, images_hw
+ )
+ sampled_proposals, sampled_targets, sampled_target_indices = (
+ None,
+ None,
+ None,
+ )
+ roi_out = self.roi_head(features, proposal_boxes)
+
+ return FRCNNOut(
+ roi=roi_out,
+ rpn=rpn_out,
+ proposals=Proposals(proposal_boxes, scores),
+ sampled_proposals=sampled_proposals,
+ sampled_targets=sampled_targets,
+ sampled_target_indices=sampled_target_indices,
+ )
+
+ def __call__(
+ self,
+ features: list[torch.Tensor],
+ images_hw: list[tuple[int, int]],
+ target_boxes: list[torch.Tensor] | None = None,
+ target_classes: list[torch.Tensor] | None = None,
+ ) -> FRCNNOut:
+ """Type definition for call implementation."""
+ return self._call_impl(
+ features, images_hw, target_boxes, target_classes
+ )
diff --git a/vis4d/op/detect/mask_rcnn.py b/vis4d/op/detect/mask_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8a528d8693b9789052af7cada4f8df69459b20c
--- /dev/null
+++ b/vis4d/op/detect/mask_rcnn.py
@@ -0,0 +1,420 @@
+"""Mask RCNN detector."""
+
+from __future__ import annotations
+
+from typing import NamedTuple, Protocol
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torchvision.ops import roi_align
+
+from vis4d.op.box.box2d import apply_mask
+from vis4d.op.box.poolers import MultiScaleRoIAlign
+from vis4d.op.mask.util import paste_masks_in_image, remove_overlap
+
+from .typing import Proposals, Targets
+
+
+class MaskRCNNHeadOut(NamedTuple):
+ """Mask R-CNN RoI head outputs."""
+
+ # logits for mask prediction. The dimension is number of masks x number of
+ # classes x H_mask x W_mask
+ mask_pred: list[torch.Tensor]
+
+
+class MaskRCNNHead(nn.Module):
+ """Mask R-CNN RoI head.
+
+ Args:
+ num_classes (int, optional): Number of classes. Defaults to 80.
+ num_convs (int, optional): Number of convolution layers. Defaults to 4.
+ roi_size (tuple[int, int], optional): Size of RoI after pooling.
+ Defaults to (14, 14).
+ in_channels (int, optional): Input feature channels. Defaults to 256.
+ conv_kernel_size (int, optional): Kernel size of convolution. Defaults
+ to 3.
+ conv_out_channels (int, optional): Output channels of convolution.
+ Defaults to 256.
+ scale_factor (int, optional): Scaling factor of upsampling. Defaults
+ to 2.
+ class_agnostic (bool, optional): Whether to do class agnostic mask
+ prediction. Defaults to False.
+ """
+
+ def __init__(
+ self,
+ num_classes: int = 80,
+ num_convs: int = 4,
+ roi_size: tuple[int, int] = (14, 14),
+ in_channels: int = 256,
+ conv_kernel_size: int = 3,
+ conv_out_channels: int = 256,
+ scale_factor: int = 2,
+ class_agnostic: bool = False,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.roi_pooler = MultiScaleRoIAlign(
+ sampling_ratio=0, resolution=roi_size, strides=[4, 8, 16, 32]
+ )
+
+ self.convs = nn.ModuleList()
+ for i in range(num_convs):
+ in_channels = in_channels if i == 0 else conv_out_channels
+ padding = (conv_kernel_size - 1) // 2
+ self.convs.append(
+ nn.Conv2d(
+ in_channels,
+ conv_out_channels,
+ conv_kernel_size,
+ padding=padding,
+ )
+ )
+
+ upsample_in_channels = (
+ conv_out_channels if num_convs > 0 else in_channels
+ )
+ self.upsample = nn.ConvTranspose2d(
+ upsample_in_channels,
+ conv_out_channels,
+ scale_factor,
+ stride=scale_factor,
+ )
+
+ out_channels = 1 if class_agnostic else num_classes
+ self.conv_logits = nn.Conv2d(conv_out_channels, out_channels, 1)
+ self.relu = nn.ReLU(inplace=True)
+
+ self._init_weights(self.convs)
+ self._init_weights(self.upsample, mode="fan_out")
+ self._init_weights(self.conv_logits, mode="fan_out")
+
+ @staticmethod
+ def _init_weights(module: nn.Module, mode: str = "fan_in") -> None:
+ """Initialize weights."""
+ if hasattr(module, "weight") and hasattr(module, "bias"):
+ assert isinstance(module.weight, torch.Tensor) and isinstance(
+ module.bias, torch.Tensor
+ )
+ nn.init.kaiming_normal_(
+ module.weight, mode=mode, nonlinearity="relu" # type: ignore
+ )
+ nn.init.constant_(module.bias, 0)
+
+ def forward(
+ self, features: list[torch.Tensor], boxes: list[torch.Tensor]
+ ) -> MaskRCNNHeadOut:
+ """Forward pass.
+
+ Args:
+ features (list[torch.Tensor]): Feature pyramid.
+ boxes (list[torch.Tensor]): Proposal boxes.
+
+ Returns:
+ MaskRCNNHeadOut: Mask prediction outputs.
+ """
+ # Take stride 4, 8, 16, 32 features
+ mask_feats = self.roi_pooler(features[2:6], boxes)
+ for conv in self.convs:
+ mask_feats = self.relu(conv(mask_feats))
+ mask_feats = self.relu(self.upsample(mask_feats))
+ mask_pred = self.conv_logits(mask_feats)
+ num_dets_per_img = tuple(len(d) for d in boxes)
+ mask_preds = mask_pred.split(num_dets_per_img, 0)
+ return MaskRCNNHeadOut(mask_pred=mask_preds)
+
+
+class MaskOut(NamedTuple):
+ """Output of the final detections from Mask RCNN."""
+
+ masks: list[torch.Tensor] # N, H, W
+ scores: list[torch.Tensor]
+ class_ids: list[torch.Tensor]
+
+
+class Det2Mask(nn.Module):
+ """Post processing of mask predictions.
+
+ Args:
+ mask_threshold (float, optional): Positive threshold. Defaults to 0.5.
+ no_overlap (bool, optional): Whether to remove overlapping pixels
+ between masks. Defaults to False.
+ """
+
+ def __init__(
+ self, mask_threshold: float = 0.5, no_overlap: bool = False
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.mask_threshold = mask_threshold
+ self.no_overlap = no_overlap
+
+ def forward(
+ self,
+ mask_outs: list[torch.Tensor],
+ det_boxes: list[torch.Tensor],
+ det_scores: list[torch.Tensor],
+ det_class_ids: list[torch.Tensor],
+ original_hw: list[tuple[int, int]],
+ ) -> MaskOut:
+ """Paste mask predictions back into original image resolution.
+
+ Args:
+ mask_outs (list[torch.Tensor]): List of mask outputs for each batch
+ element.
+ det_boxes (list[torch.Tensor]): List of detection boxes for each
+ batch element.
+ det_scores (list[torch.Tensor]): List of detection scores for each
+ batch element.
+ det_class_ids (list[torch.Tensor]): List of detection classeds for
+ each batch element.
+ original_hw (list[tuple[int, int]]): Original image resolution.
+
+ Returns:
+ MaskOut: Post-processed mask predictions.
+ """
+ all_masks = []
+ all_scores = []
+ all_class_ids = []
+ for mask_out, boxes, scores, class_ids, orig_hw in zip(
+ mask_outs, det_boxes, det_scores, det_class_ids, original_hw
+ ):
+ pasted_masks = paste_masks_in_image(
+ mask_out[torch.arange(len(mask_out)), class_ids],
+ boxes,
+ orig_hw[::-1],
+ self.mask_threshold,
+ )
+ if self.no_overlap:
+ pasted_masks = remove_overlap(pasted_masks, scores)
+ all_masks.append(pasted_masks)
+ all_scores.append(scores)
+ all_class_ids.append(class_ids)
+ return MaskOut(
+ masks=all_masks, scores=all_scores, class_ids=all_class_ids
+ )
+
+ def __call__(
+ self,
+ mask_outs: list[torch.Tensor],
+ det_boxes: list[torch.Tensor],
+ det_scores: list[torch.Tensor],
+ det_class_ids: list[torch.Tensor],
+ original_hw: list[tuple[int, int]],
+ ) -> MaskOut:
+ """Type definition for function call."""
+ return self._call_impl(
+ mask_outs, det_boxes, det_scores, det_class_ids, original_hw
+ )
+
+
+class MaskRCNNHeadLosses(NamedTuple):
+ """Mask RoI head loss container."""
+
+ rcnn_loss_mask: torch.Tensor
+
+
+class MaskRCNNHeadLoss(nn.Module):
+ """Mask RoI head loss function.
+
+ Args:
+ num_classes (int): number of object categories.
+ """
+
+ def __init__(self, num_classes: int) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.num_classes = num_classes
+
+ @staticmethod
+ def _get_targets_per_image(
+ boxes: Tensor,
+ tgt_masks: Tensor,
+ out_shape: tuple[int, int],
+ binarize: bool = True,
+ ) -> Tensor:
+ """Get aligned mask targets for each proposal.
+
+ Args:
+ boxes (Tensor): proposal boxes.
+ tgt_masks (Tensor): target masks.
+ out_shape (tuple[int, int]): output shape.
+ binarize (bool, optional): whether to convert target mask to
+ binary. Defaults to True.
+
+ Returns:
+ Tensor: aligned mask targets.
+ """
+ fake_inds = torch.arange(len(boxes), device=boxes.device)[:, None]
+ rois = torch.cat([fake_inds, boxes], dim=1) # Nx5
+ gt_masks_th = tgt_masks[:, None, :, :].type(rois.dtype)
+ targets = roi_align(
+ gt_masks_th, rois, out_shape, 1.0, 0, True
+ ).squeeze(1)
+ resized_masks = targets >= 0.5 if binarize else targets
+ return resized_masks
+
+ def forward(
+ self,
+ mask_preds: list[torch.Tensor],
+ proposal_boxes: list[torch.Tensor],
+ target_classes: list[torch.Tensor],
+ target_masks: list[torch.Tensor],
+ ) -> MaskRCNNHeadLosses:
+ """Calculate losses of Mask RCNN head.
+
+ Args:
+ mask_preds (list[torch.Tensor]): [M, C, H', W'] mask outputs per
+ batch element.
+ proposal_boxes (list[torch.Tensor]): [M, 4] proposal boxes per
+ batch element.
+ target_classes (list[torch.Tensor]): list of [M, 4] assigned
+ target boxes for each proposal.
+ target_masks (list[torch.Tensor]): list of [M, H, W] assigned
+ target masks for each proposal.
+
+ Returns:
+ MaskRCNNHeadLosses: mask loss.
+ """
+ mask_pred = torch.cat(mask_preds)
+ mask_size = (mask_pred.shape[2], mask_pred.shape[3])
+ # get targets
+ targets = []
+ for boxes, tgt_masks in zip(proposal_boxes, target_masks):
+ if len(tgt_masks) == 0:
+ targets.append(
+ torch.empty((0, *mask_size), device=tgt_masks.device)
+ )
+ else:
+ targets.append(
+ self._get_targets_per_image(boxes, tgt_masks, mask_size)
+ )
+ mask_targets = torch.cat(targets)
+ mask_labels = torch.cat(target_classes)
+
+ if len(mask_targets) > 0:
+ num_rois = mask_pred.shape[0]
+ inds = torch.arange(
+ 0, num_rois, dtype=torch.long, device=mask_pred.device
+ )
+ pred_slice = mask_pred[inds, mask_labels[inds].long()].squeeze(1)
+ loss_mask = F.binary_cross_entropy_with_logits(
+ pred_slice, mask_targets.float(), reduction="mean"
+ )
+ else:
+ loss_mask = mask_targets.sum()
+
+ return MaskRCNNHeadLosses(rcnn_loss_mask=loss_mask)
+
+
+class MaskSampler(Protocol):
+ """Type definition for mask sampler."""
+
+ def __call__(
+ self,
+ target_masks: list[Tensor],
+ sampled_target_indices: list[Tensor],
+ sampled_targets: Targets,
+ sampled_proposals: Proposals,
+ ) -> tuple[list[Tensor], list[Tensor], list[Tensor]]:
+ """Type definition for function call.
+
+ Args:
+ target_masks (list[Tensor]): list of [N, H, W] target masks per
+ batch element.
+ sampled_target_indices (list[Tensor]): list of [M] indices of
+ sampled targets per batch element.
+ sampled_targets (Targets): sampled targets.
+ sampled_proposals (Proposals): sampled proposals.
+
+ Returns:
+ tuple[list[Tensor], list[Tensor], list[Tensor]]: sampled masks,
+ sampled target indices, sampled targets.
+ """
+
+
+def positive_mask_sampler(
+ target_masks: list[Tensor],
+ sampled_target_indices: list[Tensor],
+ sampled_targets: Targets,
+ sampled_proposals: Proposals,
+) -> tuple[list[Tensor], list[Tensor], list[Tensor]]:
+ """Sample only positive masks from target masks.
+
+ Args:
+ target_masks (list[Tensor]): list of [N, H, W] target masks per
+ batch element.
+ sampled_target_indices (list[Tensor]): list of [M] indices of
+ sampled targets per batch element.
+ sampled_targets (Targets): sampled targets.
+ sampled_proposals (Proposals): sampled proposals.
+
+ Returns:
+ tuple[list[Tensor], list[Tensor], list[Tensor]]: sampled masks,
+ sampled target indices, sampled targets.
+ """
+ sampled_masks = apply_mask(sampled_target_indices, target_masks)[0]
+
+ pos_proposals, pos_classes, pos_mask_targets = apply_mask(
+ [torch.eq(label, 1) for label in sampled_targets.labels],
+ sampled_proposals.boxes,
+ sampled_targets.classes,
+ sampled_masks,
+ )
+ return pos_proposals, pos_classes, pos_mask_targets
+
+
+class SampledMaskLoss(nn.Module):
+ """Sampled Mask RCNN head loss function."""
+
+ def __init__(
+ self,
+ mask_sampler: MaskSampler,
+ loss: MaskRCNNHeadLoss,
+ ) -> None:
+ """Initialize sampled mask loss.
+
+ Args:
+ mask_sampler (MaskSampler): mask sampler.
+ loss (MaskRCNNHeadLoss): mask loss.
+ """
+ super().__init__()
+ self.loss = loss
+ self.mask_sampler = mask_sampler
+
+ def forward(
+ self,
+ mask_preds: list[Tensor],
+ target_masks: list[Tensor],
+ sampled_target_indices: list[Tensor],
+ sampled_targets: Targets,
+ sampled_proposals: Proposals,
+ ) -> MaskRCNNHeadLosses:
+ """Calculate losses of Mask RCNN head.
+
+ Args:
+ mask_preds (list[torch.Tensor]): [M, C, H', W'] mask outputs per
+ batch element.
+ target_masks (list[torch.Tensor]): list of [M, H, W] assigned
+ target masks for each proposal.
+ sampled_target_indices (list[Tensor]): list of [M, 4] assigned
+ target boxes for each proposal.
+ sampled_targets (Targets): list of [M, 4] assigned
+ target boxes for each proposal.
+ sampled_proposals (Proposals): list of [M, 4] assigned
+ target boxes for each proposal.
+
+ Returns:
+ MaskRCNNHeadLosses: mask loss.
+ """
+ pos_proposals, pos_classes, pos_mask_targets = self.mask_sampler(
+ target_masks,
+ sampled_target_indices,
+ sampled_targets,
+ sampled_proposals,
+ )
+ return self.loss(
+ mask_preds, pos_proposals, pos_classes, pos_mask_targets
+ )
diff --git a/vis4d/op/detect/rcnn.py b/vis4d/op/detect/rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb4fa5cbd1f7f57ead00bcf7614741cd4b636982
--- /dev/null
+++ b/vis4d/op/detect/rcnn.py
@@ -0,0 +1,452 @@
+"""Faster R-CNN RoI head."""
+
+from __future__ import annotations
+
+from math import prod
+from typing import NamedTuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from vis4d.common.typing import TorchLossFunc
+from vis4d.op.box.box2d import bbox_clip, multiclass_nms
+from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder, DeltaXYWHBBoxEncoder
+from vis4d.op.box.poolers import MultiScaleRoIAlign
+from vis4d.op.detect.common import DetOut
+from vis4d.op.layer.conv2d import add_conv_branch
+from vis4d.op.layer.weight_init import kaiming_init, normal_init, xavier_init
+from vis4d.op.loss.common import l1_loss
+from vis4d.op.loss.reducer import SumWeightedLoss
+
+
+class RCNNOut(NamedTuple):
+ """Faster R-CNN RoI head outputs."""
+
+ # Logits for box classication. The logit dimension is number of classes
+ # plus 1 for the background.
+ cls_score: torch.Tensor
+ # Each box has regression for all classes. So the tensor dimention is
+ # [batch_size, number of boxes, number of classes x 4]
+ bbox_pred: torch.Tensor
+
+
+def get_default_rcnn_box_codec(
+ target_means: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0),
+ target_stds: tuple[float, float, float, float] = (0.1, 0.1, 0.2, 0.2),
+) -> tuple[DeltaXYWHBBoxEncoder, DeltaXYWHBBoxDecoder]:
+ """Get the default bounding box encoder and decoder for RCNN."""
+ return (
+ DeltaXYWHBBoxEncoder(target_means, target_stds),
+ DeltaXYWHBBoxDecoder(target_means, target_stds),
+ )
+
+
+class RCNNHead(nn.Module):
+ """Faster R-CNN RoI head.
+
+ This head pools the RoIs from a set of feature maps and processes them
+ into classification / regression outputs.
+
+ Args:
+ num_shared_convs (int, optional): number of shared conv layers.
+ Defaults to 0.
+ num_shared_fcs (int, optional): number of shared fc layers. Defaults
+ to 2.
+ conv_out_channels (int, optional): number of output channels for
+ shared conv layers. Defaults to 256.
+ in_channels (int, optional): Number of channels in input feature maps.
+ Defaults to 256.
+ fc_out_channels (int, optional): Output channels of shared linear
+ layers. Defaults to 1024.
+ num_classes (int, optional): number of categories. Defaults to 80.
+ roi_size (tuple[int, int], optional): size of pooled RoIs. Defaults
+ to (7, 7).
+ """
+
+ def __init__(
+ self,
+ num_shared_convs: int = 0,
+ num_shared_fcs: int = 2,
+ conv_out_channels: int = 256,
+ in_channels: int = 256,
+ fc_out_channels: int = 1024,
+ num_classes: int = 80,
+ roi_size: tuple[int, int] = (7, 7),
+ start_level: int = 2,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.roi_pooler = MultiScaleRoIAlign(
+ sampling_ratio=0, resolution=roi_size, strides=[4, 8, 16, 32]
+ )
+
+ # Used feature layers are [start_level, end_level)
+ self.start_level = start_level
+ self.end_level = start_level + len(self.roi_pooler.scales)
+
+ self.num_shared_convs = num_shared_convs
+ self.num_shared_fcs = num_shared_fcs
+ self.conv_out_channels = conv_out_channels
+ self.fc_out_channels = fc_out_channels
+
+ # add shared convs and fcs
+ (
+ self.shared_convs,
+ self.shared_fcs,
+ last_layer_dim,
+ ) = self._add_conv_fc_branch(
+ self.num_shared_convs, self.num_shared_fcs, in_channels, True
+ )
+ self.shared_out_channels = last_layer_dim
+
+ in_channels *= prod(roi_size)
+
+ self.fc_cls = nn.Linear(
+ in_features=fc_out_channels, out_features=num_classes + 1
+ )
+ self.fc_reg = nn.Linear(
+ in_features=fc_out_channels, out_features=4 * num_classes
+ )
+ self.relu = nn.ReLU(inplace=True)
+
+ self._init_weights()
+
+ def _add_conv_fc_branch(
+ self,
+ num_branch_convs: int = 0,
+ num_branch_fcs: int = 0,
+ in_channels: int = 0,
+ is_shared: bool = False,
+ ) -> tuple[nn.ModuleList, nn.ModuleList, int]:
+ """Add shared or separable branch."""
+ convs, last_layer_dim = add_conv_branch(
+ num_branch_convs,
+ in_channels,
+ self.conv_out_channels,
+ True,
+ None,
+ None,
+ )
+
+ fcs = nn.ModuleList()
+ if num_branch_fcs > 0:
+ if is_shared or num_branch_fcs == 0:
+ last_layer_dim *= int(np.prod(self.roi_pooler.resolution))
+ for i in range(num_branch_fcs):
+ fc_in_dim = last_layer_dim if i == 0 else self.fc_out_channels
+ fcs.append(nn.Linear(fc_in_dim, self.fc_out_channels))
+ return convs, fcs, last_layer_dim
+
+ def _init_weights(self) -> None:
+ """Init weights."""
+ for m in self.shared_convs.modules():
+ kaiming_init(m)
+
+ for m in self.shared_fcs.modules():
+ xavier_init(m, distribution="uniform")
+
+ normal_init(self.fc_cls, std=0.01)
+ normal_init(self.fc_reg, std=0.001)
+
+ def forward(
+ self, features: list[torch.Tensor], boxes: list[torch.Tensor]
+ ) -> RCNNOut:
+ """Forward pass during training stage."""
+ bbox_feats = self.roi_pooler(
+ features[self.start_level : self.end_level], boxes
+ )
+ if self.num_shared_convs > 0:
+ for conv in self.shared_convs:
+ bbox_feats = conv(bbox_feats)
+
+ bbox_feats = bbox_feats.flatten(start_dim=1)
+
+ for fc in self.shared_fcs:
+ bbox_feats = self.relu(fc(bbox_feats))
+ cls_score = self.fc_cls(bbox_feats)
+ bbox_pred = self.fc_reg(bbox_feats)
+ return RCNNOut(cls_score, bbox_pred)
+
+ def __call__(
+ self, features: list[torch.Tensor], boxes: list[torch.Tensor]
+ ) -> RCNNOut:
+ """Type definition for function call."""
+ return self._call_impl(features, boxes)
+
+
+class RoI2Det(nn.Module):
+ """Post processing of RCNN results and detection generation.
+
+ It does the following:
+ 1. Take the classification and regression outputs from the RCNN heads.
+ 2. Take the proposal boxes that are RCNN inputs.
+ 3. Determine the final box classes and take the according box regression
+ parameters.
+ 4. Adjust the box sizes and offsets according the regression parameters.
+ 5. Return the final boxes.
+ """
+
+ def __init__(
+ self,
+ box_decoder: None | DeltaXYWHBBoxDecoder = None,
+ score_threshold: float = 0.05,
+ iou_threshold: float = 0.5,
+ max_per_img: int = 100,
+ class_agnostic_nms: bool = False,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ box_decoder (DeltaXYWHBBoxDecoder, optional): Decodes regression
+ parameters to detected boxes. Defaults to None. If None, it
+ will use the default decoder.
+ score_threshold (float, optional): Minimum score of a detection.
+ Defaults to 0.05.
+ iou_threshold (float, optional): IoU threshold of NMS
+ post-processing step. Defaults to 0.5.
+ max_per_img (int, optional): Maximum number of detections per
+ image. Defaults to 100.
+ class_agnostic_nms (bool, optional): Whether to use class agnostic
+ NMS. Defaults to False.
+ """
+ super().__init__()
+ if box_decoder is None:
+ _, self.box_decoder = get_default_rcnn_box_codec()
+ else:
+ self.box_decoder = box_decoder
+ self.score_threshold = score_threshold
+ self.max_per_img = max_per_img
+ self.iou_threshold = iou_threshold
+ self.class_agnostic_nms = class_agnostic_nms
+
+ def forward(
+ self,
+ class_outs: torch.Tensor,
+ regression_outs: torch.Tensor,
+ boxes: list[torch.Tensor],
+ images_hw: list[tuple[int, int]],
+ ) -> DetOut:
+ """Convert RCNN network outputs to detections.
+
+ Args:
+ class_outs (torch.Tensor): [B, num_classes] batched tensor of
+ classifiation scores.
+ regression_outs (torch.Tensor): [B, num_classes * 4] predicted
+ box offsets.
+ boxes (list[torch.Tensor]): Initial boxes (RoIs).
+ images_hw (list[tuple[int, int]]): Image sizes.
+
+ Returns:
+ DetOut: boxes, scores and class ids of detections per image.
+ """
+ num_proposals_per_img = tuple(len(p) for p in boxes)
+ regression_outs = regression_outs.split(num_proposals_per_img, 0)
+ class_outs = class_outs.split(num_proposals_per_img, 0)
+ all_det_boxes = []
+ all_det_scores = []
+ all_det_class_ids = []
+ for cls_out, reg_out, boxs, image_hw in zip(
+ class_outs, regression_outs, boxes, images_hw
+ ):
+ scores = F.softmax(cls_out, dim=-1)
+ bboxes = bbox_clip(
+ self.box_decoder(boxs[:, :4], reg_out).view(-1, 4),
+ image_hw,
+ ).view(reg_out.shape)
+ det_bbox, det_scores, det_label, _ = multiclass_nms(
+ bboxes,
+ scores,
+ self.score_threshold,
+ self.iou_threshold,
+ self.max_per_img,
+ self.class_agnostic_nms,
+ )
+ all_det_boxes.append(det_bbox)
+ all_det_scores.append(det_scores)
+ all_det_class_ids.append(det_label)
+
+ return DetOut(
+ boxes=all_det_boxes,
+ scores=all_det_scores,
+ class_ids=all_det_class_ids,
+ )
+
+ def __call__(
+ self,
+ class_outs: torch.Tensor,
+ regression_outs: torch.Tensor,
+ boxes: list[torch.Tensor],
+ images_hw: list[tuple[int, int]],
+ ) -> DetOut:
+ """Type definition for function call."""
+ return self._call_impl(class_outs, regression_outs, boxes, images_hw)
+
+
+class RCNNTargets(NamedTuple):
+ """Target container."""
+
+ labels: Tensor
+ label_weights: Tensor
+ bbox_targets: Tensor
+ bbox_weights: Tensor
+
+
+class RCNNLosses(NamedTuple):
+ """RCNN loss container."""
+
+ rcnn_loss_cls: torch.Tensor
+ rcnn_loss_bbox: torch.Tensor
+
+
+class RCNNLoss(nn.Module):
+ """RCNN loss in Faster R-CNN.
+
+ This class computes the loss of RCNN given proposal boxes and their
+ corresponding target boxes with the given box encoder.
+ """
+
+ def __init__(
+ self,
+ box_encoder: DeltaXYWHBBoxEncoder,
+ num_classes: int = 80,
+ loss_cls: TorchLossFunc = F.cross_entropy,
+ loss_bbox: TorchLossFunc = l1_loss,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ box_encoder (DeltaXYWHBBoxEncoder): Decodes box regression
+ parameters into detected boxes.
+ num_classes (int, optional): number of object categories. Defaults
+ to 80.
+ loss_cls (TorchLossFunc, optional): Classification loss function.
+ Defaults to F.cross_entropy.
+ loss_bbox (TorchLossFunc, optional): Regression loss function.
+ Defaults to l1_loss.
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.box_encoder = box_encoder
+ self.loss_cls = loss_cls
+ self.loss_bbox = loss_bbox
+
+ def _get_targets_per_image(
+ self,
+ boxes: Tensor,
+ labels: Tensor,
+ target_boxes: Tensor,
+ target_classes: Tensor,
+ ) -> RCNNTargets:
+ """Generate targets per image.
+
+ Args:
+ boxes (Tensor): [N, 4] tensor of proposal boxes
+ labels (Tensor): [N,] tensor of positive / negative / ignore labels
+ target_boxes (Tensor): [N, 4] Assigned target boxes.
+ target_classes (Tensor): [N,] Assigned target class labels.
+
+ Returns:
+ RCNNTargets: Box / class label tensors and weights.
+ """
+ pos_mask, neg_mask = torch.eq(labels, 1), torch.eq(labels, 0)
+ num_pos, num_neg = int(pos_mask.sum()), int(neg_mask.sum())
+ num_samples = num_pos + num_neg
+
+ # original implementation uses new_zeros since BG are set to be 0
+ # now use empty & fill because BG cat_id = num_classes,
+ # FG cat_id = [0, num_classes-1]
+ labels = boxes.new_full(
+ (num_samples,), self.num_classes, dtype=torch.long
+ )
+ label_weights = boxes.new_zeros(num_samples)
+ box_targets = boxes.new_zeros(num_samples, 4)
+ box_weights = boxes.new_zeros(num_samples, 4)
+ if num_pos > 0:
+ pos_target_boxes = target_boxes[pos_mask]
+ pos_target_classes = target_classes[pos_mask]
+ labels[:num_pos] = pos_target_classes
+ label_weights[:num_pos] = 1.0
+ pos_box_targets = self.box_encoder(
+ boxes[pos_mask], pos_target_boxes
+ )
+ box_targets[:num_pos, :] = pos_box_targets
+ box_weights[:num_pos, :] = 1
+ if num_neg > 0:
+ label_weights[-num_neg:] = 1.0
+ return RCNNTargets(labels, label_weights, box_targets, box_weights)
+
+ def forward(
+ self,
+ class_outs: torch.Tensor,
+ regression_outs: torch.Tensor,
+ boxes: list[torch.Tensor],
+ boxes_mask: list[torch.Tensor],
+ target_boxes: list[torch.Tensor],
+ target_classes: list[torch.Tensor],
+ ) -> RCNNLosses:
+ """Calculate losses of RCNN head.
+
+ Args:
+ class_outs (torch.Tensor): [M*B, num_classes] classification
+ outputs.
+ regression_outs (torch.Tensor): Tensor[M*B, regression_params]
+ regression outputs.
+ boxes (list[torch.Tensor]): [M, 4] proposal boxes per batch
+ element.
+ boxes_mask (list[torch.Tensor]): positive (1), ignore (-1),
+ negative (0).
+ target_boxes (list[torch.Tensor]): list of [M, 4] assigned target
+ boxes for each proposal.
+ target_classes (list[torch.Tensor]): list of [M,] assigned target
+ classes for each proposal.
+
+ Returns:
+ RCNNLosses: classification and regression losses.
+ """
+ # get targets
+ targets = []
+ for boxs, boxs_mask, tgt_boxs, tgt_cls in zip(
+ boxes, boxes_mask, target_boxes, target_classes
+ ):
+ targets.append(
+ self._get_targets_per_image(boxs, boxs_mask, tgt_boxs, tgt_cls)
+ )
+
+ labels = torch.cat([tgt.labels for tgt in targets], 0)
+ label_weights = torch.cat([tgt.label_weights for tgt in targets], 0)
+ bbox_targets = torch.cat([tgt.bbox_targets for tgt in targets], 0)
+ bbox_weights = torch.cat([tgt.bbox_weights for tgt in targets], 0)
+
+ # compute losses
+ avg_factor = torch.sum(torch.greater(label_weights, 0)).clamp(1.0)
+ if class_outs.numel() > 0:
+ loss_cls = SumWeightedLoss(label_weights, avg_factor)(
+ self.loss_cls(class_outs, labels, reduction="none")
+ )
+ else:
+ loss_cls = class_outs.sum()
+
+ bg_class_ind = self.num_classes
+ # 0~self.num_classes-1 are FG, self.num_classes is BG
+ pos_inds = torch.logical_and(
+ torch.greater_equal(labels, 0), torch.less(labels, bg_class_ind)
+ )
+ # do not perform bounding box regression for BG anymore.
+ if pos_inds.any():
+ pos_reg_outs = regression_outs.view(
+ regression_outs.size(0), -1, 4
+ )[pos_inds.type(torch.bool), labels[pos_inds.type(torch.bool)]]
+ loss_bbox = self.loss_bbox(
+ pred=pos_reg_outs,
+ target=bbox_targets[pos_inds.type(torch.bool)],
+ reducer=SumWeightedLoss(
+ bbox_weights[pos_inds.type(torch.bool)],
+ bbox_targets.size(0),
+ ),
+ )
+ else:
+ loss_bbox = regression_outs[pos_inds].sum()
+
+ return RCNNLosses(rcnn_loss_cls=loss_cls, rcnn_loss_bbox=loss_bbox)
diff --git a/vis4d/op/detect/retinanet.py b/vis4d/op/detect/retinanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..28e847c56336fb1a62b308e93d8a015a89069481
--- /dev/null
+++ b/vis4d/op/detect/retinanet.py
@@ -0,0 +1,410 @@
+"""RetinaNet."""
+
+from __future__ import annotations
+
+from math import prod
+from typing import NamedTuple
+
+import torch
+from torch import nn
+from torchvision.ops import batched_nms, sigmoid_focal_loss
+
+from vis4d.common.typing import TorchLossFunc
+from vis4d.op.box.anchor import AnchorGenerator
+from vis4d.op.box.box2d import bbox_clip, filter_boxes_by_area
+from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder, DeltaXYWHBBoxEncoder
+from vis4d.op.box.matchers import Matcher, MaxIoUMatcher
+from vis4d.op.box.samplers import PseudoSampler, Sampler
+from vis4d.op.loss.common import l1_loss
+
+from .common import DetOut
+from .dense_anchor import DenseAnchorHeadLoss
+
+
+class RetinaNetOut(NamedTuple):
+ """RetinaNet head outputs."""
+
+ # Logits for box classification for each feature level. The logit
+ # dimention is [batch_size, number of anchors * number of classes, height,
+ # width].
+ cls_score: list[torch.Tensor]
+ # Each box has regression for all classes for each feature level. So the
+ # tensor dimension is [batch_size, number of anchors * 4, height, width].
+ bbox_pred: list[torch.Tensor]
+
+
+def get_default_anchor_generator() -> AnchorGenerator:
+ """Get default anchor generator."""
+ return AnchorGenerator(
+ octave_base_scale=4,
+ scales_per_octave=3,
+ ratios=[0.5, 1.0, 2.0],
+ strides=[8, 16, 32, 64, 128],
+ )
+
+
+def get_default_box_codec() -> (
+ tuple[DeltaXYWHBBoxEncoder, DeltaXYWHBBoxDecoder]
+):
+ """Get the default bounding box encoder."""
+ return (
+ DeltaXYWHBBoxEncoder(
+ target_means=(0.0, 0.0, 0.0, 0.0), target_stds=(1.0, 1.0, 1.0, 1.0)
+ ),
+ DeltaXYWHBBoxDecoder(
+ target_means=(0.0, 0.0, 0.0, 0.0), target_stds=(1.0, 1.0, 1.0, 1.0)
+ ),
+ )
+
+
+def get_default_box_matcher() -> MaxIoUMatcher:
+ """Get default bounding box matcher."""
+ return MaxIoUMatcher(
+ thresholds=[0.4, 0.5],
+ labels=[0, -1, 1],
+ allow_low_quality_matches=True,
+ )
+
+
+def get_default_box_sampler() -> PseudoSampler:
+ """Get default bounding box sampler."""
+ return PseudoSampler()
+
+
+class RetinaNetHead(nn.Module): # TODO: Refactor to use the new API
+ """RetinaNet Head."""
+
+ def __init__(
+ self,
+ num_classes: int,
+ in_channels: int,
+ feat_channels: int = 256,
+ stacked_convs: int = 4,
+ use_sigmoid_cls: bool = True,
+ anchor_generator: AnchorGenerator | None = None,
+ box_decoder: DeltaXYWHBBoxDecoder | None = None,
+ box_matcher: Matcher | None = None,
+ box_sampler: Sampler | None = None,
+ ):
+ """Creates an instance of the class."""
+ super().__init__()
+ self.anchor_generator = (
+ anchor_generator
+ if anchor_generator is not None
+ else get_default_anchor_generator()
+ )
+ if box_decoder is None:
+ _, self.box_decoder = get_default_box_codec()
+ else:
+ self.box_decoder = box_decoder
+ self.box_matcher = (
+ box_matcher
+ if box_matcher is not None
+ else get_default_box_matcher()
+ )
+ self.box_sampler = (
+ box_sampler
+ if box_sampler is not None
+ else get_default_box_sampler()
+ )
+ num_base_priors = self.anchor_generator.num_base_priors[0]
+
+ if use_sigmoid_cls:
+ cls_out_channels = num_classes
+ else:
+ cls_out_channels = num_classes + 1
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(stacked_convs):
+ chn = in_channels if i == 0 else feat_channels
+ self.cls_convs.append(
+ nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1),
+ )
+ self.reg_convs.append(
+ nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1),
+ )
+ self.retina_cls = nn.Conv2d(
+ feat_channels, num_base_priors * cls_out_channels, 3, padding=1
+ )
+ self.retina_reg = nn.Conv2d(
+ feat_channels, num_base_priors * 4, 3, padding=1
+ )
+
+ def forward(self, features: list[torch.Tensor]) -> RetinaNetOut:
+ """Forward pass of RetinaNet.
+
+ Args:
+ features (list[torch.Tensor]): Feature pyramid
+
+ Returns:
+ RetinaNetOut: classification score and box prediction.
+ """
+ cls_scores, bbox_preds = [], []
+ for feat in features:
+ cls_feat = feat
+ reg_feat = feat
+ for cls_conv in self.cls_convs:
+ cls_feat = self.relu(cls_conv(cls_feat))
+ for reg_conv in self.reg_convs:
+ reg_feat = self.relu(reg_conv(reg_feat))
+ cls_scores.append(self.retina_cls(cls_feat))
+ bbox_preds.append(self.retina_reg(reg_feat))
+ return RetinaNetOut(cls_score=cls_scores, bbox_pred=bbox_preds)
+
+ def __call__(self, features: list[torch.Tensor]) -> RetinaNetOut:
+ """Type definition for call implementation."""
+ return self._call_impl(features)
+
+
+def get_params_per_level(
+ cls_out: torch.Tensor,
+ reg_out: torch.Tensor,
+ anchors: torch.Tensor,
+ num_pre_nms: int = 2000,
+ score_thr: float = 0.0,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Get topk params from feature output per level per image before nms.
+
+ Params include flattened classification scores, box energies, and
+ corresponding anchors.
+
+ Args:
+ cls_out (torch.Tensor):
+ [C, H, W] classification scores at a particular scale.
+ reg_out (torch.Tensor):
+ [C, H, W] regression parameters at a particular scale.
+ anchors (torch.Tensor): [H * W, 4] anchor boxes per cell.
+ num_pre_nms (int): number of predictions before nms.
+ score_thr (float): score threshold for filtering predictions.
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: topk
+ flattened classification, regression outputs, and corresponding
+ anchors.
+ """
+ assert cls_out.size()[-2:] == reg_out.size()[-2:], (
+ f"Shape mismatch: cls_out({cls_out.size()[-2:]}), reg_out("
+ f"{reg_out.size()[-2:]})."
+ )
+ reg_out = reg_out.permute(1, 2, 0).reshape(-1, 4)
+ cls_out = cls_out.permute(1, 2, 0).reshape(reg_out.size(0), -1).sigmoid()
+ valid_mask = torch.greater(cls_out, score_thr)
+ valid_idxs = torch.nonzero(valid_mask)
+ num_topk = min(num_pre_nms, valid_idxs.size(0))
+ cls_out_filt = cls_out[valid_mask]
+ cls_out_ranked, rank_inds = cls_out_filt.sort(descending=True)
+ topk_inds = valid_idxs[rank_inds[:num_topk]]
+ keep_inds, labels = topk_inds.unbind(dim=1)
+ cls_out = cls_out_ranked[:num_topk]
+ reg_out = reg_out[keep_inds, :]
+ anchors = anchors[keep_inds, :]
+
+ return cls_out, labels, reg_out, anchors
+
+
+def decode_multi_level_outputs(
+ cls_out_all: list[torch.Tensor],
+ lbl_out_all: list[torch.Tensor],
+ reg_out_all: list[torch.Tensor],
+ anchors_all: list[torch.Tensor],
+ image_hw: tuple[int, int],
+ box_decoder: DeltaXYWHBBoxDecoder,
+ max_per_img: int = 1000,
+ nms_threshold: float = 0.7,
+ min_box_size: tuple[int, int] = (0, 0),
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Decode box energies into detections for a single image.
+
+ Detections are post-processed via NMS. NMS is performed per level.
+ Afterwards, select topk detections.
+
+ Args:
+ cls_out_all (list[torch.Tensor]): topk class scores per level.
+ lbl_out_all (list[torch.Tensor]): topk class labels per level.
+ reg_out_all (list[torch.Tensor]): topk regression params per level.
+ anchors_all (list[torch.Tensor]): topk anchor boxes per level.
+ image_hw (tuple[int, int]): image size.
+ box_decoder (DeltaXYWHBBoxDecoder): bounding box encoder.
+ max_per_img (int, optional): maximum predictions per image.
+ Defaults to 1000.
+ nms_threshold (float, optional): iou threshold for NMS.
+ Defaults to 0.7.
+ min_box_size (tuple[int, int], optional): minimum box size.
+ Defaults to (0, 0).
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: decoded proposal boxes & scores.
+ """
+ scores, labels = torch.cat(cls_out_all), torch.cat(lbl_out_all)
+ boxes = bbox_clip(
+ box_decoder(torch.cat(anchors_all), torch.cat(reg_out_all)),
+ image_hw,
+ )
+
+ boxes, mask = filter_boxes_by_area(boxes, min_area=prod(min_box_size))
+ scores, labels = scores[mask], labels[mask]
+
+ if boxes.numel() > 0:
+ keep = batched_nms(boxes, scores, labels, iou_threshold=nms_threshold)[
+ :max_per_img
+ ]
+ return boxes[keep], scores[keep], labels[keep]
+ return (boxes.new_zeros(0, 4), scores.new_zeros(0), labels.new_zeros(0))
+
+
+class Dense2Det(nn.Module):
+ """Compute detections from dense network outputs.
+
+ This class acts as a stateless functor that does the following:
+ 1. Create anchor grid for feature grids (classification and regression
+ outputs) at all scales.
+ For each image
+ For each level
+ 2. Get a topk pre-selection of flattened classification scores and
+ box energies from feature output before NMS.
+ 3. Decode class scores and box energies into detection boxes,
+ apply NMS.
+ Return detection boxes for all images.
+ """
+
+ def __init__(
+ self,
+ anchor_generator: AnchorGenerator,
+ box_decoder: DeltaXYWHBBoxDecoder,
+ num_pre_nms: int = 2000,
+ max_per_img: int = 1000,
+ nms_threshold: float = 0.7,
+ min_box_size: tuple[int, int] = (0, 0),
+ score_thr: float = 0.0,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.anchor_generator = anchor_generator
+ self.box_decoder = box_decoder
+ self.num_pre_nms = num_pre_nms
+ self.max_per_img = max_per_img
+ self.nms_threshold = nms_threshold
+ self.min_box_size = min_box_size
+ self.score_thr = score_thr
+
+ def forward(
+ self,
+ cls_outs: list[torch.Tensor],
+ reg_outs: list[torch.Tensor],
+ images_hw: list[tuple[int, int]],
+ ) -> DetOut:
+ """Compute detections from dense network outputs.
+
+ Generate anchor grid for all scales.
+ For each batch element:
+ Compute classification, regression, and anchor pairs for all
+ scales. Decode those pairs into proposals, post-process with NMS.
+
+ Args:
+ cls_outs (list[torch.Tensor]): [N, C * A, H, W] per scale.
+ reg_outs (list[torch.Tensor]): [N, 4 * A, H, W] per scale.
+ images_hw (list[tuple[int, int]]): list of image sizes.
+
+ Returns:
+ DetOut: Detection outputs.
+ """
+ # since feature map sizes of all images are the same, we only compute
+ # anchors for one time
+ device = cls_outs[0].device
+ featmap_sizes: list[tuple[int, int]] = [
+ featmap.size()[-2:] for featmap in cls_outs # type: ignore
+ ]
+ assert len(featmap_sizes) == self.anchor_generator.num_levels
+ anchor_grids = self.anchor_generator.grid_priors(
+ featmap_sizes, device=device
+ )
+ proposals, scores, labels = [], [], []
+ for img_id, image_hw in enumerate(images_hw):
+ cls_out_all, lbl_out_all, reg_out_all, anchors_all = [], [], [], []
+ for cls_out, reg_out, anchor_grid in zip(
+ cls_outs, reg_outs, anchor_grids
+ ):
+ cls_out_, lbl_out, reg_out_, anchors = get_params_per_level(
+ cls_out[img_id],
+ reg_out[img_id],
+ anchor_grid,
+ self.num_pre_nms,
+ self.score_thr,
+ )
+ cls_out_all += [cls_out_]
+ lbl_out_all += [lbl_out]
+ reg_out_all += [reg_out_]
+ anchors_all += [anchors]
+
+ box, score, label = decode_multi_level_outputs(
+ cls_out_all,
+ lbl_out_all,
+ reg_out_all,
+ anchors_all,
+ image_hw,
+ self.box_decoder,
+ self.max_per_img,
+ self.nms_threshold,
+ self.min_box_size,
+ )
+ proposals.append(box)
+ scores.append(score)
+ labels.append(label)
+ return DetOut(proposals, scores, labels)
+
+ def __call__(
+ self,
+ cls_outs: list[torch.Tensor],
+ reg_outs: list[torch.Tensor],
+ images_hw: list[tuple[int, int]],
+ ) -> DetOut:
+ """Type definition for function call."""
+ return self._call_impl(cls_outs, reg_outs, images_hw)
+
+
+class RetinaNetHeadLoss(DenseAnchorHeadLoss):
+ """Loss of RetinaNet head."""
+
+ def __init__(
+ self,
+ anchor_generator: AnchorGenerator,
+ box_encoder: DeltaXYWHBBoxEncoder,
+ box_matcher: None | Matcher = None,
+ box_sampler: None | Sampler = None,
+ loss_cls: TorchLossFunc = sigmoid_focal_loss,
+ loss_bbox: TorchLossFunc = l1_loss,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ anchor_generator (AnchorGenerator): Generates anchor grid priors.
+ box_encoder (DeltaXYWHBBoxEncoder): Encodes bounding boxes to the
+ desired network output.
+ box_matcher (None | Matcher, optional): Box matcher. Defaults to
+ None.
+ box_sampler (None | Sampler, optional): Box sampler. Defaults to
+ None.
+ loss_cls (TorchLossFunc, optional): Classification loss function.
+ Defaults to sigmoid_focal_loss.
+ loss_bbox (TorchLossFunc, optional): Regression loss function.
+ Defaults to l1_loss.
+ """
+ matcher = (
+ box_matcher
+ if box_matcher is not None
+ else get_default_box_matcher()
+ )
+ sampler = (
+ box_sampler
+ if box_sampler is not None
+ else get_default_box_sampler()
+ )
+ super().__init__(
+ anchor_generator,
+ box_encoder,
+ matcher,
+ sampler,
+ loss_cls,
+ loss_bbox,
+ )
diff --git a/vis4d/op/detect/rpn.py b/vis4d/op/detect/rpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae646d47b2473588f96a82c25be0db202aa9f13a
--- /dev/null
+++ b/vis4d/op/detect/rpn.py
@@ -0,0 +1,421 @@
+"""Faster RCNN RPN Head."""
+
+from __future__ import annotations
+
+from math import prod
+from typing import NamedTuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torchvision.ops import batched_nms
+
+from vis4d.common.typing import TorchLossFunc
+from vis4d.op.box.anchor import AnchorGenerator
+from vis4d.op.box.box2d import bbox_clip, filter_boxes_by_area
+from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder, DeltaXYWHBBoxEncoder
+from vis4d.op.box.matchers import Matcher, MaxIoUMatcher
+from vis4d.op.box.samplers import RandomSampler, Sampler
+from vis4d.op.layer.conv2d import Conv2d
+from vis4d.op.loss.common import l1_loss
+
+from .dense_anchor import DenseAnchorHeadLoss, DenseAnchorHeadLosses
+from .typing import Proposals
+
+
+class RPNOut(NamedTuple):
+ """Output of RPN head."""
+
+ # Sigmoid input for binary classification of the anchor
+ # Positive means there is an object in that anchor.
+ # Each list item is for on feature pyramid level.
+ cls: list[torch.Tensor]
+ # 4 x number of anchors for center offets and sizes (width, height) of the
+ # boxes under the anchor.
+ # Each list item is for on feature pyramid level.
+ box: list[torch.Tensor]
+
+
+def get_default_rpn_box_codec(
+ target_means: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0),
+ target_stds: tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0),
+) -> tuple[DeltaXYWHBBoxEncoder, DeltaXYWHBBoxDecoder]:
+ """Get the default bounding box encoder and decoder for RPN."""
+ return (
+ DeltaXYWHBBoxEncoder(target_means, target_stds),
+ DeltaXYWHBBoxDecoder(target_means, target_stds),
+ )
+
+
+class RPNHead(nn.Module):
+ """Faster RCNN RPN Head.
+
+ Creates RPN network output from a multi-scale feature map input.
+ """
+
+ rpn_conv: nn.Module
+
+ def __init__(
+ self,
+ num_anchors: int,
+ num_convs: int = 1,
+ in_channels: int = 256,
+ feat_channels: int = 256,
+ start_level: int = 2,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ num_anchors (int): Number of anchors per cell.
+ num_convs (int, optional): Number of conv layers before RPN heads.
+ Defaults to 1.
+ in_channels (int, optional): Feature channel size of input feature
+ maps. Defaults to 256.
+ feat_channels (int, optional): Feature channel size of conv layers.
+ Defaults to 256.
+ start_level (int, optional): starting level of feature maps.
+ Defaults to 2.
+ """
+ super().__init__()
+ self.start_level = start_level
+
+ if num_convs > 1:
+ rpn_convs = []
+ for i in range(num_convs):
+ if i > 0:
+ in_channels = feat_channels
+ rpn_convs.append(
+ Conv2d(
+ in_channels,
+ feat_channels,
+ kernel_size=3,
+ padding=1,
+ activation=nn.ReLU(inplace=False),
+ )
+ )
+ self.rpn_conv = nn.Sequential(*rpn_convs)
+ else:
+ self.rpn_conv = Conv2d(
+ in_channels,
+ feat_channels,
+ kernel_size=3,
+ padding=1,
+ activation=nn.ReLU(inplace=True),
+ )
+ self.rpn_cls = Conv2d(feat_channels, num_anchors, 1)
+ self.rpn_box = Conv2d(feat_channels, num_anchors * 4, 1)
+
+ self.apply(self._init_weights)
+
+ @staticmethod
+ def _init_weights(module: nn.Module) -> None:
+ """Init RPN weights."""
+ if isinstance(module, nn.Conv2d):
+ module.weight.data.normal_(mean=0.0, std=0.01)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+ def forward(self, features: list[torch.Tensor]) -> RPNOut:
+ """Forward pass of RPN."""
+ cls_outs, box_outs = [], []
+ for feat in features[self.start_level :]:
+ feat = self.rpn_conv(feat)
+ cls_outs += [self.rpn_cls(feat)]
+ box_outs += [self.rpn_box(feat)]
+ return RPNOut(cls=cls_outs, box=box_outs)
+
+ def __call__(self, features: list[torch.Tensor]) -> RPNOut:
+ """Type definition."""
+ return self._call_impl(features)
+
+
+class RPN2RoI(nn.Module):
+ """Generate Proposals (RoIs) from RPN network output.
+
+ This class acts as a stateless functor that does the following:
+ 1. Create anchor grid for feature grids (classification and regression
+ outputs) at all scales.
+ For each image
+ For each level
+ 2. Get a topk pre-selection of flattened classification scores and
+ box energies from feature output before NMS.
+ 3. Decode class scores and box energies into proposal boxes, apply NMS.
+ Return proposal boxes for all images.
+ """
+
+ def __init__(
+ self,
+ anchor_generator: AnchorGenerator,
+ box_decoder: None | DeltaXYWHBBoxDecoder = None,
+ num_proposals_pre_nms_train: int = 2000,
+ num_proposals_pre_nms_test: int = 1000,
+ max_per_img: int = 1000,
+ proposal_nms_threshold: float = 0.7,
+ min_proposal_size: tuple[int, int] = (0, 0),
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ anchor_generator (AnchorGenerator): Creates anchor grid serving as
+ for bounding box regression.
+ box_decoder (DeltaXYWHBBoxDecoder, optional): decodes box energies
+ predicted by the network into 2D bounding box parameters.
+ Defaults to None. If None, uses the default decoder.
+ num_proposals_pre_nms_train (int, optional): How many boxes are
+ kept prior to NMS during training. Defaults to 2000.
+ num_proposals_pre_nms_test (int, optional): How many boxes are
+ kept prior to NMS during inference. Defaults to 1000.
+ max_per_img (int, optional): Maximum boxes per image.
+ Defaults to 1000.
+ proposal_nms_threshold (float, optional): NMS threshold on proposal
+ boxes. Defaults to 0.7.
+ min_proposal_size (tuple[int, int], optional): Minimum size of a
+ proposal box. Defaults to (0, 0).
+ """
+ super().__init__()
+ self.anchor_generator = anchor_generator
+
+ if box_decoder is None:
+ _, self.box_decoder = get_default_rpn_box_codec()
+ else:
+ self.box_decoder = box_decoder
+
+ self.max_per_img = max_per_img
+ self.min_proposal_size = min_proposal_size
+ self.num_proposals_pre_nms_train = num_proposals_pre_nms_train
+ self.num_proposals_pre_nms_test = num_proposals_pre_nms_test
+ self.proposal_nms_threshold = proposal_nms_threshold
+
+ def _get_params_per_level(
+ self,
+ cls_out: torch.Tensor,
+ reg_out: torch.Tensor,
+ anchors: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Get a topk pre-selection of parameters.
+
+ The parameters include flattened classification scores and box
+ energies from feature output per level per image before nms.
+
+ Args:
+ cls_out (torch.Tensor): [C, H, W] classification scores at a
+ particular scale.
+ reg_out (torch.Tensor): [C, H, W] regression parameters at a
+ particular scale.
+ anchors (torch.Tensor): [H*W, 4] anchor boxes per cell.
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Topk flattened
+ classification, regression outputs and corresponding anchors.
+ """
+ assert cls_out.size()[-2:] == reg_out.size()[-2:], (
+ f"Shape mismatch: cls_out({cls_out.size()[-2:]}), reg_out("
+ f"{reg_out.size()[-2:]})."
+ )
+ cls_out = cls_out.permute(1, 2, 0).reshape(-1).sigmoid()
+ reg_out = reg_out.permute(1, 2, 0).reshape(-1, 4)
+ if self.training:
+ num_proposals_pre_nms = self.num_proposals_pre_nms_train
+ else:
+ num_proposals_pre_nms = self.num_proposals_pre_nms_test
+
+ if 0 < num_proposals_pre_nms < cls_out.shape[0]:
+ cls_out_ranked, rank_inds = cls_out.sort(descending=True)
+ topk_inds = rank_inds[:num_proposals_pre_nms]
+ cls_out = cls_out_ranked[:num_proposals_pre_nms]
+ reg_out = reg_out[topk_inds, :]
+ anchors = anchors[topk_inds, :]
+
+ return cls_out, reg_out, anchors
+
+ def _decode_multi_level_outputs(
+ self,
+ cls_out_all: list[torch.Tensor],
+ reg_out_all: list[torch.Tensor],
+ anchors_all: list[torch.Tensor],
+ level_all: list[torch.Tensor],
+ image_hw: tuple[int, int],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Decode box energies into proposals for a single image, post-process.
+
+ Post-processing happens via NMS. NMS is performed per level.
+ Afterwards, select topk proposals.
+
+ Args:
+ cls_out_all (list[torch.Tensor]): topk class scores per level.
+ reg_out_all (list[torch.Tensor]): topk regression params per level.
+ anchors_all (list[torch.Tensor]): topk anchor boxes per level.
+ level_all (list[torch.Tensor]): tensors indicating level per entry.
+ image_hw (tuple[int, int]): image size.
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: decoded proposal boxes & scores.
+ """
+ scores = torch.cat(cls_out_all)
+ levels = torch.cat(level_all)
+
+ proposals = bbox_clip(
+ self.box_decoder(torch.cat(anchors_all), torch.cat(reg_out_all)),
+ image_hw,
+ )
+
+ proposals, mask = filter_boxes_by_area(
+ proposals, min_area=prod(self.min_proposal_size)
+ )
+ scores = scores[mask]
+ levels = levels[mask]
+
+ if proposals.numel() > 0:
+ keep = batched_nms(
+ proposals,
+ scores,
+ levels,
+ iou_threshold=self.proposal_nms_threshold,
+ )[: self.max_per_img]
+ proposals = proposals[keep]
+ scores = scores[keep]
+ else: # pragma: no cover
+ return proposals.new_zeros(0, 4), scores.new_zeros(0)
+ return proposals, scores
+
+ def forward(
+ self,
+ class_outs: list[torch.Tensor],
+ regression_outs: list[torch.Tensor],
+ images_hw: list[tuple[int, int]],
+ ) -> Proposals:
+ """Compute proposals from RPN network outputs.
+
+ Generate anchor grid for all scales.
+ For each batch element:
+ Compute classification, regression and anchor pairs for all scales.
+ Decode those pairs into proposals, post-process with NMS.
+
+ Args:
+ class_outs (list[torch.Tensor]): [N, 1 * A, H, W] per scale.
+ regression_outs (list[torch.Tensor]): [N, 4 * A, H, W] per scale.
+ images_hw (list[tuple[int, int]]): list of image sizes.
+
+ Returns:
+ Proposals: proposal boxes and scores.
+ """
+ # since feature map sizes of all images are the same, we only compute
+ # anchors for one time
+ device = class_outs[0].device
+ featmap_sizes: list[tuple[int, int]] = [
+ featmap.size()[-2:] for featmap in class_outs # type: ignore
+ ]
+ assert len(featmap_sizes) == self.anchor_generator.num_levels
+ anchor_grids = self.anchor_generator.grid_priors(
+ featmap_sizes, device=device
+ )
+ proposals, scores = [], []
+ for img_id, image_hw in enumerate(images_hw):
+ cls_out_all, reg_out_all, anchors_all, level_all = [], [], [], []
+ for level, (cls_outs, reg_outs, anchor_grid) in enumerate(
+ zip(class_outs, regression_outs, anchor_grids)
+ ):
+ cls_out, reg_out, anchors = self._get_params_per_level(
+ cls_outs[img_id], reg_outs[img_id], anchor_grid
+ )
+ cls_out_all += [cls_out]
+ reg_out_all += [reg_out]
+ anchors_all += [anchors]
+ level_all += [
+ cls_out.new_full((len(cls_out),), level, dtype=torch.long)
+ ]
+
+ box, score = self._decode_multi_level_outputs(
+ cls_out_all, reg_out_all, anchors_all, level_all, image_hw
+ )
+ proposals.append(box)
+ scores.append(score)
+ return Proposals(proposals, scores)
+
+
+class RPNLosses(NamedTuple):
+ """RPN loss container."""
+
+ rpn_loss_cls: torch.Tensor
+ rpn_loss_bbox: torch.Tensor
+
+
+class RPNLoss(DenseAnchorHeadLoss):
+ """Loss of region proposal network."""
+
+ def __init__(
+ self,
+ anchor_generator: AnchorGenerator,
+ box_encoder: DeltaXYWHBBoxEncoder,
+ matcher: Matcher | None = None,
+ sampler: Sampler | None = None,
+ loss_cls: TorchLossFunc = F.binary_cross_entropy_with_logits,
+ loss_bbox: TorchLossFunc = l1_loss,
+ ):
+ """Creates an instance of the class.
+
+ Args:
+ anchor_generator (AnchorGenerator): Generates anchor grid priors.
+ box_encoder (DeltaXYWHBBoxEncoder): Encodes bounding boxes to the
+ desired network output.
+ matcher (Matcher): Matches ground truth boxes to anchor grid
+ priors. Defaults to None. If None, uses MaxIoUMatcher.
+ sampler (Sampler): Samples anchors for training. Defaults to None.
+ If None, uses RandomSampler.
+ loss_cls (TorchLossFunc): Classification loss function. Defaults to
+ F.binary_cross_entropy_with_logits.
+ loss_bbox (TorchLossFunc): Regression loss function. Defaults to
+ l1_loss.
+ """
+ matcher = (
+ MaxIoUMatcher(
+ thresholds=[0.3, 0.7],
+ labels=[0, -1, 1],
+ allow_low_quality_matches=True,
+ min_positive_iou=0.3,
+ )
+ if matcher is None
+ else matcher
+ )
+
+ sampler = (
+ RandomSampler(batch_size=256, positive_fraction=0.5)
+ if sampler is None
+ else sampler
+ )
+
+ super().__init__(
+ anchor_generator,
+ box_encoder,
+ matcher,
+ sampler,
+ loss_cls,
+ loss_bbox,
+ )
+
+ def forward(
+ self,
+ cls_outs: list[torch.Tensor],
+ reg_outs: list[torch.Tensor],
+ target_boxes: list[torch.Tensor],
+ images_hw: list[tuple[int, int]],
+ target_class_ids: list[torch.Tensor | float] | None = None,
+ ) -> DenseAnchorHeadLosses:
+ """Compute RPN classification and regression losses.
+
+ Args:
+ cls_outs (list[torch.Tensor]): Network classification outputs
+ at all scales.
+ reg_outs (list[torch.Tensor]): Network regression outputs
+ at all scales.
+ target_boxes (list[torch.Tensor]): Target bounding boxes.
+ images_hw (list[tuple[int, int]]): Image dimensions
+ without padding.
+ target_class_ids (list[torch.Tensor] | None): Target class labels.
+
+ Returns:
+ DenseAnchorHeadLosses: Classification and regression losses.
+ """
+ return super().forward(
+ cls_outs, reg_outs, target_boxes, images_hw, target_class_ids
+ )
diff --git a/vis4d/op/detect/typing.py b/vis4d/op/detect/typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c9dc21d21b22a2f804e4e538fbe283566db9768
--- /dev/null
+++ b/vis4d/op/detect/typing.py
@@ -0,0 +1,22 @@
+"""Detect op typing."""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+from torch import Tensor
+
+
+class Proposals(NamedTuple):
+ """Output structure for 2D bounding box proposals."""
+
+ boxes: list[Tensor]
+ scores: list[Tensor]
+
+
+class Targets(NamedTuple):
+ """Output structure for targets."""
+
+ boxes: list[Tensor]
+ classes: list[Tensor]
+ labels: list[Tensor]
diff --git a/vis4d/op/detect/yolox.py b/vis4d/op/detect/yolox.py
new file mode 100644
index 0000000000000000000000000000000000000000..269565db94f29ac85ff651d3e2d700b7d22d2e41
--- /dev/null
+++ b/vis4d/op/detect/yolox.py
@@ -0,0 +1,714 @@
+"""YOLOX detection head.
+
+Modified from mmdetection (https://github.com/open-mmlab/mmdetection).
+"""
+
+from __future__ import annotations
+
+import math
+from collections.abc import Sequence
+from typing import NamedTuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torchvision.ops import batched_nms
+
+from vis4d.common.distributed import reduce_mean
+from vis4d.common.typing import TorchLossFunc
+from vis4d.op.box.anchor import MlvlPointGenerator
+from vis4d.op.box.encoder import YOLOXBBoxDecoder
+from vis4d.op.box.matchers import SimOTAMatcher
+from vis4d.op.box.samplers import PseudoSampler
+from vis4d.op.layer.conv2d import Conv2d
+from vis4d.op.layer.weight_init import bias_init_with_prob
+from vis4d.op.loss import IoULoss
+from vis4d.op.loss.reducer import SumWeightedLoss
+
+from .common import DetOut
+
+
+class YOLOXOut(NamedTuple):
+ """YOLOX head outputs."""
+
+ # Logits for box classification for each feature level. The logit
+ # dimention is [batch_size, number of classes, height, width].
+ cls_score: list[torch.Tensor]
+ # Each box has regression for all classes for each feature level. So the
+ # tensor dimension is [batch_size, 4, height, width].
+ bbox_pred: list[torch.Tensor]
+ # Objectness scores for each feature level. The tensor dimension is
+ # [batch_size, 1, height, width]
+ objectness: list[torch.Tensor]
+
+
+def get_default_point_generator() -> MlvlPointGenerator:
+ """Get default point generator."""
+ return MlvlPointGenerator(strides=[8, 16, 32], offset=0)
+
+
+class YOLOXHead(nn.Module):
+ """YOLOX Head.
+
+ Args:
+ num_classes (int): Number of classes.
+ in_channels (int): Number of input channels.
+ feat_channels (int, optional): Number of feature channels. Defaults to
+ 256.
+ stacked_convs (int, optional): Number of stacked convolutions. Defaults
+ to 2.
+ strides (Sequence[int], optional): Strides for each feature level.
+ Defaults to (8, 16, 32).
+ point_generator (MlvlPointGenerator, optional): Point generator.
+ Defaults to None.
+ box_decoder (YOLOXBBoxDecoder, optional): Bounding box decoder.
+ Defaults to None.
+ box_matcher (Matcher, optional): Bounding box matcher. Defaults to
+ None.
+ box_sampler (Sampler, optional): Bounding box sampler. Defaults to
+ None.
+ """
+
+ def __init__(
+ self,
+ num_classes: int,
+ in_channels: int,
+ feat_channels: int = 256,
+ stacked_convs: int = 2,
+ strides: Sequence[int] = (8, 16, 32),
+ point_generator: MlvlPointGenerator | None = None,
+ box_decoder: YOLOXBBoxDecoder | None = None,
+ ):
+ """Creates an instance of the class."""
+ super().__init__()
+ self.point_generator = (
+ point_generator
+ if point_generator is not None
+ else get_default_point_generator()
+ )
+ if box_decoder is None:
+ self.box_decoder = YOLOXBBoxDecoder()
+ else:
+ self.box_decoder = box_decoder
+
+ self.multi_level_cls_convs = nn.ModuleList()
+ self.multi_level_reg_convs = nn.ModuleList()
+ self.multi_level_conv_cls = nn.ModuleList()
+ self.multi_level_conv_reg = nn.ModuleList()
+ self.multi_level_conv_obj = nn.ModuleList()
+ for _ in strides:
+ self.multi_level_cls_convs.append(
+ self._build_stacked_convs(
+ in_channels, feat_channels, stacked_convs
+ )
+ )
+ self.multi_level_reg_convs.append(
+ self._build_stacked_convs(
+ in_channels, feat_channels, stacked_convs
+ )
+ )
+ conv_cls, conv_reg, conv_obj = self._build_predictor(
+ feat_channels, num_classes
+ )
+ self.multi_level_conv_cls.append(conv_cls)
+ self.multi_level_conv_reg.append(conv_reg)
+ self.multi_level_conv_obj.append(conv_obj)
+ self._init_weights()
+
+ def _build_stacked_convs(
+ self, in_channels: int, feat_channels: int, stacked_convs: int
+ ) -> nn.Module:
+ """Initialize conv layers of a single level head.
+
+ Args:
+ in_channels (int): Number of input channels.
+ feat_channels (int): Number of feature channels.
+ stacked_convs (int): Number of stacked conv layers.
+ """
+ stacked_conv_layers = []
+ for i in range(stacked_convs):
+ chn = in_channels if i == 0 else feat_channels
+ stacked_conv_layers.append(
+ Conv2d(
+ chn,
+ feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ norm=nn.BatchNorm2d(
+ feat_channels, eps=0.001, momentum=0.03
+ ),
+ activation=nn.SiLU(inplace=True),
+ bias=False,
+ )
+ )
+ return nn.Sequential(*stacked_conv_layers)
+
+ def _build_predictor(
+ self, feat_channels: int, num_classes: int
+ ) -> tuple[nn.Module, nn.Module, nn.Module]:
+ """Initialize predictor layers of a single level head.
+
+ Args:
+ feat_channels (int): Number of input channels.
+ num_classes (int): Number of classes.
+ """
+ conv_cls = nn.Conv2d(feat_channels, num_classes, 1)
+ conv_reg = nn.Conv2d(feat_channels, 4, 1)
+ conv_obj = nn.Conv2d(feat_channels, 1, 1)
+ return conv_cls, conv_reg, conv_obj
+
+ def _init_weights(self) -> None:
+ """Initialize weights."""
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_uniform_(
+ m.weight,
+ a=math.sqrt(5),
+ mode="fan_in",
+ nonlinearity="leaky_relu",
+ )
+ bias_init = bias_init_with_prob(0.01)
+ for conv_cls, conv_obj in zip(
+ self.multi_level_conv_cls, self.multi_level_conv_obj
+ ):
+ conv_cls.bias.data.fill_(bias_init) # type: ignore
+ conv_obj.bias.data.fill_(bias_init) # type: ignore
+
+ def forward(self, features: list[torch.Tensor]) -> YOLOXOut:
+ """Forward pass of YOLOX head.
+
+ Args:
+ features (list[torch.Tensor]): Input features.
+
+ Returns:
+ YOLOXOut: Classification, box, and objectness predictions.
+ """
+ cls_score, bbox_pred, objectness = [], [], []
+ for feature, cls_conv, reg_conv, conv_cls, conv_reg, conv_obj in zip(
+ features,
+ self.multi_level_cls_convs,
+ self.multi_level_reg_convs,
+ self.multi_level_conv_cls,
+ self.multi_level_conv_reg,
+ self.multi_level_conv_obj,
+ ):
+ cls_feat = cls_conv(feature)
+ reg_feat = reg_conv(feature)
+
+ cls_score.append(conv_cls(cls_feat))
+ bbox_pred.append(conv_reg(reg_feat))
+ objectness.append(conv_obj(reg_feat))
+ return YOLOXOut(
+ cls_score=cls_score, bbox_pred=bbox_pred, objectness=objectness
+ )
+
+ def __call__(self, features: list[torch.Tensor]) -> YOLOXOut:
+ """Type definition for call implementation."""
+ return self._call_impl(features)
+
+
+def bboxes_nms(
+ cls_scores: torch.Tensor,
+ bboxes: torch.Tensor,
+ objectness: torch.Tensor,
+ nms_threshold: float = 0.65,
+ score_thr: float = 0.01,
+ nms_pre: int = -1,
+ max_per_img: int = -1,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Decode box energies into detections for a single image.
+
+ Detections are post-processed via NMS. NMS is performed per level.
+ Afterwards, select topk detections.
+
+ Args:
+ cls_scores (torch.Tensor): topk class scores per level.
+ bboxes (torch.Tensor): topk class labels per level.
+ objectness (torch.Tensor): topk regression params per level.
+ nms_threshold (float, optional): iou threshold for NMS.
+ Defaults to 0.65.
+ score_thr (float, optional): score threshold to filter detections.
+ Defaults to 0.01.
+ nms_pre (int, optional): number of topk results before NMS.
+ Defaults to -1 (all).
+ max_per_img (int, optional): number of topk results after NMS.
+ Defaults to -1 (all).
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]: decoded boxes, scores,
+ and labels.
+ """
+ if nms_pre == -1:
+ nms_pre = len(cls_scores)
+ if max_per_img == -1:
+ max_per_img = len(cls_scores)
+ max_scores, labels = torch.max(cls_scores, 1)
+ valid_mask = objectness * max_scores >= score_thr
+ valid_idxs = valid_mask.nonzero()[:, 0]
+ num_topk = min(nms_pre, valid_mask.sum()) # type: ignore
+
+ scores, idxs = (max_scores[valid_mask] * objectness[valid_mask]).sort(
+ descending=True
+ )
+ scores = scores[:num_topk]
+ topk_idxs = valid_idxs[idxs[:num_topk]]
+
+ bboxes = bboxes[topk_idxs]
+ labels = labels[topk_idxs]
+
+ if labels.numel() > 0:
+ keep = batched_nms(bboxes, scores, labels, nms_threshold)[:max_per_img]
+ return bboxes[keep], scores[keep], labels[keep]
+ return bboxes.new_zeros(0, 4), scores.new_zeros(0), labels.new_zeros(0)
+
+
+def preprocess_outputs(
+ cls_outs: list[torch.Tensor],
+ reg_outs: list[torch.Tensor],
+ obj_outs: list[torch.Tensor],
+ images_hw: list[tuple[int, int]],
+ point_generator: MlvlPointGenerator,
+ box_decoder: YOLOXBBoxDecoder,
+) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+ """Preprocess model outputs before postprocessing/loss computation.
+
+ Args:
+ cls_outs (list[torch.Tensor]): [N, C, H, W] per scale.
+ reg_outs (list[torch.Tensor]): [N, 4, H, W] per scale.
+ obj_outs (list[torch.Tensor]): [N, 1, H, W] per scale.
+ images_hw (list[tuple[int, int]]): List of image sizes.
+ point_generator (MlvlPointGenerator): Point generator.
+ box_decoder (YOLOXBBoxDecoder): Box decoder.
+
+ Returns:
+ tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: Flattened outputs.
+ """
+ dtype, device = cls_outs[0].dtype, cls_outs[0].device
+ num_imgs = len(images_hw)
+ num_classes = cls_outs[0].shape[1]
+ featmap_sizes: list[tuple[int, int]] = [
+ tuple(featmap.size()[-2:]) for featmap in cls_outs # type: ignore
+ ]
+ assert len(featmap_sizes) == point_generator.num_levels
+ mlvl_points = point_generator.grid_priors(
+ featmap_sizes, dtype=dtype, device=device, with_stride=True
+ )
+
+ # flatten cls_outs, reg_outs and obj_outs
+ cls_list = [
+ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, num_classes)
+ for cls_score in cls_outs
+ ]
+ reg_list = [
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
+ for bbox_pred in reg_outs
+ ]
+ obj_list = [
+ objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
+ for objectness in obj_outs
+ ]
+
+ flatten_cls = torch.cat(cls_list, dim=1)
+ flatten_reg = torch.cat(reg_list, dim=1)
+ flatten_obj = torch.cat(obj_list, dim=1)
+ flatten_points = torch.cat(mlvl_points)
+
+ flatten_boxes = box_decoder(flatten_points, flatten_reg)
+ return flatten_cls, flatten_reg, flatten_obj, flatten_points, flatten_boxes
+
+
+class YOLOXPostprocess(nn.Module):
+ """Postprocess detections from YOLOX detection head."""
+
+ def __init__(
+ self,
+ point_generator: MlvlPointGenerator,
+ box_decoder: YOLOXBBoxDecoder,
+ nms_threshold: float = 0.65,
+ score_thr: float = 0.01,
+ nms_pre: int = -1,
+ max_per_img: int = -1,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ point_generator (MlvlPointGenerator): Point generator.
+ box_decoder (YOLOXBBoxDecoder): Box decoder.
+ nms_threshold (float, optional): IoU threshold for NMS. Defaults to
+ 0.65.
+ score_thr (float, optional): Score threshold to filter detections.
+ Defaults to 0.01.
+ nms_pre (int, optional): Number of topk results before NMS.
+ Defaults to -1 (all).
+ max_per_img (int, optional): Number of topk results after NMS.
+ Defaults to -1 (all).
+ """
+ super().__init__()
+ self.point_generator = point_generator
+ self.box_decoder = box_decoder
+ self.nms_threshold = nms_threshold
+ self.score_thr = score_thr
+ self.nms_pre = nms_pre
+ self.max_per_img = max_per_img
+
+ def forward(
+ self,
+ cls_outs: list[torch.Tensor],
+ reg_outs: list[torch.Tensor],
+ obj_outs: list[torch.Tensor],
+ images_hw: list[tuple[int, int]],
+ ) -> DetOut:
+ """Forward pass.
+
+ Args:
+ cls_outs (list[torch.Tensor]): [N, C, H, W] per scale.
+ reg_outs (list[torch.Tensor]): [N, 4, H, W] per scale.
+ obj_outs (list[torch.Tensor]): [N, 1, H, W] per scale.
+ images_hw (list[tuple[int, int]]): list of image sizes.
+
+ Returns:
+ DetOut: Detection outputs.
+ """
+ flatten_cls, _, flatten_obj, _, flatten_boxes = preprocess_outputs(
+ cls_outs,
+ reg_outs,
+ obj_outs,
+ images_hw,
+ self.point_generator,
+ self.box_decoder,
+ )
+ flatten_cls, flatten_obj = flatten_cls.sigmoid(), flatten_obj.sigmoid()
+
+ bbox_list, score_list, label_list = [], [], []
+ for img_id, _ in enumerate(images_hw):
+ bboxes, scores, labels = bboxes_nms(
+ flatten_cls[img_id],
+ flatten_boxes[img_id],
+ flatten_obj[img_id],
+ nms_threshold=self.nms_threshold,
+ score_thr=self.score_thr,
+ nms_pre=self.nms_pre,
+ max_per_img=self.max_per_img,
+ )
+ bbox_list.append(bboxes)
+ score_list.append(scores)
+ label_list.append(labels)
+ return DetOut(bbox_list, score_list, label_list)
+
+ def __call__(
+ self,
+ cls_outs: list[torch.Tensor],
+ reg_outs: list[torch.Tensor],
+ obj_outs: list[torch.Tensor],
+ images_hw: list[tuple[int, int]],
+ ) -> DetOut:
+ """Type definition for function call."""
+ return self._call_impl(cls_outs, reg_outs, obj_outs, images_hw)
+
+
+class YOLOXHeadLosses(NamedTuple):
+ """YOLOX head loss container."""
+
+ loss_cls: Tensor
+ loss_bbox: Tensor
+ loss_obj: Tensor
+ loss_l1: Tensor | None
+
+
+def bbox_xyxy_to_cxcywh(bbox: Tensor) -> Tensor:
+ """Convert bbox coordinates from (x1, y1, x2, y2) to (cx, cy, w, h).
+
+ Args:
+ bbox (Tensor): Shape (n, 4) for bboxes.
+
+ Returns:
+ Tensor: Converted bboxes.
+ """
+ x1, y1, x2, y2 = bbox.split((1, 1, 1, 1), dim=-1)
+ bbox_new = [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)]
+ return torch.cat(bbox_new, dim=-1)
+
+
+def get_l1_target(
+ bbox_target: Tensor, priors: Tensor, eps: float = 1e-8
+) -> Tensor:
+ """Convert gt bboxes to center offset and log width height.
+
+ Args:
+ bbox_target (Tensor): Shape (n, 4) for ground-truth bboxes.
+ priors (Tensor): Shape (n, 4) for prior boxes.
+ eps (float, optional): Epsilon for numerical stability. Defaults to
+ 1e-8.
+ """
+ l1_target = bbox_target.new_zeros((len(bbox_target), 4))
+ gt_cxcywh = bbox_xyxy_to_cxcywh(bbox_target)
+ l1_target[:, :2] = (gt_cxcywh[:, :2] - priors[:, :2]) / priors[:, 2:]
+ l1_target[:, 2:] = torch.log(gt_cxcywh[:, 2:] / priors[:, 2:] + eps)
+ return l1_target
+
+
+class YOLOXHeadLoss(nn.Module):
+ """Loss of YOLOX head."""
+
+ def __init__(
+ self,
+ num_classes: int,
+ point_generator: MlvlPointGenerator | None = None,
+ box_decoder: YOLOXBBoxDecoder | None = None,
+ loss_cls: TorchLossFunc = F.binary_cross_entropy_with_logits,
+ loss_bbox: TorchLossFunc = IoULoss(mode="square", eps=1e-16),
+ loss_obj: TorchLossFunc = F.binary_cross_entropy_with_logits,
+ loss_l1: TorchLossFunc | None = None,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ num_classes (int): Number of classes.
+ point_generator (MlvlPointGenerator): Point generator.
+ box_decoder (YOLOXBBoxDecoder): Box decoder.
+ loss_cls (TorchLossFunc, optional): Classification loss function.
+ Defaults to sigmoid_focal_loss.
+ loss_bbox (TorchLossFunc, optional): Regression loss function.
+ Defaults to l1_loss.
+ loss_obj (TorchLossFunc, optional): Objectness loss function.
+ Defaults to sigmoid_focal_loss.
+ loss_l1 (TorchLossFunc | None, optional): L1 loss function.
+ Defaults to None. Only used during the final few epochs.
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.point_generator = (
+ point_generator
+ if point_generator is not None
+ else get_default_point_generator()
+ )
+ if box_decoder is None:
+ self.box_decoder = YOLOXBBoxDecoder()
+ else:
+ self.box_decoder = box_decoder
+ self.box_matcher = SimOTAMatcher()
+ self.box_sampler = PseudoSampler()
+ self.loss_cls = loss_cls
+ self.loss_bbox = loss_bbox
+ self.loss_obj = loss_obj
+ self.loss_l1 = loss_l1
+
+ def _get_target_single(
+ self,
+ cls_preds: Tensor,
+ objectness: Tensor,
+ priors: Tensor,
+ decoded_bboxes: Tensor,
+ gt_bboxes: Tensor,
+ gt_labels: Tensor,
+ ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, int]:
+ """Compute YOLOX training targets in a single image.
+
+ Args:
+ cls_preds (Tensor): Classification predictions of one image,
+ a 2D-Tensor with shape [num_priors, num_classes]
+ objectness (Tensor): Objectness predictions of one image,
+ a 1D-Tensor with shape [num_priors]
+ priors (Tensor): All priors of one image, a 2D-Tensor with shape
+ [num_priors, 4] in [cx, xy, stride_w, stride_y] format.
+ decoded_bboxes (Tensor): Decoded bboxes predictions of one image,
+ a 2D-Tensor with shape [num_priors, 4] in [tl_x, tl_y,
+ br_x, br_y] format.
+ gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
+ with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (Tensor): Ground truth labels of one image, a Tensor
+ with shape [num_gts].
+ """
+ num_priors = priors.size(0)
+ num_gts = gt_labels.size(0)
+ gt_bboxes = gt_bboxes.to(decoded_bboxes.dtype)
+ # No target
+ if num_gts == 0:
+ cls_target = cls_preds.new_zeros((0, self.num_classes))
+ bbox_target = cls_preds.new_zeros((0, 4))
+ l1_target = cls_preds.new_zeros((0, 4))
+ obj_target = cls_preds.new_zeros((num_priors, 1))
+ foreground_mask = cls_preds.new_zeros(num_priors).bool()
+ return (
+ foreground_mask,
+ cls_target,
+ obj_target,
+ bbox_target,
+ l1_target,
+ 0,
+ )
+
+ # YOLOX uses center priors with 0.5 offset to assign targets,
+ # but use center priors without offset to regress bboxes.
+ offset_priors = torch.cat(
+ [priors[:, :2] + priors[:, 2:] * 0.5, priors[:, 2:]], dim=-1
+ )
+
+ scores = cls_preds.sigmoid() * objectness.unsqueeze(1).sigmoid()
+ match_result = self.box_matcher(
+ scores.sqrt_(),
+ offset_priors,
+ decoded_bboxes,
+ gt_bboxes,
+ gt_labels,
+ )
+ sampling_result = self.box_sampler(match_result)
+ positives = sampling_result.sampled_labels == 1
+ pos_inds = sampling_result.sampled_box_indices[positives]
+ pos_tgt_inds = sampling_result.sampled_target_indices[positives]
+ num_pos_per_img = pos_inds.size(0)
+
+ pos_ious = match_result.assigned_gt_iou[pos_inds]
+ # IOU aware classification score
+ cls_target = F.one_hot( # pylint: disable=not-callable
+ gt_labels[pos_tgt_inds], self.num_classes
+ ) * pos_ious.unsqueeze(-1)
+ obj_target = torch.zeros_like(objectness).unsqueeze(-1)
+ obj_target[pos_inds] = 1
+ bbox_target = gt_bboxes[pos_tgt_inds]
+ if self.loss_l1 is not None:
+ l1_target = get_l1_target(bbox_target, priors[pos_inds])
+ else:
+ l1_target = bbox_target.new_zeros((len(bbox_target), 4))
+ foreground_mask = torch.zeros_like(objectness).to(torch.bool)
+ foreground_mask[pos_inds] = 1
+ return (
+ foreground_mask,
+ cls_target,
+ obj_target,
+ bbox_target,
+ l1_target,
+ num_pos_per_img,
+ )
+
+ def forward(
+ self,
+ cls_outs: list[Tensor],
+ reg_outs: list[Tensor],
+ obj_outs: list[Tensor],
+ target_boxes: list[Tensor],
+ target_class_ids: list[Tensor],
+ images_hw: list[tuple[int, int]],
+ ) -> YOLOXHeadLosses:
+ """Compute YOLOX classification, regression, and objectness losses.
+
+ Args:
+ cls_outs (list[Tensor]): Network classification outputs at all
+ scales.
+ reg_outs (list[Tensor]): Network regression outputs at all scales.
+ obj_outs (list[Tensor]): Network objectness outputs at all scales.
+ target_boxes (list[Tensor]): Target bounding boxes.
+ images_hw (list[tuple[int, int]]): Image dimensions without
+ padding.
+ target_class_ids (list[Tensor]): Target class labels.
+
+ Returns:
+ YOLOXHeadLosses: YOLOX losses.
+ """
+ (
+ flatten_cls,
+ flatten_reg,
+ flatten_obj,
+ flatten_points,
+ flatten_boxes,
+ ) = preprocess_outputs(
+ cls_outs,
+ reg_outs,
+ obj_outs,
+ images_hw,
+ self.point_generator,
+ self.box_decoder,
+ )
+
+ num_imgs = len(images_hw)
+ pos_masks_list, cls_targets_list, obj_targets_list = [], [], []
+ bbox_targets_list, l1_targets_list, num_fg_imgs_list = [], [], []
+ for flat_cls, flat_obj, flat_pts, flat_bxs, tgt_bxs, tgt_cls in zip(
+ flatten_cls.detach(),
+ flatten_obj.detach(),
+ flatten_points.unsqueeze(0).repeat(num_imgs, 1, 1),
+ flatten_boxes.detach(),
+ target_boxes,
+ target_class_ids,
+ ):
+ targets = self._get_target_single(
+ flat_cls, flat_obj, flat_pts, flat_bxs, tgt_bxs, tgt_cls
+ )
+ pos_masks_list.append(targets[0])
+ cls_targets_list.append(targets[1])
+ obj_targets_list.append(targets[2])
+ bbox_targets_list.append(targets[3])
+ l1_targets_list.append(targets[4])
+ num_fg_imgs_list.append(targets[5])
+
+ num_pos = torch.tensor(
+ sum(num_fg_imgs_list), dtype=torch.float, device=flatten_cls.device
+ )
+ num_total_samples: Tensor | float = max( # type: ignore
+ reduce_mean(num_pos), 1.0
+ )
+
+ pos_masks = torch.cat(pos_masks_list, 0)
+ cls_targets = torch.cat(cls_targets_list, 0)
+ obj_targets = torch.cat(obj_targets_list, 0)
+ bbox_targets = torch.cat(bbox_targets_list, 0)
+ if self.loss_l1 is not None:
+ l1_targets = torch.cat(l1_targets_list, 0)
+
+ loss_obj = self.loss_obj(
+ flatten_obj.view(-1, 1), obj_targets, reduction="none"
+ )
+ loss_obj = SumWeightedLoss(1.0, num_total_samples)(loss_obj)
+
+ if num_pos > 0:
+ loss_cls = self.loss_cls(
+ flatten_cls.view(-1, self.num_classes)[pos_masks],
+ cls_targets,
+ reduction="none",
+ )
+ loss_cls = SumWeightedLoss(1.0, num_total_samples)(loss_cls)
+ loss_bbox = self.loss_bbox(
+ flatten_boxes.view(-1, 4)[pos_masks], bbox_targets
+ )
+ loss_bbox = SumWeightedLoss(5.0, num_total_samples)(loss_bbox)
+ else:
+ loss_cls = flatten_cls.sum() * 0
+ loss_bbox = flatten_boxes.sum() * 0
+
+ if self.loss_l1 is not None:
+ if num_pos > 0:
+ loss_l1 = self.loss_l1(
+ flatten_reg.view(-1, 4)[pos_masks], l1_targets
+ )
+ loss_l1 = SumWeightedLoss(1.0, num_total_samples)(loss_l1)
+ else:
+ loss_l1 = flatten_reg.sum() * 0
+ else:
+ loss_l1 = None
+
+ return YOLOXHeadLosses(
+ loss_cls=loss_cls,
+ loss_bbox=loss_bbox,
+ loss_obj=loss_obj,
+ loss_l1=loss_l1,
+ )
+
+ def __call__(
+ self,
+ cls_outs: list[Tensor],
+ reg_outs: list[Tensor],
+ obj_outs: list[Tensor],
+ target_boxes: list[Tensor],
+ target_class_ids: list[Tensor],
+ images_hw: list[tuple[int, int]],
+ ) -> YOLOXHeadLosses:
+ """Type definition."""
+ return self._call_impl(
+ cls_outs,
+ reg_outs,
+ obj_outs,
+ target_boxes,
+ target_class_ids,
+ images_hw,
+ )
diff --git a/vis4d/op/detect3d/__init__.py b/vis4d/op/detect3d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c523fd6406a8767bafa438e90cc4db159e9ab032
--- /dev/null
+++ b/vis4d/op/detect3d/__init__.py
@@ -0,0 +1 @@
+"""3D detector module."""
diff --git a/vis4d/op/detect3d/bevformer/__init__.py b/vis4d/op/detect3d/bevformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..efdc356055d2a9619afc603c6748401054239d68
--- /dev/null
+++ b/vis4d/op/detect3d/bevformer/__init__.py
@@ -0,0 +1,6 @@
+"""BEVFormer ops."""
+
+from .bevformer import BEVFormerHead
+from .grid_mask import GridMask
+
+__all__ = ["BEVFormerHead", "GridMask"]
diff --git a/vis4d/op/detect3d/bevformer/bevformer.py b/vis4d/op/detect3d/bevformer/bevformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdcf87a8a592d07d452ee69de3a12c253b20f07b
--- /dev/null
+++ b/vis4d/op/detect3d/bevformer/bevformer.py
@@ -0,0 +1,298 @@
+"""BEVFormer head."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import numpy as np
+import torch
+from torch import Tensor, nn
+
+from vis4d.data.const import AxisMode
+from vis4d.op.box.box3d import transform_boxes3d
+from vis4d.op.box.encoder.bevformer import NMSFreeDecoder
+from vis4d.op.geometry.rotation import (
+ euler_angles_to_matrix,
+ matrix_to_quaternion,
+ rotate_velocities,
+)
+from vis4d.op.layer.positional_encoding import LearnedPositionalEncoding
+from vis4d.op.layer.transformer import get_clones, inverse_sigmoid
+from vis4d.op.layer.weight_init import bias_init_with_prob
+
+from ..common import Detect3DOut
+from .transformer import PerceptionTransformer
+
+
+def bbox3d2result(
+ bbox_list: list[tuple[Tensor, Tensor, Tensor]], lidar2global: Tensor
+) -> Detect3DOut:
+ """Convert BEVFormer detection results to Detect3DOut.
+
+ Args:
+ bbox_list (list[tuple[Tensor, Tensor, Tensor]): List of bounding boxes,
+ scores and labels.
+ lidar2global (Tensor): Lidar to global transformation (B, 4, 4).
+
+ Returns:
+ Detect3DOut: Detection results.
+ """
+ boxes_3d = []
+ velocities = []
+ class_ids = []
+ scores_3d = []
+ for i, (bboxes, scores, labels) in enumerate(bbox_list):
+ # move boxes from lidar to global coordinate system
+ yaw = bboxes.new_zeros(bboxes.shape[0], 3)
+ yaw[:, 2] = bboxes[:, 6]
+ orientation = matrix_to_quaternion(euler_angles_to_matrix(yaw))
+
+ boxes3d_lidar = torch.cat([bboxes[:, :6], orientation], dim=1)
+ boxes_3d.append(
+ transform_boxes3d(
+ boxes3d_lidar, lidar2global[i], AxisMode.LIDAR, AxisMode.ROS
+ )
+ )
+
+ _velocities = bboxes.new_zeros(bboxes.shape[0], 3)
+ _velocities[:, :2] = bboxes[:, -2:]
+ velocities.append(rotate_velocities(_velocities, lidar2global[i]))
+
+ class_ids.append(labels)
+ scores_3d.append(scores)
+
+ return Detect3DOut(boxes_3d, velocities, class_ids, scores_3d)
+
+
+class BEVFormerHead(nn.Module):
+ """BEVFormer 3D detection head."""
+
+ def __init__(
+ self,
+ num_classes: int = 10,
+ embed_dims: int = 256,
+ num_query: int = 900,
+ transformer: PerceptionTransformer | None = None,
+ num_reg_fcs: int = 2,
+ num_cls_fcs: int = 2,
+ point_cloud_range: Sequence[float] = (
+ -51.2,
+ -51.2,
+ -5.0,
+ 51.2,
+ 51.2,
+ 3.0,
+ ),
+ bev_h: int = 200,
+ bev_w: int = 200,
+ ) -> None:
+ """Initialize BEVFormerHead.
+
+ Args:
+ num_classes (int, optional): Number of classes. Defaults to 10.
+ embed_dims (int, optional): Embedding dimensions. Defaults to 256.
+ num_query (int, optional): Number of queries. Defaults to 900.
+ transformer (PerceptionTransformer, optional): Transformer.
+ Defaults to None. If None, a default transformer will be
+ created.
+ num_reg_fcs (int, optional): Number of fully connected layers in
+ regression branch. Defaults to 2.
+ num_cls_fcs (int, optional): Number of fully connected layers in
+ classification branch. Defaults to 2.
+ point_cloud_range (Sequence[float], optional): Point cloud range.
+ Defaults to (-51.2, -51.2, -5.0, 51.2, 51.2, 3.0).
+ bev_h (int, optional): BEV height. Defaults to 200.
+ bev_w (int, optional): BEV width. Defaults to 200.
+ """
+ super().__init__()
+ self.embed_dims = embed_dims
+ self.num_reg_fcs = num_reg_fcs
+ self.bev_h = bev_h
+ self.bev_w = bev_w
+
+ self.positional_encoding = LearnedPositionalEncoding(
+ num_feats=embed_dims // 2, row_num_embed=bev_h, col_num_embed=bev_w
+ )
+
+ self.cls_out_channels = num_classes
+
+ self.transformer = transformer or PerceptionTransformer(
+ embed_dims=embed_dims
+ )
+
+ self.code_size = 10
+ self.num_query = num_query
+
+ self.box_decoder = NMSFreeDecoder(
+ num_classes=num_classes,
+ post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
+ max_num=300,
+ )
+ self.pc_range = list(point_cloud_range)
+ self.real_w = self.pc_range[3] - self.pc_range[0]
+ self.real_h = self.pc_range[4] - self.pc_range[1]
+ self.num_cls_fcs = num_cls_fcs - 1
+
+ self.code_weights = nn.Parameter(
+ torch.tensor(
+ [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2],
+ requires_grad=False,
+ ),
+ requires_grad=False,
+ )
+
+ self._init_layers()
+ self._init_weights()
+
+ def _init_layers(self) -> None:
+ """Initialize classification branch and regression branch of head."""
+ cls_branch: list[nn.Module] = []
+ for _ in range(self.num_reg_fcs):
+ cls_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
+ cls_branch.append(nn.LayerNorm(self.embed_dims))
+ cls_branch.append(nn.ReLU(inplace=True))
+ cls_branch.append(nn.Linear(self.embed_dims, self.cls_out_channels))
+ fc_cls = nn.Sequential(*cls_branch)
+
+ reg_branch: list[nn.Module] = []
+ for _ in range(self.num_reg_fcs):
+ reg_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
+ reg_branch.append(nn.ReLU())
+ reg_branch.append(nn.Linear(self.embed_dims, self.code_size))
+ fc_reg = nn.Sequential(*reg_branch)
+
+ num_pred = self.transformer.decoder.num_layers
+
+ self.cls_branches = get_clones(fc_cls, num_pred)
+ self.reg_branches = get_clones(fc_reg, num_pred)
+
+ self.bev_embedding = nn.Embedding(
+ self.bev_h * self.bev_w, self.embed_dims
+ )
+ self.query_embedding = nn.Embedding(
+ self.num_query, self.embed_dims * 2
+ )
+
+ def _init_weights(self) -> None:
+ """Initialize weights."""
+ bias_init = bias_init_with_prob(0.01)
+ for m in self.cls_branches:
+ nn.init.constant_(m[-1].bias, bias_init) # type: ignore
+
+ def forward(
+ self,
+ mlvl_feats: list[Tensor],
+ can_bus: Tensor,
+ images_hw: tuple[int, int],
+ cam_intrinsics: list[Tensor],
+ cam_extrinsics: list[Tensor],
+ lidar_extrinsics: Tensor,
+ prev_bev: Tensor | None = None,
+ ) -> tuple[Detect3DOut, Tensor]:
+ """Forward function.
+
+ Args:
+ mlvl_feats (list[Tensor]): Features from the upstream network, each
+ is with shape (B, N, C, H, W).
+ can_bus (Tensor): CAN bus data, with shape (B, 18).
+ images_hw (tuple[int, int]): Image height and width.
+ cam_intrinsics (list[Tensor]): Camera intrinsics.
+ cam_extrinsics (list[Tensor]): Camera extrinsics.
+ lidar_extrinsics (list[Tensor]): LiDAR extrinsics.
+ prev_bev (Tensor, optional): Previous BEV feature map, with shape
+ (B, C, H, W). Defaults to None.
+
+ Returns:
+ tuple[Detect3DOut, Tensor]: Detection results and BEV feature map.
+ """
+ batch_size = mlvl_feats[0].shape[0]
+ dtype = mlvl_feats[0].dtype
+ object_query_embeds = self.query_embedding.weight.to(dtype)
+ bev_queries = self.bev_embedding.weight.to(dtype)
+
+ bev_mask = bev_queries.new_zeros((batch_size, self.bev_h, self.bev_w))
+ bev_pos = self.positional_encoding(bev_mask)
+
+ bev_embed, hs, init_reference, inter_references = self.transformer(
+ mlvl_feats,
+ can_bus,
+ bev_queries,
+ object_query_embeds,
+ self.bev_h,
+ self.bev_w,
+ images_hw=images_hw,
+ cam_intrinsics=cam_intrinsics,
+ cam_extrinsics=cam_extrinsics,
+ lidar_extrinsics=lidar_extrinsics,
+ grid_length=(self.real_h / self.bev_h, self.real_w / self.bev_w),
+ bev_pos=bev_pos,
+ reg_branches=self.reg_branches,
+ prev_bev=prev_bev,
+ )
+
+ hs = hs.permute(0, 2, 1, 3)
+ outputs_classes = []
+ outputs_coords = []
+ for lvl in range(hs.shape[0]):
+ if lvl == 0:
+ reference = init_reference
+ else:
+ reference = inter_references[lvl - 1]
+ reference = inverse_sigmoid(reference)
+ outputs_class = self.cls_branches[lvl](hs[lvl])
+ outputs_coord = self.reg_branches[lvl](hs[lvl])
+
+ assert reference.shape[-1] == 3
+ outputs_coord[..., 0:2] += reference[..., 0:2]
+ outputs_coord[..., 0:2] = outputs_coord[..., 0:2].sigmoid()
+ outputs_coord[..., 4:5] += reference[..., 2:3]
+ outputs_coord[..., 4:5] = outputs_coord[..., 4:5].sigmoid()
+ outputs_coord[..., 0:1] = (
+ outputs_coord[..., 0:1] * (self.pc_range[3] - self.pc_range[0])
+ + self.pc_range[0]
+ )
+ outputs_coord[..., 1:2] = (
+ outputs_coord[..., 1:2] * (self.pc_range[4] - self.pc_range[1])
+ + self.pc_range[1]
+ )
+ outputs_coord[..., 4:5] = (
+ outputs_coord[..., 4:5] * (self.pc_range[5] - self.pc_range[2])
+ + self.pc_range[2]
+ )
+
+ outputs_classes.append(outputs_class)
+ outputs_coords.append(outputs_coord)
+
+ ret_list: list[tuple[Tensor, Tensor, Tensor]] = []
+ for cls_scores, bbox_preds in zip(
+ outputs_classes[-1], outputs_coords[-1]
+ ):
+ bboxes, scores, labels = self.box_decoder(cls_scores, bbox_preds)
+
+ # mapping MMDetection3D's coordinate to our LIDAR coordinate
+ bboxes[:, 6] = -(bboxes[:, 6] + np.pi / 2)
+
+ ret_list.append((bboxes, scores, labels))
+
+ return bbox3d2result(ret_list, lidar_extrinsics), bev_embed
+
+ def __call__(
+ self,
+ mlvl_feats: list[Tensor],
+ can_bus: Tensor,
+ images_hw: tuple[int, int],
+ cam_intrinsics: list[Tensor],
+ cam_extrinsics: list[Tensor],
+ lidar_extrinsics: Tensor,
+ prev_bev: Tensor | None = None,
+ ) -> tuple[Detect3DOut, Tensor]:
+ """Type definition."""
+ return self._call_impl(
+ mlvl_feats,
+ can_bus,
+ images_hw,
+ cam_intrinsics,
+ cam_extrinsics,
+ lidar_extrinsics,
+ prev_bev,
+ )
diff --git a/vis4d/op/detect3d/bevformer/decoder.py b/vis4d/op/detect3d/bevformer/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4d14ed64df2c9f4f33330d1b34580c548bfd2cb
--- /dev/null
+++ b/vis4d/op/detect3d/bevformer/decoder.py
@@ -0,0 +1,415 @@
+"""BEVFormer decoder."""
+
+from __future__ import annotations
+
+import math
+
+import torch
+from torch import Tensor, nn
+
+from vis4d.op.layer.attention import MultiheadAttention
+from vis4d.op.layer.ms_deform_attn import (
+ MSDeformAttentionFunction,
+ is_power_of_2,
+ ms_deformable_attention_cpu,
+)
+from vis4d.op.layer.transformer import FFN, inverse_sigmoid
+from vis4d.op.layer.weight_init import constant_init, xavier_init
+
+
+class BEVFormerDecoder(nn.Module):
+ """Implements the decoder in DETR3D transformer."""
+
+ def __init__(
+ self,
+ num_layers: int = 6,
+ embed_dims: int = 256,
+ return_intermediate: bool = True,
+ ) -> None:
+ """Init.
+
+ Args:
+ num_layers (int): The number of decoder layers. Default: 6.
+ embed_dims (int): The embedding dimension. Default: 256.
+ return_intermediate (bool): Whether to return intermediate
+ results. Default: True.
+ """
+ super().__init__()
+ self.num_layers = num_layers
+ self.return_intermediate = return_intermediate
+
+ self.layers = nn.ModuleList(
+ [
+ (BEVFormerDecoderLayer(embed_dims=embed_dims))
+ for _ in range(num_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ query: Tensor,
+ value: Tensor,
+ reference_points: Tensor,
+ spatial_shapes: Tensor,
+ level_start_index: Tensor,
+ query_pos: Tensor,
+ reg_branches: list[nn.Module],
+ ) -> tuple[Tensor, Tensor]:
+ """Forward function.
+
+ Args:
+ query (Tensor): Input query with shape (num_query, bs, embed_dims).
+ value (Tensor): Input value with shape (bs, num_query, embed_dims).
+ reference_points (Tensor): The reference points of offset. In shape
+ (bs, num_query, 4) when as_two_stage, otherwise has shape (bs,
+ num_query, 2).
+ spatial_shapes (Tensor): The spatial shapes of feature maps.
+ level_start_index (Tensor): The start index of each level.
+ query_pos (Tensor): The query position embedding.
+ reg_branches: (list[nn.Module]): Used for refining the regression
+ results.
+
+ Returns:
+ tuple[Tensor, Tensor]: The output of the decoder with reference
+ points. If return_intermediate is True, the output and
+ reference points of each layer will be stacked and return.
+ """
+ output = query
+ intermediate = []
+ intermediate_reference_points = []
+ for lid, layer in enumerate(self.layers):
+ # BS, NUM_QUERY, NUM_LEVEL, 2
+ reference_points_input = reference_points[..., :2].unsqueeze(2)
+ output = layer(
+ output,
+ reference_points=reference_points_input,
+ value=value,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ query_pos=query_pos,
+ )
+ output = output.permute(1, 0, 2)
+
+ tmp = reg_branches[lid](output)
+
+ assert reference_points.shape[-1] == 3
+ new_reference_points = torch.zeros_like(reference_points)
+ new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(
+ reference_points[..., :2]
+ )
+ new_reference_points[..., 2:3] = tmp[..., 4:5] + inverse_sigmoid(
+ reference_points[..., 2:3]
+ )
+
+ new_reference_points = new_reference_points.sigmoid()
+
+ reference_points = new_reference_points.detach()
+
+ output = output.permute(1, 0, 2)
+ if self.return_intermediate:
+ intermediate.append(output)
+ intermediate_reference_points.append(reference_points)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate), torch.stack(
+ intermediate_reference_points
+ )
+
+ return output, reference_points
+
+
+class BEVFormerDecoderLayer(nn.Module):
+ """Implements decoder layer in DETR transformer."""
+
+ def __init__(
+ self,
+ embed_dims: int = 256,
+ feedforward_channels: int = 512,
+ drop_out: float = 0.1,
+ ) -> None:
+ """Init.
+
+ Args:
+ embed_dims (int): The embedding dimension.
+ feedforward_channels (int): The hidden dimension of FFNs.
+ drop_out (float): The dropout rate of FFNs.
+ """
+ super().__init__()
+ self.attentions = nn.ModuleList()
+
+ self.attentions.append(
+ MultiheadAttention(
+ embed_dims=embed_dims,
+ num_heads=8,
+ attn_drop=0.1,
+ proj_drop=0.1,
+ )
+ )
+ self.attentions.append(
+ DecoderCrossAttention(embed_dims=embed_dims, num_levels=1)
+ )
+
+ self.embed_dims = embed_dims
+
+ self.ffns = nn.ModuleList()
+ self.ffns.append(
+ FFN(
+ embed_dims=self.embed_dims,
+ feedforward_channels=feedforward_channels,
+ dropout=drop_out,
+ )
+ )
+
+ self.norms = nn.ModuleList()
+ for _ in range(3):
+ self.norms.append(nn.LayerNorm(self.embed_dims))
+
+ def forward(
+ self,
+ query: Tensor,
+ reference_points: Tensor,
+ value: Tensor,
+ spatial_shapes: Tensor,
+ level_start_index: Tensor,
+ query_pos: Tensor | None = None,
+ ) -> Tensor:
+ """Forward.
+
+ Args:
+ query (Tensor): The input query, has shape (bs, num_queries, dim).
+ reference_points (Tensor): The reference points of offset. In shape
+ (bs, num_query, 4) when as_two_stage, otherwise has shape (bs,
+ num_query, 2).
+ value (Tensor, optional): The input value, has shape (bs, num_keys,
+ dim).
+ spatial_shapes (Tensor): The spatial shapes of feature maps.
+ level_start_index (Tensor): The start index of each level.
+ query_pos (Tensor, optional): The positional encoding for `query`,
+ has the same shape as `query`. If not `None`, it will be added
+ to `query` before forward function. Defaults to `None`.
+
+ Returns:
+ Tensor: forwarded results, has shape (bs, num_queries, dim).
+ """
+ query = self.attentions[0](
+ query=query,
+ key=query,
+ value=query,
+ query_pos=query_pos,
+ key_pos=query_pos,
+ )
+
+ query = self.norms[0](query)
+
+ query = self.attentions[1](
+ query=query,
+ reference_points=reference_points,
+ value=value,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ query_pos=query_pos,
+ )
+
+ query = self.norms[1](query)
+
+ query = self.ffns[0](query)
+
+ query = self.norms[2](query)
+
+ return query
+
+
+class DecoderCrossAttention(nn.Module):
+ """Custom Multi-Scale Deformable Attention."""
+
+ def __init__(
+ self,
+ embed_dims: int = 256,
+ num_heads: int = 8,
+ num_levels: int = 4,
+ num_points: int = 4,
+ im2col_step: int = 64,
+ dropout: float = 0.1,
+ batch_first: bool = False,
+ ) -> None:
+ """Initialization.
+
+ Args:
+ embed_dims (int): The embedding dimension of Attention.
+ Default: 256.
+ num_heads (int): Parallel attention heads. Default: 8.
+ num_levels (int): The number of feature map used in Attention.
+ Default: 4.
+ num_points (int): The number of sampling points for each query in
+ each head. Default: 4.
+ im2col_step (int): The step used in image_to_column.
+ Default: 64.
+ dropout (float): A Dropout layer on `inp_identity`.
+ Default: 0.1.
+ batch_first (bool): Key, Query and Value are shape of (batch, n,
+ embed_dim) or (n, batch, embed_dim). Default to False.
+ """
+ super().__init__()
+ if embed_dims % num_heads != 0:
+ raise ValueError(
+ f"embed_dims must be divisible by num_heads, "
+ f"but got {embed_dims} and {num_heads}"
+ )
+ self.dropout = nn.Dropout(dropout)
+ self.batch_first = batch_first
+
+ is_power_of_2(embed_dims // num_heads)
+
+ self.im2col_step = im2col_step
+ self.embed_dims = embed_dims
+ self.num_levels = num_levels
+ self.num_heads = num_heads
+ self.num_points = num_points
+ self.sampling_offsets = nn.Linear(
+ embed_dims, num_heads * num_levels * num_points * 2
+ )
+ self.attention_weights = nn.Linear(
+ embed_dims, num_heads * num_levels * num_points
+ )
+ self.value_proj = nn.Linear(embed_dims, embed_dims)
+ self.output_proj = nn.Linear(embed_dims, embed_dims)
+ self.init_weights()
+
+ def init_weights(self) -> None:
+ """Default initialization for Parameters of Module."""
+ constant_init(self.sampling_offsets, 0.0)
+ thetas = torch.mul(
+ torch.arange(self.num_heads, dtype=torch.float32),
+ (2.0 * math.pi / self.num_heads),
+ )
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+ .view(self.num_heads, 1, 1, 2)
+ .repeat(1, self.num_levels, self.num_points, 1)
+ )
+ for i in range(self.num_points):
+ grid_init[:, :, i, :] *= i + 1
+
+ self.sampling_offsets.bias.data = grid_init.view(-1)
+ constant_init(self.attention_weights, val=0.0, bias=0.0)
+ xavier_init(self.value_proj, distribution="uniform", bias=0.0)
+ xavier_init(self.output_proj, distribution="uniform", bias=0.0)
+
+ def forward( # pylint: disable=duplicate-code
+ self,
+ query: Tensor,
+ reference_points: Tensor,
+ value: Tensor,
+ spatial_shapes: Tensor,
+ level_start_index: Tensor,
+ key_padding_mask: Tensor | None = None,
+ query_pos: Tensor | None = None,
+ identity: Tensor | None = None,
+ ) -> Tensor:
+ """Forward.
+
+ Args:
+ query (Tensor): Query of Transformer with shape (num_query, bs,
+ embed_dims).
+ reference_points (Tensor): The normalized reference points with
+ shape (bs, num_query, num_levels, 2), all elements is range in
+ [0, 1], top-left (0,0), bottom-right (1, 1), including padding
+ area. or (N, Length_{query}, num_levels, 4), add additional two
+ dimensions is (w, h) to form reference boxes.
+ value (Tensor): The value tensor with shape (num_key, bs,
+ embed_dims).
+ spatial_shapes (Tensor): Spatial shape of features in
+ different levels. With shape (num_levels, 2),
+ last dimension represents (h, w).
+ level_start_index (Tensor): The start index of each level.
+ A tensor has shape ``(num_levels, )`` and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+ key_padding_mask (Tensor): ByteTensor for `query`, with
+ shape [bs, num_key].
+ query_pos (Tensor): The positional encoding for `query`.
+ Default: None.
+ identity (Tensor): The tensor used for addition, with the
+ same shape as `query`. Default None. If None,
+ `query` will be used.
+
+ Returns:
+ Tensor: forwarded results with shape [num_query, bs, embed_dims].
+ """
+ if identity is None:
+ identity = query
+
+ if query_pos is not None:
+ query = query + query_pos
+
+ # change to (bs, num_query ,embed_dims)
+ if not self.batch_first:
+ query = query.permute(1, 0, 2)
+ value = value.permute(1, 0, 2)
+
+ bs, num_query, _ = query.shape
+ bs, num_value, _ = value.shape
+ assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
+
+ value = self.value_proj(value)
+ if key_padding_mask is not None:
+ value = value.masked_fill(key_padding_mask[..., None], 0.0)
+ value = value.view(bs, num_value, self.num_heads, -1)
+
+ sampling_offsets = self.sampling_offsets(query).view(
+ bs, num_query, self.num_heads, self.num_levels, self.num_points, 2
+ )
+
+ attention_weights = self.attention_weights(query).view(
+ bs, num_query, self.num_heads, self.num_levels * self.num_points
+ )
+ attention_weights = attention_weights.softmax(-1)
+
+ attention_weights = attention_weights.view(
+ bs, num_query, self.num_heads, self.num_levels, self.num_points
+ )
+
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.stack(
+ [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1
+ )
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :]
+ + sampling_offsets
+ / offset_normalizer[None, None, None, :, None, :]
+ )
+ elif reference_points.shape[-1] == 4:
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :2]
+ + sampling_offsets
+ / self.num_points
+ * reference_points[:, :, None, :, None, 2:]
+ * 0.5
+ )
+ else:
+ raise ValueError(
+ f"Last dim of reference_points must be"
+ f" 2 or 4, but get {reference_points.shape[-1]} instead."
+ )
+
+ if torch.cuda.is_available() and value.is_cuda:
+ output = MSDeformAttentionFunction.apply(
+ value,
+ spatial_shapes,
+ level_start_index,
+ sampling_locations,
+ attention_weights,
+ self.im2col_step,
+ )
+ else:
+ output = ms_deformable_attention_cpu(
+ value, spatial_shapes, sampling_locations, attention_weights
+ )
+
+ output = self.output_proj(output)
+
+ # (num_query, bs ,embed_dims)
+ if not self.batch_first:
+ output = output.permute(1, 0, 2)
+
+ return self.dropout(output) + identity
diff --git a/vis4d/op/detect3d/bevformer/encoder.py b/vis4d/op/detect3d/bevformer/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddda533aedea9c0c216bb4ce51e49e58d4fb2949
--- /dev/null
+++ b/vis4d/op/detect3d/bevformer/encoder.py
@@ -0,0 +1,432 @@
+"""BEVFormer Encoder."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import torch
+from torch import Tensor, nn
+
+from vis4d.op.geometry.transform import inverse_rigid_transform
+from vis4d.op.layer.transformer import FFN, get_clones
+
+from .spatial_cross_attention import SpatialCrossAttention
+from .temporal_self_attention import TemporalSelfAttention
+
+
+class BEVFormerEncoder(nn.Module):
+ """Attention with both self and cross attention."""
+
+ def __init__(
+ self,
+ num_layers: int = 6,
+ layer: BEVFormerEncoderLayer | None = None,
+ embed_dims: int = 256,
+ num_points_in_pillar: int = 4,
+ point_cloud_range: Sequence[float] = (
+ -51.2,
+ -51.2,
+ -5.0,
+ 51.2,
+ 51.2,
+ 3.0,
+ ),
+ return_intermediate: bool = False,
+ ) -> None:
+ """Init.
+
+ Args:
+ num_layers (int): Number of layers in the encoder.
+ layer (BEVFormerEncoderLayer, optional): Encoder layer. Defaults to
+ None. If None, a default layer will be used.
+ embed_dims (int): Embedding dimension.
+ num_points_in_pillar (int): Number of points in each pillar.
+ point_cloud_range (Sequence[float]): Range of the point cloud.
+ Defaults to (-51.2, -51.2, -5.0, 51.2, 51.2, 3.0).
+ return_intermediate (bool): Whether to return intermediate outputs.
+ """
+ super().__init__()
+ self.num_layers = num_layers
+ self.embed_dims = embed_dims
+ self.num_points_in_pillar = num_points_in_pillar
+ self.pc_range = point_cloud_range
+ self.return_intermediate = return_intermediate
+
+ layer = layer or BEVFormerEncoderLayer(embed_dims=embed_dims)
+
+ self.layers = get_clones(layer, num=self.num_layers)
+
+ self.eps = 1e-5
+
+ def get_reference_points(
+ self,
+ bev_h: int,
+ bev_w: int,
+ dim: int,
+ batch_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ ) -> Tensor:
+ """Get the reference points used in SCA and TSA.
+
+ Args:
+ bev_h (int): Height of the BEV feature map.
+ bev_w (int): Width of the BEV feature map.
+ dim (int): Dimension of the reference points.
+ batch_size (int): Batch size.
+ device (torch.device): The device where reference_points should be.
+ dtype (torch.dtype): The dtype of reference_points.
+
+ Returns:
+ Tensor: reference points used in decoder, has shape (batch_size,
+ num_keys, num_levels, dim).
+ """
+ assert dim in {2, 3}, f"Unknown dim {dim}."
+ # Reference points in 3D space for spatial cross-attention (SCA)
+ if dim == 3:
+ height_z = self.pc_range[5] - self.pc_range[2]
+ zs = (
+ torch.linspace(
+ 0.5,
+ height_z - 0.5,
+ self.num_points_in_pillar,
+ dtype=dtype,
+ device=device,
+ )
+ .view(-1, 1, 1)
+ .expand(self.num_points_in_pillar, bev_h, bev_w)
+ / height_z
+ )
+ xs = (
+ torch.linspace(
+ 0.5, bev_w - 0.5, bev_w, dtype=dtype, device=device
+ )
+ .view(1, 1, bev_w)
+ .expand(self.num_points_in_pillar, bev_h, bev_w)
+ / bev_w
+ )
+ ys = (
+ torch.linspace(
+ 0.5, bev_h - 0.5, bev_h, dtype=dtype, device=device
+ )
+ .view(1, bev_h, 1)
+ .expand(self.num_points_in_pillar, bev_h, bev_w)
+ / bev_h
+ )
+ ref_3d = torch.stack((xs, ys, zs), -1)
+ ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
+ ref_3d = ref_3d[None].repeat(batch_size, 1, 1, 1)
+ return ref_3d
+
+ # Reference points on 2D bev plane for temporal self-attention (TSA)
+ ref_y, ref_x = torch.meshgrid(
+ torch.linspace(
+ 0.5, bev_h - 0.5, bev_h, dtype=dtype, device=device
+ ),
+ torch.linspace(
+ 0.5, bev_w - 0.5, bev_w, dtype=dtype, device=device
+ ),
+ indexing="ij",
+ )
+ ref_y = ref_y.reshape(-1)[None] / bev_h
+ ref_x = ref_x.reshape(-1)[None] / bev_w
+ ref_2d = torch.stack((ref_x, ref_y), -1)
+ ref_2d = ref_2d.repeat(batch_size, 1, 1).unsqueeze(2)
+ return ref_2d
+
+ def point_sampling(
+ self,
+ reference_points: Tensor,
+ images_hw: tuple[int, int],
+ cam_intrinsics: list[Tensor],
+ cam_extrinsics: list[Tensor],
+ lidar_extrinsics: Tensor,
+ ) -> tuple[Tensor, Tensor]:
+ """Sample points from reference points."""
+ lidar2img_list = []
+ for i, _cam_intrinsics in enumerate(cam_intrinsics):
+ viewpad = torch.eye(4, device=_cam_intrinsics.device)
+ viewpad[:3, :3] = _cam_intrinsics
+
+ lidar2img = (
+ viewpad
+ @ inverse_rigid_transform(cam_extrinsics[i])
+ @ lidar_extrinsics
+ )
+
+ lidar2img_list.append(lidar2img)
+
+ lidar2img = torch.stack(lidar2img_list, dim=1) # (B, N, 4, 4)
+
+ reference_points = reference_points.clone()
+ reference_points[..., 0:1] = (
+ reference_points[..., 0:1] * (self.pc_range[3] - self.pc_range[0])
+ + self.pc_range[0]
+ )
+ reference_points[..., 1:2] = (
+ reference_points[..., 1:2] * (self.pc_range[4] - self.pc_range[1])
+ + self.pc_range[1]
+ )
+ reference_points[..., 2:3] = (
+ reference_points[..., 2:3] * (self.pc_range[5] - self.pc_range[2])
+ + self.pc_range[2]
+ )
+
+ reference_points = torch.cat(
+ (reference_points, torch.ones_like(reference_points[..., :1])), -1
+ )
+
+ reference_points = reference_points.permute(1, 0, 2, 3)
+ d, b, num_query, _ = reference_points.shape
+ num_cam = lidar2img.size(1)
+
+ reference_points = (
+ reference_points.view(d, b, 1, num_query, 4)
+ .repeat(1, 1, num_cam, 1, 1)
+ .unsqueeze(-1)
+ )
+
+ lidar2img = lidar2img.view(1, b, num_cam, 1, 4, 4).repeat(
+ d, 1, 1, num_query, 1, 1
+ )
+
+ reference_points_cam = torch.matmul(
+ lidar2img, reference_points
+ ).squeeze(-1)
+
+ bev_mask = reference_points_cam[..., 2:3] > self.eps
+
+ reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
+ reference_points_cam[..., 2:3],
+ torch.mul(
+ torch.ones_like(reference_points_cam[..., 2:3]), self.eps
+ ),
+ )
+
+ reference_points_cam[..., 0] /= images_hw[1]
+ reference_points_cam[..., 1] /= images_hw[0]
+
+ bev_mask = (
+ bev_mask
+ & (reference_points_cam[..., 1:2] > 0.0)
+ & (reference_points_cam[..., 1:2] < 1.0)
+ & (reference_points_cam[..., 0:1] < 1.0)
+ & (reference_points_cam[..., 0:1] > 0.0)
+ )
+
+ reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)
+ bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)
+
+ return reference_points_cam, bev_mask
+
+ def forward(
+ self,
+ bev_query: Tensor,
+ value: Tensor,
+ bev_h: int,
+ bev_w: int,
+ bev_pos: Tensor,
+ spatial_shapes: Tensor,
+ level_start_index: Tensor,
+ prev_bev: Tensor | None,
+ shift: Tensor,
+ images_hw: tuple[int, int],
+ cam_intrinsics: list[Tensor],
+ cam_extrinsics: list[Tensor],
+ lidar_extrinsics: Tensor,
+ ) -> Tensor:
+ """Forward.
+
+ Args:
+ bev_query (Tensor): Input BEV query with shape (num_query,
+ batch_size, embed_dims).
+ value (Tensor): Input multi-cameta features with shape (num_cam,
+ num_value, batch_size, embed_dims).
+ bev_h (int): BEV height.
+ bev_w (int): BEV width.
+ bev_pos (Tensor): BEV positional encoding with shape (batch_size,
+ embed_dims).
+ spatial_shapes (Tensor): Spatial shapes of multi-level
+ features with shape (num_levels, 2).
+ level_start_index (Tensor): Start index of each level with shape
+ (num_levels, ).
+ prev_bev (Tensor | None): Previous BEV features with shape
+ (batch_size, embed_dims).
+ shift (Tensor): Shift of each level with shape (num_levels, 2).
+ images_hw (tuple[int, int]): List of image height and width.
+ cam_intrinsics (list[Tensor]): List of camera intrinsics. In shape
+ (num_cam, batch_size, 3, 3)
+ cam_extrinsics (list[Tensor]): List of camera extrinsics. In shape
+ (num_cam, batch_size, 4, 4)
+ lidar_extrinsics (Tensor): LiDAR extrinsics. In shape (batch_size,
+ 4, 4)
+
+ Returns:
+ Tensor: Results with shape [batch_size, num_query, embed_dims]
+ when return_intermediate is False, otherwise it has shape
+ [num_layers, batch_size, num_query, embed_dims].
+ """
+ intermediate = []
+
+ ref_3d = self.get_reference_points(
+ bev_h,
+ bev_w,
+ dim=3,
+ batch_size=bev_query.size(1),
+ device=bev_query.device,
+ dtype=bev_query.dtype,
+ )
+
+ ref_2d = self.get_reference_points(
+ bev_h,
+ bev_w,
+ dim=2,
+ batch_size=bev_query.size(1),
+ device=bev_query.device,
+ dtype=bev_query.dtype,
+ )
+
+ reference_points_img, bev_mask = self.point_sampling(
+ ref_3d,
+ images_hw,
+ cam_intrinsics,
+ cam_extrinsics,
+ lidar_extrinsics,
+ )
+
+ shift_ref_2d = ref_2d.clone()
+ shift_ref_2d += shift[:, None, None, :]
+
+ bev_query = bev_query.permute(1, 0, 2)
+ bev_pos = bev_pos.permute(1, 0, 2)
+
+ batch_size, len_bev, num_bev_level, _ = ref_2d.shape
+ if prev_bev is not None:
+ prev_bev = prev_bev.permute(1, 0, 2)
+ prev_bev = torch.stack([prev_bev, bev_query], 1).reshape(
+ batch_size * 2, len_bev, -1
+ )
+ hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
+ batch_size * 2, len_bev, num_bev_level, 2
+ )
+ else:
+ hybird_ref_2d = torch.stack([ref_2d, ref_2d], 1).reshape(
+ batch_size * 2, len_bev, num_bev_level, 2
+ )
+
+ for _, layer in enumerate(self.layers):
+ output = layer(
+ bev_query,
+ value,
+ bev_pos=bev_pos,
+ ref_2d=hybird_ref_2d,
+ bev_h=bev_h,
+ bev_w=bev_w,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ reference_points_img=reference_points_img,
+ bev_mask=bev_mask,
+ prev_bev=prev_bev,
+ )
+
+ bev_query = output
+
+ if self.return_intermediate:
+ intermediate.append(output)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+
+ return output
+
+
+class BEVFormerEncoderLayer(nn.Module):
+ """BEVFormer encoder layer."""
+
+ def __init__(
+ self,
+ embed_dims: int = 256,
+ self_attn: TemporalSelfAttention | None = None,
+ cross_attn: SpatialCrossAttention | None = None,
+ feedforward_channels: int = 512,
+ drop_out: float = 0.1,
+ ) -> None:
+ """Init."""
+ super().__init__()
+ self.attentions = nn.ModuleList()
+
+ self_attn = self_attn or TemporalSelfAttention(
+ embed_dims=embed_dims, num_levels=1
+ )
+ self.attentions.append(self_attn)
+
+ cross_attn = cross_attn or SpatialCrossAttention(embed_dims=embed_dims)
+ self.attentions.append(cross_attn)
+
+ self.embed_dims = embed_dims
+
+ self.ffns = nn.ModuleList()
+ self.ffns.append(
+ FFN(
+ embed_dims=embed_dims,
+ feedforward_channels=feedforward_channels,
+ dropout=drop_out,
+ )
+ )
+
+ self.norms = nn.ModuleList()
+ for _ in range(3):
+ self.norms.append(nn.LayerNorm(self.embed_dims))
+
+ def forward(
+ self,
+ query: Tensor,
+ value: Tensor,
+ bev_pos: Tensor,
+ ref_2d: Tensor,
+ bev_h: int,
+ bev_w: int,
+ spatial_shapes: Tensor,
+ level_start_index: Tensor,
+ reference_points_img: Tensor,
+ bev_mask: Tensor,
+ prev_bev: Tensor | None = None,
+ ) -> Tensor:
+ """Forward function.
+
+ self_attn -> norm -> cross_attn -> norm -> ffn -> norm
+
+ Returns:
+ Tensor: forwarded results with shape [num_queries, batch_size,
+ embed_dims].
+ """
+ # Temporal self attention
+ query = self.attentions[0](
+ query,
+ ref_2d,
+ prev_bev,
+ spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
+ level_start_index=torch.tensor([0], device=query.device),
+ query_pos=bev_pos,
+ )
+
+ query = self.norms[0](query)
+
+ # Spaital cross attention
+ query = self.attentions[1](
+ query,
+ reference_points_img,
+ value,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ bev_mask=bev_mask,
+ )
+
+ query = self.norms[1](query)
+
+ # FFN
+ query = self.ffns[0](query)
+
+ query = self.norms[2](query)
+
+ return query
diff --git a/vis4d/op/detect3d/bevformer/grid_mask.py b/vis4d/op/detect3d/bevformer/grid_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..01efd8d91a0b1915b83688db95b16447f2407e37
--- /dev/null
+++ b/vis4d/op/detect3d/bevformer/grid_mask.py
@@ -0,0 +1,82 @@
+"""Grid mask for BEVFormer."""
+
+import numpy as np
+import torch
+from PIL import Image
+from torch import Tensor, nn
+
+
+class GridMask(nn.Module):
+ """Grid Mask Layer."""
+
+ def __init__(
+ self,
+ use_h: bool,
+ use_w: bool,
+ rotate: int = 1,
+ offset: bool = False,
+ ratio: float = 0.5,
+ mode: int = 0,
+ prob: float = 1.0,
+ ) -> None:
+ """Init."""
+ super().__init__()
+ self.use_h = use_h
+ self.use_w = use_w
+ self.rotate = rotate
+ self.offset = offset
+ self.ratio = ratio
+ self.mode = mode
+ self.st_prob = prob
+ self.prob = prob
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward."""
+ if np.random.rand() > self.prob:
+ return x
+
+ device = x.device
+ n, c, h, w = x.size()
+ x = x.view(-1, h, w)
+ hh = int(1.5 * h)
+ ww = int(1.5 * w)
+ d = np.random.randint(2, h)
+ l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
+ mask = np.ones((hh, ww), np.float32)
+ st_h = np.random.randint(d)
+ st_w = np.random.randint(d)
+ if self.use_h:
+ for i in range(hh // d):
+ s = d * i + st_h
+ t = min(s + l, hh)
+ mask[s:t, :] *= 0
+ if self.use_w:
+ for i in range(ww // d):
+ s = d * i + st_w
+ t = min(s + l, ww)
+ mask[:, s:t] *= 0
+
+ r = np.random.randint(self.rotate)
+ mask_img = Image.fromarray(np.uint8(mask))
+ mask_img = mask_img.rotate(r)
+ mask = np.asarray(mask_img)
+ mask = mask[
+ (hh - h) // 2 : (hh - h) // 2 + h,
+ (ww - w) // 2 : (ww - w) // 2 + w,
+ ]
+
+ mask_tensor = torch.from_numpy(mask).to(x.dtype).to(device)
+ if self.mode == 1:
+ mask_tensor = 1 - mask_tensor
+ mask_tensor = mask_tensor.expand_as(x)
+ if self.offset:
+ offset = (
+ torch.from_numpy(2 * (np.random.rand(h, w) - 0.5))
+ .to(x.dtype)
+ .to(device)
+ )
+ x = x * mask_tensor + offset * (1 - mask_tensor)
+ else:
+ x = x * mask_tensor
+
+ return x.view(n, c, h, w)
diff --git a/vis4d/op/detect3d/bevformer/spatial_cross_attention.py b/vis4d/op/detect3d/bevformer/spatial_cross_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..1840d797af164bb254d53bddea3e5b347361ecdc
--- /dev/null
+++ b/vis4d/op/detect3d/bevformer/spatial_cross_attention.py
@@ -0,0 +1,375 @@
+"""Spatial Cross Attention Module for BEVFormer."""
+
+from __future__ import annotations
+
+import math
+
+import torch
+from torch import Tensor, nn
+
+from vis4d.op.layer.ms_deform_attn import (
+ MSDeformAttentionFunction,
+ is_power_of_2,
+ ms_deformable_attention_cpu,
+)
+from vis4d.op.layer.weight_init import constant_init, xavier_init
+
+
+class SpatialCrossAttention(nn.Module):
+ """An attention module used in BEVFormer."""
+
+ def __init__(
+ self,
+ embed_dims: int = 256,
+ num_cams: int = 6,
+ dropout: float = 0.1,
+ deformable_attention: MSDeformableAttention3D | None = None,
+ ) -> None:
+ """Init.
+
+ Args:
+ embed_dims (int): The embedding dimension of Attention. Default:
+ 256.
+ num_cams (int): The number of cameras. Default: 6.
+ dropout (float): A Dropout layer on `inp_residual`. Default: 0.1.
+ deformable_attention (MSDeformableAttention3D, optional):
+ The deformable attention module. Default: None. If None,
+ we will use `MSDeformableAttention3D` with default
+ parameters.
+ """
+ super().__init__()
+ self.dropout = nn.Dropout(dropout)
+ self.deformable_attention = (
+ deformable_attention or MSDeformableAttention3D()
+ )
+ self.embed_dims = embed_dims
+ self.num_cams = num_cams
+ self.output_proj = nn.Linear(embed_dims, embed_dims)
+ self.init_weight()
+
+ def init_weight(self) -> None:
+ """Default initialization for Parameters of Module."""
+ xavier_init(self.output_proj, distribution="uniform", bias=0.0)
+
+ def forward(
+ self,
+ query: Tensor,
+ reference_points: Tensor,
+ value: Tensor,
+ spatial_shapes: Tensor,
+ level_start_index: Tensor,
+ bev_mask: Tensor,
+ query_pos: Tensor | None = None,
+ ) -> Tensor:
+ """Forward Function of Detr3DCrossAtten.
+
+ Args:
+ query (Tensor): Query of Transformer with shape
+ (num_query, bs, embed_dims).
+ reference_points (Tensor): The normalized reference points with
+ shape (bs, num_query, 4), all elements is range in [0, 1],
+ top-left (0,0), bottom-right (1, 1), including padding area.
+ Or (N, Length_{query}, num_levels, 4), add additional two
+ dimensions is (w, h) to form reference boxes.
+ value (Tensor): The value tensor with shape `(num_key, bs,
+ embed_dims)`. (B, N, C, H, W)
+ spatial_shapes (Tensor): Spatial shape of features in different
+ level. With shape (num_levels, 2), last dimension represent
+ (h, w).
+ level_start_index (Tensor): The start index of each level.
+ A tensor has shape (num_levels) and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+ bev_mask (Tensor): The mask of BEV features with shape
+ (num_query, bs, num_levels, h, w).
+ query_pos (Tensor): The positional encoding for `query`. Default
+ None.
+
+ Returns:
+ Tensor: Forwarded results with shape [num_query, bs, embed_dims].
+ """
+ inp_residual = query
+ slots = torch.zeros_like(query)
+
+ if query_pos is not None:
+ query = query + query_pos
+
+ bs = query.shape[0]
+ d = reference_points.shape[3]
+
+ indexes = []
+ for i, mask_per_img in enumerate(bev_mask):
+ index_query_per_img = mask_per_img[0].sum(-1).nonzero().squeeze(-1)
+ indexes.append(index_query_per_img)
+ max_len = max(len(each) for each in indexes)
+
+ # Each camera only interacts with its corresponding BEV queries.
+ # This step can greatly save GPU memory.
+ queries_rebatch = query.new_zeros(
+ [bs, self.num_cams, max_len, self.embed_dims]
+ )
+ reference_points_rebatch = reference_points.new_zeros(
+ [bs, self.num_cams, max_len, d, 2]
+ )
+
+ for j in range(bs):
+ for i, _reference_points in enumerate(reference_points):
+ index_query_per_img = indexes[i]
+ queries_rebatch[j, i, : len(index_query_per_img)] = query[
+ j, index_query_per_img
+ ]
+ reference_points_rebatch[j, i, : len(index_query_per_img)] = (
+ _reference_points[j, index_query_per_img]
+ )
+
+ _, l, bs, _ = value.shape
+
+ value = value.permute(2, 0, 1, 3).reshape(
+ bs * self.num_cams, l, self.embed_dims
+ )
+
+ queries = self.deformable_attention(
+ query=queries_rebatch.view(
+ bs * self.num_cams, max_len, self.embed_dims
+ ),
+ reference_points=reference_points_rebatch.view(
+ bs * self.num_cams, max_len, d, 2
+ ),
+ value=value,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ ).view(bs, self.num_cams, max_len, self.embed_dims)
+
+ for j in range(bs):
+ for i, index_query_per_img in enumerate(indexes):
+ slots[j, index_query_per_img] += queries[
+ j, i, : len(index_query_per_img)
+ ]
+
+ count = bev_mask.sum(-1) > 0
+ count = count.permute(1, 2, 0).sum(-1)
+ count = torch.clamp(count, min=1.0)
+ slots = slots / count[..., None]
+ slots = self.output_proj(slots)
+
+ return self.dropout(slots) + inp_residual
+
+
+class MSDeformableAttention3D(nn.Module):
+ """An attention module used in BEVFormer based on Deformable-Detr.
+
+ `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
+ `_.
+ """
+
+ def __init__(
+ self,
+ embed_dims: int = 256,
+ num_heads: int = 8,
+ num_levels: int = 4,
+ num_points: int = 8,
+ im2col_step: int = 64,
+ batch_first: bool = True,
+ ) -> None:
+ """Init.
+
+ Args:
+ embed_dims (int): The embedding dimension of Attention. Default:
+ 256.
+ num_heads (int): Parallel attention heads. Default: 64.
+ num_levels (int): The number of feature map used in
+ Attention. Default: 4.
+ num_points (int): The number of sampling points for each query in
+ each head. Default: 4.
+ im2col_step (int): The step used in image_to_column.
+ Default: 64.
+ batch_first (bool): Key, Query and Value are shape of (batch, n,
+ embed_dim) or (n, batch, embed_dim). Default to True.
+ """
+ super().__init__()
+ if embed_dims % num_heads != 0:
+ raise ValueError(
+ f"embed_dims must be divisible by num_heads, "
+ f"but got {embed_dims} and {num_heads}"
+ )
+
+ self.batch_first = batch_first
+
+ is_power_of_2(embed_dims // num_heads)
+
+ self.im2col_step = im2col_step
+ self.embed_dims = embed_dims
+ self.num_levels = num_levels
+ self.num_heads = num_heads
+ self.num_points = num_points
+ self.sampling_offsets = nn.Linear(
+ embed_dims, num_heads * num_levels * num_points * 2
+ )
+ self.attention_weights = nn.Linear(
+ embed_dims, num_heads * num_levels * num_points
+ )
+ self.value_proj = nn.Linear(embed_dims, embed_dims)
+
+ self.init_weights()
+
+ def init_weights(self) -> None:
+ """Default initialization for Parameters of Module."""
+ constant_init(self.sampling_offsets, 0.0)
+ thetas = torch.mul(
+ torch.arange(self.num_heads, dtype=torch.float32),
+ (2.0 * math.pi / self.num_heads),
+ )
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+ .view(self.num_heads, 1, 1, 2)
+ .repeat(1, self.num_levels, self.num_points, 1)
+ )
+ for i in range(self.num_points):
+ grid_init[:, :, i, :] *= i + 1
+
+ self.sampling_offsets.bias.data = grid_init.view(-1)
+ constant_init(self.attention_weights, val=0.0, bias=0.0)
+ xavier_init(self.value_proj, distribution="uniform", bias=0.0)
+
+ def forward( # pylint: disable=duplicate-code
+ self,
+ query: Tensor,
+ reference_points: Tensor,
+ value: Tensor,
+ spatial_shapes: Tensor,
+ level_start_index: Tensor,
+ key_padding_mask: Tensor | None = None,
+ query_pos: Tensor | None = None,
+ ) -> Tensor:
+ """Forward.
+
+ Args:
+ query (Tensor): Query of Transformer with shape (bs, num_query,
+ embed_dims).
+ reference_points (Tensor): The normalized reference points with
+ shape (bs, num_query, num_levels, 2), all elements is range in
+ [0, 1], top-left (0,0), bottom-right (1, 1), including padding
+ area. Or (N, Length_{query}, num_levels, 4), add additional two
+ dimensions is (w, h) to form reference boxes.
+ value (Tensor): The value tensor with shape `(bs, num_key,
+ embed_dims)`.
+ spatial_shapes (Tensor): Spatial shape of features in different
+ levels. With shape (num_levels, 2), last dimension represents
+ (h, w).
+ level_start_index (Tensor): The start index of each level. A tensor
+ has shape ``(num_levels, )`` and can be represented as [0,
+ h_0*w_0, h_0*w_0+h_1*w_1, ...].
+ key_padding_mask (Tensor): ByteTensor for value, with shape [bs,
+ num_key].
+ query_pos (Tensor): The positional encoding for `query`.
+ Default: None.
+
+ Returns:
+ Tensor: forwarded results with shape [num_query, bs, embed_dims].
+ """
+ if query_pos is not None:
+ query = query + query_pos
+
+ if not self.batch_first:
+ # change to (bs, num_query ,embed_dims)
+ query = query.permute(1, 0, 2)
+ value = value.permute(1, 0, 2)
+
+ bs, num_query, _ = query.shape
+ bs, num_value, _ = value.shape
+ assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
+
+ value = self.value_proj(value)
+ if key_padding_mask is not None:
+ value = value.masked_fill(key_padding_mask[..., None], 0.0)
+ value = value.view(bs, num_value, self.num_heads, -1)
+
+ sampling_offsets = self.sampling_offsets(query).view(
+ bs, num_query, self.num_heads, self.num_levels, self.num_points, 2
+ )
+
+ attention_weights = self.attention_weights(query).view(
+ bs, num_query, self.num_heads, self.num_levels * self.num_points
+ )
+
+ attention_weights = attention_weights.softmax(-1)
+
+ # bs, num_query, num_heads, num_levels, num_all_points
+ attention_weights = attention_weights.view(
+ bs, num_query, self.num_heads, self.num_levels, self.num_points
+ )
+
+ # For each BEV query, it owns `num_z_anchors` in 3D space that
+ # having different heights. After proejcting, each BEV query has
+ # `num_z_anchors` reference points in each 2D image. For each
+ # referent point, we sample `num_points` sampling points.
+ # For `num_z_anchors` reference points, it has overall `num_points
+ # * num_z_anchors` sampling points.
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.stack(
+ [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1
+ )
+
+ bs, num_query, num_z_anchors, xy = reference_points.shape
+ reference_points = reference_points[:, :, None, None, None, :, :]
+ sampling_offsets = (
+ sampling_offsets
+ / offset_normalizer[None, None, None, :, None, :]
+ )
+ (
+ bs,
+ num_query,
+ num_heads,
+ num_levels,
+ num_all_points,
+ xy,
+ ) = sampling_offsets.shape
+ sampling_offsets = sampling_offsets.view(
+ bs,
+ num_query,
+ num_heads,
+ num_levels,
+ num_all_points // num_z_anchors,
+ num_z_anchors,
+ xy,
+ )
+ sampling_locations = reference_points + sampling_offsets
+ (
+ bs,
+ num_query,
+ num_heads,
+ num_levels,
+ num_points,
+ num_z_anchors,
+ xy,
+ ) = sampling_locations.shape
+ assert num_all_points == num_points * num_z_anchors
+
+ # bs, num_query, num_heads, num_levels, num_all_points, 2
+ sampling_locations = sampling_locations.view(
+ bs, num_query, num_heads, num_levels, num_all_points, xy
+ )
+ else:
+ raise ValueError(
+ "Last dim of reference_points must be 2 , but get "
+ + f"{reference_points.shape[-1]} instead."
+ )
+
+ if torch.cuda.is_available() and value.is_cuda:
+ output = MSDeformAttentionFunction.apply(
+ value,
+ spatial_shapes,
+ level_start_index,
+ sampling_locations,
+ attention_weights,
+ self.im2col_step,
+ )
+ else:
+ output = ms_deformable_attention_cpu(
+ value, spatial_shapes, sampling_locations, attention_weights
+ )
+
+ if not self.batch_first:
+ output = output.permute(1, 0, 2)
+
+ return output
diff --git a/vis4d/op/detect3d/bevformer/temporal_self_attention.py b/vis4d/op/detect3d/bevformer/temporal_self_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..00fafdead8979eb39386b53c4fa2520dab174433
--- /dev/null
+++ b/vis4d/op/detect3d/bevformer/temporal_self_attention.py
@@ -0,0 +1,285 @@
+"""An attention module used in BEVFormer based on Deformable-Detr."""
+
+from __future__ import annotations
+
+import math
+
+import torch
+from torch import Tensor, nn
+
+from vis4d.op.layer.ms_deform_attn import (
+ MSDeformAttentionFunction,
+ is_power_of_2,
+ ms_deformable_attention_cpu,
+)
+from vis4d.op.layer.weight_init import constant_init, xavier_init
+
+
+class TemporalSelfAttention(nn.Module):
+ """Temperal Self Attention."""
+
+ def __init__(
+ self,
+ embed_dims: int = 256,
+ num_heads: int = 8,
+ num_levels: int = 4,
+ num_points: int = 4,
+ num_bev_queue: int = 2,
+ im2col_step: int = 64,
+ dropout: float = 0.1,
+ batch_first: bool = True,
+ ) -> None:
+ """Init.
+
+ Args:
+ embed_dims (int): The embedding dimension of Attention. Default:
+ 256.
+ num_heads (int): Parallel attention heads. Default: 64.
+ num_levels (int): The number of feature map used in Attention.
+ Default: 4.
+ num_points (int): The number of sampling points for each query in
+ each head. Default: 4.
+ num_bev_queue (int): In this version, we only use one history BEV
+ and one currenct BEV. The length of BEV queue is 2.
+ im2col_step (int): The step used in image_to_column. Default: 64.
+ dropout (float): A Dropout layer on `inp_identity`. Default: 0.1.
+ batch_first (bool): Key, Query and Value are shape of (batch, n,
+ embed_dim) or (n, batch, embed_dim). Default to True.
+ """
+ super().__init__()
+ if embed_dims % num_heads != 0:
+ raise ValueError(
+ f"embed_dims must be divisible by num_heads, "
+ f"but got {embed_dims} and {num_heads}"
+ )
+
+ is_power_of_2(embed_dims // num_heads)
+
+ self.dropout = nn.Dropout(dropout)
+ self.batch_first = batch_first
+
+ self.im2col_step = im2col_step
+ self.embed_dims = embed_dims
+ self.num_levels = num_levels
+ self.num_heads = num_heads
+ self.num_points = num_points
+ self.num_bev_queue = num_bev_queue
+ self.sampling_offsets = nn.Linear(
+ embed_dims * self.num_bev_queue,
+ num_bev_queue * num_heads * num_levels * num_points * 2,
+ )
+ self.attention_weights = nn.Linear(
+ embed_dims * self.num_bev_queue,
+ num_bev_queue * num_heads * num_levels * num_points,
+ )
+ self.value_proj = nn.Linear(embed_dims, embed_dims)
+ self.output_proj = nn.Linear(embed_dims, embed_dims)
+ self.init_weights()
+
+ def init_weights(self) -> None:
+ """Default initialization for Parameters of Module."""
+ constant_init(self.sampling_offsets, 0.0)
+ thetas = torch.mul(
+ torch.arange(self.num_heads, dtype=torch.float32),
+ (2.0 * math.pi / self.num_heads),
+ )
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+ .view(self.num_heads, 1, 1, 2)
+ .repeat(
+ 1, self.num_levels * self.num_bev_queue, self.num_points, 1
+ )
+ )
+
+ for i in range(self.num_points):
+ grid_init[:, :, i, :] *= i + 1
+
+ self.sampling_offsets.bias.data = grid_init.view(-1)
+ constant_init(self.attention_weights, val=0.0, bias=0.0)
+ xavier_init(self.value_proj, distribution="uniform", bias=0.0)
+ xavier_init(self.output_proj, distribution="uniform", bias=0.0)
+
+ def forward(
+ self,
+ query: Tensor,
+ reference_points: Tensor,
+ value: Tensor | None,
+ spatial_shapes: Tensor,
+ level_start_index: Tensor,
+ key_padding_mask: Tensor | None = None,
+ identity: Tensor | None = None,
+ query_pos: Tensor | None = None,
+ ) -> Tensor:
+ """Forward Function of MultiScaleDeformAttention.
+
+ Args:
+ query (Tensor): Query of Transformer with shape (num_query, bs,
+ embed_dims).
+ reference_points (Tensor): The normalized reference points with
+ shape (bs, num_query, num_levels, 2), all elements is range in
+ [0, 1], top-left (0,0), bottom-right (1, 1), including padding
+ area. or (N, Length_{query}, num_levels, 4), add additional two
+ dimensions is (w, h) to form reference boxes.
+ value (Tensor): The value tensor with shape (num_key, bs,
+ embed_dims).
+ spatial_shapes (Tensor): Spatial shape of features in different
+ levels. With shape (num_levels, 2), last dimension represents
+ (h, w).
+ level_start_index (Tensor): The start index of each level.
+ A tensor has shape ``(num_levels, )`` and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+ key_padding_mask (Tensor): ByteTensor for value, with shape [bs,
+ num_key].
+ identity (Tensor): The tensor used for addition, with the
+ same shape as query. Default None. If None, query will be used.
+ query_pos (Tensor, optional): The positional encoding for query.
+ Default: None.
+
+ Returns:
+ Tensor: forwarded results with shape [num_query, bs, embed_dims].
+ """
+ if value is None:
+ assert self.batch_first
+ bs, len_bev, c = query.shape
+ value = torch.stack([query, query], 1).reshape(bs * 2, len_bev, c)
+
+ if identity is None:
+ identity = query
+
+ if query_pos is not None:
+ query = query + query_pos
+
+ if not self.batch_first:
+ # change to (bs, num_query ,embed_dims)
+ query = query.permute(1, 0, 2)
+ value = value.permute(1, 0, 2)
+
+ bs, num_query, embed_dims = query.shape
+ _, num_value, _ = value.shape
+ assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
+ assert self.num_bev_queue == 2
+
+ query = torch.cat([value[:bs], query], -1)
+ value = self.value_proj(value)
+ assert isinstance(value, Tensor)
+
+ if key_padding_mask is not None:
+ value = value.masked_fill(key_padding_mask[..., None], 0.0)
+
+ value = value.reshape(
+ bs * self.num_bev_queue, num_value, self.num_heads, -1
+ )
+
+ sampling_offsets = self.sampling_offsets(query)
+ sampling_offsets = sampling_offsets.view(
+ bs,
+ num_query,
+ self.num_heads,
+ self.num_bev_queue,
+ self.num_levels,
+ self.num_points,
+ 2,
+ )
+ attention_weights = self.attention_weights(query).view(
+ bs,
+ num_query,
+ self.num_heads,
+ self.num_bev_queue,
+ self.num_levels * self.num_points,
+ )
+ attention_weights = attention_weights.softmax(-1)
+
+ attention_weights = attention_weights.view(
+ bs,
+ num_query,
+ self.num_heads,
+ self.num_bev_queue,
+ self.num_levels,
+ self.num_points,
+ )
+
+ attention_weights = (
+ attention_weights.permute(0, 3, 1, 2, 4, 5)
+ .reshape(
+ bs * self.num_bev_queue,
+ num_query,
+ self.num_heads,
+ self.num_levels,
+ self.num_points,
+ )
+ .contiguous()
+ )
+
+ sampling_offsets = sampling_offsets.permute(
+ 0, 3, 1, 2, 4, 5, 6
+ ).reshape(
+ bs * self.num_bev_queue,
+ num_query,
+ self.num_heads,
+ self.num_levels,
+ self.num_points,
+ 2,
+ )
+
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.stack(
+ [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1
+ )
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :]
+ + sampling_offsets
+ / offset_normalizer[None, None, None, :, None, :]
+ )
+
+ elif reference_points.shape[-1] == 4:
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :2]
+ + sampling_offsets
+ / self.num_points
+ * reference_points[:, :, None, :, None, 2:]
+ * 0.5
+ )
+ else:
+ raise ValueError(
+ f"Last dim of reference_points must be"
+ f" 2 or 4, but get {reference_points.shape[-1]} instead."
+ )
+
+ if torch.cuda.is_available() and value.is_cuda:
+ output = MSDeformAttentionFunction.apply(
+ value,
+ spatial_shapes,
+ level_start_index,
+ sampling_locations,
+ attention_weights,
+ self.im2col_step,
+ )
+ else:
+ output = ms_deformable_attention_cpu(
+ value,
+ spatial_shapes,
+ sampling_locations,
+ attention_weights,
+ )
+
+ # output shape (bs*num_bev_queue, num_query, embed_dims)
+ # (bs*num_bev_queue, num_query, embed_dims)
+ # -> (num_query, embed_dims, bs*num_bev_queue)
+ output = output.permute(1, 2, 0)
+
+ # fuse history value and current value
+ # (num_query, embed_dims, bs*num_bev_queue)
+ # -> (num_query, embed_dims, bs, num_bev_queue)
+ output = output.view(num_query, embed_dims, bs, self.num_bev_queue)
+ output = output.mean(-1)
+
+ # (num_query, embed_dims, bs)-> (bs, num_query, embed_dims)
+ output = output.permute(2, 0, 1)
+
+ output = self.output_proj(output)
+
+ if not self.batch_first:
+ output = output.permute(1, 0, 2)
+
+ return self.dropout(output) + identity
diff --git a/vis4d/op/detect3d/bevformer/transformer.py b/vis4d/op/detect3d/bevformer/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..489c713f3dc141467daa637d497e8dc33e7d5c09
--- /dev/null
+++ b/vis4d/op/detect3d/bevformer/transformer.py
@@ -0,0 +1,271 @@
+"""BEVFormer transformer."""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+from torch import Tensor, nn
+from torchvision.transforms.functional import rotate
+
+from vis4d.op.layer.weight_init import xavier_init
+
+from .decoder import BEVFormerDecoder
+from .encoder import BEVFormerEncoder
+
+
+class PerceptionTransformer(nn.Module):
+ """Perception Transformer."""
+
+ def __init__(
+ self,
+ num_cams: int = 6,
+ encoder: BEVFormerEncoder | None = None,
+ decoder: BEVFormerDecoder | None = None,
+ embed_dims: int = 256,
+ num_feature_levels: int = 4,
+ rotate_center: tuple[int, int] = (100, 100),
+ ) -> None:
+ """Init."""
+ super().__init__()
+ self.num_cams = num_cams
+ self.embed_dims = embed_dims
+ self.num_feature_levels = num_feature_levels
+ self.rotate_center = list(rotate_center)
+
+ self.encoder = encoder or BEVFormerEncoder(embed_dims=self.embed_dims)
+ self.decoder = decoder or BEVFormerDecoder(embed_dims=self.embed_dims)
+
+ self._init_layers()
+ self._init_weights()
+
+ def _init_layers(self) -> None:
+ """Initialize layers of the Detr3DTransformer."""
+ self.level_embeds = nn.Parameter(
+ torch.Tensor(self.num_feature_levels, self.embed_dims)
+ )
+ self.cams_embeds = nn.Parameter(
+ torch.Tensor(self.num_cams, self.embed_dims)
+ )
+ self.reference_points = nn.Linear(self.embed_dims, 3)
+
+ self.can_bus_mlp = nn.Sequential(
+ nn.Linear(18, self.embed_dims // 2),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.embed_dims // 2, self.embed_dims),
+ nn.ReLU(inplace=True),
+ )
+ self.can_bus_mlp.add_module("norm", nn.LayerNorm(self.embed_dims))
+
+ def _init_weights(self) -> None:
+ """Initialize the transformer weights."""
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ nn.init.normal_(self.level_embeds)
+ nn.init.normal_(self.cams_embeds)
+ xavier_init(self.reference_points, distribution="uniform", bias=0.0)
+ xavier_init(self.can_bus_mlp, distribution="uniform", bias=0.0)
+
+ def get_bev_features(
+ self,
+ mlvl_feats: list[Tensor],
+ can_bus: Tensor,
+ bev_queries: Tensor,
+ bev_h: int,
+ bev_w: int,
+ images_hw: tuple[int, int],
+ cam_intrinsics: list[Tensor],
+ cam_extrinsics: list[Tensor],
+ lidar_extrinsics: Tensor,
+ grid_length: tuple[float, float],
+ bev_pos: Tensor,
+ prev_bev: Tensor | None = None,
+ ) -> Tensor:
+ """Obtain bev features."""
+ batch_size = mlvl_feats[0].shape[0]
+ bev_queries = bev_queries.unsqueeze(1).repeat(1, batch_size, 1)
+ bev_pos = bev_pos.flatten(2).permute(2, 0, 1)
+
+ # obtain rotation angle and shift with ego motion
+ delta_x = can_bus[:, 0].unsqueeze(1)
+ delta_y = can_bus[:, 1].unsqueeze(1)
+ ego_angle = can_bus[:, -2] / np.pi * 180
+
+ translation_length = torch.sqrt(delta_x**2 + delta_y**2)
+ translation_angle = torch.arctan2(delta_y, delta_x) / np.pi * 180
+ bev_angle = ego_angle - translation_angle
+
+ shift_y = (
+ translation_length
+ * torch.cos(bev_angle / 180 * np.pi)
+ / grid_length[0]
+ / bev_h
+ )
+ shift_x = (
+ translation_length
+ * torch.sin(bev_angle / 180 * np.pi)
+ / grid_length[1]
+ / bev_w
+ )
+
+ # B, xy
+ shift = torch.cat([shift_x, shift_y], dim=1)
+
+ if prev_bev is not None:
+ if prev_bev.shape[1] == bev_h * bev_w:
+ prev_bev = prev_bev.permute(1, 0, 2)
+
+ # rotate prev_bev
+ for i in range(batch_size):
+ rotation_angle = float(can_bus[i][-1])
+ tmp_prev_bev = (
+ prev_bev[:, i].reshape(bev_h, bev_w, -1).permute(2, 0, 1)
+ )
+ tmp_prev_bev = rotate(
+ tmp_prev_bev, rotation_angle, center=self.rotate_center
+ )
+ tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
+ bev_h * bev_w, 1, -1
+ )
+ prev_bev[:, i] = tmp_prev_bev[:, 0]
+
+ # add can bus signals
+ bev_queries = bev_queries + self.can_bus_mlp(can_bus)[None, :, :]
+
+ feat_flatten_list = []
+ spatial_shapes_list = []
+ for lvl, feat in enumerate(mlvl_feats):
+ spatial_shape = feat.shape[-2:]
+ feat = feat.flatten(3).permute(1, 0, 3, 2)
+
+ # Add cams_embeds and level_embeds
+ feat += self.cams_embeds[:, None, None, :].to(feat.dtype)
+ feat += self.level_embeds[None, None, lvl : lvl + 1, :].to(
+ feat.dtype
+ )
+
+ spatial_shapes_list.append(spatial_shape)
+ feat_flatten_list.append(feat)
+
+ feat_flatten = torch.cat(feat_flatten_list, 2)
+ spatial_shapes = torch.as_tensor(
+ spatial_shapes_list, dtype=torch.long, device=bev_pos.device
+ )
+ level_start_index = torch.cat(
+ (
+ spatial_shapes.new_zeros((1,)),
+ spatial_shapes.prod(1).cumsum(0)[:-1],
+ )
+ )
+
+ # (num_cam, H*W, bs, embed_dims)
+ feat_flatten = feat_flatten.permute(0, 2, 1, 3)
+
+ bev_embed = self.encoder(
+ bev_queries,
+ feat_flatten,
+ bev_h=bev_h,
+ bev_w=bev_w,
+ bev_pos=bev_pos,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ prev_bev=prev_bev,
+ shift=shift,
+ images_hw=images_hw,
+ cam_intrinsics=cam_intrinsics,
+ cam_extrinsics=cam_extrinsics,
+ lidar_extrinsics=lidar_extrinsics,
+ )
+ return bev_embed
+
+ def forward(
+ self,
+ mlvl_feats: list[Tensor],
+ can_bus: Tensor,
+ bev_queries: Tensor,
+ object_query_embed: Tensor,
+ bev_h: int,
+ bev_w: int,
+ images_hw: tuple[int, int],
+ cam_intrinsics: list[Tensor],
+ cam_extrinsics: list[Tensor],
+ lidar_extrinsics: Tensor,
+ grid_length: tuple[float, float],
+ bev_pos: Tensor,
+ reg_branches: list[nn.Module],
+ prev_bev: Tensor | None = None,
+ ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
+ """Forward function for BEVFormer transformer.
+
+ Args:
+ mlvl_feats (list(Tensor)): Input queries from different level. Each
+ element has shape [bs, num_cams, embed_dims, h, w].
+ can_bus (Tensor): The can bus signals, has shape [bs, 18].
+ bev_queries (Tensor): (bev_h * bev_w, embed_dims).
+ object_query_embed (Tensor): The query embedding for decoder,
+ with shape [num_query, embed_dims * 2].
+ bev_h (int): The height of BEV feature map.
+ bev_w (int): The width of BEV feature map.
+ images_hw (tuple[int, int]): The height and width of images.
+ cam_intrinsics (list[Tensor]): The camera intrinsics.
+ cam_extrinsics (list[Tensor]): The camera extrinsics.
+ lidar_extrinsics (Tensor): The lidar extrinsics.
+ grid_length (tuple[float, float]): The length of grid in x and y
+ direction.
+ bev_pos (Tensor): (bs, embed_dims, bev_h, bev_w)
+ reg_branches (list[nn.Module]): Regression heads for feature maps
+ from each decoder layer.
+ prev_bev (Tensor, optional): The previous BEV feature map, has
+ shape [bev_h * bev_w, bs, embed_dims]. Defaults to None.
+
+ Returns:
+ bev_embed (Tensor): BEV features has shape [bev_h *bev_w, bs,
+ embed_dims].
+ inter_states: Outputs from decoder has shape [1, bs, num_query,
+ embed_dims].
+ reference_points: As the initial reference has shape [bs,
+ num_queries, 4].
+ inter_references: The internal value of reference points in the
+ decoder, has shape [num_dec_layers, bs,num_query, embed_dims].
+ """
+ # bs, bev_h*bev_w, embed_dims
+ bev_embed = self.get_bev_features(
+ mlvl_feats,
+ can_bus,
+ bev_queries,
+ bev_h,
+ bev_w,
+ images_hw=images_hw,
+ cam_intrinsics=cam_intrinsics,
+ cam_extrinsics=cam_extrinsics,
+ lidar_extrinsics=lidar_extrinsics,
+ grid_length=grid_length,
+ bev_pos=bev_pos,
+ prev_bev=prev_bev,
+ )
+
+ bs = mlvl_feats[0].shape[0]
+ query_pos, query = torch.split(
+ object_query_embed, self.embed_dims, dim=1
+ )
+ query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
+ query = query.unsqueeze(0).expand(bs, -1, -1)
+ reference_points = self.reference_points(query_pos)
+ reference_points = reference_points.sigmoid()
+
+ query = query.permute(1, 0, 2)
+ query_pos = query_pos.permute(1, 0, 2)
+ bev_embed = bev_embed.permute(1, 0, 2)
+
+ inter_states, inter_references = self.decoder(
+ query=query,
+ value=bev_embed,
+ reference_points=reference_points,
+ spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
+ level_start_index=torch.tensor([0], device=query.device),
+ query_pos=query_pos,
+ reg_branches=reg_branches,
+ )
+
+ return bev_embed, inter_states, reference_points, inter_references
diff --git a/vis4d/op/detect3d/common.py b/vis4d/op/detect3d/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..33003af712f87f544b8e9c4689bc45281fa0f84b
--- /dev/null
+++ b/vis4d/op/detect3d/common.py
@@ -0,0 +1,23 @@
+"""Common classes and functions for 3D detection."""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+from torch import Tensor
+
+
+class Detect3DOut(NamedTuple):
+ """Output of detect 3D model.
+
+ Attributes:
+ boxes_3d (list[Tensor]): List of bounding boxes (B, N, 10).
+ velocities (list[Tensor]): List of velocities (B, N, 3).
+ class_ids (list[Tensor]): List of class ids (B, N).
+ scores_3d (list[Tensor]): List of scores (B, N).
+ """
+
+ boxes_3d: list[Tensor]
+ velocities: list[Tensor]
+ class_ids: list[Tensor]
+ scores_3d: list[Tensor]
diff --git a/vis4d/op/detect3d/qd_3dt.py b/vis4d/op/detect3d/qd_3dt.py
new file mode 100644
index 0000000000000000000000000000000000000000..d92370b8e38d5861c38d15e220767682a7784d79
--- /dev/null
+++ b/vis4d/op/detect3d/qd_3dt.py
@@ -0,0 +1,699 @@
+"""QD-3DT detector."""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+import numpy as np
+import torch
+from torch import Tensor, nn
+
+from vis4d.common.typing import LossesType
+from vis4d.op.box.encoder.qd_3dt import QD3DTBox3DDecoder, QD3DTBox3DEncoder
+from vis4d.op.box.matchers import Matcher, MaxIoUMatcher
+from vis4d.op.box.poolers import MultiScaleRoIAlign, MultiScaleRoIPooler
+from vis4d.op.box.samplers import (
+ CombinedSampler,
+ Sampler,
+ match_and_sample_proposals,
+)
+from vis4d.op.geometry.rotation import generate_rotation_output
+from vis4d.op.layer.conv2d import Conv2d, add_conv_branch
+from vis4d.op.layer.weight_init import kaiming_init, xavier_init
+from vis4d.op.loss.base import Loss
+from vis4d.op.loss.common import rotation_loss, smooth_l1_loss
+from vis4d.op.loss.reducer import LossReducer, SumWeightedLoss, mean_loss
+
+
+class QD3DTBBox3DHeadOutput(NamedTuple):
+ """QD-3DT bounding box 3D head training output."""
+
+ predictions: list[Tensor]
+ targets: Tensor | None
+ labels: Tensor | None
+
+
+class QD3DTDet3DOut(NamedTuple):
+ """Output of QD-3DT bounding box 3D head.
+
+ Attributes:
+ boxes_3d (list[Tensor]): Predicted 3D bounding boxes. Each tensor has
+ shape (N, 12) and contains x,y,z,h,w,l,rx,ry,rz,vx,vy,vz.
+ depth_uncertainty (list[Tensor]): Predicted depth uncertainty. Each
+ tensor has shape (N, 1).
+ """
+
+ boxes_3d: list[Tensor]
+ depth_uncertainty: list[Tensor]
+
+
+def get_default_proposal_pooler() -> MultiScaleRoIAlign:
+ """Get default proposal pooler of QD-3DT bounding box 3D head."""
+ return MultiScaleRoIAlign(
+ resolution=[7, 7], strides=[4, 8, 16, 32], sampling_ratio=0
+ )
+
+
+def get_default_box_sampler() -> CombinedSampler:
+ """Get default box sampler of QD-3DT bounding box 3D head."""
+ return CombinedSampler(
+ batch_size=512,
+ positive_fraction=0.25,
+ pos_strategy="instance_balanced",
+ neg_strategy="iou_balanced",
+ )
+
+
+def get_default_box_matcher() -> MaxIoUMatcher:
+ """Get default box matcher of QD-3DT bounding box 3D head."""
+ return MaxIoUMatcher(
+ thresholds=[0.5, 0.5],
+ labels=[0, -1, 1],
+ allow_low_quality_matches=False,
+ )
+
+
+def get_default_box_codec(
+ center_scale: float = 10.0,
+ depth_log_scale: float = 2.0,
+ dim_log_scale: float = 2.0,
+ num_rotation_bins: int = 2,
+ bin_overlap: float = 1 / 6,
+) -> tuple[QD3DTBox3DEncoder, QD3DTBox3DDecoder]:
+ """Get the default bounding box encoder and decoder."""
+ return (
+ QD3DTBox3DEncoder(
+ center_scale=center_scale,
+ depth_log_scale=depth_log_scale,
+ dim_log_scale=dim_log_scale,
+ num_rotation_bins=num_rotation_bins,
+ bin_overlap=bin_overlap,
+ ),
+ QD3DTBox3DDecoder(
+ center_scale=center_scale,
+ depth_log_scale=depth_log_scale,
+ dim_log_scale=dim_log_scale,
+ num_rotation_bins=num_rotation_bins,
+ ),
+ )
+
+
+class QD3DTBBox3DHead(nn.Module):
+ """This class implements the QD-3DT bounding box 3D head."""
+
+ def __init__( # pylint: disable=too-many-arguments, too-many-positional-arguments, line-too-long
+ self,
+ num_classes: int,
+ proposal_pooler: None | MultiScaleRoIPooler = None,
+ box_matcher: None | Matcher = None,
+ box_sampler: None | Sampler = None,
+ box_encoder: None | QD3DTBox3DEncoder = None,
+ proposal_append_gt: bool = True,
+ num_shared_convs: int = 2,
+ num_shared_fcs: int = 0,
+ num_dep_convs: int = 4,
+ num_dep_fcs: int = 0,
+ num_dim_convs: int = 4,
+ num_dim_fcs: int = 0,
+ num_rot_convs: int = 4,
+ num_rot_fcs: int = 0,
+ num_cen_2d_convs: int = 4,
+ num_cen_2d_fcs: int = 0,
+ in_channels: int = 256,
+ conv_out_dim: int = 256,
+ fc_out_dim: int = 1024,
+ roi_feat_size: int = 7,
+ conv_has_bias: bool = True,
+ norm: None | str = None,
+ num_groups: int = 32,
+ num_rotation_bins: int = 2,
+ start_level: int = 2,
+ ):
+ """Initialize the QD-3DT bounding box 3D head."""
+ super().__init__()
+ self.proposal_pooler = (
+ proposal_pooler
+ if proposal_pooler is not None
+ else get_default_proposal_pooler()
+ )
+ self.box_matcher = (
+ box_matcher
+ if box_matcher is not None
+ else get_default_box_matcher()
+ )
+ self.box_sampler = (
+ box_sampler
+ if box_sampler is not None
+ else get_default_box_sampler()
+ )
+ self.box_encoder = (
+ box_encoder if box_encoder is not None else QD3DTBox3DEncoder()
+ )
+ self.num_shared_convs = num_shared_convs
+ self.num_shared_fcs = num_shared_fcs
+ self.num_rotation_bins = num_rotation_bins
+ self.proposal_append_gt = proposal_append_gt
+ self.cls_out_channels = num_classes
+
+ # Used feature layers are [start_level, end_level)
+ self.start_level = start_level
+ num_strides = len(self.proposal_pooler.scales)
+ self.end_level = start_level + num_strides
+
+ # add shared convs and fcs
+ (
+ self.shared_convs,
+ self.shared_fcs,
+ self.shared_out_channels,
+ ) = self._add_conv_fc_branch(
+ num_shared_convs,
+ num_shared_fcs,
+ in_channels,
+ conv_out_dim,
+ fc_out_dim,
+ conv_has_bias,
+ norm,
+ num_groups,
+ True,
+ )
+
+ # add depth specific branch
+ (
+ self.dep_convs,
+ self.dep_fcs,
+ self.dep_last_dim,
+ ) = self._add_conv_fc_branch(
+ num_dep_convs,
+ num_dep_fcs,
+ self.shared_out_channels,
+ conv_out_dim,
+ fc_out_dim,
+ conv_has_bias,
+ norm,
+ num_groups,
+ )
+
+ # add dim specific branch
+ (
+ self.dim_convs,
+ self.dim_fcs,
+ self.dim_last_dim,
+ ) = self._add_conv_fc_branch(
+ num_dim_convs,
+ num_dim_fcs,
+ self.shared_out_channels,
+ conv_out_dim,
+ fc_out_dim,
+ conv_has_bias,
+ norm,
+ num_groups,
+ )
+
+ # add rot specific branch
+ (
+ self.rot_convs,
+ self.rot_fcs,
+ self.rot_last_dim,
+ ) = self._add_conv_fc_branch(
+ num_rot_convs,
+ num_rot_fcs,
+ self.shared_out_channels,
+ conv_out_dim,
+ fc_out_dim,
+ conv_has_bias,
+ norm,
+ num_groups,
+ )
+
+ # add delta 2D center specific branch
+ (
+ self.cen_2d_convs,
+ self.cen_2d_fcs,
+ self.cen_2d_last_dim,
+ ) = self._add_conv_fc_branch(
+ num_cen_2d_convs,
+ num_cen_2d_fcs,
+ self.shared_out_channels,
+ conv_out_dim,
+ fc_out_dim,
+ conv_has_bias,
+ norm,
+ num_groups,
+ )
+
+ if num_shared_fcs == 0:
+ if num_dep_fcs == 0:
+ self.dep_last_dim *= roi_feat_size * roi_feat_size
+ if num_dim_fcs == 0:
+ self.dim_last_dim *= roi_feat_size * roi_feat_size
+ if num_rot_fcs == 0:
+ self.rot_last_dim *= roi_feat_size * roi_feat_size
+ if num_cen_2d_fcs == 0:
+ self.cen_2d_last_dim *= roi_feat_size * roi_feat_size
+
+ self.relu = nn.ReLU(inplace=True)
+ # reconstruct fc_cls and fc_reg since input channels are changed
+ out_dim_dep = self.cls_out_channels
+ self.fc_dep = nn.Linear(self.dep_last_dim, out_dim_dep)
+
+ self.fc_dep_uncer = nn.Linear(self.dep_last_dim, out_dim_dep)
+
+ out_dim_size = 3 * self.cls_out_channels
+ self.fc_dim = nn.Linear(self.dim_last_dim, out_dim_size)
+
+ out_rot_size = 3 * num_rotation_bins * self.cls_out_channels
+ self.fc_rot = nn.Linear(self.rot_last_dim, out_rot_size)
+
+ out_cen_2d_size = 2 * self.cls_out_channels
+ self.fc_cen_2d = nn.Linear(self.cen_2d_last_dim, out_cen_2d_size)
+
+ self._init_weights()
+
+ def _init_weights(self) -> None:
+ """Init weights of modules in head."""
+ module_lists: list[nn.ModuleList | nn.Linear | Conv2d] = []
+ module_lists += [self.shared_convs]
+ module_lists += [self.shared_fcs]
+ module_lists += [self.dep_convs]
+ module_lists += [self.fc_dep_uncer]
+ module_lists += [self.fc_dep, self.dep_fcs]
+ module_lists += [self.dim_convs]
+ module_lists += [self.fc_dim, self.dim_fcs]
+ module_lists += [self.rot_convs]
+ module_lists += [self.fc_rot, self.rot_fcs]
+ module_lists += [self.cen_2d_convs]
+ module_lists += [self.fc_cen_2d, self.cen_2d_fcs]
+
+ for module_list in module_lists:
+ for m in module_list.modules():
+ if isinstance(m, nn.Linear):
+ xavier_init(m, distribution="uniform")
+ elif isinstance(m, Conv2d):
+ kaiming_init(m)
+
+ def _add_conv_fc_branch(
+ self,
+ num_branch_convs: int,
+ num_branch_fcs: int,
+ in_channels: int,
+ conv_out_dim: int,
+ fc_out_dim: int,
+ conv_has_bias: bool,
+ norm: None | str,
+ num_groups: int,
+ is_shared: bool = False,
+ ) -> tuple[nn.ModuleList, nn.ModuleList, int]:
+ """Init modules of head."""
+ convs, last_layer_dim = add_conv_branch(
+ num_branch_convs,
+ in_channels,
+ conv_out_dim,
+ conv_has_bias,
+ norm,
+ num_groups,
+ )
+
+ fcs = nn.ModuleList()
+ if num_branch_fcs > 0:
+ if is_shared or num_branch_fcs == 0:
+ last_layer_dim *= int(np.prod(self.proposal_pooler.resolution))
+ for i in range(num_branch_fcs):
+ fc_in_dim = last_layer_dim if i == 0 else fc_out_dim
+ fcs.append(
+ nn.Sequential(
+ nn.Linear(fc_in_dim, fc_out_dim),
+ nn.ReLU(inplace=True),
+ )
+ )
+ last_layer_dim = fc_out_dim
+ return convs, fcs, last_layer_dim
+
+ def get_embeds(
+ self, feat: Tensor
+ ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
+ """Generate embedding from bbox feature."""
+ # shared part
+ if self.num_shared_convs > 0:
+ for conv in self.shared_convs:
+ feat = conv(feat)
+
+ if self.num_shared_fcs > 0:
+ feat = feat.view(feat.size(0), -1)
+ for fc in self.shared_fcs:
+ feat = self.relu(fc(feat))
+
+ # separate branches
+ x_dep = feat
+ x_dim = feat
+ x_rot = feat
+ x_cen_2d = feat
+
+ for conv in self.dep_convs:
+ x_dep = conv(x_dep)
+ if x_dep.dim() > 2:
+ x_dep = x_dep.view(x_dep.size(0), -1)
+ for fc in self.dep_fcs:
+ x_dep = self.relu(fc(x_dep))
+
+ for conv in self.dim_convs:
+ x_dim = conv(x_dim)
+ if x_dim.dim() > 2:
+ x_dim = x_dim.view(x_dim.size(0), -1)
+ for fc in self.dim_fcs:
+ x_dim = self.relu(fc(x_dim))
+
+ for conv in self.rot_convs:
+ x_rot = conv(x_rot)
+ if x_rot.dim() > 2:
+ x_rot = x_rot.view(x_rot.size(0), -1)
+ for fc in self.rot_fcs:
+ x_rot = self.relu(fc(x_rot))
+
+ for conv in self.cen_2d_convs:
+ x_cen_2d = conv(x_cen_2d)
+ if x_cen_2d.dim() > 2:
+ x_cen_2d = x_cen_2d.view(x_cen_2d.size(0), -1)
+ for fc in self.cen_2d_fcs:
+ x_cen_2d = self.relu(fc(x_cen_2d))
+
+ return x_dep, x_dim, x_rot, x_cen_2d
+
+ def get_outputs(
+ self, x_dep: Tensor, x_dim: Tensor, x_rot: Tensor, x_cen_2d: Tensor
+ ) -> Tensor:
+ """Generate output 3D bounding box parameters."""
+ depth = self.fc_dep(x_dep).view(-1, self.cls_out_channels, 1)
+ depth_uncertainty = self.fc_dep_uncer(x_dep).view(
+ -1, self.cls_out_channels, 1
+ )
+ dim = self.fc_dim(x_dim).view(-1, self.cls_out_channels, 3)
+ alpha = generate_rotation_output(
+ self.fc_rot(x_rot), self.num_rotation_bins
+ )
+ delta_cen_2d = self.fc_cen_2d(x_cen_2d).view(
+ -1, self.cls_out_channels, 2
+ )
+ return torch.cat(
+ [delta_cen_2d, depth, dim, alpha, depth_uncertainty], -1
+ )
+
+ def get_predictions(
+ self, features: list[Tensor], boxes_2d: list[Tensor]
+ ) -> list[Tensor]:
+ """Get 3D bounding box prediction parameters."""
+ if sum(len(b) for b in boxes_2d) == 0: # pragma: no cover
+ return [
+ torch.empty(
+ (
+ 0,
+ self.cls_out_channels,
+ 6 + 3 * self.num_rotation_bins + 1,
+ ),
+ device=boxes_2d[0].device,
+ )
+ ] * len(boxes_2d)
+
+ roi_feats = self.proposal_pooler(
+ features[self.start_level : self.end_level], boxes_2d
+ )
+ x_dep, x_dim, x_rot, x_cen_2d = self.get_embeds(roi_feats)
+
+ outputs: list[Tensor] = list(
+ self.get_outputs(x_dep, x_dim, x_rot, x_cen_2d).split(
+ [len(b) for b in boxes_2d]
+ )
+ )
+ return outputs
+
+ def get_targets(
+ self,
+ pos_assigned_gt_inds: list[Tensor],
+ target_boxes: list[Tensor],
+ target_boxes3d: list[Tensor],
+ target_class_ids: list[Tensor],
+ intrinsics: Tensor,
+ ) -> tuple[Tensor, Tensor]:
+ """Get 3D bounding box targets for training."""
+ targets = []
+ labels = []
+ for i, (tgt_boxes, tgt_boxes3d, intrinsics_) in enumerate(
+ zip(target_boxes, target_boxes3d, intrinsics)
+ ):
+ bbox_target = self.box_encoder(tgt_boxes, tgt_boxes3d, intrinsics_)
+ targets.append(bbox_target[pos_assigned_gt_inds[i]])
+
+ labels.append(target_class_ids[i][pos_assigned_gt_inds[i]])
+
+ return torch.cat(targets), torch.cat(labels)
+
+ def forward(
+ self,
+ features: list[Tensor],
+ det_boxes: list[Tensor],
+ intrinsics: Tensor | None = None,
+ target_boxes: list[Tensor] | None = None,
+ target_boxes3d: list[Tensor] | None = None,
+ target_class_ids: list[Tensor] | None = None,
+ ) -> QD3DTBBox3DHeadOutput:
+ """Forward."""
+ if (
+ intrinsics is not None
+ and target_boxes is not None
+ and target_boxes3d is not None
+ and target_class_ids is not None
+ ):
+ if self.proposal_append_gt:
+ det_boxes = [
+ torch.cat([d, t]) for d, t in zip(det_boxes, target_boxes)
+ ]
+
+ (
+ sampled_box_indices,
+ sampled_target_indices,
+ sampled_labels,
+ ) = match_and_sample_proposals(
+ self.box_matcher, self.box_sampler, det_boxes, target_boxes
+ )
+ positives = [torch.eq(l, 1) for l in sampled_labels]
+ pos_assigned_gt_inds = [
+ i[p] if len(p) != 0 else p
+ for i, p in zip(sampled_target_indices, positives)
+ ]
+ pos_boxes = [
+ b[s_i][p]
+ for b, s_i, p in zip(det_boxes, sampled_box_indices, positives)
+ ]
+ predictions = self.get_predictions(features, pos_boxes)
+
+ targets, labels = self.get_targets(
+ pos_assigned_gt_inds,
+ target_boxes,
+ target_boxes3d,
+ target_class_ids,
+ intrinsics,
+ )
+
+ return QD3DTBBox3DHeadOutput(
+ predictions=predictions, targets=targets, labels=labels
+ )
+
+ predictions = self.get_predictions(features, det_boxes)
+
+ return QD3DTBBox3DHeadOutput(predictions, None, None)
+
+ def __call__(
+ self,
+ features: list[Tensor],
+ det_boxes: list[Tensor],
+ intrinsics: Tensor | None = None,
+ target_boxes: list[Tensor] | None = None,
+ target_boxes3d: list[Tensor] | None = None,
+ target_class_ids: list[Tensor] | None = None,
+ ) -> QD3DTBBox3DHeadOutput:
+ """Type definition."""
+ return self._call_impl(
+ features,
+ det_boxes,
+ intrinsics,
+ target_boxes,
+ target_boxes3d,
+ target_class_ids,
+ )
+
+
+class RoI2Det3D:
+ """Post processing for QD3DTBBox3DHead."""
+
+ def __init__(self, box_decoder: None | QD3DTBox3DDecoder = None) -> None:
+ """Initialize."""
+ self.box_decoder = (
+ QD3DTBox3DDecoder() if box_decoder is None else box_decoder
+ )
+
+ def __call__(
+ self,
+ predictions: list[Tensor],
+ boxes_2d: list[Tensor],
+ class_ids: list[Tensor],
+ intrinsics: Tensor,
+ ) -> QD3DTDet3DOut:
+ """Forward pass during testing stage.
+
+ Args:
+ predictions(list[Tensor]): Predictions.
+ boxes_2d(list[Tensor]): 2D boxes.
+ class_ids(list[Tensor]): Class IDs.
+ intrinsics(Tensor): Camera intrinsics.
+
+ Returns:
+ QD3DTDet3DOut: QD3DT 3D detection output.
+ """
+ boxes_3d = []
+ depth_uncertainty = []
+ device = boxes_2d[0].device
+ for _boxes_2d, _class_ids, _boxes_deltas, _intrinsics in zip(
+ boxes_2d, class_ids, predictions, intrinsics
+ ):
+ if len(_boxes_2d) == 0:
+ boxes_3d.append(torch.empty(0, 12).to(device))
+ depth_uncertainty.append(torch.empty(0).to(device))
+ continue
+
+ _boxes_deltas = _boxes_deltas[
+ torch.arange(_boxes_deltas.shape[0]), _class_ids
+ ]
+
+ depth_uncertainty.append(
+ _boxes_deltas[:, -1].clamp(min=0.0, max=1.0)
+ )
+ boxes_3d.append(
+ self.box_decoder(_boxes_2d, _boxes_deltas, _intrinsics)
+ )
+
+ return QD3DTDet3DOut(
+ boxes_3d=boxes_3d, depth_uncertainty=depth_uncertainty
+ )
+
+
+class Box3DUncertaintyLoss(Loss):
+ """Box3d loss for QD-3DT."""
+
+ def __init__(
+ self,
+ reducer: LossReducer = mean_loss,
+ center_loss_weight: float = 1.0,
+ depth_loss_weight: float = 1.0,
+ dimension_loss_weight: float = 1.0,
+ rotation_loss_weight: float = 1.0,
+ uncertainty_loss_weight: float = 1.0,
+ num_rotation_bins: int = 2,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ reducer (LossReducer): Reducer for the loss function.
+ center_loss_weight (float): Weight for center loss.
+ depth_loss_weight (float): Weight for depth loss.
+ dimension_loss_weight (float): Weight for dimension loss.
+ rotation_loss_weight (float): Weight for rotation loss.
+ uncertainty_loss_weight (float): Weight for uncertainty loss.
+ num_rotation_bins (int): Number of rotation bins.
+ """
+ super().__init__(reducer)
+ self.center_loss_weight = center_loss_weight
+ self.depth_loss_weight = depth_loss_weight
+ self.dimension_loss_weight = dimension_loss_weight
+ self.rotation_loss_weight = rotation_loss_weight
+ self.uncertainty_loss_weight = uncertainty_loss_weight
+ self.num_rotation_bins = num_rotation_bins
+
+ def forward(
+ self, pred: Tensor, target: Tensor, labels: Tensor
+ ) -> LossesType:
+ """Compute box3d loss.
+
+ Args:
+ pred (Tensor): Box predictions of shape [N, num_classes,
+ 6 + 3 * num_rotations_bins].
+ target (torcch.Tensor): Target boxes of shape [N,
+ 6 + num_rotation_bins].
+ labels (Tensor): Target Labels of shape [N].
+
+ Returns:
+ dict[str, Tensor] containing 'delta 2dc', 'dimension', 'depth',
+ 'rotation' and 'uncertainty' loss.
+ """
+ if pred.size(0) == 0:
+ loss_ctr3d = loss_dep3d = loss_dim3d = loss_rot3d = loss_conf3d = (
+ pred.sum() * 0
+ )
+ result_dict = {
+ "loss_ctr3d": loss_ctr3d,
+ "loss_dep3d": loss_dep3d,
+ "loss_dim3d": loss_dim3d,
+ "loss_rot3d": loss_rot3d,
+ "loss_conf3d": loss_conf3d,
+ }
+
+ return result_dict
+
+ pred = pred[torch.arange(pred.shape[0], device=pred.device), labels]
+
+ # delta 2dc loss
+ loss_cen = smooth_l1_loss(
+ pred[:, :2], target[:, :2], reducer=self.reducer, beta=1 / 9
+ )
+
+ # dimension loss
+ dim_mask = target[:, 3:6] != 100.0
+ loss_dim = smooth_l1_loss(
+ pred[:, 3:6][dim_mask],
+ target[:, 3:6][dim_mask],
+ reducer=self.reducer,
+ beta=1 / 9,
+ )
+
+ # depth loss
+ depth_mask = target[:, 2] > 0
+ loss_dep = smooth_l1_loss(
+ pred[:, 2][depth_mask],
+ target[:, 2][depth_mask],
+ reducer=self.reducer,
+ beta=1 / 9,
+ )
+
+ # rotation loss
+ loss_rot = rotation_loss(
+ pred[:, 6 : 6 + self.num_rotation_bins * 3],
+ target[:, 6 : 6 + self.num_rotation_bins],
+ target[:, 6 + self.num_rotation_bins :],
+ self.num_rotation_bins,
+ reducer=self.reducer,
+ )
+
+ # uncertainty loss
+ pos_depth_self_labels = torch.exp(
+ -torch.mul(torch.abs(pred[:, 2] - target[:, 2]), 5.0)
+ )
+ pos_depth_self_weights = torch.where(
+ pos_depth_self_labels > 0.8,
+ pos_depth_self_labels.new_ones(1) * 5.0,
+ pos_depth_self_labels.new_ones(1) * 0.1,
+ )
+
+ loss_unc3d = smooth_l1_loss(
+ pred[:, -1],
+ pos_depth_self_labels.detach().clone(),
+ reducer=SumWeightedLoss(
+ pos_depth_self_weights, len(pos_depth_self_weights)
+ ),
+ beta=1 / 9,
+ )
+
+ return {
+ "loss_ctr3d": torch.mul(self.center_loss_weight, loss_cen),
+ "loss_dep3d": torch.mul(self.depth_loss_weight, loss_dep),
+ "loss_dim3d": torch.mul(self.dimension_loss_weight, loss_dim),
+ "loss_rot3d": torch.mul(self.rotation_loss_weight, loss_rot),
+ "loss_unc3d": torch.mul(self.uncertainty_loss_weight, loss_unc3d),
+ }
diff --git a/vis4d/op/detect3d/util.py b/vis4d/op/detect3d/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..80c646200cff4107f5063982363164d305d873bf
--- /dev/null
+++ b/vis4d/op/detect3d/util.py
@@ -0,0 +1,117 @@
+"""Utilitiy functions for detection 3D ops."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor
+
+from vis4d.common.imports import VIS4D_CUDA_OPS_AVAILABLE
+
+if VIS4D_CUDA_OPS_AVAILABLE:
+ from vis4d_cuda_ops import nms_rotated # pylint: disable=no-name-in-module
+
+
+def bev_3d_nms(
+ center_x: Tensor,
+ center_y: Tensor,
+ width: Tensor,
+ length: Tensor,
+ angle: Tensor,
+ scores: Tensor,
+ class_ids: Tensor | None = None,
+ iou_threshold: float = 0.1,
+) -> Tensor:
+ """BEV 3D NMS.
+
+ Args:
+ center_x (Tensor): Center x of boxes. In shape (N, 1).
+ center_y (Tensor): Center y of boxes. In shape (N, 1).
+ width (Tensor): Width of boxes. In shape (N, 1).
+ length (Tensor): Length of boxes. In shape (N, 1).
+ angle (Tensor): Angle of boxes. In shape (N, 1).
+ scores (Tensor): Scores of boxes. In shape (N, 1).
+ class_ids (Tensor | None, optional): Class ids of boxes. In shape
+ (N,). Defaults to None. If None, class_agnostic NMS will be
+ performed.
+ iou_threshold (float, optional): IoU threshold. Defaults to 0.1.
+
+ Returns:
+ Tensor: Indices of boxes that have been kept by NMS.
+ """
+ class_ids = (
+ torch.zeros_like(scores, dtype=torch.int64) # class_agnostic
+ if class_ids is None
+ else class_ids
+ )
+
+ return batched_nms_rotated(
+ torch.cat([center_x, center_y, width, length, angle], dim=-1),
+ scores,
+ class_ids,
+ iou_threshold,
+ )
+
+
+def batched_nms_rotated(
+ boxes: Tensor,
+ scores: Tensor,
+ idxs: Tensor,
+ iou_threshold: float,
+) -> Tensor:
+ """Performs non-maximum suppression in a batched fashion.
+
+ Each index value correspond to a category, and NMS
+ will not be applied between elements of different categories.
+
+ Args:
+ boxes (Tensor): Boxes where NMS will be performed. They are expected to
+ be in (x_ctr, y_ctr, width, height, angle_degrees) format. In shape
+ (N, 5).
+ scores (Tensor): Scores for each one of the boxes. In shape (N,).
+ idxs (Tensor): Indices of the categories for each one of the boxes.
+ In shape (N,).
+ iou_threshold (float): Discards all overlapping boxes with IoU <
+ iou_threshold.
+
+ Returns:
+ Tensor: Int64 tensor with the indices of the elements that have been
+ kept by NMS, sorted in decreasing order of scores
+ """
+ assert boxes.shape[-1] == 5
+
+ if boxes.numel() == 0:
+ return torch.empty((0,), dtype=torch.int64, device=boxes.device)
+
+ boxes = boxes.float() # fp16 does not have enough range for batched NMS
+
+ # Strategy: in order to perform NMS independently per class,
+ # we add an offset to all the boxes. The offset is dependent
+ # only on the class idx, and is large enough so that boxes
+ # from different classes do not overlap
+
+ # Note that batched_nms in torchvision/ops/boxes.py only uses
+ # max_coordinate, which won't handle negative coordinates correctly.
+ # Here by using min_coordinate we can make sure the negative coordinates
+ # are correctly handled.
+ max_coordinate = (
+ torch.max(boxes[:, 0], boxes[:, 1])
+ + torch.max(boxes[:, 2], boxes[:, 3]) / 2
+ ).max()
+ min_coordinate = (
+ torch.min(boxes[:, 0], boxes[:, 1])
+ - torch.max(boxes[:, 2], boxes[:, 3]) / 2
+ ).min()
+ offsets = idxs.to(boxes) * (max_coordinate - min_coordinate + 1)
+ boxes_for_nms = (
+ boxes.clone()
+ ) # avoid modifying the original values in boxes
+ boxes_for_nms[:, :2] += offsets[:, None]
+
+ if not VIS4D_CUDA_OPS_AVAILABLE:
+ raise RuntimeError(
+ "Please install vis4d_cuda_ops to use batched_nms_rotated"
+ )
+ keep = nms_rotated( # pylint: disable=possibly-used-before-assignment
+ boxes_for_nms, scores, iou_threshold
+ )
+ return keep
diff --git a/vis4d/op/fpp/__init__.py b/vis4d/op/fpp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb06dca6c0f7362ae45c7119b8f24a5cd8a216b4
--- /dev/null
+++ b/vis4d/op/fpp/__init__.py
@@ -0,0 +1,12 @@
+"""Vis4D modules for feature pyramid processing.
+
+Feature pyramid processing is usually used for augmenting the existing feature
+maps and/or upsampling the feature maps.
+"""
+
+from .base import FeaturePyramidProcessing
+from .dla_up import DLAUp
+from .fpn import FPN
+from .yolox_pafpn import YOLOXPAFPN
+
+__all__ = ["DLAUp", "FPN", "FeaturePyramidProcessing", "YOLOXPAFPN"]
diff --git a/vis4d/op/fpp/base.py b/vis4d/op/fpp/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..76d998ea3238f119a9a378ebe60cbc2afee6b53d
--- /dev/null
+++ b/vis4d/op/fpp/base.py
@@ -0,0 +1,31 @@
+"""Feature pyramid processing base class."""
+
+from __future__ import annotations
+
+import abc
+
+from torch import Tensor, nn
+
+
+class FeaturePyramidProcessing(nn.Module):
+ """Base Neck class."""
+
+ @abc.abstractmethod
+ def forward(self, features: list[Tensor]) -> list[Tensor]:
+ """Feature pyramid processing.
+
+ This module do a further processing for the hierarchical feature
+ representation extracted by the base models.
+
+ Args:
+ features (list[Tensor]): Feature pyramid as outputs of the
+ base model.
+
+ Returns:
+ list[Tensor]: Feature pyramid after the processing.
+ """
+ raise NotImplementedError
+
+ def __call__(self, features: list[Tensor]) -> list[Tensor]:
+ """Type definition for call implementation."""
+ return self._call_impl(features)
diff --git a/vis4d/op/fpp/dla_up.py b/vis4d/op/fpp/dla_up.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e37cacf8ffc64eee49d4dd34d6f35bf65cae3d8
--- /dev/null
+++ b/vis4d/op/fpp/dla_up.py
@@ -0,0 +1,171 @@
+"""DLA-UP.
+
+TODO(fyu) need clean up and update to the latest interface.
+"""
+
+from __future__ import annotations
+
+import math
+
+import numpy as np
+import torch
+from torch import nn
+
+from vis4d.common.typing import NDArrayI64
+from vis4d.op.layer.conv2d import Conv2d
+from vis4d.op.layer.deform_conv import DeformConv
+
+from .base import FeaturePyramidProcessing
+
+
+def fill_up_weights(up_layer: nn.ConvTranspose2d) -> None:
+ """Initialize weights of upsample layer."""
+ w = up_layer.weight.data
+ f = math.ceil(w.size(2) / 2)
+ c = (2 * f - 1 - f % 2) / (2.0 * f)
+ for i in range(w.size(2)):
+ for j in range(w.size(3)):
+ w[0, 0, i, j] = (1 - math.fabs(i / f - c)) * (
+ 1 - math.fabs(j / f - c)
+ )
+ for c in range(1, w.size(0)):
+ w[c, 0, :, :] = w[0, 0, :, :]
+
+
+class IDAUp(nn.Module):
+ """IDAUp."""
+
+ def __init__(
+ self, use_dc: bool, o: int, channels: list[int], up_f: list[int]
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ for i in range(1, len(channels)):
+ c = channels[i]
+ f = int(up_f[i])
+ if use_dc:
+ proj: Conv2d | DeformConv = DeformConv(
+ c,
+ o,
+ kernel_size=3,
+ padding=1,
+ norm=nn.BatchNorm2d(o),
+ activation=nn.ReLU(inplace=True),
+ )
+ node: Conv2d | DeformConv = DeformConv(
+ o,
+ o,
+ kernel_size=3,
+ padding=1,
+ norm=nn.BatchNorm2d(o),
+ activation=nn.ReLU(inplace=True),
+ )
+ else:
+ proj = Conv2d(
+ c,
+ o,
+ kernel_size=1,
+ stride=1,
+ bias=False,
+ norm=nn.BatchNorm2d(o),
+ activation=nn.ReLU(inplace=True),
+ )
+ node = Conv2d(
+ o,
+ o,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ norm=nn.BatchNorm2d(o),
+ activation=nn.ReLU(inplace=True),
+ )
+
+ up = nn.ConvTranspose2d(
+ o,
+ o,
+ f * 2,
+ stride=f,
+ padding=f // 2,
+ output_padding=0,
+ groups=o,
+ bias=False,
+ )
+ fill_up_weights(up)
+
+ setattr(self, "proj_" + str(i), proj)
+ setattr(self, "up_" + str(i), up)
+ setattr(self, "node_" + str(i), node)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2.0 / n))
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def forward(
+ self, layers: list[torch.Tensor], startp: int, endp: int
+ ) -> None:
+ """Forward."""
+ for i in range(startp + 1, endp):
+ upsample = getattr(self, "up_" + str(i - startp))
+ project = getattr(self, "proj_" + str(i - startp))
+ layers[i] = upsample(project(layers[i]))
+ node = getattr(self, "node_" + str(i - startp))
+ layers[i] = node(layers[i] + layers[i - 1])
+
+
+class DLAUp(FeaturePyramidProcessing):
+ """DLAUp."""
+
+ def __init__(
+ self,
+ in_channels: list[int],
+ out_channels: None | int = None,
+ start_level: int = 0,
+ end_level: int = -1,
+ use_deformable_convs: bool = True,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.start_level = start_level
+ self.end_level = end_level
+ if self.end_level == -1:
+ self.end_level = len(in_channels)
+ in_channels = in_channels[self.start_level : self.end_level]
+ channels = list(in_channels)
+ scales: NDArrayI64 = np.array(
+ [2**i for i, _ in enumerate(in_channels)], dtype=np.int64
+ )
+ for i in range(len(channels) - 1):
+ j = -i - 2
+ idaup = IDAUp(
+ use_deformable_convs,
+ channels[j],
+ in_channels[j:],
+ scales[j:] // scales[j],
+ )
+ setattr(self, f"ida_{i}", idaup)
+ scales[j + 1 :] = scales[j]
+ in_channels[j + 1 :] = [channels[j] for _ in channels[j + 1 :]]
+ if out_channels is None:
+ out_channels = channels[0]
+ self.ida_final = IDAUp(
+ use_deformable_convs,
+ out_channels,
+ channels,
+ [2**i for i in range(self.end_level - self.start_level)],
+ )
+
+ def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
+ """Forward."""
+ outs = [features[self.end_level - 1]]
+ for i in range(self.end_level - self.start_level - 1):
+ ida = getattr(self, f"ida_{i}")
+ ida(features, self.end_level - i - 2, self.end_level)
+ outs.insert(0, features[self.end_level - 1])
+ self.ida_final(outs, 0, len(outs))
+ outs = [outs[-1]]
+ return outs
diff --git a/vis4d/op/fpp/fpn.py b/vis4d/op/fpp/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0433437bb9529ffb6d696cba15a3a43538afd204
--- /dev/null
+++ b/vis4d/op/fpp/fpn.py
@@ -0,0 +1,138 @@
+"""Feature Pyramid Network.
+
+This is based on `"Feature Pyramid Network for Object Detection"
+`_.
+"""
+
+from __future__ import annotations
+
+from collections import OrderedDict
+
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torchvision.ops import FeaturePyramidNetwork as _FPN
+from torchvision.ops.feature_pyramid_network import (
+ ExtraFPNBlock as _ExtraFPNBlock,
+)
+from torchvision.ops.feature_pyramid_network import (
+ LastLevelMaxPool,
+)
+
+from .base import FeaturePyramidProcessing
+
+
+class FPN(_FPN, FeaturePyramidProcessing): # type: ignore
+ """Feature Pyramid Network.
+
+ This is a wrapper of the torchvision implementation.
+ """
+
+ def __init__(
+ self,
+ in_channels_list: list[int],
+ out_channels: int,
+ extra_blocks: _ExtraFPNBlock | None = LastLevelMaxPool(),
+ start_index: int = 2,
+ ) -> None:
+ """Init without additional components.
+
+ Args:
+ in_channels_list (list[int]): List of input channels.
+ out_channels (int): Output channels.
+ extra_blocks (_ExtraFPNBlock, optional): Extra block. Defaults to
+ LastLevelMaxPool().
+ start_index (int, optional): Start index of base model feature
+ maps. Defaults to 2.
+ """
+ super().__init__(
+ in_channels_list, out_channels, extra_blocks=extra_blocks
+ )
+ self.start_index = start_index
+
+ def forward(self, x: list[Tensor]) -> list[Tensor]:
+ """Process the input features with FPN.
+
+ Because by default, FPN doesn't upsample the first two feature maps in
+ the pyramid, we keep the first two feature maps intact.
+
+ Args:
+ x (list[Tensor]): Feature pyramid as outputs of the
+ base model.
+
+ Returns:
+ list[Tensor]: Feature pyramid after FPN processing.
+ """
+ feat_dict = OrderedDict(
+ (k, v)
+ for k, v in zip(
+ [str(i) for i in range(len(x) - self.start_index)],
+ x[self.start_index :],
+ )
+ )
+ outs = super().forward(feat_dict) # type: ignore
+ return [*x[: self.start_index], *outs.values()] # type: ignore
+
+ def __call__(self, x: list[Tensor]) -> list[Tensor]:
+ """Type definition for call implementation."""
+ return self._call_impl(x)
+
+
+class ExtraFPNBlock(_ExtraFPNBlock): # type: ignore
+ """Extra block in the FPN.
+
+ This is a wrapper of the torchvision implementation.
+ """
+
+ def __init__(
+ self,
+ extra_levels: int,
+ in_channels: int,
+ out_channels: int,
+ add_extra_convs: str = "on_output",
+ extra_relu: bool = False,
+ ) -> None:
+ """Create an instance of the class."""
+ super().__init__()
+ self.extra_levels = extra_levels
+ self.add_extra_convs = add_extra_convs
+ self.extra_relu = extra_relu
+
+ self.convs = nn.ModuleList()
+ if extra_levels >= 1:
+ for i in range(extra_levels):
+ if i == 0 and self.add_extra_convs == "on_input":
+ _in_channels = in_channels
+ else:
+ _in_channels = out_channels
+
+ extra_fpn_conv = nn.Conv2d(
+ _in_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ )
+ self.convs.append(extra_fpn_conv)
+
+ def forward(
+ self, results: list[Tensor], x: list[Tensor], names: list[str]
+ ) -> tuple[list[Tensor], list[str]]:
+ """Forward."""
+ if self.add_extra_convs == "on_input":
+ extra_source = x[-1]
+ elif self.add_extra_convs == "on_output":
+ extra_source = results[-1]
+ else:
+ raise NotImplementedError
+
+ results.append(self.convs[0](extra_source))
+ names.append(str(int(names[-1]) + 1))
+
+ for i in range(1, self.extra_levels):
+ if self.extra_relu:
+ results.append(self.convs[i](F.relu(results[-1])))
+ else:
+ results.append(self.convs[i](results[-1]))
+ names.append(str(int(names[-1]) + 1))
+
+ return results, names
diff --git a/vis4d/op/fpp/yolox_pafpn.py b/vis4d/op/fpp/yolox_pafpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce5794f359d918215061753cb79462df63a583ba
--- /dev/null
+++ b/vis4d/op/fpp/yolox_pafpn.py
@@ -0,0 +1,175 @@
+"""YOLOX PAFPN.
+
+Modified from mmdetection (https://github.com/open-mmlab/mmdetection).
+"""
+
+from __future__ import annotations
+
+import math
+
+import torch
+from torch import nn
+
+from vis4d.op.layer.conv2d import Conv2d
+from vis4d.op.layer.csp_layer import CSPLayer
+
+from .base import FeaturePyramidProcessing
+
+
+class YOLOXPAFPN(FeaturePyramidProcessing):
+ """Path Aggregation Network used in YOLOX.
+
+ Args:
+ in_channels (list[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale).
+ num_csp_blocks (int, optional): Number of bottlenecks in CSPLayer.
+ Defaults to 3.
+ start_index (int, optional): Index of the first input feature map.
+ Defaults to 2.
+ """
+
+ def __init__(
+ self,
+ in_channels: list[int],
+ out_channels: int,
+ num_csp_blocks: int = 3,
+ start_index: int = 2,
+ ):
+ """Init."""
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.start_index = start_index
+
+ # build top-down blocks
+ self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
+ self.reduce_layers = nn.ModuleList()
+ self.top_down_blocks = nn.ModuleList()
+ for idx in range(len(in_channels) - 1, 0, -1):
+ self.reduce_layers.append(
+ Conv2d(
+ in_channels[idx],
+ in_channels[idx - 1],
+ 1,
+ bias=False,
+ norm=nn.BatchNorm2d(
+ in_channels[idx - 1], eps=0.001, momentum=0.03
+ ),
+ activation=nn.SiLU(inplace=True),
+ )
+ )
+ self.top_down_blocks.append(
+ CSPLayer(
+ in_channels[idx - 1] * 2,
+ in_channels[idx - 1],
+ num_blocks=num_csp_blocks,
+ add_identity=False,
+ )
+ )
+
+ # build bottom-up blocks
+ self.downsamples = nn.ModuleList()
+ self.bottom_up_blocks = nn.ModuleList()
+ for idx in range(len(in_channels) - 1):
+ self.downsamples.append(
+ Conv2d(
+ in_channels[idx],
+ in_channels[idx],
+ 3,
+ stride=2,
+ padding=1,
+ bias=False,
+ norm=nn.BatchNorm2d(
+ in_channels[idx], eps=0.001, momentum=0.03
+ ),
+ activation=nn.SiLU(inplace=True),
+ )
+ )
+ self.bottom_up_blocks.append(
+ CSPLayer(
+ in_channels[idx] * 2,
+ in_channels[idx + 1],
+ num_blocks=num_csp_blocks,
+ add_identity=False,
+ )
+ )
+
+ self.out_convs = nn.ModuleList()
+ for _, inc in enumerate(in_channels):
+ self.out_convs.append(
+ Conv2d(
+ inc,
+ out_channels,
+ 1,
+ bias=False,
+ norm=nn.BatchNorm2d(
+ out_channels, eps=0.001, momentum=0.03
+ ),
+ activation=nn.SiLU(inplace=True),
+ )
+ )
+ self._init_weights()
+
+ def _init_weights(self) -> None:
+ """Initialize weights."""
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_uniform_(
+ m.weight,
+ a=math.sqrt(5),
+ mode="fan_in",
+ nonlinearity="leaky_relu",
+ )
+
+ def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
+ """Forward pass.
+
+ Args:
+ features (tuple[Tensor]): Input features.
+
+ Returns:
+ list[Tensor]: YOLOXPAFPN features.
+ """
+ images, features = (
+ features[: self.start_index],
+ features[self.start_index :],
+ )
+ assert len(features) == len(self.in_channels)
+
+ # top-down path
+ inner_outs = [features[-1]]
+ for idx in range(len(self.in_channels) - 1, 0, -1):
+ feat_heigh = inner_outs[0]
+ feat_low = features[idx - 1]
+ feat_heigh = self.reduce_layers[len(self.in_channels) - 1 - idx](
+ feat_heigh
+ )
+ inner_outs[0] = feat_heigh
+
+ upsample_feat = self.upsample(feat_heigh)
+
+ inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx](
+ torch.cat([upsample_feat, feat_low], 1)
+ )
+ inner_outs.insert(0, inner_out)
+
+ # bottom-up path
+ outs = [inner_outs[0]]
+ for idx in range(len(self.in_channels) - 1):
+ feat_low = outs[-1]
+ feat_height = inner_outs[idx + 1]
+ downsample_feat = self.downsamples[idx](feat_low)
+ out = self.bottom_up_blocks[idx](
+ torch.cat([downsample_feat, feat_height], 1)
+ )
+ outs.append(out)
+
+ # out convs
+ for idx, conv in enumerate(self.out_convs):
+ outs[idx] = conv(outs[idx])
+
+ return images + outs
+
+ def __call__(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
+ """Type definition for call implementation."""
+ return self._call_impl(features)
diff --git a/vis4d/op/geometry/__init__.py b/vis4d/op/geometry/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..07be652039b98f91f13deba9d270f48db32da255
--- /dev/null
+++ b/vis4d/op/geometry/__init__.py
@@ -0,0 +1 @@
+"""Init geometry module."""
diff --git a/vis4d/op/geometry/__pycache__/__init__.cpython-311.pyc b/vis4d/op/geometry/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..03bf2ee18cf669d7fdb220ff7eb658b93fc9eb6a
Binary files /dev/null and b/vis4d/op/geometry/__pycache__/__init__.cpython-311.pyc differ
diff --git a/vis4d/op/geometry/__pycache__/projection.cpython-311.pyc b/vis4d/op/geometry/__pycache__/projection.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6edbb527072ab7e765916660ade8b89a246daf54
Binary files /dev/null and b/vis4d/op/geometry/__pycache__/projection.cpython-311.pyc differ
diff --git a/vis4d/op/geometry/__pycache__/rotation.cpython-311.pyc b/vis4d/op/geometry/__pycache__/rotation.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0de841b3bee8f6e089a0a5bf44ec28f792326eab
Binary files /dev/null and b/vis4d/op/geometry/__pycache__/rotation.cpython-311.pyc differ
diff --git a/vis4d/op/geometry/__pycache__/transform.cpython-311.pyc b/vis4d/op/geometry/__pycache__/transform.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fb8060a4c244977f08bb59eb25178f2e8ba753b8
Binary files /dev/null and b/vis4d/op/geometry/__pycache__/transform.cpython-311.pyc differ
diff --git a/vis4d/op/geometry/projection.py b/vis4d/op/geometry/projection.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab2c352dddbafaa5c8b845a804146f4201bdcbc4
--- /dev/null
+++ b/vis4d/op/geometry/projection.py
@@ -0,0 +1,138 @@
+"""Projection utilities."""
+
+from __future__ import annotations
+
+import torch
+
+from .transform import inverse_pinhole
+
+
+def project_points(
+ points: torch.Tensor, intrinsics: torch.Tensor
+) -> torch.Tensor:
+ """Project points to pixel coordinates with given intrinsics.
+
+ Args:
+ points: (N, 3) or (B, N, 3) 3D coordinates.
+ intrinsics: (3, 3) or (B, 3, 3) intrinsic camera matrices.
+
+ Returns:
+ torch.Tensor: (N, 2) or (B, N, 2) 2D pixel coordinates.
+
+ Raises:
+ ValueError: Shape of input points is not valid for computation.
+ """
+ assert points.shape[-1] == 3, "Input coordinates must be 3 dimensional!"
+ hom_coords = points / points[..., 2:3]
+ if len(hom_coords.shape) == 2:
+ assert (
+ len(intrinsics.shape) == 2
+ ), "Got multiple intrinsics for single point set!"
+ intrinsics = intrinsics.T
+ elif len(hom_coords.shape) == 3:
+ if len(intrinsics.shape) == 2:
+ intrinsics = intrinsics.unsqueeze(0)
+ intrinsics = intrinsics.permute(0, 2, 1)
+ else:
+ raise ValueError(f"Shape of input points not valid: {points.shape}")
+ pts_2d = hom_coords @ intrinsics
+ return pts_2d[..., :2]
+
+
+def unproject_points(
+ points: torch.Tensor, depths: torch.Tensor, intrinsics: torch.Tensor
+) -> torch.Tensor:
+ """Un-projects pixel coordinates to 3D coordinates with given intrinsics.
+
+ Args:
+ points: (N, 2) or (B, N, 2) 2D pixel coordinates.
+ depths: (N,) / (N, 1) or (B, N,) / (B, N, 1) depth values.
+ intrinsics: (3, 3) or (B, 3, 3) intrinsic camera matrices.
+
+ Returns:
+ torch.Tensor: (N, 3) or (B, N, 3) 3D coordinates.
+
+ Raises:
+ ValueError: Shape of input points is not valid for computation.
+ """
+ if len(points.shape) == 2:
+ assert (
+ len(intrinsics.shape) == 2 or intrinsics.shape[0] == 1
+ ), "Got multiple intrinsics for single point set!"
+ if len(intrinsics.shape) == 3:
+ intrinsics = intrinsics.squeeze(0)
+ inv_intrinsics = inverse_pinhole(intrinsics).transpose(0, 1)
+ if len(depths.shape) == 1:
+ depths = depths.unsqueeze(-1)
+ assert len(depths.shape) == 2, "depths must have same dims as points"
+ elif len(points.shape) == 3:
+ inv_intrinsics = inverse_pinhole(intrinsics).transpose(-2, -1)
+ if len(depths.shape) == 2:
+ depths = depths.unsqueeze(-1)
+ assert len(depths.shape) == 3, "depths must have same dims as points"
+ else:
+ raise ValueError(f"Shape of input points not valid: {points.shape}")
+ hom_coords = torch.cat([points, torch.ones_like(points)[..., 0:1]], -1)
+ pts_3d = hom_coords @ inv_intrinsics
+ pts_3d *= depths
+ return pts_3d
+
+
+def points_inside_image(
+ points_coord: torch.Tensor,
+ depths: torch.Tensor,
+ images_hw: torch.Tensor | tuple[int, int],
+) -> torch.Tensor:
+ """Generate binary mask.
+
+ Creates a mask that is true for all point coordiantes that lie inside the
+ image,
+
+ Args:
+ points_coord (torch.Tensor): 2D pixel coordinates of shape [..., 2].
+ depths (torch.Tensor): Associated depth of each 2D pixel coordinate.
+ images_hw: (torch.Tensor| tuple[int, int]]) Associated tensor of image
+ dimensions, shape [..., 2] or single height, width pair.
+
+ Returns:
+ torch.Tensor: Binary mask of points inside an image.
+ """
+ mask = torch.ones_like(depths)
+ h: int | torch.Tensor
+ w: int | torch.Tensor
+
+ if isinstance(images_hw, tuple):
+ h, w = images_hw
+ else:
+ h, w = images_hw[..., 0], images_hw[..., 1]
+ mask = torch.logical_and(mask, torch.greater(depths, 0))
+ mask = torch.logical_and(mask, points_coord[..., 0] > 0)
+ mask = torch.logical_and(mask, points_coord[..., 0] < w - 1)
+ mask = torch.logical_and(mask, points_coord[..., 1] > 0)
+ mask = torch.logical_and(mask, points_coord[..., 1] < h - 1)
+ return mask
+
+
+def generate_depth_map(
+ points: torch.Tensor,
+ intrinsics: torch.Tensor,
+ image_hw: tuple[int, int],
+) -> torch.Tensor:
+ """Generate depth map for given pointcloud.
+
+ Args:
+ points: (N, 3) coordinates.
+ intrinsics: (3, 3) intrinsic camera matrices.
+ image_hw: (tuple[int,int]) height, width of the image
+
+ Returns:
+ torch.Tensor: Projected depth map of the given pointcloud.
+ Invalid depth has 0 values
+ """
+ pts_2d = project_points(points, intrinsics).round()
+ depths = points[:, 2]
+ depth_map = points.new_zeros(image_hw)
+ mask = points_inside_image(pts_2d, depths, image_hw)
+ pts_2d = pts_2d[mask].long()
+ depth_map[pts_2d[:, 1], pts_2d[:, 0]] = depths[mask]
+ return depth_map
diff --git a/vis4d/op/geometry/rotation.py b/vis4d/op/geometry/rotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..b605ecc567b7a5d19f985d62f32ab7b8252861b2
--- /dev/null
+++ b/vis4d/op/geometry/rotation.py
@@ -0,0 +1,539 @@
+"""Rotation utilities."""
+
+import functools
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from vis4d.data.const import AxisMode
+
+
+def normalize_angle(input_angles: Tensor) -> Tensor:
+ """Normalize content of input_angles to range [-pi, pi].
+
+ Args:
+ input_angles: (Tensor) tensor of any shape containing
+ unnormalized angles.
+
+ Returns:
+ Tensor with angles normalized to +/- pi
+ """
+ return torch.sub((input_angles + torch.pi) % (2 * torch.pi), torch.pi)
+
+
+def acute_angle(theta_1: Tensor, theta_2: Tensor) -> Tensor:
+ """Update theta_1 to mkae the agnle between two thetas is acute."""
+ # Make sure the angle between two thetas is acute
+ if torch.pi / 2.0 < abs(theta_2 - theta_1) < torch.pi * 3 / 2.0:
+ theta_1 += torch.pi
+ if theta_1 > torch.pi:
+ theta_1 -= torch.pi * 2
+ if theta_1 < -torch.pi:
+ theta_1 += torch.pi * 2
+
+ # Convert the case of > 270 to < 90
+ if abs(theta_2 - theta_1) >= torch.pi * 3 / 2.0:
+ if theta_2 > 0:
+ theta_1 += torch.pi * 2
+ else:
+ theta_1 -= torch.pi * 2
+ return theta_1
+
+
+def yaw2alpha(rot_y: Tensor, center: Tensor) -> Tensor:
+ """Get alpha by vertical rotation - theta.
+
+ Args:
+ rot_y: Rotation around Y-axis in camera coordinates [-pi..pi]
+ center: 3D object center in camera coordinates
+
+ Returns:
+ alpha: Observation angle of object, ranging [-pi..pi]
+ """
+ alpha = rot_y - torch.atan2(center[..., 0], center[..., 2])
+ return normalize_angle(alpha)
+
+
+def alpha2yaw(alpha: Tensor, center: Tensor) -> Tensor:
+ """Get vertical rotation by alpha + theta.
+
+ Args:
+ alpha: Observation angle of object, ranging [-pi..pi]
+ center: 3D object center in camera coordinates
+
+ Returns:
+ rot_y: Vertical rotation in camera coordinates [-pi..pi]
+ """
+ rot_y = alpha + torch.atan2(center[..., 0], center[..., 2])
+ return normalize_angle(rot_y)
+
+
+def rotation_output_to_alpha(output: Tensor, num_bins: int = 2) -> Tensor:
+ """Get alpha from bin-based regression output.
+
+ Uses method described in (with two bins):
+ See: 3D Bounding Box Estimation Using Deep Learning and Geometry,
+ Mousavian et al., CVPR'17
+
+ Args:
+ output: (Tensor) bin based regressed output.
+ num_bins: (int) number of bins to use
+
+ Returns:
+ Tensor containing the angle from the bin-based regression output
+ """
+ out_range = torch.tensor(list(range(len(output))), device=output.device)
+ bin_idx = output[:, :num_bins].argmax(dim=-1)
+ res_idx = num_bins + 2 * bin_idx
+ bin_centers = torch.arange(
+ -torch.pi, torch.pi, 2 * torch.pi / num_bins, device=output.device
+ )
+ bin_centers += torch.pi / num_bins
+ alpha = (
+ torch.atan(output[out_range, res_idx] / output[out_range, res_idx + 1])
+ + bin_centers[bin_idx]
+ )
+ return alpha
+
+
+def generate_rotation_output(pred: Tensor, num_bins: int = 2) -> Tensor:
+ """Convert output to bin confidence and cos / sin of residual.
+
+ The viewpoint (alpha) prediction (N, num_bins + 2 * num_bins) consists of:
+ bin confidences (N, num_bins): softmax logits for bin probability.
+ 1st entry is probability for orientation being in bin 1,
+ 2nd entry is probability for orientation being in bin 2,
+ and so on.
+ bin residual (N, num_bins * 2): angle residual w.r.t. bin N orientation,
+ represented as sin and cos values.
+
+ See: 3D Bounding Box Estimation Using Deep Learning and Geometry,
+ Mousavian et al., CVPR'17
+ """
+ pred = pred.view(pred.size(0), -1, 3 * num_bins)
+ bin_logits = pred[..., :num_bins]
+
+ bin_residuals = []
+ for i in range(num_bins):
+ res_idx = num_bins + 2 * i
+ norm = pred[..., res_idx : res_idx + 2].norm(dim=-1, keepdim=True)
+ bsin = pred[..., res_idx : res_idx + 1] / norm
+ bcos = pred[..., res_idx + 1 : res_idx + 2] / norm
+ bin_residuals.append(bsin)
+ bin_residuals.append(bcos)
+
+ rot = torch.cat([bin_logits, *bin_residuals], -1)
+ return rot
+
+
+# Rotation conversion functions adapted from:
+# https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
+def _axis_angle_rotation(axis: str, angle: Tensor) -> Tensor:
+ """Get rotation matrix for an angle around an axis.
+
+ Args:
+ axis: Axis label "X" or "Y or "Z".
+ angle: any shape tensor of Euler angles in radians
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ assert axis in {"X", "Y", "Z"}, f"Invalid axis {axis}."
+ cos = torch.cos(angle)
+ sin = torch.sin(angle)
+ one = torch.ones_like(angle)
+ zero = torch.zeros_like(angle)
+
+ if axis == "X":
+ rot_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
+ elif axis == "Y":
+ rot_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
+ else:
+ rot_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
+
+ return torch.stack(rot_flat, -1).reshape(angle.shape + (3, 3))
+
+
+def euler_angles_to_matrix(
+ euler_angles: Tensor, convention: str = "XYZ"
+) -> Tensor:
+ """Convert rotations given as Euler angles in radians to rotation matrices.
+
+ Args:
+ euler_angles: Euler angles in radians as tensor of shape (..., 3).
+ convention: Convention string of three uppercase letters from
+ "X", "Y", and "Z".
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+
+ Raises:
+ ValueError: if convention string is not a combination of XYZ
+ """
+ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
+ raise ValueError("Invalid input euler angles.")
+ if len(convention) != 3:
+ raise ValueError("Convention must have 3 letters.")
+ if convention[1] in (convention[0], convention[2]):
+ raise ValueError(f"Invalid convention {convention}.")
+ for letter in convention:
+ if letter not in ("X", "Y", "Z"):
+ raise ValueError(f"Invalid letter {letter} in convention string.")
+ matrices = [
+ _axis_angle_rotation(c, a)
+ for c, a in zip(convention, torch.unbind(euler_angles, -1))
+ ]
+ return functools.reduce(torch.matmul, matrices)
+
+
+def _index_from_letter(letter: str) -> int: # pragma: no cover
+ """Return index from letter.
+
+ Args:
+ letter: (str) letter in [X,Y,Z]
+
+ Returns:
+ int mapping of the corresponding letter [0,1,2]
+
+ Raises:
+ ValueError: if the given letter is not valid
+ """
+ if letter == "X":
+ return 0
+ if letter == "Y":
+ return 1
+ if letter == "Z":
+ return 2
+ raise ValueError("letter not valid!")
+
+
+def _angle_from_tan(
+ axis: str,
+ other_axis: str,
+ data: Tensor,
+ horizontal: bool,
+ tait_bryan: bool,
+) -> Tensor:
+ """Helper function for matrix_to_euler_angles.
+
+ Extracts the first or third Euler angle from the two members of
+ the matrix which are positive constant times its sine and cosine.
+
+ Args:
+ axis: Axis label "X" or "Y or "Z" for the angle we are finding.
+ other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
+ convention.
+ data: Rotation matrices as tensor of shape (..., 3, 3).
+ horizontal: Whether we are looking for the angle for the third axis,
+ which means the relevant entries are in the same row of the
+ rotation matrix. If not, they are in the same column.
+ tait_bryan: Whether the first and third axes in the convention differ.
+
+ Returns:
+ Euler Angles in radians for each matrix in data as a tensor
+ of shape (...).
+ """
+ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
+ if horizontal:
+ i2, i1 = i1, i2
+ even = axis + other_axis in {"XY", "YZ", "ZX"}
+ if horizontal == even:
+ return torch.atan2(data[..., i1], data[..., i2])
+ if tait_bryan:
+ return torch.atan2(-data[..., i2], data[..., i1])
+ return torch.atan2(data[..., i2], -data[..., i1])
+
+
+def matrix_to_euler_angles(matrix: Tensor, convention: str = "XYZ") -> Tensor:
+ """Convert rotations given as rotation matrices to Euler angles in radians.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+ convention: Convention string of three uppercase letters.
+
+ Returns:
+ Euler angles in radians as tensor of shape (..., 3).
+
+ Raises:
+ ValueError: if convention string is not a combination of XYZ
+ """
+ if len(convention) != 3:
+ raise ValueError("Convention must have 3 letters.")
+ if convention[1] in (convention[0], convention[2]):
+ raise ValueError(f"Invalid convention {convention}.")
+ for letter in convention:
+ if letter not in ("X", "Y", "Z"):
+ raise ValueError(f"Invalid letter {letter} in convention string.")
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+ i0 = _index_from_letter(convention[0])
+ i2 = _index_from_letter(convention[2])
+ tait_bryan = i0 != i2
+ if tait_bryan:
+ rads = matrix[..., i0, i2]
+ # safety for nan
+ rads[torch.where(rads > 1.0)] = rads.new_tensor([1.0]).to(rads.device)
+ rads[torch.where(rads < -1.0)] = rads.new_tensor([-1.0]).to(
+ rads.device
+ )
+ central_angle = torch.asin(
+ rads * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
+ )
+ else:
+ central_angle = torch.acos(matrix[..., i0, i0])
+
+ o = (
+ _angle_from_tan(
+ convention[0], convention[1], matrix[..., i2], False, tait_bryan
+ ),
+ central_angle,
+ _angle_from_tan(
+ convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
+ ),
+ )
+ return torch.stack(o, -1)
+
+
+def quaternion_to_matrix(quaternions: Tensor) -> Tensor:
+ """Convert rotations given as quaternions to rotation matrices.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ r, i, j, k = torch.unbind(quaternions, -1)
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def _sqrt_positive_part(quat: Tensor) -> Tensor:
+ """Returns sqrt(max(0, x)) but with a zero subgradient where x is 0."""
+ ret = torch.zeros_like(quat)
+ positive_mask = quat > 0
+ ret[positive_mask] = torch.sqrt(quat[positive_mask])
+ return ret
+
+
+def matrix_to_quaternion(matrix: Tensor) -> Tensor:
+ """Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+
+ Raises:
+ ValueError: If shape of input matrix is not correct.
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
+ matrix.reshape(*batch_dim, 9), dim=-1
+ )
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ torch.stack(
+ [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1
+ ),
+ torch.stack(
+ [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1
+ ),
+ torch.stack(
+ [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1
+ ),
+ torch.stack(
+ [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1
+ ),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is
+ # small, the candidate won't be picked.
+ quat_candidates = quat_by_rijk / (
+ 2.0 * q_abs[..., None].max(q_abs.new_tensor(0.1))
+ )
+
+ # if not for numerical problems, quat_candidates[i] should be same
+ # (up to a sign), forall i; we pick the best-conditioned one
+ # (with the largest denominator)
+
+ return quat_candidates[
+ F.one_hot( # pylint: disable=not-callable
+ q_abs.argmax(dim=-1), num_classes=4
+ )
+ > 0.5,
+ :, # pyre-ignore[16]
+ ].reshape(*batch_dim, 4)
+
+
+def standardize_quaternion(quaternions: Tensor) -> Tensor:
+ """Convert a unit quaternion to a standard form.
+
+ Standard form: One in which the real part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
+
+
+def quaternion_raw_multiply(quat1: Tensor, quat2: Tensor) -> Tensor:
+ """Multiply two quaternions.
+
+ Usual torch rules for broadcasting apply.
+
+ Args:
+ quat1: Quaternions as tensor of shape (..., 4), real part first.
+ quat2: Quaternions as tensor of shape (..., 4), real part first.
+
+ Returns:
+ The product of quat1 and quat2, tensor of quaternions shape (..., 4).
+ """
+ aw, ax, ay, az = torch.unbind(quat1, -1)
+ bw, bx, by, bz = torch.unbind(quat2, -1)
+ ow = aw * bw - ax * bx - ay * by - az * bz
+ ox = aw * bx + ax * bw + ay * bz - az * by
+ oy = aw * by - ax * bz + ay * bw + az * bx
+ oz = aw * bz + ax * by - ay * bx + az * bw
+ return torch.stack((ow, ox, oy, oz), -1)
+
+
+def quaternion_multiply(quat1: Tensor, quat2: Tensor) -> Tensor:
+ """Multiply two quaternions representing rotations.
+
+ Returns the quaternion representing their composition, i.e. the version
+ with nonnegative real part. Usual torch rules for broadcasting apply.
+
+ Args:
+ quat1: Quaternions as tensor of shape (..., 4), real part first.
+ quat2: Quaternions as tensor of shape (..., 4), real part first.
+
+ Returns:
+ The product of quat1 and quat2, tensor of quaternions shape (..., 4).
+ """
+ return standardize_quaternion(quaternion_raw_multiply(quat1, quat2))
+
+
+def quaternion_invert(quaternion: Tensor) -> Tensor:
+ """Return quaternion that represents inverse rotation.
+
+ Args:
+ quaternion: Quaternions as tensor of shape (..., 4), with real part
+ first, which must be versors (unit quaternions).
+
+ Returns:
+ The inverse, a tensor of quaternions of shape (..., 4).
+ """
+ return quaternion * quaternion.new_tensor([1, -1, -1, -1])
+
+
+def quaternion_apply(quaternion: Tensor, points: Tensor) -> Tensor:
+ """Apply the rotation given by a quaternion to a 3D point.
+
+ Usual torch rules for broadcasting apply.
+
+ Args:
+ quaternion: Tensor of quaternions, real part first, of shape (..., 4).
+ points: Tensor of 3D points of shape (..., 3).
+
+ Returns:
+ Tensor of rotated points of shape (..., 3).
+
+ Raises:
+ ValueError: If points is not a valid 3D point set.
+ """
+ if points.size(-1) != 3:
+ raise ValueError(f"Points are not in 3D, {points.shape}.")
+ real_parts = points.new_zeros(points.shape[:-1] + (1,))
+ point_as_quaternion = torch.cat((real_parts, points), -1)
+ out = quaternion_raw_multiply(
+ quaternion_raw_multiply(quaternion, point_as_quaternion),
+ quaternion_invert(quaternion),
+ )
+ return out[..., 1:]
+
+
+def rotation_matrix_yaw(
+ rotation_matrix: Tensor, axis_mode: AxisMode
+) -> Tensor:
+ """Get yaw of 3D boxes in euler angle under given axis mode.
+
+ Args:
+ rotation_matrix (Tensor): [N, 3, 3] Rotation matrix of the object.
+ axis_mode (AxisMode): Coordinate system convention.
+
+ Returns:
+ orientation (Tensor): [N, 3] Yaw in euler angle.
+ """
+ orientation = rotation_matrix.new_zeros(rotation_matrix.shape[0], 3)
+
+ if axis_mode == AxisMode.OPENCV:
+ orientation[:, 1] = matrix_to_euler_angles(rotation_matrix, "YZX")[
+ :, 0
+ ]
+ else:
+ orientation[:, 2] = matrix_to_euler_angles(rotation_matrix, "ZYX")[
+ :, 0
+ ]
+ return orientation
+
+
+def rotate_orientation(
+ orientation: Tensor, extrinsics: Tensor, axis_mode: AxisMode = AxisMode.ROS
+) -> Tensor:
+ """Rotate the orientation of the object in different coordinate.
+
+ Args:
+ orientation (Tensor): [N, 3] Orientation of the object in euler angles.
+ extrinsics (Tensor): [4, 4] Extrinsic matrix of the object.
+ axis_mode (AxisMode): Coordinate system convention. Default:
+ AxisMode.ROS
+ """
+ rot = extrinsics[:3, :3] @ euler_angles_to_matrix(orientation)
+ return rotation_matrix_yaw(rot, axis_mode)
+
+
+def rotate_velocities(velocities: Tensor, extrinsics: Tensor) -> Tensor:
+ """Rotate the velocities of the object in different coordinate."""
+ return (extrinsics[:3, :3] @ velocities.unsqueeze(-1)).squeeze(-1)
diff --git a/vis4d/op/geometry/transform.py b/vis4d/op/geometry/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..2276419a55cb386f840dc6f536dff26756d009df
--- /dev/null
+++ b/vis4d/op/geometry/transform.py
@@ -0,0 +1,120 @@
+"""Vis4D geometric transformation functions."""
+
+import torch
+from torch import Tensor
+
+
+def transform_points(points: Tensor, transform: Tensor) -> Tensor:
+ """Applies transform to points.
+
+ Args:
+ points (Tensor): points of shape (N, D) or (B, N, D).
+ transform (Tensor): transforms of shape (D+1, D+1) or (B, D+1, D+1).
+
+ Returns:
+ Tensor: (N, D) / (B, N, D) transformed points.
+
+ Raises:
+ ValueError: Either points or transform have incorrect shape
+ """
+ hom_coords = torch.cat([points, torch.ones_like(points[..., 0:1])], -1)
+ if len(points.shape) == 2:
+ if len(transform.shape) == 3:
+ assert (
+ transform.shape[0] == 1
+ ), "Got multiple transforms for single point set!"
+ transform = transform.squeeze(0)
+ transform = transform.T
+ elif len(points.shape) == 3:
+ if len(transform.shape) == 2:
+ transform = transform.T.unsqueeze(0)
+ elif len(transform.shape) == 3:
+ transform = transform.permute(0, 2, 1)
+ else:
+ raise ValueError(f"Shape of transform invalid: {transform.shape}")
+ else:
+ raise ValueError(f"Shape of input points invalid: {points.shape}")
+ points_transformed = hom_coords @ transform
+ return points_transformed[..., : points.shape[-1]]
+
+
+def inverse_pinhole(intrinsic_matrix: Tensor) -> Tensor:
+ """Calculate inverse of pinhole projection matrix.
+
+ Args:
+ intrinsic_matrix (Tensor): [..., 3, 3] intrinsics or single [3, 3]
+ intrinsics.
+
+ Returns:
+ Tensor: Inverse of input intrinisics.
+ """
+ squeeze = False
+ inv = intrinsic_matrix.clone()
+ if len(intrinsic_matrix.shape) == 2:
+ inv = inv.unsqueeze(0)
+ squeeze = True
+
+ inv[..., 0, 0] = 1.0 / inv[..., 0, 0]
+ inv[..., 1, 1] = 1.0 / inv[..., 1, 1]
+ inv[..., 0, 2] = -inv[..., 0, 2] * inv[..., 0, 0]
+ inv[..., 1, 2] = -inv[..., 1, 2] * inv[..., 1, 1]
+
+ if squeeze:
+ inv = inv.squeeze(0)
+ return inv
+
+
+def inverse_rigid_transform(transformation: Tensor) -> Tensor:
+ """Calculate inverse of rigid body transformation(s).
+
+ Args:
+ transformation (Tensor): [N, 4, 4] transformations or single [4, 4]
+ transformation.
+
+ Returns:
+ Tensor: Inverse of input transformation(s).
+ """
+ squeeze = False
+ if len(transformation.shape) == 2:
+ transformation = transformation.unsqueeze(0)
+ squeeze = True
+ rotation, translation = transformation[:, :3, :3], transformation[:, :3, 3]
+ rot = rotation.permute(0, 2, 1)
+ t = -rot @ translation[:, :, None]
+ inv = torch.cat([torch.cat([rot, t], -1), transformation[:, 3:4]], 1)
+ if squeeze:
+ inv = inv.squeeze(0)
+ return inv
+
+
+def get_transform_matrix(rotation: Tensor, translation: Tensor) -> Tensor:
+ """Assembles 4x4 transformation from rotation / translation pair(s).
+
+ Args:
+ rotation (Tensor): [N, 3, 3] or [3, 3] rotation(s).
+ translation (Tensor): [N, 3] or [3,] translation(s).
+
+ Returns:
+ Tensor: [N, 4, 4] or [4, 4] transformation.
+ """
+ squeeze = False
+ if len(rotation.shape) == 2:
+ assert len(translation.shape) == 1
+ rotation = rotation.unsqueeze(0)
+ translation = translation.unsqueeze(0)
+ squeeze = True
+ batch_size = 1
+ else:
+ assert len(rotation.shape) == 3 and len(translation.shape) == 2
+ assert rotation.shape[0] == translation.shape[0]
+ batch_size = rotation.shape[0]
+ assert (
+ rotation.shape[-2] == rotation.shape[-1] == translation.shape[-1] == 3
+ )
+ transforms = rotation.new_zeros((batch_size, 4, 4))
+ transforms[:, :3, :3] = rotation
+ transforms[:, :3, 3] = translation
+ transforms[:, 3, 3] = 1.0
+ if squeeze:
+ transforms = transforms.squeeze(0)
+ return transforms
diff --git a/vis4d/op/layer/__init__.py b/vis4d/op/layer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7baaa368a725604b71292dab2dd125a9e3799a94
--- /dev/null
+++ b/vis4d/op/layer/__init__.py
@@ -0,0 +1 @@
+"""layers op module."""
diff --git a/vis4d/op/layer/__pycache__/__init__.cpython-311.pyc b/vis4d/op/layer/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6956b9af70383dd46441c309a28ca27c9063113a
Binary files /dev/null and b/vis4d/op/layer/__pycache__/__init__.cpython-311.pyc differ
diff --git a/vis4d/op/layer/__pycache__/attention.cpython-311.pyc b/vis4d/op/layer/__pycache__/attention.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6243d27908d23a077773034e27e4ec7482dea5a8
Binary files /dev/null and b/vis4d/op/layer/__pycache__/attention.cpython-311.pyc differ
diff --git a/vis4d/op/layer/__pycache__/conv2d.cpython-311.pyc b/vis4d/op/layer/__pycache__/conv2d.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9138db1a7adb5e568b9a9cff3a1fe87c1387cc48
Binary files /dev/null and b/vis4d/op/layer/__pycache__/conv2d.cpython-311.pyc differ
diff --git a/vis4d/op/layer/__pycache__/deform_conv.cpython-311.pyc b/vis4d/op/layer/__pycache__/deform_conv.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f56497013dc1b5e804d8c716b445938dd79c7c69
Binary files /dev/null and b/vis4d/op/layer/__pycache__/deform_conv.cpython-311.pyc differ
diff --git a/vis4d/op/layer/__pycache__/drop.cpython-311.pyc b/vis4d/op/layer/__pycache__/drop.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ba08d87d4986851f8e6769b1cb7fe350092636c
Binary files /dev/null and b/vis4d/op/layer/__pycache__/drop.cpython-311.pyc differ
diff --git a/vis4d/op/layer/__pycache__/mlp.cpython-311.pyc b/vis4d/op/layer/__pycache__/mlp.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d15262fd304789cb7dcbff07aa5882d39fd92f7
Binary files /dev/null and b/vis4d/op/layer/__pycache__/mlp.cpython-311.pyc differ
diff --git a/vis4d/op/layer/__pycache__/transformer.cpython-311.pyc b/vis4d/op/layer/__pycache__/transformer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f1d2d61ef5c2f90236af5f07376f660283924ff4
Binary files /dev/null and b/vis4d/op/layer/__pycache__/transformer.cpython-311.pyc differ
diff --git a/vis4d/op/layer/__pycache__/util.cpython-311.pyc b/vis4d/op/layer/__pycache__/util.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..af205966e5dc357cced6e5aa3816db2f5e0ff6e3
Binary files /dev/null and b/vis4d/op/layer/__pycache__/util.cpython-311.pyc differ
diff --git a/vis4d/op/layer/__pycache__/weight_init.cpython-311.pyc b/vis4d/op/layer/__pycache__/weight_init.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4eee10c16a20de1ce4e8676847aa97100c10eea
Binary files /dev/null and b/vis4d/op/layer/__pycache__/weight_init.cpython-311.pyc differ
diff --git a/vis4d/op/layer/attention.py b/vis4d/op/layer/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..f07212d0e3bd39a3bb338d565bfe845a16e50e4e
--- /dev/null
+++ b/vis4d/op/layer/attention.py
@@ -0,0 +1,241 @@
+"""Attention layer."""
+
+from __future__ import annotations
+
+from torch import Tensor, nn
+
+from vis4d.common.logging import rank_zero_warn
+from vis4d.common.typing import ArgsType
+
+
+class Attention(nn.Module):
+ """ViT Attention Layer.
+
+ Modified from timm (https://github.com/huggingface/pytorch-image-models).
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ """Init attention layer.
+
+ Args:
+ dim (int): Input tensor's dimension.
+ num_heads (int, optional): Number of attention heads. Defaults to
+ 8.
+ qkv_bias (bool, optional): If to add bias to qkv. Defaults to
+ False.
+ attn_drop (float, optional): Dropout rate for attention. Defaults
+ to 0.0.
+ proj_drop (float, optional): Dropout rate for projection. Defaults
+ to 0.0.
+ """
+ super().__init__()
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def __call__(self, data: Tensor) -> Tensor:
+ """Applies the layer.
+
+ Args:
+ data (Tensor): Input tensor of shape (B, N, dim).
+
+ Returns:
+ Tensor: Output tensor of the same shape as input.
+ """
+ return self._call_impl(data)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ batch_size, num_samples, dim = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(
+ batch_size,
+ num_samples,
+ 3,
+ self.num_heads,
+ dim // self.num_heads,
+ )
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = qkv.unbind(
+ 0
+ ) # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(batch_size, num_samples, dim)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MultiheadAttention(nn.Module):
+ """A wrapper for ``torch.nn.MultiheadAttention``.
+
+ This module implements MultiheadAttention with identity connection,
+ and positional encoding is also passed as input.
+ """
+
+ def __init__(
+ self,
+ embed_dims: int,
+ num_heads: int,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ dropout_layer: nn.Module | None = None,
+ batch_first: bool = False,
+ need_weights: bool = False,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Init MultiheadAttention.
+
+ Args:
+ embed_dims (int): The embedding dimension.
+ num_heads (int): Parallel attention heads.
+ attn_drop (float): A Dropout layer on attn_output_weights.
+ Default: 0.0.
+ proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
+ Default: 0.0.
+ dropout_layer (nn.Module | None, optional): The dropout_layer used
+ when adding the shortcut. Defaults to None.
+ batch_first (bool): When it is True, Key, Query and Value are
+ shape of (batch, n, embed_dim), otherwise (n, batch,
+ embed_dim). Default to False.
+ need_weights (bool): Whether to return the attention weights.
+ If True, the output will be a tuple of (attn_output,
+ attn_output_weights) and not using FlashAttention. If False,
+ only the attn_output will be returned. Default to False.
+ """
+ super().__init__()
+ self.batch_first = batch_first
+ self.embed_dims = embed_dims
+ self.num_heads = num_heads
+ self.need_weights = need_weights
+
+ self.attn = nn.MultiheadAttention(
+ embed_dims, num_heads, dropout=attn_drop, **kwargs
+ )
+
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.dropout_layer = dropout_layer or nn.Identity()
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor | None = None,
+ value: Tensor | None = None,
+ identity: Tensor | None = None,
+ query_pos: Tensor | None = None,
+ key_pos: Tensor | None = None,
+ attn_mask: Tensor | None = None,
+ key_padding_mask: Tensor | None = None,
+ ) -> Tensor:
+ """Forward function for `MultiheadAttention`.
+
+ **kwargs allow passing a more general data flow when combining
+ with other operations in `transformerlayer`.
+
+ Args:
+ query (Tensor): The input query with shape [num_queries, bs,
+ embed_dims] if self.batch_first is False, else
+ [bs, num_queries embed_dims].
+ key (Tensor): The key tensor with shape [num_keys, bs,
+ embed_dims] if self.batch_first is False, else
+ [bs, num_keys, embed_dims] .
+ If None, the ``query`` will be used. Defaults to None.
+ value (Tensor): The value tensor with same shape as `key`.
+ Same in `nn.MultiheadAttention.forward`. Defaults to None.
+ If None, the `key` will be used.
+ identity (Tensor): This tensor, with the same shape as query,
+ will be used for the identity link.
+ If None, `query` will be used. Defaults to None.
+ query_pos (Tensor): The positional encoding for query, with
+ the same shape as `query`. If not None, it will
+ be added to `query` before forward function. Defaults to None.
+ key_pos (Tensor): The positional encoding for `key`, with the
+ same shape as `key`. Defaults to None. If not None, it will
+ be added to `key` before forward function. If None, and
+ `query_pos` has the same shape as `key`, then `query_pos`
+ will be used for `key_pos`. Defaults to None.
+ attn_mask (Tensor): ByteTensor mask with shape [num_queries,
+ num_keys]. Same in `nn.MultiheadAttention.forward`.
+ Defaults to None.
+ key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
+ Defaults to None.
+
+ Returns:
+ Tensor: forwarded results with shape [num_queries, bs, embed_dims]
+ if self.batch_first is False, else [bs, num_queries,
+ embed_dims].
+ """
+ if key is None:
+ key = query
+
+ if value is None:
+ value = key
+
+ if identity is None:
+ identity = query
+
+ if key_pos is None and query_pos is not None:
+ # use query_pos if key_pos is not available
+ if query_pos.shape == key.shape:
+ key_pos = query_pos
+ else:
+ rank_zero_warn(
+ f"Position encoding of key in {self.__class__.__name__}"
+ + "is missing, and positional encodeing of query has "
+ + "has different shape and cannot be usde for key. "
+ + "It it is not desired, please provide key_pos."
+ )
+
+ if query_pos is not None:
+ query = query + query_pos
+
+ if key_pos is not None:
+ key = key + key_pos
+
+ # Because the dataflow('key', 'query', 'value') of
+ # ``torch.nn.MultiheadAttention`` is (num_query, batch,
+ # embed_dims), We should adjust the shape of dataflow from
+ # batch_first (batch, num_query, embed_dims) to num_query_first
+ # (num_query, batch, embed_dims), and recover ``attn_output``
+ # from num_query_first to batch_first.
+ if self.batch_first:
+ query = query.transpose(0, 1)
+ key = key.transpose(0, 1)
+ value = value.transpose(0, 1)
+
+ out = self.attn(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask,
+ need_weights=self.need_weights,
+ )
+
+ if isinstance(out, tuple):
+ out = out[0]
+
+ if self.batch_first:
+ out = out.transpose(0, 1)
+
+ return identity + self.dropout_layer(self.proj_drop(out))
diff --git a/vis4d/op/layer/conv2d.py b/vis4d/op/layer/conv2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..81521b5565080285c422b95d6cb68e096285884d
--- /dev/null
+++ b/vis4d/op/layer/conv2d.py
@@ -0,0 +1,283 @@
+"""Wrapper for conv2d."""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+import torch
+from torch import Tensor, nn
+from torch.nn import functional as F
+
+from vis4d.common.typing import ArgsType
+
+from .weight_init import constant_init
+
+
+class Conv2d(nn.Conv2d):
+ """Wrapper around Conv2d to support empty inputs and norm/activation."""
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ norm: nn.Module | None = None,
+ activation: nn.Module | None = None,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class.
+
+ If norm is specified, it is initialized with 1.0 and bias with 0.0.
+ """
+ super().__init__(*args, **kwargs)
+ self.norm = norm
+ self.activation = activation
+
+ if self.norm is not None:
+ constant_init(self.norm, 1.0, bias=0.0)
+
+ def forward( # pylint: disable=arguments-renamed
+ self, x: Tensor
+ ) -> Tensor:
+ """Forward pass."""
+ if not torch.jit.is_scripting(): # type: ignore
+ # https://github.com/pytorch/pytorch/issues/12013
+ if (
+ x.numel() == 0
+ and self.training
+ and isinstance(self.norm, nn.SyncBatchNorm)
+ ):
+ raise ValueError(
+ "SyncBatchNorm does not support empty inputs!"
+ )
+
+ x = F.conv2d( # pylint: disable=not-callable
+ x,
+ self.weight,
+ self.bias,
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.groups,
+ )
+ if self.norm is not None:
+ x = self.norm(x)
+ if self.activation is not None:
+ x = self.activation(x)
+ return x
+
+
+def add_conv_branch(
+ num_branch_convs: int,
+ last_layer_dim: int,
+ conv_out_dim: int,
+ conv_has_bias: bool,
+ norm_cfg: str | None,
+ num_groups: int | None,
+) -> tuple[nn.ModuleList, int]:
+ """Init conv branch for head."""
+ convs = nn.ModuleList()
+ if norm_cfg is not None:
+ norm = getattr(nn, norm_cfg)
+ else:
+ norm = None
+
+ if norm == nn.GroupNorm:
+ assert num_groups is not None, "num_groups must be specified"
+ norm = lambda x: nn.GroupNorm( # pylint: disable=unnecessary-lambda-assignment
+ num_groups, x
+ )
+ if num_branch_convs > 0:
+ for i in range(num_branch_convs):
+ conv_in_dim = last_layer_dim if i == 0 else conv_out_dim
+ convs.append(
+ Conv2d(
+ conv_in_dim,
+ conv_out_dim,
+ kernel_size=3,
+ padding=1,
+ bias=conv_has_bias,
+ norm=norm(conv_out_dim) if norm is not None else norm,
+ activation=nn.ReLU(inplace=True),
+ )
+ )
+ last_layer_dim = conv_out_dim
+
+ return convs, last_layer_dim
+
+
+class UnetDownConvOut(NamedTuple):
+ """Output of the UnetDownConv operator.
+
+ features: Features before applying the pooling operator
+ pooled_features: Features after applying the pooling operator
+ """
+
+ features: Tensor
+ pooled_features: Tensor
+
+
+class UnetDownConv(nn.Module):
+ """Downsamples a feature map by applying two convolutions and maxpool."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ pooling: bool = True,
+ activation: str = "ReLU",
+ ):
+ """Creates a new downsampling convolution operator.
+
+ This operator consists of two convolutions followed by a maxpool
+ operator.
+
+ Args:
+ in_channels (int): input channesl
+ out_channels (int): output channesl
+ pooling (bool): If pooling should be applied
+ activation (str): Activation that should be applied
+ """
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.pooling = pooling
+ activation = getattr(nn, activation)()
+
+ self.conv1 = nn.Conv2d(
+ self.in_channels,
+ self.out_channels,
+ kernel_size=3,
+ padding=1,
+ stride=1,
+ bias=True,
+ )
+ self.conv2 = nn.Conv2d(
+ self.out_channels,
+ self.out_channels,
+ kernel_size=3,
+ padding=1,
+ stride=1,
+ bias=True,
+ )
+
+ if self.pooling:
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+
+ def __call__(self, data: Tensor) -> UnetDownConvOut:
+ """Applies the operator.
+
+ Args:
+ data (Tensor): Input data.
+
+ Returns:
+ UnetDownConvOut: Containing the features before the pooling
+ operation (features) and after (pooled_features).
+ """
+ return self._call_impl(data)
+
+ def forward(self, data: Tensor) -> UnetDownConvOut:
+ """Applies the operator.
+
+ Args:
+ data (Tensor): Input data.
+
+ Returns:
+ UnetDownConvOut: containing the features before the pooling
+ operation (features) and after (pooled_features).
+ """
+ x = F.relu(self.conv1(data))
+ x = F.relu(self.conv2(x))
+ before_pool = x
+ if self.pooling:
+ x = self.pool(x)
+ return UnetDownConvOut(features=before_pool, pooled_features=x)
+
+
+class UnetUpConv(nn.Module):
+ """An operator that performs 2 convolutions and 1 UpConvolution.
+
+ A ReLU activation follows each convolution.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ merge_mode: str = "concat",
+ up_mode: str = "transpose",
+ ):
+ """Creates a new UpConv operator.
+
+ This operator merges two inputs by upsampling one and combining it with
+ the other.
+
+ Args:
+ in_channels: Number of input channels (low res)
+ out_channels: Number of output channels (high res)
+ merge_mode: How to merge both input channels
+ up_mode: How to upsample the channel with lower resolution
+
+ Raises:
+ ValueError: If upsampling mode is unknown
+ """
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.merge_mode = merge_mode
+ self.up_mode = up_mode
+
+ # Upsampling
+ if self.up_mode == "transpose":
+ self.upconv: nn.Module = nn.ConvTranspose2d(
+ in_channels, out_channels, kernel_size=2, stride=2
+ )
+ elif self.up_mode == "upsample":
+ self.upconv = nn.Sequential(
+ nn.Upsample(mode="bilinear", scale_factor=2),
+ nn.Conv2d(in_channels, out_channels, kernel_size=1),
+ )
+ else:
+ raise ValueError(f"Unknown upsampling mode: {up_mode}")
+
+ if self.merge_mode == "concat":
+ self.conv1 = nn.Conv2d(
+ 2 * self.out_channels, self.out_channels, 3, padding=1
+ )
+ else:
+ # num of input channels to conv2 is same
+ self.conv1 = nn.Conv2d(
+ self.out_channels, self.out_channels, 3, padding=1
+ )
+ self.conv2 = nn.Conv2d(
+ self.out_channels, self.out_channels, 3, padding=1
+ )
+
+ def __call__(self, from_down: Tensor, from_up: Tensor) -> Tensor:
+ """Forward pass.
+
+ Arguments:
+ from_down (Tensor): Tensor from the encoder pathway. Assumed to
+ have dimension 'out_channels'
+ from_up (Tensor): Upconv'd tensor from the decoder pathway. Assumed
+ to have dimension 'in_channels'
+ """
+ return self._call_impl(from_down, from_up)
+
+ def forward(self, from_down: Tensor, from_up: Tensor) -> Tensor:
+ """Forward pass.
+
+ Arguments:
+ from_down (Tensor): Tensor from the encoder pathway. Assumed to
+ have dimension 'out_channels'
+ from_up (Tensor): Upconv'd tensor from the decoder pathway. Assumed
+ to have dimension 'in_channels'
+ """
+ from_up = self.upconv(from_up)
+ if self.merge_mode == "concat":
+ x = torch.cat((from_up, from_down), 1)
+ else:
+ x = from_up + from_down
+ x = F.relu(self.conv1(x))
+ x = F.relu(self.conv2(x))
+ return x
diff --git a/vis4d/op/layer/csp_layer.py b/vis4d/op/layer/csp_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..47dc5e6f7678ccd220a8bf3203295c448a4e4bf0
--- /dev/null
+++ b/vis4d/op/layer/csp_layer.py
@@ -0,0 +1,146 @@
+"""Cross Stage Partial Layer.
+
+Modified from mmdetection (https://github.com/open-mmlab/mmdetection).
+"""
+
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+from .conv2d import Conv2d
+
+
+class DarknetBottleneck(nn.Module):
+ """The basic bottleneck block used in Darknet.
+
+ Each ResBlock consists of two Conv blocks and the input is added to the
+ final output. Each block is composed of Conv, BN, and SiLU.
+ The first convolutional layer has filter size of 1x1 and the second one
+ has filter size of 3x3.
+
+ Args:
+ in_channels (int): The input channels of this Module.
+ out_channels (int): The output channels of this Module.
+ expansion (float, optional): The kernel size of the convolution.
+ Defaults to 0.5.
+ add_identity (bool, optional): Whether to add identity to the output.
+ Defaults to True.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ expansion: float = 0.5,
+ add_identity: bool = True,
+ ):
+ """Init."""
+ super().__init__()
+ hidden_channels = int(out_channels * expansion)
+ self.conv1 = Conv2d(
+ in_channels,
+ hidden_channels,
+ 1,
+ bias=False,
+ norm=nn.BatchNorm2d(hidden_channels, eps=0.001, momentum=0.03),
+ activation=nn.SiLU(inplace=True),
+ )
+ self.conv2 = Conv2d(
+ hidden_channels,
+ out_channels,
+ 3,
+ stride=1,
+ padding=1,
+ bias=False,
+ norm=nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03),
+ activation=nn.SiLU(inplace=True),
+ )
+ self.add_identity = add_identity and in_channels == out_channels
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ """Forward pass.
+
+ Args:
+ features (torch.Tensor): Input features.
+ """
+ identity = features
+ out = self.conv1(features)
+ out = self.conv2(out)
+
+ if self.add_identity:
+ return out + identity
+ return out
+
+
+class CSPLayer(nn.Module):
+ """Cross Stage Partial Layer.
+
+ Args:
+ in_channels (int): The input channels of the CSP layer.
+ out_channels (int): The output channels of the CSP layer.
+ expand_ratio (float, optional): Ratio to adjust the number of channels
+ of the hidden layer. Defaults to 0.5.
+ num_blocks (int, optional): Number of blocks. Defaults to 1.
+ add_identity (bool, optional): Whether to add identity in blocks.
+ Defaults to True.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ expand_ratio: float = 0.5,
+ num_blocks: int = 1,
+ add_identity: bool = True,
+ ):
+ """Init."""
+ super().__init__()
+ mid_channels = int(out_channels * expand_ratio)
+ self.main_conv = Conv2d(
+ in_channels,
+ mid_channels,
+ 1,
+ bias=False,
+ norm=nn.BatchNorm2d(mid_channels, eps=0.001, momentum=0.03),
+ activation=nn.SiLU(inplace=True),
+ )
+ self.short_conv = Conv2d(
+ in_channels,
+ mid_channels,
+ 1,
+ bias=False,
+ norm=nn.BatchNorm2d(mid_channels, eps=0.001, momentum=0.03),
+ activation=nn.SiLU(inplace=True),
+ )
+ self.final_conv = Conv2d(
+ 2 * mid_channels,
+ out_channels,
+ 1,
+ bias=False,
+ norm=nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03),
+ activation=nn.SiLU(inplace=True),
+ )
+
+ self.blocks = nn.Sequential(
+ *[
+ DarknetBottleneck(
+ mid_channels, mid_channels, 1.0, add_identity
+ )
+ for _ in range(num_blocks)
+ ]
+ )
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ """Forward pass.
+
+ Args:
+ features (torch.Tensor): Input features.
+ """
+ x_short = self.short_conv(features)
+
+ x_main = self.main_conv(features)
+ x_main = self.blocks(x_main)
+
+ x_final = torch.cat((x_main, x_short), dim=1)
+ return self.final_conv(x_final)
diff --git a/vis4d/op/layer/deform_conv.py b/vis4d/op/layer/deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a3a52baf2063130799cf6f57e10b988f13545f8
--- /dev/null
+++ b/vis4d/op/layer/deform_conv.py
@@ -0,0 +1,93 @@
+"""Wrapper for deformable convolution."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor, nn
+from torchvision.ops import DeformConv2d
+
+from .weight_init import constant_init
+
+
+class DeformConv(DeformConv2d): # type: ignore
+ """Wrapper around Deformable Convolution operator with norm/activation.
+
+ If norm is specified, it is initialized with 1.0 and bias with 0.0.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = True,
+ norm: nn.Module | None = None,
+ activation: nn.Module | None = None,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ in_channels (int): Input channels.
+ out_channels (int): Output channels.
+ kernel_size (int): Size of convolutional kernel.
+ stride (int, optional): Stride of convolutional layer. Defaults to
+ 1.
+ padding (int, optional): Padding of convolutional layer. Defaults
+ to 0.
+ dilation (int, optional): Dilation of convolutional layer. Defaults
+ to 1.
+ groups (int, optional): Number of deformable groups. Defaults to 1.
+ bias (bool, optional): Whether to use bias in convolutional layer.
+ Defaults to True.
+ norm (nn.Module, optional): Normalization layer. Defaults to None.
+ activation (nn.Module, optional): Activation layer. Defaults to
+ None.
+ """
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ )
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ bias=True,
+ )
+ self.norm = norm
+ self.activation = activation
+ self.init_weights()
+
+ def init_weights(self) -> None:
+ """Initialize weights of offset conv layer."""
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_() # type: ignore
+ if self.norm is not None:
+ constant_init(self.norm, 1.0, bias=0.0)
+
+ def forward( # pylint: disable=arguments-differ
+ self, input_x: Tensor
+ ) -> Tensor:
+ """Forward."""
+ out = self.conv_offset(input_x)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+ input_x = super().forward(input_x, offset, mask)
+ if self.norm is not None:
+ input_x = self.norm(input_x)
+ if self.activation is not None:
+ input_x = self.activation(input_x)
+ return input_x
diff --git a/vis4d/op/layer/drop.py b/vis4d/op/layer/drop.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0a1cef8b7f5627c158df26fc8b34d480e0709cb
--- /dev/null
+++ b/vis4d/op/layer/drop.py
@@ -0,0 +1,68 @@
+"""DropPath (Stochastic Depth) regularization layers.
+
+Modified from timm (https://github.com/huggingface/pytorch-image-models).
+"""
+
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+
+def drop_path(
+ x: torch.Tensor,
+ drop_prob: float = 0.0,
+ training: bool = False,
+ scale_by_keep: bool = True,
+) -> torch.Tensor:
+ """Drop path regularizer (Stochastic Depth) per sample.
+
+ Args:
+ x (torch.Tensor): Input tensor of shape (batch_size, ...).
+ drop_prob (float, optional): Probability of an element to be zeroed.
+ Defaults to 0.0.
+ training (bool, optional): If to apply drop path. Defaults to False.
+ scale_by_keep (bool, optional): If to scale by keep probability.
+ Defaults to True.
+ """
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (
+ x.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class DropPath(nn.Module):
+ """DropPath regularizer (Stochastic Depth) per sample."""
+
+ def __init__(
+ self, drop_prob: float = 0.0, scale_by_keep: bool = True
+ ) -> None:
+ """Init DropPath.
+
+ Args:
+ drop_prob (float, optional): Probability of an item to be masked.
+ Defaults to 0.0.
+ scale_by_keep (bool, optional): If to scale by keep probability.
+ Defaults to True.
+ """
+ super().__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def __call__(self, data: torch.Tensor) -> torch.Tensor:
+ """Applies the layer.
+
+ Args:
+ data: (tensor) input shape [N, ...]
+ """
+ return self._call_impl(data)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass."""
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
diff --git a/vis4d/op/layer/mlp.py b/vis4d/op/layer/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea157511ab5075f93a4e04c9eadb6787b6168ad7
--- /dev/null
+++ b/vis4d/op/layer/mlp.py
@@ -0,0 +1,62 @@
+"""MLP Layers."""
+
+from __future__ import annotations
+
+from torch import Tensor, nn
+
+
+class TransformerBlockMLP(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks."""
+
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: int | None = None,
+ out_features: int | None = None,
+ act_layer: nn.Module = nn.GELU(),
+ bias: bool = True,
+ drop: float = 0.0,
+ ):
+ """Init MLP.
+
+ Args:
+ in_features (int): Number of input features.
+ hidden_features (int, optional): Number of hidden features.
+ Defaults to None.
+ out_features (int, optional): Number of output features.
+ Defaults to None.
+ act_layer (nn.Module, optional): Activation layer.
+ Defaults to nn.GELU.
+ bias (bool, optional): If bias should be used. Defaults to True.
+ drop (float, optional): Dropout probability. Defaults to 0.0.
+ """
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer
+ self.drop1 = nn.Dropout(drop)
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop2 = nn.Dropout(drop)
+
+ def __call__(self, data: Tensor) -> Tensor:
+ """Applies the layer.
+
+ Args:
+ data: (tensor) input shape [N, C]
+ """
+ return self._call_impl(data)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass.
+
+ Args:
+ x: (tensor) input shape [N, C]
+ """
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
diff --git a/vis4d/op/layer/ms_deform_attn.py b/vis4d/op/layer/ms_deform_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd91ec0e4c84774b7cbdc2bbd0e0f1acac7acfb3
--- /dev/null
+++ b/vis4d/op/layer/ms_deform_attn.py
@@ -0,0 +1,563 @@
+# pylint: disable=no-name-in-module, abstract-method, arguments-differ
+"""Multi-Scale Deformable Attention Module.
+
+Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py) # pylint: disable=line-too-long
+"""
+from __future__ import annotations
+
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.init import constant_, xavier_uniform_
+
+from vis4d.common.imports import VIS4D_CUDA_OPS_AVAILABLE
+from vis4d.common.logging import rank_zero_warn
+
+if VIS4D_CUDA_OPS_AVAILABLE:
+ from vis4d_cuda_ops import ms_deform_attn_backward, ms_deform_attn_forward
+else:
+ raise ImportError("vis4d_cuda_ops is not installed.")
+
+
+class MSDeformAttentionFunction(Function): # pragma: no cover
+ """Multi-Scale Deformable Attention Function module."""
+
+ @staticmethod
+ def forward( # type: ignore
+ ctx,
+ value: Tensor,
+ value_spatial_shapes: Tensor,
+ value_level_start_index: Tensor,
+ sampling_locations: Tensor,
+ attention_weights: Tensor,
+ im2col_step: int,
+ ) -> Tensor:
+ """Forward pass."""
+ if not VIS4D_CUDA_OPS_AVAILABLE:
+ raise RuntimeError(
+ "MSDeformAttentionFunction requires vis4d cuda ops to run."
+ )
+ ctx.im2col_step = im2col_step
+ output = ms_deform_attn_forward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ ctx.im2col_step,
+ )
+ ctx.save_for_backward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ )
+ return output
+
+ @staticmethod
+ @once_differentiable # type: ignore
+ def backward( # type: ignore
+ ctx, grad_output: Tensor
+ ) -> tuple[Tensor, None, None, Tensor, Tensor, None]:
+ """Backward pass."""
+ if not VIS4D_CUDA_OPS_AVAILABLE:
+ raise RuntimeError(
+ "MSDeformAttentionFunction requires vis4d cuda ops to run."
+ )
+ (
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ ) = ctx.saved_tensors
+ (
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight,
+ ) = ms_deform_attn_backward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ grad_output,
+ ctx.im2col_step,
+ )
+
+ return (
+ grad_value,
+ None,
+ None,
+ grad_sampling_loc,
+ grad_attn_weight,
+ None,
+ )
+
+
+def ms_deformable_attention_cpu(
+ value: Tensor,
+ value_spatial_shapes: Tensor,
+ sampling_locations: Tensor,
+ attention_weights: Tensor,
+) -> Tensor:
+ """CPU version of multi-scale deformable attention.
+
+ Args:
+ value (Tensor): The value has shape (bs, num_keys, mum_heads,
+ embed_dims // num_heads)
+ value_spatial_shapes (Tensor): Spatial shape of each feature map, has
+ shape (num_levels, 2), last dimension 2 represent (h, w).
+ sampling_locations (Tensor): The location of sampling points, has shape
+ (bs ,num_queries, num_heads, num_levels, num_points, 2), the last
+ dimension 2 represent (x, y).
+ attention_weights (Tensor): The weight of sampling points used when
+ calculate the attention, has shape (bs ,num_queries, num_heads,
+ num_levels, num_points),
+
+ Returns:
+ Tensor: has shape (bs, num_queries, embed_dims).
+ """
+ bs, _, num_heads, embed_dims = value.shape
+ (
+ _,
+ num_queries,
+ num_heads,
+ num_levels,
+ num_points,
+ _,
+ ) = sampling_locations.shape
+ value_list = value.split([h * w for h, w in value_spatial_shapes], dim=1)
+ sampling_grids: Tensor = 2 * sampling_locations - 1
+ sampling_value_list = []
+ for level, (h, w) in enumerate(value_spatial_shapes):
+ # bs, h*w, num_heads, embed_dims ->
+ # bs, h*w, num_heads*embed_dims ->
+ # bs, num_heads*embed_dims, h*w ->
+ # bs*num_heads, embed_dims, h, w
+ value_l_ = (
+ value_list[level]
+ .flatten(2)
+ .transpose(1, 2)
+ .reshape(bs * num_heads, embed_dims, h, w)
+ )
+ # bs, num_queries, num_heads, num_points, 2 ->
+ # bs, num_heads, num_queries, num_points, 2 ->
+ # bs*num_heads, num_queries, num_points, 2
+ sampling_grid_l_ = (
+ sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
+ )
+ # bs*num_heads, embed_dims, num_queries, num_points
+ sampling_value_l_ = F.grid_sample(
+ value_l_,
+ sampling_grid_l_,
+ mode="bilinear",
+ padding_mode="zeros",
+ align_corners=False,
+ )
+ sampling_value_list.append(sampling_value_l_)
+ # (bs, num_queries, num_heads, num_levels, num_points) ->
+ # (bs, num_heads, num_queries, num_levels, num_points) ->
+ # (bs, num_heads, 1, num_queries, num_levels*num_points)
+ attention_weights = attention_weights.transpose(1, 2).reshape(
+ bs * num_heads, 1, num_queries, num_levels * num_points
+ )
+ output = (
+ (
+ torch.stack(sampling_value_list, dim=-2).flatten(-2)
+ * attention_weights
+ )
+ .sum(-1)
+ .view(bs, num_heads * embed_dims, num_queries)
+ )
+ return output.transpose(1, 2).contiguous()
+
+
+def is_power_of_2(number: int) -> None:
+ """Check if a number is a power of 2."""
+ if (not isinstance(number, int)) or (number < 0):
+ raise ValueError(
+ f"invalid input for is_power_of_2: {number} (type: {type(number)})"
+ )
+ if not ((number & (number - 1) == 0) and number != 0):
+ rank_zero_warn(
+ "You'd better set hidden dimensions in MultiScaleDeformAttention"
+ "to make the dimension of each attention head a power of 2, "
+ "which is more efficient in our CUDA implementation."
+ )
+
+
+class MSDeformAttention(nn.Module):
+ """Multi-Scale Deformable Attention Module.
+
+ This is the original implementation from Deformable DETR.
+ """
+
+ def __init__(
+ self,
+ d_model: int = 256,
+ n_levels: int = 4,
+ n_heads: int = 8,
+ n_points: int = 4,
+ im2col_step: int = 64,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ d_model (int): Hidden dimensions.
+ n_levels (int): Number of feature levels.
+ n_heads (int): Number of attention heads.
+ n_points (int): Number of sampling points per attention head per
+ feature level.
+ im2col_step (int): The step used in image_to_column. Default: 64.
+ """
+ super().__init__()
+ if d_model % n_heads != 0:
+ raise ValueError(
+ "d_model must be divisible by n_heads, but got "
+ + f"{d_model} and {n_heads}."
+ )
+
+ is_power_of_2(d_model // n_heads)
+
+ self.d_model = d_model
+ self.n_levels = n_levels
+ self.n_heads = n_heads
+ self.n_points = n_points
+ self.im2col_step = im2col_step
+
+ self.sampling_offsets = nn.Linear(
+ d_model, n_heads * n_levels * n_points * 2
+ )
+ self.attention_weights = nn.Linear(
+ d_model, n_heads * n_levels * n_points
+ )
+ self.value_proj = nn.Linear(d_model, d_model)
+ self.output_proj = nn.Linear(d_model, d_model)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self) -> None:
+ """Reset parameters."""
+ constant_(self.sampling_offsets.weight.data, 0.0)
+ thetas = torch.mul(
+ torch.arange(self.n_heads, dtype=torch.float32),
+ (2.0 * math.pi / self.n_heads),
+ )
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+ .view(self.n_heads, 1, 1, 2)
+ .repeat(1, self.n_levels, self.n_points, 1)
+ )
+ for i in range(self.n_points):
+ grid_init[:, :, i, :] *= i + 1
+ with torch.no_grad():
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+ constant_(self.attention_weights.weight.data, 0.0)
+ constant_(self.attention_weights.bias.data, 0.0)
+ xavier_uniform_(self.value_proj.weight.data)
+ constant_(self.value_proj.bias.data, 0.0)
+ xavier_uniform_(self.output_proj.weight.data)
+ constant_(self.output_proj.bias.data, 0.0)
+
+ def forward(
+ self,
+ query: Tensor,
+ reference_points: Tensor,
+ input_flatten: Tensor,
+ input_spatial_shapes: Tensor,
+ input_level_start_index: Tensor,
+ input_padding_mask: Tensor | None = None,
+ ) -> Tensor:
+ r"""Forward function.
+
+ Args:
+ query (Tensor): (n, length_{query}, C).
+ reference_points (Tensor): (n, length_{query}, n_levels, 2),
+ range in [0, 1], top-left (0,0), bottom-right (1, 1), including
+ padding area or (n, length_{query}, n_levels, 4), add
+ additional (w, h) to form reference boxes.
+ input_flatten (Tensor): (n, \sum_{l=0}^{L-1} H_l \cdot W_l, C).
+ input_spatial_shapes (Tensor): (n_levels, 2), [(H_0, W_0),
+ (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+ input_level_start_index (Tensor): (n_levels, ), [0, H_0*W_0,
+ H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ...,
+ H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
+ input_padding_mask (Tensor): (n, \sum_{l=0}^{L-1} H_l \cdot W_l),
+ True for padding elements, False for non-padding elements.
+
+ Retrun
+ output (Tensor): (n, length_{query}, C).
+ """
+ n, len_q, _ = query.shape
+ n, len_in, _ = input_flatten.shape
+ assert (
+ input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]
+ ).sum() == len_in
+
+ value = self.value_proj(input_flatten)
+ if input_padding_mask is not None:
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
+ value = value.view(
+ n, len_in, self.n_heads, self.d_model // self.n_heads
+ )
+ sampling_offsets = self.sampling_offsets(query).view(
+ n, len_q, self.n_heads, self.n_levels, self.n_points, 2
+ )
+ attention_weights = self.attention_weights(query).view(
+ n, len_q, self.n_heads, self.n_levels * self.n_points
+ )
+ attention_weights = F.softmax(attention_weights, -1).view(
+ n, len_q, self.n_heads, self.n_levels, self.n_points
+ )
+ # n, len_q, n_heads, n_levels, n_points, 2
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.stack(
+ [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]],
+ -1,
+ )
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :]
+ + sampling_offsets
+ / offset_normalizer[None, None, None, :, None, :]
+ )
+ elif reference_points.shape[-1] == 4:
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :2]
+ + sampling_offsets
+ / self.n_points
+ * reference_points[:, :, None, :, None, 2:]
+ * 0.5
+ )
+ else:
+ raise ValueError(
+ "Last dim of reference_points must be 2 or 4, "
+ + f"but get {reference_points.shape[-1]} instead."
+ )
+
+ if torch.cuda.is_available() and value.is_cuda:
+ output = MSDeformAttentionFunction.apply(
+ value,
+ input_spatial_shapes,
+ input_level_start_index,
+ sampling_locations,
+ attention_weights,
+ self.im2col_step,
+ )
+ else:
+ output = ms_deformable_attention_cpu(
+ value,
+ input_spatial_shapes,
+ sampling_locations,
+ attention_weights,
+ )
+
+ output = self.output_proj(output)
+
+ return output
+
+ def __call__(
+ self,
+ query: Tensor,
+ reference_points: Tensor,
+ input_flatten: Tensor,
+ input_spatial_shapes: Tensor,
+ input_level_start_index: Tensor,
+ input_padding_mask: Tensor | None = None,
+ ) -> Tensor:
+ """Type definition for call implementation."""
+ return self._call_impl(
+ query,
+ reference_points,
+ input_flatten,
+ input_spatial_shapes,
+ input_level_start_index,
+ input_padding_mask,
+ )
+
+
+class MultiScaleDeformableAttention(nn.Module):
+ """A wrapper for ``MSDeformAttention``.
+
+ This module implements MSDeformAttention with identity connection,
+ and positional encoding is also passed as input.
+ """
+
+ def __init__(
+ self,
+ embed_dims: int = 256,
+ num_heads: int = 8,
+ num_levels: int = 4,
+ num_points: int = 4,
+ im2col_step: int = 64,
+ dropout: float = 0.0,
+ ) -> None:
+ """Init."""
+ super().__init__()
+ if embed_dims % num_heads != 0:
+ raise ValueError(
+ "embed_dims must be divisible by num_heads, but got "
+ + f"{embed_dims} and {num_heads}."
+ )
+
+ is_power_of_2(embed_dims // num_heads)
+
+ self.embed_dims = embed_dims
+ self.num_heads = num_heads
+ self.num_levels = num_levels
+ self.num_points = num_points
+ self.im2col_step = im2col_step
+
+ self.sampling_offsets = nn.Linear(
+ embed_dims, num_heads * num_levels * num_points * 2
+ )
+ self.attention_weights = nn.Linear(
+ embed_dims, num_heads * num_levels * num_points
+ )
+ self.value_proj = nn.Linear(embed_dims, embed_dims)
+ self.output_proj = nn.Linear(embed_dims, embed_dims)
+
+ self.dropout = nn.Dropout(dropout)
+
+ self._init_weights()
+
+ def _init_weights(self) -> None:
+ """Initialize weights."""
+ constant_(self.sampling_offsets.weight.data, 0.0)
+ thetas = torch.mul(
+ torch.arange(self.num_heads, dtype=torch.float32),
+ (2.0 * math.pi / self.num_heads),
+ )
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+ .view(self.num_heads, 1, 1, 2)
+ .repeat(1, self.num_levels, self.num_points, 1)
+ )
+ for i in range(self.num_points):
+ grid_init[:, :, i, :] *= i + 1
+ with torch.no_grad():
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+ constant_(self.attention_weights.weight.data, 0.0)
+ constant_(self.attention_weights.bias.data, 0.0)
+ xavier_uniform_(self.value_proj.weight.data)
+ constant_(self.value_proj.bias.data, 0.0)
+ xavier_uniform_(self.output_proj.weight.data)
+ constant_(self.output_proj.bias.data, 0.0)
+
+ def forward(
+ self,
+ query: Tensor,
+ reference_points: Tensor,
+ input_flatten: Tensor,
+ input_spatial_shapes: Tensor,
+ input_level_start_index: Tensor,
+ query_pos: Tensor | None = None,
+ identity: Tensor | None = None,
+ input_padding_mask: Tensor | None = None,
+ ) -> Tensor:
+ r"""Forward function.
+
+ Args:
+ query (Tensor): The input query with shape [bs, num_queries,
+ embed_dims].
+ reference_points (Tensor): (bs, num_queries, num_levels, 2),
+ range in [0, 1], top-left (0,0), bottom-right (1, 1), including
+ padding area or (bs, num_queries, num_levels, 4), add
+ additional (w, h) to form reference boxes.
+ input_flatten (Tensor): (bs, \sum_{l=0}^{L-1} H_l \cdot W_l, C).
+ input_spatial_shapes (Tensor): (num_levels, 2), [(H_0, W_0),
+ (H_1, W_1), ..., (H_{L-1}, W_{L-1})].
+ input_level_start_index (Tensor): (num_levels, ), [0, H_0*W_0,
+ H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ...,
+ H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}].
+ query_pos (Tensor | None): The positional encoding for query, with
+ the same shape as `query`. If not None, it will
+ be added to `query` before forward function. Defaults to None.
+ identity (Tensor | None): With the same shape as query, it will be
+ used for the identity link. If None, `query` will be used.
+ Defaults to None.
+ input_padding_mask (Tensor): (bs, \sum_{l=0}^{L-1} H_l \cdot W_l),
+ True for padding elements, False for non-padding elements.
+
+ Returns
+ output (Tensor): (bs, num_queries, C).
+ """
+ if identity is None:
+ identity = query
+
+ if query_pos is not None:
+ query = query + query_pos
+
+ n, len_q, _ = query.shape
+ n, len_in, _ = input_flatten.shape
+ assert (
+ input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]
+ ).sum() == len_in
+
+ value = self.value_proj(input_flatten)
+ if input_padding_mask is not None:
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
+ value = value.view(
+ n, len_in, self.num_heads, self.embed_dims // self.num_heads
+ )
+ sampling_offsets = self.sampling_offsets(query).view(
+ n, len_q, self.num_heads, self.num_levels, self.num_points, 2
+ )
+ attention_weights = self.attention_weights(query).view(
+ n, len_q, self.num_heads, self.num_levels * self.num_points
+ )
+ attention_weights = F.softmax(attention_weights, -1).view(
+ n, len_q, self.num_heads, self.num_levels, self.num_points
+ )
+ # n, len_q, num_heads, num_levels, num_points, 2
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.stack(
+ [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]],
+ -1,
+ )
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :]
+ + sampling_offsets
+ / offset_normalizer[None, None, None, :, None, :]
+ )
+ elif reference_points.shape[-1] == 4:
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :2]
+ + sampling_offsets
+ / self.num_points
+ * reference_points[:, :, None, :, None, 2:]
+ * 0.5
+ )
+ else:
+ raise ValueError(
+ "Last dim of reference_points must be 2 or 4, "
+ + f"but get {reference_points.shape[-1]} instead."
+ )
+
+ if torch.cuda.is_available() and value.is_cuda:
+ output = MSDeformAttentionFunction.apply(
+ value,
+ input_spatial_shapes,
+ input_level_start_index,
+ sampling_locations,
+ attention_weights,
+ self.im2col_step,
+ )
+ else:
+ output = ms_deformable_attention_cpu(
+ value,
+ input_spatial_shapes,
+ sampling_locations,
+ attention_weights,
+ )
+
+ output = self.output_proj(output)
+
+ return self.dropout(output) + identity
diff --git a/vis4d/op/layer/patch_embed.py b/vis4d/op/layer/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba775a63cd4f030ef6d2c1feb6a73977d0b45b3d
--- /dev/null
+++ b/vis4d/op/layer/patch_embed.py
@@ -0,0 +1,91 @@
+"""Image to Patch Embedding using Conv2d.
+
+Modified from vision_transformer
+(https://github.com/google-research/vision_transformer).
+"""
+
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding."""
+
+ def __init__(
+ self,
+ img_size: int = 224,
+ patch_size: int = 16,
+ in_channels: int = 3,
+ embed_dim: int = 768,
+ norm_layer: nn.Module | None = None,
+ flatten: bool = True,
+ bias: bool = True,
+ ):
+ """Init PatchEmbed.
+
+ Args:
+ img_size (int, optional): Input image's size. Defaults to 224.
+ patch_size (int, optional): Patch size. Defaults to 16.
+ in_channels (int, optional): Number of input image's channels.
+ Defaults to 3.
+ embed_dim (int, optional): Patch embedding's dim. Defaults to 768.
+ norm_layer (nn.Module, optional): Normalization layer. Defaults to
+ None, which means no normalization layer.
+ flatten (bool, optional): If to flatten the output tensor.
+ Defaults to True.
+ bias (bool, optional): If to add bias to the convolution layer.
+ Defaults to True.
+
+ Raises:
+ ValueError: If the input image's size is not divisible by the patch
+ size.
+ """
+ super().__init__()
+ self.img_size = (img_size, img_size)
+ self.patch_size = (patch_size, patch_size)
+ self.grid_size = (
+ self.img_size[0] // self.patch_size[0],
+ self.img_size[1] // self.patch_size[1],
+ )
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.flatten = flatten
+
+ self.proj = nn.Conv2d(
+ in_channels,
+ embed_dim,
+ kernel_size=patch_size,
+ stride=patch_size,
+ bias=bias,
+ )
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def __call__(self, data: torch.Tensor) -> torch.Tensor:
+ """Applies the layer.
+
+ Args:
+ data (torch.Tensor): Input tensor of shape (B, C, H, W).
+
+ Returns:
+ torch.Tensor: Output tensor of shape (B, N, C), where N is the
+ number of patches (N = H * W).
+ """
+ return self._call_impl(data)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward function."""
+ _, _, height, width = x.shape
+ assert height == self.img_size[0], (
+ f"Input image height ({height}) doesn't match model"
+ f"({self.img_size})."
+ )
+ assert width == self.img_size[1], (
+ f"Input image width ({width}) doesn't match model"
+ f"({self.img_size})."
+ )
+ x = self.proj(x)
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # (B, C, H, W) -> (B, N, C)
+ x = self.norm(x)
+ return x
diff --git a/vis4d/op/layer/positional_encoding.py b/vis4d/op/layer/positional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..37ccfa7c0e5be18feec72d9d4cbda8b9d7b5ac9e
--- /dev/null
+++ b/vis4d/op/layer/positional_encoding.py
@@ -0,0 +1,192 @@
+"""Positional encoding for transformer.
+
+Modified from mmdetection (https://github.com/open-mmlab/mmdetection).
+"""
+
+from __future__ import annotations
+
+import math
+
+import torch
+from torch import Tensor, nn
+
+from .weight_init import uniform_init
+
+
+class SinePositionalEncoding(nn.Module):
+ """Position encoding with sine and cosine functions.
+
+ See `End-to-End Object Detection with Transformers
+ `_ for details.
+ """
+
+ def __init__(
+ self,
+ num_feats: int,
+ temperature: int = 10000,
+ normalize: bool = False,
+ scale: float = 2 * math.pi,
+ eps: float = 1e-6,
+ offset: float = 0.0,
+ ) -> None:
+ """Initialization for `SinePositionalEncoding`.
+
+ Args:
+ num_feats (int): The feature dimension for each position
+ along x-axis or y-axis. Note the final returned dimension
+ for each position is 2 times of this value.
+ temperature (int, optional): The temperature used for scaling
+ the position embedding. Defaults to 10000.
+ normalize (bool, optional): Whether to normalize the position
+ embedding. Defaults to False.
+ scale (float, optional): A scale factor that scales the position
+ embedding. The scale will be used only when normalize is True.
+ Defaults to 2*pi.
+ eps (float, optional): A value added to the denominator for
+ numerical stability. Defaults to 1e-6.
+ offset (float, optional): offset add to embed when do the
+ normalization. Defaults to 0.
+ """
+ super().__init__()
+ if normalize:
+ assert isinstance(scale, (float, int)), (
+ "when normalize is set,"
+ "scale should be provided and in float or int type, "
+ f"found {type(scale)}"
+ )
+ self.num_feats = num_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ self.scale = scale
+ self.eps = eps
+ self.offset = offset
+
+ def forward(
+ self, mask: Tensor | None, inputs: Tensor | None = None
+ ) -> Tensor:
+ """Forward function for `SinePositionalEncoding`.
+
+ Args:
+ mask (Tensor | None): ByteTensor mask. Non-zero values representing
+ ignored positions, while zero values means valid positions
+ for this image. Shape [bs, h, w]. If None, it means single
+ image or batch image with no padding.
+ inputs (Tensor | None): The input tensor. It mask is None, this
+ input tensor is required to get the shape of the input image.
+
+ Returns:
+ pos (Tensor): Returned position embedding with shape
+ [bs, num_feats*2, h, w].
+ """
+ if mask is not None:
+ # For convenience of exporting to ONNX, it's required to convert
+ # `masks` from bool to int.
+ mask = mask.to(torch.int)
+ b, h, w = mask.size()
+ device = mask.device
+ not_mask = 1 - mask # logical_not
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ else:
+ # single image or batch image with no padding
+ assert isinstance(inputs, Tensor)
+ b, _, h, w = inputs.shape
+ device = inputs.device
+ x_embed = torch.arange(
+ 1, w + 1, dtype=torch.float32, device=device
+ )
+ x_embed = x_embed.view(1, 1, -1).repeat(b, h, 1)
+ y_embed = torch.arange(
+ 1, h + 1, dtype=torch.float32, device=device
+ )
+ y_embed = y_embed.view(1, -1, 1).repeat(b, 1, w)
+ if self.normalize:
+ y_embed = (
+ (y_embed + self.offset)
+ / (y_embed[:, -1:, :] + self.eps)
+ * self.scale
+ )
+ x_embed = (
+ (x_embed + self.offset)
+ / (x_embed[:, :, -1:] + self.eps)
+ * self.scale
+ )
+ dim_t = torch.arange(
+ self.num_feats, dtype=torch.float32, device=device
+ )
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats)
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ # use `view` instead of `flatten` for dynamically exporting to ONNX
+
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).view(b, h, w, -1)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).view(b, h, w, -1)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+
+class LearnedPositionalEncoding(nn.Module):
+ """Position embedding with learnable embedding weights."""
+
+ def __init__(
+ self, num_feats: int, row_num_embed: int = 50, col_num_embed: int = 50
+ ) -> None:
+ """Initialization for LearnedPositionalEncoding.
+
+ Args:
+ num_feats (int): The feature dimension for each position
+ along x-axis or y-axis. The final returned dimension for
+ each position is 2 times of this value.
+ row_num_embed (int, optional): The dictionary size of row
+ embeddings. Defaults to 50.
+ col_num_embed (int, optional): The dictionary size of col
+ embeddings. Defaults to 50.
+ """
+ super().__init__()
+ self.row_embed = nn.Embedding(row_num_embed, num_feats)
+ self.col_embed = nn.Embedding(col_num_embed, num_feats)
+ self.num_feats = num_feats
+ self.row_num_embed = row_num_embed
+ self.col_num_embed = col_num_embed
+
+ self.init_weights()
+
+ def init_weights(self) -> None:
+ """Initialize the weights of position embedding."""
+ uniform_init(self.row_embed, lower=0, upper=1)
+ uniform_init(self.col_embed, lower=0, upper=1)
+
+ def forward(self, mask: Tensor) -> Tensor:
+ """Forward function for `LearnedPositionalEncoding`.
+
+ Args:
+ mask (Tensor): ByteTensor mask. Non-zero values representing
+ ignored positions, while zero values means valid positions
+ for this image. Shape [bs, h, w].
+
+ Returns:
+ pos (Tensor): Returned position embedding with shape
+ [bs, num_feats*2, h, w].
+ """
+ h, w = mask.shape[-2:]
+ x = torch.arange(w, device=mask.device)
+ y = torch.arange(h, device=mask.device)
+ x_embed = self.col_embed(x)
+ y_embed = self.row_embed(y)
+ pos = (
+ torch.cat(
+ (
+ x_embed.unsqueeze(0).repeat(h, 1, 1),
+ y_embed.unsqueeze(1).repeat(1, w, 1),
+ ),
+ dim=-1,
+ )
+ .permute(2, 0, 1)
+ .unsqueeze(0)
+ .repeat(mask.shape[0], 1, 1, 1)
+ )
+ return pos
diff --git a/vis4d/op/layer/transformer.py b/vis4d/op/layer/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2600ade3df5b1d0fe085b452cc088c41d0c49616
--- /dev/null
+++ b/vis4d/op/layer/transformer.py
@@ -0,0 +1,255 @@
+"""Transformer layer.
+
+Modified from timm (https://github.com/huggingface/pytorch-image-models) and
+mmdetection (https://github.com/open-mmlab/mmdetection).
+"""
+
+from __future__ import annotations
+
+import copy
+
+import torch
+from torch import Tensor, nn
+
+from .attention import Attention
+from .drop import DropPath
+from .mlp import TransformerBlockMLP
+from .util import build_activation_layer
+
+
+def inverse_sigmoid(x: Tensor, eps: float = 1e-5) -> Tensor:
+ """Inverse function of sigmoid.
+
+ Args:
+ x (Tensor): The tensor to do the inverse.
+ eps (float): EPS avoid numerical overflow. Defaults 1e-5.
+
+ Returns:
+ Tensor: The x has passed the inverse function of sigmoid, has same
+ shape with input.
+ """
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1 / x2)
+
+
+def get_clones(module: nn.Module, num: int) -> nn.ModuleList:
+ """Create N identical layers."""
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(num)])
+
+
+class LayerScale(nn.Module):
+ """Layer scaler."""
+
+ def __init__(
+ self,
+ dim: int,
+ inplace: bool = False,
+ data_format: str = "channels_last",
+ init_values: float = 1e-5,
+ ):
+ """Init layer scaler.
+
+ Args:
+ dim (int): Input tensor's dimension.
+ inplace (bool): Whether performs operation in-place. Default:
+ False.
+ data_format (str): The input data format, could be 'channels_last'
+ or 'channels_first', representing (B, C, H, W) and (B, N, C)
+ format data respectively. Default: channels_last.
+ init_values (float, optional): Initial values for layer scale.
+ Defaults to 1e-5.
+ """
+ super().__init__()
+ assert data_format in {
+ "channels_last",
+ "channels_first",
+ }, "data_format could only be channels_last or channels_first."
+ self.inplace = inplace
+ self.data_format = data_format
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass."""
+ if self.data_format == "channels_first":
+ shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2))))
+ else:
+ shape = tuple((*(1 for _ in range(x.dim() - 1)), -1))
+
+ if self.inplace:
+ return x.mul_(self.gamma.view(*shape))
+
+ return x * self.gamma.view(*shape)
+
+
+class TransformerBlock(nn.Module):
+ """Transformer block for Vision Transformer."""
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values: float | None = None,
+ drop_path: float = 0.0,
+ act_layer: nn.Module = nn.GELU(),
+ norm_layer: nn.Module | None = None,
+ ):
+ """Init transformer block.
+
+ Args:
+ dim (int): Input tensor's dimension.
+ num_heads (int): Number of attention heads.
+ mlp_ratio (float, optional): Ratio of MLP hidden dim to embedding
+ dim. Defaults to 4.0.
+ qkv_bias (bool, optional): If to add bias to qkv. Defaults to
+ False.
+ drop (float, optional): Dropout rate for attention and projection.
+ Defaults to 0.0.
+ attn_drop (float, optional): Dropout rate for attention. Defaults
+ to 0.0.
+ init_values (tuple[float, float] | None, optional): Initial values
+ for layer scale. Defaults to None.
+ drop_path (float, optional): Dropout rate for drop path. Defaults
+ to 0.0.
+ act_layer (nn.Module, optional): Activation layer. Defaults to
+ nn.GELU.
+ norm_layer (nn.Module, optional): Normalization layer. If None, use
+ nn.LayerNorm.
+ """
+ super().__init__()
+ self.norm1 = (
+ norm_layer(dim) if norm_layer else nn.LayerNorm(dim, eps=1e-6)
+ )
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = (
+ LayerScale(dim, init_values=init_values)
+ if init_values
+ else nn.Identity()
+ )
+ self.drop_path1 = (
+ DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ )
+
+ self.norm2 = (
+ norm_layer(dim) if norm_layer else nn.LayerNorm(dim, eps=1e-6)
+ )
+ self.mlp = TransformerBlockMLP(
+ in_features=dim,
+ hidden_features=int(dim * mlp_ratio),
+ act_layer=act_layer,
+ drop=drop,
+ )
+ self.ls2 = (
+ LayerScale(dim, init_values=init_values)
+ if init_values
+ else nn.Identity()
+ )
+ self.drop_path2 = (
+ DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ )
+
+ def __call__(self, data: torch.Tensor) -> torch.Tensor:
+ """Forward pass.
+
+ Args:
+ data (torch.Tensor): Input tensor of shape (B, N, dim).
+
+ Returns:
+ torch.Tensor: Output tensor of shape (B, N, dim).
+ """
+ return self._call_impl(data)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass."""
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+ return x
+
+
+class FFN(nn.Module):
+ """Implements feed-forward networks (FFNs) with identity connection."""
+
+ def __init__(
+ self,
+ embed_dims: int = 256,
+ feedforward_channels: int = 1024,
+ num_fcs: int = 2,
+ dropout: float = 0.0,
+ activation: str = "ReLU",
+ inplace: bool = True,
+ dropout_layer: nn.Module | None = None,
+ add_identity: bool = True,
+ layer_scale_init_value: float = 0.0,
+ ) -> None:
+ """Init FFN.
+
+ Args:
+ embed_dims (int): The feature dimension. Defaults: 256.
+ feedforward_channels (int): The hidden dimension of FFNs.
+ Defaults: 1024.
+ num_fcs (int): The number of fully-connected layers in FFNs.
+ Defaults: 2.
+ dropout (float): The dropout rate of FFNs.
+ activation (str): The activation function of FFNs.
+ inplace (bool): Whether to set inplace for activation.
+ dropout_layer (nn.Module | None, optional): The dropout_layer used
+ when adding the shortcut. Defaults to None. If None, Identity
+ is used.
+ add_identity (bool, optional): Whether to add the identity
+ connection. Default: True.
+ layer_scale_init_value (float): Initial value of scale factor in
+ LayerScale. Default: 0.0
+ """
+ super().__init__()
+ self.embed_dims = embed_dims
+
+ layers: list[nn.Module] = []
+ in_channels = embed_dims
+ for _ in range(num_fcs - 1):
+ layers.append(
+ nn.Sequential(
+ nn.Linear(in_channels, feedforward_channels),
+ build_activation_layer(activation, inplace),
+ nn.Dropout(dropout),
+ )
+ )
+ in_channels = feedforward_channels
+ layers.append(nn.Linear(feedforward_channels, embed_dims))
+ layers.append(nn.Dropout(dropout))
+ self.layers = nn.Sequential(*layers)
+
+ self.dropout_layer = dropout_layer or nn.Identity()
+ self.add_identity = add_identity
+ self.layer_scale_init_value = layer_scale_init_value
+
+ if self.layer_scale_init_value > 0:
+ self.gamma2 = LayerScale(
+ embed_dims, init_values=self.layer_scale_init_value
+ )
+
+ def forward(self, x: Tensor, identity: Tensor | None = None) -> None:
+ """Forward function for FFN.
+
+ The function would add x to the output tensor if residue is None.
+ """
+ out = self.layers(x)
+
+ if self.layer_scale_init_value > 0:
+ out = self.gamma2(out)
+
+ if self.add_identity:
+ identity = x if identity is None else identity
+ return identity + self.dropout_layer(out)
+
+ return self.dropout_layer(out)
diff --git a/vis4d/op/layer/util.py b/vis4d/op/layer/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c6c7488ebd8bc0ae60ca1e06b243ce5e5ed0433
--- /dev/null
+++ b/vis4d/op/layer/util.py
@@ -0,0 +1,89 @@
+"""Utility functions for layer ops."""
+
+from __future__ import annotations
+
+from torch import nn
+
+from .conv2d import Conv2d
+from .deform_conv import DeformConv
+
+
+def build_conv_layer(
+ in_planes: int,
+ out_planes: int,
+ kernel_size: int = 3,
+ stride: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = False,
+ norm: nn.Module | None = None,
+ activation: nn.Module | None = None,
+ use_dcn: bool = False,
+) -> nn.Module:
+ """Build a convolution layer."""
+ if use_dcn:
+ return DeformConv(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ norm=norm,
+ activation=activation,
+ )
+
+ return Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ norm=norm,
+ activation=activation,
+ )
+
+
+def build_activation_layer(
+ activation: str, inplace: bool = False
+) -> nn.Module:
+ """Build activation layer.
+
+ Args:
+ activation (str): Activation layer type.
+ inplace (bool, optional): If to set inplace. Defaults to False. It will
+ be ignored if the activation layer is not inplace.
+ """
+ activation_layer = getattr(nn, activation)
+
+ if activation_layer in {nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU}:
+ return activation_layer()
+
+ return activation_layer(inplace=inplace)
+
+
+def build_norm_layer(
+ norm: str, out_channels: int, num_groups: int | None = None
+) -> nn.Module:
+ """Build normalization layer.
+
+ Args:
+ norm (str): Normalization layer type.
+ out_channels (int): Number of output channels.
+ num_groups (int | None, optional): Number of groups for GroupNorm.
+ Defaults to None.
+ """
+ norm_layer = getattr(nn, norm)
+ if norm_layer == nn.GroupNorm:
+ assert (
+ num_groups is not None
+ ), "num_groups must be specified when using Group Norm"
+ return norm_layer(num_groups, out_channels)
+
+ return norm_layer(out_channels)
diff --git a/vis4d/op/layer/weight_init.py b/vis4d/op/layer/weight_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e58e24e0d2400cadd2b733cca6617075a120b49
--- /dev/null
+++ b/vis4d/op/layer/weight_init.py
@@ -0,0 +1,120 @@
+"""Model weight initialization."""
+
+from typing import Literal
+
+import numpy as np
+from torch import nn
+
+NonlinearityType = Literal[
+ "linear",
+ "conv1d",
+ "conv2d",
+ "conv3d",
+ "conv_transpose1d",
+ "conv_transpose2d",
+ "conv_transpose3d",
+ "sigmoid",
+ "tanh",
+ "relu",
+ "leaky_relu",
+ "selu",
+]
+FanMode = Literal["fan_in", "fan_out"]
+
+
+def constant_init(module: nn.Module, val: float, bias: float = 0.0) -> None:
+ """Initialize module with constant value."""
+ if hasattr(module, "weight") and isinstance(module.weight, nn.Parameter):
+ nn.init.constant_(module.weight, val)
+ if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
+ nn.init.constant_(module.bias, bias)
+
+
+def xavier_init(
+ module: nn.Module,
+ gain: float = 1.0,
+ bias: float = 0.0,
+ distribution: str = "normal",
+) -> None:
+ """Initialize module with Xavier initialization."""
+ assert distribution in {"uniform", "normal"}
+ if hasattr(module, "weight") and isinstance(module.weight, nn.Parameter):
+ if distribution == "uniform":
+ nn.init.xavier_uniform_(module.weight, gain=gain)
+ else:
+ nn.init.xavier_normal_(module.weight, gain=gain)
+ if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
+ nn.init.constant_(module.bias, bias)
+
+
+def kaiming_init(
+ module: nn.Module,
+ negative_slope: float = 0.0,
+ mode: FanMode = "fan_out",
+ nonlinearity: NonlinearityType = "relu",
+ bias: float = 0.0,
+ distribution: str = "normal",
+) -> None:
+ """Initialize module with Kaiming initialization.
+
+ Args:
+ module (nn.Module): Module to initialize.
+ negative_slope (float, optional): The negative slope of the rectifier
+ used after this layer (only used with ``'leaky_relu'``). Defaults
+ to 0.0.
+ mode (FanMode, optional): Either `"fan_in"` (default) or `"fan_out"``.
+ Choosing `"fan_in"` preserves the magnitude of the variance of
+ the weights in the forward pass. Choosing `"fan_out"` preserves
+ magnitudes in the backwards pass. Defaults to "fan_out".
+ nonlinearity (NonlinearityType, optional): The non-linear function
+ (`nn.functional` name). Defaults to "relu".
+ bias (float, optional): The bias to use. Defaults to 0.0.
+ distribution (str, optional): Either ``'uniform'`` or ``'normal'``.
+ Defaults to "normal".
+ """
+ assert distribution in {"uniform", "normal"}
+ if hasattr(module, "weight") and isinstance(module.weight, nn.Parameter):
+ if distribution == "uniform":
+ nn.init.kaiming_uniform_(
+ module.weight,
+ a=negative_slope,
+ mode=mode,
+ nonlinearity=nonlinearity,
+ )
+ else:
+ nn.init.kaiming_normal_(
+ module.weight,
+ a=negative_slope,
+ mode=mode,
+ nonlinearity=nonlinearity,
+ )
+ if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
+ nn.init.constant_(module.bias, bias)
+
+
+def normal_init(
+ module: nn.Module, mean: float = 0.0, std: float = 1.0, bias: float = 0
+) -> None:
+ """Initialize module with normal distribution."""
+ if hasattr(module, "weight") and isinstance(module.weight, nn.Parameter):
+ nn.init.normal_(module.weight, mean, std)
+ if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
+ nn.init.constant_(module.bias, bias)
+
+
+def bias_init_with_prob(prior_prob: float) -> float:
+ """Initialize conv/fc bias value according to a given probability value."""
+ return float(-np.log((1 - prior_prob) / prior_prob))
+
+
+def uniform_init(
+ module: nn.Module,
+ lower: float = 0.0,
+ upper: float = 1.0,
+ bias: float = 0.0,
+) -> None:
+ """Initialize module with uniform distribution."""
+ if hasattr(module, "weight") and isinstance(module.weight, nn.Parameter):
+ nn.init.uniform_(module.weight, lower, upper)
+ if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
+ nn.init.constant_(module.bias, bias)
diff --git a/vis4d/op/loss/__init__.py b/vis4d/op/loss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8846c92abc71540756ffdbdb3b3fc086ecda395
--- /dev/null
+++ b/vis4d/op/loss/__init__.py
@@ -0,0 +1,23 @@
+"""This module contains commonly used loss functions.
+
+The losses do not follow a common API, but have a reducer as attribute,
+which is a function to aggregate loss values into a single tensor value.
+"""
+
+from .base import Loss
+from .embedding_distance import EmbeddingDistanceLoss
+from .iou_loss import IoULoss
+from .multi_level_seg_loss import MultiLevelSegLoss
+from .multi_pos_cross_entropy import MultiPosCrossEntropyLoss
+from .orthogonal_transform_loss import OrthogonalTransformRegularizationLoss
+from .seg_cross_entropy_loss import SegCrossEntropyLoss
+
+__all__ = [
+ "Loss",
+ "EmbeddingDistanceLoss",
+ "IoULoss",
+ "MultiLevelSegLoss",
+ "MultiPosCrossEntropyLoss",
+ "OrthogonalTransformRegularizationLoss",
+ "SegCrossEntropyLoss",
+]
diff --git a/vis4d/op/loss/base.py b/vis4d/op/loss/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9ecc27a1e87b1d11f82c62ba70439a7afe97767
--- /dev/null
+++ b/vis4d/op/loss/base.py
@@ -0,0 +1,28 @@
+"""Base class for meta architectures."""
+
+import abc
+
+from torch import nn
+
+from vis4d.op.loss.reducer import identity_loss
+
+from .reducer import LossReducer
+
+
+class Loss(nn.Module, abc.ABC):
+ """Base loss class."""
+
+ def __init__(self, reducer: LossReducer = identity_loss) -> None:
+ """Initialize a loss functor.
+
+ Args:
+ reducer (LossReducer): A function to aggregate the loss values into
+ a single tensor value. It is commonly used for dense prediction
+ tasks to merge pixel-wise loss to a final loss.
+
+ Example::
+ def mean_loss(loss: torch.Tensor) -> torch.Tensor:
+ return loss.mean()
+ """
+ super().__init__()
+ self.reducer = reducer
diff --git a/vis4d/op/loss/common.py b/vis4d/op/loss/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..991947c61a7b4dfb401240f47d745f60888dfe65
--- /dev/null
+++ b/vis4d/op/loss/common.py
@@ -0,0 +1,129 @@
+"""Common loss functions."""
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from vis4d.op.loss.reducer import LossReducer, identity_loss
+
+
+def smooth_l1_loss(
+ pred: Tensor,
+ target: Tensor,
+ reducer: LossReducer = identity_loss,
+ beta: float = 1.0,
+) -> Tensor:
+ """Smooth L1 loss.
+
+ L1 loss that uses a squared term if the absolute element-wise error
+ falls below beta.
+
+ Args:
+ pred (Tensor): Model predictions
+ target (Tensor): Ground truth value
+ reducer (LossReducer): Reducer to reduce the loss value. Defaults to
+ identy_loss, which is no reduction.
+ beta (float): Specifies the threshold at which to change between L1
+ and L2 loss. The value must be non-negative. Default: 1.0
+
+ Returns:
+ Tensor : The reduced smooth l1 loss:
+ |pred - target| - 0.5*beta if |pred - target| < 0.5*beta
+ (pred - target)^2 * 0.5/beta else
+ """
+ assert beta > 0
+ assert pred.size() == target.size() and target.numel() > 0
+ diff = torch.abs(pred - target)
+ loss = torch.where(
+ diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta
+ )
+ return reducer(loss)
+
+
+def l1_loss(
+ pred: Tensor, target: Tensor, reducer: LossReducer = identity_loss
+) -> Tensor:
+ """L1 loss.
+
+ Args:
+ pred (Tensor): Model predictions
+ target (Tensor): Ground truth value
+ reducer (LossReducer): Reducer to reduce the loss value. Defaults to
+ identy_loss, which is no reduction.
+
+ Returns:
+ Tensor : The reduced L1 loss (reduce(|pred - target|))
+ """
+ assert pred.size() == target.size() and target.numel() > 0
+ loss = torch.abs(pred - target)
+ return reducer(loss)
+
+
+def l2_loss(
+ pred: Tensor, target: Tensor, reducer: LossReducer = identity_loss
+) -> Tensor:
+ """L2 loss.
+
+ Args:
+ pred (Tensor): Model predictions
+ target (Tensor): Ground truth value
+ reducer (LossReducer): Reducer to reduce the loss value. Defaults to
+ identy_loss, which is no reduction.
+
+ Returns:
+ Tensor : The reduced L2 loss (reduce((pred - target)**2))
+ """
+ assert pred.size() == target.size() and target.numel() > 0
+ loss = (pred - target) ** 2
+ return reducer(loss)
+
+
+def rotation_loss(
+ pred: Tensor,
+ target_bin: Tensor,
+ target_res: Tensor,
+ num_bins: int,
+ reducer: LossReducer = identity_loss,
+) -> Tensor:
+ """Rotation loss.
+
+ Consists of bin-based classification loss and residual-based regression
+ loss.
+
+ Args:
+ pred (Tensor): Prediction shape [B, num_bins * 3]
+ target_bin (Tensor): Target bins shape [B, num_bin]
+ target_res (Tensor): Target residual shape [B, num_bin]
+ num_bins (int): Number of bins
+ reducer (LossReducer, optional): Loss Reducer.
+ Defaults to identity_loss.
+
+ Returns:
+ Tensor: The reduced loss value
+ """
+ loss_bins = (
+ F.binary_cross_entropy_with_logits(
+ pred[:, :num_bins], target_bin, reduction="none"
+ )
+ .mean(dim=0)
+ .sum()
+ )
+
+ loss_res = torch.zeros_like(loss_bins)
+ for i in range(num_bins):
+ bin_mask = target_bin[:, i] == 1
+ res_idx = num_bins + 2 * i
+ if bin_mask.any():
+ loss_sin = smooth_l1_loss(
+ pred[bin_mask, res_idx],
+ torch.sin(target_res[bin_mask, i]),
+ reducer=reducer,
+ )
+ loss_cos = smooth_l1_loss(
+ pred[bin_mask, res_idx + 1],
+ torch.cos(target_res[bin_mask, i]),
+ reducer=reducer,
+ )
+ loss_res += loss_sin + loss_cos
+
+ return loss_bins + loss_res
diff --git a/vis4d/op/loss/cross_entropy.py b/vis4d/op/loss/cross_entropy.py
new file mode 100644
index 0000000000000000000000000000000000000000..e51ead860b2207d17a2c89a5cb1c45db128c8150
--- /dev/null
+++ b/vis4d/op/loss/cross_entropy.py
@@ -0,0 +1,89 @@
+"""Cross entropy loss."""
+
+from __future__ import annotations
+
+import torch.nn.functional as F
+from torch import Tensor
+
+from .base import Loss
+from .reducer import LossReducer, mean_loss
+
+
+class CrossEntropyLoss(Loss):
+ """Cross entropy loss class."""
+
+ def __init__(
+ self,
+ reducer: LossReducer = mean_loss,
+ class_weights: list[float] | None = None,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ reducer (LossReducer): Reducer for the loss function. Defaults to
+ mean_loss.
+ class_weights (list[float], optional): Class weights for the loss
+ function. Defaults to None.
+ """
+ super().__init__(reducer)
+ self.class_weights = class_weights
+
+ def forward(
+ self,
+ output: Tensor,
+ target: Tensor,
+ reducer: LossReducer | None = None,
+ ignore_index: int = 255,
+ ) -> Tensor:
+ """Forward pass.
+
+ Args:
+ output (list[Tensor]): Model output.
+ target (Tensor): Assigned segmentation target mask.
+ reducer (LossReducer, optional): Reducer for the loss function.
+ Defaults to None.
+ ignore_index (int): Ignore class id. Default to 255.
+
+ Returns:
+ Tensor: Computed loss.
+ """
+ if self.class_weights is not None:
+ class_weights = output.new_tensor(
+ self.class_weights, device=output.device
+ )
+ else:
+ class_weights = None
+ reducer = reducer or self.reducer
+
+ return reducer(
+ cross_entropy(
+ output, target, class_weights, ignore_index=ignore_index
+ )
+ )
+
+
+def cross_entropy(
+ output: Tensor,
+ target: Tensor,
+ class_weights: Tensor | None = None,
+ ignore_index: int = 255,
+) -> Tensor:
+ """Cross entropy loss function.
+
+ Args:
+ output (Tensor): Model output.
+ target (Tensor): Assigned segmentation target mask.
+ class_weights (Tensor | None, optional): Class weights for the loss
+ function. Defaults to None.
+ ignore_index (int): Ignore class id. Default to 255.
+
+ Returns:
+ Tensor: Computed loss.
+ """
+ return F.cross_entropy(
+ output,
+ target.long(),
+ weight=class_weights,
+ ignore_index=ignore_index,
+ reduction="none",
+ )
diff --git a/vis4d/op/loss/embedding_distance.py b/vis4d/op/loss/embedding_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c614541e96c3dba7810810e9386129f59824cf7
--- /dev/null
+++ b/vis4d/op/loss/embedding_distance.py
@@ -0,0 +1,103 @@
+"""Embedding distance loss."""
+
+from __future__ import annotations
+
+import torch
+
+from vis4d.op.box.box2d import random_choice
+
+from .base import Loss
+from .common import l2_loss
+from .reducer import LossReducer, SumWeightedLoss, identity_loss
+
+
+class EmbeddingDistanceLoss(Loss):
+ """Embedding distance loss for learning appearance similarity.
+
+ Computes the difference between the target distances and the predicted
+ distances of two sets of embedding vectors. Uses hard negative mining based
+ on the loss values to select pairs for overall loss computation.
+ """
+
+ def __init__(
+ self,
+ reducer: LossReducer = identity_loss,
+ neg_pos_ub: float = 3.0,
+ pos_margin: float = 0.0,
+ neg_margin: float = 0.3,
+ hard_mining: bool = True,
+ ):
+ """Creates an instance of the class."""
+ super().__init__(reducer)
+ self.neg_pos_ub = neg_pos_ub
+ self.neg_margin = neg_margin
+ self.pos_margin = pos_margin
+ self.hard_mining = hard_mining
+
+ def forward( # pylint: disable=arguments-differ
+ self,
+ pred: torch.Tensor,
+ target: torch.Tensor,
+ weight: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The predicted distances between two sets of
+ predictions. Shape [N, M].
+ target (torch.Tensor): The corresponding target distances. Either
+ zero (different identity) or one (same identity).
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+
+ Returns:
+ loss_bbox (torch.Tensor): embedding distance loss.
+ """
+ if weight is None:
+ weight = target.new_ones(target.size())
+ pred, weight, avg_factor = self.update_weight(pred, target, weight)
+ return l2_loss(
+ pred, target, reducer=SumWeightedLoss(weight, avg_factor)
+ )
+
+ def update_weight(
+ self, pred: torch.Tensor, target: torch.Tensor, weight: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Update element-wise loss weights.
+
+ Exclude negatives according to maximum fraction of samples and/or
+ hard negative mining.
+ """
+ invalid_inds = weight <= 0
+ target[invalid_inds] = -1
+ pos_inds = torch.eq(target, 1)
+ neg_inds = torch.eq(target, 0)
+
+ if self.pos_margin > 0:
+ pred[pos_inds] -= self.pos_margin
+ if self.neg_margin > 0:
+ pred[neg_inds] -= self.neg_margin
+ pred = torch.clamp(pred, min=0, max=1)
+
+ num_pos = max(1, int(torch.eq(target, 1).sum()))
+ num_neg = int(torch.eq(target, 0).sum())
+ if self.neg_pos_ub > 0 and num_neg / num_pos > self.neg_pos_ub:
+ num_neg = int(num_pos * self.neg_pos_ub)
+ neg_idx = torch.nonzero(torch.eq(target, 0), as_tuple=False)
+
+ if self.hard_mining:
+ costs = l2_loss(pred, target)[
+ neg_idx[:, 0], neg_idx[:, 1]
+ ].detach()
+ neg_idx = neg_idx[costs.topk(num_neg)[1], :]
+ else:
+ neg_idx = random_choice(neg_idx, num_neg)
+
+ new_neg_inds = neg_inds.new_zeros(neg_inds.size()).bool()
+ new_neg_inds[neg_idx[:, 0], neg_idx[:, 1]] = True
+
+ invalid_neg_inds = torch.logical_xor(neg_inds, new_neg_inds)
+ weight[invalid_neg_inds] = 0
+
+ avg_factor = torch.greater(weight, 0).sum()
+ return pred, weight, avg_factor
diff --git a/vis4d/op/loss/iou_loss.py b/vis4d/op/loss/iou_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..5290d3b1d18010873299de5f0f6e1c964147104a
--- /dev/null
+++ b/vis4d/op/loss/iou_loss.py
@@ -0,0 +1,94 @@
+"""Embedding distance loss."""
+
+from __future__ import annotations
+
+import torch
+
+from vis4d.op.box.box2d import bbox_iou_aligned
+
+from .base import Loss
+from .reducer import LossReducer, identity_loss
+
+
+def iou_loss(
+ pred: torch.Tensor,
+ target: torch.Tensor,
+ reducer: LossReducer = identity_loss,
+ mode: str = "log",
+ eps: float = 1e-6,
+) -> torch.Tensor:
+ """Compute IoU loss.
+
+ Args:
+ pred (torch.Tensor): Predicted bboxes.
+ target (torch.Tensor): Target bboxes.
+ reducer (LossReducer): Reducer to reduce the loss value. Defaults to
+ identy_loss, which is no reduction.
+ mode (str, optional): Mode to calculate the loss. Defaults to "log".
+ eps (float, optional): Epsilon value to avoid division by zero.
+
+ Returns:
+ torch.Tensor : The reduced IoU loss.
+ """
+ assert mode in {
+ "linear",
+ "square",
+ "log",
+ }, f"Invalid mode {mode}. Must be one of 'linear', 'square', 'log'."
+ ious = bbox_iou_aligned(pred, target).clamp(min=eps)
+ if mode == "linear":
+ loss = 1 - ious
+ elif mode == "square":
+ loss = 1 - ious**2
+ else:
+ loss = -ious.log()
+ return reducer(loss)
+
+
+class IoULoss(Loss):
+ """IoU loss.
+
+ Computing the IoU loss between a set of predicted bboxes and target bboxes.
+ The loss is calculated depending on the mode:
+ - linear: 1 - IoU
+ - square: 1 - IoU^2
+ - log: -log(IoU)
+
+ Args:
+ reducer (LossReducer): Reducer to reduce the loss value. Defaults to
+ identy_loss, which is no reduction.
+ mode (str, optional): Mode to calculate the loss. Defaults to "log".
+ eps (float, optional): Epsilon value to avoid division by zero.
+ """
+
+ def __init__(
+ self,
+ reducer: LossReducer = identity_loss,
+ mode: str = "log",
+ eps: float = 1e-6,
+ ):
+ """Creates an instance of the class."""
+ super().__init__(reducer)
+ self.mode = mode
+ self.eps = eps
+ assert mode in {
+ "linear",
+ "square",
+ "log",
+ }, f"Invalid mode {mode}. Must be one of 'linear', 'square', 'log'."
+
+ def forward( # pylint: disable=arguments-differ
+ self, pred: torch.Tensor, target: torch.Tensor
+ ) -> torch.Tensor:
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): Predicted bboxes.
+ target (torch.Tensor): Target bboxes.
+
+ Returns:
+ torch.Tensor: The reduced IoU loss.
+ """
+ return iou_loss(
+ pred, target, reducer=self.reducer, mode=self.mode, eps=self.eps
+ )
diff --git a/vis4d/op/loss/multi_level_seg_loss.py b/vis4d/op/loss/multi_level_seg_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c1cddfc9fc330aa201843daec8ac3a5df1a2100
--- /dev/null
+++ b/vis4d/op/loss/multi_level_seg_loss.py
@@ -0,0 +1,72 @@
+"""Multi-level segmentation loss."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor
+
+from vis4d.common.typing import LossesType
+
+from .base import Loss
+from .cross_entropy import cross_entropy
+from .reducer import LossReducer, mean_loss
+
+
+class MultiLevelSegLoss(Loss):
+ """Multi-level segmentation loss class.
+
+ Applies the segmentation loss function to multiple levels of predictions to
+ provide auxiliary losses for intermediate outputs in addition to the final
+ output, used in FCN.
+ """
+
+ def __init__(
+ self,
+ reducer: LossReducer = mean_loss,
+ feature_idx: tuple[int, ...] = (0,),
+ weights: list[float] | None = None,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ reducer (LossReducer): Reducer for the loss function. Defaults to
+ mean_loss.
+ feature_idx (tuple[int]): Indices for the level of features to
+ compute losses. Defaults to (0,).
+ weights (list[float], optional): The weights of each feature level.
+ If None passes, it will set to 1 for all levels. Defaults to
+ None.
+ """
+ super().__init__(reducer)
+ self.feature_idx = feature_idx
+ if weights is None:
+ self.weights = [1.0] * len(self.feature_idx)
+ else:
+ self.weights = weights
+
+ def forward(
+ self, outputs: list[Tensor], target: Tensor, ignore_index: int = 255
+ ) -> LossesType:
+ """Forward pass.
+
+ Args:
+ outputs (list[Tensor]): Multi-level outputs.
+ target (Tensor): Assigned segmentation target mask.
+ ignore_index (int): Ignore class id. Default to 255.
+
+ Returns:
+ LossesType: Computed losses for each level.
+ """
+ losses: LossesType = {}
+ tgt_h, tgt_w = target.shape[-2:]
+ for i, idx in enumerate(self.feature_idx):
+ loss = self.reducer(
+ cross_entropy(
+ outputs[idx][:, :, :tgt_h, :tgt_w],
+ target,
+ ignore_index=ignore_index,
+ )
+ )
+ losses[f"loss_seg_level{idx}"] = torch.mul(self.weights[i], loss)
+
+ return losses
diff --git a/vis4d/op/loss/multi_pos_cross_entropy.py b/vis4d/op/loss/multi_pos_cross_entropy.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6ab48f087fcca3dba1e62b701f9393be11b9f01
--- /dev/null
+++ b/vis4d/op/loss/multi_pos_cross_entropy.py
@@ -0,0 +1,60 @@
+"""Multi-positive cross entropy loss."""
+
+import torch
+from torch import Tensor
+
+from .base import Loss
+from .reducer import LossReducer, SumWeightedLoss
+
+
+class MultiPosCrossEntropyLoss(Loss):
+ """Multi-positive cross entropy loss.
+
+ Used for appearance similiary learning in QDTrack.
+ """
+
+ def forward(
+ self,
+ pred: Tensor,
+ target: Tensor,
+ weight: Tensor,
+ avg_factor: float,
+ ) -> Tensor:
+ """Multi-positive cross entropy loss.
+
+ Args:
+ pred (Tensor): Similarity scores before softmax. Shape [N, M]
+ target (Tensor): Target for each pair. Either one, meaning
+ same identity or zero, meaning different identity. Shape [N, M]
+ weight (Tensor): The weight of loss for each prediction.
+ avg_factor (float): Averaging factor for the loss.
+
+ Returns:
+ Tensor: Scalar loss value.
+ """
+ return multi_pos_cross_entropy(
+ pred, target, reducer=SumWeightedLoss(weight, avg_factor)
+ )
+
+
+def multi_pos_cross_entropy(
+ pred: Tensor, target: Tensor, reducer: LossReducer
+) -> Tensor:
+ """Calculate multi-positive cross-entropy loss."""
+ pos_inds = torch.eq(target, 1)
+ neg_inds = torch.eq(target, 0)
+ pred_pos = pred * pos_inds.float()
+ pred_neg = pred * neg_inds.float()
+ # use -inf to mask out unwanted elements.
+ pred_pos[neg_inds] = pred_pos[neg_inds] + float("inf")
+ pred_neg[pos_inds] = pred_neg[pos_inds] + float("-inf")
+
+ _pos_expand = torch.repeat_interleave(pred_pos, pred.shape[1], dim=1)
+ _neg_expand = pred_neg.repeat(1, pred.shape[1])
+
+ x = torch.nn.functional.pad( # pylint: disable=not-callable
+ (_neg_expand - _pos_expand), (0, 1), "constant", 0
+ )
+ loss = torch.logsumexp(x, dim=1)
+
+ return reducer(loss)
diff --git a/vis4d/op/loss/orthogonal_transform_loss.py b/vis4d/op/loss/orthogonal_transform_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..c97dc2718b0ed03ce9fb4d37eed5dd958fddbc0a
--- /dev/null
+++ b/vis4d/op/loss/orthogonal_transform_loss.py
@@ -0,0 +1,61 @@
+"""Orthogonal Transform Loss."""
+
+from __future__ import annotations
+
+import torch
+
+from .base import Loss
+
+
+class OrthogonalTransformRegularizationLoss(Loss):
+ """Loss that punishes linear transformations that are not orthogonal.
+
+ Calculates difference of X'*X and identity matrix using norm( X'*X - I)
+ """
+
+ def __call___(self, transforms: list[torch.Tensor]) -> torch.Tensor:
+ """Calculates the loss.
+
+ Calculates difference of X'*X and the identity matrix using
+ norm(X'*X - I) for each transformation
+
+ Args:
+ transforms: (list(torch.tensor)) list with transformation matrices
+ batched ([N, 3, 3], [N, x, x], ....)
+
+ Returns:
+ torch.Tensor containing the mean loss value (mean(norm(X'*X - I)))
+ """
+ return self._call_impl(transforms)
+
+ def forward(self, transforms: list[torch.Tensor]) -> torch.Tensor:
+ """Calculates the loss.
+
+ Calculates difference of X'*X and the identity matrix using
+ norm(X'*X - I) for each transformation
+
+ Args:
+ transforms: (list(torch.tensor)) list with transformation matrices
+ batched ([N, 3, 3], [N, x, x], ....)
+
+ Returns:
+ torch.Tensor containing the mean loss value (mean(norm(X'*X - I)))
+ """
+ loss = torch.tensor(0.0)
+ for trans in transforms:
+ d = trans.size()[1]
+
+ try:
+ identity = self.get_buffer(f"identity_{d}")
+ except AttributeError as _:
+ # Create identity buffers if not yet allocated
+ identity = torch.eye(d, device=trans.device)
+ self.register_buffer(f"identity_{d}", identity)
+
+ loss += torch.mean(
+ torch.norm(
+ torch.bmm(trans, trans.transpose(2, 1)) - identity,
+ dim=(1, 2),
+ )
+ )
+ return loss
diff --git a/vis4d/op/loss/reducer.py b/vis4d/op/loss/reducer.py
new file mode 100644
index 0000000000000000000000000000000000000000..05a47c59631f75f2254ce97679ec6d9aa04083c0
--- /dev/null
+++ b/vis4d/op/loss/reducer.py
@@ -0,0 +1,69 @@
+"""Definitions of loss reducers.
+
+Loss reducers are usually used as the last step in loss computation to average
+or sum the loss maps from dense predictions or object detections.
+"""
+
+from __future__ import annotations
+
+from typing import Callable
+
+from torch import Tensor
+
+LossReducer = Callable[[Tensor], Tensor]
+
+
+def identity_loss(loss: Tensor) -> Tensor:
+ """Make no change to the loss."""
+ return loss
+
+
+def mean_loss(loss: Tensor) -> Tensor:
+ """Average the loss tensor values to a single value.
+
+ Args:
+ loss (Tensor): Input multi-dimentional tensor.
+
+ Returns:
+ Tensor: Tensor containing a single loss value.
+ """
+ return loss.mean()
+
+
+def sum_loss(loss: Tensor) -> Tensor:
+ """Sum the loss tensor values to a single value.
+
+ Args:
+ loss (Tensor): Input multi-dimentional tensor.
+
+ Returns:
+ Tensor: Tensor containing a single loss value.
+ """
+ return loss.sum()
+
+
+class SumWeightedLoss:
+ """A loss reducer to calculated weighted sum loss."""
+
+ def __init__(
+ self, weight: float | Tensor, avg_factor: float | Tensor
+ ) -> None:
+ """Initialize the loss reducer.
+
+ Args:
+ weight (float | Tensor): Weights for each loss elements
+ avg_factor (float | Tensor): average factor for the weighted loss
+ """
+ self.weight = weight
+ self.avg_factor = avg_factor
+
+ def __call__(self, loss: Tensor) -> Tensor:
+ """Weight the loss elements and take the sum with the average factor.
+
+ Args:
+ loss (Tensor): input loss
+
+ Returns:
+ Tensor: output loss
+ """
+ return (loss * self.weight).sum() / self.avg_factor
diff --git a/vis4d/op/loss/seg_cross_entropy_loss.py b/vis4d/op/loss/seg_cross_entropy_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..069affb8941933133140cd372ec2d5b5d2ed6d53
--- /dev/null
+++ b/vis4d/op/loss/seg_cross_entropy_loss.py
@@ -0,0 +1,50 @@
+"""Segmentation cross entropy loss."""
+
+from __future__ import annotations
+
+from torch import Tensor
+
+from vis4d.common.typing import LossesType
+
+from .base import Loss
+from .cross_entropy import cross_entropy
+from .reducer import LossReducer, mean_loss
+
+
+class SegCrossEntropyLoss(Loss):
+ """Segmentation cross entropy loss class.
+
+ Wrapper for nn.CrossEntropyLoss that additionally clips the output to the
+ target size and converts the target mask tensor to long.
+ """
+
+ def __init__(self, reducer: LossReducer = mean_loss) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ reducer (LossReducer): Reducer for the loss function. Defaults to
+ mean_loss.
+ """
+ super().__init__(reducer)
+
+ def forward(
+ self, output: Tensor, target: Tensor, ignore_index: int = 255
+ ) -> LossesType:
+ """Forward pass.
+
+ Args:
+ output (list[Tensor]): Model output.
+ target (Tensor): Assigned segmentation target mask.
+ ignore_index (int): Ignore class id. Default to 255.
+
+ Returns:
+ LossesType: Computed loss.
+ """
+ losses: LossesType = {}
+ tgt_h, tgt_w = target.shape[-2:]
+ losses["loss_seg"] = self.reducer(
+ cross_entropy(
+ output[:, :, :tgt_h, :tgt_w], target, ignore_index=ignore_index
+ )
+ )
+ return losses
diff --git a/vis4d/op/mask/__init__.py b/vis4d/op/mask/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5942e862bf69effca1ea4e645cea9c65df72b95b
--- /dev/null
+++ b/vis4d/op/mask/__init__.py
@@ -0,0 +1 @@
+"""Operations on 2D segmentation masks."""
diff --git a/vis4d/op/mask/util.py b/vis4d/op/mask/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e0b299ca5e78342f916834a437b3977b2480fda
--- /dev/null
+++ b/vis4d/op/mask/util.py
@@ -0,0 +1,283 @@
+"""Utility functions for segmentation masks."""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+
+def _do_paste_mask( # type: ignore
+ masks: Tensor,
+ boxes: Tensor,
+ img_h: int,
+ img_w: int,
+ skip_empty: bool = True,
+) -> tuple[Tensor, tuple[slice, slice] | tuple[()]]:
+ """Paste mask onto image.
+
+ On GPU, paste all masks together (up to chunk size) by using the entire
+ image to sample the masks Compared to pasting them one by one, this has
+ more operations but is faster on COCO-scale dataset.
+
+ This implementation is modified from
+ https://github.com/facebookresearch/detectron2/
+
+ Args:
+ masks (Tensor): Masks with shape [N, 1, Hmask, Wmask].
+ boxes (Tensor): Boxes with shape [N, 4].
+ img_h (int): Image height.
+ img_w (int): Image width.
+ skip_empty (bool, optional): Only paste masks within the region that
+ tightly bound all boxes, and returns the results this region only.
+ An important optimization for CPU. Defaults to True.
+
+ Returns:
+ Tensor: Mask with shape [N, Himg, Wimg] if skip_empty == True, or
+ a mask of shape (N, H', W') and the slice object for the
+ corresponding region if skip_empty == False.
+ """
+ device = masks.device
+
+ if skip_empty:
+ x0_int, y0_int = torch.clamp(
+ boxes.min(dim=0).values.floor()[:2] - 1, min=0
+ ).to(dtype=torch.int32)
+ x0_int, y0_int = x0_int.item(), y0_int.item()
+ x1_int = (
+ torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w)
+ .to(dtype=torch.int32)
+ .item()
+ )
+ y1_int = (
+ torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h)
+ .to(dtype=torch.int32)
+ .item()
+ )
+ else:
+ x0_int, y0_int = 0, 0
+ x1_int, y1_int = img_w, img_h
+ x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1
+
+ num_masks = masks.shape[0]
+
+ img_y: Tensor = (
+ torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
+ )
+ img_x: Tensor = (
+ torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
+ )
+ img_y = (img_y - y0) / (y1 - y0) * 2 - 1 # (N, h)
+ img_x = (img_x - x0) / (x1 - x0) * 2 - 1 # (N, w)
+
+ gx = img_x[:, None, :].expand(num_masks, img_y.size(1), img_x.size(1))
+ gy = img_y[:, :, None].expand(num_masks, img_y.size(1), img_x.size(1))
+ grid = torch.stack([gx, gy], dim=3)
+
+ if not masks.dtype.is_floating_point:
+ masks = masks.float()
+ img_masks = F.grid_sample(masks, grid, align_corners=False)
+
+ if skip_empty:
+ return img_masks[:, 0], ( # pylint: disable=unsubscriptable-object
+ slice(y0_int, y1_int),
+ slice(x0_int, x1_int),
+ )
+ return img_masks[:, 0], () # pylint: disable=unsubscriptable-object
+
+
+def paste_masks_in_image(
+ masks: Tensor,
+ boxes: Tensor,
+ image_shape: tuple[int, int],
+ threshold: float = 0.5,
+ bytes_per_float: int = 4,
+ gpu_mem_limit: int = 1024**3,
+) -> Tensor:
+ """Paste masks that are of a fixed resolution into an image.
+
+ The location, height, and width for pasting each mask is determined by
+ their corresponding bounding boxes in boxes.
+
+ This implementation is modified from
+ https://github.com/facebookresearch/detectron2/
+
+ Args:
+ masks (Tensor): Masks with shape [N, Hmask, Wmask], where N is
+ the number of detected object instances in the image and Hmask,
+ Wmask are the mask width and mask height of the predicted mask
+ (e.g., Hmask = Wmask = 28). Values are in [0, 1].
+ boxes (Tensor): Boxes with shape [N, 4]. boxes[i] and masks[i]
+ correspond to the same object instance.
+ image_shape (tuple[int, int]): Image resolution (width, height).
+ threshold (float, optional): Threshold for discretization of mask.
+ Defaults to 0.5.
+ bytes_per_float (int, optional): Number of bytes per float. Defaults to
+ 4.
+ gpu_mem_limit (int, optional): GPU memory limit. Defaults to 1024**3.
+
+ Returns:
+ Tensor: Masks with shape [N, Himage, Wimage], where N is the
+ number of detected object instances and Himage, Wimage are the
+ image width and height.
+ """
+ assert (
+ masks.shape[-1] == masks.shape[-2]
+ ), "Only square mask predictions are supported"
+ assert threshold >= 0
+ num_masks = len(masks)
+ if num_masks == 0:
+ return masks
+
+ img_w, img_h = image_shape
+
+ # The actual implementation split the input into chunks,
+ # and paste them chunk by chunk.
+ if masks.device.type == "cpu":
+ # CPU is most efficient when they are pasted one by one with
+ # skip_empty=True so that it performs minimal number of operations.
+ num_chunks = num_masks
+ else: # pragma: no cover
+ # GPU benefits from parallelism for larger chunks, but may have
+ # memory issue int(img_h) because shape may be tensors in tracing
+ num_chunks = int(
+ np.ceil(
+ num_masks
+ * int(img_h)
+ * int(img_w)
+ * bytes_per_float
+ / gpu_mem_limit
+ )
+ )
+ assert (
+ num_chunks <= num_masks
+ ), "Default gpu_mem_limit is too small; try increasing it"
+ chunks = torch.chunk(
+ torch.arange(num_masks, device=masks.device), num_chunks
+ )
+
+ img_masks = torch.zeros(
+ num_masks, img_h, img_w, device=masks.device, dtype=torch.bool
+ )
+ for inds in chunks:
+ (
+ masks_chunk,
+ spatial_inds,
+ ) = _do_paste_mask(
+ masks[inds, None, :, :],
+ boxes[inds, :4],
+ img_h,
+ img_w,
+ skip_empty=masks.device.type == "cpu",
+ )
+ masks_chunk = torch.greater_equal(masks_chunk, threshold).to(
+ dtype=torch.bool
+ )
+ img_masks[(inds,) + spatial_inds] = masks_chunk
+ return img_masks.type(torch.uint8)
+
+
+def nhw_to_hwc_mask(
+ masks: Tensor, class_ids: Tensor, ignore_class: int = 255
+) -> Tensor:
+ """Convert N binary HxW masks to HxW semantic mask.
+
+ Args:
+ masks (Tensor): Masks with shape [N, H, W].
+ class_ids (Tensor): Class IDs with shape [N, 1].
+ ignore_class (int, optional): Ignore label. Defaults to 255.
+
+ Returns:
+ Tensor: Masks with shape [H, W], where each location indicate the
+ class label.
+ """
+ hwc_mask = torch.full(
+ masks.shape[1:], ignore_class, dtype=masks.dtype, device=masks.device
+ )
+ for mask, cat_id in zip(masks, class_ids):
+ hwc_mask[mask > 0] = cat_id
+ return hwc_mask
+
+
+def clip_mask(mask: Tensor, target_shape: tuple[int, int]) -> Tensor:
+ """Clip mask.
+
+ Args:
+ mask (Tensor): Mask with shape [C, H, W].
+ target_shape (tuple[int, int]): Target shape (Ht, Wt).
+
+ Returns:
+ Tensor: Clipped mask with shape [C, Ht, Wt].
+ """
+ return mask[:, : target_shape[0], : target_shape[1]]
+
+
+def remove_overlap(mask: Tensor, score: Tensor) -> Tensor:
+ """Remove overlapping pixels between masks.
+
+ Args:
+ mask (Tensor): Mask with shape [N, H, W].
+ score (Tensor): Score with shape [N].
+
+ Returns:
+ Tensor: Mask with shape [N, H, W].
+ """
+ foreground = torch.zeros(
+ mask.shape[1:], dtype=torch.bool, device=mask.device
+ )
+ sort_idx = score.argsort(descending=True)
+ for i in sort_idx:
+ mask[i] = torch.logical_and(mask[i], ~foreground)
+ foreground = torch.logical_or(mask[i], foreground)
+ return mask
+
+
+def postprocess_segms(
+ segms: Tensor,
+ images_hw: list[tuple[int, int]],
+ original_hw: list[tuple[int, int]],
+) -> Tensor:
+ """Postprocess segmentations.
+
+ Args:
+ segms (Tensor): Segmentations with shape [B, C, H, W].
+ images_hw (list[tuple[int, int]]): Image resolutions.
+ original_hw (list[tuple[int, int]]): Original image resolutions.
+
+ Returns:
+ Tensor: Post-processed segmentations.
+ """
+ post_segms = []
+ for segm, image_hw, orig_hw in zip(segms, images_hw, original_hw):
+ post_segms.append(
+ F.interpolate(
+ segm[:, : image_hw[0], : image_hw[1]].unsqueeze(1),
+ size=(orig_hw[0], orig_hw[1]),
+ mode="bilinear",
+ ).squeeze(1)
+ )
+ return torch.stack(post_segms).argmax(dim=1)
+
+
+def masks2boxes(masks: Tensor) -> Tensor:
+ """Obtain the tight bounding boxes of binary masks.
+
+ Args:
+ masks (Tensor): Binary mask of shape (N, H, W).
+
+ Returns:
+ Tensor: Boxes with shape (N, 4) of positive region in binary mask.
+ """
+ num_masks = masks.shape[0]
+ bboxes = masks.new_zeros((num_masks, 4), dtype=torch.float32)
+ x_any = torch.any(masks, dim=1)
+ y_any = torch.any(masks, dim=2)
+ for i in range(num_masks):
+ x = torch.where(x_any[i, :])[0]
+ y = torch.where(y_any[i, :])[0]
+ if len(x) > 0 and len(y) > 0:
+ bboxes[i, :] = bboxes.new_tensor(
+ [x[0], y[0], x[-1] + 1, y[-1] + 1]
+ )
+ return bboxes
diff --git a/vis4d/op/motion/__init__.py b/vis4d/op/motion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..725a2159d956ef4aaf41f115a8b1237d36c9e706
--- /dev/null
+++ b/vis4d/op/motion/__init__.py
@@ -0,0 +1 @@
+"""Motion operations."""
diff --git a/vis4d/op/motion/kalman_filter.py b/vis4d/op/motion/kalman_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a3a05963afba0c14fa77e21da4c0cebee864d21
--- /dev/null
+++ b/vis4d/op/motion/kalman_filter.py
@@ -0,0 +1,84 @@
+"""Kalman Filter PyTorch implementation."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor
+
+
+def predict(
+ motion_mat: Tensor,
+ cov_motion_q: Tensor,
+ mean: Tensor,
+ covariance: Tensor,
+) -> tuple[Tensor, Tensor]:
+ """Run Kalman filter prediction step."""
+ # x = Fx
+ mean = torch.matmul(motion_mat, mean)
+
+ # P = (FP)F + Q
+ covariance = (
+ torch.matmul(motion_mat, torch.matmul(covariance, motion_mat.T))
+ + cov_motion_q
+ )
+
+ return mean, covariance
+
+
+def project(
+ update_mat: Tensor, cov_project_r: Tensor, mean: Tensor, covariance: Tensor
+) -> tuple[Tensor, Tensor]:
+ """Project state distribution to measurement space."""
+ # Hx
+ mean = torch.matmul(update_mat, mean)
+
+ # HPH^T + R
+ covariance = torch.matmul(
+ update_mat, torch.matmul(covariance, update_mat.T)
+ )
+ projected_cov = covariance + cov_project_r
+ return mean, projected_cov
+
+
+def update(
+ update_mat: Tensor,
+ cov_project_r: Tensor,
+ mean: Tensor,
+ covariance: Tensor,
+ measurement: Tensor,
+) -> tuple[Tensor, Tensor]:
+ """Run Kalman filter correction step."""
+ # Hx, S = HPH^T + R
+ projected_mean, projected_cov = project(
+ update_mat, cov_project_r, mean, covariance
+ )
+
+ # K = PHT * S^-1
+ chol_factor = torch.linalg.cholesky( # pylint: disable=not-callable
+ projected_cov
+ )
+ kalman_gain = torch.cholesky_solve(
+ torch.matmul(covariance, update_mat.T).T,
+ chol_factor,
+ upper=False,
+ ).T
+
+ # y = z - Hx
+ innovation = measurement - projected_mean
+
+ # x = x + Ky
+ new_mean = mean + torch.matmul(innovation, kalman_gain.T)
+
+ # P = (I-KH)P(I-KH)' + KRK'
+ # This is more numerically stable
+ # and works for non-optimal K vs the equation
+ # P = (I-KH)P usually seen in the literature.
+ i_kh = torch.eye(mean.shape[-1]).to(
+ device=measurement.device
+ ) - torch.matmul(kalman_gain, update_mat)
+
+ new_covariance = torch.matmul(
+ torch.matmul(i_kh, covariance), i_kh.T
+ ) + torch.matmul(torch.matmul(kalman_gain, cov_project_r), kalman_gain.T)
+
+ return new_mean, new_covariance
diff --git a/vis4d/op/motion/velo_lstm.py b/vis4d/op/motion/velo_lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1b8ec715ae28a644778327f20711aa6d89c362f
--- /dev/null
+++ b/vis4d/op/motion/velo_lstm.py
@@ -0,0 +1,56 @@
+"""VeloLSTM operations."""
+
+from __future__ import annotations
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from vis4d.common.typing import LossesType
+from vis4d.op.loss.base import Loss
+
+
+class VeloLSTMLoss(Loss):
+ """Loss term for VeloLSTM."""
+
+ def __init__(self, loc_dim: int = 7, smooth_weight: float = 0.001) -> None:
+ """Initialize the loss term."""
+ super().__init__()
+ self.loc_dim = loc_dim
+ self.smooth_weight = smooth_weight
+
+ @staticmethod
+ def linear_motion_loss(outputs: Tensor) -> Tensor:
+ """Linear motion loss.
+
+ Loss: |(loc_t - loc_t-1), (loc_t-1, loc_t-2)|_1 for t = [2, s_len]
+ """
+ s_len = outputs.shape[1]
+
+ loss = outputs.new_zeros(1)
+ past_motion = outputs[:, 1, :] - outputs[:, 0, :]
+ for idx in range(2, s_len, 1):
+ curr_motion = outputs[:, idx, :] - outputs[:, idx - 1, :]
+ loss += F.l1_loss(past_motion, curr_motion, reduction="mean")
+ past_motion = curr_motion
+ return loss / (s_len - 2)
+
+ def forward(
+ self, loc_preds: Tensor, loc_refines: Tensor, gt_traj: Tensor
+ ) -> LossesType:
+ """Loss term for VeloLSTM."""
+ refine_loss = F.smooth_l1_loss(
+ loc_refines, gt_traj[:, 1:, : self.loc_dim], reduction="mean"
+ )
+ pred_loss = F.smooth_l1_loss(
+ loc_preds[:, :-1, :],
+ gt_traj[:, 2:, : self.loc_dim],
+ reduction="mean",
+ )
+ linear_loss = self.linear_motion_loss(loc_preds[:, :-1, :])
+
+ return {
+ "refine_loss": refine_loss,
+ "pred_loss": pred_loss,
+ "linear_loss": torch.mul(self.smooth_weight, linear_loss),
+ }
diff --git a/vis4d/op/seg/__init__.py b/vis4d/op/seg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2abad433ff740ed5313977b11976abcf913e7957
--- /dev/null
+++ b/vis4d/op/seg/__init__.py
@@ -0,0 +1 @@
+"""Segmentor module."""
diff --git a/vis4d/op/seg/fcn.py b/vis4d/op/seg/fcn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab9ba09927214f68b4da6fd6c96f5b99b30cd66c
--- /dev/null
+++ b/vis4d/op/seg/fcn.py
@@ -0,0 +1,117 @@
+"""FCN Head for semantic segmentation."""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class FCNOut(NamedTuple):
+ """Output of the FCN prediction."""
+
+ pred: torch.Tensor # logits for final prediction, (N, C, H, W)
+ outputs: list[torch.Tensor] # transformed feature maps
+
+
+class FCNHead(nn.Module):
+ """FCN Head made with ResNet base model.
+
+ This is based on the implementation in `torchvision
+ `_.
+ """
+
+ def __init__(
+ self,
+ in_channels: list[int],
+ out_channels: int,
+ dropout_prob: float = 0.1,
+ resize: tuple[int, int] | None = None,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ in_channels (list[int]): Number of channels in multi-level image
+ feature.
+ out_channels (int): Number of output channels. Usually the number
+ of classes.
+ dropout_prob (float, optional): Dropout probability. Defaults to
+ 0.1.
+ resize (tuple(int,int), optional): Target shape to resize output.
+ Defaults to None.
+ """
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.resize = resize
+ self.heads = nn.ModuleList()
+ for in_channel in self.in_channels:
+ self.heads.append(
+ self._make_head(in_channel, self.out_channels, dropout_prob)
+ )
+
+ def _make_head(
+ self, in_channels: int, channels: int, dropout_prob: float
+ ) -> nn.Module:
+ """Generate FCN segmentation head.
+
+ Args:
+ in_channels (int): Input feature channels.
+ channels (int): Output segmentation channels.
+ dropout_prob (float): Dropout probability.
+
+ Returns:
+ nn.Module: FCN segmentation head.
+ """
+ inter_channels = in_channels // 4
+ layers = [
+ nn.Conv2d(
+ in_channels,
+ inter_channels,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(),
+ nn.Dropout(dropout_prob),
+ nn.Conv2d(inter_channels, channels, kernel_size=1),
+ ]
+ return nn.Sequential(*layers)
+
+ def forward(self, feats: list[torch.Tensor]) -> FCNOut:
+ """Transforms feature maps and returns segmentation prediction.
+
+ Args:
+ feats (list[torch.Tensor]): List of multi-level image features.
+
+ Returns:
+ output (list[torch.Tensor]): Each tensor has shape (batch_size,
+ self.channels, H, W) which is prediction for each FCN stages. E.g.,
+
+ outputs[-1] ==> main output map
+ outputs[-2] ==> aux output map (e.g., used for training)
+ outputs[:-2] ==> x[:-2]
+ """
+ outputs = feats.copy()
+ num_features = len(feats)
+ for i in range(len(self.in_channels)):
+ idx = num_features - len(self.in_channels) + i
+ feat = feats[idx]
+ output = self.heads[i](feat)
+ if self.resize:
+ output = F.interpolate(
+ output,
+ size=self.resize,
+ mode="bilinear",
+ align_corners=False,
+ )
+ outputs[idx] = F.log_softmax(output, dim=1)
+ return FCNOut(pred=outputs[-1], outputs=outputs)
+
+ def __call__(self, feats: list[torch.Tensor]) -> FCNOut:
+ """Type definition for function call."""
+ return super()._call_impl(feats)
diff --git a/vis4d/op/seg/semantic_fpn.py b/vis4d/op/seg/semantic_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b646fda3d6566e4098863d9d150b893fa9d2c451
--- /dev/null
+++ b/vis4d/op/seg/semantic_fpn.py
@@ -0,0 +1,122 @@
+"""Semantic FPN Head for segmentation."""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from vis4d.op.layer.conv2d import Conv2d
+
+
+class SemanticFPNOut(NamedTuple):
+ """Output of the SemanticFPN prediction."""
+
+ outputs: Tensor # logits for final prediction, (N, C, H, W)
+
+
+class SemanticFPNHead(nn.Module):
+ """SemanticFPNHead used in Panoptic FPN."""
+
+ def __init__(
+ self,
+ num_classes: int = 53,
+ in_channels: int = 256,
+ inner_channels: int = 128,
+ start_level: int = 2,
+ end_level: int = 6,
+ dropout_ratio: float = 0.1,
+ ):
+ """Creates an instance of the class.
+
+ Args:
+ num_classes (int): Number of classes. Default: 53.
+ in_channels (int): Number of channels in the input feature map.
+ inner_channels (int): Number of channels in inner features.
+ start_level (int): The start level of the input features used in
+ SemanticFPN.
+ end_level (int): The end level of the used features, the
+ ``end_level``-th layer will not be used.
+ dropout_ratio (float): The drop ratio of dropout layer.
+ Default: 0.1.
+ """
+ super().__init__()
+ self.num_classes = num_classes
+
+ # Used feature layers are [start_level, end_level)
+ self.start_level = start_level
+ self.end_level = end_level
+ self.num_stages = end_level - start_level
+ self.inner_channels = inner_channels
+
+ self.scale_heads = nn.ModuleList()
+ for i in range(start_level, end_level):
+ head_length = max(1, i - start_level)
+ scale_head: list[nn.Module] = []
+ for k in range(head_length):
+ scale_head.append(
+ Conv2d(
+ in_channels if k == 0 else inner_channels,
+ inner_channels,
+ 3,
+ padding=1,
+ stride=1,
+ bias=False,
+ norm=nn.BatchNorm2d(inner_channels),
+ activation=nn.ReLU(inplace=True),
+ )
+ )
+ if i > start_level:
+ scale_head.append(
+ nn.Upsample(
+ scale_factor=2,
+ mode="bilinear",
+ align_corners=False,
+ )
+ )
+ self.scale_heads.append(nn.Sequential(*scale_head))
+ self.conv_seg = nn.Conv2d(inner_channels, num_classes, 1)
+ self.dropout_ratio = dropout_ratio
+ if dropout_ratio > 0:
+ self.dropout = nn.Dropout2d(dropout_ratio)
+ self.init_weights()
+
+ def init_weights(self) -> None:
+ """Initialize weights."""
+ nn.init.kaiming_normal_(
+ self.conv_seg.weight, mode="fan_out", nonlinearity="relu"
+ )
+ if hasattr(self.conv_seg, "bias") and self.conv_seg.bias is not None:
+ nn.init.constant_(self.conv_seg.bias, 0)
+
+ def forward(self, features: list[Tensor]) -> SemanticFPNOut:
+ """Transforms feature maps and returns segmentation prediction.
+
+ Args:
+ features (list[Tensor]): List of multi-level image features.
+
+ Returns:
+ SemanticFPNOut: Segmentation outputs.
+ """
+ assert self.num_stages <= len(
+ features
+ ), "Number of subnets must be not more than length of features."
+
+ output = self.scale_heads[0](features[self.start_level])
+ for i in range(1, self.num_stages):
+ output = output + F.interpolate(
+ self.scale_heads[i](features[self.start_level + i]),
+ size=output.shape[2:],
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ if self.dropout_ratio > 0:
+ output = self.dropout(output)
+ seg_preds = self.conv_seg(output)
+ return SemanticFPNOut(outputs=seg_preds)
+
+ def __call__(self, feats: list[Tensor]) -> SemanticFPNOut:
+ """Type definition for function call."""
+ return super()._call_impl(feats)
diff --git a/vis4d/op/track/__init__.py b/vis4d/op/track/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f2b8dff252eed5da17e7f1929da96c2806c7637
--- /dev/null
+++ b/vis4d/op/track/__init__.py
@@ -0,0 +1 @@
+"""Tracking models module."""
diff --git a/vis4d/op/track/assignment.py b/vis4d/op/track/assignment.py
new file mode 100644
index 0000000000000000000000000000000000000000..4136c4161cd85a4521cc1ec8a841dfa3f1c11ce2
--- /dev/null
+++ b/vis4d/op/track/assignment.py
@@ -0,0 +1,104 @@
+"""Track assignment functions."""
+
+from __future__ import annotations
+
+import torch
+from scipy.optimize import linear_sum_assignment
+from torch import Tensor
+
+
+def greedy_assign(
+ detection_scores: Tensor,
+ tracklet_ids: Tensor,
+ affinity_scores: Tensor,
+ match_score_thr: float = 0.5,
+ obj_score_thr: float = 0.3,
+ nms_conf_thr: None | float = None,
+) -> Tensor:
+ """Greedy assignment of detections to tracks given affinities."""
+ ids = torch.full(
+ (len(detection_scores),),
+ -1,
+ dtype=torch.long,
+ device=detection_scores.device,
+ )
+
+ for i, score in enumerate(detection_scores):
+ conf, memo_ind = torch.max(affinity_scores[i, :], dim=0)
+ cur_id = tracklet_ids[memo_ind]
+ if conf > match_score_thr:
+ if cur_id > -1:
+ if score > obj_score_thr:
+ ids[i] = cur_id
+ affinity_scores[:i, memo_ind] = 0
+ affinity_scores[(i + 1) :, memo_ind] = 0
+ elif nms_conf_thr is not None and conf > nms_conf_thr:
+ ids[i] = -2
+ return ids
+
+
+def hungarian_assign(
+ detection_scores: Tensor,
+ tracklet_ids: Tensor,
+ affinity_scores: Tensor,
+ match_score_thr: float = 0.5,
+ obj_score_thr: float = 0.3,
+ nms_conf_thr: None | float = None,
+) -> Tensor:
+ """Hungarian assignment of detections to tracks given affinities."""
+ ids = torch.full(
+ (len(detection_scores),),
+ -1,
+ dtype=torch.long,
+ device=detection_scores.device,
+ )
+
+ matched_indices = linear_sum_assignment(-affinity_scores.cpu().numpy())
+
+ for idx in range(len(matched_indices[0])):
+ i = matched_indices[0][idx]
+ memo_ind = matched_indices[1][idx]
+ conf = affinity_scores[i, memo_ind]
+ tid = tracklet_ids[memo_ind]
+ if conf > match_score_thr and tid > -1:
+ if detection_scores[i] > obj_score_thr:
+ ids[i] = tid
+ affinity_scores[:i, memo_ind] = 0
+ affinity_scores[i + 1 :, memo_ind] = 0
+ elif nms_conf_thr is not None and conf > nms_conf_thr:
+ ids[i] = -2
+
+ return ids
+
+
+class TrackIDCounter:
+ """Global counter for track ids.
+
+ Holds a count of tracks to enable unique and contiguous track ids starting
+ from zero.
+ """
+
+ count: int = 0
+
+ @classmethod
+ def reset(cls) -> None:
+ """Reset track id counter."""
+ cls.count = 0
+
+ @classmethod
+ def get_ids(
+ cls, num_ids: int, device: torch.device = torch.device("cpu")
+ ) -> Tensor:
+ """Generate a num_ids number of new unique tracking ids.
+
+ Args:
+ num_ids (int): number of ids
+ device (torch.device, optional): Device to create ids on. Defaults
+ to torch.device("cpu").
+
+ Returns:
+ Tensor: Tensor of new contiguous track ids.
+ """
+ new_ids = torch.arange(cls.count, cls.count + num_ids, device=device)
+ cls.count = cls.count + num_ids
+ return new_ids
diff --git a/vis4d/op/track/common.py b/vis4d/op/track/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8b9e1a88298769cf86f6ead619d46fec692c2e1
--- /dev/null
+++ b/vis4d/op/track/common.py
@@ -0,0 +1,23 @@
+"""Common classes and functions for tracking."""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+from torch import Tensor
+
+
+class TrackOut(NamedTuple):
+ """Output of track model.
+
+ Attributes:
+ boxes (list[Tensor]): List of bounding boxes (B, N, 4).
+ class_ids (list[Tensor]): List of class ids (B, N).
+ scores (list[Tensor]): List of scores (B, N).
+ track_ids (list[Tensor]): List of track ids (B, N).
+ """
+
+ boxes: list[Tensor]
+ class_ids: list[Tensor]
+ scores: list[Tensor]
+ track_ids: list[Tensor]
diff --git a/vis4d/op/track/matching.py b/vis4d/op/track/matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1baff96a10c28ddc8e829d25891c94963fa358a
--- /dev/null
+++ b/vis4d/op/track/matching.py
@@ -0,0 +1,48 @@
+"""Matching calculation utils."""
+
+from __future__ import annotations
+
+import torch
+from torch.nn import functional as F
+
+
+def calc_bisoftmax_affinity(
+ detection_embeddings: torch.Tensor,
+ track_embeddings: torch.Tensor,
+ detection_class_ids: torch.Tensor | None = None,
+ track_class_ids: torch.Tensor | None = None,
+ with_categories: bool = False,
+) -> torch.Tensor:
+ """Calculate affinity matrix using bisoftmax metric."""
+ feats = torch.mm(detection_embeddings, track_embeddings.t())
+ d2t_scores = feats.softmax(dim=1)
+ t2d_scores = feats.softmax(dim=0)
+ similarity_scores = (d2t_scores + t2d_scores) / 2
+
+ if with_categories:
+ assert (
+ detection_class_ids is not None and track_class_ids is not None
+ ), "Please provide class ids if with_categories=True!"
+ cat_same = detection_class_ids.view(-1, 1) == track_class_ids.view(
+ 1, -1
+ )
+ similarity_scores *= cat_same.float()
+ return similarity_scores
+
+
+def cosine_similarity(
+ key_embeds: torch.Tensor,
+ ref_embeds: torch.Tensor,
+ normalize: bool = True,
+ temperature: float = -1,
+) -> torch.Tensor:
+ """Calculate cosine similarity."""
+ if normalize:
+ key_embeds = F.normalize(key_embeds, p=2, dim=1)
+ ref_embeds = F.normalize(ref_embeds, p=2, dim=1)
+
+ dists = torch.mm(key_embeds, ref_embeds.t())
+
+ if temperature > 0:
+ dists /= temperature # pragma: no cover
+ return dists
diff --git a/vis4d/op/track/qdtrack.py b/vis4d/op/track/qdtrack.py
new file mode 100644
index 0000000000000000000000000000000000000000..b177d4f08acc0cbbbfd59489298a0b636c4cf24c
--- /dev/null
+++ b/vis4d/op/track/qdtrack.py
@@ -0,0 +1,681 @@
+"""Quasi-dense embedding similarity based graph."""
+
+from __future__ import annotations
+
+import math
+from typing import NamedTuple
+
+import torch
+from torch import Tensor, nn
+
+from vis4d.op.box.box2d import bbox_iou
+from vis4d.op.box.matchers.max_iou import MaxIoUMatcher
+from vis4d.op.box.poolers import MultiScaleRoIAlign, MultiScaleRoIPooler
+from vis4d.op.box.samplers import CombinedSampler, match_and_sample_proposals
+from vis4d.op.layer.conv2d import add_conv_branch
+from vis4d.op.loss import EmbeddingDistanceLoss, MultiPosCrossEntropyLoss
+
+from .assignment import TrackIDCounter, greedy_assign
+from .matching import calc_bisoftmax_affinity, cosine_similarity
+
+
+def get_default_box_sampler() -> CombinedSampler:
+ """Get default box sampler of qdtrack."""
+ box_sampler = CombinedSampler(
+ batch_size=256,
+ positive_fraction=0.5,
+ pos_strategy="instance_balanced",
+ neg_strategy="iou_balanced",
+ )
+ return box_sampler
+
+
+def get_default_box_matcher() -> MaxIoUMatcher:
+ """Get default box matcher of qdtrack."""
+ box_matcher = MaxIoUMatcher(
+ thresholds=[0.3, 0.7],
+ labels=[0, -1, 1],
+ allow_low_quality_matches=False,
+ )
+ return box_matcher
+
+
+class QDTrackOut(NamedTuple):
+ """Output of QDTrack during training."""
+
+ key_embeddings: list[Tensor]
+ ref_embeddings: list[list[Tensor]] | None
+ key_track_ids: list[Tensor] | None
+ ref_track_ids: list[list[Tensor]] | None
+
+
+class QDTrackHead(nn.Module):
+ """QDTrack - quasi-dense instance similarity learning."""
+
+ def __init__(
+ self,
+ similarity_head: QDSimilarityHead | None = None,
+ box_sampler: CombinedSampler | None = None,
+ box_matcher: MaxIoUMatcher | None = None,
+ proposal_append_gt: bool = True,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.similarity_head = (
+ QDSimilarityHead() if similarity_head is None else similarity_head
+ )
+
+ self.box_sampler = (
+ box_sampler
+ if box_sampler is not None
+ else get_default_box_sampler()
+ )
+
+ self.box_matcher = (
+ box_matcher
+ if box_matcher is not None
+ else get_default_box_matcher()
+ )
+
+ self.proposal_append_gt = proposal_append_gt
+
+ @torch.no_grad()
+ def _sample_proposals(
+ self,
+ det_boxes: list[list[Tensor]],
+ target_boxes: list[list[Tensor]],
+ target_track_ids: list[list[Tensor]],
+ ) -> tuple[list[list[Tensor]], list[list[Tensor]]]:
+ """Sample proposals for instance similarity learning."""
+ sampled_boxes, sampled_track_ids = [], []
+ for i, (boxes, tgt_boxes) in enumerate(zip(det_boxes, target_boxes)):
+ if self.proposal_append_gt:
+ boxes = [torch.cat([d, t]) for d, t in zip(boxes, tgt_boxes)]
+
+ (
+ sampled_box_indices,
+ sampled_target_indices,
+ sampled_labels,
+ ) = match_and_sample_proposals(
+ self.box_matcher, self.box_sampler, boxes, tgt_boxes
+ )
+
+ positives = [l == 1 for l in sampled_labels]
+ if i == 0: # key view: take only positives
+ sampled_box = [
+ b[s_i][p]
+ for b, s_i, p in zip(boxes, sampled_box_indices, positives)
+ ]
+ sampled_tr_id = [
+ t[s_i][p]
+ for t, s_i, p in zip(
+ target_track_ids[i], sampled_target_indices, positives
+ )
+ ]
+ else: # set track_ids to -1 for all negatives
+ sampled_box = [
+ b[s_i] for b, s_i in zip(boxes, sampled_box_indices)
+ ]
+ sampled_tr_id = [
+ t[s_i]
+ for t, s_i in zip(
+ target_track_ids[i], sampled_target_indices
+ )
+ ]
+ for pos, samp_tgt in zip(positives, sampled_tr_id):
+ samp_tgt[~pos] = -1
+
+ sampled_boxes.append(sampled_box)
+ sampled_track_ids.append(sampled_tr_id)
+ return sampled_boxes, sampled_track_ids
+
+ def forward(
+ self,
+ features: list[Tensor] | list[list[Tensor]],
+ det_boxes: list[Tensor] | list[list[Tensor]],
+ target_boxes: None | list[list[Tensor]] = None,
+ target_track_ids: None | list[list[Tensor]] = None,
+ ) -> QDTrackOut:
+ """Forward function."""
+ if target_boxes is not None and target_track_ids is not None:
+ sampled_boxes, sampled_track_ids = self._sample_proposals(
+ det_boxes, # type: ignore
+ target_boxes,
+ target_track_ids,
+ )
+
+ embeddings = []
+ for feats, boxes in zip(features, sampled_boxes):
+ assert isinstance(feats, list) and isinstance(boxes, list)
+ embeddings.append(self.similarity_head(feats, boxes))
+
+ return QDTrackOut(
+ embeddings[0],
+ embeddings[1:],
+ sampled_track_ids[0],
+ sampled_track_ids[1:],
+ )
+
+ key_embeddings = self.similarity_head(features, det_boxes) # type: ignore # pylint: disable=line-too-long
+
+ return QDTrackOut(key_embeddings, None, None, None)
+
+ def __call__(
+ self,
+ features: list[Tensor] | list[list[Tensor]],
+ det_boxes: list[Tensor] | list[list[Tensor]],
+ target_boxes: None | list[list[Tensor]] = None,
+ target_track_ids: None | list[list[Tensor]] = None,
+ ) -> QDTrackOut:
+ """Type definition for call implementation."""
+ return self._call_impl(
+ features, det_boxes, target_boxes, target_track_ids
+ )
+
+
+class QDTrackAssociation:
+ """Data association relying on quasi-dense instance similarity.
+
+ This class assigns detection candidates to a given memory of existing
+ tracks and backdrops.
+ Backdrops are low-score detections kept in case they have high
+ similarity with a high-score detection in succeeding frames.
+
+ Attributes:
+ init_score_thr: Confidence threshold for initializing a new track
+ obj_score_thr: Confidence treshold s.t. a detection is considered in
+ the track / det matching process.
+ match_score_thr: Similarity score threshold for matching a detection to
+ an existing track.
+ memo_backdrop_frames: Number of timesteps to keep backdrops.
+ memo_momentum: Momentum of embedding memory for smoothing embeddings.
+ nms_backdrop_iou_thr: Maximum IoU of a backdrop with another detection.
+ nms_class_iou_thr: Maximum IoU of a high score detection with another
+ of a different class.
+ with_cats: If to consider category information for tracking (i.e. all
+ detections within a track must have consistent category labels).
+ """
+
+ def __init__(
+ self,
+ init_score_thr: float = 0.7,
+ obj_score_thr: float = 0.3,
+ match_score_thr: float = 0.5,
+ nms_conf_thr: float = 0.5,
+ nms_backdrop_iou_thr: float = 0.3,
+ nms_class_iou_thr: float = 0.7,
+ with_cats: bool = True,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ self.init_score_thr = init_score_thr
+ self.obj_score_thr = obj_score_thr
+ self.match_score_thr = match_score_thr
+ self.nms_backdrop_iou_thr = nms_backdrop_iou_thr
+ self.nms_class_iou_thr = nms_class_iou_thr
+ self.nms_conf_thr = nms_conf_thr
+ self.with_cats = with_cats
+
+ def _filter_detections(
+ self,
+ detections: Tensor,
+ scores: Tensor,
+ class_ids: Tensor,
+ embeddings: Tensor,
+ ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+ """Remove overlapping objects across classes via nms.
+
+ Args:
+ detections (Tensor): [N, 4] Tensor of boxes.
+ scores (Tensor): [N,] Tensor of confidence scores.
+ class_ids (Tensor): [N,] Tensor of class ids.
+ embeddings (Tensor): [N, C] tensor of appearance embeddings.
+
+ Returns:
+ tuple[Tensor]: filtered detections, scores, class_ids,
+ embeddings, and filtered indices.
+ """
+ scores, inds = scores.sort(descending=True)
+ detections, embeddings, class_ids = (
+ detections[inds],
+ embeddings[inds],
+ class_ids[inds],
+ )
+ valids = embeddings.new_ones((len(detections),), dtype=torch.bool)
+ ious = bbox_iou(detections, detections)
+ for i in range(1, len(detections)):
+ if scores[i] < self.obj_score_thr:
+ thr = self.nms_backdrop_iou_thr
+ else:
+ thr = self.nms_class_iou_thr
+
+ if (ious[i, :i] > thr).any():
+ valids[i] = False
+ detections = detections[valids]
+ scores = scores[valids]
+ class_ids = class_ids[valids]
+ embeddings = embeddings[valids]
+ return detections, scores, class_ids, embeddings, inds[valids]
+
+ def __call__(
+ self,
+ detections: Tensor,
+ detection_scores: Tensor,
+ detection_class_ids: Tensor,
+ detection_embeddings: Tensor,
+ memory_track_ids: Tensor | None = None,
+ memory_class_ids: Tensor | None = None,
+ memory_embeddings: Tensor | None = None,
+ ) -> tuple[Tensor, Tensor]:
+ """Process inputs, match detections with existing tracks.
+
+ Args:
+ detections (Tensor): [N, 4] detected boxes.
+ detection_scores (Tensor): [N,] confidence scores.
+ detection_class_ids (Tensor): [N,] class indices.
+ detection_embeddings (Tensor): [N, C] appearance embeddings.
+ memory_track_ids (Tensor): [M,] track ids in memory.
+ memory_class_ids (Tensor): [M,] class indices in memory.
+ memory_embeddings (Tensor): [M, C] appearance embeddings in
+ memory.
+
+ Returns:
+ tuple[Tensor, Tensor]: track ids of active tracks and selected
+ detection indices corresponding to tracks.
+ """
+ (
+ detections,
+ detection_scores,
+ detection_class_ids,
+ detection_embeddings,
+ permute_inds,
+ ) = self._filter_detections(
+ detections,
+ detection_scores,
+ detection_class_ids,
+ detection_embeddings,
+ )
+
+ # match if buffer is not empty
+ if len(detections) > 0 and memory_track_ids is not None:
+ assert (
+ memory_class_ids is not None and memory_embeddings is not None
+ )
+
+ affinity_scores = calc_bisoftmax_affinity(
+ detection_embeddings,
+ memory_embeddings,
+ detection_class_ids,
+ memory_class_ids,
+ self.with_cats,
+ )
+ ids = greedy_assign(
+ detection_scores,
+ memory_track_ids,
+ affinity_scores,
+ self.match_score_thr,
+ self.obj_score_thr,
+ self.nms_conf_thr,
+ )
+ else:
+ ids = torch.full(
+ (len(detections),),
+ -1,
+ dtype=torch.long,
+ device=detections.device,
+ )
+ new_inds = (ids == -1) & (detection_scores > self.init_score_thr)
+ ids[new_inds] = TrackIDCounter.get_ids(
+ new_inds.sum(), device=ids.device # type: ignore
+ )
+ return ids, permute_inds
+
+
+class QDSimilarityHead(nn.Module):
+ """Instance embedding head for quasi-dense similarity learning.
+
+ Given a set of input feature maps and RoIs, pool RoI representations from
+ feature maps and process them to a per-RoI embeddings vector.
+ """
+
+ def __init__(
+ self,
+ proposal_pooler: None | MultiScaleRoIPooler = None,
+ in_dim: int = 256,
+ num_convs: int = 4,
+ conv_out_dim: int = 256,
+ conv_has_bias: bool = False,
+ num_fcs: int = 1,
+ fc_out_dim: int = 1024,
+ embedding_dim: int = 256,
+ norm: str = "GroupNorm",
+ num_groups: int = 32,
+ start_level: int = 2,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ proposal_pooler (None | MultiScaleRoIPooler, optional): RoI pooling
+ module. Defaults to None.
+ in_dim (int, optional): Input feature dimension. Defaults to 256.
+ num_convs (int, optional): Number of convolutional layers inside
+ the head. Defaults to 4.
+ conv_out_dim (int, optional): Output dimension of the last conv
+ layer. Defaults to 256.
+ conv_has_bias (bool, optional): If the conv layers have a bias
+ parameter. Defaults to False.
+ num_fcs (int, optional): Number of fully connected layers following
+ the conv layers. Defaults to 1.
+ fc_out_dim (int, optional): Output dimension of the last fully
+ connected layer. Defaults to 1024.
+ embedding_dim (int, optional): Dimensionality of the output
+ instance embedding. Defaults to 256.
+ norm (str, optional): Normalization of the layers inside the head.
+ One of BatchNorm2d, GroupNorm. Defaults to "GroupNorm".
+ num_groups (int, optional): Number of groups for the GroupNorm
+ normalization. Defaults to 32.
+ start_level (int, optional): starting level of feature maps.
+ Defaults to 2.
+ """
+ super().__init__()
+ self.in_dim = in_dim
+ self.num_convs = num_convs
+ self.conv_out_dim = conv_out_dim
+ self.conv_has_bias = conv_has_bias
+ self.num_fcs = num_fcs
+ self.fc_out_dim = fc_out_dim
+ self.norm = norm
+ self.num_groups = num_groups
+
+ if proposal_pooler is not None:
+ self.roi_pooler = proposal_pooler
+ else:
+ self.roi_pooler = MultiScaleRoIAlign(
+ resolution=[7, 7], strides=[4, 8, 16, 32], sampling_ratio=0
+ )
+
+ # Used feature layers are [start_level, end_level)
+ self.start_level = start_level
+ num_strides = len(self.roi_pooler.scales)
+ self.end_level = start_level + num_strides
+
+ self.convs, self.fcs, last_layer_dim = self._init_embedding_head()
+ self.fc_embed = nn.Linear(last_layer_dim, embedding_dim)
+ self._init_weights()
+
+ def _init_weights(self) -> None:
+ """Init weights of modules in head."""
+ for m in self.convs:
+ nn.init.kaiming_uniform_(m.weight, a=1) # type: ignore
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0) # type: ignore
+
+ for m in self.fcs:
+ if isinstance(m[0], nn.Linear): # type: ignore
+ nn.init.xavier_uniform_(m[0].weight) # type: ignore
+ nn.init.constant_(m[0].bias, 0) # type: ignore
+
+ nn.init.normal_(self.fc_embed.weight, 0, 0.01)
+ nn.init.constant_(self.fc_embed.bias, 0)
+
+ def _init_embedding_head(
+ self,
+ ) -> tuple[torch.nn.ModuleList, torch.nn.ModuleList, int]:
+ """Init modules of head."""
+ convs, last_layer_dim = add_conv_branch(
+ self.num_convs,
+ self.in_dim,
+ self.conv_out_dim,
+ self.conv_has_bias,
+ self.norm,
+ self.num_groups,
+ )
+
+ fcs = nn.ModuleList()
+ if self.num_fcs > 0:
+ last_layer_dim *= math.prod(self.roi_pooler.resolution)
+ for i in range(self.num_fcs):
+ fc_in_dim = last_layer_dim if i == 0 else self.fc_out_dim
+ fcs.append(
+ nn.Sequential(
+ nn.Linear(fc_in_dim, self.fc_out_dim),
+ nn.ReLU(inplace=True),
+ )
+ )
+ last_layer_dim = self.fc_out_dim
+ return convs, fcs, last_layer_dim
+
+ def forward(
+ self, features: list[Tensor], boxes: list[Tensor]
+ ) -> list[Tensor]:
+ """Similarity head forward pass.
+
+ Args:
+ features (list[Tensor]): A feature pyramid. The list index
+ represents the level, which has a downsampling raio of 2^index.
+ fp[0] is a feature map with the image resolution instead of the
+ original image.
+ boxes (list[Tensor]): A list of [N, 4] 2D bounding boxes per
+ batch element.
+
+ Returns:
+ list[Tensor]: An embedding vector per input box, .
+ """
+ # RoI pooling
+ x = self.roi_pooler(features[self.start_level : self.end_level], boxes)
+
+ # convs
+ if self.num_convs > 0:
+ for conv in self.convs:
+ x = conv(x)
+
+ # fcs
+ x = torch.flatten(x, start_dim=1)
+ if self.num_fcs > 0:
+ for fc in self.fcs:
+ x = fc(x)
+
+ embeddings: list[Tensor] = list(
+ self.fc_embed(x).split([len(b) for b in boxes])
+ )
+ return embeddings
+
+ def __call__(
+ self, features: list[Tensor], boxes: list[Tensor]
+ ) -> list[Tensor]:
+ """Type definition."""
+ return self._call_impl(features, boxes)
+
+
+class QDTrackInstanceSimilarityLosses(NamedTuple):
+ """QDTrack losses return type. Consists of two scalar loss tensors."""
+
+ track_loss: Tensor
+ track_loss_aux: Tensor
+
+
+class QDTrackInstanceSimilarityLoss(nn.Module):
+ """Instance similarity loss as in QDTrack.
+
+ Given a number of key frame embeddings and a number of reference frame
+ embeddings along with their track identities, compute two losses:
+ 1. Multi-positive cross-entropy loss.
+ 2. Cosine similarity loss (auxiliary).
+ """
+
+ def __init__(self, softmax_temp: float = -1):
+ """Creates an instance of the class.
+
+ Args:
+ softmax_temp (float, optional): Temperature parameter for
+ multi-positive cross-entropy loss. Defaults to -1.
+ """
+ super().__init__()
+ self.softmax_temp = softmax_temp
+ self.track_loss = MultiPosCrossEntropyLoss()
+ self.track_loss_aux = EmbeddingDistanceLoss()
+ self.track_loss_weight = 0.25
+
+ def forward(
+ self,
+ key_embeddings: list[Tensor],
+ ref_embeddings: list[list[Tensor]],
+ key_track_ids: list[Tensor],
+ ref_track_ids: list[list[Tensor]],
+ ) -> QDTrackInstanceSimilarityLosses:
+ """The QDTrack instance similarity loss.
+
+ Key inputs are of type list[Tensor/Boxes2D] (Lists are length N)
+ Ref inputs are of type list[list[Tensor/Boxes2D]] where the lists
+ are of length MxN.
+ Where M is the number of reference views and N is the
+ number of batch elements.
+
+ NOTE: this only works if key only contains positives and all
+ negatives in ref have track_id -1
+
+ Args:
+ key_embeddings (list[Tensor]): key frame embeddings.
+ ref_embeddings (list[list[Tensor]]): reference frame
+ embeddings.
+ key_track_ids (list[Tensor]): associated track ids per
+ embedding in key frame.
+ ref_track_ids (list[list[Tensor]]): associated track ids per
+ embedding in reference frame(s).
+
+ Returns:
+ QDTrackInstanceSimilarityLosses: Scalar loss tensors.
+ """
+ if sum(len(e) for e in key_embeddings) == 0: # pragma: no cover
+ dummy_loss = sum(e.sum() * 0.0 for e in key_embeddings)
+ return QDTrackInstanceSimilarityLosses(dummy_loss, dummy_loss) # type: ignore # pylint: disable=line-too-long
+
+ loss_track = torch.tensor(0.0, device=key_embeddings[0].device)
+ loss_track_aux = torch.tensor(0.0, device=key_embeddings[0].device)
+ dists, cos_dists = self._match(key_embeddings, ref_embeddings)
+ track_targets, track_weights = self._get_targets(
+ key_track_ids, ref_track_ids
+ )
+ # for each reference view
+ for curr_dists, curr_cos_dists, curr_targets, curr_weights in zip(
+ dists, cos_dists, track_targets, track_weights
+ ):
+ # for each batch element
+ for _dists, _cos_dists, _targets, _weights in zip(
+ curr_dists, curr_cos_dists, curr_targets, curr_weights
+ ):
+ if all(_dists.shape):
+ loss_track += (
+ self.track_loss(
+ _dists,
+ _targets,
+ _weights,
+ avg_factor=_weights.sum() + 1e-5,
+ )
+ * self.track_loss_weight
+ )
+ if self.track_loss_aux is not None:
+ loss_track_aux += self.track_loss_aux(
+ _cos_dists, _targets
+ )
+
+ num_pairs = len(dists) * len(dists[0])
+ loss_track = torch.div(loss_track, num_pairs)
+ loss_track_aux = torch.div(loss_track_aux, num_pairs)
+
+ return QDTrackInstanceSimilarityLosses(
+ track_loss=loss_track, track_loss_aux=loss_track_aux
+ )
+
+ def __call__(
+ self,
+ key_embeddings: list[Tensor],
+ ref_embeddings: list[list[Tensor]],
+ key_track_ids: list[Tensor],
+ ref_track_ids: list[list[Tensor]],
+ ) -> QDTrackInstanceSimilarityLosses:
+ """Type definition."""
+ return self._call_impl(
+ key_embeddings, ref_embeddings, key_track_ids, ref_track_ids
+ )
+
+ @staticmethod
+ def _get_targets(
+ key_track_ids: list[Tensor],
+ ref_track_ids: list[list[Tensor]],
+ ) -> tuple[list[list[Tensor]], list[list[Tensor]]]:
+ """Create tracking target tensors.
+
+ Args:
+ key_track_ids (list[Tensor]): A List of Tensors [N,] per
+ batch element containing the corresponding track ids of each
+ box in the key frame.
+ ref_track_ids (list[list[Tensor]]): A nested list fo Tensors
+ [N,] per batch element, per reference view. The inner list
+ denotes the batch index, the outer list the reference view
+ index. Contains track ids of boxes in all reference views
+ across the batch.
+
+ Returns:
+ tuple[list[list[Tensor]], list[list[Tensor]]]: The
+ target tensors per key-reference pair containing 1 if the
+ identities of two boxes across the key and a reference view
+ match, and 0 otherwise and the loss reduction weights for
+ a certain box.
+ """
+ # for each reference view
+ track_targets, track_weights = [], []
+ for ref_target in ref_track_ids:
+ # for each batch element
+ curr_targets, curr_weights = [], []
+ for key_target, ref_target_ in zip(key_track_ids, ref_target):
+ # target shape: len(key_target) x len(ref_target_)
+ # NOTE: this only works if key only contains positives and all
+ # negatives in ref have track_id -1
+ target = (
+ key_target.view(-1, 1) == ref_target_.view(1, -1)
+ ).int()
+ weight = (target.sum(dim=1) > 0).float()
+ curr_targets.append(target)
+ curr_weights.append(weight)
+ track_targets.append(curr_targets)
+ track_weights.append(curr_weights)
+ return track_targets, track_weights
+
+ def _match(
+ self,
+ key_embeds: list[Tensor],
+ ref_embeds: list[list[Tensor]],
+ ) -> tuple[list[list[Tensor]], list[list[Tensor]]]:
+ """Calculate distances for all pairs of key / ref embeddings.
+
+ Args:
+ key_embeds (list[Tensor]): Embeddings for boxes in key frame.
+ ref_embeds (list[list[Tensor]]): Embeddings for boxes in
+ all reference frames.
+
+ Returns:
+ tuple[list[list[Tensor]], list[list[Tensor]]]:
+ Embedding distances for all embedding pairs, first normalized
+ via softmax, then normal cosine similary.
+ """
+ # for each reference view
+ dists, cos_dists = [], []
+ for ref_embed in ref_embeds:
+ # for each batch element
+ dists_curr, cos_dists_curr = [], []
+ for key_embed, ref_embed_ in zip(key_embeds, ref_embed):
+ dist = cosine_similarity(
+ key_embed,
+ ref_embed_,
+ normalize=False,
+ temperature=self.softmax_temp,
+ )
+ dists_curr.append(dist)
+ if self.track_loss_aux is not None:
+ cos_dist = cosine_similarity(key_embed, ref_embed_)
+ cos_dists_curr.append(cos_dist)
+
+ dists.append(dists_curr)
+ cos_dists.append(cos_dists_curr)
+ return dists, cos_dists
diff --git a/vis4d/op/track3d/__init__.py b/vis4d/op/track3d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdb3d1797016ad1ca3b483e81bcf2f79a0ba3d0c
--- /dev/null
+++ b/vis4d/op/track3d/__init__.py
@@ -0,0 +1 @@
+"""3D tracking models module."""
diff --git a/vis4d/op/track3d/cc_3dt.py b/vis4d/op/track3d/cc_3dt.py
new file mode 100644
index 0000000000000000000000000000000000000000..837b2b76f8a01ec9db6d151f939bb09e0169e897
--- /dev/null
+++ b/vis4d/op/track3d/cc_3dt.py
@@ -0,0 +1,446 @@
+"""CC-3DT graph."""
+
+from __future__ import annotations
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from vis4d.op.box.box2d import bbox_iou
+from vis4d.op.geometry.rotation import (
+ euler_angles_to_matrix,
+ matrix_to_quaternion,
+ rotate_orientation,
+ rotate_velocities,
+)
+from vis4d.op.geometry.transform import transform_points
+from vis4d.op.track.assignment import TrackIDCounter, greedy_assign
+from vis4d.op.track.matching import calc_bisoftmax_affinity
+
+from .common import Track3DOut
+
+
+def get_track_3d_out(
+ boxes_3d: Tensor, class_ids: Tensor, scores_3d: Tensor, track_ids: Tensor
+) -> Track3DOut:
+ """Get track 3D output.
+
+ Args:
+ boxes_3d (Tensor): (N, 12): x,y,z,h,w,l,rx,ry,rz,vx,vy,vz
+ class_ids (Tensor): (N,)
+ scores_3d (Tensor): (N,)
+ track_ids (Tensor): (N,)
+
+ Returns:
+ Track3DOut: output
+ """
+ center = boxes_3d[:, :3]
+ # HWL -> WLH
+ dims = boxes_3d[:, [4, 5, 3]]
+ orientation = matrix_to_quaternion(
+ euler_angles_to_matrix(boxes_3d[:, 6:9])
+ )
+
+ return Track3DOut(
+ boxes_3d=[torch.cat([center, dims, orientation], dim=1)],
+ velocities=[boxes_3d[:, 9:12]],
+ class_ids=[class_ids],
+ scores_3d=[scores_3d],
+ track_ids=[track_ids],
+ )
+
+
+class CC3DTrackAssociation:
+ """Data association relying on quasi-dense instance similarity and 3D clue.
+
+ This class assigns detection candidates to a given memory of existing
+ tracks and backdrops.
+ Backdrops are low-score detections kept in case they have high
+ similarity with a high-score detection in succeeding frames.
+ """
+
+ def __init__(
+ self,
+ init_score_thr: float = 0.8,
+ obj_score_thr: float = 0.5,
+ match_score_thr: float = 0.5,
+ nms_backdrop_iou_thr: float = 0.3,
+ nms_class_iou_thr: float = 0.7,
+ nms_conf_thr: float = 0.5,
+ with_cats: bool = True,
+ with_velocities: bool = False,
+ bbox_affinity_weight: float = 0.5,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ init_score_thr (float): Confidence threshold for initializing a new
+ track.
+ obj_score_thr (float): Confidence treshold s.t. a detection is
+ considered in the track / det matching process.
+ match_score_thr (float): Similarity score threshold for matching a
+ detection to an existing track.
+ nms_backdrop_iou_thr (float): Maximum IoU of a backdrop with
+ another detection.
+ nms_class_iou_thr (float): Maximum IoU of a high score detection
+ with another of a different class.
+ nms_conf_thr (float): Confidence threshold for NMS.
+ with_cats (bool): If to consider category information for
+ tracking (i.e. all detections within a track must have
+ consistent category labels).
+ with_velocities (bool): If to use predicted velocities for
+ matching.
+ bbox_affinity_weight (float): Weight of bbox affinity in the
+ overall affinity score.
+ """
+ super().__init__()
+ self.init_score_thr = init_score_thr
+ self.obj_score_thr = obj_score_thr
+ self.match_score_thr = match_score_thr
+ self.nms_backdrop_iou_thr = nms_backdrop_iou_thr
+ self.nms_class_iou_thr = nms_class_iou_thr
+ self.nms_conf_thr = nms_conf_thr
+ self.with_cats = with_cats
+ self.with_velocities = with_velocities
+ self.bbox_affinity_weight = bbox_affinity_weight
+ self.feat_affinity_weight = 1 - bbox_affinity_weight
+
+ def _filter_detections(
+ self,
+ detections: Tensor,
+ camera_ids: Tensor,
+ scores: Tensor,
+ detections_3d: Tensor,
+ scores_3d: Tensor,
+ class_ids: Tensor,
+ embeddings: Tensor,
+ velocities: Tensor | None = None,
+ ) -> tuple[
+ Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor | None, Tensor
+ ]:
+ """Remove overlapping objects across classes via nms.
+
+ Args:
+ detections (Tensor): [N, 4] Tensor of boxes.
+ camera_ids (Tensor): [N,] Tensor of camera ids.
+ scores (Tensor): [N,] Tensor of confidence scores.
+ detections_3d (Tensor): [N, 7] Tensor of 3D boxes.
+ scores_3d (Tensor): [N,] Tensor of 3D confidence scores.
+ class_ids (Tensor): [N,] Tensor of class ids.
+ embeddings (Tensor): [N, C] tensor of appearance embeddings.
+ velocities (Tensor | None): [N, 3] Tensor of velocities.
+
+ Returns:
+ tuple[Tensor]: filtered detections, scores, class_ids,
+ embeddings, and filtered indices.
+ """
+ scores, inds = scores.sort(descending=True)
+ (
+ detections,
+ camera_ids,
+ embeddings,
+ class_ids,
+ detections_3d,
+ scores_3d,
+ ) = (
+ detections[inds],
+ camera_ids[inds],
+ embeddings[inds],
+ class_ids[inds],
+ detections_3d[inds],
+ scores_3d[inds],
+ )
+
+ if velocities is not None:
+ velocities = velocities[inds]
+
+ valids = embeddings.new_ones((len(detections),), dtype=torch.bool)
+
+ ious = bbox_iou(detections, detections)
+ valid_ious = torch.eq(
+ camera_ids.unsqueeze(1), camera_ids.unsqueeze(0)
+ ).int()
+ ious *= valid_ious
+
+ for i in range(1, len(detections)):
+ if scores[i] < self.obj_score_thr:
+ thr = self.nms_backdrop_iou_thr
+ else:
+ thr = self.nms_class_iou_thr
+
+ if (ious[i, :i] > thr).any():
+ valids[i] = False
+
+ detections = detections[valids]
+ scores = scores[valids]
+ detections_3d = detections_3d[valids]
+ scores_3d = scores_3d[valids]
+ class_ids = class_ids[valids]
+ embeddings = embeddings[valids]
+
+ if velocities is not None:
+ velocities = velocities[valids]
+
+ return (
+ detections,
+ scores,
+ detections_3d,
+ scores_3d,
+ class_ids,
+ embeddings,
+ velocities,
+ inds[valids],
+ )
+
+ def depth_ordering(
+ self,
+ obsv_boxes_3d: Tensor,
+ obsv_velocities: Tensor | None,
+ memory_boxes_3d_predict: Tensor,
+ memory_boxes_3d: Tensor,
+ memory_velocities: Tensor,
+ ) -> Tensor:
+ """Depth ordering matching."""
+ # Centroid
+ centroid_weight_list = []
+ for memory_box_3d_predict in memory_boxes_3d_predict:
+ centroid_weight_list.append(
+ F.pairwise_distance( # pylint: disable=not-callable
+ obsv_boxes_3d[:, :3],
+ memory_box_3d_predict[:3],
+ keepdim=True,
+ )
+ )
+ centroid_weight = torch.cat(centroid_weight_list, dim=1)
+ centroid_weight = torch.exp(-torch.div(centroid_weight, 10.0))
+
+ # Moving distance should be aligned
+ motion_weight_list = []
+ moving_dist = (
+ obsv_boxes_3d[:, :3, None]
+ - memory_boxes_3d[:, :3, None].transpose(2, 0)
+ ).transpose(1, 2)
+ for v in moving_dist:
+ motion_weight_list.append(
+ F.pairwise_distance( # pylint: disable=not-callable
+ v, memory_velocities[:, :3]
+ ).unsqueeze(0)
+ )
+ motion_weight = torch.cat(motion_weight_list, dim=0)
+ motion_weight = torch.exp(-torch.div(motion_weight, 5.0))
+
+ # Velocity scores
+ if self.with_velocities:
+ assert (
+ obsv_velocities is not None
+ ), "Please provide velocities if with_velocities=True!"
+
+ velsim_weight_list = []
+ obsvvv_velocities = obsv_velocities.unsqueeze(1).expand_as(
+ moving_dist
+ )
+ for v in obsvvv_velocities:
+ velsim_weight_list.append(
+ F.pairwise_distance( # pylint: disable=not-callable
+ v, memory_velocities[:, -3:]
+ ).unsqueeze(0)
+ )
+ velsim_weight = torch.cat(velsim_weight_list, dim=0)
+ cos_sim = torch.exp(-velsim_weight / 5.0)
+ else:
+ # Moving direction should be aligned
+ # Set to 0.5 when two vector not within +-90 degree
+ cos_sim_list = []
+ obsv_direct = (
+ obsv_boxes_3d[:, :2, None]
+ - memory_boxes_3d[:, :2, None].transpose(2, 0)
+ ).transpose(1, 2)
+ for d in obsv_direct:
+ cos_sim_list.append(
+ F.cosine_similarity( # pylint: disable=not-callable
+ d, memory_velocities[:, :2]
+ ).unsqueeze(0)
+ )
+ cos_sim = torch.cat(cos_sim_list, dim=0)
+ cos_sim = torch.add(cos_sim, 1.0)
+ cos_sim = torch.div(cos_sim, 2.0)
+
+ scores_depth = (
+ cos_sim * centroid_weight + (1.0 - cos_sim) * motion_weight
+ )
+
+ return scores_depth
+
+ def __call__(
+ self,
+ detections: Tensor,
+ camera_ids: Tensor,
+ detection_scores: Tensor,
+ detections_3d: Tensor,
+ detection_scores_3d: Tensor,
+ detection_class_ids: Tensor,
+ detection_embeddings: Tensor,
+ obs_velocities: Tensor | None = None,
+ memory_boxes_3d: Tensor | None = None,
+ memory_track_ids: Tensor | None = None,
+ memory_class_ids: Tensor | None = None,
+ memory_embeddings: Tensor | None = None,
+ memory_boxes_3d_predict: Tensor | None = None,
+ memory_velocities: Tensor | None = None,
+ with_depth_confidence: bool = True,
+ ) -> tuple[Tensor, Tensor]:
+ """Process inputs, match detections with existing tracks.
+
+ Args:
+ detections (Tensor): [N, 4] detected boxes.
+ camera_ids (Tensor): [N,] camera ids.
+ detection_scores (Tensor): [N,] confidence scores.
+ detections_3d (Tensor): [N, 7] detected boxes in 3D.
+ detection_scores_3d (Tensor): [N,] confidence scores in 3D.
+ detection_class_ids (Tensor): [N,] class indices.
+ detection_embeddings (Tensor): [N, C] appearance embeddings.
+ obs_velocities (Tensor | None): [N, 3] velocities of detections.
+ memory_boxes_3d (Tensor): [M, 7] boxes in memory.
+ memory_track_ids (Tensor): [M,] track ids in memory.
+ memory_class_ids (Tensor): [M,] class indices in memory.
+ memory_embeddings (Tensor): [M, C] appearance embeddings in
+ memory.
+ memory_boxes_3d_predict (Tensor): [M, 7] predicted boxes in
+ memory.
+ memory_velocities (Tensor): [M, 7] velocities in memory.
+
+ Returns:
+ tuple[Tensor, Tensor]: track ids of active tracks and selected
+ detection indices corresponding to tracks.
+ """
+ (
+ detections,
+ detection_scores,
+ detections_3d,
+ detection_scores_3d,
+ detection_class_ids,
+ detection_embeddings,
+ obs_velocities,
+ permute_inds,
+ ) = self._filter_detections(
+ detections,
+ camera_ids,
+ detection_scores,
+ detections_3d,
+ detection_scores_3d,
+ detection_class_ids,
+ detection_embeddings,
+ obs_velocities,
+ )
+
+ if with_depth_confidence:
+ depth_confidence = detection_scores_3d
+ else:
+ depth_confidence = detection_scores_3d.new_ones(
+ len(detection_scores_3d)
+ )
+
+ # match if buffer is not empty
+ if len(detections) > 0 and memory_boxes_3d is not None:
+ assert (
+ memory_track_ids is not None
+ and memory_class_ids is not None
+ and memory_embeddings is not None
+ and memory_boxes_3d_predict is not None
+ and memory_velocities is not None
+ )
+
+ # Box 3D
+ bbox3d_weight_list = []
+ for memory_box_3d_predict in memory_boxes_3d_predict:
+ bbox3d_weight_list.append(
+ F.pairwise_distance( # pylint: disable=not-callable
+ detections_3d,
+ memory_box_3d_predict,
+ keepdim=True,
+ )
+ )
+ bbox3d_weight = torch.cat(bbox3d_weight_list, dim=1)
+ scores_iou = torch.exp(-torch.div(bbox3d_weight, 10.0))
+
+ # Depth Ordering
+ scores_depth = self.depth_ordering(
+ detections_3d,
+ obs_velocities,
+ memory_boxes_3d_predict,
+ memory_boxes_3d,
+ memory_velocities,
+ )
+
+ # match using bisoftmax metric
+ similarity_scores = calc_bisoftmax_affinity(
+ detection_embeddings,
+ memory_embeddings,
+ detection_class_ids,
+ memory_class_ids,
+ )
+
+ if self.with_cats:
+ assert (
+ detection_class_ids is not None
+ and memory_class_ids is not None
+ ), "Please provide class ids if with_categories=True!"
+ cat_same = detection_class_ids.view(
+ -1, 1
+ ) == memory_class_ids.view(1, -1)
+ scores_cats = cat_same.float()
+
+ affinity_scores = (
+ self.bbox_affinity_weight * scores_iou * scores_depth
+ + self.feat_affinity_weight * similarity_scores
+ )
+ affinity_scores /= (
+ self.bbox_affinity_weight + self.feat_affinity_weight
+ )
+ affinity_scores = torch.mul(
+ affinity_scores, torch.greater(scores_iou, 0.0).float()
+ )
+ affinity_scores = torch.mul(
+ affinity_scores, torch.greater(scores_depth, 0.0).float()
+ )
+ if self.with_cats:
+ affinity_scores = torch.mul(affinity_scores, scores_cats)
+
+ ids = greedy_assign(
+ detection_scores * depth_confidence,
+ memory_track_ids,
+ affinity_scores,
+ self.match_score_thr,
+ self.obj_score_thr,
+ self.nms_conf_thr,
+ )
+ else:
+ ids = torch.full(
+ (len(detections),),
+ -1,
+ dtype=torch.long,
+ device=detections.device,
+ )
+ new_inds = (ids == -1) & (detection_scores > self.init_score_thr)
+ ids[new_inds] = TrackIDCounter.get_ids(
+ new_inds.sum(), device=ids.device # type: ignore
+ )
+ return ids, permute_inds
+
+
+def cam_to_global(
+ boxes_3d_list: list[Tensor], extrinsics: Tensor
+) -> list[Tensor]:
+ """Convert camera coordinates to global coordinates."""
+ for i, boxes_3d in enumerate(boxes_3d_list):
+ if len(boxes_3d) != 0:
+ boxes_3d_list[i][:, :3] = transform_points(
+ boxes_3d_list[i][:, :3], extrinsics[i]
+ )
+ boxes_3d_list[i][:, 6:9] = rotate_orientation(
+ boxes_3d_list[i][:, 6:9], extrinsics[i]
+ )
+ boxes_3d_list[i][:, 9:12] = rotate_velocities(
+ boxes_3d_list[i][:, 9:12], extrinsics[i]
+ )
+ return boxes_3d_list
diff --git a/vis4d/op/track3d/common.py b/vis4d/op/track3d/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..77a2d289a84a4110bf54e1e76f0a3d748611b086
--- /dev/null
+++ b/vis4d/op/track3d/common.py
@@ -0,0 +1,25 @@
+"""Common classes and functions for 3D tracking."""
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+from torch import Tensor
+
+
+class Track3DOut(NamedTuple):
+ """Output of track 3D model.
+
+ Attributes:
+ boxes_3d (list[Tensor]): List of bounding boxes (B, N, 10).
+ velocities (list[Tensor]): List of velocities (B, N, 3).
+ class_ids (list[Tensor]): List of class ids (B, N).
+ scores_3d (list[Tensor]): List of scores (B, N).
+ track_ids (list[Tensor]): List of track ids (B, N).
+ """
+
+ boxes_3d: list[Tensor]
+ velocities: list[Tensor]
+ class_ids: list[Tensor]
+ scores_3d: list[Tensor]
+ track_ids: list[Tensor]
diff --git a/vis4d/op/util.py b/vis4d/op/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb683e7516b1bdd16d4a58df4becf737a156cd48
--- /dev/null
+++ b/vis4d/op/util.py
@@ -0,0 +1,28 @@
+"""Utilities for op."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor
+
+
+def unmap(data: Tensor, count: int, inds: Tensor, fill: int = 0) -> Tensor:
+ """Unmap a subset of data back to the original data (of size count).
+
+ Args:
+ data (Tensor): Subset of the original data.
+ count (int): Length of the original data.
+ inds (Tensor): Indices of the subset entries in the original set.
+ fill (int, optional): Fill value for other entries. Defaults to 0.
+
+ Returns:
+ Tensor: Tensor sized like original data that contains the subset.
+ """
+ if data.dim() == 1:
+ ret = data.new_full((count,), fill)
+ ret[inds.type(torch.bool)] = data
+ else:
+ new_size = (count,) + data.size()[1:]
+ ret = data.new_full(new_size, fill)
+ ret[inds.type(torch.bool), :] = data
+ return ret
diff --git a/vis4d/state/__init__.py b/vis4d/state/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..034dac3be7266d41b216e2fa2781d09c75211038
--- /dev/null
+++ b/vis4d/state/__init__.py
@@ -0,0 +1 @@
+"""Memory and internal states needed for models."""
diff --git a/vis4d/state/track/__init__.py b/vis4d/state/track/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3d677fa8378e3cae1ed73fe5d51f9be0a0e1734
--- /dev/null
+++ b/vis4d/state/track/__init__.py
@@ -0,0 +1 @@
+"""Memory and state for tracking algorithms."""
diff --git a/vis4d/state/track/qdtrack.py b/vis4d/state/track/qdtrack.py
new file mode 100644
index 0000000000000000000000000000000000000000..316608954c022312312da242837857f32cae5ef8
--- /dev/null
+++ b/vis4d/state/track/qdtrack.py
@@ -0,0 +1,327 @@
+"""Memory for QDTrack inference."""
+
+from __future__ import annotations
+
+from typing import TypedDict
+
+import torch
+from torch import Tensor
+
+from vis4d.op.box.box2d import bbox_iou
+from vis4d.op.track.assignment import TrackIDCounter
+from vis4d.op.track.common import TrackOut
+from vis4d.op.track.qdtrack import QDTrackAssociation
+
+
+class Track(TypedDict):
+ """QDTrack Track state.
+
+ Attributes:
+ box (Tensor): In shape (4,) and contains x1, y1, x2, y2.
+ score (Tensor): In shape (1,).
+ class_id (Tensor): In shape (1,).
+ embedding (Tensor): In shape (E,). E is the embedding dimension.
+ last_frame (int): Last frame id.
+ """
+
+ box: Tensor
+ score: Tensor
+ class_id: Tensor
+ embed: Tensor
+ last_frame: int
+
+
+class QDTrackGraph:
+ """Quasi-dense embedding similarity based graph."""
+
+ def __init__(
+ self,
+ track: QDTrackAssociation | None = None,
+ memory_size: int = 10,
+ memory_momentum: float = 0.8,
+ nms_backdrop_iou_thr: float = 0.3,
+ backdrop_memory_size: int = 1,
+ ) -> None:
+ """Init."""
+ assert memory_size >= 0
+ self.memory_size = memory_size
+ assert 0 <= memory_momentum <= 1.0
+ self.memory_momentum = memory_momentum
+ assert backdrop_memory_size >= 0
+ self.backdrop_memory_size = backdrop_memory_size
+ self.nms_backdrop_iou_thr = nms_backdrop_iou_thr
+
+ self.tracker = QDTrackAssociation() if track is None else track
+
+ self.tracklets: dict[int, Track] = {}
+ self.backdrops: list[dict[str, Tensor]] = []
+
+ def reset(self) -> None:
+ """Empty the memory."""
+ self.tracklets.clear()
+ self.backdrops.clear()
+
+ def is_empty(self) -> bool:
+ """Check if the memory is empty."""
+ return len(self.tracklets) == 0
+
+ def get_tracks(
+ self,
+ device: torch.device,
+ frame_id: int | None = None,
+ add_backdrops: bool = False,
+ ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+ """Get tracklests.
+
+ If the frame_id is not provided, will return the latest state of all
+ tracklets. Otherwise, will return the state of all tracklets at the
+ given frame_id. If add_backdrops is True, will also return the
+ backdrops.
+
+ Args:
+ device (torch.device): Device to put the tensors on.
+ frame_id (int, optional): Frame id to query. Defaults to None.
+ add_backdrops (bool, optional): Whether to add backdrops to the
+ output. Defaults to False.
+
+ Returns:
+ boxes (Tensor): 2D boxes in shape (N, 4).
+ scores (Tensor): 2D scores in shape (N,).
+ class_ids (Tensor): Class ids in shape (N,).
+ track_ids (Tensor): Track ids in shape (N,).
+ embeddings (Tensor): Embeddings in shape (N, E).
+ """
+ (
+ boxes_list,
+ scores_list,
+ class_ids_list,
+ embeddings_list,
+ track_ids_list,
+ ) = ([], [], [], [], [])
+
+ for track_id, track in self.tracklets.items():
+ if frame_id is None or track["last_frame"] == frame_id:
+ boxes_list.append(track["box"].unsqueeze(0))
+ scores_list.append(track["score"].unsqueeze(0))
+ class_ids_list.append(track["class_id"].unsqueeze(0))
+ embeddings_list.append(track["embed"].unsqueeze(0))
+ track_ids_list.append(track_id)
+
+ boxes = (
+ torch.cat(boxes_list)
+ if len(boxes_list) > 0
+ else torch.empty((0, 4), device=device)
+ )
+ scores = (
+ torch.cat(scores_list)
+ if len(scores_list) > 0
+ else torch.empty((0,), device=device)
+ )
+ class_ids = (
+ torch.cat(class_ids_list)
+ if len(class_ids_list) > 0
+ else torch.empty((0,), device=device)
+ )
+ embeddings = (
+ torch.cat(embeddings_list)
+ if len(embeddings_list) > 0
+ else torch.empty((0,), device=device)
+ )
+ track_ids = torch.tensor(track_ids_list, device=device)
+
+ if add_backdrops:
+ for backdrop in self.backdrops:
+ backdrop_ids = torch.full(
+ (len(backdrop["embeddings"]),),
+ -1,
+ dtype=torch.long,
+ device=device,
+ )
+ track_ids = torch.cat([track_ids, backdrop_ids])
+ boxes = torch.cat([boxes, backdrop["boxes"]])
+ scores = torch.cat([scores, backdrop["scores"]])
+ class_ids = torch.cat([class_ids, backdrop["class_ids"]])
+ embeddings = torch.cat([embeddings, backdrop["embeddings"]])
+
+ return boxes, scores, class_ids, track_ids, embeddings
+
+ def __call__(
+ self,
+ embeddings_list: list[Tensor],
+ det_boxes_list: list[Tensor],
+ det_scores_list: list[Tensor],
+ class_ids_list: list[Tensor],
+ frame_id_list: list[int],
+ ) -> TrackOut:
+ """Forward during test."""
+ (
+ batched_boxes,
+ batched_scores,
+ batched_class_ids,
+ batched_track_ids,
+ ) = ([], [], [], [])
+
+ for frame_id, det_boxes, det_scores, class_ids, embeddings in zip(
+ frame_id_list,
+ det_boxes_list,
+ det_scores_list,
+ class_ids_list,
+ embeddings_list,
+ ):
+ # reset graph at begin of sequence
+ if frame_id == 0:
+ self.reset()
+ TrackIDCounter.reset()
+
+ if not self.is_empty():
+ (
+ _,
+ _,
+ memo_class_ids,
+ memo_track_ids,
+ memo_embeds,
+ ) = self.get_tracks(det_boxes.device, add_backdrops=True)
+ else:
+ memo_class_ids = None
+ memo_track_ids = None
+ memo_embeds = None
+
+ track_ids, filter_indices = self.tracker(
+ det_boxes,
+ det_scores,
+ class_ids,
+ embeddings,
+ memo_track_ids,
+ memo_class_ids,
+ memo_embeds,
+ )
+
+ self.update(
+ frame_id,
+ track_ids,
+ det_boxes[filter_indices],
+ det_scores[filter_indices],
+ class_ids[filter_indices],
+ embeddings[filter_indices],
+ )
+
+ (
+ boxes,
+ scores,
+ class_ids,
+ track_ids,
+ _,
+ ) = self.get_tracks(det_boxes.device, frame_id=frame_id)
+
+ batched_boxes.append(boxes)
+ batched_scores.append(scores)
+ batched_class_ids.append(class_ids)
+ batched_track_ids.append(track_ids)
+
+ return TrackOut(
+ boxes=batched_boxes,
+ class_ids=batched_class_ids,
+ scores=batched_scores,
+ track_ids=batched_track_ids,
+ )
+
+ def update(
+ self,
+ frame_id: int,
+ track_ids: Tensor,
+ boxes: Tensor,
+ scores: Tensor,
+ class_ids: Tensor,
+ embeddings: Tensor,
+ ) -> None:
+ """Update the track memory with a new state."""
+ valid_tracks = track_ids > -1
+
+ # update memo
+ for track_id, box, score, class_id, embed in zip(
+ track_ids[valid_tracks],
+ boxes[valid_tracks],
+ scores[valid_tracks],
+ class_ids[valid_tracks],
+ embeddings[valid_tracks],
+ ):
+ track_id = int(track_id)
+ if track_id in self.tracklets:
+ self.update_track(
+ track_id, box, score, class_id, embed, frame_id
+ )
+ else:
+ self.create_track(
+ track_id, box, score, class_id, embed, frame_id
+ )
+
+ # backdrops
+ backdrop_inds = torch.nonzero(
+ torch.eq(track_ids, -1), as_tuple=False
+ ).squeeze(1)
+
+ ious = bbox_iou(boxes[backdrop_inds], boxes)
+
+ for i, ind in enumerate(backdrop_inds):
+ if (ious[i, :ind] > self.nms_backdrop_iou_thr).any():
+ backdrop_inds[i] = -1
+ backdrop_inds = backdrop_inds[backdrop_inds > -1]
+
+ self.backdrops.insert(
+ 0,
+ {
+ "boxes": boxes[backdrop_inds],
+ "scores": scores[backdrop_inds],
+ "class_ids": class_ids[backdrop_inds],
+ "embeddings": embeddings[backdrop_inds],
+ },
+ )
+
+ # delete invalid tracks from memory
+ invalid_ids = []
+ for k, v in self.tracklets.items():
+ if frame_id - v["last_frame"] >= self.memory_size:
+ invalid_ids.append(k)
+ for invalid_id in invalid_ids:
+ self.tracklets.pop(invalid_id)
+
+ if len(self.backdrops) > self.backdrop_memory_size:
+ self.backdrops.pop()
+
+ def update_track(
+ self,
+ track_id: int,
+ box: Tensor,
+ score: Tensor,
+ class_id: Tensor,
+ embedding: Tensor,
+ frame_id: int,
+ ) -> None:
+ """Update a specific track with a new models."""
+ self.tracklets[track_id]["box"] = box
+ self.tracklets[track_id]["score"] = score
+ self.tracklets[track_id]["class_id"] = class_id
+ self.tracklets[track_id]["embed"] = (
+ 1 - self.memory_momentum
+ ) * self.tracklets[track_id][
+ "embed"
+ ] + self.memory_momentum * embedding
+ self.tracklets[track_id]["last_frame"] = frame_id
+
+ def create_track(
+ self,
+ track_id: int,
+ box: Tensor,
+ score: Tensor,
+ class_id: Tensor,
+ embedding: Tensor,
+ frame_id: int,
+ ) -> None:
+ """Create a new track from a models."""
+ self.tracklets[track_id] = Track(
+ box=box,
+ score=score,
+ class_id=class_id,
+ embed=embedding,
+ last_frame=frame_id,
+ )
diff --git a/vis4d/state/track3d/__init__.py b/vis4d/state/track3d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b145e08204552b43e29a0ad9058381c565f0453
--- /dev/null
+++ b/vis4d/state/track3d/__init__.py
@@ -0,0 +1 @@
+"""Memory and state for 3D tracking algorithms."""
diff --git a/vis4d/state/track3d/cc_3dt.py b/vis4d/state/track3d/cc_3dt.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d9928695a6037d8dd283640fd5bdf60c83911ac
--- /dev/null
+++ b/vis4d/state/track3d/cc_3dt.py
@@ -0,0 +1,577 @@
+"""Memory for CC-3DT inference."""
+
+from __future__ import annotations
+
+from typing import TypedDict
+
+import torch
+from torch import Tensor, nn
+
+from vis4d.common.typing import DictStrAny
+from vis4d.op.box.box2d import bbox_iou
+from vis4d.op.track3d.cc_3dt import CC3DTrackAssociation, get_track_3d_out
+from vis4d.op.track3d.common import Track3DOut
+from vis4d.op.track.assignment import TrackIDCounter
+
+from .motion import BaseMotionModel, KF3DMotionModel, LSTM3DMotionModel
+
+
+class Track(TypedDict):
+ """CC-3DT Track state.
+
+ Attributes:
+ box_2d (Tensor): In shape (4,) and contains x1, y1, x2, y2.
+ score_2d (Tensor): In shape (1,).
+ box_3d (Tensor): In shape (12,) contains x,y,z,h,w,l,rx,ry,rz,vx,vy,vz.
+ score_3d (Tensor): In shape (1,).
+ class_id (Tensor): In shape (1,).
+ embed (Tensor): In shape (E,). E is the embedding dimension.
+ motion_model (BaseMotionModel): The motion model.
+ velocity (Tensor): In shape (motion_dims,).
+ last_frame (int): The last frame the track was updated.
+ acc_frame (int): The number of frames the track was updated.
+ """
+
+ box_2d: Tensor
+ score_2d: Tensor
+ box_3d: Tensor
+ score_3d: Tensor
+ class_id: Tensor
+ embed: Tensor
+ motion_model: BaseMotionModel
+ velocity: Tensor
+ last_frame: int
+ acc_frame: int
+
+
+class CC3DTrackGraph:
+ """CC-3DT tracking graph."""
+
+ def __init__(
+ self,
+ track: CC3DTrackAssociation | None = None,
+ memory_size: int = 10,
+ memory_momentum: float = 0.8,
+ backdrop_memory_size: int = 1,
+ nms_backdrop_iou_thr: float = 0.3,
+ motion_model: str = "KF3D",
+ lstm_model: nn.Module | None = None,
+ motion_dims: int = 7,
+ num_frames: int = 5,
+ fps: int = 2,
+ update_3d_score: bool = True,
+ use_velocities: bool = False,
+ add_backdrops: bool = True,
+ ) -> None:
+ """Creates an instance of the class."""
+ assert memory_size >= 0
+ self.memory_size = memory_size
+ assert 0 <= memory_momentum <= 1.0
+ self.memory_momentum = memory_momentum
+ assert backdrop_memory_size >= 0
+ self.backdrop_memory_size = backdrop_memory_size
+ self.nms_backdrop_iou_thr = nms_backdrop_iou_thr
+
+ self.tracker = CC3DTrackAssociation() if track is None else track
+
+ self.tracklets: dict[int, Track] = {}
+ self.backdrops: list[DictStrAny] = []
+
+ if motion_model == "VeloLSTM":
+ assert (
+ lstm_model is not None
+ ), "lstm_model must be provided for VeloLSTM"
+ self.lstm_model = lstm_model
+
+ self.motion_model = motion_model
+ self.motion_dims = motion_dims
+ self.num_frames = num_frames
+ self.fps = fps
+ self.update_3d_score = update_3d_score
+ self.add_backdrops = add_backdrops
+ self.use_velocities = use_velocities
+
+ def reset(self) -> None:
+ """Empty the memory."""
+ self.tracklets.clear()
+ self.backdrops.clear()
+
+ def is_empty(self) -> bool:
+ """Check if the memory is empty."""
+ return len(self.tracklets) == 0
+
+ def get_tracks(
+ self,
+ device: torch.device,
+ frame_id: int | None = None,
+ add_backdrops: bool = False,
+ ) -> tuple[
+ Tensor,
+ Tensor,
+ Tensor,
+ Tensor,
+ Tensor,
+ Tensor,
+ Tensor,
+ list[BaseMotionModel],
+ Tensor,
+ ]:
+ """Get tracklests.
+
+ If the frame_id is not provided, will return the latest state of all
+ tracklets. Otherwise, will return the state of all tracklets at the
+ given frame_id. If add_backdrops is True, will also return the
+ backdrops.
+
+ Args:
+ device (torch.device): Device to put the tensors on.
+ frame_id (int, optional): Frame id to query. Defaults to None.
+ add_backdrops (bool, optional): Whether to add backdrops to the
+ output. Defaults to False.
+
+ Returns:
+ boxes_2d (Tensor): 2D boxes in shape (N, 4).
+ scores_2d (Tensor): 2D scores in shape (N,).
+ boxes_3d (Tensor): 3D boxes in shape (N, 12).
+ scores_3d (Tensor): 3D scores in shape (N,).
+ class_ids (Tensor): Class ids in shape (N,).
+ track_ids (Tensor): Track ids in shape (N,).
+ embeds (Tensor): Embeddings in shape (N, E).
+ motion_models (list[BaseMotionModel]): Motion models.
+ velocities (Tensor): Velocities in shape (N, 3).
+ """
+ (
+ boxes_2d_list,
+ scores_2d_list,
+ boxes_3d_list,
+ scores_3d_list,
+ class_ids_list,
+ embeds_list,
+ motion_models,
+ velocities_list,
+ track_ids_list,
+ ) = ([], [], [], [], [], [], [], [], [])
+
+ for track_id, track in self.tracklets.items():
+ if frame_id is None or track["last_frame"] == frame_id:
+ boxes_2d_list.append(track["box_2d"].unsqueeze(0))
+ scores_2d_list.append(track["score_2d"].unsqueeze(0))
+ boxes_3d_list.append(track["box_3d"].unsqueeze(0))
+ scores_3d_list.append(track["score_3d"].unsqueeze(0))
+ class_ids_list.append(track["class_id"].unsqueeze(0))
+ embeds_list.append(track["embed"].unsqueeze(0))
+ motion_models.append(track["motion_model"])
+ velocities_list.append(track["velocity"].unsqueeze(0))
+ track_ids_list.append(track_id)
+
+ boxes_2d = (
+ torch.cat(boxes_2d_list)
+ if len(boxes_2d_list) > 0
+ else torch.empty((0, 4), device=device)
+ )
+ scores_2d = (
+ torch.cat(scores_2d_list)
+ if len(scores_2d_list) > 0
+ else torch.empty((0,), device=device)
+ )
+ boxes_3d = (
+ torch.cat(boxes_3d_list)
+ if len(boxes_3d_list) > 0
+ else torch.empty((0, 12), device=device)
+ )
+ scores_3d = (
+ torch.cat(scores_3d_list)
+ if len(scores_3d_list) > 0
+ else torch.empty((0,), device=device)
+ )
+ class_ids = (
+ torch.cat(class_ids_list)
+ if len(class_ids_list) > 0
+ else torch.empty((0,), device=device)
+ )
+ embeds = (
+ torch.cat(embeds_list)
+ if len(embeds_list) > 0
+ else torch.empty((0,), device=device)
+ )
+ velocities = (
+ torch.cat(velocities_list)
+ if len(velocities_list) > 0
+ else torch.empty((0, self.motion_dims), device=device)
+ )
+ track_ids = torch.tensor(track_ids_list, device=device)
+
+ if add_backdrops:
+ for backdrop in self.backdrops:
+ backdrop_ids = torch.full(
+ (len(backdrop["embeddings"]),),
+ -1,
+ dtype=torch.long,
+ device=device,
+ )
+ track_ids = torch.cat([track_ids, backdrop_ids])
+ boxes_2d = torch.cat([boxes_2d, backdrop["boxes_2d"]])
+ scores_2d = torch.cat([scores_2d, backdrop["scores_2d"]])
+ boxes_3d = torch.cat([boxes_3d, backdrop["boxes_3d"]])
+ scores_3d = torch.cat([scores_3d, backdrop["scores_3d"]])
+ class_ids = torch.cat([class_ids, backdrop["class_ids"]])
+ embeds = torch.cat([embeds, backdrop["embeddings"]])
+ motion_models.extend(backdrop["motion_models"])
+ backdrop_vs = torch.zeros_like(
+ backdrop["boxes_3d"][:, : self.motion_dims]
+ )
+ velocities = torch.cat([velocities, backdrop_vs])
+
+ return (
+ boxes_2d,
+ scores_2d,
+ boxes_3d,
+ scores_3d,
+ class_ids,
+ track_ids,
+ embeds,
+ motion_models,
+ velocities,
+ )
+
+ def __call__(
+ self,
+ boxes_2d: Tensor,
+ scores_2d: Tensor,
+ camera_ids: Tensor,
+ boxes_3d: Tensor,
+ scores_3d: Tensor,
+ class_ids: Tensor,
+ embeddings: Tensor,
+ frame_id: int,
+ ) -> Track3DOut:
+ """Update the tracker with new detections."""
+ if frame_id == 0:
+ self.reset()
+ TrackIDCounter.reset()
+
+ if not self.is_empty():
+ (
+ _,
+ _,
+ memo_boxes_3d,
+ _,
+ memo_class_ids,
+ memo_track_ids,
+ memo_embeds,
+ memo_motion_models,
+ memo_velocities,
+ ) = self.get_tracks(
+ boxes_2d.device, add_backdrops=self.add_backdrops
+ )
+
+ memory_boxes_3d = torch.cat(
+ [memo_boxes_3d[:, :6], memo_boxes_3d[:, 8].unsqueeze(1)],
+ dim=1,
+ )
+
+ memory_track_ids = memo_track_ids
+ memory_class_ids = memo_class_ids
+ memory_embeddings = memo_embeds
+
+ memory_boxes_3d_predict = memory_boxes_3d.clone()
+ for i, memo_motion_model in enumerate(memo_motion_models):
+ pd_box_3d = memo_motion_model.predict(
+ update_state=memo_motion_model.age != 0
+ )
+ memory_boxes_3d_predict[i, :3] += pd_box_3d[self.motion_dims :]
+
+ memory_velocities = memo_velocities
+
+ else:
+ memory_boxes_3d = None
+ memory_track_ids = None
+ memory_class_ids = None
+ memory_embeddings = None
+ memory_boxes_3d_predict = None
+ memory_velocities = None
+
+ obs_velocities = boxes_3d[:, 9:]
+ obs_boxes_3d = torch.cat(
+ [boxes_3d[:, :6], boxes_3d[:, 8].unsqueeze(1)], dim=1
+ )
+
+ track_ids, filter_indices = self.tracker(
+ boxes_2d,
+ camera_ids,
+ scores_2d,
+ obs_boxes_3d,
+ scores_3d,
+ class_ids,
+ embeddings,
+ obs_velocities,
+ memory_boxes_3d,
+ memory_track_ids,
+ memory_class_ids,
+ memory_embeddings,
+ memory_boxes_3d_predict,
+ memory_velocities,
+ self.update_3d_score,
+ )
+
+ self.update(
+ frame_id,
+ track_ids,
+ boxes_2d[filter_indices],
+ scores_2d[filter_indices],
+ camera_ids[filter_indices],
+ boxes_3d[filter_indices],
+ scores_3d[filter_indices],
+ class_ids[filter_indices],
+ embeddings[filter_indices],
+ obs_boxes_3d[filter_indices],
+ )
+
+ (
+ _,
+ scores_2d,
+ boxes_3d,
+ scores_3d,
+ class_ids,
+ track_ids,
+ _,
+ _,
+ _,
+ ) = self.get_tracks(boxes_2d.device, frame_id=frame_id)
+
+ # update 3D score
+ if self.update_3d_score:
+ track_scores_3d = scores_2d * scores_3d
+ else:
+ track_scores_3d = scores_3d
+
+ return get_track_3d_out(
+ boxes_3d, class_ids, track_scores_3d, track_ids
+ )
+
+ def update(
+ self,
+ frame_id: int,
+ track_ids: Tensor,
+ boxes_2d: Tensor,
+ scores_2d: Tensor,
+ camera_ids: Tensor,
+ boxes_3d: Tensor,
+ scores_3d: Tensor,
+ class_ids: Tensor,
+ embeddings: Tensor,
+ obs_boxes_3d: Tensor,
+ ) -> None:
+ """Update the track memory with a new state."""
+ valid_tracks = track_ids > -1
+
+ # update memo
+ for (
+ track_id,
+ box_2d,
+ score_2d,
+ box_3d,
+ score_3d,
+ class_id,
+ embed,
+ obs_box_3d,
+ ) in zip(
+ track_ids[valid_tracks],
+ boxes_2d[valid_tracks],
+ scores_2d[valid_tracks],
+ boxes_3d[valid_tracks],
+ scores_3d[valid_tracks],
+ class_ids[valid_tracks],
+ embeddings[valid_tracks],
+ obs_boxes_3d[valid_tracks],
+ ):
+ track_id = int(track_id)
+ if track_id in self.tracklets:
+ self.update_track(
+ track_id,
+ box_2d,
+ score_2d,
+ box_3d,
+ score_3d,
+ class_id,
+ embed,
+ obs_box_3d,
+ frame_id,
+ )
+ else:
+ self.create_track(
+ track_id,
+ box_2d,
+ score_2d,
+ box_3d,
+ score_3d,
+ class_id,
+ embed,
+ obs_box_3d,
+ frame_id,
+ )
+
+ # Handle vanished tracklets
+ for track_id, track in self.tracklets.items():
+ if frame_id > track["last_frame"] and track_id > -1:
+ pd_box_3d = track["motion_model"].predict()
+ track["box_3d"][:6] = pd_box_3d[:6]
+ track["box_3d"][8] = pd_box_3d[6]
+
+ # Backdrops
+ backdrop_inds = torch.nonzero(
+ torch.eq(track_ids, -1), as_tuple=False
+ ).squeeze(1)
+
+ valid_ious = torch.eq(
+ camera_ids[backdrop_inds].unsqueeze(1),
+ camera_ids.unsqueeze(0),
+ ).int()
+ ious = bbox_iou(boxes_2d[backdrop_inds], boxes_2d)
+ ious *= valid_ious
+
+ for i, ind in enumerate(backdrop_inds):
+ if (ious[i, :ind] > self.nms_backdrop_iou_thr).any():
+ backdrop_inds[i] = -1
+ backdrop_inds = backdrop_inds[backdrop_inds > -1]
+
+ backdrop_motion_model = []
+ for bd_ind in backdrop_inds:
+ backdrop_motion_model.append(
+ self.build_motion_model(obs_boxes_3d[bd_ind])
+ )
+
+ self.backdrops.insert(
+ 0,
+ {
+ "boxes_2d": boxes_2d[backdrop_inds],
+ "scores_2d": scores_2d[backdrop_inds],
+ "boxes_3d": boxes_3d[backdrop_inds],
+ "scores_3d": scores_3d[backdrop_inds],
+ "class_ids": class_ids[backdrop_inds],
+ "embeddings": embeddings[backdrop_inds],
+ "motion_models": backdrop_motion_model,
+ },
+ )
+
+ # delete invalid tracks from memory
+ invalid_ids = []
+ for k, v in self.tracklets.items():
+ if frame_id - v["last_frame"] >= self.memory_size:
+ invalid_ids.append(k)
+ for invalid_id in invalid_ids:
+ self.tracklets.pop(invalid_id)
+
+ if len(self.backdrops) > self.backdrop_memory_size:
+ self.backdrops.pop()
+
+ def update_track(
+ self,
+ track_id: int,
+ box_2d: Tensor,
+ score_2d: Tensor,
+ box_3d: Tensor,
+ score_3d: Tensor,
+ class_id: Tensor,
+ embed: Tensor,
+ obs_box_3d: Tensor,
+ frame_id: int,
+ ) -> None:
+ """Update a track."""
+ self.tracklets[track_id]["box_2d"] = box_2d
+ self.tracklets[track_id]["score_2d"] = score_2d
+ self.tracklets[track_id]["motion_model"].update(obs_box_3d, score_3d)
+
+ pd_box_3d = self.tracklets[track_id]["motion_model"].get_state()[
+ : self.motion_dims
+ ]
+
+ prev_obs = torch.cat(
+ [
+ self.tracklets[track_id]["box_3d"][:6],
+ self.tracklets[track_id]["box_3d"][8].unsqueeze(0),
+ ]
+ )
+
+ self.tracklets[track_id]["box_3d"] = box_3d
+ self.tracklets[track_id]["box_3d"][:6] = pd_box_3d[:6]
+ self.tracklets[track_id]["box_3d"][8] = pd_box_3d[6]
+ self.tracklets[track_id]["box_3d"][9:12] = self.tracklets[track_id][
+ "motion_model"
+ ].predict_velocity()
+ self.tracklets[track_id]["score_3d"] = score_3d
+ self.tracklets[track_id]["class_id"] = class_id
+
+ self.tracklets[track_id]["embed"] = (
+ 1 - self.memory_momentum
+ ) * self.tracklets[track_id]["embed"] + self.memory_momentum * embed
+
+ velocity = (pd_box_3d - prev_obs) / (
+ frame_id - self.tracklets[track_id]["last_frame"]
+ )
+
+ self.tracklets[track_id]["velocity"] = (
+ self.tracklets[track_id]["velocity"]
+ * self.tracklets[track_id]["acc_frame"]
+ + velocity
+ ) / (self.tracklets[track_id]["acc_frame"] + 1)
+
+ # Use predicted velocity if available
+ if self.use_velocities:
+ self.tracklets[track_id]["velocity"][4:] = self.tracklets[
+ track_id
+ ]["box_3d"][9:12]
+
+ self.tracklets[track_id]["last_frame"] = frame_id
+ self.tracklets[track_id]["acc_frame"] += 1
+
+ def create_track(
+ self,
+ track_id: int,
+ box_2d: Tensor,
+ score_2d: Tensor,
+ box_3d: Tensor,
+ score_3d: Tensor,
+ class_id: Tensor,
+ embed: Tensor,
+ obs_box_3d: Tensor,
+ frame_id: int,
+ ) -> None:
+ """Create a new track."""
+ motion_model = self.build_motion_model(obs_box_3d)
+
+ self.tracklets[track_id] = Track(
+ box_2d=box_2d,
+ score_2d=score_2d,
+ box_3d=box_3d,
+ score_3d=score_3d,
+ class_id=class_id,
+ embed=embed,
+ motion_model=motion_model,
+ velocity=torch.zeros(self.motion_dims, device=box_3d.device),
+ last_frame=frame_id,
+ acc_frame=0,
+ )
+
+ def build_motion_model(self, obs_3d: Tensor) -> BaseMotionModel:
+ """Build motion model."""
+ if self.motion_model == "KF3D":
+ return KF3DMotionModel(
+ num_frames=self.num_frames,
+ obs_3d=obs_3d,
+ motion_dims=self.motion_dims,
+ fps=self.fps,
+ )
+
+ if self.motion_model == "VeloLSTM":
+ return LSTM3DMotionModel(
+ num_frames=self.num_frames,
+ lstm_model=self.lstm_model,
+ obs_3d=obs_3d,
+ motion_dims=self.motion_dims,
+ fps=self.fps,
+ )
+
+ raise NotImplementedError(
+ f"Motion model: {self.motion_model} not known!"
+ )
diff --git a/vis4d/state/track3d/motion/__init__.py b/vis4d/state/track3d/motion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..585785bbed43e6fb75a904c39e13760de71a3ef1
--- /dev/null
+++ b/vis4d/state/track3d/motion/__init__.py
@@ -0,0 +1,7 @@
+"""3D Motional Models."""
+
+from .base import BaseMotionModel
+from .kf3d import KF3DMotionModel
+from .lstm_3d import LSTM3DMotionModel
+
+__all__ = ["BaseMotionModel", "KF3DMotionModel", "LSTM3DMotionModel"]
diff --git a/vis4d/state/track3d/motion/base.py b/vis4d/state/track3d/motion/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..c83fa6357d9ca977a8e44d76a7655ec0dfb79676
--- /dev/null
+++ b/vis4d/state/track3d/motion/base.py
@@ -0,0 +1,50 @@
+"""Motion model base class."""
+
+from torch import Tensor
+
+
+class BaseMotionModel:
+ """Base class for motion model."""
+
+ def __init__(
+ self,
+ num_frames: int,
+ motion_dims: int,
+ hits: int = 1,
+ hit_streak: int = 0,
+ time_since_update: int = 0,
+ age: int = 0,
+ fps: int = 1,
+ ) -> None:
+ """Creates an instance of the class."""
+ self.num_frames = num_frames
+ self.motion_dims = motion_dims
+ self.hits = hits
+ self.hit_streak = hit_streak
+ self.time_since_update = time_since_update
+ self.age = age
+ self.fps = fps
+
+ def update(self, obs_3d: Tensor, info: Tensor) -> None:
+ """Update the state."""
+ raise NotImplementedError()
+
+ def predict_velocity(self) -> Tensor:
+ """Predict velocity."""
+ raise NotImplementedError()
+
+ def predict(self, update_state: bool = True) -> Tensor:
+ """Predict the state."""
+ raise NotImplementedError()
+
+ def get_state(self) -> Tensor:
+ """Get the state."""
+ raise NotImplementedError()
+
+
+def update_array(origin_array: Tensor, input_array: Tensor) -> Tensor:
+ """Update array according the input."""
+ new_array = origin_array.clone()
+ new_array[:-1] = origin_array[1:]
+ new_array[-1:] = input_array
+ return new_array
diff --git a/vis4d/state/track3d/motion/kf3d.py b/vis4d/state/track3d/motion/kf3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..c102cba4df9069d401d262b833e5f760e16017eb
--- /dev/null
+++ b/vis4d/state/track3d/motion/kf3d.py
@@ -0,0 +1,133 @@
+"""Kalman Filter 3D motion model."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor
+
+from vis4d.common.typing import ArgsType
+from vis4d.op.geometry.rotation import acute_angle, normalize_angle
+from vis4d.op.motion.kalman_filter import predict, update
+
+from .base import BaseMotionModel
+
+
+class KF3DMotionModel(BaseMotionModel):
+ """Kalman filter 3D motion model."""
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ obs_3d: Tensor,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__(*args, **kwargs)
+ self.device = obs_3d.device
+
+ # F, H, Q, R
+ (
+ self._motion_mat,
+ self._update_mat,
+ self._cov_motion_q,
+ self._cov_project_r,
+ ) = self._kf3d_init()
+
+ self._motion_mat = self._motion_mat.to(self.device)
+ self._update_mat = self._update_mat.to(self.device)
+ self._cov_motion_q = self._cov_motion_q.to(self.device)
+ self._cov_project_r = self._cov_project_r.to(self.device)
+
+ self.mean, self.covariance = self._init_mean_cov(obs_3d)
+
+ def _kf3d_init(self) -> tuple[Tensor, Tensor, Tensor, Tensor]:
+ """KF3D init function."""
+ motion_mat = torch.Tensor(
+ [
+ [1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0, 0, 1, 0],
+ [0, 0, 1, 0, 0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
+ ]
+ )
+
+ update_mat = torch.Tensor(
+ [
+ [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
+ ]
+ )
+
+ cov_motion_q = torch.eye(self.motion_dims + 3)
+ cov_motion_q[self.motion_dims :, self.motion_dims :] *= 0.01
+
+ cov_project_r = torch.eye(self.motion_dims)
+ return motion_mat, update_mat, cov_motion_q, cov_project_r
+
+ def _init_mean_cov(self, obs_3d: Tensor) -> tuple[Tensor, Tensor]:
+ """Init KF3D mean and covariance."""
+ mean = torch.zeros(self.motion_dims + 3).to(obs_3d.device)
+ mean[: self.motion_dims] = obs_3d
+ covariance = torch.eye(self.motion_dims + 3).to(obs_3d.device) * 10.0
+ covariance[self.motion_dims :, self.motion_dims :] *= 1000.0
+ return mean, covariance
+
+ def update(self, obs_3d: Tensor, info: Tensor) -> None:
+ """Update the state."""
+ self.time_since_update = 0
+ self.hits += 1
+ self.hit_streak += 1
+
+ self.mean[6] = normalize_angle(self.mean[6])
+ obs_3d[6] = normalize_angle(obs_3d[6])
+
+ self.mean[6] = acute_angle(self.mean[6], obs_3d[6])
+
+ self.mean, self.covariance = update(
+ self._update_mat,
+ self._cov_project_r,
+ self.mean,
+ self.covariance,
+ obs_3d,
+ )
+ self.mean[6] = normalize_angle(self.mean[6])
+
+ def predict_velocity(self) -> Tensor:
+ """Predict velocity."""
+ pred_loc, _ = predict(
+ self._motion_mat,
+ self._cov_motion_q,
+ self.mean,
+ self.covariance,
+ )
+ return (pred_loc[:3] - self.mean[:3]) * self.fps
+
+ def predict(self, update_state: bool = True) -> Tensor:
+ """Predict the state."""
+ self.mean, self.covariance = predict(
+ self._motion_mat, self._cov_motion_q, self.mean, self.covariance
+ )
+
+ self.mean[6] = normalize_angle(self.mean[6])
+
+ self.age += 1
+ if self.time_since_update > 0:
+ self.hit_streak = 0
+ self.time_since_update += 1
+
+ return self.mean
+
+ def get_state(self) -> Tensor:
+ """Returns the current bounding box estimate."""
+ return self.mean
diff --git a/vis4d/state/track3d/motion/lstm_3d.py b/vis4d/state/track3d/motion/lstm_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..d943e7c1c714701817beedde32ed52ceccd4d297
--- /dev/null
+++ b/vis4d/state/track3d/motion/lstm_3d.py
@@ -0,0 +1,149 @@
+"""LSTM 3D motion model."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor, nn
+
+from vis4d.common.typing import ArgsType
+from vis4d.model.motion.velo_lstm import VeloLSTM
+from vis4d.op.geometry.rotation import acute_angle, normalize_angle
+
+from .base import BaseMotionModel, update_array
+
+
+class LSTM3DMotionModel(BaseMotionModel):
+ """LSTM 3D motion model."""
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ lstm_model: nn.Module,
+ obs_3d: Tensor,
+ init_flag: bool = True,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Initialize a motion model using initial bounding box."""
+ super().__init__(*args, **kwargs)
+ self.init_flag = init_flag
+ self.device = obs_3d.device
+
+ assert isinstance(
+ lstm_model, VeloLSTM
+ ), "Currently only support VeloLSTM motion model!"
+ self.lstm_model = lstm_model
+ self.lstm_model.to(self.device)
+ self.lstm_model.eval()
+
+ self.obj_state = torch.cat([obs_3d, obs_3d.new_zeros(3)])
+ self.history = obs_3d.new_zeros(self.num_frames, self.motion_dims)
+ self.ref_history = torch.cat(
+ [obs_3d.view(1, self.motion_dims)] * (self.num_frames + 1)
+ )
+ self.prev_ref = obs_3d.clone()
+ self.hidden_pred = self.lstm_model.init_hidden(
+ self.device, batch_size=1
+ )
+ self.hidden_ref = self.lstm_model.init_hidden(
+ self.device, batch_size=1
+ )
+
+ def _update_history(self, bbox_3d: Tensor) -> None:
+ """Update velocity history."""
+ self.ref_history = update_array(self.ref_history, bbox_3d)
+ self.history = update_array(
+ self.history, self.ref_history[-1] - self.ref_history[-2]
+ )
+ self.prev_ref[: self.motion_dims] = self.obj_state[: self.motion_dims]
+
+ def _init_history(self, bbox_3d: Tensor) -> None:
+ """Initialize velocity history."""
+ self.ref_history = update_array(self.ref_history, bbox_3d)
+ self.history = torch.cat(
+ [
+ (self.ref_history[-1] - self.ref_history[-2]).view(
+ 1, self.motion_dims
+ )
+ ]
+ * self.num_frames
+ )
+ self.prev_ref[: self.motion_dims] = self.obj_state[: self.motion_dims]
+
+ def update(self, obs_3d: Tensor, info: Tensor) -> None:
+ """Updates the state vector with observed bbox."""
+ self.time_since_update = 0
+ self.hits += 1
+ self.hit_streak += 1
+
+ if self.age == 1:
+ self.obj_state[: self.motion_dims] = obs_3d.clone()
+
+ self.obj_state[6] = normalize_angle(self.obj_state[6])
+ obs_3d[6] = normalize_angle(obs_3d[6])
+
+ # acute angle
+ self.obj_state[6] = acute_angle(self.obj_state[6], obs_3d[6])
+
+ with torch.no_grad():
+ refined_loc, self.hidden_ref = self.lstm_model.refine(
+ self.obj_state[: self.motion_dims].unsqueeze(0),
+ obs_3d.unsqueeze(0),
+ self.prev_ref.unsqueeze(0),
+ info.unsqueeze(0).unsqueeze(0),
+ self.hidden_ref,
+ )
+
+ refined_obj = refined_loc.view(self.motion_dims)
+ refined_obj[6] = normalize_angle(refined_obj[6])
+
+ self.obj_state[: self.motion_dims] = refined_obj
+
+ if self.init_flag:
+ self._init_history(refined_obj)
+ self.init_flag = False
+ else:
+ self._update_history(refined_obj)
+
+ def predict_velocity(self) -> Tensor:
+ """Predict velocity."""
+ with torch.no_grad():
+ pred_loc, _ = self.lstm_model.predict(
+ self.history[..., : self.motion_dims].view(
+ self.num_frames, -1, self.motion_dims
+ ),
+ self.obj_state[: self.motion_dims],
+ self.hidden_pred,
+ )
+ return (pred_loc[0][:3] - self.prev_ref[:3]) * self.fps
+
+ def predict(self, update_state: bool = True) -> Tensor:
+ """Advances the state vector and returns the predicted bounding box."""
+ with torch.no_grad():
+ pred_loc, hidden_pred = self.lstm_model.predict(
+ self.history[..., : self.motion_dims].view(
+ self.num_frames, -1, self.motion_dims
+ ),
+ self.obj_state[: self.motion_dims],
+ self.hidden_pred,
+ )
+
+ pred_state = self.obj_state.clone()
+ pred_state[: self.motion_dims] = pred_loc.view(self.motion_dims)
+ pred_state[self.motion_dims :] = pred_state[:3] - self.prev_ref[:3]
+
+ pred_state[6] = normalize_angle(pred_state[6])
+
+ if update_state:
+ self.hidden_pred = hidden_pred
+ self.obj_state = pred_state
+
+ self.age += 1
+ if self.time_since_update > 0:
+ self.hit_streak = 0
+ self.time_since_update += 1
+
+ return pred_state
+
+ def get_state(self) -> Tensor:
+ """Returns the current bounding box estimate."""
+ return self.obj_state
diff --git a/vis4d/vis/__init__.py b/vis4d/vis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f08cdc210651555f0038f2f178dc863c06e047df
--- /dev/null
+++ b/vis4d/vis/__init__.py
@@ -0,0 +1 @@
+"""Contains visualization tools for a variety of data types."""
diff --git a/vis4d/vis/__pycache__/__init__.cpython-311.pyc b/vis4d/vis/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..838c5141e849e7b06b452cb976958e6cf8a39473
Binary files /dev/null and b/vis4d/vis/__pycache__/__init__.cpython-311.pyc differ
diff --git a/vis4d/vis/__pycache__/util.cpython-311.pyc b/vis4d/vis/__pycache__/util.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c49c968e4486f9a0572cb093c44c6d619f70c2f
Binary files /dev/null and b/vis4d/vis/__pycache__/util.cpython-311.pyc differ
diff --git a/vis4d/vis/base.py b/vis4d/vis/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..1985f00435521160ec52aa377aa2f5385fe0ea96
--- /dev/null
+++ b/vis4d/vis/base.py
@@ -0,0 +1,53 @@
+"""Visualizer base class."""
+
+from vis4d.common.typing import ArgsType
+
+
+class Visualizer:
+ """Base visualizer class."""
+
+ def __init__(self, vis_freq: int = 50, image_mode: str = "RGB") -> None:
+ """Initialize the visualizer.
+
+ Args:
+ vis_freq (int): Visualization frequency. Defaults to 0.
+ image_mode (str): Image channel mode (RGB or BGR).
+ """
+ self.vis_freq = vis_freq
+ self.image_mode = image_mode
+ assert image_mode in {"RGB", "BGR"}
+
+ def _run_on_batch(self, cur_iter: int) -> bool:
+ """Return whether to run on current iteration.
+
+ Args:
+ cur_iter (int): Current iteration.
+ """
+ return cur_iter % self.vis_freq == 0
+
+ def reset(self) -> None:
+ """Reset visualizer for new round of evaluation."""
+ raise NotImplementedError()
+
+ def process(self, cur_iter: int, *args: ArgsType) -> None:
+ """Process data of single sample."""
+ raise NotImplementedError()
+
+ def show(self, cur_iter: int, blocking: bool = True) -> None:
+ """Shows the visualization.
+
+ Args:
+ cur_iter (int): Current iteration.
+ blocking (bool): If the visualization should be blocking and wait
+ for human input. Defaults to True.
+ """
+ raise NotImplementedError()
+
+ def save_to_disk(self, cur_iter: int, output_folder: str) -> None:
+ """Saves the visualization to disk.
+
+ Args:
+ cur_iter (int): Current iteration.
+ output_folder (str): Folder where the output should be written.
+ """
+ raise NotImplementedError()
diff --git a/vis4d/vis/image/__init__.py b/vis4d/vis/image/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7525db917d3bd5b8bd108a354f9fe215a7480d34
--- /dev/null
+++ b/vis4d/vis/image/__init__.py
@@ -0,0 +1,6 @@
+"""Image Visualization."""
+
+from .bounding_box_visualizer import BoundingBoxVisualizer
+from .seg_mask_visualizer import SegMaskVisualizer
+
+__all__ = ["BoundingBoxVisualizer", "SegMaskVisualizer"]
diff --git a/vis4d/vis/image/bbox3d_visualizer.py b/vis4d/vis/image/bbox3d_visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f336f48bbb91c7288f4588c5139902de1bab943
--- /dev/null
+++ b/vis4d/vis/image/bbox3d_visualizer.py
@@ -0,0 +1,465 @@
+"""Bounding box 3D visualizer."""
+
+from __future__ import annotations
+
+import os
+from collections import defaultdict
+from collections.abc import Sequence
+from dataclasses import dataclass
+
+import numpy as np
+import torch
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.typing import (
+ ArgsType,
+ ArrayLike,
+ ArrayLikeFloat,
+ ArrayLikeInt,
+ NDArrayF32,
+ NDArrayUI8,
+)
+from vis4d.data.const import AxisMode
+from vis4d.op.geometry.transform import inverse_rigid_transform
+from vis4d.vis.base import Visualizer
+from vis4d.vis.util import generate_color_map
+
+from .canvas import CanvasBackend, PillowCanvasBackend
+from .util import preprocess_boxes3d, preprocess_image, project_point
+from .viewer import ImageViewerBackend, MatplotlibImageViewer
+
+
+@dataclass
+class DetectionBox3D:
+ """Dataclass storing box informations."""
+
+ corners: list[tuple[float, float, float]]
+ label: str
+ color: tuple[int, int, int]
+ track_id: int | None
+
+
+@dataclass
+class DataSample:
+ """Dataclass storing a data sample that can be visualized."""
+
+ image: NDArrayUI8
+ image_name: str
+ intrinsics: NDArrayF32
+ extrinsics: NDArrayF32 | None
+ sequence_name: str | None
+ camera_name: str | None
+ boxes: list[DetectionBox3D]
+
+
+class BoundingBox3DVisualizer(Visualizer):
+ """Bounding box 3D visualizer class."""
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ n_colors: int = 50,
+ cat_mapping: dict[str, int] | None = None,
+ file_type: str = "png",
+ image_mode: str = "RGB",
+ width: int = 2,
+ camera_near_clip: float = 0.15,
+ plot_heading: bool = True,
+ axis_mode: AxisMode = AxisMode.ROS,
+ trajectory_length: int = 10,
+ plot_trajectory: bool = True,
+ save_boxes3d: bool = False,
+ canvas: CanvasBackend | None = None,
+ viewer: ImageViewerBackend | None = None,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates a new Visualizer for Image and 3D Bounding Boxes.
+
+ Args:
+ n_colors (int): How many colors should be used for the internal
+ color map. Defaults to 100.
+ cat_mapping (dict[str, int]): Mapping from class names to class
+ ids. Defaults to None.
+ file_type (str): Desired file type. Defaults to "png".
+ image_mode (str): Image channel mode (RGB or BGR). Defaults to
+ "RGB".
+ width (int): Width of the drawn bounding boxes. Defaults to 2.
+ camera_near_clip (float): Near clipping plane of the camera.
+ Defaults to 0.15.
+ plot_heading (bool): If the heading should be plotted. Defaults to
+ True.
+ axis_mode (AxisMode): Axis mode for the input bboxes. Defaults to
+ AxisMode.ROS (i.e. global coordinate).
+ trajectory_length (int): How many past frames should be used to
+ draw the trajectory. Defaults to 10.
+ plot_trajectory (bool): If the trajectory should be plotted.
+ Defaults to True.
+ save_boxes3d (bool): If the corners of 3D boxes should be saved to
+ disk in the format of npy. Defaults to False.
+ canvas (CanvasBackend): Backend that is used to draw on images. If
+ None a PillowCanvasBackend is used.
+ viewer (ImageViewerBackend): Backend that is used show images. If
+ None a MatplotlibImageViewer is used.
+ """
+ super().__init__(*args, **kwargs)
+ self._samples: list[DataSample] = []
+ self.axis_mode = axis_mode
+ self.trajectories: dict[int, list[tuple[float, float, float]]] = (
+ defaultdict(list)
+ )
+ self.trajectory_length = trajectory_length
+ self.plot_trajectory = plot_trajectory
+
+ self.color_palette = generate_color_map(n_colors)
+
+ self.class_id_mapping = (
+ {v: k for k, v in cat_mapping.items()}
+ if cat_mapping is not None
+ else {}
+ )
+
+ self.file_type = file_type
+ self.image_mode = image_mode
+ self.width = width
+
+ self.camera_near_clip = camera_near_clip
+ self.plot_heading = plot_heading
+ self.save_boxes3d = save_boxes3d
+
+ self.canvas = canvas if canvas is not None else PillowCanvasBackend()
+ self.viewer = viewer if viewer is not None else MatplotlibImageViewer()
+
+ def reset(self) -> None:
+ """Reset visualizer."""
+ self._samples.clear()
+
+ def __repr__(self) -> str:
+ """Return string representation."""
+ return "BoundingBox3DVisualizer"
+
+ def process( # pylint: disable=arguments-differ
+ self,
+ cur_iter: int,
+ images: list[ArrayLike],
+ image_names: list[str],
+ boxes3d: list[ArrayLikeFloat],
+ intrinsics: ArrayLikeFloat,
+ extrinsics: None | ArrayLikeFloat = None,
+ scores: None | list[ArrayLikeFloat] = None,
+ class_ids: None | list[ArrayLikeInt] = None,
+ track_ids: None | list[ArrayLikeInt] = None,
+ sequence_names: None | list[str] = None,
+ categories: None | list[list[str]] = None,
+ ) -> None:
+ """Processes a batch of data.
+
+ Args:
+ cur_iter (int): Current iteration.
+ images (list[ArrayLike]): Images to show.
+ image_names (list[str]): Image names.
+ boxes3d (list[ArrayLikeFloat]): List of predicted bounding boxes
+ with shape [B, N, 10].
+ intrinsics (ArrayLikeFloat): Camera intrinsics with shape
+ [B, 3, 3].
+ extrinsics (None | ArrayLikeFloat, optional): Camera extrinsics
+ with shape [B, 4, 4]. Defaults to None.
+ scores (None | list[ArrayLikeFloat], optional): List of predicted
+ box scores each of shape [B, N]. Defaults to None.
+ class_ids (None | list[ArrayLikeInt], optional): List of predicted
+ class ids each of shape [B, N]. Defaults to None.
+ track_ids (None | list[ArrayLikeInt], optional): List of predicted
+ track ids each of shape [B, N]. Defaults to None.
+ sequence_names (None | list[str], optional): List of sequence
+ names of shape [B,]. Defaults to None.
+ categories (None | list[list[str]], optional): List of categories
+ for each image. Instead of class ids, the categories will be
+ used to label the boxes. Defaults to None.
+ """
+ if self._run_on_batch(cur_iter):
+ for batch, image in enumerate(images):
+ self.process_single_image(
+ image,
+ image_names[batch],
+ boxes3d[batch],
+ intrinsics[batch], # type: ignore
+ (
+ None if extrinsics is None else extrinsics[batch] # type: ignore # pylint: disable=line-too-long
+ ),
+ None if scores is None else scores[batch],
+ None if class_ids is None else class_ids[batch],
+ None if track_ids is None else track_ids[batch],
+ None if sequence_names is None else sequence_names[batch],
+ None if categories is None else categories[batch],
+ )
+
+ for tid in self.trajectories:
+ if len(self.trajectories[tid]) > self.trajectory_length:
+ self.trajectories[tid].pop(0)
+
+ def process_single_image(
+ self,
+ image: ArrayLike,
+ image_name: str,
+ boxes3d: ArrayLikeFloat,
+ intrinsics: ArrayLikeFloat,
+ extrinsics: None | ArrayLikeFloat = None,
+ scores: None | ArrayLikeFloat = None,
+ class_ids: None | ArrayLikeInt = None,
+ track_ids: None | ArrayLikeInt = None,
+ sequence_name: None | str = None,
+ categories: None | list[str] = None,
+ camera_name: None | str = None,
+ ) -> None:
+ """Processes a single image entry.
+
+ Args:
+ image (ArrayLike): Image to show.
+ image_name (str): Image name.
+ boxes3d (ArrayLikeFloat): Predicted bounding boxes with shape
+ [N, 10], where N is the number of boxes.
+ intrinsics (ArrayLikeFloat): Camera intrinsics with shape [3, 3].
+ extrinsics (None | ArrayLikeFloat, optional): Camera extrinsics
+ with shape [4, 4]. Defaults to None.
+ scores (None | ArrayLikeFloat, optional): Predicted box scores of
+ shape [N]. Defaults to None.
+ class_ids (None | ArrayLikeInt, optional): Predicted class ids of
+ shape [N]. Defaults to None.
+ track_ids (None | ArrayLikeInt, optional): Predicted track ids of
+ shape [N]. Defaults to None.
+ sequence_name (None | str, optional): Sequence name. Defaults to
+ None.
+ categories (None | list[str], optional): List of categories for
+ each box. Instead of class ids, the categories will be used to
+ label the boxes. Defaults to None.
+ camera_name (None | str, optional): Camera name. Defaults to None.
+ """
+ img_normalized = preprocess_image(image, mode=self.image_mode)
+ image_hw = (img_normalized.shape[0], img_normalized.shape[1])
+
+ intrinsics_np = array_to_numpy(intrinsics, n_dims=2, dtype=np.float32)
+ extrinsics_np = (
+ array_to_numpy(extrinsics, n_dims=2, dtype=np.float32)
+ if extrinsics is not None
+ else None
+ )
+ data_sample = DataSample(
+ img_normalized,
+ image_name,
+ intrinsics_np,
+ extrinsics_np,
+ sequence_name,
+ camera_name,
+ [],
+ )
+
+ if len(boxes3d) != 0: # type: ignore
+ for center, corners, label, color, track_id in zip(
+ *preprocess_boxes3d(
+ image_hw,
+ boxes3d,
+ intrinsics,
+ extrinsics,
+ scores,
+ class_ids,
+ track_ids,
+ self.color_palette,
+ self.class_id_mapping,
+ axis_mode=self.axis_mode,
+ categories=categories,
+ )
+ ):
+ data_sample.boxes.append(
+ DetectionBox3D(
+ corners=corners,
+ label=label,
+ color=color,
+ track_id=track_id,
+ )
+ )
+ if track_id is not None:
+ self.trajectories[track_id].append(center)
+
+ self._samples.append(data_sample)
+
+ def show(self, cur_iter: int, blocking: bool = True) -> None:
+ """Shows the processed images in a interactive window.
+
+ Args:
+ cur_iter (int): Current iteration.
+ blocking (bool): If the visualizer should be blocking i.e. wait for
+ human input for each image. Defaults to True.
+ """
+ if self._run_on_batch(cur_iter):
+ image_data = [self._draw_image(d) for d in self._samples]
+ self.viewer.show_images(image_data, blocking=blocking)
+
+ def _draw_image(self, sample: DataSample) -> NDArrayUI8:
+ """Visualizes the datasample and returns is as numpy image.
+
+ Args:
+ sample (DataSample): The data sample to visualize.
+
+ Returns:
+ NDArrayUI8: A image with the visualized data sample.
+ """
+ self.canvas.create_canvas(sample.image)
+
+ if self.plot_trajectory:
+ assert (
+ sample.extrinsics is not None
+ ), "Extrinsics is needed to plot trajectory."
+ global_to_cam = inverse_rigid_transform(
+ torch.from_numpy(sample.extrinsics)
+ ).numpy()
+
+ for box in sample.boxes:
+ self.canvas.draw_box_3d(
+ box.corners,
+ box.color,
+ sample.intrinsics,
+ self.width,
+ self.camera_near_clip,
+ self.plot_heading,
+ )
+
+ selected_corner = project_point(box.corners[0], sample.intrinsics)
+ self.canvas.draw_text(
+ (selected_corner[0], selected_corner[1]), box.label, box.color
+ )
+
+ if self.plot_trajectory:
+ assert (
+ box.track_id is not None
+ ), "track id must be set to plot trajectory."
+
+ trajectory = self.trajectories[box.track_id]
+ for center in trajectory:
+ # Move global center to current camera frame
+ center_cam = np.dot(global_to_cam, [*center, 1])[:3]
+
+ if center_cam[2] > 0:
+ projected_center = project_point(
+ center_cam, sample.intrinsics
+ )
+ self.canvas.draw_circle(
+ projected_center, box.color, self.width * 2
+ )
+
+ return self.canvas.as_numpy_image()
+
+ def save_to_disk(self, cur_iter: int, output_folder: str) -> None:
+ """Saves the visualization to disk.
+
+ Writes all processes samples to the output folder naming each image
+ ..
+
+ Args:
+ cur_iter (int): Current iteration.
+ output_folder (str): Folder where the output should be written.
+ """
+ if self._run_on_batch(cur_iter):
+ for sample in self._samples:
+ output_dir = output_folder
+ image_name = f"{sample.image_name}.{self.file_type}"
+
+ self._draw_image(sample)
+
+ if sample.sequence_name is not None:
+ output_dir = os.path.join(output_dir, sample.sequence_name)
+
+ if sample.camera_name is not None:
+ output_dir = os.path.join(output_dir, sample.camera_name)
+
+ os.makedirs(output_dir, exist_ok=True)
+ self.canvas.save_to_disk(os.path.join(output_dir, image_name))
+
+ if self.save_boxes3d:
+ corners = np.array([box.corners for box in sample.boxes])
+
+ np.save(
+ os.path.join(output_dir, f"{sample.image_name}.npy"),
+ corners,
+ )
+
+
+class MultiCameraBBox3DVisualizer(BoundingBox3DVisualizer):
+ """Bounding box 3D visualizer class for multi-camera datasets."""
+
+ def __init__(
+ self, *args: ArgsType, cameras: Sequence[str], **kwargs: ArgsType
+ ) -> None:
+ """Creates a new Visualizer for Image and 3D Bounding Boxes.
+
+ Args:
+ cameras (Sequence[str]): Camera names.
+ """
+ super().__init__(*args, **kwargs)
+
+ self.cameras = cameras
+
+ def __repr__(self) -> str:
+ """Return string representation."""
+ return "MultiCameraBBox3DVisualizer"
+
+ def process( # type: ignore # pylint: disable=arguments-differ
+ self,
+ cur_iter: int,
+ images: list[list[ArrayLike]],
+ image_names: list[list[str]],
+ boxes3d: list[ArrayLikeFloat],
+ intrinsics: list[ArrayLikeFloat],
+ extrinsics: list[ArrayLikeFloat] | None = None,
+ scores: list[ArrayLikeFloat] | None = None,
+ class_ids: list[ArrayLikeInt] | None = None,
+ track_ids: list[ArrayLikeInt] | None = None,
+ sequence_names: list[str] | None = None,
+ categories: None | list[list[str]] = None,
+ ) -> None:
+ """Processes a batch of data.
+
+ Args:
+ cur_iter (int): Current iteration.
+ images (list[ArrayLike]): Images to show.
+ image_names (list[str]): Image names.
+ boxes3d (list[ArrayLikeFloat]): List of predicted bounding boxes
+ with shape [B, N, 10].
+ intrinsics (ArrayLikeFloat): Camera intrinsics with shape
+ [num_cam, B, 3, 3].
+ extrinsics (None | ArrayLikeFloat, optional): Camera extrinsics
+ with shape [num_cam, B, 4, 4]. Defaults to None.
+ scores (None | list[ArrayLikeFloat], optional): List of predicted
+ box scores each of shape [B, N]. Defaults to None.
+ class_ids (None | list[ArrayLikeInt], optional): List of predicted
+ class ids each of shape [B, N]. Defaults to None.
+ track_ids (None | list[ArrayLikeInt], optional): List of predicted
+ track ids each of shape [B, N]. Defaults to None.
+ sequence_names (None | list[str], optional): List of sequence
+ names of shape [B,]. Defaults to None.
+ categories (None | list[list[str]], optional): List of categories
+ for each image. Instead of class ids, the categories will be
+ used to label the boxes. Defaults to None.
+ """
+ if self._run_on_batch(cur_iter):
+ for idx, batch_images in enumerate(images):
+ for batch, image in enumerate(batch_images):
+ self.process_single_image(
+ image,
+ image_names[idx][batch],
+ boxes3d[batch],
+ intrinsics[idx][batch], # type: ignore
+ (
+ None
+ if extrinsics is None
+ else extrinsics[idx][batch] # type: ignore
+ ),
+ None if scores is None else scores[batch],
+ None if class_ids is None else class_ids[batch],
+ None if track_ids is None else track_ids[batch],
+ (
+ None
+ if sequence_names is None
+ else sequence_names[batch]
+ ),
+ None if categories is None else categories[batch],
+ self.cameras[idx],
+ )
diff --git a/vis4d/vis/image/bev_visualizer.py b/vis4d/vis/image/bev_visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6c68c901becd9c96df710c581dd00a4e4671c72
--- /dev/null
+++ b/vis4d/vis/image/bev_visualizer.py
@@ -0,0 +1,361 @@
+"""BEV Bounding box 3D visualizer."""
+
+from __future__ import annotations
+
+import os
+from collections import defaultdict
+from dataclasses import dataclass
+
+import numpy as np
+import torch
+from torch import Tensor
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.typing import (
+ ArgsType,
+ ArrayLikeFloat,
+ ArrayLikeInt,
+ NDArrayF32,
+ NDArrayUI8,
+)
+from vis4d.data.const import AxisMode
+from vis4d.op.box.box3d import boxes3d_to_corners, transform_boxes3d
+from vis4d.op.geometry.transform import inverse_rigid_transform
+from vis4d.vis.base import Visualizer
+from vis4d.vis.util import generate_color_map
+
+from .canvas import CanvasBackend, PillowCanvasBackend
+from .viewer import ImageViewerBackend, MatplotlibImageViewer
+
+
+@dataclass
+class BEVBox:
+ """Dataclass storing box informations."""
+
+ corners: list[tuple[float, float]]
+ color: tuple[int, int, int]
+ track_id: int | None
+
+
+@dataclass
+class DataSample:
+ """Dataclass storing a data sample that can be visualized."""
+
+ name: str
+ extrinsics: NDArrayF32
+ sequence_name: str | None
+ boxes: list[BEVBox]
+
+
+class BEVBBox3DVisualizer(Visualizer):
+ """BEV Bounding box 3D visualizer class."""
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ n_colors: int = 50,
+ file_type: str = "png",
+ max_range: float = 60,
+ scale: float = 10,
+ width: int = 2,
+ margin: int = 10,
+ axis_mode: AxisMode = AxisMode.ROS,
+ trajectory_length: int = 10,
+ plot_trajectory: bool = True,
+ canvas: CanvasBackend | None = None,
+ viewer: ImageViewerBackend | None = None,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates a new Visualizer for BEV Image and Bounding Boxes.
+
+ Args:
+ n_colors (int): How many colors should be used for the internal
+ color map. Defaults to 100.
+ file_type (str): Desired file type. Defaults to "png".
+ max_range (float): Maximum range (meters) of the BEV image.
+ Defaults to 60.
+ scale (float): Scale of the BEV image. Defaults to 10. Means that
+ 1m in the BEV image is 10px.
+ width (int): Width of the drawn bounding boxes. Defaults to 2.
+ margin (int): Margin of the BEV image. Defaults to 10.
+ axis_mode (AxisMode): Axis mode for the input bboxes. Defaults to
+ AxisMode.ROS (i.e. global coordinate).
+ trajectory_length (int): How many past frames should be used to
+ draw the trajectory. Defaults to 10.
+ plot_trajectory (bool): If the trajectory should be plotted.
+ Defaults to True.
+ canvas (CanvasBackend): Backend that is used to draw on images. If
+ None a PillowCanvasBackend is used.
+ viewer (ImageViewerBackend): Backend that is used show images. If
+ None a MatplotlibImageViewer is used.
+ """
+ super().__init__(*args, **kwargs)
+ self._samples: list[DataSample] = []
+ self.axis_mode = axis_mode
+ self.trajectories: dict[int, list[tuple[float, float, float]]] = (
+ defaultdict(list)
+ )
+ self.trajectory_length = trajectory_length
+ self.plot_trajectory = plot_trajectory
+
+ self.color_palette = generate_color_map(n_colors)
+
+ self.file_type = file_type
+ self.max_range = max_range
+ self.scale = scale
+
+ # Generate figure size
+ self.figure_hw = (
+ int(max_range * scale + margin) * 2,
+ int(max_range * scale + margin) * 2,
+ )
+
+ self.width = width
+
+ self.canvas = canvas if canvas is not None else PillowCanvasBackend()
+ self.viewer = viewer if viewer is not None else MatplotlibImageViewer()
+
+ def __repr__(self) -> str:
+ """Return string representation."""
+ return "BEVBBox3DVisualizer"
+
+ def reset(self) -> None:
+ """Reset visualizer."""
+ self._samples.clear()
+
+ def process( # pylint: disable=arguments-differ
+ self,
+ cur_iter: int,
+ sample_names: list[list[str]] | list[str],
+ boxes3d: list[ArrayLikeFloat],
+ extrinsics: list[ArrayLikeFloat] | ArrayLikeFloat,
+ class_ids: None | list[ArrayLikeInt] = None,
+ track_ids: None | list[ArrayLikeInt] = None,
+ sequence_names: None | list[str] = None,
+ ) -> None:
+ """Processes a batch of data."""
+ # Handle multi-sensor connector results from multi-sensor data dict
+ if isinstance(sample_names[0], list) and isinstance(extrinsics, list):
+ sample_names = sample_names[0]
+ extrinsics = extrinsics[0]
+
+ if self._run_on_batch(cur_iter):
+ for batch, sample_name in enumerate(sample_names):
+ self.process_single(
+ sample_name, # type: ignore
+ boxes3d[batch],
+ extrinsics[batch], # type: ignore
+ class_ids[batch] if class_ids is not None else None,
+ track_ids[batch] if track_ids is not None else None,
+ (
+ sequence_names[batch]
+ if sequence_names is not None
+ else None
+ ),
+ )
+
+ for tid in self.trajectories:
+ if len(self.trajectories[tid]) > self.trajectory_length:
+ self.trajectories[tid].pop(0)
+
+ def process_single(
+ self,
+ sample_name: str,
+ boxes3d: ArrayLikeFloat,
+ extrinsics: ArrayLikeFloat,
+ class_ids: None | ArrayLikeInt = None,
+ track_ids: None | ArrayLikeInt = None,
+ sequence_name: None | str = None,
+ ) -> None:
+ """Process single batch."""
+ boxes3d = array_to_numpy(boxes3d, n_dims=2, dtype=np.float32)
+ extrinsics_np = array_to_numpy(extrinsics, n_dims=2, dtype=np.float32)
+ data_sample = DataSample(
+ sample_name,
+ extrinsics_np,
+ sequence_name,
+ [],
+ )
+
+ boxes3d_lidar, boxes3d = self._get_lidar_and_global_boxes3d(
+ boxes3d, extrinsics_np
+ )
+
+ corners = boxes3d_to_corners(
+ boxes3d_lidar, axis_mode=AxisMode.LIDAR
+ ).numpy()
+
+ track_ids_np = array_to_numpy(track_ids, n_dims=1, dtype=np.int32)
+ class_ids_np = array_to_numpy(class_ids, n_dims=1, dtype=np.int32)
+
+ for i in range(corners.shape[0]):
+ track_id = None if track_ids_np is None else int(track_ids_np[i])
+ class_id = None if class_ids_np is None else int(class_ids_np[i])
+
+ if track_id is not None:
+ color = self.color_palette[track_id % len(self.color_palette)]
+ self.trajectories[track_id].append(
+ tuple(boxes3d[i][:3].tolist())
+ )
+ elif class_id is not None:
+ color = self.color_palette[class_id % len(self.color_palette)]
+ else:
+ color = (255, 0, 0)
+
+ data_sample.boxes.append(
+ BEVBox(
+ [tuple(pts) for pts in corners[i, :4, :2]],
+ color,
+ track_id=track_id,
+ )
+ )
+
+ self._samples.append(data_sample)
+
+ def _get_lidar_and_global_boxes3d(
+ self, boxes3d: NDArrayF32, extrinsics: NDArrayF32
+ ) -> tuple[Tensor, NDArrayF32]:
+ """Get boxes3d in lidar and global frame."""
+ if self.axis_mode == AxisMode.ROS:
+ global_to_lidar = inverse_rigid_transform(
+ torch.from_numpy(extrinsics)
+ )
+
+ boxes3d_global = boxes3d
+
+ boxes3d_lidar = transform_boxes3d(
+ torch.from_numpy(boxes3d),
+ global_to_lidar,
+ source_axis_mode=self.axis_mode,
+ target_axis_mode=AxisMode.LIDAR,
+ )
+ elif self.axis_mode == AxisMode.LIDAR:
+ boxes3d_global = transform_boxes3d(
+ torch.from_numpy(boxes3d),
+ torch.from_numpy(extrinsics),
+ source_axis_mode=self.axis_mode,
+ target_axis_mode=AxisMode.ROS,
+ ).numpy()
+
+ boxes3d_lidar = torch.from_numpy(boxes3d)
+ else:
+ raise NotImplementedError(
+ f"Axis mode {self.axis_mode} not supported"
+ )
+ return boxes3d_lidar, boxes3d_global
+
+ def show(self, cur_iter: int, blocking: bool = True) -> None:
+ """Shows the processed images in a interactive window.
+
+ Args:
+ cur_iter (int): Current iteration.
+ blocking (bool): If the visualizer should be blocking i.e. wait for
+ human input for each image. Defaults to True.
+ """
+ if self._run_on_batch(cur_iter):
+ image_data = [self._draw_image(d) for d in self._samples]
+ self.viewer.show_images(image_data, blocking=blocking)
+
+ def _map_lidar_to_bev_image(
+ self, point_x: float, point_y: float
+ ) -> tuple[float, float]:
+ """Maps a point from lidar frame to BEV image frame."""
+ return (
+ self.scale * point_x + self.figure_hw[1] // 2,
+ self.scale * -point_y + self.figure_hw[0] // 2,
+ )
+
+ def _draw_image(self, sample: DataSample) -> NDArrayUI8:
+ """Visualizes the datasample and returns is as numpy image.
+
+ Args:
+ sample (DataSample): The data sample to visualize.
+
+ Returns:
+ NDArrayUI8: A image with the visualized data sample.
+ """
+ self.canvas.create_canvas(image_hw=self.figure_hw)
+
+ img_center = self._map_lidar_to_bev_image(0, 0)
+
+ # Mark range every 10m
+ for i in range(int(self.max_range / 10), 0, -1):
+ distance = int(10 * self.scale * i)
+ grey_level = 140 + i * 10
+ self.canvas.draw_circle(
+ img_center, (grey_level, grey_level, grey_level), distance
+ )
+
+ self.canvas.draw_text(
+ (img_center[0] + distance - 25, img_center[1]),
+ f"{10 * i} m",
+ color=(0, 0, 0),
+ )
+
+ # Draw ego car
+ self.canvas.draw_rotated_box(
+ [
+ (img_center[0] - self.scale, img_center[1] - self.scale * 2),
+ (img_center[0] + self.scale, img_center[1] - self.scale * 2),
+ (img_center[0] - self.scale, img_center[1] + self.scale * 2),
+ (img_center[0] + self.scale, img_center[1] + self.scale * 2),
+ ],
+ (0, 0, 0),
+ self.width,
+ )
+
+ global_to_lidar = inverse_rigid_transform(
+ torch.from_numpy(sample.extrinsics)
+ ).numpy()
+
+ for box in sample.boxes:
+ corners = [
+ self._map_lidar_to_bev_image(pts[0], pts[1])
+ for pts in box.corners
+ ]
+ self.canvas.draw_rotated_box(corners, box.color, self.width)
+
+ if self.plot_trajectory:
+ assert (
+ box.track_id is not None
+ ), "Track id must be set to plot trajectory."
+
+ trajectory = self.trajectories[box.track_id]
+ for center in trajectory:
+ # Move global center to current lidar frame
+ center_lidar = np.dot(global_to_lidar, [*center, 1])[:3]
+
+ bev_center = self._map_lidar_to_bev_image(
+ center_lidar[0], center_lidar[1]
+ )
+
+ self.canvas.draw_circle(
+ bev_center, box.color, self.width * 2
+ )
+
+ return self.canvas.as_numpy_image()
+
+ def save_to_disk(self, cur_iter: int, output_folder: str) -> None:
+ """Saves the visualization to disk.
+
+ Writes all processes samples to the output folder naming each image
+ ..
+
+ Args:
+ cur_iter (int): Current iteration.
+ output_folder (str): Folder where the output should be written.
+ """
+ if self._run_on_batch(cur_iter):
+ for sample in self._samples:
+ output_dir = output_folder
+ sample_name = f"{sample.name}.{self.file_type}"
+
+ self._draw_image(sample)
+
+ if sample.sequence_name is not None:
+ output_dir = os.path.join(output_dir, sample.sequence_name)
+
+ output_dir = os.path.join(output_dir, "BEV")
+
+ os.makedirs(output_dir, exist_ok=True)
+ self.canvas.save_to_disk(os.path.join(output_dir, sample_name))
diff --git a/vis4d/vis/image/bounding_box_visualizer.py b/vis4d/vis/image/bounding_box_visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a436c89dba735ef788b568a7c7e6c70c38f2eaf
--- /dev/null
+++ b/vis4d/vis/image/bounding_box_visualizer.py
@@ -0,0 +1,226 @@
+"""Bounding box visualizer."""
+
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass
+
+from vis4d.common.typing import (
+ ArgsType,
+ ArrayLike,
+ ArrayLikeFloat,
+ ArrayLikeInt,
+ NDArrayUI8,
+)
+from vis4d.vis.base import Visualizer
+from vis4d.vis.util import generate_color_map
+
+from .canvas import CanvasBackend, PillowCanvasBackend
+from .util import preprocess_boxes, preprocess_image
+from .viewer import ImageViewerBackend, MatplotlibImageViewer
+
+
+@dataclass
+class DetectionBox2D:
+ """Dataclass storing box informations."""
+
+ corners: tuple[float, float, float, float]
+ label: str
+ color: tuple[int, int, int]
+
+
+@dataclass
+class DataSample:
+ """Dataclass storing a data sample that can be visualized."""
+
+ image: NDArrayUI8
+ image_name: str
+ boxes: list[DetectionBox2D]
+
+
+class BoundingBoxVisualizer(Visualizer):
+ """Bounding box visualizer class."""
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ n_colors: int = 50,
+ cat_mapping: dict[str, int] | None = None,
+ file_type: str = "png",
+ width: int = 2,
+ canvas: CanvasBackend = PillowCanvasBackend(),
+ viewer: ImageViewerBackend = MatplotlibImageViewer(),
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates a new Visualizer for Image and Bounding Boxes.
+
+ Args:
+ n_colors (int): How many colors should be used for the internal
+ color map
+ cat_mapping (dict[str, int]): Mapping from class names to class
+ ids. Defaults to None.
+ file_type (str): Desired file type. Defaults to "png".
+ width (int): Width of the bounding box lines. Defaults to 2.
+ canvas (CanvasBackend): Backend that is used to draw on images.
+ viewer (ImageViewerBackend): Backend that is used show images.
+ """
+ super().__init__(*args, **kwargs)
+ self._samples: list[DataSample] = []
+ self.color_palette = generate_color_map(n_colors)
+ self.class_id_mapping = (
+ {v: k for k, v in cat_mapping.items()}
+ if cat_mapping is not None
+ else {}
+ )
+ self.file_type = file_type
+ self.width = width
+ self.canvas = canvas
+ self.viewer = viewer
+
+ def __repr__(self) -> str:
+ """Return string representation of the visualizer."""
+ return "BoundingBoxVisualizer"
+
+ def reset(self) -> None:
+ """Reset visualizer."""
+ self._samples.clear()
+
+ def process( # pylint: disable=arguments-differ
+ self,
+ cur_iter: int,
+ images: list[ArrayLike],
+ image_names: list[str],
+ boxes: list[ArrayLikeFloat],
+ scores: None | list[ArrayLikeFloat] = None,
+ class_ids: None | list[ArrayLikeInt] = None,
+ track_ids: None | list[ArrayLikeInt] = None,
+ categories: None | list[list[str]] = None,
+ ) -> None:
+ """Processes a batch of data.
+
+ Args:
+ cur_iter (int): Current iteration.
+ images (list[ArrayLike]): Images to show.
+ image_names (list[str]): Image names.
+ boxes (list[ArrayLikeFloat]): List of predicted bounding boxes with
+ shape [N, (x1, y1, x2, y2)], where N is the number of boxes.
+ scores (None | list[ArrayLikeFloat], optional): List of predicted
+ box scores each of shape [N]. Defaults to None.
+ class_ids (None | list[ArrayLikeInt], optional): List of predicted
+ class ids each of shape [N]. Defaults to None.
+ track_ids (None | list[ArrayLikeInt], optional): List of predicted
+ track ids each of shape [N]. Defaults to None.
+ categories (None | list[list[str]], optional): List of categories
+ for each image. Instead of class ids, the categories will be
+ used to label the boxes. Defaults to None.
+ """
+ if self._run_on_batch(cur_iter):
+ for idx, image in enumerate(images):
+ self.process_single_image(
+ image,
+ image_names[idx],
+ boxes[idx],
+ None if scores is None else scores[idx],
+ None if class_ids is None else class_ids[idx],
+ None if track_ids is None else track_ids[idx],
+ None if categories is None else categories[idx],
+ )
+
+ def process_single_image(
+ self,
+ image: ArrayLike,
+ image_name: str,
+ boxes: ArrayLikeFloat,
+ scores: None | ArrayLikeFloat = None,
+ class_ids: None | ArrayLikeInt = None,
+ track_ids: None | ArrayLikeInt = None,
+ categories: None | list[str] = None,
+ ) -> None:
+ """Processes a single image entry.
+
+ Args:
+ image (ArrayLike): Image to show.
+ image_name (str): Image name.
+ boxes (ArrayLikeFloat): Predicted bounding boxes with shape
+ [N, (x1,y1,x2,y2)], where N is the number of boxes.
+ scores (None | ArrayLikeFloat, optional): Predicted box scores of
+ shape [N]. Defaults to None.
+ class_ids (None | ArrayLikeInt, optional): Predicted class ids of
+ shape [N]. Defaults to None.
+ track_ids (None | ArrayLikeInt, optional): Predicted track ids of
+ shape [N]. Defaults to None.
+ categories (None | list[str], optional): List of categories for
+ each box. Instead of class ids, the categories will be used to
+ label the boxes. Defaults to None.
+ """
+ img_normalized = preprocess_image(image, mode=self.image_mode)
+ data_sample = DataSample(img_normalized, image_name, [])
+
+ for corners, label, color in zip(
+ *preprocess_boxes(
+ boxes,
+ scores,
+ class_ids,
+ track_ids,
+ self.color_palette,
+ self.class_id_mapping,
+ categories=categories,
+ )
+ ):
+ data_sample.boxes.append(
+ DetectionBox2D(
+ corners=(corners[0], corners[1], corners[2], corners[3]),
+ label=label,
+ color=color,
+ )
+ )
+
+ self._samples.append(data_sample)
+
+ def show(self, cur_iter: int, blocking: bool = True) -> None:
+ """Shows the processed images in a interactive window.
+
+ Args:
+ cur_iter (int): Current iteration.
+ blocking (bool): If the visualizer should be blocking i.e. wait for
+ human input for each image. Defaults to True.
+ """
+ if self._run_on_batch(cur_iter):
+ image_data = [self._draw_image(d) for d in self._samples]
+ self.viewer.show_images(image_data, blocking=blocking)
+
+ def _draw_image(self, sample: DataSample) -> NDArrayUI8:
+ """Visualizes the datasample and returns is as numpy image.
+
+ Args:
+ sample (DataSample): The data sample to visualize.
+
+ Returns:
+ NDArrayUI8: A image with the visualized data sample.
+ """
+ self.canvas.create_canvas(sample.image)
+ for box in sample.boxes:
+ self.canvas.draw_box(box.corners, box.color, width=self.width)
+ self.canvas.draw_text(box.corners[:2], box.label, box.color)
+
+ return self.canvas.as_numpy_image()
+
+ def save_to_disk(self, cur_iter: int, output_folder: str) -> None:
+ """Saves the visualization to disk.
+
+ Writes all processes samples to the output folder naming each image
+ ..
+
+ Args:
+ cur_iter (int): Current iteration.
+ output_folder (str): Folder where the output should be written.
+ """
+ if self._run_on_batch(cur_iter):
+ for sample in self._samples:
+ image_name = f"{sample.image_name}.{self.file_type}"
+
+ _ = self._draw_image(sample)
+
+ self.canvas.save_to_disk(
+ os.path.join(output_folder, image_name)
+ )
diff --git a/vis4d/vis/image/canvas/__init__.py b/vis4d/vis/image/canvas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea8e856bc8e8319e9576592b6fe763add88e5b56
--- /dev/null
+++ b/vis4d/vis/image/canvas/__init__.py
@@ -0,0 +1,6 @@
+"""Vis4D image canvas backends."""
+
+from .base import CanvasBackend
+from .pillow_backend import PillowCanvasBackend
+
+__all__ = ["CanvasBackend", "PillowCanvasBackend"]
diff --git a/vis4d/vis/image/canvas/base.py b/vis4d/vis/image/canvas/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..6eb31eda31dafea2da0418407c0bb9a8d83e0a1e
--- /dev/null
+++ b/vis4d/vis/image/canvas/base.py
@@ -0,0 +1,175 @@
+"""Base class of canvas for image based visualization."""
+
+from __future__ import annotations
+
+from vis4d.common.typing import NDArrayBool, NDArrayF32, NDArrayUI8
+
+
+class CanvasBackend:
+ """Abstract interface that allows to draw on images.
+
+ Supports drawing different bounding boxes on top of an image.
+ """
+
+ def create_canvas(
+ self,
+ image: NDArrayUI8 | None = None,
+ image_hw: tuple[int, int] | None = None,
+ ) -> None:
+ """Creates a new canvas with a given image or shape internally.
+
+ Either provide a background image or the desired height, width
+ of the canvas.
+
+ Args:
+ image (np.array[uint8] | None): Numpy array with a background image
+ image_hw (tuple[int, int] | None): height, width of the canvas
+ """
+ raise NotImplementedError
+
+ def draw_bitmap(
+ self,
+ bitmap: NDArrayBool,
+ color: tuple[int, int, int],
+ top_left_corner: tuple[float, float] = (0, 0),
+ alpha: float = 0.5,
+ ) -> None:
+ """Draws a binary mask onto the given canvas.
+
+ Args:
+ bitmap (ndarray): The binary mask to draw
+ color (tuple[int, int, int]): Color of the box [0,255].
+ top_left_corner (tuple(float, float)): Coordinates of top left
+ corner of the bitmap. Defaults to (0, 0).
+ alpha (float, optional): Alpha value for transparency of this mask.
+ Defaults to 0.5.
+ """
+ raise NotImplementedError
+
+ def draw_text(
+ self,
+ position: tuple[float, float],
+ text: str,
+ color: tuple[int, int, int] = (255, 255, 255),
+ ) -> None:
+ """Draw text onto canvas at given position.
+
+ Args:
+ position (tuple[float, float]): x,y position where the text will
+ start.
+ text (str): Text to be placed at the given location.
+ color (tuple[int, int, int], optional): Text color. Defaults to
+ (255, 255, 255).
+ """
+ raise NotImplementedError
+
+ def draw_line(
+ self,
+ point1: tuple[float, float],
+ point2: tuple[float, float],
+ color: tuple[int, int, int],
+ width: int = 0,
+ ) -> None:
+ """Draw a line onto canvas from point 1 to 2.
+
+ Args:
+ point1 (tuple[float, float]): Start point (2D pixel coordinates).
+ point2 (tuple[float, float]): End point (2D pixel coordinates).
+ color (ttuple[int, int, int]): Color of the line.
+ width (int, optional): Line width. Defaults to 0.
+ """
+ raise NotImplementedError
+
+ def draw_circle(
+ self,
+ center: tuple[float, float],
+ color: tuple[int, int, int],
+ radius: int = 2,
+ ) -> None:
+ """Draw a circle onto canvas.
+
+ Args:
+ center (tuple[float, float]): Center of the circle.
+ color (tuple[int, int, int]): Color of the circle.
+ radius (int, optional): Radius of the circle. Defaults to 2.
+ """
+ raise NotImplementedError
+
+ def draw_box(
+ self,
+ corners: tuple[float, float, float, float],
+ color: tuple[int, int, int],
+ width: int = 1,
+ ) -> None:
+ """Draws a box onto the given canvas.
+
+ Args:
+ corners (list[float]): Containing [x1,y1,x2,y2] the corners of
+ the box.
+ color (tuple[int, int, int]): Color of the box [0,255].
+ width (int, optional): Line width. Defaults to 1.
+
+ Raises:
+ ValueError: If the canvas is not initialized.
+ """
+ raise NotImplementedError
+
+ def draw_rotated_box(
+ self,
+ corners: list[tuple[float, float]],
+ color: tuple[int, int, int],
+ width: int = 0,
+ ) -> None:
+ """Draws a box onto the given canvas.
+
+ Corner ordering:
+
+ (2) +---------+ (3)
+ | |
+ | |
+ | |
+ (0) +---------+ (1)
+
+ Args:
+ corners (list[tuple[float, float]]): Containing the four corners of
+ the box.
+ color (tuple[int, int, int]): Color of the box [0,255].
+ width (int, optional): Line width. Defaults to 0.
+ """
+ raise NotImplementedError
+
+ def draw_box_3d(
+ self,
+ corners: list[tuple[float, float, float]],
+ color: tuple[int, int, int],
+ intrinsics: NDArrayF32,
+ width: int = 0,
+ camera_near_clip: float = 0.15,
+ plot_heading: bool = True,
+ ) -> None:
+ """Draws a line between two points.
+
+ Args:
+ corners (list[tuple[float, float, float]]): Containing the eight
+ corners of the box.
+ color (tuple[int, int, int]): Color of the line.
+ intrinsics (NDArrayF32): Camera intrinsics matrix.
+ width (int, optional): The width of the line. Defaults to 0.
+ camera_near_clip (float, optional): The near clipping plane of the
+ camera. Defaults to 0.15.
+ plot_heading (bool, optional): If True, the heading of the box will
+ be plotted as a line. Defaults to True.
+ """
+ raise NotImplementedError
+
+ def as_numpy_image(self) -> NDArrayUI8:
+ """Returns the current canvas as numpy image."""
+ raise NotImplementedError
+
+ def save_to_disk(self, image_path: str) -> None:
+ """Writes the current canvas to disk.
+
+ Args:
+ image_path (str): Full image path (with file name and ending).
+ """
+ raise NotImplementedError
diff --git a/vis4d/vis/image/canvas/pillow_backend.py b/vis4d/vis/image/canvas/pillow_backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bfd0104097cbaabf5087f18ba0cc4b6b951ac4e
--- /dev/null
+++ b/vis4d/vis/image/canvas/pillow_backend.py
@@ -0,0 +1,376 @@
+"""Pillow backend implementation to draw on images."""
+
+from __future__ import annotations
+
+import numpy as np
+from PIL import Image, ImageDraw
+from PIL.ImageFont import ImageFont, load_default
+
+from vis4d.common.typing import NDArrayBool, NDArrayF32, NDArrayF64, NDArrayUI8
+
+from ..util import get_intersection_point, project_point
+from .base import CanvasBackend
+
+
+class PillowCanvasBackend(CanvasBackend):
+ """Canvas backend using Pillow."""
+
+ def __init__(
+ self, font: ImageFont | None = None, font_size: int | None = None
+ ) -> None:
+ """Creates a new canvas backend.
+
+ Args:
+ font (ImageFont): Pillow font to use for the label.
+ font_size (int): Font size to use for the label.
+ """
+ self._image_draw: ImageDraw.ImageDraw | None = None
+ self._font = font if font is not None else load_default(font_size)
+ self._image: Image.Image | None = None
+
+ def create_canvas(
+ self,
+ image: NDArrayUI8 | None = None,
+ image_hw: tuple[int, int] | None = None,
+ ) -> None:
+ """Creates a new canvas with a given image or shape internally.
+
+ Either provide a background image or the desired height, width
+ of the canvas.
+
+ Args:
+ image (np.array[uint8] | None): Numpy array with a background image
+ image_hw (tuple[int, int] | None): height, width of the canvas
+
+ Raises:
+ ValueError: If the canvas is not initialized.
+ """
+ if image_hw is not None:
+ white_image = np.ones([*image_hw, 3]) * 255
+ image = white_image.astype(np.uint8)
+ else:
+ assert (
+ image is not None
+ ), "Image or Image Shapes required to create canvas"
+
+ self._image = Image.fromarray(image)
+ self._image_draw = ImageDraw.Draw(self._image)
+
+ def draw_bitmap(
+ self,
+ bitmap: NDArrayBool,
+ color: tuple[int, int, int],
+ top_left_corner: tuple[float, float] = (0, 0),
+ alpha: float = 0.5,
+ ) -> None:
+ """Draws a binary mask onto the given canvas.
+
+ Args:
+ bitmap (ndarray): The binary mask to draw.
+ color (tuple[int, int, int]): Color of the box [0,255].
+ top_left_corner (tuple(float, float)): Coordinates of top left
+ corner of the bitmap.
+ alpha (float): Alpha value for transparency of this mask.
+
+ Raises:
+ ValueError: If the canvas is not initialized.
+ """
+ if self._image_draw is None:
+ raise ValueError(
+ "No Image Draw initialized! Did you call 'create_canvas'?"
+ )
+ mask = np.squeeze(bitmap)
+ assert len(mask.shape) == 2, "Bitmap expected to have shape [h,w]"
+
+ bitmap_with_alpha: NDArrayF64 = np.repeat(
+ mask[:, :, None], 4, axis=2
+ ).astype(np.float64)
+ bitmap_with_alpha[..., -1] = bitmap_with_alpha[..., -1] * alpha * 255
+ bitmap_pil = Image.fromarray(
+ bitmap_with_alpha.astype(np.uint8), mode="RGBA"
+ )
+ self._image_draw.bitmap(
+ top_left_corner, bitmap_pil, fill=color # type: ignore
+ )
+
+ def draw_text(
+ self,
+ position: tuple[float, float],
+ text: str,
+ color: tuple[int, int, int] = (255, 255, 255),
+ ) -> None:
+ """Draw text onto canvas at given position.
+
+ Args:
+ position (tuple[float, float]): x,y position where the text will
+ start.
+ text (str): Text to be placed at the given location.
+ color (tuple[int, int, int], optional): Text color. Defaults to
+ (255, 255, 255).
+
+ Raises:
+ ValueError: If the canvas is not initialized.
+ """
+ if self._image_draw is None:
+ raise ValueError(
+ "No Image Draw initialized! Did you call 'create_canvas'?"
+ )
+ left, top, right, bottom = self._image_draw.textbbox(
+ position, text, font=self._font
+ )
+ self._image_draw.rectangle(
+ (left - 2, top - 2, right + 2, bottom + 2), fill=color
+ )
+ self._image_draw.text(position, text, (255, 255, 255), font=self._font)
+
+ def draw_box(
+ self,
+ corners: tuple[float, float, float, float],
+ color: tuple[int, int, int],
+ width: int = 1,
+ ) -> None:
+ """Draws a box onto the given canvas.
+
+ Args:
+ corners (list[float]): Containing [x1,y2,x2,y2] the corners of
+ the box.
+ color (tuple[int, int, int]): Color of the box [0,255].
+ width (int, optional): Line width. Defaults to 1.
+
+ Raises:
+ ValueError: If the canvas is not initialized.
+ """
+ if self._image_draw is None:
+ raise ValueError(
+ "No Image Draw initialized! Did you call 'create_canvas'?"
+ )
+
+ self._image_draw.rectangle(corners, outline=color, width=width)
+
+ def draw_rotated_box(
+ self,
+ corners: list[tuple[float, float]],
+ color: tuple[int, int, int],
+ width: int = 0,
+ ) -> None:
+ """Draws a box onto the given canvas.
+
+ Corner ordering:
+
+ (2) +---------+ (3)
+ | |
+ | |
+ | |
+ (0) +---------+ (1)
+
+ Args:
+ corners (list[tuple[float, float]]): Containing the four corners of
+ the box.
+ color (tuple[int, int, int]): Color of the box [0,255].
+ width (int, optional): Line width. Defaults to 0.
+
+ Raises:
+ ValueError: If the canvas is not initialized.
+ """
+ assert len(corners) == 4, "2D box must consist of 4 corner points."
+ if self._image_draw is None:
+ raise ValueError(
+ "No Image Draw initialized! Did you call 'create_canvas'?"
+ )
+
+ self.draw_line(corners[0], corners[1], color, 2 * width)
+ self.draw_line(corners[0], corners[2], color, width)
+ self.draw_line(corners[1], corners[3], color, width)
+ self.draw_line(corners[2], corners[3], color, width)
+
+ center_forward = np.mean(corners[:2], axis=0, dtype=np.float32)
+ center = np.mean(corners, axis=0, dtype=np.float32)
+ self.draw_line(
+ tuple(center.tolist()),
+ tuple(center_forward.tolist()),
+ color,
+ width,
+ )
+
+ def draw_line(
+ self,
+ point1: tuple[float, float],
+ point2: tuple[float, float],
+ color: tuple[int, int, int],
+ width: int = 0,
+ ) -> None:
+ """Draw a line onto canvas from point 1 to 2.
+
+ Args:
+ point1 (tuple[float, float]): Start point (2D pixel coordinates).
+ point2 (tuple[float, float]): End point (2D pixel coordinates).
+ color (tuple[int, int, int]): Color of the line.
+ width (int, optional): Line width. Defaults to 0.
+
+ Raises:
+ ValueError: If the canvas is not initialized.
+ """
+ if self._image_draw is None:
+ raise ValueError(
+ "No Image Draw initialized! Did you call 'create_canvas'?"
+ )
+ self._image_draw.line((point1, point2), width=width, fill=color)
+
+ def draw_circle(
+ self,
+ center: tuple[float, float],
+ color: tuple[int, int, int],
+ radius: int = 2,
+ ) -> None:
+ """Draw a circle onto canvas.
+
+ Args:
+ center (tuple[float, float]): Center of the circle.
+ color (tuple[int, int, int]): Color of the circle.
+ radius (int, optional): Radius of the circle. Defaults to 2.
+ """
+ x1 = center[0] - radius
+ y1 = center[1] - radius
+ x2 = center[0] + radius
+ y2 = center[1] + radius
+ if self._image_draw is None:
+ raise ValueError(
+ "No Image Draw initialized! Did you call 'create_canvas'?"
+ )
+ self._image_draw.ellipse((x1, y1, x2, y2), fill=color, outline=color)
+
+ def _draw_box_3d_line(
+ self,
+ point1: tuple[float, float, float],
+ point2: tuple[float, float, float],
+ color: tuple[int, int, int],
+ intrinsics: NDArrayF32,
+ width: int = 0,
+ camera_near_clip: float = 0.15,
+ ) -> None:
+ """Draws a line between two points.
+
+ Args:
+ point1 (tuple[float, float, float]): The first point. The third
+ coordinate is the depth.
+ point2 (tuple[float, float, float]): The first point. The third
+ coordinate is the depth.
+ color (tuple[int, int, int]): Color of the line.
+ intrinsics (NDArrayF32): Camera intrinsics matrix.
+ width (int, optional): The width of the line. Defaults to 0.
+ camera_near_clip (float, optional): The near clipping plane of the
+ camera. Defaults to 0.15.
+
+ Raises:
+ ValueError: If the canvas is not initialized.
+ """
+ if point1[2] < camera_near_clip and point2[2] < camera_near_clip:
+ return
+
+ if point1[2] < camera_near_clip:
+ point1 = get_intersection_point(point1, point2, camera_near_clip)
+ elif point2[2] < camera_near_clip:
+ point2 = get_intersection_point(point1, point2, camera_near_clip)
+
+ pt1 = project_point(point1, intrinsics)
+ pt2 = project_point(point2, intrinsics)
+
+ if self._image_draw is None:
+ raise ValueError(
+ "No Image Draw initialized! Did you call 'create_canvas'?"
+ )
+ self._image_draw.line((pt1, pt2), width=width, fill=color)
+
+ def draw_box_3d(
+ self,
+ corners: list[tuple[float, float, float]],
+ color: tuple[int, int, int],
+ intrinsics: NDArrayF32,
+ width: int = 0,
+ camera_near_clip: float = 0.15,
+ plot_heading: bool = True,
+ ) -> None:
+ """Draws a 3D box onto the given canvas."""
+ # Draw Front
+ self._draw_box_3d_line(
+ corners[0], corners[1], color, intrinsics, width, camera_near_clip
+ )
+ self._draw_box_3d_line(
+ corners[1], corners[5], color, intrinsics, width, camera_near_clip
+ )
+ self._draw_box_3d_line(
+ corners[5], corners[4], color, intrinsics, width, camera_near_clip
+ )
+ self._draw_box_3d_line(
+ corners[4], corners[0], color, intrinsics, width, camera_near_clip
+ )
+
+ # Draw Sides
+ self._draw_box_3d_line(
+ corners[0], corners[2], color, intrinsics, width, camera_near_clip
+ )
+ self._draw_box_3d_line(
+ corners[1], corners[3], color, intrinsics, width, camera_near_clip
+ )
+ self._draw_box_3d_line(
+ corners[4], corners[6], color, intrinsics, width, camera_near_clip
+ )
+ self._draw_box_3d_line(
+ corners[5], corners[7], color, intrinsics, width, camera_near_clip
+ )
+
+ # Draw Back
+ self._draw_box_3d_line(
+ corners[2], corners[3], color, intrinsics, width, camera_near_clip
+ )
+ self._draw_box_3d_line(
+ corners[3], corners[7], color, intrinsics, width, camera_near_clip
+ )
+ self._draw_box_3d_line(
+ corners[7], corners[6], color, intrinsics, width, camera_near_clip
+ )
+ self._draw_box_3d_line(
+ corners[6], corners[2], color, intrinsics, width, camera_near_clip
+ )
+
+ # Draw line indicating the front
+ if plot_heading:
+ center_bottom_forward = np.mean(
+ corners[:2], axis=0, dtype=np.float32
+ )
+ center_bottom = np.mean(corners[:4], axis=0, dtype=np.float32)
+ self._draw_box_3d_line(
+ tuple(center_bottom.tolist()),
+ tuple(center_bottom_forward.tolist()),
+ color,
+ intrinsics,
+ width,
+ camera_near_clip,
+ )
+
+ def as_numpy_image(self) -> NDArrayUI8:
+ """Returns the current canvas as numpy image.
+
+ Raises:
+ ValueError: If the canvas is not initialized.
+ """
+ if self._image is None:
+ raise ValueError(
+ "No Image initialized! Did you call 'create_canvas'?"
+ )
+ return np.asarray(self._image)
+
+ def save_to_disk(self, image_path: str) -> None:
+ """Writes the current canvas to disk.
+
+ Args:
+ image_path (str): Full image path (with file name and ending).
+
+ Raises:
+ ValueError: If the canvas is not initialized.
+ """
+ if self._image is None:
+ raise ValueError(
+ "No Image initialized! Did you call 'create_canvas'?"
+ )
+ self._image.save(image_path)
diff --git a/vis4d/vis/image/functional.py b/vis4d/vis/image/functional.py
new file mode 100644
index 0000000000000000000000000000000000000000..5113b198be71e932f38dd1d98a842ca3f7f5f75e
--- /dev/null
+++ b/vis4d/vis/image/functional.py
@@ -0,0 +1,430 @@
+"""Function interface for image visualization functions."""
+
+from __future__ import annotations
+
+import numpy as np
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.typing import (
+ ArrayLike,
+ ArrayLikeBool,
+ ArrayLikeFloat,
+ ArrayLikeInt,
+ NDArrayF32,
+ NDArrayUI8,
+)
+
+from ..util import generate_color_map
+from .canvas import CanvasBackend, PillowCanvasBackend
+from .util import (
+ preprocess_boxes,
+ preprocess_boxes3d,
+ preprocess_image,
+ preprocess_masks,
+ project_point,
+)
+from .viewer import ImageViewerBackend, MatplotlibImageViewer
+
+
+def imshow(
+ image: ArrayLike,
+ image_mode: str = "RGB",
+ image_viewer: ImageViewerBackend = MatplotlibImageViewer(),
+) -> None:
+ """Shows a single image.
+
+ Args:
+ image (NDArrayNumber): The image to show.
+ image_mode (str, optional): Image Mode. Defaults to "RGB".
+ image_viewer (ImageViewerBackend, optional): The Image viewer backend
+ to use. Defaults to MatplotlibImageViewer().
+ """
+ image = preprocess_image(image, image_mode)
+ image_viewer.show_images([image])
+
+
+def draw_masks(
+ image: ArrayLike,
+ masks: ArrayLikeBool,
+ class_ids: ArrayLikeInt | None,
+ n_colors: int = 50,
+ image_mode: str = "RGB",
+ canvas: CanvasBackend = PillowCanvasBackend(),
+) -> NDArrayUI8:
+ """Draws semantic masks into the given image.
+
+ Args:
+ image (ArrayLike): The image to draw the bboxes into.
+ masks (ArrayLikeBool): The semantic masks with the same shape as the
+ image.
+ class_ids (ArrayLikeInt, optional): Predicted class ids.
+ Defaults to None.
+ n_colors (int, optional): Number of colors to use for color palette.
+ Defaults to 50.
+ image_mode (str, optional): Image Mode. Defaults to "RGB".
+ canvas (CanvasBackend, optional): Canvas backend to use.
+ Defaults to PillowCanvasBackend().
+
+ Returns:
+ NDArrayUI8: The image with semantic masks drawn into it,
+ """
+ image = preprocess_image(image, mode=image_mode)
+ canvas.create_canvas(image)
+ for m, c in zip(
+ *preprocess_masks(masks, class_ids, generate_color_map(n_colors))
+ ):
+ canvas.draw_bitmap(m, c)
+ return canvas.as_numpy_image()
+
+
+def draw_bboxes(
+ image: ArrayLike,
+ boxes: ArrayLikeFloat,
+ scores: None | ArrayLikeFloat = None,
+ class_ids: None | ArrayLikeInt = None,
+ track_ids: None | ArrayLikeInt = None,
+ class_id_mapping: None | dict[int, str] = None,
+ n_colors: int = 50,
+ image_mode: str = "RGB",
+ box_width: int = 1,
+ canvas: CanvasBackend = PillowCanvasBackend(),
+) -> CanvasBackend:
+ """Draws the predicted bounding boxes into the given image.
+
+ Args:
+ image (ArrayLike): The image to draw the bboxes into.
+ boxes (ArrayLikeFloat): Predicted bounding boxes.
+ scores (None | ArrayLikeFloat, optional): Predicted scores.
+ Defaults to None.
+ class_ids (ArrayLikeInt, optional): Predicted class ids.
+ Defaults to None.
+ track_ids (ArrayLikeInt, optional): Predicted track ids.
+ Defaults to None.
+ class_id_mapping (dict[int, str], optional): Mapping from class id to
+ name. Defaults to None.
+ n_colors (int, optional): Number of colors to use for color palette.
+ Defaults to 50.
+ image_mode (str, optional): Image Mode. Defaults to "RGB".
+ box_width (int, optional): Width of the box border. Defaults to 1.
+ canvas (CanvasBackend, optional): Canvas backend to use.
+ Defaults to PillowCanvasBackend().
+
+ Returns:
+ NDArrayUI8: The image with boxes drawn into it,
+ """
+ image = preprocess_image(image, image_mode)
+ box_data = preprocess_boxes(
+ boxes,
+ scores,
+ class_ids,
+ track_ids,
+ color_palette=generate_color_map(n_colors),
+ class_id_mapping=class_id_mapping,
+ )
+ canvas.create_canvas(image)
+
+ for corners, label, color in zip(*box_data):
+ canvas.draw_box(corners, color, box_width)
+
+ if len(label) > 0:
+ canvas.draw_text((corners[0], corners[1]), label, color=color)
+ return canvas
+
+
+def imshow_bboxes(
+ image: ArrayLike,
+ boxes: ArrayLikeFloat,
+ scores: None | ArrayLikeFloat = None,
+ class_ids: None | ArrayLikeInt = None,
+ track_ids: None | ArrayLikeInt = None,
+ class_id_mapping: None | dict[int, str] = None,
+ n_colors: int = 50,
+ image_mode: str = "RGB",
+ box_width: int = 1,
+ image_viewer: ImageViewerBackend = MatplotlibImageViewer(),
+ file_path: str | None = None,
+) -> None:
+ """Shows the bounding boxes overlayed on the given image.
+
+ Args:
+ image (ArrayLike): Background Image
+ boxes (ArrayLikeFloat): Boxes to show. Shape [N, 4] with
+ (x1,y1,x2,y2) as corner convention
+ scores (ArrayLikeFloat, optional): Score for each box shape [N]
+ class_ids (ArrayLikeInt, optional): Class id for each box shape [N]
+ track_ids (ArrayLikeInt, optional): Track id for each box shape [N]
+ class_id_mapping (dict[int, str], optional): Mapping to convert
+ class id to class name
+ n_colors (int, optional): Number of distinct colors used to color the
+ boxes. Defaults to 50.
+ image_mode (str, optional): Image channel mode (RGB or BGR).
+ box_width (int, optional): Width of the box border. Defaults to 1.
+ image_viewer (ImageViewerBackend, optional): The Image viewer backend
+ to use. Defaults to MatplotlibImageViewer().
+ file_path (str): The path to save the image to. Defaults to None.
+ """
+ image = preprocess_image(image, mode=image_mode)
+ canvas = draw_bboxes(
+ image,
+ boxes,
+ scores,
+ class_ids,
+ track_ids,
+ class_id_mapping,
+ n_colors,
+ image_mode,
+ box_width,
+ )
+ imshow(canvas.as_numpy_image(), image_mode, image_viewer)
+
+ if file_path is not None:
+ canvas.save_to_disk(file_path)
+
+
+def draw_bbox3d(
+ image: NDArrayUI8,
+ boxes3d: ArrayLikeFloat,
+ intrinsics: NDArrayF32,
+ extrinsics: NDArrayF32 | None = None,
+ scores: None | ArrayLikeFloat = None,
+ class_ids: None | ArrayLikeInt = None,
+ track_ids: None | ArrayLikeInt = None,
+ class_id_mapping: None | dict[int, str] = None,
+ n_colors: int = 50,
+ image_mode: str = "RGB",
+ canvas: CanvasBackend = PillowCanvasBackend(),
+ width: int = 4,
+ camera_near_clip: float = 0.15,
+) -> CanvasBackend:
+ """Draw 3D box onto image."""
+ image = preprocess_image(image, image_mode)
+ image_hw = (image.shape[0], image.shape[1])
+ _, corners, labels, colors, _ = preprocess_boxes3d(
+ image_hw,
+ boxes3d,
+ intrinsics,
+ extrinsics,
+ scores,
+ class_ids,
+ track_ids,
+ color_palette=generate_color_map(n_colors),
+ class_id_mapping=class_id_mapping,
+ )
+ canvas.create_canvas(image)
+
+ for corner, label, color in zip(corners, labels, colors):
+ canvas.draw_box_3d(corner, color, intrinsics, width, camera_near_clip)
+
+ selected_corner = project_point(corner[0], intrinsics)
+
+ if len(label) > 0:
+ canvas.draw_text(
+ (selected_corner[0], selected_corner[1]), label, color=color
+ )
+
+ return canvas
+
+
+def imshow_bboxes3d(
+ image: ArrayLike,
+ boxes3d: ArrayLikeFloat,
+ intrinsics: NDArrayF32,
+ extrinsics: NDArrayF32 | None = None,
+ scores: None | ArrayLikeFloat = None,
+ class_ids: None | ArrayLikeInt = None,
+ track_ids: None | ArrayLikeInt = None,
+ class_id_mapping: None | dict[int, str] = None,
+ n_colors: int = 50,
+ image_mode: str = "RGB",
+ image_viewer: ImageViewerBackend = MatplotlibImageViewer(),
+ file_path: str | None = None,
+) -> None:
+ """Show image with bounding boxes."""
+ image = preprocess_image(image, mode=image_mode)
+ canvas = draw_bbox3d(
+ image,
+ boxes3d,
+ intrinsics,
+ extrinsics,
+ scores,
+ class_ids,
+ track_ids,
+ class_id_mapping=class_id_mapping,
+ n_colors=n_colors,
+ image_mode=image_mode,
+ )
+ imshow(canvas.as_numpy_image(), image_mode, image_viewer)
+
+ if file_path is not None:
+ canvas.save_to_disk(file_path)
+
+
+def imshow_masks(
+ image: ArrayLike,
+ masks: ArrayLikeBool,
+ class_ids: ArrayLikeInt | None,
+ n_colors: int = 50,
+ image_mode: str = "RGB",
+ canvas: CanvasBackend = PillowCanvasBackend(),
+) -> None:
+ """Shows semantic masks overlayed over the given image.
+
+ Args:
+ image (ArrayLike): The image to draw the bboxes into.
+ masks (ArrayLikeBool): The semantic masks with the same shape as the
+ image.
+ class_ids (ArrayLikeInt, optional): Predicted class ids.
+ Defaults to None.
+ n_colors (int, optional): Number of colors to use for color palette.
+ Defaults to 50.
+ image_mode (str, optional): Image Mode.. Defaults to "RGB".
+ canvas (CanvasBackend, optional): Canvas backend to use.
+ Defaults to PillowCanvasBackend().
+ """
+ imshow(
+ draw_masks(image, masks, class_ids, n_colors, image_mode, canvas),
+ image_mode,
+ )
+
+
+def imshow_topk_bboxes(
+ image: ArrayLike,
+ boxes: ArrayLikeFloat,
+ scores: ArrayLikeFloat,
+ topk: int = 100,
+ class_ids: None | ArrayLikeInt = None,
+ track_ids: None | ArrayLikeInt = None,
+ class_id_mapping: None | dict[int, str] = None,
+ n_colors: int = 50,
+ image_mode: str = "RGB",
+ box_width: int = 1,
+ image_viewer: ImageViewerBackend = MatplotlibImageViewer(),
+ file_path: str | None = None,
+) -> None:
+ """Visualize the 'topk' bounding boxes with highest score.
+
+ Args:
+ image (ArrayLike): Background Image
+ boxes (ArrayLikeFloat): Boxes to show. Shape [N, 4] with
+ (x1,y1,x2,y2) as corner convention
+ scores (ArrayLikeFloat): Score for each box shape [N]
+ topk (int): Number of boxes to visualize
+ class_ids (ArrayLikeInt, optional): Class id for each box shape [N]
+ track_ids (ArrayLikeInt, optional): Track id for each box shape [N]
+ class_id_mapping (dict[int, str], optional): Mapping to convert
+ class id to class name
+ n_colors (int, optional): Number of distinct colors used to color the
+ boxes. Defaults to 50.
+ image_mode (str, optional): Image channel mode (RGB or BGR).
+ box_width (int, optional): Width of the box border. Defaults to 1.
+ image_viewer (ImageViewerBackend, optional): The Image viewer backend
+ to use. Defaults to MatplotlibImageViewer().
+ file_path (str): The path to save the image to. Defaults to None.
+
+ """
+ scores = array_to_numpy(scores, n_dims=1, dtype=np.float32)
+ top_k_idxs = np.argpartition(scores.ravel(), -topk)[-topk:]
+
+ boxes_np = array_to_numpy(boxes, n_dims=2, dtype=np.float32)
+ class_ids_np = array_to_numpy(class_ids, n_dims=1, dtype=np.int32)
+ track_ids_np = array_to_numpy(track_ids, n_dims=1, dtype=np.int32)
+ imshow_bboxes(
+ image,
+ boxes_np[top_k_idxs],
+ scores[top_k_idxs],
+ class_ids_np[top_k_idxs] if class_ids_np is not None else None,
+ track_ids_np[top_k_idxs] if track_ids_np is not None else None,
+ class_id_mapping,
+ n_colors,
+ image_mode,
+ box_width,
+ image_viewer,
+ file_path,
+ )
+
+
+def imshow_track_matches(
+ key_imgs: list[ArrayLike],
+ ref_imgs: list[ArrayLike],
+ key_boxes: list[ArrayLikeFloat],
+ ref_boxes: list[ArrayLikeFloat],
+ key_track_ids: list[ArrayLikeInt],
+ ref_track_ids: list[ArrayLikeInt],
+ image_mode: str = "RGB",
+ image_viewer: ImageViewerBackend = MatplotlibImageViewer(),
+) -> None:
+ """Visualize paired bounding boxes successively for batched frame pairs.
+
+ Args:
+ key_imgs (list[ArrayLike]): Key Images.
+ ref_imgs (list[ArrayLike]): Reference Images.
+ key_boxes (list[ArrayLikeFloat]): Predicted Boxes for the key image.
+ Shape [N, 4]
+ ref_boxes (list[ArrayLikeFloat]): Predicted Boxes for the key image.
+ Shape [N, 4]
+ key_track_ids (list[ArrayLikeInt]): Predicted ids for the key images.
+ ref_track_ids (list[ArrayLikeInt]): Predicted ids for the reference
+ images.
+ image_mode (str, optional): Color mode if the image. Defaults to "RGB".
+ image_viewer (ImageViewerBackend, optional): The Image viewer backend
+ to use. Defaults to MatplotlibImageViewer().
+ """
+ key_imgs_np = tuple(
+ array_to_numpy(img, n_dims=3, dtype=np.float32) for img in key_imgs
+ )
+ ref_imgs_np = tuple(
+ array_to_numpy(img, n_dims=3, dtype=np.float32) for img in ref_imgs
+ )
+ key_boxes_np = tuple(
+ array_to_numpy(b, n_dims=2, dtype=np.float32) for b in key_boxes
+ )
+ ref_boxes_np = tuple(
+ array_to_numpy(b, n_dims=2, dtype=np.float32) for b in ref_boxes
+ )
+ key_track_ids_np = tuple(
+ array_to_numpy(t, n_dims=1, dtype=np.int32) for t in key_track_ids
+ )
+ ref_track_ids_np = tuple(
+ array_to_numpy(t, n_dims=1, dtype=np.int32) for t in ref_track_ids
+ )
+
+ for batch_i, (key_box, ref_box) in enumerate(
+ zip(key_boxes_np, ref_boxes_np)
+ ):
+ target = key_track_ids_np[batch_i].reshape(-1, 1) == ref_track_ids_np[
+ batch_i
+ ].reshape(1, -1)
+ for key_i in range(target.shape[0]):
+ if target[key_i].sum() == 0:
+ continue
+ ref_i = np.argmax(target[key_i]).item()
+ ref_image = ref_imgs_np[batch_i]
+ key_image = key_imgs_np[batch_i]
+
+ if ref_image.shape != key_image.shape:
+ # Can not stack images together
+ imshow_bboxes(
+ key_image,
+ key_box[key_i],
+ image_mode=image_mode,
+ image_viewer=image_viewer,
+ )
+ imshow_bboxes(
+ ref_image,
+ ref_box[ref_i],
+ image_mode=image_mode,
+ image_viewer=image_viewer,
+ )
+ else:
+ # stack imgs horizontal
+ k_canvas = draw_bboxes(
+ key_image, key_box[batch_i], image_mode=image_mode
+ )
+ r_canvas = draw_bboxes(
+ ref_image, ref_box[batch_i], image_mode=image_mode
+ )
+ k_np_img = k_canvas.as_numpy_image()
+ r_np_img = r_canvas.as_numpy_image()
+ stacked_img = np.vstack([k_np_img, r_np_img])
+
+ imshow(stacked_img, image_mode, image_viewer)
diff --git a/vis4d/vis/image/seg_mask_visualizer.py b/vis4d/vis/image/seg_mask_visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b95451a0d3228aa7d85da20152abf22a20051f2
--- /dev/null
+++ b/vis4d/vis/image/seg_mask_visualizer.py
@@ -0,0 +1,213 @@
+"""Segmentation mask visualizer."""
+
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass
+
+from vis4d.common.typing import (
+ ArgsType,
+ ArrayLikeFloat,
+ ArrayLikeInt,
+ ArrayLikeUInt,
+ NDArrayBool,
+ NDArrayUI8,
+)
+from vis4d.vis.base import Visualizer
+from vis4d.vis.image.canvas import CanvasBackend, PillowCanvasBackend
+from vis4d.vis.image.util import preprocess_image, preprocess_masks
+from vis4d.vis.image.viewer import ImageViewerBackend, MatplotlibImageViewer
+from vis4d.vis.util import generate_color_map
+
+
+@dataclass
+class SegMask2D:
+ """Dataclass storing mask information."""
+
+ mask: NDArrayBool
+ color: tuple[int, int, int]
+
+
+@dataclass
+class ImageWithSegMask:
+ """Dataclass storing a data sample that can be visualized."""
+
+ image: NDArrayUI8
+ image_name: str
+ masks: list[SegMask2D]
+
+
+class SegMaskVisualizer(Visualizer):
+ """Segmentation mask visualizer class."""
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ n_colors: int = 50,
+ class_id_mapping: dict[int, str] | None = None,
+ file_type: str = "png",
+ color_palette: list[tuple[int, int, int]] | None = None,
+ canvas: CanvasBackend = PillowCanvasBackend(),
+ viewer: ImageViewerBackend = MatplotlibImageViewer(),
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates a new Visualizer for Image and Bounding Boxes.
+
+ Args:
+ n_colors (int): How many colors should be used for the color map.
+ class_id_mapping (dict[int, str]): Mapping from class id to
+ human readable name.
+ file_type (str): Desired file type
+ color_palette (list[tuple[int, int, int]]): Color palette for each
+ class, in RGB format (0-255). If None, a random color palette
+ with n_colors is generated automatically. Defaults to None.
+ canvas (CanvasBackend): Backend that is used to draw on images
+ viewer (ImageViewerBackend): Backend that is used show images
+ """
+ super().__init__(*args, **kwargs)
+ self._samples: list[ImageWithSegMask] = []
+ self.color_palette = (
+ generate_color_map(n_colors)
+ if color_palette is None
+ else color_palette
+ )
+ self.class_id_mapping = (
+ class_id_mapping if class_id_mapping is not None else {}
+ )
+ self.file_type = file_type
+ self.canvas = canvas
+ self.viewer = viewer
+
+ def reset(self) -> None:
+ """Reset visualizer for new round of evaluation."""
+ self._samples.clear()
+
+ def _add_masks(
+ self,
+ data_sample: ImageWithSegMask,
+ masks: ArrayLikeUInt,
+ class_ids: ArrayLikeInt | None = None,
+ ) -> None:
+ """Adds a mask to the current data sample.
+
+ Args:
+ data_sample (ImageWithSegMask): Data sample to add mask to.
+ masks (ArrayLikeUInt): Binary masks shape [N, H, W] or [H, W].
+ class_ids (NDArrayInt, optional): Class ids for each mask, with
+ shape [N]. Defaults to None.
+ """
+ if class_ids is not None:
+ assert (
+ class_ids.shape[0] == masks.shape[0] # type: ignore
+ ), "The amount of masks must match the given class count!"
+
+ for mask, color in zip(
+ *preprocess_masks(masks, class_ids, self.color_palette)
+ ):
+ data_sample.masks.append(SegMask2D(mask=mask, color=color))
+
+ def _draw_image(self, sample: ImageWithSegMask) -> NDArrayUI8:
+ """Visualizes the datasample and returns is as numpy image.
+
+ Args:
+ sample (DataSample): The data sample to visualize.
+
+ Returns:
+ NDArrayUI8: A image with the visualized data sample.
+ """
+ self.canvas.create_canvas(sample.image)
+ for mask in sample.masks:
+ self.canvas.draw_bitmap(mask.mask, mask.color)
+ return self.canvas.as_numpy_image()
+
+ def process( # pylint: disable=arguments-differ
+ self,
+ cur_iter: int,
+ images: list[ArrayLikeFloat],
+ image_names: list[str],
+ masks: list[ArrayLikeUInt],
+ class_ids: list[ArrayLikeInt] | None = None,
+ ) -> None:
+ """Processes a batch of data.
+
+ Args:
+ cur_iter (int): Current iteration.
+ images (list[ArrayLikeFloat]): Images to show.
+ image_names (list[str]): Image names.
+ masks (list[ArrayLikeUInt]): Segmentation masks to show, each
+ with shape [H, W] or [N, H, W]. If the shape is [H, W], the
+ mask is assumed to be a semantic segmentation mask with each
+ pixel being the class id. If the shape is [N, H, W], each mask
+ is assumed to be a binary mask with each pixel being either 0
+ or 1.
+ class_ids (list[ArrayLikeInt], optional): Class ids for each mask,
+ with shape [N]. If set, the masks are assumed to be binary
+ masks and the length of class_ids must match the amount of
+ masks. Defaults to None.
+ """
+ if not self._run_on_batch(cur_iter):
+ return
+
+ for idx, image in enumerate(images):
+ self.process_single_image(
+ image,
+ image_names[idx],
+ masks[idx],
+ None if class_ids is None else class_ids[idx],
+ )
+
+ def process_single_image(
+ self,
+ image: ArrayLikeFloat,
+ image_name: str,
+ masks: ArrayLikeUInt,
+ class_ids: ArrayLikeInt | None = None,
+ ) -> None:
+ """Processes a single image entry.
+
+ Args:
+ image (ArrayLikeFloat): Images to show.
+ image_name (str): Name of the image.
+ masks (ArrayLikeUInt): Binary masks to show, each with shape
+ [N, H, W] or [H, W].
+ class_ids (ArrayLikeInt, optional): Class ids for each mask, with
+ shape [N]. Defaults to None.
+ """
+ img_normalized = preprocess_image(image, mode=self.image_mode)
+ data_sample = ImageWithSegMask(img_normalized, image_name, [])
+ self._add_masks(data_sample, masks, class_ids)
+ self._samples.append(data_sample)
+
+ def show(self, cur_iter: int, blocking: bool = True) -> None:
+ """Shows the processed images in a interactive window.
+
+ Args:
+ cur_iter (int): Current iteration.
+ blocking (bool): If the visualizer should be blocking i.e. wait for
+ human input for each image
+ """
+ if not self._run_on_batch(cur_iter):
+ return
+ image_data = [self._draw_image(d) for d in self._samples]
+ self.viewer.show_images(image_data, blocking=blocking)
+
+ def save_to_disk(self, cur_iter: int, output_folder: str) -> None:
+ """Saves the visualization to disk.
+
+ Writes all processes samples to the output folder naming each image
+ ..
+
+ Args:
+ cur_iter (int): Current iteration.
+ output_folder (str): Folder where the output should be written.
+ """
+ if not self._run_on_batch(cur_iter):
+ return
+ for sample in self._samples:
+ image_name = f"{sample.image_name}.{self.file_type}"
+
+ self.canvas.create_canvas(sample.image)
+ for mask in sample.masks:
+ self.canvas.draw_bitmap(mask.mask, mask.color)
+
+ self.canvas.save_to_disk(os.path.join(output_folder, image_name))
diff --git a/vis4d/vis/image/util.py b/vis4d/vis/image/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..a66dc65b82450c878c2a2f0b36867b77378887d2
--- /dev/null
+++ b/vis4d/vis/image/util.py
@@ -0,0 +1,423 @@
+"""Utility functions for image processing operations."""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.typing import (
+ ArrayLike,
+ ArrayLikeFloat,
+ ArrayLikeInt,
+ ArrayLikeUInt,
+ NDArrayBool,
+ NDArrayF32,
+ NDArrayUI8,
+)
+from vis4d.data.const import AxisMode
+from vis4d.op.box.box3d import (
+ boxes3d_in_image,
+ boxes3d_to_corners,
+ transform_boxes3d,
+)
+from vis4d.op.geometry.projection import project_points
+from vis4d.op.geometry.transform import inverse_rigid_transform
+from vis4d.vis.util import DEFAULT_COLOR_MAPPING
+
+
+def _get_box_label(
+ category: str | None,
+ score: float | None,
+ track_id: int | None,
+) -> str:
+ """Gets a unique string representation for a box definition.
+
+ Args:
+ category (str): The category name
+ score (float): The confidence score
+ track_id (int): The track id
+
+ Returns:
+ str: Label for this box of format
+ 'class_name, track_id, score%'
+ """
+ labels = []
+
+ if category is not None:
+ labels.append(category)
+ if track_id is not None:
+ labels.append(str(track_id))
+ if score is not None:
+ labels.append(f"{score * 100:.1f}%")
+ return ", ".join(labels)
+
+
+def _to_binary_mask(
+ mask: NDArrayUI8, ignore_class: int = 255
+) -> tuple[NDArrayUI8, NDArrayUI8]:
+ """Converts a mask to binary masks.
+
+ Args:
+ mask (NDArrayUI8): The mask to convert with shape [H, W].
+ ignore_class (int): The class id to ignore. Defaults to 255.
+
+ Returns:
+ NDArrayUI8: The binary masks with shape [N, H, W].
+ NDArrayUI8: The class ids for each binary mask.
+ """
+ binary_masks = []
+ class_ids = []
+ for class_id in np.unique(mask):
+ if class_id == ignore_class:
+ continue
+ binary_masks.append(mask == class_id)
+ class_ids.append(class_id)
+ return np.stack(binary_masks, axis=0), np.array(class_ids, dtype=np.uint8)
+
+
+def preprocess_boxes(
+ boxes: ArrayLikeFloat,
+ scores: None | ArrayLikeFloat = None,
+ class_ids: None | ArrayLikeInt = None,
+ track_ids: None | ArrayLikeInt = None,
+ color_palette: list[tuple[int, int, int]] = DEFAULT_COLOR_MAPPING,
+ class_id_mapping: dict[int, str] | None = None,
+ default_color: tuple[int, int, int] = (255, 0, 0),
+ categories: None | list[str] = None,
+) -> tuple[
+ list[tuple[float, float, float, float]],
+ list[str],
+ list[tuple[int, int, int]],
+]:
+ """Preprocesses bounding boxes.
+
+ Converts the given predicted bounding boxes and class/track information
+ into lists of corners, labels and colors.
+
+ Args:
+ boxes (ArrayLikeFloat): Boxes of shape [N, 4] where N is the number of
+ boxes and the second channel consists of
+ (x1,y1,x2,y2) box coordinates.
+ scores (ArrayLikeFloat): Scores for each box shape [N]
+ class_ids (ArrayLikeInt): Class id for each box shape [N]
+ track_ids (ArrayLikeInt): Track id for each box shape [N]
+ color_palette (list[tuple[float, float, float]]): Color palette for
+ each id.
+ class_id_mapping(dict[int, str], optional): Mapping from class id
+ to color tuple (0-255).
+ default_color (tuple[int, int, int]): fallback color for boxes of no
+ class or track id is given.
+ categories (None | list[str], optional): List of categories for each
+ box.
+
+ Returns:
+ boxes_proc (list[tuple[float, float, float, float]]): List of box
+ corners.
+ labels_proc (list[str]): List of labels.
+ colors_proc (list[tuple[int, int, int]]): List of colors.
+ """
+ if class_id_mapping is None:
+ class_id_mapping = {}
+
+ boxes = array_to_numpy(boxes, n_dims=2, dtype=np.float32)
+
+ scores_np = array_to_numpy(scores, n_dims=1, dtype=np.float32)
+ class_ids_np = array_to_numpy(class_ids, n_dims=1, dtype=np.int32)
+ track_ids_np = array_to_numpy(track_ids, n_dims=1, dtype=np.int32)
+
+ boxes_proc: list[tuple[float, float, float, float]] = []
+ colors_proc: list[tuple[int, int, int]] = []
+ labels_proc: list[str] = []
+
+ # Only one box provided
+ if len(boxes.shape) == 1:
+ # unsqueeze one dimension
+ boxes = boxes.reshape(1, -1)
+
+ for idx in range(boxes.shape[0]):
+ class_id = None if class_ids_np is None else class_ids_np[idx].item()
+ score = None if scores_np is None else scores_np[idx].item()
+ track_id = None if track_ids_np is None else track_ids_np[idx].item()
+
+ if track_id is not None:
+ color = color_palette[track_id % len(color_palette)]
+ elif class_id is not None:
+ color = color_palette[class_id % len(color_palette)]
+ else:
+ color = default_color
+
+ boxes_proc.append(
+ (
+ boxes[idx][0].item(),
+ boxes[idx][1].item(),
+ boxes[idx][2].item(),
+ boxes[idx][3].item(),
+ )
+ )
+ colors_proc.append(color)
+
+ if categories is not None:
+ category = categories[idx]
+ elif class_id is not None:
+ category = class_id_mapping.get(class_id, str(class_id))
+ else:
+ category = None
+
+ labels_proc.append(_get_box_label(category, score, track_id))
+ return boxes_proc, labels_proc, colors_proc
+
+
+def preprocess_boxes3d(
+ image_hw: tuple[int, int],
+ boxes3d: ArrayLikeFloat,
+ intrinsics: ArrayLikeFloat,
+ extrinsics: ArrayLikeFloat | None = None,
+ scores: None | ArrayLikeFloat = None,
+ class_ids: None | ArrayLikeInt = None,
+ track_ids: None | ArrayLikeInt = None,
+ color_palette: list[tuple[int, int, int]] = DEFAULT_COLOR_MAPPING,
+ class_id_mapping: dict[int, str] | None = None,
+ default_color: tuple[int, int, int] = (255, 0, 0),
+ axis_mode: AxisMode = AxisMode.OPENCV,
+ categories: None | list[str] = None,
+) -> tuple[
+ list[tuple[float, float, float]],
+ list[list[tuple[float, float, float]]],
+ list[str],
+ list[tuple[int, int, int]],
+ list[int | None],
+]:
+ """Preprocesses bounding boxes.
+
+ Converts the given predicted bounding boxes and class/track information
+ into lists of centers, corners, labels, colors and track_ids.
+ """
+ if class_id_mapping is None:
+ class_id_mapping = {}
+
+ boxes3d = array_to_numpy(boxes3d, n_dims=2, dtype=np.float32)
+ intrinsics = array_to_numpy(intrinsics, n_dims=2, dtype=np.float32)
+
+ boxes3d = torch.from_numpy(boxes3d)
+ intrinsics = torch.from_numpy(intrinsics)
+
+ if axis_mode != AxisMode.OPENCV:
+ assert (
+ extrinsics is not None
+ ), "extrinsics must be provided to move boxes to camera coordiante."
+ extrinsics = array_to_numpy(extrinsics, n_dims=2, dtype=np.float32)
+ extrinsics = torch.from_numpy(extrinsics)
+ global_to_cam = inverse_rigid_transform(extrinsics)
+ boxes3d_cam = transform_boxes3d(
+ boxes3d,
+ global_to_cam,
+ source_axis_mode=AxisMode.ROS,
+ target_axis_mode=AxisMode.OPENCV,
+ )
+ else:
+ boxes3d_cam = boxes3d
+
+ corners = boxes3d_to_corners(boxes3d_cam, axis_mode=AxisMode.OPENCV)
+
+ mask = boxes3d_in_image(corners, intrinsics, image_hw)
+
+ boxes3d_np = boxes3d.numpy()
+ corners_np = corners.numpy()
+
+ scores_np = array_to_numpy(scores, n_dims=1, dtype=np.float32)
+ class_ids_np = array_to_numpy(class_ids, n_dims=1, dtype=np.int32)
+ track_ids_np = array_to_numpy(track_ids, n_dims=1, dtype=np.int32)
+
+ centers_proc: list[tuple[float, float, float]] = []
+ corners_proc: list[list[tuple[float, float, float]]] = []
+ colors_proc: list[tuple[int, int, int]] = []
+ labels_proc: list[str] = []
+ track_ids_proc: list[int | None] = []
+
+ if len(mask) == 1:
+ if not mask[0]:
+ return (
+ centers_proc,
+ corners_proc,
+ labels_proc,
+ colors_proc,
+ track_ids_proc,
+ )
+ else:
+ boxes3d_np = boxes3d_np[mask]
+ corners_np = corners_np[mask]
+ scores_np = scores_np[mask] if scores_np is not None else None
+ class_ids_np = class_ids_np[mask] if class_ids_np is not None else None
+ track_ids_np = track_ids_np[mask] if track_ids_np is not None else None
+
+ for idx in range(corners_np.shape[0]):
+ class_id = None if class_ids_np is None else class_ids_np[idx].item()
+ score = None if scores_np is None else scores_np[idx].item()
+ track_id = None if track_ids_np is None else track_ids_np[idx].item()
+
+ if track_id is not None:
+ color = color_palette[track_id % len(color_palette)]
+ elif class_id is not None:
+ color = color_palette[class_id % len(color_palette)]
+ else:
+ color = default_color
+
+ centers_proc.append(
+ (
+ boxes3d_np[idx][0].item(),
+ boxes3d_np[idx][1].item(),
+ boxes3d_np[idx][2].item(),
+ )
+ )
+ corners_proc.append([tuple(pts) for pts in corners_np[idx].tolist()])
+ colors_proc.append(color)
+
+ if categories is not None:
+ category = categories[idx]
+ elif class_id is not None:
+ category = class_id_mapping.get(class_id, str(class_id))
+ else:
+ category = None
+
+ labels_proc.append(_get_box_label(category, score, track_id))
+ track_ids_proc.append(track_id)
+ return centers_proc, corners_proc, labels_proc, colors_proc, track_ids_proc
+
+
+def preprocess_masks(
+ masks: ArrayLikeUInt,
+ class_ids: ArrayLikeInt | None = None,
+ color_mapping: list[tuple[int, int, int]] = DEFAULT_COLOR_MAPPING,
+) -> tuple[list[NDArrayBool], list[tuple[int, int, int]]]:
+ """Preprocesses predicted semantic or instance segmentation masks.
+
+ Args:
+ masks (ArrayLikeUInt): Masks of shape [H, W] or [N, H, W]. If the
+ masks are of shape [H, W], they are assumed to be semantic
+ segmentation masks, i.e. each pixel contains the class id.
+ If the masks are of shape [N, H, W], they are assumed to be
+ the binary masks of N instances.
+ class_ids (ArrayLikeInt, None): An array with class ids for each mask
+ shape [N]. If None, then the masks must be semantic segmentation
+ masks and the class ids are extracted from the masks.
+ color_mapping (list[tuple[int, int, int]]): Color mapping for
+ each class.
+
+ Returns:
+ tuple[list[masks], list[colors]]: Returns a list with all masks of
+ shape [H, W] as well as a list with the corresponding colors.
+
+ Raises:
+ ValueError: If the masks have an invalid shape.
+ """
+ masks_np = array_to_numpy(masks, n_dims=None, dtype=np.uint8)
+
+ if len(masks_np.shape) == 2:
+ masks_np, class_ids = _to_binary_mask(masks_np)
+ elif len(masks_np.shape) == 3:
+ if class_ids is not None:
+ class_ids = array_to_numpy(class_ids, n_dims=1, dtype=np.int32)
+ else:
+ raise ValueError(
+ f"Expected masks to have 2 or 3 dimensions, but got "
+ f"{len(masks_np.shape)}"
+ )
+
+ masks_binary = masks_np.astype(bool)
+ mask_list: list[NDArrayBool] = []
+ color_list: list[tuple[int, int, int]] = []
+
+ for idx in range(masks_binary.shape[0]):
+ mask = masks_binary[idx, ...]
+
+ class_id = None if class_ids is None else class_ids[idx].item()
+ if class_id is not None:
+ color = color_mapping[class_id % len(color_mapping)]
+ else:
+ color = color_mapping[idx % len(color_mapping)]
+ mask_list.append(mask)
+ color_list.append(color)
+ return mask_list, color_list
+
+
+def preprocess_image(image: ArrayLike, mode: str = "RGB") -> NDArrayUI8:
+ """Validate and convert input image.
+
+ Args:
+ image: CHW or HWC image (ArrayLike) with C = 3.
+ mode: input channel format (e.g. BGR, HSV).
+
+ Returns:
+ np.array[uint8]: Processed image_np in RGB.
+ """
+ image_np = array_to_numpy(image, n_dims=3, dtype=np.float32)
+ # Convert torch to numpy
+ assert len(image_np.shape) == 3
+ assert image_np.shape[0] == 3 or image_np.shape[-1] == 3
+
+ # Convert torch to numpy convention
+ if not image_np.shape[-1] == 3:
+ image_np = np.transpose(image_np, (1, 2, 0))
+
+ # Convert image_np to [0, 255]
+ min_val, max_val = (
+ np.min(image_np, axis=(0, 1)),
+ np.max(image_np, axis=(0, 1)),
+ )
+ image_np = image_np.astype(np.float32)
+ image_np = (image_np - min_val) / (max_val - min_val) * 255.0
+
+ if mode == "BGR":
+ image_np = image_np[..., [2, 1, 0]]
+
+ return image_np.astype(np.uint8)
+
+
+def get_intersection_point(
+ point1: tuple[float, float, float],
+ point2: tuple[float, float, float],
+ camera_near_clip: float,
+) -> tuple[float, float, float]:
+ """Get point intersecting with camera near plane on line point1 -> point2.
+
+ The line is defined by two points in camera coordinates and their depth.
+
+ Args:
+ point1 (tuple[float x 3]): First point in camera coordinates.
+ point2 (tuple[float x 3]): Second point in camera coordinates
+ camera_near_clip (float): camera_near_clip
+
+ Returns:
+ tuple[float, float, float]: The intersection point in camera
+ coordiantes.
+ """
+ c1, c2, c3 = 0, 0, camera_near_clip
+ a1, a2, a3 = 0, 0, 1
+ x1, y1, z1 = point1
+ x2, y2, z2 = point2
+
+ k_up = abs(a1 * (x1 - c1) + a2 * (y1 - c2) + a3 * (z1 - c3))
+ k_down = abs(a1 * (x1 - x2) + a2 * (y1 - y2) + a3 * (z1 - z2))
+ if k_up > k_down:
+ k = 1.0
+ else:
+ k = k_up / k_down
+
+ return ((1 - k) * x1 + k * x2, (1 - k) * y1 + k * y2, camera_near_clip)
+
+
+def project_point(
+ point: tuple[float, float, float], intrinsics: NDArrayF32
+) -> tuple[float, float]:
+ """Project single point into the image plane."""
+ projected_x, projected_y = (
+ project_points(
+ torch.from_numpy(np.array([point], dtype=np.float32)),
+ torch.from_numpy(intrinsics),
+ )
+ .squeeze(0)
+ .numpy()
+ .tolist()
+ )
+ return projected_x, projected_y
diff --git a/vis4d/vis/image/viewer/__init__.py b/vis4d/vis/image/viewer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..207225e6dace026ad9eec9aa3bf778e93fc84f4d
--- /dev/null
+++ b/vis4d/vis/image/viewer/__init__.py
@@ -0,0 +1,6 @@
+"""Viewer implementations to display images."""
+
+from .base import ImageViewerBackend
+from .matplotlib_viewer import MatplotlibImageViewer
+
+__all__ = ["ImageViewerBackend", "MatplotlibImageViewer"]
diff --git a/vis4d/vis/image/viewer/base.py b/vis4d/vis/image/viewer/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..59aed95cda1b3ac17a8e0aff9fd91ba8c2147798
--- /dev/null
+++ b/vis4d/vis/image/viewer/base.py
@@ -0,0 +1,32 @@
+"""Base class of image viewer for image based visualization."""
+
+from __future__ import annotations
+
+from vis4d.common.typing import NDArrayUI8
+
+
+class ImageViewerBackend:
+ """Abstract interface that allows to show images."""
+
+ def show_images(
+ self, images: list[NDArrayUI8], blocking: bool = True
+ ) -> None:
+ """Shows a list of images.
+
+ Args:
+ images (list[NDArrayUI8]): Images to display.
+ blocking (bool, optional): If the viewer should be blocking and
+ wait for input after each image. Defaults to True.
+ """
+ raise NotImplementedError
+
+ def save_images(
+ self, images: list[NDArrayUI8], file_paths: list[str]
+ ) -> None:
+ """Saves a list of images.
+
+ Args:
+ images (list[NDArrayUI8]): Images to save.
+ file_paths (list[str]): File paths to save the images to.
+ """
+ raise NotImplementedError
diff --git a/vis4d/vis/image/viewer/matplotlib_viewer.py b/vis4d/vis/image/viewer/matplotlib_viewer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf11ec78d8fe600c1b462a2e1d868a578f9972f4
--- /dev/null
+++ b/vis4d/vis/image/viewer/matplotlib_viewer.py
@@ -0,0 +1,42 @@
+"""Matplotlib based image viewer."""
+
+from __future__ import annotations
+
+import matplotlib.pyplot as plt
+
+from vis4d.common.typing import NDArrayUI8
+
+from .base import ImageViewerBackend
+
+
+class MatplotlibImageViewer(ImageViewerBackend):
+ """A image viewer using matplotlib.pyplot."""
+
+ def show_images(
+ self, images: list[NDArrayUI8], blocking: bool = True
+ ) -> None:
+ """Shows a list of images.
+
+ Args:
+ images (list[NDArrayUI8]): Images to display.
+ blocking (bool): If the viewer should be blocking and wait
+ for human input after each image.
+ """
+ for image in images:
+ plt.imshow(image)
+ plt.axis("off")
+ plt.show(block=blocking)
+
+ def save_images(
+ self, images: list[NDArrayUI8], file_paths: list[str]
+ ) -> None:
+ """Saves a list of images.
+
+ Args:
+ images (list[NDArrayUI8]): Images to save.
+ file_paths (list[str]): File paths to save the images to.
+ """
+ for i, image in enumerate(images):
+ plt.imshow(image)
+ plt.axis("off")
+ plt.savefig(f"{file_paths[i]}", bbox_inches="tight")
diff --git a/vis4d/vis/pointcloud/__init__.py b/vis4d/vis/pointcloud/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..147260a84af3a18416d898f4a4ee814e33f7efcf
--- /dev/null
+++ b/vis4d/vis/pointcloud/__init__.py
@@ -0,0 +1,5 @@
+"""Pointcloud Visualization Package."""
+
+from .pointcloud_visualizer import PointCloudVisualizer
+
+__all__ = ["PointCloudVisualizer"]
diff --git a/vis4d/vis/pointcloud/functional.py b/vis4d/vis/pointcloud/functional.py
new file mode 100644
index 0000000000000000000000000000000000000000..fec7fd83af77f826c1adb94749499acc152df952
--- /dev/null
+++ b/vis4d/vis/pointcloud/functional.py
@@ -0,0 +1,87 @@
+"""Function interface for point cloud visualization functions."""
+
+from __future__ import annotations
+
+from vis4d.common.typing import ArrayLikeFloat, ArrayLikeInt
+
+from ..util import DEFAULT_COLOR_MAPPING
+from .scene import Scene3D
+from .viewer import Open3DVisualizationBackend, PointCloudVisualizerBackend
+
+
+def show_3d(
+ scene: Scene3D,
+ viewer: PointCloudVisualizerBackend = Open3DVisualizationBackend(
+ class_color_mapping=DEFAULT_COLOR_MAPPING
+ ),
+) -> None:
+ """Shows a given 3D scene.
+
+ This method shows a 3D visualization of a given 3D scene. Use the viewer
+ attribute to use different visualization backends (e.g. open3d)
+
+ Args:
+ scene (Scene3D): The 3D scene that should be visualized.
+ viewer (PointCloudVisualizerBackend, optional): The Visualization
+ backend that should be used to visualize the scene.
+ Defaults to Open3DVisualizationBackend.
+ """
+ viewer.add_scene(scene)
+ viewer.show()
+ viewer.reset()
+
+
+def draw_points(
+ points_xyz: ArrayLikeFloat,
+ colors: ArrayLikeFloat | None = None,
+ classes: ArrayLikeInt | None = None,
+ instances: ArrayLikeInt | None = None,
+ transform: ArrayLikeFloat | None = None,
+ scene: Scene3D | None = None,
+) -> Scene3D:
+ """Adds pointcloud data to a 3D scene for visualization purposes.
+
+ Args:
+ points_xyz: xyz coordinates of the points shape [N, 3]
+ classes: semantic ids of the points shape [N, 1]
+ instances: instance ids of the points shape [N, 1]
+ colors: colors of the points shape [N,3] and ranging from [0,1]
+ transform: Optional 4x4 SE3 transform that transforms the point data
+ into a static reference frame.
+ scene (Scene3D | None): Visualizer that should be used to display the
+ data.
+ """
+ if scene is None:
+ scene = Scene3D()
+
+ return scene.add_pointcloud(
+ points_xyz, colors, classes, instances, transform
+ )
+
+
+def show_points(
+ points_xyz: ArrayLikeFloat,
+ colors: ArrayLikeFloat | None = None,
+ classes: ArrayLikeInt | None = None,
+ instances: ArrayLikeInt | None = None,
+ transform: ArrayLikeFloat | None = None,
+ viewer: PointCloudVisualizerBackend = Open3DVisualizationBackend(
+ class_color_mapping=DEFAULT_COLOR_MAPPING
+ ),
+) -> None:
+ """Visualizes a pointcloud with color and semantic information.
+
+ Args:
+ points_xyz: xyz coordinates of the points shape [N, 3]
+ classes: semantic ids of the points shape [N, 1]
+ instances: instance ids of the points shape [N, 1]
+ colors: colors of the points shape [N,3] and ranging from [0,1]
+ transform: Optional 4x4 SE3 transform that transforms the point data
+ into a static reference frame
+ viewer (PointCloudVisualizerBackend, optional): The Visualization
+ backend that should be used to visualize the scene.
+ Defaults to Open3DVisualizationBackend.
+ """
+ show_3d(
+ draw_points(points_xyz, colors, classes, instances, transform), viewer
+ )
diff --git a/vis4d/vis/pointcloud/pointcloud_visualizer.py b/vis4d/vis/pointcloud/pointcloud_visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae27211340c8a5409217286ad09ae48b8053132e
--- /dev/null
+++ b/vis4d/vis/pointcloud/pointcloud_visualizer.py
@@ -0,0 +1,181 @@
+"""Vis4D Visualization tools for analysis and debugging."""
+
+from __future__ import annotations
+
+from vis4d.common.imports import OPEN3D_AVAILABLE
+from vis4d.common.typing import ArgsType, NDArrayF64, NDArrayI64
+from vis4d.vis.base import Visualizer
+from vis4d.vis.pointcloud.scene import Scene3D
+from vis4d.vis.pointcloud.viewer import PointCloudVisualizerBackend
+from vis4d.vis.util import DEFAULT_COLOR_MAPPING
+
+if OPEN3D_AVAILABLE:
+ from .viewer.open3d_viewer import Open3DVisualizationBackend
+
+
+# TODO: Check typing
+class PointCloudVisualizer(Visualizer):
+ """Visualizer that visualizes pointclouds."""
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ backend: str = "open3d",
+ class_color_mapping: list[
+ tuple[int, int, int]
+ ] = DEFAULT_COLOR_MAPPING,
+ instance_color_mapping: list[
+ tuple[int, int, int]
+ ] = DEFAULT_COLOR_MAPPING,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates a new Pointcloud visualizer.
+
+ Args:
+ backend (str): Visualization backend that should be used. Choice
+ of [open3d].
+ class_color_mapping (list[tuple[int, int, int]], optional): List
+ of length n_classes that assigns each class a unique color.
+ instance_color_mapping (list[tuple[int, int, int]], optional): List
+ of length n_classes that assigns each class a unique color.
+ """
+ super().__init__(*args, **kwargs)
+ if backend == "open3d":
+ if not OPEN3D_AVAILABLE:
+ raise ValueError(
+ "You have specified the open3d backend."
+ "But open3d is not installed on this system!"
+ )
+ self.visualization_backend: PointCloudVisualizerBackend = (
+ Open3DVisualizationBackend(
+ class_color_mapping=class_color_mapping,
+ instance_color_mapping=instance_color_mapping,
+ )
+ )
+ else:
+ raise ValueError(f"Unknown Point Visualization Backend {backend}")
+
+ self.current_scene_idx: int | None = None
+ self.current_scene: Scene3D | None = None
+
+ def process_single(
+ self,
+ points_xyz: NDArrayF64,
+ semantics: NDArrayI64 | None = None,
+ instances: NDArrayI64 | None = None,
+ colors: NDArrayF64 | None = None,
+ scene_index: NDArrayI64 | int | None = None,
+ ) -> None:
+ """Processes data and adds it to the visualizer.
+
+ Args:
+ points_xyz: xyz coordinates of the points shape [B, N, 3]
+ semantics: semantic ids of the points shape [B, N, 1]
+ instances: instance ids of the points shape [B, N, 1]
+ colors: colors of the points shape [B, N,3] and ranging from [0,1]
+ scene_index: Scene index for visualization of shape [B, 1].
+ This allows to plot multiple predictions in the same scene
+ if e.g. for memory reasons it had to be split up in multiple
+ channels..
+
+ Raises:
+ ValueError: If shapes of the arrays missmatch.
+ """
+ # Load correct scene
+ if scene_index is None:
+ # No scene index given. Create new scene for each call
+ self.current_scene = self.visualization_backend.create_new_scene()
+ else:
+ # Scene index given, check if we should update given scene
+ # or create a new one
+ new_scene_idx = (
+ scene_index
+ if isinstance(scene_index, int)
+ else scene_index.item()
+ )
+ if (
+ self.current_scene_idx is None
+ or self.current_scene_idx != new_scene_idx
+ ):
+ self.current_scene = (
+ self.visualization_backend.create_new_scene()
+ )
+ self.current_scene_idx = new_scene_idx
+
+ if self.current_scene is None:
+ self.current_scene = self.visualization_backend.create_new_scene()
+
+ # Add data to scene
+ self.current_scene.add_pointcloud(
+ points_xyz, colors=colors, classes=semantics, instances=instances
+ )
+
+ def process( # pylint: disable=arguments-differ
+ self,
+ cur_iter: int,
+ points_xyz: NDArrayF64,
+ semantics: NDArrayI64 | None = None,
+ instances: NDArrayI64 | None = None,
+ colors: NDArrayF64 | None = None,
+ scene_index: NDArrayI64 | None = None,
+ ) -> None:
+ """Processes a batch of data and adds it to the visualizer.
+
+ Args:
+ cur_iter: Current iteration.
+ points_xyz: xyz coordinates of the points shape [N, 3]
+ semantics: semantic ids of the points shape [N, 1]
+ instances: instance ids of the points shape [N, 1]
+ colors: colors of the points shape [N,3] and ranging from [0,1]
+ scene_index: Scene index for visualization of sape [1] or int.
+ This allows to plot multiple predictions in the same scene
+ if e.g. for memory reasons it had to be split up in multiple
+ chunls.
+
+ Raises:
+ ValueError: If shapes of the arrays missmatch.
+ """
+ if self._run_on_batch(cur_iter):
+ if len(points_xyz.shape) == 2: # Data is not batched
+ self.process_single(
+ points_xyz, semantics, instances, colors, scene_index
+ )
+ elif len(points_xyz.shape) == 3:
+ for idx in range(points_xyz.shape[0]):
+ self.process_single(
+ points_xyz[idx, ...],
+ semantics[idx, ...] if semantics is not None else None,
+ instances[idx, ...] if instances is not None else None,
+ colors[idx, ...] if colors is not None else None,
+ (
+ scene_index[idx, ...]
+ if scene_index is not None
+ else None
+ ),
+ )
+
+ else:
+ raise ValueError(
+ f"Invalid shape for point data: {points_xyz.shape}"
+ )
+
+ def show(self, cur_iter: int, blocking: bool = True) -> None:
+ """Shows the visualization.
+
+ Args:
+ cur_iter (int): Current iteration.
+ blocking (bool): If the visualization should be blocking and wait
+ for human input
+ """
+ self.visualization_backend.show(blocking)
+
+ def reset(self) -> None:
+ """Clears all saved data."""
+ self.visualization_backend.reset()
+ self.current_scene_idx = None
+ self.current_scene = None
+
+ def save_to_disk(self, cur_iter: int, output_folder: str) -> None:
+ """Saves the visualization to disk."""
+ if self._run_on_batch(cur_iter):
+ self.visualization_backend.save_to_disk(output_folder)
diff --git a/vis4d/vis/pointcloud/scene.py b/vis4d/vis/pointcloud/scene.py
new file mode 100644
index 0000000000000000000000000000000000000000..963a99f7cc200e9ecf9aecdd5849b8ae2ab5d43c
--- /dev/null
+++ b/vis4d/vis/pointcloud/scene.py
@@ -0,0 +1,279 @@
+"""Data structures to store 3D data."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import numpy as np
+
+from vis4d.common.array import array_to_numpy
+from vis4d.common.typing import ArrayLike, NDArrayFloat, NDArrayInt
+
+
+@dataclass
+class BoundingBoxData:
+ """Stores bounding box data for visualization.
+
+ Attributes:
+ corners (NDArrayFloat): Corners of the bounding box shape [8, 3].
+ color (NDArrayFloat): Colors of the bounding box shape [3].
+ class (int | None): Class id of the bounding box. Defaults to None.
+ instance (int | None): Instance id of the bounding box.
+ Defaults to None.
+ score (float | None): Score of the bounding box. Defaults to None.
+ """
+
+ corners: NDArrayFloat
+ color: NDArrayFloat | None
+ class_: int | None
+ instance: int | None
+ score: float | None
+
+ def transform(self, transform: NDArrayFloat) -> BoundingBoxData:
+ """Transforms the bounding box.
+
+ Args:
+ transform (NDArrayFloat): Transformation matrix shape [4,4] that
+ transforms points from the current local frame to a fixed
+ global frame.
+
+ Returns:
+ BoundingBoxData: Returns a new bounding box with the transformed
+ points.
+ """
+ assert transform.shape == (
+ 4,
+ 4,
+ ), "Shape of the provided transform not valid."
+ return BoundingBoxData(
+ (transform[:3, :3] @ self.corners.T).T + transform[:3, -1],
+ self.color,
+ self.class_,
+ self.instance,
+ self.score,
+ )
+
+
+@dataclass
+class PointcloudData:
+ """Stores pointcloud data for visualization.
+
+ Attributes:
+ xyz: Point Coordinates shape [n_pts,3].
+ colors: Point Colors shape [n_pts, 3] or None.
+ classes: Class ids shape [n_pts] or None.
+ instances: Instance ids shape [n_pts] or None.
+ num_points: Total number of points.
+ num_classes: Total number of classes.
+ num_instances: Total number of unique class, instance combinations.
+ """
+
+ xyz: NDArrayFloat
+ colors: NDArrayFloat | None
+ classes: NDArrayInt | None
+ instances: NDArrayInt | None
+
+ num_points: int
+ num_classes: int
+ num_instances: int
+
+ def __init__(
+ self,
+ xyz: ArrayLike,
+ colors: ArrayLike | None = None,
+ classes: ArrayLike | None = None,
+ instances: ArrayLike | None = None,
+ ) -> None:
+ """Creates a new pointcloud.
+
+ Args:
+ xyz (ArrayLike): Coordinates for each point shape [n_pts, 3]
+ colors (ArrayLike | None, optional): Colors for each point encoded
+ as rgb [n_pts, 3] in the range (0,255). Defaults to None.
+ classes (ArrayLike | None, optional): Class id for each point
+ shape [n_pts]. Defaults to None.
+ instances (ArrayLike | None, optional): Instance id for each point.
+ shape [n_pts]. Defaults to None.
+ """
+ self.xyz = array_to_numpy(xyz, n_dims=2, dtype=np.float32)
+ self.colors = array_to_numpy(colors, n_dims=2, dtype=np.float32)
+ self.classes = array_to_numpy(classes, n_dims=1, dtype=np.int32)
+ self.instances = array_to_numpy(instances, n_dims=1, dtype=np.int32)
+
+ # Assing other properties. Number points, ...
+ self.num_points = self.xyz.shape[0]
+
+ if self.classes is not None:
+ self.num_classes = len(np.unique(self.classes))
+
+ if self.instances is not None:
+ if self.classes is None:
+ self.num_instances = len(np.unique(self.instances))
+ else:
+ self.num_instances = len(
+ np.unique(
+ self.classes * np.max(self.instances) + self.instances
+ )
+ )
+
+ def transform(self, transform: NDArrayFloat) -> PointcloudData:
+ """Transforms the pointcloud.
+
+ Args:
+ transform (NDArrayFloat): Transformation matrix shape [4,4] that
+ transforms points from the current local frame to a fixed
+ global frame.
+
+ Returns:
+ PointcloudData: Returns a new pointcloud with the transformed
+ points.
+ """
+ assert transform.shape == (
+ 4,
+ 4,
+ ), "Shape of the provided transform not valid."
+ return PointcloudData(
+ (transform[:3, :3] @ self.xyz.T).T + transform[:3, -1],
+ self.colors,
+ self.classes,
+ self.instances,
+ )
+
+
+class Scene3D:
+ """Stores the data for a 3D scene.
+
+ This Scene3D object can be used to be visualized by any 3D viewer.
+
+ Attributes:
+ pointclouds (list[PointcloudData]): Stores all pointclouds that
+ have been registered for this scene so far.
+ pointclouds (list[NDArrayFloat]): Stores a transformation matrix
+ (SE3, shape (4,4)) for each pointcloud.
+ """
+
+ def __init__(self) -> None:
+ """Creates a new, empty scene."""
+ self._pointclouds: list[tuple[PointcloudData, NDArrayFloat]] = []
+ self._bounding_boxes: list[tuple[BoundingBoxData, NDArrayFloat]] = []
+
+ @staticmethod
+ def _parse_se3_transform(transform: ArrayLike | None) -> NDArrayFloat:
+ """Parses a SE3 transformation matrix.
+
+ Args:
+ transform (ArrayLike | None): Transformation matrix shape [4,4]
+ that transforms points from the current local frame to a fixed
+ global frame.
+
+ Returns:
+ NDArrayFloat: Returns a valid SE3 transformation matrix.
+ """
+ tf = array_to_numpy(transform, n_dims=2, dtype=np.float32)
+
+ if tf is None:
+ return np.eye(4)
+
+ assert tf.shape == (
+ 4,
+ 4,
+ ), "Shape of the provided transform not valid."
+ return tf
+
+ def add_bounding_box(
+ self,
+ corners: ArrayLike,
+ color: ArrayLike | None,
+ class_: int | None,
+ instance: int | None,
+ score: float | None,
+ transform: ArrayLike | None = None,
+ ) -> Scene3D:
+ """Adds a bounding box to the 3D Scene.
+
+ Args:
+ corners (ArrayLike): Corners of the bounding box shape [8, 3].
+ color (ArrayLike | None): Color of the bounding box shape [3].
+ class_ (int | None): Class id of the bounding box.
+ Defaults to None.
+ instance (int | None): Instance id of the bounding box.
+ Defaults to None.
+ score (float | None): Score of the bounding box. Defaults to None.
+ transform (ArrayLike | None): Transformation matrix shape [4,4]
+ that transforms points from the current local frame to a fixed
+ global frame.
+
+ Returns:
+ Scene3D: Returns 'self' to chain calls.
+ """
+ corners_np = array_to_numpy(corners, n_dims=2, dtype=np.float32)
+ colors_np = array_to_numpy(color, n_dims=1, dtype=np.float32)
+ self._bounding_boxes.append(
+ (
+ BoundingBoxData(
+ corners_np,
+ colors_np,
+ class_,
+ instance,
+ score,
+ ),
+ self._parse_se3_transform(transform),
+ ),
+ )
+ return self
+
+ def add_pointcloud(
+ self,
+ xyz: ArrayLike,
+ colors: ArrayLike | None = None,
+ classes: ArrayLike | None = None,
+ instances: ArrayLike | None = None,
+ transform: ArrayLike | None = None,
+ ) -> Scene3D:
+ """Adds a pointcloud to the 3D Scene.
+
+ Args:
+ xyz (ArrayLike): Coordinates for each point shape [n_pts, 3] in the
+ current local frame.
+ colors (ArrayLike | None, optional): Colors for each point encoded
+ as rgb [n_pts, 3] in the range (0,255) or (0,1).
+ Defaults to None.
+ classes (ArrayLike | None, optional): Class id for each point
+ shape [n_pts]. Defaults to None.
+ instances (ArrayLike | None, optional): Instance id for each point.
+ shape [n_pts]. Defaults to None.
+ transform (ArrayLike | None, optional): Transformation matrix
+ shape [4,4] that transforms points from the current local frame
+ to a fixed global frame. Defaults to None which is the identity
+ matrix.
+
+ Returns:
+ Scene3D: Returns 'self' to chain calls.
+ """
+ self._pointclouds.append(
+ (
+ PointcloudData(xyz, colors, classes, instances),
+ self._parse_se3_transform(transform),
+ )
+ )
+ return self
+
+ @property
+ def bounding_boxes(self) -> list[BoundingBoxData]:
+ """Returns all bounding boxes in the scene.
+
+ Returns:
+ list[BoundingBoxData]: List of all bounding boxes in the scene.
+ """
+ return [bbox.transform(tf) for (bbox, tf) in self._bounding_boxes]
+
+ @property
+ def points(self) -> list[PointcloudData]:
+ """Returns all points of all pointclouds in the scene.
+
+ Returns:
+ List[PointcloudData]: Data information for all points in the scene.
+ Providing information about the points, colors, classes and
+ instances.
+ """
+ return [pc.transform(tf) for (pc, tf) in self._pointclouds]
diff --git a/vis4d/vis/pointcloud/viewer/__init__.py b/vis4d/vis/pointcloud/viewer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2d6326dea18ed46acb6be4c7587330c609e4235
--- /dev/null
+++ b/vis4d/vis/pointcloud/viewer/__init__.py
@@ -0,0 +1,6 @@
+"""Viewer implementations to display pointcloud."""
+
+from .base import PointCloudVisualizerBackend
+from .open3d_viewer import Open3DVisualizationBackend
+
+__all__ = ["PointCloudVisualizerBackend", "Open3DVisualizationBackend"]
diff --git a/vis4d/vis/pointcloud/viewer/base.py b/vis4d/vis/pointcloud/viewer/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bce996c7d255de2793f585a2e8fc84e442fa476
--- /dev/null
+++ b/vis4d/vis/pointcloud/viewer/base.py
@@ -0,0 +1,86 @@
+"""Generic classes to visualize and save pointcloud data."""
+
+from __future__ import annotations
+
+import numpy as np
+
+from ..scene import Scene3D
+
+
+class PointCloudVisualizerBackend:
+ """Visualization Backen Interface for Pointclouds."""
+
+ def __init__(
+ self,
+ class_color_mapping: list[tuple[int, int, int]],
+ instance_color_mapping: list[tuple[int, int, int]] | None = None,
+ ) -> None:
+ """Creates a new Open3D visualization backend.
+
+ Args:
+ class_color_mapping (list[tuple[int, int ,int]]): List of length
+ n_classes that maps each class index to a unique color.
+ instance_color_mapping (list[tuple[int, int ,int]], optional): List
+ of length n_instances that maps each instance id to a unique
+ color. Defaults to None.
+ """
+ self.scenes: list[Scene3D] = []
+
+ self.class_color_mapping = np.asarray(class_color_mapping)
+
+ if np.any(self.class_color_mapping > 1): # Color mapping from [0, 255]
+ self.class_color_mapping = self.class_color_mapping / 255
+
+ if instance_color_mapping is None:
+ self.instance_color_mapping = self.class_color_mapping
+ else:
+ self.instance_color_mapping = np.asarray(instance_color_mapping)
+ if np.any(self.instance_color_mapping > 1):
+ self.instance_color_mapping = self.instance_color_mapping / 255
+
+ def create_new_scene(self) -> Scene3D:
+ """Creates a new empty scene."""
+ self.scenes.append(Scene3D())
+ return self.get_current_scene()
+
+ def get_current_scene(self) -> Scene3D:
+ """Returns the currently active scene.
+
+ If no scene is available, an new empty one is created.
+
+ Returns:
+ Scene3D: current pointcloud scene
+ """
+ if (len(self.scenes)) == 0:
+ return self.create_new_scene()
+
+ return self.scenes[-1]
+
+ def show(self, blocking: bool = True) -> None:
+ """Shows the visualization.
+
+ Args:
+ blocking (bool): If the visualization should be blocking
+ and wait for human input
+ """
+ raise NotImplementedError()
+
+ def reset(self) -> None:
+ """Clears all stored data."""
+ self.scenes = []
+
+ def add_scene(self, scene: Scene3D) -> None:
+ """Adds a given Scene3D to the visualization.
+
+ Args:
+ scene (Scene3D): 3D scene that should be added.
+ """
+ self.scenes.append(scene)
+
+ def save_to_disk(self, path_to_out_folder: str) -> None:
+ """Saves the visualization to disk.
+
+ Args:
+ path_to_out_folder (str): Path to output folder
+ """
+ raise NotImplementedError()
diff --git a/vis4d/vis/pointcloud/viewer/open3d_viewer.py b/vis4d/vis/pointcloud/viewer/open3d_viewer.py
new file mode 100644
index 0000000000000000000000000000000000000000..942878b26ee34ddeee70d99b01de89085a32057f
--- /dev/null
+++ b/vis4d/vis/pointcloud/viewer/open3d_viewer.py
@@ -0,0 +1,184 @@
+"""Open3d visualization backend."""
+
+from __future__ import annotations
+
+import os
+from typing import TypedDict
+
+import numpy as np
+
+from vis4d.common.imports import OPEN3D_AVAILABLE
+from vis4d.common.typing import NDArrayF64
+from vis4d.vis.pointcloud.scene import Scene3D
+
+from .base import PointCloudVisualizerBackend
+
+if OPEN3D_AVAILABLE:
+ import open3d as o3d
+
+
+class PointcloudVisEntry(TypedDict):
+ """Entry for a pointcloud to visualize with open3d.
+
+ Only used for typing.
+ """
+
+ name: str
+ geometry: o3d.geometry.PointCloud
+
+
+class Open3DVisualizationBackend(PointCloudVisualizerBackend):
+ """Backend that uses open3d to visualize potincloud data."""
+
+ def __init__(
+ self,
+ class_color_mapping: list[tuple[int, int, int]],
+ instance_color_mapping: list[tuple[int, int, int]] | None = None,
+ ) -> None:
+ """Creates a new Open3D visualization backend.
+
+ Args:
+ color_mapping (NDArrayF64): array of size [n_classes, 3] that maps
+ each class index to a unique color.
+ class_color_mapping (list[tuple[int, int, int]]): List of length
+ n_classes that assigns each class a unique color.
+ instance_color_mapping (list[tuple[int, int, int]], optional): List
+ of length n_classes that maps each instance id to unqiue color.
+ Defaults to None.
+ """
+ super().__init__(
+ class_color_mapping=class_color_mapping,
+ instance_color_mapping=instance_color_mapping,
+ )
+
+ def save_to_disk(self, path_to_out_folder: str) -> None:
+ """Saves the visualization to disk.
+
+ Creates files [colors.ply, classes.ply, instances.ply] for each scene
+
+ Args:
+ path_to_out_folder (str): Path to output folder
+ """
+ for idx, scene in enumerate(self.scenes):
+ out_folder = os.path.join(path_to_out_folder, f"scene_{idx:03d}")
+ os.makedirs(out_folder, exist_ok=True)
+
+ for vis_pc in self._get_pc_data_for_scene(scene):
+ name = vis_pc["name"]
+ pc = vis_pc["geometry"]
+ o3d.io.write_point_cloud(
+ os.path.join(out_folder, f"{name}.ply"), pc
+ )
+ print("written", f"{name}.ply")
+
+ def show(self, blocking: bool = False) -> None:
+ """Shows the visualization.
+
+ Args:
+ blocking (bool): If the visualization should be blocking
+ and wait for human input.
+ """
+ for scene in self.scenes:
+ vis_data = []
+ vis_data += self._get_pc_data_for_scene(scene)
+
+ o3d.visualization.draw(
+ vis_data, non_blocking_and_return_uid=not blocking
+ )
+
+ def _get_pc_data_for_scene(
+ self, scene: Scene3D
+ ) -> list[PointcloudVisEntry]:
+ """Converts a given scene to a list of o3d data to visualize.
+
+ Args:
+ scene (PointcloudVisEntry): Point cloud scene to visualize
+ Returns:
+ list[dict[str, Any]]: List of o3d geometries primitives to show.
+ """
+ xyz, colors, classes, instances = [], [], [], []
+ has_classes = False
+ has_instances = False
+
+ for pc in scene.points:
+ n_pts = pc.xyz.shape[0]
+
+ xyz.append(pc.xyz)
+ colors.append(
+ pc.colors if pc.colors is not None else np.zeros((n_pts, 3))
+ )
+
+ if pc.classes is not None:
+ has_classes = True
+ col = self.class_color_mapping[
+ pc.classes.squeeze() % self.class_color_mapping.shape[0]
+ ]
+ classes.append(col)
+ else:
+ classes.append(np.zeros((n_pts, 3)))
+
+ if pc.instances is not None:
+ has_instances = True
+ col = self.instance_color_mapping[
+ pc.instances.squeeze()
+ % self.instance_color_mapping.shape[0]
+ ]
+ instances.append(col)
+ else:
+ instances.append(np.zeros((n_pts, 3)))
+
+ data: list[PointcloudVisEntry] = []
+
+ data += [
+ {
+ "name": "colors",
+ "geometry": self._create_o3d_cloud(
+ np.concatenate(xyz), np.concatenate(colors)
+ ),
+ }
+ ]
+ if has_instances:
+ data += [
+ {
+ "name": "instances",
+ "geometry": self._create_o3d_cloud(
+ np.concatenate(xyz), np.concatenate(instances)
+ ),
+ }
+ ]
+ if has_classes:
+ data += [
+ {
+ "name": "classes",
+ "geometry": self._create_o3d_cloud(
+ np.concatenate(xyz), np.concatenate(classes)
+ ),
+ }
+ ]
+
+ return data
+
+ @staticmethod
+ def _create_o3d_cloud(
+ points: NDArrayF64,
+ colors: NDArrayF64 | None = None,
+ normals: NDArrayF64 | None = None,
+ ) -> o3d.geometry.PointCloud:
+ """Creates a o3d pointcloud from poitns and colors.
+
+ Args:
+ points (NDArrayF64): xyz coordinates of the points
+ colors (NDArrayF64, optional): Colors of the points
+ normals (NDArrayF64, optional): Surface normals
+
+ Returns:
+ o3d.geometry.PointCloud: o3d pointcloud with the given attributes
+ """
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(points)
+ if colors is not None and len(colors) > 0:
+ pcd.colors = o3d.utility.Vector3dVector(colors)
+ if normals is not None and len(normals) > 0:
+ pcd.normals = o3d.utility.Vector3dVector(normals)
+
+ return pcd
diff --git a/vis4d/vis/util.py b/vis4d/vis/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..9242f0a9391b6e1b38cc4ec8ae37d24ce9543e46
--- /dev/null
+++ b/vis4d/vis/util.py
@@ -0,0 +1,33 @@
+"""Utilities for visualization."""
+
+from __future__ import annotations
+
+import colorsys
+
+import numpy as np
+
+
+def generate_color_map(length: int) -> list[tuple[int, int, int]]:
+ """Generate a color palette of [length] colors.
+
+ Args:
+ length (int): Number of colors to generate.
+
+ Returns:
+ list[tuple[int, int, int]]: List with different colors ranging
+ from [0,255].
+ """
+ brightness = 0.7
+ hsv = [(i / length, 1, brightness) for i in range(length)]
+ colors_float = [colorsys.hsv_to_rgb(*c) for c in hsv]
+ colors: list[int] = (
+ (np.array(colors_float) * 255).astype(np.uint8).tolist()
+ )
+ s = np.random.get_state()
+ np.random.seed(0)
+ result = [tuple(colors[i]) for i in np.random.permutation(len(colors))]
+ np.random.set_state(s)
+ return result
+
+
+DEFAULT_COLOR_MAPPING = generate_color_map(50)
diff --git a/vis4d/zoo/__init__.py b/vis4d/zoo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9a31eb9aacef6c1e4da8c7f43df5e7d702cd0da
--- /dev/null
+++ b/vis4d/zoo/__init__.py
@@ -0,0 +1,31 @@
+"""Model Zoo."""
+
+from __future__ import annotations
+
+from vis4d.common.typing import ArgsType
+
+from .bdd100k import AVAILABLE_MODELS as BDD100K_MODELS
+from .bevformer import AVAILABLE_MODELS as BEVFORMER_MODELS
+from .cc_3dt import AVAILABLE_MODELS as CC_3DT_MODELS
+from .faster_rcnn import AVAILABLE_MODELS as FASTER_RCNN_MODELS
+from .fcn_resnet import AVAILABLE_MODELS as FCN_RESNET_MODELS
+from .mask_rcnn import AVAILABLE_MODELS as MASK_RCNN_MODELS
+from .qdtrack import AVAILABLE_MODELS as QDTRACK_MODELS
+from .retinanet import AVAILABLE_MODELS as RETINANET_MODELS
+from .shift import AVAILABLE_MODELS as SHIFT_MODELS
+from .vit import AVAILABLE_MODELS as VIT_MODELS
+from .yolox import AVAILABLE_MODELS as YOLOX_MODELS
+
+AVAILABLE_MODELS: dict[str, dict[str, ArgsType]] = {
+ "bdd100k": BDD100K_MODELS,
+ "cc_3dt": CC_3DT_MODELS,
+ "bevformer": BEVFORMER_MODELS,
+ "faster_rcnn": FASTER_RCNN_MODELS,
+ "fcn_resnet": FCN_RESNET_MODELS,
+ "mask_rcnn": MASK_RCNN_MODELS,
+ "qdtrack": QDTRACK_MODELS,
+ "retinanet": RETINANET_MODELS,
+ "shift": SHIFT_MODELS,
+ "vit": VIT_MODELS,
+ "yolox": YOLOX_MODELS,
+}
diff --git a/vis4d/zoo/base/__init__.py b/vis4d/zoo/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b920b1c675f7cb72fdcac830cfee95700df7f75d
--- /dev/null
+++ b/vis4d/zoo/base/__init__.py
@@ -0,0 +1,18 @@
+"""Model Zoo base."""
+
+from .callable import get_callable_cfg
+from .dataloader import get_inference_dataloaders_cfg, get_train_dataloader_cfg
+from .optimizer import get_lr_scheduler_cfg, get_optimizer_cfg
+from .pl_trainer import get_default_pl_trainer_cfg
+from .runtime import get_default_callbacks_cfg, get_default_cfg
+
+__all__ = [
+ "get_callable_cfg",
+ "get_train_dataloader_cfg",
+ "get_inference_dataloaders_cfg",
+ "get_optimizer_cfg",
+ "get_lr_scheduler_cfg",
+ "get_default_cfg",
+ "get_default_callbacks_cfg",
+ "get_default_pl_trainer_cfg",
+]
diff --git a/vis4d/zoo/base/callable.py b/vis4d/zoo/base/callable.py
new file mode 100644
index 0000000000000000000000000000000000000000..a50553f109671084bc35b30dac322ce933947185
--- /dev/null
+++ b/vis4d/zoo/base/callable.py
@@ -0,0 +1,19 @@
+"""Callable objects for use in config files."""
+
+from ml_collections import ConfigDict
+
+from vis4d.common.typing import ArgsType, GenericFunc
+from vis4d.config import class_config, delay_instantiation
+
+
+def get_callable_cfg(func: GenericFunc, **kwargs: ArgsType) -> ConfigDict:
+ """Return callable config.
+
+ Args:
+ func (GenericFunc): Callable object.
+ **kwargs (ArgsType): Keyword arguments to pass to the callable.
+
+ Returns:
+ ConfigDict: Config for the callable.
+ """
+ return delay_instantiation(class_config(func, **kwargs))
diff --git a/vis4d/zoo/base/data_connectors/__init__.py b/vis4d/zoo/base/data_connectors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb607050d8bd36610def74b22b6148ff4d1801a0
--- /dev/null
+++ b/vis4d/zoo/base/data_connectors/__init__.py
@@ -0,0 +1,20 @@
+"""Base data connectors."""
+
+from .common import CONN_IMAGES_TEST, CONN_IMAGES_TRAIN
+from .detection import CONN_BBOX_2D_TEST, CONN_BBOX_2D_TRAIN, CONN_BOX_LOSS_2D
+from .visualizers import (
+ CONN_BBOX_2D_TRACK_VIS,
+ CONN_BBOX_2D_VIS,
+ CONN_INS_MASK_2D_VIS,
+)
+
+__all__ = [
+ "CONN_IMAGES_TEST",
+ "CONN_IMAGES_TRAIN",
+ "CONN_BBOX_2D_TEST",
+ "CONN_BBOX_2D_TRAIN",
+ "CONN_BOX_LOSS_2D",
+ "CONN_BBOX_2D_VIS",
+ "CONN_BBOX_2D_TRACK_VIS",
+ "CONN_INS_MASK_2D_VIS",
+]
diff --git a/vis4d/zoo/base/data_connectors/cls.py b/vis4d/zoo/base/data_connectors/cls.py
new file mode 100644
index 0000000000000000000000000000000000000000..d76edfb5d949ad9bd22e91f80ecc0cf9c9d11873
--- /dev/null
+++ b/vis4d/zoo/base/data_connectors/cls.py
@@ -0,0 +1,18 @@
+"""Data connectors for classification."""
+
+from vis4d.data.const import CommonKeys as K
+from vis4d.engine.connectors import data_key, pred_key
+
+CONN_CLS_TRAIN = {K.images: K.images}
+
+CONN_CLS_TEST = {K.images: K.images}
+
+CONN_CLS_LOSS = {
+ "input": pred_key("logits"),
+ "target": data_key("categories"),
+}
+
+CONN_CLS_EVAL = {
+ "prediction": pred_key("probs"),
+ "groundtruth": data_key("categories"),
+}
diff --git a/vis4d/zoo/base/data_connectors/common.py b/vis4d/zoo/base/data_connectors/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..5738c731a4a965c296ee8c05a6723ddee514c3b5
--- /dev/null
+++ b/vis4d/zoo/base/data_connectors/common.py
@@ -0,0 +1,11 @@
+"""Data connectors for common tasks."""
+
+from vis4d.data.const import CommonKeys as K
+
+CONN_IMAGES_TRAIN = {"images": K.images, "input_hw": K.input_hw}
+
+CONN_IMAGES_TEST = {
+ "images": K.images,
+ "input_hw": K.input_hw,
+ "original_hw": K.original_hw,
+}
diff --git a/vis4d/zoo/base/data_connectors/detection.py b/vis4d/zoo/base/data_connectors/detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ec2921da2705575797f93454b267b83ff25af57
--- /dev/null
+++ b/vis4d/zoo/base/data_connectors/detection.py
@@ -0,0 +1,24 @@
+"""Data connectors for detection."""
+
+from vis4d.data.const import CommonKeys as K
+from vis4d.engine.connectors import data_key, pred_key
+
+CONN_BBOX_2D_TRAIN = {
+ "images": K.images,
+ "input_hw": K.input_hw,
+ "boxes2d": K.boxes2d,
+ "boxes2d_classes": K.boxes2d_classes,
+}
+
+CONN_BBOX_2D_TEST = {
+ "images": K.images,
+ "input_hw": K.input_hw,
+ "original_hw": K.original_hw,
+}
+
+CONN_BOX_LOSS_2D = {
+ "cls_outs": pred_key("cls_score"),
+ "reg_outs": pred_key("bbox_pred"),
+ "target_boxes": data_key(K.boxes2d),
+ "images_hw": data_key(K.input_hw),
+}
diff --git a/vis4d/zoo/base/data_connectors/seg.py b/vis4d/zoo/base/data_connectors/seg.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6899f725856a1b6c6dad8e28f2d961bd2611f29
--- /dev/null
+++ b/vis4d/zoo/base/data_connectors/seg.py
@@ -0,0 +1,29 @@
+"""Data connectors for segmentation."""
+
+from vis4d.data.const import CommonKeys as K
+from vis4d.engine.connectors import data_key, pred_key
+
+CONN_MASKS_TRAIN = {"images": K.images}
+
+CONN_MASKS_TEST = {"images": K.images, K.original_hw: "original_hw"}
+
+CONN_SEG_LOSS = {
+ "output": pred_key("outputs"),
+ "target": data_key(K.seg_masks),
+}
+
+CONN_MULTI_SEG_LOSS = {
+ "outputs": pred_key("outputs"),
+ "target": data_key(K.seg_masks),
+}
+
+CONN_SEG_EVAL = {
+ "prediction": pred_key(K.seg_masks),
+ "groundtruth": data_key(K.seg_masks),
+}
+
+CONN_SEG_VIS = {
+ K.images: data_key(K.images),
+ "image_names": data_key(K.sample_names),
+ "masks": pred_key("masks"),
+}
diff --git a/vis4d/zoo/base/data_connectors/visualizers.py b/vis4d/zoo/base/data_connectors/visualizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b29e35af4fe996899c08c67fff1cccdb9d60a78
--- /dev/null
+++ b/vis4d/zoo/base/data_connectors/visualizers.py
@@ -0,0 +1,27 @@
+"""Default data connectors for visualizers."""
+
+from vis4d.data.const import CommonKeys as K
+from vis4d.engine.connectors import data_key, pred_key
+
+CONN_BBOX_2D_VIS = {
+ "images": data_key(K.original_images),
+ "image_names": data_key(K.sample_names),
+ "boxes": pred_key("boxes"),
+ "scores": pred_key("scores"),
+ "class_ids": pred_key("class_ids"),
+}
+
+CONN_BBOX_2D_TRACK_VIS = {
+ "images": data_key(K.original_images),
+ "image_names": data_key(K.sample_names),
+ "boxes": pred_key("boxes"),
+ "scores": pred_key("scores"),
+ "class_ids": pred_key("class_ids"),
+ "track_ids": pred_key("track_ids"),
+}
+
+CONN_INS_MASK_2D_VIS = {
+ "images": data_key(K.original_images),
+ "image_names": data_key(K.sample_names),
+ "masks": pred_key("masks.masks"),
+}
diff --git a/vis4d/zoo/base/dataloader.py b/vis4d/zoo/base/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..09c6371b46b0f17732a4d713d4dde360ab1c50e7
--- /dev/null
+++ b/vis4d/zoo/base/dataloader.py
@@ -0,0 +1,134 @@
+"""Dataloader configuration."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+from ml_collections import ConfigDict, FieldReference
+
+from vis4d.common.typing import GenericFunc
+from vis4d.config import class_config
+from vis4d.data.data_pipe import DataPipe
+from vis4d.data.loader import (
+ DEFAULT_COLLATE_KEYS,
+ build_inference_dataloaders,
+ build_train_dataloader,
+ default_collate,
+)
+from vis4d.data.transforms.to_tensor import ToTensor
+
+from .callable import get_callable_cfg
+
+
+def get_train_dataloader_cfg(
+ datasets_cfg: ConfigDict | list[ConfigDict],
+ samples_per_gpu: int | FieldReference = 1,
+ workers_per_gpu: int | FieldReference = 1,
+ batchprocess_cfg: ConfigDict | None = None,
+ collate_fn: GenericFunc = default_collate,
+ collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
+ sensors: Sequence[str] | None = None,
+ pin_memory: bool | FieldReference = True,
+ shuffle: bool | FieldReference = True,
+ aspect_ratio_grouping: bool | FieldReference = False,
+) -> ConfigDict:
+ """Creates dataloader configuration given dataset and preprocessing.
+
+ Args:
+ datasets_cfg (ConfigDict | list[ConfigDict]): The configuration
+ contains the single dataset or datasets. If it is a list,
+ it will be wrapped into a DataPipe.
+ samples_per_gpu (int | FieldReference, optional): How many samples each
+ GPU will process. Defaults to 1.
+ workers_per_gpu (int | FieldReference, optional): How many workers to
+ spawn per GPU. Defaults to 1.
+ batchprocess_cfg (ConfigDict, optional): The config that contains the
+ batch processing operations. Defaults to None. If None, ToTensor
+ will be used.
+ collate_fn (GenericFunc, optional): The collate function to use.
+ Defaults to default_collate.
+ collate_keys (Sequence[str], optional): The keys to collate. Defaults
+ to DEFAULT_COLLATE_KEYS.
+ sensors (Sequence[str], optional): The sensors to collate. Defaults to
+ None.
+ pin_memory (bool | FieldReference, optional): Whether to pin memory.
+ Defaults to True.
+ shuffle (bool | FieldReference, optional): Whether to shuffle the
+ dataset. Defaults to True.
+ aspect_ratio_grouping (bool | FieldReference, optional): Whether to
+ group the samples by aspect ratio. Defaults to False.
+
+ Returns:
+ ConfigDict: Configuration that can be instantiate as a dataloader.
+ """
+ if batchprocess_cfg is None:
+ batchprocess_cfg = class_config(ToTensor)
+
+ if isinstance(datasets_cfg, list):
+ dataset = class_config(DataPipe, datasets=datasets_cfg)
+ else:
+ dataset = datasets_cfg
+
+ return class_config(
+ build_train_dataloader,
+ dataset=dataset,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ batchprocess_fn=batchprocess_cfg,
+ collate_fn=get_callable_cfg(collate_fn),
+ collate_keys=collate_keys,
+ sensors=sensors,
+ pin_memory=pin_memory,
+ shuffle=shuffle,
+ aspect_ratio_grouping=aspect_ratio_grouping,
+ )
+
+
+def get_inference_dataloaders_cfg(
+ datasets_cfg: ConfigDict | list[ConfigDict],
+ samples_per_gpu: int | FieldReference = 1,
+ workers_per_gpu: int | FieldReference = 1,
+ video_based_inference: bool | FieldReference = False,
+ batchprocess_cfg: ConfigDict | None = None,
+ collate_fn: GenericFunc = default_collate,
+ collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
+ sensors: Sequence[str] | None = None,
+) -> ConfigDict:
+ """Creates dataloader configuration given dataset for inference.
+
+ Args:
+ datasets_cfg (ConfigDict | list[ConfigDict]): The configuration
+ contains the single dataset or datasets.
+ samples_per_gpu (int | FieldReference, optional): How many samples each
+ GPU will process per batch. Defaults to 1.
+ workers_per_gpu (int | FieldReference, optional): How many workers each
+ GPU will spawn. Defaults to 1.
+ video_based_inference (bool | FieldReference , optional): Whether to
+ split dataset by sequences. Defaults to False.
+ batchprocess_cfg (ConfigDict, optional): The config that contains the
+ batch processing operations. Defaults to None. If None, ToTensor
+ will be used.
+ collate_fn (GenericFunc, optional): The collate function that will be
+ used to stack the batch. Defaults to default_collate.
+ collate_keys (Sequence[str], optional): The keys to collate. Defaults
+ to DEFAULT_COLLATE_KEYS.
+ sensors (Sequence[str], optional): The sensors to collate. Defaults to
+ None.
+
+ Returns:
+ ConfigDict: The dataloader configuration.
+ """
+ if batchprocess_cfg is None:
+ batchprocess_cfg = class_config(ToTensor)
+
+ return class_config(
+ build_inference_dataloaders,
+ datasets=datasets_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ video_based_inference=video_based_inference,
+ batchprocess_fn=batchprocess_cfg,
+ collate_fn=get_callable_cfg(collate_fn),
+ collate_keys=collate_keys,
+ sensors=sensors,
+ )
diff --git a/vis4d/zoo/base/datasets/__init__.py b/vis4d/zoo/base/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..085c5df6ca1c7a1d779444fb412de95c703d1d2d
--- /dev/null
+++ b/vis4d/zoo/base/datasets/__init__.py
@@ -0,0 +1 @@
+"""Model Zoo base datasets."""
diff --git a/vis4d/zoo/base/datasets/bdd100k/__init__.py b/vis4d/zoo/base/datasets/bdd100k/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..66ea6eb77ddbc574d1b77853f5d6926b28d2aba4
--- /dev/null
+++ b/vis4d/zoo/base/datasets/bdd100k/__init__.py
@@ -0,0 +1,19 @@
+"""BDD100K dataset config."""
+
+from .detect import (
+ CONN_BDD100K_DET_EVAL,
+ CONN_BDD100K_INS_EVAL,
+ get_bdd100k_detection_config,
+)
+from .sem_seg import CONN_BDD100K_SEG_EVAL, get_bdd100k_sem_seg_cfg
+from .track import CONN_BDD100K_TRACK_EVAL, get_bdd100k_track_cfg
+
+__all__ = [
+ "CONN_BDD100K_DET_EVAL",
+ "CONN_BDD100K_INS_EVAL",
+ "get_bdd100k_detection_config",
+ "get_bdd100k_sem_seg_cfg",
+ "CONN_BDD100K_SEG_EVAL",
+ "get_bdd100k_track_cfg",
+ "CONN_BDD100K_TRACK_EVAL",
+]
diff --git a/vis4d/zoo/base/datasets/bdd100k/detect.py b/vis4d/zoo/base/datasets/bdd100k/detect.py
new file mode 100644
index 0000000000000000000000000000000000000000..777323029a78d139ced7c8066fec62312e8a90cc
--- /dev/null
+++ b/vis4d/zoo/base/datasets/bdd100k/detect.py
@@ -0,0 +1,243 @@
+# pylint: disable=duplicate-code
+"""BDD100K dataset config for object detection."""
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+from ml_collections import ConfigDict
+
+from vis4d.config import class_config
+from vis4d.config.typing import DataConfig
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.data_pipe import DataPipe
+from vis4d.data.datasets.bdd100k import BDD100K
+from vis4d.data.io import DataBackend
+from vis4d.data.transforms.base import RandomApply, compose
+from vis4d.data.transforms.flip import (
+ FlipBoxes2D,
+ FlipImages,
+ FlipInstanceMasks,
+)
+from vis4d.data.transforms.normalize import NormalizeImages
+from vis4d.data.transforms.pad import PadImages
+from vis4d.data.transforms.resize import (
+ GenResizeParameters,
+ ResizeBoxes2D,
+ ResizeImages,
+ ResizeInstanceMasks,
+)
+from vis4d.data.transforms.to_tensor import ToTensor
+from vis4d.engine.connectors import data_key, pred_key
+from vis4d.zoo.base import (
+ get_inference_dataloaders_cfg,
+ get_train_dataloader_cfg,
+)
+
+CONN_BDD100K_DET_EVAL = {
+ "frame_ids": data_key("frame_ids"),
+ "sample_names": data_key("sample_names"),
+ "sequence_names": data_key("sequence_names"),
+ "pred_boxes": pred_key("boxes"),
+ "pred_scores": pred_key("scores"),
+ "pred_classes": pred_key("class_ids"),
+}
+CONN_BDD100K_INS_EVAL = {
+ "frame_ids": data_key("frame_ids"),
+ "sample_names": data_key("sample_names"),
+ "sequence_names": data_key("sequence_names"),
+ "pred_boxes": pred_key("boxes.boxes"),
+ "pred_scores": pred_key("boxes.scores"),
+ "pred_classes": pred_key("boxes.class_ids"),
+ "pred_masks": pred_key("masks.masks"),
+}
+
+
+def get_train_dataloader(
+ data_root: str,
+ anno_path: str,
+ keys_to_load: Sequence[str] = (K.images, K.boxes2d),
+ ins_seg: bool = False,
+ data_backend: None | DataBackend = None,
+ image_size: tuple[int, int] = (720, 1280),
+ multi_scale: bool = False,
+ samples_per_gpu: int = 2,
+ workers_per_gpu: int = 2,
+) -> ConfigDict:
+ """Get the default train dataloader for BDD100K segmentation."""
+ # Train Dataset
+ train_dataset_cfg = class_config(
+ BDD100K,
+ data_root=data_root,
+ annotation_path=anno_path,
+ config_path="ins_seg" if ins_seg else "det",
+ keys_to_load=keys_to_load,
+ data_backend=data_backend,
+ skip_empty_samples=True,
+ )
+
+ # Train Preprocessing
+ if multi_scale:
+ ms_shapes = [(image_size[0] - 24 * i, image_size[1]) for i in range(6)]
+ preprocess_transforms = [
+ class_config(
+ GenResizeParameters,
+ shape=ms_shapes,
+ keep_ratio=True,
+ multiscale_mode="list",
+ align_long_edge=True,
+ )
+ ]
+ else:
+ preprocess_transforms = [
+ class_config(
+ GenResizeParameters,
+ shape=image_size,
+ keep_ratio=True,
+ align_long_edge=True,
+ )
+ ]
+ preprocess_transforms += [
+ class_config(ResizeImages),
+ class_config(ResizeBoxes2D),
+ ]
+ if K.instance_masks in keys_to_load:
+ preprocess_transforms.append(class_config(ResizeInstanceMasks))
+
+ flip_transforms = [class_config(FlipImages), class_config(FlipBoxes2D)]
+ if K.instance_masks in keys_to_load:
+ flip_transforms.append(class_config(FlipInstanceMasks))
+
+ preprocess_transforms.append(
+ class_config(
+ RandomApply,
+ transforms=flip_transforms,
+ probability=0.5,
+ )
+ )
+
+ preprocess_transforms.append(class_config(NormalizeImages))
+
+ train_preprocess_cfg = class_config(
+ compose, transforms=preprocess_transforms
+ )
+
+ train_batchprocess_cfg = class_config(
+ compose,
+ transforms=[class_config(PadImages), class_config(ToTensor)],
+ )
+
+ return get_train_dataloader_cfg(
+ datasets_cfg=class_config(
+ DataPipe,
+ datasets=train_dataset_cfg,
+ preprocess_fn=train_preprocess_cfg,
+ ),
+ batchprocess_cfg=train_batchprocess_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+
+def get_test_dataloader(
+ data_root: str,
+ anno_path: str,
+ keys_to_load: Sequence[str] = (K.images, K.original_images),
+ ins_seg: bool = False,
+ data_backend: None | DataBackend = None,
+ image_size: tuple[int, int] = (720, 1280),
+ samples_per_gpu: int = 1,
+ workers_per_gpu: int = 1,
+) -> ConfigDict:
+ """Get the default test dataloader for BDD100K segmentation."""
+ # Test Dataset
+ test_dataset_cfg = class_config(
+ BDD100K,
+ data_root=data_root,
+ annotation_path=anno_path,
+ config_path="ins_seg" if ins_seg else "det",
+ keys_to_load=keys_to_load,
+ data_backend=data_backend,
+ )
+
+ # Test Preprocessing
+ preprocess_transforms = [
+ class_config(
+ GenResizeParameters,
+ shape=image_size,
+ keep_ratio=True,
+ align_long_edge=True,
+ ),
+ class_config(ResizeImages),
+ ]
+
+ preprocess_transforms.append(class_config(NormalizeImages))
+
+ test_preprocess_cfg = class_config(
+ compose, transforms=preprocess_transforms
+ )
+
+ test_batchprocess_cfg = class_config(
+ compose,
+ transforms=[class_config(PadImages), class_config(ToTensor)],
+ )
+
+ # Test Dataset Config
+ test_dataset_cfg = class_config(
+ DataPipe, datasets=test_dataset_cfg, preprocess_fn=test_preprocess_cfg
+ )
+
+ return get_inference_dataloaders_cfg(
+ datasets_cfg=test_dataset_cfg,
+ batchprocess_cfg=test_batchprocess_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+
+def get_bdd100k_detection_config(
+ data_root: str = "data/bdd100k/images/100k",
+ train_split: str = "train",
+ train_keys_to_load: Sequence[str] = (K.images, K.boxes2d),
+ test_split: str = "val",
+ test_keys_to_load: Sequence[str] = (K.images, K.original_images),
+ ins_seg: bool = False,
+ data_backend: None | ConfigDict = None,
+ image_size: tuple[int, int] = (720, 1280),
+ multi_scale: bool = False,
+ samples_per_gpu: int = 2,
+ workers_per_gpu: int = 2,
+) -> DataConfig:
+ """Get the default config for BDD100K detection."""
+ data = DataConfig()
+
+ if K.instance_masks in train_keys_to_load:
+ train_anno_path = "data/bdd100k/labels/ins_seg_train_rle.json"
+ test_anno_path = "data/bdd100k/labels/ins_seg_val_rle.json"
+ else:
+ train_anno_path = "data/bdd100k/labels/det_20/det_train.json"
+ test_anno_path = "data/bdd100k/labels/det_20/det_val.json"
+
+ data.train_dataloader = get_train_dataloader(
+ data_root=f"{data_root}/{train_split}",
+ anno_path=train_anno_path,
+ keys_to_load=train_keys_to_load,
+ ins_seg=ins_seg,
+ data_backend=data_backend,
+ image_size=image_size,
+ multi_scale=multi_scale,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+ data.test_dataloader = get_test_dataloader(
+ data_root=f"{data_root}/{test_split}",
+ anno_path=test_anno_path,
+ keys_to_load=test_keys_to_load,
+ ins_seg=ins_seg,
+ data_backend=data_backend,
+ image_size=image_size,
+ samples_per_gpu=1,
+ workers_per_gpu=1,
+ )
+
+ return data
diff --git a/vis4d/zoo/base/datasets/bdd100k/sem_seg.py b/vis4d/zoo/base/datasets/bdd100k/sem_seg.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2351a5b8e4d17599eecd4cae2e57cfa6e023189
--- /dev/null
+++ b/vis4d/zoo/base/datasets/bdd100k/sem_seg.py
@@ -0,0 +1,216 @@
+# pylint: disable=duplicate-code
+"""BDD100K dataset config for semantic segmentation."""
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+from ml_collections import ConfigDict
+
+from vis4d.config import class_config
+from vis4d.config.typing import DataConfig
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.data_pipe import DataPipe
+from vis4d.data.datasets.bdd100k import BDD100K
+from vis4d.data.io import DataBackend
+from vis4d.data.transforms.base import RandomApply, compose
+from vis4d.data.transforms.crop import (
+ CropImages,
+ CropSegMasks,
+ GenCropParameters,
+)
+from vis4d.data.transforms.flip import FlipImages, FlipSegMasks
+from vis4d.data.transforms.normalize import NormalizeImages
+from vis4d.data.transforms.pad import PadImages, PadSegMasks
+from vis4d.data.transforms.photometric import ColorJitter
+from vis4d.data.transforms.resize import (
+ GenResizeParameters,
+ ResizeImages,
+ ResizeSegMasks,
+)
+from vis4d.data.transforms.to_tensor import ToTensor
+from vis4d.engine.connectors import data_key, pred_key
+from vis4d.zoo.base import (
+ get_inference_dataloaders_cfg,
+ get_train_dataloader_cfg,
+)
+
+CONN_BDD100K_SEG_EVAL = {
+ "data_names": data_key("sample_names"),
+ "masks_list": pred_key("masks"),
+}
+
+
+def get_train_dataloader(
+ data_root: str,
+ anno_path: str,
+ keys_to_load: Sequence[str] = (K.images, K.seg_masks),
+ data_backend: None | DataBackend = None,
+ image_size: tuple[int, int] = (720, 1280),
+ crop_size: tuple[int, int] = (512, 1024),
+ samples_per_gpu: int = 2,
+ workers_per_gpu: int = 2,
+) -> ConfigDict:
+ """Get the default train dataloader for BDD100K segmentation."""
+ # Train Dataset
+ train_dataset_cfg = class_config(
+ BDD100K,
+ data_root=data_root,
+ annotation_path=anno_path,
+ config_path="sem_seg",
+ keys_to_load=keys_to_load,
+ data_backend=data_backend,
+ )
+
+ # Train Preprocessing
+ preprocess_transforms = [
+ class_config(
+ GenResizeParameters,
+ shape=image_size,
+ keep_ratio=True,
+ scale_range=(0.5, 2.0),
+ ),
+ class_config(ResizeImages),
+ class_config(ResizeSegMasks),
+ ]
+
+ preprocess_transforms = [
+ class_config(GenCropParameters, shape=crop_size, cat_max_ratio=0.75),
+ class_config(CropImages),
+ class_config(CropSegMasks),
+ ]
+
+ preprocess_transforms.append(
+ class_config(
+ RandomApply,
+ transforms=[class_config(FlipImages), class_config(FlipSegMasks)],
+ probability=0.5,
+ )
+ )
+
+ preprocess_transforms.append(
+ class_config(
+ RandomApply,
+ transforms=[class_config(ColorJitter)],
+ probability=0.5,
+ )
+ )
+
+ preprocess_transforms.append(class_config(NormalizeImages))
+
+ train_preprocess_cfg = class_config(
+ compose, transforms=preprocess_transforms
+ )
+
+ train_batchprocess_cfg = class_config(
+ compose,
+ transforms=[
+ class_config(PadImages, shape=crop_size),
+ class_config(PadSegMasks, shape=crop_size),
+ class_config(ToTensor),
+ ],
+ )
+
+ return get_train_dataloader_cfg(
+ datasets_cfg=class_config(
+ DataPipe,
+ datasets=train_dataset_cfg,
+ preprocess_fn=train_preprocess_cfg,
+ ),
+ batchprocess_cfg=train_batchprocess_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+
+def get_test_dataloader(
+ data_root: str,
+ anno_path: str,
+ keys_to_load: Sequence[str] = (K.images, K.seg_masks),
+ data_backend: None | DataBackend = None,
+ image_size: tuple[int, int] = (720, 1280),
+ samples_per_gpu: int = 1,
+ workers_per_gpu: int = 1,
+) -> ConfigDict:
+ """Get the default test dataloader for BDD100K segmentation."""
+ # Test Dataset
+ test_dataset_cfg = class_config(
+ BDD100K,
+ data_root=data_root,
+ annotation_path=anno_path,
+ config_path="sem_seg",
+ keys_to_load=keys_to_load,
+ data_backend=data_backend,
+ )
+
+ # Test Preprocessing
+ preprocess_transforms = [
+ class_config(GenResizeParameters, shape=image_size, keep_ratio=True),
+ class_config(ResizeImages),
+ class_config(ResizeSegMasks),
+ ]
+
+ preprocess_transforms.append(class_config(NormalizeImages))
+
+ test_preprocess_cfg = class_config(
+ compose, transforms=preprocess_transforms
+ )
+
+ test_batchprocess_cfg = class_config(
+ compose,
+ transforms=[
+ class_config(PadImages, shape=image_size),
+ class_config(PadSegMasks, shape=image_size),
+ class_config(ToTensor),
+ ],
+ )
+
+ # Test Dataset Config
+ test_dataset_cfg = class_config(
+ DataPipe, datasets=test_dataset_cfg, preprocess_fn=test_preprocess_cfg
+ )
+
+ return get_inference_dataloaders_cfg(
+ datasets_cfg=test_dataset_cfg,
+ batchprocess_cfg=test_batchprocess_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+
+def get_bdd100k_sem_seg_cfg(
+ data_root: str = "data/bdd100k/images/10k",
+ train_split: str = "train",
+ train_keys_to_load: Sequence[str] = (K.images, K.seg_masks),
+ test_split: str = "val",
+ test_keys_to_load: Sequence[str] = (K.images, K.seg_masks),
+ data_backend: None | ConfigDict = None,
+ image_size: tuple[int, int] = (720, 1280),
+ crop_size: tuple[int, int] = (512, 1024),
+ samples_per_gpu: int = 2,
+ workers_per_gpu: int = 2,
+) -> DataConfig:
+ """Get the default config for BDD100K semantic segmentation."""
+ data = DataConfig()
+
+ data.train_dataloader = get_train_dataloader(
+ data_root=f"{data_root}/{train_split}",
+ anno_path=f"data/bdd100k/labels/sem_seg_{train_split}_rle.json",
+ keys_to_load=train_keys_to_load,
+ data_backend=data_backend,
+ image_size=image_size,
+ crop_size=crop_size,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+ data.test_dataloader = get_test_dataloader(
+ data_root=f"{data_root}/{test_split}",
+ anno_path=f"data/bdd100k/labels/sem_seg_{test_split}_rle.json",
+ keys_to_load=test_keys_to_load,
+ data_backend=data_backend,
+ image_size=image_size,
+ samples_per_gpu=1,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+ return data
diff --git a/vis4d/zoo/base/datasets/bdd100k/track.py b/vis4d/zoo/base/datasets/bdd100k/track.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed3bc2b5228c9eff6e83c008ec7f31f03b058f48
--- /dev/null
+++ b/vis4d/zoo/base/datasets/bdd100k/track.py
@@ -0,0 +1,212 @@
+"""BDD100K tracking dataset configs."""
+
+from __future__ import annotations
+
+from ml_collections import ConfigDict
+
+from vis4d.config import class_config
+from vis4d.config.typing import DataConfig
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.data_pipe import DataPipe
+from vis4d.data.datasets.bdd100k import BDD100K, bdd100k_track_map
+from vis4d.data.reference import MultiViewDataset, UniformViewSampler
+from vis4d.data.transforms import RandomApply, compose
+from vis4d.data.transforms.flip import FlipBoxes2D, FlipImages
+from vis4d.data.transforms.normalize import NormalizeImages
+from vis4d.data.transforms.pad import PadImages
+from vis4d.data.transforms.post_process import PostProcessBoxes2D
+from vis4d.data.transforms.resize import (
+ GenResizeParameters,
+ ResizeBoxes2D,
+ ResizeImages,
+)
+from vis4d.data.transforms.to_tensor import ToTensor
+from vis4d.engine.connectors import data_key, pred_key
+from vis4d.zoo.base import (
+ get_inference_dataloaders_cfg,
+ get_train_dataloader_cfg,
+)
+
+CONN_BDD100K_TRACK_EVAL = {
+ "frame_ids": data_key("frame_ids"),
+ "sample_names": data_key(K.sample_names),
+ "sequence_names": data_key(K.sequence_names),
+ "pred_boxes": pred_key("boxes"),
+ "pred_classes": pred_key("class_ids"),
+ "pred_scores": pred_key("scores"),
+ "pred_track_ids": pred_key("track_ids"),
+}
+
+
+def get_train_dataloader(
+ data_backend: None | ConfigDict,
+ samples_per_gpu: int,
+ workers_per_gpu: int,
+) -> ConfigDict:
+ """Get the default train dataloader for BDD100K tracking."""
+ bdd100k_det_train = class_config(
+ BDD100K,
+ data_root="data/bdd100k/images/100k/train/",
+ keys_to_load=(K.images, K.boxes2d),
+ annotation_path="data/bdd100k/labels/det_20/det_train.json",
+ config_path="det",
+ data_backend=data_backend,
+ category_map=bdd100k_track_map,
+ skip_empty_samples=True,
+ cache_as_binary=True,
+ cached_file_path="data/bdd100k/det_train.pkl",
+ )
+
+ bdd100k_track_train = class_config(
+ BDD100K,
+ data_root="data/bdd100k/images/track/train/",
+ keys_to_load=(K.images, K.boxes2d),
+ annotation_path="data/bdd100k/labels/box_track_20/train/",
+ config_path="box_track",
+ data_backend=data_backend,
+ category_map=bdd100k_track_map,
+ skip_empty_samples=True,
+ cache_as_binary=True,
+ cached_file_path="data/bdd100k/track_train.pkl",
+ )
+
+ train_dataset_cfg = [
+ class_config(
+ MultiViewDataset,
+ dataset=bdd100k_det_train,
+ sampler=class_config(
+ UniformViewSampler, scope=0, num_ref_samples=1
+ ),
+ ),
+ class_config(
+ MultiViewDataset,
+ dataset=bdd100k_track_train,
+ sampler=class_config(
+ UniformViewSampler, scope=3, num_ref_samples=1
+ ),
+ ),
+ ]
+
+ preprocess_transforms = [
+ class_config(
+ GenResizeParameters,
+ shape=(720, 1280),
+ keep_ratio=True,
+ ),
+ class_config(ResizeImages),
+ class_config(ResizeBoxes2D),
+ ]
+
+ preprocess_transforms.append(
+ class_config(
+ RandomApply,
+ transforms=[
+ class_config(FlipImages),
+ class_config(FlipBoxes2D),
+ ],
+ probability=0.5,
+ )
+ )
+
+ preprocess_transforms.append(class_config(NormalizeImages))
+ preprocess_transforms.append(class_config(PostProcessBoxes2D))
+
+ train_preprocess_cfg = class_config(
+ compose,
+ transforms=preprocess_transforms,
+ )
+
+ train_batchprocess_cfg = class_config(
+ compose,
+ transforms=[class_config(PadImages), class_config(ToTensor)],
+ )
+
+ return get_train_dataloader_cfg(
+ datasets_cfg=class_config(
+ DataPipe,
+ datasets=train_dataset_cfg,
+ preprocess_fn=train_preprocess_cfg,
+ ),
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ batchprocess_cfg=train_batchprocess_cfg,
+ )
+
+
+def get_test_dataloader(
+ data_backend: None | ConfigDict,
+ samples_per_gpu: int,
+ workers_per_gpu: int,
+) -> ConfigDict:
+ """Get the default test dataloader for BDD100K tracking."""
+ test_dataset = class_config(
+ BDD100K,
+ data_root="data/bdd100k/images/track/val/",
+ keys_to_load=(K.images, K.original_images),
+ annotation_path="data/bdd100k/labels/box_track_20/val/",
+ config_path="box_track",
+ category_map=bdd100k_track_map,
+ data_backend=data_backend,
+ cache_as_binary=True,
+ cached_file_path="data/bdd100k/track_val.pkl",
+ )
+
+ preprocess_transforms = [
+ class_config(
+ GenResizeParameters,
+ shape=(720, 1280),
+ keep_ratio=True,
+ ),
+ class_config(ResizeImages),
+ class_config(NormalizeImages),
+ ]
+
+ test_preprocess_cfg = class_config(
+ compose,
+ transforms=preprocess_transforms,
+ )
+
+ test_batchprocess_cfg = class_config(
+ compose,
+ transforms=[
+ class_config(PadImages),
+ class_config(ToTensor),
+ ],
+ )
+
+ test_dataset_cfg = class_config(
+ DataPipe,
+ datasets=test_dataset,
+ preprocess_fn=test_preprocess_cfg,
+ )
+
+ return get_inference_dataloaders_cfg(
+ datasets_cfg=test_dataset_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ video_based_inference=True,
+ batchprocess_cfg=test_batchprocess_cfg,
+ )
+
+
+def get_bdd100k_track_cfg(
+ data_backend: None | ConfigDict = None,
+ samples_per_gpu: int = 2,
+ workers_per_gpu: int = 2,
+) -> DataConfig:
+ """Get the default config for BDD100K tracking."""
+ data = DataConfig()
+
+ data.train_dataloader = get_train_dataloader(
+ data_backend=data_backend,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+ data.test_dataloader = get_test_dataloader(
+ data_backend=data_backend,
+ samples_per_gpu=1,
+ workers_per_gpu=1,
+ )
+
+ return data
diff --git a/vis4d/zoo/base/datasets/coco/__init__.py b/vis4d/zoo/base/datasets/coco/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb7015d84df0e5b1b05b2aace1aa7393cd49bc17
--- /dev/null
+++ b/vis4d/zoo/base/datasets/coco/__init__.py
@@ -0,0 +1,15 @@
+"""COCO dataset config."""
+
+from .detection import (
+ CONN_COCO_BBOX_EVAL,
+ CONN_COCO_MASK_EVAL,
+ get_coco_detection_cfg,
+)
+from .sem_seg import get_coco_sem_seg_cfg
+
+__all__ = [
+ "get_coco_detection_cfg",
+ "CONN_COCO_BBOX_EVAL",
+ "CONN_COCO_MASK_EVAL",
+ "get_coco_sem_seg_cfg",
+]
diff --git a/vis4d/zoo/base/datasets/coco/detection.py b/vis4d/zoo/base/datasets/coco/detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..33330dbc52e6b1486e6bb18e0828ce11f7a31365
--- /dev/null
+++ b/vis4d/zoo/base/datasets/coco/detection.py
@@ -0,0 +1,246 @@
+# pylint: disable=duplicate-code
+"""COCO data loading config for object detection."""
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+from ml_collections import ConfigDict
+
+from vis4d.config import class_config
+from vis4d.config.typing import DataConfig
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.data_pipe import DataPipe
+from vis4d.data.datasets.coco import COCO
+from vis4d.data.io import DataBackend
+from vis4d.data.transforms.base import RandomApply, compose
+from vis4d.data.transforms.flip import (
+ FlipBoxes2D,
+ FlipImages,
+ FlipInstanceMasks,
+)
+from vis4d.data.transforms.normalize import NormalizeImages
+from vis4d.data.transforms.pad import PadImages
+from vis4d.data.transforms.resize import (
+ GenResizeParameters,
+ ResizeBoxes2D,
+ ResizeImages,
+ ResizeInstanceMasks,
+)
+from vis4d.data.transforms.to_tensor import ToTensor
+from vis4d.engine.connectors import data_key, pred_key
+from vis4d.zoo.base import (
+ get_inference_dataloaders_cfg,
+ get_train_dataloader_cfg,
+)
+
+CONN_COCO_BBOX_EVAL = {
+ "coco_image_id": data_key(K.sample_names),
+ "pred_boxes": pred_key("boxes"),
+ "pred_scores": pred_key("scores"),
+ "pred_classes": pred_key("class_ids"),
+}
+
+CONN_COCO_MASK_EVAL = {
+ "coco_image_id": data_key(K.sample_names),
+ "pred_boxes": pred_key("boxes.boxes"),
+ "pred_scores": pred_key("boxes.scores"),
+ "pred_classes": pred_key("boxes.class_ids"),
+ "pred_masks": pred_key("masks"),
+}
+
+
+def get_train_dataloader(
+ data_root: str,
+ split: str,
+ keys_to_load: Sequence[str],
+ data_backend: None | DataBackend,
+ image_size: tuple[int, int],
+ samples_per_gpu: int,
+ workers_per_gpu: int,
+ cache_as_binary: bool,
+ cached_file_path: str | None = None,
+) -> ConfigDict:
+ """Get the default train dataloader for COCO detection."""
+ # Train Dataset
+ train_dataset_cfg = class_config(
+ COCO,
+ keys_to_load=keys_to_load,
+ data_root=data_root,
+ split=split,
+ remove_empty=True,
+ data_backend=data_backend,
+ cache_as_binary=cache_as_binary,
+ cached_file_path=cached_file_path,
+ )
+
+ # Train Preprocessing
+ preprocess_transforms = [
+ class_config(
+ GenResizeParameters,
+ shape=image_size,
+ keep_ratio=True,
+ align_long_edge=True,
+ ),
+ class_config(ResizeImages),
+ class_config(ResizeBoxes2D),
+ ]
+
+ if K.instance_masks in keys_to_load:
+ preprocess_transforms.append(class_config(ResizeInstanceMasks))
+
+ flip_transforms = [class_config(FlipImages), class_config(FlipBoxes2D)]
+
+ if K.instance_masks in keys_to_load:
+ flip_transforms.append(class_config(FlipInstanceMasks))
+
+ preprocess_transforms.append(
+ class_config(
+ RandomApply,
+ transforms=flip_transforms,
+ probability=0.5,
+ )
+ )
+
+ preprocess_transforms.append(class_config(NormalizeImages))
+
+ train_preprocess_cfg = class_config(
+ compose,
+ transforms=preprocess_transforms,
+ )
+
+ train_batchprocess_cfg = class_config(
+ compose,
+ transforms=[
+ class_config(PadImages),
+ class_config(ToTensor),
+ ],
+ )
+
+ return get_train_dataloader_cfg(
+ datasets_cfg=class_config(
+ DataPipe,
+ datasets=train_dataset_cfg,
+ preprocess_fn=train_preprocess_cfg,
+ ),
+ batchprocess_cfg=train_batchprocess_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+
+def get_test_dataloader(
+ data_root: str,
+ split: str,
+ keys_to_load: Sequence[str],
+ data_backend: None | DataBackend,
+ image_size: tuple[int, int],
+ samples_per_gpu: int,
+ workers_per_gpu: int,
+ cache_as_binary: bool,
+ cached_file_path: str | None = None,
+) -> ConfigDict:
+ """Get the default test dataloader for COCO detection."""
+ # Test Dataset
+ test_dataset = class_config(
+ COCO,
+ keys_to_load=keys_to_load,
+ data_root=data_root,
+ split=split,
+ data_backend=data_backend,
+ cache_as_binary=cache_as_binary,
+ cached_file_path=cached_file_path,
+ )
+
+ # Test Preprocessing
+ preprocess_transforms = [
+ class_config(
+ GenResizeParameters,
+ shape=image_size,
+ keep_ratio=True,
+ align_long_edge=True,
+ ),
+ class_config(ResizeImages),
+ class_config(ResizeBoxes2D),
+ ]
+
+ preprocess_transforms.append(class_config(NormalizeImages))
+
+ test_preprocess_cfg = class_config(
+ compose,
+ transforms=preprocess_transforms,
+ )
+
+ test_batchprocess_cfg = class_config(
+ compose,
+ transforms=[
+ class_config(PadImages),
+ class_config(ToTensor),
+ ],
+ )
+
+ # Test Dataset Config
+ test_dataset_cfg = class_config(
+ DataPipe,
+ datasets=test_dataset,
+ preprocess_fn=test_preprocess_cfg,
+ )
+
+ return get_inference_dataloaders_cfg(
+ datasets_cfg=test_dataset_cfg,
+ batchprocess_cfg=test_batchprocess_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+
+def get_coco_detection_cfg(
+ data_root: str = "data/coco",
+ train_split: str = "train2017",
+ train_keys_to_load: Sequence[str] = (
+ K.images,
+ K.boxes2d,
+ K.boxes2d_classes,
+ ),
+ train_cached_file_path: str | None = "data/coco/train.pkl",
+ test_split: str = "val2017",
+ test_keys_to_load: Sequence[str] = (
+ K.images,
+ K.original_images,
+ K.boxes2d,
+ K.boxes2d_classes,
+ ),
+ test_cached_file_path: str | None = "data/coco/val.pkl",
+ cache_as_binary: bool = True,
+ data_backend: None | ConfigDict = None,
+ image_size: tuple[int, int] = (800, 1333),
+ samples_per_gpu: int = 2,
+ workers_per_gpu: int = 2,
+) -> DataConfig:
+ """Get the default config for COCO detection."""
+ data = DataConfig()
+
+ data.train_dataloader = get_train_dataloader(
+ data_root=data_root,
+ split=train_split,
+ keys_to_load=train_keys_to_load,
+ data_backend=data_backend,
+ image_size=image_size,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ cache_as_binary=cache_as_binary,
+ cached_file_path=train_cached_file_path,
+ )
+
+ data.test_dataloader = get_test_dataloader(
+ data_root=data_root,
+ split=test_split,
+ keys_to_load=test_keys_to_load,
+ data_backend=data_backend,
+ image_size=image_size,
+ samples_per_gpu=1,
+ workers_per_gpu=workers_per_gpu,
+ cache_as_binary=cache_as_binary,
+ cached_file_path=test_cached_file_path,
+ )
+
+ return data
diff --git a/vis4d/zoo/base/datasets/coco/sem_seg.py b/vis4d/zoo/base/datasets/coco/sem_seg.py
new file mode 100644
index 0000000000000000000000000000000000000000..36a4e489c80cd2634adc26afda1a226259542ca9
--- /dev/null
+++ b/vis4d/zoo/base/datasets/coco/sem_seg.py
@@ -0,0 +1,195 @@
+# pylint: disable=duplicate-code
+"""COCO data loading config for for semantic segmentation."""
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+from ml_collections import ConfigDict
+
+from vis4d.config import class_config
+from vis4d.config.typing import DataConfig
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.data_pipe import DataPipe
+from vis4d.data.datasets.coco import COCO
+from vis4d.data.io import DataBackend
+from vis4d.data.transforms.base import RandomApply, compose
+from vis4d.data.transforms.flip import FlipImages, FlipSegMasks
+from vis4d.data.transforms.normalize import NormalizeImages
+from vis4d.data.transforms.pad import PadImages, PadSegMasks
+from vis4d.data.transforms.photometric import ColorJitter
+from vis4d.data.transforms.resize import (
+ GenResizeParameters,
+ ResizeImages,
+ ResizeSegMasks,
+)
+from vis4d.data.transforms.to_tensor import ToTensor
+from vis4d.zoo.base import (
+ get_inference_dataloaders_cfg,
+ get_train_dataloader_cfg,
+)
+
+
+def get_train_dataloader(
+ data_root: str,
+ split: str,
+ keys_to_load: Sequence[str],
+ data_backend: None | DataBackend,
+ image_size: tuple[int, int],
+ samples_per_gpu: int,
+ workers_per_gpu: int,
+) -> ConfigDict:
+ """Get the default train dataloader for COCO detection."""
+ # Train Dataset
+ train_dataset_cfg = class_config(
+ COCO,
+ keys_to_load=keys_to_load,
+ data_root=data_root,
+ split=split,
+ remove_empty=True,
+ data_backend=data_backend,
+ )
+
+ # Train Preprocessing
+ preprocess_transforms = [
+ class_config(
+ GenResizeParameters,
+ shape=image_size,
+ keep_ratio=True,
+ scale_range=(0.5, 2.0),
+ ),
+ class_config(ResizeImages),
+ class_config(ResizeSegMasks),
+ ]
+
+ preprocess_transforms.append(
+ class_config(
+ RandomApply,
+ transforms=[class_config(FlipImages), class_config(FlipSegMasks)],
+ probability=0.5,
+ )
+ )
+
+ preprocess_transforms.append(
+ class_config(
+ RandomApply,
+ transforms=[class_config(ColorJitter)],
+ probability=0.5,
+ )
+ )
+
+ preprocess_transforms.append(class_config(NormalizeImages))
+
+ train_preprocess_cfg = class_config(
+ compose, transforms=preprocess_transforms
+ )
+
+ train_batchprocess_cfg = class_config(
+ compose,
+ transforms=[
+ class_config(PadImages),
+ class_config(PadSegMasks),
+ class_config(ToTensor),
+ ],
+ )
+
+ return get_train_dataloader_cfg(
+ datasets_cfg=class_config(
+ DataPipe,
+ datasets=train_dataset_cfg,
+ preprocess_fn=train_preprocess_cfg,
+ ),
+ batchprocess_cfg=train_batchprocess_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+
+def get_test_dataloader(
+ data_root: str,
+ split: str,
+ keys_to_load: Sequence[str],
+ data_backend: None | DataBackend,
+ image_size: tuple[int, int],
+ samples_per_gpu: int,
+ workers_per_gpu: int,
+) -> ConfigDict:
+ """Get the default test dataloader for COCO detection."""
+ # Test Dataset
+ test_dataset = class_config(
+ COCO,
+ keys_to_load=keys_to_load,
+ data_root=data_root,
+ split=split,
+ data_backend=data_backend,
+ )
+
+ # Test Preprocessing
+ preprocess_transforms = [
+ class_config(GenResizeParameters, shape=image_size, keep_ratio=True),
+ class_config(ResizeImages),
+ class_config(ResizeSegMasks),
+ ]
+
+ preprocess_transforms.append(class_config(NormalizeImages))
+
+ test_preprocess_cfg = class_config(
+ compose, transforms=preprocess_transforms
+ )
+
+ test_batchprocess_cfg = class_config(
+ compose,
+ transforms=[
+ class_config(PadImages, shape=image_size),
+ class_config(PadSegMasks, shape=image_size),
+ class_config(ToTensor),
+ ],
+ )
+
+ # Test Dataset Config
+ test_dataset_cfg = class_config(
+ DataPipe, datasets=test_dataset, preprocess_fn=test_preprocess_cfg
+ )
+
+ return get_inference_dataloaders_cfg(
+ datasets_cfg=test_dataset_cfg,
+ batchprocess_cfg=test_batchprocess_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+
+def get_coco_sem_seg_cfg(
+ data_root: str = "data/coco",
+ train_split: str = "train2017",
+ train_keys_to_load: Sequence[str] = (K.images, K.seg_masks),
+ test_split: str = "val2017",
+ test_keys_to_load: Sequence[str] = (K.images, K.seg_masks),
+ data_backend: None | ConfigDict = None,
+ image_size: tuple[int, int] = (520, 520),
+ samples_per_gpu: int = 2,
+ workers_per_gpu: int = 2,
+) -> DataConfig:
+ """Get the default config for COCO semantic segmentation."""
+ data = DataConfig()
+
+ data.train_dataloader = get_train_dataloader(
+ data_root=data_root,
+ split=train_split,
+ keys_to_load=train_keys_to_load,
+ data_backend=data_backend,
+ image_size=image_size,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+ data.test_dataloader = get_test_dataloader(
+ data_root=data_root,
+ split=test_split,
+ keys_to_load=test_keys_to_load,
+ data_backend=data_backend,
+ image_size=image_size,
+ samples_per_gpu=1,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+ return data
diff --git a/vis4d/zoo/base/datasets/imagenet.py b/vis4d/zoo/base/datasets/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f73f02ee51a613ffa5c34216045fdcde93205bf
--- /dev/null
+++ b/vis4d/zoo/base/datasets/imagenet.py
@@ -0,0 +1,217 @@
+"""ImageNet classification config."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+from ml_collections import ConfigDict
+
+from vis4d.config import class_config
+from vis4d.config.typing import DataConfig
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.data_pipe import DataPipe
+from vis4d.data.datasets.imagenet import ImageNet
+from vis4d.data.transforms.autoaugment import RandAug
+from vis4d.data.transforms.base import RandomApply, compose
+from vis4d.data.transforms.crop import (
+ CropImages,
+ GenCentralCropParameters,
+ GenRandomSizeCropParameters,
+)
+from vis4d.data.transforms.flip import FlipImages
+from vis4d.data.transforms.mixup import (
+ GenMixupParameters,
+ MixupCategories,
+ MixupImages,
+)
+from vis4d.data.transforms.normalize import NormalizeImages
+from vis4d.data.transforms.random_erasing import RandomErasing
+from vis4d.data.transforms.resize import GenResizeParameters, ResizeImages
+from vis4d.data.transforms.to_tensor import ToTensor
+from vis4d.engine.connectors import data_key, pred_key
+from vis4d.zoo.base import (
+ get_inference_dataloaders_cfg,
+ get_train_dataloader_cfg,
+)
+
+CONN_IMAGENET_CLS_EVAL = {
+ "prediction": pred_key("probs"),
+ "groundtruth": data_key("categories"),
+}
+
+
+def get_train_dataloader(
+ data_root: str,
+ split: str,
+ keys_to_load: Sequence[str],
+ image_size: tuple[int, int],
+ samples_per_gpu: int,
+ workers_per_gpu: int,
+) -> ConfigDict:
+ """Get the default train dataloader for ImageNet 1K dataset."""
+ # Train Dataset
+ train_dataset_cfg = class_config(
+ ImageNet,
+ data_root=data_root,
+ split=split,
+ num_classes=1000,
+ keys_to_load=keys_to_load,
+ )
+
+ flip_trans = class_config(
+ RandomApply,
+ transforms=[class_config(FlipImages)],
+ probability=0.5,
+ )
+ random_resized_crop_trans = [
+ class_config(GenRandomSizeCropParameters),
+ class_config(CropImages),
+ class_config(GenResizeParameters, shape=image_size, keep_ratio=False),
+ class_config(ResizeImages),
+ ]
+ random_aug_trans = [
+ class_config(RandAug, magnitude=10, use_increasing=True),
+ class_config(RandomErasing),
+ ]
+ normalize_trans = class_config(NormalizeImages)
+ train_preprocess_cfg = class_config(
+ compose,
+ transforms=[
+ flip_trans,
+ *random_resized_crop_trans,
+ *random_aug_trans,
+ normalize_trans,
+ ],
+ )
+
+ mixup_trans = [
+ class_config(GenMixupParameters, alpha=0.2, out_shape=image_size),
+ class_config(MixupImages),
+ class_config(MixupCategories, num_classes=1000, label_smoothing=0.1),
+ ]
+ train_batchprocess_cfg = class_config(
+ compose,
+ transforms=[
+ *mixup_trans,
+ class_config(ToTensor),
+ ],
+ )
+
+ return get_train_dataloader_cfg(
+ datasets_cfg=class_config(
+ DataPipe,
+ datasets=train_dataset_cfg,
+ preprocess_fn=train_preprocess_cfg,
+ ),
+ batchprocess_cfg=train_batchprocess_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+
+def get_test_dataloader(
+ data_root: str,
+ split: str,
+ keys_to_load: Sequence[str],
+ image_size: tuple[int, int],
+ samples_per_gpu: int,
+ workers_per_gpu: int,
+ crop_pct: float = 0.875,
+) -> ConfigDict:
+ """Get the default test dataloader for COCO detection."""
+ # Test Dataset
+
+ test_dataset_cfg = class_config(
+ ImageNet,
+ data_root=data_root,
+ split=split,
+ num_classes=1000,
+ keys_to_load=keys_to_load,
+ )
+
+ crop_size = tuple(int(size / crop_pct) for size in image_size)
+ resized_crop_trans = [
+ class_config(
+ GenResizeParameters,
+ shape=crop_size,
+ keep_ratio=True,
+ allow_overflow=True,
+ ),
+ class_config(ResizeImages),
+ class_config(
+ GenCentralCropParameters, shape=image_size, keep_ratio=False
+ ),
+ class_config(CropImages),
+ ]
+ normalize_trans = class_config(NormalizeImages)
+ test_preprocess_cfg = class_config(
+ compose,
+ transforms=[
+ *resized_crop_trans,
+ normalize_trans,
+ ],
+ )
+
+ mixup_trans = [
+ class_config(GenMixupParameters, alpha=0.2, out_shape=image_size),
+ class_config(MixupImages),
+ class_config(MixupCategories, num_classes=1000, label_smoothing=0.1),
+ ]
+ test_batchprocess_cfg = class_config(
+ compose,
+ transforms=[
+ *mixup_trans,
+ class_config(ToTensor),
+ ],
+ )
+
+ return get_inference_dataloaders_cfg(
+ datasets_cfg=class_config(
+ DataPipe,
+ datasets=test_dataset_cfg,
+ preprocess_fn=test_preprocess_cfg,
+ ),
+ batchprocess_cfg=test_batchprocess_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+
+def get_imagenet_cls_cfg(
+ data_root: str = "data/imagenet",
+ train_split: str = "train",
+ train_keys_to_load: Sequence[str] = (
+ K.images,
+ K.categories,
+ ),
+ test_split: str = "val",
+ test_keys_to_load: Sequence[str] = (
+ K.images,
+ K.categories,
+ ),
+ image_size: tuple[int, int] = (224, 224),
+ samples_per_gpu: int = 256,
+ workers_per_gpu: int = 8,
+) -> DataConfig:
+ """Get the default config for COCO detection."""
+ data = DataConfig()
+
+ data.train_dataloader = get_train_dataloader(
+ data_root=data_root,
+ split=train_split,
+ keys_to_load=train_keys_to_load,
+ image_size=image_size,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+ data.test_dataloader = get_test_dataloader(
+ data_root=data_root,
+ split=test_split,
+ keys_to_load=test_keys_to_load,
+ image_size=image_size,
+ samples_per_gpu=1,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+ return data
diff --git a/vis4d/zoo/base/datasets/nuscenes/__init__.py b/vis4d/zoo/base/datasets/nuscenes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..35ed92ba8a9c902abb9c1e2d22b9075fd9313690
--- /dev/null
+++ b/vis4d/zoo/base/datasets/nuscenes/__init__.py
@@ -0,0 +1,21 @@
+"""NuScenes dataset config."""
+
+from .nuscenes import (
+ get_nusc_mini_train_cfg,
+ get_nusc_mini_val_cfg,
+ get_nusc_train_cfg,
+ get_nusc_val_cfg,
+)
+from .nuscenes_mono import (
+ get_nusc_mono_mini_train_cfg,
+ get_nusc_mono_train_cfg,
+)
+
+__all__ = [
+ "get_nusc_train_cfg",
+ "get_nusc_mini_train_cfg",
+ "get_nusc_val_cfg",
+ "get_nusc_mini_val_cfg",
+ "get_nusc_mono_train_cfg",
+ "get_nusc_mono_mini_train_cfg",
+]
diff --git a/vis4d/zoo/base/datasets/nuscenes/nuscenes.py b/vis4d/zoo/base/datasets/nuscenes/nuscenes.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c2491e6a19700fe00edec9258f83168a4c936a1
--- /dev/null
+++ b/vis4d/zoo/base/datasets/nuscenes/nuscenes.py
@@ -0,0 +1,115 @@
+"""NuScenes multi-sensor video dataset config."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+from ml_collections import ConfigDict
+
+from vis4d.config import class_config
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.datasets.nuscenes import NuScenes
+
+
+def get_nusc_train_cfg(
+ data_root: str = "data/nuscenes",
+ keys_to_load: Sequence[str] = (K.images, K.boxes2d, K.boxes3d),
+ skip_empty_samples: bool = True,
+ cache_as_binary: bool = True,
+ cached_file_path: str | None = None,
+ data_backend: None | ConfigDict = None,
+) -> ConfigDict:
+ """Get the nuScenes validation dataset config."""
+ if cache_as_binary and cached_file_path is None:
+ cached_file_path = f"{data_root}/train.pkl"
+
+ return class_config(
+ NuScenes,
+ data_root=data_root,
+ keys_to_load=keys_to_load,
+ version="v1.0-trainval",
+ split="train",
+ skip_empty_samples=skip_empty_samples,
+ data_backend=data_backend,
+ cache_as_binary=cache_as_binary,
+ cached_file_path=cached_file_path,
+ )
+
+
+def get_nusc_mini_train_cfg(
+ data_root: str = "data/nuscenes",
+ keys_to_load: Sequence[str] = (K.images, K.boxes2d, K.boxes3d),
+ skip_empty_samples: bool = True,
+ cache_as_binary: bool = True,
+ cached_file_path: str | None = None,
+ data_backend: None | ConfigDict = None,
+) -> ConfigDict:
+ """Get the nuScenes validation dataset config."""
+ if cache_as_binary and cached_file_path is None:
+ cached_file_path = f"{data_root}/mini_train.pkl"
+
+ return class_config(
+ NuScenes,
+ data_root=data_root,
+ keys_to_load=keys_to_load,
+ version="v1.0-mini",
+ split="mini_train",
+ skip_empty_samples=skip_empty_samples,
+ data_backend=data_backend,
+ cache_as_binary=cache_as_binary,
+ cached_file_path=cached_file_path,
+ )
+
+
+def get_nusc_val_cfg(
+ data_root: str = "data/nuscenes",
+ keys_to_load: Sequence[str] = (K.images, K.original_images, K.boxes3d),
+ skip_empty_samples: bool = False,
+ cache_as_binary: bool = True,
+ cached_file_path: str | None = None,
+ image_channel_mode: str = "RGB",
+ data_backend: None | ConfigDict = None,
+) -> ConfigDict:
+ """Get the nuScenes validation dataset config."""
+ if cache_as_binary and cached_file_path is None:
+ cached_file_path = f"{data_root}/val.pkl"
+
+ return class_config(
+ NuScenes,
+ data_root=data_root,
+ image_channel_mode=image_channel_mode,
+ keys_to_load=keys_to_load,
+ version="v1.0-trainval",
+ split="val",
+ skip_empty_samples=skip_empty_samples,
+ data_backend=data_backend,
+ cache_as_binary=cache_as_binary,
+ cached_file_path=cached_file_path,
+ )
+
+
+def get_nusc_mini_val_cfg(
+ data_root: str = "data/nuscenes",
+ keys_to_load: Sequence[str] = (K.images, K.original_images, K.boxes3d),
+ skip_empty_samples: bool = False,
+ cache_as_binary: bool = True,
+ cached_file_path: str | None = None,
+ image_channel_mode: str = "RGB",
+ data_backend: None | ConfigDict = None,
+) -> ConfigDict:
+ """Get the nuScenes mini validation dataset config."""
+ if cache_as_binary and cached_file_path is None:
+ cached_file_path = f"{data_root}/mini_val.pkl"
+
+ return class_config(
+ NuScenes,
+ data_root=data_root,
+ image_channel_mode=image_channel_mode,
+ keys_to_load=keys_to_load,
+ version="v1.0-mini",
+ split="mini_val",
+ skip_empty_samples=skip_empty_samples,
+ data_backend=data_backend,
+ cache_as_binary=cache_as_binary,
+ cached_file_path=cached_file_path,
+ )
diff --git a/vis4d/zoo/base/datasets/nuscenes/nuscenes_mono.py b/vis4d/zoo/base/datasets/nuscenes/nuscenes_mono.py
new file mode 100644
index 0000000000000000000000000000000000000000..b28ad98d59c9357ed4e16f7bb48c171002c15016
--- /dev/null
+++ b/vis4d/zoo/base/datasets/nuscenes/nuscenes_mono.py
@@ -0,0 +1,61 @@
+"""NuScenes monocular dataset config."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+from ml_collections import ConfigDict
+
+from vis4d.config import class_config
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.datasets.nuscenes_mono import NuScenesMono
+
+
+def get_nusc_mono_train_cfg(
+ data_root: str = "data/nuscenes",
+ keys_to_load: Sequence[str] = (K.images, K.boxes2d, K.boxes3d),
+ skip_empty_samples: bool = True,
+ cache_as_binary: bool = True,
+ cached_file_path: str | None = None,
+ data_backend: None | ConfigDict = None,
+) -> ConfigDict:
+ """Get the nuScenes monocular training dataset config."""
+ if cache_as_binary and cached_file_path is None:
+ cached_file_path = f"{data_root}/mono_train.pkl"
+
+ return class_config(
+ NuScenesMono,
+ data_root=data_root,
+ keys_to_load=keys_to_load,
+ version="v1.0-trainval",
+ split="train",
+ skip_empty_samples=skip_empty_samples,
+ cache_as_binary=cache_as_binary,
+ cached_file_path=cached_file_path,
+ data_backend=data_backend,
+ )
+
+
+def get_nusc_mono_mini_train_cfg(
+ data_root: str = "data/nuscenes",
+ keys_to_load: Sequence[str] = (K.images, K.boxes2d, K.boxes3d),
+ skip_empty_samples: bool = True,
+ cache_as_binary: bool = True,
+ cached_file_path: str | None = None,
+ data_backend: None | ConfigDict = None,
+) -> ConfigDict:
+ """Get the nuScenes monocular mini training dataset config."""
+ if cache_as_binary and cached_file_path is None:
+ cached_file_path = f"{data_root}/mono_mini_train.pkl"
+
+ return class_config(
+ NuScenesMono,
+ data_root=data_root,
+ keys_to_load=keys_to_load,
+ version="v1.0-mini",
+ split="mini_train",
+ skip_empty_samples=skip_empty_samples,
+ cache_as_binary=cache_as_binary,
+ cached_file_path=cached_file_path,
+ data_backend=data_backend,
+ )
diff --git a/vis4d/zoo/base/datasets/shift/__init__.py b/vis4d/zoo/base/datasets/shift/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..53686843f4f5e9bc3a47dfe806f1d48053589e79
--- /dev/null
+++ b/vis4d/zoo/base/datasets/shift/__init__.py
@@ -0,0 +1,27 @@
+"""SHIFT dataset config."""
+
+from .tasks import (
+ CONN_SHIFT_DET_EVAL,
+ CONN_SHIFT_INS_EVAL,
+ get_shift_depth_est_config,
+ get_shift_det_config,
+ get_shift_instance_seg_config,
+ get_shift_multitask_2d_config,
+ get_shift_multitask_3d_config,
+ get_shift_optical_flow_config,
+ get_shift_sem_seg_config,
+ get_shift_tracking_config,
+)
+
+__all__ = [
+ "CONN_SHIFT_DET_EVAL",
+ "CONN_SHIFT_INS_EVAL",
+ "get_shift_depth_est_config",
+ "get_shift_det_config",
+ "get_shift_instance_seg_config",
+ "get_shift_tracking_config",
+ "get_shift_multitask_2d_config",
+ "get_shift_multitask_3d_config",
+ "get_shift_optical_flow_config",
+ "get_shift_sem_seg_config",
+]
diff --git a/vis4d/zoo/base/datasets/shift/common.py b/vis4d/zoo/base/datasets/shift/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ed84ec5b86306c856f6883d163884799bf3cee3
--- /dev/null
+++ b/vis4d/zoo/base/datasets/shift/common.py
@@ -0,0 +1,414 @@
+"""SHIFT data loading config for data augmentation."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+from ml_collections.config_dict import ConfigDict
+
+from vis4d.config import class_config
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.data_pipe import DataPipe
+from vis4d.data.datasets.shift import SHIFT
+from vis4d.data.loader import default_collate, multi_sensor_collate
+from vis4d.data.transforms.base import RandomApply, compose
+from vis4d.data.transforms.crop import (
+ CropBoxes2D,
+ CropDepthMaps,
+ CropImages,
+ CropOpticalFlows,
+ CropSegMasks,
+ GenCropParameters,
+)
+from vis4d.data.transforms.flip import (
+ FlipBoxes2D,
+ FlipDepthMaps,
+ FlipImages,
+ FlipInstanceMasks,
+ FlipOpticalFlows,
+ FlipSegMasks,
+)
+from vis4d.data.transforms.normalize import NormalizeImages
+from vis4d.data.transforms.photometric import ColorJitter
+from vis4d.data.transforms.resize import (
+ GenResizeParameters,
+ ResizeBoxes2D,
+ ResizeDepthMaps,
+ ResizeImages,
+ ResizeInstanceMasks,
+ ResizeOpticalFlows,
+ ResizeSegMasks,
+)
+from vis4d.data.transforms.select_sensor import SelectSensor
+from vis4d.data.transforms.to_tensor import ToTensor
+from vis4d.zoo.base import (
+ get_inference_dataloaders_cfg,
+ get_train_dataloader_cfg,
+)
+
+IMAGE_MEAN = [122.884, 117.266, 110.287]
+IMAGE_STD = [59.925, 59.466, 60.69]
+
+
+def get_train_preprocessing(
+ image_size: tuple[int, int] = (800, 1280),
+ crop_size: tuple[int, int] | None = None,
+ horizontal_flip_prob: float = 0.5,
+ color_jitter_prob: float = 0.0,
+ keys_to_load: Sequence[str] = (K.images, K.seg_masks),
+ views_to_load: Sequence[str] = ("front",),
+) -> ConfigDict:
+ """Get the default data preprocessing for SHIFT dataset.
+
+ Args:
+ image_size: The image size to resize to. Defaults to (800, 1280).
+ crop_size: The crop size to crop to randomly, if not None. Defaults to
+ None. This step is applied after the resize step.
+ horizontal_flip_prob: The probability of horizontal flipping. Defaults
+ to 0.5.
+ color_jitter_prob: The probability of color jittering. Defaults to 0.5.
+ keys_to_load: The keys to load from the dataset. Defaults to
+ (K.images, K.seg_masks).
+ views_to_load: The views to load from the dataset. Defaults to
+ ("front",).
+
+ Returns:
+ The data preprocessing config.
+ """
+ preprocess_transforms = []
+
+ for key_to_load in keys_to_load:
+ assert key_to_load in SHIFT.KEYS, f"Invalid key: {key_to_load}"
+
+ views_arg = {}
+ if len(views_to_load) == 1:
+ preprocess_transforms.append(
+ class_config(
+ SelectSensor,
+ selected_sensor=views_to_load[0],
+ sensors=views_to_load,
+ )
+ )
+ elif len(views_to_load) > 1:
+ views_arg["sensors"] = views_to_load
+
+ # Resize
+ if image_size != (800, 1280):
+ preprocess_transforms.append(
+ class_config(
+ GenResizeParameters,
+ shape=image_size,
+ keep_ratio=True,
+ **views_arg,
+ )
+ )
+ preprocess_transforms.append(class_config(ResizeImages, **views_arg))
+ if K.seg_masks in keys_to_load:
+ preprocess_transforms.append(
+ class_config(ResizeSegMasks, **views_arg)
+ )
+ if K.boxes2d in keys_to_load:
+ preprocess_transforms.append(
+ class_config(ResizeBoxes2D, **views_arg)
+ )
+ if K.instance_masks in keys_to_load:
+ preprocess_transforms.append(
+ class_config(ResizeInstanceMasks, **views_arg)
+ )
+ if K.depth_maps in keys_to_load:
+ preprocess_transforms.append(
+ class_config(ResizeDepthMaps, **views_arg)
+ )
+ if K.optical_flows in keys_to_load:
+ preprocess_transforms.append(
+ class_config(
+ ResizeOpticalFlows, normalized_flow=False, **views_arg
+ )
+ )
+
+ # Crop
+ if crop_size is not None:
+ preprocess_transforms.append(
+ class_config(
+ GenCropParameters,
+ shape=crop_size,
+ cat_max_ratio=0.75,
+ **views_arg,
+ ),
+ )
+ preprocess_transforms.append(class_config(CropImages, **views_arg))
+ if K.seg_masks in keys_to_load:
+ preprocess_transforms.append(
+ class_config(CropSegMasks, **views_arg)
+ )
+ if K.boxes2d in keys_to_load:
+ preprocess_transforms.append(
+ class_config(CropBoxes2D, **views_arg)
+ )
+ if K.depth_maps in keys_to_load:
+ preprocess_transforms.append(
+ class_config(CropDepthMaps, **views_arg)
+ )
+ if K.optical_flows in keys_to_load:
+ preprocess_transforms.append(
+ class_config(CropOpticalFlows, **views_arg)
+ )
+
+ # Random flip
+ if horizontal_flip_prob > 0:
+ flip_transforms = []
+ flip_transforms.append(class_config(FlipImages))
+ if K.seg_masks in keys_to_load:
+ flip_transforms.append(class_config(FlipSegMasks))
+ if K.boxes2d in keys_to_load:
+ flip_transforms.append(class_config(FlipBoxes2D))
+ if K.instance_masks in keys_to_load:
+ flip_transforms.append(class_config(FlipInstanceMasks))
+ if K.depth_maps in keys_to_load:
+ flip_transforms.append(class_config(FlipDepthMaps))
+ if K.optical_flows in keys_to_load:
+ flip_transforms.append(class_config(FlipOpticalFlows))
+ preprocess_transforms.append(
+ class_config(
+ RandomApply,
+ transforms=flip_transforms,
+ probability=horizontal_flip_prob,
+ **views_arg,
+ )
+ )
+
+ if color_jitter_prob > 0:
+ preprocess_transforms.append(
+ class_config(
+ RandomApply,
+ transforms=[class_config(ColorJitter, **views_arg)],
+ probability=color_jitter_prob,
+ )
+ )
+
+ preprocess_transforms.append(
+ class_config(
+ NormalizeImages, mean=IMAGE_MEAN, std=IMAGE_STD, **views_arg
+ )
+ )
+ train_preprocess_cfg = class_config(
+ compose, transforms=preprocess_transforms
+ )
+
+ batchprocess_transforms = [class_config(ToTensor, **views_arg)]
+ train_batchprocess_cfg = class_config(
+ compose, transforms=batchprocess_transforms
+ )
+
+ return train_preprocess_cfg, train_batchprocess_cfg
+
+
+def get_test_preprocessing(
+ image_size: tuple[int, int] = (800, 1280),
+ keys_to_load: Sequence[str] = (K.images, K.seg_masks),
+ views_to_load: Sequence[str] = ("front",),
+) -> ConfigDict:
+ """Get the default data preprocessing for SHIFT dataset.
+
+ Args:
+ image_size: The image size to resize to. Defaults to (800, 1280).
+ keys_to_load: The keys to load from the dataset. Defaults to
+ (K.images, K.seg_masks).
+ views_to_load: The views to load from the dataset. Defaults to
+ ("front",).
+
+ Returns:
+ The data preprocessing config.
+ """
+ preprocess_transforms = []
+
+ for key_to_load in keys_to_load:
+ assert key_to_load in SHIFT.KEYS, f"Invalid key: {key_to_load}"
+
+ views_arg = {}
+ if len(views_to_load) == 1:
+ preprocess_transforms.append(
+ class_config(
+ SelectSensor,
+ selected_sensor=views_to_load[0],
+ sensors=views_to_load,
+ )
+ )
+ elif len(views_to_load) > 1:
+ views_arg["sensors"] = views_to_load
+
+ # Resize
+ if image_size != (800, 1280):
+ preprocess_transforms.append(
+ class_config(
+ GenResizeParameters,
+ shape=image_size,
+ keep_ratio=True,
+ **views_arg,
+ )
+ )
+ preprocess_transforms.append(class_config(ResizeImages, **views_arg))
+ if K.seg_masks in keys_to_load:
+ preprocess_transforms.append(
+ class_config(ResizeSegMasks, **views_arg)
+ )
+ if K.boxes2d in keys_to_load:
+ preprocess_transforms.append(
+ class_config(ResizeBoxes2D, **views_arg)
+ )
+ if K.depth_maps in keys_to_load:
+ preprocess_transforms.append(
+ class_config(ResizeDepthMaps, **views_arg)
+ )
+ if K.optical_flows in keys_to_load:
+ preprocess_transforms.append(
+ class_config(ResizeOpticalFlows, **views_arg)
+ )
+
+ preprocess_transforms.append(
+ class_config(
+ NormalizeImages, mean=IMAGE_MEAN, std=IMAGE_STD, **views_arg
+ )
+ )
+ test_preprocess_cfg = class_config(
+ compose, transforms=preprocess_transforms
+ )
+
+ batchprocess_transforms = [class_config(ToTensor, **views_arg)]
+
+ test_batchprocess_cfg = class_config(
+ compose, transforms=batchprocess_transforms
+ )
+
+ return test_preprocess_cfg, test_batchprocess_cfg
+
+
+def get_shift_dataloader_config(
+ train_dataset_cfg: ConfigDict,
+ test_dataset_cfg: ConfigDict,
+ keys_to_load: Sequence[str] = (K.images, K.seg_masks),
+ image_size: tuple[int, int] = (800, 1280),
+ crop_size: tuple[int, int] | None = None,
+ horizontal_flip_prob: float = 0.5,
+ color_jitter_prob: float = 0.5,
+ samples_per_gpu: int = 2,
+ workers_per_gpu: int = 2,
+ train_views_to_load: Sequence[str] = ("front",),
+ test_views_to_load: Sequence[str] = ("front",),
+) -> ConfigDict:
+ """Get the default config for BDD100K segmentation."""
+ data = ConfigDict()
+
+ train_preprocess_cfg, train_batchprocess_cfg = get_train_preprocessing(
+ keys_to_load=keys_to_load,
+ image_size=image_size,
+ crop_size=crop_size,
+ horizontal_flip_prob=horizontal_flip_prob,
+ color_jitter_prob=color_jitter_prob,
+ views_to_load=train_views_to_load,
+ )
+
+ test_preprocess_cfg, test_batchprocess_cfg = get_test_preprocessing(
+ keys_to_load=keys_to_load,
+ image_size=image_size,
+ views_to_load=test_views_to_load,
+ )
+
+ data.train_dataloader = get_train_dataloader_cfg(
+ datasets_cfg=class_config(
+ DataPipe,
+ datasets=train_dataset_cfg,
+ preprocess_fn=train_preprocess_cfg,
+ ),
+ batchprocess_cfg=train_batchprocess_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ shuffle=True,
+ collate_fn=(
+ multi_sensor_collate
+ if len(train_views_to_load) > 1
+ else default_collate
+ ),
+ )
+
+ # Test Dataset Config
+ test_dataset_cfg = class_config(
+ DataPipe, datasets=test_dataset_cfg, preprocess_fn=test_preprocess_cfg
+ )
+ data.test_dataloader = get_inference_dataloaders_cfg(
+ datasets_cfg=test_dataset_cfg,
+ batchprocess_cfg=test_batchprocess_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ collate_fn=(
+ multi_sensor_collate
+ if len(test_views_to_load) > 1
+ else default_collate
+ ),
+ )
+ return data
+
+
+def get_shift_config( # pylint: disable=too-many-arguments, too-many-positional-arguments, line-too-long
+ data_root: str = "data/shift/images",
+ train_split: str = "train",
+ train_framerate: str = "images",
+ train_shift_type: str = "discrete",
+ train_views_to_load: Sequence[str] = ("front",),
+ train_keys_to_load: Sequence[str] = (K.images, K.seg_masks),
+ train_attributes_to_load: Sequence[dict[str, str | float]] | None = None,
+ train_skip_empty_frames: bool = False,
+ test_split: str = "val",
+ test_framerate: str = "images",
+ test_shift_type: str = "discrete",
+ test_views_to_load: Sequence[str] = ("front",),
+ test_keys_to_load: Sequence[str] = (K.images, K.seg_masks),
+ test_attributes_to_load: Sequence[dict[str, str | float]] | None = None,
+ test_skip_empty_frames: bool = False,
+ data_backend: None | ConfigDict = None,
+ image_size: tuple[int, int] = (800, 1280),
+ crop_size: tuple[int, int] | None = None,
+ horizontal_flip_prob: float = 0.5,
+ color_jitter_prob: float = 0.0,
+ samples_per_gpu: int = 2,
+ workers_per_gpu: int = 2,
+) -> ConfigDict:
+ """Get the default config for BDD100K segmentation."""
+ train_dataset_cfg = class_config(
+ SHIFT,
+ data_root=data_root,
+ split=train_split,
+ framerate=train_framerate,
+ shift_type=train_shift_type,
+ views_to_load=train_views_to_load,
+ keys_to_load=train_keys_to_load,
+ attributes_to_load=train_attributes_to_load,
+ skip_empty_frames=train_skip_empty_frames,
+ backend=data_backend,
+ )
+ test_dataset_cfg = class_config(
+ SHIFT,
+ data_root=data_root,
+ split=test_split,
+ framerate=test_framerate,
+ shift_type=test_shift_type,
+ views_to_load=test_views_to_load,
+ keys_to_load=test_keys_to_load,
+ attributes_to_load=test_attributes_to_load,
+ skip_empty_frames=test_skip_empty_frames,
+ backend=data_backend,
+ )
+
+ return get_shift_dataloader_config(
+ train_dataset_cfg=train_dataset_cfg,
+ test_dataset_cfg=test_dataset_cfg,
+ keys_to_load=train_keys_to_load,
+ image_size=image_size,
+ crop_size=crop_size,
+ horizontal_flip_prob=horizontal_flip_prob,
+ color_jitter_prob=color_jitter_prob,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ train_views_to_load=train_views_to_load,
+ test_views_to_load=test_views_to_load,
+ )
diff --git a/vis4d/zoo/base/datasets/shift/tasks.py b/vis4d/zoo/base/datasets/shift/tasks.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f0304a6e211a486f08cce84ec784ce6e2800668
--- /dev/null
+++ b/vis4d/zoo/base/datasets/shift/tasks.py
@@ -0,0 +1,183 @@
+"""SHIFT data loading config for segmentation."""
+
+from __future__ import annotations
+
+from ml_collections.config_dict import ConfigDict
+
+from vis4d.common.typing import ArgsType
+from vis4d.data.const import CommonKeys as K
+from vis4d.engine.connectors import data_key, pred_key
+
+from .common import get_shift_config
+
+CONN_SHIFT_DET_EVAL = {
+ "frame_ids": data_key("frame_ids"),
+ "sample_names": data_key("sample_names"),
+ "sequence_names": data_key("sequence_names"),
+ "pred_boxes": pred_key("boxes"),
+ "pred_scores": pred_key("scores"),
+ "pred_classes": pred_key("class_ids"),
+}
+CONN_SHIFT_INS_EVAL = {
+ "frame_ids": data_key("frame_ids"),
+ "sample_names": data_key("sample_names"),
+ "sequence_names": data_key("sequence_names"),
+ "pred_boxes": pred_key("boxes.boxes"),
+ "pred_scores": pred_key("boxes.scores"),
+ "pred_classes": pred_key("boxes.class_ids"),
+ "pred_masks": pred_key("masks.masks"),
+}
+
+
+def get_shift_sem_seg_config(**kwargs: ArgsType) -> ConfigDict:
+ """Get the config for the SHIFT segmentation task."""
+ keys_to_load = (K.images, K.input_hw, K.original_hw, K.seg_masks)
+ cfg = get_shift_config(
+ train_keys_to_load=keys_to_load,
+ test_keys_to_load=keys_to_load,
+ horizontal_flip_prob=0.5,
+ color_jitter_prob=0.5,
+ crop_size=kwargs.get("crop_size", (512, 1024)),
+ **kwargs,
+ )
+ return cfg
+
+
+def get_shift_det_config(**kwargs: ArgsType) -> ConfigDict:
+ """Get the config for the SHIFT detection task."""
+ keys_to_load = (
+ K.images,
+ K.input_hw,
+ K.original_hw,
+ K.boxes2d,
+ K.boxes2d_classes,
+ )
+ cfg = get_shift_config(
+ train_keys_to_load=keys_to_load,
+ test_keys_to_load=keys_to_load,
+ train_skip_empty_frames=True,
+ test_skip_empty_frames=False,
+ horizontal_flip_prob=0.5,
+ color_jitter_prob=0.0,
+ crop_size=kwargs.get("crop_size", None),
+ **kwargs,
+ )
+ return cfg
+
+
+def get_shift_instance_seg_config(**kwargs: ArgsType) -> ConfigDict:
+ """Get the config for the SHIFT instance segmentation task."""
+ keys_to_load = (
+ K.images,
+ K.input_hw,
+ K.original_hw,
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.instance_masks,
+ )
+ cfg = get_shift_config(
+ train_keys_to_load=keys_to_load,
+ test_keys_to_load=keys_to_load,
+ train_skip_empty_frames=True,
+ test_skip_empty_frames=False,
+ horizontal_flip_prob=0.5,
+ color_jitter_prob=0.5,
+ crop_size=kwargs.get("crop_size", None),
+ **kwargs,
+ )
+ return cfg
+
+
+def get_shift_depth_est_config(**kwargs: ArgsType) -> ConfigDict:
+ """Get the config for the SHIFT depth estimation task."""
+ keys_to_load = (K.images, K.input_hw, K.original_hw, K.depth_maps)
+ cfg = get_shift_config(
+ train_keys_to_load=keys_to_load,
+ test_keys_to_load=keys_to_load,
+ horizontal_flip_prob=0.5,
+ crop_size=kwargs.get("crop_size", None),
+ **kwargs,
+ )
+ return cfg
+
+
+def get_shift_optical_flow_config(**kwargs: ArgsType) -> ConfigDict:
+ """Get the config for the SHIFT optical flow task."""
+ keys_to_load = (K.images, K.input_hw, K.original_hw, K.optical_flows)
+ cfg = get_shift_config(
+ train_keys_to_load=keys_to_load,
+ test_keys_to_load=keys_to_load,
+ horizontal_flip_prob=0.5,
+ crop_size=kwargs.get("crop_size", None),
+ **kwargs,
+ )
+ return cfg
+
+
+def get_shift_tracking_config(**kwargs: ArgsType) -> ConfigDict:
+ """Get the config for the SHIFT tracking task."""
+ keys_to_load = (
+ K.images,
+ K.input_hw,
+ K.original_hw,
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.boxes2d_track_ids,
+ )
+ cfg = get_shift_config(
+ train_keys_to_load=keys_to_load,
+ test_keys_to_load=keys_to_load,
+ horizontal_flip_prob=0.5,
+ crop_size=kwargs.get("crop_size", None),
+ **kwargs,
+ )
+ return cfg
+
+
+def get_shift_multitask_2d_config(**kwargs: ArgsType) -> ConfigDict:
+ """Get the config for the SHIFT multitask 2D task."""
+ keys_to_load = (
+ K.images,
+ K.input_hw,
+ K.original_hw,
+ K.intrinsics,
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.boxes2d_track_ids,
+ K.seg_masks,
+ K.depth_maps,
+ )
+ cfg = get_shift_config(
+ train_keys_to_load=keys_to_load,
+ test_keys_to_load=keys_to_load,
+ horizontal_flip_prob=0.5,
+ crop_size=kwargs.get("crop_size", None),
+ **kwargs,
+ )
+ return cfg
+
+
+def get_shift_multitask_3d_config(**kwargs: ArgsType) -> ConfigDict:
+ """Get the config for the SHIFT multitask 3D task."""
+ keys_to_load = (
+ K.images,
+ K.input_hw,
+ K.original_hw,
+ K.intrinsics,
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.boxes2d_track_ids,
+ K.boxes3d,
+ K.boxes3d_classes,
+ K.boxes3d_track_ids,
+ K.seg_masks,
+ K.depth_maps,
+ )
+ cfg = get_shift_config(
+ train_keys_to_load=keys_to_load,
+ test_keys_to_load=keys_to_load,
+ horizontal_flip_prob=0.5,
+ crop_size=kwargs.get("crop_size", None),
+ **kwargs,
+ )
+ return cfg
diff --git a/vis4d/zoo/base/models/__init__.py b/vis4d/zoo/base/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..95d427c92e9b91c3b2d68699c9b8f7d14828d594
--- /dev/null
+++ b/vis4d/zoo/base/models/__init__.py
@@ -0,0 +1 @@
+"""Model Zoo base models."""
diff --git a/vis4d/zoo/base/models/faster_rcnn.py b/vis4d/zoo/base/models/faster_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b0a4a18491aab96a7242f91debe0dbdd15f34bb
--- /dev/null
+++ b/vis4d/zoo/base/models/faster_rcnn.py
@@ -0,0 +1,148 @@
+"""Faseter R-CNN base model config."""
+
+from __future__ import annotations
+
+from ml_collections import ConfigDict, FieldReference
+
+from vis4d.config import class_config
+from vis4d.engine.connectors import LossConnector, data_key, pred_key
+from vis4d.engine.loss_module import LossModule
+from vis4d.model.detect.faster_rcnn import FasterRCNN
+from vis4d.op.box.anchor import AnchorGenerator
+from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder, DeltaXYWHBBoxEncoder
+from vis4d.op.box.matchers import MaxIoUMatcher
+from vis4d.op.box.samplers import RandomSampler
+from vis4d.op.detect.faster_rcnn import FasterRCNNHead
+from vis4d.op.detect.rcnn import RCNNHead, RCNNLoss
+from vis4d.op.detect.rpn import RPNLoss
+
+# Data connectors
+CONN_RPN_LOSS_2D = {
+ "cls_outs": pred_key("rpn.cls"),
+ "reg_outs": pred_key("rpn.box"),
+ "target_boxes": data_key("boxes2d"),
+ "images_hw": data_key("input_hw"),
+}
+
+CONN_ROI_LOSS_2D = {
+ "class_outs": pred_key("roi.cls_score"),
+ "regression_outs": pred_key("roi.bbox_pred"),
+ "boxes": pred_key("sampled_proposals.boxes"),
+ "boxes_mask": pred_key("sampled_targets.labels"),
+ "target_boxes": pred_key("sampled_targets.boxes"),
+ "target_classes": pred_key("sampled_targets.classes"),
+}
+
+
+def get_default_rpn_box_codec_cfg(
+ target_means: tuple[float, ...] = (0.0, 0.0, 0.0, 0.0),
+ target_stds: tuple[float, ...] = (1.0, 1.0, 1.0, 1.0),
+) -> tuple[ConfigDict, ConfigDict]:
+ """Get default config for rpn box encoder and decoder."""
+ return tuple(
+ class_config(x, target_means=target_means, target_stds=target_stds)
+ for x in (DeltaXYWHBBoxEncoder, DeltaXYWHBBoxDecoder)
+ )
+
+
+def get_default_rcnn_box_codec_cfg(
+ target_means: tuple[float, ...] = (0.0, 0.0, 0.0, 0.0),
+ target_stds: tuple[float, ...] = (0.1, 0.1, 0.2, 0.2),
+) -> tuple[ConfigDict, ConfigDict]:
+ """Get default config for rcnn box encoder and decoder."""
+ return tuple(
+ class_config(x, target_means=target_means, target_stds=target_stds)
+ for x in (DeltaXYWHBBoxEncoder, DeltaXYWHBBoxDecoder)
+ )
+
+
+def get_faster_rcnn_cfg(
+ num_classes: FieldReference | int,
+ basemodel: ConfigDict,
+ weights: str | None = None,
+) -> tuple[ConfigDict, ConfigDict]:
+ """Return default config for faster_rcnn model and loss.
+
+ This is an example for setting every component of the model and loss.
+ Everything is the same as the default args.
+
+ Args:
+ num_classes (FieldReference | int): Number of classes.
+ basemodel (ConfigDict): Base model config.
+ weights (str | None, optional): Weights to load. Defaults to None.
+ """
+ ######################################################
+ ## MODEL ##
+ ######################################################
+ anchor_generator = class_config(
+ AnchorGenerator,
+ scales=[8],
+ ratios=[0.5, 1.0, 2.0],
+ strides=[4, 8, 16, 32, 64],
+ )
+
+ rpn_box_encoder, rpn_box_decoder = get_default_rpn_box_codec_cfg()
+ rcnn_box_encoder, rcnn_box_decoder = get_default_rcnn_box_codec_cfg()
+
+ box_matcher = class_config(
+ MaxIoUMatcher,
+ thresholds=[0.5],
+ labels=[0, 1],
+ allow_low_quality_matches=False,
+ )
+
+ box_sampler = class_config(
+ RandomSampler, batch_size=512, positive_fraction=0.25
+ )
+
+ roi_head = class_config(RCNNHead, num_classes=num_classes)
+
+ faster_rcnn_head = class_config(
+ FasterRCNNHead,
+ num_classes=num_classes,
+ anchor_generator=anchor_generator,
+ rpn_box_decoder=rpn_box_decoder,
+ box_matcher=box_matcher,
+ box_sampler=box_sampler,
+ roi_head=roi_head,
+ )
+
+ model = class_config(
+ FasterRCNN,
+ num_classes=num_classes,
+ basemodel=basemodel,
+ faster_rcnn_head=faster_rcnn_head,
+ rcnn_box_decoder=rcnn_box_decoder,
+ weights=weights,
+ )
+
+ ######################################################
+ ## LOSS ##
+ ######################################################
+ rpn_loss = class_config(
+ RPNLoss,
+ anchor_generator=anchor_generator,
+ box_encoder=rpn_box_encoder,
+ )
+ rcnn_loss = class_config(
+ RCNNLoss, box_encoder=rcnn_box_encoder, num_classes=num_classes
+ )
+
+ loss = class_config(
+ LossModule,
+ losses=[
+ {
+ "loss": rpn_loss,
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_RPN_LOSS_2D
+ ),
+ },
+ {
+ "loss": rcnn_loss,
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_ROI_LOSS_2D
+ ),
+ },
+ ],
+ )
+ return model, loss
diff --git a/vis4d/zoo/base/models/mask_rcnn.py b/vis4d/zoo/base/models/mask_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f556e56d3eaffacb36b0e74cc354597984e14a6
--- /dev/null
+++ b/vis4d/zoo/base/models/mask_rcnn.py
@@ -0,0 +1,163 @@
+"""Mask RCNN base model config."""
+
+from __future__ import annotations
+
+from ml_collections import ConfigDict, FieldReference
+
+from vis4d.config import class_config
+from vis4d.data.const import CommonKeys as K
+from vis4d.engine.connectors import (
+ LossConnector,
+ data_key,
+ pred_key,
+ remap_pred_keys,
+)
+from vis4d.engine.loss_module import LossModule
+from vis4d.model.detect.mask_rcnn import MaskRCNN
+from vis4d.op.box.anchor import AnchorGenerator
+from vis4d.op.box.matchers import MaxIoUMatcher
+from vis4d.op.box.samplers import RandomSampler
+from vis4d.op.detect.faster_rcnn import FasterRCNNHead
+from vis4d.op.detect.mask_rcnn import (
+ MaskRCNNHead,
+ MaskRCNNHeadLoss,
+ SampledMaskLoss,
+ positive_mask_sampler,
+)
+from vis4d.op.detect.rcnn import RCNNHead, RCNNLoss
+from vis4d.op.detect.rpn import RPNLoss
+from vis4d.zoo.base import get_callable_cfg
+from vis4d.zoo.base.models.faster_rcnn import (
+ CONN_ROI_LOSS_2D as _CONN_ROI_LOSS_2D,
+)
+from vis4d.zoo.base.models.faster_rcnn import (
+ CONN_RPN_LOSS_2D as _CONN_RPN_LOSS_2D,
+)
+from vis4d.zoo.base.models.faster_rcnn import (
+ get_default_rcnn_box_codec_cfg,
+ get_default_rpn_box_codec_cfg,
+)
+
+# Data connectors
+CONN_MASK_HEAD_LOSS_2D = {
+ "mask_preds": pred_key("masks.mask_pred"),
+ "target_masks": data_key(K.instance_masks),
+ "sampled_target_indices": pred_key("boxes.sampled_target_indices"),
+ "sampled_targets": pred_key("boxes.sampled_targets"),
+ "sampled_proposals": pred_key("boxes.sampled_proposals"),
+}
+
+CONN_RPN_LOSS_2D = remap_pred_keys(_CONN_RPN_LOSS_2D, "boxes")
+
+CONN_ROI_LOSS_2D = remap_pred_keys(_CONN_ROI_LOSS_2D, "boxes")
+
+
+def get_mask_rcnn_cfg(
+ num_classes: FieldReference | int,
+ basemodel: ConfigDict,
+ no_overlap: bool = False,
+ weights: str | None = None,
+) -> tuple[ConfigDict, ConfigDict]:
+ """Return default config for mask_rcnn model and loss.
+
+ This is an example for setting every component of the model and loss.
+ Everything is the same as the default args.
+
+ Args:
+ num_classes (FieldReference | int): Number of classes.
+ basemodel (ConfigDict): Base model config.
+ no_overlap (bool, optional): Whether to remove overlapping pixels
+ between masks. Defaults to False.
+ weights (str | None, optional): Weights to load. Defaults to None.
+ """
+ ######################################################
+ ## MODEL ##
+ ######################################################
+ anchor_generator = class_config(
+ AnchorGenerator,
+ scales=[8],
+ ratios=[0.5, 1.0, 2.0],
+ strides=[4, 8, 16, 32, 64],
+ )
+
+ rpn_box_encoder, rpn_box_decoder = get_default_rpn_box_codec_cfg()
+ rcnn_box_encoder, rcnn_box_decoder = get_default_rcnn_box_codec_cfg()
+
+ box_matcher = class_config(
+ MaxIoUMatcher,
+ thresholds=[0.5],
+ labels=[0, 1],
+ allow_low_quality_matches=False,
+ )
+
+ box_sampler = class_config(
+ RandomSampler, batch_size=512, positive_fraction=0.25
+ )
+
+ roi_head = class_config(RCNNHead, num_classes=num_classes)
+
+ mask_head = class_config(MaskRCNNHead, num_classes=num_classes)
+
+ faster_rcnn_head = class_config(
+ FasterRCNNHead,
+ num_classes=num_classes,
+ anchor_generator=anchor_generator,
+ rpn_box_decoder=rpn_box_decoder,
+ box_matcher=box_matcher,
+ box_sampler=box_sampler,
+ roi_head=roi_head,
+ )
+
+ model = class_config(
+ MaskRCNN,
+ num_classes=num_classes,
+ basemodel=basemodel,
+ faster_rcnn_head=faster_rcnn_head,
+ mask_head=mask_head,
+ rcnn_box_decoder=rcnn_box_decoder,
+ no_overlap=no_overlap,
+ weights=weights,
+ )
+
+ ######################################################
+ ## LOSS ##
+ ######################################################
+ rpn_loss = class_config(
+ RPNLoss,
+ anchor_generator=anchor_generator,
+ box_encoder=rpn_box_encoder,
+ )
+ rcnn_loss = class_config(
+ RCNNLoss, box_encoder=rcnn_box_encoder, num_classes=num_classes
+ )
+
+ mask_loss = class_config(
+ SampledMaskLoss,
+ mask_sampler=get_callable_cfg(positive_mask_sampler),
+ loss=class_config(MaskRCNNHeadLoss, num_classes=num_classes),
+ )
+
+ loss = class_config(
+ LossModule,
+ losses=[
+ {
+ "loss": rpn_loss,
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_RPN_LOSS_2D
+ ),
+ },
+ {
+ "loss": rcnn_loss,
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_ROI_LOSS_2D
+ ),
+ },
+ {
+ "loss": mask_loss,
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_MASK_HEAD_LOSS_2D
+ ),
+ },
+ ],
+ )
+ return model, loss
diff --git a/vis4d/zoo/base/models/qdtrack.py b/vis4d/zoo/base/models/qdtrack.py
new file mode 100644
index 0000000000000000000000000000000000000000..89b2e5848ac0cf5f8446a0a07762459c979f5a6a
--- /dev/null
+++ b/vis4d/zoo/base/models/qdtrack.py
@@ -0,0 +1,219 @@
+"""QD-Track model config."""
+
+from __future__ import annotations
+
+from ml_collections import ConfigDict, FieldReference
+
+from vis4d.config import class_config
+from vis4d.data.const import CommonKeys as K
+from vis4d.engine.connectors import LossConnector, pred_key, remap_pred_keys
+from vis4d.engine.loss_module import LossModule
+from vis4d.model.adapter import ModelExpEMAAdapter
+from vis4d.model.track.qdtrack import FasterRCNNQDTrack, YOLOXQDTrack
+from vis4d.op.box.anchor import AnchorGenerator
+from vis4d.op.box.poolers import MultiScaleRoIAlign
+from vis4d.op.detect.faster_rcnn import FasterRCNNHead
+from vis4d.op.detect.rcnn import RCNNLoss
+from vis4d.op.detect.rpn import RPNLoss
+from vis4d.op.detect.yolox import YOLOXHeadLoss
+from vis4d.op.loss.common import smooth_l1_loss
+from vis4d.op.track.qdtrack import (
+ QDSimilarityHead,
+ QDTrackHead,
+ QDTrackInstanceSimilarityLoss,
+)
+from vis4d.zoo.base import get_callable_cfg
+from vis4d.zoo.base.models.faster_rcnn import (
+ CONN_ROI_LOSS_2D as _CONN_ROI_LOSS_2D,
+)
+from vis4d.zoo.base.models.faster_rcnn import (
+ get_default_rcnn_box_codec_cfg,
+ get_default_rpn_box_codec_cfg,
+)
+
+from .yolox import get_yolox_model_cfg
+
+PRED_PREFIX = "detector_out"
+
+CONN_BBOX_2D_TRAIN = {
+ "images": K.images,
+ "images_hw": K.input_hw,
+ "original_hw": K.original_hw,
+ "frame_ids": K.frame_ids,
+ "boxes2d": K.boxes2d,
+ "boxes2d_classes": K.boxes2d_classes,
+ "boxes2d_track_ids": K.boxes2d_track_ids,
+ "keyframes": "keyframes",
+}
+
+CONN_BBOX_2D_TEST = {
+ "images": K.images,
+ "images_hw": K.input_hw,
+ "original_hw": K.original_hw,
+ "frame_ids": K.frame_ids,
+}
+
+CONN_RPN_LOSS_2D = {
+ "cls_outs": pred_key(f"{PRED_PREFIX}.rpn.cls"),
+ "reg_outs": pred_key(f"{PRED_PREFIX}.rpn.box"),
+ "target_boxes": pred_key("key_target_boxes"),
+ "images_hw": pred_key("key_images_hw"),
+}
+
+CONN_ROI_LOSS_2D = remap_pred_keys(_CONN_ROI_LOSS_2D, PRED_PREFIX)
+
+CONN_TRACK_LOSS_2D = {
+ "key_embeddings": pred_key("key_embeddings"),
+ "ref_embeddings": pred_key("ref_embeddings"),
+ "key_track_ids": pred_key("key_track_ids"),
+ "ref_track_ids": pred_key("ref_track_ids"),
+}
+
+CONN_YOLOX_LOSS_2D = {
+ "cls_outs": pred_key(f"{PRED_PREFIX}.cls_score"),
+ "reg_outs": pred_key(f"{PRED_PREFIX}.bbox_pred"),
+ "obj_outs": pred_key(f"{PRED_PREFIX}.objectness"),
+ "target_boxes": pred_key("key_target_boxes"),
+ "target_class_ids": pred_key("key_target_classes"),
+ "images_hw": pred_key("key_images_hw"),
+}
+
+
+def get_qdtrack_cfg(
+ num_classes: int | FieldReference,
+ basemodel: ConfigDict,
+ weights: str | None = None,
+) -> tuple[ConfigDict, ConfigDict]:
+ """Get QDTrack model config."""
+ ######################################################
+ ## MODEL ##
+ ######################################################
+ anchor_generator = class_config(
+ AnchorGenerator,
+ scales=[8],
+ ratios=[0.5, 1.0, 2.0],
+ strides=[4, 8, 16, 32, 64],
+ )
+
+ rpn_box_encoder, _ = get_default_rpn_box_codec_cfg()
+ rcnn_box_encoder, _ = get_default_rcnn_box_codec_cfg()
+
+ faster_rcnn_head = class_config(
+ FasterRCNNHead,
+ num_classes=num_classes,
+ anchor_generator=anchor_generator,
+ )
+
+ model = class_config(
+ FasterRCNNQDTrack,
+ num_classes=num_classes,
+ basemodel=basemodel,
+ faster_rcnn_head=faster_rcnn_head,
+ weights=weights,
+ )
+
+ rpn_loss = class_config(
+ RPNLoss,
+ anchor_generator=anchor_generator,
+ box_encoder=rpn_box_encoder,
+ loss_bbox=get_callable_cfg(smooth_l1_loss, beta=1.0 / 9.0),
+ )
+ rcnn_loss = class_config(
+ RCNNLoss,
+ box_encoder=rcnn_box_encoder,
+ num_classes=num_classes,
+ loss_bbox=get_callable_cfg(smooth_l1_loss),
+ )
+
+ track_loss = class_config(QDTrackInstanceSimilarityLoss)
+
+ loss = class_config(
+ LossModule,
+ losses=[
+ {
+ "loss": rpn_loss,
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_RPN_LOSS_2D
+ ),
+ },
+ {
+ "loss": rcnn_loss,
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_ROI_LOSS_2D
+ ),
+ },
+ {
+ "loss": track_loss,
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_TRACK_LOSS_2D
+ ),
+ },
+ ],
+ )
+
+ return model, loss
+
+
+def get_qdtrack_yolox_cfg(
+ num_classes: int | FieldReference,
+ model_type: str,
+ use_ema: bool = True,
+ weights: str | None = None,
+) -> tuple[ConfigDict, ConfigDict]:
+ """Get QDTrack YOLOX model config."""
+ ######################################################
+ ## MODEL ##
+ ######################################################
+ basemodel, fpn, yolox_head = get_yolox_model_cfg(num_classes, model_type)
+ if model_type == "tiny":
+ in_dim = 96
+ elif model_type == "small":
+ in_dim = 128
+ elif model_type == "large":
+ in_dim = 256
+ elif model_type == "xlarge":
+ in_dim = 320
+ else:
+ raise ValueError(f"Invalid model type: {model_type}")
+ model = class_config(
+ YOLOXQDTrack,
+ num_classes=num_classes,
+ basemodel=basemodel,
+ fpn=fpn,
+ yolox_head=yolox_head,
+ qdtrack_head=class_config(
+ QDTrackHead,
+ similarity_head=class_config(
+ QDSimilarityHead,
+ proposal_pooler=MultiScaleRoIAlign(
+ resolution=(7, 7), strides=[8, 16, 32], sampling_ratio=0
+ ),
+ in_dim=in_dim,
+ ),
+ ),
+ weights=weights,
+ )
+ if use_ema:
+ model = class_config(ModelExpEMAAdapter, model=model)
+
+ track_loss = class_config(QDTrackInstanceSimilarityLoss)
+
+ loss = class_config(
+ LossModule,
+ losses=[
+ {
+ "loss": class_config(YOLOXHeadLoss, num_classes=num_classes),
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_YOLOX_LOSS_2D
+ ),
+ },
+ {
+ "loss": track_loss,
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_TRACK_LOSS_2D
+ ),
+ },
+ ],
+ )
+
+ return model, loss
diff --git a/vis4d/zoo/base/models/yolox.py b/vis4d/zoo/base/models/yolox.py
new file mode 100644
index 0000000000000000000000000000000000000000..eadc3bc191210a57dd9ee0f2ec27c8b5c4e726cf
--- /dev/null
+++ b/vis4d/zoo/base/models/yolox.py
@@ -0,0 +1,221 @@
+"""YOLOX base model config."""
+
+from __future__ import annotations
+
+from ml_collections import ConfigDict, FieldReference
+from torch.optim.lr_scheduler import CosineAnnealingLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import OptimizerConfig
+from vis4d.data.const import CommonKeys as K
+from vis4d.engine.callbacks import (
+ EMACallback,
+ YOLOXModeSwitchCallback,
+ YOLOXSyncNormCallback,
+ YOLOXSyncRandomResizeCallback,
+)
+from vis4d.engine.connectors import LossConnector, data_key, pred_key
+from vis4d.engine.loss_module import LossModule
+from vis4d.engine.optim.scheduler import ConstantLR, QuadraticLRWarmup
+from vis4d.model.adapter import ModelExpEMAAdapter
+from vis4d.model.detect.yolox import YOLOX
+from vis4d.op.base import CSPDarknet
+from vis4d.op.detect.yolox import YOLOXHead, YOLOXHeadLoss
+from vis4d.op.fpp import YOLOXPAFPN
+from vis4d.zoo.base import get_lr_scheduler_cfg, get_optimizer_cfg
+
+# Data connectors
+CONN_YOLOX_LOSS_2D = {
+ "cls_outs": pred_key("cls_score"),
+ "reg_outs": pred_key("bbox_pred"),
+ "obj_outs": pred_key("objectness"),
+ "target_boxes": data_key(K.boxes2d),
+ "target_class_ids": data_key(K.boxes2d_classes),
+ "images_hw": data_key(K.input_hw),
+}
+
+
+def get_yolox_optimizers_cfg(
+ lr: float | FieldReference,
+ num_epochs: int | FieldReference,
+ warmup_epochs: int,
+ num_last_epochs: int,
+) -> list[OptimizerConfig]:
+ """Construct optimizer for YOLOX training."""
+ return [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD,
+ lr=lr,
+ momentum=0.9,
+ weight_decay=5e-4,
+ nesterov=True,
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(QuadraticLRWarmup, max_steps=warmup_epochs),
+ end=warmup_epochs,
+ epoch_based=False,
+ convert_epochs_to_steps=True,
+ convert_attributes=["max_steps"],
+ ),
+ get_lr_scheduler_cfg(
+ class_config(
+ CosineAnnealingLR,
+ T_max=num_epochs - num_last_epochs - warmup_epochs,
+ eta_min=lr * 0.05,
+ ),
+ begin=warmup_epochs,
+ end=num_epochs - num_last_epochs,
+ epoch_based=False,
+ convert_epochs_to_steps=True,
+ convert_attributes=["T_max"],
+ ),
+ get_lr_scheduler_cfg(
+ class_config(
+ ConstantLR, max_steps=num_last_epochs, factor=1.0
+ ),
+ begin=num_epochs - num_last_epochs,
+ end=num_epochs,
+ epoch_based=True,
+ ),
+ ],
+ param_groups=[
+ {
+ "custom_keys": ["basemodel", "fpn", "yolox_head"],
+ "norm_decay_mult": 0.0,
+ },
+ {
+ "custom_keys": ["basemodel", "fpn", "yolox_head"],
+ "bias_decay_mult": 0.0,
+ },
+ ],
+ )
+ ]
+
+
+def get_yolox_callbacks_cfg(
+ switch_epoch: int,
+ shape: tuple[int, int] = (480, 480),
+ num_sizes: int = 11,
+ use_ema: bool = True,
+) -> list[ConfigDict]:
+ """Get YOLOX callbacks for training."""
+ callbacks = []
+ if num_sizes > 0:
+ callbacks.append(
+ class_config(
+ YOLOXSyncRandomResizeCallback,
+ size_list=[
+ (shape[0] + i * 32, shape[1] + i * 32)
+ for i in range(num_sizes)
+ ],
+ interval=10,
+ )
+ )
+ callbacks += [
+ class_config(YOLOXModeSwitchCallback, switch_epoch=switch_epoch),
+ class_config(YOLOXSyncNormCallback),
+ ]
+ if use_ema:
+ callbacks += [class_config(EMACallback)]
+ return callbacks
+
+
+def get_model_setting(model_type: str) -> tuple[float, float, int, list[int]]:
+ """Get YOLOX model setting."""
+ if model_type == "tiny":
+ deepen_factor, widen_factor, num_csp_blocks = 0.33, 0.375, 1
+ in_channels = [96, 192, 384]
+ elif model_type == "small":
+ deepen_factor, widen_factor, num_csp_blocks = 0.33, 0.5, 1
+ in_channels = [128, 256, 512]
+ elif model_type == "large":
+ deepen_factor, widen_factor, num_csp_blocks = 1.0, 1.0, 3
+ in_channels = [256, 512, 1024]
+ elif model_type == "xlarge":
+ deepen_factor, widen_factor, num_csp_blocks = 1.33, 1.25, 4
+ in_channels = [320, 640, 1280]
+ else:
+ raise ValueError(f"Unknown model type: {model_type}")
+ return deepen_factor, widen_factor, num_csp_blocks, in_channels
+
+
+def get_yolox_model_cfg(
+ num_classes: FieldReference | int, model_type: str
+) -> ConfigDict:
+ """Get YOLOX model."""
+ assert model_type in {"tiny", "small", "large", "xlarge"}, (
+ f"model_type must be one of 'tiny', 'small', 'large', 'xlarge', "
+ f"got {model_type}."
+ )
+ (
+ deepen_factor,
+ widen_factor,
+ num_csp_blocks,
+ in_channels,
+ ) = get_model_setting(model_type)
+ basemodel = class_config(
+ CSPDarknet, deepen_factor=deepen_factor, widen_factor=widen_factor
+ )
+ fpn = class_config(
+ YOLOXPAFPN,
+ in_channels=in_channels,
+ out_channels=in_channels[0],
+ num_csp_blocks=num_csp_blocks,
+ )
+ yolox_head = class_config(
+ YOLOXHead,
+ num_classes=num_classes,
+ in_channels=in_channels[0],
+ feat_channels=in_channels[0],
+ )
+ return basemodel, fpn, yolox_head
+
+
+def get_yolox_cfg(
+ num_classes: FieldReference | int,
+ model_type: str,
+ use_ema: bool = True,
+ weights: str | None = None,
+) -> tuple[ConfigDict, ConfigDict]:
+ """Return default config for YOLOX model and loss.
+
+ Args:
+ num_classes (FieldReference | int): Number of classes.
+ model_type (str): Model type. Must be one of 'tiny', 'small', 'large',
+ 'xlarge'.
+ use_ema (bool, optional): Whether to use EMA. Defaults to True.
+ weights (str | None, optional): Weights to load. Defaults to None.
+ """
+ ######################################################
+ ## MODEL ##
+ ######################################################
+ basemodel, fpn, yolox_head = get_yolox_model_cfg(num_classes, model_type)
+ model = class_config(
+ YOLOX,
+ num_classes=num_classes,
+ basemodel=basemodel,
+ fpn=fpn,
+ yolox_head=yolox_head,
+ weights=weights,
+ )
+ if use_ema:
+ model = class_config(ModelExpEMAAdapter, model=model)
+
+ ######################################################
+ ## LOSS ##
+ ######################################################
+ loss = class_config(
+ LossModule,
+ losses=[
+ {
+ "loss": class_config(YOLOXHeadLoss, num_classes=num_classes),
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_YOLOX_LOSS_2D
+ ),
+ },
+ ],
+ )
+ return model, loss
diff --git a/vis4d/zoo/base/optimizer.py b/vis4d/zoo/base/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d5bebcb550e637301d508893db8ac43661a1004
--- /dev/null
+++ b/vis4d/zoo/base/optimizer.py
@@ -0,0 +1,82 @@
+"""Optimizer configuration."""
+
+from __future__ import annotations
+
+from ml_collections import ConfigDict
+
+from vis4d.config.typing import (
+ LrSchedulerConfig,
+ OptimizerConfig,
+ ParamGroupCfg,
+)
+
+
+def get_lr_scheduler_cfg(
+ scheduler: ConfigDict,
+ begin: int = 0,
+ end: int = -1,
+ epoch_based: bool = True,
+ convert_epochs_to_steps: bool = False,
+ convert_attributes: list[str] | None = None,
+) -> LrSchedulerConfig:
+ """Default learning rate scheduler configuration.
+
+ This creates a config object that can be initialized as a LearningRate
+ scheduler for training.
+
+ Args:
+ scheduler (ConfigDict): Learning rate scheduler configuration.
+ begin (int, optional): Begin epoch. Defaults to 0.
+ end (int, optional): End epoch. Defaults to None. Defaults to -1.
+ epoch_based (bool, optional): Whether the learning rate scheduler is
+ epoch based or step based. Defaults to True.
+ convert_epochs_to_steps (bool): Whether to convert the begin and end
+ for a step based scheduler to steps automatically based on length
+ of train dataloader. Enables users to set the iteration breakpoints
+ as epochs. Defaults to False.
+ convert_attributes (list[str] | None): List of attributes in the
+ scheduler that should be converted to steps. Defaults to None.
+
+ Returns:
+ LrSchedulerConfig: Config dict that can be instantiated as LearningRate
+ scheduler.
+ """
+ lr_scheduler = LrSchedulerConfig()
+
+ lr_scheduler.scheduler = scheduler
+ lr_scheduler.begin = begin
+ lr_scheduler.end = end
+ lr_scheduler.epoch_based = epoch_based
+ lr_scheduler.convert_epochs_to_steps = convert_epochs_to_steps
+ lr_scheduler.convert_attributes = convert_attributes
+
+ return lr_scheduler
+
+
+def get_optimizer_cfg(
+ optimizer: ConfigDict,
+ lr_schedulers: list[LrSchedulerConfig] | None = None,
+ param_groups: list[ParamGroupCfg] | None = None,
+) -> OptimizerConfig:
+ """Default optimizer configuration.
+
+ This creates a config object that can be initialized as an Optimizer for
+ training.
+
+ Args:
+ optimizer (ConfigDict): Optimizer configuration.
+ lr_schedulers (list[LrSchedulerConfig] | None, optional): Learning rate
+ schedulers configuration. Defaults to None.
+ param_groups (list[ParamGroupCfg] | None, optional): Parameter groups
+ configuration. Defaults to None.
+
+ Returns:
+ OptimizerConfig: Config dict that can be instantiated as Optimizer.
+ """
+ optim = OptimizerConfig()
+
+ optim.optimizer = optimizer
+ optim.lr_schedulers = lr_schedulers
+ optim.param_groups = param_groups
+
+ return optim
diff --git a/vis4d/zoo/base/pl_trainer.py b/vis4d/zoo/base/pl_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e54e625cb78d8946d7018c95fd86302d232cba97
--- /dev/null
+++ b/vis4d/zoo/base/pl_trainer.py
@@ -0,0 +1,38 @@
+"""Default runtime configuration for PyTorch Lightning."""
+
+import inspect
+
+from lightning import Trainer
+
+from vis4d.config import FieldConfigDict
+from vis4d.config.typing import ExperimentConfig
+
+
+def get_default_pl_trainer_cfg(config: ExperimentConfig) -> ExperimentConfig:
+ """Get PyTorch Lightning Trainer config."""
+ pl_trainer = FieldConfigDict()
+
+ # PL Trainer arguments
+ for k, v in inspect.signature(Trainer).parameters.items():
+ if not k in {"callbacks", "devices", "logger", "strategy"}:
+ pl_trainer[k] = v.default
+
+ # PL Trainer settings
+ pl_trainer.benchmark = config.benchmark
+ pl_trainer.use_distributed_sampler = False
+ pl_trainer.num_sanity_val_steps = 0
+
+ # logger
+ pl_trainer.enable_progress_bar = False
+ pl_trainer.log_every_n_steps = config.log_every_n_steps
+
+ # Default Trainer arguments
+ pl_trainer.work_dir = config.work_dir
+ pl_trainer.exp_name = config.experiment_name
+ pl_trainer.version = config.version
+ pl_trainer.find_unused_parameters = False
+ pl_trainer.checkpoint_period = 1
+ pl_trainer.save_top_k = 1
+ pl_trainer.wandb = False
+
+ return pl_trainer
diff --git a/vis4d/zoo/base/runtime.py b/vis4d/zoo/base/runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..d127ace78dd8b54678911e07e4b9f9dd01ff4256
--- /dev/null
+++ b/vis4d/zoo/base/runtime.py
@@ -0,0 +1,93 @@
+"""Default runtime configuration for the project."""
+
+from __future__ import annotations
+
+import platform
+from datetime import datetime
+
+from ml_collections import ConfigDict
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig
+from vis4d.engine.callbacks import LoggingCallback
+
+
+def get_default_cfg(
+ exp_name: str, work_dir: str = "vis4d-workspace"
+) -> ExperimentConfig:
+ """Set default config for the project.
+
+ Args:
+ exp_name (str): Experiment name.
+ work_dir (str, optional): Working directory. Defaults to
+ "vis4d-workspace".
+
+ Returns:
+ ExperimentConfig: Config for the project.
+ """
+ config = ExperimentConfig()
+
+ config.work_dir = work_dir
+ config.experiment_name = exp_name
+
+ timestamp = (
+ str(datetime.now())
+ .split(".", maxsplit=1)[0]
+ .replace(" ", "_")
+ .replace(":", "-")
+ )
+ config.timestamp = timestamp
+ config.version = timestamp
+
+ if platform.system() == "Windows":
+ path_component = "\\"
+ else:
+ path_component = "/"
+
+ config.output_dir = (
+ config.work_dir
+ + path_component
+ + config.experiment_name
+ + path_component
+ + config.version
+ )
+
+ # Set default value for the following fields
+ config.seed = -1
+ config.log_every_n_steps = 50
+ config.use_tf32 = False
+ config.tf32_matmul_precision = "highest"
+ config.benchmark = False
+ config.compute_flops = False
+ config.check_unused_parameters = False
+
+ return config
+
+
+def get_default_callbacks_cfg(
+ epoch_based: bool = True,
+ refresh_rate: int = 50,
+) -> list[ConfigDict]:
+ """Get default callbacks config.
+
+ It will return a list of callbacks config including:
+ - LoggingCallback
+
+ Args:
+ epoch_based (bool, optional): Whether to use epoch based logging.
+ refresh_rate (int, optional): Refresh rate for the logging. Defaults to
+ 50.
+
+ Returns:
+ list[ConfigDict]: List of callbacks config.
+ """
+ callbacks = []
+
+ # Logger
+ callbacks.append(
+ class_config(
+ LoggingCallback, epoch_based=epoch_based, refresh_rate=refresh_rate
+ )
+ )
+
+ return callbacks
diff --git a/vis4d/zoo/bdd100k/__init__.py b/vis4d/zoo/bdd100k/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..44d079c7f6048d6ca3787f42f360776ca24c0fc2
--- /dev/null
+++ b/vis4d/zoo/bdd100k/__init__.py
@@ -0,0 +1,27 @@
+"""BDD100K Model Zoo."""
+
+from .faster_rcnn import faster_rcnn_r50_1x_bdd100k, faster_rcnn_r50_3x_bdd100k
+from .mask_rcnn import (
+ mask_rcnn_r50_1x_bdd100k,
+ mask_rcnn_r50_3x_bdd100k,
+ mask_rcnn_r50_5x_bdd100k,
+)
+from .qdtrack import qdtrack_frcnn_r50_fpn_1x_bdd100k
+from .semantic_fpn import (
+ semantic_fpn_r50_40k_bdd100k,
+ semantic_fpn_r50_80k_bdd100k,
+ semantic_fpn_r101_80k_bdd100k,
+)
+
+# Lists of available models in BDD100K Model Zoo.
+AVAILABLE_MODELS = {
+ "faster_rcnn_r50_1x_bdd100k": faster_rcnn_r50_1x_bdd100k,
+ "faster_rcnn_r50_3x_bdd100k": faster_rcnn_r50_3x_bdd100k,
+ "mask_rcnn_r50_1x_bdd100k": mask_rcnn_r50_1x_bdd100k,
+ "mask_rcnn_r50_3x_bdd100k": mask_rcnn_r50_3x_bdd100k,
+ "mask_rcnn_r50_5x_bdd100k": mask_rcnn_r50_5x_bdd100k,
+ "semantic_fpn_r50_40k_bdd100k": semantic_fpn_r50_40k_bdd100k,
+ "semantic_fpn_r50_80k_bdd100k": semantic_fpn_r50_80k_bdd100k,
+ "semantic_fpn_r101_80k_bdd100k": semantic_fpn_r101_80k_bdd100k,
+ "qdtrack_frcnn_r50_fpn_1x_bdd100k": qdtrack_frcnn_r50_fpn_1x_bdd100k,
+}
diff --git a/vis4d/zoo/bdd100k/faster_rcnn/__init__.py b/vis4d/zoo/bdd100k/faster_rcnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc567e6ee7189a8c9ce7908e0ed3d1d68d8625e1
--- /dev/null
+++ b/vis4d/zoo/bdd100k/faster_rcnn/__init__.py
@@ -0,0 +1 @@
+"""Faster R-CNN for BDD100K."""
diff --git a/vis4d/zoo/bdd100k/faster_rcnn/faster_rcnn_r50_1x_bdd100k.py b/vis4d/zoo/bdd100k/faster_rcnn/faster_rcnn_r50_1x_bdd100k.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad720c0ec350a8feddf5bded595f7352a3c74a1e
--- /dev/null
+++ b/vis4d/zoo/bdd100k/faster_rcnn/faster_rcnn_r50_1x_bdd100k.py
@@ -0,0 +1,162 @@
+# pylint: disable=duplicate-code
+"""Faster RCNN BDD100K training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.eval.bdd100k import BDD100KDetectEvaluator
+from vis4d.op.base import ResNet
+from vis4d.vis.image import BoundingBoxVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors import (
+ CONN_BBOX_2D_TEST,
+ CONN_BBOX_2D_TRAIN,
+ CONN_BBOX_2D_VIS,
+)
+from vis4d.zoo.base.datasets.bdd100k import (
+ CONN_BDD100K_DET_EVAL,
+ get_bdd100k_detection_config,
+)
+from vis4d.zoo.base.models.faster_rcnn import get_faster_rcnn_cfg
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the Faster-RCNN config dict for the BDD100K detection task.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="faster_rcnn_r50_1x_bdd100k")
+
+ # High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.02
+ params.num_epochs = 12
+ params.num_classes = 10
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/bdd100k/images/100k"
+ train_split = "train"
+ test_split = "val"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_bdd100k_detection_config(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3
+ )
+
+ config.model, config.loss = get_faster_rcnn_cfg(
+ num_classes=params.num_classes, basemodel=basemodel
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[8, 11], gamma=0.1),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(BoundingBoxVisualizer, vis_freq=100),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BBOX_2D_VIS
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ BDD100KDetectEvaluator,
+ annotation_path="data/bdd100k/labels/det_20/det_val.json",
+ config_path="det",
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BDD100K_DET_EVAL
+ ),
+ metrics_to_eval=[BDD100KDetectEvaluator.METRICS_DET],
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/bdd100k/faster_rcnn/faster_rcnn_r50_3x_bdd100k.py b/vis4d/zoo/bdd100k/faster_rcnn/faster_rcnn_r50_3x_bdd100k.py
new file mode 100644
index 0000000000000000000000000000000000000000..de53a841171f3b551cb9126c80f7539480ca61ba
--- /dev/null
+++ b/vis4d/zoo/bdd100k/faster_rcnn/faster_rcnn_r50_3x_bdd100k.py
@@ -0,0 +1,163 @@
+# pylint: disable=duplicate-code
+"""Faster RCNN BDD100K training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.eval.bdd100k import BDD100KDetectEvaluator
+from vis4d.op.base import ResNet
+from vis4d.vis.image import BoundingBoxVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors import (
+ CONN_BBOX_2D_TEST,
+ CONN_BBOX_2D_TRAIN,
+ CONN_BBOX_2D_VIS,
+)
+from vis4d.zoo.base.datasets.bdd100k import (
+ CONN_BDD100K_DET_EVAL,
+ get_bdd100k_detection_config,
+)
+from vis4d.zoo.base.models.faster_rcnn import get_faster_rcnn_cfg
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the Faster-RCNN config dict for the BDD100K detection task.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="faster_rcnn_r50_3x_bdd100k")
+
+ # High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.02
+ params.num_epochs = 36
+ params.num_classes = 10
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/bdd100k/images/100k"
+ train_split = "train"
+ test_split = "val"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_bdd100k_detection_config(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ data_backend=data_backend,
+ multi_scale=True,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3
+ )
+
+ config.model, config.loss = get_faster_rcnn_cfg(
+ num_classes=params.num_classes, basemodel=basemodel
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[24, 33], gamma=0.1),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(BoundingBoxVisualizer, vis_freq=100),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BBOX_2D_VIS
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ BDD100KDetectEvaluator,
+ annotation_path="data/bdd100k/labels/det_20/det_val.json",
+ config_path="det",
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BDD100K_DET_EVAL
+ ),
+ metrics_to_eval=[BDD100KDetectEvaluator.METRICS_DET],
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/bdd100k/mask_rcnn/__init__.py b/vis4d/zoo/bdd100k/mask_rcnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e738b6b71d039bbcf361451597e4850179bb1883
--- /dev/null
+++ b/vis4d/zoo/bdd100k/mask_rcnn/__init__.py
@@ -0,0 +1 @@
+"""Mask R-CNN for BDD100K."""
diff --git a/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_1x_bdd100k.py b/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_1x_bdd100k.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6fb4c4e515c58be188081e399928f12d778a6ff
--- /dev/null
+++ b/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_1x_bdd100k.py
@@ -0,0 +1,167 @@
+# pylint: disable=duplicate-code
+"""Mask RCNN BDD100K training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.eval.bdd100k import BDD100KDetectEvaluator
+from vis4d.op.base import ResNet
+from vis4d.vis.image import SegMaskVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors import (
+ CONN_BBOX_2D_TEST,
+ CONN_BBOX_2D_TRAIN,
+ CONN_INS_MASK_2D_VIS,
+)
+from vis4d.zoo.base.datasets.bdd100k import (
+ CONN_BDD100K_INS_EVAL,
+ get_bdd100k_detection_config,
+)
+from vis4d.zoo.base.models.mask_rcnn import get_mask_rcnn_cfg
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the Mask R-CNN config dict for BDD100K instance segmentation.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="mask_rcnn_r50_1x_bdd100k")
+
+ # High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.02
+ params.num_epochs = 12
+ params.num_classes = 8
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/bdd100k/images/10k"
+ train_split = "train"
+ test_split = "val"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_bdd100k_detection_config(
+ data_root=data_root,
+ train_split=train_split,
+ train_keys_to_load=(K.images, K.boxes2d, K.instance_masks),
+ test_split=test_split,
+ test_keys_to_load=(K.images, K.original_images),
+ ins_seg=True,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3
+ )
+
+ config.model, config.loss = get_mask_rcnn_cfg(
+ num_classes=params.num_classes,
+ basemodel=basemodel,
+ no_overlap=True,
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[8, 11], gamma=0.1),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(SegMaskVisualizer, vis_freq=25),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_INS_MASK_2D_VIS
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ BDD100KDetectEvaluator,
+ annotation_path="data/bdd100k/labels/ins_seg_val_rle.json",
+ config_path="ins_seg",
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BDD100K_INS_EVAL
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_3x_bdd100k.py b/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_3x_bdd100k.py
new file mode 100644
index 0000000000000000000000000000000000000000..f57801ba3af4f21bbada1b6aafea2ecb81f51431
--- /dev/null
+++ b/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_3x_bdd100k.py
@@ -0,0 +1,168 @@
+# pylint: disable=duplicate-code
+"""Mask RCNN BDD100K training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.eval.bdd100k import BDD100KDetectEvaluator
+from vis4d.op.base import ResNet
+from vis4d.vis.image import SegMaskVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors import (
+ CONN_BBOX_2D_TEST,
+ CONN_BBOX_2D_TRAIN,
+ CONN_INS_MASK_2D_VIS,
+)
+from vis4d.zoo.base.datasets.bdd100k import (
+ CONN_BDD100K_INS_EVAL,
+ get_bdd100k_detection_config,
+)
+from vis4d.zoo.base.models.mask_rcnn import get_mask_rcnn_cfg
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the Mask R-CNN config dict for BDD100K instance segmentation.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="mask_rcnn_r50_3x_bdd100k")
+ config.check_val_every_n_epoch = 3
+
+ # High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.02
+ params.num_epochs = 36
+ params.num_classes = 8
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/bdd100k/images/10k"
+ train_split = "train"
+ test_split = "val"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_bdd100k_detection_config(
+ data_root=data_root,
+ train_split=train_split,
+ train_keys_to_load=(K.images, K.boxes2d, K.instance_masks),
+ test_split=test_split,
+ test_keys_to_load=(K.images, K.original_images),
+ ins_seg=True,
+ multi_scale=True,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3
+ )
+
+ config.model, config.loss = get_mask_rcnn_cfg(
+ num_classes=params.num_classes, basemodel=basemodel, no_overlap=True
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[24, 33], gamma=0.1),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(SegMaskVisualizer, vis_freq=25),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_INS_MASK_2D_VIS
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ BDD100KDetectEvaluator,
+ annotation_path="data/bdd100k/labels/ins_seg_val_rle.json",
+ config_path="ins_seg",
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BDD100K_INS_EVAL
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_5x_bdd100k.py b/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_5x_bdd100k.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2e6dc4ba71bf0383eec50846883e4c48ee62b34
--- /dev/null
+++ b/vis4d/zoo/bdd100k/mask_rcnn/mask_rcnn_r50_5x_bdd100k.py
@@ -0,0 +1,168 @@
+# pylint: disable=duplicate-code
+"""Mask RCNN BDD100K training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.eval.bdd100k import BDD100KDetectEvaluator
+from vis4d.op.base import ResNet
+from vis4d.vis.image import SegMaskVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors import (
+ CONN_BBOX_2D_TEST,
+ CONN_BBOX_2D_TRAIN,
+ CONN_INS_MASK_2D_VIS,
+)
+from vis4d.zoo.base.datasets.bdd100k import (
+ CONN_BDD100K_INS_EVAL,
+ get_bdd100k_detection_config,
+)
+from vis4d.zoo.base.models.mask_rcnn import get_mask_rcnn_cfg
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the Mask R-CNN config dict for BDD100K instance segmentation.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="mask_rcnn_r50_5x_bdd100k")
+ config.check_val_every_n_epoch = 5
+
+ # High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.02
+ params.num_epochs = 60
+ params.num_classes = 8
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/bdd100k/images/10k"
+ train_split = "train"
+ test_split = "val"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_bdd100k_detection_config(
+ data_root=data_root,
+ train_split=train_split,
+ train_keys_to_load=(K.images, K.boxes2d, K.instance_masks),
+ test_split=test_split,
+ test_keys_to_load=(K.images, K.original_images),
+ ins_seg=True,
+ multi_scale=True,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3
+ )
+
+ config.model, config.loss = get_mask_rcnn_cfg(
+ num_classes=params.num_classes, basemodel=basemodel, no_overlap=True
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[40, 55], gamma=0.1),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(SegMaskVisualizer, vis_freq=25),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_INS_MASK_2D_VIS
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ BDD100KDetectEvaluator,
+ annotation_path="data/bdd100k/labels/ins_seg_val_rle.json",
+ config_path="ins_seg",
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BDD100K_INS_EVAL
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/bdd100k/qdtrack/__init__.py b/vis4d/zoo/bdd100k/qdtrack/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd5f1cd38da271c6b636d9c8e26995e9dc86caad
--- /dev/null
+++ b/vis4d/zoo/bdd100k/qdtrack/__init__.py
@@ -0,0 +1 @@
+"""QDTrack for BDD100k."""
diff --git a/vis4d/zoo/bdd100k/qdtrack/qdtrack_frcnn_r50_fpn_1x_bdd100k.py b/vis4d/zoo/bdd100k/qdtrack/qdtrack_frcnn_r50_fpn_1x_bdd100k.py
new file mode 100644
index 0000000000000000000000000000000000000000..90470087f689d390ce36fc5cdbe46e97cd160817
--- /dev/null
+++ b/vis4d/zoo/bdd100k/qdtrack/qdtrack_frcnn_r50_fpn_1x_bdd100k.py
@@ -0,0 +1,140 @@
+# pylint: disable=duplicate-code
+"""QDTrack with Faster R-CNN on BDD100K."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.datasets.bdd100k import bdd100k_track_map
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.eval.bdd100k import BDD100KTrackEvaluator
+from vis4d.op.base import ResNet
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.datasets.bdd100k import (
+ CONN_BDD100K_TRACK_EVAL,
+ get_bdd100k_track_cfg,
+)
+from vis4d.zoo.base.models.qdtrack import (
+ CONN_BBOX_2D_TEST,
+ CONN_BBOX_2D_TRAIN,
+ get_qdtrack_cfg,
+)
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for qdtrack on bdd100k.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="qdtrack_frcnn_r50_fpn_1x_bdd100k")
+
+ # High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 4 # batch size = 4 GPUs * 4 samples per GPU = 16
+ params.workers_per_gpu = 4
+ params.lr = 0.02
+ params.num_epochs = 12
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_bdd100k_track_cfg(
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL ##
+ ######################################################
+ num_classes = len(bdd100k_track_map)
+ basemodel = class_config(
+ ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3
+ )
+
+ config.model, config.loss = get_qdtrack_cfg(
+ num_classes=num_classes, basemodel=basemodel
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(LinearLR, start_factor=0.1, total_iters=1000),
+ end=1000,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[8, 11], gamma=0.1),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ BDD100KTrackEvaluator,
+ annotation_path="data/bdd100k/labels/box_track_20/val/",
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BDD100K_TRACK_EVAL
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ config.pl_trainer = pl_trainer
+
+ pl_trainer.gradient_clip_val = 35
+
+ return config.value_mode()
diff --git a/vis4d/zoo/bdd100k/semantic_fpn/__init__.py b/vis4d/zoo/bdd100k/semantic_fpn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc7e40eb37a174fb774bbba23d0bbe851a10b38b
--- /dev/null
+++ b/vis4d/zoo/bdd100k/semantic_fpn/__init__.py
@@ -0,0 +1 @@
+"""Semantic FPN for BDD100K."""
diff --git a/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r101_80k_bdd100k.py b/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r101_80k_bdd100k.py
new file mode 100644
index 0000000000000000000000000000000000000000..bea81a90ad0ada011dcba9a13166c61a186be58e
--- /dev/null
+++ b/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r101_80k_bdd100k.py
@@ -0,0 +1,199 @@
+# pylint: disable=duplicate-code
+"""Semantic FPN BDD100K training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import (
+ CallbackConnector,
+ DataConnector,
+ LossConnector,
+)
+from vis4d.engine.loss_module import LossModule
+from vis4d.engine.optim import PolyLR
+from vis4d.eval.bdd100k import BDD100KSegEvaluator
+from vis4d.model.seg.semantic_fpn import SemanticFPN
+from vis4d.op.base import ResNetV1c
+from vis4d.op.loss import SegCrossEntropyLoss
+from vis4d.vis.image import SegMaskVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors.seg import (
+ CONN_MASKS_TEST,
+ CONN_MASKS_TRAIN,
+ CONN_SEG_LOSS,
+ CONN_SEG_VIS,
+)
+from vis4d.zoo.base.datasets.bdd100k import (
+ CONN_BDD100K_SEG_EVAL,
+ get_bdd100k_sem_seg_cfg,
+)
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for the BDD100K semantic segmentation task.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="semantic_fpn_r101_80k_bdd100k")
+ config.sync_batchnorm = True
+ config.val_check_interval = 4000
+ config.check_val_every_n_epoch = None
+
+ ## High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.01
+ params.num_steps = 80000
+ params.num_classes = 19
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/bdd100k/images/10k"
+ train_split = "train"
+ test_split = "val"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_bdd100k_sem_seg_cfg(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNetV1c,
+ resnet_name="resnet101_v1c",
+ pretrained=True,
+ trainable_layers=5,
+ norm_frozen=False,
+ )
+ config.model = class_config(
+ SemanticFPN, num_classes=params.num_classes, basemodel=basemodel
+ )
+ config.loss = class_config(
+ LossModule,
+ losses=[
+ {
+ "loss": class_config(SegCrossEntropyLoss),
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_SEG_LOSS
+ ),
+ },
+ ],
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0005
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(
+ PolyLR,
+ max_steps=params.num_steps,
+ min_lr=0.0001,
+ power=0.9,
+ ),
+ epoch_based=False,
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_MASKS_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_MASKS_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ callbacks = get_default_callbacks_cfg(epoch_based=False)
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ BDD100KSegEvaluator,
+ annotation_path="data/bdd100k/labels/sem_seg_val_rle.json",
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BDD100K_SEG_EVAL
+ ),
+ )
+ )
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(SegMaskVisualizer, vis_freq=20),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_SEG_VIS
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.epoch_based = False
+ pl_trainer.max_steps = params.num_steps
+
+ pl_trainer.checkpoint_period = config.val_check_interval
+ pl_trainer.val_check_interval = config.val_check_interval
+ pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch
+
+ pl_trainer.sync_batchnorm = config.sync_batchnorm
+ # pl_trainer.precision = 16
+
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r50_40k_bdd100k.py b/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r50_40k_bdd100k.py
new file mode 100644
index 0000000000000000000000000000000000000000..902b7e9dd57b34c01e409fa3fc53d6250b99d09b
--- /dev/null
+++ b/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r50_40k_bdd100k.py
@@ -0,0 +1,189 @@
+# pylint: disable=duplicate-code
+"""Semantic FPN BDD100K training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import (
+ CallbackConnector,
+ DataConnector,
+ LossConnector,
+)
+from vis4d.engine.loss_module import LossModule
+from vis4d.engine.optim import PolyLR
+from vis4d.eval.bdd100k import BDD100KSegEvaluator
+from vis4d.model.seg.semantic_fpn import SemanticFPN
+from vis4d.op.loss import SegCrossEntropyLoss
+from vis4d.vis.image import SegMaskVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors.seg import (
+ CONN_MASKS_TEST,
+ CONN_MASKS_TRAIN,
+ CONN_SEG_LOSS,
+ CONN_SEG_VIS,
+)
+from vis4d.zoo.base.datasets.bdd100k import (
+ CONN_BDD100K_SEG_EVAL,
+ get_bdd100k_sem_seg_cfg,
+)
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for the BDD100K semantic segmentation task.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="semantic_fpn_r50_40k_bdd100k")
+ config.sync_batchnorm = True
+ config.val_check_interval = 2000
+ config.check_val_every_n_epoch = None
+
+ ## High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.01
+ params.num_steps = 40000
+ params.num_classes = 19
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/bdd100k/images/10k"
+ train_split = "train"
+ test_split = "val"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_bdd100k_sem_seg_cfg(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ config.model = class_config(SemanticFPN, num_classes=params.num_classes)
+ config.loss = class_config(
+ LossModule,
+ losses=[
+ {
+ "loss": class_config(SegCrossEntropyLoss),
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_SEG_LOSS
+ ),
+ },
+ ],
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0005
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(
+ PolyLR,
+ max_steps=params.num_steps,
+ min_lr=0.0001,
+ power=0.9,
+ ),
+ epoch_based=False,
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_MASKS_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_MASKS_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ callbacks = get_default_callbacks_cfg(epoch_based=False)
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ BDD100KSegEvaluator,
+ annotation_path="data/bdd100k/labels/sem_seg_val_rle.json",
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BDD100K_SEG_EVAL
+ ),
+ )
+ )
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(SegMaskVisualizer, vis_freq=20),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_SEG_VIS
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.epoch_based = False
+ pl_trainer.max_steps = params.num_steps
+
+ pl_trainer.checkpoint_period = config.val_check_interval
+ pl_trainer.val_check_interval = config.val_check_interval
+ pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch
+
+ pl_trainer.sync_batchnorm = config.sync_batchnorm
+ # pl_trainer.precision = 16
+
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r50_80k_bdd100k.py b/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r50_80k_bdd100k.py
new file mode 100644
index 0000000000000000000000000000000000000000..61709ea1dd350215cb6fa4ef3dc5affa8e04be1c
--- /dev/null
+++ b/vis4d/zoo/bdd100k/semantic_fpn/semantic_fpn_r50_80k_bdd100k.py
@@ -0,0 +1,189 @@
+# pylint: disable=duplicate-code
+"""Semantic FPN BDD100K training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import (
+ CallbackConnector,
+ DataConnector,
+ LossConnector,
+)
+from vis4d.engine.loss_module import LossModule
+from vis4d.engine.optim import PolyLR
+from vis4d.eval.bdd100k import BDD100KSegEvaluator
+from vis4d.model.seg.semantic_fpn import SemanticFPN
+from vis4d.op.loss import SegCrossEntropyLoss
+from vis4d.vis.image import SegMaskVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors.seg import (
+ CONN_MASKS_TEST,
+ CONN_MASKS_TRAIN,
+ CONN_SEG_LOSS,
+ CONN_SEG_VIS,
+)
+from vis4d.zoo.base.datasets.bdd100k import (
+ CONN_BDD100K_SEG_EVAL,
+ get_bdd100k_sem_seg_cfg,
+)
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for the BDD100K semantic segmentation task.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="semantic_fpn_r50_80k_bdd100k")
+ config.sync_batchnorm = True
+ config.val_check_interval = 4000
+ config.check_val_every_n_epoch = None
+
+ ## High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.01
+ params.num_steps = 80000
+ params.num_classes = 19
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/bdd100k/images/10k"
+ train_split = "train"
+ test_split = "val"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_bdd100k_sem_seg_cfg(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ config.model = class_config(SemanticFPN, num_classes=params.num_classes)
+ config.loss = class_config(
+ LossModule,
+ losses=[
+ {
+ "loss": class_config(SegCrossEntropyLoss),
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_SEG_LOSS
+ ),
+ },
+ ],
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0005
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(
+ PolyLR,
+ max_steps=params.num_steps,
+ min_lr=0.0001,
+ power=0.9,
+ ),
+ epoch_based=False,
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_MASKS_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_MASKS_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ callbacks = get_default_callbacks_cfg(epoch_based=False)
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ BDD100KSegEvaluator,
+ annotation_path="data/bdd100k/labels/sem_seg_val_rle.json",
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BDD100K_SEG_EVAL
+ ),
+ )
+ )
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(SegMaskVisualizer, vis_freq=20),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_SEG_VIS
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.epoch_based = False
+ pl_trainer.max_steps = params.num_steps
+
+ pl_trainer.checkpoint_period = config.val_check_interval
+ pl_trainer.val_check_interval = config.val_check_interval
+ pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch
+
+ pl_trainer.sync_batchnorm = config.sync_batchnorm
+ # pl_trainer.precision = 16
+
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/bevformer/__init__.py b/vis4d/zoo/bevformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e75b3249c6d7d3c8c0f504a52e962d6b527bd5e
--- /dev/null
+++ b/vis4d/zoo/bevformer/__init__.py
@@ -0,0 +1,9 @@
+"""BEVFormer model zoo."""
+
+from . import bevformer_base, bevformer_tiny, bevformer_vis
+
+AVAILABLE_MODELS = {
+ "bevformer_base": bevformer_base,
+ "bevformer_tiny": bevformer_tiny,
+ "bevformer_vis": bevformer_vis,
+}
diff --git a/vis4d/zoo/bevformer/bevformer_base.py b/vis4d/zoo/bevformer/bevformer_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd491bbc7bb68b0377cb875c08609143b0fe4413
--- /dev/null
+++ b/vis4d/zoo/bevformer/bevformer_base.py
@@ -0,0 +1,157 @@
+# pylint: disable=duplicate-code
+"""BEVFormer base with ResNet-101-DCN backbone."""
+from __future__ import annotations
+
+from torch.optim.adamw import AdamW
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback
+from vis4d.engine.connectors import CallbackConnector, MultiSensorDataConnector
+from vis4d.eval.nuscenes import NuScenesDet3DEvaluator
+from vis4d.model.detect3d.bevformer import BEVFormer
+from vis4d.op.base import ResNet
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.bevformer.data import (
+ CONN_NUSC_BBOX_3D_TEST,
+ CONN_NUSC_DET3D_EVAL,
+ get_nusc_cfg,
+ nuscenes_class_map,
+)
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for BEVFormer on nuScenes.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="bevformer_base")
+
+ # Hyper Parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 1
+ params.workers_per_gpu = 4
+ params.lr = 2e-4
+ params.num_epochs = 24
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/nuscenes"
+ version = "v1.0-trainval"
+ train_split = "train"
+ test_split = "val"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_nusc_cfg(
+ data_root=data_root,
+ version=version,
+ train_split=train_split,
+ test_split=test_split,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet,
+ resnet_name="resnet101",
+ trainable_layers=3,
+ style="caffe",
+ stages_with_dcn=(False, False, True, True),
+ )
+
+ config.model = class_config(
+ BEVFormer,
+ basemodel=basemodel,
+ weights="https://github.com/zhiqi-li/storage/releases/download/v1.0/bevformer_r101_dcn_24ep.pth", # pylint: disable=line-too-long
+ )
+
+ config.loss = None
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(AdamW, lr=params.lr, weight_decay=0.01),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=1.0 / 3, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(CosineAnnealingLR, T_max=params.num_epochs),
+ ),
+ ],
+ param_groups=[{"custom_keys": ["basemodel"], "lr_mult": 0.1}],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = None
+
+ config.test_data_connector = class_config(
+ MultiSensorDataConnector, key_mapping=CONN_NUSC_BBOX_3D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ NuScenesDet3DEvaluator,
+ data_root=data_root,
+ version=version,
+ split=test_split,
+ class_map=nuscenes_class_map,
+ velocity_thres=0.2,
+ ),
+ save_predictions=True,
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_NUSC_DET3D_EVAL
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ pl_trainer.gradient_clip_val = 35
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/bevformer/bevformer_tiny.py b/vis4d/zoo/bevformer/bevformer_tiny.py
new file mode 100644
index 0000000000000000000000000000000000000000..e850ef4c442e77401c7c45ea265e3333828bffbf
--- /dev/null
+++ b/vis4d/zoo/bevformer/bevformer_tiny.py
@@ -0,0 +1,195 @@
+# pylint: disable=duplicate-code
+"""BEVFormer tiny with ResNet-50 backbone."""
+from __future__ import annotations
+
+from torch.optim.adamw import AdamW
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback
+from vis4d.engine.connectors import CallbackConnector, MultiSensorDataConnector
+from vis4d.eval.nuscenes import NuScenesDet3DEvaluator
+from vis4d.model.detect3d.bevformer import BEVFormer
+from vis4d.op.base import ResNet
+from vis4d.op.detect3d.bevformer import BEVFormerHead
+from vis4d.op.detect3d.bevformer.encoder import (
+ BEVFormerEncoder,
+ BEVFormerEncoderLayer,
+)
+from vis4d.op.detect3d.bevformer.spatial_cross_attention import (
+ MSDeformableAttention3D,
+ SpatialCrossAttention,
+)
+from vis4d.op.detect3d.bevformer.transformer import PerceptionTransformer
+from vis4d.op.fpp.fpn import FPN
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.bevformer.data import (
+ CONN_NUSC_BBOX_3D_TEST,
+ CONN_NUSC_DET3D_EVAL,
+ get_nusc_cfg,
+ nuscenes_class_map,
+)
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for BEVFormer on nuScenes.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="bevformer_tiny")
+
+ # Hyper Parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 1
+ params.workers_per_gpu = 4
+ params.lr = 2e-4
+ params.num_epochs = 24
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/nuscenes"
+ version = "v1.0-trainval"
+ train_split = "train"
+ test_split = "val"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_nusc_cfg(
+ data_root=data_root,
+ version=version,
+ train_split=train_split,
+ test_split=test_split,
+ data_backend=data_backend,
+ scale_factor=0.5,
+ style="pytorch",
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet50", trainable_layers=3, pretrained=True
+ )
+
+ config.model = class_config(
+ BEVFormer,
+ basemodel=basemodel,
+ fpn=class_config(
+ FPN,
+ in_channels_list=[2048],
+ out_channels=256,
+ extra_blocks=None,
+ start_index=5,
+ ),
+ pts_bbox_head=class_config(
+ BEVFormerHead,
+ transformer=class_config(
+ PerceptionTransformer,
+ encoder=class_config(
+ BEVFormerEncoder,
+ layer=class_config(
+ BEVFormerEncoderLayer,
+ cross_attn=class_config(
+ SpatialCrossAttention,
+ deformable_attention=class_config(
+ MSDeformableAttention3D,
+ num_levels=1,
+ ),
+ ),
+ ),
+ num_layers=3,
+ ),
+ ),
+ bev_h=50,
+ bev_w=50,
+ ),
+ weights="https://github.com/zhiqi-li/storage/releases/download/v1.0/bevformer_tiny_epoch_24.pth", # pylint: disable=line-too-long
+ )
+
+ config.loss = None
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(AdamW, lr=params.lr, weight_decay=0.01),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=1.0 / 3, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(CosineAnnealingLR, T_max=params.num_epochs),
+ ),
+ ],
+ param_groups=[{"custom_keys": ["basemodel"], "lr_mult": 0.1}],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = None
+
+ config.test_data_connector = class_config(
+ MultiSensorDataConnector, key_mapping=CONN_NUSC_BBOX_3D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ NuScenesDet3DEvaluator,
+ data_root=data_root,
+ version=version,
+ split=test_split,
+ class_map=nuscenes_class_map,
+ velocity_thres=0.2,
+ ),
+ save_predictions=True,
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_NUSC_DET3D_EVAL
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ pl_trainer.gradient_clip_val = 35
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/bevformer/bevformer_vis.py b/vis4d/zoo/bevformer/bevformer_vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..75a98eeac13b4015821e3d0ef2ea56485d34170b
--- /dev/null
+++ b/vis4d/zoo/bevformer/bevformer_vis.py
@@ -0,0 +1,63 @@
+"""BEVFormer Visualizaion for NuScenes Example."""
+
+from __future__ import annotations
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig
+from vis4d.engine.callbacks import VisualizerCallback
+from vis4d.engine.connectors import MultiSensorCallbackConnector
+from vis4d.vis.image.bbox3d_visualizer import MultiCameraBBox3DVisualizer
+from vis4d.zoo.base import get_default_callbacks_cfg
+from vis4d.zoo.bevformer.bevformer_base import (
+ get_config as get_bevformer_config,
+)
+from vis4d.zoo.bevformer.data import (
+ CONN_NUSC_BBOX_3D_VIS,
+ NUSC_CAMERAS,
+ nuscenes_class_map,
+)
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for BEVFormer on nuScenes.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_bevformer_config().ref_mode()
+
+ config.experiment_name = "bevformer_vis"
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(
+ MultiCameraBBox3DVisualizer,
+ cat_mapping=nuscenes_class_map,
+ width=2,
+ camera_near_clip=0.15,
+ cameras=NUSC_CAMERAS,
+ vis_freq=1,
+ plot_trajectory=False,
+ ),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ MultiSensorCallbackConnector,
+ key_mapping=CONN_NUSC_BBOX_3D_VIS,
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ return config.value_mode()
diff --git a/vis4d/zoo/bevformer/data.py b/vis4d/zoo/bevformer/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bf1642720f4c19e2be2353e27c79c05e3614d66
--- /dev/null
+++ b/vis4d/zoo/bevformer/data.py
@@ -0,0 +1,199 @@
+"""BEVFormer NuScenes data config."""
+
+from __future__ import annotations
+
+from ml_collections import ConfigDict
+
+from vis4d.config import class_config
+from vis4d.config.typing import DataConfig
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.data_pipe import DataPipe
+from vis4d.data.loader import multi_sensor_collate
+from vis4d.data.transforms import compose
+from vis4d.data.transforms.normalize import NormalizeImages
+from vis4d.data.transforms.pad import PadImages
+from vis4d.data.transforms.resize import (
+ GenResizeParameters,
+ ResizeImages,
+ ResizeIntrinsics,
+)
+from vis4d.data.transforms.to_tensor import ToTensor
+from vis4d.engine.connectors import data_key, pred_key
+from vis4d.zoo.base import get_inference_dataloaders_cfg
+from vis4d.zoo.base.datasets.nuscenes import (
+ get_nusc_mini_val_cfg,
+ get_nusc_val_cfg,
+)
+
+nuscenes_class_map = {
+ "car": 0,
+ "truck": 1,
+ "construction_vehicle": 2,
+ "bus": 3,
+ "trailer": 4,
+ "barrier": 5,
+ "motorcycle": 6,
+ "bicycle": 7,
+ "pedestrian": 8,
+ "traffic_cone": 9,
+}
+
+NUSC_SENSORS = [
+ "LIDAR_TOP",
+ "CAM_FRONT",
+ "CAM_FRONT_RIGHT",
+ "CAM_FRONT_LEFT",
+ "CAM_BACK",
+ "CAM_BACK_LEFT",
+ "CAM_BACK_RIGHT",
+]
+
+NUSC_CAMERAS = [
+ "CAM_FRONT",
+ "CAM_FRONT_RIGHT",
+ "CAM_FRONT_LEFT",
+ "CAM_BACK",
+ "CAM_BACK_LEFT",
+ "CAM_BACK_RIGHT",
+]
+
+CONN_NUSC_BBOX_3D_TEST = {
+ "images": data_key(K.images, sensors=NUSC_CAMERAS),
+ "can_bus": "can_bus",
+ "scene_names": K.sequence_names,
+ "cam_intrinsics": data_key(K.intrinsics, sensors=NUSC_CAMERAS),
+ "cam_extrinsics": data_key(K.extrinsics, sensors=NUSC_CAMERAS),
+ "lidar_extrinsics": data_key(K.extrinsics, sensors=["LIDAR_TOP"]),
+}
+
+CONN_NUSC_BBOX_3D_VIS = {
+ "images": data_key(K.original_images, sensors=NUSC_CAMERAS),
+ "image_names": data_key(K.sample_names, sensors=NUSC_CAMERAS),
+ "boxes3d": pred_key("boxes_3d"),
+ "intrinsics": data_key(K.intrinsics, sensors=NUSC_CAMERAS),
+ "extrinsics": data_key(K.extrinsics, sensors=NUSC_CAMERAS),
+ "scores": pred_key("scores_3d"),
+ "class_ids": pred_key("class_ids"),
+ "sequence_names": data_key(K.sequence_names),
+}
+
+CONN_NUSC_DET3D_EVAL = {
+ "tokens": data_key("token"),
+ "boxes_3d": pred_key("boxes_3d"),
+ "velocities": pred_key("velocities"),
+ "class_ids": pred_key("class_ids"),
+ "scores_3d": pred_key("scores_3d"),
+}
+
+
+def get_test_dataloader(
+ test_dataset: ConfigDict,
+ shape: tuple[int, int],
+ mean: list[float],
+ std: list[float],
+ samples_per_gpu: int,
+ workers_per_gpu: int,
+) -> ConfigDict:
+ """Get the default test dataloader for nuScenes tracking."""
+ test_transforms = [
+ class_config(
+ GenResizeParameters,
+ shape=shape,
+ keep_ratio=True,
+ sensors=NUSC_CAMERAS,
+ ),
+ class_config(ResizeImages, sensors=NUSC_CAMERAS),
+ class_config(ResizeIntrinsics, sensors=NUSC_CAMERAS),
+ class_config(
+ NormalizeImages, mean=mean, std=std, sensors=NUSC_CAMERAS
+ ),
+ ]
+
+ test_preprocess_cfg = class_config(compose, transforms=test_transforms)
+
+ test_batch_transforms = [
+ class_config(PadImages, sensors=NUSC_CAMERAS),
+ class_config(ToTensor, sensors=NUSC_SENSORS),
+ ]
+
+ test_batchprocess_cfg = class_config(
+ compose, transforms=test_batch_transforms
+ )
+
+ test_dataset_cfg = class_config(
+ DataPipe, datasets=test_dataset, preprocess_fn=test_preprocess_cfg
+ )
+
+ return get_inference_dataloaders_cfg(
+ datasets_cfg=test_dataset_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ video_based_inference=True,
+ batchprocess_cfg=test_batchprocess_cfg,
+ collate_fn=multi_sensor_collate,
+ sensors=NUSC_SENSORS,
+ )
+
+
+def get_nusc_cfg(
+ data_root: str = "data/nuscenes",
+ version: str = "v1.0-trainval",
+ train_split: str = "train",
+ test_split: str = "val",
+ data_backend: None | ConfigDict = None,
+ scale_factor: float = 1.0,
+ style: str = "caffe",
+ samples_per_gpu: int = 1,
+ workers_per_gpu: int = 4,
+) -> DataConfig:
+ """Get the default config for nuScenes tracking."""
+ data = DataConfig()
+
+ shape = (int(900 * scale_factor), int(1600 * scale_factor))
+
+ if style == "pytorch":
+ mean = [123.675, 116.28, 103.53]
+ std = [58.395, 57.12, 57.375]
+ image_channel_mode = "RGB"
+ elif style == "caffe":
+ mean = [103.530, 116.280, 123.675]
+ std = [1.0, 1.0, 1.0]
+ image_channel_mode = "BGR"
+ else:
+ raise ValueError(f"Unknown style {style}")
+
+ if version == "v1.0-mini": # pragma: no cover
+ assert train_split == "mini_train"
+ assert test_split == "mini_val"
+ test_dataset = get_nusc_mini_val_cfg(
+ data_root=data_root,
+ image_channel_mode=image_channel_mode,
+ data_backend=data_backend,
+ cached_file_path=f"{data_root}/bevformer_mini_val.pkl",
+ )
+ elif version == "v1.0-trainval":
+ assert train_split == "train"
+ assert test_split == "val"
+ test_dataset = get_nusc_val_cfg(
+ data_root=data_root,
+ image_channel_mode=image_channel_mode,
+ data_backend=data_backend,
+ cached_file_path=f"{data_root}/bevformer_val.pkl",
+ )
+ else:
+ # TODO: Add support for v1.0-test
+ raise ValueError(f"Unknown version {version}")
+
+ # TODO: Add train dataloader
+ data.train_dataloader = None
+
+ data.test_dataloader = get_test_dataloader(
+ test_dataset,
+ shape,
+ mean,
+ std,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+ return data
diff --git a/vis4d/zoo/cc_3dt/__init__.py b/vis4d/zoo/cc_3dt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9ba88f85d296f5fd6972145e5e3107cfd942f47
--- /dev/null
+++ b/vis4d/zoo/cc_3dt/__init__.py
@@ -0,0 +1,15 @@
+"""CC-3DT Model Zoo."""
+
+from . import (
+ cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc,
+ cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc,
+ cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc,
+ cc_3dt_nusc_vis,
+)
+
+AVAILABLE_MODELS = {
+ "cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc": cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc,
+ "cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc": cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc,
+ "cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc": cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc, # pylint: disable=line-too-long
+ "cc_3dt_nusc_vis": cc_3dt_nusc_vis,
+}
diff --git a/vis4d/zoo/cc_3dt/cc_3dt_bevformer_base_velo_lstm_nusc.py b/vis4d/zoo/cc_3dt/cc_3dt_bevformer_base_velo_lstm_nusc.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddf9740fcd307a6e9eb6317153c40e7f4a815c37
--- /dev/null
+++ b/vis4d/zoo/cc_3dt/cc_3dt_bevformer_base_velo_lstm_nusc.py
@@ -0,0 +1,113 @@
+# pylint: disable=duplicate-code
+"""CC-3DT with BEV detector on nuScenes."""
+from __future__ import annotations
+
+from vis4d.config import class_config
+from vis4d.config.typing import DataConfig, ExperimentConfig
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.datasets.nuscenes import NuScenes
+from vis4d.data.datasets.nuscenes_detection import NuScenesDetection
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.connectors import MultiSensorDataConnector, data_key
+from vis4d.model.motion.velo_lstm import VeloLSTM
+from vis4d.model.track3d.cc_3dt import CC3DT
+from vis4d.op.base import ResNet
+from vis4d.op.track3d.cc_3dt import CC3DTrackAssociation
+from vis4d.state.track3d.cc_3dt import CC3DTrackGraph
+from vis4d.zoo.cc_3dt.cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc import (
+ get_config as get_velo_lstm_cfg,
+)
+from vis4d.zoo.cc_3dt.data import CONN_NUSC_BBOX_3D_TEST, get_test_dataloader
+
+CONN_NUSC_BBOX_3D_TEST = {
+ "images_list": data_key(K.images, sensors=NuScenes.CAMERAS),
+ "images_hw": data_key(K.original_hw, sensors=NuScenes.CAMERAS),
+ "intrinsics_list": data_key(K.intrinsics, sensors=NuScenes.CAMERAS),
+ "extrinsics_list": data_key(K.extrinsics, sensors=NuScenes.CAMERAS),
+ "frame_ids": K.frame_ids,
+ "pred_boxes3d": data_key("pred_boxes3d", sensors=["LIDAR_TOP"]),
+ "pred_boxes3d_classes": data_key(
+ "pred_boxes3d_classes", sensors=["LIDAR_TOP"]
+ ),
+ "pred_boxes3d_scores": data_key(
+ "pred_boxes3d_scores", sensors=["LIDAR_TOP"]
+ ),
+ "pred_boxes3d_velocities": data_key(
+ "pred_boxes3d_velocities", sensors=["LIDAR_TOP"]
+ ),
+}
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for CC-3DT on nuScenes.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_velo_lstm_cfg().ref_mode()
+
+ config.experiment_name = "cc_3dt_bevformer_base_velo_lstm_nusc"
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ config.pure_detection = ""
+
+ data = DataConfig()
+
+ data.train_dataloader = None
+
+ test_dataset = class_config(
+ NuScenesDetection,
+ data_root="data/nuscenes",
+ version="v1.0-trainval",
+ split="val",
+ keys_to_load=[K.images, K.original_images, K.boxes3d],
+ data_backend=class_config(HDF5Backend),
+ pure_detection=config.pure_detection,
+ cache_as_binary=True,
+ cached_file_path="data/nuscenes/val.pkl",
+ )
+
+ data.test_dataloader = get_test_dataloader(
+ test_dataset=test_dataset, samples_per_gpu=1, workers_per_gpu=4
+ )
+
+ config.data = data
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet101", pretrained=True, trainable_layers=3
+ )
+
+ track_graph = class_config(
+ CC3DTrackGraph,
+ track=class_config(
+ CC3DTrackAssociation, init_score_thr=0.2, obj_score_thr=0.1
+ ),
+ motion_model="VeloLSTM",
+ lstm_model=class_config(VeloLSTM, weights=config.velo_lstm_ckpt),
+ update_3d_score=False,
+ add_backdrops=False,
+ )
+
+ config.model = class_config(
+ CC3DT,
+ basemodel=basemodel,
+ track_graph=track_graph,
+ detection_range=[40, 40, 40, 50, 50, 50, 50, 50, 30, 30],
+ )
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.test_data_connector = class_config(
+ MultiSensorDataConnector, key_mapping=CONN_NUSC_BBOX_3D_TEST
+ )
+
+ return config.value_mode()
diff --git a/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc.py b/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc.py
new file mode 100644
index 0000000000000000000000000000000000000000..53aa553179009485cef5eab80576cf11c7bd4421
--- /dev/null
+++ b/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc.py
@@ -0,0 +1,200 @@
+# pylint: disable=duplicate-code
+"""CC-3DT with Faster-RCNN ResNet-101 detector using KF3D motion model."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.datasets.nuscenes import nuscenes_class_map
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback
+from vis4d.engine.connectors import (
+ CallbackConnector,
+ DataConnector,
+ MultiSensorDataConnector,
+)
+from vis4d.eval.nuscenes import (
+ NuScenesDet3DEvaluator,
+ NuScenesTrack3DEvaluator,
+)
+from vis4d.op.base import ResNet
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.cc_3dt.data import (
+ CONN_NUSC_BBOX_3D_TEST,
+ CONN_NUSC_DET3D_EVAL,
+ CONN_NUSC_TRACK3D_EVAL,
+ get_nusc_cfg,
+)
+from vis4d.zoo.cc_3dt.model import CONN_BBOX_3D_TRAIN, get_cc_3dt_cfg
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for cc-3dt on nuScenes.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc")
+
+ # Hyper Parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 4
+ params.workers_per_gpu = 4
+ params.lr = 0.01
+ params.num_epochs = 24
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/nuscenes"
+ version = "v1.0-trainval"
+ train_split = "train"
+ test_split = "val"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_nusc_cfg(
+ data_root=data_root,
+ version=version,
+ train_split=train_split,
+ test_split=test_split,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet101", pretrained=True, trainable_layers=3
+ )
+
+ config.model, config.loss = get_cc_3dt_cfg(
+ num_classes=len(nuscenes_class_map), basemodel=basemodel, fps=2
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(LinearLR, start_factor=0.1, total_iters=1000),
+ end=1000,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[16, 22], gamma=0.1),
+ ),
+ ],
+ param_groups=[
+ {
+ "custom_keys": [
+ "faster_rcnn_head.rpn_head.rpn_cls.weight",
+ "faster_rcnn_head.rpn_head.rpn_box.weight",
+ "faster_rcnn_head.roi_head.fc_cls.weight",
+ "faster_rcnn_head.roi_head.fc_reg.weight",
+ "bbox_3d_head.dep_convs.0.weight",
+ "bbox_3d_head.dep_convs.1.weight",
+ "bbox_3d_head.dep_convs.2.weight",
+ "bbox_3d_head.dep_convs.3.weight",
+ "bbox_3d_head.dim_convs.0.weight",
+ "bbox_3d_head.dim_convs.1.weight",
+ "bbox_3d_head.dim_convs.2.weight",
+ "bbox_3d_head.dim_convs.3.weight",
+ "bbox_3d_head.rot_convs.0.weight"
+ "bbox_3d_head.rot_convs.1.weight",
+ "bbox_3d_head.rot_convs.2.weight",
+ "bbox_3d_head.rot_convs.3.weight",
+ "bbox_3d_head.cen_2d_convs.0.weight",
+ "bbox_3d_head.cen_2d_convs.1.weight",
+ "bbox_3d_head.cen_2d_convs.2.weight",
+ "bbox_3d_head.cen_2d_convs.3.weight",
+ "bbox_3d_head.fc_dep.weight",
+ "bbox_3d_head.fc_dep_uncer.weight",
+ "bbox_3d_head.fc_dim.weight",
+ "bbox_3d_head.fc_rot.weight",
+ "bbox_3d_head.fc_cen_2d.weight",
+ ],
+ "lr_mult": 10.0,
+ }
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_3D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ MultiSensorDataConnector, key_mapping=CONN_NUSC_BBOX_3D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ NuScenesDet3DEvaluator,
+ data_root=data_root,
+ version=version,
+ split=test_split,
+ ),
+ save_predictions=True,
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_NUSC_DET3D_EVAL
+ ),
+ )
+ )
+
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(NuScenesTrack3DEvaluator),
+ save_predictions=True,
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_NUSC_TRACK3D_EVAL
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ pl_trainer.gradient_clip_val = 10
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_pure_det_nusc.py b/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_pure_det_nusc.py
new file mode 100644
index 0000000000000000000000000000000000000000..6653033cdde78c86b51a18b029801b5b1bb30445
--- /dev/null
+++ b/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_pure_det_nusc.py
@@ -0,0 +1,88 @@
+"""CC-3DT with Faster-RCNN ResNet-101 detector generating pure detection."""
+
+from __future__ import annotations
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig
+from vis4d.data.datasets.nuscenes import NuScenes, nuscenes_class_map
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback
+from vis4d.engine.connectors import MultiSensorCallbackConnector
+from vis4d.eval.nuscenes import NuScenesDet3DEvaluator
+from vis4d.op.base import ResNet
+from vis4d.zoo.base import get_default_callbacks_cfg
+from vis4d.zoo.cc_3dt.cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc import (
+ get_config as get_kf3d_config,
+)
+from vis4d.zoo.cc_3dt.data import CONN_NUSC_DET3D_EVAL, get_nusc_cfg
+from vis4d.zoo.cc_3dt.model import get_cc_3dt_cfg
+
+
+def get_config() -> ExperimentConfig:
+ """Get config."""
+ config = get_kf3d_config().ref_mode()
+
+ config.experiment_name = "cc_3dt_frcnn_r101_fpn_pure_det_nusc"
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/nuscenes"
+ version = "v1.0-trainval"
+ train_split = "train"
+ test_split = "train"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_nusc_cfg(
+ data_root=data_root,
+ version=version,
+ train_split=train_split,
+ test_split=test_split,
+ data_backend=data_backend,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet101", pretrained=True, trainable_layers=3
+ )
+
+ config.model, _ = get_cc_3dt_cfg(
+ num_classes=len(nuscenes_class_map),
+ basemodel=basemodel,
+ fps=2,
+ pure_det=True,
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ NuScenesDet3DEvaluator,
+ data_root=data_root,
+ version=version,
+ split=test_split,
+ save_only=True,
+ ),
+ save_predictions=True,
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ MultiSensorCallbackConnector,
+ key_mapping=CONN_NUSC_DET3D_EVAL,
+ sensors=NuScenes.CAMERAS,
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ return config.value_mode()
diff --git a/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc.py b/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc.py
new file mode 100644
index 0000000000000000000000000000000000000000..343ede71b15887f9c5c4cd9a148c3a8e1f2bce89
--- /dev/null
+++ b/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc.py
@@ -0,0 +1,46 @@
+"""CC-3DT inference with Faster-RCNN ResNet-101 detector using VeloLSTM."""
+
+from __future__ import annotations
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig
+from vis4d.data.datasets.nuscenes import nuscenes_class_map
+from vis4d.model.motion.velo_lstm import VeloLSTM
+from vis4d.op.base import ResNet
+from vis4d.zoo.cc_3dt.cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc import (
+ get_config as get_kf3d_cfg,
+)
+from vis4d.zoo.cc_3dt.model import get_cc_3dt_cfg
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for cc-3dt on nuScenes.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_kf3d_cfg().ref_mode()
+
+ config.experiment_name = "cc_3dt_frcnn_r101_fpn_velo_lstm_24e_nusc"
+
+ config.velo_lstm_ckpt = ""
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet101", pretrained=True, trainable_layers=3
+ )
+
+ config.model, _ = get_cc_3dt_cfg(
+ num_classes=len(nuscenes_class_map),
+ basemodel=basemodel,
+ motion_model="VeloLSTM",
+ lstm_model=class_config(VeloLSTM, weights=config.velo_lstm_ckpt),
+ fps=2,
+ )
+
+ return config.value_mode()
diff --git a/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc.py b/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc.py
new file mode 100644
index 0000000000000000000000000000000000000000..169852be7e143b4cea7a7cef9c23c11a0d4f6c0e
--- /dev/null
+++ b/vis4d/zoo/cc_3dt/cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc.py
@@ -0,0 +1,200 @@
+# pylint: disable=duplicate-code
+"""CC-3DT with Faster-RCNN ResNet-50 detector using KF3D motion model."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.datasets.nuscenes import nuscenes_class_map
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback
+from vis4d.engine.connectors import (
+ CallbackConnector,
+ DataConnector,
+ MultiSensorDataConnector,
+)
+from vis4d.eval.nuscenes import (
+ NuScenesDet3DEvaluator,
+ NuScenesTrack3DEvaluator,
+)
+from vis4d.op.base import ResNet
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.cc_3dt.data import (
+ CONN_NUSC_BBOX_3D_TEST,
+ CONN_NUSC_DET3D_EVAL,
+ CONN_NUSC_TRACK3D_EVAL,
+ get_nusc_cfg,
+)
+from vis4d.zoo.cc_3dt.model import CONN_BBOX_3D_TRAIN, get_cc_3dt_cfg
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for cc-3dt on nuScenes.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc")
+
+ # Hyper Parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 4
+ params.workers_per_gpu = 4
+ params.lr = 0.01
+ params.num_epochs = 12
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/nuscenes"
+ version = "v1.0-trainval"
+ train_split = "train"
+ test_split = "val"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_nusc_cfg(
+ data_root=data_root,
+ version=version,
+ train_split=train_split,
+ test_split=test_split,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3
+ )
+
+ config.model, config.loss = get_cc_3dt_cfg(
+ num_classes=len(nuscenes_class_map), basemodel=basemodel, fps=2
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(LinearLR, start_factor=0.1, total_iters=1000),
+ end=1000,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[8, 11], gamma=0.1),
+ ),
+ ],
+ param_groups=[
+ {
+ "custom_keys": [
+ "faster_rcnn_head.rpn_head.rpn_cls.weight",
+ "faster_rcnn_head.rpn_head.rpn_box.weight",
+ "faster_rcnn_head.roi_head.fc_cls.weight",
+ "faster_rcnn_head.roi_head.fc_reg.weight",
+ "bbox_3d_head.dep_convs.0.weight",
+ "bbox_3d_head.dep_convs.1.weight",
+ "bbox_3d_head.dep_convs.2.weight",
+ "bbox_3d_head.dep_convs.3.weight",
+ "bbox_3d_head.dim_convs.0.weight",
+ "bbox_3d_head.dim_convs.1.weight",
+ "bbox_3d_head.dim_convs.2.weight",
+ "bbox_3d_head.dim_convs.3.weight",
+ "bbox_3d_head.rot_convs.0.weight"
+ "bbox_3d_head.rot_convs.1.weight",
+ "bbox_3d_head.rot_convs.2.weight",
+ "bbox_3d_head.rot_convs.3.weight",
+ "bbox_3d_head.cen_2d_convs.0.weight",
+ "bbox_3d_head.cen_2d_convs.1.weight",
+ "bbox_3d_head.cen_2d_convs.2.weight",
+ "bbox_3d_head.cen_2d_convs.3.weight",
+ "bbox_3d_head.fc_dep.weight",
+ "bbox_3d_head.fc_dep_uncer.weight",
+ "bbox_3d_head.fc_dim.weight",
+ "bbox_3d_head.fc_rot.weight",
+ "bbox_3d_head.fc_cen_2d.weight",
+ ],
+ "lr_mult": 10.0,
+ }
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_3D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ MultiSensorDataConnector, key_mapping=CONN_NUSC_BBOX_3D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ NuScenesDet3DEvaluator,
+ data_root=data_root,
+ version=version,
+ split=test_split,
+ ),
+ save_predictions=True,
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_NUSC_DET3D_EVAL
+ ),
+ )
+ )
+
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(NuScenesTrack3DEvaluator),
+ save_predictions=True,
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_NUSC_TRACK3D_EVAL
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ pl_trainer.gradient_clip_val = 10
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/cc_3dt/cc_3dt_nusc_test.py b/vis4d/zoo/cc_3dt/cc_3dt_nusc_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ed480340feab7908609a4b2304466776e49b768
--- /dev/null
+++ b/vis4d/zoo/cc_3dt/cc_3dt_nusc_test.py
@@ -0,0 +1,106 @@
+# pylint: disable=duplicate-code
+"""CC-3DT with BEV detector on nuScenes."""
+from __future__ import annotations
+
+from vis4d.config import class_config
+from vis4d.config.typing import DataConfig, ExperimentConfig
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.datasets.nuscenes_detection import NuScenesDetection
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback
+from vis4d.engine.connectors import CallbackConnector
+from vis4d.eval.nuscenes import (
+ NuScenesDet3DEvaluator,
+ NuScenesTrack3DEvaluator,
+)
+from vis4d.zoo.base import get_default_callbacks_cfg
+from vis4d.zoo.cc_3dt.cc_3dt_bevformer_base_velo_lstm_nusc import (
+ get_config as get_cc_3dt_config,
+)
+from vis4d.zoo.cc_3dt.data import (
+ CONN_NUSC_DET3D_EVAL,
+ CONN_NUSC_TRACK3D_EVAL,
+ get_test_dataloader,
+)
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for CC-3DT on nuScenes.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_cc_3dt_config().ref_mode()
+
+ config.experiment_name = "cc_3dt_nusc_test"
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ config.pure_detection = ""
+
+ data = DataConfig()
+
+ data.train_dataloader = None
+
+ test_dataset = class_config(
+ NuScenesDetection,
+ data_root="data/nuscenes",
+ version="v1.0-test",
+ split="test",
+ keys_to_load=[K.images, K.original_images],
+ data_backend=class_config(HDF5Backend),
+ pure_detection=config.pure_detection,
+ cache_as_binary=True,
+ cached_file_path="data/nuscenes/test.pkl",
+ )
+
+ data.test_dataloader = get_test_dataloader(
+ test_dataset=test_dataset, samples_per_gpu=1, workers_per_gpu=4
+ )
+
+ config.data = data
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ NuScenesDet3DEvaluator,
+ data_root="data/nuscenes",
+ version="v1.0-test",
+ split="test",
+ save_only=True,
+ ),
+ save_predictions=True,
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_NUSC_DET3D_EVAL
+ ),
+ )
+ )
+
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(NuScenesTrack3DEvaluator),
+ save_predictions=True,
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_NUSC_TRACK3D_EVAL
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ return config.value_mode()
diff --git a/vis4d/zoo/cc_3dt/cc_3dt_nusc_vis.py b/vis4d/zoo/cc_3dt/cc_3dt_nusc_vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..f629cd412754ca35cbfe2e1a4578a209e88c27ca
--- /dev/null
+++ b/vis4d/zoo/cc_3dt/cc_3dt_nusc_vis.py
@@ -0,0 +1,77 @@
+"""CC-3DT Visualizaion for NuScenes Example."""
+
+from __future__ import annotations
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig
+from vis4d.data.datasets.nuscenes import NuScenes, nuscenes_class_map
+from vis4d.engine.callbacks import VisualizerCallback
+from vis4d.engine.connectors import MultiSensorCallbackConnector
+from vis4d.vis.image.bbox3d_visualizer import MultiCameraBBox3DVisualizer
+from vis4d.vis.image.bev_visualizer import BEVBBox3DVisualizer
+from vis4d.zoo.base import get_default_callbacks_cfg
+from vis4d.zoo.cc_3dt.cc_3dt_frcnn_r50_fpn_kf3d_12e_nusc import (
+ get_config as get_cc_3dt_config,
+)
+from vis4d.zoo.cc_3dt.data import (
+ CONN_NUSC_BBOX_3D_VIS,
+ CONN_NUSC_BEV_BBOX_3D_VIS,
+)
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for cc-3dt on nuScenes.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_cc_3dt_config().ref_mode()
+
+ config.experiment_name = "cc_3dt_nusc_vis"
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(
+ MultiCameraBBox3DVisualizer,
+ cat_mapping=nuscenes_class_map,
+ width=2,
+ camera_near_clip=0.15,
+ cameras=NuScenes.CAMERAS,
+ vis_freq=1,
+ ),
+ output_dir=config.output_dir,
+ save_prefix="boxes3d",
+ test_connector=class_config(
+ MultiSensorCallbackConnector,
+ key_mapping=CONN_NUSC_BBOX_3D_VIS,
+ ),
+ )
+ )
+
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(BEVBBox3DVisualizer, width=2, vis_freq=1),
+ output_dir=config.output_dir,
+ save_prefix="bev",
+ test_connector=class_config(
+ MultiSensorCallbackConnector,
+ key_mapping=CONN_NUSC_BEV_BBOX_3D_VIS,
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ return config.value_mode()
diff --git a/vis4d/zoo/cc_3dt/cc_3dt_pp_kf3d.py b/vis4d/zoo/cc_3dt/cc_3dt_pp_kf3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5eed3563ebba92b3258c4a632d58861b340e047
--- /dev/null
+++ b/vis4d/zoo/cc_3dt/cc_3dt_pp_kf3d.py
@@ -0,0 +1,175 @@
+# pylint: disable=duplicate-code
+"""CC-3DT++ on nuScenes."""
+from __future__ import annotations
+
+from vis4d.config import class_config
+from vis4d.config.typing import DataConfig, ExperimentConfig
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.datasets.nuscenes import NuScenes
+from vis4d.data.datasets.nuscenes_detection import NuScenesDetection
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback
+from vis4d.engine.connectors import (
+ CallbackConnector,
+ MultiSensorDataConnector,
+ data_key,
+)
+from vis4d.eval.nuscenes import (
+ NuScenesDet3DEvaluator,
+ NuScenesTrack3DEvaluator,
+)
+from vis4d.model.track3d.cc_3dt import CC3DT
+from vis4d.op.base import ResNet
+from vis4d.op.track3d.cc_3dt import CC3DTrackAssociation
+from vis4d.state.track3d.cc_3dt import CC3DTrackGraph
+from vis4d.zoo.base import get_default_callbacks_cfg
+from vis4d.zoo.cc_3dt.cc_3dt_frcnn_r101_fpn_kf3d_24e_nusc import (
+ get_config as get_kf3d_cfg,
+)
+from vis4d.zoo.cc_3dt.data import (
+ CONN_NUSC_DET3D_EVAL,
+ CONN_NUSC_TRACK3D_EVAL,
+ get_test_dataloader,
+)
+
+CONN_NUSC_BBOX_3D_TEST = {
+ "images_list": data_key(K.images, sensors=NuScenes.CAMERAS),
+ "images_hw": data_key(K.original_hw, sensors=NuScenes.CAMERAS),
+ "intrinsics_list": data_key(K.intrinsics, sensors=NuScenes.CAMERAS),
+ "extrinsics_list": data_key(K.extrinsics, sensors=NuScenes.CAMERAS),
+ "frame_ids": K.frame_ids,
+ "pred_boxes3d": data_key("pred_boxes3d", sensors=["LIDAR_TOP"]),
+ "pred_boxes3d_classes": data_key(
+ "pred_boxes3d_classes", sensors=["LIDAR_TOP"]
+ ),
+ "pred_boxes3d_scores": data_key(
+ "pred_boxes3d_scores", sensors=["LIDAR_TOP"]
+ ),
+ "pred_boxes3d_velocities": data_key(
+ "pred_boxes3d_velocities", sensors=["LIDAR_TOP"]
+ ),
+}
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for CC-3DT on nuScenes.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_kf3d_cfg().ref_mode()
+
+ config.experiment_name = "cc_3dt_pp_kf3d_nusc"
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ config.pure_detection = ""
+
+ data_root = "data/nuscenes"
+ version = "v1.0-trainval"
+ test_split = "val"
+
+ data = DataConfig()
+
+ data.train_dataloader = None
+
+ test_dataset = class_config(
+ NuScenesDetection,
+ data_root=data_root,
+ version=version,
+ split=test_split,
+ keys_to_load=[K.images, K.original_images],
+ data_backend=class_config(HDF5Backend),
+ pure_detection=config.pure_detection,
+ cache_as_binary=True,
+ cached_file_path=f"{data_root}/val.pkl",
+ )
+
+ data.test_dataloader = get_test_dataloader(
+ test_dataset=test_dataset, samples_per_gpu=1, workers_per_gpu=1
+ )
+
+ config.data = data
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet101", pretrained=True, trainable_layers=3
+ )
+
+ track_graph = class_config(
+ CC3DTrackGraph,
+ track=class_config(
+ CC3DTrackAssociation,
+ init_score_thr=0.2,
+ obj_score_thr=0.1,
+ match_score_thr=0.3,
+ nms_class_iou_thr=0.3,
+ bbox_affinity_weight=0.75,
+ with_velocities=True,
+ ),
+ update_3d_score=False,
+ use_velocities=True,
+ add_backdrops=False,
+ )
+
+ config.model = class_config(
+ CC3DT,
+ basemodel=basemodel,
+ track_graph=track_graph,
+ detection_range=[40, 40, 40, 50, 50, 50, 50, 50, 30, 30],
+ )
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.test_data_connector = class_config(
+ MultiSensorDataConnector, key_mapping=CONN_NUSC_BBOX_3D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg(config.output_dir)
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ NuScenesDet3DEvaluator,
+ data_root=data_root,
+ version=version,
+ split=test_split,
+ ),
+ save_predictions=True,
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_NUSC_DET3D_EVAL
+ ),
+ )
+ )
+
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ NuScenesTrack3DEvaluator, metadata=("use_camera", "use_radar")
+ ),
+ save_predictions=True,
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_NUSC_TRACK3D_EVAL
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ return config.value_mode()
diff --git a/vis4d/zoo/cc_3dt/data.py b/vis4d/zoo/cc_3dt/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ed579cd778160145f7248c9829ca5aa468d0604
--- /dev/null
+++ b/vis4d/zoo/cc_3dt/data.py
@@ -0,0 +1,240 @@
+"""CC-3DT NuScenes data config."""
+
+from __future__ import annotations
+
+from ml_collections import ConfigDict
+
+from vis4d.config import class_config
+from vis4d.config.typing import DataConfig
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.data_pipe import DataPipe
+from vis4d.data.datasets.nuscenes import NuScenes
+from vis4d.data.loader import multi_sensor_collate
+from vis4d.data.reference import MultiViewDataset, UniformViewSampler
+from vis4d.data.transforms import RandomApply, compose
+from vis4d.data.transforms.flip import (
+ FlipBoxes2D,
+ FlipBoxes3D,
+ FlipImages,
+ FlipIntrinsics,
+)
+from vis4d.data.transforms.normalize import NormalizeImages
+from vis4d.data.transforms.pad import PadImages
+from vis4d.data.transforms.post_process import PostProcessBoxes2D
+from vis4d.data.transforms.resize import (
+ GenResizeParameters,
+ ResizeBoxes2D,
+ ResizeImages,
+ ResizeIntrinsics,
+)
+from vis4d.data.transforms.to_tensor import ToTensor
+from vis4d.engine.connectors import data_key, pred_key
+from vis4d.zoo.base import (
+ get_inference_dataloaders_cfg,
+ get_train_dataloader_cfg,
+)
+from vis4d.zoo.base.datasets.nuscenes import (
+ get_nusc_mini_val_cfg,
+ get_nusc_mono_mini_train_cfg,
+ get_nusc_mono_train_cfg,
+ get_nusc_train_cfg,
+ get_nusc_val_cfg,
+)
+
+CONN_NUSC_DET3D_EVAL = {
+ "tokens": data_key("token"),
+ "boxes_3d": pred_key("boxes_3d"),
+ "velocities": pred_key("velocities"),
+ "class_ids": pred_key("class_ids"),
+ "scores_3d": pred_key("scores_3d"),
+}
+
+CONN_NUSC_TRACK3D_EVAL = {
+ "tokens": data_key("token"),
+ "boxes_3d": pred_key("boxes_3d"),
+ "velocities": pred_key("velocities"),
+ "class_ids": pred_key("class_ids"),
+ "scores_3d": pred_key("scores_3d"),
+ "track_ids": pred_key("track_ids"),
+}
+
+CONN_NUSC_BBOX_3D_TEST = {
+ "images": data_key(K.images, sensors=NuScenes.CAMERAS),
+ "images_hw": data_key(K.original_hw, sensors=NuScenes.CAMERAS),
+ "intrinsics": data_key(K.intrinsics, sensors=NuScenes.CAMERAS),
+ "extrinsics": data_key(K.extrinsics, sensors=NuScenes.CAMERAS),
+ "frame_ids": K.frame_ids,
+}
+
+CONN_NUSC_BBOX_3D_VIS = {
+ "images": data_key(K.original_images, sensors=NuScenes.CAMERAS),
+ "image_names": data_key(K.sample_names, sensors=NuScenes.CAMERAS),
+ "boxes3d": pred_key("boxes_3d"),
+ "intrinsics": data_key(K.intrinsics, sensors=NuScenes.CAMERAS),
+ "extrinsics": data_key(K.extrinsics, sensors=NuScenes.CAMERAS),
+ "scores": pred_key("scores_3d"),
+ "class_ids": pred_key("class_ids"),
+ "track_ids": pred_key("track_ids"),
+ "sequence_names": data_key(K.sequence_names),
+}
+
+CONN_NUSC_BEV_BBOX_3D_VIS = {
+ "sample_names": data_key(K.sample_names, sensors=["LIDAR_TOP"]),
+ "boxes3d": pred_key("boxes_3d"),
+ "extrinsics": data_key(K.extrinsics, sensors=["LIDAR_TOP"]),
+ "track_ids": pred_key("track_ids"),
+ "sequence_names": data_key(K.sequence_names),
+}
+
+
+def get_train_dataloader(
+ train_dataset: ConfigDict, samples_per_gpu: int, workers_per_gpu: int
+) -> ConfigDict:
+ """Get the default train dataloader for nuScenes tracking."""
+ train_dataset_cfg = class_config(
+ MultiViewDataset,
+ dataset=train_dataset,
+ sampler=class_config(UniformViewSampler, scope=2, num_ref_samples=1),
+ )
+
+ preprocess_transforms = [
+ class_config(GenResizeParameters, shape=(900, 1600), keep_ratio=True),
+ class_config(ResizeImages),
+ class_config(ResizeBoxes2D),
+ ]
+
+ preprocess_transforms.append(
+ class_config(
+ RandomApply,
+ transforms=[
+ class_config(FlipImages),
+ class_config(FlipIntrinsics),
+ class_config(FlipBoxes2D),
+ class_config(FlipBoxes3D),
+ ],
+ probability=0.5,
+ )
+ )
+
+ preprocess_transforms.append(class_config(PostProcessBoxes2D))
+
+ train_preprocess_cfg = class_config(
+ compose, transforms=preprocess_transforms
+ )
+
+ train_batchprocess_cfg = class_config(
+ compose,
+ transforms=[
+ class_config(PadImages),
+ class_config(NormalizeImages),
+ class_config(ToTensor),
+ ],
+ )
+
+ return get_train_dataloader_cfg(
+ datasets_cfg=class_config(
+ DataPipe,
+ datasets=train_dataset_cfg,
+ preprocess_fn=train_preprocess_cfg,
+ ),
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ batchprocess_cfg=train_batchprocess_cfg,
+ )
+
+
+def get_test_dataloader(
+ test_dataset: ConfigDict, samples_per_gpu: int, workers_per_gpu: int
+) -> ConfigDict:
+ """Get the default test dataloader for nuScenes tracking."""
+ test_transforms = [
+ class_config(
+ GenResizeParameters,
+ shape=(900, 1600),
+ keep_ratio=True,
+ sensors=NuScenes.CAMERAS,
+ ),
+ class_config(ResizeImages, sensors=NuScenes.CAMERAS),
+ class_config(ResizeIntrinsics, sensors=NuScenes.CAMERAS),
+ ]
+
+ test_preprocess_cfg = class_config(compose, transforms=test_transforms)
+
+ test_batch_transforms = [
+ class_config(PadImages, sensors=NuScenes.CAMERAS),
+ class_config(NormalizeImages, sensors=NuScenes.CAMERAS),
+ class_config(ToTensor, sensors=NuScenes.SENSORS),
+ ]
+
+ test_batchprocess_cfg = class_config(
+ compose, transforms=test_batch_transforms
+ )
+
+ test_dataset_cfg = class_config(
+ DataPipe, datasets=test_dataset, preprocess_fn=test_preprocess_cfg
+ )
+
+ return get_inference_dataloaders_cfg(
+ datasets_cfg=test_dataset_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ video_based_inference=True,
+ batchprocess_cfg=test_batchprocess_cfg,
+ collate_fn=multi_sensor_collate,
+ sensors=NuScenes.SENSORS,
+ )
+
+
+def get_nusc_cfg(
+ data_root: str = "data/nuscenes",
+ version: str = "v1.0-trainval",
+ train_split: str = "train",
+ test_split: str = "val",
+ data_backend: None | ConfigDict = None,
+ samples_per_gpu: int = 2,
+ workers_per_gpu: int = 2,
+) -> DataConfig:
+ """Get the default config for nuScenes tracking."""
+ data = DataConfig()
+
+ if version == "v1.0-mini": # pragma: no cover
+ assert train_split == "mini_train"
+ assert test_split == "mini_val"
+ train_dataset = get_nusc_mono_mini_train_cfg(
+ data_root=data_root, data_backend=data_backend
+ )
+ test_dataset = get_nusc_mini_val_cfg(
+ data_root=data_root, data_backend=data_backend
+ )
+ elif version == "v1.0-trainval":
+ assert train_split == "train"
+ train_dataset = get_nusc_mono_train_cfg(
+ data_root=data_root, data_backend=data_backend
+ )
+
+ if test_split == "val":
+ test_dataset = get_nusc_val_cfg(
+ data_root=data_root, data_backend=data_backend
+ )
+ elif test_split == "train":
+ test_dataset = get_nusc_train_cfg(
+ data_root=data_root,
+ skip_empty_samples=False,
+ keys_to_load=[K.images, K.original_images, K.boxes3d],
+ data_backend=data_backend,
+ )
+ else:
+ # TODO: Add support for v1.0-test
+ raise ValueError(f"Unknown version {version}")
+
+ data.train_dataloader = get_train_dataloader(
+ train_dataset=train_dataset,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+ data.test_dataloader = get_test_dataloader(
+ test_dataset, samples_per_gpu=1, workers_per_gpu=1
+ )
+
+ return data
diff --git a/vis4d/zoo/cc_3dt/model.py b/vis4d/zoo/cc_3dt/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..39d6c9b1935a94b6607aa8cf359fc0174dc25748
--- /dev/null
+++ b/vis4d/zoo/cc_3dt/model.py
@@ -0,0 +1,171 @@
+"""CC-3DT model config."""
+
+from __future__ import annotations
+
+from ml_collections import ConfigDict, FieldReference
+
+from vis4d.config import class_config
+from vis4d.data.const import CommonKeys as K
+from vis4d.engine.connectors import LossConnector, pred_key, remap_pred_keys
+from vis4d.engine.loss_module import LossModule
+from vis4d.model.track3d.cc_3dt import FasterRCNNCC3DT
+from vis4d.op.box.anchor import AnchorGenerator
+from vis4d.op.detect3d.qd_3dt import Box3DUncertaintyLoss
+from vis4d.op.detect.faster_rcnn import FasterRCNNHead
+from vis4d.op.detect.rcnn import RCNNHead, RCNNLoss
+from vis4d.op.detect.rpn import RPNLoss
+from vis4d.op.loss.common import smooth_l1_loss
+from vis4d.op.track.qdtrack import QDTrackInstanceSimilarityLoss
+from vis4d.state.track3d.cc_3dt import CC3DTrackGraph
+from vis4d.zoo.base import get_callable_cfg
+from vis4d.zoo.base.models.faster_rcnn import (
+ get_default_rcnn_box_codec_cfg,
+ get_default_rpn_box_codec_cfg,
+)
+from vis4d.zoo.base.models.qdtrack import CONN_ROI_LOSS_2D as _CONN_ROI_LOSS_2D
+from vis4d.zoo.base.models.qdtrack import (
+ CONN_TRACK_LOSS_2D as _CONN_TRACK_LOSS_2D,
+)
+
+PRED_PREFIX = "qdtrack_out"
+
+CONN_RPN_LOSS_2D = {
+ "cls_outs": pred_key(f"{PRED_PREFIX}.detector_out.rpn.cls"),
+ "reg_outs": pred_key(f"{PRED_PREFIX}.detector_out.rpn.box"),
+ "target_boxes": pred_key(f"{PRED_PREFIX}.key_target_boxes"),
+ "images_hw": pred_key(f"{PRED_PREFIX}.key_images_hw"),
+}
+
+CONN_ROI_LOSS_2D = remap_pred_keys(_CONN_ROI_LOSS_2D, PRED_PREFIX)
+
+CONN_TRACK_LOSS_2D = remap_pred_keys(_CONN_TRACK_LOSS_2D, PRED_PREFIX)
+
+CONN_DET_3D_LOSS = {
+ "pred": pred_key("detector_3d_out"),
+ "target": pred_key("detector_3d_target"),
+ "labels": pred_key("detector_3d_labels"),
+}
+
+CONN_BBOX_3D_TRAIN = {
+ "images": K.images,
+ "images_hw": K.input_hw,
+ "intrinsics": K.intrinsics,
+ "boxes2d": K.boxes2d,
+ "boxes3d": K.boxes3d,
+ "boxes3d_classes": K.boxes3d_classes,
+ "boxes3d_track_ids": K.boxes3d_track_ids,
+ "keyframes": "keyframes",
+}
+
+
+def get_cc_3dt_cfg(
+ num_classes: int | FieldReference,
+ basemodel: ConfigDict,
+ pure_det: bool | FieldReference = False,
+ motion_model: str | FieldReference = "KF3D",
+ lstm_model: ConfigDict | None = None,
+ fps: int | FieldReference = 2,
+) -> tuple[ConfigDict, ConfigDict]:
+ """Get CC-3DT model config.
+
+ Args:
+ num_classes (int): Number of classes.
+ basemodel (ConfigDict): Base model config.
+ pure_det (bool, optional): Whether to use pure detection mode.
+ Defaults to False.
+ motion_model (str, optional): Motion model. Defaults to "KF3D".
+ lstm_model (ConfigDict, optional): LSTM model config. Defaults to None.
+ fps (int, optional): FPS. Defaults to 2.
+ """
+ ######################################################
+ ## MODEL ##
+ ######################################################
+ anchor_generator = class_config(
+ AnchorGenerator,
+ scales=[4, 8],
+ ratios=[0.25, 0.5, 1.0, 2.0, 4.0],
+ strides=[4, 8, 16, 32, 64],
+ )
+
+ roi_head = class_config(
+ RCNNHead,
+ num_shared_convs=4,
+ num_classes=num_classes,
+ )
+
+ faster_rcnn_head = class_config(
+ FasterRCNNHead,
+ num_classes=num_classes,
+ anchor_generator=anchor_generator,
+ roi_head=roi_head,
+ )
+
+ track_graph = class_config(
+ CC3DTrackGraph,
+ motion_model=motion_model,
+ lstm_model=lstm_model,
+ fps=fps,
+ )
+
+ model = class_config(
+ FasterRCNNCC3DT,
+ num_classes=num_classes,
+ basemodel=basemodel,
+ faster_rcnn_head=faster_rcnn_head,
+ track_graph=track_graph,
+ pure_det=pure_det,
+ )
+
+ ######################################################
+ ## LOSS ##
+ ######################################################
+ rpn_box_encoder, _ = get_default_rpn_box_codec_cfg()
+ rcnn_box_encoder, _ = get_default_rcnn_box_codec_cfg()
+
+ rpn_loss = class_config(
+ RPNLoss,
+ anchor_generator=anchor_generator,
+ box_encoder=rpn_box_encoder,
+ loss_bbox=get_callable_cfg(smooth_l1_loss, beta=1.0 / 9.0),
+ )
+ rcnn_loss = class_config(
+ RCNNLoss,
+ box_encoder=rcnn_box_encoder,
+ num_classes=num_classes,
+ loss_bbox=get_callable_cfg(smooth_l1_loss, beta=1.0 / 9.0),
+ )
+
+ track_loss = class_config(QDTrackInstanceSimilarityLoss)
+
+ loss = class_config(
+ LossModule,
+ losses=[
+ {
+ "loss": rpn_loss,
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_RPN_LOSS_2D
+ ),
+ },
+ {
+ "loss": rcnn_loss,
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_ROI_LOSS_2D
+ ),
+ "weight": 5.0,
+ },
+ {
+ "loss": track_loss,
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_TRACK_LOSS_2D
+ ),
+ },
+ {
+ "loss": class_config(Box3DUncertaintyLoss),
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_DET_3D_LOSS
+ ),
+ },
+ ],
+ )
+
+ return model, loss
diff --git a/vis4d/zoo/cc_3dt/velo_lstm_bevformer_base_100e_nusc.py b/vis4d/zoo/cc_3dt/velo_lstm_bevformer_base_100e_nusc.py
new file mode 100644
index 0000000000000000000000000000000000000000..51c566fd31ffc6557ae914a196ddda888f50a37b
--- /dev/null
+++ b/vis4d/zoo/cc_3dt/velo_lstm_bevformer_base_100e_nusc.py
@@ -0,0 +1,150 @@
+# pylint: disable=duplicate-code
+"""CC-3DT VeloLSTM for BEVFormer on nuScenes."""
+from __future__ import annotations
+
+from torch.optim.adam import Adam
+from torch.optim.lr_scheduler import MultiStepLR
+
+from vis4d.config import class_config
+from vis4d.config.typing import (
+ DataConfig,
+ ExperimentConfig,
+ ExperimentParameters,
+)
+from vis4d.data.datasets.nuscenes_trajectory import NuScenesTrajectory
+from vis4d.engine.connectors import (
+ DataConnector,
+ LossConnector,
+ data_key,
+ pred_key,
+)
+from vis4d.engine.loss_module import LossModule
+from vis4d.model.motion.velo_lstm import VeloLSTM
+from vis4d.op.motion.velo_lstm import VeloLSTMLoss
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+ get_train_dataloader_cfg,
+)
+
+TRAJ_TRAIN = {"pred_traj": "pred_traj"}
+TRAJ_LOSS = {
+ "loc_preds": pred_key("loc_preds"),
+ "loc_refines": pred_key("loc_refines"),
+ "gt_traj": data_key("gt_traj"),
+}
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for VeloLSTM on nuScenes.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="velo_lstm_bevformer_base_100e_nusc")
+
+ # Hyper Parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 32
+ params.workers_per_gpu = 4
+ params.lr = 0.005
+ params.num_epochs = 100
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data = DataConfig()
+
+ train_dataset_cfg = class_config(
+ NuScenesTrajectory,
+ detector="cc_3dt_frcnn_r101_fpn",
+ data_root="data/nuscenes",
+ version="v1.0-trainval",
+ split="train",
+ pure_detection="./vis4d-workspace/pure_det/bevformer_base.json",
+ cache_as_binary=True,
+ cached_file_path="data/nuscenes/cc_3dt_bevformer_base_traj_train.pkl",
+ )
+
+ data.train_dataloader = get_train_dataloader_cfg(
+ datasets_cfg=train_dataset_cfg,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ collate_keys=["pred_traj", "gt_traj"],
+ )
+
+ data.test_dataloader = None
+
+ config.data = data
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ config.model = class_config(VeloLSTM)
+
+ config.loss = class_config(
+ LossModule,
+ losses=[
+ {
+ "loss": class_config(VeloLSTMLoss),
+ "weight": 10.0,
+ "connector": class_config(
+ LossConnector, key_mapping=TRAJ_LOSS
+ ),
+ }
+ ],
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ Adam, lr=params.lr, amsgrad=True, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ MultiStepLR, milestones=[20, 40, 60, 80], gamma=0.5
+ ),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=TRAJ_TRAIN
+ )
+
+ config.test_data_connector = None
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ pl_trainer.gradient_clip_val = 3
+ pl_trainer.check_val_every_n_epoch = 101 # Disable validation
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/cc_3dt/velo_lstm_frcnn_r101_fpn_100e_nusc.py b/vis4d/zoo/cc_3dt/velo_lstm_frcnn_r101_fpn_100e_nusc.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3752ef781b7ca2716c2670fdcdd61e4ecddea5f
--- /dev/null
+++ b/vis4d/zoo/cc_3dt/velo_lstm_frcnn_r101_fpn_100e_nusc.py
@@ -0,0 +1,150 @@
+# pylint: disable=duplicate-code
+"""CC-3DT VeloLSTM on nuScenes."""
+from __future__ import annotations
+
+from torch.optim.adam import Adam
+from torch.optim.lr_scheduler import MultiStepLR
+
+from vis4d.config import class_config
+from vis4d.config.typing import (
+ DataConfig,
+ ExperimentConfig,
+ ExperimentParameters,
+)
+from vis4d.data.data_pipe import DataPipe
+from vis4d.data.datasets.nuscenes_trajectory import NuScenesTrajectory
+from vis4d.engine.connectors import (
+ DataConnector,
+ LossConnector,
+ data_key,
+ pred_key,
+)
+from vis4d.engine.loss_module import LossModule
+from vis4d.model.motion.velo_lstm import VeloLSTM
+from vis4d.op.motion.velo_lstm import VeloLSTMLoss
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+ get_train_dataloader_cfg,
+)
+
+TRAJ_TRAIN = {"pred_traj": "pred_traj"}
+TRAJ_LOSS = {
+ "loc_preds": pred_key("loc_preds"),
+ "loc_refines": pred_key("loc_refines"),
+ "gt_traj": data_key("gt_traj"),
+}
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for VeloLSTM on nuScenes.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="velo_lstm_frcnn_r101_fpn_100e_nusc")
+
+ # Hyper Parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 32
+ params.workers_per_gpu = 4
+ params.lr = 0.005
+ params.num_epochs = 100
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data = DataConfig()
+
+ train_dataset_cfg = class_config(
+ NuScenesTrajectory,
+ detector="cc_3dt_frcnn_r101_fpn",
+ data_root="data/nuscenes",
+ version="v1.0-trainval",
+ split="train",
+ pure_detection="./vis4d-workspace/pure_det/cc_3dt_frcnn_r101_fpn.json",
+ cache_as_binary=True,
+ cached_file_path="data/nuscenes/cc_3dt_frcnn_r101_fpn_traj_train.pkl",
+ )
+
+ data.train_dataloader = get_train_dataloader_cfg(
+ datasets_cfg=class_config(DataPipe, datasets=train_dataset_cfg),
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ collate_keys=["pred_traj", "gt_traj"],
+ )
+
+ data.test_dataloader = None
+
+ config.data = data
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ config.model = class_config(VeloLSTM)
+
+ config.loss = class_config(
+ LossModule,
+ losses=[
+ {
+ "loss": class_config(VeloLSTMLoss),
+ "connector": class_config(
+ LossConnector, key_mapping=TRAJ_LOSS
+ ),
+ }
+ ],
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ Adam, lr=params.lr, amsgrad=True, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ MultiStepLR, milestones=[20, 40, 60, 80], gamma=0.5
+ ),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=TRAJ_TRAIN
+ )
+
+ config.test_data_connector = None
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ pl_trainer.gradient_clip_val = 3
+ pl_trainer.check_val_every_n_epoch = 101 # Disable validation
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/faster_rcnn/__init__.py b/vis4d/zoo/faster_rcnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aaded2cdd7db40c570a95ea559526a3b37657a8
--- /dev/null
+++ b/vis4d/zoo/faster_rcnn/__init__.py
@@ -0,0 +1,7 @@
+"""Faster-RCNN Model Zoo."""
+
+from . import faster_rcnn_coco
+
+AVAILABLE_MODELS = {
+ "faster_rcnn_coco": faster_rcnn_coco,
+}
diff --git a/vis4d/zoo/faster_rcnn/faster_rcnn_coco.py b/vis4d/zoo/faster_rcnn/faster_rcnn_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7aa6dab689ed88a8403e3ffb08741a512ea7658
--- /dev/null
+++ b/vis4d/zoo/faster_rcnn/faster_rcnn_coco.py
@@ -0,0 +1,170 @@
+# pylint: disable=duplicate-code
+"""Faster RCNN COCO training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.eval.coco import COCODetectEvaluator
+from vis4d.op.base import ResNet
+from vis4d.vis.image import BoundingBoxVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors import (
+ CONN_BBOX_2D_TEST,
+ CONN_BBOX_2D_TRAIN,
+ CONN_BBOX_2D_VIS,
+)
+from vis4d.zoo.base.datasets.coco import (
+ CONN_COCO_BBOX_EVAL,
+ get_coco_detection_cfg,
+)
+from vis4d.zoo.base.models.faster_rcnn import get_faster_rcnn_cfg
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the Faster-RCNN config dict for the coco detection task.
+
+ This is an example that shows how to set up a training experiment for the
+ COCO detection task.
+
+ Note that the high level params are exposed in the config. This allows
+ to easily change them from the command line.
+ E.g.:
+ >>> python -m vis4d.engine.run fit --config configs/faster_rcnn/faster_rcnn_coco.py --config.params.lr 0.001
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="faster_rcnn_r50_fpn_coco")
+
+ # High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.02
+ params.num_epochs = 12
+ params.num_classes = 80
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/coco"
+ train_split = "train2017"
+ test_split = "val2017"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_coco_detection_cfg(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3
+ )
+
+ config.model, config.loss = get_faster_rcnn_cfg(
+ num_classes=params.num_classes, basemodel=basemodel
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[8, 11], gamma=0.1),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector,
+ key_mapping=CONN_BBOX_2D_TRAIN,
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector,
+ key_mapping=CONN_BBOX_2D_TEST,
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(BoundingBoxVisualizer, vis_freq=100),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BBOX_2D_VIS
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ COCODetectEvaluator, data_root=data_root, split=test_split
+ ),
+ metrics_to_eval=["Det"],
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_COCO_BBOX_EVAL
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/fcn_resnet/__init__.py b/vis4d/zoo/fcn_resnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e5854217e69e0807c689cf86a5fc4c85913d09f
--- /dev/null
+++ b/vis4d/zoo/fcn_resnet/__init__.py
@@ -0,0 +1,7 @@
+"""FCN Model Zoo."""
+
+from . import fcn_resnet_coco
+
+AVAILABLE_MODELS = {
+ "fcn_resnet_coco": fcn_resnet_coco,
+}
diff --git a/vis4d/zoo/fcn_resnet/fcn_resnet_coco.py b/vis4d/zoo/fcn_resnet/fcn_resnet_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb905d9bd35f7d876ae550e672b53d6c0385ea77
--- /dev/null
+++ b/vis4d/zoo/fcn_resnet/fcn_resnet_coco.py
@@ -0,0 +1,163 @@
+"""FCN-ResNet COCO training example."""
+
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.connectors import DataConnector, LossConnector
+from vis4d.engine.loss_module import LossModule
+from vis4d.engine.optim import PolyLR
+from vis4d.model.seg.fcn_resnet import FCNResNet
+from vis4d.op.loss import MultiLevelSegLoss
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors.seg import (
+ CONN_MASKS_TEST,
+ CONN_MASKS_TRAIN,
+ CONN_MULTI_SEG_LOSS,
+)
+from vis4d.zoo.base.datasets.coco import get_coco_sem_seg_cfg
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for the COCO semantic segmentation task.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="fcn_coco")
+ config.sync_batchnorm = True
+ config.val_check_interval = 2000
+ config.check_val_every_n_epoch = None
+
+ ## High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.01
+ params.num_steps = 40000
+ params.num_classes = 21
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/COCO"
+ train_split = "train2017"
+ test_split = "val2017"
+ image_size = (520, 520)
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_coco_sem_seg_cfg(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ data_backend=data_backend,
+ image_size=image_size,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL ##
+ ######################################################
+ config.model = class_config(
+ FCNResNet,
+ base_model="resnet50",
+ num_classes=params.num_classes,
+ resize=image_size,
+ )
+
+ ######################################################
+ ## LOSS ##
+ ######################################################
+ config.loss = class_config(
+ LossModule,
+ losses={
+ "loss": class_config(
+ MultiLevelSegLoss, feature_idx=[4, 5], weights=[0.5, 1]
+ ),
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_MULTI_SEG_LOSS
+ ),
+ },
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0005
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(
+ PolyLR,
+ max_steps=params.num_steps,
+ min_lr=0.0001,
+ power=0.9,
+ ),
+ epoch_based=False,
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_MASKS_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_MASKS_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ callbacks = get_default_callbacks_cfg(epoch_based=False)
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.epoch_based = False
+ pl_trainer.max_steps = params.num_steps
+
+ pl_trainer.checkpoint_period = config.val_check_interval
+ pl_trainer.val_check_interval = config.val_check_interval
+ pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch
+
+ pl_trainer.sync_batchnorm = config.sync_batchnorm
+ # pl_trainer.precision = 16
+
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/mask_rcnn/__init__.py b/vis4d/zoo/mask_rcnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b09f158ffd75c7e3b652df48643f5d049981294
--- /dev/null
+++ b/vis4d/zoo/mask_rcnn/__init__.py
@@ -0,0 +1,7 @@
+"""Mask-RCNN Model Zoo."""
+
+from . import mask_rcnn_coco
+
+AVAILABLE_MODELS = {
+ "mask_rcnn_coco": mask_rcnn_coco,
+}
diff --git a/vis4d/zoo/mask_rcnn/mask_rcnn_coco.py b/vis4d/zoo/mask_rcnn/mask_rcnn_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..dff61dd6999b2b28495f1b2140fb960d70cc0de5
--- /dev/null
+++ b/vis4d/zoo/mask_rcnn/mask_rcnn_coco.py
@@ -0,0 +1,192 @@
+# pylint: disable=duplicate-code
+"""Mask RCNN COCO training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import (
+ CallbackConnector,
+ DataConnector,
+ remap_pred_keys,
+)
+from vis4d.eval.coco import COCODetectEvaluator
+from vis4d.op.base import ResNet
+from vis4d.vis.image import BoundingBoxVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors import (
+ CONN_BBOX_2D_TEST,
+ CONN_BBOX_2D_TRAIN,
+ CONN_BBOX_2D_VIS,
+)
+from vis4d.zoo.base.datasets.coco import (
+ CONN_COCO_BBOX_EVAL,
+ get_coco_detection_cfg,
+)
+from vis4d.zoo.base.models.mask_rcnn import get_mask_rcnn_cfg
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the Mask-RCNN config dict for the coco detection task.
+
+ This is an example that shows how to set up a training experiment for the
+ COCO detection task.
+
+ Note that the high level params are exposed in the config. This allows
+ to easily change them from the command line.
+ E.g.:
+ >>> python -m vis4d.engine.run fit --config configs/faster_rcnn/faster_rcnn_coco.py --config.params.lr 0.001
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="mask_rcnn_r50_fpn_coco")
+
+ # High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.02
+ params.num_epochs = 12
+ params.num_classes = 80
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/coco"
+ train_split = "train2017"
+ test_split = "val2017"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_coco_detection_cfg(
+ data_root=data_root,
+ train_split=train_split,
+ train_keys_to_load=(
+ K.images,
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.instance_masks,
+ ),
+ test_split=test_split,
+ test_keys_to_load=(
+ K.images,
+ K.original_images,
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.instance_masks,
+ ),
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3
+ )
+
+ config.model, config.loss = get_mask_rcnn_cfg(
+ num_classes=params.num_classes, basemodel=basemodel
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[8, 11], gamma=0.1),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector,
+ key_mapping=CONN_BBOX_2D_TRAIN,
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector,
+ key_mapping=CONN_BBOX_2D_TEST,
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(BoundingBoxVisualizer, vis_freq=100),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector,
+ key_mapping=remap_pred_keys(CONN_BBOX_2D_VIS, "boxes"),
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ COCODetectEvaluator,
+ data_root=data_root,
+ split=test_split,
+ ),
+ metrics_to_eval=["Det"],
+ test_connector=class_config(
+ CallbackConnector,
+ key_mapping=remap_pred_keys(CONN_COCO_BBOX_EVAL, "boxes"),
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/qdtrack/__init__.py b/vis4d/zoo/qdtrack/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d58127bbf41a2c403e77569d3cb50bee68ccf4f
--- /dev/null
+++ b/vis4d/zoo/qdtrack/__init__.py
@@ -0,0 +1,14 @@
+"""QDTrack."""
+
+from . import (
+ qdtrack_frcnn_r50_fpn_augs_1x_bdd100k,
+ qdtrack_yolox_x_25e_bdd100k,
+)
+
+# Lists of available models in BDD100K Model Zoo.
+AVAILABLE_MODELS = {
+ "qdtrack_frcnn_r50_fpn_augs_1x_bdd100k": (
+ qdtrack_frcnn_r50_fpn_augs_1x_bdd100k
+ ),
+ "qdtrack_yolox_x_25e_bdd100k": qdtrack_yolox_x_25e_bdd100k,
+}
diff --git a/vis4d/zoo/qdtrack/data_yolox.py b/vis4d/zoo/qdtrack/data_yolox.py
new file mode 100644
index 0000000000000000000000000000000000000000..376654492f0b23fa9e995b7bd274f2dcbac849f5
--- /dev/null
+++ b/vis4d/zoo/qdtrack/data_yolox.py
@@ -0,0 +1,275 @@
+"""BDD100K data loading config for QDTrack YOLOX."""
+
+from __future__ import annotations
+
+from ml_collections import ConfigDict
+
+from vis4d.config import class_config
+from vis4d.config.typing import DataConfig
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.data_pipe import DataPipe, MultiSampleDataPipe
+from vis4d.data.datasets.bdd100k import BDD100K, bdd100k_track_map
+from vis4d.data.loader import build_train_dataloader, default_collate
+from vis4d.data.reference import MultiViewDataset, UniformViewSampler
+from vis4d.data.transforms.affine import (
+ AffineBoxes2D,
+ AffineImages,
+ GenAffineParameters,
+)
+from vis4d.data.transforms.base import RandomApply, compose
+from vis4d.data.transforms.crop import (
+ CropBoxes2D,
+ CropImages,
+ GenCropParameters,
+)
+from vis4d.data.transforms.flip import FlipBoxes2D, FlipImages
+from vis4d.data.transforms.mixup import (
+ GenMixupParameters,
+ MixupBoxes2D,
+ MixupImages,
+)
+from vis4d.data.transforms.mosaic import (
+ GenMosaicParameters,
+ MosaicBoxes2D,
+ MosaicImages,
+)
+from vis4d.data.transforms.normalize import NormalizeImages
+from vis4d.data.transforms.pad import PadImages
+from vis4d.data.transforms.photometric import RandomHSV
+from vis4d.data.transforms.post_process import (
+ PostProcessBoxes2D,
+ RescaleTrackIDs,
+)
+from vis4d.data.transforms.resize import (
+ GenResizeParameters,
+ ResizeBoxes2D,
+ ResizeImages,
+)
+from vis4d.data.transforms.to_tensor import ToTensor
+from vis4d.zoo.base import get_inference_dataloaders_cfg
+from vis4d.zoo.base.callable import get_callable_cfg
+
+
+def get_train_dataloader(
+ data_backend: None | ConfigDict,
+ image_size: tuple[int, int],
+ normalize_image: bool,
+ samples_per_gpu: int,
+ workers_per_gpu: int,
+) -> ConfigDict:
+ """Get the default train dataloader for BDD100K tracking."""
+ bdd100k_det_train = class_config(
+ BDD100K,
+ data_root="data/bdd100k/images/100k/train/",
+ keys_to_load=(K.images, K.boxes2d),
+ annotation_path="data/bdd100k/labels/det_20/det_train.json",
+ category_map=bdd100k_track_map,
+ config_path="det",
+ image_channel_mode="BGR",
+ data_backend=data_backend,
+ skip_empty_samples=True,
+ cache_as_binary=True,
+ cached_file_path="data/bdd100k/pkl/det_train.pkl",
+ )
+
+ bdd100k_track_train = class_config(
+ BDD100K,
+ data_root="data/bdd100k/images/track/train/",
+ keys_to_load=(K.images, K.boxes2d),
+ annotation_path="data/bdd100k/labels/box_track_20/train/",
+ category_map=bdd100k_track_map,
+ config_path="box_track",
+ image_channel_mode="BGR",
+ data_backend=data_backend,
+ skip_empty_samples=True,
+ cache_as_binary=True,
+ cached_file_path="data/bdd100k/pkl/track_train.pkl",
+ )
+
+ train_dataset_cfg = [
+ class_config(
+ MultiViewDataset,
+ dataset=bdd100k_det_train,
+ sampler=class_config(
+ UniformViewSampler, scope=0, num_ref_samples=1
+ ),
+ ),
+ class_config(
+ MultiViewDataset,
+ dataset=bdd100k_track_train,
+ sampler=class_config(
+ UniformViewSampler, scope=3, num_ref_samples=1
+ ),
+ ),
+ ]
+
+ # Train Preprocessing
+ preprocess_transforms = [
+ [
+ class_config(GenMosaicParameters, out_shape=image_size),
+ class_config(MosaicImages, imresize_backend="cv2"),
+ class_config(MosaicBoxes2D),
+ ],
+ [class_config(RescaleTrackIDs)],
+ ]
+
+ preprocess_transforms += [
+ [
+ class_config(
+ GenAffineParameters,
+ scaling_ratio_range=(0.5, 1.5),
+ border=(-image_size[0] // 2, -image_size[1] // 2),
+ ),
+ class_config(AffineImages, as_int=True),
+ class_config(AffineBoxes2D),
+ ]
+ ]
+
+ preprocess_transforms += [
+ [
+ class_config(
+ GenMixupParameters,
+ out_shape=image_size,
+ mixup_ratio_dist="const",
+ scale_range=(0.8, 1.6),
+ pad_value=114.0,
+ ),
+ class_config(MixupImages, imresize_backend="cv2"),
+ class_config(MixupBoxes2D),
+ ],
+ [class_config(RescaleTrackIDs)],
+ ]
+
+ preprocess_transforms.append(
+ [class_config(PostProcessBoxes2D, min_area=1.0)]
+ )
+
+ batch_transforms = [
+ class_config(RandomHSV, same_on_batch=False),
+ class_config(
+ RandomApply,
+ transforms=[class_config(FlipImages), class_config(FlipBoxes2D)],
+ probability=0.5,
+ same_on_batch=False,
+ ),
+ class_config(
+ GenResizeParameters,
+ shape=image_size,
+ keep_ratio=True,
+ scale_range=(0.5, 1.5),
+ same_on_batch=False,
+ ),
+ class_config(ResizeImages),
+ class_config(ResizeBoxes2D),
+ class_config(GenCropParameters, shape=image_size, same_on_batch=False),
+ class_config(CropImages),
+ class_config(CropBoxes2D),
+ ]
+ if normalize_image:
+ batch_transforms += [
+ class_config(NormalizeImages),
+ class_config(PadImages),
+ ]
+ else:
+ batch_transforms += [class_config(PadImages, value=114.0)]
+ train_batchprocess_cfg = class_config(
+ compose, transforms=batch_transforms + [class_config(ToTensor)]
+ )
+
+ return class_config(
+ build_train_dataloader,
+ dataset=class_config(
+ MultiSampleDataPipe,
+ datasets=train_dataset_cfg,
+ preprocess_fn=preprocess_transforms,
+ ),
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ batchprocess_fn=train_batchprocess_cfg,
+ collate_fn=get_callable_cfg(default_collate),
+ pin_memory=True,
+ shuffle=True,
+ )
+
+
+def get_test_dataloader(
+ data_backend: None | ConfigDict,
+ image_size: tuple[int, int],
+ normalize_image: bool,
+ samples_per_gpu: int,
+ workers_per_gpu: int,
+) -> ConfigDict:
+ """Get the default test dataloader for BDD100K tracking."""
+ test_dataset = class_config(
+ BDD100K,
+ data_root="data/bdd100k/images/track/val/",
+ keys_to_load=(K.images, K.original_images),
+ annotation_path="data/bdd100k/labels/box_track_20/val/",
+ category_map=bdd100k_track_map,
+ config_path="box_track",
+ image_channel_mode="BGR",
+ data_backend=data_backend,
+ cache_as_binary=True,
+ cached_file_path="data/bdd100k/pkl/track_val.pkl",
+ )
+
+ preprocess_transforms = [
+ class_config(GenResizeParameters, shape=image_size, keep_ratio=True),
+ class_config(ResizeImages),
+ ]
+
+ test_preprocess_cfg = class_config(
+ compose, transforms=preprocess_transforms
+ )
+
+ if normalize_image:
+ batch_transforms = [
+ class_config(NormalizeImages),
+ class_config(PadImages),
+ ]
+ else:
+ batch_transforms = [class_config(PadImages, value=114.0)]
+ test_batchprocess_cfg = class_config(
+ compose, transforms=batch_transforms + [class_config(ToTensor)]
+ )
+
+ test_dataset_cfg = class_config(
+ DataPipe, datasets=test_dataset, preprocess_fn=test_preprocess_cfg
+ )
+
+ return get_inference_dataloaders_cfg(
+ datasets_cfg=test_dataset_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ video_based_inference=True,
+ batchprocess_cfg=test_batchprocess_cfg,
+ )
+
+
+def get_bdd100k_track_cfg(
+ data_backend: None | ConfigDict = None,
+ image_size: tuple[int, int] = (800, 1440),
+ normalize_image: bool = False,
+ samples_per_gpu: int = 2,
+ workers_per_gpu: int = 2,
+) -> DataConfig:
+ """Get the default config for BDD100K tracking."""
+ data = DataConfig()
+
+ data.train_dataloader = get_train_dataloader(
+ data_backend=data_backend,
+ image_size=image_size,
+ normalize_image=normalize_image,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+ data.test_dataloader = get_test_dataloader(
+ data_backend=data_backend,
+ image_size=image_size,
+ normalize_image=normalize_image,
+ samples_per_gpu=1,
+ workers_per_gpu=1,
+ )
+
+ return data
diff --git a/vis4d/zoo/qdtrack/qdtrack_frcnn_r50_fpn_augs_1x_bdd100k.py b/vis4d/zoo/qdtrack/qdtrack_frcnn_r50_fpn_augs_1x_bdd100k.py
new file mode 100644
index 0000000000000000000000000000000000000000..757f97f945b37888be6beb6a94ea4c35dad0581a
--- /dev/null
+++ b/vis4d/zoo/qdtrack/qdtrack_frcnn_r50_fpn_augs_1x_bdd100k.py
@@ -0,0 +1,175 @@
+# pylint: disable=duplicate-code
+"""QDTrack with Faster R-CNN on BDD100K."""
+from __future__ import annotations
+
+from lightning.pytorch.callbacks import ModelCheckpoint
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.datasets.bdd100k import bdd100k_track_map
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import (
+ EvaluatorCallback,
+ VisualizerCallback,
+ YOLOXModeSwitchCallback,
+)
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.eval.bdd100k import BDD100KTrackEvaluator
+from vis4d.op.base import ResNet
+from vis4d.vis.image import BoundingBoxVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors import CONN_BBOX_2D_TRACK_VIS
+from vis4d.zoo.base.datasets.bdd100k import CONN_BDD100K_TRACK_EVAL
+from vis4d.zoo.base.models.qdtrack import (
+ CONN_BBOX_2D_TEST,
+ CONN_BBOX_2D_TRAIN,
+ get_qdtrack_cfg,
+)
+from vis4d.zoo.qdtrack.data_yolox import get_bdd100k_track_cfg
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for qdtrack on bdd100k.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="qdtrack_frcnn_r50_fpn_augs_1x_bdd100k")
+
+ # High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 4 # batch size = 4 GPUs * 4 samples per GPU = 16
+ params.workers_per_gpu = 8
+ params.lr = 0.02
+ params.num_epochs = 12
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_bdd100k_track_cfg(
+ data_backend=data_backend,
+ image_size=(720, 1280),
+ normalize_image=True,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL ##
+ ######################################################
+ num_classes = len(bdd100k_track_map)
+ basemodel = class_config(
+ ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3
+ )
+
+ config.model, config.loss = get_qdtrack_cfg(
+ num_classes=num_classes, basemodel=basemodel
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(LinearLR, start_factor=0.1, total_iters=1000),
+ end=1000,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[8, 11], gamma=0.1),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Mode switch for strong augmentations
+ callbacks += [class_config(YOLOXModeSwitchCallback, switch_epoch=9)]
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(
+ BoundingBoxVisualizer, vis_freq=500, image_mode="BGR"
+ ),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BBOX_2D_TRACK_VIS
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ BDD100KTrackEvaluator,
+ annotation_path="data/bdd100k/labels/box_track_20/val/",
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BDD100K_TRACK_EVAL
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ pl_trainer.checkpoint_callback = class_config(
+ ModelCheckpoint,
+ dirpath=config.get_ref("output_dir") + "/checkpoints",
+ verbose=True,
+ save_last=True,
+ save_on_train_epoch_end=True,
+ every_n_epochs=1,
+ save_top_k=4,
+ mode="max",
+ monitor="step",
+ )
+ pl_trainer.wandb = True
+ pl_trainer.gradient_clip_val = 35
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/qdtrack/qdtrack_yolox_x_25e_bdd100k.py b/vis4d/zoo/qdtrack/qdtrack_yolox_x_25e_bdd100k.py
new file mode 100644
index 0000000000000000000000000000000000000000..782c42ad12278ba11ce0cb1b5ee614fa51f99d7f
--- /dev/null
+++ b/vis4d/zoo/qdtrack/qdtrack_yolox_x_25e_bdd100k.py
@@ -0,0 +1,163 @@
+# pylint: disable=duplicate-code
+"""QDTrack with YOLOX-x on BDD100K."""
+from __future__ import annotations
+
+from lightning.pytorch.callbacks import ModelCheckpoint
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.datasets.bdd100k import bdd100k_track_map
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.eval.bdd100k import BDD100KTrackEvaluator
+from vis4d.vis.image import BoundingBoxVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+)
+from vis4d.zoo.base.data_connectors import CONN_BBOX_2D_TRACK_VIS
+from vis4d.zoo.base.datasets.bdd100k import CONN_BDD100K_TRACK_EVAL
+from vis4d.zoo.base.models.qdtrack import (
+ CONN_BBOX_2D_TEST,
+ CONN_BBOX_2D_TRAIN,
+ get_qdtrack_yolox_cfg,
+)
+from vis4d.zoo.base.models.yolox import (
+ get_yolox_callbacks_cfg,
+ get_yolox_optimizers_cfg,
+)
+from vis4d.zoo.qdtrack.data_yolox import get_bdd100k_track_cfg
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for qdtrack on bdd100k.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="qdtrack_yolox_x_25e_bdd100k")
+ config.checkpoint_period = 5
+ config.check_val_every_n_epoch = 5
+
+ # Hyper Parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 8 # batch size = 8 GPUs * 8 samples per GPU = 64
+ params.workers_per_gpu = 8
+ params.lr = 0.001
+ params.num_epochs = 25
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_bdd100k_track_cfg(
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL ##
+ ######################################################
+ num_classes = len(bdd100k_track_map)
+ weights = (
+ "mmdet://yolox/yolox_x_8x8_300e_coco/"
+ "yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth"
+ )
+ config.model, config.loss = get_qdtrack_yolox_cfg(
+ num_classes, "xlarge", weights=weights
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ # we use a schedule with 50 epochs, but only train for 25 epochs
+ num_total_epochs, num_last_epochs = 50, 10
+ config.optimizers = get_yolox_optimizers_cfg(
+ params.lr, num_total_epochs, 1, num_last_epochs
+ )
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg(
+ refresh_rate=config.log_every_n_steps
+ )
+
+ # YOLOX callbacks
+ callbacks += get_yolox_callbacks_cfg(
+ switch_epoch=num_total_epochs - num_last_epochs, num_sizes=0
+ )
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(
+ BoundingBoxVisualizer, vis_freq=500, image_mode="BGR"
+ ),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BBOX_2D_TRACK_VIS
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ BDD100KTrackEvaluator,
+ annotation_path="data/bdd100k/labels/box_track_20/val/",
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BDD100K_TRACK_EVAL
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch
+ pl_trainer.checkpoint_callback = class_config(
+ ModelCheckpoint,
+ dirpath=config.get_ref("output_dir") + "/checkpoints",
+ verbose=True,
+ save_last=True,
+ save_on_train_epoch_end=True,
+ every_n_epochs=config.checkpoint_period,
+ save_top_k=5,
+ mode="max",
+ monitor="step",
+ )
+ pl_trainer.wandb = True
+ pl_trainer.precision = "16-mixed"
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/retinanet/__init__.py b/vis4d/zoo/retinanet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8ea582388b856e975f2668f55c7bd5b059adec8
--- /dev/null
+++ b/vis4d/zoo/retinanet/__init__.py
@@ -0,0 +1,7 @@
+"""RetinaNet Model Zoo."""
+
+from . import retinanet_coco
+
+AVAILABLE_MODELS = {
+ "retinanet_coco": retinanet_coco,
+}
diff --git a/vis4d/zoo/retinanet/retinanet_coco.py b/vis4d/zoo/retinanet/retinanet_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..446345faaf7f5d4da66f06bcd4f157f11ad5809d
--- /dev/null
+++ b/vis4d/zoo/retinanet/retinanet_coco.py
@@ -0,0 +1,206 @@
+# pylint: disable=duplicate-code
+"""RetinaNet COCO training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import (
+ CallbackConnector,
+ DataConnector,
+ LossConnector,
+)
+from vis4d.engine.loss_module import LossModule
+from vis4d.eval.coco import COCODetectEvaluator
+from vis4d.model.detect.retinanet import RetinaNet
+from vis4d.op.box.encoder import DeltaXYWHBBoxEncoder
+from vis4d.op.detect.retinanet import (
+ RetinaNetHeadLoss,
+ get_default_anchor_generator,
+)
+from vis4d.vis.image import BoundingBoxVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors import (
+ CONN_BBOX_2D_VIS,
+ CONN_BOX_LOSS_2D,
+ CONN_IMAGES_TEST,
+ CONN_IMAGES_TRAIN,
+)
+from vis4d.zoo.base.datasets.coco import (
+ CONN_COCO_BBOX_EVAL,
+ get_coco_detection_cfg,
+)
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the RetinaNet config dict for the coco detection task.
+
+ This is an example that shows how to set up a training experiment for the
+ COCO detection task.
+
+ Note that the high level params are exposed in the config. This allows
+ to easily change them from the command line.
+ E.g.:
+ >>> python -m vis4d.engine.run fit --config vis4d/zoo/retinanet/retinanet_rcnn_coco.py --config.num_epochs 100 --config.params.lr 0.001
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="retinanet_r50_fpn_coco")
+
+ # High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.01
+ params.num_epochs = 12
+ params.num_classes = 80
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/coco"
+ train_split = "train2017"
+ test_split = "val2017"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_coco_detection_cfg(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ config.model = class_config(
+ RetinaNet,
+ num_classes=params.num_classes,
+ # weights="mmdet",
+ )
+
+ box_encoder = class_config(
+ DeltaXYWHBBoxEncoder,
+ target_means=(0.0, 0.0, 0.0, 0.0),
+ target_stds=(1.0, 1.0, 1.0, 1.0),
+ )
+
+ anchor_generator = class_config(get_default_anchor_generator)
+
+ retina_loss = class_config(
+ RetinaNetHeadLoss,
+ box_encoder=box_encoder,
+ anchor_generator=anchor_generator,
+ )
+
+ config.loss = class_config(
+ LossModule,
+ losses={
+ "loss": retina_loss,
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_BOX_LOSS_2D
+ ),
+ },
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[8, 11], gamma=0.1),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector,
+ key_mapping=CONN_IMAGES_TRAIN,
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector,
+ key_mapping=CONN_IMAGES_TEST,
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(BoundingBoxVisualizer, vis_freq=100),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector,
+ key_mapping=CONN_BBOX_2D_VIS,
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ COCODetectEvaluator,
+ data_root=data_root,
+ split=test_split,
+ ),
+ metrics_to_eval=["Det"],
+ test_connector=class_config(
+ CallbackConnector,
+ key_mapping=CONN_COCO_BBOX_EVAL,
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/run.py b/vis4d/zoo/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..a53b8d955e5fc130aa2c69011f5200d74b003c59
--- /dev/null
+++ b/vis4d/zoo/run.py
@@ -0,0 +1,27 @@
+"""CLI interface."""
+
+from __future__ import annotations
+
+from absl import app # pylint: disable=no-name-in-module
+
+from vis4d.common.typing import ArgsType
+from vis4d.zoo import AVAILABLE_MODELS
+
+
+def main(argv: ArgsType) -> None:
+ """Main entry point for the model zoo."""
+ assert len(argv) > 1, "Command must be specified: `list`"
+ if argv[1] == "list":
+ for ds, models in AVAILABLE_MODELS.items():
+ print(ds)
+ model_names = list(models.keys())
+ for model in model_names[:-1]:
+ print(" ├─", model)
+ print(" └─", model_names[-1])
+ else:
+ raise ValueError(f"Invalid command. {argv[1]}")
+
+
+def entrypoint() -> None:
+ """Entry point for the CLI."""
+ app.run(main)
diff --git a/vis4d/zoo/shift/__init__.py b/vis4d/zoo/shift/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..498a278c10e344317651b20aae8194781a091287
--- /dev/null
+++ b/vis4d/zoo/shift/__init__.py
@@ -0,0 +1,31 @@
+"""BDD100K Model Zoo."""
+
+from .faster_rcnn import (
+ faster_rcnn_r50_6e_shift_all_domains,
+ faster_rcnn_r50_12e_shift,
+ faster_rcnn_r50_36e_shift,
+)
+from .mask_rcnn import (
+ mask_rcnn_r50_6e_shift_all_domains,
+ mask_rcnn_r50_12e_shift,
+ mask_rcnn_r50_36e_shift,
+)
+from .semantic_fpn import (
+ semantic_fpn_r50_40k_shift,
+ semantic_fpn_r50_40k_shift_all_domains,
+ semantic_fpn_r50_160k_shift,
+ semantic_fpn_r50_160k_shift_all_domains,
+)
+
+AVAILABLE_MODELS = {
+ "faster_rcnn_r50_6e_shift_all_domains": faster_rcnn_r50_6e_shift_all_domains, # pylint: disable=line-too-long
+ "faster_rcnn_r50_12e_shift": faster_rcnn_r50_12e_shift,
+ "faster_rcnn_r50_36e_shift": faster_rcnn_r50_36e_shift,
+ "mask_rcnn_r50_6e_shift_all_domains": mask_rcnn_r50_6e_shift_all_domains,
+ "mask_rcnn_r50_12e_shift": mask_rcnn_r50_12e_shift,
+ "mask_rcnn_r50_36e_shift": mask_rcnn_r50_36e_shift,
+ "semantic_fpn_r50_40k_shift_all_domains": semantic_fpn_r50_40k_shift_all_domains, # pylint: disable=line-too-long
+ "semantic_fpn_r50_40k_shift": semantic_fpn_r50_40k_shift,
+ "semantic_fpn_r50_160k_shift_all_domains": semantic_fpn_r50_160k_shift_all_domains, # pylint: disable=line-too-long
+ "semantic_fpn_r50_160k_shift": semantic_fpn_r50_160k_shift,
+}
diff --git a/vis4d/zoo/shift/faster_rcnn/__init__.py b/vis4d/zoo/shift/faster_rcnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7bf040b0026b6b3c7bf685770abfa55fde40e0f
--- /dev/null
+++ b/vis4d/zoo/shift/faster_rcnn/__init__.py
@@ -0,0 +1 @@
+"""Faster R-CNN for SHIFT."""
diff --git a/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_12e_shift.py b/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_12e_shift.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f422d3e57eae876e45fc84727d119d8ad75a4b0
--- /dev/null
+++ b/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_12e_shift.py
@@ -0,0 +1,170 @@
+# pylint: disable=duplicate-code
+"""Faster RCNN SHIFT training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.eval.shift import SHIFTDetectEvaluator
+from vis4d.op.base import ResNet
+from vis4d.vis.image import BoundingBoxVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors import (
+ CONN_BBOX_2D_TEST,
+ CONN_BBOX_2D_TRAIN,
+ CONN_BBOX_2D_VIS,
+)
+from vis4d.zoo.base.datasets.shift import (
+ CONN_SHIFT_DET_EVAL,
+ get_shift_det_config,
+)
+from vis4d.zoo.base.models.faster_rcnn import get_faster_rcnn_cfg
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the Faster-RCNN config dict for the SHIFT detection task.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="faster_rcnn_r50_12e_shift")
+
+ # High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.02
+ params.num_epochs = 12
+ params.num_classes = 6
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/shift/"
+ views_to_load = ["front"]
+ train_split = "train"
+ test_split = "val"
+ domain_attr = [{"weather_coarse": "clear", "timeofday_coarse": "daytime"}]
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_shift_det_config(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ train_views_to_load=views_to_load,
+ test_views_to_load=views_to_load,
+ train_attributes_to_load=domain_attr,
+ test_attributes_to_load=domain_attr,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=4
+ )
+
+ config.model, config.loss = get_faster_rcnn_cfg(
+ num_classes=params.num_classes, basemodel=basemodel
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[8, 11], gamma=0.1),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(BoundingBoxVisualizer, vis_freq=100),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BBOX_2D_VIS
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ SHIFTDetectEvaluator,
+ annotation_path=(
+ f"{data_root}/discrete/images/val/front/det_2d.json"
+ ),
+ attributes_to_load=domain_attr,
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_SHIFT_DET_EVAL
+ ),
+ metrics_to_eval=[SHIFTDetectEvaluator.METRICS_DET],
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_36e_shift.py b/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_36e_shift.py
new file mode 100644
index 0000000000000000000000000000000000000000..64dbef288e1f6ec12f2aee9085841533e8e7b471
--- /dev/null
+++ b/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_36e_shift.py
@@ -0,0 +1,170 @@
+# pylint: disable=duplicate-code
+"""Faster RCNN SHIFT training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.eval.shift import SHIFTDetectEvaluator
+from vis4d.op.base import ResNet
+from vis4d.vis.image import BoundingBoxVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors import (
+ CONN_BBOX_2D_TEST,
+ CONN_BBOX_2D_TRAIN,
+ CONN_BBOX_2D_VIS,
+)
+from vis4d.zoo.base.datasets.shift import (
+ CONN_SHIFT_DET_EVAL,
+ get_shift_det_config,
+)
+from vis4d.zoo.base.models.faster_rcnn import get_faster_rcnn_cfg
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the Faster-RCNN config dict for the SHIFT detection task.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="faster_rcnn_r50_36e_shift")
+
+ # High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.02
+ params.num_epochs = 36
+ params.num_classes = 6
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/shift/"
+ views_to_load = ["front"]
+ train_split = "train"
+ test_split = "val"
+ domain_attr = [{"weather_coarse": "clear", "timeofday_coarse": "daytime"}]
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_shift_det_config(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ train_views_to_load=views_to_load,
+ test_views_to_load=views_to_load,
+ train_attributes_to_load=domain_attr,
+ test_attributes_to_load=domain_attr,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=4
+ )
+
+ config.model, config.loss = get_faster_rcnn_cfg(
+ num_classes=params.num_classes, basemodel=basemodel
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[24, 33], gamma=0.1),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(BoundingBoxVisualizer, vis_freq=100),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BBOX_2D_VIS
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ SHIFTDetectEvaluator,
+ annotation_path=(
+ f"{data_root}/discrete/images/val/front/det_2d.json"
+ ),
+ attributes_to_load=domain_attr,
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_SHIFT_DET_EVAL
+ ),
+ metrics_to_eval=[SHIFTDetectEvaluator.METRICS_DET],
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_6e_shift_all_domains.py b/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_6e_shift_all_domains.py
new file mode 100644
index 0000000000000000000000000000000000000000..45d427e501100a5fa487bd7d183eeb38566133b2
--- /dev/null
+++ b/vis4d/zoo/shift/faster_rcnn/faster_rcnn_r50_6e_shift_all_domains.py
@@ -0,0 +1,170 @@
+# pylint: disable=duplicate-code
+"""Faster RCNN SHIFT training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.eval.shift import SHIFTDetectEvaluator
+from vis4d.op.base import ResNet
+from vis4d.vis.image import BoundingBoxVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors import (
+ CONN_BBOX_2D_TEST,
+ CONN_BBOX_2D_TRAIN,
+ CONN_BBOX_2D_VIS,
+)
+from vis4d.zoo.base.datasets.shift import (
+ CONN_SHIFT_DET_EVAL,
+ get_shift_det_config,
+)
+from vis4d.zoo.base.models.faster_rcnn import get_faster_rcnn_cfg
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the Faster-RCNN config dict for the SHIFT detection task.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="faster_rcnn_r50_6e_shift_all_domains")
+
+ # High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.02
+ params.num_epochs = 6
+ params.num_classes = 6
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/shift/"
+ views_to_load = ["front"]
+ train_split = "train"
+ test_split = "val"
+ domain_attr = None
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_shift_det_config(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ train_views_to_load=views_to_load,
+ test_views_to_load=views_to_load,
+ train_attributes_to_load=domain_attr,
+ test_attributes_to_load=domain_attr,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=4
+ )
+
+ config.model, config.loss = get_faster_rcnn_cfg(
+ num_classes=params.num_classes, basemodel=basemodel
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[4, 5], gamma=0.1),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(BoundingBoxVisualizer, vis_freq=100),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BBOX_2D_VIS
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ SHIFTDetectEvaluator,
+ annotation_path=(
+ f"{data_root}/discrete/images/val/front/det_2d.json"
+ ),
+ attributes_to_load=domain_attr,
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_SHIFT_DET_EVAL
+ ),
+ metrics_to_eval=[SHIFTDetectEvaluator.METRICS_DET],
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/shift/mask_rcnn/__init__.py b/vis4d/zoo/shift/mask_rcnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..14605694ec1b083b095b2e098ff55c802550018b
--- /dev/null
+++ b/vis4d/zoo/shift/mask_rcnn/__init__.py
@@ -0,0 +1 @@
+"""Mask R-CNN for SHIFT."""
diff --git a/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_12e_shift.py b/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_12e_shift.py
new file mode 100644
index 0000000000000000000000000000000000000000..a66c4a5140de1d3b2e773da75dc32489fb285b55
--- /dev/null
+++ b/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_12e_shift.py
@@ -0,0 +1,174 @@
+# pylint: disable=duplicate-code
+"""Mask RCNN SHIFT training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import FieldConfigDict, class_config
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.eval.shift import SHIFTDetectEvaluator
+from vis4d.op.base import ResNet
+from vis4d.vis.image import SegMaskVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors import (
+ CONN_BBOX_2D_TEST,
+ CONN_BBOX_2D_TRAIN,
+ CONN_INS_MASK_2D_VIS,
+)
+from vis4d.zoo.base.datasets.shift import (
+ CONN_SHIFT_INS_EVAL,
+ get_shift_instance_seg_config,
+)
+from vis4d.zoo.base.models.mask_rcnn import get_mask_rcnn_cfg
+
+
+def get_config() -> FieldConfigDict:
+ """Returns the Faster-RCNN config dict for the SHIFT detection task.
+
+ Returns:
+ FieldConfigDict: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="mask_rcnn_r50_12e_shift")
+
+ # High level hyper parameters
+ params = FieldConfigDict()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.02
+ params.num_epochs = 12
+ params.num_classes = 6
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/shift/"
+ views_to_load = ["front"]
+ train_split = "train"
+ test_split = "val"
+ domain_attr = [{"weather_coarse": "clear", "timeofday_coarse": "daytime"}]
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_shift_instance_seg_config(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ train_views_to_load=views_to_load,
+ test_views_to_load=views_to_load,
+ train_attributes_to_load=domain_attr,
+ test_attributes_to_load=domain_attr,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=4
+ )
+
+ config.model, config.loss = get_mask_rcnn_cfg(
+ num_classes=params.num_classes,
+ basemodel=basemodel,
+ no_overlap=True,
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[8, 11], gamma=0.1),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(SegMaskVisualizer, vis_freq=25),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_INS_MASK_2D_VIS
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ SHIFTDetectEvaluator,
+ annotation_path=(
+ f"{data_root}/discrete/images/val/front/det_insseg_2d.json"
+ ),
+ attributes_to_load=domain_attr,
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_SHIFT_INS_EVAL
+ ),
+ metrics_to_eval=[
+ SHIFTDetectEvaluator.METRICS_DET,
+ SHIFTDetectEvaluator.METRICS_INS_SEG,
+ ],
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_36e_shift.py b/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_36e_shift.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dd2f746ee810d8bb40a205d99b46a788b20bd7f
--- /dev/null
+++ b/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_36e_shift.py
@@ -0,0 +1,174 @@
+# pylint: disable=duplicate-code
+"""Mask RCNN SHIFT training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import FieldConfigDict, class_config
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.eval.shift import SHIFTDetectEvaluator
+from vis4d.op.base import ResNet
+from vis4d.vis.image import SegMaskVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors import (
+ CONN_BBOX_2D_TEST,
+ CONN_BBOX_2D_TRAIN,
+ CONN_INS_MASK_2D_VIS,
+)
+from vis4d.zoo.base.datasets.shift import (
+ CONN_SHIFT_INS_EVAL,
+ get_shift_instance_seg_config,
+)
+from vis4d.zoo.base.models.mask_rcnn import get_mask_rcnn_cfg
+
+
+def get_config() -> FieldConfigDict:
+ """Returns the Faster-RCNN config dict for the SHIFT detection task.
+
+ Returns:
+ FieldConfigDict: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="mask_rcnn_r50_36e_shift")
+
+ # High level hyper parameters
+ params = FieldConfigDict()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.02
+ params.num_epochs = 36
+ params.num_classes = 6
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/shift/"
+ views_to_load = ["front"]
+ train_split = "train"
+ test_split = "val"
+ domain_attr = [{"weather_coarse": "clear", "timeofday_coarse": "daytime"}]
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_shift_instance_seg_config(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ train_views_to_load=views_to_load,
+ test_views_to_load=views_to_load,
+ train_attributes_to_load=domain_attr,
+ test_attributes_to_load=domain_attr,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=4
+ )
+
+ config.model, config.loss = get_mask_rcnn_cfg(
+ num_classes=params.num_classes,
+ basemodel=basemodel,
+ no_overlap=True,
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[24, 33], gamma=0.1),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(SegMaskVisualizer, vis_freq=25),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_INS_MASK_2D_VIS
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ SHIFTDetectEvaluator,
+ annotation_path=(
+ f"{data_root}/discrete/images/val/front/det_insseg_2d.json"
+ ),
+ attributes_to_load=domain_attr,
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_SHIFT_INS_EVAL
+ ),
+ metrics_to_eval=[
+ SHIFTDetectEvaluator.METRICS_DET,
+ SHIFTDetectEvaluator.METRICS_INS_SEG,
+ ],
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_6e_shift_all_domains.py b/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_6e_shift_all_domains.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f92ff6df1367e4ea643ac65f76da5572f69d9af
--- /dev/null
+++ b/vis4d/zoo/shift/mask_rcnn/mask_rcnn_r50_6e_shift_all_domains.py
@@ -0,0 +1,174 @@
+# pylint: disable=duplicate-code
+"""Mask RCNN SHIFT training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR, MultiStepLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import FieldConfigDict, class_config
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.eval.shift import SHIFTDetectEvaluator
+from vis4d.op.base import ResNet
+from vis4d.vis.image import SegMaskVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors import (
+ CONN_BBOX_2D_TEST,
+ CONN_BBOX_2D_TRAIN,
+ CONN_INS_MASK_2D_VIS,
+)
+from vis4d.zoo.base.datasets.shift import (
+ CONN_SHIFT_INS_EVAL,
+ get_shift_instance_seg_config,
+)
+from vis4d.zoo.base.models.mask_rcnn import get_mask_rcnn_cfg
+
+
+def get_config() -> FieldConfigDict:
+ """Returns the Faster-RCNN config dict for the SHIFT detection task.
+
+ Returns:
+ FieldConfigDict: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="mask_rcnn_r50_6e_shift_all_domains")
+
+ # High level hyper parameters
+ params = FieldConfigDict()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.02
+ params.num_epochs = 6
+ params.num_classes = 6
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/shift/"
+ views_to_load = ["front"]
+ train_split = "train"
+ test_split = "val"
+ domain_attr = None
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_shift_instance_seg_config(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ train_views_to_load=views_to_load,
+ test_views_to_load=views_to_load,
+ train_attributes_to_load=domain_attr,
+ test_attributes_to_load=domain_attr,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ basemodel = class_config(
+ ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=4
+ )
+
+ config.model, config.loss = get_mask_rcnn_cfg(
+ num_classes=params.num_classes,
+ basemodel=basemodel,
+ no_overlap=True,
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(MultiStepLR, milestones=[4, 5], gamma=0.1),
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg()
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(SegMaskVisualizer, vis_freq=25),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_INS_MASK_2D_VIS
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ SHIFTDetectEvaluator,
+ annotation_path=(
+ f"{data_root}/discrete/images/val/front/det_insseg_2d.json"
+ ),
+ attributes_to_load=domain_attr,
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_SHIFT_INS_EVAL
+ ),
+ metrics_to_eval=[
+ SHIFTDetectEvaluator.METRICS_DET,
+ SHIFTDetectEvaluator.METRICS_INS_SEG,
+ ],
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/shift/semantic_fpn/__init__.py b/vis4d/zoo/shift/semantic_fpn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..56b5a6da0c8a3eb7db1a09297403b73a5f30af4d
--- /dev/null
+++ b/vis4d/zoo/shift/semantic_fpn/__init__.py
@@ -0,0 +1 @@
+"""Semantic FPN for SHIFT."""
diff --git a/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_160k_shift.py b/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_160k_shift.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb6d91e2da013146ae8d8a7714795b00117c9ac7
--- /dev/null
+++ b/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_160k_shift.py
@@ -0,0 +1,191 @@
+# pylint: disable=duplicate-code
+"""Semantic FPN SHIFT training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import (
+ CallbackConnector,
+ DataConnector,
+ LossConnector,
+)
+from vis4d.engine.loss_module import LossModule
+from vis4d.engine.optim import PolyLR
+from vis4d.eval.shift import SHIFTSegEvaluator
+from vis4d.model.seg.semantic_fpn import SemanticFPN
+from vis4d.op.loss import SegCrossEntropyLoss
+from vis4d.vis.image import SegMaskVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors.seg import (
+ CONN_MASKS_TEST,
+ CONN_MASKS_TRAIN,
+ CONN_SEG_EVAL,
+ CONN_SEG_LOSS,
+ CONN_SEG_VIS,
+)
+from vis4d.zoo.base.datasets.shift import get_shift_sem_seg_config
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for the BDD100K semantic segmentation task.
+
+ Returns:
+ ExperimentParameters: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="semantic_fpn_r50_160k_shift")
+ config.sync_batchnorm = True
+ config.val_check_interval = 2000
+ config.check_val_every_n_epoch = None
+
+ ## High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.01
+ params.num_steps = 160000
+ params.num_classes = 23
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/shift/"
+ views_to_load = ["front"]
+ train_split = "train"
+ test_split = "val"
+ domain_attr = [{"weather_coarse": "clear", "timeofday_coarse": "daytime"}]
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_shift_sem_seg_config(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ train_views_to_load=views_to_load,
+ test_views_to_load=views_to_load,
+ train_attributes_to_load=domain_attr,
+ test_attributes_to_load=domain_attr,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ config.model = class_config(SemanticFPN, num_classes=params.num_classes)
+ config.loss = class_config(
+ LossModule,
+ losses=[
+ {
+ "loss": class_config(SegCrossEntropyLoss),
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_SEG_LOSS
+ ),
+ },
+ ],
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0005
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(
+ PolyLR,
+ max_steps=params.num_steps,
+ min_lr=0.0001,
+ power=0.9,
+ ),
+ epoch_based=False,
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_MASKS_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_MASKS_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ callbacks = get_default_callbacks_cfg(epoch_based=False)
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ SHIFTSegEvaluator,
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_SEG_EVAL
+ ),
+ )
+ )
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(SegMaskVisualizer, vis_freq=20),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_SEG_VIS
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.epoch_based = False
+ pl_trainer.max_steps = params.num_steps
+
+ pl_trainer.checkpoint_period = config.val_check_interval
+ pl_trainer.val_check_interval = config.val_check_interval
+ pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch
+
+ pl_trainer.sync_batchnorm = config.sync_batchnorm
+ # pl_trainer.precision = 16
+
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_160k_shift_all_domains.py b/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_160k_shift_all_domains.py
new file mode 100644
index 0000000000000000000000000000000000000000..25fcf86184a1633905aa1f97d6529f367616b13d
--- /dev/null
+++ b/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_160k_shift_all_domains.py
@@ -0,0 +1,193 @@
+# pylint: disable=duplicate-code
+"""Semantic FPN SHIFT training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import (
+ CallbackConnector,
+ DataConnector,
+ LossConnector,
+)
+from vis4d.engine.loss_module import LossModule
+from vis4d.engine.optim import PolyLR
+from vis4d.eval.shift import SHIFTSegEvaluator
+from vis4d.model.seg.semantic_fpn import SemanticFPN
+from vis4d.op.loss import SegCrossEntropyLoss
+from vis4d.vis.image import SegMaskVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors.seg import (
+ CONN_MASKS_TEST,
+ CONN_MASKS_TRAIN,
+ CONN_SEG_EVAL,
+ CONN_SEG_LOSS,
+ CONN_SEG_VIS,
+)
+from vis4d.zoo.base.datasets.shift import get_shift_sem_seg_config
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for the BDD100K semantic segmentation task.
+
+ Returns:
+ ExperimentParameters: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(
+ exp_name="semantic_fpn_r50_160k_shift_all_domains"
+ )
+ config.sync_batchnorm = True
+ config.val_check_interval = 2000
+ config.check_val_every_n_epoch = None
+
+ ## High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.01
+ params.num_steps = 160000
+ params.num_classes = 23
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/shift/"
+ views_to_load = ["front"]
+ train_split = "train"
+ test_split = "val"
+ domain_attr = None
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_shift_sem_seg_config(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ train_views_to_load=views_to_load,
+ test_views_to_load=views_to_load,
+ train_attributes_to_load=domain_attr,
+ test_attributes_to_load=domain_attr,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ config.model = class_config(SemanticFPN, num_classes=params.num_classes)
+ config.loss = class_config(
+ LossModule,
+ losses=[
+ {
+ "loss": class_config(SegCrossEntropyLoss),
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_SEG_LOSS
+ ),
+ },
+ ],
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0005
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(
+ PolyLR,
+ max_steps=params.num_steps,
+ min_lr=0.0001,
+ power=0.9,
+ ),
+ epoch_based=False,
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_MASKS_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_MASKS_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ callbacks = get_default_callbacks_cfg(epoch_based=False)
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ SHIFTSegEvaluator,
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_SEG_EVAL
+ ),
+ )
+ )
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(SegMaskVisualizer, vis_freq=20),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_SEG_VIS
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.epoch_based = False
+ pl_trainer.max_steps = params.num_steps
+
+ pl_trainer.checkpoint_period = config.val_check_interval
+ pl_trainer.val_check_interval = config.val_check_interval
+ pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch
+
+ pl_trainer.sync_batchnorm = config.sync_batchnorm
+ # pl_trainer.precision = 16
+
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_40k_shift.py b/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_40k_shift.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd8dc16032662fe410a5ae7431aaf61e4ef5eb74
--- /dev/null
+++ b/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_40k_shift.py
@@ -0,0 +1,191 @@
+# pylint: disable=duplicate-code
+"""Semantic FPN SHIFT training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import (
+ CallbackConnector,
+ DataConnector,
+ LossConnector,
+)
+from vis4d.engine.loss_module import LossModule
+from vis4d.engine.optim import PolyLR
+from vis4d.eval.shift import SHIFTSegEvaluator
+from vis4d.model.seg.semantic_fpn import SemanticFPN
+from vis4d.op.loss import SegCrossEntropyLoss
+from vis4d.vis.image import SegMaskVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors.seg import (
+ CONN_MASKS_TEST,
+ CONN_MASKS_TRAIN,
+ CONN_SEG_EVAL,
+ CONN_SEG_LOSS,
+ CONN_SEG_VIS,
+)
+from vis4d.zoo.base.datasets.shift import get_shift_sem_seg_config
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for the BDD100K semantic segmentation task.
+
+ Returns:
+ ExperimentParameters: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="semantic_fpn_r50_40k_shift")
+ config.sync_batchnorm = True
+ config.val_check_interval = 2000
+ config.check_val_every_n_epoch = None
+
+ ## High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.01
+ params.num_steps = 160000
+ params.num_classes = 23
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/shift/"
+ views_to_load = ["front"]
+ train_split = "train"
+ test_split = "val"
+ domain_attr = [{"weather_coarse": "clear", "timeofday_coarse": "daytime"}]
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_shift_sem_seg_config(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ train_views_to_load=views_to_load,
+ test_views_to_load=views_to_load,
+ train_attributes_to_load=domain_attr,
+ test_attributes_to_load=domain_attr,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ config.model = class_config(SemanticFPN, num_classes=params.num_classes)
+ config.loss = class_config(
+ LossModule,
+ losses=[
+ {
+ "loss": class_config(SegCrossEntropyLoss),
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_SEG_LOSS
+ ),
+ },
+ ],
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0005
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(
+ PolyLR,
+ max_steps=params.num_steps,
+ min_lr=0.0001,
+ power=0.9,
+ ),
+ epoch_based=False,
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_MASKS_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_MASKS_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ callbacks = get_default_callbacks_cfg(epoch_based=False)
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ SHIFTSegEvaluator,
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_SEG_EVAL
+ ),
+ )
+ )
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(SegMaskVisualizer, vis_freq=20),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_SEG_VIS
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.epoch_based = False
+ pl_trainer.max_steps = params.num_steps
+
+ pl_trainer.checkpoint_period = config.val_check_interval
+ pl_trainer.val_check_interval = config.val_check_interval
+ pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch
+
+ pl_trainer.sync_batchnorm = config.sync_batchnorm
+ # pl_trainer.precision = 16
+
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_40k_shift_all_domains.py b/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_40k_shift_all_domains.py
new file mode 100644
index 0000000000000000000000000000000000000000..50c607cede9fd2505201bc1967afd69602dae7b4
--- /dev/null
+++ b/vis4d/zoo/shift/semantic_fpn/semantic_fpn_r50_40k_shift_all_domains.py
@@ -0,0 +1,191 @@
+# pylint: disable=duplicate-code
+"""Semantic FPN SHIFT training example."""
+from __future__ import annotations
+
+from torch.optim.lr_scheduler import LinearLR
+from torch.optim.sgd import SGD
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import (
+ CallbackConnector,
+ DataConnector,
+ LossConnector,
+)
+from vis4d.engine.loss_module import LossModule
+from vis4d.engine.optim import PolyLR
+from vis4d.eval.shift import SHIFTSegEvaluator
+from vis4d.model.seg.semantic_fpn import SemanticFPN
+from vis4d.op.loss import SegCrossEntropyLoss
+from vis4d.vis.image import SegMaskVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors.seg import (
+ CONN_MASKS_TEST,
+ CONN_MASKS_TRAIN,
+ CONN_SEG_EVAL,
+ CONN_SEG_LOSS,
+ CONN_SEG_VIS,
+)
+from vis4d.zoo.base.datasets.shift import get_shift_sem_seg_config
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for the BDD100K semantic segmentation task.
+
+ Returns:
+ ExperimentParameters: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="semantic_fpn_r50_40k_shift_all_domains")
+ config.sync_batchnorm = True
+ config.val_check_interval = 2000
+ config.check_val_every_n_epoch = None
+
+ ## High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 2
+ params.workers_per_gpu = 2
+ params.lr = 0.01
+ params.num_steps = 160000
+ params.num_classes = 23
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/shift/"
+ views_to_load = ["front"]
+ train_split = "train"
+ test_split = "val"
+ domain_attr = [{"weather_coarse": "clear", "timeofday_coarse": "daytime"}]
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_shift_sem_seg_config(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ train_views_to_load=views_to_load,
+ test_views_to_load=views_to_load,
+ train_attributes_to_load=domain_attr,
+ test_attributes_to_load=domain_attr,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ config.model = class_config(SemanticFPN, num_classes=params.num_classes)
+ config.loss = class_config(
+ LossModule,
+ losses=[
+ {
+ "loss": class_config(SegCrossEntropyLoss),
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_SEG_LOSS
+ ),
+ },
+ ],
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ SGD, lr=params.lr, momentum=0.9, weight_decay=0.0005
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(
+ LinearLR, start_factor=0.001, total_iters=500
+ ),
+ end=500,
+ epoch_based=False,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(
+ PolyLR,
+ max_steps=params.num_steps,
+ min_lr=0.0001,
+ power=0.9,
+ ),
+ epoch_based=False,
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_MASKS_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_MASKS_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ callbacks = get_default_callbacks_cfg(epoch_based=False)
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ SHIFTSegEvaluator,
+ ),
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_SEG_EVAL
+ ),
+ )
+ )
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(SegMaskVisualizer, vis_freq=20),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_SEG_VIS
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.epoch_based = False
+ pl_trainer.max_steps = params.num_steps
+
+ pl_trainer.checkpoint_period = config.val_check_interval
+ pl_trainer.val_check_interval = config.val_check_interval
+ pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch
+
+ pl_trainer.sync_batchnorm = config.sync_batchnorm
+ # pl_trainer.precision = 16
+
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/util.py b/vis4d/zoo/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..b58c5f87017cf7a261ccf9a2446b0c123f22e11e
--- /dev/null
+++ b/vis4d/zoo/util.py
@@ -0,0 +1,14 @@
+"""Utility functions for the zoo module."""
+
+from __future__ import annotations
+
+import importlib
+
+from vis4d.config.typing import ExperimentConfig
+
+
+def get_config_for_name(config_name: str) -> ExperimentConfig:
+ """Get config for name."""
+ module = importlib.import_module("vis4d.zoo." + config_name)
+
+ return module.get_config()
diff --git a/vis4d/zoo/vit/__init__.py b/vis4d/zoo/vit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..59c0b0ef7a1b904cf878134f38dd52a0554b0a10
--- /dev/null
+++ b/vis4d/zoo/vit/__init__.py
@@ -0,0 +1,8 @@
+"""ViT for image classification configs."""
+
+from . import vit_small_imagenet, vit_tiny_imagenet
+
+AVAILABLE_MODELS = {
+ "vit_small_imagenet": vit_small_imagenet,
+ "vit_tiny_imagenet": vit_tiny_imagenet,
+}
diff --git a/vis4d/zoo/vit/vit_small_imagenet.py b/vis4d/zoo/vit/vit_small_imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2c10bda561118188d22a2729fb5b0e3e5a284ce
--- /dev/null
+++ b/vis4d/zoo/vit/vit_small_imagenet.py
@@ -0,0 +1,181 @@
+# pylint: disable=duplicate-code
+"""VIT ImageNet-1k training example."""
+from __future__ import annotations
+
+from torch import nn
+from torch.optim.adamw import AdamW
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.engine.callbacks import EMACallback, EvaluatorCallback
+from vis4d.engine.connectors import (
+ CallbackConnector,
+ DataConnector,
+ LossConnector,
+)
+from vis4d.engine.loss_module import LossModule
+from vis4d.eval.common.cls import ClassificationEvaluator
+from vis4d.model.adapter import ModelEMAAdapter
+from vis4d.model.cls.vit import ViTClassifer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors.cls import (
+ CONN_CLS_LOSS,
+ CONN_CLS_TEST,
+ CONN_CLS_TRAIN,
+)
+from vis4d.zoo.base.datasets.imagenet import (
+ CONN_IMAGENET_CLS_EVAL,
+ get_imagenet_cls_cfg,
+)
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for the ImageNet Classification task.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+
+ config = get_default_cfg(exp_name="vit_small_16_imagenet1k")
+ config.sync_batchnorm = True
+ config.check_val_every_n_epoch = 1
+ config.ema_decay_rate = 0.99996
+
+ ## High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 256
+ params.workers_per_gpu = 8
+ params.num_epochs = 300
+ params.lr = 1e-3
+ params.weight_decay = 0.01
+ params.num_classes = 1000
+ params.grad_norm_clip = 1.0
+ params.accumulate_grad_batches = 1
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/imagenet1k"
+ train_split = "train"
+ test_split = "val"
+ image_size = (224, 224)
+
+ config.data = get_imagenet_cls_cfg(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ image_size=image_size,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL ##
+ ######################################################
+ config.model = class_config(
+ ModelEMAAdapter,
+ model=class_config(
+ ViTClassifer,
+ variant="vit_small_patch16_224",
+ num_classes=params.num_classes,
+ drop_rate=0.1,
+ drop_path_rate=0.1,
+ ),
+ )
+
+ ######################################################
+ ## LOSS ##
+ ######################################################
+ config.loss = class_config(
+ LossModule,
+ losses={
+ "loss": class_config(nn.CrossEntropyLoss),
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_CLS_LOSS
+ ),
+ },
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ AdamW, lr=params.lr, weight_decay=params.weight_decay
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(LinearLR, estart_factor=1e-3, total_iters=10),
+ end=10,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(
+ CosineAnnealingLR,
+ T_max=params.num_epochs,
+ eta_min=1e-9,
+ ),
+ begin=10,
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector,
+ key_mapping=CONN_CLS_TRAIN,
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector,
+ key_mapping=CONN_CLS_TEST,
+ )
+
+ ######################################################
+ ## GENERIC CALLBACKS ##
+ ######################################################
+ callbacks = get_default_callbacks_cfg()
+
+ # EMA callback
+ callbacks.append(class_config(EMACallback))
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(ClassificationEvaluator),
+ metrics_to_eval=["Cls"],
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_IMAGENET_CLS_EVAL
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ pl_trainer.gradient_clip_val = params.grad_norm_clip
+ pl_trainer.accumulate_grad_batches = params.accumulate_grad_batches
+
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/vit/vit_tiny_imagenet.py b/vis4d/zoo/vit/vit_tiny_imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e28542ed5db406b22a114904cb53123600bea198
--- /dev/null
+++ b/vis4d/zoo/vit/vit_tiny_imagenet.py
@@ -0,0 +1,181 @@
+# pylint: disable=duplicate-code
+"""VIT ImageNet-1k training example."""
+from __future__ import annotations
+
+from torch import nn
+from torch.optim.adamw import AdamW
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.engine.callbacks import EMACallback, EvaluatorCallback
+from vis4d.engine.connectors import (
+ CallbackConnector,
+ DataConnector,
+ LossConnector,
+)
+from vis4d.engine.loss_module import LossModule
+from vis4d.eval.common.cls import ClassificationEvaluator
+from vis4d.model.adapter import ModelEMAAdapter
+from vis4d.model.cls.vit import ViTClassifer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+ get_lr_scheduler_cfg,
+ get_optimizer_cfg,
+)
+from vis4d.zoo.base.data_connectors.cls import (
+ CONN_CLS_LOSS,
+ CONN_CLS_TEST,
+ CONN_CLS_TRAIN,
+)
+from vis4d.zoo.base.datasets.imagenet import (
+ CONN_IMAGENET_CLS_EVAL,
+ get_imagenet_cls_cfg,
+)
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the config dict for the ImageNet Classification task.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+
+ config = get_default_cfg(exp_name="vit_tiny_16_imagenet1k")
+ config.sync_batchnorm = True
+ config.check_val_every_n_epoch = 1
+
+ ## High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 256
+ params.workers_per_gpu = 8
+ params.num_epochs = 300
+ params.lr = 1e-3
+ params.weight_decay = 0.01
+ params.num_classes = 1000
+ params.grad_norm_clip = 1.0
+ params.accumulate_grad_batches = 1
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/imagenet1k"
+ train_split = "train"
+ test_split = "val"
+ image_size = (224, 224)
+
+ config.data = get_imagenet_cls_cfg(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ image_size=image_size,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL ##
+ ######################################################
+ config.model = class_config(
+ ModelEMAAdapter,
+ model=class_config(
+ ViTClassifer,
+ variant="vit_tiny_patch16_224",
+ num_classes=params.num_classes,
+ drop_rate=0.1,
+ drop_path_rate=0.1,
+ ),
+ decay=0.99998,
+ )
+
+ ######################################################
+ ## LOSS ##
+ ######################################################
+ config.loss = class_config(
+ LossModule,
+ losses={
+ "loss": class_config(nn.CrossEntropyLoss),
+ "connector": class_config(
+ LossConnector, key_mapping=CONN_CLS_LOSS
+ ),
+ },
+ )
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+
+ config.optimizers = [
+ get_optimizer_cfg(
+ optimizer=class_config(
+ AdamW, lr=params.lr, weight_decay=params.weight_decay
+ ),
+ lr_schedulers=[
+ get_lr_scheduler_cfg(
+ class_config(LinearLR, estart_factor=1e-3, total_iters=10),
+ end=10,
+ ),
+ get_lr_scheduler_cfg(
+ class_config(
+ CosineAnnealingLR,
+ T_max=params.num_epochs,
+ eta_min=1e-9,
+ ),
+ begin=10,
+ ),
+ ],
+ )
+ ]
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector,
+ key_mapping=CONN_CLS_TRAIN,
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector,
+ key_mapping=CONN_CLS_TEST,
+ )
+
+ ######################################################
+ ## GENERIC CALLBACKS ##
+ ######################################################
+ callbacks = get_default_callbacks_cfg()
+
+ # EMA callback
+ callbacks.append(class_config(EMACallback))
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(ClassificationEvaluator),
+ metrics_to_eval=["Cls"],
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_IMAGENET_CLS_EVAL
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ pl_trainer.gradient_clip_val = params.grad_norm_clip
+ pl_trainer.accumulate_grad_batches = params.accumulate_grad_batches
+
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/yolox/__init__.py b/vis4d/zoo/yolox/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..17ba54d4d1130c41aa4d141eb035c6923dc7e0bd
--- /dev/null
+++ b/vis4d/zoo/yolox/__init__.py
@@ -0,0 +1,8 @@
+"""YOLOX Model Zoo."""
+
+from . import yolox_s_300e_coco, yolox_tiny_300e_coco
+
+AVAILABLE_MODELS = {
+ "yolox_s_300e_coco": yolox_s_300e_coco,
+ "yolox_tiny_300e_coco": yolox_tiny_300e_coco,
+}
diff --git a/vis4d/zoo/yolox/data.py b/vis4d/zoo/yolox/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..eefda5e35a025aede27ee6a0b10ff8a93b9d56d0
--- /dev/null
+++ b/vis4d/zoo/yolox/data.py
@@ -0,0 +1,261 @@
+# pylint: disable=duplicate-code
+"""COCO data loading config for YOLOX object detection."""
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+from ml_collections import ConfigDict
+
+from vis4d.config import class_config
+from vis4d.config.typing import DataConfig
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.data_pipe import DataPipe, MultiSampleDataPipe
+from vis4d.data.datasets.coco import COCO
+from vis4d.data.io import DataBackend
+from vis4d.data.loader import build_train_dataloader, default_collate
+from vis4d.data.transforms.affine import (
+ AffineBoxes2D,
+ AffineImages,
+ GenAffineParameters,
+)
+from vis4d.data.transforms.base import RandomApply, compose
+from vis4d.data.transforms.flip import FlipBoxes2D, FlipImages
+from vis4d.data.transforms.mixup import (
+ GenMixupParameters,
+ MixupBoxes2D,
+ MixupImages,
+)
+from vis4d.data.transforms.mosaic import (
+ GenMosaicParameters,
+ MosaicBoxes2D,
+ MosaicImages,
+)
+from vis4d.data.transforms.pad import PadImages
+from vis4d.data.transforms.photometric import RandomHSV
+from vis4d.data.transforms.post_process import PostProcessBoxes2D
+from vis4d.data.transforms.resize import (
+ GenResizeParameters,
+ ResizeBoxes2D,
+ ResizeImages,
+)
+from vis4d.data.transforms.to_tensor import ToTensor
+from vis4d.engine.connectors import data_key, pred_key
+from vis4d.zoo.base import get_inference_dataloaders_cfg
+from vis4d.zoo.base.callable import get_callable_cfg
+
+CONN_COCO_BBOX_EVAL = {
+ "coco_image_id": data_key(K.sample_names),
+ "pred_boxes": pred_key("boxes"),
+ "pred_scores": pred_key("scores"),
+ "pred_classes": pred_key("class_ids"),
+}
+
+CONN_COCO_MASK_EVAL = {
+ "coco_image_id": data_key(K.sample_names),
+ "pred_boxes": pred_key("boxes.boxes"),
+ "pred_scores": pred_key("boxes.scores"),
+ "pred_classes": pred_key("boxes.class_ids"),
+ "pred_masks": pred_key("masks"),
+}
+
+
+def get_train_dataloader(
+ data_root: str,
+ split: str,
+ keys_to_load: Sequence[str],
+ data_backend: None | DataBackend,
+ image_size: tuple[int, int],
+ scaling_ratio_range: tuple[float, float],
+ use_mixup: bool,
+ samples_per_gpu: int,
+ workers_per_gpu: int,
+) -> ConfigDict:
+ """Get the default train dataloader for COCO detection."""
+ # Train Dataset
+ train_dataset_cfg = class_config(
+ COCO,
+ keys_to_load=keys_to_load,
+ data_root=data_root,
+ split=split,
+ remove_empty=False,
+ image_channel_mode="BGR",
+ data_backend=data_backend,
+ )
+
+ # Train Preprocessing
+ preprocess_transforms = [
+ [
+ class_config(GenMosaicParameters, out_shape=image_size),
+ class_config(MosaicImages, imresize_backend="cv2"),
+ class_config(MosaicBoxes2D),
+ ]
+ ]
+
+ preprocess_transforms += [
+ [
+ class_config(
+ GenAffineParameters,
+ scaling_ratio_range=scaling_ratio_range,
+ border=(-image_size[0] // 2, -image_size[1] // 2),
+ ),
+ class_config(AffineImages, as_int=True),
+ class_config(AffineBoxes2D),
+ ]
+ ]
+
+ if use_mixup:
+ preprocess_transforms += [
+ [
+ class_config(
+ GenMixupParameters,
+ out_shape=image_size,
+ mixup_ratio_dist="const",
+ scale_range=(0.8, 1.6),
+ pad_value=114.0,
+ ),
+ class_config(MixupImages, imresize_backend="cv2"),
+ class_config(MixupBoxes2D),
+ ]
+ ]
+
+ preprocess_transforms.append(
+ [class_config(PostProcessBoxes2D, min_area=1.0)]
+ )
+
+ train_batchprocess_cfg = class_config(
+ compose,
+ transforms=[
+ class_config(RandomHSV, same_on_batch=False),
+ class_config(
+ RandomApply,
+ transforms=[
+ class_config(FlipImages),
+ class_config(FlipBoxes2D),
+ ],
+ probability=0.5,
+ same_on_batch=False,
+ ),
+ class_config(
+ GenResizeParameters,
+ shape=image_size,
+ keep_ratio=True,
+ same_on_batch=False,
+ ),
+ class_config(ResizeImages, imresize_backend="cv2"),
+ class_config(ResizeBoxes2D),
+ class_config(PadImages, value=114.0, pad2square=True),
+ class_config(ToTensor),
+ ],
+ )
+
+ return class_config(
+ build_train_dataloader,
+ dataset=class_config(
+ MultiSampleDataPipe,
+ datasets=train_dataset_cfg,
+ preprocess_fn=preprocess_transforms,
+ ),
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ batchprocess_fn=train_batchprocess_cfg,
+ collate_fn=get_callable_cfg(default_collate),
+ pin_memory=True,
+ shuffle=True,
+ )
+
+
+def get_test_dataloader(
+ data_root: str,
+ split: str,
+ keys_to_load: Sequence[str],
+ data_backend: None | DataBackend,
+ image_size: tuple[int, int],
+ samples_per_gpu: int,
+ workers_per_gpu: int,
+) -> ConfigDict:
+ """Get the default test dataloader for COCO detection."""
+ # Test Dataset
+ test_dataset = class_config(
+ COCO,
+ keys_to_load=keys_to_load,
+ data_root=data_root,
+ split=split,
+ image_channel_mode="BGR",
+ data_backend=data_backend,
+ )
+
+ # Test Preprocessing
+ preprocess_transforms = [
+ class_config(GenResizeParameters, shape=image_size, keep_ratio=True),
+ class_config(ResizeImages, imresize_backend="cv2"),
+ ]
+
+ test_preprocess_cfg = class_config(
+ compose, transforms=preprocess_transforms
+ )
+
+ test_batchprocess_cfg = class_config(
+ compose,
+ transforms=[
+ class_config(PadImages, value=114.0, pad2square=True),
+ class_config(ToTensor),
+ ],
+ )
+
+ # Test Dataset Config
+ test_dataset_cfg = class_config(
+ DataPipe, datasets=test_dataset, preprocess_fn=test_preprocess_cfg
+ )
+
+ return get_inference_dataloaders_cfg(
+ datasets_cfg=test_dataset_cfg,
+ batchprocess_cfg=test_batchprocess_cfg,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+
+def get_coco_yolox_cfg(
+ data_root: str = "data/coco",
+ train_split: str = "train2017",
+ train_keys_to_load: Sequence[str] = (
+ K.images,
+ K.boxes2d,
+ K.boxes2d_classes,
+ ),
+ test_split: str = "val2017",
+ test_keys_to_load: Sequence[str] = (K.images, K.original_images),
+ data_backend: None | ConfigDict = None,
+ train_image_size: tuple[int, int] = (640, 640),
+ scaling_ratio_range: tuple[float, float] = (0.1, 2.0),
+ use_mixup: bool = True,
+ test_image_size: tuple[int, int] = (640, 640),
+ samples_per_gpu: int = 2,
+ workers_per_gpu: int = 2,
+) -> DataConfig:
+ """Get the default config for COCO detection."""
+ data = DataConfig()
+
+ data.train_dataloader = get_train_dataloader(
+ data_root=data_root,
+ split=train_split,
+ keys_to_load=train_keys_to_load,
+ data_backend=data_backend,
+ image_size=train_image_size,
+ scaling_ratio_range=scaling_ratio_range,
+ use_mixup=use_mixup,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+ data.test_dataloader = get_test_dataloader(
+ data_root=data_root,
+ split=test_split,
+ keys_to_load=test_keys_to_load,
+ data_backend=data_backend,
+ image_size=test_image_size,
+ samples_per_gpu=1,
+ workers_per_gpu=workers_per_gpu,
+ )
+
+ return data
diff --git a/vis4d/zoo/yolox/yolox_s_300e_coco.py b/vis4d/zoo/yolox/yolox_s_300e_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b65110335ba1f394f2a01f37a7fe1af380bc4e4
--- /dev/null
+++ b/vis4d/zoo/yolox/yolox_s_300e_coco.py
@@ -0,0 +1,159 @@
+# pylint: disable=duplicate-code
+"""YOLOX COCO."""
+from __future__ import annotations
+
+from lightning.pytorch.callbacks import ModelCheckpoint
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.eval.coco import COCODetectEvaluator
+from vis4d.vis.image import BoundingBoxVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+)
+from vis4d.zoo.base.data_connectors import CONN_BBOX_2D_TEST, CONN_BBOX_2D_VIS
+from vis4d.zoo.base.models.yolox import (
+ get_yolox_callbacks_cfg,
+ get_yolox_cfg,
+ get_yolox_optimizers_cfg,
+)
+from vis4d.zoo.yolox.data import CONN_COCO_BBOX_EVAL, get_coco_yolox_cfg
+
+CONN_BBOX_2D_TRAIN = {"images": K.images}
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the YOLOX config dict for the coco detection task.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="yolox_s_300e_coco")
+ config.checkpoint_period = 15
+ config.check_val_every_n_epoch = 10
+
+ # High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 8
+ params.workers_per_gpu = 4
+ params.lr = 0.01
+ params.num_epochs = 300
+ params.num_classes = 80
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/coco"
+ train_split = "train2017"
+ test_split = "val2017"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_coco_yolox_cfg(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ data_backend=data_backend,
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ config.model, config.loss = get_yolox_cfg(params.num_classes, "small")
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ num_last_epochs, warmup_epochs = 15, 5
+ config.optimizers = get_yolox_optimizers_cfg(
+ params.lr, params.num_epochs, warmup_epochs, num_last_epochs
+ )
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg(
+ refresh_rate=config.log_every_n_steps
+ )
+
+ # YOLOX callbacks
+ callbacks += get_yolox_callbacks_cfg(
+ switch_epoch=params.num_epochs - num_last_epochs
+ )
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(
+ BoundingBoxVisualizer, vis_freq=100, image_mode="BGR"
+ ),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BBOX_2D_VIS
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ COCODetectEvaluator, data_root=data_root, split=test_split
+ ),
+ metrics_to_eval=["Det"],
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_COCO_BBOX_EVAL
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch
+ pl_trainer.checkpoint_callback = class_config(
+ ModelCheckpoint,
+ dirpath=config.get_ref("output_dir") + "/checkpoints",
+ verbose=True,
+ save_last=True,
+ save_on_train_epoch_end=True,
+ every_n_epochs=config.checkpoint_period,
+ save_top_k=3,
+ mode="max",
+ monitor="step",
+ )
+ pl_trainer.wandb = True
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/vis4d/zoo/yolox/yolox_tiny_300e_coco.py b/vis4d/zoo/yolox/yolox_tiny_300e_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6bad53e5718edc606e5c6e7f236c91ad10f942b
--- /dev/null
+++ b/vis4d/zoo/yolox/yolox_tiny_300e_coco.py
@@ -0,0 +1,162 @@
+# pylint: disable=duplicate-code
+"""YOLOX COCO."""
+from __future__ import annotations
+
+from lightning.pytorch.callbacks import ModelCheckpoint
+
+from vis4d.config import class_config
+from vis4d.config.typing import ExperimentConfig, ExperimentParameters
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.io.hdf5 import HDF5Backend
+from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
+from vis4d.engine.connectors import CallbackConnector, DataConnector
+from vis4d.eval.coco import COCODetectEvaluator
+from vis4d.vis.image import BoundingBoxVisualizer
+from vis4d.zoo.base import (
+ get_default_callbacks_cfg,
+ get_default_cfg,
+ get_default_pl_trainer_cfg,
+)
+from vis4d.zoo.base.data_connectors import CONN_BBOX_2D_TEST, CONN_BBOX_2D_VIS
+from vis4d.zoo.base.models.yolox import (
+ get_yolox_callbacks_cfg,
+ get_yolox_cfg,
+ get_yolox_optimizers_cfg,
+)
+from vis4d.zoo.yolox.data import CONN_COCO_BBOX_EVAL, get_coco_yolox_cfg
+
+CONN_BBOX_2D_TRAIN = {"images": K.images}
+
+
+def get_config() -> ExperimentConfig:
+ """Returns the YOLOX config dict for the coco detection task.
+
+ Returns:
+ ExperimentConfig: The configuration
+ """
+ ######################################################
+ ## General Config ##
+ ######################################################
+ config = get_default_cfg(exp_name="yolox_tiny_300e_coco")
+ config.checkpoint_period = 15
+ config.check_val_every_n_epoch = 10
+
+ # High level hyper parameters
+ params = ExperimentParameters()
+ params.samples_per_gpu = 8
+ params.workers_per_gpu = 4
+ params.lr = 0.01
+ params.num_epochs = 300
+ params.num_classes = 80
+ config.params = params
+
+ ######################################################
+ ## Datasets with augmentations ##
+ ######################################################
+ data_root = "data/coco"
+ train_split = "train2017"
+ test_split = "val2017"
+
+ data_backend = class_config(HDF5Backend)
+
+ config.data = get_coco_yolox_cfg(
+ data_root=data_root,
+ train_split=train_split,
+ test_split=test_split,
+ data_backend=data_backend,
+ scaling_ratio_range=(0.5, 1.5),
+ use_mixup=False,
+ test_image_size=(416, 416),
+ samples_per_gpu=params.samples_per_gpu,
+ workers_per_gpu=params.workers_per_gpu,
+ )
+
+ ######################################################
+ ## MODEL & LOSS ##
+ ######################################################
+ config.model, config.loss = get_yolox_cfg(params.num_classes, "tiny")
+
+ ######################################################
+ ## OPTIMIZERS ##
+ ######################################################
+ num_last_epochs, warmup_epochs = 15, 5
+ config.optimizers = get_yolox_optimizers_cfg(
+ params.lr, params.num_epochs, warmup_epochs, num_last_epochs
+ )
+
+ ######################################################
+ ## DATA CONNECTOR ##
+ ######################################################
+ config.train_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
+ )
+
+ config.test_data_connector = class_config(
+ DataConnector, key_mapping=CONN_BBOX_2D_TEST
+ )
+
+ ######################################################
+ ## CALLBACKS ##
+ ######################################################
+ # Logger
+ callbacks = get_default_callbacks_cfg(
+ refresh_rate=config.log_every_n_steps
+ )
+
+ # YOLOX callbacks
+ callbacks += get_yolox_callbacks_cfg(
+ switch_epoch=params.num_epochs - num_last_epochs, shape=(320, 320)
+ )
+
+ # Visualizer
+ callbacks.append(
+ class_config(
+ VisualizerCallback,
+ visualizer=class_config(
+ BoundingBoxVisualizer, vis_freq=100, image_mode="BGR"
+ ),
+ output_dir=config.output_dir,
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_BBOX_2D_VIS
+ ),
+ )
+ )
+
+ # Evaluator
+ callbacks.append(
+ class_config(
+ EvaluatorCallback,
+ evaluator=class_config(
+ COCODetectEvaluator, data_root=data_root, split=test_split
+ ),
+ metrics_to_eval=["Det"],
+ test_connector=class_config(
+ CallbackConnector, key_mapping=CONN_COCO_BBOX_EVAL
+ ),
+ )
+ )
+
+ config.callbacks = callbacks
+
+ ######################################################
+ ## PL CLI ##
+ ######################################################
+ # PL Trainer args
+ pl_trainer = get_default_pl_trainer_cfg(config)
+ pl_trainer.max_epochs = params.num_epochs
+ pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch
+ pl_trainer.checkpoint_callback = class_config(
+ ModelCheckpoint,
+ dirpath=config.get_ref("output_dir") + "/checkpoints",
+ verbose=True,
+ save_last=True,
+ save_on_train_epoch_end=True,
+ every_n_epochs=config.checkpoint_period,
+ save_top_k=3,
+ mode="max",
+ monitor="step",
+ )
+ pl_trainer.wandb = True
+ config.pl_trainer = pl_trainer
+
+ return config.value_mode()
diff --git a/wilddet3d/__init__.py b/wilddet3d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ce3fca0e97e250dd3b5af327624c378434d3949
--- /dev/null
+++ b/wilddet3d/__init__.py
@@ -0,0 +1,29 @@
+"""WildDet3D: Open-Vocabulary Monocular 3D Object Detection in the Wild."""
+
+import sys
+from pathlib import Path
+
+# Add third_party submodules to Python path
+_third_party = Path(__file__).parent.parent / "third_party"
+_sam3_path = str(_third_party / "sam3")
+_lingbot_path = str(_third_party / "lingbot_depth")
+
+if _sam3_path not in sys.path:
+ sys.path.insert(0, _sam3_path)
+if _lingbot_path not in sys.path:
+ sys.path.insert(0, _lingbot_path)
+
+from .data_types import Det3DOut, WildDet3DInput, WildDet3DOut
+from .inference import WildDet3DPredictor, build_model
+from .model import WildDet3D
+from .preprocessing import preprocess
+
+__all__ = [
+ "WildDet3D",
+ "WildDet3DPredictor",
+ "WildDet3DInput",
+ "WildDet3DOut",
+ "Det3DOut",
+ "build_model",
+ "preprocess",
+]
diff --git a/wilddet3d/__pycache__/__init__.cpython-311.pyc b/wilddet3d/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e2b553fec1d09854d0bbbaeb9ed86d53e833d2d6
Binary files /dev/null and b/wilddet3d/__pycache__/__init__.cpython-311.pyc differ
diff --git a/wilddet3d/__pycache__/data_types.cpython-311.pyc b/wilddet3d/__pycache__/data_types.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5951004c70b0d5c933fe5795c3aec5ccd652d885
Binary files /dev/null and b/wilddet3d/__pycache__/data_types.cpython-311.pyc differ
diff --git a/wilddet3d/__pycache__/inference.cpython-311.pyc b/wilddet3d/__pycache__/inference.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..404984d26919949c7d9eb15acd7cbc5ba22237bb
Binary files /dev/null and b/wilddet3d/__pycache__/inference.cpython-311.pyc differ
diff --git a/wilddet3d/__pycache__/model.cpython-311.pyc b/wilddet3d/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c25246280fffb9cd1da2963ee69ad33eee0b1fd1
Binary files /dev/null and b/wilddet3d/__pycache__/model.cpython-311.pyc differ
diff --git a/wilddet3d/__pycache__/preprocessing.cpython-311.pyc b/wilddet3d/__pycache__/preprocessing.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e770793d4dbe2c093b7a9d4212527c980aec66bb
Binary files /dev/null and b/wilddet3d/__pycache__/preprocessing.cpython-311.pyc differ
diff --git a/wilddet3d/connector.py b/wilddet3d/connector.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea034b79986782562e0822d7aaba9b6a5d8ade7d
--- /dev/null
+++ b/wilddet3d/connector.py
@@ -0,0 +1,1852 @@
+"""WildDet3D data connector and collator configuration.
+
+This module provides:
+1. DataConnector key mappings for train/test
+2. WildDet3DCollator: converts per-image DataLoader output to WildDet3DInput
+3. Point prompt sampling (from mask or box region)
+"""
+
+from __future__ import annotations
+
+import random
+import time
+from collections import defaultdict
+from typing import List, Literal, Optional
+
+import numpy as np
+import torch
+from torch import Tensor
+
+from wilddet3d.ops.profiler import profile_start, profile_stop
+
+from ml_collections import ConfigDict
+from vis4d.config import class_config
+from vis4d.data.const import CommonKeys as K
+from vis4d.engine.connectors import DataConnector, data_key, pred_key
+
+from wilddet3d.model import WildDet3DInput
+
+
+# ============================================================================
+# Point Sampling Utilities
+# ============================================================================
+
+def sample_points_from_mask(
+ mask: np.ndarray,
+ n_points: int,
+ mode: Literal["centered", "random_mask", "random_box"],
+ box: Optional[np.ndarray] = None,
+) -> np.ndarray:
+ """Sample points from a binary mask.
+
+ Args:
+ mask: Binary mask (H, W), 1=foreground, 0=background
+ n_points: Number of points to sample
+ mode: Sampling mode
+ - "centered": sample from mask center (farthest from edges)
+ - "random_mask": uniform sample from mask interior
+ - "random_box": uniform sample from box, label from mask
+ box: Box in xyxy format (required for random_box mode)
+
+ Returns:
+ Points array (n_points, 3) with (x, y, label)
+ """
+ if mode == "centered":
+ return _center_positive_sample(mask, n_points)
+ elif mode == "random_mask":
+ return _uniform_positive_sample(mask, n_points)
+ elif mode == "random_box":
+ assert box is not None, "'random_box' mode requires a provided box."
+ return _uniform_sample_from_box(mask, box, n_points)
+ else:
+ raise ValueError(f"Unknown point sampling mode {mode}.")
+
+
+def _uniform_positive_sample(mask: np.ndarray, n_points: int) -> np.ndarray:
+ """Sample positive points uniformly from mask interior."""
+ mask_points = np.stack(np.nonzero(mask), axis=0).transpose(1, 0)
+ if len(mask_points) == 0:
+ # Empty mask, return center of image as fallback
+ h, w = mask.shape
+ return np.array([[w // 2, h // 2, 1]] * n_points)
+
+ selected_idxs = np.random.randint(low=0, high=len(mask_points), size=n_points)
+ selected_points = mask_points[selected_idxs]
+ selected_points = selected_points[:, ::-1] # (y, x) -> (x, y)
+ labels = np.ones((len(selected_points), 1))
+ return np.concatenate([selected_points, labels], axis=1)
+
+
+def _center_positive_sample(mask: np.ndarray, n_points: int) -> np.ndarray:
+ """Sample points farthest from mask edges (using distance transform)."""
+ try:
+ import cv2
+ except ImportError:
+ # Fallback to uniform sampling if cv2 not available
+ return _uniform_positive_sample(mask, n_points)
+
+ if np.max(mask) == 0:
+ h, w = mask.shape
+ return np.array([[w // 2, h // 2, 1]] * n_points)
+
+ padded_mask = np.pad(mask.astype(np.uint8), 1)
+ points = []
+
+ for _ in range(n_points):
+ if np.max(padded_mask) == 0:
+ break
+ dist = cv2.distanceTransform(padded_mask, cv2.DIST_L2, 0)
+ point = np.unravel_index(dist.argmax(), dist.shape)
+ padded_mask[point[0], point[1]] = 0
+ points.append(point[::-1]) # (y, x) -> (x, y)
+
+ if len(points) == 0:
+ h, w = mask.shape
+ return np.array([[w // 2, h // 2, 1]] * n_points)
+
+ points = np.stack(points, axis=0)
+ points = points - 1 # Subtract padding offset
+ labels = np.ones((len(points), 1))
+ return np.concatenate([points, labels], axis=1)
+
+
+def _uniform_sample_from_box(
+ mask: np.ndarray,
+ box: np.ndarray,
+ n_points: int,
+) -> np.ndarray:
+ """Sample points uniformly from box, determine labels from mask."""
+ int_box = np.ceil(box).astype(int)
+ x1, y1, x2, y2 = int_box
+
+ # Ensure valid box
+ x2 = max(x2, x1 + 1)
+ y2 = max(y2, y1 + 1)
+
+ x = np.random.randint(low=x1, high=x2, size=n_points)
+ y = np.random.randint(low=y1, high=y2, size=n_points)
+
+ # Clip to mask boundaries
+ h, w = mask.shape
+ x = np.clip(x, 0, w - 1)
+ y = np.clip(y, 0, h - 1)
+
+ labels = mask[y, x]
+ return np.stack([x, y, labels], axis=1)
+
+
+def sample_points_without_mask(
+ box: np.ndarray,
+ n_positive: int,
+ n_negative: int,
+ H: int,
+ W: int,
+) -> np.ndarray:
+ """Sample points when no mask is available.
+
+ Uses box region as pseudo-mask:
+ - Positive points: uniformly from inside box
+ - Negative points: uniformly from outside box
+
+ Args:
+ box: Box in xyxy format (x1, y1, x2, y2)
+ n_positive: Number of positive points to sample
+ n_negative: Number of negative points to sample
+ H: Image height
+ W: Image width
+
+ Returns:
+ Points array (n_positive + n_negative, 3) with (x, y, label)
+ """
+ x1, y1, x2, y2 = map(int, box)
+
+ # Ensure valid box
+ x1 = max(0, min(x1, W - 1))
+ x2 = max(x1 + 1, min(x2, W))
+ y1 = max(0, min(y1, H - 1))
+ y2 = max(y1 + 1, min(y2, H))
+
+ points_list = []
+
+ # Positive points: inside box
+ if n_positive > 0:
+ pos_x = np.random.randint(x1, x2, size=n_positive)
+ pos_y = np.random.randint(y1, y2, size=n_positive)
+ pos_labels = np.ones(n_positive)
+ pos_points = np.stack([pos_x, pos_y, pos_labels], axis=1)
+ points_list.append(pos_points)
+
+ # Negative points: outside box
+ if n_negative > 0:
+ neg_points = []
+ max_attempts = n_negative * 100
+
+ for _ in range(max_attempts):
+ if len(neg_points) >= n_negative:
+ break
+ x = np.random.randint(0, W)
+ y = np.random.randint(0, H)
+ # Check if outside box
+ if not (x1 <= x < x2 and y1 <= y < y2):
+ neg_points.append([x, y, 0])
+
+ if len(neg_points) < n_negative:
+ # Fallback: sample from image corners if box is too large
+ corners = [(0, 0), (W-1, 0), (0, H-1), (W-1, H-1)]
+ while len(neg_points) < n_negative:
+ cx, cy = corners[len(neg_points) % 4]
+ neg_points.append([cx, cy, 0])
+
+ neg_points = np.array(neg_points[:n_negative])
+ points_list.append(neg_points)
+
+ if points_list:
+ return np.concatenate(points_list, axis=0)
+ else:
+ return np.zeros((0, 3))
+
+
+def noise_box(
+ box: np.ndarray,
+ im_size: tuple,
+ box_noise_std: float = 0.1,
+ box_noise_max: Optional[float] = None,
+ min_box_area: float = 0.0,
+) -> np.ndarray:
+ """Add noise to a box for data augmentation.
+
+ Follows SAM3's noise_box implementation:
+ - Gaussian noise scaled by box dimensions
+ - Optional pixel clamp
+ - Fallback to original box if area too small
+
+ Args:
+ box: Box in xyxy format (x1, y1, x2, y2)
+ im_size: Image size (H, W)
+ box_noise_std: Noise std relative to box size
+ box_noise_max: Max noise in pixels (None = no clamp)
+ min_box_area: Min area after noising (SAM3 default: 0.0)
+
+ Returns:
+ Noised box in xyxy format
+ """
+ if box_noise_std <= 0.0:
+ return box
+
+ noise = box_noise_std * np.random.randn(4)
+ w, h = box[2] - box[0], box[3] - box[1]
+ scale_factor = np.array([w, h, w, h])
+ noise = noise * scale_factor
+
+ if box_noise_max is not None:
+ noise = np.clip(noise, -box_noise_max, box_noise_max)
+
+ noised_box = box + noise
+
+ # Clamp to image bounds
+ H, W = im_size
+ noised_box = np.maximum(noised_box, 0)
+ noised_box = np.minimum(noised_box, [W, H, W, H])
+
+ # Check min area (SAM3 default: 0.0 = no limit)
+ new_w = noised_box[2] - noised_box[0]
+ new_h = noised_box[3] - noised_box[1]
+ if new_w * new_h <= min_box_area:
+ return box
+
+ return noised_box
+
+
+# ============================================================================
+# WildDet3D Collator
+# ============================================================================
+
+class WildDet3DCollator:
+ """Collator that converts per-image data to WildDet3DInput.
+
+ Design (SAM3 original - per-category queries):
+ - DataLoader produces per-image samples
+ - Collator groups GT boxes by category
+ - Each category creates ONE query with multi-instance targets
+ - This aligns with SAM3's multi-instance detection design
+
+ Per-prompt batch strategy:
+ - N_prompts = sum of unique categories across batch (NOT sum of boxes!)
+ - img_ids[i] indicates which image prompt i belongs to
+ - Each prompt can have multiple GT boxes (multi-instance targets)
+
+ Coordinate format:
+ - Input boxes2d: pixel xyxy (from dataset)
+ - geo_boxes: normalized cxcywh [0,1] (for SAM3)
+ - geo_points: normalized xy [0,1] (for SAM3)
+ - gt_boxes2d: normalized xyxy [0,1] (for loss)
+ - gt_boxes2d shape: (N_prompts, max_gts, 4) for multi-instance
+ - num_gts: (N_prompts,) number of GT boxes per query (can be > 1)
+
+ Text/Visual Query:
+ - text_query_prob controls the ratio of text vs visual queries
+ - text_query_prob=1.0: all text queries (SAM3 default for training)
+ - text_query_prob=0.7: 70% text, 30% visual (recommended by SAM3)
+ - Visual queries use one randomly selected target box as geo_box
+ """
+
+ def __init__(
+ self,
+ max_prompts_per_image: int = 50,
+ use_text_prompts: bool = True,
+ default_text: str = "visual",
+ # Point prompt options
+ use_point_prompts: bool = False,
+ num_positive_points: int | tuple[int, int] = 1,
+ num_negative_points: int | tuple[int, int] = 0,
+ point_sample_mode: Literal["centered", "random_mask", "random_box"] = "random_mask",
+ # Box prompt options
+ use_box_prompts: bool = True,
+ box_noise_std: float = 0.0,
+ box_noise_max: float | None = None,
+ # Multi-tier box noise: (prob, std) tiers sampled per box.
+ # If set, overrides box_noise_std. Each tier is (probability, std).
+ # Probabilities must sum to 1.0.
+ # Example: [(0.3, 0.0), (0.5, 0.1), (0.2, 0.2)]
+ # = 30% no noise, 50% mild, 20% extreme
+ box_noise_tiers: list[tuple[float, float]] | None = None,
+ # Text/Visual query ratio (SAM3 original design)
+ text_query_prob: float = 0.7, # 70% text, 30% visual (SAM3 recommended)
+ keep_text_for_visual: bool = False, # If True, visual queries keep category text
+ # Geometry prompt options (text + geometry training)
+ use_geometry_prompts: bool = False, # If True, create 2 queries per category
+ geometric_query_str: str = "geometric", # Text for geometry queries
+ visual_query_str: str = "visual", # Text for visual queries
+ # 5-mode training: Branch 1 and Branch 2 probabilities
+ # Branch 1 (o2m): TEXT (text_only_prob) / VISUAL or VISUAL+LABEL (1-text_only_prob)
+ # Branch 2 (o2o): GEOMETRY or GEOMETRY+LABEL
+ # use_label_prob controls +LABEL variants for both branches
+ text_only_prob: float = 0.5, # Branch 1: P(TEXT) vs P(box-based query)
+ use_label_prob: float = 1/3, # P(+LABEL) when query has a box prompt
+ # Oracle evaluation mode (GT box as geometry prompt)
+ oracle_eval: bool = False, # If True, each GT box = one geometry prompt
+ oracle_text_category: bool = False, # If True, oracle + category text
+ # Point prompt: SAM3-style box/point budget (only when use_point_prompts=True)
+ # num_points is the total geometric prompt budget.
+ # box_chance controls probability of including a box (which takes 1 slot).
+ # E.g. num_points=(1,3), box_chance=0.5:
+ # num=1, box=True → pure box | num=1, box=False → 1 point
+ # num=2, box=True → box+1pt | num=2, box=False → 2 points
+ # num=3, box=True → box+2pt | num=3, box=False → 3 points
+ box_chance: float = 0.5,
+ # Exclusive point mode probability. When use_point_prompts=True,
+ # Branch 2 randomly picks EITHER box-only OR point-only (never
+ # both). Point-only is chosen with probability point_mode_prob,
+ # but only when the selected box has a mask (masks2d_rle).
+ # Otherwise box-only. Points use SAM3 random_box mode: uniform
+ # from box region, mask determines pos/neg labels.
+ point_mode_prob: float = 0.3,
+ # Negative sampling (SAM3 style)
+ include_negatives: bool = False, # Add negative queries (absent categories)
+ max_negatives_per_image: int = 5, # Max negative queries per image
+ # Training vs inference filtering
+ filter_empty_boxes: bool = True, # Set False at test time to keep 0-GT-box images
+ ):
+ """Initialize collator.
+
+ Args:
+ max_prompts_per_image: Max number of prompts (categories) per image
+ use_text_prompts: Whether to include text with geometric prompts
+ default_text: Default text when class name not available
+ use_point_prompts: Whether to sample point prompts (for ablation)
+ num_positive_points: Number of positive points to sample
+ Can be int or (min, max) tuple for random range
+ num_negative_points: Number of negative points to sample
+ Can be int or (min, max) tuple for random range
+ point_sample_mode: How to sample points when mask is available
+ - "centered": sample from mask center (farthest from edges)
+ - "random_mask": uniform sample from mask interior
+ - "random_box": uniform sample from box, label from mask
+ use_box_prompts: Whether to use box prompts
+ box_noise_std: Noise std for box jittering (0 = no noise)
+ box_noise_max: Max noise in pixels (None = no clamp)
+ box_noise_tiers: Multi-tier noise as list of (prob, std).
+ Overrides box_noise_std when set.
+ text_query_prob: Probability of text-only queries (SAM3 recommended: 0.7)
+ Only used when use_geometry_prompts=False (legacy 2-mode).
+ keep_text_for_visual: If True, visual queries keep category text
+ If False (default), visual queries use "visual" as text.
+ Only used when use_geometry_prompts=False (legacy 2-mode).
+ use_geometry_prompts: If True, 5-mode training with 2 queries
+ per category (Branch 1 o2m + Branch 2 o2o).
+ geometric_query_str: Text for geometry queries (default "geometric")
+ visual_query_str: Text for visual queries (default "visual")
+ text_only_prob: Branch 1 probability of TEXT mode (no box).
+ Remaining (1-text_only_prob) is box-based (VISUAL or VISUAL+LABEL).
+ use_label_prob: Probability of +LABEL variant when query has a box.
+ Controls both Branch 1 (VISUAL vs VISUAL+LABEL) and
+ Branch 2 (GEOMETRY vs GEOMETRY+LABEL).
+ +LABEL format: "visual: car" / "geometric: car".
+ oracle_eval: If True, each GT 2D box becomes its own geometry
+ prompt (one-to-one). For measuring 3D regression quality
+ in isolation, following DetAny3D's GT prompt evaluation.
+ oracle_text_category: If True, oracle mode with category text.
+ Each GT box = one GEOMETRY+LABEL prompt with text
+ "geometric: " (e.g., "geometric: apple").
+ """
+ self.max_prompts_per_image = max_prompts_per_image
+ self.use_text_prompts = use_text_prompts
+ self.default_text = default_text
+
+ # Point prompt options
+ self.use_point_prompts = use_point_prompts
+ self.num_positive_points = num_positive_points
+ self.num_negative_points = num_negative_points
+ self.point_sample_mode = point_sample_mode
+
+ # Box prompt options
+ self.use_box_prompts = use_box_prompts
+ self.box_noise_std = box_noise_std
+ self.box_noise_max = box_noise_max
+ self.box_noise_tiers = box_noise_tiers
+
+ # Text/Visual query ratio
+ self.text_query_prob = text_query_prob
+ self.keep_text_for_visual = keep_text_for_visual
+
+ # Geometry prompt options (5-mode training)
+ self.use_geometry_prompts = use_geometry_prompts
+ self.geometric_query_str = geometric_query_str
+ self.visual_query_str = visual_query_str
+ self.text_only_prob = text_only_prob
+ self.use_label_prob = use_label_prob
+
+ # Oracle evaluation mode
+ self.oracle_eval = oracle_eval
+ self.oracle_text_category = oracle_text_category
+
+ # Point prompt: box/point budget
+ self.box_chance = box_chance
+ self.point_mode_prob = point_mode_prob
+
+ # Negative sampling (SAM3 style presence loss training)
+ self.include_negatives = include_negatives
+ self.max_negatives_per_image = max_negatives_per_image
+
+ # Training vs inference filtering
+ self.filter_empty_boxes = filter_empty_boxes
+
+ def _sample_box_noise_std(self) -> float:
+ """Sample box noise std from tiers or fallback to self.box_noise_std."""
+ if self.box_noise_tiers is not None:
+ r = random.random()
+ cumulative = 0.0
+ for prob, std in self.box_noise_tiers:
+ cumulative += prob
+ if r < cumulative:
+ return std
+ return self.box_noise_tiers[-1][1]
+ return self.box_noise_std
+
+ def _sample_num_points(self, num_spec: int | tuple[int, int]) -> int:
+ """Sample number of points from spec."""
+ if isinstance(num_spec, int):
+ return num_spec
+ else:
+ low, high = num_spec
+ return np.random.randint(low, high + 1)
+
+ def _sample_points_for_box(
+ self,
+ box_xyxy: np.ndarray,
+ mask: Optional[np.ndarray],
+ H: int,
+ W: int,
+ ) -> np.ndarray:
+ """Sample points for a single box.
+
+ Args:
+ box_xyxy: Box in pixel xyxy format
+ mask: Optional binary mask (H, W)
+ H, W: Image dimensions
+
+ Returns:
+ Points array (N, 3) with (x, y, label) in pixel coords
+ """
+ n_pos = self._sample_num_points(self.num_positive_points)
+ n_neg = self._sample_num_points(self.num_negative_points)
+
+ if mask is not None:
+ # Sample from actual mask
+ points = sample_points_from_mask(
+ mask, n_pos + n_neg, self.point_sample_mode, box_xyxy
+ )
+ else:
+ # Use box as pseudo-mask
+ points = sample_points_without_mask(box_xyxy, n_pos, n_neg, H, W)
+
+ return points
+
+ def _sample_geo_budget(self) -> tuple[int, bool]:
+ """Sample geometric prompt budget (SAM3 style).
+
+ Returns:
+ (n_points, use_box): number of point prompts and whether to
+ include a box. Box takes 1 slot from the total budget.
+ """
+ n_total = self._sample_num_points(self.num_positive_points)
+ if self.box_chance > 0:
+ use_box = random.random() < self.box_chance
+ n_points = max(n_total - int(use_box), 0)
+ else:
+ use_box = False
+ n_points = n_total
+ return n_points, use_box
+
+ def _sample_points_normalized(
+ self,
+ box_xyxy_pixel: np.ndarray,
+ n_points: int,
+ H: int,
+ W: int,
+ mask: Optional[np.ndarray] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Sample n_points from box region, return normalized coords + labels.
+
+ Returns:
+ pts_xy: (n_points, 2) normalized [0,1]
+ pts_labels: (n_points,) long, 1=positive 0=negative
+ """
+ if n_points <= 0:
+ return None, None
+ if mask is not None:
+ points = sample_points_from_mask(
+ mask, n_points, self.point_sample_mode, box_xyxy_pixel
+ )
+ else:
+ points = sample_points_without_mask(
+ box_xyxy_pixel, n_points, 0, H, W
+ )
+ pts_xy = torch.tensor(
+ points[:, :2] / np.array([W, H]),
+ dtype=torch.float32,
+ )
+ pts_labels = torch.tensor(points[:, 2], dtype=torch.long)
+ return pts_xy, pts_labels
+
+ def __call__(self, batch: List[dict]) -> WildDet3DInput:
+ """Collate batch of per-image samples to WildDet3DInput.
+
+ Args:
+ batch: List of dicts, each containing:
+ - images: (3, H, W)
+ - boxes2d: (N_i, 4) pixel xyxy
+ - boxes2d_classes: (N_i,) class indices
+ - boxes2d_names: List[str] class names (optional)
+ - boxes3d: (N_i, 7+) 3D box params
+ - intrinsics: (3, 3)
+ - masks2d: (N_i, H, W) binary masks (optional)
+
+ Returns:
+ WildDet3DInput with per-prompt batch
+ """
+ profile_start(" collator_total")
+
+ # Filter out images with no GT boxes to avoid empty prompts.
+ # Only applied during training; at test time we keep all images so
+ # that the evaluator receives predictions for every image (even ones
+ # with 0 valid 3D GT boxes). _forward_test already handles the
+ # n_prompts_this_img==0 case by returning empty tensors.
+ original_batch_size = len(batch)
+ if self.filter_empty_boxes:
+ batch = [
+ item for item in batch
+ if item.get("boxes2d") is not None and len(item["boxes2d"]) > 0
+ ]
+
+ # if len(batch) < original_batch_size:
+ # import torch.distributed as dist
+ # rank = dist.get_rank() if dist.is_initialized() else 0
+ # filtered_count = original_batch_size - len(batch)
+ # print(
+ # f"[WildDet3DCollator] Filtered {filtered_count}/{original_batch_size} "
+ # f"empty images on rank {rank}"
+ # )
+
+ B = len(batch)
+
+ # Handle completely empty batch (all images filtered out)
+ if B == 0:
+ # import torch.distributed as dist
+ # rank = dist.get_rank() if dist.is_initialized() else 0
+ # print(
+ # f"[WildDet3DCollator] WARNING: Entire batch empty after filtering "
+ # f"({original_batch_size} images all had 0 GT boxes) on rank {rank}"
+ # )
+ # Return minimal empty batch - model will handle this gracefully
+ return WildDet3DInput(
+ images=torch.zeros(0, 3, 1, 1), # (0, 3, H, W)
+ intrinsics=torch.zeros(0, 3, 3), # (0, 3, 3)
+ img_ids=torch.zeros(0, dtype=torch.long),
+ text_ids=torch.zeros(0, dtype=torch.long),
+ unique_texts=[self.default_text],
+ sample_names=None,
+ dataset_name=None,
+ original_hw=None,
+ original_images=None,
+ original_intrinsics=None,
+ padding=None,
+ )
+
+ device = batch[0]["images"].device if batch[0]["images"].is_cuda else "cpu"
+
+ # Collect image-level data
+ profile_start(" collator_image_stack")
+ # Images might be (3, H, W) or (1, 3, H, W) depending on data pipeline
+ images_list = []
+ for b in batch:
+ img = b["images"]
+ # Handle case where img might have extra batch dim
+ if img.dim() == 4 and img.shape[0] == 1:
+ img = img.squeeze(0) # (1, 3, H, W) -> (3, H, W)
+ images_list.append(img)
+ images = torch.stack(images_list) # (B, 3, H, W)
+ intrinsics = torch.stack([b["intrinsics"] for b in batch]) # (B, 3, 3)
+ H, W = images.shape[-2:] # Use -2: and -1 for H, W to be safe
+ profile_stop(" collator_image_stack")
+
+ # Collect metadata for evaluation/visualization
+ sample_names = []
+ dataset_name_list = []
+ original_hw_list = []
+ original_images_list = []
+ original_intrinsics_list = []
+ padding_list = []
+ for b_idx, b in enumerate(batch):
+ # sample_names - image identifier for evaluation
+ if "sample_names" in b:
+ sample_names.append(b["sample_names"])
+ elif "image_id" in b:
+ sample_names.append(b["image_id"])
+ else:
+ sample_names.append(None)
+
+ # dataset_name - for evaluator to route to correct dataset
+ if "dataset_name" in b:
+ dataset_name_list.append(b["dataset_name"])
+ else:
+ dataset_name_list.append(None)
+
+ # original_hw - for coordinate scaling back
+ if "original_hw" in b:
+ original_hw_list.append(b["original_hw"])
+ else:
+ original_hw_list.append(None)
+
+ # original_images - unresized images for visualization
+ if "original_images" in b:
+ original_images_list.append(b["original_images"])
+ else:
+ original_images_list.append(None)
+
+ # original_intrinsics - intrinsics before resize
+ if "original_intrinsics" in b:
+ original_intrinsics_list.append(b["original_intrinsics"])
+ else:
+ original_intrinsics_list.append(None)
+
+ # padding - CenterPad offsets [pad_left, pad_right, pad_top, pad_bottom]
+ if "padding" in b:
+ padding_list.append(b["padding"])
+ else:
+ padding_list.append(None)
+
+ # Collect depth maps for geometry backend supervision
+ depth_maps_list = []
+ for b in batch:
+ # depth_maps - K.depth_maps key from dataset
+ if "depth_maps" in b and b["depth_maps"] is not None:
+ depth_maps_list.append(b["depth_maps"])
+ else:
+ depth_maps_list.append(None)
+
+ # Stack depth maps if available (all images must have depth)
+ depth_gt = None
+ if depth_maps_list and all(d is not None for d in depth_maps_list):
+ try:
+ depth_gt = torch.stack(depth_maps_list, dim=0) # (B, H, W) or (B, 1, H, W)
+ if depth_gt.dim() == 3:
+ depth_gt = depth_gt.unsqueeze(1) # (B, H, W) -> (B, 1, H, W)
+ except (RuntimeError, TypeError):
+ depth_gt = None
+
+ # Convert to proper format (None if all are None)
+ sample_names = sample_names if any(s is not None for s in sample_names) else None
+ dataset_name = dataset_name_list if any(d is not None for d in dataset_name_list) else None
+ original_hw = original_hw_list if any(h is not None for h in original_hw_list) else None
+ padding = padding_list if any(p is not None for p in padding_list) else None
+ original_images = None
+ if any(img is not None for img in original_images_list):
+ # Convert numpy arrays to tensors, then try stacking.
+ # Different-sized images (e.g. cross-dataset) cannot be stacked;
+ # in that case keep as list for the visualizer.
+ imgs = []
+ for img in original_images_list:
+ if img is None:
+ continue
+ if not isinstance(img, torch.Tensor):
+ img = torch.as_tensor(img)
+ imgs.append(img)
+ if len(imgs) == 1:
+ original_images = imgs[0].unsqueeze(0) if imgs[0].dim() == 3 else imgs[0]
+ elif len(imgs) > 1:
+ try:
+ original_images = torch.stack(imgs)
+ except RuntimeError:
+ # Different shapes across batch - keep first only
+ original_images = imgs[0].unsqueeze(0) if imgs[0].dim() == 3 else imgs[0]
+ original_intrinsics = None
+ if any(intr is not None for intr in original_intrinsics_list):
+ intrs = []
+ for intr in original_intrinsics_list:
+ if intr is None:
+ continue
+ if not isinstance(intr, torch.Tensor):
+ intr = torch.as_tensor(intr)
+ intrs.append(intr)
+ try:
+ original_intrinsics = torch.stack(intrs)
+ except (RuntimeError, TypeError):
+ original_intrinsics = None
+
+ # Build per-prompt data (SAM3 original: per-category queries)
+ # If use_geometry_prompts=True: Each category creates TWO queries
+ # - TEXT query (one-to-many targets)
+ # - GEOMETRY query (one-to-one target)
+ # If use_geometry_prompts=False: Original behavior (text or visual per category)
+ img_ids_list = []
+ text_ids_list = []
+ geo_boxes_list = [] # normalized cxcywh (for visual/geometry queries)
+ geo_points_list = [] # normalized xy (N, 2) or None
+ geo_point_labels_list = [] # labels (N,) or None
+ is_visual_query_list = [] # Track which queries have visual prompts
+ # Query types (collator-level label only, does NOT control SAM3 internal matching):
+ # 0=TEXT, 1=VISUAL, 2=GEOMETRY, 3=VISUAL+LABEL, 4=GEOMETRY+LABEL
+ query_types_list = []
+
+ # Multi-instance targets: list of lists
+ # gt_boxes2d_per_query[i] = list of normalized xyxy boxes for query i
+ gt_boxes2d_per_query = []
+ gt_boxes3d_per_query = []
+ gt_category_ids_list = []
+
+ # Ignore boxes per query (for negative loss suppression)
+ ignore_boxes2d_per_query = []
+
+ # Build unique text list
+ unique_texts = []
+ text_to_id = {}
+
+ # Helper function to normalize box to xyxy [0,1]
+ def normalize_box_xyxy(box_xyxy_raw):
+ if isinstance(box_xyxy_raw, torch.Tensor):
+ gt_box_norm = box_xyxy_raw.clone().float()
+ else:
+ gt_box_norm = torch.tensor(box_xyxy_raw, dtype=torch.float32)
+ gt_box_norm[0::2] /= W
+ gt_box_norm[1::2] /= H
+ return gt_box_norm.to(device)
+
+ # Helper function to convert xyxy to cxcywh
+ def xyxy_to_cxcywh(box_norm_xyxy):
+ cx = (box_norm_xyxy[0] + box_norm_xyxy[2]) / 2
+ cy = (box_norm_xyxy[1] + box_norm_xyxy[3]) / 2
+ w_box = box_norm_xyxy[2] - box_norm_xyxy[0]
+ h_box = box_norm_xyxy[3] - box_norm_xyxy[1]
+ return torch.tensor([cx, cy, w_box, h_box], device=device)
+
+ profile_start(" collator_category_group")
+
+ if self.oracle_eval:
+ # ========== Oracle Mode: Each GT box = one geometry prompt ==========
+ # Following DetAny3D's GT prompt evaluation approach.
+ # One-to-one mapping: each GT box becomes a separate geometry
+ # prompt, model predicts 3D for each box independently.
+ geo_text = self.geometric_query_str
+ if geo_text not in text_to_id:
+ text_to_id[geo_text] = len(unique_texts)
+ unique_texts.append(geo_text)
+ geo_text_id = text_to_id[geo_text]
+
+ for img_idx, sample in enumerate(batch):
+ boxes2d = sample.get("boxes2d")
+ boxes3d = sample.get("boxes3d")
+ class_ids = sample.get("boxes2d_classes")
+
+ if boxes2d is None or len(boxes2d) == 0:
+ continue
+
+ # During test, boxes2d are in original pixel space (test
+ # transforms don't include ResizeBoxes2D / CenterPadBoxes2D).
+ # Transform to padded pixel space using the SAME math as
+ # _forward_test's inverse (subtract pad, divide scale), reversed:
+ # original -> padded: x * scale_x + pad_left
+ # where scale_x = content_w / orig_w (from _forward_test)
+ original_hw = sample.get("original_hw", None)
+ pad_info = sample.get("padding", None)
+
+ if original_hw is not None and pad_info is not None:
+ orig_h, orig_w = original_hw
+ if isinstance(orig_h, torch.Tensor):
+ orig_h, orig_w = orig_h.item(), orig_w.item()
+ pad_left, pad_right, pad_top, pad_bottom = pad_info
+ if isinstance(pad_left, torch.Tensor):
+ pad_left = pad_left.item()
+ pad_right = pad_right.item()
+ pad_top = pad_top.item()
+ pad_bottom = pad_bottom.item()
+ content_w = W - pad_left - pad_right
+ content_h = H - pad_top - pad_bottom
+ scale_x = content_w / orig_w
+ scale_y = content_h / orig_h
+
+ def transform_box_to_padded(box_raw):
+ """Transform box: original pixel -> padded pixel."""
+ if isinstance(box_raw, torch.Tensor):
+ box = box_raw.clone().float()
+ else:
+ box = torch.tensor(box_raw, dtype=torch.float32)
+ box[0::2] = box[0::2] * scale_x + pad_left
+ box[1::2] = box[1::2] * scale_y + pad_top
+ return box
+ else:
+ def transform_box_to_padded(box_raw):
+ if isinstance(box_raw, torch.Tensor):
+ return box_raw.clone().float()
+ return torch.tensor(box_raw, dtype=torch.float32)
+
+ for box_idx in range(len(boxes2d)):
+ img_ids_list.append(img_idx)
+
+ # Category ID
+ if class_ids is not None:
+ cat_id = class_ids[box_idx]
+ if isinstance(cat_id, torch.Tensor):
+ cat_id = cat_id.item()
+ else:
+ cat_id = 0
+ gt_category_ids_list.append(cat_id)
+
+ # Geometry query type
+ query_types_list.append(2) # GEOMETRY
+ is_visual_query_list.append(True)
+ text_ids_list.append(geo_text_id)
+
+ # Transform box to padded pixel space, then normalize
+ box_padded = transform_box_to_padded(boxes2d[box_idx])
+ box_norm_xyxy = normalize_box_xyxy(box_padded)
+ geo_boxes_list.append(xyxy_to_cxcywh(box_norm_xyxy))
+ geo_points_list.append(None)
+ geo_point_labels_list.append(None)
+
+ # Target = this single box (one-to-one)
+ gt_boxes2d_per_query.append(
+ [normalize_box_xyxy(boxes2d[box_idx])]
+ )
+ if boxes3d is not None and box_idx < len(boxes3d):
+ gt_boxes3d_per_query.append(
+ [boxes3d[box_idx].to(device)]
+ )
+ else:
+ gt_boxes3d_per_query.append(None)
+ # Oracle mode: no ignore box suppression needed
+ ignore_boxes2d_per_query.append([])
+
+ elif self.oracle_text_category:
+ # ========== Oracle + Text Category Mode ==========
+ # Same as oracle (each GT box = one geometry prompt), but with
+ # category-specific text: "geometric: " instead of
+ # generic "geometric". Query type = GEOMETRY+LABEL (4).
+ for img_idx, sample in enumerate(batch):
+ boxes2d = sample.get("boxes2d")
+ boxes3d = sample.get("boxes3d")
+ class_ids = sample.get("boxes2d_classes")
+ class_names = sample.get("boxes2d_names", None)
+
+ if boxes2d is None or len(boxes2d) == 0:
+ continue
+
+ original_hw = sample.get("original_hw", None)
+ pad_info = sample.get("padding", None)
+
+ if original_hw is not None and pad_info is not None:
+ orig_h, orig_w = original_hw
+ if isinstance(orig_h, torch.Tensor):
+ orig_h, orig_w = orig_h.item(), orig_w.item()
+ pad_left, pad_right, pad_top, pad_bottom = pad_info
+ if isinstance(pad_left, torch.Tensor):
+ pad_left = pad_left.item()
+ pad_right = pad_right.item()
+ pad_top = pad_top.item()
+ pad_bottom = pad_bottom.item()
+ content_w = W - pad_left - pad_right
+ content_h = H - pad_top - pad_bottom
+ scale_x = content_w / orig_w
+ scale_y = content_h / orig_h
+
+ def transform_box_to_padded(box_raw):
+ """Transform box: original pixel -> padded pixel."""
+ if isinstance(box_raw, torch.Tensor):
+ box = box_raw.clone().float()
+ else:
+ box = torch.tensor(box_raw, dtype=torch.float32)
+ box[0::2] = box[0::2] * scale_x + pad_left
+ box[1::2] = box[1::2] * scale_y + pad_top
+ return box
+ else:
+ def transform_box_to_padded(box_raw):
+ if isinstance(box_raw, torch.Tensor):
+ return box_raw.clone().float()
+ return torch.tensor(box_raw, dtype=torch.float32)
+
+ for box_idx in range(len(boxes2d)):
+ img_ids_list.append(img_idx)
+
+ # Category ID
+ if class_ids is not None:
+ cat_id = class_ids[box_idx]
+ if isinstance(cat_id, torch.Tensor):
+ cat_id = cat_id.item()
+ else:
+ cat_id = 0
+ gt_category_ids_list.append(cat_id)
+
+ # Get category name
+ if class_names is not None and cat_id < len(class_names):
+ cat_name = class_names[cat_id]
+ else:
+ cat_name = self.default_text
+
+ # GEOMETRY+LABEL query: "geometric: "
+ gl_text = f"{self.geometric_query_str}: {cat_name}"
+ if gl_text not in text_to_id:
+ text_to_id[gl_text] = len(unique_texts)
+ unique_texts.append(gl_text)
+ query_types_list.append(4) # GEOMETRY+LABEL
+ is_visual_query_list.append(True)
+ text_ids_list.append(text_to_id[gl_text])
+
+ # Transform box to padded pixel space, then normalize
+ box_padded = transform_box_to_padded(boxes2d[box_idx])
+ box_norm_xyxy = normalize_box_xyxy(box_padded)
+ geo_boxes_list.append(xyxy_to_cxcywh(box_norm_xyxy))
+ geo_points_list.append(None)
+ geo_point_labels_list.append(None)
+
+ # Target = this single box (one-to-one)
+ gt_boxes2d_per_query.append(
+ [normalize_box_xyxy(boxes2d[box_idx])]
+ )
+ if boxes3d is not None and box_idx < len(boxes3d):
+ gt_boxes3d_per_query.append(
+ [boxes3d[box_idx].to(device)]
+ )
+ else:
+ gt_boxes3d_per_query.append(None)
+ ignore_boxes2d_per_query.append([])
+
+ else:
+ # ========== Standard Mode: Group by category ==========
+ for img_idx, sample in enumerate(batch):
+ boxes2d = sample.get("boxes2d") # (N_i, 4) pixel xyxy
+ boxes3d = sample.get("boxes3d") # (N_i, 7+)
+ class_ids = sample.get("boxes2d_classes") # (N_i,)
+ class_names = sample.get("boxes2d_names", None) # List[str] or None
+ masks2d = sample.get("masks2d", None) # (N_i, H, W) or None
+
+ if boxes2d is None or len(boxes2d) == 0:
+ continue
+
+ # ========== SAM3 Original: Group boxes by category ==========
+ cat_to_box_indices = defaultdict(list)
+ for box_idx in range(len(boxes2d)):
+ if class_ids is not None:
+ cat_id = class_ids[box_idx]
+ if isinstance(cat_id, torch.Tensor):
+ cat_id = cat_id.item()
+ else:
+ cat_id = 0
+ cat_to_box_indices[cat_id].append(box_idx)
+
+ # Group ignore boxes by category (for negative loss suppression)
+ ignore_boxes2d_raw = sample.get("ignore_boxes2d", None)
+ ignore_class_ids_raw = sample.get("ignore_class_ids", None)
+ cat_to_ignore_indices = defaultdict(list)
+ if (
+ ignore_boxes2d_raw is not None
+ and len(ignore_boxes2d_raw) > 0
+ ):
+ for ign_idx in range(len(ignore_boxes2d_raw)):
+ ign_cat_id = int(ignore_class_ids_raw[ign_idx])
+ cat_to_ignore_indices[ign_cat_id].append(ign_idx)
+
+ # Limit number of categories (queries) per image
+ categories = list(cat_to_box_indices.keys())
+ if len(categories) > self.max_prompts_per_image:
+ random.shuffle(categories)
+ categories = categories[:self.max_prompts_per_image]
+
+ # ========== Create queries per category ==========
+ for cat_id in categories:
+ box_indices = cat_to_box_indices[cat_id]
+
+ # Get category name for text
+ if self.use_text_prompts and class_names is not None:
+ cat_name = class_names[cat_id] if cat_id < len(class_names) else self.default_text
+ else:
+ cat_name = self.default_text
+
+ if self.use_geometry_prompts:
+ # ========== 5-Mode Training ==========
+ # Creates 2 queries per category:
+ #
+ # Branch 1 ("multi-target"): target = ALL instances of this category
+ # - TEXT: text="car", no box
+ # - VISUAL: text="visual", geo_box
+ # - VISUAL+LABEL: text="visual: car", geo_box
+ #
+ # Branch 2 ("single-target"): target = 1 selected instance only
+ # - GEOMETRY: text="geometric", geo_box
+ # - GEOMETRY+LABEL: text="geometric: car", geo_box
+ #
+ # NOTE on "multi-target" vs "single-target":
+ # This refers to how many GT boxes are assigned as
+ # targets in this collator (num_gts). This is DIFFERENT
+ # from SAM3's internal o2o/o2m matching (DAC mechanism).
+ # SAM3's DAC always runs both Hungarian (o2o) and
+ # one-to-many (o2m) matchers in the decoder regardless
+ # of how many GT targets we assign here.
+
+ # Helper: add text to unique_texts and return its id
+ def _get_text_id(text_str):
+ if text_str not in text_to_id:
+ text_to_id[text_str] = len(unique_texts)
+ unique_texts.append(text_str)
+ return text_to_id[text_str]
+
+ # Helper: select a random GT box and return its
+ # normalized cxcywh (with optional noise)
+ def _make_geo_box(box_indices_inner):
+ sel_idx = random.choice(box_indices_inner)
+ bx = boxes2d[sel_idx]
+ bx_np = bx.cpu().numpy() if isinstance(bx, torch.Tensor) else bx
+ std = self._sample_box_noise_std()
+ if std > 0:
+ bx_np = noise_box(
+ bx_np,
+ im_size=(H, W),
+ box_noise_std=std,
+ box_noise_max=self.box_noise_max,
+ )
+ norm_xyxy = torch.tensor([
+ bx_np[0] / W, bx_np[1] / H,
+ bx_np[2] / W, bx_np[3] / H,
+ ], dtype=torch.float32, device=device)
+ return sel_idx, xyxy_to_cxcywh(norm_xyxy)
+
+ # ----- Branch 1 (multi-target): TEXT / VISUAL / VISUAL+LABEL -----
+ img_ids_list.append(img_idx)
+ gt_category_ids_list.append(cat_id)
+
+ is_text_only = random.random() < self.text_only_prob
+ if is_text_only:
+ # TEXT: text="car", no box, no points, all targets
+ query_types_list.append(0) # TEXT
+ is_visual_query_list.append(False)
+ text_ids_list.append(_get_text_id(cat_name))
+ geo_boxes_list.append(None)
+ geo_points_list.append(None)
+ geo_point_labels_list.append(None)
+ else:
+ # Box-based o2m query
+ has_label = random.random() < self.use_label_prob
+ if has_label:
+ # VISUAL+LABEL: text="visual: car", box, all targets
+ query_types_list.append(3) # VISUAL+LABEL
+ vl_text = f"{self.visual_query_str}: {cat_name}"
+ text_ids_list.append(_get_text_id(vl_text))
+ else:
+ # VISUAL: text="visual", box, all targets
+ query_types_list.append(1) # VISUAL
+ text_ids_list.append(_get_text_id(self.visual_query_str))
+ is_visual_query_list.append(True)
+ _, geo_cxcywh = _make_geo_box(box_indices)
+ geo_boxes_list.append(geo_cxcywh)
+ # Branch 1 visual: no point prompts (box only)
+ geo_points_list.append(None)
+ geo_point_labels_list.append(None)
+
+ # Targets: ALL boxes of this category (multi-target)
+ query_gt_boxes2d = []
+ query_gt_boxes3d = []
+ for box_idx in box_indices:
+ query_gt_boxes2d.append(normalize_box_xyxy(boxes2d[box_idx]))
+ if boxes3d is not None and box_idx < len(boxes3d):
+ query_gt_boxes3d.append(boxes3d[box_idx].to(device))
+ gt_boxes2d_per_query.append(query_gt_boxes2d)
+ gt_boxes3d_per_query.append(query_gt_boxes3d if query_gt_boxes3d else None)
+ # Collect ignore boxes for this category
+ ign_indices = cat_to_ignore_indices.get(cat_id, [])
+ query_ign = [normalize_box_xyxy(ignore_boxes2d_raw[i]) for i in ign_indices] if ign_indices and ignore_boxes2d_raw is not None else []
+ ignore_boxes2d_per_query.append(query_ign)
+
+ # ----- Branch 2 (single-target): GEOMETRY / GEOMETRY+LABEL -----
+ img_ids_list.append(img_idx)
+ gt_category_ids_list.append(cat_id)
+
+ has_label_b2 = random.random() < self.use_label_prob
+ if has_label_b2:
+ # GEOMETRY+LABEL: text="geometric: car", 1 target
+ query_types_list.append(4) # GEOMETRY+LABEL
+ gl_text = f"{self.geometric_query_str}: {cat_name}"
+ text_ids_list.append(_get_text_id(gl_text))
+ else:
+ # GEOMETRY: text="geometric", 1 target
+ query_types_list.append(2) # GEOMETRY
+ text_ids_list.append(_get_text_id(self.geometric_query_str))
+ is_visual_query_list.append(True)
+
+ selected_idx, geo_cxcywh = _make_geo_box(box_indices)
+
+ # Decide geometric prompt mode for Branch 2
+ if self.use_point_prompts:
+ # Exclusive mode: box OR point, never both
+ masks2d = sample.get(
+ "masks2d", None
+ )
+ has_mask = (
+ masks2d is not None
+ and selected_idx < len(masks2d)
+ and masks2d[selected_idx].sum() > 0
+ )
+ use_pt = (
+ has_mask
+ and random.random()
+ < self.point_mode_prob
+ )
+ if use_pt:
+ # Point-only (no box)
+ sel_mask = masks2d[selected_idx]
+ if isinstance(
+ sel_mask, torch.Tensor
+ ):
+ sel_mask = (
+ sel_mask.cpu().numpy()
+ )
+ sel_box = boxes2d[selected_idx]
+ sel_box_np = (
+ sel_box.cpu().numpy()
+ if isinstance(
+ sel_box, torch.Tensor
+ )
+ else np.array(sel_box)
+ )
+ n_pts = self._sample_num_points(
+ self.num_positive_points
+ )
+ if n_pts == 1:
+ # Single point: always positive
+ # from mask center (farthest
+ # from edges)
+ points = sample_points_from_mask(
+ sel_mask,
+ 1,
+ "centered",
+ )
+ else:
+ # Multi-point: random_box mode,
+ # mask determines pos/neg labels
+ points = sample_points_from_mask(
+ sel_mask,
+ n_pts,
+ "random_box",
+ sel_box_np,
+ )
+ pts_xy = torch.tensor(
+ points[:, :2]
+ / np.array([W, H]),
+ dtype=torch.float32,
+ )
+ pts_labels = torch.tensor(
+ points[:, 2],
+ dtype=torch.long,
+ )
+ geo_boxes_list.append(None)
+ geo_points_list.append(pts_xy)
+ geo_point_labels_list.append(
+ pts_labels
+ )
+ else:
+ # Box-only (no points)
+ geo_boxes_list.append(geo_cxcywh)
+ geo_points_list.append(None)
+ geo_point_labels_list.append(None)
+ else:
+ geo_boxes_list.append(geo_cxcywh)
+ geo_points_list.append(None)
+ geo_point_labels_list.append(None)
+
+ # Target: ONLY the selected box (single-target)
+ query_gt_boxes2d = [normalize_box_xyxy(boxes2d[selected_idx])]
+ query_gt_boxes3d = []
+ if boxes3d is not None and selected_idx < len(boxes3d):
+ query_gt_boxes3d.append(boxes3d[selected_idx].to(device))
+ gt_boxes2d_per_query.append(query_gt_boxes2d)
+ gt_boxes3d_per_query.append(query_gt_boxes3d if query_gt_boxes3d else None)
+ # Same ignore boxes as Branch 1 (same category)
+ ign_indices = cat_to_ignore_indices.get(cat_id, [])
+ query_ign = [normalize_box_xyxy(ignore_boxes2d_raw[i]) for i in ign_indices] if ign_indices and ignore_boxes2d_raw is not None else []
+ ignore_boxes2d_per_query.append(query_ign)
+
+ else:
+ # ========== Original: Text/Visual random selection ==========
+ img_ids_list.append(img_idx)
+ gt_category_ids_list.append(cat_id)
+
+ # Decide query type: text-only or visual
+ is_text_query = random.random() < self.text_query_prob
+ is_visual_query = not is_text_query
+
+ # Track query type (0=TEXT for both text and visual in original mode)
+ query_types_list.append(0 if is_text_query else 1) # 1=VISUAL
+ is_visual_query_list.append(is_visual_query)
+
+ # Determine text for this query
+ if is_visual_query and not self.keep_text_for_visual:
+ text = "visual"
+ else:
+ text = cat_name
+
+ if text not in text_to_id:
+ text_to_id[text] = len(unique_texts)
+ unique_texts.append(text)
+ text_ids_list.append(text_to_id[text])
+
+ # Visual query: pick one target as geo_box
+ if is_visual_query and self.use_box_prompts:
+ selected_idx = random.choice(box_indices)
+ box_xyxy = boxes2d[selected_idx]
+ box_xyxy_np = box_xyxy.cpu().numpy() if isinstance(box_xyxy, torch.Tensor) else box_xyxy
+
+ std = self._sample_box_noise_std()
+ if std > 0:
+ box_xyxy_np = noise_box(
+ box_xyxy_np,
+ im_size=(H, W),
+ box_noise_std=std,
+ box_noise_max=self.box_noise_max,
+ )
+
+ box_norm_xyxy = torch.tensor([
+ box_xyxy_np[0] / W,
+ box_xyxy_np[1] / H,
+ box_xyxy_np[2] / W,
+ box_xyxy_np[3] / H,
+ ], dtype=torch.float32, device=device)
+ geo_boxes_list.append(xyxy_to_cxcywh(box_norm_xyxy))
+ else:
+ geo_boxes_list.append(None)
+ # Legacy mode: no point prompts
+ geo_points_list.append(None)
+ geo_point_labels_list.append(None)
+
+ # Multi-instance targets: ALL boxes of this category
+ query_gt_boxes2d = []
+ query_gt_boxes3d = []
+ for box_idx in box_indices:
+ query_gt_boxes2d.append(normalize_box_xyxy(boxes2d[box_idx]))
+ if boxes3d is not None and box_idx < len(boxes3d):
+ query_gt_boxes3d.append(boxes3d[box_idx].to(device))
+ gt_boxes2d_per_query.append(query_gt_boxes2d)
+ gt_boxes3d_per_query.append(query_gt_boxes3d if query_gt_boxes3d else None)
+ # Collect ignore boxes for this category
+ ign_indices = cat_to_ignore_indices.get(cat_id, [])
+ query_ign = [normalize_box_xyxy(ignore_boxes2d_raw[i]) for i in ign_indices] if ign_indices and ignore_boxes2d_raw is not None else []
+ ignore_boxes2d_per_query.append(query_ign)
+
+ # ========== Negative sampling (SAM3 style) ==========
+ # Add TEXT queries for absent categories (num_gts=0).
+ # These train the presence head to predict "not present".
+ # SAM3 does this via COCO_FROM_JSON include_negatives=True.
+ if (
+ self.include_negatives
+ and class_names is not None
+ and 0 < len(class_names) <= 100
+ ):
+ present_cats = set(cat_to_box_indices.keys())
+ all_cats = set(range(len(class_names)))
+ absent_cats = list(all_cats - present_cats)
+
+ if len(absent_cats) > self.max_negatives_per_image:
+ absent_cats = random.sample(
+ absent_cats, self.max_negatives_per_image
+ )
+
+ for neg_cat_id in absent_cats:
+ neg_cat_name = class_names[neg_cat_id]
+ img_ids_list.append(img_idx)
+ gt_category_ids_list.append(neg_cat_id)
+ query_types_list.append(0) # TEXT (exhaustive)
+ is_visual_query_list.append(False)
+ if neg_cat_name not in text_to_id:
+ text_to_id[neg_cat_name] = len(unique_texts)
+ unique_texts.append(neg_cat_name)
+ text_ids_list.append(text_to_id[neg_cat_name])
+ geo_boxes_list.append(None)
+ geo_points_list.append(None)
+ geo_point_labels_list.append(None)
+ gt_boxes2d_per_query.append([])
+ gt_boxes3d_per_query.append(None)
+ ignore_boxes2d_per_query.append([])
+
+ profile_stop(" collator_category_group")
+
+ N_prompts = len(img_ids_list)
+
+ if N_prompts == 0:
+ import torch.distributed as dist
+ rank = dist.get_rank() if dist.is_initialized() else 0
+ print(
+ f"[WildDet3DCollator] WARNING: Unexpected N_prompts=0 "
+ f"(B={B} images passed filter) on rank {rank}"
+ )
+ return WildDet3DInput(
+ images=images,
+ intrinsics=intrinsics,
+ img_ids=torch.zeros(0, dtype=torch.long, device=device),
+ text_ids=torch.zeros(0, dtype=torch.long, device=device),
+ unique_texts=[self.default_text],
+ sample_names=sample_names,
+ dataset_name=dataset_name,
+ original_hw=original_hw,
+ original_images=original_images,
+ original_intrinsics=original_intrinsics,
+ padding=padding,
+ )
+
+ # Stack tensors
+ profile_start(" collator_tensor_stack")
+ img_ids = torch.tensor(img_ids_list, dtype=torch.long, device=device)
+ text_ids = torch.tensor(text_ids_list, dtype=torch.long, device=device)
+
+ # ========== Box prompts for visual queries ==========
+ # geo_boxes: (N_prompts, 1, 4) - None for text-only queries
+ geo_boxes = None
+ geo_boxes_mask = None
+ geo_box_labels = None
+
+ # Check if any visual queries exist
+ has_visual = any(g is not None for g in geo_boxes_list)
+ if has_visual:
+ # Stack geo_boxes, use zeros for text-only queries
+ stacked_geo_boxes = []
+ for g in geo_boxes_list:
+ if g is not None:
+ stacked_geo_boxes.append(g)
+ else:
+ stacked_geo_boxes.append(torch.zeros(4, device=device))
+ geo_boxes = torch.stack(stacked_geo_boxes).unsqueeze(1) # (N, 1, 4)
+
+ # Mask: True = padding (i.e., text-only queries have no valid box)
+ geo_boxes_mask = torch.tensor(
+ [[g is None] for g in geo_boxes_list],
+ dtype=torch.bool, device=device
+ ) # (N, 1)
+
+ # Labels: 1 for positive (valid) boxes
+ geo_box_labels = torch.tensor(
+ [[1 if g is not None else 0] for g in geo_boxes_list],
+ dtype=torch.long, device=device
+ ) # (N, 1)
+
+ # ========== Point prompts: pad to (N_prompts, max_P, 2) ==========
+ geo_points = None
+ geo_points_mask = None
+ geo_point_labels = None
+ has_points = any(p is not None for p in geo_points_list)
+ if has_points:
+ max_P = max(
+ len(p) for p in geo_points_list if p is not None
+ )
+ if max_P > 0:
+ pts_padded = []
+ pts_mask_list = []
+ pts_labels_padded = []
+ for pts, lbls in zip(
+ geo_points_list, geo_point_labels_list
+ ):
+ if pts is None or len(pts) == 0:
+ pts_padded.append(
+ torch.zeros(max_P, 2, device=device)
+ )
+ pts_mask_list.append(
+ torch.ones(max_P, dtype=torch.bool, device=device)
+ )
+ pts_labels_padded.append(
+ torch.zeros(max_P, dtype=torch.long, device=device)
+ )
+ else:
+ n = len(pts)
+ pad_n = max_P - n
+ pts_padded.append(torch.cat([
+ pts.to(device),
+ torch.zeros(pad_n, 2, device=device),
+ ]))
+ pts_mask_list.append(torch.cat([
+ torch.zeros(n, dtype=torch.bool, device=device),
+ torch.ones(pad_n, dtype=torch.bool, device=device),
+ ]))
+ pts_labels_padded.append(torch.cat([
+ lbls.to(device),
+ torch.zeros(pad_n, dtype=torch.long, device=device),
+ ]))
+ geo_points = torch.stack(pts_padded) # (N, max_P, 2)
+ geo_points_mask = torch.stack(pts_mask_list) # (N, max_P)
+ geo_point_labels = torch.stack(pts_labels_padded) # (N, max_P)
+
+ # ========== Multi-instance GT boxes: pad to (N_prompts, max_gt, 4) ==========
+ # Find max number of targets per query (at least 1 for tensor shape)
+ max_gt = max(
+ (len(q) for q in gt_boxes2d_per_query), default=1
+ )
+ max_gt = max(max_gt, 1) # Ensure at least 1 for padded tensor shape
+ num_gts_list = []
+
+ gt_boxes2d_padded = []
+ for query_boxes in gt_boxes2d_per_query:
+ n_gt = len(query_boxes)
+ num_gts_list.append(n_gt)
+
+ if n_gt == 0:
+ # Negative query: all-zero padding, num_gts=0
+ padded = [torch.zeros(4, device=device)] * max_gt
+ elif n_gt < max_gt:
+ # Pad with zeros
+ padded = query_boxes + [torch.zeros(4, device=device)] * (max_gt - n_gt)
+ else:
+ padded = query_boxes
+ gt_boxes2d_padded.append(torch.stack(padded))
+
+ gt_boxes2d = torch.stack(gt_boxes2d_padded) # (N_prompts, max_gt, 4)
+ num_gts = torch.tensor(num_gts_list, dtype=torch.long, device=device) # (N_prompts,)
+
+ # 3D boxes (if available)
+ gt_boxes3d = None
+ if any(q is not None for q in gt_boxes3d_per_query):
+ # Get 3D box dimension from first valid entry
+ box3d_dim = None
+ for q in gt_boxes3d_per_query:
+ if q is not None and len(q) > 0:
+ box3d_dim = q[0].shape[-1]
+ break
+
+ if box3d_dim is not None:
+ gt_boxes3d_padded = []
+ for query_boxes in gt_boxes3d_per_query:
+ if query_boxes is None or len(query_boxes) == 0:
+ # No 3D boxes for this query
+ padded = [torch.zeros(box3d_dim, device=device)] * max_gt
+ else:
+ n_gt = len(query_boxes)
+ if n_gt < max_gt:
+ padded = query_boxes + [torch.zeros(box3d_dim, device=device)] * (max_gt - n_gt)
+ else:
+ padded = query_boxes
+ gt_boxes3d_padded.append(torch.stack(padded))
+ gt_boxes3d = torch.stack(gt_boxes3d_padded) # (N_prompts, max_gt, box3d_dim)
+
+ gt_category_ids = torch.tensor(gt_category_ids_list, dtype=torch.long, device=device)
+
+ # ========== Ignore boxes: pad to (N_prompts, max_ignore, 4) ==========
+ max_ignore = max(
+ (len(q) for q in ignore_boxes2d_per_query), default=0
+ )
+ if max_ignore > 0:
+ num_ignores_list = []
+ ignore_padded = []
+ for q in ignore_boxes2d_per_query:
+ n_ign = len(q)
+ num_ignores_list.append(n_ign)
+ if n_ign < max_ignore:
+ padded = q + [
+ torch.zeros(4, device=device)
+ ] * (max_ignore - n_ign)
+ else:
+ padded = q
+ ignore_padded.append(torch.stack(padded))
+ ignore_boxes2d_tensor = torch.stack(ignore_padded)
+ num_ignores_tensor = torch.tensor(
+ num_ignores_list, dtype=torch.long, device=device
+ )
+ else:
+ ignore_boxes2d_tensor = None
+ num_ignores_tensor = None
+
+ # Query types: 0=TEXT, 1=VISUAL, 2=GEOMETRY
+ query_types = torch.tensor(query_types_list, dtype=torch.long, device=device)
+ profile_stop(" collator_tensor_stack")
+ profile_stop(" collator_total")
+
+ return WildDet3DInput(
+ images=images,
+ intrinsics=intrinsics,
+ img_ids=img_ids,
+ text_ids=text_ids,
+ unique_texts=unique_texts,
+ geo_boxes=geo_boxes,
+ geo_boxes_mask=geo_boxes_mask,
+ geo_box_labels=geo_box_labels,
+ geo_points=geo_points,
+ geo_points_mask=geo_points_mask,
+ geo_point_labels=geo_point_labels,
+ gt_boxes2d=gt_boxes2d,
+ gt_boxes3d=gt_boxes3d,
+ num_gts=num_gts,
+ gt_category_ids=gt_category_ids,
+ ignore_boxes2d=ignore_boxes2d_tensor,
+ num_ignores=num_ignores_tensor,
+ query_types=query_types,
+ # Metadata for evaluation/visualization
+ sample_names=sample_names,
+ dataset_name=dataset_name,
+ original_hw=original_hw,
+ original_images=original_images,
+ original_intrinsics=original_intrinsics,
+ padding=padding,
+ # Depth ground truth for geometry backend supervision
+ depth_gt=depth_gt,
+ depth_mask=None, # Not yet implemented
+ )
+
+
+# ============================================================================
+# WildDet3D Specific Connectors
+# ============================================================================
+
+# Training connector for WildDet3D
+# Note: SAM3 uses geometric prompts (boxes/points) instead of text
+CONN_WILDDET3D_TRAIN = {
+ "images": K.images,
+ "input_hw": K.input_hw,
+ # Geometric prompts (boxes as prompts)
+ "prompt_boxes": K.boxes2d, # Use GT boxes as prompts during training
+ "prompt_box_labels": K.boxes2d_classes,
+ # Targets
+ "boxes2d": K.boxes2d,
+ "boxes2d_classes": K.boxes2d_classes,
+ "boxes3d": K.boxes3d,
+ # Camera
+ "intrinsics": K.intrinsics,
+ # Depth for geometry backend
+ "depth_gt": K.depth_maps,
+}
+
+# Test connector for WildDet3D
+CONN_WILDDET3D_TEST = {
+ "images": K.images,
+ "input_hw": K.input_hw,
+ "original_hw": K.original_hw,
+ # Geometric prompts (from external detector or user input)
+ "prompt_boxes": K.boxes2d, # External 2D detections as prompts
+ # Camera
+ "intrinsics": K.intrinsics,
+ "padding": "padding",
+}
+
+# Loss connector for WildDet3D
+CONN_WILDDET3D_LOSS = {
+ # Model outputs
+ "pred_logits": pred_key("pred_logits"),
+ "pred_boxes_2d": pred_key("pred_boxes_2d"),
+ "pred_boxes_3d": pred_key("pred_boxes_3d"),
+ "aux_outputs": pred_key("aux_outputs"),
+ "geom_losses": pred_key("geom_losses"),
+ # Matching indices (computed by model)
+ "indices": pred_key("indices"),
+ # Targets
+ "targets": {
+ "boxes": data_key(K.boxes2d),
+ "boxes_xyxy": data_key(K.boxes2d), # Will be converted
+ "boxes_3d": data_key(K.boxes3d),
+ "num_boxes": data_key("num_boxes"),
+ "image_size": data_key(K.input_hw), # (H, W) for pixel coordinate conversion
+ },
+ # Camera
+ "intrinsics": data_key(K.intrinsics),
+ # Image size for pixel coordinate conversion (following GDino3D)
+ "image_size": data_key(K.input_hw),
+}
+
+# Evaluation connector
+CONN_WILDDET3D_EVAL = {
+ "coco_image_id": data_key(K.sample_names),
+ "pred_boxes": pred_key("boxes"),
+ "pred_scores": pred_key("scores"),
+ "pred_classes": pred_key("class_ids"),
+ "pred_boxes3d": pred_key("boxes3d"),
+}
+
+# Visualization connector
+CONN_WILDDET3D_VIS = {
+ "images": data_key(K.original_images),
+ "image_names": data_key(K.sample_names),
+ "intrinsics": data_key("original_intrinsics"),
+ "boxes3d": pred_key("boxes3d"),
+ "class_ids": pred_key("class_ids"),
+ "scores": pred_key("scores"),
+}
+
+
+class WildDet3DPassthroughConnector:
+ """Data connector that passes WildDet3DInput directly to model.
+
+ Since WildDet3DCollator already produces WildDet3DInput with all needed
+ data, we just pass it through as the 'batch' parameter to model.forward().
+
+ This bypasses the key_mapping approach used by vis4d's DataConnector,
+ which expects raw DataLoader output format.
+ """
+
+ def __call__(self, data: WildDet3DInput) -> dict:
+ """Pass batch directly to model.
+
+ Args:
+ data: WildDet3DInput from collator
+
+ Returns:
+ Dict with 'batch' key pointing to the input data
+ """
+ return {"batch": data}
+
+
+class WildDet3DLossConnector:
+ """Loss connector that passes model output and batch directly to loss.
+
+ Similar to WildDet3DPassthroughConnector, this bypasses vis4d's key_mapping
+ since WildDet3DLoss expects structured objects (WildDet3DOut, WildDet3DInput).
+
+ This connector is used with LossModule to enable proper wandb logging of
+ individual loss components (loss_cls, loss_bbox, loss_giou, etc.).
+ """
+
+ def __call__(self, predictions, batch: WildDet3DInput) -> dict:
+ """Map model output and batch to loss function inputs.
+
+ Args:
+ predictions: WildDet3DOut from model.forward()
+ batch: WildDet3DInput from collator
+
+ Returns:
+ Dict with 'out' and 'batch' keys for WildDet3DLoss.forward()
+ """
+ return {
+ "out": predictions,
+ "batch": batch,
+ }
+
+
+class WildDet3DVisConnector:
+ """Vis connector that extracts from WildDet3DInput for visualization.
+
+ vis4d's CallbackConnector uses dict access (data[key]) which doesn't
+ work with WildDet3DInput dataclass. This connector does the
+ extraction manually.
+
+ Args:
+ score_threshold: Only visualize boxes with score >= this value.
+ Separate from model's score_threshold so evaluation AP is unaffected.
+ """
+
+ def __init__(self, score_threshold: float = 0.0):
+ self.score_threshold = score_threshold
+
+ def __call__(self, prediction, data: WildDet3DInput) -> dict:
+ """Extract visualization data from dataclass + prediction.
+
+ Args:
+ prediction: Det3DOut NamedTuple from model.
+ data: WildDet3DInput from collator.
+
+ Returns:
+ Dict with keys expected by BoundingBox3DVisualizer.
+ """
+ # When the collator filters out images with no GT boxes (empty batch),
+ # original_images is None. Return empty tensor so the visualizer's
+ # for-loop iterates 0 times instead of crashing.
+ images = data.original_images
+ if images is None:
+ images = torch.zeros(0, 3, 1, 1)
+
+ boxes3d = prediction.boxes3d
+ class_ids = prediction.class_ids
+ scores = prediction.scores
+
+ # Filter by score threshold per image for cleaner visualization
+ if self.score_threshold > 0.0 and scores is not None:
+ filtered_boxes3d = []
+ filtered_class_ids = []
+ filtered_scores = []
+ for i in range(len(scores)):
+ mask = scores[i] >= self.score_threshold
+ filtered_scores.append(scores[i][mask])
+ filtered_class_ids.append(class_ids[i][mask])
+ filtered_boxes3d.append(boxes3d[i][mask])
+ boxes3d = filtered_boxes3d
+ class_ids = filtered_class_ids
+ scores = filtered_scores
+
+ # Cast to float32 for numpy compatibility (bf16 not supported)
+ if scores is not None:
+ scores = [s.float() for s in scores]
+ if boxes3d is not None:
+ boxes3d = [b.float() for b in boxes3d]
+
+ intrinsics = data.original_intrinsics
+ if intrinsics is not None:
+ intrinsics = intrinsics.float()
+
+ return {
+ "images": images,
+ "image_names": data.sample_names,
+ "intrinsics": intrinsics,
+ "boxes3d": boxes3d,
+ "class_ids": class_ids,
+ "scores": scores,
+ }
+
+
+class WildDet3DEvalConnector:
+ """Eval connector that extracts from WildDet3DInput for evaluator.
+
+ Same issue as WildDet3DVisConnector: CallbackConnector doesn't work with
+ dataclass. This connector manually extracts fields.
+ """
+
+ def __call__(self, prediction, data: WildDet3DInput) -> dict:
+ """Extract evaluation data from dataclass + prediction.
+
+ Args:
+ prediction: Det3DOut NamedTuple from model.
+ data: WildDet3DInput from collator.
+
+ Returns:
+ Dict with keys expected by Omni3DEvaluator.
+ """
+ return {
+ "coco_image_id": data.sample_names,
+ "dataset_names": data.dataset_name,
+ "pred_boxes": prediction.boxes,
+ "pred_scores": prediction.scores,
+ "pred_classes": prediction.class_ids,
+ "pred_boxes3d": prediction.boxes3d,
+ }
+
+
+class WildDet3DDetect3DEvalConnector:
+ """Eval connector for Detect3DEvaluator with WildDet3DInput.
+
+ Unlike WildDet3DEvalConnector, this connector does not include dataset_names
+ since Detect3DEvaluator.process_batch does not accept that argument.
+ """
+
+ def __call__(self, prediction, data: WildDet3DInput) -> dict:
+ """Extract evaluation data from dataclass + prediction.
+
+ Args:
+ prediction: Det3DOut NamedTuple from model.
+ data: WildDet3DInput from collator.
+
+ Returns:
+ Dict with keys expected by Detect3DEvaluator.process_batch.
+ """
+ return {
+ "coco_image_id": data.sample_names,
+ "pred_boxes": prediction.boxes,
+ "pred_scores": prediction.scores,
+ "pred_classes": prediction.class_ids,
+ "pred_boxes3d": prediction.boxes3d,
+ }
+
+
+def get_wilddet3d_data_connector_cfg() -> tuple[ConfigDict, ConfigDict]:
+ """Get WildDet3D data connector configuration.
+
+ Returns:
+ Tuple of (train_connector, test_connector).
+
+ Note:
+ Uses WildDet3DPassthroughConnector which passes the collated batch
+ directly to model.forward(batch=...), rather than mapping individual
+ keys like standard vis4d DataConnector.
+ """
+ train_data_connector = class_config(WildDet3DPassthroughConnector)
+ test_data_connector = class_config(WildDet3DPassthroughConnector)
+
+ return train_data_connector, test_data_connector
+
+
+def get_wilddet3d_collator_cfg(
+ max_prompts_per_image: int = 50,
+ use_text_prompts: bool = True,
+ # Point prompt options (for ablation)
+ use_point_prompts: bool = False,
+ num_positive_points: int | tuple[int, int] = 1,
+ num_negative_points: int | tuple[int, int] = 0,
+ point_sample_mode: Literal["centered", "random_mask", "random_box"] = "random_mask",
+ # Box prompt options
+ use_box_prompts: bool = True,
+ box_noise_std: float = 0.0,
+ box_noise_max: float | None = 20.0,
+ # Text/Visual query ratio (SAM3 original design)
+ text_query_prob: float = 0.7,
+ keep_text_for_visual: bool = False,
+) -> ConfigDict:
+ """Get WildDet3D collator configuration.
+
+ The collator converts per-image DataLoader output to WildDet3DInput.
+ Following SAM3 original design: per-category queries with multi-instance targets.
+
+ Args:
+ max_prompts_per_image: Max prompts (categories) per image
+ use_text_prompts: Whether to include text with geometric prompts
+ use_point_prompts: Whether to sample point prompts (for ablation)
+ num_positive_points: Number of positive points to sample
+ Can be int or (min, max) tuple for random range
+ num_negative_points: Number of negative points to sample
+ Can be int or (min, max) tuple for random range
+ point_sample_mode: How to sample points when mask is available
+ - "centered": sample from mask center (farthest from edges)
+ - "random_mask": uniform sample from mask interior
+ - "random_box": uniform sample from box, label from mask
+ use_box_prompts: Whether to use box prompts
+ box_noise_std: Noise std for box jittering (0 = no noise)
+ box_noise_max: Max noise in pixels
+ text_query_prob: Probability of text-only queries (SAM3 recommended: 0.7)
+ 1.0 = all text queries (pure text training)
+ 0.7 = 70% text, 30% visual (SAM3 mixed training)
+ 0.0 = all visual queries (DetAny3D style)
+ keep_text_for_visual: If True, visual queries keep category text
+ If False (default), visual queries use "visual" as text
+
+ Returns:
+ Collator configuration
+ """
+ return class_config(
+ WildDet3DCollator,
+ max_prompts_per_image=max_prompts_per_image,
+ use_text_prompts=use_text_prompts,
+ use_point_prompts=use_point_prompts,
+ num_positive_points=num_positive_points,
+ num_negative_points=num_negative_points,
+ point_sample_mode=point_sample_mode,
+ use_box_prompts=use_box_prompts,
+ box_noise_std=box_noise_std,
+ box_noise_max=box_noise_max,
+ text_query_prob=text_query_prob,
+ keep_text_for_visual=keep_text_for_visual,
+ )
diff --git a/wilddet3d/data/__init__.py b/wilddet3d/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..68ffbd63f41daecbc90a48541f408e03349431e3
--- /dev/null
+++ b/wilddet3d/data/__init__.py
@@ -0,0 +1 @@
+"""Data utilities."""
diff --git a/wilddet3d/data/__pycache__/__init__.cpython-311.pyc b/wilddet3d/data/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e86e856a2a2bce6ef5568d3965d30f143d540678
Binary files /dev/null and b/wilddet3d/data/__pycache__/__init__.cpython-311.pyc differ
diff --git a/wilddet3d/data/datasets/__init__.py b/wilddet3d/data/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/wilddet3d/data/datasets/argoverse.py b/wilddet3d/data/datasets/argoverse.py
new file mode 100644
index 0000000000000000000000000000000000000000..5032fc27faed0de4d2b531e6c9826e106c99400b
--- /dev/null
+++ b/wilddet3d/data/datasets/argoverse.py
@@ -0,0 +1,94 @@
+"""Argoverse V2 Sensor dataset."""
+
+from __future__ import annotations
+
+from vis4d.common.typing import ArgsType, DictStrAny
+
+from .coco3d import COCO3DDataset
+
+TRAIN_SAMPLE_RATE = 10
+VAL_SAMPLE_RATE = 5
+ACC_FRAMES = 5
+
+
+av2_class_map = {
+ "regular vehicle": 0,
+ "pedestrian": 1,
+ "bicyclist": 2,
+ "motorcyclist": 3,
+ "wheeled rider": 4,
+ "bollard": 5,
+ "construction cone": 6,
+ "sign": 7,
+ "construction barrel": 8,
+ "stop sign": 9,
+ "mobile pedestrian crossing sign": 10,
+ "large vehicle": 11,
+ "bus": 12,
+ "box truck": 13,
+ "truck": 14,
+ "vehicular trailer": 15,
+ "truck cab": 16,
+ "school bus": 17,
+ "articulated bus": 18,
+ "message board trailer": 19,
+ "bicycle": 20,
+ "motorcycle": 21,
+ "wheeled device": 22,
+ "wheelchair": 23,
+ "stroller": 24,
+ "dog": 25,
+}
+
+av2_det_map = {
+ "regular vehicle": 0,
+ "pedestrian": 1,
+ "bicyclist": 2,
+ "motorcyclist": 3,
+ "wheeled rider": 4,
+ "bollard": 5,
+ "construction cone": 6,
+ "sign": 7,
+ "construction barrel": 8,
+ "stop sign": 9,
+ "mobile pedestrian crossing sign": 10,
+ "large vehicle": 11,
+ "bus": 12,
+ "box truck": 13,
+ "truck": 14,
+ "vehicular trailer": 15,
+ "truck cab": 16,
+ "school bus": 17,
+ "articulated bus": 18,
+ "bicycle": 19,
+ "motorcycle": 20,
+ "wheeled device": 21,
+ "stroller": 22,
+}
+
+
+class AV2SensorDataset(COCO3DDataset):
+ """Argoverse V2 Sensor dataset."""
+
+ def __init__(
+ self,
+ class_map: dict[str, int] = av2_class_map,
+ max_depth: float = 80.0,
+ depth_scale: float = 256.0,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__(
+ class_map=class_map,
+ max_depth=max_depth,
+ depth_scale=depth_scale,
+ **kwargs,
+ )
+
+ def get_depth_filenames(self, img: DictStrAny) -> str | None:
+ """Get the depth filenames."""
+ return (
+ img["file_path"]
+ .replace("images", "depth")
+ .replace(".jpg", "_depth.png")
+ )
diff --git a/wilddet3d/data/datasets/coco3d.py b/wilddet3d/data/datasets/coco3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b0613fca6abf0443fec6b38fd9e2c43dcc56cb9
--- /dev/null
+++ b/wilddet3d/data/datasets/coco3d.py
@@ -0,0 +1,574 @@
+"""COCO 3D API."""
+
+from __future__ import annotations
+
+import contextlib
+import io
+import json
+import os
+import time
+from collections import defaultdict
+from collections.abc import Sequence
+
+import numpy as np
+from pycocotools.coco import COCO
+from pyquaternion import Quaternion
+from scipy.spatial.transform import Rotation as R
+from vis4d.common.logging import rank_zero_info, rank_zero_warn
+from vis4d.common.typing import ArgsType, DictStrAny
+from vis4d.data.const import AxisMode
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.datasets.base import Dataset
+from vis4d.data.datasets.util import (
+ CacheMappingMixin,
+ im_decode,
+ print_class_histogram,
+)
+from vis4d.data.typing import DictData
+
+
+class COCO3DDataset(CacheMappingMixin, Dataset):
+ """3D Object Detection Dataset using coco annotation files."""
+
+ def __init__(
+ self,
+ data_root: str,
+ dataset_name: str,
+ class_map: dict[str, int],
+ det_map: dict[str, int],
+ keys_to_load: Sequence[str] = (K.images, K.boxes2d, K.boxes3d),
+ with_depth: bool = False,
+ max_depth: float = 80.0,
+ depth_scale: float = 256.0,
+ remove_empty: bool = False,
+ data_prefix: str | None = None,
+ text_prompt_mapping: dict[str, dict[str, str]] | None = None,
+ cache_as_binary: bool = False,
+ cached_file_path: str | None = None,
+ # Omni3DAPI filtering thresholds (passed to COCO3D)
+ truncation_thres: float = 0.33333333,
+ visibility_thres: float = 0.33333333,
+ min_height_thres: float = 0.0625,
+ max_height_thres: float = 1.50,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__(**kwargs)
+ self.data_root = data_root
+ self.dataset_name = dataset_name
+ self.annotation_file = f"{dataset_name}.json"
+
+ self.keys_to_load = list(keys_to_load)
+ self.remove_empty = remove_empty
+
+ self.class_map = class_map # Class mapping in the annotation file
+ self.det_map = det_map # Class mapping for detection
+ self.categories = sorted(self.det_map, key=self.det_map.get)
+
+ self.data_prefix = data_prefix
+ self.text_prompt_mapping = text_prompt_mapping
+
+ # Omni3DAPI filtering thresholds
+ self.truncation_thres = truncation_thres
+ self.visibility_thres = visibility_thres
+ self.min_height_thres = min_height_thres
+ self.max_height_thres = max_height_thres
+
+ # Metric Depth
+ if with_depth and not K.depth_maps in keys_to_load:
+ self.keys_to_load.append(K.depth_maps)
+
+ self.max_depth = max_depth
+ self.depth_scale = depth_scale
+
+ # Load annotations
+ self.samples, _ = self._load_mapping(
+ self._generate_data_mapping,
+ self._filter_data,
+ cache_as_binary=cache_as_binary,
+ cached_file_path=cached_file_path,
+ )
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset."""
+ return self.dataset_name
+
+ def _filter_data(self, data: list[DictStrAny]) -> list[DictStrAny]:
+ """Remove empty samples."""
+ samples = []
+
+ frequencies = {cat: 0 for cat in sorted(self.det_map)}
+
+ empty_samples = 0
+ no_depth_samples = 0
+ for sample in data:
+ if self.remove_empty and len(sample["anns"]) == 0:
+ empty_samples += 1
+ continue
+
+ if (
+ K.depth_maps in self.keys_to_load
+ and "depth_filename" not in sample
+ ):
+ empty_samples += 1
+ no_depth_samples += 1
+ continue
+
+ for ann in sample["anns"]:
+ frequencies[ann["category_name"]] += 1
+
+ samples.append(sample)
+
+ rank_zero_info(
+ f"Propocessing {self.dataset_name} with {len(samples)} samples."
+ )
+ rank_zero_info(f"No depth samples: {no_depth_samples}")
+ rank_zero_info(f"Filtered {empty_samples} empty samples")
+ print_class_histogram(frequencies)
+
+ return samples
+
+ def _get_cat_id(
+ self, img: DictStrAny, ann: DictStrAny, cat_name: str
+ ) -> None:
+ """Get the category id from the category name."""
+ ann["category_id"] = self.det_map[cat_name]
+
+ def _generate_data_mapping(self) -> list[DictStrAny]:
+ """Generates the data mapping."""
+ # Load annotations
+ with contextlib.redirect_stdout(io.StringIO()):
+ coco_api = COCO3D(
+ os.path.join(
+ self.data_root, "annotations", self.annotation_file
+ ),
+ self.categories,
+ truncation_thres=self.truncation_thres,
+ visibility_thres=self.visibility_thres,
+ min_height_thres=self.min_height_thres,
+ max_height_thres=self.max_height_thres,
+ )
+
+ cats_map = {v: k for k, v in self.class_map.items()}
+
+ img_ids = sorted(coco_api.getImgIds())
+ imgs = coco_api.loadImgs(img_ids)
+
+ samples = []
+ for img_id, img in zip(img_ids, imgs):
+ # Fix file path for Omni3D
+ if self.data_prefix is not None:
+ img["file_path"] = os.path.join(
+ self.data_prefix, img["file_path"]
+ )
+
+ valid_anns = []
+ anns = coco_api.imgToAnns[img_id]
+
+ boxes = []
+ boxes3d = np.empty((0, 10), dtype=np.float32)[1:]
+ class_ids = np.empty((0,), dtype=np.int64)[1:]
+ ignore_boxes = []
+ ignore_class_ids_list = []
+ for ann in anns:
+ cat_name = cats_map[ann["category_id"]]
+ assert cat_name == ann["category_name"]
+
+ if cat_name in {"dontcare", "ignore", "void"}:
+ continue
+
+ if ann["ignore"]:
+ # Preserve ignore box 2D coords and class ID
+ # for negative loss suppression during training.
+ # Only keep objects that are actually visible in
+ # the image — skip behind_camera and degenerate bbox.
+ if (
+ cat_name in self.det_map
+ and not ann.get("behind_camera", False)
+ ):
+ x1, y1, w, h = ann["bbox"]
+ if w > 0 and h > 0:
+ ignore_boxes.append(
+ (x1, y1, x1 + w, y1 + h)
+ )
+ ignore_class_ids_list.append(
+ self.det_map[cat_name]
+ )
+ continue
+
+ # Box 3D
+ center = ann["center_cam"]
+ width, height, length = ann["dimensions"]
+
+ # Check if the rotation matrix is valid
+ try:
+ x, y, z, w = R.from_matrix(
+ np.array(ann["R_cam"])
+ ).as_quat()
+ except Exception as e:
+ rank_zero_warn(
+ f"Error processing rotation matrix for annotation {ann['id']}: {e}"
+ )
+ continue
+
+ orientation = Quaternion([w, x, y, z])
+
+ boxes3d = np.concatenate(
+ [
+ boxes3d,
+ np.array(
+ [
+ [
+ *center,
+ width,
+ length,
+ height,
+ *orientation.elements,
+ ]
+ ],
+ dtype=np.float32,
+ ),
+ ]
+ )
+
+ # Box 2D
+ x1, y1, width, height = ann["bbox"]
+ x2, y2 = x1 + width, y1 + height
+ boxes.append((x1, y1, x2, y2))
+
+ # Class
+ self._get_cat_id(img, ann, cat_name)
+
+ class_ids = np.concatenate(
+ [
+ class_ids,
+ np.array([ann["category_id"]], dtype=np.int64),
+ ]
+ )
+
+ valid_anns.append(ann)
+
+ boxes2d = (
+ np.empty((0, 4), dtype=np.float32)
+ if not boxes
+ else np.array(boxes, dtype=np.float32)
+ )
+
+ depth_filename = self.get_depth_filenames(img)
+
+ ignore_boxes2d = (
+ np.empty((0, 4), dtype=np.float32)
+ if not ignore_boxes
+ else np.array(ignore_boxes, dtype=np.float32)
+ )
+ ignore_class_ids = np.array(
+ ignore_class_ids_list, dtype=np.int64
+ )
+
+ sample = {
+ "img_id": img_id,
+ "img": img,
+ "anns": valid_anns,
+ "boxes2d": boxes2d,
+ "boxes3d": boxes3d,
+ "class_ids": class_ids,
+ "ignore_boxes2d": ignore_boxes2d,
+ "ignore_class_ids": ignore_class_ids,
+ }
+
+ if depth_filename is not None and (
+ self.data_backend.exists(depth_filename)
+ or os.path.exists(depth_filename)
+ ):
+ sample["depth_filename"] = depth_filename
+
+ samples.append(sample)
+
+ return samples
+
+ def get_depth_filenames(self, img: DictStrAny) -> str | None:
+ """Get the depth filenames.
+
+ Since not every data has depth.
+ """
+ return None
+
+ def get_cat_ids(self, idx: int) -> list[int]:
+ """Return the samples."""
+ return self.samples[idx]["class_ids"].tolist()
+
+ def __len__(self) -> int:
+ """Total number of samples of data."""
+ return len(self.samples)
+
+ def get_depth_map(self, sample: DictStrAny) -> np.ndarray:
+ """Get the depth map."""
+ depth_bytes = self.data_backend.get(sample["depth_filename"])
+ depth_array = im_decode(depth_bytes)
+
+ depth = np.ascontiguousarray(depth_array, dtype=np.float32)
+
+ depth = depth / self.depth_scale
+
+ return depth
+
+ def __getitem__(self, idx: int) -> DictData:
+ """Get single sample.
+
+ Args:
+ idx (int): Index of sample.
+
+ Returns:
+ DictData: sample at index in Vis4D input format.
+ """
+ sample = self.samples[idx]
+ data_dict: DictData = {}
+
+ # Get image info
+ data_dict[K.sample_names] = sample["img_id"]
+
+ data_dict["dataset_name"] = self.dataset_name
+ data_dict[K.boxes2d_names] = self.categories
+ data_dict["text_prompt_mapping"] = self.text_prompt_mapping
+
+ if K.images in self.keys_to_load:
+ im_bytes = self.data_backend.get(sample["img"]["file_path"])
+ image = np.ascontiguousarray(
+ im_decode(im_bytes, mode=self.image_channel_mode),
+ dtype=np.float32,
+ )[None]
+
+ data_dict[K.images] = image
+ data_dict[K.input_hw] = (image.shape[1], image.shape[2])
+
+ data_dict[K.original_images] = image
+ data_dict[K.original_hw] = (image.shape[1], image.shape[2])
+
+ # Get camera info
+ intrinsics = np.array(sample["img"]["K"], dtype=np.float32)
+ data_dict[K.intrinsics] = intrinsics
+ data_dict["original_intrinsics"] = intrinsics
+
+ data_dict[K.boxes2d] = sample["boxes2d"]
+ data_dict[K.boxes2d_classes] = sample["class_ids"]
+ data_dict[K.boxes3d] = sample["boxes3d"]
+ data_dict[K.boxes3d_classes] = sample["class_ids"]
+ data_dict[K.axis_mode] = AxisMode.OPENCV
+
+ # Ignore boxes for negative loss suppression (backward compat)
+ data_dict["ignore_boxes2d"] = sample.get(
+ "ignore_boxes2d", np.empty((0, 4), dtype=np.float32)
+ )
+ data_dict["ignore_class_ids"] = sample.get(
+ "ignore_class_ids", np.empty((0,), dtype=np.int64)
+ )
+
+ if K.depth_maps in self.keys_to_load:
+ depth = self.get_depth_map(sample)
+
+ depth[depth > self.max_depth] = 0
+
+ data_dict[K.depth_maps] = depth
+
+ data_dict["tokens_positive"] = None
+
+ self.data_backend.close()
+
+ return data_dict
+
+
+class COCO3D(COCO):
+ """COCO API with 3D annotations."""
+
+ def __init__(
+ self,
+ annotation_files: Sequence[str] | str,
+ category_names: Sequence[str] | None = None,
+ ignore_names: Sequence[str] = ("dontcare", "ignore", "void"),
+ truncation_thres: float = 0.33333333,
+ visibility_thres: float = 0.33333333,
+ min_height_thres: float = 0.0625,
+ max_height_thres: float = 1.50,
+ modal_2D_boxes: bool = False,
+ trunc_2D_boxes: bool = True,
+ max_depth: int = 1e8,
+ ) -> None:
+ """Creates an instance of the class."""
+ self.dataset, self.anns, self.cats, self.imgs = {}, {}, {}, {}
+ self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
+
+ self.truncation_thres = truncation_thres
+ self.visibility_thres = visibility_thres
+ self.min_height_thres = min_height_thres
+ self.max_height_thres = max_height_thres
+ self.max_depth = max_depth
+
+ if isinstance(annotation_files, str):
+ annotation_files = [annotation_files]
+
+ cats_ids_master = []
+ cats_master = []
+
+ for annotation_file in annotation_files:
+ _, tail = os.path.split(annotation_file)
+ name, _ = os.path.splitext(tail)
+
+ print(f"loading {name} annotations into memory...")
+ tic = time.time()
+
+ with open(annotation_file, "r") as f:
+ dataset = json.load(f)
+
+ assert (
+ type(dataset) == dict
+ ), f"annotation file format {type(dataset)} not supported"
+ print(f"Done (t={time.time() - tic:.2f}s)")
+
+ if "info" not in dataset:
+ dataset["info"] = {"description": name}
+
+ if type(dataset["info"]) == list:
+ dataset["info"] = dataset["info"][0]
+
+ dataset["info"]["known_category_ids"] = [
+ cat["id"] for cat in dataset["categories"]
+ ]
+
+ # first dataset
+ if len(self.dataset) == 0:
+ self.dataset = dataset
+ # concatenate datasets
+ else:
+ if type(self.dataset["info"]) == dict:
+ self.dataset["info"] = [self.dataset["info"]]
+
+ self.dataset["info"] += [dataset["info"]]
+ self.dataset["annotations"] += dataset["annotations"]
+ self.dataset["images"] += dataset["images"]
+
+ # sort through categories
+ for cat in dataset["categories"]:
+ if not cat["id"] in cats_ids_master:
+ cats_ids_master.append(cat["id"])
+ cats_master.append(cat)
+
+ # category names are provided to us
+ if category_names is not None:
+ self.dataset["categories"] = [
+ cats_master[i]
+ for i in np.argsort(cats_ids_master)
+ if cats_master[i]["name"] in category_names
+ ]
+ # no categories are provided, so assume use ALL available.
+ else:
+ self.dataset["categories"] = [
+ cats_master[i] for i in np.argsort(cats_ids_master)
+ ]
+
+ category_names = [
+ cat["name"] for cat in self.dataset["categories"]
+ ]
+
+ # determine which categories we may actually use for filtering.
+ trainable_cats = set(ignore_names) | set(category_names)
+
+ valid_anns = []
+ im_height_map = {}
+
+ for im_obj in self.dataset["images"]:
+ im_height_map[im_obj["id"]] = im_obj["height"]
+
+ # Filter out annotations
+ for anno_idx, anno in enumerate(self.dataset["annotations"]):
+
+ im_height = im_height_map[anno["image_id"]]
+
+ # tightly annotated 2D boxes are not always available.
+ if (
+ modal_2D_boxes
+ and "bbox2D_tight" in anno
+ and anno["bbox2D_tight"][0] != -1
+ ):
+ bbox2D = anno["bbox2D_tight"]
+ elif (
+ trunc_2D_boxes
+ and "bbox2D_trunc" in anno
+ and not np.all([val == -1 for val in anno["bbox2D_trunc"]])
+ ):
+ bbox2D = anno["bbox2D_trunc"]
+ elif anno["bbox2D_proj"][0] != -1:
+ bbox2D = anno["bbox2D_proj"]
+ elif anno["bbox2D_tight"][0] != -1:
+ bbox2D = anno["bbox2D_tight"]
+ else:
+ continue
+
+ # convert to xywh
+ bbox2D[2] = bbox2D[2] - bbox2D[0]
+ bbox2D[3] = bbox2D[3] - bbox2D[1]
+
+ ignore = self.is_ignore(anno, bbox2D, ignore_names, im_height)
+
+ width = bbox2D[2]
+ height = bbox2D[3]
+
+ self.dataset["annotations"][anno_idx]["area"] = width * height
+ self.dataset["annotations"][anno_idx]["iscrowd"] = False
+ self.dataset["annotations"][anno_idx]["ignore"] = ignore
+ self.dataset["annotations"][anno_idx]["ignore2D"] = ignore
+ self.dataset["annotations"][anno_idx]["ignore3D"] = ignore
+
+ self.dataset["annotations"][anno_idx]["bbox"] = bbox2D
+ self.dataset["annotations"][anno_idx]["bbox3D"] = anno[
+ "bbox3D_cam"
+ ]
+ self.dataset["annotations"][anno_idx]["depth"] = anno[
+ "center_cam"
+ ][2]
+
+ category_name = anno["category_name"]
+
+ if category_name in trainable_cats:
+ valid_anns.append(self.dataset["annotations"][anno_idx])
+
+ self.dataset["annotations"] = valid_anns
+
+ self.createIndex()
+
+ def is_ignore(
+ self,
+ anno,
+ bbox2D: list[float, float, float, float],
+ ignore_names: Sequence[str] | None,
+ image_height: int,
+ ) -> bool:
+ ignore = anno["behind_camera"]
+ ignore |= not bool(anno["valid3D"])
+
+ if ignore:
+ return ignore
+
+ ignore |= anno["dimensions"][0] <= 0
+ ignore |= anno["dimensions"][1] <= 0
+ ignore |= anno["dimensions"][2] <= 0
+ ignore |= anno["center_cam"][2] > self.max_depth
+ ignore |= anno["lidar_pts"] == 0
+ ignore |= anno["segmentation_pts"] == 0
+ ignore |= anno["depth_error"] > 0.5
+
+ ignore |= bbox2D[3] <= self.min_height_thres * image_height
+ ignore |= bbox2D[3] >= self.max_height_thres * image_height
+
+ ignore |= (
+ anno["truncation"] >= 0
+ and anno["truncation"] >= self.truncation_thres
+ )
+ ignore |= (
+ anno["visibility"] >= 0
+ and anno["visibility"] <= self.visibility_thres
+ )
+
+ if ignore_names is not None:
+ ignore |= anno["category_name"] in ignore_names
+
+ return ignore
diff --git a/wilddet3d/data/datasets/cubifyanything.py b/wilddet3d/data/datasets/cubifyanything.py
new file mode 100644
index 0000000000000000000000000000000000000000..92f066ab56189fabca293cfa15c3de216f936ad3
--- /dev/null
+++ b/wilddet3d/data/datasets/cubifyanything.py
@@ -0,0 +1,90 @@
+"""CubifyAnything (CA-1M) dataset for 3D object detection."""
+
+from __future__ import annotations
+
+import json
+import os
+
+from vis4d.common.typing import ArgsType, DictStrAny
+
+from wilddet3d.data.datasets.coco3d import COCO3DDataset
+
+
+def get_cubifyanything_det_map(
+ dataset_name: str,
+ data_root: str = "data/cubifyanything",
+) -> dict[str, int]:
+ """Build det_map from CA-1M annotation JSON categories.
+
+ CA-1M has ~3000 free-form categories. Since our model is
+ open-vocabulary (text-prompted), we build det_map dynamically
+ from the annotation JSON's categories list.
+
+ Args:
+ dataset_name: e.g. "CubifyAnything_train" or "CubifyAnything_val"
+ data_root: Root directory for CubifyAnything data.
+ """
+ cache_path = os.path.join(
+ data_root, "annotations", f"{dataset_name}_class_map.json"
+ )
+ if os.path.exists(cache_path):
+ with open(cache_path) as f:
+ return json.load(f)
+ json_path = os.path.join(
+ data_root, "annotations", f"{dataset_name}.json"
+ )
+ with open(json_path) as f:
+ data = json.load(f)
+ class_map = {cat["name"]: cat["id"] for cat in data["categories"]}
+ with open(cache_path, "w") as f:
+ json.dump(class_map, f)
+ return class_map
+
+
+def get_cubifyanything_class_map(
+ dataset_name: str,
+ data_root: str = "data/cubifyanything",
+) -> dict[str, int]:
+ """Build class_map from CA-1M annotation JSON categories.
+
+ CA-1M has ~3000 categories (not in omni3d_class_map), so
+ we build class_map dynamically from the annotation JSON.
+ class_map maps category_name -> category_id (same as det_map
+ for CA-1M, since all categories are trainable).
+
+ Args:
+ dataset_name: e.g. "CubifyAnything_train" or "CubifyAnything_val"
+ data_root: Root directory for CubifyAnything data.
+ """
+ return get_cubifyanything_det_map(dataset_name, data_root)
+
+
+class CubifyAnything(COCO3DDataset):
+ """CubifyAnything (CA-1M) Dataset.
+
+ Indoor scenes with uint16 mm-encoded depth maps.
+ """
+
+ def __init__(
+ self,
+ max_depth: float = 20.0,
+ depth_scale: float = 1000.0,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__(
+ max_depth=max_depth,
+ depth_scale=depth_scale,
+ **kwargs,
+ )
+
+ def get_depth_filenames(self, img: DictStrAny) -> str | None:
+ """Get the depth filename for a given image.
+
+ Maps image path to depth path:
+ cubifyanything/data/CubifyAnything/train/42446540/ts.jpg
+ -> cubifyanything/depth_gt/train/42446540/ts.png
+ """
+ return img["file_path"].replace(
+ "data/CubifyAnything", "depth_gt"
+ ).replace(".jpg", ".png")
diff --git a/wilddet3d/data/datasets/foundationpose.py b/wilddet3d/data/datasets/foundationpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..882b8ab1042b14d22e75f36199989d04c95a6b32
--- /dev/null
+++ b/wilddet3d/data/datasets/foundationpose.py
@@ -0,0 +1,78 @@
+"""FoundationPose (GSO) dataset for 3D object detection.
+
+Synthetic dataset from FoundationPose with Google Scanned Objects (GSO).
+438 categories, ~446K images with dense depth maps (uint16, depth_m * 256).
+"""
+
+from __future__ import annotations
+
+import json
+import os
+
+from vis4d.common.typing import ArgsType, DictStrAny
+
+from wilddet3d.data.datasets.coco3d import COCO3DDataset
+
+
+def get_foundationpose_det_map(
+ dataset_name: str,
+ data_root: str = "data/foundationpose",
+) -> dict[str, int]:
+ """Build det_map from FoundationPose annotation JSON categories."""
+ cache_path = os.path.join(
+ data_root, "annotations", f"{dataset_name}_class_map.json"
+ )
+ if os.path.exists(cache_path):
+ with open(cache_path) as f:
+ return json.load(f)
+ json_path = os.path.join(
+ data_root, "annotations", f"{dataset_name}.json"
+ )
+ with open(json_path) as f:
+ data = json.load(f)
+ class_map = {cat["name"]: cat["id"] for cat in data["categories"]}
+ with open(cache_path, "w") as f:
+ json.dump(class_map, f)
+ return class_map
+
+
+def get_foundationpose_class_map(
+ dataset_name: str,
+ data_root: str = "data/foundationpose",
+) -> dict[str, int]:
+ """Build class_map from FoundationPose annotation JSON categories."""
+ return get_foundationpose_det_map(dataset_name, data_root)
+
+
+class FoundationPoseDataset(COCO3DDataset):
+ """FoundationPose (GSO) Dataset.
+
+ Synthetic scenes with dense depth maps (uint16, depth_m * 256).
+ """
+
+ def __init__(
+ self,
+ max_depth: float = 20.0,
+ depth_scale: float = 256.0,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__(
+ max_depth=max_depth,
+ depth_scale=depth_scale,
+ **kwargs,
+ )
+
+ def get_depth_filenames(self, img: DictStrAny) -> str | None:
+ """Get the depth filename for a given image.
+
+ Maps image path to depth path:
+ foundationpose/images_jpg/gso/{name}.jpg
+ -> foundationpose/depth/gso/{name}.png
+ """
+ path = img["file_path"]
+ path = path.replace(
+ "foundationpose/images_jpg/", "foundationpose/depth/"
+ )
+ path = path.replace(".jpg", ".png")
+ return path
diff --git a/wilddet3d/data/datasets/in_the_wild.py b/wilddet3d/data/datasets/in_the_wild.py
new file mode 100644
index 0000000000000000000000000000000000000000..0335f8c2ef99383b1919b48d5dcd9d371d5d9984
--- /dev/null
+++ b/wilddet3d/data/datasets/in_the_wild.py
@@ -0,0 +1,447 @@
+"""In-The-Wild 3D dataset (COCO/LVIS/Objects365 with human-annotated 3D boxes)."""
+
+from __future__ import annotations
+
+import json
+import os
+import time
+from collections import defaultdict
+
+import numpy as np
+import cv2
+from pycocotools import mask as maskUtils
+
+from vis4d.common.logging import rank_zero_info
+from vis4d.common.typing import ArgsType, DictStrAny
+from vis4d.data.const import CommonKeys as K
+
+from .coco3d import COCO3DDataset
+
+_V4_DEPTH_ROOT = (
+ "/weka/oe-training-default/weikaih/3d_boundingbox_detection"
+ "/single_frame_data/experiment/v4_depth_new"
+)
+
+# Depth directories (v4_depth_new, unscaled)
+_DEPTH_DIRS_NEW = {
+ "coco/val": f"{_V4_DEPTH_ROOT}/coco/val/depth",
+ "coco/train": f"{_V4_DEPTH_ROOT}/coco/train/depth",
+ "obj365/val": f"{_V4_DEPTH_ROOT}/obj365/val/depth",
+ "obj365/train": f"{_V4_DEPTH_ROOT}/obj365/train/depth",
+ "v3det/train": f"{_V4_DEPTH_ROOT}/v3det/train/depth",
+}
+
+# Confidence map directories (uint8 PNG, same resolution as depth)
+_CONF_DIRS = {
+ "coco/val": f"{_V4_DEPTH_ROOT}/coco/val/confidence",
+ "coco/train": f"{_V4_DEPTH_ROOT}/coco/train/confidence",
+ "obj365/val": f"{_V4_DEPTH_ROOT}/obj365/val/confidence",
+ "obj365/train": f"{_V4_DEPTH_ROOT}/obj365/train/confidence",
+ "v3det/train": f"{_V4_DEPTH_ROOT}/v3det/train/confidence",
+}
+
+# Depth values in the .npy files are in mm; convert to meters
+_DEPTH_MM_TO_M = 1.0 / 1000.0
+
+
+def _get_source_key_from_file_path(file_path: str) -> str:
+ """Infer v4_depth source key from image file_path.
+
+ Handles both absolute paths (legacy) and HDF5 relative paths:
+ /weka/.../coco/train2017/X.jpg -> "coco/train"
+ images/coco_train/X.jpg -> "coco/train"
+ images/v3det_train/Q.../X.jpg -> "v3det/train"
+ """
+ if "/v3det_train/" in file_path:
+ return "v3det/train"
+ elif "coco/val2017" in file_path or "/coco_val/" in file_path:
+ return "coco/val"
+ elif "coco/train2017" in file_path or "/coco_train/" in file_path:
+ return "coco/train"
+ elif (
+ ("obj365" in file_path and "/train/" in file_path)
+ or "/obj365_train/" in file_path
+ ):
+ return "obj365/train"
+ else:
+ return "obj365/val"
+
+
+def _get_formatted_id_from_file_path(file_path: str) -> str:
+ """Extract zero-padded 12-digit image ID from file path."""
+ basename = file_path.split("/")[-1] # e.g. 000000000724.jpg
+ return (
+ basename.replace(".jpg", "")
+ .replace("obj365_val_", "")
+ .replace("obj365_train_", "")
+ )
+
+
+def load_in_the_wild_class_map(
+ annotation_path: str = "data/in_the_wild/annotations/InTheWild_val.json",
+) -> dict[str, int]:
+ """Load class map from InTheWild annotation file.
+
+ Returns a mapping from category name to category ID (0-indexed alphabetical).
+
+ Args:
+ annotation_path: Path to the InTheWild annotation JSON file.
+
+ Returns:
+ dict mapping category name to annotation category ID.
+ """
+ cache_path = annotation_path.replace(".json", "_class_map.json")
+ if os.path.exists(cache_path):
+ with open(cache_path) as f:
+ return json.load(f)
+ with open(annotation_path) as f:
+ data = json.load(f)
+ class_map = {cat["name"]: cat["id"] for cat in data["categories"]}
+ with open(cache_path, "w") as f:
+ json.dump(class_map, f)
+ return class_map
+
+
+class InTheWild3DDataset(COCO3DDataset):
+ """In-The-Wild 3D dataset with 800+ open-vocabulary categories.
+
+ Human-annotated 3D bounding boxes on COCO val2017, LVIS (COCO train2017),
+ and Objects365 val images.
+
+ Annotations converted from human_annotated_val_full2d.json to Omni3D
+ COCO3D format using scripts/in_the_wild/convert_in_the_wild.py.
+ Camera intrinsics are scaled back to original image resolution (non-SR).
+
+ Depth maps are from v4_depth (SR 1024-long-edge .npy, mm units),
+ resized to original image resolution on load.
+ """
+
+ def __init__(
+ self,
+ class_map: dict[str, int],
+ max_depth: float = 100.0,
+ per_image_categories: bool = False,
+ depth_confidence_threshold: int = 0,
+ mask_annotation_files: dict[str, str] | None = None,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ class_map: Mapping from category name to category ID.
+ max_depth: Maximum depth in meters (clip beyond this).
+ per_image_categories: If True, boxes2d_names only contains
+ the GT categories present in each image. Required for
+ GDino/3D-MOOD eval (avoids BERT truncation with 1246
+ categories). Must be False for WildDet3D (collator indexes
+ boxes2d_names by global cat_id).
+ depth_confidence_threshold: Minimum confidence (uint8, 0-255)
+ for a depth pixel to be considered valid. Pixels below
+ this threshold are set to 0 (invalid). Set to 0 to
+ disable confidence masking. Only applies when confidence
+ map exists for the image.
+ mask_annotation_files: Optional dict mapping source key
+ (e.g. "coco/train", "obj365/val") to the annotation
+ JSON path that contains segmentation masks. When
+ provided, masks are matched to each sample's boxes and
+ returned as "masks2d_rle" in __getitem__.
+ """
+ super().__init__(
+ class_map=class_map,
+ det_map=class_map,
+ max_depth=max_depth,
+ **kwargs,
+ )
+ self.per_image_categories = per_image_categories
+ self.depth_confidence_threshold = depth_confidence_threshold
+
+ # Separate dict for mask RLEs (DatasetFromList serializes
+ # samples, so in-place mutation does not persist).
+ self._mask_rle_index: dict[int, list] = {}
+ if mask_annotation_files:
+ self._build_mask_index(mask_annotation_files)
+
+ def _build_mask_index(
+ self, mask_annotation_files: dict[str, str | list[str]]
+ ) -> None:
+ """Load mask annotations and build per-sample mask index.
+
+ For each mask annotation file, builds an index by image filename,
+ then matches masks to ITW sample boxes by (x1, y1, w, h)
+ coordinate proximity. Supports multiple files per source key
+ (e.g. both LVIS and COCO instances for coco/train).
+
+ Args:
+ mask_annotation_files: {source_key: path_or_list_of_paths}.
+ """
+ # Group samples by (source_key, basename) for matching
+ source_bn_to_indices = defaultdict(list)
+ for i in range(len(self.samples)):
+ sample = self.samples[i]
+ fp = sample["img"]["file_path"]
+ sk = _get_source_key_from_file_path(fp)
+ bn = fp.split("/")[-1]
+ source_bn_to_indices[(sk, bn)].append(i)
+
+ # Normalize to list of paths per source key
+ expanded = {}
+ for source_key, paths in mask_annotation_files.items():
+ if isinstance(paths, str):
+ expanded[source_key] = [paths]
+ else:
+ expanded[source_key] = list(paths)
+
+ for source_key, ann_paths in expanded.items():
+ for ann_path in ann_paths:
+ # Basenames we need from this source
+ needed_bns = {
+ bn
+ for (sk, bn) in source_bn_to_indices
+ if sk == source_key
+ }
+ if not needed_bns:
+ continue
+
+ rank_zero_info(
+ f"[masks] Loading {source_key} from {ann_path} ..."
+ )
+ t0 = time.time()
+ with open(ann_path) as f:
+ data = json.load(f)
+ rank_zero_info(
+ f"[masks] Loaded in {time.time() - t0:.1f}s "
+ f"({len(data.get('images', []))} images, "
+ f"{len(data.get('annotations', []))} annotations)"
+ )
+
+ # filename -> (mask_img_id, height, width)
+ fn_to_info = {}
+ for img in data["images"]:
+ fn = img.get("file_name")
+ if fn is None:
+ # LVIS format: file_name is None, use id
+ fn = f"{img['id']:012d}.jpg"
+ else:
+ fn = fn.split("/")[-1]
+ if fn in needed_bns:
+ fn_to_info[fn] = (
+ img["id"],
+ img["height"],
+ img["width"],
+ )
+
+ # Reverse lookup: mask_img_id -> (height, width)
+ mid_to_hw = {
+ v[0]: (v[1], v[2]) for v in fn_to_info.values()
+ }
+
+ rank_zero_info(
+ f"[masks] Matched {len(fn_to_info)} images "
+ "by filename"
+ )
+
+ # mask_img_id -> [(x1, y1, rle_dict), ...]
+ needed_ids = set(mid_to_hw.keys())
+ mask_by_id = defaultdict(list)
+ for ann in data["annotations"]:
+ mid = ann["image_id"]
+ if mid not in needed_ids:
+ continue
+ seg = ann.get("segmentation")
+ if seg is None:
+ continue
+ bbox = ann["bbox"] # xywh
+ # Convert polygon / uncompressed RLE to compressed
+ # RLE for uniform handling
+ hw = mid_to_hw.get(mid)
+ if hw is None:
+ continue
+ if isinstance(seg, list):
+ # Polygon format
+ rles = maskUtils.frPyObjects(seg, hw[0], hw[1])
+ seg = maskUtils.merge(rles)
+ elif isinstance(seg.get("counts"), list):
+ # Uncompressed RLE (iscrowd) -> compress
+ seg = maskUtils.frPyObjects(
+ seg, hw[0], hw[1]
+ )
+ mask_by_id[mid].append(
+ (bbox[0], bbox[1], bbox[2], bbox[3], seg)
+ )
+
+ del data # free raw JSON
+
+ # Match masks to ITW sample boxes (merge with
+ # existing matches from previous files)
+ n_matched = 0
+ n_total = 0
+ for (sk, bn), indices in source_bn_to_indices.items():
+ if sk != source_key:
+ continue
+ info = fn_to_info.get(bn)
+ if info is None:
+ continue
+ mid = info[0]
+ masks_for_img = mask_by_id.get(mid, [])
+ if not masks_for_img:
+ continue
+ for si in indices:
+ sample = self.samples[si]
+ boxes2d = sample["boxes2d"] # (N, 4) xyxy
+ # Get existing matches (from previous file)
+ existing = self._mask_rle_index.get(si)
+ masks_rle = (
+ list(existing)
+ if existing is not None
+ else [None] * len(boxes2d)
+ )
+ for bi, box in enumerate(boxes2d):
+ if masks_rle[bi] is not None:
+ # Already matched by previous file
+ n_total += 1
+ n_matched += 1
+ continue
+ x1 = float(box[0])
+ y1 = float(box[1])
+ bw = float(box[2]) - x1
+ bh = float(box[3]) - y1
+ matched = None
+ for mx1, my1, mw, mh, rle in masks_for_img:
+ if (
+ abs(mx1 - x1) < 1.0
+ and abs(my1 - y1) < 1.0
+ and abs(mw - bw) < 2.0
+ and abs(mh - bh) < 2.0
+ ):
+ matched = rle
+ break
+ masks_rle[bi] = matched
+ n_total += 1
+ if matched is not None:
+ n_matched += 1
+ self._mask_rle_index[si] = masks_rle
+
+ rank_zero_info(
+ f"[masks] Matched {n_matched}/{n_total} boxes "
+ f"for {source_key}"
+ )
+
+ rank_zero_info(
+ f"[masks] Total: {len(self._mask_rle_index)}"
+ f"/{len(self.samples)} samples have masks"
+ )
+
+ def __getitem__(self, idx: int):
+ """Get single sample, optionally with per-image category filtering."""
+ data_dict = super().__getitem__(idx)
+ if self.per_image_categories:
+ class_ids_in_img = data_dict[K.boxes2d_classes]
+ if len(class_ids_in_img) > 0:
+ unique_global_ids = sorted(set(class_ids_in_img.tolist()))
+ data_dict[K.boxes2d_names] = [
+ self.categories[gid] for gid in unique_global_ids
+ ]
+ else:
+ data_dict[K.boxes2d_names] = []
+
+ # Decode masks and add as (N, H, W) uint8 array for transforms.
+ # masks_rle is aligned with sample["boxes2d"] (pre-filter).
+ # data_dict boxes2d comes from COCO3D which may filter some
+ # boxes (ignore, bad rotation, etc.), but the ordering of
+ # valid boxes is preserved, so masks_rle indices still match.
+ masks_rle = self._mask_rle_index.get(idx)
+ if masks_rle is not None and len(masks_rle) > 0:
+ n_boxes = len(data_dict[K.boxes2d])
+ if n_boxes == 0 or n_boxes != len(masks_rle):
+ pass # Misaligned or empty, skip masks
+ else:
+ sample = self.samples[idx]
+ h = sample["img"]["height"]
+ w = sample["img"]["width"]
+ decoded = []
+ for rle in masks_rle:
+ if rle is not None:
+ decoded.append(maskUtils.decode(rle))
+ else:
+ decoded.append(
+ np.zeros((h, w), dtype=np.uint8)
+ )
+ data_dict["masks2d"] = np.stack(
+ decoded, axis=0
+ )
+
+ return data_dict
+
+ def get_depth_filenames(self, img: DictStrAny) -> str | None:
+ """Return path to the .npy depth file for this image.
+
+ Uses v4_depth_new (unscaled depth maps).
+ """
+ file_path = img["file_path"]
+ source_key = _get_source_key_from_file_path(file_path)
+ if "formatted_id" in img:
+ formatted_id = img["formatted_id"]
+ else:
+ formatted_id = _get_formatted_id_from_file_path(file_path)
+ depth_dir = _DEPTH_DIRS_NEW.get(source_key)
+ if depth_dir is None:
+ return None
+ depth_path = f"{depth_dir}/{formatted_id}_sr_1024_long.npy"
+ return depth_path if os.path.exists(depth_path) else None
+
+ def get_depth_map(self, sample: DictStrAny) -> np.ndarray:
+ """Load .npy depth (mm) and resize to original image resolution.
+
+ If depth_confidence_threshold > 0, loads the MoGe2 confidence
+ map (uint8 PNG, same resolution as depth) and zeros out pixels
+ where confidence < threshold.
+ """
+ depth_npy = np.load(sample["depth_filename"]) # (H_sr, W_sr) float32, mm
+
+ # Apply MoGe2 confidence masking before resize
+ if self.depth_confidence_threshold > 0:
+ img_entry = sample["img"]
+ file_path = img_entry["file_path"]
+ source_key = _get_source_key_from_file_path(file_path)
+ conf_dir = _CONF_DIRS.get(source_key)
+ if conf_dir is not None:
+ if "formatted_id" in img_entry:
+ formatted_id = img_entry["formatted_id"]
+ else:
+ formatted_id = _get_formatted_id_from_file_path(
+ file_path
+ )
+ conf_path = f"{conf_dir}/{formatted_id}.png"
+ if os.path.exists(conf_path):
+ conf = cv2.imread(
+ conf_path, cv2.IMREAD_UNCHANGED
+ ) # uint8, same shape as depth
+ if conf.shape != depth_npy.shape:
+ conf = cv2.resize(
+ conf,
+ (depth_npy.shape[1], depth_npy.shape[0]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ depth_npy[
+ conf < self.depth_confidence_threshold
+ ] = 0.0
+
+ orig_h = sample["img"]["height"]
+ orig_w = sample["img"]["width"]
+
+ # Resize to original image size using nearest-neighbor to avoid
+ # interpolation artifacts at depth discontinuities
+ if depth_npy.shape != (orig_h, orig_w):
+ depth_npy = cv2.resize(
+ depth_npy,
+ (orig_w, orig_h),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ # Convert mm -> meters
+ depth = depth_npy * _DEPTH_MM_TO_M
+
+ # Clip to max_depth
+ depth[depth > self.max_depth] = 0.0
+
+ return depth.astype(np.float32)
diff --git a/wilddet3d/data/datasets/labelany3d_coco.py b/wilddet3d/data/datasets/labelany3d_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b0e34605646838fe4b3f60ab2a9b0b1d9b14226
--- /dev/null
+++ b/wilddet3d/data/datasets/labelany3d_coco.py
@@ -0,0 +1,145 @@
+"""LabelAny3D COCO Dataset.
+
+This dataset contains COCO images with 3D annotations generated by LabelAny3D.
+It uses standard COCO 80 categories with metric 3D bounding boxes.
+
+The dataset provides:
+- 2010 validation images from COCO val2017
+- 5409 3D bounding box annotations
+- 80 COCO object categories
+- Camera intrinsics (K matrix) for each image
+- Full 3D annotations: center_cam, dimensions, R_cam, bbox3D_cam
+"""
+
+from __future__ import annotations
+
+from vis4d.common.typing import ArgsType, DictStrAny
+
+from wilddet3d.data.datasets.coco3d import COCO3DDataset
+
+# =============================================================================
+# COCO 80 categories mapping
+# Keys: category names (str)
+# Values: category IDs from the annotation file (int)
+# Sorted alphabetically for consistency
+# =============================================================================
+labelany3d_coco_class_map = {
+ "airplane": 98,
+ "apple": 136,
+ "backpack": 116,
+ "banana": 135,
+ "baseball bat": 126,
+ "baseball glove": 127,
+ "bear": 113,
+ "bed": 39,
+ "bench": 105,
+ "bicycle": 11,
+ "bird": 106,
+ "boat": 100,
+ "book": 149,
+ "bottle": 15,
+ "bowl": 56,
+ "broccoli": 139,
+ "bus": 12,
+ "cake": 144,
+ "car": 1,
+ "carrot": 140,
+ "cat": 107,
+ "cell phone": 148,
+ "chair": 18,
+ "clock": 87,
+ "couch": 145,
+ "cow": 111,
+ "cup": 19,
+ "dining table": 146,
+ "dog": 108,
+ "donut": 143,
+ "elephant": 112,
+ "fire hydrant": 102,
+ "fork": 132,
+ "frisbee": 121,
+ "giraffe": 115,
+ "hair drier": 152,
+ "handbag": 118,
+ "horse": 109,
+ "hot dog": 141,
+ "keyboard": 77,
+ "kite": 125,
+ "knife": 133,
+ "laptop": 20,
+ "microwave": 54,
+ "motorcycle": 10,
+ "mouse": 81,
+ "orange": 138,
+ "oven": 57,
+ "parking meter": 104,
+ "person": 7,
+ "pizza": 142,
+ "potted plant": 73,
+ "refrigerator": 49,
+ "remote": 95,
+ "sandwich": 137,
+ "scissors": 150,
+ "sheep": 110,
+ "sink": 28,
+ "skateboard": 128,
+ "skis": 122,
+ "snowboard": 123,
+ "spoon": 134,
+ "sports ball": 124,
+ "stop sign": 103,
+ "suitcase": 120,
+ "surfboard": 129,
+ "teddy bear": 151,
+ "tennis racket": 130,
+ "tie": 119,
+ "toaster": 72,
+ "toilet": 32,
+ "toothbrush": 153,
+ "traffic light": 101,
+ "train": 99,
+ "truck": 5,
+ "tv": 147,
+ "umbrella": 117,
+ "vase": 58,
+ "wine glass": 131,
+ "zebra": 114,
+}
+
+# Detection map for evaluation (0-indexed, continuous)
+labelany3d_coco_det_map = {cat: i for i, cat in enumerate(sorted(labelany3d_coco_class_map.keys()))}
+
+
+class LabelAny3DCOCO(COCO3DDataset):
+ """LabelAny3D COCO Dataset with 3D annotations."""
+
+ def __init__(
+ self,
+ class_map: dict[str, int] = labelany3d_coco_class_map,
+ max_depth: float = 80.0,
+ depth_scale: float = 256.0,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ class_map: Mapping from category names to class IDs
+ max_depth: Maximum depth value for clipping
+ depth_scale: Scale factor for depth values
+ **kwargs: Additional arguments passed to COCO3DDataset
+ """
+ super().__init__(
+ class_map=class_map,
+ max_depth=max_depth,
+ depth_scale=depth_scale,
+ **kwargs,
+ )
+
+ def get_depth_filenames(self, img: DictStrAny) -> str | None:
+ """Get the depth filenames.
+
+ LabelAny3D COCO doesn't have pre-computed depth maps.
+ Depth will be estimated on-the-fly during inference.
+ """
+ return None
+
diff --git a/wilddet3d/data/datasets/odvg.py b/wilddet3d/data/datasets/odvg.py
new file mode 100644
index 0000000000000000000000000000000000000000..17f63dba8df954450eb4b4680883836209729b35
--- /dev/null
+++ b/wilddet3d/data/datasets/odvg.py
@@ -0,0 +1,280 @@
+"""Object detection and visual grounding dataset."""
+
+from __future__ import annotations
+
+import json
+import os.path as osp
+
+import numpy as np
+from tqdm import tqdm
+from vis4d.common.logging import rank_zero_info
+from vis4d.common.typing import ArgsType, DictStrAny
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.datasets.base import Dataset
+from vis4d.data.datasets.util import (
+ CacheMappingMixin,
+ im_decode,
+ print_class_histogram,
+)
+from vis4d.data.typing import DictData
+
+
+class ODVGDataset(CacheMappingMixin, Dataset):
+ """Object detection and visual grounding dataset."""
+
+ def __init__(
+ self,
+ data_root: str,
+ ann_file: str,
+ label_map_file: str | None = None,
+ dataset_type: str = "VG",
+ dataset_prefix: str | None = None,
+ remove_empty: bool = False,
+ cache_as_binary: bool = False,
+ cached_file_path: str | None = None,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Create an object detection and visual grounding dataset."""
+ super().__init__(**kwargs)
+
+ self.data_root = data_root
+ self.ann_file = ann_file
+ self.dataset_type = dataset_type
+ self.dataset_prefix = dataset_prefix
+ self.remove_empty = remove_empty
+
+ if label_map_file is not None:
+ label_map_file = osp.join(self.data_root, label_map_file)
+
+ with open(label_map_file, "r") as file:
+ # dict[class_id (str): class_name (str)]
+ self.label_map = json.load(file)
+
+ self.dataset_type = "OD"
+
+ self.det_map = {v: int(k) for k, v in self.label_map.items()}
+ self.categories = sorted(self.det_map, key=self.det_map.get)
+ else:
+ self.label_map = None
+ self.dataset_type = "VG"
+
+ # Load annotations
+ self.samples, _ = self._load_mapping(
+ self._generate_data_mapping,
+ self._filter_data,
+ cache_as_binary=cache_as_binary,
+ cached_file_path=cached_file_path,
+ )
+
+ def __repr__(self) -> str:
+ """Concise representation of the dataset."""
+ return f"ODVGDataset({self.ann_file})"
+
+ def _filter_data(self, data: list[DictStrAny]) -> list[DictStrAny]:
+ """Remove empty samples."""
+ samples = []
+
+ if self.dataset_type == "OD":
+ frequencies = {cat: 0 for _, cat in self.label_map.items()}
+
+ empty_samples = 0
+ for sample in data:
+ if self.remove_empty and len(sample["anns"]) == 0:
+ empty_samples += 1
+ continue
+
+ if self.dataset_type == "OD":
+ for ann in sample["anns"]:
+ frequencies[ann["category"]] += 1
+
+ samples.append(sample)
+
+ rank_zero_info(f"Propocessing {self} with {len(samples)} samples.")
+ rank_zero_info(f"Filtered {empty_samples} empty samples")
+
+ if self.dataset_type == "OD":
+ frequencies = dict(sorted(frequencies.items()))
+
+ print_class_histogram(frequencies)
+
+ return samples
+
+ def _generate_data_mapping(self) -> list[DictStrAny]:
+ """Generates the data mapping."""
+ with open(osp.join(self.data_root, self.ann_file), "r") as f:
+ data_list = [json.loads(line) for line in f]
+
+ if self.with_camera:
+ with open(osp.join(self.data_root, "cam_info.json"), "r") as f:
+ cameras = json.load(f)
+
+ samples = []
+ for data in tqdm(data_list):
+ data_info = {}
+
+ if self.dataset_prefix is not None:
+ img_path = osp.join(
+ self.data_root, self.dataset_prefix, data["filename"]
+ )
+ else:
+ img_path = osp.join(self.data_root, data["filename"])
+
+ data_info["img_path"] = img_path
+
+ # Pseudo K
+ if self.with_camera:
+ data_info["K"] = cameras[img_path][0]
+
+ # Pseudo Depth Path
+ if self.dataset_prefix is not None:
+ depth_path = osp.join(
+ self.data_root,
+ f"{self.dataset_prefix}_depth",
+ data["filename"].replace(".jpg", "_depth.png"),
+ )
+ else:
+ depth_path = osp.join(
+ self.data_root,
+ data["filename"].replace(".jpg", "_depth.png"),
+ )
+ data_info["depth_path"] = depth_path
+
+ data_info["height"] = data["height"]
+ data_info["width"] = data["width"]
+
+ valid_anns = []
+ boxes = []
+ class_ids = np.empty((0,), dtype=np.int64)[1:]
+ if self.dataset_type == "OD":
+ instances = data.get("detection", {}).get("instances", [])
+
+ for ann in instances:
+ bbox = ann["bbox"]
+
+ # Box 2D
+ x1, y1, x2, y2 = bbox
+ inter_w = max(0, min(x2, data["width"]) - max(x1, 0))
+ inter_h = max(0, min(y2, data["height"]) - max(y1, 0))
+
+ if inter_w * inter_h == 0:
+ continue
+ if (x2 - x1) < 1 or (y2 - y1) < 1:
+ continue
+
+ boxes.append(bbox)
+
+ # Class
+ class_ids = np.concatenate(
+ [class_ids, np.array([ann["label"]], dtype=np.int64)]
+ )
+
+ valid_anns.append(ann)
+ else:
+ anno = data["grounding"]
+
+ caption = anno["caption"].lower().strip()
+ if not caption.endswith("."):
+ caption = caption + ". "
+
+ data_info["caption"] = caption
+
+ regions = anno["regions"]
+ phrases = []
+ positive_positions = []
+ for i, region in enumerate(regions):
+ bboxes = region["bbox"]
+
+ if not isinstance(bboxes[0], list):
+ bboxes = [bboxes]
+
+ for bbox in bboxes:
+ x1, y1, x2, y2 = bbox
+ inter_w = max(0, min(x2, data["width"]) - max(x1, 0))
+ inter_h = max(0, min(y2, data["height"]) - max(y1, 0))
+
+ if inter_w * inter_h == 0:
+ continue
+ if (x2 - x1) < 1 or (y2 - y1) < 1:
+ continue
+
+ boxes.append(bbox)
+ phrases.append(region["phrase"])
+ positive_positions.append(region["tokens_positive"])
+ valid_anns.append(region)
+
+ class_ids = np.concatenate(
+ [class_ids, np.array([i], dtype=np.int64)]
+ )
+
+ data_info["phrases"] = phrases
+ data_info["positive_positions"] = positive_positions
+
+ boxes2d = (
+ np.empty((0, 4), dtype=np.float32)
+ if not boxes
+ else np.array(boxes, dtype=np.float32)
+ )
+
+ data_info["boxes2d"] = boxes2d
+ data_info["class_ids"] = class_ids
+ data_info["anns"] = valid_anns
+
+ samples.append(data_info)
+
+ del data_list
+ return samples
+
+ def get_cat_ids(self, idx: int) -> list[int]:
+ """Return the samples."""
+ return self.samples[idx]["class_ids"].tolist()
+
+ def __len__(self) -> int:
+ """Total number of samples of data."""
+ return len(self.samples)
+
+ def __getitem__(self, idx: int) -> DictData:
+ """Get single sample.
+
+ Args:
+ idx (int): Index of sample.
+
+ Returns:
+ DictData: sample at index in Vis4D input format.
+ """
+ sample = self.samples[idx]
+ data_dict: DictData = {}
+
+ # Get image info
+ sample_name = sample["img_path"].split("/")[-1]
+ data_dict[K.sample_names] = sample_name
+
+ im_bytes = self.data_backend.get(sample["img_path"])
+ image = np.ascontiguousarray(
+ im_decode(im_bytes, mode=self.image_channel_mode),
+ dtype=np.float32,
+ )[None]
+
+ data_dict[K.images] = image
+ data_dict[K.input_hw] = (image.shape[1], image.shape[2])
+
+ data_dict[K.original_images] = image
+ data_dict[K.original_hw] = (image.shape[1], image.shape[2])
+
+ data_dict[K.boxes2d] = sample["boxes2d"]
+ data_dict[K.boxes2d_classes] = sample["class_ids"]
+
+ if self.dataset_type == "OD":
+ data_dict[K.boxes2d_names] = self.categories
+ data_dict["phrases"] = None
+ data_dict["positive_positions"] = None
+ else:
+ data_dict[K.boxes2d_names] = sample["caption"]
+ data_dict["phrases"] = sample["phrases"]
+ data_dict["positive_positions"] = sample["positive_positions"]
+
+ data_dict["dataset_type"] = self.dataset_type
+ data_dict["label_map"] = self.label_map
+
+ self.data_backend.close()
+
+ return data_dict
diff --git a/wilddet3d/data/datasets/omni3d/__init__.py b/wilddet3d/data/datasets/omni3d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f173db6b4ea0691fb71340d1a8b870c841873c00
--- /dev/null
+++ b/wilddet3d/data/datasets/omni3d/__init__.py
@@ -0,0 +1 @@
+"""Omni3D Dataset."""
diff --git a/wilddet3d/data/datasets/omni3d/arkitscenes.py b/wilddet3d/data/datasets/omni3d/arkitscenes.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbbf4591454016c81de77b048b5d263529604bf7
--- /dev/null
+++ b/wilddet3d/data/datasets/omni3d/arkitscenes.py
@@ -0,0 +1,81 @@
+"""ARKitScenes from Omni3D."""
+
+from __future__ import annotations
+
+import os
+
+from vis4d.common.typing import ArgsType, DictStrAny
+
+from wilddet3d.data.datasets.coco3d import COCO3DDataset
+
+from .omni3d_classes import omni3d_class_map
+
+arkitscenes_det_map = {
+ "bathtub": 0,
+ "bed": 1,
+ "cabinet": 2,
+ "chair": 3,
+ "fireplace": 4,
+ "machine": 5,
+ "oven": 6,
+ "refrigerator": 7,
+ "shelves": 8,
+ "sink": 9,
+ "sofa": 10,
+ "stove": 11,
+ "table": 12,
+ "television": 13,
+ "toilet": 14,
+}
+
+omni3d_arkitscenes_det_map = {
+ "table": 0,
+ "bed": 1,
+ "sofa": 2,
+ "television": 3,
+ "refrigerator": 4,
+ "chair": 5,
+ "oven": 6,
+ "machine": 7,
+ "stove": 8,
+ "shelves": 9,
+ "sink": 10,
+ "cabinet": 11,
+ "bathtub": 12,
+ "toilet": 13,
+}
+
+
+class ARKitScenes(COCO3DDataset):
+ """ARKitScenes Dataset."""
+
+ def __init__(
+ self,
+ class_map: dict[str, int] = omni3d_class_map,
+ max_depth: float = 10.0,
+ depth_scale: float = 1000.0,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__(
+ class_map=class_map,
+ max_depth=max_depth,
+ depth_scale=depth_scale,
+ **kwargs,
+ )
+
+ def get_depth_filenames(self, img: DictStrAny) -> str | None:
+ """Get the depth filenames.
+
+ Since not every data has depth.
+ """
+ _, _, split, video_id, image_name = img["file_path"].split("/")
+
+ depth_filename = os.path.join(
+ "data/ARKitScenes_depth",
+ split,
+ video_id,
+ image_name.replace("jpg", "png"),
+ )
+
+ return depth_filename
diff --git a/wilddet3d/data/datasets/omni3d/hypersim.py b/wilddet3d/data/datasets/omni3d/hypersim.py
new file mode 100644
index 0000000000000000000000000000000000000000..8260f8415144d7d37d3f33a41abeb84bf91170e9
--- /dev/null
+++ b/wilddet3d/data/datasets/omni3d/hypersim.py
@@ -0,0 +1,190 @@
+"""Hypersim from Omni3D."""
+
+from __future__ import annotations
+
+import os
+
+from vis4d.common.typing import ArgsType, DictStrAny
+
+from wilddet3d.data.datasets.coco3d import COCO3DDataset
+
+from .omni3d_classes import omni3d_class_map
+
+hypersim_train_det_map = {
+ "bathtub": 0,
+ "bed": 1,
+ "blinds": 2,
+ "bookcase": 3,
+ "books": 4,
+ "box": 5,
+ "cabinet": 6,
+ "chair": 7,
+ "clothes": 8,
+ "counter": 9,
+ "curtain": 10,
+ "desk": 11,
+ "door": 12,
+ "dresser": 13,
+ "floor mat": 14,
+ "lamp": 15,
+ "mirror": 16,
+ "night stand": 17,
+ "person": 18,
+ "picture": 19,
+ "pillow": 20,
+ "refrigerator": 21,
+ "shelves": 22,
+ "sink": 23,
+ "sofa": 24,
+ "stationery": 25,
+ "table": 26,
+ "television": 27,
+ "toilet": 28,
+ "towel": 29,
+ "window": 30,
+}
+
+hypersim_val_det_map = {
+ "bathtub": 0,
+ "bed": 1,
+ "blinds": 2,
+ "bookcase": 3,
+ "books": 4,
+ "box": 5,
+ "cabinet": 6,
+ "chair": 7,
+ "clothes": 8,
+ "counter": 9,
+ "curtain": 10,
+ "desk": 11,
+ "door": 12,
+ "dresser": 13,
+ "floor mat": 14,
+ "lamp": 15,
+ "mirror": 16,
+ "night stand": 17,
+ "picture": 18,
+ "pillow": 19,
+ "refrigerator": 20,
+ "shelves": 21,
+ "sink": 22,
+ "sofa": 23,
+ "stationery": 24,
+ "table": 25,
+ "television": 26,
+ "toilet": 27,
+ "towel": 28,
+ "window": 29,
+}
+
+hypersim_test_det_map = {
+ "bathtub": 0,
+ "bed": 1,
+ "blinds": 2,
+ "board": 3,
+ "bookcase": 4,
+ "books": 5,
+ "box": 6,
+ "cabinet": 7,
+ "chair": 8,
+ "clothes": 9,
+ "counter": 10,
+ "curtain": 11,
+ "desk": 12,
+ "door": 13,
+ "floor mat": 14,
+ "lamp": 15,
+ "mirror": 16,
+ "night stand": 17,
+ "picture": 18,
+ "pillow": 19,
+ "refrigerator": 20,
+ "shelves": 21,
+ "sink": 22,
+ "sofa": 23,
+ "stationery": 24,
+ "table": 25,
+ "television": 26,
+ "towel": 27,
+ "window": 28,
+}
+
+
+omni3d_hypersim_det_map = {
+ "books": 0,
+ "chair": 1,
+ "towel": 2,
+ "blinds": 3,
+ "window": 4,
+ "lamp": 5,
+ "shelves": 6,
+ "mirror": 7,
+ "sink": 8,
+ "cabinet": 9,
+ "bathtub": 10,
+ "door": 11,
+ "desk": 12,
+ "box": 13,
+ "bookcase": 14,
+ "picture": 15,
+ "table": 16,
+ "counter": 17,
+ "bed": 18,
+ "night stand": 19,
+ "pillow": 20,
+ "sofa": 21,
+ "television": 22,
+ "floor mat": 23,
+ "curtain": 24,
+ "clothes": 25,
+ "stationery": 26,
+ "refrigerator": 27,
+}
+
+
+def get_hypersim_det_map(split: str) -> dict[str, int]:
+ """Get Hypersim detection map."""
+ assert split in {"train", "val", "test"}, f"Invalid split: {split}"
+
+ if split == "train":
+ return hypersim_train_det_map
+ elif split == "val":
+ return hypersim_val_det_map
+ elif split == "test":
+ return hypersim_test_det_map
+
+
+class Hypersim(COCO3DDataset):
+ """Hypersim Dataset."""
+
+ def __init__(
+ self,
+ class_map: dict[str, int] = omni3d_class_map,
+ max_depth: float = 50.0,
+ depth_scale: float = 1000.0,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__(
+ class_map=class_map,
+ max_depth=max_depth,
+ depth_scale=depth_scale,
+ **kwargs,
+ )
+
+ def get_depth_filenames(self, img: DictStrAny) -> str | None:
+ """Get the depth filenames.
+
+ Since not every data has depth.
+ """
+ _, _, scene, _, img_dir, img_name = img["file_path"].split("/")
+
+ depth_filename = os.path.join(
+ "data/hypersim_depth",
+ scene,
+ "images",
+ img_dir,
+ img_name.replace("jpg", "png"),
+ )
+
+ return depth_filename
diff --git a/wilddet3d/data/datasets/omni3d/kitti_object.py b/wilddet3d/data/datasets/omni3d/kitti_object.py
new file mode 100644
index 0000000000000000000000000000000000000000..13aa64d94ea7d4ff5dd15d8e0ca7bed7070d0fa1
--- /dev/null
+++ b/wilddet3d/data/datasets/omni3d/kitti_object.py
@@ -0,0 +1,105 @@
+"""KITTI Object from Omni3D.
+
+KITTI Object Labels:
+Categories, -, -, alpha, x1, y1, x2, y2, h, w, l, x, botom_y, z, ry
+
+KITTI Object Categories:
+{
+ "Pedestrian": "pedestrian",
+ "Cyclist": "cyclist",
+ "Car": "car",
+ "Van": "car",
+ "Truck": "truck",
+ "Tram": "tram",
+ "Person": "pedestrian",
+ "Person_sitting": "pedestrian",
+ "Misc": "misc",
+ "DontCare": "dontcare",
+}
+"""
+
+from __future__ import annotations
+
+import os
+
+from vis4d.common.typing import ArgsType, DictStrAny
+
+from wilddet3d.data.datasets.coco3d import COCO3DDataset
+
+from .omni3d_classes import omni3d_class_map
+
+kitti_train_det_map = kitti_test_det_map = {
+ "car": 0,
+ "cyclist": 1,
+ "pedestrian": 2,
+ "person": 3,
+ "tram": 4,
+ "truck": 5,
+ "van": 6,
+}
+
+kitti_val_det_map = {
+ "car": 0,
+ "cyclist": 1,
+ "pedestrian": 2,
+ "tram": 3,
+ "truck": 4,
+}
+
+# KITTI-Omni3D Mapping
+omni3d_kitti_det_map = {
+ "pedestrian": 0,
+ "car": 1,
+ "cyclist": 2,
+ "van": 3,
+ "truck": 4,
+}
+
+
+def get_kitti_det_map(split: str) -> dict[str, int]:
+ """Get the KITTI detection map."""
+ assert split in {"train", "val", "test"}, f"Invalid split: {split}"
+
+ if split == "val":
+ return kitti_val_det_map
+
+ # Train and Test are the same
+ return kitti_train_det_map
+
+
+class KITTIObject(COCO3DDataset):
+ """KITTI Object Dataset."""
+
+ def __init__(
+ self,
+ class_map: dict[str, int] = omni3d_class_map,
+ max_depth: float = 80.0,
+ depth_scale: float = 256.0,
+ depth_data_root: str = "data/KITTI_object_depth",
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class."""
+ self.depth_data_root = depth_data_root
+
+ super().__init__(
+ class_map=class_map,
+ max_depth=max_depth,
+ depth_scale=depth_scale,
+ **kwargs,
+ )
+
+ def get_depth_filenames(self, img: DictStrAny) -> str | None:
+ """Get the depth filenames.
+
+ Since not every data has depth.
+ """
+ _, _, split, image_id, img_filename = img["file_path"].split("/")
+
+ depth_filename = os.path.join(
+ self.depth_data_root,
+ split,
+ image_id,
+ img_filename.replace(".jpg", ".png"),
+ )
+
+ return depth_filename
diff --git a/wilddet3d/data/datasets/omni3d/nuscenes.py b/wilddet3d/data/datasets/omni3d/nuscenes.py
new file mode 100644
index 0000000000000000000000000000000000000000..f218bff9205ddaf87801c98c6a68005f66c7d364
--- /dev/null
+++ b/wilddet3d/data/datasets/omni3d/nuscenes.py
@@ -0,0 +1,62 @@
+"""nuScenes from Omni3D."""
+
+from __future__ import annotations
+
+from vis4d.common.typing import ArgsType, DictStrAny
+
+from wilddet3d.data.datasets.coco3d import COCO3DDataset
+
+from .omni3d_classes import omni3d_class_map
+
+nusc_det_map = {
+ "bicycle": 0,
+ "motorcycle": 1,
+ "pedestrian": 2,
+ "bus": 3,
+ "car": 4,
+ "trailer": 5,
+ "truck": 6,
+ "traffic cone": 7,
+ "barrier": 8,
+}
+
+
+class nuScenes(COCO3DDataset):
+ """nuScenes dataset."""
+
+ def __init__(
+ self,
+ class_map: dict[str, int] = omni3d_class_map,
+ max_depth: float = 80.0,
+ depth_scale: float = 256.0,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__(
+ class_map=class_map,
+ max_depth=max_depth,
+ depth_scale=depth_scale,
+ **kwargs,
+ )
+
+ def get_depth_filenames(self, img: DictStrAny) -> str | None:
+ """Get the depth filenames.
+
+ Since not every data has depth.
+ """
+ img["file_path"] = img["file_path"].replace("nuScenes", "nuscenes")
+
+ depth_filename = (
+ img["file_path"]
+ .replace("nuscenes", "nuscenes_depth")
+ .replace("jpg", "png")
+ )
+ return depth_filename
+
+ def get_cat_ids(self, idx: int) -> list[int]:
+ """Return the samples."""
+ return self.samples[idx]["class_ids"].tolist()
+
+ def __len__(self) -> int:
+ """Total number of samples of data."""
+ return len(self.samples)
diff --git a/wilddet3d/data/datasets/omni3d/objectron.py b/wilddet3d/data/datasets/omni3d/objectron.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f9cf007b3753b45c3c13290ba94324502828f7a
--- /dev/null
+++ b/wilddet3d/data/datasets/omni3d/objectron.py
@@ -0,0 +1,56 @@
+"""Objectron from Omni3D."""
+
+from __future__ import annotations
+
+import os
+
+from vis4d.common.typing import ArgsType, DictStrAny
+
+from wilddet3d.data.datasets.coco3d import COCO3DDataset
+
+from .omni3d_classes import omni3d_class_map
+
+objectron_det_map = {
+ "bicycle": 0,
+ "books": 1,
+ "bottle": 2,
+ "camera": 3,
+ "cereal box": 4,
+ "chair": 5,
+ "cup": 6,
+ "laptop": 7,
+ "shoes": 8,
+}
+
+
+class Objectron(COCO3DDataset):
+ """Objectron dataset."""
+
+ def __init__(
+ self,
+ class_map: dict[str, int] = omni3d_class_map,
+ max_depth: float = 12.0,
+ depth_scale: float = 1000.0,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__(
+ class_map=class_map,
+ max_depth=max_depth,
+ depth_scale=depth_scale,
+ **kwargs,
+ )
+
+ def get_depth_filenames(self, img: DictStrAny) -> str | None:
+ """Get the depth filenames.
+
+ Since not every data has depth.
+ """
+ _, _, split, img_name = img["file_path"].split("/")
+
+ depth_filename = os.path.join(
+ "data/objectron_depth",
+ split,
+ img_name.replace(".jpg", "_depth.png"),
+ )
+ return depth_filename
diff --git a/wilddet3d/data/datasets/omni3d/omni3d_classes.py b/wilddet3d/data/datasets/omni3d/omni3d_classes.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9dceb74ea84bd9fd1ebe91388990ae6134fc291
--- /dev/null
+++ b/wilddet3d/data/datasets/omni3d/omni3d_classes.py
@@ -0,0 +1,156 @@
+"""Omni3D classes."""
+
+omni3d_class_map = {
+ "pedestrian": 0,
+ "car": 1,
+ "dontcare": 2,
+ "cyclist": 3,
+ "van": 4,
+ "truck": 5,
+ "tram": 6,
+ "person": 7,
+ "traffic cone": 8,
+ "barrier": 9,
+ "motorcycle": 10,
+ "bicycle": 11,
+ "bus": 12,
+ "trailer": 13,
+ "books": 14,
+ "bottle": 15,
+ "camera": 16,
+ "cereal box": 17,
+ "chair": 18,
+ "cup": 19,
+ "laptop": 20,
+ "shoes": 21,
+ "towel": 22,
+ "blinds": 23,
+ "window": 24,
+ "lamp": 25,
+ "shelves": 26,
+ "mirror": 27,
+ "sink": 28,
+ "cabinet": 29,
+ "bathtub": 30,
+ "door": 31,
+ "toilet": 32,
+ "desk": 33,
+ "box": 34,
+ "bookcase": 35,
+ "picture": 36,
+ "table": 37,
+ "counter": 38,
+ "bed": 39,
+ "night stand": 40,
+ "dresser": 41,
+ "pillow": 42,
+ "sofa": 43,
+ "television": 44,
+ "floor mat": 45,
+ "curtain": 46,
+ "clothes": 47,
+ "stationery": 48,
+ "refrigerator": 49,
+ "board": 50,
+ "kitchen pan": 51,
+ "bin": 52,
+ "stove": 53,
+ "microwave": 54,
+ "plates": 55,
+ "bowl": 56,
+ "oven": 57,
+ "vase": 58,
+ "faucet": 59,
+ "tissues": 60,
+ "machine": 61,
+ "printer": 62,
+ "monitor": 63,
+ "podium": 64,
+ "cart": 65,
+ "projector": 66,
+ "electronics": 67,
+ "computer": 68,
+ "air conditioner": 69,
+ "drawers": 70,
+ "coffee maker": 71,
+ "toaster": 72,
+ "potted plant": 73,
+ "painting": 74,
+ "bag": 75,
+ "tray": 76,
+ "keyboard": 77,
+ "blanket": 78,
+ "rack": 79,
+ "phone": 80,
+ "mouse": 81,
+ "fire extinguisher": 82,
+ "toys": 83,
+ "ladder": 84,
+ "fan": 85,
+ "glass": 86,
+ "clock": 87,
+ "toilet paper": 88,
+ "closet": 89,
+ "fume hood": 90,
+ "utensils": 91,
+ "soundsystem": 92,
+ "fire place": 93,
+ "shower curtain": 94,
+ "remote": 95,
+ "pen": 96,
+ "fireplace": 97,
+}
+
+# Used for Cube R-CNN and Omni3D benchmark
+omni3d_det_map = {
+ "pedestrian": 0,
+ "car": 1,
+ "cyclist": 2,
+ "van": 3,
+ "truck": 4,
+ "traffic cone": 5,
+ "barrier": 6,
+ "motorcycle": 7,
+ "bicycle": 8,
+ "bus": 9,
+ "trailer": 10,
+ "books": 11,
+ "bottle": 12,
+ "camera": 13,
+ "cereal box": 14,
+ "chair": 15,
+ "cup": 16,
+ "laptop": 17,
+ "shoes": 18,
+ "towel": 19,
+ "blinds": 20,
+ "window": 21,
+ "lamp": 22,
+ "shelves": 23,
+ "mirror": 24,
+ "sink": 25,
+ "cabinet": 26,
+ "bathtub": 27,
+ "door": 28,
+ "toilet": 29,
+ "desk": 30,
+ "box": 31,
+ "bookcase": 32,
+ "picture": 33,
+ "table": 34,
+ "counter": 35,
+ "bed": 36,
+ "night stand": 37,
+ "pillow": 38,
+ "sofa": 39,
+ "television": 40,
+ "floor mat": 41,
+ "curtain": 42,
+ "clothes": 43,
+ "stationery": 44,
+ "refrigerator": 45,
+ "bin": 46,
+ "stove": 47,
+ "oven": 48,
+ "machine": 49,
+}
diff --git a/wilddet3d/data/datasets/omni3d/sunrgbd.py b/wilddet3d/data/datasets/omni3d/sunrgbd.py
new file mode 100644
index 0000000000000000000000000000000000000000..274579c850e7555f44197c05554d05fe1e564cee
--- /dev/null
+++ b/wilddet3d/data/datasets/omni3d/sunrgbd.py
@@ -0,0 +1,278 @@
+"""SUN RGB-D from Omni3D."""
+
+from __future__ import annotations
+
+import os
+
+import numpy as np
+from vis4d.common.typing import ArgsType, DictStrAny
+from vis4d.data.datasets.util import im_decode
+
+from wilddet3d.data.datasets.coco3d import COCO3DDataset
+
+from .omni3d_classes import omni3d_class_map
+
+# Train and Test are sharing the classes
+sun_rgbd_train_det_map = sun_rgbd_test_det_map = {
+ "air conditioner": 0,
+ "bag": 1,
+ "bathtub": 2,
+ "bed": 3,
+ "bicycle": 4,
+ "bin": 5,
+ "blanket": 6,
+ "blinds": 7,
+ "board": 8,
+ "bookcase": 9,
+ "books": 10,
+ "bottle": 11,
+ "bowl": 12,
+ "box": 13,
+ "cabinet": 14,
+ "cart": 15,
+ "chair": 16,
+ "clock": 17,
+ "closet": 18,
+ "clothes": 19,
+ "coffee maker": 20,
+ "computer": 21,
+ "counter": 22,
+ "cup": 23,
+ "curtain": 24,
+ "desk": 25,
+ "door": 26,
+ "drawers": 27,
+ "dresser": 28,
+ "electronics": 29,
+ "fan": 30,
+ "faucet": 31,
+ "fire extinguisher": 32,
+ "fire place": 33,
+ "floor mat": 34,
+ "fume hood": 35,
+ "glass": 36,
+ "keyboard": 37,
+ "kitchen pan": 38,
+ "ladder": 39,
+ "lamp": 40,
+ "laptop": 41,
+ "machine": 42,
+ "microwave": 43,
+ "mirror": 44,
+ "monitor": 45,
+ "mouse": 46,
+ "night stand": 47,
+ "oven": 48,
+ "painting": 49,
+ "pen": 50,
+ "person": 51,
+ "phone": 52,
+ "picture": 53,
+ "pillow": 54,
+ "plates": 55,
+ "podium": 56,
+ "potted plant": 57,
+ "printer": 58,
+ "projector": 59,
+ "rack": 60,
+ "refrigerator": 61,
+ "remote": 62,
+ "shelves": 63,
+ "shoes": 64,
+ "shower curtain": 65,
+ "sink": 66,
+ "sofa": 67,
+ "soundsystem": 68,
+ "stationery": 69,
+ "stove": 70,
+ "table": 71,
+ "television": 72,
+ "tissues": 73,
+ "toaster": 74,
+ "toilet": 75,
+ "toilet paper": 76,
+ "towel": 77,
+ "toys": 78,
+ "tray": 79,
+ "utensils": 80,
+ "vase": 81,
+ "window": 82,
+}
+
+sun_rgbd_val_det_map = {
+ "air conditioner": 0,
+ "bag": 1,
+ "bathtub": 2,
+ "bed": 3,
+ "bin": 4,
+ "blanket": 5,
+ "blinds": 6,
+ "board": 7,
+ "bookcase": 8,
+ "books": 9,
+ "bottle": 10,
+ "bowl": 11,
+ "box": 12,
+ "cabinet": 13,
+ "cart": 14,
+ "chair": 15,
+ "closet": 16,
+ "clothes": 17,
+ "coffee maker": 18,
+ "computer": 19,
+ "counter": 20,
+ "cup": 21,
+ "curtain": 22,
+ "desk": 23,
+ "door": 24,
+ "drawers": 25,
+ "dresser": 26,
+ "electronics": 27,
+ "fan": 28,
+ "faucet": 29,
+ "fire extinguisher": 30,
+ "fire place": 31,
+ "fume hood": 32,
+ "keyboard": 33,
+ "kitchen pan": 34,
+ "lamp": 35,
+ "laptop": 36,
+ "machine": 37,
+ "microwave": 38,
+ "mirror": 39,
+ "monitor": 40,
+ "night stand": 41,
+ "oven": 42,
+ "painting": 43,
+ "pen": 44,
+ "person": 45,
+ "phone": 46,
+ "picture": 47,
+ "pillow": 48,
+ "plates": 49,
+ "potted plant": 50,
+ "printer": 51,
+ "projector": 52,
+ "rack": 53,
+ "refrigerator": 54,
+ "shelves": 55,
+ "sink": 56,
+ "sofa": 57,
+ "soundsystem": 58,
+ "stationery": 59,
+ "stove": 60,
+ "table": 61,
+ "television": 62,
+ "tissues": 63,
+ "toaster": 64,
+ "toilet": 65,
+ "towel": 66,
+ "toys": 67,
+ "tray": 68,
+ "utensils": 69,
+ "vase": 70,
+ "window": 71,
+}
+
+omni3d_sun_rgbd_det_map = {
+ "bicycle": 0,
+ "books": 1,
+ "bottle": 2,
+ "chair": 3,
+ "cup": 4,
+ "laptop": 5,
+ "shoes": 6,
+ "towel": 7,
+ "blinds": 8,
+ "window": 9,
+ "lamp": 10,
+ "shelves": 11,
+ "mirror": 12,
+ "sink": 13,
+ "cabinet": 14,
+ "bathtub": 15,
+ "door": 16,
+ "toilet": 17,
+ "desk": 18,
+ "box": 19,
+ "bookcase": 20,
+ "picture": 21,
+ "table": 22,
+ "counter": 23,
+ "bed": 24,
+ "night stand": 25,
+ "pillow": 26,
+ "sofa": 27,
+ "television": 28,
+ "floor mat": 29,
+ "curtain": 30,
+ "clothes": 31,
+ "stationery": 32,
+ "refrigerator": 33,
+ "bin": 34,
+ "stove": 35,
+ "oven": 36,
+ "machine": 37,
+}
+
+
+def get_sunrgbd_det_map(split: str) -> dict[str, int]:
+ """Get the SUN RGB-D detection map."""
+ assert split in {"train", "val", "test"}, f"Invalid split: {split}"
+
+ if split == "train":
+ return sun_rgbd_train_det_map
+ elif split == "val":
+ return sun_rgbd_val_det_map
+ else:
+ return sun_rgbd_test_det_map
+
+
+class SUNRGBD(COCO3DDataset):
+ """SUN RGB-D Dataset."""
+
+ def __init__(
+ self,
+ class_map: dict[str, int] = omni3d_class_map,
+ max_depth: float = 8.0,
+ depth_scale: float = 1000.0,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Initialize SUN RGB-D dataset."""
+ super().__init__(
+ class_map=class_map,
+ max_depth=max_depth,
+ depth_scale=depth_scale,
+ **kwargs,
+ )
+
+ def get_depth_filenames(self, img: DictStrAny) -> str | None:
+ """Get the depth filenames.
+
+ Since not every data has depth.
+ """
+ img["file_path"] = img["file_path"].replace("//", "/")
+
+ data_dir = img["file_path"].split("/image")[0]
+
+ depth_files = self.data_backend.listdir(
+ os.path.join(data_dir, "depth")
+ )
+ assert len(depth_files) == 1
+
+ depth_filename = os.path.join(data_dir, "depth", depth_files[0])
+
+ return depth_filename
+
+ def get_depth_map(self, sample: DictStrAny) -> np.ndarray:
+ """Get the depth map."""
+ depth_bytes = self.data_backend.get(sample["depth_filename"])
+ depth_array = im_decode(depth_bytes)
+
+ depth_array = depth_array >> 3 | depth_array << (16 - 3)
+
+ depth = np.ascontiguousarray(depth_array, dtype=np.float32)
+
+ depth = depth / self.depth_scale
+
+ return depth
diff --git a/wilddet3d/data/datasets/omni3d/util.py b/wilddet3d/data/datasets/omni3d/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebdaa8365c5741c74cff9e22b378fa65a4ff2b40
--- /dev/null
+++ b/wilddet3d/data/datasets/omni3d/util.py
@@ -0,0 +1,86 @@
+"""Omni3D data util."""
+
+from __future__ import annotations
+
+from .arkitscenes import arkitscenes_det_map, omni3d_arkitscenes_det_map
+from .hypersim import get_hypersim_det_map, omni3d_hypersim_det_map
+from .kitti_object import get_kitti_det_map, omni3d_kitti_det_map
+from .nuscenes import nusc_det_map
+from .objectron import objectron_det_map
+from .sunrgbd import get_sunrgbd_det_map, omni3d_sun_rgbd_det_map
+
+DATASET_ID_MAP = {
+ 0: "KITTI_train",
+ 1: "KITTI_val",
+ 2: "KITTI_test",
+ 3: "nuScenes_train",
+ 4: "nuScenes_val",
+ 5: "nuScenes_test",
+ 6: "Objectron_train",
+ 7: "Objectron_val",
+ 8: "Objectron_test",
+ 9: "Hypersim_train",
+ 10: "Hypersim_val",
+ 11: "Hypersim_test",
+ 12: "SUNRGBD_train",
+ 13: "SUNRGBD_val",
+ 14: "SUNRGBD_test",
+ 15: "ARKitScenes_train",
+ 16: "ARKitScenes_val",
+ 17: "ARKitScenes_test",
+}
+
+
+def get_dataset_det_map(
+ dataset_name: str,
+ omni3d50: bool = True,
+) -> tuple[str, dict[str, int]]:
+ """Get the detection map."""
+ if "train" in dataset_name:
+ split = "train"
+ elif "val" in dataset_name:
+ split = "val"
+ elif "test" in dataset_name:
+ split = "test"
+ else:
+ raise ValueError(f"Unknown dataset_name: {dataset_name}")
+
+ if "nuScenes" in dataset_name:
+ det_map = nusc_det_map
+ elif "KITTI" in dataset_name:
+ if omni3d50:
+ det_map = omni3d_kitti_det_map
+ else:
+ det_map = get_kitti_det_map(split)
+ elif "Objectron" in dataset_name:
+ det_map = objectron_det_map
+ elif "SUNRGBD" in dataset_name:
+ if omni3d50:
+ det_map = omni3d_sun_rgbd_det_map
+ else:
+ det_map = get_sunrgbd_det_map(split)
+ elif "Hypersim" in dataset_name:
+ if omni3d50:
+ det_map = omni3d_hypersim_det_map
+ else:
+ det_map = get_hypersim_det_map(split)
+ elif "ARKitScenes" in dataset_name:
+ det_map = (
+ omni3d_arkitscenes_det_map if omni3d50 else arkitscenes_det_map
+ )
+ elif "CubifyAnything" in dataset_name:
+ from wilddet3d.data.datasets.cubifyanything import (
+ get_cubifyanything_det_map,
+ )
+
+ det_map = get_cubifyanything_det_map(dataset_name)
+ elif "Waymo" in dataset_name:
+ from wilddet3d.data.datasets.waymo import (
+ get_waymo_det_map,
+ )
+
+ det_map = get_waymo_det_map(dataset_name)
+ else:
+ raise ValueError(f"Unknown dataset_name: {dataset_name}")
+
+ return det_map
diff --git a/wilddet3d/data/datasets/scannet.py b/wilddet3d/data/datasets/scannet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f31438a6b22fa81876aa58f4652877786fa9246
--- /dev/null
+++ b/wilddet3d/data/datasets/scannet.py
@@ -0,0 +1,449 @@
+"""ScanNet dataset."""
+
+from __future__ import annotations
+
+from vis4d.common.typing import ArgsType, DictStrAny
+
+from .coco3d import COCO3DDataset
+
+scannet_class_map = {
+ "cabinet": 3,
+ "bed": 4,
+ "chair": 5,
+ "sofa": 6,
+ "table": 7,
+ "door": 8,
+ "window": 9,
+ "bookshelf": 10,
+ "picture": 11,
+ "counter": 12,
+ "desk": 14,
+ "curtain": 16,
+ "refrigerator": 24,
+ "shower curtain": 28,
+ "toilet": 33,
+ "sink": 34,
+ "bathtub": 36,
+ "other furniture": 39,
+}
+
+scannet_det_map = {
+ "cabinet": 0,
+ "bed": 1,
+ "chair": 2,
+ "sofa": 3,
+ "table": 4,
+ "door": 5,
+ "window": 6,
+ "bookshelf": 7,
+ "picture": 8,
+ "counter": 9,
+ "desk": 10,
+ "curtain": 11,
+ "refrigerator": 12,
+ "shower curtain": 13,
+ "toilet": 14,
+ "sink": 15,
+ "bathtub": 16,
+ "other furniture": 17,
+}
+
+scannet200_class_map = {
+ "chair": 2,
+ "book": 22,
+ "door": 5,
+ "object": 1163,
+ "window": 16,
+ "table": 4,
+ "trash can": 56,
+ "pillow": 13,
+ "picture": 15,
+ "box": 26,
+ "doorframe": 161,
+ "monitor": 19,
+ "cabinet": 7,
+ "desk": 9,
+ "shelf": 8,
+ "office chair": 10,
+ "towel": 31,
+ "couch": 6,
+ "sink": 14,
+ "backpack": 48,
+ "lamp": 28,
+ "bed": 11,
+ "bookshelf": 18,
+ "mirror": 71,
+ "curtain": 21,
+ "plant": 40,
+ "whiteboard": 52,
+ "radiator": 96,
+ "kitchen cabinet": 29,
+ "toilet paper": 49,
+ "armchair": 23,
+ "shoe": 63,
+ "coffee table": 24,
+ "toilet": 17,
+ "bag": 47,
+ "clothes": 32,
+ "keyboard": 46,
+ "bottle": 65,
+ "recycling bin": 97,
+ "nightstand": 34,
+ "stool": 38,
+ "tv": 33,
+ "file cabinet": 75,
+ "dresser": 36,
+ "computer tower": 64,
+ "telephone": 101,
+ "cup": 130,
+ "refrigerator": 27,
+ "end table": 44,
+ "jacket": 131,
+ "shower curtain": 55,
+ "bathtub": 42,
+ "microwave": 59,
+ "kitchen counter": 159,
+ "sofa chair": 74,
+ "paper towel dispenser": 82,
+ "bathroom vanity": 1164,
+ "suitcase": 93,
+ "laptop": 77,
+ "ottoman": 67,
+ "shower wall": 128,
+ "printer": 50,
+ "counter": 35,
+ "board": 69,
+ "soap dispenser": 100,
+ "stove": 62,
+ "light": 105,
+ "closet wall": 1165,
+ "mini fridge": 165,
+ "fan": 76,
+ "tissue box": 230,
+ "blanket": 54,
+ "bathroom stall": 125,
+ "copier": 72,
+ "bench": 68,
+ "bar": 145,
+ "soap dish": 157,
+ "laundry hamper": 1166,
+ "storage bin": 132,
+ "bathroom stall door": 1167,
+ "light switch": 232,
+ "coffee maker": 134,
+ "tv stand": 51,
+ "decoration": 250,
+ "ceiling light": 1168,
+ "range hood": 342,
+ "blackboard": 89,
+ "clock": 103,
+ "wardrobe": 99,
+ "rail": 95,
+ "bulletin board": 154,
+ "mat": 140,
+ "trash bin": 1169,
+ "ledge": 193,
+ "seat": 116,
+ "mouse": 202,
+ "basket": 73,
+ "shower": 78,
+ "dumbbell": 1170,
+ "paper": 79,
+ "person": 80,
+ "windowsill": 141,
+ "closet": 57,
+ "bucket": 102,
+ "sign": 261,
+ "speaker": 118,
+ "dishwasher": 136,
+ "container": 98,
+ "stair rail": 1171,
+ "shower curtain rod": 170,
+ "tube": 1172,
+ "bathroom cabinet": 1173,
+ "storage container": 221,
+ "paper bag": 570,
+ "paper towel roll": 138,
+ "ball": 168,
+ "closet door": 276,
+ "laundry basket": 106,
+ "cart": 214,
+ "dish rack": 323,
+ "stairs": 58,
+ "blinds": 86,
+ "purse": 399,
+ "bicycle": 121,
+ "tray": 185,
+ "plunger": 300,
+ "paper cutter": 180,
+ "toilet paper dispenser": 163,
+ "bin": 66,
+ "toilet seat cover dispenser": 208,
+ "guitar": 112,
+ "mailbox": 540,
+ "handicap bar": 395,
+ "fire extinguisher": 166,
+ "ladder": 122,
+ "column": 120,
+ "pipe": 107,
+ "vacuum cleaner": 283,
+ "plate": 88,
+ "piano": 90,
+ "water cooler": 177,
+ "cd case": 1174,
+ "bowl": 562,
+ "closet rod": 1175,
+ "bathroom counter": 1156,
+ "oven": 84,
+ "stand": 104,
+ "scale": 229,
+ "washing machine": 70,
+ "broom": 325,
+ "hat": 169,
+ "guitar case": 331,
+ "rack": 87,
+ "water pitcher": 488,
+ "laundry detergent": 776,
+ "hair dryer": 370,
+ "pillar": 191,
+ "divider": 748,
+ "power outlet": 242,
+ "dining table": 45,
+ "shower floor": 417,
+ "shower door": 188,
+ "coffee kettle": 1176,
+ "structure": 1178,
+ "clothes dryer": 110,
+ "toaster": 148,
+ "ironing board": 155,
+ "alarm clock": 572,
+ "shower head": 1179,
+ "water bottle": 392,
+ "keyboard piano": 1180,
+ "projector screen": 609,
+ "case of water bottles": 1181,
+ "toaster oven": 195,
+ "music stand": 581,
+ "coat rack": 1182,
+ "storage organizer": 1183,
+ "machine": 139,
+ "folded chair": 1184,
+ "fire alarm": 1185,
+ "fireplace": 156,
+ "vent": 408,
+ "furniture": 213,
+ "power strip": 1186,
+ "calendar": 1187,
+ "poster": 1188,
+ "toilet paper holder": 115,
+ "potted plant": 1189,
+ "stuffed animal": 304,
+ "luggage": 1190,
+ "headphones": 312,
+ "crate": 233,
+ "candle": 286,
+ "projector": 264,
+ "mattress": 1191,
+ "dustpan": 356,
+ "cushion": 39,
+ "stick": 1163,
+}
+
+scannet200_det_map = {
+ "chair": 0,
+ "table": 1,
+ "door": 2,
+ "couch": 3,
+ "cabinet": 4,
+ "shelf": 5,
+ "desk": 6,
+ "office chair": 7,
+ "bed": 8,
+ "pillow": 9,
+ "sink": 10,
+ "picture": 11,
+ "window": 12,
+ "toilet": 13,
+ "bookshelf": 14,
+ "monitor": 15,
+ "curtain": 16,
+ "book": 17,
+ "armchair": 18,
+ "coffee table": 19,
+ "box": 20,
+ "refrigerator": 21,
+ "lamp": 22,
+ "kitchen cabinet": 23,
+ "towel": 24,
+ "clothes": 25,
+ "tv": 26,
+ "nightstand": 27,
+ "counter": 28,
+ "dresser": 29,
+ "stool": 30,
+ "plant": 31,
+ "bathtub": 32,
+ "end table": 33,
+ "dining table": 34,
+ "keyboard": 35,
+ "bag": 36,
+ "backpack": 37,
+ "toilet paper": 38,
+ "printer": 39,
+ "tv stand": 40,
+ "whiteboard": 41,
+ "blanket": 42,
+ "shower curtain": 43,
+ "trash can": 44,
+ "closet": 45,
+ "stairs": 46,
+ "microwave": 47,
+ "stove": 48,
+ "shoe": 49,
+ "computer tower": 50,
+ "bottle": 51,
+ "bin": 52,
+ "ottoman": 53,
+ "bench": 54,
+ "board": 55,
+ "washing machine": 56,
+ "mirror": 57,
+ "copier": 58,
+ "basket": 59,
+ "sofa chair": 60,
+ "file cabinet": 61,
+ "fan": 62,
+ "laptop": 63,
+ "shower": 64,
+ "paper": 65,
+ "person": 66,
+ "paper towel dispenser": 67,
+ "oven": 68,
+ "blinds": 69,
+ "rack": 70,
+ "plate": 71,
+ "blackboard": 72,
+ "piano": 73,
+ "suitcase": 74,
+ "rail": 75,
+ "radiator": 76,
+ "recycling bin": 77,
+ "container": 78,
+ "wardrobe": 79,
+ "soap dispenser": 80,
+ "telephone": 81,
+ "bucket": 82,
+ "clock": 83,
+ "stand": 84,
+ "light": 85,
+ "laundry basket": 86,
+ "pipe": 87,
+ "clothes dryer": 88,
+ "guitar": 89,
+ "toilet paper holder": 90,
+ "seat": 91,
+ "speaker": 92,
+ "column": 93,
+ "ladder": 94,
+ "cup": 95,
+ "jacket": 96,
+ "storage bin": 97,
+ "coffee maker": 98,
+ "dishwasher": 99,
+ "paper towel roll": 100,
+ "machine": 101,
+ "mat": 102,
+ "windowsill": 103,
+ "bar": 104,
+ "bulletin board": 105,
+ "ironing board": 106,
+ "fireplace": 107,
+ "soap dish": 108,
+ "kitchen counter": 109,
+ "doorframe": 110,
+ "toilet paper dispenser": 111,
+ "mini fridge": 112,
+ "fire extinguisher": 113,
+ "ball": 114,
+ "hat": 115,
+ "shower curtain rod": 116,
+ "water cooler": 117,
+ "paper cutter": 118,
+ "tray": 119,
+ "pillar": 120,
+ "ledge": 121,
+ "toaster oven": 122,
+ "mouse": 123,
+ "toilet seat cover dispenser": 124,
+ "cart": 125,
+ "scale": 126,
+ "tissue box": 127,
+ "light switch": 128,
+ "crate": 129,
+ "power outlet": 130,
+ "decoration": 131,
+ "sign": 132,
+ "projector": 133,
+ "closet door": 134,
+ "vacuum cleaner": 135,
+ "headphones": 136,
+ "dish rack": 137,
+ "broom": 138,
+ "range hood": 139,
+ "hair dryer": 140,
+ "water bottle": 141,
+ "vent": 142,
+ "mailbox": 143,
+ "bowl": 144,
+ "paper bag": 145,
+ "projector screen": 146,
+ "divider": 147,
+ "laundry detergent": 148,
+ "bathroom counter": 149,
+ "stick": 150,
+ "bathroom vanity": 151,
+ "closet wall": 152,
+ "laundry hamper": 153,
+ "bathroom stall door": 154,
+ "ceiling light": 155,
+ "trash bin": 156,
+ "dumbbell": 157,
+ "stair rail": 158,
+ "tube": 159,
+ "bathroom cabinet": 160,
+ "coffee kettle": 161,
+ "shower head": 162,
+ "case of water bottles": 163,
+ "power strip": 164,
+ "calendar": 165,
+ "poster": 166,
+ "mattress": 167,
+}
+
+
+class ScanNetDataset(COCO3DDataset):
+ """ScanNetV2 dataset."""
+
+ def __init__(
+ self,
+ class_map: dict[str, int] = scannet_class_map,
+ max_depth: float = 12.0,
+ depth_scale: float = 1000.0,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__(
+ class_map=class_map,
+ max_depth=max_depth,
+ depth_scale=depth_scale,
+ **kwargs,
+ )
+
+ def get_depth_filenames(self, img: DictStrAny) -> str | None:
+ """Get the depth filenames.
+
+ Since not every data has depth.
+ """
+ return (
+ img["file_path"].replace("image", "depth").replace(".jpg", ".png")
+ )
diff --git a/wilddet3d/data/datasets/stereo4d.py b/wilddet3d/data/datasets/stereo4d.py
new file mode 100644
index 0000000000000000000000000000000000000000..074a5948ec9b7a000dc78aaf97ec5ed6fc395a34
--- /dev/null
+++ b/wilddet3d/data/datasets/stereo4d.py
@@ -0,0 +1,178 @@
+"""Stereo4D tinyval 3D dataset (real stereo depth, 500 images)."""
+
+from __future__ import annotations
+
+import json
+import os
+
+import cv2
+import numpy as np
+
+from vis4d.common.typing import ArgsType, DictStrAny
+from vis4d.data.const import CommonKeys as K
+
+from .coco3d import COCO3DDataset
+
+# Stereo4D v3 depth directory (meters, 512x512 .npy files)
+_STEREO4D_DEPTH_DIR = (
+ "/weka/oe-training-default/weikaih/3d_boundingbox_detection"
+ "/video_data/stereo4d_test/stereo4d_dataset_v3/depth"
+)
+
+# V3 annotation for image_id -> filename mapping (depth file lookup)
+_V3_ANN_PATH = (
+ "/weka/oe-training-default/weikaih/3d_boundingbox_detection"
+ "/video_data/stereo4d_test/stereo4d_dataset_v3"
+ "/annotations/stereo4d_test.json"
+)
+
+# Tinyval source directory (to recover original v3 image_ids)
+_TINYVAL_DIR = (
+ "/weka/oe-training-default/weikaih/3d_boundingbox_detection"
+ "/single_frame_data/experiment/v4_score_merged_la3d/stereo4d/tinyval"
+)
+
+# Cached v3 id-to-stem mapping (built once, reused)
+_v3_id_to_stem_cache: dict[int, str] | None = None
+_tinyval_orig_ids_cache: list[int] | None = None
+
+
+def _load_v3_id_to_stem() -> dict[int, str]:
+ """Load v3 image_id -> file stem mapping for depth lookup."""
+ global _v3_id_to_stem_cache
+ if _v3_id_to_stem_cache is not None:
+ return _v3_id_to_stem_cache
+ with open(_V3_ANN_PATH) as f:
+ v3 = json.load(f)
+ _v3_id_to_stem_cache = {}
+ for img in v3["images"]:
+ stem = os.path.splitext(os.path.basename(img["file_name"]))[0]
+ _v3_id_to_stem_cache[img["id"]] = stem
+ return _v3_id_to_stem_cache
+
+
+def _load_tinyval_orig_ids() -> list[int]:
+ """Load tinyval original v3 image_ids in sorted order."""
+ global _tinyval_orig_ids_cache
+ if _tinyval_orig_ids_cache is not None:
+ return _tinyval_orig_ids_cache
+ files = sorted(
+ f for f in os.listdir(_TINYVAL_DIR) if f.endswith(".json")
+ )
+ _tinyval_orig_ids_cache = []
+ for f in files:
+ img_id = int(f.split("_")[-1].replace(".json", ""))
+ _tinyval_orig_ids_cache.append(img_id)
+ return _tinyval_orig_ids_cache
+
+
+def load_stereo4d_class_map(
+ annotation_path: str,
+) -> dict[str, int]:
+ """Load class map from Stereo4D annotation file.
+
+ Returns a mapping from category name to category ID.
+ """
+ cache_path = annotation_path.replace(".json", "_class_map.json")
+ if os.path.exists(cache_path):
+ with open(cache_path) as f:
+ return json.load(f)
+ with open(annotation_path) as f:
+ data = json.load(f)
+ class_map = {cat["name"]: cat["id"] for cat in data["categories"]}
+ with open(cache_path, "w") as f:
+ json.dump(class_map, f)
+ return class_map
+
+
+class Stereo4D3DDataset(COCO3DDataset):
+ """Stereo4D tinyval 3D dataset with real stereo depth.
+
+ 500 images from Stereo4D test set with human-reviewed 3D bounding
+ boxes. Depth maps are real stereo depth (meters, 512x512).
+
+ Key differences from InTheWild3DDataset:
+ - Depth is real stereo depth (meters), not estimated depth (mm).
+ - All images are 512x512.
+ - No confidence masking needed (stereo depth is high quality).
+ """
+
+ def __init__(
+ self,
+ class_map: dict[str, int],
+ max_depth: float = 100.0,
+ per_image_categories: bool = False,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ class_map: Mapping from category name to category ID.
+ max_depth: Maximum depth in meters (clip beyond this).
+ per_image_categories: If True, boxes2d_names only contains
+ the GT categories present in each image.
+ """
+ # Initialize depth mappings BEFORE super().__init__() because
+ # _generate_data_mapping -> get_depth_filenames needs these.
+ self.per_image_categories = per_image_categories
+ self._v3_id_to_stem = _load_v3_id_to_stem()
+ self._tinyval_orig_ids = _load_tinyval_orig_ids()
+
+ super().__init__(
+ class_map=class_map,
+ det_map=class_map,
+ max_depth=max_depth,
+ **kwargs,
+ )
+
+ def __getitem__(self, idx: int):
+ """Get single sample, optionally with per-image category filtering."""
+ data_dict = super().__getitem__(idx)
+ if self.per_image_categories:
+ class_ids_in_img = data_dict[K.boxes2d_classes]
+ if len(class_ids_in_img) > 0:
+ unique_global_ids = sorted(set(class_ids_in_img.tolist()))
+ data_dict[K.boxes2d_names] = [
+ self.categories[gid] for gid in unique_global_ids
+ ]
+ else:
+ data_dict[K.boxes2d_names] = []
+ return data_dict
+
+ def get_depth_filenames(self, img: DictStrAny) -> str | None:
+ """Return path to the .npy stereo depth file for this image.
+
+ Maps: converted image index -> tinyval orig_id -> v3 stem -> depth .npy
+ """
+ img_id = img["id"]
+ if img_id >= len(self._tinyval_orig_ids):
+ return None
+ orig_id = self._tinyval_orig_ids[img_id]
+ stem = self._v3_id_to_stem.get(orig_id)
+ if stem is None:
+ return None
+ depth_path = os.path.join(_STEREO4D_DEPTH_DIR, f"{stem}.npy")
+ return depth_path if os.path.exists(depth_path) else None
+
+ def get_depth_map(self, sample: DictStrAny) -> np.ndarray:
+ """Load stereo depth .npy (meters, 512x512).
+
+ No mm-to-meters conversion needed (already in meters).
+ Resize to original resolution if needed.
+ """
+ depth = np.load(sample["depth_filename"]) # (H, W) float32, meters
+
+ orig_h = sample["img"]["height"]
+ orig_w = sample["img"]["width"]
+
+ if depth.shape != (orig_h, orig_w):
+ depth = cv2.resize(
+ depth,
+ (orig_w, orig_h),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ # Clip to max_depth
+ depth[depth > self.max_depth] = 0.0
+
+ return depth.astype(np.float32)
diff --git a/wilddet3d/data/datasets/threeeed.py b/wilddet3d/data/datasets/threeeed.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d61644b152c5719e94ebbe69eeb0c9e99f56c47
--- /dev/null
+++ b/wilddet3d/data/datasets/threeeed.py
@@ -0,0 +1,79 @@
+"""3EED dataset for 3D object detection.
+
+Multi-platform outdoor scenes (Waymo vehicle, M3ED drone, M3ED quadruped)
+with sparse LiDAR depth maps (uint16, depth_m * 256).
+Categories: car, pedestrian, bus, truck, othervehicle, cyclist.
+"""
+
+from __future__ import annotations
+
+import json
+import os
+
+from vis4d.common.typing import ArgsType, DictStrAny
+
+from wilddet3d.data.datasets.coco3d import COCO3DDataset
+
+
+def get_threeeed_det_map(
+ dataset_name: str,
+ data_root: str = "data/3eed",
+) -> dict[str, int]:
+ """Build det_map from 3EED annotation JSON categories."""
+ cache_path = os.path.join(
+ data_root, "annotations", f"{dataset_name}_class_map.json"
+ )
+ if os.path.exists(cache_path):
+ with open(cache_path) as f:
+ return json.load(f)
+ json_path = os.path.join(
+ data_root, "annotations", f"{dataset_name}.json"
+ )
+ with open(json_path) as f:
+ data = json.load(f)
+ class_map = {cat["name"]: cat["id"] for cat in data["categories"]}
+ with open(cache_path, "w") as f:
+ json.dump(class_map, f)
+ return class_map
+
+
+def get_threeeed_class_map(
+ dataset_name: str,
+ data_root: str = "data/3eed",
+) -> dict[str, int]:
+ """Build class_map from 3EED annotation JSON categories."""
+ return get_threeeed_det_map(dataset_name, data_root)
+
+
+class ThreeEEDDataset(COCO3DDataset):
+ """3EED Dataset.
+
+ Multi-platform outdoor scenes with sparse LiDAR depth maps.
+ """
+
+ def __init__(
+ self,
+ max_depth: float = 80.0,
+ depth_scale: float = 256.0,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__(
+ max_depth=max_depth,
+ depth_scale=depth_scale,
+ **kwargs,
+ )
+
+ def get_depth_filenames(self, img: DictStrAny) -> str | None:
+ """Get the depth filename for a given image.
+
+ Maps image path to depth path:
+ 3eed/3eed_dataset/{platform}/{seq}/{frame}/image.jpg
+ -> 3eed/depth/{platform}/{seq}/{frame}.png
+ """
+ # image: 3eed/3eed_dataset/waymo/seq/frame/image.jpg
+ # depth: 3eed/depth/waymo/seq/frame.png
+ path = img["file_path"]
+ parts = path.replace("3eed/3eed_dataset/", "3eed/depth/")
+ parts = parts.replace("/image.jpg", ".png")
+ return parts
diff --git a/wilddet3d/data/datasets/waymo.py b/wilddet3d/data/datasets/waymo.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b2066530c43229d9423d632e0c6a6fe3c90d9be
--- /dev/null
+++ b/wilddet3d/data/datasets/waymo.py
@@ -0,0 +1,89 @@
+"""Waymo Open Dataset for 3D object detection.
+
+Outdoor driving scenes with sparse LiDAR depth maps (uint16, depth_m * 256).
+Categories: vehicle, pedestrian, cyclist, sign.
+"""
+
+from __future__ import annotations
+
+import json
+import os
+
+from vis4d.common.typing import ArgsType, DictStrAny
+
+from wilddet3d.data.datasets.coco3d import COCO3DDataset
+
+
+def get_waymo_det_map(
+ dataset_name: str,
+ data_root: str = "data/waymo",
+) -> dict[str, int]:
+ """Build det_map from Waymo annotation JSON categories.
+
+ Waymo has 4 categories (vehicle, pedestrian, cyclist, sign).
+ Since our model is open-vocabulary (text-prompted), we build
+ det_map dynamically from the annotation JSON.
+
+ Args:
+ dataset_name: e.g. "Waymo_train" or "Waymo_val"
+ data_root: Root directory for Waymo data.
+ """
+ cache_path = os.path.join(
+ data_root, "annotations", f"{dataset_name}_class_map.json"
+ )
+ if os.path.exists(cache_path):
+ with open(cache_path) as f:
+ return json.load(f)
+ json_path = os.path.join(
+ data_root, "annotations", f"{dataset_name}.json"
+ )
+ with open(json_path) as f:
+ data = json.load(f)
+ class_map = {cat["name"]: cat["id"] for cat in data["categories"]}
+ with open(cache_path, "w") as f:
+ json.dump(class_map, f)
+ return class_map
+
+
+def get_waymo_class_map(
+ dataset_name: str,
+ data_root: str = "data/waymo",
+) -> dict[str, int]:
+ """Build class_map from Waymo annotation JSON categories.
+
+ Args:
+ dataset_name: e.g. "Waymo_train" or "Waymo_val"
+ data_root: Root directory for Waymo data.
+ """
+ return get_waymo_det_map(dataset_name, data_root)
+
+
+class WaymoDataset(COCO3DDataset):
+ """Waymo Open Dataset.
+
+ Outdoor driving scenes with sparse LiDAR depth maps.
+ """
+
+ def __init__(
+ self,
+ max_depth: float = 80.0,
+ depth_scale: float = 256.0,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__(
+ max_depth=max_depth,
+ depth_scale=depth_scale,
+ **kwargs,
+ )
+
+ def get_depth_filenames(self, img: DictStrAny) -> str | None:
+ """Get the depth filename for a given image.
+
+ Maps image path to depth path:
+ waymo/images/validation/xxx.jpg
+ -> waymo/depth/validation/xxx.png
+ """
+ return img["file_path"].replace(
+ "images", "depth"
+ ).replace(".jpg", ".png")
diff --git a/wilddet3d/data/samplers.py b/wilddet3d/data/samplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..536278189841bd67aa0c5431e79a2bc15e99e2be
--- /dev/null
+++ b/wilddet3d/data/samplers.py
@@ -0,0 +1,224 @@
+"""Dataset-ratio weighted sampler for multi-dataset training."""
+
+from __future__ import annotations
+
+import math
+from collections.abc import Callable, Iterator, Sequence
+
+import torch
+from torch.utils.data import ConcatDataset, DataLoader, Sampler
+from torch.utils.data.distributed import DistributedSampler
+
+from vis4d.common.distributed import get_rank, get_world_size
+from vis4d.data.data_pipe import DataPipe
+from vis4d.data.loader import build_train_dataloader
+from vis4d.data.typing import DictData, DictDataOrList
+
+
+class DatasetRatioSampler(Sampler[int]):
+ """Weighted sampler that controls per-dataset sampling ratios.
+
+ For a ConcatDataset with N sub-datasets, this sampler assigns each
+ sample a weight based on which sub-dataset it belongs to, then
+ performs weighted random sampling. This allows controlling the
+ proportion each dataset appears during training without dropping
+ any data.
+
+ Two modes of specifying ratios:
+
+ 1. dataset_ratios (original): raw per-dataset weights.
+ weight_i = ratio_i / size_i, proportion is derived.
+ Example: dataset_ratios=[1.0, 1.0] for Omni3D(100K)+CA-1M(200K)
+ -> 50/50 sampling proportion.
+
+ 2. target_proportions (new): directly specify desired proportions.
+ Must sum to 1.0. Weights are computed automatically.
+ Example: target_proportions=[0.5, 0.25, 0.25]
+ -> Omni3D 50%, CA-1M 25%, Waymo 25%.
+
+ epoch_dataset_idx: If set, one epoch = the specified dataset sees
+ every sample once. num_samples is computed as:
+ size[idx] / proportion[idx]
+
+ Supports distributed training (splits indices across ranks).
+
+ Args:
+ dataset: A ConcatDataset (e.g., DataPipe with multiple datasets).
+ dataset_ratios: Per-dataset sampling weight. Mutually exclusive
+ with target_proportions.
+ target_proportions: Per-dataset target proportion (must sum to 1).
+ Mutually exclusive with dataset_ratios.
+ epoch_dataset_idx: If set, one epoch = this dataset sees all its
+ samples once. Overrides num_samples.
+ num_samples: Total samples per epoch. If None and
+ epoch_dataset_idx is None, uses sum of all dataset sizes.
+ shuffle: Whether to shuffle indices each epoch.
+ seed: Random seed for reproducibility.
+ """
+
+ def __init__(
+ self,
+ dataset: ConcatDataset,
+ dataset_ratios: list[float] | None = None,
+ target_proportions: list[float] | None = None,
+ epoch_dataset_idx: int | None = None,
+ num_samples: int | None = None,
+ shuffle: bool = True,
+ seed: int = 0,
+ ) -> None:
+ """Creates an instance of the class."""
+ assert isinstance(dataset, ConcatDataset), (
+ "dataset must be a ConcatDataset (e.g., DataPipe)"
+ )
+ assert (dataset_ratios is None) != (target_proportions is None), (
+ "Exactly one of dataset_ratios or target_proportions "
+ "must be provided"
+ )
+ self.dataset = dataset
+ self.shuffle = shuffle
+ self.seed = seed
+ self.epoch = 0
+
+ num_datasets = len(dataset.datasets)
+ sizes = [len(d) for d in dataset.datasets]
+
+ if target_proportions is not None:
+ assert len(target_proportions) == num_datasets, (
+ f"target_proportions length ({len(target_proportions)}) "
+ f"must match number of sub-datasets ({num_datasets})"
+ )
+ assert abs(sum(target_proportions) - 1.0) < 1e-6, (
+ f"target_proportions must sum to 1.0, "
+ f"got {sum(target_proportions)}"
+ )
+ # weight per sample = proportion_i / size_i
+ # Expected count: num_samples * (prop_i/size_i * size_i) / sum(prop) = num_samples * prop_i
+ sample_weights = []
+ for size, prop in zip(sizes, target_proportions):
+ w = prop / size
+ sample_weights.extend([w] * size)
+ proportions = list(target_proportions)
+ else:
+ assert len(dataset_ratios) == num_datasets, (
+ f"dataset_ratios length ({len(dataset_ratios)}) must "
+ f"match number of sub-datasets ({num_datasets})"
+ )
+ # weight_i = ratio_i / size_i
+ sample_weights = []
+ for size, ratio in zip(sizes, dataset_ratios):
+ w = ratio / size
+ sample_weights.extend([w] * size)
+ # Compute actual proportions for epoch_dataset_idx
+ raw = [r / s for r, s in zip(dataset_ratios, sizes)]
+ total = sum(raw)
+ proportions = [r / total for r in raw]
+
+ self.weights = torch.tensor(sample_weights, dtype=torch.float64)
+
+ # Determine num_samples (epoch length)
+ if epoch_dataset_idx is not None:
+ assert 0 <= epoch_dataset_idx < num_datasets
+ # 1 epoch = dataset[idx] sees all samples once
+ self.num_samples = int(
+ sizes[epoch_dataset_idx] / proportions[epoch_dataset_idx]
+ )
+ print(
+ f"[DatasetRatioSampler] epoch_dataset_idx={epoch_dataset_idx}"
+ f" ({sizes[epoch_dataset_idx]} samples,"
+ f" {proportions[epoch_dataset_idx]:.1%} proportion)"
+ f" -> {self.num_samples} samples/epoch"
+ )
+ elif num_samples is not None:
+ self.num_samples = num_samples
+ else:
+ self.num_samples = len(dataset)
+
+ # Log dataset info
+ for i, (size, prop) in enumerate(zip(sizes, proportions)):
+ expected = int(self.num_samples * prop)
+ print(
+ f"[DatasetRatioSampler] dataset[{i}]: "
+ f"size={size}, proportion={prop:.1%}, "
+ f"~{expected} samples/epoch"
+ )
+
+ # Distributed settings
+ self.world_size = get_world_size()
+ self.rank = get_rank()
+ # Each rank gets an equal share
+ self.num_samples_per_rank = math.ceil(
+ self.num_samples / self.world_size
+ )
+ self.total_size = self.num_samples_per_rank * self.world_size
+
+ def __iter__(self) -> Iterator[int]:
+ """Generate sampled indices."""
+ g = torch.Generator()
+ g.manual_seed(self.seed + self.epoch)
+
+ indices = torch.multinomial(
+ self.weights,
+ num_samples=self.total_size,
+ replacement=True,
+ generator=g,
+ ).tolist()
+
+ # Subsample for this rank
+ indices = indices[self.rank::self.world_size]
+ assert len(indices) == self.num_samples_per_rank
+
+ return iter(indices)
+
+ def __len__(self) -> int:
+ """Return number of samples for this rank."""
+ return self.num_samples_per_rank
+
+ def set_epoch(self, epoch: int) -> None:
+ """Set epoch for shuffling (required for distributed training)."""
+ self.epoch = epoch
+
+
+def build_train_dataloader_with_ratios(
+ dataset: DataPipe,
+ dataset_ratios: list[float] | None = None,
+ target_proportions: list[float] | None = None,
+ epoch_dataset_idx: int | None = None,
+ num_samples: int | None = None,
+ **kwargs,
+) -> DataLoader[DictDataOrList]:
+ """Build training dataloader with per-dataset ratio sampling.
+
+ Thin wrapper around vis4d's build_train_dataloader that creates a
+ DatasetRatioSampler at runtime (when the dataset is instantiated).
+
+ Two ways to specify dataset mixing:
+
+ 1. dataset_ratios: raw weights (original, for backwards compat).
+ Example: dataset_ratios=[1.0, 1.0] -> equal weight per dataset.
+
+ 2. target_proportions: direct proportions (must sum to 1.0).
+ Example: target_proportions=[0.5, 0.25, 0.25]
+
+ Args:
+ dataset: DataPipe (ConcatDataset) with multiple sub-datasets.
+ dataset_ratios: Per-dataset sampling weight (mutually exclusive
+ with target_proportions).
+ target_proportions: Per-dataset target proportion, must sum to 1.
+ epoch_dataset_idx: If set, 1 epoch = this dataset sees all its
+ samples once. Overrides num_samples.
+ num_samples: Total samples per epoch (overridden by
+ epoch_dataset_idx).
+ **kwargs: All other arguments forwarded to build_train_dataloader.
+ """
+ sampler = DatasetRatioSampler(
+ dataset,
+ dataset_ratios=dataset_ratios,
+ target_proportions=target_proportions,
+ epoch_dataset_idx=epoch_dataset_idx,
+ num_samples=num_samples,
+ shuffle=kwargs.pop("shuffle", True),
+ )
+ # shuffle must be False when using custom sampler (PyTorch requirement)
+ return build_train_dataloader(
+ dataset=dataset, sampler=sampler, shuffle=False, **kwargs
+ )
diff --git a/wilddet3d/data/transforms/__init__.py b/wilddet3d/data/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2f92c4b6244dea653d83d9a1f4396519f679b27
--- /dev/null
+++ b/wilddet3d/data/transforms/__init__.py
@@ -0,0 +1 @@
+"""Data transforms."""
diff --git a/wilddet3d/data/transforms/__pycache__/__init__.cpython-311.pyc b/wilddet3d/data/transforms/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ecdb728e1ea33c5b26f108392223efb1366f2a82
Binary files /dev/null and b/wilddet3d/data/transforms/__pycache__/__init__.cpython-311.pyc differ
diff --git a/wilddet3d/data/transforms/__pycache__/pad.cpython-311.pyc b/wilddet3d/data/transforms/__pycache__/pad.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d9577488bacd77ecd4c1ea79244d3e4fa0fae2b
Binary files /dev/null and b/wilddet3d/data/transforms/__pycache__/pad.cpython-311.pyc differ
diff --git a/wilddet3d/data/transforms/__pycache__/resize.cpython-311.pyc b/wilddet3d/data/transforms/__pycache__/resize.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f2b24e8f082cbbd4983ed8ca306bcc01997ca59
Binary files /dev/null and b/wilddet3d/data/transforms/__pycache__/resize.cpython-311.pyc differ
diff --git a/wilddet3d/data/transforms/crop.py b/wilddet3d/data/transforms/crop.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3bc38cc72158d85e1e8f8795528df74d0ecf313
--- /dev/null
+++ b/wilddet3d/data/transforms/crop.py
@@ -0,0 +1,43 @@
+"""Crop transforms."""
+
+from __future__ import annotations
+
+from vis4d.common.typing import (
+ NDArrayBool,
+ NDArrayF32,
+ NDArrayI64,
+)
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.transforms.base import Transform
+
+
+@Transform(
+ in_keys=[
+ K.boxes3d,
+ K.boxes3d_classes,
+ K.boxes3d_track_ids,
+ "transforms.crop.keep_mask",
+ ],
+ out_keys=[K.boxes3d, K.boxes3d_classes, K.boxes3d_track_ids],
+)
+class CropBoxes3D:
+ """Crop 3D bounding boxes."""
+
+ def __call__(
+ self,
+ boxes_list: list[NDArrayF32],
+ classes_list: list[NDArrayI64],
+ track_ids_list: list[NDArrayI64] | None,
+ keep_mask_list: list[NDArrayBool],
+ ) -> tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None]:
+ """Crop 3D bounding boxes."""
+ for i, (boxes, classes, keep_mask) in enumerate(
+ zip(boxes_list, classes_list, keep_mask_list)
+ ):
+ boxes_list[i] = boxes[keep_mask]
+ classes_list[i] = classes[keep_mask]
+
+ if track_ids_list is not None:
+ track_ids_list[i] = track_ids_list[i][keep_mask]
+
+ return boxes_list, classes_list, track_ids_list
diff --git a/wilddet3d/data/transforms/language.py b/wilddet3d/data/transforms/language.py
new file mode 100644
index 0000000000000000000000000000000000000000..07e5c0ab40242a8e0550799a41de8c1f1cead87a
--- /dev/null
+++ b/wilddet3d/data/transforms/language.py
@@ -0,0 +1,267 @@
+"""Language related transforms."""
+
+from __future__ import annotations
+
+import random
+import re
+
+import numpy as np
+from transformers import AutoTokenizer
+from vis4d.common.logging import rank_zero_warn
+from vis4d.common.typing import NDArrayF32, NDArrayI64
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.transforms.base import Transform
+
+
+def clean_name(name: str) -> str:
+ """Clean the name."""
+ name = re.sub(r"\(.*\)", "", name)
+ name = re.sub(r"_", " ", name)
+ name = re.sub(r" ", " ", name)
+ name = name.lower()
+ return name
+
+
+def generate_senetence_given_labels(
+ positive_label_list: list[int],
+ negative_label_list: list[str],
+ label_map: dict[str, str],
+) -> tuple[dict[int, list[list[int]]], str, dict[int, int]]:
+ """Generate a sentence given positive and negative labels."""
+ label_to_positions = {}
+
+ label_list = negative_label_list + positive_label_list
+
+ random.shuffle(label_list)
+
+ pheso_caption = ""
+
+ label_remap_dict = {}
+ for index, label in enumerate(label_list):
+ start_index = len(pheso_caption)
+
+ pheso_caption += clean_name(label_map[str(label)])
+
+ end_index = len(pheso_caption)
+
+ if label in positive_label_list:
+ label_to_positions[index] = [[start_index, end_index]]
+ label_remap_dict[int(label)] = index
+
+ pheso_caption += ". "
+
+ return label_to_positions, pheso_caption, label_remap_dict
+
+
+@Transform(
+ [
+ "dataset_type",
+ K.boxes2d,
+ K.boxes2d_classes,
+ K.boxes2d_names,
+ "label_map",
+ "positive_positions",
+ ],
+ [K.boxes2d, K.boxes2d_classes, K.boxes2d_names, "tokens_positive"],
+)
+class RandomSamplingNegPos:
+ """Randomly sample negative and positive labels for object detection."""
+
+ def __init__(
+ self,
+ tokenizer_name: str = "bert-base-uncased",
+ num_sample_negative: int = 85,
+ max_tokens: int = 256,
+ full_sampling_prob: float = 0.5,
+ ) -> None:
+ """Creates an instance of RandomSamplingNegPos."""
+ if AutoTokenizer is None:
+ raise RuntimeError(
+ "transformers is not installed, please install it by: "
+ "pip install transformers."
+ )
+
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+ self.num_sample_negative = num_sample_negative
+ self.full_sampling_prob = full_sampling_prob
+ self.max_tokens = max_tokens
+
+ def __call__(
+ self,
+ dataset_type_list: list[str],
+ boxes_list: list[NDArrayF32],
+ class_ids_list: list[NDArrayI64],
+ texts_list: list[str] | None = None,
+ label_map_list: dict | None = None,
+ positive_positions_list: list[dict] | None = None,
+ ) -> tuple[
+ list[NDArrayF32],
+ list[NDArrayI64],
+ list[str],
+ list[dict[int, list[list[int]]]],
+ ]:
+ """Randomly sample negative and positive labels."""
+ new_texts_list = []
+ tokens_positive_list = []
+ for i, (boxes, class_ids) in enumerate(
+ zip(boxes_list, class_ids_list)
+ ):
+ if dataset_type_list[i] == "OD":
+ assert (
+ label_map_list[i] is not None
+ ), "label_map should not be None"
+ boxes_list[i], class_ids_list[i], text, tokens_positive = (
+ self.od_aug(boxes, class_ids, label_map_list[i])
+ )
+ new_texts_list.append(text)
+ tokens_positive_list.append(tokens_positive)
+ else:
+ assert (
+ positive_positions_list[i] is not None
+ ), "positive_positions should not be None"
+ tokens_positive = self.vg_aug(
+ class_ids, positive_positions_list[i]
+ )
+ new_texts_list.append(texts_list[i])
+ tokens_positive_list.append(tokens_positive)
+
+ return boxes_list, class_ids_list, new_texts_list, tokens_positive_list
+
+ def vg_aug(self, class_ids: NDArrayI64, positive_positions):
+ """Visual Genome data augmentation."""
+ positive_label_list = np.unique(class_ids).tolist()
+
+ label_to_positions = {}
+ for label in positive_label_list:
+ label_to_positions[label] = positive_positions[label]
+
+ return label_to_positions
+
+ def od_aug(
+ self,
+ boxes: NDArrayF32,
+ class_ids: NDArrayI64,
+ label_map: dict,
+ ) -> tuple[NDArrayF32, NDArrayI64, str, dict[int, list[list[int]]]]:
+ """Object detection data augmentation."""
+ original_box_num = len(class_ids)
+
+ # If the category name is in the format of 'a/b' (in object365),
+ # we randomly select one of them.
+ for key, value in label_map.items():
+ if "/" in value:
+ label_map[key] = random.choice(value.split("/")).strip()
+
+ keep_box_index, class_ids, positive_caption_length = (
+ self.check_for_positive_overflow(class_ids, label_map)
+ )
+
+ boxes = boxes[keep_box_index]
+
+ if len(boxes) < original_box_num:
+ rank_zero_warn(
+ f"Remove {original_box_num - len(boxes)} boxes due to "
+ "positive caption overflow."
+ )
+
+ valid_negative_indexes = list(label_map.keys())
+
+ positive_label_list = np.unique(class_ids).tolist()
+
+ full_negative = self.num_sample_negative
+ if full_negative > len(valid_negative_indexes):
+ full_negative = len(valid_negative_indexes)
+
+ outer_prob = random.random()
+
+ if outer_prob < self.full_sampling_prob:
+ # c. probability_full: add both all positive and all negatives
+ num_negatives = full_negative
+ else:
+ if random.random() < 1.0:
+ num_negatives = np.random.choice(max(1, full_negative)) + 1
+ else:
+ num_negatives = full_negative
+
+ # Keep some negatives
+ negative_label_list = set()
+ if num_negatives != -1:
+ if num_negatives > len(valid_negative_indexes):
+ num_negatives = len(valid_negative_indexes)
+
+ for i in np.random.choice(
+ valid_negative_indexes, size=num_negatives, replace=False
+ ):
+ if int(i) not in positive_label_list:
+ negative_label_list.add(i)
+
+ random.shuffle(positive_label_list)
+
+ negative_label_list = list(negative_label_list)
+ random.shuffle(negative_label_list)
+
+ negative_max_length = self.max_tokens - positive_caption_length
+ screened_negative_label_list = []
+
+ for negative_label in negative_label_list:
+ label_text = clean_name(label_map[str(negative_label)]) + ". "
+
+ tokenized = self.tokenizer.tokenize(label_text)
+
+ negative_max_length -= len(tokenized)
+
+ if negative_max_length > 0:
+ screened_negative_label_list.append(negative_label)
+ else:
+ break
+
+ negative_label_list = screened_negative_label_list
+ label_to_positions, pheso_caption, label_remap_dict = (
+ generate_senetence_given_labels(
+ positive_label_list, negative_label_list, label_map
+ )
+ )
+
+ # label remap
+ if len(class_ids) > 0:
+ class_ids = np.vectorize(lambda x: label_remap_dict[x])(class_ids)
+
+ return boxes, class_ids, pheso_caption, label_to_positions
+
+ def check_for_positive_overflow(
+ self, class_ids: NDArrayI64, label_map: dict[str, str]
+ ) -> tuple[list[int], NDArrayI64, int]:
+ """Check if having too many positive labels."""
+ # generate a caption by appending the positive labels
+ positive_label_list = np.unique(class_ids).tolist()
+
+ # random shuffule so we can sample different annotations
+ # at different epochs
+ random.shuffle(positive_label_list)
+
+ kept_lables = []
+ length = 0
+ for _, label in enumerate(positive_label_list):
+ label_text = clean_name(label_map[str(label)]) + ". "
+
+ tokenized = self.tokenizer.tokenize(label_text)
+
+ length += len(tokenized)
+
+ if length > self.max_tokens:
+ break
+ else:
+ kept_lables.append(label)
+
+ keep_box_index = []
+ keep_gt_labels = []
+ for i, class_id in enumerate(class_ids):
+ if class_id in kept_lables:
+ keep_box_index.append(i)
+ keep_gt_labels.append(class_id)
+
+ return (
+ keep_box_index,
+ np.array(keep_gt_labels, dtype=np.int64),
+ length,
+ )
diff --git a/wilddet3d/data/transforms/masks.py b/wilddet3d/data/transforms/masks.py
new file mode 100644
index 0000000000000000000000000000000000000000..18c0849a46228f3e062f81cc8f7ab818a8fea410
--- /dev/null
+++ b/wilddet3d/data/transforms/masks.py
@@ -0,0 +1,120 @@
+"""Spatial transforms for per-box binary masks (masks2d).
+
+masks2d is a list (per image in batch) of (N, H, W) uint8 arrays,
+where N is the number of boxes in that image. Each mask slice is a
+binary mask for one box. These transforms keep masks aligned with
+images, boxes2d, and depth_maps through the spatial augmentation
+pipeline.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from vis4d.data.transforms.base import Transform
+
+MASKS2D_KEY = "masks2d"
+
+
+@Transform(
+ [MASKS2D_KEY, "transforms.resize.target_shape"],
+ MASKS2D_KEY,
+)
+class ResizeMasks2D:
+ """Resize per-box masks using nearest interpolation."""
+
+ def __call__(
+ self,
+ masks_list,
+ target_shapes,
+ ):
+ """Resize masks."""
+ if masks_list is None:
+ return masks_list
+ for i, (masks, target_shape) in enumerate(
+ zip(masks_list, target_shapes)
+ ):
+ if masks is None or len(masks) == 0:
+ continue
+ # masks: (N, H, W) uint8
+ t = torch.from_numpy(masks).float().unsqueeze(1) # (N,1,H,W)
+ t = F.interpolate(
+ t, size=target_shape, mode="nearest"
+ )
+ masks_list[i] = (
+ t.squeeze(1).to(torch.uint8).numpy()
+ ) # (N, H', W')
+ return masks_list
+
+
+@Transform([MASKS2D_KEY, "transforms.crop.crop_box"], MASKS2D_KEY)
+class CropMasks2D:
+ """Crop per-box masks."""
+
+ def __call__(
+ self,
+ masks_list,
+ crop_box_list,
+ ):
+ """Crop masks."""
+ if masks_list is None:
+ return masks_list
+ for i, (masks, crop_box) in enumerate(
+ zip(masks_list, crop_box_list)
+ ):
+ if masks is None or len(masks) == 0:
+ continue
+ x1, y1, x2, y2 = crop_box
+ masks_list[i] = masks[:, y1:y2, x1:x2]
+ return masks_list
+
+
+@Transform(MASKS2D_KEY, MASKS2D_KEY)
+class FlipMasks2D:
+ """Flip per-box masks horizontally."""
+
+ def __call__(
+ self,
+ masks_list,
+ ):
+ """Flip masks."""
+ if masks_list is None:
+ return masks_list
+ for i, masks in enumerate(masks_list):
+ if masks is None or len(masks) == 0:
+ continue
+ masks_list[i] = np.ascontiguousarray(
+ masks[:, :, ::-1]
+ )
+ return masks_list
+
+
+@Transform([MASKS2D_KEY, "transforms.pad"], MASKS2D_KEY)
+class CenterPadMasks2D:
+ """Center-pad per-box masks."""
+
+ def __call__(
+ self,
+ masks_list,
+ pad_params,
+ ):
+ """Pad masks."""
+ if masks_list is None:
+ return masks_list
+ for i, (masks, pad_param) in enumerate(
+ zip(masks_list, pad_params)
+ ):
+ if masks is None or len(masks) == 0:
+ continue
+ pad = (
+ pad_param["pad_left"],
+ pad_param["pad_right"],
+ pad_param["pad_top"],
+ pad_param["pad_bottom"],
+ )
+ t = torch.from_numpy(masks).unsqueeze(1) # (N,1,H,W)
+ t = F.pad(t, pad, mode="constant", value=0)
+ masks_list[i] = t.squeeze(1).numpy() # (N, H', W')
+ return masks_list
diff --git a/wilddet3d/data/transforms/pad.py b/wilddet3d/data/transforms/pad.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cc32a1d0233ca25bd34426a56cf3a7d17ba7d5a
--- /dev/null
+++ b/wilddet3d/data/transforms/pad.py
@@ -0,0 +1,176 @@
+"""Pad transformation."""
+
+from __future__ import annotations
+
+from typing import TypedDict
+
+import torch
+import torch.nn.functional as F
+from vis4d.common.typing import NDArrayF32
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.transforms.base import Transform
+from vis4d.data.transforms.pad import _get_max_shape
+
+
+class PadParam(TypedDict):
+ """Parameters for Reshape."""
+
+ pad_top: int
+ pad_bottom: int
+ pad_left: int
+ pad_right: int
+
+
+@Transform(
+ [K.images, K.input_hw],
+ [K.images, "transforms.pad", K.input_hw, "padding"],
+)
+class CenterPadImages:
+ """Pad batch of images at the bottom right."""
+
+ def __init__(
+ self,
+ stride: int = 32,
+ mode: str = "constant",
+ value: float = 0.0,
+ update_input_hw: bool = False,
+ shape: tuple[int, int] | None = None,
+ pad2square: bool = False,
+ ) -> None:
+ """Creates an instance of PadImage.
+
+ Args:
+ stride (int, optional): Chooses padding size so that the input will
+ be divisible by stride. Defaults to 32.
+ mode (str, optional): Padding mode. One of constant, reflect,
+ replicate or circular. Defaults to "constant".
+ value (float, optional): Value for constant padding.
+ Defaults to 0.0.
+ shape (tuple[int, int], optional): Shape of the padded image
+ (H, W). Defaults to None.
+ pad2square (bool, optional): Pad to square. Defaults to False.
+ """
+ self.stride = stride
+ self.mode = mode
+ self.value = value
+ self.update_input_hw = update_input_hw
+ self.shape = shape
+ self.pad2square = pad2square
+
+ def __call__(
+ self, images: list[NDArrayF32], input_hw: list[tuple[int, int]]
+ ) -> tuple[list[NDArrayF32], list[PadParam], list[tuple[int, int]]]:
+ """Pad images to consistent size."""
+ heights = [im.shape[1] for im in images]
+ widths = [im.shape[2] for im in images]
+
+ max_hw = _get_max_shape(
+ heights, widths, self.stride, self.shape, self.pad2square
+ )
+
+ # generate params for torch pad
+ pad_params = []
+ target_input_hw = []
+ paddings = []
+ for i, (image, h, w) in enumerate(zip(images, heights, widths)):
+ pad_top, pad_bottom = (max_hw[0] - h) // 2, max_hw[0] - h - (
+ max_hw[0] - h
+ ) // 2
+
+ pad_left, pad_right = (max_hw[1] - w) // 2, max_hw[1] - w - (
+ max_hw[1] - w
+ ) // 2
+
+ image_ = torch.from_numpy(image).permute(0, 3, 1, 2)
+ image_ = F.pad(
+ image_,
+ (pad_left, pad_right, pad_top, pad_bottom),
+ self.mode,
+ self.value,
+ )
+ images[i] = image_.permute(0, 2, 3, 1).numpy()
+
+ pad_params.append(
+ PadParam(
+ pad_top=pad_top,
+ pad_bottom=pad_bottom,
+ pad_left=pad_left,
+ pad_right=pad_right,
+ )
+ )
+
+ paddings.append([pad_left, pad_right, pad_top, pad_bottom])
+
+ target_input_hw.append(max_hw)
+
+ if self.update_input_hw:
+ input_hw = target_input_hw
+
+ return images, pad_params, input_hw, paddings
+
+
+@Transform([K.intrinsics, "transforms.pad"], K.intrinsics)
+class CenterPadIntrinsics:
+ """Resize Intrinsics."""
+
+ def __call__(
+ self, intrinsics: list[NDArrayF32], pad_params: list[PadParam]
+ ) -> list[NDArrayF32]:
+ """Scale camera intrinsics when resizing."""
+ for i, intrinsic in enumerate(intrinsics):
+ intrinsic[0, 2] += pad_params[i]["pad_left"]
+ intrinsic[1, 2] += pad_params[i]["pad_top"]
+
+ intrinsics[i] = intrinsic
+ return intrinsics
+
+
+@Transform([K.boxes2d, "transforms.pad"], K.boxes2d)
+class CenterPadBoxes2D:
+ """Pad batch of depth maps at the bottom right."""
+
+ def __call__(
+ self, boxes_list: list[NDArrayF32], pad_params: list[PadParam]
+ ) -> list[NDArrayF32]:
+ """Scale camera intrinsics when resizing."""
+ for i, boxes in enumerate(boxes_list):
+ boxes[:, 0] += pad_params[i]["pad_left"]
+ boxes[:, 1] += pad_params[i]["pad_top"]
+ boxes[:, 2] += pad_params[i]["pad_left"]
+ boxes[:, 3] += pad_params[i]["pad_top"]
+
+ boxes_list[i] = boxes
+
+ return boxes_list
+
+
+@Transform([K.depth_maps, "transforms.pad"], K.depth_maps)
+class CenterPadDepthMaps:
+ """Pad batch of depth maps at the bottom right."""
+
+ def __init__(self, mode: str = "constant", value: int = 0) -> None:
+ """Creates an instance."""
+ self.mode = mode
+ self.value = value
+
+ def __call__(
+ self, depth_maps: list[NDArrayF32], pad_params: list[PadParam]
+ ) -> list[NDArrayF32]:
+ """Pad images to consistent size."""
+
+ # generate params for torch pad
+ for i, (depth, pad_param_dict) in enumerate(
+ zip(depth_maps, pad_params)
+ ):
+ pad_param = (
+ pad_param_dict["pad_left"],
+ pad_param_dict["pad_right"],
+ pad_param_dict["pad_top"],
+ pad_param_dict["pad_bottom"],
+ )
+
+ depth_ = torch.from_numpy(depth).unsqueeze(0).unsqueeze(0)
+ depth_ = F.pad(depth_, pad_param, self.mode, self.value)
+ depth_maps[i] = depth_.squeeze(0).squeeze(0).numpy()
+
+ return depth_maps
diff --git a/wilddet3d/data/transforms/resize.py b/wilddet3d/data/transforms/resize.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ef2b4aac045f5bf11bb19f6ea9dedeb5a7a50fb
--- /dev/null
+++ b/wilddet3d/data/transforms/resize.py
@@ -0,0 +1,121 @@
+"""Resize transformation."""
+
+from __future__ import annotations
+
+import math
+
+import numpy as np
+import torch
+from vis4d.common.typing import NDArrayF32, NDArrayI64
+from vis4d.data.const import CommonKeys as K
+from vis4d.data.transforms.base import Transform
+from vis4d.data.transforms.resize import ResizeParam, resize_tensor
+
+
+@Transform(K.images, ["transforms.resize", K.input_hw])
+class GenResizeParameters:
+ """Generate the parameters for a resize operation."""
+
+ def __init__(
+ self, shape: tuple[int, int], scales: tuple[float, float] | float = 1.0
+ ) -> None:
+ """Create a new instance of the class."""
+ self.shape = shape
+ self.scales = scales
+
+ def __call__(
+ self, images: list[NDArrayF32]
+ ) -> tuple[list[ResizeParam], list[tuple[int, int]]]:
+ """Compute the parameters and put them in the data dict."""
+ if isinstance(self.scales, float):
+ random_scale = self.scales
+ else:
+ random_scale = np.random.uniform(self.scales[0], self.scales[1])
+
+ shape = (
+ math.ceil(self.shape[0] * random_scale - 0.5),
+ math.ceil(self.shape[1] * random_scale - 0.5),
+ )
+
+ output_ratio = shape[1] / shape[0]
+
+ image = images[0]
+
+ input_h, input_w = (image.shape[1], image.shape[2])
+ input_ratio = input_w / input_h
+
+ if output_ratio > input_ratio:
+ scale = shape[0] / input_h
+ else:
+ scale = shape[1] / input_w
+
+ target_shape = (
+ math.ceil(input_h * scale - 0.5),
+ math.ceil(input_w * scale - 0.5),
+ )
+
+ scale_factor = (target_shape[0] / input_h, target_shape[1] / input_w)
+
+ resize_params = [
+ ResizeParam(target_shape=target_shape, scale_factor=scale_factor)
+ ] * len(images)
+ target_shapes = [target_shape] * len(images)
+
+ return resize_params, target_shapes
+
+
+@Transform(
+ [K.panoptic_masks, "transforms.resize.target_shape"], K.panoptic_masks
+)
+class ResizePanopticMasks:
+ """Resize panoptic segmentation masks."""
+
+ def __call__(
+ self,
+ masks_list: list[NDArrayI64],
+ target_shape_list: list[tuple[int, int]],
+ ) -> list[NDArrayI64]:
+ """Resize masks."""
+ for i, (masks, target_shape) in enumerate(
+ zip(masks_list, target_shape_list)
+ ):
+ masks_ = torch.from_numpy(masks)
+ masks_ = (
+ resize_tensor(
+ masks_.float().unsqueeze(0).unsqueeze(0),
+ target_shape,
+ interpolation="nearest",
+ )
+ .type(masks_.dtype)
+ .squeeze(0)
+ .squeeze(0)
+ )
+ masks_list[i] = masks_.numpy()
+ return masks_list
+
+
+@Transform([K.boxes3d, "transforms.resize.scale_factor"], K.boxes3d)
+class ResizeBoxes3D:
+ """Resize list of 2D bounding boxes."""
+
+ def __call__(
+ self,
+ boxes_list: list[NDArrayF32],
+ scale_factors: list[tuple[float, float]],
+ ) -> list[NDArrayF32]:
+ """Resize 2D bounding boxes.
+
+ Args:
+ boxes_list: (list[NDArrayF32]): The bounding boxes to be resized.
+ scale_factors (list[tuple[float, float]]): scaling factors.
+
+ Returns:
+ list[NDArrayF32]: Resized bounding boxes according to parameters in
+ resize.
+ """
+ for i, (boxes, scale_factor) in enumerate(
+ zip(boxes_list, scale_factors)
+ ):
+ boxes[:, 2] /= scale_factor[0]
+ boxes_list[i] = boxes
+ return boxes_list
diff --git a/wilddet3d/data_types.py b/wilddet3d/data_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..c587150c43764b349490b8d6db254da11d9770ee
--- /dev/null
+++ b/wilddet3d/data_types.py
@@ -0,0 +1,229 @@
+"""WildDet3D data types."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, fields
+from typing import List, NamedTuple
+
+import torch
+from torch import Tensor
+
+
+class Det3DOut(NamedTuple):
+ """Output of the detection model.
+
+ boxes (list[Tensor]): 2D bounding boxes of shape [N, 4] in xyxy format.
+ boxes3d (list[Tensor]): 3D bounding boxes of shape [N, 10].
+ scores (list[Tensor]): 2D confidence scores of shape [N,].
+ class_ids (list[Tensor]): class ids of shape [N,].
+ depth_maps (list[Tensor] | None): depth maps for each image.
+ categories (list[list[str]] | None): category names for each detection.
+ predicted_intrinsics (Tensor | None): predicted camera intrinsics (B, 3, 3).
+ scores_3d (list[Tensor] | None): 3D confidence scores of shape [N,].
+ scores_2d (list[Tensor] | None): pure 2D confidence scores of shape [N,].
+ """
+
+ boxes: list[Tensor]
+ boxes3d: list[Tensor]
+ scores: list[Tensor]
+ class_ids: list[Tensor]
+ depth_maps: list[Tensor] | None
+ categories: list[list[str]] | None = None
+ predicted_intrinsics: Tensor | None = None
+ scores_3d: list[Tensor] | None = None
+ scores_2d: list[Tensor] | None = None
+
+
+class WildDet3DOut(NamedTuple):
+ """Output of WildDet3D model.
+
+ All tensors use batch-first format: (B, num_queries, dim)
+ where B = N_prompts (per-prompt batch).
+
+ Coordinate formats:
+ - pred_boxes_2d: normalized xyxy [0, 1]
+ - pred_boxes_3d: encoded 3D params (delta_center, log_depth, log_dims, rot_6d)
+ """
+ # 2D Detection (from SAM3 decoder) - O2O outputs
+ pred_logits: Tensor # (N_prompts, num_queries, 1) - objectness
+ pred_boxes_2d: Tensor # (N_prompts, num_queries, 4) - normalized xyxy
+
+ # 3D Detection (from 3D head) - O2O outputs
+ pred_boxes_3d: Tensor | None # (N_prompts, num_queries, 12) - encoded 3D params
+
+ # Auxiliary outputs for each decoder layer (for deep supervision)
+ aux_outputs: list[dict] | None
+
+ # Geometry backend losses (SILog depth, phi, theta)
+ geom_losses: dict[str, Tensor] | None
+
+ # SAM3 specific outputs
+ presence_logits: Tensor | None # (N_prompts, num_queries, 1)
+ queries: Tensor | None # (N_prompts, num_queries, d_model) - for segmentation
+
+ # Encoder hidden states (for depth head if needed)
+ encoder_hidden_states: Tensor | None # (H*W, N_prompts, d_model)
+
+ # Matching indices from SAM3 (for loss computation)
+ # Format: (batch_idx, src_idx, tgt_idx) from Hungarian matching
+ indices: tuple | None = None
+
+ # Normalized cxcywh boxes (needed by SAM3's Boxes loss for L1)
+ pred_boxes_2d_cxcywh: Tensor | None = None # (N_prompts, num_queries, 4) - normalized cxcywh
+
+ # O2M (One-to-Many) outputs from SAM3 DAC mechanism
+ # These are separate outputs from the second half of queries in DAC mode
+ pred_logits_o2m: Tensor | None = None # (N_prompts, num_queries, 1)
+ pred_boxes_2d_o2m: Tensor | None = None # (N_prompts, num_queries, 4) - normalized xyxy
+ pred_boxes_2d_cxcywh_o2m: Tensor | None = None # (N_prompts, num_queries, 4) - normalized cxcywh
+ pred_boxes_3d_o2m: Tensor | None = None # (N_prompts, num_queries, 12) - encoded 3D params
+
+ # 3D confidence head outputs (camera+depth conditioned)
+ pred_conf_3d: Tensor | None = None # (N_prompts, num_queries, 1)
+ pred_conf_3d_o2m: Tensor | None = None # (N_prompts, num_queries, 1)
+
+ def __getitem__(self, key: str):
+ """Support dict-like access for vis4d data connector compatibility."""
+ return getattr(self, key)
+
+ def keys(self):
+ """Return field names for dict-like iteration."""
+ return [f.name for f in fields(self)]
+
+ def __contains__(self, key: str) -> bool:
+ """Support 'in' operator for dict-like access."""
+ return hasattr(self, key)
+
+
+@dataclass
+class WildDet3DInput:
+ """WildDet3D batched input format (per-prompt batch).
+
+ Design Principles:
+ 1. Aligned with SAM3's BatchedDatapoint
+ 2. Added 3D detection required fields (intrinsics, gt_boxes3d)
+ 3. Supports three modes: TEXT / GEOMETRIC / TEXT_GEOMETRIC
+
+ Coordinate Format Convention:
+ - geo_boxes: normalized [0,1] cxcywh (SAM3 Geometry Encoder input)
+ - gt_boxes2d: normalized [0,1] xyxy (for loss computation)
+ - Model output pred_boxes_2d: normalized xyxy [0,1]
+ """
+
+ # ========== Image-level (Backbone processing) ==========
+ images: Tensor # (B_images, 3, H, W)
+ intrinsics: Tensor # (B_images, 3, 3)
+
+ # ========== Prompt-level (expanded) ==========
+ img_ids: Tensor # (N_prompts,) - which image each prompt belongs to
+ text_ids: Tensor # (N_prompts,) - text index for each prompt
+ unique_texts: List[str] # deduplicated texts (including "visual" placeholder)
+
+ # Geometry input - batch-first: (N_prompts, max_K, 4) - normalized cxcywh
+ # Converted to sequence-first when passed to SAM3 Prompt class
+ geo_boxes: Tensor | None = None # (N_prompts, max_K, 4)
+ geo_boxes_mask: Tensor | None = None # (N_prompts, max_K) - True=padding
+ geo_box_labels: Tensor | None = None # (N_prompts, max_K) - 0/1 for neg/pos
+
+ # Point prompts (optional)
+ geo_points: Tensor | None = None # (N_prompts, max_P, 2) - (x, y)
+ geo_points_mask: Tensor | None = None # (N_prompts, max_P) - True=padding
+ geo_point_labels: Tensor | None = None # (N_prompts, max_P) - 0/1 for neg/pos
+
+ # Ground Truth - normalized xyxy (training)
+ gt_boxes2d: Tensor | None = None # (N_prompts, max_gt, 4) - xyxy
+ gt_boxes3d: Tensor | None = None # (N_prompts, max_gt, 12) - 3D params
+ num_gts: Tensor | None = None # (N_prompts,) - number of GTs per prompt
+ gt_category_ids: Tensor | None = None # (N_prompts, max_gt)
+
+ # Ignore boxes for negative loss suppression (per-prompt, same category)
+ # Objects marked ignore in Omni3D (truncated, occluded, behind camera, etc.)
+ # are not used as GT but should not cause FP penalty either.
+ ignore_boxes2d: Tensor | None = None # (N_prompts, max_ignore, 4) normalized xyxy
+ num_ignores: Tensor | None = None # (N_prompts,) number of ignore boxes per prompt
+
+ # Query type tracking (collator-level label, does NOT control SAM3 internal matching).
+ # 0=TEXT, 1=VISUAL, 2=GEOMETRY, 3=VISUAL+LABEL, 4=GEOMETRY+LABEL
+ # "multi-target" (0,1,3): num_gts can be > 1 (all instances of a category)
+ # "single-target" (2,4): num_gts = 1 (one selected instance)
+ # NOTE: SAM3's DAC mechanism (internal o2o/o2m matcher) always runs
+ # both branches regardless of this field.
+ query_types: Tensor | None = None # (N_prompts,) int
+
+ # Metadata for evaluation/visualization
+ sample_names: List[str] | None = None # (B_images,) - image identifiers
+ dataset_name: List[str] | None = None # (B_images,) - dataset names for evaluator
+ original_hw: List[tuple] | None = None # (B_images,) - original (H, W) per image
+ original_images: Tensor | None = None # (B_images, 3, H_orig, W_orig) - unresized
+ original_intrinsics: Tensor | None = None # (B_images, 3, 3) - intrinsics before resize
+
+ # CenterPad offsets [pad_left, pad_right, pad_top, pad_bottom]
+ padding: List | None = None # (B_images,) - padding offsets per image
+
+ # Depth Ground Truth (for geometry backend supervision)
+ depth_gt: Tensor | None = None # (B_images, 1, H, W) depth map
+ depth_mask: Tensor | None = None # (B_images, H, W) valid depth mask
+
+ # Key aliases for vis4d DataConnector compatibility
+ # Maps expected DataLoader keys to actual dataclass field names
+ _KEY_ALIASES = {
+ # Target boxes (for loss computation)
+ "boxes2d": "gt_boxes2d",
+ "boxes3d": "gt_boxes3d",
+ "boxes2d_classes": "gt_category_ids",
+ # Geometric prompts (for SAM3 input)
+ "prompt_boxes": "geo_boxes",
+ "prompt_box_labels": "geo_box_labels",
+ # Not available in per-prompt batch
+ "depth_maps": None,
+ "original_hw": None,
+ "original_images": None,
+ "padding": None,
+ }
+
+ def __getitem__(self, key: str):
+ """Support dict-like access for vis4d data connector compatibility.
+
+ Supports both actual field names and aliased keys from raw DataLoader.
+ """
+ # Check alias first
+ if key in self._KEY_ALIASES:
+ aliased_key = self._KEY_ALIASES[key]
+ if aliased_key is None:
+ return None # Field not available
+ return getattr(self, aliased_key)
+
+ # Handle special computed fields
+ if key == "input_hw":
+ # Return (H, W) from images shape
+ return (self.images.shape[2], self.images.shape[3])
+
+ # Direct field access
+ if hasattr(self, key):
+ return getattr(self, key)
+
+ # Return None for unknown keys instead of raising error
+ return None
+
+ def keys(self):
+ """Return field names for dict-like iteration."""
+ return [f.name for f in fields(self)]
+
+ def __contains__(self, key: str) -> bool:
+ """Support 'in' operator for dict-like access."""
+ return hasattr(self, key)
+
+ @property
+ def num_images(self) -> int:
+ """Number of unique images."""
+ return self.images.shape[0]
+
+ @property
+ def num_prompts(self) -> int:
+ """Number of prompts (batch size for decoder)."""
+ return self.img_ids.shape[0]
+
+ @property
+ def device(self) -> torch.device:
+ """Device of the batch."""
+ return self.images.device
diff --git a/wilddet3d/depth/__init__.py b/wilddet3d/depth/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0090ead850330426c45d79494176e5df9133eb5b
--- /dev/null
+++ b/wilddet3d/depth/__init__.py
@@ -0,0 +1,10 @@
+"""Depth estimation backends."""
+
+from .base import GeometryBackendBase, GeometryBackendOutput
+from .lingbot_backend import LingbotDepthBackend
+
+__all__ = [
+ "GeometryBackendBase",
+ "GeometryBackendOutput",
+ "LingbotDepthBackend",
+]
diff --git a/wilddet3d/depth/__pycache__/__init__.cpython-311.pyc b/wilddet3d/depth/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ac8232468f56074bfc14d446d055735c07afb79
Binary files /dev/null and b/wilddet3d/depth/__pycache__/__init__.cpython-311.pyc differ
diff --git a/wilddet3d/depth/__pycache__/base.cpython-311.pyc b/wilddet3d/depth/__pycache__/base.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4dcadc33aef3774e2b25009a974e589c037d330a
Binary files /dev/null and b/wilddet3d/depth/__pycache__/base.cpython-311.pyc differ
diff --git a/wilddet3d/depth/__pycache__/depth_fusion.cpython-311.pyc b/wilddet3d/depth/__pycache__/depth_fusion.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5bb84f3cebb34bbafa5c669923b6daa82af1ac51
Binary files /dev/null and b/wilddet3d/depth/__pycache__/depth_fusion.cpython-311.pyc differ
diff --git a/wilddet3d/depth/__pycache__/lingbot_backend.cpython-311.pyc b/wilddet3d/depth/__pycache__/lingbot_backend.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea80e0fd3d9f5d3fc767ea6945614a384b724758
Binary files /dev/null and b/wilddet3d/depth/__pycache__/lingbot_backend.cpython-311.pyc differ
diff --git a/wilddet3d/depth/base.py b/wilddet3d/depth/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..c856ce6ff1fb60e03b33c447dd24c071cfad4daf
--- /dev/null
+++ b/wilddet3d/depth/base.py
@@ -0,0 +1,187 @@
+"""GeometryBackendBase: Abstract interface for depth/geometry backends.
+
+Each backend is a self-contained geometry module that:
+- Extracts features using its own method (DINO, Swin+FPN, etc.)
+- Runs its own depth head
+- Computes its own geometry losses
+
+The interface provides a unified way to plug different geometry systems
+into the 3D-MOOD / GroundingDINO3D framework.
+"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TypedDict
+
+import torch
+from torch import Tensor, nn
+
+
+class GeometryBackendOutput(TypedDict, total=False):
+ """Output dictionary from GeometryBackend.
+
+ Attributes:
+ depth_map: Predicted depth map [B, 1, H, W] in metric scale.
+ depth_latents: Depth latent tokens [B, N, C] for 3D head.
+ Dimension C is aligned to target_latent_dim (default: 128).
+ K_pred: Predicted camera intrinsics [B, 3, 3] (optional).
+ ray_intrinsics: Intrinsics to use for ray_embeddings generation [B, 3, 3].
+ This may be adjusted intrinsics for DINOv2-based backends.
+ ray_image_hw: Image (H, W) to use for ray_embeddings generation.
+ This corresponds to the space where depth_latents were computed.
+ ray_downsample: Downsample factor for ray_embeddings (8 or 16).
+ Must match the spatial resolution of depth_latents.
+ aux: Auxiliary outputs (rays, points, confidence, etc.).
+ losses: Dictionary of geometry losses (only in training).
+ """
+
+ depth_map: Tensor
+ depth_latents: Tensor
+ K_pred: Tensor | None
+ ray_intrinsics: Tensor
+ ray_image_hw: tuple[int, int]
+ ray_downsample: int
+ aux: dict[str, Tensor]
+ losses: dict[str, Tensor]
+
+
+class GeometryBackendBase(nn.Module, ABC):
+ """Abstract base class for geometry backends.
+
+ Each concrete implementation wraps a complete geometry pipeline:
+ - Feature extraction (backbone + neck specific to this backend)
+ - Depth head
+ - Loss computation
+
+ This allows switching between different depth systems (UniDepthHead,
+ DetAny3D, UniDepthV2) without changing the main GroundingDINO3D code.
+
+ Args:
+ detach_depth_latents: If True, detach depth_latents before returning.
+ This prevents gradients from the 3D head from flowing back to
+ the depth head. Useful when you want to freeze depth training
+ but still use its features for 3D detection.
+ """
+
+ # Whether this backend's depth decoder already incorporates ray/camera info.
+ # If True, the 3D head does NOT need a separate camera prompt branch,
+ # because the depth_latents are already ray-aware.
+ # - UniDepthV2 / DetAny3D: True (decoder fuses rays internally)
+ # - UniDepthHead (v1): False (no ray info in decoder)
+ is_ray_aware: bool = False
+
+ def __init__(self, detach_depth_latents: bool = False) -> None:
+ """Initialize the geometry backend.
+
+ Args:
+ detach_depth_latents: Whether to detach depth_latents from the graph.
+ """
+ super().__init__()
+ self.detach_depth_latents = detach_depth_latents
+
+ def _maybe_detach_latents(self, depth_latents: Tensor | None) -> Tensor | None:
+ """Optionally detach depth latents from computation graph.
+
+ Args:
+ depth_latents: Depth latents [B, N, C] or None
+
+ Returns:
+ Detached latents if detach_depth_latents is True, otherwise unchanged
+ """
+ if depth_latents is not None and self.detach_depth_latents:
+ return depth_latents.detach()
+ return depth_latents
+
+ @abstractmethod
+ def forward_train(
+ self,
+ images: Tensor,
+ depth_feats: list[Tensor] | None,
+ intrinsics: Tensor,
+ image_hw: tuple[int, int],
+ depth_gt: Tensor | None = None,
+ depth_mask: Tensor | None = None,
+ **kwargs,
+ ) -> GeometryBackendOutput:
+ """Forward pass for training.
+
+ Args:
+ images: Input images [B, 3, H, W].
+ depth_feats: Multi-scale features from FPN [B, C, H_i, W_i] (for
+ backends that use external features like UniDepthHead).
+ Can be None for backends with their own encoder (e.g., UniDepthV2).
+ intrinsics: Camera intrinsics [B, 3, 3].
+ image_hw: Input image size (H, W).
+ depth_gt: Ground truth depth [B, H, W] (optional).
+ depth_mask: Valid depth mask [B, H, W] (optional).
+ **kwargs: Additional backend-specific arguments.
+
+ Returns:
+ GeometryBackendOutput containing:
+ - depth_map: [B, 1, H, W]
+ - depth_latents: [B, N, C]
+ - K_pred: [B, 3, 3] or None
+ - aux: dict of auxiliary outputs
+ - losses: dict of loss tensors
+ """
+ raise NotImplementedError
+
+ @torch.no_grad()
+ @abstractmethod
+ def forward_test(
+ self,
+ images: Tensor,
+ depth_feats: list[Tensor] | None,
+ intrinsics: Tensor,
+ image_hw: tuple[int, int],
+ **kwargs,
+ ) -> GeometryBackendOutput:
+ """Forward pass for inference (no loss computation).
+
+ Args:
+ images: Input images [B, 3, H, W].
+ depth_feats: Multi-scale features from FPN (optional).
+ intrinsics: Camera intrinsics [B, 3, 3].
+ image_hw: Input image size (H, W).
+ **kwargs: Additional backend-specific arguments.
+
+ Returns:
+ GeometryBackendOutput containing:
+ - depth_map: [B, 1, H, W]
+ - depth_latents: [B, N, C]
+ - K_pred: [B, 3, 3] or None
+ - aux: dict of auxiliary outputs
+ - losses: empty dict
+ """
+ raise NotImplementedError
+
+ def forward(
+ self,
+ images: Tensor,
+ depth_feats: list[Tensor] | None,
+ intrinsics: Tensor,
+ image_hw: tuple[int, int],
+ depth_gt: Tensor | None = None,
+ depth_mask: Tensor | None = None,
+ **kwargs,
+ ) -> GeometryBackendOutput:
+ """Forward pass (dispatches to train or test based on mode)."""
+ if self.training:
+ return self.forward_train(
+ images=images,
+ depth_feats=depth_feats,
+ intrinsics=intrinsics,
+ image_hw=image_hw,
+ depth_gt=depth_gt,
+ depth_mask=depth_mask,
+ **kwargs,
+ )
+ return self.forward_test(
+ images=images,
+ depth_feats=depth_feats,
+ intrinsics=intrinsics,
+ image_hw=image_hw,
+ depth_gt=depth_gt,
+ **kwargs,
+ )
diff --git a/wilddet3d/depth/depth_fusion.py b/wilddet3d/depth/depth_fusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..3487cc477bced2c1fa5f9d48e987263ef542c770
--- /dev/null
+++ b/wilddet3d/depth/depth_fusion.py
@@ -0,0 +1,223 @@
+"""Early Depth Fusion Modules.
+
+Two variants for fusing depth latents into visual features before the encoder:
+
+1. EarlyDepthFusionUniDepthV2 (Concat-Add):
+ Concatenate visual + depth, project back, residual add.
+ delta = W * [P; D]
+ output = P + delta
+
+2. EarlyDepthFusionLingbot (ControlNet-style):
+ LayerNorm depth, project depth only, residual add.
+ delta = W_d @ LayerNorm(D)
+ output = P + delta
+"""
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+
+class EarlyDepthFusionUniDepthV2(nn.Module):
+ """Concat-Add fusion for UniDepthV2 backend.
+
+ Concatenates visual and depth features, projects back to visual dim,
+ then adds as residual. More expressive than depth-only projection:
+ delta = W_P * P + W_D * D (from concat projection)
+ output = P + delta = (I + W_P) * P + W_D * D
+
+ Args:
+ visual_dim: Dimension of visual features (e.g., 256).
+ depth_dim: Dimension of depth latents (e.g., 256).
+ fusion_type: Kept for config compatibility, ignored.
+ zero_init: Whether to zero-initialize the projection layer.
+ """
+
+ def __init__(
+ self,
+ visual_dim: int = 256,
+ depth_dim: int = 256,
+ fusion_type: str = "concat_add",
+ zero_init: bool = True,
+ ):
+ super().__init__()
+
+ self.visual_dim = visual_dim
+ self.depth_dim = depth_dim
+
+ # Projection: [C + C_depth] -> [C]
+ self.proj = nn.Conv2d(
+ visual_dim + depth_dim,
+ visual_dim,
+ kernel_size=1,
+ bias=True,
+ )
+
+ if zero_init:
+ nn.init.zeros_(self.proj.weight)
+ nn.init.zeros_(self.proj.bias)
+
+ def forward(
+ self,
+ visual_feats: list[Tensor],
+ depth_latents: Tensor,
+ depth_latents_hw: tuple[int, int],
+ ) -> list[Tensor]:
+ """Fuse depth latents into visual features.
+
+ Args:
+ visual_feats: List of visual features [[B, C, H, W]].
+ depth_latents: Depth features [B, N, C_depth].
+ depth_latents_hw: (H_d, W_d) spatial dims of depth latents.
+
+ Returns:
+ List of fused visual features with same shapes as input.
+ """
+ if depth_latents is None or len(visual_feats) == 0:
+ return visual_feats
+
+ B, N, C_depth = depth_latents.shape
+ H_d, W_d = depth_latents_hw
+
+ assert N == H_d * W_d, f"depth_latents N={N} != H_d*W_d={H_d * W_d}"
+
+ # Reshape: [B, N, C_depth] -> [B, C_depth, H_d, W_d]
+ depth_2d = depth_latents.permute(0, 2, 1).reshape(
+ B, C_depth, H_d, W_d
+ )
+
+ fused_feats = []
+ for visual_feat in visual_feats:
+ B_v, C_v, H_v, W_v = visual_feat.shape
+ assert C_v == self.visual_dim
+
+ # Interpolate depth to match visual spatial size
+ if (H_d, W_d) != (H_v, W_v):
+ depth_resized = torch.nn.functional.interpolate(
+ depth_2d,
+ size=(H_v, W_v),
+ mode="bilinear",
+ align_corners=False,
+ )
+ else:
+ depth_resized = depth_2d
+
+ # Concat + project + residual
+ concat_feat = torch.cat([visual_feat, depth_resized], dim=1)
+ proj_feat = self.proj(concat_feat)
+ fused_feat = visual_feat + proj_feat
+
+ fused_feats.append(fused_feat)
+
+ return fused_feats
+
+
+class EarlyDepthFusionLingbot(nn.Module):
+ """ControlNet-style fusion for Lingbot depth backend.
+
+ LayerNorm on depth latents, project depth only, residual add.
+ Visual features never pass through any trainable layer, preserving
+ the pretrained distribution.
+
+ Args:
+ visual_dim: Dimension of visual features (e.g., 256).
+ depth_dim: Dimension of depth latents (e.g., 256).
+ fusion_type: Kept for config compatibility, ignored.
+ zero_init: Whether to zero-initialize the projection layer.
+ """
+
+ def __init__(
+ self,
+ visual_dim: int = 256,
+ depth_dim: int = 256,
+ fusion_type: str = "concat_add",
+ zero_init: bool = True,
+ ):
+ super().__init__()
+
+ self.visual_dim = visual_dim
+ self.depth_dim = depth_dim
+
+ # Normalize depth_latents to unit scale before projection.
+ # depth_latents (raw neck output, std~4.0) and visual features
+ # (SAM3 FPN, std~0.017) differ by ~230x. LayerNorm brings depth
+ # to mean=0, std=1 so the projection sees consistent input scale.
+ self.depth_norm = nn.LayerNorm(depth_dim)
+
+ # Projection: depth_dim -> visual_dim (depth only)
+ self.proj = nn.Conv2d(
+ depth_dim,
+ visual_dim,
+ kernel_size=1,
+ bias=True,
+ )
+
+ if zero_init:
+ nn.init.zeros_(self.proj.weight)
+ nn.init.zeros_(self.proj.bias)
+
+ def forward(
+ self,
+ visual_feats: list[Tensor],
+ depth_latents: Tensor,
+ depth_latents_hw: tuple[int, int],
+ ) -> list[Tensor]:
+ """Fuse depth latents into visual features.
+
+ Args:
+ visual_feats: List of visual features [[B, C, H, W]].
+ depth_latents: Depth features [B, N, C_depth].
+ depth_latents_hw: (H_d, W_d) spatial dims of depth latents.
+
+ Returns:
+ List of fused visual features with same shapes as input.
+ """
+ if depth_latents is None or len(visual_feats) == 0:
+ return visual_feats
+
+ B, N, C_depth = depth_latents.shape
+ H_d, W_d = depth_latents_hw
+
+ assert N == H_d * W_d, f"depth_latents N={N} != H_d*W_d={H_d * W_d}"
+
+ # Normalize depth_latents to unit scale
+ # Cast to match LayerNorm dtype (AMP bf16 compatibility)
+ depth_latents = depth_latents.to(self.depth_norm.weight.dtype)
+ depth_latents = self.depth_norm(depth_latents)
+
+ # Reshape: [B, N, C_depth] -> [B, C_depth, H_d, W_d]
+ depth_2d = depth_latents.permute(0, 2, 1).reshape(
+ B, C_depth, H_d, W_d
+ )
+
+ fused_feats = []
+ for visual_feat in visual_feats:
+ B_v, C_v, H_v, W_v = visual_feat.shape
+ assert C_v == self.visual_dim
+
+ # Interpolate depth to match visual spatial size
+ if (H_d, W_d) != (H_v, W_v):
+ depth_resized = torch.nn.functional.interpolate(
+ depth_2d,
+ size=(H_v, W_v),
+ mode="bilinear",
+ align_corners=False,
+ )
+ else:
+ depth_resized = depth_2d
+
+ # Project depth only + residual add
+ delta = self.proj(depth_resized)
+ fused_feat = visual_feat + delta
+
+ self._last_delta_mean_abs = delta.detach().abs().mean().item()
+
+ fused_feats.append(fused_feat)
+
+ return fused_feats
+
+
+# Backward compatibility alias
+EarlyDepthFusion = EarlyDepthFusionUniDepthV2
diff --git a/wilddet3d/depth/lingbot_backend.py b/wilddet3d/depth/lingbot_backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a7dd5fa2e3b9fff89c01fac62b90845849b9eef
--- /dev/null
+++ b/wilddet3d/depth/lingbot_backend.py
@@ -0,0 +1,1543 @@
+"""LingbotDepthBackend: LingBot-Depth geometry backend for 3D-MOOD.
+
+Uses DINOv2 RGB-D encoder with mixed depth input strategy (per-sample):
+- 70% monocular: zero depth input
+- 20% patch-masked: patch-level random masking (60-90% ratio, following
+ the Masked Depth Modeling paper) for depth completion training
+- 10% copy-through: full depth_gt as input
+- Inference: always zero depth (monocular mode)
+
+Intrinsic prediction: MLP on cls_token predicts camera K.
+is_ray_aware = False so the 3D head's camera prompt branch is active.
+
+Depth loss: L1 + MoGe2 affine-invariant losses (global, local, edge)
+ + confidence mask BCE on all valid pixels.
+Camera loss: ray-based MSE (same approach as UniDepthV2).
+"""
+
+from __future__ import annotations
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from .base import GeometryBackendBase, GeometryBackendOutput
+from wilddet3d.ops.ray import generate_rays
+
+import utils3d
+
+
+def backproject_depth_to_points(
+ depth: Tensor, K: Tensor, H: int, W: int
+) -> Tensor:
+ """Back-project depth map to 3D points using camera intrinsics.
+
+ Uses utils3d (same as MoGe2) with normalized intrinsics.
+
+ Args:
+ depth: [B, 1, H, W] or [B, H, W] metric depth.
+ K: [B, 3, 3] camera intrinsics (pixel space).
+ H: Image height.
+ W: Image width.
+
+ Returns:
+ points: [B, H, W, 3] 3D points in camera space (x, y, z).
+ """
+ z = depth.squeeze(1) if depth.ndim == 4 else depth # [B, H, W]
+ # Normalize pixel intrinsics to [0, 1] for utils3d
+ K_norm = K.clone()
+ K_norm[:, 0, 0] /= W
+ K_norm[:, 0, 2] /= W
+ K_norm[:, 1, 1] /= H
+ K_norm[:, 1, 2] /= H
+ return utils3d.pt.depth_map_to_point_map(z, intrinsics=K_norm)
+
+
+class LingbotDepthBackend(GeometryBackendBase):
+ """Backend using LingBot-Depth (DINOv2 RGB-D encoder + ConvStack decoder).
+
+ Loads a pretrained MDMModel and decomposes it into:
+ - encoder: DINOv2_RGBD_Encoder (RGB-D feature extraction)
+ - neck: ConvStack (multiscale refinement)
+ - depth_head: ConvStack (depth regression)
+
+ depth_latents are extracted from neck level 1 output (after 2 ResBlocks,
+ 256-dim, 2x encoder resolution) and pooled to encoder grid size.
+ This matches UniDepthV2's approach of using decoder intermediate features.
+
+ During training, each sample independently gets one of three modes:
+ - monocular (zero depth): prob = monocular_prob (default 0.7)
+ - patch-masked depth: prob = masked_prob (default 0.2)
+ - copy-through (full depth): prob = 1 - monocular - masked (0.1)
+ During inference, always zero depth.
+
+ Args:
+ pretrained_model: Path or HuggingFace repo ID for MDMModel.
+ num_tokens: Number of base tokens for the encoder.
+ target_latent_dim: Target dimension for depth_latents.
+ Neck level 1 outputs 256-dim; if target != 256, a Linear
+ projection is applied. Use 256 to avoid projection.
+ depth_loss_weight: Weight for L1 depth loss.
+ silog_loss_weight: Weight for SILog depth loss (scale-invariant).
+ affine_global_weight: Weight for MoGe2 affine-invariant global loss.
+ affine_local_weight: Weight for MoGe2 affine-invariant local loss.
+ edge_loss_weight: Weight for MoGe2 edge loss.
+ mask_loss_weight: Weight for confidence mask BCE loss.
+ monocular_prob: Probability of zero depth input (training).
+ masked_prob: Probability of patch-masked depth input (training).
+ mask_ratio_range: (min, max) masking ratio for patch-masked mode.
+ mask_patch_size: Patch size for depth masking grid.
+ camera_loss_weight: Weight for ray-based L2 camera loss.
+ detach_depth_latents: Whether to detach depth_latents from graph.
+ encoder_freeze_blocks: Number of encoder transformer blocks to
+ freeze (from the beginning). ViT-L has 24 blocks; e.g. 20
+ freezes blocks[0..19], only training the last 4.
+ """
+
+ # Encoder does not fuse camera rays; 3D head needs camera prompt
+ is_ray_aware: bool = False
+
+ def __init__(
+ self,
+ pretrained_model: str = (
+ "robbyant/lingbot-depth-pretrain-vitl-14-v0.5"
+ ),
+ num_tokens: int = 2400,
+ target_latent_dim: int = 128,
+ depth_loss_weight: float = 1.0,
+ silog_loss_weight: float = 0.5,
+ affine_global_weight: float = 10.0,
+ affine_local_weight: float = 10.0,
+ edge_loss_weight: float = 10.0,
+ mask_loss_weight: float = 0.1,
+ monocular_prob: float = 0.7,
+ masked_prob: float = 0.2,
+ mask_ratio_range: tuple[float, float] = (0.6, 0.9),
+ mask_patch_size: int = 14,
+ camera_loss_weight: float = 1.0,
+ detach_depth_latents: bool = True,
+ encoder_freeze_blocks: int = 0,
+ unpad_test: bool = True,
+ ) -> None:
+ """Initialize the LingbotDepthBackend."""
+ super().__init__(detach_depth_latents=detach_depth_latents)
+ self.unpad_test = unpad_test
+
+ self.num_tokens = num_tokens
+ self.target_latent_dim = target_latent_dim
+ self.depth_loss_weight = depth_loss_weight
+ self.silog_loss_weight = silog_loss_weight
+ self.affine_global_weight = affine_global_weight
+ self.affine_local_weight = affine_local_weight
+ self.edge_loss_weight = edge_loss_weight
+ self.mask_loss_weight = mask_loss_weight
+ self.monocular_prob = monocular_prob
+ self.masked_prob = masked_prob
+ self.mask_ratio_range = mask_ratio_range
+ self.mask_patch_size = mask_patch_size
+ self.camera_loss_weight = camera_loss_weight
+
+ # SILog loss (scale-invariant) - lazy init, only needed for training
+ self._silog_loss_weight = silog_loss_weight
+ self._silog_loss = None
+
+ # Load pretrained MDMModel and decompose into sub-modules
+ from mdm.model.v2 import MDMModel
+
+ print(
+ f"[LingbotDepth] Loading pretrained model: "
+ f"{pretrained_model}"
+ )
+ mdm_model = MDMModel.from_pretrained(pretrained_model)
+
+ self.encoder = mdm_model.encoder
+ self.neck = mdm_model.neck
+ self.depth_head = mdm_model.depth_head
+ self.remap_depth_in = mdm_model.remap_depth_in
+ self.remap_depth_out = mdm_model.remap_depth_out
+
+ # Load mask_head from pretrained model (confidence prediction)
+ if hasattr(mdm_model, "mask_head"):
+ self.mask_head = mdm_model.mask_head
+ print("[LingbotDepth] mask_head loaded from checkpoint")
+ else:
+ self.mask_head = None
+ print(
+ "[LingbotDepth] WARNING: mask_head not found in "
+ "checkpoint, confidence prediction disabled"
+ )
+
+ # Get dimensions from loaded model
+ cls_dim = self.encoder.dim_features
+
+ # Neck level 1 outputs 256-dim features.
+ # If target_latent_dim != 256, project; otherwise Identity.
+ self._neck_latent_dim = 256
+ if target_latent_dim != self._neck_latent_dim:
+ self.latent_proj = nn.Linear(
+ self._neck_latent_dim, target_latent_dim
+ )
+ else:
+ self.latent_proj = nn.Identity()
+
+ # Intrinsic prediction head: cls_token -> camera K
+ # Same parameterization as UniDepthV2 CameraHead:
+ # exp(raw_f) * 0.7 * diagonal for focal length,
+ # sigmoid(raw_c) * W/H for principal point.
+ # Init: exp(0)=1.0 gives fx ~ 0.7*diag, sigmoid(0)=0.5 gives cx=W/2
+ self.intrinsic_head = nn.Sequential(
+ nn.LayerNorm(cls_dim),
+ nn.Linear(cls_dim, 256),
+ nn.ReLU(),
+ nn.Linear(256, 4),
+ )
+ nn.init.zeros_(self.intrinsic_head[-1].weight)
+ nn.init.zeros_(self.intrinsic_head[-1].bias)
+
+ # De-normalization buffers: convert 3D-MOOD normalized images
+ # back to [0,1] for the encoder (which does its own ImageNet norm)
+ self.register_buffer(
+ "denorm_mean",
+ torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1),
+ )
+ self.register_buffer(
+ "denorm_std",
+ torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1),
+ )
+
+ # Delete reference to full model (sub-modules survive via self)
+ del mdm_model
+
+ # torch.compile for encoder (controlled by SAM3_COMPILE env var)
+ import os
+ if os.environ.get("SAM3_COMPILE", "0") == "1":
+ self.encoder = torch.compile(self.encoder)
+ print("[LingbotDepth] torch.compile ENABLED for encoder")
+
+ # Freeze the first N transformer blocks of the encoder backbone.
+ # ViT-L has 24 blocks; e.g. encoder_freeze_blocks=20 freezes
+ # blocks[0..19] and only trains blocks[20..23] + patch_embed +
+ # norm + output_projections + neck + depth_head + new heads.
+ num_blocks = len(self.encoder.backbone.blocks)
+ encoder_freeze_blocks = min(encoder_freeze_blocks, num_blocks)
+ if encoder_freeze_blocks > 0:
+ bb = self.encoder.backbone
+ # Freeze everything in backbone first
+ for p in bb.parameters():
+ p.requires_grad = False
+ # Unfreeze the last (num_blocks - freeze_blocks) blocks
+ for i in range(encoder_freeze_blocks, num_blocks):
+ for p in bb.blocks[i].parameters():
+ p.requires_grad = True
+ # Unfreeze final norm (after all blocks)
+ for p in bb.norm.parameters():
+ p.requires_grad = True
+
+ copythrough_prob = 1.0 - monocular_prob - masked_prob
+ freeze_msg = (
+ f" encoder freeze: {encoder_freeze_blocks}/{num_blocks}"
+ f" blocks frozen"
+ )
+ print(
+ f"[LingbotDepth] Initialized: "
+ f"cls_dim={cls_dim}, num_tokens={num_tokens}, "
+ f"depth_latents=neck[1] (256-dim, pooled)\n"
+ f" remap_depth_in={self.remap_depth_in}, "
+ f"remap_depth_out={self.remap_depth_out}\n"
+ f" depth strategy: {monocular_prob:.0%} monocular / "
+ f"{masked_prob:.0%} patch-masked / "
+ f"{copythrough_prob:.0%} copy-through\n"
+ f" mask_ratio_range={mask_ratio_range}, "
+ f"mask_patch_size={mask_patch_size}\n"
+ f" losses: L1={depth_loss_weight}, "
+ f"affine_global={affine_global_weight}, "
+ f"affine_local={affine_local_weight}, "
+ f"edge={edge_loss_weight}, "
+ f"mask_bce={mask_loss_weight}, "
+ f"camera_ray={camera_loss_weight}\n"
+ f" mask_head={'loaded' if self.mask_head is not None else 'none'}\n"
+ f"{freeze_msg}"
+ )
+
+ def load_pretrained_weights(self) -> None:
+ """No-op: weights already loaded in __init__ via from_pretrained."""
+ pass
+
+ def _compute_token_grid(
+ self, H: int, W: int
+ ) -> tuple[int, int]:
+ """Compute token grid dimensions from image aspect ratio.
+
+ Same formula as MDMModel.forward lines 110-115.
+
+ Args:
+ H: Image height.
+ W: Image width.
+
+ Returns:
+ (base_h, base_w) token grid dimensions.
+ """
+ aspect_ratio = W / H
+ base_h = round(math.sqrt(self.num_tokens / aspect_ratio))
+ base_w = round(math.sqrt(self.num_tokens * aspect_ratio))
+ return base_h, base_w
+
+ def _prepare_depth_input(
+ self,
+ depth_gt: Tensor | None,
+ depth_mask: Tensor | None,
+ B: int,
+ H: int,
+ W: int,
+ device: torch.device,
+ ) -> Tensor | None:
+ """Prepare depth input with mixed strategy for training.
+
+ Per-sample mode selection:
+ - [0, monocular_prob): zero depth (monocular)
+ - [monocular_prob, monocular_prob + masked_prob): patch-masked
+ - [monocular_prob + masked_prob, 1.0): copy-through (full depth)
+
+ Args:
+ depth_gt: Ground truth depth [B, H, W] or [B, 1, H, W].
+ depth_mask: Valid depth mask [B, H, W] or [B, 1, H, W].
+ B: Batch size.
+ H: Image height.
+ W: Image width.
+ device: Tensor device.
+
+ Returns:
+ depth_input [B, 1, H, W] or None if no depth_gt.
+ """
+ if depth_gt is None:
+ return None
+
+ if depth_gt.ndim == 3:
+ depth_gt = depth_gt.unsqueeze(1) # [B, 1, H, W]
+
+ # Apply depth_mask if provided
+ if depth_mask is not None:
+ if depth_mask.ndim == 3:
+ depth_mask = depth_mask.unsqueeze(1)
+ depth_gt = depth_gt * depth_mask.float()
+
+ depth_input = torch.zeros_like(depth_gt)
+ rand_vals = torch.rand(B, device=device)
+ masked_threshold = self.monocular_prob + self.masked_prob
+
+ for i in range(B):
+ if rand_vals[i] < self.monocular_prob:
+ # Monocular: keep zeros
+ pass
+ elif rand_vals[i] < masked_threshold:
+ # Patch-level random masking
+ depth_input[i] = self._patch_mask_depth(
+ depth_gt[i], H, W, device
+ )
+ else:
+ # Copy-through: full depth
+ depth_input[i] = depth_gt[i]
+
+ return depth_input
+
+ def _patch_mask_depth(
+ self,
+ depth: Tensor,
+ H: int,
+ W: int,
+ device: torch.device,
+ ) -> Tensor:
+ """Apply patch-level random masking to depth map.
+
+ Following the MDM paper: randomly mask 60-90% of patches,
+ zeroing out entire patch regions.
+
+ Args:
+ depth: [1, H, W] single-sample depth map.
+ H: Image height.
+ W: Image width.
+ device: Tensor device.
+
+ Returns:
+ Masked depth [1, H, W] with some patches zeroed out.
+ """
+ ps = self.mask_patch_size
+ grid_h = H // ps
+ grid_w = W // ps
+ num_patches = grid_h * grid_w
+
+ # Random masking ratio in [min, max]
+ lo, hi = self.mask_ratio_range
+ mask_ratio = torch.rand(1, device=device).item() * (hi - lo) + lo
+ num_masked = int(num_patches * mask_ratio)
+
+ # Random permutation: first num_masked patches are masked (0)
+ perm = torch.randperm(num_patches, device=device)
+ keep = torch.ones(num_patches, device=device)
+ keep[perm[:num_masked]] = 0.0
+
+ # Reshape to spatial grid and upsample to image size
+ keep = keep.view(1, 1, grid_h, grid_w)
+ keep = F.interpolate(
+ keep, size=(grid_h * ps, grid_w * ps), mode="nearest"
+ ) # [1, 1, grid_h*ps, grid_w*ps]
+
+ # Pad if image size not divisible by patch size
+ pad_h = H - grid_h * ps
+ pad_w = W - grid_w * ps
+ if pad_h > 0 or pad_w > 0:
+ keep = F.pad(keep, (0, pad_w, 0, pad_h), value=1.0)
+
+ return depth * keep.squeeze(0) # [1, H, W]
+
+ def _predict_intrinsics(
+ self, cls_token: Tensor, H: int, W: int
+ ) -> Tensor:
+ """Predict camera intrinsics from cls_token.
+
+ Same parameterization as UniDepthV2 CameraHead.fill_intrinsics:
+ - fx = exp(raw) * 0.7 * diagonal
+ - fy = exp(raw) * 0.7 * diagonal
+ - cx = sigmoid(raw) * W
+ - cy = sigmoid(raw) * H
+
+ Args:
+ cls_token: [B, cls_dim] class token from encoder.
+ H: Image height (original pixel space).
+ W: Image width (original pixel space).
+
+ Returns:
+ K_pred: [B, 3, 3] predicted intrinsics in pixel coords.
+ """
+ params = self.intrinsic_head(cls_token) # [B, 4]
+
+ diagonal = (H**2 + W**2) ** 0.5
+ fx = torch.exp(params[:, 0].clamp(-10, 10)) * 0.7 * diagonal
+ fy = torch.exp(params[:, 1].clamp(-10, 10)) * 0.7 * diagonal
+ cx = torch.sigmoid(params[:, 2]) * W
+ cy = torch.sigmoid(params[:, 3]) * H
+
+ B = cls_token.shape[0]
+ K_pred = torch.zeros(
+ B, 3, 3, device=cls_token.device, dtype=cls_token.dtype
+ )
+ K_pred[:, 0, 0] = fx
+ K_pred[:, 1, 1] = fy
+ K_pred[:, 0, 2] = cx
+ K_pred[:, 1, 2] = cy
+ K_pred[:, 2, 2] = 1.0
+
+ return K_pred
+
+ def _run_encoder_and_decoder(
+ self,
+ images: Tensor,
+ depth_input: Tensor | None,
+ image_hw: tuple[int, int],
+ ) -> tuple[Tensor, Tensor, Tensor, int, int, list[Tensor]]:
+ """Run encoder + neck + depth_head pipeline.
+
+ Replicates MDMModel.forward() logic (lines 98-168 of v2.py).
+
+ Args:
+ images: [B, 3, H, W] 3D-MOOD normalized images.
+ depth_input: [B, 1, H, W] depth for encoder, or None.
+ image_hw: Original (H, W) dimensions.
+
+ Returns:
+ depth_map: [B, 1, H, W] metric depth in meters.
+ depth_latents: [B, N, target_latent_dim].
+ cls_token: [B, cls_dim].
+ base_h: Token grid height.
+ base_w: Token grid width.
+ neck_out: List of neck feature maps for mask_head.
+ """
+ from mdm.utils.geo import normalized_view_plane_uv
+
+ B = images.shape[0]
+ H, W = image_hw
+ device, dtype = images.device, images.dtype
+
+ # De-normalize from 3D-MOOD normalization to [0, 1]
+ # 3D-MOOD: norm_img = (img_255 - mean_255) / std_255
+ # Reverse: img_01 = norm_img * (std_255/255) + (mean_255/255)
+ # = norm_img * imagenet_std + imagenet_mean
+ images_01 = images * self.denorm_std + self.denorm_mean
+
+ # Compute token grid
+ base_h, base_w = self._compute_token_grid(H, W)
+
+ # Prepare depth: zeros if None (monocular mode)
+ if depth_input is None:
+ depth_for_encoder = torch.zeros(
+ B, 1, H, W, device=device, dtype=dtype
+ )
+ else:
+ depth_for_encoder = depth_input
+
+ # Encoder forward: expects [0,1] images
+ # (encoder internally normalizes with ImageNet stats and resizes
+ # to (base_h*14, base_w*14))
+ # enable_depth_mask=False avoids xformers BlockDiagonalMask
+ # dependency and uses standard attention instead
+ features, cls_token, _, _ = self.encoder(
+ images_01,
+ depth_for_encoder,
+ base_h,
+ base_w,
+ return_class_token=True,
+ remap_depth_in=self.remap_depth_in,
+ enable_depth_mask=False,
+ )
+ # features: [B, encoder_dim, base_h, base_w]
+ # cls_token: [B, cls_dim]
+
+ # Run neck + depth_head (MDMModel.forward lines 120-148)
+ aspect_ratio = W / H
+
+ # Add cls_token to features
+ feat_with_cls = features + cls_token[..., None, None]
+ feat_list = [feat_with_cls, None, None, None, None]
+
+ # Concat UV coordinates at 5 pyramid levels
+ for level in range(5):
+ uv = normalized_view_plane_uv(
+ width=base_w * 2**level,
+ height=base_h * 2**level,
+ aspect_ratio=aspect_ratio,
+ dtype=dtype,
+ device=device,
+ )
+ uv = (
+ uv.permute(2, 0, 1).unsqueeze(0).expand(B, -1, -1, -1)
+ )
+ if feat_list[level] is None:
+ feat_list[level] = uv
+ else:
+ feat_list[level] = torch.cat(
+ [feat_list[level], uv], dim=1
+ )
+
+ # Shared neck
+ neck_out = self.neck(feat_list)
+
+ # Extract depth_latents from neck level 1 (after 2 ResBlocks)
+ # neck_out[1]: [B, 256, base_h*2, base_w*2]
+ # Pool to (base_h, base_w) to keep N = base_h * base_w
+ neck_feat = neck_out[1] # [B, 256, base_h*2, base_w*2]
+ neck_feat_pooled = F.adaptive_avg_pool2d(
+ neck_feat, (base_h, base_w)
+ ) # [B, 256, base_h, base_w]
+ depth_latents = neck_feat_pooled.flatten(2).permute(
+ 0, 2, 1
+ ) # [B, N, 256]
+ depth_latents = self.latent_proj(depth_latents)
+
+ # Depth head: take last output
+ depth_reg = self.depth_head(neck_out)[-1] # [B, 1, h, w]
+
+ # Resize to original image dimensions
+ depth_reg = F.interpolate(
+ depth_reg,
+ (H, W),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ # Apply output remapping
+ # Clamp before exp to prevent overflow (float16 overflows at ~11,
+ # float32 at ~88). Range [-10, 10] maps to depth [4.5e-5, 22026] m.
+ if self.remap_depth_out == "exp":
+ depth_map = depth_reg.clamp(-10, 10).exp() # [B, 1, H, W]
+ elif self.remap_depth_out == "linear":
+ # Linear output can be negative; clamp to positive for
+ # downstream log-based losses and 3D head depth usage.
+ depth_map = depth_reg.clamp(min=1e-3)
+ else:
+ raise ValueError(
+ f"Invalid remap_depth_out: {self.remap_depth_out}"
+ )
+
+ return depth_map, depth_latents, cls_token, base_h, base_w, neck_out
+
+ def _run_mask_head(
+ self,
+ neck_out: list[Tensor],
+ H: int,
+ W: int,
+ ) -> Tensor | None:
+ """Run mask_head to produce confidence map.
+
+ Args:
+ neck_out: List of neck feature maps.
+ H: Target height.
+ W: Target width.
+
+ Returns:
+ confidence_map: [B, 1, H, W] sigmoid probabilities, or None.
+ """
+ if self.mask_head is None:
+ return None
+ confidence_raw = self.mask_head(neck_out)[-1] # [B, 1, h, w]
+ confidence_map = F.interpolate(
+ confidence_raw,
+ (H, W),
+ mode="bilinear",
+ align_corners=False,
+ ).sigmoid()
+ return confidence_map
+
+ @torch.autocast(device_type="cuda", enabled=False)
+ def _compute_losses(
+ self,
+ depth_map: Tensor,
+ depth_gt: Tensor | None,
+ depth_mask: Tensor | None,
+ K_pred: Tensor,
+ intrinsics: Tensor,
+ image_hw: tuple[int, int],
+ confidence_map: Tensor | None = None,
+ ) -> dict[str, Tensor]:
+ """Compute depth and camera losses.
+
+ Depth loss: masked L1 + MoGe2 affine-invariant losses (global,
+ local level 4 & 16, edge) + confidence mask BCE.
+ Camera loss: ray-based L2 RMSE (same as UniDepthV2).
+
+ Args:
+ depth_map: [B, 1, H, W] predicted metric depth.
+ depth_gt: [B, H, W] or [B, 1, H, W] ground truth depth.
+ depth_mask: [B, H, W] or [B, 1, H, W] valid depth mask.
+ K_pred: [B, 3, 3] predicted intrinsics.
+ intrinsics: [B, 3, 3] ground truth intrinsics.
+ image_hw: (H, W) image dimensions.
+ confidence_map: [B, 1, H, W] confidence from mask_head.
+
+ Returns:
+ Dictionary of loss tensors.
+ """
+ # Lazy import: moge losses only needed for training
+ # Lazy imports: only needed for training
+ from moge.train.losses import (
+ affine_invariant_global_loss,
+ affine_invariant_local_loss,
+ edge_loss,
+ mask_bce_loss,
+ )
+ if self._silog_loss is None and self._silog_loss_weight > 0:
+ from wilddet3d.loss.silog_loss import SILogLoss
+ self._silog_loss = SILogLoss(scale_pred_weight=0.15)
+
+ losses = {}
+ H, W = image_hw
+
+ # Cast to float32 for numerical stability under mixed precision
+ depth_map = depth_map.float()
+ K_pred = K_pred.float()
+ intrinsics = intrinsics.float()
+ if depth_gt is not None:
+ depth_gt = depth_gt.float()
+ if depth_mask is not None:
+ depth_mask = depth_mask.float()
+ if confidence_map is not None:
+ confidence_map = confidence_map.float()
+
+ # Depth losses
+ if depth_gt is not None:
+ depth_pred = depth_map.squeeze(1) # [B, H, W]
+
+ if depth_gt.ndim == 4:
+ depth_gt = depth_gt.squeeze(1) # [B, H, W]
+
+ valid_mask = depth_gt > 0
+ if depth_mask is not None:
+ if depth_mask.ndim == 4:
+ depth_mask = depth_mask.squeeze(1)
+ valid_mask = valid_mask & depth_mask.bool()
+
+ # Filter out extreme GT depth (>100m) and extreme
+ # pred/gt ratio (>3x or <1/3x) to prevent unstable
+ # gradients from outlier pixels.
+ _MAX_DEPTH = 100.0
+ _MAX_RATIO = 3.0
+ valid_mask = valid_mask & (depth_gt <= _MAX_DEPTH)
+ with torch.no_grad():
+ ratio = depth_pred / (depth_gt + 1e-6)
+ valid_mask = valid_mask & (
+ (ratio > 1.0 / _MAX_RATIO)
+ & (ratio < _MAX_RATIO)
+ )
+
+ B = depth_pred.shape[0]
+
+ # L1 metric depth loss
+ if valid_mask.any():
+ depth_loss = F.l1_loss(
+ depth_pred[valid_mask], depth_gt[valid_mask]
+ )
+ else:
+ depth_loss = depth_pred.new_tensor(0.0)
+
+ losses["depth_l1"] = (
+ depth_loss.clamp(max=10.0) * self.depth_loss_weight
+ )
+
+ # SILog loss (scale-invariant)
+ if self._silog_loss is not None and valid_mask.any():
+ silog_val = self._silog_loss(
+ depth_pred, depth_gt, mask=valid_mask
+ )
+ losses["depth_silog"] = (
+ silog_val.clamp(max=10.0)
+ * self.silog_loss_weight
+ )
+
+ # Back-project to 3D points for MoGe2 losses
+ # 50% chance per image: use K_pred or GT intrinsics
+ # This trains intrinsic head via MoGe2 loss while keeping
+ # depth supervised with GT intrinsics half the time.
+ use_pred_k = torch.rand(B, device=depth_pred.device) < 0.5
+ K_for_pred = torch.where(
+ use_pred_k[:, None, None], K_pred, intrinsics
+ )
+ pred_points = backproject_depth_to_points(
+ depth_pred, K_for_pred, H, W
+ ) # [B, H, W, 3]
+ gt_points = backproject_depth_to_points(
+ depth_gt, intrinsics, H, W
+ ) # [B, H, W, 3]
+ # MoGe2 convention: invalid GT -> inf
+ gt_points[~valid_mask] = float("inf")
+
+ # Per-image MoGe2 losses (alignment is per-image)
+ zero = depth_pred.new_tensor(0.0)
+ aff_global_sum = zero
+ aff_local4_sum = zero
+ aff_local16_sum = zero
+ edge_sum = zero
+
+ for i in range(B):
+ has_valid = valid_mask[i].any()
+ if has_valid:
+ loss_g, _, scale_i = (
+ affine_invariant_global_loss(
+ pred_points[i],
+ gt_points[i],
+ align_resolution=48,
+ )
+ )
+ else:
+ loss_g = zero
+ scale_i = zero
+ aff_global_sum = aff_global_sum + loss_g
+
+ # MoGe2 local loss expects normalized focal
+ # (fx/W, fy/H ~0.5-1.0), not pixel focal
+ fx_norm = K_pred[i, 0, 0] / W
+ fy_norm = K_pred[i, 1, 1] / H
+ focal_i = 1.0 / (
+ 1.0 / fx_norm**2 + 1.0 / fy_norm**2
+ ) ** 0.5
+
+ if has_valid:
+ loss_l4, _ = affine_invariant_local_loss(
+ pred_points[i],
+ gt_points[i],
+ focal_i,
+ scale_i,
+ level=4,
+ align_resolution=24,
+ num_patches=16,
+ importance_sampling=False,
+ )
+ loss_l16, _ = affine_invariant_local_loss(
+ pred_points[i],
+ gt_points[i],
+ focal_i,
+ scale_i,
+ level=16,
+ align_resolution=12,
+ num_patches=256,
+ importance_sampling=False,
+ )
+ loss_e, _ = edge_loss(
+ pred_points[i], gt_points[i]
+ )
+ else:
+ loss_l4 = zero
+ loss_l16 = zero
+ loss_e = zero
+ aff_local4_sum = aff_local4_sum + loss_l4
+ aff_local16_sum = aff_local16_sum + loss_l16
+ edge_sum = edge_sum + loss_e
+
+ losses["affine_global"] = (
+ (aff_global_sum / B).clamp(max=10.0)
+ * self.affine_global_weight
+ )
+ losses["affine_local_4"] = (
+ (aff_local4_sum / B).clamp(max=10.0)
+ * self.affine_local_weight
+ )
+ losses["affine_local_16"] = (
+ (aff_local16_sum / B).clamp(max=10.0)
+ * self.affine_local_weight
+ )
+ losses["edge"] = (
+ (edge_sum / B).clamp(max=10.0)
+ * self.edge_loss_weight
+ )
+
+ # Mask BCE loss (confidence map)
+ # MoGe2 uses 3-state masks (fin / inf / unknown).
+ # For sparse data (LiDAR), most pixels have no
+ # annotation and should NOT be labeled "known invalid".
+ # Use per-image coverage to decide: dense (>50%)
+ # treats all non-valid as known-invalid; sparse
+ # treats only depth_mask-annotated invalid pixels.
+ if (
+ confidence_map is not None
+ and self.mask_loss_weight > 0
+ ):
+ conf = confidence_map.squeeze(1) # [B, H, W]
+ gt_mask_fin = valid_mask # [B, H, W]
+ has_depth = depth_gt > 0 # [B, H, W]
+ if depth_mask is not None:
+ annotated = depth_mask.bool()
+ else:
+ # Per-image: dense -> all pixels annotated;
+ # sparse -> only depth>0 pixels annotated.
+ coverage = has_depth.flatten(1).float().mean(1)
+ is_dense = coverage > 0.7 # [B]
+ annotated = torch.where(
+ is_dense[:, None, None],
+ torch.ones_like(has_depth),
+ has_depth,
+ )
+ gt_mask_inf = annotated & ~has_depth
+ loss_mask, _ = mask_bce_loss(
+ conf, gt_mask_fin, gt_mask_inf
+ )
+ losses["mask_bce"] = (
+ loss_mask.mean().clamp(max=10.0)
+ * self.mask_loss_weight
+ )
+
+ # Camera loss: ray-based MSE (same as UniDepthV2)
+ rays_pred, _ = generate_rays(K_pred, image_hw)
+ rays_gt, _ = generate_rays(intrinsics, image_hw)
+ camera_loss = F.mse_loss(rays_pred, rays_gt)
+ losses["camera_ray"] = (
+ camera_loss.clamp(max=10.0) * self.camera_loss_weight
+ )
+
+ return losses
+
+ def _scale_intrinsics(
+ self,
+ intrinsics: Tensor,
+ from_hw: tuple[int, int],
+ to_hw: tuple[int, int],
+ ) -> Tensor:
+ """Scale intrinsics from one image space to another.
+
+ Args:
+ intrinsics: [B, 3, 3] intrinsics in from_hw space.
+ from_hw: Source (H, W).
+ to_hw: Target (H, W).
+
+ Returns:
+ Scaled intrinsics [B, 3, 3] in to_hw space.
+ """
+ scale_x = to_hw[1] / from_hw[1]
+ scale_y = to_hw[0] / from_hw[0]
+
+ K_scaled = intrinsics.clone()
+ K_scaled[:, 0, 0] *= scale_x # fx
+ K_scaled[:, 0, 2] *= scale_x # cx
+ K_scaled[:, 1, 1] *= scale_y # fy
+ K_scaled[:, 1, 2] *= scale_y # cy
+
+ return K_scaled
+
+ def _has_valid_padding(self, padding: list | None) -> bool:
+ """Check if padding info is valid and non-zero."""
+ if padding is None:
+ return False
+ return any(
+ p is not None and any(v > 0 for v in p) for p in padding
+ )
+
+ def _crop_padding_single(
+ self,
+ image: Tensor,
+ intrinsics: Tensor,
+ pad_info: list[int],
+ H_pad: int,
+ W_pad: int,
+ depth_gt: Tensor | None = None,
+ depth_mask: Tensor | None = None,
+ ) -> tuple[Tensor, Tensor, int, int, Tensor | None, Tensor | None]:
+ """Crop padding from a single image and adjust intrinsics.
+
+ Args:
+ image: [1, 3, H_pad, W_pad] padded image.
+ intrinsics: [1, 3, 3] padded-space intrinsics.
+ pad_info: [pad_left, pad_right, pad_top, pad_bottom].
+ H_pad: Padded height.
+ W_pad: Padded width.
+ depth_gt: [1, 1, H_pad, W_pad] or None.
+ depth_mask: [1, 1, H_pad, W_pad] or [1, H_pad, W_pad] or None.
+
+ Returns:
+ (cropped_image, adjusted_intrinsics, H_orig, W_orig,
+ cropped_depth_gt, cropped_depth_mask)
+ """
+ pad_left, pad_right, pad_top, pad_bottom = pad_info
+ H_orig = H_pad - pad_top - pad_bottom
+ W_orig = W_pad - pad_left - pad_right
+
+ # Crop image
+ img_cropped = image[
+ :, :, pad_top : pad_top + H_orig, pad_left : pad_left + W_orig
+ ]
+
+ # Adjust intrinsics: reverse CenterPadIntrinsics
+ K_cropped = intrinsics.clone()
+ K_cropped[0, 0, 2] -= pad_left # cx
+ K_cropped[0, 1, 2] -= pad_top # cy
+
+ # Crop depth_gt
+ dgt_cropped = None
+ if depth_gt is not None:
+ dgt_cropped = depth_gt[
+ :, :,
+ pad_top : pad_top + H_orig,
+ pad_left : pad_left + W_orig,
+ ]
+
+ # Crop depth_mask
+ dm_cropped = None
+ if depth_mask is not None:
+ if depth_mask.ndim == 3:
+ dm_cropped = depth_mask[
+ :,
+ pad_top : pad_top + H_orig,
+ pad_left : pad_left + W_orig,
+ ]
+ else:
+ dm_cropped = depth_mask[
+ :, :,
+ pad_top : pad_top + H_orig,
+ pad_left : pad_left + W_orig,
+ ]
+
+ return (
+ img_cropped,
+ K_cropped,
+ H_orig,
+ W_orig,
+ dgt_cropped,
+ dm_cropped,
+ )
+
+ def _repad_depth_latents(
+ self,
+ depth_latents: Tensor,
+ base_h_orig: int,
+ base_w_orig: int,
+ base_h_pad: int,
+ base_w_pad: int,
+ pad_top: int,
+ pad_left: int,
+ H_pad: int,
+ W_pad: int,
+ ) -> Tensor:
+ """Repad depth latents from original to padded token grid.
+
+ Places original-resolution tokens at the correct position within
+ the padded token grid, with zeros filling the padding regions.
+
+ Args:
+ depth_latents: [1, N_orig, C] original-resolution latents.
+ base_h_orig: Original token grid height.
+ base_w_orig: Original token grid width.
+ base_h_pad: Padded token grid height.
+ base_w_pad: Padded token grid width.
+ pad_top: Pixel-space top padding.
+ pad_left: Pixel-space left padding.
+ H_pad: Padded image height.
+ W_pad: Padded image width.
+
+ Returns:
+ [1, N_pad, C] depth latents in padded token grid.
+ """
+ if (
+ base_h_orig == base_h_pad
+ and base_w_orig == base_w_pad
+ ):
+ return depth_latents
+
+ _, N_orig, C = depth_latents.shape
+
+ # Reshape to spatial: [1, C, base_h_orig, base_w_orig]
+ dl_2d = depth_latents.permute(0, 2, 1).reshape(
+ 1, C, base_h_orig, base_w_orig
+ )
+
+ # Compute token-space offsets
+ pad_top_tok = round(pad_top * base_h_pad / H_pad)
+ pad_left_tok = round(pad_left * base_w_pad / W_pad)
+
+ # Clamp to valid range
+ pad_top_tok = min(pad_top_tok, base_h_pad - 1)
+ pad_left_tok = min(pad_left_tok, base_w_pad - 1)
+
+ # How many original tokens fit
+ h_fit = min(base_h_orig, base_h_pad - pad_top_tok)
+ w_fit = min(base_w_orig, base_w_pad - pad_left_tok)
+
+ # Create padded output with zeros
+ dl_padded = torch.zeros(
+ 1,
+ C,
+ base_h_pad,
+ base_w_pad,
+ device=depth_latents.device,
+ dtype=depth_latents.dtype,
+ )
+ dl_padded[
+ :,
+ :,
+ pad_top_tok : pad_top_tok + h_fit,
+ pad_left_tok : pad_left_tok + w_fit,
+ ] = dl_2d[:, :, :h_fit, :w_fit]
+
+ # Flatten back: [1, N_pad, C]
+ return dl_padded.flatten(2).permute(0, 2, 1)
+
+ def _repad_depth_map(
+ self,
+ depth_map: Tensor,
+ pad_left: int,
+ pad_right: int,
+ pad_top: int,
+ pad_bottom: int,
+ ) -> Tensor:
+ """Repad depth map from original to padded resolution.
+
+ Args:
+ depth_map: [1, 1, H_orig, W_orig].
+ pad_left, pad_right, pad_top, pad_bottom: Pixel padding.
+
+ Returns:
+ [1, 1, H_pad, W_pad] with zeros in padding region.
+ """
+ return F.pad(
+ depth_map,
+ (pad_left, pad_right, pad_top, pad_bottom),
+ value=0.0,
+ )
+
+ def forward_train(
+ self,
+ images: Tensor,
+ depth_feats: list[Tensor] | None,
+ intrinsics: Tensor,
+ image_hw: tuple[int, int],
+ depth_gt: Tensor | None = None,
+ depth_mask: Tensor | None = None,
+ **kwargs,
+ ) -> GeometryBackendOutput:
+ """Forward pass for training.
+
+ Uses mixed depth input strategy: each sample independently
+ gets monocular / patch-masked / copy-through depth input.
+
+ When padding info is provided, crops padding before the encoder
+ so LingBot-Depth processes at original resolution with correct
+ aspect ratio, then repads outputs back to padded space.
+
+ Args:
+ images: [B, 3, H, W] 3D-MOOD normalized images.
+ depth_feats: Ignored (we use our own encoder).
+ intrinsics: [B, 3, 3] camera intrinsics.
+ image_hw: (H, W) image dimensions.
+ depth_gt: [B, H, W] ground truth depth.
+ depth_mask: [B, H, W] valid depth mask.
+ **kwargs: May contain 'padding' (list of [L,R,T,B] per image).
+
+ Returns:
+ GeometryBackendOutput.
+ """
+ B = images.shape[0]
+ H_pad, W_pad = image_hw
+ padding = kwargs.get("padding", None)
+
+ # If no valid padding, use original batched code path
+ if not self._has_valid_padding(padding):
+ return self._forward_train_batched(
+ images, intrinsics, image_hw, depth_gt, depth_mask
+ )
+
+ # Per-image processing at original (unpadded) resolution
+ # Padded token grid (target for repadding depth_latents)
+ base_h_pad, base_w_pad = self._compute_token_grid(
+ H_pad, W_pad
+ )
+
+ depth_maps_list = []
+ depth_latents_list = []
+ K_pred_list = []
+ confidence_maps_list = []
+ losses_accum = {}
+
+ for i in range(B):
+ pad_info = padding[i]
+ if pad_info is None or all(v == 0 for v in pad_info):
+ # No padding for this image
+ pad_left = pad_right = pad_top = pad_bottom = 0
+ img_i = images[i : i + 1]
+ K_i = intrinsics[i : i + 1]
+ H_orig, W_orig = H_pad, W_pad
+ dgt_i = (
+ depth_gt[i : i + 1] if depth_gt is not None
+ else None
+ )
+ dm_i = (
+ depth_mask[i : i + 1]
+ if depth_mask is not None
+ else None
+ )
+ else:
+ pad_left, pad_right, pad_top, pad_bottom = pad_info
+ (
+ img_i,
+ K_i,
+ H_orig,
+ W_orig,
+ dgt_i,
+ dm_i,
+ ) = self._crop_padding_single(
+ images[i : i + 1],
+ intrinsics[i : i + 1],
+ pad_info,
+ H_pad,
+ W_pad,
+ (
+ depth_gt[i : i + 1]
+ if depth_gt is not None
+ else None
+ ),
+ (
+ depth_mask[i : i + 1]
+ if depth_mask is not None
+ else None
+ ),
+ )
+
+ orig_hw = (H_orig, W_orig)
+
+ # Prepare depth input with mixed strategy (per-image)
+ depth_input_i = self._prepare_depth_input(
+ dgt_i, dm_i, 1, H_orig, W_orig, images.device
+ )
+
+ # Run encoder at ORIGINAL resolution (correct aspect ratio)
+ (
+ depth_map_i,
+ depth_latents_i,
+ cls_token_i,
+ base_h_i,
+ base_w_i,
+ neck_out_i,
+ ) = self._run_encoder_and_decoder(
+ img_i, depth_input_i, orig_hw
+ )
+
+ # Predict intrinsics at original resolution
+ K_pred_i = self._predict_intrinsics(
+ cls_token_i, H_orig, W_orig
+ )
+
+ # Run mask_head for confidence map
+ confidence_map_i = self._run_mask_head(
+ neck_out_i, H_orig, W_orig
+ )
+
+ # Compute losses at original resolution
+ losses_i = self._compute_losses(
+ depth_map_i,
+ dgt_i,
+ dm_i,
+ K_pred_i,
+ K_i,
+ orig_hw,
+ confidence_map=confidence_map_i,
+ )
+
+ # Accumulate losses
+ for key, val in losses_i.items():
+ if key not in losses_accum:
+ losses_accum[key] = val
+ else:
+ losses_accum[key] = losses_accum[key] + val
+
+ # Repad depth_map back to padded resolution
+ depth_map_padded_i = self._repad_depth_map(
+ depth_map_i,
+ pad_left,
+ pad_right,
+ pad_top,
+ pad_bottom,
+ )
+ depth_maps_list.append(depth_map_padded_i)
+
+ # Repad confidence_map back to padded resolution
+ if confidence_map_i is not None:
+ confidence_maps_list.append(
+ self._repad_depth_map(
+ confidence_map_i,
+ pad_left,
+ pad_right,
+ pad_top,
+ pad_bottom,
+ )
+ )
+
+ # Repad depth_latents to padded token grid
+ depth_latents_padded_i = self._repad_depth_latents(
+ depth_latents_i,
+ base_h_i,
+ base_w_i,
+ base_h_pad,
+ base_w_pad,
+ pad_top,
+ pad_left,
+ H_pad,
+ W_pad,
+ )
+ depth_latents_list.append(depth_latents_padded_i)
+
+ # K_pred: restore to padded space (add padding offset)
+ # fx, fy unchanged (padding doesn't change focal length)
+ # Use non-inplace ops to preserve autograd graph
+ K_pred_padded_i = K_pred_i.clone()
+ K_pred_padded_i[:, 0, 2] = K_pred_i[:, 0, 2] + pad_left
+ K_pred_padded_i[:, 1, 2] = K_pred_i[:, 1, 2] + pad_top
+ K_pred_list.append(K_pred_padded_i)
+
+ # Average losses across batch
+ for key in losses_accum:
+ losses_accum[key] = losses_accum[key] / B
+
+ # Stack results
+ depth_map = torch.cat(depth_maps_list, dim=0)
+ depth_latents = torch.cat(depth_latents_list, dim=0)
+ K_pred = torch.cat(K_pred_list, dim=0)
+ confidence_map = (
+ torch.cat(confidence_maps_list, dim=0)
+ if confidence_maps_list
+ else None
+ )
+
+ depth_latents = self._maybe_detach_latents(depth_latents)
+
+ # Ray intrinsics: padded intrinsics scaled to padded token grid
+ # (consistent with padded depth_latents space)
+ internal_hw = (base_h_pad * 14, base_w_pad * 14)
+ ray_intrinsics = self._scale_intrinsics(
+ intrinsics, (H_pad, W_pad), internal_hw
+ )
+
+ return GeometryBackendOutput(
+ depth_map=depth_map,
+ depth_latents=depth_latents,
+ K_pred=K_pred,
+ ray_intrinsics=ray_intrinsics,
+ ray_image_hw=internal_hw,
+ ray_downsample=14,
+ aux={
+ "depth_latents_hw": (base_h_pad, base_w_pad),
+ "confidence_map": confidence_map,
+ },
+ losses=losses_accum,
+ )
+
+ def _forward_train_batched(
+ self,
+ images: Tensor,
+ intrinsics: Tensor,
+ image_hw: tuple[int, int],
+ depth_gt: Tensor | None,
+ depth_mask: Tensor | None,
+ ) -> GeometryBackendOutput:
+ """Original batched forward_train path (no unpadding)."""
+ B = images.shape[0]
+ H, W = image_hw
+
+ depth_input = self._prepare_depth_input(
+ depth_gt, depth_mask, B, H, W, images.device
+ )
+
+ (
+ depth_map, depth_latents, cls_token,
+ base_h, base_w, neck_out,
+ ) = self._run_encoder_and_decoder(
+ images, depth_input, image_hw
+ )
+
+ depth_latents = self._maybe_detach_latents(depth_latents)
+ K_pred = self._predict_intrinsics(cls_token, H, W)
+
+ # Run mask_head for confidence map
+ confidence_map = self._run_mask_head(neck_out, H, W)
+
+ losses = self._compute_losses(
+ depth_map, depth_gt, depth_mask, K_pred, intrinsics,
+ image_hw, confidence_map=confidence_map,
+ )
+
+ internal_hw = (base_h * 14, base_w * 14)
+ ray_intrinsics = self._scale_intrinsics(
+ intrinsics, (H, W), internal_hw
+ )
+
+ return GeometryBackendOutput(
+ depth_map=depth_map,
+ depth_latents=depth_latents,
+ K_pred=K_pred,
+ ray_intrinsics=ray_intrinsics,
+ ray_image_hw=internal_hw,
+ ray_downsample=14,
+ aux={
+ "depth_latents_hw": (base_h, base_w),
+ "confidence_map": confidence_map,
+ },
+ losses=losses,
+ )
+
+ @torch.no_grad()
+ def forward_test(
+ self,
+ images: Tensor,
+ depth_feats: list[Tensor] | None,
+ intrinsics: Tensor,
+ image_hw: tuple[int, int],
+ depth_gt: Tensor | None = None,
+ **kwargs,
+ ) -> GeometryBackendOutput:
+ """Forward pass for inference.
+
+ When padding info is provided, crops padding before the encoder
+ so LingBot-Depth processes at original resolution, then repads.
+
+ Args:
+ images: [B, 3, H, W] 3D-MOOD normalized images.
+ depth_feats: Ignored.
+ intrinsics: [B, 3, 3] camera intrinsics.
+ image_hw: (H, W) image dimensions.
+ depth_gt: [B, 1, H, W] depth map input (optional).
+ **kwargs: May contain 'padding' (list of [L,R,T,B]).
+
+ Returns:
+ GeometryBackendOutput.
+ """
+ H_pad, W_pad = image_hw
+ padding = kwargs.get("padding", None)
+
+ # If unpad disabled or no valid padding, use batched (padded) path
+ if not self.unpad_test or not self._has_valid_padding(padding):
+ return self._forward_test_batched(
+ images, intrinsics, image_hw, depth_gt
+ )
+
+ # Per-image processing at original resolution
+ B = images.shape[0]
+ base_h_pad, base_w_pad = self._compute_token_grid(
+ H_pad, W_pad
+ )
+
+ depth_maps_list = []
+ depth_latents_list = []
+ K_pred_list = []
+ confidence_maps_list = []
+
+ for i in range(B):
+ pad_info = padding[i]
+ if pad_info is None or all(v == 0 for v in pad_info):
+ pad_left = pad_right = pad_top = pad_bottom = 0
+ img_i = images[i : i + 1]
+ K_i = intrinsics[i : i + 1]
+ H_orig, W_orig = H_pad, W_pad
+ dgt_i = (
+ depth_gt[i : i + 1]
+ if depth_gt is not None
+ else None
+ )
+ else:
+ pad_left, pad_right, pad_top, pad_bottom = pad_info
+ (
+ img_i,
+ K_i,
+ H_orig,
+ W_orig,
+ dgt_i,
+ _,
+ ) = self._crop_padding_single(
+ images[i : i + 1],
+ intrinsics[i : i + 1],
+ pad_info,
+ H_pad,
+ W_pad,
+ (
+ depth_gt[i : i + 1]
+ if depth_gt is not None
+ else None
+ ),
+ )
+
+ orig_hw = (H_orig, W_orig)
+
+ # Use depth_gt as input if available, otherwise monocular
+ depth_input_i = dgt_i if dgt_i is not None else None
+
+ (
+ depth_map_i,
+ depth_latents_i,
+ cls_token_i,
+ base_h_i,
+ base_w_i,
+ neck_out_i,
+ ) = self._run_encoder_and_decoder(
+ img_i, depth_input_i, orig_hw
+ )
+
+ K_pred_i = self._predict_intrinsics(
+ cls_token_i, H_orig, W_orig
+ )
+
+ # Run mask_head for confidence map
+ confidence_map_i = self._run_mask_head(
+ neck_out_i, H_orig, W_orig
+ )
+
+ # Repad depth_map
+ depth_maps_list.append(
+ self._repad_depth_map(
+ depth_map_i,
+ pad_left,
+ pad_right,
+ pad_top,
+ pad_bottom,
+ )
+ )
+
+ # Repad confidence_map
+ if confidence_map_i is not None:
+ confidence_maps_list.append(
+ self._repad_depth_map(
+ confidence_map_i,
+ pad_left,
+ pad_right,
+ pad_top,
+ pad_bottom,
+ )
+ )
+
+ # Repad depth_latents
+ depth_latents_list.append(
+ self._repad_depth_latents(
+ depth_latents_i,
+ base_h_i,
+ base_w_i,
+ base_h_pad,
+ base_w_pad,
+ pad_top,
+ pad_left,
+ H_pad,
+ W_pad,
+ )
+ )
+
+ # K_pred: restore to padded space (non-inplace for autograd)
+ K_pred_padded_i = K_pred_i.clone()
+ K_pred_padded_i[:, 0, 2] = K_pred_i[:, 0, 2] + pad_left
+ K_pred_padded_i[:, 1, 2] = K_pred_i[:, 1, 2] + pad_top
+ K_pred_list.append(K_pred_padded_i)
+
+ depth_map = torch.cat(depth_maps_list, dim=0)
+ depth_latents = torch.cat(depth_latents_list, dim=0)
+ K_pred = torch.cat(K_pred_list, dim=0)
+ confidence_map = (
+ torch.cat(confidence_maps_list, dim=0)
+ if confidence_maps_list
+ else None
+ )
+
+ depth_latents = self._maybe_detach_latents(depth_latents)
+
+ internal_hw = (base_h_pad * 14, base_w_pad * 14)
+ ray_intrinsics = self._scale_intrinsics(
+ intrinsics, (H_pad, W_pad), internal_hw
+ )
+
+ return GeometryBackendOutput(
+ depth_map=depth_map,
+ depth_latents=depth_latents,
+ K_pred=K_pred,
+ ray_intrinsics=ray_intrinsics,
+ ray_image_hw=internal_hw,
+ ray_downsample=14,
+ aux={
+ "depth_latents_hw": (base_h_pad, base_w_pad),
+ "confidence_map": confidence_map,
+ },
+ losses={},
+ )
+
+ def _forward_test_batched(
+ self,
+ images: Tensor,
+ intrinsics: Tensor,
+ image_hw: tuple[int, int],
+ depth_gt: Tensor | None,
+ ) -> GeometryBackendOutput:
+ """Original batched forward_test path (no unpadding)."""
+ H, W = image_hw
+
+ depth_input = depth_gt if depth_gt is not None else None
+ (
+ depth_map, depth_latents, cls_token,
+ base_h, base_w, neck_out,
+ ) = self._run_encoder_and_decoder(
+ images, depth_input, image_hw
+ )
+
+ depth_latents = self._maybe_detach_latents(depth_latents)
+ K_pred = self._predict_intrinsics(cls_token, H, W)
+
+ # Run mask_head for confidence map
+ confidence_map = self._run_mask_head(neck_out, H, W)
+
+ internal_hw = (base_h * 14, base_w * 14)
+ ray_intrinsics = self._scale_intrinsics(
+ intrinsics, (H, W), internal_hw
+ )
+
+ return GeometryBackendOutput(
+ depth_map=depth_map,
+ depth_latents=depth_latents,
+ K_pred=K_pred,
+ ray_intrinsics=ray_intrinsics,
+ ray_image_hw=internal_hw,
+ ray_downsample=14,
+ aux={
+ "depth_latents_hw": (base_h, base_w),
+ "confidence_map": confidence_map,
+ },
+ losses={},
+ )
diff --git a/wilddet3d/eval/__init__.py b/wilddet3d/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/wilddet3d/eval/detect3d.py b/wilddet3d/eval/detect3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..78ca7f853a13298d515a09604ff080e25e01d3d8
--- /dev/null
+++ b/wilddet3d/eval/detect3d.py
@@ -0,0 +1,1734 @@
+"""3D Multiple Object Detection Evaluator."""
+
+import contextlib
+import copy
+import datetime
+import io
+import itertools
+import json
+import os
+import time
+from collections import defaultdict
+
+import numpy as np
+import pycocotools.mask as maskUtils
+import torch
+from pycocotools.cocoeval import COCOeval
+from scipy.spatial.distance import cdist
+from terminaltables import AsciiTable
+from vis4d.common.array import array_to_numpy
+from vis4d.common.distributed import all_gather_object_cpu
+from vis4d.common.typing import (
+ ArrayLike,
+ DictStrAny,
+ GenericFunc,
+ MetricLogs,
+ NDArrayF32,
+ NDArrayI64,
+)
+from vis4d.eval.base import Evaluator
+from vis4d.eval.coco.detect import xyxy_to_xywh
+
+from vis4d.data.const import AxisMode
+from vis4d.op.box.box3d import boxes3d_to_corners
+from vis4d.op.geometry.rotation import quaternion_to_matrix
+
+from wilddet3d.data.datasets.coco3d import COCO3D
+from wilddet3d.ops.box3d import box3d_overlap
+from wilddet3d.ops.rotation import so3_relative_angle
+
+
+def _canonicalize_rotation_np(R_cam, dims_whl):
+ """Canonicalize rotation for evaluation (numpy version).
+
+ Matches _normalize_canonical in coder.py. Eliminates 4-fold OBB
+ rotation ambiguity:
+ Step 1 - Force W <= L: if W > L, swap and apply Ry(90).
+ Step 2 - Normalize yaw to [0, pi): if yaw outside, apply Ry(180).
+
+ Args:
+ R_cam: 3x3 rotation matrix (numpy).
+ dims_whl: [W, H, L] dimensions (numpy or list).
+
+ Returns:
+ R_out: 3x3 canonical rotation matrix.
+ """
+ R_out = np.array(R_cam, dtype=np.float64).copy()
+ w, h, l = float(dims_whl[0]), float(dims_whl[1]), float(dims_whl[2])
+
+ # Step 1: Force W <= L
+ if w > l:
+ w, l = l, w
+ col0 = R_out[:, 0].copy()
+ R_out[:, 0] = -R_out[:, 2]
+ R_out[:, 2] = col0
+
+ # Step 2: Normalize yaw to [0, pi)
+ # YZX intrinsic: yaw = atan2(-R[2,0], R[0,0])
+ yaw = np.arctan2(-R_out[2, 0], R_out[0, 0])
+ if yaw < 0 or yaw > np.pi - 1e-4:
+ R_out[:, 0] = -R_out[:, 0]
+ R_out[:, 2] = -R_out[:, 2]
+
+ return R_out
+
+
+class Detect3DEvaluator(Evaluator):
+ """3D object detection evaluation with COCO format."""
+
+ def __init__(
+ self,
+ det_map: dict[str, int],
+ cat_map: dict[str, int],
+ annotation: str,
+ id2name: dict[int, str] | None = None,
+ per_class_eval: bool = True,
+ eval_prox: bool = False,
+ iou_type: str = "bbox",
+ num_columns: int = 6,
+ base_classes: list[str] | None = None,
+ # Frequency-based AP split (LVIS-style)
+ # Categories with APr
+ # Categories with rare_thresh..freq_thresh images -> APc
+ # Categories with >=freq_thresh images -> APf
+ freq_rare_thresh: int = 0,
+ freq_freq_thresh: int = 0,
+ # APRel3D parameters (LabelAny3D-style)
+ enable_aprel3d: bool = False,
+ aprel_2d_iou_thresh: float = 0.75,
+ ) -> None:
+ """Create an instance of the class."""
+ if id2name is None:
+ self.id2name = {v: k for k, v in det_map.items()}
+ else:
+ self.id2name = id2name
+
+ self.annotation = annotation
+ self.per_class_eval = per_class_eval
+ self.eval_prox = eval_prox
+ self.iou_type = iou_type
+ self.num_columns = num_columns
+ self.base_classes = base_classes
+
+ # APRel3D settings (LabelAny3D-style)
+ self.enable_aprel3d = enable_aprel3d
+ self.aprel_2d_iou_thresh = aprel_2d_iou_thresh
+
+ self.tp_errors = ["ATE", "AOE", "ASE"]
+
+ category_names = sorted(det_map, key=det_map.get)
+
+ with contextlib.redirect_stdout(io.StringIO()):
+ self._coco_gt = COCO3D([annotation], category_names)
+
+ self.cat_map = cat_map
+
+ # Build frequency split if thresholds are set
+ self.freq_rare_thresh = freq_rare_thresh
+ self.freq_freq_thresh = freq_freq_thresh
+ self.cat_freq_group: dict[int, str] | None = None
+ if freq_rare_thresh > 0 and freq_freq_thresh > 0:
+ with open(annotation) as f:
+ ann_data = json.load(f)
+ cat_img_count: dict[int, set] = {}
+ for ann in ann_data["annotations"]:
+ cid = ann["category_id"]
+ if cid not in cat_img_count:
+ cat_img_count[cid] = set()
+ cat_img_count[cid].add(ann["image_id"])
+ self.cat_freq_group = {}
+ for cat in ann_data["categories"]:
+ n = len(cat_img_count.get(cat["id"], set()))
+ if n < freq_rare_thresh:
+ self.cat_freq_group[cat["id"]] = "rare"
+ elif n < freq_freq_thresh:
+ self.cat_freq_group[cat["id"]] = "common"
+ else:
+ self.cat_freq_group[cat["id"]] = "frequent"
+ n_r = sum(1 for v in self.cat_freq_group.values() if v == "rare")
+ n_c = sum(1 for v in self.cat_freq_group.values() if v == "common")
+ n_f = sum(1 for v in self.cat_freq_group.values() if v == "frequent")
+ print(f"[Detect3DEvaluator] Frequency split: "
+ f"rare(<{freq_rare_thresh})={n_r}, "
+ f"common({freq_rare_thresh}-{freq_freq_thresh})={n_c}, "
+ f"frequent(>={freq_freq_thresh})={n_f}")
+
+ self.bbox_2D_evals_per_cat_area: DictStrAny = {}
+ self.bbox_3D_evals_per_cat_area: DictStrAny = {}
+ self._predictions: list[DictStrAny] = []
+
+ # Store optimal scales for APRel3D
+ self.optimal_scales: dict[int, float] = {}
+
+ def __repr__(self) -> str:
+ """Returns the string representation of the object."""
+ return f"3D Object Detection Evaluator with {self.annotation}"
+
+ @property
+ def metrics(self) -> list[str]:
+ """Supported metrics.
+
+ Returns:
+ list[str]: Metrics to evaluate.
+ """
+ return ["2D", "3D"]
+
+ def gather(self, gather_func: GenericFunc) -> None:
+ """Accumulate predictions across all processes.
+
+ Uses NCCL-based all_gather_object instead of vis4d's file-based
+ all_gather_object_cpu, which fails on weka cross-node due to
+ filesystem cache consistency issues.
+ """
+ import torch.distributed as dist
+
+ if not dist.is_initialized() or dist.get_world_size() == 1:
+ return
+
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+
+ # Use NCCL-based gathering (avoids cross-node filesystem issues)
+ all_preds = [None] * world_size
+ dist.all_gather_object(all_preds, self._predictions)
+
+ if rank == 0:
+ self._predictions = list(
+ itertools.chain(*all_preds)
+ )
+ else:
+ self._predictions = []
+
+ def reset(self) -> None:
+ """Reset the saved predictions to start new round of evaluation."""
+ self._predictions.clear()
+ self.bbox_2D_evals_per_cat_area.clear()
+ self.bbox_3D_evals_per_cat_area.clear()
+ self.optimal_scales.clear()
+
+ def _find_optimal_scale(
+ self, preds: list[DictStrAny], gts: list[DictStrAny]
+ ) -> float:
+ """Find optimal global scale factor (LabelAny3D method).
+
+ Ported from LabelAny3D compute_optimal_scale():
+ 1. Match each dt to best gt using 2D IoU (threshold 0.75).
+ 2. Grid search [0.1, 3.5] step 0.1, maximize avg 3D IoU.
+ """
+ # Collect dt/gt 3D corners and 2D boxes
+ dt_boxes = []
+ dt_boxes_2d = []
+ for pred in preds:
+ if "bbox3D" not in pred or "bbox" not in pred:
+ continue
+ dt_boxes.append(pred["bbox3D"])
+ # bbox is COCO [x, y, w, h], convert to [x1, y1, x2, y2]
+ b = pred["bbox"]
+ dt_boxes_2d.append([b[0], b[1], b[0] + b[2], b[1] + b[3]])
+
+ gt_boxes = []
+ gt_boxes_2d = []
+ for gt in gts:
+ if "bbox3D" not in gt or "bbox" not in gt:
+ continue
+ gt_boxes.append(gt["bbox3D"])
+ b = gt["bbox"]
+ gt_boxes_2d.append([b[0], b[1], b[0] + b[2], b[1] + b[3]])
+
+ if len(gt_boxes) == 0 or len(dt_boxes) == 0:
+ return 1.0
+
+ dt_boxes = np.array(dt_boxes, dtype=np.float32)
+ gt_boxes = np.array(gt_boxes, dtype=np.float32)
+
+ # Match each dt to the most similar gt using 2D IoU
+ matched_pairs = []
+ for dt_idx, dt_2d in enumerate(dt_boxes_2d):
+ best_iou = 0
+ best_gt_idx = -1
+
+ for gt_idx, gt_2d in enumerate(gt_boxes_2d):
+ x1 = max(dt_2d[0], gt_2d[0])
+ y1 = max(dt_2d[1], gt_2d[1])
+ x2 = min(dt_2d[2], gt_2d[2])
+ y2 = min(dt_2d[3], gt_2d[3])
+
+ if x2 <= x1 or y2 <= y1:
+ continue
+
+ inter_area = (x2 - x1) * (y2 - y1)
+ dt_area = (dt_2d[2] - dt_2d[0]) * (dt_2d[3] - dt_2d[1])
+ gt_area = (gt_2d[2] - gt_2d[0]) * (gt_2d[3] - gt_2d[1])
+ iou = inter_area / (dt_area + gt_area - inter_area)
+
+ if iou > best_iou:
+ best_iou = iou
+ best_gt_idx = gt_idx
+
+ if best_gt_idx >= 0 and best_iou > 0.75:
+ matched_pairs.append((dt_idx, best_gt_idx))
+
+ if len(matched_pairs) == 0:
+ return 1.0
+
+ def compute_avg_iou(scale):
+ avg_iou = 0.0
+ for dt_idx, gt_idx in matched_pairs:
+ scaled_dt_box = dt_boxes[dt_idx] * scale
+ dt_tensor = torch.tensor(
+ scaled_dt_box[np.newaxis, :, :],
+ dtype=torch.float32,
+ )
+ gt_tensor = torch.tensor(
+ gt_boxes[gt_idx][np.newaxis, :, :],
+ dtype=torch.float32,
+ )
+ iou = box3d_overlap(dt_tensor, gt_tensor).cpu().numpy()[0]
+ avg_iou += iou
+ return avg_iou / len(matched_pairs)
+
+ # Grid search: start with scale=1.0, then search [0.1, 3.5]
+ best_scale = 1.0
+ best_iou = compute_avg_iou(best_scale)
+
+ for scale in np.arange(0.1, 3.51, 0.1):
+ iou = compute_avg_iou(scale)
+ if iou > best_iou:
+ best_iou = iou
+ best_scale = scale
+
+ return best_scale
+
+ def _optimize_and_apply_scales(self) -> None:
+ """Optimize per-image scale and apply to all predictions."""
+ print("Optimizing scales for APRel3D (LabelAny3D method)...")
+ print(f" 2D IoU match threshold: {self.aprel_2d_iou_thresh}")
+
+ # Step 1: Group predictions by image
+ preds_by_image = defaultdict(list)
+ for pred in self._predictions:
+ preds_by_image[pred["image_id"]].append(pred)
+
+ # Step 2: Optimize scale for each image
+ n_matched_images = 0
+ for img_id, preds in preds_by_image.items():
+ gts = self._coco_gt.loadAnns(
+ self._coco_gt.getAnnIds(imgIds=[img_id])
+ )
+ if len(gts) == 0:
+ self.optimal_scales[img_id] = 1.0
+ continue
+
+ s_star = self._find_optimal_scale(preds, gts)
+ self.optimal_scales[img_id] = s_star
+ if s_star != 1.0:
+ n_matched_images += 1
+
+ # Step 3: Apply scales (direct corner multiplication)
+ scaled_predictions = []
+ for pred in self._predictions:
+ img_id = pred["image_id"]
+ scale = self.optimal_scales.get(img_id, 1.0)
+
+ scaled_pred = pred.copy()
+
+ if "center_cam" in pred:
+ scaled_pred["center_cam"] = [
+ c * scale for c in pred["center_cam"]
+ ]
+ if "dimensions" in pred:
+ scaled_pred["dimensions"] = [
+ d * scale for d in pred["dimensions"]
+ ]
+ if "bbox3D" in pred:
+ scaled_pred["bbox3D"] = [
+ [c * scale for c in corner]
+ for corner in pred["bbox3D"]
+ ]
+ if "depth" in pred:
+ scaled_pred["depth"] = pred["depth"] * scale
+
+ scaled_predictions.append(scaled_pred)
+
+ self._predictions = scaled_predictions
+
+ # Print statistics
+ scales = list(self.optimal_scales.values())
+ if len(scales) > 0:
+ print(f"APRel3D: {len(scales)} images, "
+ f"{n_matched_images} had 2D-IoU matches")
+ print(f" Mean scale: {np.mean(scales):.3f}")
+ print(f" Std scale: {np.std(scales):.3f}")
+ print(f" Min scale: {np.min(scales):.3f}")
+ print(f" Max scale: {np.max(scales):.3f}")
+
+ def process_batch(
+ self,
+ coco_image_id: list[int],
+ pred_boxes: list[ArrayLike],
+ pred_scores: list[ArrayLike],
+ pred_classes: list[ArrayLike],
+ pred_boxes3d: list[ArrayLike] | None = None,
+ ) -> None:
+ """Process sample and convert detections to coco format."""
+ for i, image_id in enumerate(coco_image_id):
+ boxes = array_to_numpy(
+ pred_boxes[i].to(torch.float32), n_dims=None, dtype=np.float32
+ )
+ scores = array_to_numpy(
+ pred_scores[i].to(torch.float32), n_dims=None, dtype=np.float32
+ )
+ classes = array_to_numpy(
+ pred_classes[i], n_dims=None, dtype=np.int64
+ )
+
+ if pred_boxes3d is not None:
+ boxes3d = array_to_numpy(
+ pred_boxes3d[i].to(torch.float32),
+ n_dims=None,
+ dtype=np.float32,
+ )
+ else:
+ boxes3d = None
+
+ self._predictions_to_coco(
+ image_id, boxes, boxes3d, scores, classes
+ )
+
+ def _predictions_to_coco(
+ self,
+ img_id: int,
+ boxes: NDArrayF32,
+ boxes3d: NDArrayF32 | None,
+ scores: NDArrayF32,
+ classes: NDArrayI64,
+ ) -> None:
+ """Convert predictions to COCO format."""
+ boxes_xyxy = copy.deepcopy(boxes)
+ boxes_xywh = xyxy_to_xywh(boxes_xyxy)
+
+ if boxes3d is not None:
+ # FIXME: Make axismode configurable
+ corners_3d = boxes3d_to_corners(
+ torch.from_numpy(boxes3d), AxisMode.OPENCV
+ )
+
+ for i, (box, box_score, box_class) in enumerate(
+ zip(boxes_xywh, scores, classes)
+ ):
+ xywh = box.tolist()
+
+ result = {
+ "image_id": img_id,
+ "bbox": xywh,
+ "category_id": self.cat_map[self.id2name[box_class.item()]],
+ "score": box_score.item(),
+ }
+
+ # mapping to Omni3D format
+ if boxes3d is not None:
+ result["center_cam"] = boxes3d[i][:3].tolist()
+
+ # wlh to whl
+ result["dimensions"] = boxes3d[i][[3, 5, 4]].tolist()
+
+ result["R_cam"] = (
+ quaternion_to_matrix(torch.from_numpy(boxes3d[i][6:10]))
+ .numpy()
+ .tolist()
+ )
+
+ corners = corners_3d[i].numpy().tolist()
+
+ result["bbox3D"] = [
+ corners[6],
+ corners[4],
+ corners[0],
+ corners[2],
+ corners[7],
+ corners[5],
+ corners[1],
+ corners[3],
+ ]
+
+ result["depth"] = boxes3d[i][2].item()
+
+ self._predictions.append(result)
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """Evaluate predictions."""
+ if metric == "2D":
+ metrics = ["AP", "AP50", "AP75", "AP95", "APs", "APm", "APl"]
+ else:
+ if self.iou_type == "bbox":
+ if self.enable_aprel3d:
+ metrics = [
+ "APRel3D",
+ "APRel15",
+ "APRel25",
+ "APRel50",
+ "APReln",
+ "APRelm",
+ "APRelf",
+ ]
+ main_metric = "APRel3D"
+ else:
+ metrics = ["AP", "AP15", "AP25", "AP50", "APn", "APm", "APf"]
+ main_metric = "AP"
+ else:
+ if self.enable_aprel3d:
+ metrics = ["APRel", "ATERel", "ASERel", "AOERel", "ODSRel", "ODSRelSym"]
+ main_metric = "ODSRel"
+ else:
+ metrics = ["AP", "ATE", "ASE", "AOE", "ODS", "ODS_Sym"]
+ main_metric = "ODS"
+
+ if self.base_classes is not None:
+ metrics += [f"{main_metric}_Base", f"{main_metric}_Novel"]
+
+ if len(self._predictions) == 0:
+ return {m: 0.0 for m in metrics}, "No predictions to evaluate."
+
+ # APRel3D: Optimize and apply scales before evaluation
+ if self.enable_aprel3d and metric == "3D":
+ self._optimize_and_apply_scales()
+
+ with contextlib.redirect_stdout(io.StringIO()):
+ coco_dt = self._coco_gt.loadRes(self._predictions)
+
+ assert coco_dt is not None
+ evaluator = Detect3Deval(
+ self._coco_gt,
+ coco_dt,
+ mode=metric,
+ eval_prox=self.eval_prox,
+ iou_type=self.iou_type,
+ )
+ evaluator.evaluate()
+ evaluator.accumulate()
+
+ if self.iou_type == "bbox":
+ log_str = "\n" + evaluator.summarize()
+
+ # precision: (iou, recall, cls, area range, max dets)
+ precisions = evaluator.eval["precision"]
+ assert len(self._coco_gt.getCatIds()) == precisions.shape[2]
+
+ if metric == "2D":
+ self.bbox_2D_evals_per_cat_area = evaluator.evals_per_cat_area
+
+ score_dict = dict(zip(metrics, evaluator.stats))
+ else:
+ if self.iou_type == "bbox":
+ self.bbox_3D_evals_per_cat_area = evaluator.evals_per_cat_area
+
+ score_dict = dict(zip(metrics, evaluator.stats))
+
+ # Compute mASE, mAOE, mAOE_Sym for bbox mode
+ # Note: ATE is not returned in bbox mode because the normalization
+ # by IoU threshold makes it unreliable (can be > 1)
+ rot_tp_errors = evaluator.eval["rot_tp_errors"]
+ rot_sym_tp_errors = evaluator.eval["rot_sym_tp_errors"]
+ rot_canonical_tp_errors = evaluator.eval["rot_canonical_tp_errors"]
+ scale_tp_errors = evaluator.eval["scale_tp_errors"]
+
+ rot_tp = rot_tp_errors[:, :, :, 0, -1]
+ rot_tp = rot_tp[rot_tp > -1]
+
+ rot_sym_tp = rot_sym_tp_errors[:, :, :, 0, -1]
+ rot_sym_tp = rot_sym_tp[rot_sym_tp > -1]
+
+ rot_canonical_tp = rot_canonical_tp_errors[:, :, :, 0, -1]
+ rot_canonical_tp = rot_canonical_tp[rot_canonical_tp > -1]
+
+ scale_tp = scale_tp_errors[:, :, :, 0, -1]
+ scale_tp = scale_tp[scale_tp > -1]
+
+ if rot_tp.size:
+ mAOE = np.mean(rot_tp).item()
+ mAOE_Sym = np.mean(rot_sym_tp).item()
+ mAOE_Canonical = np.mean(rot_canonical_tp).item()
+ mASE = np.mean(scale_tp).item()
+ else:
+ mAOE = float("nan")
+ mAOE_Sym = float("nan")
+ mAOE_Canonical = float("nan")
+ mASE = float("nan")
+
+ # Add error metrics to output (no ATE in bbox mode)
+ if self.enable_aprel3d:
+ score_dict["ASERel"] = mASE
+ score_dict["AOERel"] = mAOE
+ score_dict["AOERelSym"] = mAOE_Sym
+ score_dict["AOERelCanonical"] = mAOE_Canonical
+ else:
+ score_dict["ASE"] = mASE
+ score_dict["AOE"] = mAOE
+ score_dict["AOE_Sym"] = mAOE_Sym
+ score_dict["AOE_Canonical"] = mAOE_Canonical
+
+ # Add scale statistics for APRel3D
+ if self.enable_aprel3d and len(self.optimal_scales) > 0:
+ scales = list(self.optimal_scales.values())
+ score_dict["mean_scale"] = np.mean(scales)
+ score_dict["std_scale"] = np.std(scales)
+ else:
+ trans_tp_errors = evaluator.eval["trans_tp_errors"]
+ rot_tp_errors = evaluator.eval["rot_tp_errors"]
+ rot_sym_tp_errors = evaluator.eval["rot_sym_tp_errors"]
+ rot_canonical_tp_errors = evaluator.eval["rot_canonical_tp_errors"]
+ scale_tp_errors = evaluator.eval["scale_tp_errors"]
+
+ precision = precisions[:, :, :, 0, -1]
+ precision = precision[precision > -1]
+ if precision.size:
+ mAP = np.mean(precision).item()
+ else:
+ mAP = float("nan")
+
+ trans_tp = trans_tp_errors[:, :, :, 0, -1]
+ trans_tp = trans_tp[trans_tp > -1]
+
+ rot_tp = rot_tp_errors[:, :, :, 0, -1]
+ rot_tp = rot_tp[rot_tp > -1]
+
+ rot_sym_tp = rot_sym_tp_errors[:, :, :, 0, -1]
+ rot_sym_tp = rot_sym_tp[rot_sym_tp > -1]
+
+ rot_canonical_tp = rot_canonical_tp_errors[:, :, :, 0, -1]
+ rot_canonical_tp = rot_canonical_tp[rot_canonical_tp > -1]
+
+ scale_tp = scale_tp_errors[:, :, :, 0, -1]
+ scale_tp = scale_tp[scale_tp > -1]
+
+ if trans_tp.size:
+ mATE = np.mean(trans_tp).item()
+ mAOE = np.mean(rot_tp).item()
+ mAOE_Sym = np.mean(rot_sym_tp).item()
+ mAOE_Canonical = np.mean(rot_canonical_tp).item()
+ mASE = np.mean(scale_tp).item()
+
+ mODS = (
+ np.sum(mAP * 3 + (1 - mATE) + (1 - mAOE) + (1 - mASE))
+ / 6
+ )
+ mODS_Sym = (
+ np.sum(mAP * 3 + (1 - mATE) + (1 - mAOE_Sym) + (1 - mASE))
+ / 6
+ )
+ mODS_Canonical = (
+ np.sum(mAP * 3 + (1 - mATE) + (1 - mAOE_Canonical) + (1 - mASE))
+ / 6
+ )
+
+ else:
+ mATE = float("nan")
+ mAOE = float("nan")
+ mAOE_Sym = float("nan")
+ mAOE_Canonical = float("nan")
+ mASE = float("nan")
+ mODS = float("nan")
+ mODS_Sym = float("nan")
+ mODS_Canonical = float("nan")
+
+ if self.enable_aprel3d:
+ score_dict = {
+ "APRel": mAP,
+ "ATERel": mATE,
+ "ASERel": mASE,
+ "AOERel": mAOE,
+ "AOERelSym": mAOE_Sym,
+ "AOERelCanonical": mAOE_Canonical,
+ "ODSRel": mODS,
+ "ODSRelSym": mODS_Sym,
+ "ODSRelCanonical": mODS_Canonical,
+ }
+ else:
+ score_dict = {
+ "AP": mAP,
+ "ATE": mATE,
+ "ASE": mASE,
+ "AOE": mAOE,
+ "AOE_Sym": mAOE_Sym,
+ "AOE_Canonical": mAOE_Canonical,
+ "ODS": mODS,
+ "ODS_Sym": mODS_Sym,
+ "ODS_Canonical": mODS_Canonical,
+ }
+
+ # Add scale statistics for APRel3D
+ if self.enable_aprel3d and len(self.optimal_scales) > 0:
+ scales = list(self.optimal_scales.values())
+ score_dict["mean_scale"] = np.mean(scales)
+ score_dict["std_scale"] = np.std(scales)
+
+ log_str = "\nHigh-level metrics:"
+ for k, v in score_dict.items():
+ log_str += f"\n{k}: {v:.4f}"
+
+ if self.per_class_eval:
+ results_per_category = []
+ score_base_list = []
+ score_novel_list = []
+ freq_ap: dict[str, list] = {"rare": [], "common": [], "frequent": []}
+
+ for idx, cat_id in enumerate(self._coco_gt.getCatIds()):
+ # area range index 0: all area ranges
+ # max dets index -1: typically 100 per image
+ nm = self._coco_gt.loadCats(cat_id)[0]
+ precision = precisions[:, :, idx, 0, -1]
+ precision = precision[precision > -1]
+ if precision.size:
+ ap = np.mean(precision).item()
+ else:
+ ap = float("nan")
+
+ if self.iou_type == "dist":
+ trans_tp = trans_tp_errors[:, :, idx, 0, -1]
+ trans_tp = trans_tp[trans_tp > -1]
+
+ rot_tp = rot_tp_errors[:, :, idx, 0, -1]
+ rot_tp = rot_tp[rot_tp > -1]
+
+ rot_sym_tp = rot_sym_tp_errors[:, :, idx, 0, -1]
+ rot_sym_tp = rot_sym_tp[rot_sym_tp > -1]
+
+ scale_tp = scale_tp_errors[:, :, idx, 0, -1]
+ scale_tp = scale_tp[scale_tp > -1]
+
+ if trans_tp.size:
+ ate = np.mean(trans_tp).item()
+ aoe = np.mean(rot_tp).item()
+ aoe_sym = np.mean(rot_sym_tp).item()
+ ase = np.mean(scale_tp).item()
+
+ ods = (
+ np.sum(ap * 3 + (1 - ate) + (1 - aoe) + (1 - ase))
+ / 6
+ )
+ ods_sym = (
+ np.sum(ap * 3 + (1 - ate) + (1 - aoe_sym) + (1 - ase))
+ / 6
+ )
+
+ else:
+ ate = float("nan")
+ aoe = float("nan")
+ aoe_sym = float("nan")
+ ase = float("nan")
+ ods = float("nan")
+ ods_sym = float("nan")
+
+ results_per_category.append(
+ (
+ f'{nm["name"]}',
+ f"{ap:0.3f}",
+ f"{ate:0.3f}",
+ f"{ase:0.3f}",
+ f"{aoe:0.3f}",
+ f"{aoe_sym:0.3f}",
+ f"{ods:0.3f}",
+ f"{ods_sym:0.3f}",
+ )
+ )
+ else:
+ results_per_category.append(
+ (f'{nm["name"]}', f"{ap:0.3f}")
+ )
+
+ if self.base_classes is not None:
+ if self.iou_type == "dist":
+ score = ods
+ else:
+ score = ap
+
+ if nm["name"] in self.base_classes:
+ score_base_list.append(score)
+ else:
+ score_novel_list.append(score)
+
+ if self.cat_freq_group is not None and not np.isnan(ap):
+ group = self.cat_freq_group.get(cat_id, "rare")
+ freq_ap[group].append(ap)
+
+ results_flatten = list(itertools.chain(*results_per_category))
+
+ if self.iou_type == "dist":
+ num_columns = 8
+ headers = ["category", "AP", "ATE", "ASE", "AOE", "AOE_Sym", "ODS", "ODS_Sym"]
+ else:
+ num_columns = min(
+ self.num_columns, len(results_per_category) * 2
+ )
+ headers = ["category", "AP"] * (num_columns // 2)
+ results = itertools.zip_longest(
+ *[results_flatten[i::num_columns] for i in range(num_columns)]
+ )
+ table_data = [headers] + list(results)
+ if AsciiTable is not None:
+ table = AsciiTable(table_data)
+ log_str = f"\n{table.table}\n{log_str}"
+ else:
+ # Fallback when terminaltables is not installed.
+ log_str = f"\n(per-class table omitted; install terminaltables for pretty output)\n{log_str}"
+
+ if self.base_classes is not None:
+ score_dict[f"{main_metric}_Base"] = np.mean(score_base_list).item()
+ score_dict[f"{main_metric}_Novel"] = np.mean(
+ score_novel_list
+ ).item()
+
+ if self.cat_freq_group is not None and self.per_class_eval:
+ score_dict["APr"] = np.mean(freq_ap["rare"]).item() if freq_ap["rare"] else float("nan")
+ score_dict["APc"] = np.mean(freq_ap["common"]).item() if freq_ap["common"] else float("nan")
+ score_dict["APf"] = np.mean(freq_ap["frequent"]).item() if freq_ap["frequent"] else float("nan")
+ log_str += (
+ f"\nFrequency split (<{self.freq_rare_thresh}/{self.freq_freq_thresh}):"
+ f" APr={score_dict['APr']:.4f} ({len(freq_ap['rare'])} cats),"
+ f" APc={score_dict['APc']:.4f} ({len(freq_ap['common'])} cats),"
+ f" APf={score_dict['APf']:.4f} ({len(freq_ap['frequent'])} cats)"
+ )
+
+ return score_dict, log_str
+
+ def save(
+ self, metric: str, output_dir: str, prefix: str | None = None
+ ) -> None:
+ """Save the results to json files."""
+ assert metric in self.metrics
+
+ if prefix is not None:
+ result_folder = os.path.join(output_dir, prefix)
+ os.makedirs(result_folder, exist_ok=True)
+ else:
+ result_folder = output_dir
+
+ result_file = os.path.join(
+ result_folder, f"detect_{metric}_results.json"
+ )
+
+ with open(result_file, mode="w", encoding="utf-8") as f:
+ json.dump(self._predictions, f)
+
+
+class Detect3Deval(COCOeval):
+ """COCOeval Wrapper for 2D and 3D box evaluation.
+
+ Now it support bbox IoU matching only.
+ """
+
+ def __init__(
+ self,
+ cocoGt=None,
+ cocoDt=None,
+ mode: str = "2D",
+ iou_type: str = "bbox",
+ eval_prox: bool = False,
+ ):
+ """Initialize Detect3Deval using coco APIs for Gt and Dt.
+
+ Args:
+ cocoGt: COCO object with ground truth annotations
+ cocoDt: COCO object with detection results
+ mode: (str) defines whether to evaluate 2D or 3D performance.
+ One of {"2D", "3D"}
+ eval_prox: (bool) if True, performs "Proximity Evaluation", i.e.
+ evaluates detections in the proximity of the ground truth2D
+ boxes. This is used for datasets which are not exhaustively
+ annotated.
+ """
+ if mode not in {"2D", "3D"}:
+ raise Exception(f"{mode} mode is not supported")
+ self.mode = mode
+ self.iou_type = iou_type
+ self.eval_prox = eval_prox
+
+ self.cocoGt = cocoGt # ground truth COCO API
+ self.cocoDt = cocoDt # detections COCO API
+
+ # per-image per-category evaluation results [KxAxI] elements
+ self.evalImgs = defaultdict(list)
+
+ self.eval = {} # accumulated evaluation results
+ self._gts = defaultdict(list) # gt for evaluation
+ self._dts = defaultdict(list) # dt for evaluation
+ self.params = Detect3DParams(mode=mode, iouType=iou_type) # parameters
+ self._paramsEval = {} # parameters for evaluation
+ self.stats = [] # result summarization
+ self.ious = {} # ious between all gts and dts
+
+ if cocoGt is not None:
+ self.params.imgIds = sorted(cocoGt.getImgIds())
+ self.params.catIds = sorted(cocoGt.getCatIds())
+
+ self.evals_per_cat_area = None
+
+ def _prepare(self) -> None:
+ """Prepare ._gts and ._dts for evaluation based on params."""
+ p = self.params
+
+ if p.useCats:
+ gts = self.cocoGt.loadAnns(
+ self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)
+ )
+ dts = self.cocoDt.loadAnns(
+ self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)
+ )
+
+ else:
+ gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
+ dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))
+
+ # set ignore flag
+ ignore_flag = "ignore2D" if self.mode == "2D" else "ignore3D"
+ for gt in gts:
+ gt[ignore_flag] = gt[ignore_flag] if ignore_flag in gt else 0
+
+ self._gts = defaultdict(list) # gt for evaluation
+ self._dts = defaultdict(list) # dt for evaluation
+
+ for gt in gts:
+ self._gts[gt["image_id"], gt["category_id"]].append(gt)
+
+ for dt in dts:
+ self._dts[dt["image_id"], dt["category_id"]].append(dt)
+
+ self.evalImgs = defaultdict(
+ list
+ ) # per-image per-category evaluation results
+ self.eval = {} # accumulated evaluation results
+
+ def accumulate(self, p=None) -> None:
+ """Accumulate per image evaluation and store the result in self.eval.
+
+ Args:
+ p: input params for evaluation
+ """
+ print("Accumulating evaluation results...")
+ assert self.evalImgs, "Please run evaluate() first"
+
+ tic = time.time()
+
+ # allows input customized parameters
+ if p is None:
+ p = self.params
+
+ p.catIds = p.catIds if p.useCats == 1 else [-1]
+
+ T = len(p.iouThrs)
+ R = len(p.recThrs)
+ K = len(p.catIds) if p.useCats else 1
+ A = len(p.areaRng)
+ M = len(p.maxDets)
+
+ precision = -np.ones(
+ (T, R, K, A, M)
+ ) # -1 for the precision of absent categories
+ trans_tp_errors = -np.ones((T, R, K, A, M))
+ rot_tp_errors = -np.ones((T, R, K, A, M))
+ rot_sym_tp_errors = -np.ones((T, R, K, A, M))
+ rot_canonical_tp_errors = -np.ones((T, R, K, A, M))
+ scale_tp_errors = -np.ones((T, R, K, A, M))
+ recall = -np.ones((T, K, A, M))
+ scores = -np.ones((T, R, K, A, M))
+
+ # create dictionary for future indexing
+ _pe = self._paramsEval
+
+ catIds = _pe.catIds if _pe.useCats else [-1]
+ setK = set(catIds)
+ setA = set(map(tuple, _pe.areaRng))
+ setM = set(_pe.maxDets)
+ setI = set(_pe.imgIds)
+
+ # get inds to evaluate
+ catid_list = [k for n, k in enumerate(p.catIds) if k in setK]
+ k_list = [n for n, k in enumerate(p.catIds) if k in setK]
+ m_list = [m for n, m in enumerate(p.maxDets) if m in setM]
+ a_list = [
+ n
+ for n, a in enumerate(map(lambda x: tuple(x), p.areaRng))
+ if a in setA
+ ]
+ i_list = [n for n, i in enumerate(p.imgIds) if i in setI]
+
+ I0 = len(_pe.imgIds)
+ A0 = len(_pe.areaRng)
+
+ has_precomputed_evals = not (self.evals_per_cat_area is None)
+
+ if has_precomputed_evals:
+ evals_per_cat_area = self.evals_per_cat_area
+ else:
+ evals_per_cat_area = {}
+
+ # retrieve E at each category, area range, and max number of detections
+ for k, (k0, catId) in enumerate(zip(k_list, catid_list)):
+ Nk = k0 * A0 * I0
+ for a, a0 in enumerate(a_list):
+ Na = a0 * I0
+
+ if has_precomputed_evals:
+ E = evals_per_cat_area.get((catId, a), [])
+
+ else:
+ E = [self.evalImgs[Nk + Na + i] for i in i_list]
+ E = [e for e in E if not e is None]
+ evals_per_cat_area[(catId, a)] = E
+
+ if len(E) == 0:
+ continue
+
+ for m, maxDet in enumerate(m_list):
+
+ dtScores = np.concatenate(
+ [e["dtScores"][0:maxDet] for e in E]
+ )
+
+ # different sorting method generates slightly different results.
+ # mergesort is used to be consistent as Matlab implementation.
+ inds = np.argsort(-dtScores, kind="mergesort")
+ dtScoresSorted = dtScores[inds]
+
+ dtm = np.concatenate(
+ [e["dtMatches"][:, 0:maxDet] for e in E], axis=1
+ )[:, inds]
+ dtIg = np.concatenate(
+ [e["dtIgnore"][:, 0:maxDet] for e in E], axis=1
+ )[:, inds]
+ gtIg = np.concatenate([e["gtIgnore"] for e in E])
+ npig = np.count_nonzero(gtIg == 0)
+
+ if npig == 0:
+ continue
+
+ tps = np.logical_and(dtm, np.logical_not(dtIg))
+ fps = np.logical_and(
+ np.logical_not(dtm), np.logical_not(dtIg)
+ )
+
+ tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float64)
+ fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float64)
+
+ # Compute TP error (for both bbox and dist modes)
+ tems = np.concatenate(
+ [e["dtTranslationError"][:, 0:maxDet] for e in E],
+ axis=1,
+ )[:, inds]
+
+ oems = np.concatenate(
+ [e["dtOrientationError"][:, 0:maxDet] for e in E],
+ axis=1,
+ )[:, inds]
+
+ oems_sym = np.concatenate(
+ [e["dtOrientationErrorSym"][:, 0:maxDet] for e in E],
+ axis=1,
+ )[:, inds]
+
+ oems_canonical = np.concatenate(
+ [e["dtOrientationErrorCanonical"][:, 0:maxDet] for e in E],
+ axis=1,
+ )[:, inds]
+
+ sems = np.concatenate(
+ [e["dtScaleError"][:, 0:maxDet] for e in E], axis=1
+ )[:, inds]
+
+ for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
+ tp = np.array(tp)
+ fp = np.array(fp)
+ nd = len(tp)
+ rc = tp / npig
+ pr = tp / (fp + tp + np.spacing(1))
+
+ q = np.zeros((R,))
+ ss = np.zeros((R,))
+ tran_tp_error = np.ones((R,))
+ rot_tp_error = np.ones((R,))
+ rot_sym_tp_error = np.ones((R,))
+ rot_canonical_tp_error = np.ones((R,))
+ scale_tp_error = np.ones((R,))
+
+ if nd:
+ recall[t, k, a, m] = rc[-1]
+
+ else:
+ recall[t, k, a, m] = 0
+
+ # numpy is slow without cython optimization for accessing elements
+ # use python array gets significant speed improvement
+ pr = pr.tolist()
+ q = q.tolist()
+ tran_tp_error = tran_tp_error.tolist()
+ rot_tp_error = rot_tp_error.tolist()
+ rot_sym_tp_error = rot_sym_tp_error.tolist()
+ rot_canonical_tp_error = rot_canonical_tp_error.tolist()
+ scale_tp_error = scale_tp_error.tolist()
+
+ for i in range(nd - 1, 0, -1):
+ if pr[i] > pr[i - 1]:
+ pr[i - 1] = pr[i]
+
+ inds = np.searchsorted(rc, p.recThrs, side="left")
+
+ try:
+ for ri, pi in enumerate(inds):
+ q[ri] = pr[pi]
+ ss[ri] = dtScoresSorted[pi]
+ # Store errors for both bbox and dist modes
+ tran_tp_error[ri] = tems[t][pi]
+ rot_tp_error[ri] = oems[t][pi]
+ rot_sym_tp_error[ri] = oems_sym[t][pi]
+ rot_canonical_tp_error[ri] = oems_canonical[t][pi]
+ scale_tp_error[ri] = sems[t][pi]
+ except:
+ pass
+
+ precision[t, :, k, a, m] = np.array(q)
+ scores[t, :, k, a, m] = np.array(ss)
+
+ # Store errors for both bbox and dist modes
+ trans_tp_errors[t, :, k, a, m] = np.array(
+ tran_tp_error
+ )
+ rot_tp_errors[t, :, k, a, m] = np.array(
+ rot_tp_error
+ )
+ rot_sym_tp_errors[t, :, k, a, m] = np.array(
+ rot_sym_tp_error
+ )
+ rot_canonical_tp_errors[t, :, k, a, m] = np.array(
+ rot_canonical_tp_error
+ )
+ scale_tp_errors[t, :, k, a, m] = np.array(
+ scale_tp_error
+ )
+
+ self.evals_per_cat_area = evals_per_cat_area
+
+ self.eval = {
+ "params": p,
+ "counts": [T, R, K, A, M],
+ "date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+ "precision": precision,
+ "recall": recall,
+ "scores": scores,
+ "trans_tp_errors": trans_tp_errors,
+ "rot_tp_errors": rot_tp_errors,
+ "rot_sym_tp_errors": rot_sym_tp_errors,
+ "rot_canonical_tp_errors": rot_canonical_tp_errors,
+ "scale_tp_errors": scale_tp_errors,
+ }
+
+ toc = time.time()
+ print("DONE (t={:0.2f}s).".format(toc - tic))
+
+ def evaluate(self) -> None:
+ """Run per image evaluation on given images.
+
+ It will store results (a list of dict) in self.evalImgs
+ """
+ print("Running per image evaluation...")
+
+ p = self.params
+ print(f"Evaluate annotation type *{p.iouType}*")
+
+ tic = time.time()
+
+ p.imgIds = list(np.unique(p.imgIds))
+ if p.useCats:
+ p.catIds = list(np.unique(p.catIds))
+
+ p.maxDets = sorted(p.maxDets)
+ self.params = p
+
+ self._prepare()
+
+ catIds = p.catIds if p.useCats else [-1]
+
+ # loop through images, area range, max detection number
+ self.ious = {
+ (imgId, catId): self.computeIoU(imgId, catId)
+ for imgId in p.imgIds
+ for catId in catIds
+ }
+
+ maxDet = p.maxDets[-1]
+
+ self.evalImgs = [
+ self.evaluateImg(imgId, catId, areaRng, maxDet)
+ for catId in catIds
+ for areaRng in p.areaRng
+ for imgId in p.imgIds
+ ]
+
+ self._paramsEval = copy.deepcopy(self.params)
+
+ toc = time.time()
+ print("DONE (t={:0.2f}s).".format(toc - tic))
+
+ def computeIoU(self, imgId, catId) -> tuple[NDArrayF32, NDArrayF32]:
+ """Computes the IoUs by sorting based on score"""
+ p = self.params
+
+ if p.useCats:
+ gt = self._gts[imgId, catId]
+ dt = self._dts[imgId, catId]
+ else:
+ gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
+ dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
+
+ if len(gt) == 0 and len(dt) == 0:
+ return []
+
+ inds = np.argsort([-d["score"] for d in dt], kind="mergesort")
+ dt = [dt[i] for i in inds]
+ if len(dt) > p.maxDets[-1]:
+ dt = dt[0 : p.maxDets[-1]]
+
+ if self.mode == "2D":
+ g = [g["bbox"] for g in gt]
+ d = [d["bbox"] for d in dt]
+ elif self.mode == "3D":
+ g = [g["bbox3D"] for g in gt]
+ d = [d["bbox3D"] for d in dt]
+
+ # compute iou between each dt and gt region
+ # iscrowd is required in builtin maskUtils so we
+ # use a dummy buffer for it
+ iscrowd = [0 for _ in gt]
+ if self.mode == "2D":
+ ious = maskUtils.iou(d, g, iscrowd)
+ elif len(d) > 0 and len(g) > 0:
+ if p.iouType == "bbox":
+ dd = torch.tensor(d, dtype=torch.float32)
+ gg = torch.tensor(g, dtype=torch.float32)
+
+ ious = box3d_overlap(dd, gg).cpu().numpy()
+ else:
+ ious = np.zeros((len(d), len(g)))
+
+ dd = [d["center_cam"] for d in dt]
+ gg = [g["center_cam"] for g in gt]
+
+ ious = cdist(dd, gg, metric="euclidean")
+ else:
+ ious = []
+
+ in_prox = None
+
+ if self.eval_prox:
+ g = [g["bbox"] for g in gt]
+ d = [d["bbox"] for d in dt]
+ iscrowd = [0 for o in gt]
+ ious2d = maskUtils.iou(d, g, iscrowd)
+
+ if type(ious2d) == list:
+ in_prox = []
+
+ else:
+ in_prox = ious2d > p.proximity_thresh
+
+ return ious, in_prox
+
+ def evaluateImg(self, imgId, catId, aRng, maxDet):
+ """
+ Perform evaluation for single category and image
+ Returns:
+ dict (single image results)
+ """
+
+ p = self.params
+ if p.useCats:
+ gt = self._gts[imgId, catId]
+ dt = self._dts[imgId, catId]
+
+ else:
+ gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
+ dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
+
+ if len(gt) == 0 and len(dt) == 0:
+ return None
+
+ flag_range = "area" if self.mode == "2D" else "depth"
+ flag_ignore = "ignore2D" if self.mode == "2D" else "ignore3D"
+
+ for g in gt:
+ if g[flag_ignore] or (
+ g[flag_range] < aRng[0] or g[flag_range] > aRng[1]
+ ):
+ g["_ignore"] = 1
+ else:
+ g["_ignore"] = 0
+
+ # sort dt highest score first, sort gt ignore last
+ gtind = np.argsort([g["_ignore"] for g in gt], kind="mergesort")
+ gt = [gt[i] for i in gtind]
+ dtind = np.argsort([-d["score"] for d in dt], kind="mergesort")
+ dt = [dt[i] for i in dtind[0:maxDet]]
+
+ # load computed ious
+ ious = (
+ self.ious[imgId, catId][0][:, gtind]
+ if len(self.ious[imgId, catId][0]) > 0
+ else self.ious[imgId, catId][0]
+ )
+
+ if self.eval_prox:
+ in_prox = (
+ self.ious[imgId, catId][1][:, gtind]
+ if len(self.ious[imgId, catId][1]) > 0
+ else self.ious[imgId, catId][1]
+ )
+
+ T = len(p.iouThrs)
+ G = len(gt)
+ D = len(dt)
+ gtm = np.zeros((T, G))
+ dtm = np.zeros((T, D))
+ tem = np.ones((T, D)) # Translation Error
+ sem = np.ones((T, D)) # Scale Error
+ oem = np.ones((T, D)) # Oritentation Error
+ oem_sym = np.ones((T, D)) # Symmetric Orientation Error (mod 180)
+ oem_canonical = np.ones((T, D)) # Canonical Orientation Error
+ gtIg = np.array([g["_ignore"] for g in gt])
+ dtIg = np.zeros((T, D))
+
+ dist_thres = 1
+ if not len(ious) == 0:
+ for tind, t in enumerate(p.iouThrs):
+ for dind, d in enumerate(dt):
+
+ # information about best match so far (m=-1 -> unmatched)
+ iou = min([t, 1 - 1e-10])
+ m = -1
+
+ for gind, g in enumerate(gt):
+ # in case of proximity evaluation, if not in proximity continue
+ if self.eval_prox and not in_prox[dind, gind]:
+ continue
+
+ # if this gt already matched, continue
+ if gtm[tind, gind] > 0:
+ continue
+
+ # if dt matched to reg gt, and on ignore gt, stop
+ if m > -1 and gtIg[m] == 0 and gtIg[gind] == 1:
+ break
+
+ # continue to next gt unless better match made
+ if p.iouType == "bbox" and ious[dind, gind] < iou:
+ continue
+
+ if p.iouType == "dist":
+ # Compute Object Radius
+ gt_obj_radius = (
+ np.linalg.norm(np.array(g["dimensions"])) / 2
+ )
+ if ious[dind, gind] > gt_obj_radius * iou:
+ continue
+ else:
+ dist_thres = gt_obj_radius * iou
+
+ # if match successful and best so far, store appropriately
+ iou = ious[dind, gind]
+ m = gind
+
+ # if match made store id of match for both dt and gt
+ if m == -1:
+ continue
+
+ dtIg[tind, dind] = gtIg[m]
+ dtm[tind, dind] = gt[m]["id"]
+ gtm[tind, m] = d["id"]
+
+ # Compute errors for both bbox and dist modes
+ # (previously only computed for dist mode)
+
+ # Compute GT object radius for normalization
+ gt_obj_radius = (
+ np.linalg.norm(np.array(gt[m]["dimensions"])) / 2
+ )
+
+ # Translation Error
+ if p.iouType == "dist":
+ # For dist mode: normalize by distance threshold
+ # (dist_thres was computed during matching)
+ tem[tind, dind] = np.linalg.norm(
+ np.array(d["center_cam"])
+ - np.array(gt[m]["center_cam"])
+ ) / (dist_thres)
+ else:
+ # For bbox mode: normalize by distance threshold
+ # (same as dist mode, for consistency)
+ dist_thres_bbox = gt_obj_radius * t
+ tem[tind, dind] = np.linalg.norm(
+ np.array(d["center_cam"])
+ - np.array(gt[m]["center_cam"])
+ ) / (dist_thres_bbox + 1e-6)
+
+ # Orientation Error (same for both modes)
+ try:
+ angle = so3_relative_angle(
+ torch.tensor(d["R_cam"])[None],
+ torch.tensor(gt[m]["R_cam"])[None],
+ cos_bound=1e-2,
+ eps=1e-3,
+ ).item()
+ oem[tind, dind] = angle / np.pi
+ # Symmetric: fold 180 ambiguity, min(angle, pi-angle)
+ # range [0, pi/2], normalized by pi/2 to [0, 1]
+ oem_sym[tind, dind] = min(angle, np.pi - angle) / (np.pi / 2)
+
+ # Canonical: normalize both to canonical form
+ # (W<=L + yaw [0,pi)) before computing angle
+ R_pred_c = _canonicalize_rotation_np(
+ d["R_cam"], d["dimensions"]
+ )
+ R_gt_c = _canonicalize_rotation_np(
+ gt[m]["R_cam"], gt[m]["dimensions"]
+ )
+ angle_c = so3_relative_angle(
+ torch.tensor(R_pred_c)[None],
+ torch.tensor(R_gt_c)[None],
+ cos_bound=1e-2,
+ eps=1e-3,
+ ).item()
+ oem_canonical[tind, dind] = angle_c / np.pi
+ except ValueError as e:
+ # Skip invalid rotation matrix pairs
+ # This can happen when GT or prediction has numerical precision issues
+ import warnings
+ R_pred = np.array(d["R_cam"])
+ R_gt = np.array(gt[m]["R_cam"])
+ R_rel = R_pred @ R_gt.T
+ warnings.warn(
+ f"Skipping rotation error for img={imgId}, cat={catId}: {e}\n"
+ f" det(R_pred)={np.linalg.det(R_pred):.6f}, "
+ f"det(R_gt)={np.linalg.det(R_gt):.6f}, "
+ f"trace(R_rel)={np.trace(R_rel):.6f}"
+ )
+ # Set to maximum error (180 degrees = 1.0 in normalized units)
+ oem[tind, dind] = 1.0
+ oem_sym[tind, dind] = 1.0
+ oem_canonical[tind, dind] = 1.0
+
+ # Scale Error (same for both modes)
+ min_whl = np.minimum(
+ d["dimensions"], gt[m]["dimensions"]
+ )
+ volume_annotation = np.prod(gt[m]["dimensions"])
+ volume_result = np.prod(d["dimensions"])
+
+ intersection = np.prod(min_whl)
+ union = (
+ volume_annotation + volume_result - intersection
+ )
+ scale_iou = intersection / union
+
+ sem[tind, dind] = 1 - scale_iou
+
+ # set unmatched detections outside of area range to ignore
+ a = np.array(
+ [d[flag_range] < aRng[0] or d[flag_range] > aRng[1] for d in dt]
+ ).reshape((1, len(dt)))
+
+ dtIg = np.logical_or(
+ dtIg, np.logical_and(dtm == 0, np.repeat(a, T, 0))
+ )
+
+ # in case of proximity evaluation, ignore detections which are far from gt regions
+ if self.eval_prox and len(in_prox) > 0:
+ dt_far = in_prox.any(1) == 0
+ dtIg = np.logical_or(
+ dtIg, np.repeat(dt_far.reshape((1, len(dt))), T, 0)
+ )
+
+ # store results for given image and category
+ return {
+ "image_id": imgId,
+ "category_id": catId,
+ "aRng": aRng,
+ "maxDet": maxDet,
+ "dtIds": [d["id"] for d in dt],
+ "gtIds": [g["id"] for g in gt],
+ "dtMatches": dtm,
+ "gtMatches": gtm,
+ "dtScores": [d["score"] for d in dt],
+ "gtIgnore": gtIg,
+ "dtIgnore": dtIg,
+ "dtTranslationError": tem,
+ "dtScaleError": sem,
+ "dtOrientationError": oem,
+ "dtOrientationErrorSym": oem_sym,
+ "dtOrientationErrorCanonical": oem_canonical,
+ }
+
+ def summarize(self):
+ """
+ Compute and display summary metrics for evaluation results.
+ Note this functin can *only* be applied on the default parameter setting
+ """
+
+ def _summarize(
+ mode, ap=1, iouThr=None, areaRng="all", maxDets=100, log_str=""
+ ):
+ p = self.params
+ eval = self.eval
+
+ if mode == "2D":
+ if self.iou_type == "bbox":
+ iStr = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}"
+ else:
+ iStr = " {:<18} {} @[ Dist={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}"
+
+ elif mode == "3D":
+ if self.iou_type == "bbox":
+ iStr = " {:<18} {} @[ IoU={:<9} | depth={:>6s} | maxDets={:>3d} ] = {:0.3f}"
+ else:
+ iStr = " {:<18} {} @[ Dist={:<9} | depth={:>6s} | maxDets={:>3d} ] = {:0.3f}"
+
+ titleStr = "Average Precision" if ap == 1 else "Average Recall"
+ typeStr = "(AP)" if ap == 1 else "(AR)"
+
+ iouStr = (
+ "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
+ if iouThr is None
+ else "{:0.2f}".format(iouThr)
+ )
+
+ aind = [
+ i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng
+ ]
+ mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
+
+ if ap == 1:
+
+ # dimension of precision: [TxRxKxAxM]
+ s = eval["precision"]
+
+ # IoU
+ if iouThr is not None:
+ t = np.where(np.isclose(iouThr, p.iouThrs.astype(float)))[
+ 0
+ ]
+ s = s[t]
+
+ s = s[:, :, :, aind, mind]
+
+ else:
+ # dimension of recall: [TxKxAxM]
+ s = eval["recall"]
+ if iouThr is not None:
+ t = np.where(iouThr == p.iouThrs)[0]
+ s = s[t]
+ s = s[:, :, aind, mind]
+
+ if len(s[s > -1]) == 0:
+ mean_s = -1
+
+ else:
+ mean_s = np.mean(s[s > -1])
+
+ if log_str != "":
+ log_str += "\n"
+
+ log_str += "mode={} ".format(mode) + iStr.format(
+ titleStr, typeStr, iouStr, areaRng, maxDets, mean_s
+ )
+
+ return mean_s, log_str
+
+ def _summarizeDets(mode):
+
+ params = self.params
+
+ # Define the thresholds to be printed
+ if mode == "2D":
+ thres = [0.5, 0.75, 0.95]
+ else:
+ if self.iou_type == "bbox":
+ thres = [0.15, 0.25, 0.50]
+ else:
+ thres = [0.5, 0.75, 1.0]
+
+ stats = np.zeros((13,))
+ stats[0], log_str = _summarize(mode, 1)
+
+ stats[1], log_str = _summarize(
+ mode,
+ 1,
+ iouThr=thres[0],
+ maxDets=params.maxDets[2],
+ log_str=log_str,
+ )
+
+ stats[2], log_str = _summarize(
+ mode,
+ 1,
+ iouThr=thres[1],
+ maxDets=params.maxDets[2],
+ log_str=log_str,
+ )
+
+ stats[3], log_str = _summarize(
+ mode,
+ 1,
+ iouThr=thres[2],
+ maxDets=params.maxDets[2],
+ log_str=log_str,
+ )
+
+ stats[4], log_str = _summarize(
+ mode,
+ 1,
+ areaRng=params.areaRngLbl[1],
+ maxDets=params.maxDets[2],
+ log_str=log_str,
+ )
+
+ stats[5], log_str = _summarize(
+ mode,
+ 1,
+ areaRng=params.areaRngLbl[2],
+ maxDets=params.maxDets[2],
+ log_str=log_str,
+ )
+
+ stats[6], log_str = _summarize(
+ mode,
+ 1,
+ areaRng=params.areaRngLbl[3],
+ maxDets=params.maxDets[2],
+ log_str=log_str,
+ )
+
+ stats[7], log_str = _summarize(
+ mode, 0, maxDets=params.maxDets[0], log_str=log_str
+ )
+
+ stats[8], log_str = _summarize(
+ mode, 0, maxDets=params.maxDets[1], log_str=log_str
+ )
+
+ stats[9], log_str = _summarize(
+ mode, 0, maxDets=params.maxDets[2], log_str=log_str
+ )
+
+ stats[10], log_str = _summarize(
+ mode,
+ 0,
+ areaRng=params.areaRngLbl[1],
+ maxDets=params.maxDets[2],
+ log_str=log_str,
+ )
+
+ stats[11], log_str = _summarize(
+ mode,
+ 0,
+ areaRng=params.areaRngLbl[2],
+ maxDets=params.maxDets[2],
+ log_str=log_str,
+ )
+
+ stats[12], log_str = _summarize(
+ mode,
+ 0,
+ areaRng=params.areaRngLbl[3],
+ maxDets=params.maxDets[2],
+ log_str=log_str,
+ )
+
+ return stats, log_str
+
+ if not self.eval:
+ raise Exception("Please run accumulate() first")
+
+ stats, log_str = _summarizeDets(self.mode)
+ self.stats = stats
+
+ return log_str
+
+
+class Detect3DParams:
+ """Params for the 3d detection evaluation API."""
+
+ def __init__(
+ self,
+ mode: str = "2D",
+ iouType: str = "bbox",
+ proximity_thresh: float = 0.3,
+ ) -> None:
+ """Create an instance of Detect3DParams.
+
+ Args:
+ mode: (str) defines whether to evaluate 2D or 3D performance.
+ iouType: (str) defines the type of IoU to be used for evaluation.
+ proximity_thresh (float): It defines the neighborhood when
+ evaluating on non-exhaustively annotated datasets.
+ """
+ assert iouType in {"bbox", "dist"}, f"Invalid iouType {iouType}."
+ self.iouType = iouType
+
+ if mode == "2D":
+ self.setDet2DParams()
+ elif mode == "3D":
+ self.setDet3DParams()
+ else:
+ raise Exception(f"{mode} mode is not supported")
+ self.mode = mode
+ self.proximity_thresh = proximity_thresh
+
+ def setDet2DParams(self) -> None:
+ """Set parameters for 2D detection evaluation."""
+ self.imgIds = []
+ self.catIds = []
+
+ # np.arange causes trouble. the data point on arange is slightly larger than the true value
+ self.iouThrs = np.linspace(
+ 0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True
+ )
+
+ self.recThrs = np.linspace(
+ 0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True
+ )
+
+ self.maxDets = [1, 10, 100]
+ self.areaRng = [
+ [0**2, 1e5**2],
+ [0**2, 32**2],
+ [32**2, 96**2],
+ [96**2, 1e5**2],
+ ]
+
+ self.areaRngLbl = ["all", "small", "medium", "large"]
+ self.useCats = 1
+
+ def setDet3DParams(self) -> None:
+ """Set parameters for 3D detection evaluation."""
+ self.imgIds = []
+ self.catIds = []
+
+ # np.arange causes trouble. The data point on arange is slightly
+ # larger than the true value
+ if self.iouType == "bbox":
+ self.iouThrs = np.linspace(
+ 0.05,
+ 0.5,
+ int(np.round((0.5 - 0.05) / 0.05)) + 1,
+ endpoint=True,
+ )
+ else:
+ self.iouThrs = np.linspace(
+ 0.5, 1.0, int(np.round((1.00 - 0.5) / 0.05)) + 1, endpoint=True
+ )
+
+ self.recThrs = np.linspace(
+ 0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True
+ )
+
+ self.maxDets = [1, 10, 100]
+ self.areaRng = [[0, 1e5], [0, 10], [10, 35], [35, 1e5]]
+ self.areaRngLbl = ["all", "near", "medium", "far"]
+ self.useCats = 1
diff --git a/wilddet3d/eval/omni3d.py b/wilddet3d/eval/omni3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..77fc2805c8faeb02dc7ced5cd636f461a7062358
--- /dev/null
+++ b/wilddet3d/eval/omni3d.py
@@ -0,0 +1,378 @@
+"""Omni3D 3D detection evaluation."""
+
+import contextlib
+import copy
+import io
+import itertools
+import os
+from collections.abc import Sequence
+
+import numpy as np
+from terminaltables import AsciiTable
+from vis4d.common.logging import rank_zero_info
+from vis4d.common.typing import GenericFunc, MetricLogs, NDArrayNumber
+from vis4d.eval.base import Evaluator
+
+from wilddet3d.data.datasets.omni3d.omni3d_classes import omni3d_class_map
+from wilddet3d.data.datasets.omni3d.util import get_dataset_det_map
+
+from .detect3d import Detect3Deval, Detect3DEvaluator
+
+omni3d_in = {
+ "stationery",
+ "sink",
+ "table",
+ "floor mat",
+ "bottle",
+ "bookcase",
+ "bin",
+ "blinds",
+ "pillow",
+ "bicycle",
+ "refrigerator",
+ "night stand",
+ "chair",
+ "sofa",
+ "books",
+ "oven",
+ "towel",
+ "cabinet",
+ "window",
+ "curtain",
+ "bathtub",
+ "laptop",
+ "desk",
+ "television",
+ "clothes",
+ "stove",
+ "cup",
+ "shelves",
+ "box",
+ "shoes",
+ "mirror",
+ "door",
+ "picture",
+ "lamp",
+ "machine",
+ "counter",
+ "bed",
+ "toilet",
+}
+
+omni3d_out = {
+ "cyclist",
+ "pedestrian",
+ "trailer",
+ "bus",
+ "motorcycle",
+ "car",
+ "barrier",
+ "truck",
+ "van",
+ "traffic cone",
+ "bicycle",
+}
+
+
+class Omni3DEvaluator(Evaluator):
+ """Omni3D 3D detection evaluator."""
+
+ def __init__(
+ self,
+ data_root: str = "data/omni3d",
+ omni3d50: bool = True,
+ datasets: Sequence[str] = (
+ "KITTI_test",
+ "nuScenes_test",
+ "SUNRGBD_test",
+ "Hypersim_test",
+ "ARKitScenes_test",
+ "Objectron_test",
+ ),
+ per_class_eval: bool = True,
+ # APRel3D parameters (LabelAny3D-style)
+ enable_aprel3d: bool = False,
+ aprel_2d_iou_thresh: float = 0.75,
+ # Mini dataset support
+ use_mini_dataset: bool = False,
+ ) -> None:
+ """Initialize the evaluator.
+
+ Args:
+ data_root: Root directory for Omni3D data.
+ omni3d50: Whether to use Omni3D-50 class mapping.
+ datasets: List of dataset names to evaluate.
+ per_class_eval: Whether to evaluate per-class metrics.
+ enable_aprel3d: Whether to enable APRel3D evaluation.
+ aprel_2d_iou_thresh: 2D IoU threshold for matching (default 0.75).
+ use_mini_dataset: If True, use annotations_mini100/ for GT.
+ """
+ super().__init__()
+ self.id_to_name = {v: k for k, v in omni3d_class_map.items()}
+ self.dataset_names = datasets
+ self.per_class_eval = per_class_eval
+ self.enable_aprel3d = enable_aprel3d
+ self.aprel_2d_iou_thresh = aprel_2d_iou_thresh
+ self.use_mini_dataset = use_mini_dataset
+
+ # Each dataset evaluator is stored here
+ self.evaluators: dict[str, Detect3DEvaluator] = {}
+
+ # These store the evaluations for each category and area,
+ # concatenated from ALL evaluated datasets. Doing so avoids
+ # the need to re-compute them when accumulating results.
+ self.evals_per_cat_area2D = {}
+ self.evals_per_cat_area3D = {}
+
+ self.overall_imgIds = set()
+ self.overall_catIds = set()
+
+ # Determine annotation directory based on mini dataset flag
+ if use_mini_dataset:
+ annotation_dir = os.path.join(data_root, "annotations_mini100")
+ else:
+ annotation_dir = os.path.join(data_root, "annotations")
+
+ for dataset_name in self.dataset_names:
+ annotation = os.path.join(
+ annotation_dir, f"{dataset_name}.json"
+ )
+
+ det_map = get_dataset_det_map(
+ dataset_name=dataset_name, omni3d50=omni3d50
+ )
+
+ # create an individual dataset evaluator
+ self.evaluators[dataset_name] = Detect3DEvaluator(
+ det_map,
+ cat_map=omni3d_class_map,
+ annotation=annotation,
+ eval_prox=(
+ "Objectron" in dataset_name or "SUNRGBD" in dataset_name
+ ),
+ enable_aprel3d=enable_aprel3d,
+ aprel_2d_iou_thresh=aprel_2d_iou_thresh,
+ )
+
+ self.overall_imgIds.update(
+ set(self.evaluators[dataset_name]._coco_gt.getImgIds())
+ )
+ self.overall_catIds.update(
+ set(self.evaluators[dataset_name]._coco_gt.getCatIds())
+ )
+
+ def __repr__(self) -> str:
+ """Returns the string representation of the object."""
+ datasets_str = ", ".join(self.dataset_names)
+ return f"Omni3DEvaluator ({datasets_str})"
+
+ @property
+ def metrics(self) -> list[str]:
+ """Supported metrics.
+
+ Returns:
+ list[str]: Metrics to evaluate.
+ """
+ return ["2D", "3D"]
+
+ def reset(self) -> None:
+ """Reset the saved predictions to start new round of evaluation."""
+ for dataset_name in self.dataset_names:
+ self.evaluators[dataset_name].reset()
+ self.evals_per_cat_area2D.clear()
+ self.evals_per_cat_area3D.clear()
+
+ def gather(self, gather_func: GenericFunc) -> None:
+ """Accumulate predictions across processes."""
+ for dataset_name in self.dataset_names:
+ self.evaluators[dataset_name].gather(gather_func)
+
+ def process_batch(
+ self,
+ coco_image_id: list[int],
+ dataset_names: list[str],
+ pred_boxes: list[NDArrayNumber],
+ pred_scores: list[NDArrayNumber],
+ pred_classes: list[NDArrayNumber],
+ pred_boxes3d: list[NDArrayNumber] | None = None,
+ ) -> None:
+ """Process sample and convert detections to coco format."""
+ # Handle empty batch (can happen when all images have 0 GT boxes)
+ if dataset_names is None or len(dataset_names) == 0:
+ return
+ for i, dataset_name in enumerate(dataset_names):
+ self.evaluators[dataset_name].process_batch(
+ [coco_image_id[i]],
+ [pred_boxes[i]],
+ [pred_scores[i]],
+ [pred_classes[i]],
+ pred_boxes3d=[pred_boxes3d[i]] if pred_boxes3d else None,
+ )
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """Evaluate predictions and return the results."""
+ assert metric in self.metrics, f"Unsupported metric: {metric}"
+
+ log_dict = {}
+ per_dataset_results = {} # Store results for later aggregation
+
+ for dataset_name in self.dataset_names:
+ rank_zero_info(f"Evaluating {dataset_name}...")
+ per_dataset_log_dict, dataset_log_str = self.evaluators[
+ dataset_name
+ ].evaluate(metric)
+
+ per_dataset_results[dataset_name] = per_dataset_log_dict
+
+ # Get the main metric key (APRel3D/APRel in APRel mode, AP otherwise)
+ # Priority: APRel3D > APRel > AP
+ if "APRel3D" in per_dataset_log_dict:
+ main_metric_key = "APRel3D"
+ elif "APRel" in per_dataset_log_dict:
+ main_metric_key = "APRel"
+ elif "AP" in per_dataset_log_dict:
+ main_metric_key = "AP"
+ else:
+ # Fallback: use the first key that starts with "AP"
+ main_metric_key = next((k for k in per_dataset_log_dict.keys() if k.startswith("AP")), "AP")
+
+ log_dict[f"AP_{dataset_name}"] = per_dataset_log_dict[main_metric_key]
+
+ rank_zero_info(dataset_log_str + "\n")
+
+ # store the partially accumulated evaluations per category per area
+ if metric == "2D":
+ for key, item in self.evaluators[
+ dataset_name
+ ].bbox_2D_evals_per_cat_area.items():
+ if not key in self.evals_per_cat_area2D:
+ self.evals_per_cat_area2D[key] = []
+ self.evals_per_cat_area2D[key] += item
+ else:
+ for key, item in self.evaluators[
+ dataset_name
+ ].bbox_3D_evals_per_cat_area.items():
+ if not key in self.evals_per_cat_area3D:
+ self.evals_per_cat_area3D[key] = []
+ self.evals_per_cat_area3D[key] += item
+
+ results_per_category_dict = {}
+ results_per_category = []
+
+ rank_zero_info(f"Evaluating Omni3D for {metric} Detection...")
+
+ evaluator = Detect3Deval(mode=metric)
+ evaluator.params.catIds = list(self.overall_catIds)
+ evaluator.params.imgIds = list(self.overall_imgIds)
+ evaluator.evalImgs = True
+
+ if metric == "2D":
+ evaluator.evals_per_cat_area = self.evals_per_cat_area2D
+ metrics = ["AP", "AP50", "AP75", "AP95", "APs", "APm", "APl"]
+ else:
+ evaluator.evals_per_cat_area = self.evals_per_cat_area3D
+ if self.enable_aprel3d:
+ metrics = ["APRel3D", "APRel15", "APRel25", "APRel50", "APReln", "APRelm", "APRelf"]
+ else:
+ metrics = ["AP", "AP15", "AP25", "AP50", "APn", "APm", "APf"]
+
+ evaluator._paramsEval = copy.deepcopy(evaluator.params)
+
+ with contextlib.redirect_stdout(io.StringIO()):
+ evaluator.accumulate()
+ log_str = "\n" + evaluator.summarize()
+
+ log_dict.update(dict(zip(metrics, evaluator.stats)))
+
+ # Add error metrics (aggregate from all datasets)
+ # Note: In bbox mode, only ASE and AOE are returned (no ATE)
+ # In dist mode, ATE, ASE, and AOE are all returned
+ if metric == "3D":
+ # Collect error metrics from all datasets
+ all_ase = []
+ all_aoe = []
+ all_aoe_sym = []
+ all_aoe_canonical = []
+ all_ods_sym = []
+ all_ods_canonical = []
+
+ # Determine which keys to look for based on mode
+ if self.enable_aprel3d:
+ ase_key, aoe_key, aoe_sym_key = "ASERel", "AOERel", "AOERelSym"
+ aoe_canonical_key = "AOERelCanonical"
+ ods_sym_key = "ODSRelSym"
+ ods_canonical_key = "ODSRelCanonical"
+ else:
+ ase_key, aoe_key, aoe_sym_key = "ASE", "AOE", "AOE_Sym"
+ aoe_canonical_key = "AOE_Canonical"
+ ods_sym_key = "ODS_Sym"
+ ods_canonical_key = "ODS_Canonical"
+
+ for dataset_name in self.dataset_names:
+ per_dataset_log_dict = per_dataset_results[dataset_name]
+ if ase_key in per_dataset_log_dict and not np.isnan(per_dataset_log_dict[ase_key]):
+ all_ase.append(per_dataset_log_dict[ase_key])
+ if aoe_key in per_dataset_log_dict and not np.isnan(per_dataset_log_dict[aoe_key]):
+ all_aoe.append(per_dataset_log_dict[aoe_key])
+ if aoe_sym_key in per_dataset_log_dict and not np.isnan(per_dataset_log_dict[aoe_sym_key]):
+ all_aoe_sym.append(per_dataset_log_dict[aoe_sym_key])
+ if aoe_canonical_key in per_dataset_log_dict and not np.isnan(per_dataset_log_dict[aoe_canonical_key]):
+ all_aoe_canonical.append(per_dataset_log_dict[aoe_canonical_key])
+ if ods_sym_key in per_dataset_log_dict and not np.isnan(per_dataset_log_dict[ods_sym_key]):
+ all_ods_sym.append(per_dataset_log_dict[ods_sym_key])
+ if ods_canonical_key in per_dataset_log_dict and not np.isnan(per_dataset_log_dict[ods_canonical_key]):
+ all_ods_canonical.append(per_dataset_log_dict[ods_canonical_key])
+
+ log_dict[ase_key] = np.mean(all_ase) if len(all_ase) > 0 else float("nan")
+ log_dict[aoe_key] = np.mean(all_aoe) if len(all_aoe) > 0 else float("nan")
+ log_dict[aoe_sym_key] = np.mean(all_aoe_sym) if len(all_aoe_sym) > 0 else float("nan")
+ log_dict[aoe_canonical_key] = np.mean(all_aoe_canonical) if len(all_aoe_canonical) > 0 else float("nan")
+ log_dict[ods_sym_key] = np.mean(all_ods_sym) if len(all_ods_sym) > 0 else float("nan")
+ log_dict[ods_canonical_key] = np.mean(all_ods_canonical) if len(all_ods_canonical) > 0 else float("nan")
+
+ if self.per_class_eval:
+ precisions = evaluator.eval["precision"]
+ for idx, cat_id in enumerate(self.overall_catIds):
+ cat_name = self.id_to_name[cat_id]
+ precision = precisions[:, :, idx, 0, -1]
+ precision = precision[precision > -1]
+ if precision.size:
+ ap = float(np.mean(precision).item())
+ else:
+ ap = float("nan")
+
+ results_per_category_dict[cat_name] = ap
+ results_per_category.append((f"{cat_name}", f"{ap:0.3f}"))
+
+ num_columns = min(6, len(results_per_category) * 2)
+ results_flatten = list(itertools.chain(*results_per_category))
+ headers = ["category", "AP"] * (num_columns // 2)
+ results_2d = itertools.zip_longest(
+ *[results_flatten[i::num_columns] for i in range(num_columns)]
+ )
+ table_data = [headers] + list(results_2d)
+ table = AsciiTable(table_data)
+ log_str = f"\n{table.table}\n{log_str}"
+
+ # Omni3D Outdoor performance
+ ap_out_lst = []
+ for cat in omni3d_out:
+ ap_out_lst.append(results_per_category_dict.get(cat, 0.0))
+
+ log_dict["Omni3D_Out"] = np.mean(ap_out_lst).item()
+
+ # Omni3D Indoor performance
+ ap_in_lst = []
+ for cat in omni3d_in:
+ ap_in_lst.append(results_per_category_dict.get(cat, 0.0))
+
+ log_dict["Omni3D_In"] = np.mean(ap_in_lst).item()
+
+ return log_dict, log_str
+
+ def save(self, metric: str, output_dir: str) -> None:
+ """Save the results to json files."""
+ for dataset_name in self.dataset_names:
+ self.evaluators[dataset_name].save(
+ metric, output_dir, prefix=dataset_name
+ )
diff --git a/wilddet3d/eval/open.py b/wilddet3d/eval/open.py
new file mode 100644
index 0000000000000000000000000000000000000000..044ffb5e6691f57caac00df8d37bc40e37bd8629
--- /dev/null
+++ b/wilddet3d/eval/open.py
@@ -0,0 +1,143 @@
+"""Multi-data 3D detection evaluation."""
+
+from collections.abc import Sequence
+
+from vis4d.common.logging import rank_zero_info
+from vis4d.common.typing import GenericFunc, MetricLogs, NDArrayNumber
+from vis4d.eval.base import Evaluator
+
+from .detect3d import Detect3DEvaluator
+from .omni3d import Omni3DEvaluator
+
+
+class OpenDetect3DEvaluator(Evaluator):
+ """Multi-data 3D detection evaluator."""
+
+ def __init__(
+ self,
+ datasets: Sequence[str],
+ evaluators: Sequence[Detect3DEvaluator],
+ omni3d_evaluator: Omni3DEvaluator | None = None,
+ ) -> None:
+ """Initialize the evaluator."""
+ super().__init__()
+ self.dataset_names = datasets
+ self.evaluators = {
+ name: evaluator for name, evaluator in zip(datasets, evaluators)
+ }
+
+ self.omni3d_evaluator = omni3d_evaluator
+
+ def __repr__(self) -> str:
+ """Returns the string representation of the object."""
+ datasets_str = ", ".join(self.dataset_names)
+ return f"Open 3D Object Detection Evaluator ({datasets_str})"
+
+ @property
+ def metrics(self) -> list[str]:
+ """Supported metrics.
+
+ Returns:
+ list[str]: Metrics to evaluate.
+ """
+ return ["2D", "3D"]
+
+ def reset(self) -> None:
+ """Reset the saved predictions to start new round of evaluation."""
+ for dataset_name in self.dataset_names:
+ self.evaluators[dataset_name].reset()
+
+ if self.omni3d_evaluator is not None:
+ self.omni3d_evaluator.reset()
+
+ def gather(self, gather_func: GenericFunc) -> None:
+ """Accumulate predictions across processes."""
+ for dataset_name in self.dataset_names:
+ self.evaluators[dataset_name].gather(gather_func)
+
+ if self.omni3d_evaluator is not None:
+ self.omni3d_evaluator.gather(gather_func)
+
+ def process_batch(
+ self,
+ coco_image_id: list[int],
+ dataset_names: list[str],
+ pred_boxes: list[NDArrayNumber],
+ pred_scores: list[NDArrayNumber],
+ pred_classes: list[NDArrayNumber],
+ pred_boxes3d: list[NDArrayNumber] | None = None,
+ ) -> None:
+ """Process sample and convert detections to coco format."""
+ for i, dataset_name in enumerate(dataset_names):
+ if (
+ self.omni3d_evaluator is not None
+ and dataset_name in self.omni3d_evaluator.dataset_names
+ ):
+ self.omni3d_evaluator.process_batch(
+ [coco_image_id[i]],
+ [dataset_name],
+ [pred_boxes[i]],
+ [pred_scores[i]],
+ [pred_classes[i]],
+ pred_boxes3d=[pred_boxes3d[i]] if pred_boxes3d else None,
+ )
+ else:
+ self.evaluators[dataset_name].process_batch(
+ [coco_image_id[i]],
+ [pred_boxes[i]],
+ [pred_scores[i]],
+ [pred_classes[i]],
+ pred_boxes3d=[pred_boxes3d[i]] if pred_boxes3d else None,
+ )
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ """Evaluate predictions and return the results."""
+ assert metric in self.metrics, f"Unsupported metric: {metric}"
+
+ log_dict = {}
+ log_str = ""
+
+ if self.omni3d_evaluator is not None:
+ log_dict_omni3d, omni3d_log_str = self.omni3d_evaluator.evaluate(
+ metric
+ )
+
+ log_dict.update(log_dict_omni3d)
+ log_str += omni3d_log_str
+
+ for dataset_name in self.dataset_names:
+ rank_zero_info(f"Evaluating {dataset_name}...")
+ per_dataset_log_dict, dataset_log_str = self.evaluators[
+ dataset_name
+ ].evaluate(metric)
+
+ if "ODS" in per_dataset_log_dict:
+ score = "ODS"
+ else:
+ score = "AP"
+
+ log_dict[f"{score}_{dataset_name}"] = per_dataset_log_dict[score]
+
+ if self.evaluators[dataset_name].base_classes is not None:
+ log_dict[f"{score}_Base_{dataset_name}"] = (
+ per_dataset_log_dict[f"{score}_Base"]
+ )
+ log_dict[f"{score}_Novel_{dataset_name}"] = (
+ per_dataset_log_dict[f"{score}_Novel"]
+ )
+
+ log_str += f"\nCheck {dataset_name} results in log dict."
+
+ rank_zero_info(dataset_log_str + "\n")
+
+ return log_dict, log_str
+
+ def save(self, metric: str, output_dir: str) -> None:
+ """Save the results to json files."""
+ if self.omni3d_evaluator is not None:
+ self.omni3d_evaluator.save(metric, output_dir)
+
+ for dataset_name in self.dataset_names:
+ self.evaluators[dataset_name].save(
+ metric, output_dir, prefix=dataset_name
+ )
diff --git a/wilddet3d/eval/postprocess_cache_export.py b/wilddet3d/eval/postprocess_cache_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..e63c586a191439e227a0bafa215c7ba57e17a043
--- /dev/null
+++ b/wilddet3d/eval/postprocess_cache_export.py
@@ -0,0 +1,185 @@
+"""Postprocess cache exporter (test-time).
+
+This evaluator is used with vis4d's EvaluatorCallback to export per-image caches
+needed for depth-based 3D box post-processing, without changing the normal
+evaluation flow.
+
+Cache layout:
+ {cache_root}/{dataset_name}/{image_id}.npz
+
+We intentionally store the full metric depth map (aligned to original_hw) to
+avoid coordinate-system bugs from cropping.
+"""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+import numpy as np
+import torch
+from vis4d.common.array import array_to_numpy
+from vis4d.common.typing import GenericFunc, MetricLogs, NDArrayNumber
+from vis4d.eval.base import Evaluator
+
+
+class PostprocessCacheExporter(Evaluator):
+ """Exports model outputs needed for post-processing into .npz cache files."""
+
+ def __init__(
+ self,
+ cache_root: str,
+ compress: bool = True,
+ overwrite: bool = False,
+ depth_dtype: str = "float32",
+ ) -> None:
+ super().__init__()
+ self.cache_root = cache_root
+ self.compress = compress
+ self.overwrite = overwrite
+ if depth_dtype not in {"float16", "float32"}:
+ raise ValueError(f"Unsupported depth_dtype: {depth_dtype}")
+ self.depth_dtype = depth_dtype
+
+ self._num_written = 0
+ self._num_skipped = 0
+
+ @property
+ def metrics(self) -> list[str]:
+ # Not a real evaluator; we only export.
+ return []
+
+ def reset(self) -> None: # pragma: no cover
+ self._num_written = 0
+ self._num_skipped = 0
+
+ def gather(self, gather_func: GenericFunc) -> None: # pragma: no cover
+ # Nothing to gather; each rank writes its own files (safe because image_id is unique).
+ return
+
+ def process_batch(
+ self,
+ coco_image_id: list[int],
+ dataset_names: list[str],
+ pred_boxes: list[NDArrayNumber],
+ pred_scores: list[NDArrayNumber],
+ pred_classes: list[NDArrayNumber],
+ pred_boxes3d: list[NDArrayNumber] | None = None,
+ pred_categories: list[list[str]] | None = None,
+ depth_maps: list[torch.Tensor] | None = None,
+ intrinsics: list[NDArrayNumber] | NDArrayNumber | None = None,
+ original_hw: list[tuple[int, int]] | None = None,
+ ) -> None:
+ """Write one .npz per image."""
+ if pred_boxes3d is None:
+ # No 3D boxes -> nothing to export for depth alignment.
+ print("[PostprocessCacheExporter] Skipping: pred_boxes3d is None")
+ return
+ if depth_maps is None:
+ # Depth backend disabled -> nothing to export.
+ print("[PostprocessCacheExporter] Skipping: depth_maps is None")
+ return
+ if intrinsics is None:
+ print("[PostprocessCacheExporter] Skipping: intrinsics is None")
+ return
+ if original_hw is None:
+ print("[PostprocessCacheExporter] Skipping: original_hw is None")
+ return
+
+ print(f"[PostprocessCacheExporter] Processing batch: {len(coco_image_id)} images")
+
+ # Normalize intrinsics to per-sample list
+ if torch.is_tensor(intrinsics):
+ # intrinsics: Tensor [B, 3, 3] (may be on GPU)
+ intrinsics_np = intrinsics.detach().cpu().numpy()
+ intrinsics_list = [intrinsics_np[j] for j in range(intrinsics_np.shape[0])]
+ elif isinstance(intrinsics, np.ndarray):
+ # intrinsics: ndarray [3,3] or [B,3,3]
+ if intrinsics.ndim == 2:
+ intrinsics_list = [intrinsics for _ in range(len(coco_image_id))]
+ else:
+ intrinsics_list = [intrinsics[j] for j in range(intrinsics.shape[0])]
+ else:
+ # intrinsics: sequence of arrays/tensors
+ intrinsics_list = list(intrinsics)
+
+ for i, image_id in enumerate(coco_image_id):
+ dataset_name = dataset_names[i]
+ out_dir = os.path.join(self.cache_root, str(dataset_name))
+ os.makedirs(out_dir, exist_ok=True)
+
+ out_path = os.path.join(out_dir, f"{int(image_id)}.npz")
+ if (not self.overwrite) and os.path.exists(out_path):
+ self._num_skipped += 1
+ continue
+
+ boxes2d = array_to_numpy(
+ pred_boxes[i].to(torch.float32) if hasattr(pred_boxes[i], "to") else pred_boxes[i],
+ n_dims=None,
+ dtype=np.float32,
+ )
+ scores = array_to_numpy(
+ pred_scores[i].to(torch.float32) if hasattr(pred_scores[i], "to") else pred_scores[i],
+ n_dims=None,
+ dtype=np.float32,
+ )
+ class_ids = array_to_numpy(
+ pred_classes[i].to(torch.int64) if hasattr(pred_classes[i], "to") else pred_classes[i],
+ n_dims=None,
+ dtype=np.int64,
+ )
+ boxes3d = array_to_numpy(
+ pred_boxes3d[i].to(torch.float32) if hasattr(pred_boxes3d[i], "to") else pred_boxes3d[i],
+ n_dims=None,
+ dtype=np.float32,
+ )
+
+ # depth_maps is list[Tensor] where each Tensor is [H, W] or [1, H, W]
+ depth = depth_maps[i]
+ if depth.ndim == 3 and depth.shape[0] == 1:
+ depth = depth[0]
+ depth_np = depth.detach().cpu().numpy()
+ depth_np = depth_np.astype(np.float16 if self.depth_dtype == "float16" else np.float32)
+
+ Ki = intrinsics_list[i]
+ if torch.is_tensor(Ki):
+ K = Ki.detach().cpu().numpy().astype(np.float32)
+ else:
+ K = np.asarray(Ki, dtype=np.float32)
+ hw = original_hw[i]
+
+ meta: dict[str, Any] = {
+ "dataset_name": str(dataset_name),
+ "image_id": int(image_id),
+ "original_hw": np.asarray(hw, dtype=np.int32),
+ }
+
+ # Categories are variable-length strings; store as object array.
+ if pred_categories is not None and i < len(pred_categories) and pred_categories[i] is not None:
+ cats = np.asarray(pred_categories[i], dtype=object)
+ else:
+ cats = np.asarray([], dtype=object)
+
+ save_fn = np.savez_compressed if self.compress else np.savez
+ save_fn(
+ out_path,
+ boxes2d=boxes2d,
+ scores=scores,
+ class_ids=class_ids,
+ boxes3d_raw=boxes3d,
+ categories=cats,
+ depth_map=depth_np,
+ intrinsics=K,
+ meta=np.asarray(meta, dtype=object),
+ )
+ self._num_written += 1
+
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
+ # No evaluation; return empty.
+ return {}, f"PostprocessCacheExporter: wrote={self._num_written}, skipped={self._num_skipped}"
+
+ def save(self, metric: str, output_dir: str, prefix: str | None = None) -> None: # pragma: no cover
+ # Nothing to save beyond the cache files.
+ return
+
+
diff --git a/wilddet3d/head/__init__.py b/wilddet3d/head/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5c8863b9706e3dfafb6f584f1e74090c2a6fed9
--- /dev/null
+++ b/wilddet3d/head/__init__.py
@@ -0,0 +1,12 @@
+"""3D detection head."""
+
+from .coder_3d import Det3DCoder
+from .depth_cross_attn import DepthCrossAttention
+from .head_3d import Det3DHead, RoI2Det3D
+
+__all__ = [
+ "Det3DHead",
+ "RoI2Det3D",
+ "Det3DCoder",
+ "DepthCrossAttention",
+]
diff --git a/wilddet3d/head/__pycache__/__init__.cpython-311.pyc b/wilddet3d/head/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b2ddb486972490a79f081628dcaf1b21dbc1bb22
Binary files /dev/null and b/wilddet3d/head/__pycache__/__init__.cpython-311.pyc differ
diff --git a/wilddet3d/head/__pycache__/coder_3d.cpython-311.pyc b/wilddet3d/head/__pycache__/coder_3d.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4f100addb5accc9a72fef27dffcaa0c69119060
Binary files /dev/null and b/wilddet3d/head/__pycache__/coder_3d.cpython-311.pyc differ
diff --git a/wilddet3d/head/__pycache__/depth_cross_attn.cpython-311.pyc b/wilddet3d/head/__pycache__/depth_cross_attn.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5e8db85fe10cd9ab695ef961548419bcb9b3a00f
Binary files /dev/null and b/wilddet3d/head/__pycache__/depth_cross_attn.cpython-311.pyc differ
diff --git a/wilddet3d/head/__pycache__/head_3d.cpython-311.pyc b/wilddet3d/head/__pycache__/head_3d.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..38351d955d44e0b88c747644984f81614f3e4617
Binary files /dev/null and b/wilddet3d/head/__pycache__/head_3d.cpython-311.pyc differ
diff --git a/wilddet3d/head/coder_3d.py b/wilddet3d/head/coder_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc59dd319ef1af28753e870e0449986fd55bc8e0
--- /dev/null
+++ b/wilddet3d/head/coder_3d.py
@@ -0,0 +1,263 @@
+"""3D bounding box encoder."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor
+from vis4d.data.const import AxisMode
+from vis4d.op.geometry.projection import project_points, unproject_points
+from vis4d.op.geometry.rotation import (
+ euler_angles_to_matrix,
+ matrix_to_quaternion,
+ quaternion_to_matrix,
+ rotation_matrix_yaw,
+)
+
+from wilddet3d.ops.rotation import (
+ matrix_to_rotation_6d,
+ rotation_6d_to_matrix,
+)
+
+
+def _normalize_rotation_half(poses: Tensor) -> Tensor:
+ """Normalize rotation matrices to [0, pi) yaw range.
+
+ For objects with 180-degree rotational ambiguity (e.g. tables, chairs),
+ this folds yaw into [0, pi) so that 90 and 270 map to the same target.
+ Also handles boundary: 180 and 0 map to the same target.
+
+ Uses Y-axis rotation (OPENCV convention) to detect and flip.
+ """
+ import math
+
+ yaw = rotation_matrix_yaw(
+ poses, axis_mode=AxisMode.OPENCV
+ )[:, 1] # [N]
+ # Flip by 180 around Y-axis: Ry(pi) = diag(-1, 1, -1)
+ # yaw in [-pi, 0) or yaw ~= pi -> flip to [0, pi)
+ flip_mask = (yaw < 0) | (yaw > math.pi - 1e-4)
+ poses_out = poses.clone()
+ # R_new = R @ Ry(pi), Ry(pi) negates columns 0 and 2
+ poses_out[flip_mask, :, 0] = -poses[flip_mask, :, 0]
+ poses_out[flip_mask, :, 2] = -poses[flip_mask, :, 2]
+ return poses_out
+
+
+def _normalize_canonical(
+ poses: Tensor, dims: Tensor,
+) -> tuple[Tensor, Tensor]:
+ """Normalize rotation and dimensions to canonical form.
+
+ Eliminates OBB rotation ambiguity via 2 steps:
+
+ Step 1 - Force W <= L:
+ If W > L, swap W and L, then apply Ry(90 deg) to rotation.
+ boxes3d dims = [W, L, H]. Canonical: X=L, Z=W, so swapping
+ W<->L requires rotating 90 deg around Y to keep the box
+ geometry identical.
+ Ry(90): new_col0 = old_col2, new_col2 = -old_col0
+
+ Step 2 - Normalize yaw to [0, pi):
+ Same as _normalize_rotation_half. Apply Ry(180 deg) if yaw < 0
+ or yaw >= pi.
+
+ Together these reduce 4-fold Ry ambiguity to 1-fold.
+ (Rx(180) upside-down ambiguity is left to data preprocessing.)
+
+ Args:
+ poses: Rotation matrices [N, 3, 3].
+ dims: Dimensions [N, 3] as [W, L, H].
+
+ Returns:
+ poses_out: Normalized rotation matrices [N, 3, 3].
+ dims_out: Normalized dimensions [N, 3] with W <= L.
+ """
+ import math
+
+ poses_out = poses.clone()
+ dims_out = dims.clone()
+
+ # Step 1: Force W <= L
+ # dims = [W, L, H], indices 0, 1, 2
+ swap_mask = dims_out[:, 0] > dims_out[:, 1] # W > L
+ if swap_mask.any():
+ # Swap W and L
+ w_old = dims_out[swap_mask, 0].clone()
+ dims_out[swap_mask, 0] = dims_out[swap_mask, 1]
+ dims_out[swap_mask, 1] = w_old
+
+ # Apply Ry(90 deg): R_new = R @ Ry(90)
+ # Ry(90) = [[0,0,1],[0,1,0],[-1,0,0]]
+ # col0_new = R @ [0,0,-1]^T = -col2
+ # col1_new = R @ [0,1,0]^T = col1 (unchanged)
+ # col2_new = R @ [1,0,0]^T = col0
+ col0 = poses_out[swap_mask, :, 0].clone()
+ col2 = poses_out[swap_mask, :, 2].clone()
+ poses_out[swap_mask, :, 0] = -col2
+ poses_out[swap_mask, :, 2] = col0
+
+ # Step 2: Normalize yaw to [0, pi)
+ yaw = rotation_matrix_yaw(
+ poses_out, axis_mode=AxisMode.OPENCV
+ )[:, 1] # [N]
+ flip_mask = (yaw < 0) | (yaw > math.pi - 1e-4)
+ if flip_mask.any():
+ # R_new = R @ Ry(pi), negates columns 0 and 2
+ poses_out[flip_mask, :, 0] = -poses_out[flip_mask, :, 0]
+ poses_out[flip_mask, :, 2] = -poses_out[flip_mask, :, 2]
+
+ return poses_out, dims_out
+
+
+class Det3DCoder:
+ """3D box coder for encoding/decoding 3D bounding boxes."""
+
+ def __init__(
+ self,
+ center_scale: float = 10.0,
+ depth_scale: float = 2.0,
+ dim_scale: float = 2.0,
+ orientation: str = "rotation_6d",
+ ambiguous_rotation: bool = False,
+ canonical_rotation: bool = False,
+ ) -> None:
+ """Initialize the 3D box coder."""
+ self.center_scale = center_scale
+ self.depth_scale = depth_scale
+ self.dim_scale = dim_scale
+ self.ambiguous_rotation = ambiguous_rotation
+ self.canonical_rotation = canonical_rotation
+ if canonical_rotation:
+ print(
+ "[Det3DCoder] canonical_rotation=True: "
+ "dims normalized to W<=L, yaw to [0, 180)"
+ )
+ elif ambiguous_rotation:
+ print(
+ "[Det3DCoder] ambiguous_rotation=True: "
+ "GT rotation normalized to [0, 180) yaw range"
+ )
+
+ assert orientation in {
+ "yaw",
+ "rotation_6d",
+ }, f"Invalid orientation {orientation}."
+ self.orientation = orientation
+
+ if orientation == "yaw":
+ reg_dims = 8
+ elif orientation == "rotation_6d":
+ reg_dims = 12
+
+ self.reg_dims = reg_dims
+
+ def encode(
+ self, boxes: Tensor, boxes3d: Tensor, intrinsics: Tensor,
+ ) -> tuple[Tensor, Tensor]:
+ """Encode the 3D bounding boxes.
+
+ Args:
+ boxes: 2D boxes in PIXEL xyxy format. Shape (N, 4).
+ IMPORTANT: Should be GT 2D boxes during training (not predictions!)
+ This ensures stable targets. At inference, decode() uses pred boxes.
+ boxes3d: GT 3D boxes [center_3d(3), dims(3), quat(4)]. Shape (N, 10).
+ intrinsics: Camera intrinsics. Shape (3, 3) or (N, 3, 3).
+
+ Returns:
+ boxes3d_target: Encoded targets [delta_2d(2), log_depth(1), log_dims(3), rot_6d(6)].
+ boxes3d_weights: Per-element weights (0 for invalid depth/dims).
+ """
+ projected_center_3d = project_points(boxes3d[:, :3], intrinsics)
+ ctr_x = (boxes[:, 0] + boxes[:, 2]) / 2
+ ctr_y = (boxes[:, 1] + boxes[:, 3]) / 2
+ center_2d = torch.stack([ctr_x, ctr_y], -1)
+
+ delta_center = projected_center_3d - center_2d
+
+ delta_center /= self.center_scale
+
+ valid_depth = boxes3d[:, 2] > 0
+
+ depth = torch.where(
+ valid_depth,
+ torch.log(boxes3d[:, 2]) * self.depth_scale,
+ boxes3d[:, 2].new_zeros(1),
+ )
+ depth = depth.unsqueeze(-1)
+
+ raw_dims = boxes3d[:, 3:6] # [W, L, H]
+
+ poses = quaternion_to_matrix(boxes3d[:, 6:])
+
+ if self.canonical_rotation:
+ poses, raw_dims = _normalize_canonical(poses, raw_dims)
+ elif self.ambiguous_rotation:
+ poses = _normalize_rotation_half(poses)
+
+ valid_dims = raw_dims > 0
+ dims = torch.where(
+ valid_dims,
+ torch.log(raw_dims) * self.dim_scale,
+ raw_dims.new_zeros(1),
+ )
+
+ if self.orientation == "yaw":
+ yaw = rotation_matrix_yaw(
+ poses,
+ axis_mode=AxisMode.OPENCV,
+ )[:, 1]
+
+ sin_yaw = torch.sin(yaw).unsqueeze(-1)
+ cos_yaw = torch.cos(yaw).unsqueeze(-1)
+
+ boxes3d_target = torch.cat(
+ [delta_center, depth, dims, sin_yaw, cos_yaw], -1
+ )
+ elif self.orientation == "rotation_6d":
+ rot_6d = matrix_to_rotation_6d(poses)
+
+ boxes3d_target = torch.cat([delta_center, depth, dims, rot_6d], -1)
+
+ boxes3d_weights = torch.ones_like(boxes3d_target)
+ boxes3d_weights[:, 2] = valid_depth.float()
+ boxes3d_weights[:, 3:6] = valid_dims.float()
+
+ return boxes3d_target, boxes3d_weights
+
+ def decode(
+ self, boxes: Tensor, boxes3d: Tensor, intrinsics: Tensor
+ ) -> Tensor:
+ """Decode the 3D bounding boxes."""
+ delta_center = boxes3d[:, :2] * self.center_scale
+
+ ctr_x = (boxes[:, 0] + boxes[:, 2]) / 2
+ ctr_y = (boxes[:, 1] + boxes[:, 3]) / 2
+ center_2d = torch.stack([ctr_x, ctr_y], -1)
+
+ proj_center_3d = center_2d + delta_center
+
+ depth = torch.exp(boxes3d[:, 2] / self.depth_scale)
+
+ center_3d = unproject_points(proj_center_3d, depth, intrinsics)
+
+ dims = torch.exp(boxes3d[:, 3:6] / self.dim_scale)
+
+ if self.orientation == "yaw":
+ yaw = torch.atan2(boxes3d[:, 6], boxes3d[:, 7])
+
+ orientation = torch.stack(
+ [torch.zeros_like(yaw), yaw, torch.zeros_like(yaw)], -1
+ )
+
+ poses = euler_angles_to_matrix(orientation)
+ elif self.orientation == "rotation_6d":
+ poses = rotation_6d_to_matrix(boxes3d[:, 6:])
+
+ if self.canonical_rotation:
+ poses, dims = _normalize_canonical(poses, dims)
+ elif self.ambiguous_rotation:
+ poses = _normalize_rotation_half(poses)
+
+ orientation = matrix_to_quaternion(poses)
+
+ return torch.cat([center_3d, dims, orientation], dim=1)
diff --git a/wilddet3d/head/depth_cross_attn.py b/wilddet3d/head/depth_cross_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9c0412e62ea0b872321704ee87f78c02d56ca38
--- /dev/null
+++ b/wilddet3d/head/depth_cross_attn.py
@@ -0,0 +1,340 @@
+"""Depth cross-attention head."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import torch
+from einops import rearrange
+from timm.layers import trunc_normal_
+from torch import Tensor, nn
+from torch.nn import functional as F
+
+from wilddet3d.ops.ray import generate_rays, rsh_cart_8
+from wilddet3d.ops.attention import (
+ AttentionBlock,
+ NystromBlock,
+ PositionEmbeddingSine,
+)
+from wilddet3d.ops.mlp import MLP
+from wilddet3d.ops.upsample import ConvUpsample
+from wilddet3d.ops.util import flat_interpolate
+
+
+class DepthCrossAttention(nn.Module):
+ """Depth cross-attention head for depth estimation."""
+
+ def __init__(
+ self,
+ embed_dims: int = 256,
+ depth_scale: float = 2.0,
+ input_dims: Sequence[int] = (256, 256, 256),
+ output_scales: int = 1,
+ ) -> None:
+ """Initialize the depth head."""
+ super().__init__()
+ self.depth_scale = depth_scale
+ assert (
+ output_scales >= 1 and output_scales <= 3
+ ), "Invalid output scales."
+ self.output_scales = output_scales
+
+ num_resolutions = len(input_dims)
+ self.input_dims = input_dims
+ self.num_resolutions = num_resolutions
+
+ # Pool features as depth query
+ self.features_channel_cat = nn.Linear(
+ embed_dims * self.num_resolutions, embed_dims
+ )
+ self.to_latents = MLP(embed_dims, expansion=2)
+
+ self.pos_embed = PositionEmbeddingSine(embed_dims // 2, normalize=True)
+
+ self.level_embeds = nn.Parameter(
+ torch.randn(self.num_resolutions, embed_dims),
+ requires_grad=True,
+ )
+ self.level_embed_layer = nn.Sequential(
+ nn.Linear(embed_dims, embed_dims),
+ nn.GELU(),
+ nn.Linear(embed_dims, embed_dims),
+ nn.LayerNorm(embed_dims),
+ )
+
+ self.aggregate_16 = AttentionBlock(
+ embed_dims,
+ num_heads=1,
+ expansion=4,
+ context_dim=embed_dims,
+ )
+
+ self.prompt_camera = AttentionBlock(
+ embed_dims, num_heads=1, expansion=4, context_dim=embed_dims
+ )
+
+ # 1/16 resolution
+ self.project_rays_16 = MLP(81, expansion=4, output_dim=embed_dims)
+
+ self.layers_16 = nn.ModuleList(
+ [
+ AttentionBlock(embed_dims, num_heads=8, expansion=4),
+ NystromBlock(embed_dims, num_heads=8, expansion=4),
+ ]
+ )
+
+ self.up_8 = ConvUpsample(embed_dims, expansion=4)
+
+ if self.output_scales == 1:
+ self.out_8 = nn.Conv2d(embed_dims // 2, 1, 3, padding=1)
+
+ if self.output_scales >= 2:
+ # 1/8 resolution
+ embed_dims_8 = embed_dims // 2
+ self.project_rays_8 = MLP(81, expansion=4, output_dim=embed_dims_8)
+
+ self.layers_8 = nn.ModuleList(
+ [
+ AttentionBlock(embed_dims_8, num_heads=4, expansion=4),
+ NystromBlock(embed_dims_8, num_heads=4, expansion=4),
+ ]
+ )
+
+ self.up_4 = ConvUpsample(embed_dims_8, expansion=4)
+
+ if self.output_scales == 2:
+ self.out_4 = nn.Conv2d(embed_dims_8 // 2, 1, 3, padding=1)
+
+ if self.output_scales == 3:
+ # 1/4 resolution
+ embed_dims_4 = embed_dims // 4
+ self.project_rays_4 = MLP(81, expansion=4, output_dim=embed_dims_4)
+
+ self.layers_4 = nn.ModuleList(
+ [
+ AttentionBlock(embed_dims_4, num_heads=2, expansion=4),
+ NystromBlock(embed_dims_4, num_heads=2, expansion=4),
+ ]
+ )
+
+ self.up_2 = ConvUpsample(embed_dims_4, expansion=4)
+
+ self.out_2 = nn.Conv2d(embed_dims_4 // 2, 1, 3, padding=1)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ trunc_normal_(m.weight, std=0.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def get_rsh_cart(self, rays_embedding: Tensor) -> Tensor:
+ """Get real spherical harmonic."""
+ return rsh_cart_8(rays_embedding)
+
+ def forward(
+ self, feats: Tensor, intrinsics: Tensor, image_hw: tuple[int, int]
+ ) -> Tensor:
+ """Forward."""
+ # Camera Embedding
+ rays_hr, _ = generate_rays(intrinsics, image_hw)
+
+ # 1/16 shape
+ shape = image_hw[0] // 16, image_hw[1] // 16
+
+ latents = []
+ for _, feat in enumerate(feats):
+ latent = (
+ F.interpolate(
+ feat,
+ size=shape,
+ mode="bilinear",
+ align_corners=False,
+ antialias=True,
+ )
+ .flatten(2)
+ .permute(0, 2, 1)
+ )
+
+ latents.append(latent)
+
+ # positional embeddings, spatial and level
+ level_embed = torch.cat(
+ [
+ self.level_embed_layer(self.level_embeds)[i : i + 1]
+ .unsqueeze(0)
+ .repeat(feats[0].shape[0], shape[0] * shape[1], 1)
+ for i in range(self.num_resolutions)
+ ],
+ dim=1,
+ )
+ pos_embed = self.pos_embed(
+ torch.zeros(
+ feats[0].shape[0],
+ 1,
+ shape[0],
+ shape[1],
+ device=feats[0].device,
+ requires_grad=False,
+ )
+ )
+ pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat(
+ 1, self.num_resolutions, 1
+ )
+
+ features_tokens = torch.cat(latents, dim=1)
+ features_tokens_pos = pos_embed + level_embed
+
+ features_channels = torch.cat(latents, dim=-1)
+ features_16 = self.features_channel_cat(features_channels)
+ latents_16 = self.to_latents(features_16)
+
+ # Aggregate features: F -> D
+ latents_16 = self.aggregate_16(
+ latents_16,
+ context=features_tokens,
+ pos_embed_context=features_tokens_pos,
+ )
+
+ # 1/16 shape
+ rays_embedding_16 = F.normalize(
+ flat_interpolate(rays_hr, old=image_hw, new=shape), dim=-1
+ )
+
+ rays_embedding_16 = self.project_rays_16(
+ self.get_rsh_cart(rays_embedding_16)
+ )
+
+ # Aggregate camera: D -> D|E
+ latents_16 = self.prompt_camera(latents_16, context=rays_embedding_16)
+
+ outs = []
+ depth_latents = []
+
+ # Block 16 - Out 8
+ for layer in self.layers_16:
+ latents_16 = layer(latents_16, pos_embed=rays_embedding_16)
+
+ latents_8 = self.up_8(
+ rearrange(
+ latents_16,
+ "b (h w) c -> b c h w",
+ h=shape[0],
+ w=shape[1],
+ ).contiguous()
+ )
+
+ if self.output_scales == 1:
+ out_8 = self.out_8(
+ rearrange(
+ latents_8,
+ "b (h w) c -> b c h w",
+ h=shape[0] * 2,
+ w=shape[1] * 2,
+ )
+ )
+ outs.append(out_8)
+ depth_latents.append(latents_8.detach())
+
+ if self.output_scales >= 2:
+ # 1/8 shape
+ rays_embedding_8 = F.normalize(
+ flat_interpolate(
+ rays_hr, old=image_hw, new=(shape[0] * 2, shape[1] * 2)
+ ),
+ dim=-1,
+ )
+
+ rays_embedding_8 = self.project_rays_8(
+ self.get_rsh_cart(rays_embedding_8)
+ )
+
+ # Block 8 - Out 4
+ for layer in self.layers_8:
+ latents_8 = layer(latents_8, pos_embed=rays_embedding_8)
+
+ latents_4 = self.up_4(
+ rearrange(
+ latents_8,
+ "b (h w) c -> b c h w",
+ h=shape[0] * 2,
+ w=shape[1] * 2,
+ ).contiguous()
+ )
+
+ if self.output_scales == 2:
+ out_4 = self.out_4(
+ rearrange(
+ latents_4,
+ "b (h w) c -> b c h w",
+ h=shape[0] * 4,
+ w=shape[1] * 4,
+ )
+ )
+ outs.append(out_4)
+ depth_latents.append(latents_4.detach())
+
+ if self.output_scales == 3:
+ # 1/4 shape
+ rays_embedding_4 = F.normalize(
+ flat_interpolate(
+ rays_hr, old=image_hw, new=(shape[0] * 4, shape[1] * 4)
+ ),
+ dim=-1,
+ )
+
+ rays_embedding_4 = self.project_rays_4(
+ self.get_rsh_cart(rays_embedding_4)
+ )
+
+ # Block 4 - Out 2
+ for layer in self.layers_4:
+ latents_4 = layer(latents_4, pos_embed=rays_embedding_4)
+
+ latents_2 = self.up_2(
+ rearrange(
+ latents_4,
+ "b (h w) c -> b c h w",
+ h=shape[0] * 4,
+ w=shape[1] * 4,
+ ).contiguous()
+ )
+ out_2 = self.out_2(
+ rearrange(
+ latents_2,
+ "b (h w) c -> b c h w",
+ h=shape[0] * 8,
+ w=shape[1] * 8,
+ )
+ )
+ outs.append(out_2)
+ depth_latents.append(latents_2.detach())
+
+ # MS Outputs
+ depth_preds = (
+ sum(
+ [
+ F.interpolate(
+ torch.exp((out / self.depth_scale).clamp(-10.0, 10.0)),
+ size=image_hw,
+ mode="bilinear",
+ align_corners=False,
+ antialias=True,
+ )
+ for out in outs
+ ]
+ )
+ / len(outs)
+ ).squeeze(1)
+
+ depth_latent = depth_latents[-1]
+
+ return depth_preds, depth_latent
diff --git a/wilddet3d/head/head_3d.py b/wilddet3d/head/head_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f3ebbbf49e1a17c1bdaad1ec97a5d80ee85c3ff
--- /dev/null
+++ b/wilddet3d/head/head_3d.py
@@ -0,0 +1,452 @@
+"""3D detection head."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor, nn
+from torch.nn import functional as F
+from torchvision.ops import batched_nms, nms
+from vis4d.op.layer.attention import MultiheadAttention
+from vis4d.op.layer.transformer import FFN, get_clones
+from vis4d.op.layer.weight_init import xavier_init
+
+from wilddet3d.ops.box2d import bbox_cxcywh_to_xyxy
+from wilddet3d.ops.ray import generate_rays, rsh_cart_8
+from wilddet3d.ops.mlp import MLP
+from wilddet3d.ops.util import flat_interpolate
+
+from .coder_3d import Det3DCoder
+
+
+def convert_grounding_to_cls_scores(
+ logits: Tensor, positive_maps: dict[int, list[int, int]]
+) -> Tensor:
+ """Convert logits to class scores."""
+ assert len(positive_maps) == logits.shape[0] # batch size
+
+ scores = torch.zeros(
+ logits.shape[0], logits.shape[1], len(positive_maps[0])
+ ).to(logits.device)
+ if positive_maps is not None:
+ if all(x == positive_maps[0] for x in positive_maps):
+ # only need to compute once
+ positive_map = positive_maps[0]
+ for label_j in positive_map:
+ scores[:, :, label_j - 1] = logits[
+ :, :, torch.LongTensor(positive_map[label_j])
+ ].mean(-1)
+ else:
+ for i, positive_map in enumerate(positive_maps):
+ for label_j in positive_map:
+ scores[i, :, label_j - 1] = logits[
+ i, :, torch.LongTensor(positive_map[label_j])
+ ].mean(-1)
+ return scores
+
+
+class Det3DHead(nn.Module):
+ """3D detection head.
+
+ Args:
+ embed_dims: Embedding dimension for the head.
+ num_decoder_layer: Number of decoder layers.
+ num_reg_fcs: Number of fully connected layers in regression branch.
+ as_two_stage: Whether to use two-stage detection.
+ box_coder: 3D box coder for encoding/decoding.
+ depth_output_scales: Scale factor for depth embedding dims.
+ use_camera_prompt: Whether to use camera/ray prompt branch.
+ Set to False when using ray-aware depth backends (UniDepthV2, DetAny3D)
+ since their depth_latents already incorporate ray information.
+ Set to True for non-ray-aware backends (UniDepthHead v1).
+ use_depth_prompt: Whether to use depth prompt branch.
+ Set to False for ablation: only use depth via encoder fusion.
+ """
+
+ def __init__(
+ self,
+ embed_dims: int = 256,
+ num_decoder_layer: int = 6,
+ num_reg_fcs: int = 2,
+ as_two_stage: bool = True,
+ box_coder: Det3DCoder | None = None,
+ depth_output_scales: int = 1,
+ depth_latent_dim: int | None = None,
+ use_camera_prompt: bool = True,
+ use_depth_prompt: bool = True,
+ ) -> None:
+ """Initialize the 3D detection head.
+
+ Args:
+ depth_latent_dim: Dimension of depth latents from geometry backend.
+ If provided, uses this directly. If None, computes from
+ depth_output_scales as embed_dims // 2**depth_output_scales.
+ """
+ super().__init__()
+ self.embed_dims = embed_dims
+ self.use_camera_prompt = use_camera_prompt
+ self.use_depth_prompt = use_depth_prompt
+
+ self.num_pred_layer = (
+ num_decoder_layer + 1 if as_two_stage else num_decoder_layer
+ )
+ self.as_two_stage = as_two_stage
+
+ self.box_coder = box_coder or Det3DCoder()
+
+ reg_branch = self._get_reg_branch(num_reg_fcs, self.box_coder.reg_dims)
+ self.reg_branches = get_clones(reg_branch, self.num_pred_layer)
+
+ # 3D confidence branch (predicts 3D-aware objectness score)
+ conf_branch = self._get_conf_branch(num_reg_fcs)
+ self.conf_branches = get_clones(conf_branch, self.num_pred_layer)
+
+ # Camera prompt branch (only created if use_camera_prompt is True)
+ if self.use_camera_prompt:
+ project_rays, prompt_camera = self._get_condition_branch(
+ input_dims=81, expansion=4, embed_dims=embed_dims
+ )
+ self.project_rays = get_clones(project_rays, self.num_pred_layer)
+ self.prompt_camera = get_clones(prompt_camera, self.num_pred_layer)
+ else:
+ self.project_rays = None
+ self.prompt_camera = None
+
+ # Depth prompt branch (only created if use_depth_prompt is True)
+ if self.use_depth_prompt:
+ # Use depth_latent_dim directly if provided, else compute from depth_output_scales
+ if depth_latent_dim is not None:
+ depth_embed_dims = depth_latent_dim
+ else:
+ depth_embed_dims = embed_dims // 2**depth_output_scales
+ project_depth, prompt_depth = self._get_condition_branch(
+ depth_embed_dims, expansion=4, embed_dims=embed_dims
+ )
+ self.project_depth = get_clones(project_depth, self.num_pred_layer)
+ self.prompt_depth = get_clones(prompt_depth, self.num_pred_layer)
+ else:
+ self.project_depth = None
+ self.prompt_depth = None
+
+ self._init_weights()
+
+ def _get_reg_branch(
+ self, num_reg_fcs: int, reg_dims: int
+ ) -> nn.Sequential:
+ """Get the regression branch."""
+ reg_branch = []
+ for _ in range(num_reg_fcs):
+ reg_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
+ reg_branch.append(nn.ReLU())
+ reg_branch.append(nn.Linear(self.embed_dims, reg_dims))
+ return nn.Sequential(*reg_branch)
+
+ def _get_conf_branch(self, num_reg_fcs: int) -> nn.Sequential:
+ """Get the 3D confidence branch (output dim = 1)."""
+ conf_branch = []
+ for _ in range(num_reg_fcs):
+ conf_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
+ conf_branch.append(nn.ReLU())
+ conf_branch.append(nn.Linear(self.embed_dims, 1))
+ return nn.Sequential(*conf_branch)
+
+ def _get_condition_branch(
+ self, input_dims: int, expansion: int, embed_dims: int
+ ) -> tuple[nn.Module, nn.Module]:
+ """Get the condition branch."""
+ project_layer = MLP(
+ input_dims, expansion=expansion, output_dim=embed_dims
+ )
+
+ prompt_layer = Prompt3DQueryLayer(embed_dims)
+
+ return project_layer, prompt_layer
+
+ def _init_weights(self) -> None:
+ """Initialize weights of the Deformable DETR head."""
+ for m in self.reg_branches:
+ xavier_init(m, distribution="uniform")
+ for m in self.conf_branches:
+ xavier_init(m, distribution="uniform")
+
+ def get_camera_embeddings(
+ self,
+ intrinsics: Tensor,
+ image_shape: tuple[int, int],
+ downsample: int = 16,
+ ) -> Tensor:
+ """Get the camera embeddings.
+
+ Args:
+ intrinsics: Camera intrinsics [B, 3, 3]. Should match the space
+ where depth_latents were computed (may be adjusted for DINOv2).
+ image_shape: Image (H, W) in the same space as intrinsics.
+ downsample: Downsample factor for ray grid (8 or 16).
+ Must match depth_latents resolution.
+
+ Returns:
+ ray_embeddings: [B, H//downsample * W//downsample, 81]
+ """
+ rays, _ = generate_rays(intrinsics, image_shape)
+
+ rays = F.normalize(
+ flat_interpolate(
+ rays,
+ old=image_shape,
+ new=(image_shape[0] // downsample, image_shape[1] // downsample),
+ ),
+ dim=-1,
+ )
+
+ return rsh_cart_8(rays)
+
+ def single_forward(
+ self,
+ layer_id: int,
+ hidden_state: Tensor,
+ ray_embeddings: Tensor | None,
+ depth_latents: Tensor | None = None,
+ ) -> tuple[Tensor, Tensor]:
+ """Single layer forward pass of the 3D detection head.
+
+ Args:
+ layer_id: Index of the decoder layer.
+ hidden_state: Query hidden states [B, num_queries, embed_dims].
+ ray_embeddings: Ray embeddings [B, H*W, 81]. Only used if use_camera_prompt=True.
+ depth_latents: Depth latent features [B, H*W, depth_embed_dims].
+
+ Returns:
+ Tuple of (reg_output, conf_output):
+ - reg_output: 3D box regression [B, num_queries, reg_dims]
+ - conf_output: 3D confidence logits [B, num_queries, 1]
+ """
+ # Camera-aware 3D queries (only if use_camera_prompt is True)
+ if self.use_camera_prompt and ray_embeddings is not None:
+ ray_embedding = self.project_rays[layer_id](ray_embeddings)
+ hidden_state = self.prompt_camera[layer_id](
+ hidden_state, ray_embedding, ray_embedding
+ )
+
+ # Depth-aware 3D queries (only if use_depth_prompt is True)
+ if self.use_depth_prompt and depth_latents is not None:
+ proj_depth_latents = self.project_depth[layer_id](depth_latents)
+ hidden_state = self.prompt_depth[layer_id](
+ hidden_state, proj_depth_latents, proj_depth_latents
+ )
+
+ reg_output = self.reg_branches[layer_id](hidden_state)
+ conf_output = self.conf_branches[layer_id](hidden_state)
+
+ return reg_output, conf_output
+
+ def forward(
+ self,
+ hidden_states: Tensor,
+ ray_embeddings: Tensor | None,
+ depth_latents: Tensor | None = None,
+ ) -> tuple[Tensor, Tensor]:
+ """Forward pass of the 3D detection head.
+
+ Args:
+ hidden_states: Query hidden states [num_layers, B, num_queries, embed_dims].
+ ray_embeddings: Ray embeddings [B, H*W, 81]. Can be None if use_camera_prompt=False.
+ depth_latents: Depth latent features [B, H*W, depth_embed_dims].
+
+ Returns:
+ Tuple of (stacked_reg, stacked_conf):
+ - stacked_reg: [num_layers, B, num_queries, reg_dims]
+ - stacked_conf: [num_layers, B, num_queries, 1]
+ """
+ all_layers_outputs_3d = []
+ all_layers_conf_3d = []
+
+ for layer_id in range(hidden_states.shape[0]):
+ hidden_state = hidden_states[layer_id]
+
+ reg_output, conf_output = self.single_forward(
+ layer_id, hidden_state, ray_embeddings, depth_latents
+ )
+
+ all_layers_outputs_3d.append(reg_output)
+ all_layers_conf_3d.append(conf_output)
+
+ return torch.stack(all_layers_outputs_3d), torch.stack(all_layers_conf_3d)
+
+
+class Prompt3DQueryLayer(nn.Module):
+ """Prompt 3D object query Layer."""
+
+ def __init__(self, embed_dims: int = 256) -> None:
+ """Init."""
+ super().__init__()
+ self.self_attn = MultiheadAttention(
+ embed_dims=256, num_heads=8, batch_first=True
+ )
+
+ self.norm1 = nn.LayerNorm(embed_dims)
+
+ self.cross_attn = MultiheadAttention(
+ embed_dims=256, num_heads=1, batch_first=True
+ )
+
+ self.norm2 = nn.LayerNorm(embed_dims)
+
+ self.ffn = FFN(embed_dims)
+
+ self.norm3 = nn.LayerNorm(embed_dims)
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ query_pos: Tensor | None = None,
+ ) -> Tensor:
+ """Forward."""
+ # self attention
+ query = self.self_attn(
+ query=query,
+ key=query,
+ value=query,
+ query_pos=query_pos,
+ key_pos=query_pos,
+ )
+ query = self.norm1(query)
+
+ # cross attention
+ query = self.cross_attn(
+ query=query,
+ key=key,
+ value=value,
+ query_pos=query_pos,
+ )
+ query = self.norm2(query)
+
+ # FFN
+ query = self.ffn(query)
+ query = self.norm3(query)
+
+ return query
+
+
+class RoI2Det3D:
+ """Convert RoI to 3D Detection."""
+
+ def __init__(
+ self,
+ nms: bool = False,
+ max_per_img: int = 300,
+ class_agnostic_nms: bool = False,
+ score_threshold: float = 0.0,
+ iou_threshold: float = 0.5,
+ box_coder: Det3DCoder | None = None,
+ ) -> None:
+ """Create an instance of RoI2Det3D."""
+ self.nms = nms
+ self.max_per_img = max_per_img
+ self.class_agnostic_nms = class_agnostic_nms
+ self.score_threshold = score_threshold
+ self.iou_threshold = iou_threshold
+
+ self.box_coder = box_coder or Det3DCoder()
+
+ def __call__(
+ self,
+ cls_score: Tensor,
+ bbox_pred: Tensor,
+ token_positive_maps: dict[int, list[int]] | None,
+ img_shape: tuple[int, int],
+ ori_shape: tuple[int, int],
+ bbox_3d_pred: Tensor,
+ intrinsics: Tensor,
+ padding: list[int] | None,
+ ) -> tuple[Tensor, Tensor, Tensor]:
+ """Transform the bbox head output into bbox results."""
+ assert len(cls_score) == len(bbox_pred) # num_queries
+
+ det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred)
+ det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1]
+ det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0]
+ det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
+ det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])
+
+ if token_positive_maps is not None:
+ cls_score = convert_grounding_to_cls_scores(
+ logits=cls_score.sigmoid()[None],
+ positive_maps=[token_positive_maps],
+ )[0]
+
+ k = min(self.max_per_img, cls_score.view(-1).shape[0])
+ if k == 0:
+ device = cls_score.device
+ return (
+ torch.zeros(0, 4, device=device),
+ torch.zeros(0, device=device),
+ torch.zeros(0, dtype=torch.long, device=device),
+ torch.zeros(0, 10, device=device),
+ )
+ scores, indexes = cls_score.view(-1).topk(k)
+ num_classes = cls_score.shape[-1]
+ det_labels = indexes % num_classes
+ bbox_index = indexes // num_classes
+ det_bboxes = det_bboxes[bbox_index]
+ bbox_3d_pred = bbox_3d_pred[bbox_index]
+
+ # Remove low scoring boxes
+ if self.score_threshold > 0.0:
+ mask = scores > self.score_threshold
+ det_bboxes = det_bboxes[mask]
+ det_labels = det_labels[mask]
+ scores = scores[mask]
+ bbox_3d_pred = bbox_3d_pred[mask]
+
+ if self.nms:
+ if self.class_agnostic_nms:
+ keep = nms(det_bboxes, scores, self.iou_threshold)
+ else:
+ keep = batched_nms(
+ det_bboxes, scores, det_labels, self.iou_threshold
+ )
+
+ det_bboxes = det_bboxes[keep]
+ det_labels = det_labels[keep]
+ scores = scores[keep]
+ bbox_3d_pred = bbox_3d_pred[keep]
+ else:
+ cls_score = cls_score.sigmoid()
+ scores, _ = cls_score.max(-1)
+ scores, indexes = scores.topk(self.max_per_img)
+ det_bboxes = det_bboxes[indexes]
+ bbox_3d_pred = bbox_3d_pred[indexes]
+ det_labels = scores.new_zeros(scores.shape, dtype=torch.long)
+
+ if bbox_3d_pred.numel() == 0:
+ return (
+ det_bboxes,
+ scores,
+ det_labels,
+ bbox_3d_pred.new_empty((0, 10)),
+ )
+
+ det_bboxes3d = self.box_coder.decode(
+ det_bboxes, bbox_3d_pred, intrinsics
+ )
+
+ # Remove padding when input_hw is affected by padding
+ if padding is not None:
+ det_bboxes[:, 0] -= padding[0]
+ det_bboxes[:, 1] -= padding[2]
+ det_bboxes[:, 2] -= padding[0]
+ det_bboxes[:, 3] -= padding[2]
+
+ scales = [
+ (img_shape[1] - padding[0] - padding[1]) / ori_shape[1],
+ (img_shape[0] - padding[2] - padding[3]) / ori_shape[0],
+ ]
+
+ else:
+ scales = [img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]]
+
+ # Rescale to original shape
+ det_bboxes /= det_bboxes.new_tensor(scales).repeat((1, 2))
+
+ return det_bboxes, scores, det_labels, det_bboxes3d
diff --git a/wilddet3d/inference.py b/wilddet3d/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdda2f055a6a5d235d6d06c36e44b1be0d13f50f
--- /dev/null
+++ b/wilddet3d/inference.py
@@ -0,0 +1,606 @@
+"""WildDet3D inference wrapper.
+
+Provides a simple forward() interface for WildDet3D inference:
+
+Supports three prompt types with 5-mode text labels:
+- Text prompt: input_texts=["chair", "table"]
+- Box prompt: input_boxes=[[x1, y1, x2, y2]] (pixel xyxy)
+- Point prompt: input_points=[[(x, y, label), ...]] (pixel coords,
+ label: 1=pos, 0=neg)
+
+5-mode support via prompt_text parameter (for box/point prompts):
+- "visual" -> VISUAL mode (one-to-many, no category label)
+- "visual: car" -> VISUAL+LABEL mode (one-to-many, with category)
+- "geometric" -> GEOMETRY mode (one-to-one, no category label)
+- "geometric: car" -> GEOMETRY+LABEL mode (one-to-one, with category)
+- "object" -> default (backward compatible)
+
+Example usage:
+ from wilddet3d.inference import build_model
+ from wilddet3d.preprocessing import preprocess
+
+ # Build model
+ model = build_model(
+ checkpoint="path/to/checkpoint.ckpt"
+ )
+
+ # Preprocess data
+ data = preprocess(image, intrinsics)
+
+ # TEXT mode
+ boxes, boxes3d, scores, class_ids, depth_maps = model(
+ images=data["images"],
+ intrinsics=data["intrinsics"],
+ input_hw=[data["input_hw"]],
+ original_hw=[data["original_hw"]],
+ padding=[data["padding"]],
+ input_texts=["chair", "table"],
+ )
+
+ # VISUAL mode (box prompt, one-to-many)
+ boxes, boxes3d, scores, class_ids, depth_maps = model(
+ ...,
+ input_boxes=[[100, 200, 300, 400]],
+ prompt_text="visual",
+ )
+
+ # GEOMETRY mode (box prompt, one-to-one)
+ boxes, boxes3d, scores, class_ids, depth_maps = model(
+ ...,
+ input_boxes=[[100, 200, 300, 400]],
+ prompt_text="geometric",
+ )
+
+ # Point prompt (works with any prompt_text)
+ boxes, boxes3d, scores, class_ids, depth_maps = model(
+ ...,
+ input_points=[[(150, 250, 1), (200, 300, 0)]],
+ prompt_text="geometric",
+ )
+"""
+
+from typing import List, Optional, Tuple
+
+import torch
+from torch import Tensor, nn
+
+from wilddet3d.data_types import WildDet3DInput
+from wilddet3d.depth import LingbotDepthBackend
+from wilddet3d.depth.depth_fusion import EarlyDepthFusionLingbot
+from wilddet3d.head import Det3DCoder, RoI2Det3D
+from wilddet3d.model import WildDet3D
+
+
+class WildDet3DPredictor(nn.Module):
+ """WildDet3D wrapper with a simple forward() interface.
+
+ Provides a simple forward() interface:
+ boxes, boxes3d, scores, class_ids, depth_maps = model(
+ images=...,
+ intrinsics=...,
+ input_texts=["chair", "table"],
+ )
+ """
+
+ def __init__(
+ self,
+ wilddet3d: WildDet3D,
+ score_threshold: float = 0.3,
+ ):
+ super().__init__()
+ self.wilddet3d = wilddet3d
+ self.score_threshold = score_threshold
+
+ def forward(
+ self,
+ images: Tensor,
+ intrinsics: Optional[Tensor],
+ input_hw: List[Tuple[int, int]],
+ original_hw: List[Tuple[int, int]],
+ padding: List[Tuple[int, int, int, int]],
+ # Prompt types (mutually exclusive)
+ input_texts: Optional[List[str]] = None,
+ input_boxes: Optional[List[List[float]]] = None,
+ input_points: Optional[
+ List[List[Tuple[float, float, int]]]
+ ] = None,
+ # Text label for box/point prompts (5-mode support)
+ # e.g. "visual", "visual: car", "geometric", "geometric: car"
+ prompt_text: str = "object",
+ return_predicted_intrinsics: bool = False,
+ # Optional depth input (e.g., from LiDAR)
+ depth_gt: Optional[Tensor] = None, # (B, 1, H, W) meters
+ ) -> Tuple[
+ List[Tensor],
+ List[Tensor],
+ List[Tensor],
+ List[Tensor],
+ Optional[List[Tensor]],
+ Optional[Tensor],
+ ]:
+ """Forward with simple interface.
+
+ Args:
+ images: (B, 3, H, W) preprocessed images
+ intrinsics: (B, 3, 3) camera intrinsics, or None to use
+ predicted
+ input_hw: List of (H, W) tuples for each image
+ original_hw: List of original (H, W) tuples
+ padding: List of (left, right, top, bottom) padding tuples
+ input_texts: Text prompts (e.g., ["chair", "table"])
+ input_boxes: Box prompts per image, pixel xyxy
+ [[x1,y1,x2,y2], ...]
+ input_points: Point prompts per image
+ [[(x,y,label), ...], ...]
+ prompt_text: Text label for box/point prompts. Controls
+ 5-mode: "object" (default), "visual", "visual: car",
+ "geometric", "geometric: car"
+ return_predicted_intrinsics: Whether to return predicted
+ intrinsics
+ depth_gt: Optional depth input (B, 1, H, W) in meters
+
+ Returns:
+ boxes: List of 2D boxes per image (pixel xyxy)
+ boxes3d: List of 3D boxes per image
+ scores: List of confidence scores per image
+ class_ids: List of class IDs per image
+ depth_maps: List of depth maps per image (or None)
+ predicted_intrinsics: (B, 3, 3) predicted intrinsics
+ (if requested)
+ """
+ device = images.device
+ B = images.shape[0]
+ H, W = input_hw[0]
+
+ # Determine prompt type and create batch
+ if input_texts is not None:
+ batch = self._create_text_batch(
+ images,
+ intrinsics,
+ input_texts,
+ device,
+ padding=padding,
+ )
+ class_names = input_texts
+ elif input_boxes is not None:
+ batch = self._create_box_batch(
+ images,
+ intrinsics,
+ input_boxes,
+ (H, W),
+ device,
+ text=prompt_text,
+ padding=padding,
+ )
+ class_names = [prompt_text]
+ elif input_points is not None:
+ batch = self._create_point_batch(
+ images,
+ intrinsics,
+ input_points,
+ (H, W),
+ device,
+ text=prompt_text,
+ padding=padding,
+ )
+ class_names = [prompt_text]
+ else:
+ raise ValueError(
+ "Must provide one of: input_texts, input_boxes, "
+ "input_points"
+ )
+
+ # Attach depth input if provided
+ if depth_gt is not None:
+ batch.depth_gt = depth_gt
+
+ # Run inference
+ with torch.no_grad():
+ output = self.wilddet3d(batch)
+
+ # Output is Det3DOut with per-image lists
+ boxes = output.boxes
+ boxes3d = output.boxes3d
+ scores = output.scores
+ scores_2d = output.scores_2d
+ scores_3d = output.scores_3d
+ class_ids = output.class_ids
+ depth_maps = output.depth_maps
+
+ # Apply score threshold and rescale boxes to original size
+ boxes_out = []
+ boxes3d_out = []
+ scores_out = []
+ scores_2d_out = []
+ scores_3d_out = []
+ class_ids_out = []
+
+ for i in range(B):
+ # Filter by 2D score
+ mask = scores[i] >= self.score_threshold
+ img_scores = scores[i][mask]
+ img_scores_2d = (
+ scores_2d[i][mask]
+ if scores_2d is not None
+ else torch.zeros_like(img_scores)
+ )
+ img_scores_3d = (
+ scores_3d[i][mask]
+ if scores_3d is not None
+ else torch.zeros_like(img_scores)
+ )
+ img_boxes = boxes[i][mask]
+ img_boxes3d = boxes3d[i][mask]
+ img_class_ids = class_ids[i][mask]
+
+ # Rescale 2D boxes from input_hw to original_hw
+ # Account for padding
+ pad_left, pad_right, pad_top, pad_bottom = padding[i]
+ orig_h, orig_w = original_hw[i]
+
+ # Remove padding offset and rescale
+ img_boxes = img_boxes.clone()
+ img_boxes[:, 0] -= pad_left # x1
+ img_boxes[:, 2] -= pad_left # x2
+ img_boxes[:, 1] -= pad_top # y1
+ img_boxes[:, 3] -= pad_top # y2
+
+ # Scale from padded size to original
+ padded_h = H - pad_top - pad_bottom
+ padded_w = W - pad_left - pad_right
+ scale_x = orig_w / padded_w
+ scale_y = orig_h / padded_h
+
+ img_boxes[:, 0::2] *= scale_x
+ img_boxes[:, 1::2] *= scale_y
+
+ # Clamp to image bounds
+ img_boxes[:, 0::2] = img_boxes[:, 0::2].clamp(0, orig_w)
+ img_boxes[:, 1::2] = img_boxes[:, 1::2].clamp(0, orig_h)
+
+ boxes_out.append(img_boxes)
+ boxes3d_out.append(img_boxes3d)
+ scores_out.append(img_scores)
+ scores_2d_out.append(img_scores_2d)
+ scores_3d_out.append(img_scores_3d)
+ class_ids_out.append(img_class_ids)
+
+ # Get predicted intrinsics if available
+ predicted_K = output.predicted_intrinsics
+
+ if return_predicted_intrinsics:
+ return (
+ boxes_out,
+ boxes3d_out,
+ scores_out,
+ scores_2d_out,
+ scores_3d_out,
+ class_ids_out,
+ depth_maps,
+ predicted_K,
+ )
+ else:
+ return (
+ boxes_out,
+ boxes3d_out,
+ scores_out,
+ scores_2d_out,
+ scores_3d_out,
+ class_ids_out,
+ depth_maps,
+ )
+
+ def _create_text_batch(
+ self,
+ images: Tensor,
+ intrinsics: Tensor,
+ texts: List[str],
+ device: torch.device,
+ padding: Optional[List[Tuple[int, int, int, int]]] = None,
+ ) -> WildDet3DInput:
+ """Create batch for text prompts."""
+ n_prompts = len(texts)
+
+ return WildDet3DInput(
+ images=images,
+ intrinsics=intrinsics,
+ img_ids=torch.zeros(
+ n_prompts, dtype=torch.long, device=device
+ ),
+ text_ids=torch.arange(
+ n_prompts, dtype=torch.long, device=device
+ ),
+ unique_texts=texts,
+ padding=padding,
+ )
+
+ def _create_box_batch(
+ self,
+ images: Tensor,
+ intrinsics: Tensor,
+ boxes_xyxy: List[List[float]],
+ input_hw: Tuple[int, int],
+ device: torch.device,
+ text: str = "object",
+ padding: Optional[List[Tuple[int, int, int, int]]] = None,
+ ) -> WildDet3DInput:
+ """Create batch for box prompts.
+
+ Args:
+ text: Text label for the prompt. Controls 5-mode behavior:
+ "visual" / "visual: car" for one-to-many matching,
+ "geometric" / "geometric: car" for one-to-one matching.
+ """
+ H, W = input_hw
+ n_prompts = len(boxes_xyxy)
+
+ # Convert pixel xyxy to normalized cxcywh
+ boxes_cxcywh = []
+ for box in boxes_xyxy:
+ x1, y1, x2, y2 = box
+ cx = (x1 + x2) / 2 / W
+ cy = (y1 + y2) / 2 / H
+ w = (x2 - x1) / W
+ h = (y2 - y1) / H
+ boxes_cxcywh.append([cx, cy, w, h])
+
+ geo_boxes = torch.tensor(
+ boxes_cxcywh, dtype=torch.float32, device=device
+ )
+ geo_boxes = geo_boxes.unsqueeze(1) # (n_prompts, 1, 4)
+
+ return WildDet3DInput(
+ images=images,
+ intrinsics=intrinsics,
+ img_ids=torch.zeros(
+ n_prompts, dtype=torch.long, device=device
+ ),
+ text_ids=torch.zeros(
+ n_prompts, dtype=torch.long, device=device
+ ),
+ unique_texts=[text],
+ geo_boxes=geo_boxes,
+ geo_boxes_mask=torch.zeros(
+ n_prompts, 1, dtype=torch.bool, device=device
+ ),
+ geo_box_labels=torch.ones(
+ n_prompts, 1, dtype=torch.long, device=device
+ ),
+ padding=padding,
+ )
+
+ def _create_point_batch(
+ self,
+ images: Tensor,
+ intrinsics: Tensor,
+ points_list: List[List[Tuple[float, float, int]]],
+ input_hw: Tuple[int, int],
+ device: torch.device,
+ text: str = "object",
+ padding: Optional[List[Tuple[int, int, int, int]]] = None,
+ ) -> WildDet3DInput:
+ """Create batch for point prompts.
+
+ Args:
+ text: Text label for the prompt. Controls 5-mode behavior:
+ "visual" / "visual: car" for one-to-many matching,
+ "geometric" / "geometric: car" for one-to-one matching.
+ """
+ H, W = input_hw
+ n_prompts = len(points_list)
+
+ # Find max points per prompt for padding
+ max_points = max(len(pts) for pts in points_list)
+
+ # Normalize and pad points
+ geo_points = torch.zeros(
+ n_prompts, max_points, 2, device=device
+ )
+ geo_point_labels = torch.zeros(
+ n_prompts, max_points, dtype=torch.long, device=device
+ )
+ geo_points_mask = torch.ones(
+ n_prompts, max_points, dtype=torch.bool, device=device
+ )
+
+ for i, pts in enumerate(points_list):
+ for j, (x, y, label) in enumerate(pts):
+ geo_points[i, j] = torch.tensor([x / W, y / H])
+ geo_point_labels[i, j] = label
+ geo_points_mask[i, j] = False # False = valid
+
+ return WildDet3DInput(
+ images=images,
+ intrinsics=intrinsics,
+ img_ids=torch.zeros(
+ n_prompts, dtype=torch.long, device=device
+ ),
+ text_ids=torch.zeros(
+ n_prompts, dtype=torch.long, device=device
+ ),
+ unique_texts=[text],
+ geo_points=geo_points,
+ geo_points_mask=geo_points_mask,
+ geo_point_labels=geo_point_labels,
+ padding=padding,
+ )
+
+
+def build_model(
+ checkpoint: str,
+ sam3_checkpoint: str = "pretrained/sam3/sam3_detector.pt",
+ score_threshold: float = 0.3,
+ nms: bool = True,
+ iou_threshold: float = 0.6,
+ device: str = "cuda",
+ backbone_freeze_blocks: int = 28,
+ lingbot_encoder_freeze_blocks: int = 21,
+ ambiguous_rotation: bool = False,
+ canonical_rotation: bool = False,
+ use_depth_input_test: bool = False,
+ use_predicted_intrinsics: bool = False,
+ skip_pretrained: bool = False,
+) -> WildDet3DPredictor:
+ """Build WildDet3D model with LingBot-Depth backend.
+
+ Args:
+ checkpoint: Path to trained WildDet3D checkpoint (.ckpt file)
+ sam3_checkpoint: Path to SAM3 pretrained weights
+ score_threshold: Confidence threshold for filtering
+ nms: Whether to apply NMS
+ iou_threshold: IoU threshold for NMS
+ device: Device to load model on
+ backbone_freeze_blocks: Number of SAM3 ViT blocks to freeze.
+ lingbot_encoder_freeze_blocks: Number of LingBot encoder blocks
+ to freeze.
+ use_predicted_intrinsics: If True, use geometry backend's
+ predicted intrinsics (K_pred) for 3D box decoding instead of
+ the input intrinsics. Useful for in-the-wild images without
+ GT intrinsics.
+ skip_pretrained: If True, skip loading pretrained weights for
+ SAM3 and LingBot-Depth. Use this for inference when the
+ training checkpoint already contains all weights (avoids
+ loading ~4GB of pretrained weights that get immediately
+ overwritten).
+
+ Returns:
+ WildDet3DPredictor model ready for inference
+ """
+ print("Building WildDet3D model with LingBot-Depth backend...")
+
+ # When skip_pretrained=True, patch MDMModel.from_pretrained to build
+ # model structure from config without loading weights (~1GB saved).
+ _mdm_patch_cleanup = None
+ if skip_pretrained:
+ from mdm.model.v2 import MDMModel
+
+ _orig_from_pretrained = MDMModel.from_pretrained
+
+ @classmethod
+ def _from_pretrained_config_only(cls, path, **kwargs):
+ from pathlib import Path as P
+
+ from huggingface_hub import hf_hub_download
+
+ if P(path).exists():
+ cp = path
+ else:
+ cp = hf_hub_download(
+ repo_id=path,
+ repo_type="model",
+ filename="model.pt",
+ **kwargs,
+ )
+ ckpt = torch.load(
+ cp, map_location="cpu", weights_only=True
+ )
+ model = cls(**ckpt["model_config"])
+ print(
+ f"[LingbotDepth] Built model structure from config "
+ f"(skipped pretrained weights)"
+ )
+ return model
+
+ MDMModel.from_pretrained = _from_pretrained_config_only
+ _mdm_patch_cleanup = lambda: setattr(
+ MDMModel, "from_pretrained", _orig_from_pretrained
+ )
+
+ # Build geometry backend (LingBot-Depth)
+ geometry_backend = LingbotDepthBackend(
+ pretrained_model="robbyant/lingbot-depth-postrain-dc-vitl14",
+ num_tokens=2400,
+ target_latent_dim=256,
+ depth_loss_weight=1.0,
+ silog_loss_weight=0.5,
+ monocular_prob=0.7,
+ masked_prob=0.2,
+ mask_ratio_range=(0.6, 0.9),
+ mask_patch_size=14,
+ camera_loss_weight=1.0,
+ detach_depth_latents=True,
+ encoder_freeze_blocks=lingbot_encoder_freeze_blocks,
+ )
+
+ # Restore original from_pretrained
+ if _mdm_patch_cleanup is not None:
+ _mdm_patch_cleanup()
+
+ # Build components
+ box_coder = Det3DCoder(
+ ambiguous_rotation=ambiguous_rotation,
+ canonical_rotation=canonical_rotation,
+ )
+ roi2det3d = RoI2Det3D(
+ box_coder=box_coder,
+ score_threshold=0.0, # Threshold in wrapper
+ nms=nms,
+ iou_threshold=iou_threshold,
+ )
+
+ # ControlNet-style fusion for LingBot-Depth
+ early_depth_fusion = EarlyDepthFusionLingbot(
+ visual_dim=256,
+ depth_dim=256,
+ zero_init=True,
+ )
+
+ # Build WildDet3D
+ # When skip_pretrained=True, build SAM3 model structure without
+ # loading pretrained weights (~3.2GB) since the training checkpoint
+ # already contains all weights.
+ if skip_pretrained:
+ from sam3.model_builder import build_sam3_image_model
+
+ print(
+ "[skip_pretrained] Building SAM3 structure without "
+ "pretrained weights..."
+ )
+ sam3_model = build_sam3_image_model(
+ checkpoint_path=None,
+ load_from_HF=False,
+ device="cpu",
+ eval_mode=False,
+ enable_segmentation=False,
+ )
+ else:
+ sam3_model = None
+
+ wilddet3d = WildDet3D(
+ sam3_model=sam3_model if skip_pretrained else None,
+ sam3_checkpoint=None if skip_pretrained else sam3_checkpoint,
+ box_coder=box_coder,
+ geometry_backend=geometry_backend,
+ roi2det3d=roi2det3d,
+ early_depth_fusion=early_depth_fusion,
+ backbone_freeze_blocks=backbone_freeze_blocks,
+ use_depth_input_test=use_depth_input_test,
+ use_predicted_intrinsics=use_predicted_intrinsics,
+ )
+
+ # Load trained checkpoint
+ print(f"Loading checkpoint: {checkpoint}")
+ ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
+ state_dict = ckpt.get("state_dict", ckpt)
+
+ # Remove "model." prefix
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ new_key = (
+ k.replace("model.", "") if k.startswith("model.") else k
+ )
+ new_state_dict[new_key] = v
+
+ wilddet3d.load_state_dict(new_state_dict, strict=False)
+ wilddet3d = wilddet3d.to(device)
+ wilddet3d.eval()
+
+ # Wrap with predictor interface
+ model = WildDet3DPredictor(
+ wilddet3d, score_threshold=score_threshold
+ )
+ model = model.to(device)
+ model.eval()
+
+ print("Model ready!")
+ return model
diff --git a/wilddet3d/loss/__init__.py b/wilddet3d/loss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/wilddet3d/loss/det2d_loss.py b/wilddet3d/loss/det2d_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3c2a09109ae00395909e5474b004e5d26c74e92
--- /dev/null
+++ b/wilddet3d/loss/det2d_loss.py
@@ -0,0 +1,964 @@
+"""G-DINO Loss."""
+
+import torch
+from torch import Tensor, nn
+from vis4d.common.distributed import reduce_mean
+from vis4d.op.loss.common import l1_loss
+from vis4d.op.loss.reducer import SumWeightedLoss
+
+from wilddet3d.ops.box2d import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
+from wilddet3d.ops.matchers.hungarian import HungarianMatcher
+from wilddet3d.loss.focal_loss import FocalLoss
+from wilddet3d.loss.iou_loss import GIoULoss
+from wilddet3d.ops.match_cost import (
+ BBoxL1Cost,
+ BinaryFocalLossCost,
+ IoUCost,
+)
+from wilddet3d.ops.util import multi_apply
+
+
+class Det2DLoss(nn.Module):
+ """Grounding DINO loss module."""
+
+ def __init__(
+ self, max_text_len: int = 256, sync_cls_avg_factor: bool = True
+ ):
+ super().__init__()
+ self.sync_cls_avg_factor = sync_cls_avg_factor
+ self.max_text_len = max_text_len
+
+ # Matcher
+ self.cls_cost = BinaryFocalLossCost(weight=2.0)
+ self.reg_cost = BBoxL1Cost(weight=5.0, box_format="xywh")
+ self.iou_cost = IoUCost(weight=2.0, iou_mode="giou")
+
+ self.assigner = HungarianMatcher()
+
+ # Losses
+ self.loss_cls = FocalLoss(alpha=0.25, gamma=2.0)
+ self.bg_cls_weight = 0.0
+ self.cls_loss_weight = 1.0
+
+ self.loss_bbox = l1_loss
+ self.bbox_loss_weight = 5.0
+
+ self.loss_iou = GIoULoss()
+ self.iou_loss_weight = 2.0
+
+ def get_targets(
+ self,
+ cls_scores_list: list[Tensor],
+ bbox_preds_list: list[Tensor],
+ input_hw: list[tuple[int, int]],
+ batch_gt_boxes: list[Tensor],
+ batch_gt_boxes_classes: list[Tensor],
+ positive_maps: list[Tensor],
+ text_token_mask: Tensor,
+ ) -> tuple:
+ """Compute regression and classification targets for a batch image.
+
+ Outputs from a single decoder layer of a single feature level are used.
+
+ Args:
+ cls_scores_list (list[Tensor]): Box score logits from a single
+ decoder layer for each image, has shape [num_queries,
+ cls_out_channels].
+ bbox_preds_list (list[Tensor]): Sigmoid outputs from a single
+ decoder layer for each image, with normalized coordinate
+ (cx, cy, w, h) and shape [num_queries, 4].
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+ gt_instance. It usually includes ``bboxes`` and ``labels``
+ attributes.
+ batch_img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+
+ Returns:
+ tuple: a tuple containing the following targets.
+
+ - labels_list (list[Tensor]): Labels for all images.
+ - label_weights_list (list[Tensor]): Label weights for all images.
+ - bbox_targets_list (list[Tensor]): BBox targets for all images.
+ - bbox_weights_list (list[Tensor]): BBox weights for all images.
+ - num_total_pos (int): Number of positive samples in all images.
+ - num_total_neg (int): Number of negative samples in all images.
+ """
+ (
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ pos_inds_list,
+ neg_inds_list,
+ ) = multi_apply(
+ self._get_targets_single,
+ cls_scores_list,
+ bbox_preds_list,
+ input_hw,
+ batch_gt_boxes,
+ batch_gt_boxes_classes,
+ positive_maps,
+ text_token_mask,
+ )
+
+ num_total_pos = sum((inds.numel() for inds in pos_inds_list))
+ num_total_neg = sum((inds.numel() for inds in neg_inds_list))
+
+ return (
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ num_total_pos,
+ num_total_neg,
+ )
+
+ def _get_cost(
+ self,
+ cls_score,
+ bbox_pred,
+ gt_boxes,
+ input_hw,
+ text_token_mask,
+ positive_map,
+ ):
+ """Compute regression and classification cost for one image."""
+ if self.cls_cost.weight != 0:
+ cls_cost = self.cls_cost(cls_score, text_token_mask, positive_map)
+ else:
+ cls_cost = 0
+
+ if self.reg_cost.weight != 0:
+ reg_cost = self.reg_cost(
+ bbox_pred, gt_boxes, input_hw[0], input_hw[1]
+ )
+ else:
+ reg_cost = 0
+
+ if self.iou_cost.weight != 0:
+ iou_cost = self.iou_cost(bbox_pred, gt_boxes)
+ else:
+ iou_cost = 0
+
+ return cls_cost + reg_cost + iou_cost
+
+ def _get_targets_2d_single(
+ self,
+ cls_score: Tensor,
+ bbox_pred: Tensor,
+ input_hw: tuple[int, int],
+ gt_boxes: Tensor,
+ gt_classes: Tensor,
+ positive_map: Tensor,
+ text_token_mask: Tensor,
+ ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+ """Compute regression and classification targets for one image."""
+ img_h, img_w = input_hw
+ num_bboxes = bbox_pred.size(0)
+ factor = bbox_pred.new_tensor([img_w, img_h, img_w, img_h]).unsqueeze(
+ 0
+ )
+
+ # convert bbox_pred from xywh, normalized to xyxy, unnormalized
+ bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred)
+ bbox_pred = bbox_pred * factor
+
+ # assigner and sampler
+ cost = self._get_cost(
+ cls_score,
+ bbox_pred,
+ gt_boxes,
+ input_hw,
+ text_token_mask,
+ positive_map,
+ )
+
+ assign_result = self.assigner(cost, bbox_pred, gt_boxes, gt_classes)
+
+ pos_inds = (
+ torch.nonzero(
+ assign_result.assigned_gt_indices > 0, as_tuple=False
+ )
+ .squeeze(-1)
+ .unique()
+ )
+ neg_inds = (
+ torch.nonzero(
+ assign_result.assigned_gt_indices == 0, as_tuple=False
+ )
+ .squeeze(-1)
+ .unique()
+ )
+ pos_assigned_gt_inds = assign_result.assigned_gt_indices[pos_inds] - 1
+ pos_gt_bboxes = gt_boxes[pos_assigned_gt_inds.long(), :]
+
+ # Major changes. The labels are 0-1 binary labels for each bbox
+ # and text tokens.
+ labels = gt_boxes.new_full(
+ (num_bboxes, self.max_text_len), 0, dtype=torch.float32
+ )
+ labels[pos_inds] = positive_map[pos_assigned_gt_inds]
+ label_weights = gt_boxes.new_ones(num_bboxes)
+
+ # bbox targets
+ bbox_targets = torch.zeros_like(bbox_pred, dtype=gt_boxes.dtype)
+ bbox_weights = torch.zeros_like(bbox_pred, dtype=gt_boxes.dtype)
+ bbox_weights[pos_inds] = 1.0
+
+ # DETR regress the relative position of boxes (cxcywh) in the image.
+ # Thus the learning target should be normalized by the image size, also
+ # the box format should be converted from defaultly x1y1x2y2 to cxcywh.
+ pos_gt_bboxes_normalized = pos_gt_bboxes / factor
+ pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized)
+ bbox_targets[pos_inds] = pos_gt_bboxes_targets
+
+ return (
+ labels,
+ label_weights,
+ bbox_targets,
+ bbox_weights,
+ pos_gt_bboxes,
+ pos_inds,
+ neg_inds,
+ pos_assigned_gt_inds,
+ )
+
+ def _get_targets_single(
+ self,
+ cls_score: Tensor,
+ bbox_pred: Tensor,
+ input_hw: tuple[int, int],
+ gt_boxes: Tensor,
+ gt_classes: Tensor,
+ positive_map: Tensor,
+ text_token_mask: Tensor,
+ ) -> tuple:
+ """Compute regression and classification targets for one image.
+
+ Outputs from a single decoder layer of a single feature level are used.
+
+ Args:
+ cls_score (Tensor): Box score logits from a single decoder layer
+ for one image. Shape [num_queries, cls_out_channels].
+ bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
+ for one image, with normalized coordinate (cx, cy, w, h) and
+ shape [num_queries, 4].
+ gt_instances (:obj:`InstanceData`): Ground truth of instance
+ annotations. It should includes ``bboxes`` and ``labels``
+ attributes.
+ img_meta (dict): Meta information for one image.
+
+ Returns:
+ tuple[Tensor]: a tuple containing the following for one image.
+
+ - labels (Tensor): Labels of each image.
+ - label_weights (Tensor]): Label weights of each image.
+ - bbox_targets (Tensor): BBox targets of each image.
+ - bbox_weights (Tensor): BBox weights of each image.
+ - pos_inds (Tensor): Sampled positive indices for each image.
+ - neg_inds (Tensor): Sampled negative indices for each image.
+ """
+ (
+ labels,
+ label_weights,
+ bbox_targets,
+ bbox_weights,
+ _,
+ pos_inds,
+ neg_inds,
+ _,
+ ) = self._get_targets_2d_single(
+ cls_score,
+ bbox_pred,
+ input_hw,
+ gt_boxes,
+ gt_classes,
+ positive_map,
+ text_token_mask,
+ )
+
+ return (
+ labels,
+ label_weights,
+ bbox_targets,
+ bbox_weights,
+ pos_inds,
+ neg_inds,
+ )
+
+ def loss_by_feat_single(
+ self,
+ cls_scores: Tensor,
+ bbox_preds: Tensor,
+ text_token_mask: Tensor,
+ input_hw: list[tuple[int, int]],
+ batch_gt_boxes: list[Tensor],
+ batch_gt_boxes_classes: list[Tensor],
+ positive_maps: list[Tensor],
+ ) -> tuple[Tensor]:
+ """Loss function for outputs from a single decoder layer of a single
+ feature level.
+
+ Args:
+ cls_scores (Tensor): Box score logits from a single decoder layer
+ for all images, has shape (bs, num_queries, cls_out_channels).
+ bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
+ for all images, with normalized coordinate (cx, cy, w, h) and
+ shape (bs, num_queries, 4).
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+ gt_instance. It usually includes ``bboxes`` and ``labels``
+ attributes.
+ batch_img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+
+ Returns:
+ Tuple[Tensor]: A tuple including `loss_cls`, `loss_box` and
+ `loss_iou`.
+ """
+ num_imgs = cls_scores.size(0)
+
+ cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
+ bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
+
+ with torch.no_grad():
+ (
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ num_total_pos,
+ num_total_neg,
+ ) = self.get_targets(
+ cls_scores_list,
+ bbox_preds_list,
+ input_hw,
+ batch_gt_boxes,
+ batch_gt_boxes_classes,
+ positive_maps,
+ text_token_mask,
+ )
+
+ labels = torch.stack(labels_list, 0)
+ label_weights = torch.stack(label_weights_list, 0)
+ bbox_targets = torch.cat(bbox_targets_list, 0)
+ bbox_weights = torch.cat(bbox_weights_list, 0)
+
+ # Loss is not computed for the padded regions of the text.
+ assert text_token_mask.dim() == 2
+ text_masks = text_token_mask.new_zeros(
+ (text_token_mask.size(0), self.max_text_len)
+ )
+ text_masks[:, : text_token_mask.size(1)] = text_token_mask
+ text_mask = (text_masks > 0).unsqueeze(1)
+ text_mask = text_mask.repeat(1, cls_scores.size(1), 1)
+ cls_scores = torch.masked_select(cls_scores, text_mask).contiguous()
+
+ labels = torch.masked_select(labels, text_mask)
+ label_weights = label_weights[..., None].repeat(
+ 1, 1, text_mask.size(-1)
+ )
+ label_weights = torch.masked_select(label_weights, text_mask)
+
+ # classification loss
+ # construct weighted avg_factor to match with the official DETR repo
+ cls_avg_factor = (
+ num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight
+ )
+ if self.sync_cls_avg_factor:
+ cls_avg_factor = reduce_mean(
+ cls_scores.new_tensor([cls_avg_factor])
+ )
+ cls_avg_factor = max(cls_avg_factor, 1)
+
+ loss_cls = self.cls_loss_weight * self.loss_cls(
+ cls_scores,
+ labels,
+ reducer=SumWeightedLoss(
+ weight=label_weights, avg_factor=cls_avg_factor
+ ),
+ )
+
+ # Compute the average number of gt boxes across all gpus, for
+ # normalization purposes
+ num_total_pos = loss_cls.new_tensor([num_total_pos])
+ num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
+
+ # construct factors used for rescale bboxes
+ factors = []
+ for img_hw, bbox_pred in zip(input_hw, bbox_preds):
+ img_h, img_w = img_hw
+ factor = (
+ bbox_pred.new_tensor([img_w, img_h, img_w, img_h])
+ .unsqueeze(0)
+ .repeat(bbox_pred.size(0), 1)
+ )
+ factors.append(factor)
+ factors = torch.cat(factors, 0)
+
+ # DETR regress the relative position of boxes (cxcywh) in the image,
+ # thus the learning target is normalized by the image size. So here
+ # we need to re-scale them for calculating IoU loss
+ bbox_preds = bbox_preds.reshape(-1, 4)
+ bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
+ bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors
+
+ # regression L1 loss
+ loss_bbox = self.bbox_loss_weight * self.loss_bbox(
+ bbox_preds,
+ bbox_targets,
+ reducer=SumWeightedLoss(
+ weight=bbox_weights, avg_factor=num_total_pos
+ ),
+ )
+
+ # regression IoU loss, defaultly GIoU loss
+ loss_iou = self.iou_loss_weight * self.loss_iou(
+ bboxes,
+ bboxes_gt,
+ reducer=SumWeightedLoss(
+ weight=bbox_weights.mean(-1), avg_factor=num_total_pos
+ ),
+ )
+
+ return loss_cls, loss_bbox, loss_iou
+
+ def forward(
+ self,
+ all_layers_cls_scores: Tensor,
+ all_layers_bbox_preds: Tensor,
+ text_token_mask: Tensor,
+ enc_cls_scores: Tensor,
+ enc_bbox_preds: Tensor,
+ dn_meta: dict[str, int],
+ positive_maps: list[Tensor],
+ input_hw: list[tuple[int, int]],
+ batch_gt_boxes: list[Tensor],
+ batch_gt_boxes_classes: list[Tensor],
+ ) -> dict[str, Tensor]:
+ """Loss function.
+
+ Args:
+ all_layers_cls_scores (Tensor): Classification scores of all
+ decoder layers, has shape (num_decoder_layers, bs,
+ num_queries_total, cls_out_channels), where
+ `num_queries_total` is the sum of `num_denoising_queries`
+ and `num_matching_queries`.
+ all_layers_bbox_preds (Tensor): Regression outputs of all decoder
+ layers. Each is a 4D-tensor with normalized coordinate format
+ (cx, cy, w, h) and has shape (num_decoder_layers, bs,
+ num_queries_total, 4).
+ enc_cls_scores (Tensor): The score of each point on encode
+ feature map, has shape (bs, num_feat_points, cls_out_channels).
+ enc_bbox_preds (Tensor): The proposal generate from the encode
+ feature map, has shape (bs, num_feat_points, 4) with the last
+ dimension arranged as (cx, cy, w, h).
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+ gt_instance. It usually includes ``bboxes`` and ``labels``
+ attributes.
+ batch_img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ dn_meta (Dict[str, int]): The dictionary saves information about
+ group collation, including 'num_denoising_queries' and
+ 'num_denoising_groups'. It will be used for split outputs of
+ denoising and matching parts and loss calculation.
+ batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
+ Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+ data that is ignored during training and testing.
+ Defaults to None.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ # extract denoising and matching part of outputs
+ (
+ all_layers_matching_cls_scores,
+ all_layers_matching_bbox_preds,
+ all_layers_denoising_cls_scores,
+ all_layers_denoising_bbox_preds,
+ ) = split_outputs(
+ all_layers_cls_scores, all_layers_bbox_preds, dn_meta
+ )
+
+ # DETRHead loss_by_feat
+ losses_cls, losses_bbox, losses_iou = multi_apply(
+ self.loss_by_feat_single,
+ all_layers_matching_cls_scores,
+ all_layers_matching_bbox_preds,
+ text_token_mask=text_token_mask,
+ input_hw=input_hw,
+ batch_gt_boxes=batch_gt_boxes,
+ batch_gt_boxes_classes=batch_gt_boxes_classes,
+ positive_maps=positive_maps,
+ )
+
+ loss_dict = dict()
+
+ # loss from the last decoder layer
+ loss_dict["loss_cls"] = losses_cls[-1]
+ loss_dict["loss_bbox"] = losses_bbox[-1]
+ loss_dict["loss_iou"] = losses_iou[-1]
+
+ # loss from other decoder layers
+ for num_dec_layer, (loss_cls_i, loss_bbox_i, loss_iou_i) in enumerate(
+ zip(losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1])
+ ):
+ loss_dict[f"d{num_dec_layer}.loss_cls"] = loss_cls_i
+ loss_dict[f"d{num_dec_layer}.loss_bbox"] = loss_bbox_i
+ loss_dict[f"d{num_dec_layer}.loss_iou"] = loss_iou_i
+
+ # loss of proposal generated from encode feature map.
+ if enc_cls_scores is not None:
+ # NOTE The enc_loss calculation of the DINO is
+ # different from that of Deformable DETR.
+ enc_loss_cls, enc_losses_bbox, enc_losses_iou = (
+ self.loss_by_feat_single(
+ enc_cls_scores,
+ enc_bbox_preds,
+ text_token_mask=text_token_mask,
+ input_hw=input_hw,
+ batch_gt_boxes=batch_gt_boxes,
+ batch_gt_boxes_classes=batch_gt_boxes_classes,
+ positive_maps=positive_maps,
+ )
+ )
+ loss_dict["enc_loss_cls"] = enc_loss_cls
+ loss_dict["enc_loss_bbox"] = enc_losses_bbox
+ loss_dict["enc_loss_iou"] = enc_losses_iou
+
+ if all_layers_denoising_cls_scores is not None:
+ # calculate denoising loss from all decoder layers
+ dn_losses_cls, dn_losses_bbox, dn_losses_iou = self.loss_dn(
+ all_layers_denoising_cls_scores,
+ all_layers_denoising_bbox_preds,
+ boxes2d=batch_gt_boxes,
+ boxes2d_classes=batch_gt_boxes_classes,
+ positive_maps=positive_maps,
+ input_hw=input_hw,
+ text_token_mask=text_token_mask,
+ dn_meta=dn_meta,
+ )
+
+ # collate denoising loss
+ loss_dict["dn_loss_cls"] = dn_losses_cls[-1]
+ loss_dict["dn_loss_bbox"] = dn_losses_bbox[-1]
+ loss_dict["dn_loss_iou"] = dn_losses_iou[-1]
+
+ for num_dec_layer, (
+ loss_cls_i,
+ loss_bbox_i,
+ loss_iou_i,
+ ) in enumerate(
+ zip(
+ dn_losses_cls[:-1], dn_losses_bbox[:-1], dn_losses_iou[:-1]
+ )
+ ):
+ loss_dict[f"d{num_dec_layer}.dn_loss_cls"] = loss_cls_i
+ loss_dict[f"d{num_dec_layer}.dn_loss_bbox"] = loss_bbox_i
+ loss_dict[f"d{num_dec_layer}.dn_loss_iou"] = loss_iou_i
+
+ return loss_dict
+
+ def _get_dn_targets_single(
+ self,
+ gt_bboxes: Tensor,
+ gt_labels: Tensor,
+ positive_maps: Tensor,
+ img_shape: tuple[int, int],
+ num_groups: int,
+ num_denoising_queries: int,
+ ) -> tuple:
+ """Get targets in denoising part for one image.
+
+ Args:
+ gt_instances (:obj:`InstanceData`): Ground truth of instance
+ annotations. It should includes ``bboxes`` and ``labels``
+ attributes.
+ img_meta (dict): Meta information for one image.
+ dn_meta (Dict[str, int]): The dictionary saves information about
+ group collation, including 'num_denoising_queries' and
+ 'num_denoising_groups'. It will be used for split outputs of
+ denoising and matching parts and loss calculation.
+
+ Returns:
+ tuple[Tensor]: a tuple containing the following for one image.
+
+ - labels (Tensor): Labels of each image.
+ - label_weights (Tensor]): Label weights of each image.
+ - bbox_targets (Tensor): BBox targets of each image.
+ - bbox_weights (Tensor): BBox weights of each image.
+ - pos_inds (Tensor): Sampled positive indices for each image.
+ - neg_inds (Tensor): Sampled negative indices for each image.
+ """
+ num_queries_each_group = int(num_denoising_queries / num_groups)
+ device = gt_bboxes.device
+
+ if len(gt_labels) > 0:
+ t = torch.arange(len(gt_labels), dtype=torch.long, device=device)
+ t = t.unsqueeze(0).repeat(num_groups, 1)
+ pos_assigned_gt_inds = t.flatten()
+ pos_inds = torch.arange(
+ num_groups, dtype=torch.long, device=device
+ )
+ pos_inds = pos_inds.unsqueeze(1) * num_queries_each_group + t
+ pos_inds = pos_inds.flatten()
+ else:
+ pos_inds = pos_assigned_gt_inds = gt_bboxes.new_tensor(
+ [], dtype=torch.long
+ )
+
+ neg_inds = pos_inds + num_queries_each_group // 2
+ # label targets
+ # this change
+ labels = gt_bboxes.new_full(
+ (num_denoising_queries, self.max_text_len), 0, dtype=torch.float32
+ )
+ labels[pos_inds] = positive_maps[pos_assigned_gt_inds]
+ label_weights = gt_bboxes.new_ones(num_denoising_queries)
+
+ # bbox targets
+ bbox_targets = torch.zeros(num_denoising_queries, 4, device=device)
+ bbox_weights = torch.zeros(num_denoising_queries, 4, device=device)
+ bbox_weights[pos_inds] = 1.0
+
+ img_h, img_w = img_shape
+
+ # DETR regress the relative position of boxes (cxcywh) in the image.
+ # Thus the learning target should be normalized by the image size, also
+ # the box format should be converted from defaultly x1y1x2y2 to cxcywh.
+ factor = gt_bboxes.new_tensor([img_w, img_h, img_w, img_h]).unsqueeze(
+ 0
+ )
+ gt_bboxes_normalized = gt_bboxes / factor
+ gt_bboxes_targets = bbox_xyxy_to_cxcywh(gt_bboxes_normalized)
+ bbox_targets[pos_inds] = gt_bboxes_targets.repeat([num_groups, 1])
+
+ return (
+ labels,
+ label_weights,
+ bbox_targets,
+ bbox_weights,
+ pos_inds,
+ neg_inds,
+ )
+
+ def get_dn_targets(
+ self,
+ boxes2d: list[Tensor],
+ boxes2d_classes: list[Tensor],
+ positive_maps: list[Tensor],
+ input_hw: list[tuple[int, int]],
+ dn_meta: dict[str, int],
+ ) -> tuple:
+ """Get targets in denoising part for a batch of images.
+
+ Args:
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+ gt_instance. It usually includes ``bboxes`` and ``labels``
+ attributes.
+ batch_img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ dn_meta (Dict[str, int]): The dictionary saves information about
+ group collation, including 'num_denoising_queries' and
+ 'num_denoising_groups'. It will be used for split outputs of
+ denoising and matching parts and loss calculation.
+
+ Returns:
+ tuple: a tuple containing the following targets.
+
+ - labels_list (list[Tensor]): Labels for all images.
+ - label_weights_list (list[Tensor]): Label weights for all images.
+ - bbox_targets_list (list[Tensor]): BBox targets for all images.
+ - bbox_weights_list (list[Tensor]): BBox weights for all images.
+ - num_total_pos (int): Number of positive samples in all images.
+ - num_total_neg (int): Number of negative samples in all images.
+ """
+ (
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ pos_inds_list,
+ neg_inds_list,
+ ) = multi_apply(
+ self._get_dn_targets_single,
+ boxes2d,
+ boxes2d_classes,
+ positive_maps,
+ input_hw,
+ num_groups=dn_meta["num_denoising_groups"],
+ num_denoising_queries=dn_meta["num_denoising_queries"],
+ )
+
+ num_total_pos = sum((inds.numel() for inds in pos_inds_list))
+ num_total_neg = sum((inds.numel() for inds in neg_inds_list))
+
+ return (
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ num_total_pos,
+ num_total_neg,
+ )
+
+ def _loss_dn_single(
+ self,
+ dn_cls_scores: Tensor,
+ dn_bbox_preds: Tensor,
+ boxes2d: list[Tensor],
+ boxes2d_classes: list[Tensor],
+ positive_maps: list[Tensor],
+ input_hw: list[tuple[int, int]],
+ text_token_mask: Tensor,
+ dn_meta,
+ ):
+ """Denoising loss for outputs from a single decoder layer.
+
+ Args:
+ dn_cls_scores (Tensor): Classification scores of a single decoder
+ layer in denoising part, has shape (bs, num_denoising_queries,
+ cls_out_channels).
+ dn_bbox_preds (Tensor): Regression outputs of a single decoder
+ layer in denoising part. Each is a 4D-tensor with normalized
+ coordinate format (cx, cy, w, h) and has shape
+ (bs, num_denoising_queries, 4).
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+ gt_instance. It usually includes ``bboxes`` and ``labels``
+ attributes.
+ batch_img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ dn_meta (Dict[str, int]): The dictionary saves information about
+ group collation, including 'num_denoising_queries' and
+ 'num_denoising_groups'. It will be used for split outputs of
+ denoising and matching parts and loss calculation.
+
+ Returns:
+ Tuple[Tensor]: A tuple including `loss_cls`, `loss_box` and
+ `loss_iou`.
+ """
+ (
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ num_total_pos,
+ num_total_neg,
+ ) = self.get_dn_targets(
+ boxes2d, boxes2d_classes, positive_maps, input_hw, dn_meta
+ )
+
+ labels = torch.stack(labels_list, 0)
+ label_weights = torch.stack(label_weights_list, 0)
+ bbox_targets = torch.cat(bbox_targets_list, 0)
+ bbox_weights = torch.cat(bbox_weights_list, 0)
+
+ # Loss is not computed for the padded regions of the text.
+ assert text_token_mask.dim() == 2
+ text_masks = text_token_mask.new_zeros(
+ (text_token_mask.size(0), self.max_text_len)
+ )
+ text_masks[:, : text_token_mask.size(1)] = text_token_mask
+ text_mask = (text_masks > 0).unsqueeze(1)
+ text_mask = text_mask.repeat(1, dn_cls_scores.size(1), 1)
+ cls_scores = torch.masked_select(dn_cls_scores, text_mask).contiguous()
+ labels = torch.masked_select(labels, text_mask)
+ label_weights = label_weights[..., None].repeat(
+ 1, 1, text_mask.size(-1)
+ )
+ label_weights = torch.masked_select(label_weights, text_mask)
+
+ # classification loss
+ # construct weighted avg_factor to match with the official DETR repo
+ cls_avg_factor = (
+ num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight
+ )
+ if self.sync_cls_avg_factor:
+ cls_avg_factor = reduce_mean(
+ cls_scores.new_tensor([cls_avg_factor])
+ )
+ cls_avg_factor = max(cls_avg_factor, 1)
+
+ if len(cls_scores) > 0:
+ loss_cls = self.cls_loss_weight * self.loss_cls(
+ cls_scores,
+ labels,
+ reducer=SumWeightedLoss(
+ weight=label_weights, avg_factor=cls_avg_factor
+ ),
+ )
+ else:
+ loss_cls = torch.zeros(
+ 1, dtype=cls_scores.dtype, device=cls_scores.device
+ )
+
+ # Compute the average number of gt boxes across all gpus, for
+ # normalization purposes
+ num_total_pos = loss_cls.new_tensor([num_total_pos])
+ num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
+
+ # construct factors used for rescale bboxes
+ factors = []
+ for img_hw, bbox_pred in zip(input_hw, dn_bbox_preds):
+ img_h, img_w = img_hw
+ factor = (
+ bbox_pred.new_tensor([img_w, img_h, img_w, img_h])
+ .unsqueeze(0)
+ .repeat(bbox_pred.size(0), 1)
+ )
+ factors.append(factor)
+ factors = torch.cat(factors)
+
+ # DETR regress the relative position of boxes (cxcywh) in the image,
+ # thus the learning target is normalized by the image size. So here
+ # we need to re-scale them for calculating IoU loss
+ bbox_preds = dn_bbox_preds.reshape(-1, 4)
+ bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
+ bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors
+
+ if bbox_targets.shape[0] == 0:
+ loss_bbox = bbox_preds.sum()
+ loss_iou = bbox_preds.sum()
+ return loss_cls, loss_bbox, loss_iou
+
+ # regression L1 loss
+ loss_bbox = self.bbox_loss_weight * self.loss_bbox(
+ bbox_preds,
+ bbox_targets,
+ reducer=SumWeightedLoss(
+ weight=bbox_weights, avg_factor=num_total_pos
+ ),
+ )
+
+ # regression IoU loss, defaultly GIoU loss
+ loss_iou = self.iou_loss_weight * self.loss_iou(
+ bboxes,
+ bboxes_gt,
+ reducer=SumWeightedLoss(
+ weight=bbox_weights.mean(-1), avg_factor=num_total_pos
+ ),
+ )
+
+ return loss_cls, loss_bbox, loss_iou
+
+ def loss_dn(
+ self,
+ all_layers_denoising_cls_scores: Tensor,
+ all_layers_denoising_bbox_preds: Tensor,
+ boxes2d: list[Tensor],
+ boxes2d_classes: list[Tensor],
+ positive_maps: list[Tensor],
+ input_hw: list[tuple[int, int]],
+ text_token_mask: Tensor,
+ dn_meta: dict[str, int],
+ ):
+ """Calculate denoising loss.
+
+ Args:
+ all_layers_denoising_cls_scores (Tensor): Classification scores of
+ all decoder layers in denoising part, has shape (
+ num_decoder_layers, bs, num_denoising_queries,
+ cls_out_channels).
+ all_layers_denoising_bbox_preds (Tensor): Regression outputs of all
+ decoder layers in denoising part. Each is a 4D-tensor with
+ normalized coordinate format (cx, cy, w, h) and has shape
+ (num_decoder_layers, bs, num_denoising_queries, 4).
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+ gt_instance. It usually includes ``bboxes`` and ``labels``
+ attributes.
+ batch_img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ dn_meta (Dict[str, int]): The dictionary saves information about
+ group collation, including 'num_denoising_queries' and
+ 'num_denoising_groups'. It will be used for split outputs of
+ denoising and matching parts and loss calculation.
+
+ Returns:
+ Tuple[List[Tensor]]: The loss_dn_cls, loss_dn_bbox, and loss_dn_iou
+ of each decoder layers.
+ """
+ return multi_apply(
+ self._loss_dn_single,
+ all_layers_denoising_cls_scores,
+ all_layers_denoising_bbox_preds,
+ boxes2d=boxes2d,
+ boxes2d_classes=boxes2d_classes,
+ positive_maps=positive_maps,
+ input_hw=input_hw,
+ text_token_mask=text_token_mask,
+ dn_meta=dn_meta,
+ )
+
+
+# TODO: Move to DINO ops
+def split_outputs(
+ all_layers_cls_scores: Tensor,
+ all_layers_bbox_preds: Tensor,
+ dn_meta: dict[str, int] | None = None,
+) -> tuple[Tensor, Tensor, Tensor, Tensor]:
+ """Split outputs of the denoising part and the matching part.
+
+ For the total outputs of `num_queries_total` length, the former
+ `num_denoising_queries` outputs are from denoising queries, and
+ the rest `num_matching_queries` ones are from matching queries,
+ where `num_queries_total` is the sum of `num_denoising_queries` and
+ `num_matching_queries`.
+
+ Args:
+ all_layers_cls_scores (Tensor): Classification scores of all
+ decoder layers, has shape (num_decoder_layers, bs,
+ num_queries_total, cls_out_channels).
+ all_layers_bbox_preds (Tensor): Regression outputs of all decoder
+ layers. Each is a 4D-tensor with normalized coordinate format
+ (cx, cy, w, h) and has shape (num_decoder_layers, bs,
+ num_queries_total, 4).
+ dn_meta (Dict[str, int]): The dictionary saves information about
+ group collation, including 'num_denoising_queries' and
+ 'num_denoising_groups'.
+
+ Returns:
+ Tuple[Tensor]: a tuple containing the following outputs.
+
+ - all_layers_matching_cls_scores (Tensor): Classification scores
+ of all decoder layers in matching part, has shape
+ (num_decoder_layers, bs, num_matching_queries, cls_out_channels).
+ - all_layers_matching_bbox_preds (Tensor): Regression outputs of
+ all decoder layers in matching part. Each is a 4D-tensor with
+ normalized coordinate format (cx, cy, w, h) and has shape
+ (num_decoder_layers, bs, num_matching_queries, 4).
+ - all_layers_denoising_cls_scores (Tensor): Classification scores
+ of all decoder layers in denoising part, has shape
+ (num_decoder_layers, bs, num_denoising_queries,
+ cls_out_channels).
+ - all_layers_denoising_bbox_preds (Tensor): Regression outputs of
+ all decoder layers in denoising part. Each is a 4D-tensor with
+ normalized coordinate format (cx, cy, w, h) and has shape
+ (num_decoder_layers, bs, num_denoising_queries, 4).
+ """
+ # FIXME: Can dn_meta be None?
+ num_denoising_queries = dn_meta["num_denoising_queries"]
+
+ if dn_meta is not None:
+ all_layers_denoising_cls_scores = all_layers_cls_scores[
+ :, :, :num_denoising_queries, :
+ ]
+ all_layers_denoising_bbox_preds = all_layers_bbox_preds[
+ :, :, :num_denoising_queries, :
+ ]
+ all_layers_matching_cls_scores = all_layers_cls_scores[
+ :, :, num_denoising_queries:, :
+ ]
+ all_layers_matching_bbox_preds = all_layers_bbox_preds[
+ :, :, num_denoising_queries:, :
+ ]
+ else:
+ all_layers_denoising_cls_scores = None
+ all_layers_denoising_bbox_preds = None
+ all_layers_matching_cls_scores = all_layers_cls_scores
+ all_layers_matching_bbox_preds = all_layers_bbox_preds
+
+ return (
+ all_layers_matching_cls_scores,
+ all_layers_matching_bbox_preds,
+ all_layers_denoising_cls_scores,
+ all_layers_denoising_bbox_preds,
+ )
diff --git a/wilddet3d/loss/det3d_loss.py b/wilddet3d/loss/det3d_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..739d0b36e24847c9839a6bb967a1d984b8d40948
--- /dev/null
+++ b/wilddet3d/loss/det3d_loss.py
@@ -0,0 +1,490 @@
+"""3D-MOOD loss."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor
+from vis4d.common.distributed import reduce_mean
+from vis4d.common.typing import ArgsType
+from vis4d.op.loss.common import l1_loss
+from vis4d.op.loss.reducer import SumWeightedLoss
+
+from wilddet3d.ops.box2d import bbox_cxcywh_to_xyxy
+from wilddet3d.loss.det2d_loss import (
+ Det2DLoss,
+ split_outputs,
+)
+from wilddet3d.ops.util import multi_apply
+
+from .coder import Det3DCoder
+
+
+class Det3DLoss(Det2DLoss):
+ """Grounding DINO with 3D loss."""
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ box_coder: Det3DCoder | None = None,
+ loss_center_weight: float = 1.0,
+ loss_depth_weight: float = 1.0,
+ loss_dim_weight: float = 1.0,
+ loss_rot_weight: float = 1.0,
+ loss_2d_scale: float = 1.0,
+ loss_3d_scale: float = 1.0,
+ **kwargs: ArgsType,
+ ):
+ """Init."""
+ super().__init__(*args, **kwargs)
+ self.box_coder = box_coder or Det3DCoder()
+
+ self.reg_dims = self.box_coder.reg_dims
+
+ self.loss_center_weight = loss_center_weight
+ self.loss_depth_weight = loss_depth_weight
+ self.loss_dim_weight = loss_dim_weight
+ self.loss_rot_weight = loss_rot_weight
+ self.loss_2d_scale = loss_2d_scale
+ self.loss_3d_scale = loss_3d_scale
+
+ def get_targets_3d(
+ self,
+ cls_scores_list: list[Tensor],
+ bbox_preds_list: list[Tensor],
+ bbox_preds_3d_list: list[Tensor],
+ input_hw: list[tuple[int, int]],
+ batch_gt_boxes: list[Tensor],
+ batch_gt_boxes_3d: list[Tensor],
+ batch_gt_boxes_classes: list[Tensor],
+ batch_gt_intrinsics: list[Tensor],
+ positive_maps: list[Tensor],
+ text_token_mask: Tensor,
+ ) -> tuple:
+ """Compute regression and classification targets for a batch image."""
+ (
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ bbox_targets_3d_list,
+ bbox_weights_3d_list,
+ pos_inds_list,
+ neg_inds_list,
+ ) = multi_apply(
+ self._get_targets_3d_single,
+ cls_scores_list,
+ bbox_preds_list,
+ bbox_preds_3d_list,
+ input_hw,
+ batch_gt_boxes,
+ batch_gt_boxes_3d,
+ batch_gt_boxes_classes,
+ batch_gt_intrinsics,
+ positive_maps,
+ text_token_mask,
+ )
+
+ num_total_pos = sum((inds.numel() for inds in pos_inds_list))
+ num_total_neg = sum((inds.numel() for inds in neg_inds_list))
+
+ return (
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ bbox_targets_3d_list,
+ bbox_weights_3d_list,
+ num_total_pos,
+ num_total_neg,
+ )
+
+ def _get_targets_3d_single(
+ self,
+ cls_score: Tensor,
+ bbox_pred: Tensor,
+ bbox_pred_3d: Tensor,
+ input_hw: tuple[int, int],
+ gt_boxes: Tensor,
+ gt_boxes_3d: Tensor,
+ gt_classes: Tensor,
+ gt_intrinsics: Tensor,
+ positive_map: Tensor,
+ text_token_mask: Tensor,
+ ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+ """Compute regression and classification targets for one image."""
+ # 2D Target
+ with torch.no_grad():
+ (
+ labels,
+ label_weights,
+ bbox_targets,
+ bbox_weights,
+ pos_pred_boxes2d,
+ pos_inds,
+ neg_inds,
+ pos_assigned_gt_inds,
+ ) = self._get_targets_2d_single(
+ cls_score,
+ bbox_pred,
+ input_hw,
+ gt_boxes,
+ gt_classes,
+ positive_map,
+ text_token_mask,
+ )
+
+ # 3D Target
+ pos_gt_boxes3d = gt_boxes_3d[pos_assigned_gt_inds.long(), :]
+
+ pos_gt_bboxes_3d, pos_gt_bboxes_3d_weights = self.box_coder.encode(
+ pos_pred_boxes2d, pos_gt_boxes3d, gt_intrinsics
+ )
+
+ bbox_targets_3d = torch.zeros_like(bbox_pred_3d)
+ bbox_targets_3d[pos_inds] = pos_gt_bboxes_3d
+
+ bbox_weights_3d = torch.zeros_like(bbox_pred_3d)
+ bbox_weights_3d[pos_inds] = pos_gt_bboxes_3d_weights
+
+ return (
+ labels,
+ label_weights,
+ bbox_targets,
+ bbox_weights,
+ bbox_targets_3d,
+ bbox_weights_3d,
+ pos_inds,
+ neg_inds,
+ )
+
+ def loss_3d_by_feat_single(
+ self,
+ cls_scores: Tensor,
+ bbox_preds: Tensor,
+ bbox_3d_preds: Tensor,
+ text_token_mask: Tensor,
+ input_hw: list[tuple[int, int]],
+ batch_gt_boxes: list[Tensor],
+ batch_gt_boxes_3d: list[Tensor],
+ batch_gt_boxes_classes: list[Tensor],
+ batch_gt_intrinsics: list[Tensor],
+ positive_maps: list[Tensor],
+ ):
+ """Loss function for outputs from a single decoder layer."""
+ num_imgs = cls_scores.size(0)
+
+ cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
+ bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
+ bbox_preds_3d_list = [bbox_3d_preds[i] for i in range(num_imgs)]
+
+ (
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ bbox_targets_3d_list,
+ bbox_weights_3d_list,
+ num_total_pos,
+ num_total_neg,
+ ) = self.get_targets_3d(
+ cls_scores_list,
+ bbox_preds_list,
+ bbox_preds_3d_list,
+ input_hw,
+ batch_gt_boxes,
+ batch_gt_boxes_3d,
+ batch_gt_boxes_classes,
+ batch_gt_intrinsics,
+ positive_maps,
+ text_token_mask,
+ )
+
+ labels = torch.stack(labels_list, 0)
+ label_weights = torch.stack(label_weights_list, 0)
+ bbox_targets = torch.cat(bbox_targets_list, 0)
+ bbox_targets_3d = torch.cat(bbox_targets_3d_list, 0)
+ bbox_weights = torch.cat(bbox_weights_list, 0)
+ bbox_weights_3d = torch.cat(bbox_weights_3d_list, 0)
+
+ # Loss is not computed for the padded regions of the text.
+ assert text_token_mask.dim() == 2
+ text_masks = text_token_mask.new_zeros(
+ (text_token_mask.size(0), self.max_text_len)
+ )
+ text_masks[:, : text_token_mask.size(1)] = text_token_mask
+ text_mask = (text_masks > 0).unsqueeze(1)
+ text_mask = text_mask.repeat(1, cls_scores.size(1), 1)
+ cls_scores = torch.masked_select(cls_scores, text_mask).contiguous()
+
+ labels = torch.masked_select(labels, text_mask)
+ label_weights = label_weights[..., None].repeat(
+ 1, 1, text_mask.size(-1)
+ )
+ label_weights = torch.masked_select(label_weights, text_mask)
+
+ # classification loss
+ # construct weighted avg_factor to match with the official DETR repo
+ cls_avg_factor = (
+ num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight
+ )
+ if self.sync_cls_avg_factor:
+ cls_avg_factor = reduce_mean(
+ cls_scores.new_tensor([cls_avg_factor])
+ )
+ cls_avg_factor = max(cls_avg_factor, 1)
+
+ loss_cls = self.loss_2d_scale * self.cls_loss_weight * self.loss_cls(
+ cls_scores,
+ labels,
+ reducer=SumWeightedLoss(
+ weight=label_weights, avg_factor=cls_avg_factor
+ ),
+ )
+
+ # Compute the average number of gt boxes across all gpus, for
+ # normalization purposes
+ num_total_pos = loss_cls.new_tensor([num_total_pos])
+ num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
+
+ # construct factors used for rescale bboxes
+ factors = []
+ for img_hw, bbox_pred in zip(input_hw, bbox_preds):
+ img_h, img_w = img_hw
+ factor = (
+ bbox_pred.new_tensor([img_w, img_h, img_w, img_h])
+ .unsqueeze(0)
+ .repeat(bbox_pred.size(0), 1)
+ )
+ factors.append(factor)
+ factors = torch.cat(factors, 0)
+
+ # DETR regress the relative position of boxes (cxcywh) in the image,
+ # thus the learning target is normalized by the image size. So here
+ # we need to re-scale them for calculating IoU loss
+ bbox_preds = bbox_preds.reshape(-1, 4)
+ bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
+ bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors
+
+ # regression L1 loss (2D)
+ loss_bbox = self.loss_2d_scale * self.bbox_loss_weight * self.loss_bbox(
+ bbox_preds,
+ bbox_targets,
+ reducer=SumWeightedLoss(
+ weight=bbox_weights, avg_factor=num_total_pos
+ ),
+ )
+
+ # regression IoU loss (2D)
+ loss_iou = self.loss_2d_scale * self.iou_loss_weight * self.loss_iou(
+ bboxes,
+ bboxes_gt,
+ reducer=SumWeightedLoss(
+ weight=bbox_weights.mean(-1), avg_factor=num_total_pos
+ ),
+ )
+
+ # 3D Loss
+ bbox_3d_preds = bbox_3d_preds.reshape(-1, self.reg_dims)
+
+ # Delta 2D center Loss (3D)
+ loss_cen = self.loss_3d_scale * self.loss_center_weight * l1_loss(
+ bbox_3d_preds[:, :2],
+ bbox_targets_3d[:, :2],
+ reducer=SumWeightedLoss(
+ weight=bbox_weights_3d[:, :2], avg_factor=num_total_pos
+ ),
+ )
+
+ # Depth Loss (3D)
+ loss_depth = self.loss_3d_scale * self.loss_depth_weight * l1_loss(
+ bbox_3d_preds[:, 2],
+ bbox_targets_3d[:, 2],
+ reducer=SumWeightedLoss(
+ weight=bbox_weights_3d[:, 2], avg_factor=num_total_pos
+ ),
+ )
+
+ # Dimension Loss (3D)
+ loss_dim = self.loss_3d_scale * self.loss_dim_weight * l1_loss(
+ bbox_3d_preds[:, 3:6],
+ bbox_targets_3d[:, 3:6],
+ reducer=SumWeightedLoss(
+ weight=bbox_weights_3d[:, 3:6], avg_factor=num_total_pos
+ ),
+ )
+
+ # Rotation Loss (3D)
+ loss_rot = self.loss_3d_scale * self.loss_rot_weight * l1_loss(
+ bbox_3d_preds[:, 6:],
+ bbox_targets_3d[:, 6:],
+ reducer=SumWeightedLoss(
+ weight=bbox_weights_3d[:, 6:], avg_factor=num_total_pos
+ ),
+ )
+
+ return (
+ loss_cls,
+ loss_bbox,
+ loss_iou,
+ loss_cen,
+ loss_depth,
+ loss_dim,
+ loss_rot,
+ )
+
+ def forward(
+ self,
+ all_layers_cls_scores: Tensor,
+ all_layers_bbox_preds: Tensor,
+ all_layers_bbox_3d_preds: Tensor,
+ text_token_mask: Tensor,
+ enc_cls_scores: Tensor,
+ enc_bbox_preds: Tensor,
+ enc_outputs_3d: Tensor,
+ dn_meta: dict[str, int],
+ positive_maps: list[Tensor],
+ input_hw: list[tuple[int, int]],
+ batch_gt_boxes: list[Tensor],
+ batch_gt_boxes_3d: list[Tensor],
+ batch_gt_boxes_classes: list[Tensor],
+ batch_gt_intrinsics: list[Tensor],
+ ) -> dict[str, Tensor]:
+ """Forward pass of the 3D Grounding DINO loss."""
+ (
+ all_layers_matching_cls_scores,
+ all_layers_matching_bbox_preds,
+ all_layers_denoising_cls_scores,
+ all_layers_denoising_bbox_preds,
+ ) = split_outputs(
+ all_layers_cls_scores, all_layers_bbox_preds, dn_meta
+ )
+
+ (
+ losses_cls,
+ losses_bbox,
+ losses_iou,
+ losses_cen,
+ losses_depth,
+ losses_dim,
+ losses_rot,
+ ) = multi_apply(
+ self.loss_3d_by_feat_single,
+ all_layers_matching_cls_scores,
+ all_layers_matching_bbox_preds,
+ all_layers_bbox_3d_preds,
+ text_token_mask=text_token_mask,
+ input_hw=input_hw,
+ batch_gt_boxes=batch_gt_boxes,
+ batch_gt_boxes_3d=batch_gt_boxes_3d,
+ batch_gt_boxes_classes=batch_gt_boxes_classes,
+ batch_gt_intrinsics=batch_gt_intrinsics,
+ positive_maps=positive_maps,
+ )
+
+ loss_dict = dict()
+
+ # loss from the last decoder layer
+ loss_dict["loss_cls"] = losses_cls[-1]
+ loss_dict["loss_bbox"] = losses_bbox[-1]
+ loss_dict["loss_iou"] = losses_iou[-1]
+ loss_dict["loss_delta_2d"] = losses_cen[-1]
+ loss_dict["loss_depth"] = losses_depth[-1]
+ loss_dict["loss_dim"] = losses_dim[-1]
+ loss_dict["loss_rot"] = losses_rot[-1]
+
+ # loss from other decoder layers
+ for num_dec_layer, (loss_cls_i, loss_bbox_i, loss_iou_i) in enumerate(
+ zip(losses_cls[:-1], losses_bbox[:-1], losses_iou[:-1])
+ ):
+ loss_dict[f"d{num_dec_layer}.loss_cls"] = loss_cls_i
+ loss_dict[f"d{num_dec_layer}.loss_bbox"] = loss_bbox_i
+ loss_dict[f"d{num_dec_layer}.loss_iou"] = loss_iou_i
+ loss_dict[f"d{num_dec_layer}.loss_delta_2d"] = losses_cen[
+ num_dec_layer
+ ]
+ loss_dict[f"d{num_dec_layer}.loss_depth"] = losses_depth[
+ num_dec_layer
+ ]
+ loss_dict[f"d{num_dec_layer}.loss_dim"] = losses_dim[num_dec_layer]
+ loss_dict[f"d{num_dec_layer}.loss_rot"] = losses_rot[num_dec_layer]
+
+ # loss of proposal generated from encode feature map.
+ if enc_cls_scores is not None:
+ if enc_outputs_3d is None:
+ # NOTE The enc_loss calculation of the DINO is
+ # different from that of Deformable DETR.
+ enc_loss_cls, enc_losses_bbox, enc_losses_iou = (
+ self.loss_by_feat_single(
+ enc_cls_scores,
+ enc_bbox_preds,
+ text_token_mask=text_token_mask,
+ input_hw=input_hw,
+ batch_gt_boxes=batch_gt_boxes,
+ batch_gt_boxes_classes=batch_gt_boxes_classes,
+ positive_maps=positive_maps,
+ )
+ )
+ loss_dict["enc_loss_cls"] = enc_loss_cls
+ loss_dict["enc_loss_bbox"] = enc_losses_bbox
+ loss_dict["enc_loss_iou"] = enc_losses_iou
+ else:
+ (
+ enc_loss_cls,
+ enc_losses_bbox,
+ enc_losses_iou,
+ enc_losses_cen,
+ enc_losses_depth,
+ enc_losses_dim,
+ enc_losses_rot,
+ ) = self.loss_3d_by_feat_single(
+ enc_cls_scores,
+ enc_bbox_preds,
+ enc_outputs_3d,
+ text_token_mask=text_token_mask,
+ input_hw=input_hw,
+ batch_gt_boxes=batch_gt_boxes,
+ batch_gt_boxes_3d=batch_gt_boxes_3d,
+ batch_gt_boxes_classes=batch_gt_boxes_classes,
+ batch_gt_intrinsics=batch_gt_intrinsics,
+ positive_maps=positive_maps,
+ )
+ loss_dict["enc_loss_cls"] = enc_loss_cls
+ loss_dict["enc_loss_bbox"] = enc_losses_bbox
+ loss_dict["enc_loss_iou"] = enc_losses_iou
+ loss_dict["enc_loss_delta_2d"] = enc_losses_cen
+ loss_dict["enc_loss_depth"] = enc_losses_depth
+ loss_dict["enc_loss_dim"] = enc_losses_dim
+ loss_dict["enc_loss_rot"] = enc_losses_rot
+
+ if all_layers_denoising_cls_scores is not None:
+ # calculate denoising loss from all decoder layers
+ dn_losses_cls, dn_losses_bbox, dn_losses_iou = self.loss_dn(
+ all_layers_denoising_cls_scores,
+ all_layers_denoising_bbox_preds,
+ boxes2d=batch_gt_boxes,
+ boxes2d_classes=batch_gt_boxes_classes,
+ positive_maps=positive_maps,
+ input_hw=input_hw,
+ text_token_mask=text_token_mask,
+ dn_meta=dn_meta,
+ )
+
+ # collate denoising loss
+ loss_dict["dn_loss_cls"] = dn_losses_cls[-1]
+ loss_dict["dn_loss_bbox"] = dn_losses_bbox[-1]
+ loss_dict["dn_loss_iou"] = dn_losses_iou[-1]
+
+ for num_dec_layer, (
+ loss_cls_i,
+ loss_bbox_i,
+ loss_iou_i,
+ ) in enumerate(
+ zip(
+ dn_losses_cls[:-1], dn_losses_bbox[:-1], dn_losses_iou[:-1]
+ )
+ ):
+ loss_dict[f"d{num_dec_layer}.dn_loss_cls"] = loss_cls_i
+ loss_dict[f"d{num_dec_layer}.dn_loss_bbox"] = loss_bbox_i
+ loss_dict[f"d{num_dec_layer}.dn_loss_iou"] = loss_iou_i
+
+ return loss_dict
diff --git a/wilddet3d/loss/focal_loss.py b/wilddet3d/loss/focal_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1cdb12ad534780d2b33bac40f3159c78322e68b
--- /dev/null
+++ b/wilddet3d/loss/focal_loss.py
@@ -0,0 +1,62 @@
+"""Focal Loss."""
+
+from __future__ import annotations
+
+import torch.nn.functional as F
+from torch import Tensor
+from torchvision.ops import sigmoid_focal_loss
+from vis4d.op.loss.base import Loss
+from vis4d.op.loss.reducer import LossReducer, mean_loss
+
+
+class FocalLoss(Loss):
+ """Focal loss `_."""
+
+ def __init__(
+ self,
+ alpha: float = 0.25,
+ gamma: float = 2.0,
+ reducer: LossReducer = mean_loss,
+ ) -> None:
+ """Creates an instance of the class.
+
+ Args:
+ alpha (float, optional): A balanced form for Focal Loss.
+ Defaults to 0.25.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ reducer (LossReducer, optional): Reducer for the loss function.
+ Defaults to mean_loss.
+ """
+ super().__init__(reducer)
+ self.alpha = alpha
+ self.gamma = gamma
+
+ def forward(
+ self, pred: Tensor, target: Tensor, reducer: LossReducer | None = None
+ ) -> Tensor:
+ """Forward function.
+
+ Args:
+ pred (Tensor): The prediction.
+ target (Tensor): The learning label of the prediction.
+
+ Returns:
+ Tensor: The calculated loss.
+ """
+ # this means that target is not in One-Hot form.
+ if pred.dim() != target.dim():
+ num_classes = pred.size(1)
+ target = F.one_hot(target, num_classes=num_classes + 1).float()
+ target = target[:, :num_classes]
+
+ reducer = reducer or self.reducer
+
+ focal_loss = sigmoid_focal_loss(
+ pred,
+ target,
+ alpha=self.alpha,
+ gamma=self.gamma,
+ )
+
+ return reducer(focal_loss)
diff --git a/wilddet3d/loss/geom_loss_aggregator.py b/wilddet3d/loss/geom_loss_aggregator.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0459b01730c973841195726e1a7250747bce2e2
--- /dev/null
+++ b/wilddet3d/loss/geom_loss_aggregator.py
@@ -0,0 +1,55 @@
+"""Geometry Loss Aggregator.
+
+This module provides a loss class that aggregates geometry losses from
+the model output (geom_losses dict from GeometryBackend).
+"""
+
+from __future__ import annotations
+
+from torch import Tensor
+from vis4d.common.typing import ArgsType
+from vis4d.op.loss.base import Loss
+
+
+class GeomLossAggregator(Loss):
+ """Aggregates geometry losses from model output.
+
+ This loss class takes the geom_losses dict from the model output
+ and returns the sum of all losses. Each individual loss is also
+ logged separately.
+
+ Args:
+ weight: Global weight multiplier for all geometry losses.
+ """
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ weight: float = 1.0,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Initialize the GeomLossAggregator."""
+ super().__init__(*args, **kwargs)
+ self.weight = weight
+
+ def forward(
+ self,
+ geom_losses: dict[str, Tensor] | None,
+ ) -> dict[str, Tensor]:
+ """Forward function.
+
+ Args:
+ geom_losses: Dictionary of geometry losses from the model.
+
+ Returns:
+ Dictionary of weighted losses.
+ """
+ if geom_losses is None or len(geom_losses) == 0:
+ return {}
+
+ weighted_losses = {}
+ for name, loss in geom_losses.items():
+ weighted_losses[f"geom_{name}"] = loss * self.weight
+
+ return weighted_losses
+
diff --git a/wilddet3d/loss/iou_loss.py b/wilddet3d/loss/iou_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5d93c7efa7ff98cb6facb4b83076d79e7b7e4ae
--- /dev/null
+++ b/wilddet3d/loss/iou_loss.py
@@ -0,0 +1,81 @@
+"""IoU Loss."""
+
+import torch
+from torch import Tensor
+from vis4d.op.loss.base import Loss
+from vis4d.op.loss.reducer import LossReducer, mean_loss
+
+from wilddet3d.ops.box2d import bbox_overlaps
+
+
+def giou_loss(pred: Tensor, target: Tensor, eps: float = 1e-7) -> Tensor:
+ r"""`Generalized Intersection over Union: A Metric and A Loss for Bounding
+ Box Regression `_.
+
+ Args:
+ pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
+ shape (n, 4).
+ target (Tensor): Corresponding gt bboxes, shape (n, 4).
+ eps (float): Epsilon to avoid log(0).
+
+ Return:
+ Tensor: Loss tensor.
+ """
+ # avoid fp16 overflow
+ if pred.dtype == torch.float16:
+ fp16 = True
+ pred = pred.to(torch.float32)
+ else:
+ fp16 = False
+
+ gious = bbox_overlaps(pred, target, mode="giou", is_aligned=True, eps=eps)
+
+ if fp16:
+ gious = gious.to(torch.float16)
+
+ loss = 1 - gious
+ return loss
+
+
+class GIoULoss(Loss):
+ r"""`Generalized Intersection over Union: A Metric and A Loss for Bounding
+ Box Regression `_.
+ """
+
+ def __init__(
+ self,
+ eps: float = 1e-6,
+ reducer: LossReducer = mean_loss,
+ ) -> None:
+ super().__init__(reducer)
+ self.eps = eps
+
+ def forward(
+ self,
+ pred: Tensor,
+ target: Tensor,
+ reducer: LossReducer | None = None,
+ ) -> Tensor:
+ """Forward function.
+
+ Args:
+ pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
+ shape (n, 4).
+ target (Tensor): The learning target of the prediction,
+ shape (n, 4).
+ weight (Optional[Tensor], optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (Optional[int], optional): Average factor that is used
+ to average the loss. Defaults to None.
+ reduction_override (Optional[str], optional): The reduction method
+ used to override the original reduction method of the loss.
+ Defaults to None. Options are "none", "mean" and "sum".
+
+ Returns:
+ Tensor: Loss tensor.
+ """
+ reducer = reducer or self.reducer
+
+ loss = giou_loss(pred, target, eps=self.eps)
+
+ return reducer(loss)
diff --git a/wilddet3d/loss/silog_loss.py b/wilddet3d/loss/silog_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..de46146513fda9269994dffe007a4d4aa1bd6fbd
--- /dev/null
+++ b/wilddet3d/loss/silog_loss.py
@@ -0,0 +1,59 @@
+"""SILog loss for depth estimation."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor
+from vis4d.common.typing import ArgsType
+from vis4d.op.loss.base import Loss
+
+from .util import masked_mean_var
+
+
+class SILogLoss(Loss):
+ """SILogLoss."""
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ scale_pred_weight: float = 0.15,
+ eps: float = 1e-5,
+ min_depth: float = 0.0,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Init."""
+ super().__init__(*args, **kwargs)
+ self.scale_pred_weight = scale_pred_weight
+ self.eps = eps
+ self.min_depth = min_depth
+
+ def forward(
+ self, depths: Tensor, target_depths: Tensor, mask: Tensor | None = None
+ ) -> Tensor:
+ """Forward function.
+
+ Args:
+ depths (Tensor): Predicted depth. Shape: (B, H, W)
+ target_depths (Tensor): Target depth. Shape: (B, H, W)
+ mask (Tensor | None): Mask. Shape: (B, H, W)
+ """
+ if mask is None:
+ mask = target_depths > self.min_depth
+ else:
+ mask = mask.to(torch.bool)
+ mask = torch.logical_and(mask, target_depths > self.min_depth)
+
+ log_depths = torch.log(depths.clamp(min=self.eps))
+ log_target_depths = torch.log(target_depths.clamp(min=self.eps))
+
+ log_error = log_depths - log_target_depths
+
+ mean_error, var_error = masked_mean_var(log_error, mask=mask)
+
+ scale_error = mean_error**2
+
+ loss = var_error + self.scale_pred_weight * scale_error
+
+ out_loss = torch.sqrt(loss.clamp(min=self.eps))
+
+ return out_loss.mean()
diff --git a/wilddet3d/loss/util.py b/wilddet3d/loss/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f01a93fdbcedfab1396af9ec44f4c6cabb75751
--- /dev/null
+++ b/wilddet3d/loss/util.py
@@ -0,0 +1,35 @@
+"""Loss util."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor
+
+
+def masked_mean_var(error: Tensor, mask: Tensor | None = None) -> Tensor:
+ """Compute mean and variance of error with mask."""
+ if mask is None:
+ return error.mean(dim=[-2, -1], keepdim=True), error.var(
+ dim=[-2, -1], keepdim=True
+ )
+ mask = mask.float()
+ mask_sum = torch.sum(mask, dim=[-2, -1], keepdim=True)
+ mask_mean = torch.sum(
+ error * mask, dim=[-2, -1], keepdim=True
+ ) / torch.clamp(mask_sum, min=1.0)
+ mask_var = torch.sum(
+ mask * (error - mask_mean) ** 2, dim=[-2, -1], keepdim=True
+ ) / torch.clamp(mask_sum, min=1.0)
+ return mask_mean.squeeze([-2, -1]), mask_var.squeeze([-2, -1])
+
+
+def masked_mean(data: Tensor, mask: Tensor | None):
+ """Compute mean of data with mask."""
+ if mask is None:
+ return data.mean(dim=[-2, -1], keepdim=True)
+ mask = mask.float()
+ mask_sum = torch.sum(mask, dim=[-2, -1], keepdim=True)
+ mask_mean = torch.sum(
+ data * mask, dim=[-2, -1], keepdim=True
+ ) / torch.clamp(mask_sum, min=1.0)
+ return mask_mean
diff --git a/wilddet3d/loss/wilddet3d_loss.py b/wilddet3d/loss/wilddet3d_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5947ab9822f8e66c32bd6d89a5cc6ea075e82d1
--- /dev/null
+++ b/wilddet3d/loss/wilddet3d_loss.py
@@ -0,0 +1,1256 @@
+"""WildDet3D Loss Module.
+
+This module implements the loss function for WildDet3D, combining:
+1. SAM3-style 2D losses (IABCEMdetr for classification, L1+GIoU for boxes)
+2. 3D-MOOD-style 3D losses (delta_center, depth, dimensions, rotation)
+
+Key Design Decisions:
+- Uses SAM3's Hungarian matcher for assignment (already computed in model)
+- Follows SAM3's loss normalization (global/local/none)
+- Adds 3D regression losses on top of 2D losses
+- Supports deep supervision on auxiliary decoder outputs
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Literal
+
+import numpy as np
+import torch
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+from vis4d.common.distributed import reduce_mean
+from vis4d.op.loss.common import l1_loss
+from vis4d.op.loss.reducer import SumWeightedLoss
+
+from wilddet3d.head.coder_3d import Det3DCoder
+from sam3.model.box_ops import fast_diag_box_iou, fast_diag_generalized_box_iou
+from sam3.train.matcher import BinaryOneToManyMatcher
+from sam3.train.loss.loss_fns import (
+ IABCEMdetr,
+ Boxes as SAM3Boxes,
+ sigmoid_focal_loss,
+)
+
+
+def _packed_to_padded(boxes_packed: Tensor, num_boxes: Tensor, fill_value: float = 0.0) -> Tensor:
+ """Convert packed tensor to padded tensor.
+
+ This function converts a packed (concatenated) tensor of bounding boxes
+ to a batch-wise padded tensor, following SAM3's collator implementation.
+
+ Args:
+ boxes_packed: Packed boxes tensor of shape (N_total, 4) where
+ N_total = N_1 + N_2 + ... + N_B
+ num_boxes: Number of boxes per image, shape (B,)
+ fill_value: Value to use for padding (default: 0.0)
+
+ Returns:
+ Padded boxes tensor of shape (B, max_N, 4) where max_N = max(num_boxes)
+
+ Example:
+ >>> boxes = torch.tensor([[1,2,3,4], [5,6,7,8], [9,10,11,12]])
+ >>> num_boxes = torch.tensor([1, 2])
+ >>> padded = _packed_to_padded(boxes, num_boxes)
+ >>> padded.shape
+ torch.Size([2, 2, 4])
+ """
+ B = num_boxes.shape[0]
+ Ns = num_boxes.tolist()
+ max_N = max(Ns)
+
+ # Create padded tensor
+ boxes_padded = boxes_packed.new_full((B, max_N, *boxes_packed.shape[1:]), fill_value)
+
+ # Fill in actual boxes
+ prev_idx = 0
+ for i in range(B):
+ next_idx = prev_idx + Ns[i]
+ boxes_padded[i, :Ns[i]] = boxes_packed[prev_idx:next_idx]
+ prev_idx = next_idx
+
+ return boxes_padded
+
+
+@dataclass
+class WildDet3DLossConfig:
+ """Configuration for WildDet3D loss.
+
+ Follows SAM3's loss configuration style with additional 3D loss weights.
+ """
+ # ========== Global Scale Factors ==========
+ # These allow adjusting the balance between 2D, 3D, and geometry losses
+ # Default 1.0, can be adjusted in training config to tune 2D:3D:Geom ratio
+ loss_2d_scale: float = 1.0 # Scale for 2D losses (cls, bbox, giou)
+ loss_3d_scale: float = 1.0 # Scale for 3D losses (delta, depth, dim, rot)
+ loss_geom_scale: float = 10.0 # Scale for geometry backend losses (SILog, SSI, camera angles)
+
+ # ========== O2M (One-to-Many) Matcher Configuration ==========
+ # Note: O2O matcher is configured in wilddet3d.py (self.sam3.matcher)
+ use_o2m: bool = True # Enable O2M matching
+ o2m_loss_clip: float = 150.0 # Clip O2M loss to prevent gradient explosion
+ o2m_alpha: float = 0.3 # Alpha for O2M cost computation
+ o2m_threshold: float = 0.4 # IoU threshold for O2M matching
+ o2m_topk: int = 4 # Top-k predictions per GT (SAM3 original: topk: 4)
+ o2m_loss_weight: float = 2.0 # Weight for O2M loss (SAM3 original: o2m_weight: 2.0)
+
+ # ========== 2D Loss Weights (SAM3 style) ==========
+ # Classification loss (IABCEMdetr style)
+ loss_cls_weight: float = 20.0 # SAM3 original
+ pos_weight: float = 5.0 # SAM3 original (was incorrectly 10.0)
+ gamma: float = 2.0 # SAM3 original focal (was incorrectly 0.0)
+ alpha: float = 0.25 # IoU-aware alpha
+
+ # IABCEMdetr advanced features
+ use_weak_loss: bool = False # Enable weak supervision (SAM3 original: weak_loss: False)
+ weak_loss_weight: float = 1.0 # Weight for weak loss (only used if use_weak_loss=True)
+ use_presence: bool = True # Enable presence loss (per-category presence detection)
+ presence_loss_weight: float = 20.0 # Weight for presence loss (SAM3 original: presence_weight: 20.0)
+ presence_alpha: float = 0.5 # SAM3 original presence focal loss alpha
+ presence_gamma: float = 0.0 # SAM3 original (gamma=0 = plain BCE, no focal weighting)
+
+ # Box regression loss
+ loss_bbox_weight: float = 5.0 # L1 loss weight
+ loss_giou_weight: float = 2.0 # GIoU loss weight
+
+ # ========== 3D Loss Weights (3D-MOOD style) ==========
+ loss_delta_2d_weight: float = 1.0 # Delta 2D center
+ loss_depth_weight: float = 1.0 # Log depth
+ loss_dim_weight: float = 1.0 # Log dimensions
+ loss_rot_weight: float = 1.0 # 6D rotation
+
+ # ========== Geometry Backend Loss Weights ==========
+ loss_silog_weight: float = 1.0 # SILog depth loss
+ loss_phi_weight: float = 0.1 # Phi angle loss
+ loss_theta_weight: float = 0.1 # Theta angle loss
+ loss_opt_ssi_weight: float = 0.5 # SSI loss weight (UniDepthV2)
+
+ # ========== Normalization ==========
+ normalization: Literal["global", "local", "none"] = "global"
+
+ # ========== Auxiliary Loss ==========
+ aux_loss_weight: float = 1.0 # Weight for auxiliary decoder outputs
+
+ # ========== Mask Loss (optional) ==========
+ loss_mask_weight: float = 0.0 # Set > 0 to enable mask loss
+ loss_dice_weight: float = 0.0 # Set > 0 to enable dice loss
+
+ # ========== 3D Confidence Head ==========
+ # Positive: soft target = quality (iou_3d + depth). Negative: push to 0.
+ # Inference: final_score = 2d_score + conf_3d_inference_weight * 3d_score
+ use_3d_conf: bool = False # Enable 3D confidence head loss
+ loss_3d_conf_weight: float = 20.0 # Weight for 3D confidence loss (same as 2D loss_cls_weight)
+ conf_depth_weight: float = 0.7 # Weight for depth quality in quality target
+ conf_iou_3d_weight: float = 0.3 # Weight for 3D IoU in quality target
+
+ # ========== Ignore Box Negative Loss Suppression ==========
+ # Suppress negative classification loss for predictions that overlap
+ # with ignore-annotated objects (truncated, occluded, etc.).
+ # This aligns training with eval, where such detections are neutral.
+ use_ignore_suppress: bool = False
+ ignore_iou_threshold: float = 0.5 # 2D IoU threshold for suppression
+
+
+class WildDet3DLoss(nn.Module):
+ """Loss function for WildDet3D.
+
+ Combines SAM3-style 2D losses with 3D-MOOD-style 3D losses.
+
+ Loss Components:
+ 1. Classification: IABCEMdetr (IoU-aware BCE with soft targets)
+ 2. 2D Box: L1 + GIoU
+ 3. 3D Box: L1 for (delta_center, log_depth, log_dims, rot_6d)
+ 4. Geometry: SILog depth + phi/theta angles (from geometry backend)
+ """
+
+ def __init__(
+ self,
+ config: WildDet3DLossConfig | None = None,
+ box_coder: Det3DCoder | None = None,
+ ) -> None:
+ """Initialize WildDet3D loss.
+
+ Args:
+ config: Loss configuration
+ box_coder: 3D box encoder/decoder for target encoding
+ """
+ super().__init__()
+ self.config = config or WildDet3DLossConfig()
+ self.box_coder = box_coder or Det3DCoder()
+ self.reg_dims = self.box_coder.reg_dims
+
+ # SAM3's 2D loss classes (directly imported from sam3.train.loss.loss_fns)
+ # weak_loss=False follows SAM3's own training configs — all unmatched
+ # predictions receive negative loss regardless of is_exhaustive.
+ self.cls_loss = IABCEMdetr(
+ pos_weight=self.config.pos_weight,
+ gamma=self.config.gamma,
+ alpha=self.config.alpha,
+ weak_loss=False,
+ use_presence=self.config.use_presence,
+ presence_alpha=self.config.presence_alpha,
+ presence_gamma=self.config.presence_gamma,
+ )
+ self.box_loss = SAM3Boxes()
+
+ # O2M matcher for DAC one-to-many loss
+ if self.config.use_o2m:
+ self.o2m_matcher = BinaryOneToManyMatcher(
+ alpha=self.config.o2m_alpha,
+ threshold=self.config.o2m_threshold,
+ topk=self.config.o2m_topk,
+ )
+ else:
+ self.o2m_matcher = None
+
+ def _compute_ignore_neg_mask(
+ self,
+ pred_boxes: Tensor,
+ ignore_boxes: Tensor,
+ num_ignores: Tensor,
+ threshold: float = 0.5,
+ ) -> Tensor:
+ """Compute mask for predictions overlapping ignore boxes.
+
+ Args:
+ pred_boxes: (B, S, 4) normalized xyxy predicted boxes.
+ ignore_boxes: (B, max_ignore, 4) normalized xyxy ignore boxes.
+ num_ignores: (B,) number of valid ignore boxes per prompt.
+ threshold: 2D IoU threshold above which to suppress.
+
+ Returns:
+ mask: (B, S) float. 1.0 = suppress negative loss, 0.0 = normal.
+ """
+ import torchvision.ops
+
+ B, S, _ = pred_boxes.shape
+ device = pred_boxes.device
+ mask = torch.zeros(B, S, device=device)
+
+ for b in range(B):
+ n_ign = num_ignores[b].item()
+ if n_ign == 0:
+ continue
+ iou = torchvision.ops.box_iou(
+ pred_boxes[b],
+ ignore_boxes[b, :n_ign],
+ ) # (S, n_ign)
+ mask[b] = (iou.max(dim=1).values > threshold).float()
+
+ return mask
+
+ def _build_targets_from_batch(
+ self, batch: "WildDet3DInput"
+ ) -> dict[str, Tensor]:
+ """Build targets dict from WildDet3DInput.
+
+ WildDet3D uses per-category queries with multi-instance targets.
+ The collator produces:
+ - gt_boxes2d: (N_prompts, max_gt, 4) - multiple GTs per query
+ - gt_boxes3d: (N_prompts, max_gt, 12) - multiple GTs per query (if available)
+ - num_gts: (N_prompts,) - number of valid GTs per query (can be > 1)
+
+ We convert this to the packed format expected by loss computation.
+
+ Args:
+ batch: WildDet3DInput containing GT boxes
+
+ Returns:
+ targets dict with:
+ - boxes_xyxy: (N_total, 4) GT boxes in xyxy format (packed)
+ - boxes_3d: (N_total, 12) 3D GT boxes (packed)
+ - num_boxes: (N_prompts,) number of GTs per query
+ - intrinsics: (N_prompts, 3, 3) camera intrinsics per prompt
+ """
+ device = batch.images.device
+ N_prompts = batch.img_ids.shape[0]
+
+ # Extract GT from batch
+ gt_boxes2d = batch.gt_boxes2d # (N_prompts, max_gt, 4) or (N_prompts, 4)
+ gt_boxes3d = batch.gt_boxes3d # (N_prompts, max_gt, 12) or None
+ num_gts = batch.num_gts # (N_prompts,) number of valid GTs per query
+
+ if gt_boxes2d is None:
+ # No GT available
+ return {
+ "boxes_xyxy": torch.zeros(0, 4, device=device),
+ "boxes_3d": torch.zeros(0, 12, device=device),
+ "classes": torch.zeros(0, dtype=torch.long, device=device),
+ "num_boxes": torch.zeros(N_prompts, dtype=torch.long, device=device),
+ "intrinsics": batch.intrinsics[batch.img_ids],
+ }
+
+ # Handle both old (N_prompts, 4) and new (N_prompts, max_gt, 4) formats
+ if gt_boxes2d.dim() == 2:
+ # Old format: (N_prompts, 4) - single GT per prompt
+ boxes_xyxy = gt_boxes2d
+ if num_gts is None:
+ num_gts = torch.ones(N_prompts, dtype=torch.long, device=device)
+
+ if gt_boxes3d is not None and gt_boxes3d.dim() == 2:
+ boxes_3d = gt_boxes3d
+ else:
+ boxes_3d = torch.zeros(N_prompts, 12, device=device)
+ else:
+ # New format: (N_prompts, max_gt, 4) - multi-instance targets
+ # Pack valid boxes into a flat tensor
+ if num_gts is None:
+ # Fallback: assume all boxes are valid
+ num_gts = torch.tensor([gt_boxes2d.shape[1]] * N_prompts, dtype=torch.long, device=device)
+
+ # Pack boxes into (N_total, 4)
+ boxes_list = []
+ boxes_3d_list = []
+ for i in range(N_prompts):
+ n_gt = num_gts[i].item()
+ boxes_list.append(gt_boxes2d[i, :n_gt]) # (n_gt, 4)
+ if gt_boxes3d is not None:
+ boxes_3d_list.append(gt_boxes3d[i, :n_gt]) # (n_gt, 12)
+
+ if boxes_list:
+ boxes_xyxy = torch.cat(boxes_list, dim=0) # (N_total, 4)
+ else:
+ boxes_xyxy = torch.zeros(0, 4, device=device)
+
+ if boxes_3d_list:
+ boxes_3d = torch.cat(boxes_3d_list, dim=0) # (N_total, 12)
+ else:
+ box3d_dim = gt_boxes3d.shape[-1] if gt_boxes3d is not None else 12
+ boxes_3d = torch.zeros(boxes_xyxy.shape[0], box3d_dim, device=device)
+
+ # SAM3 uses binary detection (all targets are class 1)
+ N_total = boxes_xyxy.shape[0]
+ classes = torch.ones(N_total, dtype=torch.long, device=device)
+
+ # Get per-prompt intrinsics
+ intrinsics = batch.intrinsics[batch.img_ids] # (N_prompts, 3, 3)
+
+ # SAM3's IABCEMdetr and Boxes loss classes need additional formats:
+ # - boxes (cxcywh packed) for L1 loss
+ # - boxes_padded (cxcywh padded) for presence keep_loss
+ # - object_ids_padded for presence keep_loss
+ # - is_exhaustive for weak loss masking
+ boxes_cxcywh = self._xyxy_to_cxcywh(boxes_xyxy)
+
+ # Padded format (B, max_N, 4) for presence loss keep_loss computation
+ boxes_padded = _packed_to_padded(boxes_cxcywh, num_gts)
+ max_N = boxes_padded.shape[1]
+
+ # Object IDs: sequential within each prompt's targets
+ object_ids_padded = torch.full(
+ (N_prompts, max_N), -1, dtype=torch.long, device=device
+ )
+ offset = 0
+ for i in range(N_prompts):
+ n = int(num_gts[i].item())
+ if n > 0:
+ object_ids_padded[i, :n] = torch.arange(
+ offset, offset + n, device=device
+ )
+ offset += n
+
+ # is_exhaustive: multi-target queries are exhaustive, single-target are not
+ # query_types: 0=TEXT, 1=VISUAL, 3=VISUAL+LABEL → exhaustive (True)
+ # query_types: 2=GEOMETRY, 4=GEOMETRY+LABEL → not exhaustive (False)
+ if batch.query_types is not None:
+ qt = batch.query_types.to(device)
+ is_exhaustive = (qt == 0) | (qt == 1) | (qt == 3)
+ else:
+ is_exhaustive = torch.ones(N_prompts, dtype=torch.bool, device=device)
+
+ return {
+ "boxes_xyxy": boxes_xyxy,
+ "boxes": boxes_cxcywh,
+ "boxes_padded": boxes_padded,
+ "boxes_3d": boxes_3d,
+ "classes": classes,
+ "num_boxes": num_gts,
+ "intrinsics": intrinsics,
+ "object_ids_padded": object_ids_padded,
+ "is_exhaustive": is_exhaustive,
+ }
+
+ def forward(
+ self,
+ out: "WildDet3DOutput",
+ batch: "WildDet3DInput",
+ ) -> dict[str, Tensor]:
+ """Compute all losses.
+
+ vis4d LossModule interface: expects either Tensor, dict, or namedtuple.
+ We return a dict of tensors, and LossModule will sum them automatically.
+
+ Following SAM3 and GDino3D's design, we compute 2D box L1 loss in normalized
+ cxcywh space and GIoU loss in pixel xyxy space for consistent loss weights.
+
+ Args:
+ out: Model output (WildDet3DOutput dataclass)
+ batch: Input batch (WildDet3DInput dataclass)
+
+ Returns:
+ Dict of loss tensors (vis4d LossModule will sum them)
+ """
+ import time
+ import os
+ import torch
+ _PROFILE_LOSS = os.environ.get("PROFILE_WILDDET3D", "0") == "1"
+ if _PROFILE_LOSS:
+ torch.cuda.synchronize()
+ _loss_start = time.perf_counter()
+ # Unpack model outputs
+ pred_logits = out.pred_logits
+ pred_boxes_2d = out.pred_boxes_2d
+ pred_boxes_3d = out.pred_boxes_3d
+ aux_outputs = out.aux_outputs
+ geom_losses = out.geom_losses
+
+ # Build targets from batch
+ # Get per-prompt intrinsics by indexing into batch intrinsics
+ B_images = batch.images.shape[0]
+ N_prompts = batch.img_ids.shape[0]
+ intrinsics = batch.intrinsics[batch.img_ids] # (N_prompts, 3, 3)
+
+ # Image size from batch
+ image_size = (batch.images.shape[2], batch.images.shape[3]) # (H, W)
+
+ if _PROFILE_LOSS:
+ torch.cuda.synchronize()
+ _t_targets = time.perf_counter()
+
+ targets = self._build_targets_from_batch(batch)
+ losses = {}
+
+ # Normalize targets to [0, 1] range (for matching and computation)
+ normalized_targets = self._normalize_targets(targets)
+
+ if _PROFILE_LOSS:
+ torch.cuda.synchronize()
+ _loss_targets_time = (time.perf_counter() - _t_targets) * 1000
+
+ # Store image_size for pixel coordinate conversion
+ if image_size is None and "image_size" in targets:
+ image_size = targets["image_size"]
+
+ # Get matching indices from SAM3's internal matching
+ # SAM3's forward_grounding computes indices via _compute_matching when find_target is provided
+ # Handle empty batch (N_prompts=0) case - return zero loss with grad
+ if out.indices is None:
+ device = pred_logits.device
+ rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+ print(f"[WildDet3D Loss] Empty batch detected on rank {rank}, returning zero loss")
+
+ # CRITICAL: Must still participate in all_reduce to prevent DDP deadlock
+ # Other ranks may have non-empty batches and will call all_reduce
+ if self.config.normalization == "global" and torch.distributed.is_initialized():
+ dummy_num_boxes = torch.tensor(0.0, device=device)
+ torch.distributed.all_reduce(dummy_num_boxes)
+
+ # Use pred_logits.sum() * 0 to maintain computation graph for DDP
+ zero_loss = pred_logits.sum() * 0
+ return {
+ "loss_cls": zero_loss, # Keep grad for DDP
+ "loss_bbox": zero_loss.clone(),
+ "loss_giou": zero_loss.clone(),
+ }
+
+ batch_idx, src_idx, tgt_idx = out.indices
+
+ # Move indices to the same device as predictions
+ batch_idx = batch_idx.to(pred_logits.device)
+ src_idx = src_idx.to(pred_logits.device)
+ tgt_idx = tgt_idx.to(pred_logits.device) if tgt_idx is not None else None
+
+ indices = (batch_idx, src_idx, tgt_idx)
+
+ # Get number of boxes for normalization
+ num_boxes = self._get_num_boxes(normalized_targets)
+
+ # ========== 2D Losses via SAM3's loss classes (scaled by loss_2d_scale) ==========
+ if _PROFILE_LOSS:
+ torch.cuda.synchronize()
+ _t0 = time.perf_counter()
+
+ # Build SAM3-format outputs dict for loss classes
+ sam3_outputs = {
+ "pred_logits": pred_logits,
+ "pred_boxes_xyxy": pred_boxes_2d,
+ "pred_boxes": out.pred_boxes_2d_cxcywh,
+ }
+ if out.presence_logits is not None:
+ sam3_outputs["presence_logit_dec"] = out.presence_logits
+
+ # Compute ignore negative loss suppression mask
+ if (
+ self.config.use_ignore_suppress
+ and batch.ignore_boxes2d is not None
+ and batch.num_ignores is not None
+ ):
+ normalized_targets["_ignore_boxes2d"] = batch.ignore_boxes2d
+ normalized_targets["_num_ignores"] = batch.num_ignores
+ normalized_targets["ignore_neg_mask"] = (
+ self._compute_ignore_neg_mask(
+ pred_boxes_2d,
+ batch.ignore_boxes2d,
+ batch.num_ignores,
+ threshold=self.config.ignore_iou_threshold,
+ )
+ )
+
+ # Classification + presence via SAM3's IABCEMdetr
+ cls_losses = self.cls_loss.get_loss(
+ sam3_outputs, normalized_targets, indices, num_boxes
+ )
+ losses["loss_cls"] = (
+ self.config.loss_2d_scale * cls_losses["loss_ce"] * self.config.loss_cls_weight
+ )
+ # Metrics from SAM3's IABCEMdetr (not losses, just for wandb logging)
+ if "ce_f1" in cls_losses:
+ losses["metric_ce_f1"] = cls_losses["ce_f1"].detach()
+ # Presence loss (computed inside IABCEMdetr when use_presence=True)
+ presence_val = cls_losses.get("presence_loss")
+ if presence_val is not None and isinstance(presence_val, Tensor):
+ losses["loss_presence"] = (
+ self.config.loss_2d_scale * presence_val
+ * self.config.presence_loss_weight
+ )
+ if "presence_dec_acc" in cls_losses:
+ losses["metric_presence_acc"] = cls_losses["presence_dec_acc"].detach()
+
+ if _PROFILE_LOSS:
+ torch.cuda.synchronize()
+ _loss_cls_time = (time.perf_counter() - _t0) * 1000
+
+ # 2D box losses (L1 + GIoU) via SAM3's Boxes class
+ if _PROFILE_LOSS:
+ torch.cuda.synchronize()
+ _t1 = time.perf_counter()
+
+ box_losses = self.box_loss.get_loss(
+ sam3_outputs, normalized_targets, indices, num_boxes
+ )
+ losses["loss_bbox"] = (
+ self.config.loss_2d_scale * box_losses["loss_bbox"] * self.config.loss_bbox_weight
+ )
+ losses["loss_giou"] = (
+ self.config.loss_2d_scale * box_losses["loss_giou"] * self.config.loss_giou_weight
+ )
+
+ if _PROFILE_LOSS:
+ torch.cuda.synchronize()
+ _loss_2d_box_time = (time.perf_counter() - _t1) * 1000
+
+ # ========== O2M Loss (2D scaled by loss_2d_scale, 3D scaled by loss_3d_scale) ==========
+ if _PROFILE_LOSS:
+ torch.cuda.synchronize()
+ _t_o2m = time.perf_counter()
+ _loss_o2m_time = 0
+
+ # Use real O2M outputs from SAM3 DAC mechanism (not O2O outputs)
+ if self.config.use_o2m and self.o2m_matcher is not None and out.pred_logits_o2m is not None:
+ o2m_losses = self._loss_o2m(
+ pred_logits=out.pred_logits_o2m,
+ pred_boxes_2d=out.pred_boxes_2d_o2m,
+ pred_boxes_2d_cxcywh=out.pred_boxes_2d_cxcywh_o2m,
+ pred_boxes_3d=out.pred_boxes_3d_o2m,
+ targets=normalized_targets,
+ num_boxes=num_boxes,
+ intrinsics=intrinsics,
+ image_size=image_size,
+ pred_conf_3d=out.pred_conf_3d_o2m,
+ )
+ # Apply appropriate scale and loss weights (following SAM3 original)
+ # SAM3 original: loss = loss_value * o2m_weight * loss_weight
+ # We need to apply the individual loss weights, not just o2m_loss_weight
+ o2m_weight_map = {
+ "loss_cls": self.config.loss_cls_weight,
+ "loss_bbox": self.config.loss_bbox_weight,
+ "loss_giou": self.config.loss_giou_weight,
+ "loss_delta_2d": self.config.loss_delta_2d_weight,
+ "loss_depth": self.config.loss_depth_weight,
+ "loss_dim": self.config.loss_dim_weight,
+ "loss_rot": self.config.loss_rot_weight,
+ "loss_3d_cls": self.config.loss_3d_conf_weight,
+ }
+ for key, value in o2m_losses.items():
+ loss_weight = o2m_weight_map.get(key, 1.0)
+ if key in ("loss_delta_2d", "loss_depth", "loss_dim", "loss_rot"):
+ # 3D losses use loss_3d_scale
+ o2m_loss_val = (
+ self.config.loss_3d_scale * value * loss_weight * self.config.o2m_loss_weight
+ )
+ elif key == "loss_3d_cls":
+ # 3D confidence loss: weight * o2m_weight (no extra scale)
+ o2m_loss_val = value * loss_weight * self.config.o2m_loss_weight
+ else:
+ # 2D losses (loss_cls, loss_bbox, loss_giou) use loss_2d_scale
+ o2m_loss_val = (
+ self.config.loss_2d_scale * value * loss_weight * self.config.o2m_loss_weight
+ )
+ # Clip O2M loss to prevent gradient explosion
+ losses[f"o2m_{key}"] = torch.clamp(o2m_loss_val, max=self.config.o2m_loss_clip)
+
+ if _PROFILE_LOSS:
+ torch.cuda.synchronize()
+ _loss_o2m_time = (time.perf_counter() - _t_o2m) * 1000
+
+ # ========== 3D Losses (scaled by loss_3d_scale) ==========
+ if _PROFILE_LOSS:
+ torch.cuda.synchronize()
+ _t2 = time.perf_counter()
+ _loss_3d_time = 0
+
+ if pred_boxes_3d is not None and intrinsics is not None:
+ loss_3d = self._loss_boxes_3d(
+ pred_boxes_2d, pred_boxes_3d, indices, normalized_targets,
+ intrinsics, num_boxes, image_size=image_size
+ )
+ # Apply loss_3d_scale to all 3D losses
+ for key, value in loss_3d.items():
+ losses[key] = self.config.loss_3d_scale * value
+
+ if _PROFILE_LOSS:
+ torch.cuda.synchronize()
+ _loss_3d_time = (time.perf_counter() - _t2) * 1000
+
+ # ========== 3D Confidence Loss (positive samples only) ==========
+ if (self.config.use_3d_conf
+ and out.pred_conf_3d is not None
+ and pred_boxes_3d is not None
+ and intrinsics is not None):
+ loss_3d_cls = self._loss_3d_classification(
+ out.pred_conf_3d, pred_boxes_2d, pred_boxes_3d,
+ indices, normalized_targets, intrinsics, num_boxes, image_size,
+ )
+ losses["loss_3d_cls"] = self.config.loss_3d_conf_weight * loss_3d_cls
+
+ # ========== Geometry Backend Losses (scaled by loss_geom_scale) ==========
+ if _PROFILE_LOSS:
+ torch.cuda.synchronize()
+ _t_geom = time.perf_counter()
+ _loss_geom_time = 0
+
+ if geom_losses is not None:
+ for key, value in geom_losses.items():
+ if key.startswith("metric_"):
+ # Monitoring-only: log raw value, no scaling
+ losses[key] = value.detach()
+ else:
+ weight = getattr(
+ self.config, f"loss_{key}_weight", 1.0
+ )
+ losses[f"loss_{key}"] = (
+ self.config.loss_geom_scale * value * weight
+ )
+
+ if _PROFILE_LOSS:
+ torch.cuda.synchronize()
+ _loss_geom_time = (time.perf_counter() - _t_geom) * 1000
+
+ # ========== Auxiliary Losses (Deep Supervision) ==========
+ if _PROFILE_LOSS:
+ torch.cuda.synchronize()
+ _t3 = time.perf_counter()
+ _loss_aux_time = 0
+
+ _num_aux_layers = 0
+ if aux_outputs is not None:
+ _num_aux_layers = len(aux_outputs)
+ for i, aux_out in enumerate(aux_outputs):
+ aux_losses = self._compute_aux_loss(
+ aux_out, indices, normalized_targets, num_boxes, intrinsics, image_size
+ )
+ for key, value in aux_losses.items():
+ losses[f"d{i}.{key}"] = value * self.config.aux_loss_weight
+
+ if _PROFILE_LOSS:
+ torch.cuda.synchronize()
+ _loss_aux_time = (time.perf_counter() - _t3) * 1000
+ _loss_total_time = (time.perf_counter() - _loss_start) * 1000
+
+ # Print loss timing summary (every N steps via profiler)
+ from wilddet3d.ops.profiler import profiler
+ p = profiler()
+ p.current_step_timings["loss_total"] = _loss_total_time / 1000
+ p.current_step_timings[" loss_targets"] = _loss_targets_time / 1000
+ p.current_step_timings[" loss_cls"] = _loss_cls_time / 1000
+ p.current_step_timings[" loss_2d_box"] = _loss_2d_box_time / 1000
+ p.current_step_timings[" loss_o2m"] = _loss_o2m_time / 1000
+ p.current_step_timings[" loss_3d"] = _loss_3d_time / 1000
+ p.current_step_timings[" loss_geom"] = _loss_geom_time / 1000
+ p.current_step_timings[" loss_aux"] = _loss_aux_time / 1000
+ p.current_step_timings[" loss_aux_layers"] = _num_aux_layers
+
+ # ========== Ensure all losses are tensors ==========
+ # vis4d LossModule expects dict of tensors
+ for k, v in list(losses.items()):
+ if not isinstance(v, Tensor):
+ losses[k] = torch.tensor(v, device=pred_logits.device)
+
+ # vis4d LossModule will sum all losses in the dict automatically
+ return losses
+
+ def _get_num_boxes(self, targets: dict) -> Tensor:
+ """Get number of boxes for loss normalization."""
+ num_boxes = targets["num_boxes"].sum().float()
+
+ if self.config.normalization == "global":
+ # Handle non-distributed case
+ if torch.distributed.is_initialized():
+ torch.distributed.all_reduce(num_boxes)
+ world_size = torch.distributed.get_world_size()
+ num_boxes = torch.clamp(num_boxes / world_size, min=1)
+ else:
+ # Non-distributed: just clamp
+ num_boxes = torch.clamp(num_boxes, min=1)
+ elif self.config.normalization == "local":
+ num_boxes = torch.clamp(num_boxes, min=1)
+ else: # "none"
+ num_boxes = torch.ones_like(num_boxes)
+
+ return num_boxes
+
+ # 2D classification and box losses are now handled by SAM3's
+ # IABCEMdetr (self.cls_loss) and Boxes (self.box_loss) classes.
+
+ def _loss_o2m(
+ self,
+ pred_logits: Tensor, # (B, S, 1)
+ pred_boxes_2d: Tensor, # (B, S, 4) normalized xyxy
+ pred_boxes_2d_cxcywh: Tensor | None, # (B, S, 4) normalized cxcywh
+ pred_boxes_3d: Tensor | None, # (B, S, reg_dims)
+ targets: dict,
+ num_boxes: Tensor,
+ intrinsics: Tensor | None = None, # (B, 3, 3)
+ image_size: tuple[int, int] | None = None,
+ pred_conf_3d: Tensor | None = None, # (B, S, 1) 3D confidence
+ ) -> dict[str, Tensor]:
+ """Compute O2M (One-to-Many) auxiliary loss.
+
+ Uses SAM3's IABCEMdetr and Boxes classes for 2D losses,
+ plus our own 3D loss for matched predictions.
+ """
+ losses = {}
+ device = pred_logits.device
+ B, S = pred_logits.shape[:2]
+
+ # Prepare targets in padded format for O2M matcher
+ num_boxes_per_image = targets.get(
+ "num_boxes",
+ torch.tensor([len(targets["boxes_xyxy"])], device=device),
+ )
+ boxes_padded = targets.get("boxes_padded")
+ if boxes_padded is None:
+ boxes_cxcywh = self._xyxy_to_cxcywh(targets["boxes_xyxy"])
+ boxes_padded = _packed_to_padded(boxes_cxcywh, num_boxes_per_image)
+
+ max_N = boxes_padded.shape[1]
+ target_is_valid_padded = torch.zeros(
+ B, max_N, dtype=torch.bool, device=device
+ )
+ for i in range(B):
+ target_is_valid_padded[i, :num_boxes_per_image[i]] = True
+
+ # O2M matching
+ if pred_boxes_2d_cxcywh is None:
+ pred_boxes_2d_cxcywh = self._xyxy_to_cxcywh(pred_boxes_2d)
+
+ outputs_dict = {
+ "pred_logits": pred_logits,
+ "pred_boxes": pred_boxes_2d_cxcywh,
+ }
+ targets_dict = {
+ "boxes_padded": boxes_padded,
+ "labels": targets["classes"],
+ "num_boxes": num_boxes_per_image,
+ }
+ batch_idx, src_idx, tgt_idx = self.o2m_matcher(
+ outputs_dict,
+ targets_dict,
+ target_is_valid_padded=target_is_valid_padded,
+ )
+
+ if batch_idx.numel() == 0:
+ zero_losses = {
+ "loss_cls": torch.tensor(0.0, device=device),
+ "loss_bbox": torch.tensor(0.0, device=device),
+ "loss_giou": torch.tensor(0.0, device=device),
+ }
+ if pred_boxes_3d is not None and intrinsics is not None:
+ zero_losses.update({
+ "loss_delta_2d": torch.tensor(0.0, device=device),
+ "loss_depth": torch.tensor(0.0, device=device),
+ "loss_dim": torch.tensor(0.0, device=device),
+ "loss_rot": torch.tensor(0.0, device=device),
+ })
+ return zero_losses
+
+ o2m_indices = (batch_idx, src_idx, tgt_idx)
+
+ # Recompute ignore mask for O2M predictions (different pred boxes)
+ if "_ignore_boxes2d" in targets:
+ targets = targets.copy()
+ targets["ignore_neg_mask"] = self._compute_ignore_neg_mask(
+ pred_boxes_2d,
+ targets["_ignore_boxes2d"],
+ targets["_num_ignores"],
+ threshold=self.config.ignore_iou_threshold,
+ )
+
+ # 2D losses via SAM3 classes
+ o2m_outputs = {
+ "pred_logits": pred_logits,
+ "pred_boxes_xyxy": pred_boxes_2d,
+ "pred_boxes": pred_boxes_2d_cxcywh,
+ }
+ cls_losses = self.cls_loss.get_loss(
+ o2m_outputs, targets, o2m_indices, num_boxes
+ )
+ losses["loss_cls"] = cls_losses["loss_ce"]
+
+ box_losses = self.box_loss.get_loss(
+ o2m_outputs, targets, o2m_indices, num_boxes
+ )
+ losses["loss_bbox"] = box_losses["loss_bbox"]
+ losses["loss_giou"] = box_losses["loss_giou"]
+
+ # 3D losses (our own, not in SAM3)
+ if (pred_boxes_3d is not None and intrinsics is not None
+ and "boxes_3d" in targets):
+ loss_3d = self._loss_boxes_3d(
+ pred_boxes_2d=pred_boxes_2d,
+ pred_boxes_3d=pred_boxes_3d,
+ indices=o2m_indices,
+ targets=targets,
+ intrinsics=intrinsics,
+ num_boxes=num_boxes,
+ image_size=image_size,
+ )
+ losses.update(loss_3d)
+
+ # 3D confidence loss (O2M branch)
+ if (self.config.use_3d_conf
+ and pred_conf_3d is not None
+ and pred_boxes_3d is not None
+ and intrinsics is not None):
+ loss_3d_cls = self._loss_3d_classification(
+ pred_conf_3d, pred_boxes_2d, pred_boxes_3d,
+ o2m_indices, targets, intrinsics, num_boxes, image_size,
+ )
+ losses["loss_3d_cls"] = loss_3d_cls
+
+ return losses
+
+ # _loss_boxes_2d replaced by SAM3's Boxes class (self.box_loss).
+
+ def _loss_boxes_3d(
+ self,
+ pred_boxes_2d: Tensor, # (B, S, 4)
+ pred_boxes_3d: Tensor, # (B, S, reg_dims)
+ indices: tuple[Tensor, Tensor, Tensor | None],
+ targets: dict,
+ intrinsics: Tensor,
+ num_boxes: Tensor,
+ image_size: tuple[int, int] | None = None,
+ ) -> dict[str, Tensor]:
+ """Compute 3D box regression losses.
+
+ Args:
+ pred_boxes_2d: Predicted 2D boxes in normalized xyxy [0,1]. Shape (B, S, 4).
+ pred_boxes_3d: Predicted 3D box parameters. Shape (B, S, reg_dims).
+ indices: Matching indices (batch_idx, src_idx, tgt_idx).
+ targets: Target dict containing boxes_3d.
+ intrinsics: Camera intrinsics. Shape (B, 3, 3).
+ num_boxes: Number of matched boxes for normalization.
+ image_size: (H, W) tuple for converting normalized to pixel coords.
+ Required for correct box_coder.encode() which expects pixel coords.
+ """
+ batch_idx, src_idx, tgt_idx = indices
+
+ # Get matched predictions (for loss computation)
+ src_boxes_3d = pred_boxes_3d[(batch_idx, src_idx)]
+
+ # Get matched GT 2D boxes (for box_coder.encode target computation)
+ # IMPORTANT: Use GT 2D boxes, NOT predicted boxes!
+ # This matches GDino3D's design where encode() uses GT 2D boxes to compute
+ # stable targets, while decode() at inference uses predicted 2D boxes.
+ target_boxes_2d = (
+ targets["boxes_xyxy"][tgt_idx] if tgt_idx is not None
+ else targets["boxes_xyxy"]
+ )
+
+ # Get matched GT 3D boxes
+ target_boxes_3d = (
+ targets["boxes_3d"][tgt_idx] if tgt_idx is not None
+ else targets["boxes_3d"]
+ )
+
+ # Get intrinsics for matched samples
+ # Note: intrinsics is (B, 3, 3), need to index by batch_idx
+ # Since box_coder.encode() expects single intrinsics (3, 3),
+ # we need to process each matched box individually
+ if len(batch_idx) == 0:
+ # No matches, return zero losses
+ return {
+ "loss_delta_2d": torch.tensor(0.0, device=pred_boxes_2d.device),
+ "loss_depth": torch.tensor(0.0, device=pred_boxes_2d.device),
+ "loss_dim": torch.tensor(0.0, device=pred_boxes_2d.device),
+ "loss_rot": torch.tensor(0.0, device=pred_boxes_2d.device),
+ }
+
+ target_boxes_3d_encoded_list = []
+ weights_3d_list = []
+
+ # Validate image_size is provided - required for correct box_coder.encode()
+ if image_size is None:
+ raise ValueError(
+ "image_size is required for _loss_boxes_3d. "
+ "box_coder.encode() expects pixel coordinates because "
+ "project_points() returns pixel coords and "
+ "delta_center = projected_3d_center - 2d_box_center (both in pixels)."
+ )
+
+ H, W = image_size
+ factors = target_boxes_2d.new_tensor([W, H, W, H])
+
+ for i in range(len(batch_idx)):
+ single_box_3d = target_boxes_3d[i:i+1]
+
+ # Skip entries with invalid (all-zero) 3D boxes: set weight=0
+ # so they don't contribute to 3D loss. This handles the case
+ # where GT has a valid 2D box but no 3D annotation.
+ if single_box_3d.abs().sum() < 1e-6:
+ reg_dims = pred_boxes_3d.shape[-1]
+ target_boxes_3d_encoded_list.append(
+ torch.zeros(1, reg_dims, device=pred_boxes_3d.device)
+ )
+ weights_3d_list.append(
+ torch.zeros(1, reg_dims, device=pred_boxes_3d.device)
+ )
+ continue
+
+ # Use GT 2D box (normalized xyxy) and convert to pixel
+ single_gt_box_2d = target_boxes_2d[i:i+1]
+ single_gt_box_2d_pixel = single_gt_box_2d * factors
+
+ single_intrinsic = intrinsics[batch_idx[i]] # (3, 3)
+
+ encoded, weights = self.box_coder.encode(
+ single_gt_box_2d_pixel, single_box_3d, single_intrinsic,
+ )
+ target_boxes_3d_encoded_list.append(encoded)
+ weights_3d_list.append(weights)
+
+ target_boxes_3d_encoded = torch.cat(target_boxes_3d_encoded_list, dim=0)
+ weights_3d = torch.cat(weights_3d_list, dim=0)
+
+ losses = {}
+
+ # Delta 2D center loss
+ loss_delta_2d = l1_loss(
+ src_boxes_3d[:, :2],
+ target_boxes_3d_encoded[:, :2],
+ reducer=SumWeightedLoss(
+ weight=weights_3d[:, :2], avg_factor=num_boxes.item()
+ ),
+ )
+ losses["loss_delta_2d"] = loss_delta_2d * self.config.loss_delta_2d_weight
+
+ # Depth loss
+ loss_depth = l1_loss(
+ src_boxes_3d[:, 2],
+ target_boxes_3d_encoded[:, 2],
+ reducer=SumWeightedLoss(
+ weight=weights_3d[:, 2], avg_factor=num_boxes.item()
+ ),
+ )
+ losses["loss_depth"] = loss_depth * self.config.loss_depth_weight
+
+ # Dimension loss
+ loss_dim = l1_loss(
+ src_boxes_3d[:, 3:6],
+ target_boxes_3d_encoded[:, 3:6],
+ reducer=SumWeightedLoss(
+ weight=weights_3d[:, 3:6], avg_factor=num_boxes.item()
+ ),
+ )
+ losses["loss_dim"] = loss_dim * self.config.loss_dim_weight
+
+ # Rotation loss
+ loss_rot = l1_loss(
+ src_boxes_3d[:, 6:],
+ target_boxes_3d_encoded[:, 6:],
+ reducer=SumWeightedLoss(
+ weight=weights_3d[:, 6:], avg_factor=num_boxes.item()
+ ),
+ )
+ losses["loss_rot"] = loss_rot * self.config.loss_rot_weight
+
+ return losses
+
+ def _loss_3d_classification(
+ self,
+ pred_conf_3d: Tensor, # (B, S, 1)
+ pred_boxes_2d: Tensor, # (B, S, 4) normalized xyxy
+ pred_boxes_3d: Tensor, # (B, S, 12) encoded
+ indices: tuple[Tensor, Tensor, Tensor | None],
+ targets: dict,
+ intrinsics: Tensor, # (N_prompts, 3, 3)
+ num_boxes: Tensor,
+ image_size: tuple[int, int],
+ ) -> Tensor:
+ """Compute 3D confidence loss (positive + negative).
+
+ Positive: soft target = quality (0.7*iou_3d + 0.3*depth)
+ Negative: target = 0, with focal weighting
+ Same structure as 2D cls loss (IABCEMdetr).
+
+ At inference: final_score = 2d_score + 0.5 * 3d_score
+ """
+ batch_idx, src_idx, tgt_idx = indices
+ B, S, _ = pred_conf_3d.shape
+ device = pred_conf_3d.device
+ M = len(batch_idx)
+
+ if M == 0:
+ return pred_conf_3d.sum() * 0.0
+
+ prob = pred_conf_3d.sigmoid()
+ target_classes = torch.zeros(B, S, 1, device=device)
+ target_classes[(batch_idx, src_idx)] = 1.0
+
+ with torch.no_grad():
+ # 1. Depth quality - directly from encoded params, no decode needed
+ src_boxes_3d = pred_boxes_3d[(batch_idx, src_idx)]
+ target_boxes_3d_raw = (
+ targets["boxes_3d"][tgt_idx] if tgt_idx is not None
+ else targets["boxes_3d"]
+ )
+ depth_scale = self.box_coder.depth_scale
+ pred_log_z = src_boxes_3d[:, 2] / depth_scale # = log(pred_z)
+ gt_z = target_boxes_3d_raw[:, 2].clamp(min=0.1)
+ gt_log_z = torch.log(gt_z)
+ depth_quality = torch.exp(-torch.abs(pred_log_z - gt_log_z))
+ depth_quality = torch.nan_to_num(depth_quality, nan=0.0, posinf=1.0, neginf=0.0)
+
+ # 2. 3D IoU using safe shapely-based implementation
+ # (CPU, full rotation support, never crashes)
+ from wilddet3d.ops.iou_box3d import batch_box3d_iou
+
+ H, W = image_size
+ factors = pred_boxes_2d.new_tensor([[W, H, W, H]])
+ src_boxes_2d_pixel = pred_boxes_2d[(batch_idx, src_idx)] * factors
+
+ pred_decoded_list = []
+ for i in range(M):
+ single_decoded = self.box_coder.decode(
+ src_boxes_2d_pixel[i:i+1],
+ src_boxes_3d[i:i+1],
+ intrinsics[batch_idx[i]],
+ )
+ pred_decoded_list.append(single_decoded)
+ pred_decoded = torch.cat(pred_decoded_list, dim=0) # (M, 10)
+
+ iou_3d = batch_box3d_iou(pred_decoded, target_boxes_3d_raw[:, :10])
+
+ # 3. Combined quality
+ quality = (
+ self.config.conf_depth_weight * depth_quality
+ + self.config.conf_iou_3d_weight * iou_3d
+ )
+ quality = torch.nan_to_num(quality, nan=0.0).clamp(0.0, 1.0)
+
+ # 4. Build soft target (same as 2D IABCEMdetr pattern)
+ t = (
+ prob[(batch_idx, src_idx)].squeeze(-1) ** self.config.alpha
+ * quality ** (1 - self.config.alpha)
+ )
+ t = t.clamp(min=0.01).detach()
+
+ positive_target = target_classes.clone()
+ positive_target[(batch_idx, src_idx)] = t.unsqueeze(-1)
+
+ # Positive loss with soft quality target
+ loss_pos = F.binary_cross_entropy_with_logits(
+ pred_conf_3d, positive_target, reduction="none"
+ )
+ loss_pos = loss_pos * target_classes * self.config.pos_weight
+
+ # Negative loss with focal weighting (push unmatched queries toward 0)
+ loss_neg = F.binary_cross_entropy_with_logits(
+ pred_conf_3d, target_classes, reduction="none"
+ )
+ loss_neg = loss_neg * (1 - target_classes) * (prob ** self.config.gamma)
+
+ # Suppress negative loss for predictions overlapping ignore boxes
+ if "ignore_neg_mask" in targets:
+ neg_suppress = targets["ignore_neg_mask"].unsqueeze(-1)
+ loss_neg = loss_neg * (1 - neg_suppress)
+
+ loss_bce = loss_pos + loss_neg
+
+ # Apply presence mask (zero out loss for prompts with no GT)
+ if self.config.use_presence:
+ num_gts = targets.get(
+ "num_boxes", torch.zeros(B, dtype=torch.long, device=device)
+ )
+ keep_loss = (num_gts > 0).float().view(B, 1, 1) # (B, 1, 1) for (B, S, 1) broadcasting
+ loss_bce = loss_bce * keep_loss
+
+ return loss_bce.mean()
+
+ def _compute_aux_loss(
+ self,
+ aux_out: dict,
+ indices: tuple[Tensor, Tensor, Tensor | None],
+ targets: dict,
+ num_boxes: Tensor,
+ intrinsics: Tensor | None = None,
+ image_size: tuple[int, int] | None = None,
+ ) -> dict[str, Tensor]:
+ """Compute losses for auxiliary decoder outputs.
+
+ Following GDino3D's design, we compute all losses (2D + 3D) for auxiliary outputs
+ to enable full deep supervision across all decoder layers.
+
+ Args:
+ aux_out: Auxiliary output dictionary containing pred_logits, pred_boxes_2d, pred_boxes_3d
+ indices: Matching indices from matcher
+ targets: Ground truth targets
+ num_boxes: Number of boxes for normalization
+ intrinsics: Camera intrinsics for 3D loss computation
+ image_size: (H, W) for pixel coordinate conversion
+
+ Returns:
+ Dictionary of auxiliary losses
+ """
+ losses = {}
+
+ # Build SAM3-format outputs for aux layer
+ sam3_aux = {
+ "pred_logits": aux_out.get("pred_logits"),
+ "pred_boxes_xyxy": aux_out.get(
+ "pred_boxes_xyxy", aux_out.get("pred_boxes_2d")
+ ),
+ "pred_boxes": aux_out.get("pred_boxes"),
+ }
+ # If pred_boxes (cxcywh) not available, convert from xyxy
+ if sam3_aux["pred_boxes"] is None and sam3_aux["pred_boxes_xyxy"] is not None:
+ sam3_aux["pred_boxes"] = self._xyxy_to_cxcywh(
+ sam3_aux["pred_boxes_xyxy"]
+ )
+
+ # Recompute ignore mask for this aux layer's predicted boxes
+ if "_ignore_boxes2d" in targets and sam3_aux["pred_boxes_xyxy"] is not None:
+ targets = targets.copy()
+ targets["ignore_neg_mask"] = self._compute_ignore_neg_mask(
+ sam3_aux["pred_boxes_xyxy"],
+ targets["_ignore_boxes2d"],
+ targets["_num_ignores"],
+ threshold=self.config.ignore_iou_threshold,
+ )
+
+ # Classification loss via SAM3's IABCEMdetr (scaled by loss_2d_scale)
+ if sam3_aux["pred_logits"] is not None:
+ cls_losses = self.cls_loss.get_loss(
+ sam3_aux, targets, indices, num_boxes
+ )
+ losses["loss_cls"] = (
+ self.config.loss_2d_scale
+ * cls_losses["loss_ce"]
+ * self.config.loss_cls_weight
+ )
+
+ # 2D box losses via SAM3's Boxes class (scaled by loss_2d_scale)
+ if sam3_aux["pred_boxes"] is not None:
+ box_losses = self.box_loss.get_loss(
+ sam3_aux, targets, indices, num_boxes
+ )
+ losses["loss_bbox"] = (
+ self.config.loss_2d_scale
+ * box_losses["loss_bbox"]
+ * self.config.loss_bbox_weight
+ )
+ losses["loss_giou"] = (
+ self.config.loss_2d_scale
+ * box_losses["loss_giou"]
+ * self.config.loss_giou_weight
+ )
+
+ # 3D box loss (our own, scaled by loss_3d_scale)
+ pred_boxes_2d_aux = aux_out.get(
+ "pred_boxes_2d", aux_out.get("pred_boxes_xyxy")
+ )
+ if "pred_boxes_3d" in aux_out and intrinsics is not None:
+ loss_3d = self._loss_boxes_3d(
+ pred_boxes_2d_aux,
+ aux_out["pred_boxes_3d"],
+ indices,
+ targets,
+ intrinsics,
+ num_boxes,
+ image_size=image_size,
+ )
+ for key, value in loss_3d.items():
+ losses[key] = self.config.loss_3d_scale * value
+
+ # 3D confidence loss (deep supervision)
+ if (self.config.use_3d_conf
+ and "pred_conf_3d" in aux_out
+ and "pred_boxes_3d" in aux_out
+ and intrinsics is not None):
+ loss_3d_cls = self._loss_3d_classification(
+ aux_out["pred_conf_3d"],
+ pred_boxes_2d_aux,
+ aux_out["pred_boxes_3d"],
+ indices, targets, intrinsics, num_boxes, image_size,
+ )
+ losses["loss_3d_cls"] = self.config.loss_3d_conf_weight * loss_3d_cls
+
+ return losses
+
+ def _normalize_targets(self, targets: dict) -> dict:
+ """Ensure targets are in expected format for loss computation.
+
+ Note: WildDet3D collator always outputs GT boxes in normalized [0, 1] xyxy format.
+ This function simply ensures the classes tensor exists (for binary classification).
+
+ Args:
+ targets: Dictionary containing ground truth data
+ - boxes_xyxy: (N, 4) boxes in normalized xyxy [0, 1] format
+ - classes: (N,) class labels (all ones for SAM3)
+ - num_boxes: (N,) number of boxes per prompt (always 1)
+ - boxes_3d: (N, 12) 3D boxes (optional)
+
+ Returns:
+ Targets dict with classes tensor guaranteed to exist
+ """
+ normalized = targets.copy()
+ boxes_xyxy = targets["boxes_xyxy"]
+
+ # Ensure classes tensor exists (all ones for binary classification)
+ if "classes" not in normalized:
+ num_boxes = boxes_xyxy.shape[0]
+ normalized["classes"] = torch.ones(
+ num_boxes, dtype=torch.long, device=boxes_xyxy.device
+ )
+
+ return normalized
+
+ def _xyxy_to_cxcywh(self, boxes_xyxy: Tensor) -> Tensor:
+ """Convert boxes from xyxy to cxcywh format.
+
+ Args:
+ boxes_xyxy: (N, 4) boxes in xyxy format
+
+ Returns:
+ boxes_cxcywh: (N, 4) boxes in cxcywh format
+ """
+ x1, y1, x2, y2 = boxes_xyxy.unbind(-1)
+ cx = (x1 + x2) / 2
+ cy = (y1 + y2) / 2
+ w = x2 - x1
+ h = y2 - y1
+ return torch.stack([cx, cy, w, h], dim=-1)
+
diff --git a/wilddet3d/model.py b/wilddet3d/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..929b7444dd927b423b6d0d91b10c95e5440d98d1
--- /dev/null
+++ b/wilddet3d/model.py
@@ -0,0 +1,1647 @@
+"""WildDet3D: SAM3 with 3D Detection Head.
+
+This module combines SAM3 (2D detection with geometric prompting) with
+3D detection head and geometry backend.
+
+Key Design Decisions (from Design Doc):
+1. Coordinate format: SAM3 uses normalized cxcywh internally,
+ model outputs normalized xyxy [0, 1]
+2. Tensor format: SAM3 Decoder outputs sequence-first (L, S, B, C),
+ 3D Head expects batch-first (L, B, S, C) -> need permute
+3. Batch strategy: per-prompt batch with img_ids indexing
+4. bbox_head: Reuse SAM3 Decoder's internal bbox_embed,
+ no external bbox_head needed
+5. Forward: Reuse SAM3's forward_grounding() method for 2D detection,
+ then add 3D head on top
+
+Data Flow:
+1. DataLoader produces per-image data
+2. Collator expands to per-prompt batch (WildDet3DInput)
+3. Model forward receives expanded data, calls SAM3's forward_grounding
+4. 3D head processes SAM3 output
+"""
+
+from __future__ import annotations
+
+from typing import List
+
+import torch
+from torch import Tensor, nn
+from torchvision.ops import nms, batched_nms, box_iou
+
+from wilddet3d.ops.profiler import profile_start, profile_stop, profile_step
+
+# SAM3 imports
+from sam3.model.sam3_image import Sam3Image
+from sam3.model.geometry_encoders import Prompt
+from sam3.model.box_ops import box_cxcywh_to_xyxy
+from sam3.model.data_misc import FindStage, BatchedFindTarget
+
+# 3D detection imports
+from wilddet3d.head import (
+ Det3DHead,
+ Det3DCoder,
+ RoI2Det3D,
+)
+from wilddet3d.data_types import Det3DOut, WildDet3DOut, WildDet3DInput
+from wilddet3d.depth import GeometryBackendBase
+
+
+class Fp32LayerNorm(nn.LayerNorm):
+ """LayerNorm that always computes in fp32.
+
+ In mixed-precision training (bf16/fp16), standard LayerNorm can overflow
+ because the variance computation involves squaring values. bf16 max is
+ ~65504, so values > ~256 squared will overflow.
+
+ This wrapper casts input to fp32, runs LayerNorm, then casts back.
+ The overhead is negligible since LayerNorm is memory-bound.
+ """
+
+ def forward(self, x: Tensor) -> Tensor:
+ orig_dtype = x.dtype
+ x = x.float()
+ x = super().forward(x)
+ return x.to(orig_dtype)
+
+
+def _upgrade_layernorms_to_fp32(module: nn.Module) -> int:
+ """Replace all nn.LayerNorm in a module tree with Fp32LayerNorm.
+
+ Walks the module tree and swaps each nn.LayerNorm with an Fp32LayerNorm
+ that shares the same weight and bias tensors (no copy, no extra memory).
+
+ Args:
+ module: Root module to patch.
+
+ Returns:
+ Number of LayerNorm modules replaced.
+ """
+ count = 0
+ for name, child in module.named_children():
+ if isinstance(child, nn.LayerNorm) and not isinstance(child, Fp32LayerNorm):
+ fp32_ln = Fp32LayerNorm(
+ child.normalized_shape,
+ eps=child.eps,
+ elementwise_affine=child.elementwise_affine,
+ )
+ # Share weight/bias tensors (no copy)
+ fp32_ln.weight = child.weight
+ fp32_ln.bias = child.bias
+ setattr(module, name, fp32_ln)
+ count += 1
+ else:
+ count += _upgrade_layernorms_to_fp32(child)
+ return count
+
+
+class WildDet3D(nn.Module):
+ """SAM3 with 3D Detection Head.
+
+ This model combines:
+ 1. SAM3's backbone, encoder, decoder (for 2D detection with geometric prompting)
+ 2. Geometry backend (depth estimation)
+ 3. 3D head (3D box regression)
+
+ Architecture:
+ ```
+ Image + Prompts
+ |
+ v
+ +------------------------------------------+
+ | SAM3 (backbone + encoder + decoder) |
+ | - ViT backbone with SimpleFPN |
+ | - Geometry Encoder for prompts |
+ | - Transformer Encoder/Decoder |
+ | - Internal bbox_embed for 2D boxes |
+ +-------------------+----------------------+
+ | hidden_states, pred_boxes (cxcywh)
+ |
+ +-------+-------+
+ v v
+ +-----------+ +---------------+
+ | cxcywh | | Geometry |
+ | -> xyxy | | Backend |
+ +-----+-----+ | (depth) |
+ | +-------+-------+
+ | | depth_latents
+ v v
+ +-------------------------------+
+ | 3D Head |
+ | (depth + ray cross-attention)|
+ +---------------+---------------+
+ |
+ v
+ pred_boxes_3d
+ ```
+ """
+
+ def __init__(
+ self,
+ # ========== SAM3 Components ==========
+ sam3_model: Sam3Image | None = None,
+ sam3_checkpoint: str | None = None,
+
+ # ========== 3D Components ==========
+ bbox3d_head: Det3DHead | None = None,
+ box_coder: Det3DCoder | None = None,
+ geometry_backend: GeometryBackendBase | None = None,
+ roi2det3d: RoI2Det3D | None = None,
+
+ # ========== Depth-Memory Fusion ==========
+ early_depth_fusion: nn.Module | None = None,
+
+ # ========== Freeze Settings ==========
+ backbone_freeze_blocks: int = 0,
+
+ # ========== Oracle Evaluation ==========
+ oracle_eval: bool = False,
+
+ # ========== Depth Input at Test Time ==========
+ use_depth_input_test: bool = False,
+
+ # ========== Predicted Intrinsics ==========
+ use_predicted_intrinsics: bool = False,
+
+ # ========== Eval Score Control ==========
+ eval_3d_conf_weight: float = 0.5,
+ use_presence_score: bool = True,
+ ) -> None:
+ """Initialize WildDet3D.
+
+ Args:
+ sam3_model: Complete SAM3 model (backbone + encoder + decoder).
+ If None, will be built from sam3_checkpoint.
+ sam3_checkpoint: Path to SAM3 checkpoint. Only used if sam3_model is None.
+ bbox3d_head: 3D box regression head. If None, creates default.
+ box_coder: 3D box encoder/decoder. If None, creates default.
+ geometry_backend: Depth estimation backend. If None, no depth.
+ roi2det3d: Inference post-processor. If None, creates default.
+ early_depth_fusion: Early fusion module (after backbone, before encoder).
+ If None, no early fusion is performed.
+ backbone_freeze_blocks: Number of SAM3 ViT backbone blocks to
+ freeze (from the beginning). SAM3 has 32 blocks; e.g. 30
+ freezes blocks[0..29], only training the last 2.
+ 0 means no freezing.
+ oracle_eval: If True, use oracle evaluation mode where each
+ prompt gets top-1 prediction (no NMS, no score filtering).
+ For measuring 3D regression quality with GT box prompts.
+ use_predicted_intrinsics: If True, use geometry backend's
+ predicted intrinsics (K_pred) for 3D box decoding at test
+ time instead of batch.intrinsics (dataset/default).
+ Useful for in-the-wild images without GT intrinsics.
+ Can be overridden by env var SAM3_USE_PRED_K=1/0.
+ eval_3d_conf_weight: Weight for 3D confidence in eval score.
+ final_score = 2d_score + weight * 3d_score.
+ Set to 0.0 to use only 2D confidence for eval.
+ """
+ super().__init__()
+
+ # SAM3 model - build if not provided
+ if sam3_model is None:
+ import os
+ from sam3.model_builder import build_sam3_image_model
+
+ # Check if torch.compile should be enabled for SAM3
+ use_compile = os.environ.get("SAM3_COMPILE", "0") == "1"
+ if use_compile:
+ print("[WildDet3D] torch.compile ENABLED for SAM3 backbone (SAM3_COMPILE=1)")
+ else:
+ print("[WildDet3D] torch.compile disabled (set SAM3_COMPILE=1 to enable)")
+
+ print(f"Building SAM3 model from checkpoint: {sam3_checkpoint}")
+ sam3_model = build_sam3_image_model(
+ checkpoint_path=sam3_checkpoint,
+ load_from_HF=(sam3_checkpoint is None), # Only load from HF if no checkpoint provided
+ device="cpu", # Will be moved to correct device later
+ eval_mode=False, # Must be False to enable matcher for training
+ enable_segmentation=False, # Skip seg head for 3D detection (saves ~4GB memory)
+ compile=use_compile, # Enable torch.compile for backbone
+ )
+ # Store checkpoint path for logging in on_load_checkpoint
+ self._sam3_checkpoint_path = sam3_checkpoint
+ else:
+ self._sam3_checkpoint_path = "provided_model"
+
+ self.sam3 = sam3_model
+ self.hidden_dim = sam3_model.hidden_dim
+ self.oracle_eval = oracle_eval
+ self.use_depth_input_test = use_depth_input_test
+ self.use_predicted_intrinsics = use_predicted_intrinsics
+ self.eval_3d_conf_weight = eval_3d_conf_weight
+ self.use_presence_score = use_presence_score
+ print(f"[WildDet3D] use_presence_score={self.use_presence_score}")
+
+ # 3D components
+ self.box_coder = box_coder or Det3DCoder()
+ self.geometry_backend = geometry_backend
+ self.roi2det3d = roi2det3d
+ self.early_depth_fusion = early_depth_fusion
+
+ # Determine use_camera_prompt based on geometry_backend.is_ray_aware
+ # Ray-aware backends already fuse ray info into depth_latents,
+ # so we don't need the separate ray_embeddings (camera prompt) branch.
+ if self.geometry_backend is not None and hasattr(self.geometry_backend, 'is_ray_aware'):
+ use_camera_prompt = not self.geometry_backend.is_ray_aware
+ print(f"[WildDet3D] geometry_backend.is_ray_aware={self.geometry_backend.is_ray_aware}, use_camera_prompt={use_camera_prompt}")
+ else:
+ use_camera_prompt = True # Default to True for safety
+ print(f"[WildDet3D] No geometry_backend or is_ray_aware attr, defaulting use_camera_prompt=True")
+
+ # Get depth_latent_dim from geometry_backend (for 3D head)
+ if self.geometry_backend is not None and hasattr(self.geometry_backend, 'target_latent_dim'):
+ depth_latent_dim = self.geometry_backend.target_latent_dim
+ else:
+ depth_latent_dim = 256 # Default
+
+ # Create or validate bbox3d_head with correct use_camera_prompt setting
+ if bbox3d_head is not None:
+ self.bbox3d_head = bbox3d_head
+ # Warn if provided head has mismatched use_camera_prompt
+ if hasattr(bbox3d_head, 'use_camera_prompt') and bbox3d_head.use_camera_prompt != use_camera_prompt:
+ print(f"[WildDet3D] Warning: bbox3d_head.use_camera_prompt={bbox3d_head.use_camera_prompt} "
+ f"but geometry_backend suggests use_camera_prompt={use_camera_prompt}")
+ else:
+ self.bbox3d_head = Det3DHead(
+ embed_dims=self.hidden_dim,
+ box_coder=self.box_coder,
+ use_camera_prompt=use_camera_prompt,
+ depth_latent_dim=depth_latent_dim,
+ )
+ print(f"[WildDet3D] Created bbox3d_head with use_camera_prompt={use_camera_prompt}, depth_latent_dim={depth_latent_dim}")
+
+ # 3D conf_branches use xavier init (from _init_weights in head.py).
+ # No warm start from class_embed: the positive-only loss design
+ # (quality targets ~0.1-0.3 early) conflicts with class_embed's
+ # high-logit initialization, causing large initial loss.
+
+ # Load geometry backend pretrained weights
+ # This is called during __init__ to ensure weights are loaded for first training
+ # (on_load_checkpoint is only called when resuming from checkpoint)
+ if self.geometry_backend is not None and hasattr(self.geometry_backend, 'load_pretrained_weights'):
+ print("[WildDet3D] Loading geometry backend pretrained weights...")
+ self.geometry_backend.load_pretrained_weights()
+
+ # Ensure SAM3 has a matcher for training
+ # SAM3 built with eval_mode=True doesn't have a matcher, so we create one
+ # Using BinaryHungarianMatcherV2 with focal=True to match SAM3 original config
+ if self.sam3.matcher is None:
+ from sam3.train.matcher import BinaryHungarianMatcherV2
+ print("[WildDet3D] Creating BinaryHungarianMatcherV2 for training...")
+ self.sam3.matcher = BinaryHungarianMatcherV2(
+ cost_class=2.0, # SAM3 original
+ cost_bbox=5.0, # SAM3 original
+ cost_giou=2.0, # SAM3 original
+ focal=True, # SAM3 original
+ alpha=0.25, # SAM3 original
+ gamma=2.0, # SAM3 original
+ )
+
+ # Freeze SAM3 ViT backbone blocks (like lingbot encoder_freeze_blocks)
+ # SAM3 ViT has 32 blocks at sam3.backbone.vision_backbone.trunk.blocks
+ if backbone_freeze_blocks > 0:
+ trunk = self.sam3.backbone.vision_backbone.trunk
+ num_blocks = len(trunk.blocks)
+ backbone_freeze_blocks = min(backbone_freeze_blocks, num_blocks)
+
+ # Freeze patch_embed + ln_pre + first N blocks
+ for p in trunk.patch_embed.parameters():
+ p.requires_grad = False
+ for p in trunk.ln_pre.parameters():
+ p.requires_grad = False
+ for i in range(backbone_freeze_blocks):
+ for p in trunk.blocks[i].parameters():
+ p.requires_grad = False
+
+ frozen_params = sum(
+ p.numel() for p in trunk.parameters() if not p.requires_grad
+ )
+ total_params = sum(p.numel() for p in trunk.parameters())
+ print(
+ f"[WildDet3D] Backbone freeze: {backbone_freeze_blocks}/{num_blocks}"
+ f" blocks frozen ({frozen_params/1e6:.1f}M/{total_params/1e6:.1f}M params)"
+ )
+
+ # Upgrade ALL LayerNorm in the entire model to fp32.
+ # In bf16 mixed-precision, LayerNorm's variance computation can
+ # overflow (bf16 max ~65504). This covers sam3 (transformer decoder,
+ # backbone, encoder), geometry_backend (DINOv2 encoder, intrinsic
+ # head), early_depth_fusion (depth_norm), and bbox3d_head.
+ # Negligible performance cost -- LayerNorm is memory-bound.
+ n_replaced = _upgrade_layernorms_to_fp32(self)
+ print(f"[WildDet3D] Upgraded {n_replaced} LayerNorm -> Fp32LayerNorm (entire model)")
+
+ def _xyxy_to_cxcywh(self, boxes: Tensor) -> Tensor:
+ """Convert boxes from xyxy to cxcywh format.
+
+ Args:
+ boxes: Tensor of shape (..., 4) in xyxy format
+
+ Returns:
+ Tensor of shape (..., 4) in cxcywh format
+ """
+ x1, y1, x2, y2 = boxes.unbind(-1)
+ cx = (x1 + x2) / 2
+ cy = (y1 + y2) / 2
+ w = x2 - x1
+ h = y2 - y1
+ return torch.stack([cx, cy, w, h], dim=-1)
+
+ def _build_find_target(self, batch: WildDet3DInput) -> BatchedFindTarget:
+ """Convert WildDet3DInput GT to SAM3's BatchedFindTarget format.
+
+ This is used for SAM3's internal matching during training.
+
+ SAM3 expects:
+ - boxes: (N_total, 4) packed cxcywh normalized
+ - boxes_padded: (N_prompts, max_gt, 4) padded cxcywh
+ - num_boxes: (N_prompts,) number of GT per prompt
+ - is_exhaustive: (N_prompts,) bool
+
+ Note: In WildDet3D, each prompt corresponds to exactly one GT box,
+ so gt_boxes2d has shape (N_prompts, 4) not (N_prompts, max_gt, 4).
+
+ Args:
+ batch: WildDet3DInput with gt_boxes2d in normalized xyxy
+
+ Returns:
+ BatchedFindTarget for SAM3's _compute_matching
+ """
+ device = batch.gt_boxes2d.device
+ gt_boxes_xyxy = batch.gt_boxes2d
+
+ # Handle different input shapes
+ # Case 1: (N_prompts, 4) - one GT per prompt (WildDet3D design)
+ # Case 2: (N_prompts, max_gt, 4) - multiple GTs per prompt (general case)
+ if gt_boxes_xyxy.dim() == 2:
+ # Shape: (N_prompts, 4) - one GT per prompt
+ N_prompts = gt_boxes_xyxy.shape[0]
+ max_gt = 1
+
+ # Convert xyxy -> cxcywh
+ gt_boxes_cxcywh = self._xyxy_to_cxcywh(gt_boxes_xyxy) # (N_prompts, 4)
+
+ # Each prompt has exactly 1 GT box
+ num_boxes = torch.ones(N_prompts, dtype=torch.long, device=device)
+
+ # Packed boxes = all boxes (no padding)
+ boxes_packed = gt_boxes_cxcywh # (N_prompts, 4)
+
+ # Padded format: add max_gt dimension
+ gt_boxes_cxcywh_padded = gt_boxes_cxcywh.unsqueeze(1) # (N_prompts, 1, 4)
+
+ # Object IDs: sequential
+ object_ids = torch.arange(N_prompts, device=device)
+ object_ids_padded = torch.arange(N_prompts, device=device).unsqueeze(1) # (N_prompts, 1)
+
+ else:
+ # Shape: (N_prompts, max_gt, 4) - multiple GTs per prompt
+ N_prompts = gt_boxes_xyxy.shape[0]
+ max_gt = gt_boxes_xyxy.shape[1]
+
+ # Convert xyxy -> cxcywh
+ gt_boxes_cxcywh = self._xyxy_to_cxcywh(gt_boxes_xyxy)
+
+ # Compute num_boxes per prompt (count non-zero boxes)
+ valid_mask = (gt_boxes_xyxy.abs().sum(dim=-1) > 1e-6) # (N_prompts, max_gt)
+ num_boxes = valid_mask.sum(dim=-1) # (N_prompts,)
+
+ # Pack boxes (remove padding)
+ boxes_list = []
+ for i in range(N_prompts):
+ n = int(num_boxes[i].item())
+ if n > 0:
+ boxes_list.append(gt_boxes_cxcywh[i, :n])
+ if boxes_list:
+ boxes_packed = torch.cat(boxes_list, dim=0) # (N_total, 4)
+ else:
+ boxes_packed = torch.zeros(0, 4, device=device)
+
+ gt_boxes_cxcywh_padded = gt_boxes_cxcywh
+
+ # Object IDs (placeholder - just sequential)
+ object_ids = torch.arange(len(boxes_packed), device=device)
+ object_ids_padded = torch.full(
+ (N_prompts, max_gt), -1, device=device, dtype=torch.long
+ )
+ offset = 0
+ for i in range(N_prompts):
+ n = int(num_boxes[i].item())
+ if n > 0:
+ object_ids_padded[i, :n] = torch.arange(
+ offset, offset + n, device=device
+ )
+ offset += n
+
+ return BatchedFindTarget(
+ num_boxes=num_boxes,
+ boxes=boxes_packed,
+ boxes_padded=gt_boxes_cxcywh_padded,
+ repeated_boxes=None,
+ segments=None,
+ semantic_segments=None,
+ is_valid_segment=None,
+ # is_exhaustive: controls negative loss masking in SAM3's IABCEMdetr.
+ # Multi-target queries (TEXT=0, VISUAL=1, VISUAL+LABEL=3) are exhaustive:
+ # all instances of the category are annotated as targets.
+ # Single-target queries (GEOMETRY=2, GEOMETRY+LABEL=4) are NOT exhaustive:
+ # only 1 selected instance is the target, other instances of the
+ # same category exist but are not annotated for this query.
+ is_exhaustive=self._get_is_exhaustive(batch, N_prompts, device),
+ object_ids=object_ids,
+ object_ids_padded=object_ids_padded,
+ )
+
+ def _get_is_exhaustive(
+ self,
+ batch: WildDet3DInput,
+ N_prompts: int,
+ device: torch.device,
+ ) -> Tensor:
+ """Determine is_exhaustive per query based on query_types.
+
+ Multi-target queries (TEXT=0, VISUAL=1, VISUAL+LABEL=3) are exhaustive:
+ all instances of the category are annotated as targets, so unmatched
+ predictions should receive negative loss.
+
+ Single-target queries (GEOMETRY=2, GEOMETRY+LABEL=4) are NOT exhaustive:
+ only 1 selected instance is the target. Other instances of the same
+ category exist but are not annotated for this query, so unmatched
+ predictions should NOT receive negative loss.
+ """
+ if batch.query_types is not None:
+ qt = batch.query_types.to(device)
+ return (qt == 0) | (qt == 1) | (qt == 3)
+ return torch.ones(N_prompts, dtype=torch.bool, device=device)
+
+ def on_load_checkpoint(self, checkpoint):
+ """
+ PyTorch Lightning hook called when loading a checkpoint.
+
+ This is called BEFORE load_state_dict, so we can:
+ 1. Load SAM3 pretrained weights first (if first training)
+ 2. Load geometry backend pretrained weights first (if first training)
+ 3. Filter out incompatible keys from the checkpoint
+ 4. Let PyTorch Lightning load the filtered checkpoint
+ """
+ print("\n" + "="*80)
+ print("WildDet3D CHECKPOINT LOADING (PyTorch Lightning Hook)")
+ print("="*80)
+
+ # Get the state_dict from checkpoint
+ state_dict = checkpoint.get('state_dict', {})
+
+ # Analyze checkpoint content
+ has_sam3 = any('sam3.' in key for key in state_dict.keys())
+ has_geometry_backend = any('geometry_backend' in key for key in state_dict.keys())
+ has_bbox3d_head = any('bbox3d_head' in key for key in state_dict.keys())
+
+ # Determine if this is resume training or first training
+ is_resume = has_sam3 and has_geometry_backend
+
+ if is_resume:
+ # Resume training: load everything from checkpoint
+ print("\nMode: Resume Training")
+ print("Loading complete checkpoint (all components)")
+ print(f" Resuming from epoch {checkpoint.get('epoch', 'unknown')}")
+ print(f" Resuming from global_step {checkpoint.get('global_step', 'unknown')}")
+
+ else:
+ # First training: load pretrained weights
+ print("\nMode: First Training (Fine-tuning)")
+
+ # Step 1: Load SAM3 pretrained weights (if not already loaded in __init__)
+ if not has_sam3 and self.sam3 is not None:
+ print("\n[Step 1/3] SAM3 weights already loaded in __init__")
+ print(f" SAM3 checkpoint: {getattr(self, '_sam3_checkpoint_path', 'unknown')}")
+
+ # Step 2: Load geometry backend pretrained weights
+ if self.geometry_backend is not None and hasattr(self.geometry_backend, 'load_pretrained_weights'):
+ print("\n[Step 2/3] Loading geometry backend pretrained weights...")
+ self.geometry_backend.load_pretrained_weights()
+
+ # Step 3: Filter checkpoint if needed
+ print("\n[Step 3/3] Processing checkpoint...")
+ if not has_sam3:
+ print(" No SAM3 weights in checkpoint (will use pretrained SAM3)")
+ if not has_geometry_backend:
+ print(" No geometry_backend weights in checkpoint (will use pretrained)")
+ if not has_bbox3d_head:
+ print(" No bbox3d_head weights in checkpoint (will initialize randomly)")
+
+ # Step 4: Reset training state (epoch, step, optimizer)
+ print("\n[Step 4/4] Resetting training state for fine-tuning...")
+ if 'epoch' in checkpoint:
+ old_epoch = checkpoint['epoch']
+ checkpoint['epoch'] = 0
+ print(f" Reset epoch: {old_epoch} -> 0")
+
+ if 'global_step' in checkpoint:
+ old_step = checkpoint['global_step']
+ checkpoint['global_step'] = 0
+ print(f" Reset global_step: {old_step} -> 0")
+
+ # Remove optimizer states (they won't match our new optimizer config)
+ if 'optimizer_states' in checkpoint:
+ del checkpoint['optimizer_states']
+ print(f" Removed optimizer_states (will initialize fresh)")
+
+ # Remove lr_scheduler states
+ if 'lr_schedulers' in checkpoint:
+ del checkpoint['lr_schedulers']
+ print(f" Removed lr_schedulers (will initialize fresh)")
+
+ # Store resume status for later use
+ self._is_resume_training = is_resume
+
+ print("\n" + "="*80)
+ print("Checkpoint loading hook completed")
+ print("="*80 + "\n")
+
+ def forward(
+ self,
+ batch: WildDet3DInput,
+ targets: dict | None = None,
+ ) -> WildDet3DOut:
+ """Forward pass of WildDet3D using SAM3's forward_grounding.
+
+ This method reuses SAM3's complete 2D detection pipeline and adds
+ 3D detection on top.
+
+ Args:
+ batch: WildDet3DInput containing:
+ - images: (B_images, 3, H, W)
+ - intrinsics: (B_images, 3, 3)
+ - img_ids: (N_prompts,) - which image each prompt belongs to
+ - text_ids: (N_prompts,) - text index per prompt
+ - unique_texts: List[str] - all unique texts
+ - geo_boxes: (N_prompts, max_K, 4) - normalized cxcywh
+ - geo_boxes_mask: (N_prompts, max_K) - True=padding
+ - geo_box_labels: (N_prompts, max_K) - 0/1 for neg/pos
+ targets: Training targets (optional)
+
+ Returns:
+ WildDet3DOut with 2D and 3D predictions
+ """
+ B_images = batch.images.shape[0]
+ N_prompts = len(batch.img_ids)
+ _, _, H, W = batch.images.shape
+ device = batch.images.device
+
+ profile_start("forward_total")
+
+ # Sync SAM3 training mode with parent module
+ # This is important because SAM3's forward_grounding only computes
+ # matching indices when self.training is True
+ if self.sam3.training != self.training:
+ self.sam3.train(self.training)
+
+ # Handle empty batch (no prompts)
+ if N_prompts == 0:
+ if self.training:
+ # Create dummy output connected to ALL model parameters for DDP backward
+ # DDP requires all parameters to participate in backward across all ranks
+ # Using only one parameter causes deadlock when other ranks use all params
+ rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+ print(f"[WildDet3D] Empty batch (N_prompts=0) on rank {rank}, using all-param dummy")
+ dummy_grad = sum(p.sum() * 0 for p in self.parameters() if p.requires_grad)
+ dummy_logits = torch.zeros(1, 1, 1, device=device) + dummy_grad
+ return WildDet3DOut(
+ pred_logits=dummy_logits,
+ pred_boxes_2d=torch.zeros(1, 1, 4, device=device),
+ pred_boxes_3d=None,
+ aux_outputs=None,
+ geom_losses=None,
+ presence_logits=None,
+ queries=None,
+ encoder_hidden_states=None,
+ indices=None,
+ )
+ else:
+ # Test mode: return empty Det3DOut
+ return Det3DOut(
+ boxes=[torch.zeros(0, 4, device=device) for _ in range(B_images)],
+ boxes3d=[torch.zeros(0, 10, device=device) for _ in range(B_images)],
+ scores=[torch.zeros(0, device=device) for _ in range(B_images)],
+ class_ids=[torch.zeros(0, dtype=torch.long, device=device) for _ in range(B_images)],
+ depth_maps=None,
+ categories=None,
+ )
+
+ # ========== Step 1 & 2: SAM3 Backbone + Geometry Backend (PARALLEL) ==========
+ # These two operations are independent - run them in parallel using CUDA streams
+ profile_start(" backbone+geom_parallel")
+
+ # Convert images for SAM3 (needed by backbone)
+ images_for_sam3 = self._convert_imagenet_to_sam3_norm(batch.images)
+
+ # Prepare geometry backend inputs
+ geom_losses = None
+ depth_latents = None
+ geom_out = None
+ _, _, H, W = batch.images.shape
+
+ if self.geometry_backend is not None:
+ # Create CUDA streams for parallel execution
+ backbone_stream = torch.cuda.Stream()
+ geom_stream = torch.cuda.Stream()
+
+ # Prepare inputs for geometry backend (before streams)
+ intrinsics_per_image = batch.intrinsics
+ depth_gt = None
+ depth_mask = None
+ if self.training or self.use_depth_input_test:
+ depth_gt = getattr(batch, 'depth_gt', None)
+ if self.training:
+ depth_mask = getattr(batch, 'depth_mask', None)
+
+ # Run backbone on stream 1
+ profile_start(" backbone")
+ with torch.cuda.stream(backbone_stream):
+ backbone_out = {"img_batch_all_stages": batch.images}
+ backbone_out.update(self.sam3.backbone.forward_image(images_for_sam3))
+ text_out = self.sam3.backbone.forward_text(
+ batch.unique_texts, device=device
+ )
+ backbone_out.update(text_out)
+
+ # Run geometry backend on stream 2 (parallel with backbone)
+ profile_start(" geometry_backend")
+ with torch.cuda.stream(geom_stream):
+ geom_out = self.geometry_backend(
+ images=batch.images,
+ depth_feats=None, # Not using backbone features
+ intrinsics=intrinsics_per_image,
+ image_hw=(H, W),
+ depth_gt=depth_gt,
+ depth_mask=depth_mask,
+ padding=batch.padding,
+ )
+
+ # Wait for both streams to complete
+ backbone_stream.synchronize()
+ profile_stop(" backbone")
+ geom_stream.synchronize()
+ profile_stop(" geometry_backend")
+
+ # Extract geometry outputs
+ depth_latents = geom_out.get("depth_latents")
+ if self.training:
+ geom_losses = geom_out.get("losses", {})
+ else:
+ # No geometry backend - just run backbone
+ profile_start(" backbone")
+ backbone_out = {"img_batch_all_stages": batch.images}
+ backbone_out.update(self.sam3.backbone.forward_image(images_for_sam3))
+ text_out = self.sam3.backbone.forward_text(
+ batch.unique_texts, device=device
+ )
+ backbone_out.update(text_out)
+ profile_stop(" backbone")
+
+ profile_stop(" backbone+geom_parallel")
+
+ # ========== Step 2.5: Early Depth Fusion (after backbone, before encoder) ==========
+ # Fuse depth_latents into backbone visual features before encoder
+ # This allows depth information to participate in encoder's self-attention
+ # and text cross-attention
+ if self.early_depth_fusion is not None and depth_latents is not None:
+ # Get depth_latents spatial dimensions from geometry backend output
+ aux = geom_out.get("aux", {})
+ depth_latents_hw = aux.get("depth_latents_hw")
+
+ if depth_latents_hw is not None and "backbone_fpn" in backbone_out:
+ # Fuse depth into visual features
+ backbone_fpn = backbone_out["backbone_fpn"]
+
+ # early_depth_fusion expects list of visual features
+ if not isinstance(backbone_fpn, list):
+ backbone_fpn = [backbone_fpn]
+
+ # Perform fusion
+ fused_fpn = self.early_depth_fusion(
+ visual_feats=backbone_fpn,
+ depth_latents=depth_latents,
+ depth_latents_hw=depth_latents_hw,
+ )
+
+ # Update backbone_out with fused features
+ # SAM3 will use these fused features in encoder
+ if len(fused_fpn) == 1:
+ backbone_out["backbone_fpn"] = fused_fpn[0]
+ else:
+ backbone_out["backbone_fpn"] = fused_fpn
+
+ # Log fusion delta magnitude (monitoring only)
+ if self.training and geom_losses is not None:
+ geom_losses["metric_fusion_delta"] = torch.tensor(
+ self.early_depth_fusion._last_delta_mean_abs,
+ device=device,
+ )
+ else:
+ # Warn user that early depth fusion is configured but cannot run
+ import warnings
+ if depth_latents_hw is None:
+ warnings.warn(
+ "EarlyDepthFusion is configured but depth_latents_hw not "
+ "provided by geometry backend. Skipping depth fusion. "
+ "Check geometry backend outputs include 'aux.depth_latents_hw'.",
+ UserWarning,
+ )
+ elif "backbone_fpn" not in backbone_out:
+ warnings.warn(
+ "EarlyDepthFusion is configured but backbone_fpn not found "
+ "in backbone outputs. Skipping depth fusion.",
+ UserWarning,
+ )
+
+ # ========== Step 3: Build SAM3 inputs ==========
+ find_input = self._build_find_stage(batch, device)
+ geometric_prompt = self._build_geometric_prompt(batch, device)
+
+ # ========== Step 4: SAM3 forward_grounding ==========
+ # This does: encode_prompt -> encoder -> decoder -> score/box prediction
+ #
+ # In training mode, we build find_target from batch GT boxes so that
+ # SAM3's internal _compute_matching can compute matching indices.
+ # These indices are then used by our loss function.
+ find_target = None
+ if self.training:
+ assert batch.gt_boxes2d is not None, \
+ "Training requires GT boxes (batch.gt_boxes2d)"
+ find_target = self._build_find_target(batch)
+
+ profile_start(" sam3_grounding")
+ sam3_out = self.sam3.forward_grounding(
+ backbone_out=backbone_out,
+ find_input=find_input,
+ find_target=find_target,
+ geometric_prompt=geometric_prompt,
+ )
+ profile_stop(" sam3_grounding")
+
+ # ========== Step 5: Extract SAM3 outputs ==========
+ # SAM3 output format (after _update_scores_and_boxes):
+ # - pred_logits: (N_prompts, num_queries, 1) - final layer
+ # - pred_boxes: (N_prompts, num_queries, 4) - normalized cxcywh
+ # - pred_boxes_xyxy: (N_prompts, num_queries, 4) - normalized xyxy
+ # - queries: (N_prompts, num_queries, d_model) - last layer hidden states
+ # - aux_outputs: list of dicts for each decoder layer (for deep supervision)
+ # O2O outputs (one-to-one matching)
+ pred_logits = sam3_out["pred_logits"] # (N_prompts, S, 1)
+ pred_boxes_xyxy = sam3_out["pred_boxes_xyxy"] # (N_prompts, S, 4)
+ pred_boxes_cxcywh = sam3_out["pred_boxes"] # (N_prompts, S, 4)
+ queries = sam3_out.get("queries") # (N_prompts, S, d_model)
+ encoder_hidden_states = sam3_out.get("encoder_hidden_states")
+ presence_logits = sam3_out.get("presence_logit_dec")
+
+ # O2M outputs (one-to-many matching) from SAM3 DAC mechanism
+ # These are separate outputs from the second half of queries in DAC mode
+ pred_logits_o2m = sam3_out.get("pred_logits_o2m") # (N_prompts, S, 1)
+ pred_boxes_xyxy_o2m = sam3_out.get("pred_boxes_xyxy_o2m") # (N_prompts, S, 4)
+ pred_boxes_cxcywh_o2m = sam3_out.get("pred_boxes_o2m") # (N_prompts, S, 4)
+ queries_o2m = sam3_out.get("queries_o2m") # (N_prompts, S, d_model)
+
+ # Extract auxiliary outputs from SAM3 for deep supervision
+ sam3_aux_outputs = sam3_out.get("aux_outputs", [])
+
+ # ========== Step 6: 3D Head ==========
+ profile_start(" 3d_head")
+ pred_boxes_3d = None
+ pred_conf_3d = None
+ aux_outputs = None
+
+ if self.bbox3d_head is not None and queries is not None:
+ # Generate ray embeddings if camera prompt is enabled
+ # For ray-aware backends, depth_latents already
+ # contain ray info, so we can either use camera prompt or skip it
+ ray_embeddings = None
+ if self.bbox3d_head.use_camera_prompt:
+ # Get ray parameters from geometry backend output
+ if geom_out is not None:
+ # Use backend's ray parameters for consistent space
+ ray_intrinsics = geom_out.get("ray_intrinsics", batch.intrinsics)
+ ray_image_hw = geom_out.get("ray_image_hw", (H, W))
+ ray_downsample = geom_out.get("ray_downsample", 16)
+ else:
+ # Fallback: use image-level intrinsics with default downsample
+ # Note: This will broadcast to all prompts, not per-prompt
+ ray_intrinsics = batch.intrinsics
+ ray_image_hw = (H, W)
+ ray_downsample = 16 # Default
+
+ ray_embeddings = self.bbox3d_head.get_camera_embeddings(
+ ray_intrinsics, ray_image_hw, ray_downsample
+ )
+
+ # Align depth_latents and ray_embeddings spatial resolution (if needed)
+ #
+ # Note: This code only runs when use_camera_prompt=True (i.e., for non-ray-aware
+ # backends). For ray-aware backends, use_camera_prompt=False and
+ # ray_embeddings=None, so this block is skipped.
+ #
+ # When this does run, depth_latents and ray_embeddings may have different spatial
+ # resolutions that need to be aligned for the 3D head's cross-attention.
+ if depth_latents is not None and ray_embeddings is not None:
+ # depth_latents: [B_images, N_depth, C_depth]
+ # ray_embeddings: [B_images, N_ray, C_ray]
+ B_depth, N_depth, C_depth = depth_latents.shape
+ B_ray, N_ray, C_ray = ray_embeddings.shape
+
+ if N_depth != N_ray:
+ # Resize depth_latents to match ray spatial size
+ # Infer spatial dimensions (assuming square)
+ H_depth = int(N_depth ** 0.5)
+ W_depth = H_depth
+ H_ray = int(N_ray ** 0.5)
+ W_ray = H_ray
+
+ # Reshape depth_latents: [B, N, C] -> [B, C, H, W]
+ depth_latents_2d = depth_latents.permute(0, 2, 1).reshape(
+ B_depth, C_depth, H_depth, W_depth
+ )
+
+ # Adaptive pool to ray size
+ depth_latents_resized = torch.nn.functional.adaptive_avg_pool2d(
+ depth_latents_2d, (H_ray, W_ray)
+ )
+
+ # Reshape back: [B, C, H, W] -> [B, N, C]
+ depth_latents = depth_latents_resized.reshape(
+ B_depth, C_depth, H_ray * W_ray
+ ).permute(0, 2, 1)
+
+ # Index ray_embeddings and depth_latents from per-image to per-prompt
+ # ray_embeddings and depth_latents are per-image [B_images, N, C]
+ # But 3D head expects them to be per-prompt [N_prompts, N, C]
+ # Use batch.img_ids to correctly map prompts to their corresponding images
+ if ray_embeddings is not None:
+ # batch.img_ids: [N_prompts] - which image each prompt belongs to
+ # ray_embeddings: [B_images, N, C]
+ # Index to get: [N_prompts, N, C]
+ ray_embeddings = ray_embeddings[batch.img_ids]
+
+ if depth_latents is not None:
+ # depth_latents: [B_images, N, C]
+ # Index to get: [N_prompts, N, C]
+ depth_latents = depth_latents[batch.img_ids]
+
+ # ========== Deep Supervision: Process all decoder layers ==========
+ # Following SAM3's design, we process auxiliary outputs from all decoder layers
+ # for deep supervision during training
+ #
+ # SAM3's output structure:
+ # - aux_outputs[0..L-2]: intermediate decoder layers (layer 0 to layer L-2)
+ # - final output (pred_logits, queries, etc.): final decoder layer (layer L-1)
+
+ # Collect all layers' queries in correct order: [layer0, layer1, ..., layerL-1]
+ # Track which aux_outputs have queries for building aux_outputs later
+ all_layers_queries = []
+ aux_indices_with_queries = [] # Track original indices of aux_outputs with queries
+ for i, aux_out in enumerate(sam3_aux_outputs):
+ aux_queries = aux_out.get("queries")
+ if aux_queries is not None:
+ all_layers_queries.append(aux_queries)
+ aux_indices_with_queries.append(i)
+ all_layers_queries.append(queries) # Final layer at the end
+
+ # Stack to (L, N_prompts, S, C) format expected by 3D head
+ if len(all_layers_queries) > 1:
+ # Have auxiliary outputs - stack all layers
+ hidden_states = torch.stack(all_layers_queries, dim=0) # (L, N_prompts, S, C)
+ else:
+ # No auxiliary outputs - just expand final layer
+ hidden_states = queries.unsqueeze(0) # (1, N_prompts, S, C)
+
+ # Call 3D head with all layers
+ # Returns: (L, N_prompts, S, 12), (L, N_prompts, S, 1)
+ all_layers_boxes_3d, all_layers_conf_3d = self.bbox3d_head(
+ hidden_states=hidden_states,
+ ray_embeddings=ray_embeddings,
+ depth_latents=depth_latents,
+ )
+
+ # Extract final layer output
+ if len(all_layers_queries) > 1:
+ pred_boxes_3d = all_layers_boxes_3d[-1] # (N_prompts, S, 12)
+ pred_conf_3d = all_layers_conf_3d[-1] # (N_prompts, S, 1)
+ else:
+ pred_boxes_3d = all_layers_boxes_3d.squeeze(0) # (N_prompts, S, 12)
+ pred_conf_3d = all_layers_conf_3d.squeeze(0) # (N_prompts, S, 1)
+
+ # Build auxiliary outputs for deep supervision
+ # Only include layers that have queries (tracked by aux_indices_with_queries)
+ if len(aux_indices_with_queries) > 0 and self.training:
+ aux_outputs = []
+ for layer_idx, orig_idx in enumerate(aux_indices_with_queries):
+ aux_out = sam3_aux_outputs[orig_idx]
+ aux_dict = {
+ "pred_logits": aux_out["pred_logits"],
+ "pred_boxes_2d": aux_out["pred_boxes_xyxy"],
+ "pred_boxes_3d": all_layers_boxes_3d[layer_idx], # 3D predictions for this layer
+ }
+ # Include presence logits if available
+ if "presence_logit_dec" in aux_out:
+ aux_dict["presence_logits"] = aux_out["presence_logit_dec"]
+ aux_outputs.append(aux_dict)
+
+ # Compute 3D boxes for O2M queries (if available, only during training)
+ pred_boxes_3d_o2m = None
+ pred_conf_3d_o2m = None
+ if self.bbox3d_head is not None and queries_o2m is not None and self.training:
+ # O2M queries use the same 3D head but only compute final layer (no aux)
+ o2m_hidden_states = queries_o2m.unsqueeze(0) # (1, N_prompts, S, C)
+ o2m_boxes_3d, o2m_conf_3d = self.bbox3d_head(
+ hidden_states=o2m_hidden_states,
+ ray_embeddings=ray_embeddings,
+ depth_latents=depth_latents,
+ )
+ pred_boxes_3d_o2m = o2m_boxes_3d.squeeze(0) # (N_prompts, S, 12)
+ pred_conf_3d_o2m = o2m_conf_3d.squeeze(0) # (N_prompts, S, 1)
+
+ profile_stop(" 3d_head")
+
+ # Training mode: return raw outputs for loss computation
+ if self.training:
+ # Extract matching indices from SAM3 output (computed by _compute_matching)
+ sam3_indices = sam3_out.get("indices", None)
+
+ profile_stop("forward_total")
+
+ # Record profiling step (will print summary every N steps if enabled)
+ profile_step()
+
+ return WildDet3DOut(
+ pred_logits=pred_logits,
+ pred_boxes_2d=pred_boxes_xyxy,
+ pred_boxes_3d=pred_boxes_3d,
+ aux_outputs=aux_outputs,
+ geom_losses=geom_losses,
+ presence_logits=presence_logits,
+ queries=queries,
+ encoder_hidden_states=encoder_hidden_states,
+ indices=sam3_indices,
+ pred_boxes_2d_cxcywh=pred_boxes_cxcywh,
+ # O2M outputs from SAM3 DAC mechanism
+ pred_logits_o2m=pred_logits_o2m,
+ pred_boxes_2d_o2m=pred_boxes_xyxy_o2m,
+ pred_boxes_2d_cxcywh_o2m=pred_boxes_cxcywh_o2m,
+ pred_boxes_3d_o2m=pred_boxes_3d_o2m,
+ # 3D confidence head outputs
+ pred_conf_3d=pred_conf_3d,
+ pred_conf_3d_o2m=pred_conf_3d_o2m,
+ )
+
+ # Test mode: forward_test returns Det3DOut for evaluation
+ return self._forward_test(
+ pred_logits=pred_logits,
+ pred_boxes_2d=pred_boxes_xyxy,
+ pred_boxes_3d=pred_boxes_3d,
+ pred_conf_3d=pred_conf_3d,
+ presence_logits=presence_logits,
+ batch=batch,
+ geom_out=geom_out,
+ )
+
+ def _convert_imagenet_to_sam3_norm(self, images: Tensor) -> Tensor:
+ """Convert ImageNet normalized images to SAM3 normalization.
+
+ vis4d/3D-MOOD uses ImageNet normalization:
+ ImageNet: (x - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
+ Output range: ~[-2.5, 2.5]
+
+ SAM3 expects custom normalization:
+ SAM3: (x - 0.5) / 0.5
+ Output range: [-1, 1]
+
+ This function converts from ImageNet normalized to SAM3 normalized:
+ 1. Denormalize ImageNet -> [0, 1]
+ 2. Normalize SAM3 -> [-1, 1]
+
+ Args:
+ images: ImageNet normalized images (B, 3, H, W)
+
+ Returns:
+ SAM3 normalized images (B, 3, H, W)
+ """
+ # ImageNet constants
+ imagenet_mean = torch.tensor(
+ [0.485, 0.456, 0.406], device=images.device, dtype=images.dtype
+ ).view(1, 3, 1, 1)
+ imagenet_std = torch.tensor(
+ [0.229, 0.224, 0.225], device=images.device, dtype=images.dtype
+ ).view(1, 3, 1, 1)
+
+ # Denormalize: ImageNet normalized -> [0, 1]
+ images_01 = images * imagenet_std + imagenet_mean
+
+ # Normalize: [0, 1] -> SAM3 [-1, 1]
+ images_sam3 = (images_01 - 0.5) / 0.5
+
+ return images_sam3
+
+ def _forward_test(
+ self,
+ pred_logits: Tensor,
+ pred_boxes_2d: Tensor,
+ pred_boxes_3d: Tensor | None,
+ pred_conf_3d: Tensor | None = None,
+ presence_logits: Tensor | None = None,
+ batch: WildDet3DInput | None = None,
+ geom_out: dict | None = None,
+ ) -> Det3DOut:
+ """Forward pass for test/inference mode.
+
+ Postprocesses model outputs to Det3DOut format for evaluation.
+ Converts per-prompt outputs to per-image outputs with:
+ - Pixel coordinate boxes (scaled from normalized)
+ - Decoded 3D boxes
+ - Score thresholding (optional)
+
+ Args:
+ pred_logits: (N_prompts, S, 1) objectness logits
+ pred_boxes_2d: (N_prompts, S, 4) normalized xyxy boxes
+ pred_boxes_3d: (N_prompts, S, 12) encoded 3D params or None
+ pred_conf_3d: (N_prompts, S, 1) 3D confidence logits or None
+ presence_logits: (N_prompts, 1) presence logits (category exists in image)
+ batch: Input batch with img_ids, intrinsics, etc.
+ geom_out: Geometry backend output (may contain depth_maps)
+
+ Returns:
+ Det3DOut with per-image detection results
+ """
+ H, W = batch.images.shape[2:]
+ device = pred_logits.device
+ B_images = batch.images.shape[0]
+
+ # 2D confidence (foreground/background) - used for threshold & NMS
+ scores_2d = pred_logits.sigmoid().squeeze(-1) # (N_prompts, S)
+
+ # 3D confidence (depth/geometry quality) - tracked separately
+ scores_3d_all = None
+ if pred_conf_3d is not None:
+ scores_3d_all = pred_conf_3d.sigmoid().squeeze(-1) # (N_prompts, S)
+
+ # Combined score for ranking (NMS tie-breaking etc)
+ # WILDDET3D_CONF_WEIGHT env var overrides config (e.g., "0.0" for 2D only)
+ import os
+ conf_weight = self.eval_3d_conf_weight
+ conf_weight_override = os.environ.get("WILDDET3D_CONF_WEIGHT", None)
+ if conf_weight_override is not None:
+ conf_weight = float(conf_weight_override)
+ if scores_3d_all is not None and conf_weight > 0:
+ scores_all = scores_2d + conf_weight * scores_3d_all
+ else:
+ scores_all = scores_2d
+
+ # Apply presence score if available (following SAM3 original postprocessors.py)
+ # Presence score indicates whether a category has objects in the image
+ # This suppresses all proposals for categories that don't exist in the image
+ # SAM3 original: presence_score = outputs["presence_logit_dec"].sigmoid().unsqueeze(1)
+ if presence_logits is not None and self.use_presence_score:
+ presence_score = presence_logits.sigmoid()
+ # Ensure correct shape for broadcasting: (N_prompts, 1) or (N_prompts,) -> (N_prompts, 1)
+ if presence_score.dim() == 1:
+ presence_score = presence_score.unsqueeze(-1)
+ scores_all = scores_all * presence_score # (N_prompts, S) * (N_prompts, 1)
+ scores_2d = scores_2d * presence_score # Also apply to 2D scores
+
+ # Scale boxes to pixel coordinates
+ # pred_boxes_2d is normalized xyxy [0, 1]
+ boxes_pixel = pred_boxes_2d.clone()
+ boxes_pixel[..., 0::2] *= W
+ boxes_pixel[..., 1::2] *= H
+
+ # Group by image
+ boxes_list = []
+ boxes3d_list = []
+ scores_list = []
+ scores_2d_list = []
+ scores_3d_list = []
+ class_ids_list = []
+
+ # Get parameters from roi2det3d if available
+ score_threshold = getattr(self.roi2det3d, 'score_threshold', -1.0) if self.roi2det3d else -1.0
+
+ # NMS parameters (following 3D-MOOD's RoI2Det3D design)
+ # Note: max_per_img not used - WildDet3D already limits to 100 proposals per category
+ use_nms = getattr(self.roi2det3d, 'nms', False) if self.roi2det3d else False
+ # class_agnostic_nms=False: NMS only within same category (recommended for per-category prediction)
+ class_agnostic_nms = getattr(self.roi2det3d, 'class_agnostic_nms', False) if self.roi2det3d else False
+ iou_threshold = getattr(self.roi2det3d, 'iou_threshold', 0.5) if self.roi2det3d else 0.5
+
+ # Environment variable overrides (useful for A/B testing)
+ import os
+ # SAM3_NMS=0 to disable, SAM3_NMS=1 to enable
+ nms_override = os.environ.get("SAM3_NMS", None)
+ if nms_override is not None:
+ use_nms = nms_override == "1"
+ # SAM3_SCORE_THRESH to override score threshold (e.g., "0.0" to disable)
+ score_thresh_override = os.environ.get("SAM3_SCORE_THRESH", None)
+ if score_thresh_override is not None:
+ score_threshold = float(score_thresh_override)
+ # SAM3_IOU_THRESH to override NMS IoU threshold (e.g., "0.8" for more conservative)
+ iou_thresh_override = os.environ.get("SAM3_IOU_THRESH", None)
+ if iou_thresh_override is not None:
+ iou_threshold = float(iou_thresh_override)
+
+ # Debug: print config once at start
+ if not hasattr(self, '_nms_config_printed'):
+ print(f"[NMS CONFIG] use_nms={use_nms}, class_agnostic={class_agnostic_nms}, iou_thresh={iou_threshold}, score_thresh={score_threshold}")
+ # Log predicted intrinsics setting
+ _use_pred_k = self.use_predicted_intrinsics
+ _pred_k_override = os.environ.get("SAM3_USE_PRED_K", None)
+ if _pred_k_override is not None:
+ _use_pred_k = _pred_k_override == "1"
+ print(f"[INTRINSICS CONFIG] use_predicted_intrinsics={_use_pred_k}")
+ self._nms_config_printed = True
+
+ S = scores_all.shape[1] # predictions per prompt
+
+ for img_idx in range(B_images):
+ # Find prompts belonging to this image
+ prompt_mask = batch.img_ids == img_idx
+ n_prompts_this_img = prompt_mask.sum().item()
+
+ if n_prompts_this_img == 0:
+ # No prompts for this image
+ boxes_list.append(torch.zeros(0, 4, device=device))
+ boxes3d_list.append(torch.zeros(0, 10, device=device))
+ scores_list.append(torch.zeros(0, device=device))
+ scores_2d_list.append(torch.zeros(0, device=device))
+ scores_3d_list.append(torch.zeros(0, device=device))
+ class_ids_list.append(torch.zeros(0, dtype=torch.long, device=device))
+ continue
+
+ # Get predictions for this image's prompts
+ img_scores = scores_all[prompt_mask] # (n_prompts, S)
+ img_boxes = boxes_pixel[prompt_mask] # (n_prompts, S, 4)
+
+ # Get class IDs for each prompt
+ if batch.gt_category_ids is not None:
+ img_class_ids = batch.gt_category_ids[prompt_mask] # (n_prompts,) or (n_prompts, max_gt)
+ if img_class_ids.dim() > 1:
+ img_class_ids = img_class_ids[:, 0] # Take first if multiple
+ elif batch.text_ids is not None:
+ img_class_ids = batch.text_ids[prompt_mask]
+ else:
+ img_class_ids = torch.zeros(n_prompts_this_img, dtype=torch.long, device=device)
+
+ if self.oracle_eval:
+ # Oracle mode: IoU top-K + highest confidence
+ # 1. Compute 2D IoU between each proposal and its GT box
+ # 2. Take top-K proposals by IoU (well-localized candidates)
+ # 3. Among top-K, pick highest confidence (best quality)
+ oracle_topk = int(os.environ.get("SAM3_ORACLE_TOPK", "10"))
+ prompt_indices = torch.arange(n_prompts_this_img, device=device)
+ best_indices = torch.zeros(n_prompts_this_img, dtype=torch.long, device=device)
+
+ if batch.geo_boxes is not None:
+ # geo_boxes is in padded-normalized cxcywh (correct space)
+ img_geo_boxes = batch.geo_boxes[prompt_mask] # (n_prompts, max_K, 4)
+ gt_cxcywh = img_geo_boxes[:, 0, :] # (n_prompts, 4)
+ gt_xyxy_norm = box_cxcywh_to_xyxy(gt_cxcywh)
+ gt_boxes_pixel = gt_xyxy_norm.clone()
+ gt_boxes_pixel[:, 0::2] *= W
+ gt_boxes_pixel[:, 1::2] *= H
+
+ K = min(oracle_topk, S)
+ for p_idx in range(n_prompts_this_img):
+ ious = box_iou(
+ img_boxes[p_idx], gt_boxes_pixel[p_idx].unsqueeze(0)
+ ).squeeze(-1) # (S,)
+ # Top-K by IoU
+ _, topk_iou_indices = ious.topk(K)
+ # Among top-K, pick highest confidence
+ topk_scores = img_scores[p_idx][topk_iou_indices]
+ best_in_topk = topk_scores.argmax()
+ best_indices[p_idx] = topk_iou_indices[best_in_topk]
+
+ if img_idx == 0 and not hasattr(self, '_oracle_debug_printed'):
+ self._oracle_debug_printed = True
+ p0_ious = box_iou(
+ img_boxes[0], gt_boxes_pixel[0].unsqueeze(0)
+ ).squeeze(-1)
+ sel = best_indices[0].item()
+ print(
+ f"[ORACLE] topK={K}, "
+ f"IoU={p0_ious[sel]:.4f}, "
+ f"score={img_scores[0][sel]:.4f}, "
+ f"maxIoU={p0_ious.max():.4f}"
+ )
+ else:
+ # Fallback: pure argmax
+ best_indices = img_scores.argmax(dim=1)
+
+ img_scores_flat = img_scores[prompt_indices, best_indices]
+ img_boxes_flat = img_boxes[prompt_indices, best_indices]
+ img_class_ids_flat = img_class_ids
+
+ # Track 2D and 3D scores for oracle mode
+ img_scores_2d_flat = scores_2d[prompt_mask][prompt_indices, best_indices]
+ if scores_3d_all is not None:
+ img_scores_3d = scores_3d_all[prompt_mask]
+ img_scores_3d_flat = img_scores_3d[prompt_indices, best_indices]
+ else:
+ img_scores_3d_flat = torch.zeros_like(img_scores_flat)
+
+ if pred_boxes_3d is not None:
+ img_boxes3d = pred_boxes_3d[prompt_mask]
+ img_boxes3d_flat = img_boxes3d[prompt_indices, best_indices]
+ else:
+ img_boxes3d_flat = None
+
+ else:
+ # Standard mode: flatten all proposals + NMS
+ # Flatten all predictions: (n_prompts, S) -> (n_prompts * S,)
+ img_scores_flat = img_scores.flatten() # (n_prompts * S,)
+ img_boxes_flat = img_boxes.reshape(-1, 4) # (n_prompts * S, 4)
+
+ # Track 2D scores separately for threshold filtering and output
+ img_scores_2d = scores_2d[prompt_mask].flatten() # (n_prompts * S,)
+ img_scores_2d_flat = img_scores_2d # alias for output
+
+ # Track 3D scores
+ if scores_3d_all is not None:
+ img_scores_3d_flat = scores_3d_all[prompt_mask].flatten()
+ else:
+ img_scores_3d_flat = torch.zeros_like(img_scores_flat)
+
+ # Expand class_ids to match flattened shape
+ img_class_ids_flat = img_class_ids.unsqueeze(1).expand(-1, S).flatten() # (n_prompts * S,)
+
+ # Get 3D boxes if available (flattened)
+ if pred_boxes_3d is not None:
+ img_boxes3d = pred_boxes_3d[prompt_mask] # (n_prompts, S, 12)
+ img_boxes3d_flat = img_boxes3d.reshape(-1, 12) # (n_prompts * S, 12)
+ else:
+ img_boxes3d_flat = None
+
+ # Score threshold filter (uses 2D score only)
+ if score_threshold > 0:
+ keep = img_scores_2d > score_threshold
+ img_scores_flat = img_scores_flat[keep]
+ img_scores_2d_flat = img_scores_2d_flat[keep]
+ img_scores_2d = img_scores_2d[keep]
+ img_scores_3d_flat = img_scores_3d_flat[keep]
+ img_boxes_flat = img_boxes_flat[keep]
+ img_class_ids_flat = img_class_ids_flat[keep]
+ if img_boxes3d_flat is not None:
+ img_boxes3d_flat = img_boxes3d_flat[keep]
+
+ # NMS based on 2D boxes (following RoI2Det3D design)
+ if use_nms and len(img_boxes_flat) > 0:
+ n_before_nms = len(img_boxes_flat)
+ if class_agnostic_nms:
+ keep = nms(img_boxes_flat, img_scores_flat, iou_threshold)
+ else:
+ keep = batched_nms(
+ img_boxes_flat, img_scores_flat, img_class_ids_flat, iou_threshold
+ )
+ img_scores_flat = img_scores_flat[keep]
+ img_scores_2d_flat = img_scores_2d_flat[keep]
+ img_scores_3d_flat = img_scores_3d_flat[keep]
+ img_boxes_flat = img_boxes_flat[keep]
+ img_class_ids_flat = img_class_ids_flat[keep]
+ if img_boxes3d_flat is not None:
+ img_boxes3d_flat = img_boxes3d_flat[keep]
+ if img_idx == 0:
+ n_after_nms = len(img_boxes_flat)
+ print(f"[NMS DEBUG] img={img_idx}, before={n_before_nms}, after={n_after_nms}, suppressed={n_before_nms - n_after_nms}, iou_thresh={iou_threshold}")
+
+ # Decode 3D boxes in padded space BEFORE rescaling (matching GDino3D)
+ # Use padded-space intrinsics since 2D boxes are still in padded
+ # pixel coordinates at this point.
+ # When use_predicted_intrinsics is enabled, use geometry backend's
+ # K_pred (also in padded space) instead of dataset intrinsics.
+ if img_boxes3d_flat is not None and self.box_coder is not None and len(img_boxes_flat) > 0:
+ # Determine whether to use predicted intrinsics
+ use_pred_k = self.use_predicted_intrinsics
+ pred_k_override = os.environ.get("SAM3_USE_PRED_K", None)
+ if pred_k_override is not None:
+ use_pred_k = pred_k_override == "1"
+
+ if use_pred_k and geom_out is not None and "K_pred" in geom_out and geom_out["K_pred"] is not None:
+ intrinsics_this_img = geom_out["K_pred"][img_idx] # (3, 3) padded-space
+ else:
+ intrinsics_this_img = batch.intrinsics[img_idx] # (3, 3) padded-space
+
+ decoded_boxes3d = self.box_coder.decode(
+ img_boxes_flat, # pixel xyxy in padded space
+ img_boxes3d_flat,
+ intrinsics_this_img,
+ )
+ else:
+ decoded_boxes3d = torch.zeros(len(img_boxes_flat), 10, device=device)
+
+ # Rescale 2D boxes from padded space (H, W) to original image space
+ # Must account for CenterPad: first subtract padding offset, then
+ # divide by content_size/original_size (NOT padded_size/original_size).
+ # Matches GDino3D RoI2Det3D.__call__ (head.py:380-396).
+ if batch.original_hw is not None:
+ # original_hw may be List[tuple] or a single tuple
+ # (Lightning's transfer_batch_to_device can unwrap
+ # single-element lists for batch_size=1)
+ hw = batch.original_hw
+ if isinstance(hw, (tuple, list)) and len(hw) == 2 and isinstance(hw[0], (int, float)):
+ # Direct tuple (h, w) - single image batch
+ orig_h, orig_w = hw
+ elif isinstance(hw, (tuple, list)) and img_idx < len(hw):
+ orig_h, orig_w = hw[img_idx]
+ else:
+ orig_h, orig_w = None, None
+
+ if orig_h is None:
+ continue
+
+ img_boxes_flat = img_boxes_flat.clone() # Don't modify in-place
+
+ # padding may also be unwrapped for batch_size=1
+ pad_info = batch.padding
+ if pad_info is not None:
+ if isinstance(pad_info, (tuple, list)) and len(pad_info) == 4 and isinstance(pad_info[0], (int, float)):
+ # Direct [L,R,T,B] - single image batch
+ pad_left, pad_right, pad_top, pad_bottom = pad_info
+ elif isinstance(pad_info, (tuple, list)) and img_idx < len(pad_info) and pad_info[img_idx] is not None:
+ pad_left, pad_right, pad_top, pad_bottom = pad_info[img_idx]
+ else:
+ pad_left = pad_right = pad_top = pad_bottom = 0
+
+ # Step 1: subtract CenterPad offset
+ img_boxes_flat[:, 0::2] -= pad_left
+ img_boxes_flat[:, 1::2] -= pad_top
+ # Step 2: scale = content_size / original_size
+ content_w = W - pad_left - pad_right
+ content_h = H - pad_top - pad_bottom
+ scale_x = content_w / orig_w
+ scale_y = content_h / orig_h
+ else:
+ # Fallback: no padding info, use full image size
+ scale_x = W / orig_w
+ scale_y = H / orig_h
+ img_boxes_flat[:, 0::2] /= scale_x # x coordinates
+ img_boxes_flat[:, 1::2] /= scale_y # y coordinates
+
+ boxes_list.append(img_boxes_flat)
+ boxes3d_list.append(decoded_boxes3d)
+ scores_list.append(img_scores_flat)
+ scores_2d_list.append(img_scores_2d_flat)
+ scores_3d_list.append(img_scores_3d_flat)
+ class_ids_list.append(img_class_ids_flat)
+
+ # Get depth maps if available
+ depth_maps = None
+ if geom_out is not None and "depth_map" in geom_out:
+ depth_maps = [geom_out["depth_map"][i] for i in range(B_images)]
+
+ # Get predicted intrinsics if available
+ predicted_intrinsics = None
+ if geom_out is not None and "K_pred" in geom_out:
+ predicted_intrinsics = geom_out["K_pred"]
+
+ return Det3DOut(
+ boxes=boxes_list,
+ boxes3d=boxes3d_list,
+ scores=scores_list,
+ class_ids=class_ids_list,
+ depth_maps=depth_maps,
+ categories=None,
+ predicted_intrinsics=predicted_intrinsics,
+ scores_3d=scores_3d_list,
+ scores_2d=scores_2d_list,
+ )
+
+ def _build_find_stage(
+ self,
+ batch: WildDet3DInput,
+ device: torch.device,
+ ) -> FindStage:
+ """Convert WildDet3DInput to SAM3's FindStage format.
+
+ FindStage is SAM3's internal representation for per-prompt batch,
+ containing img_ids, text_ids, and geometry inputs.
+ """
+ N_prompts = len(batch.img_ids)
+
+ # Prepare geometry inputs - need to convert to sequence-first
+ # FindStage expects (max_K, N_prompts, 4) for boxes
+ if batch.geo_boxes is not None:
+ # (N_prompts, max_K, 4) -> (max_K, N_prompts, 4)
+ input_boxes = batch.geo_boxes.permute(1, 0, 2)
+ input_boxes_mask = batch.geo_boxes_mask # (N_prompts, max_K)
+ input_boxes_label = (
+ batch.geo_box_labels.permute(1, 0)
+ if batch.geo_box_labels is not None
+ else torch.ones(
+ input_boxes.shape[0], N_prompts, dtype=torch.long, device=device
+ )
+ )
+ else:
+ # No geometry input - create empty tensors
+ input_boxes = torch.zeros(0, N_prompts, 4, device=device)
+ input_boxes_mask = torch.ones(N_prompts, 0, dtype=torch.bool, device=device)
+ input_boxes_label = torch.zeros(0, N_prompts, dtype=torch.long, device=device)
+
+ # Points (if any)
+ if batch.geo_points is not None:
+ input_points = batch.geo_points.permute(1, 0, 2) # (max_P, N, 2)
+ input_points_mask = batch.geo_points_mask
+ else:
+ input_points = torch.zeros(0, N_prompts, 2, device=device)
+ input_points_mask = torch.ones(N_prompts, 0, dtype=torch.bool, device=device)
+
+ return FindStage(
+ img_ids=batch.img_ids,
+ text_ids=batch.text_ids,
+ input_boxes=input_boxes,
+ input_boxes_mask=input_boxes_mask,
+ input_boxes_label=input_boxes_label,
+ input_points=input_points,
+ input_points_mask=input_points_mask,
+ object_ids=None,
+ )
+
+ def _build_geometric_prompt(
+ self,
+ batch: WildDet3DInput,
+ device: torch.device,
+ ) -> Prompt:
+ """Build SAM3 Prompt object from batch.
+
+ SAM3's Prompt class expects sequence-first format: (K, N_prompts, dim)
+ """
+ N_prompts = len(batch.img_ids)
+
+ # Box prompts
+ if batch.geo_boxes is not None and batch.geo_boxes.shape[1] > 0:
+ # (N_prompts, max_K, 4) -> (max_K, N_prompts, 4)
+ box_embeddings = batch.geo_boxes.permute(1, 0, 2)
+ box_mask = batch.geo_boxes_mask # (N_prompts, max_K)
+ box_labels = (
+ batch.geo_box_labels.permute(1, 0)
+ if batch.geo_box_labels is not None
+ else torch.ones(
+ box_embeddings.shape[0], N_prompts, dtype=torch.long, device=device
+ )
+ )
+ else:
+ box_embeddings = None
+ box_mask = None
+ box_labels = None
+
+ # Point prompts
+ if batch.geo_points is not None and batch.geo_points.shape[1] > 0:
+ point_embeddings = batch.geo_points.permute(1, 0, 2) # (max_P, N, 2)
+ point_mask = batch.geo_points_mask
+ point_labels = (
+ batch.geo_point_labels.permute(1, 0)
+ if batch.geo_point_labels is not None
+ else torch.ones(
+ point_embeddings.shape[0], N_prompts, dtype=torch.long, device=device
+ )
+ )
+ else:
+ # For text-only mode: create empty tensors instead of None
+ # SAM3's geometry encoder cannot handle None for points
+ point_embeddings = torch.zeros(0, N_prompts, 2, device=device)
+ point_mask = torch.ones(N_prompts, 0, dtype=torch.bool, device=device)
+ point_labels = torch.zeros(0, N_prompts, dtype=torch.long, device=device)
+
+ # Ensure box prompts also have empty tensors if None
+ if box_embeddings is None:
+ box_embeddings = torch.zeros(0, N_prompts, 4, device=device)
+ box_mask = torch.ones(N_prompts, 0, dtype=torch.bool, device=device)
+ box_labels = torch.zeros(0, N_prompts, dtype=torch.long, device=device)
+
+ return Prompt(
+ box_embeddings=box_embeddings,
+ box_mask=box_mask,
+ box_labels=box_labels,
+ point_embeddings=point_embeddings,
+ point_mask=point_mask,
+ point_labels=point_labels,
+ )
+
+ @torch.no_grad()
+ def inference(
+ self,
+ batch: WildDet3DInput,
+ score_threshold: float = 0.3,
+ nms_threshold: float = 0.5,
+ ) -> list[dict]:
+ """Run inference and decode 3D boxes.
+
+ Args:
+ batch: WildDet3DInput with images and prompts
+ score_threshold: Confidence threshold
+ nms_threshold: NMS IoU threshold
+
+ Returns:
+ List of dicts per image with decoded 3D boxes
+ """
+ self.eval()
+
+ out = self.forward(batch)
+
+ if self.roi2det3d is None or out.pred_boxes_3d is None:
+ return self._decode_2d_only(out, batch.img_ids, score_threshold)
+
+ # Decode 3D boxes using roi2det3d
+ H, W = batch.images.shape[2:]
+ intrinsics_per_prompt = batch.intrinsics[batch.img_ids]
+ results = self.roi2det3d(
+ pred_logits=out.pred_logits,
+ pred_boxes_2d=out.pred_boxes_2d,
+ pred_boxes_3d=out.pred_boxes_3d,
+ intrinsics=intrinsics_per_prompt,
+ image_size=(H, W),
+ img_ids=batch.img_ids,
+ score_threshold=score_threshold,
+ nms_threshold=nms_threshold,
+ )
+ return results
+
+ def _decode_2d_only(
+ self,
+ out: WildDet3DOut,
+ img_ids: Tensor,
+ score_threshold: float,
+ ) -> list[dict]:
+ """Decode 2D-only results when 3D head is not available."""
+ scores = out.pred_logits.sigmoid().squeeze(-1) # (N_prompts, S)
+ boxes = out.pred_boxes_2d # (N_prompts, S, 4) normalized xyxy
+
+ results = []
+ unique_img_ids = img_ids.unique()
+
+ for img_id in unique_img_ids:
+ mask = img_ids == img_id
+ img_scores = scores[mask].flatten()
+ img_boxes = boxes[mask].reshape(-1, 4)
+
+ keep = img_scores > score_threshold
+ results.append({
+ "scores": img_scores[keep],
+ "boxes_2d": img_boxes[keep],
+ "boxes_3d": None,
+ })
+
+ return results
+
+
+def build_wilddet3d(
+ sam3_checkpoint: str | None = None,
+ geometry_backend_type: str = "unidepth_v2",
+ hidden_dim: int = 256,
+ num_decoder_layers: int = 6,
+ device: str = "cuda",
+) -> WildDet3D:
+ """Factory function to build WildDet3D model.
+
+ Args:
+ sam3_checkpoint: Path to SAM3 checkpoint
+ geometry_backend_type: Type of geometry backend
+ hidden_dim: Hidden dimension for 3D head
+ num_decoder_layers: Number of decoder layers
+ device: Device to load model on
+
+ Returns:
+ Initialized WildDet3D model
+
+ Note:
+ Learning rate control is handled by param_groups in optimizer config,
+ not by freezing parameters.
+ """
+ from sam3.model.sam3_image import build_sam3_image
+ from wilddet3d.depth import GeometryBackendBase
+
+ # Build SAM3 model
+ sam3_model = build_sam3_image(checkpoint=sam3_checkpoint)
+ sam3_model = sam3_model.to(device)
+
+ # Build geometry backend
+ # Note: geometry backend construction depends on the specific backend type
+ # For now, this is a placeholder - users should construct the backend externally
+ geometry_backend = None
+
+ # Build 3D head
+ bbox3d_head = Det3DHead(
+ hidden_dim=hidden_dim,
+ num_layers=num_decoder_layers,
+ )
+
+ # Build box coder
+ box_coder = Det3DCoder()
+
+ # Build inference post-processor
+ roi2det3d = RoI2Det3D(box_coder=box_coder)
+
+ model = WildDet3D(
+ sam3_model=sam3_model,
+ bbox3d_head=bbox3d_head,
+ box_coder=box_coder,
+ geometry_backend=geometry_backend,
+ roi2det3d=roi2det3d,
+ )
+
+ return model.to(device)
diff --git a/wilddet3d/ops/__init__.py b/wilddet3d/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f43b8f2189da9b7d3906ec11138e0eb7a3ac920c
--- /dev/null
+++ b/wilddet3d/ops/__init__.py
@@ -0,0 +1 @@
+"""Operations and layers."""
diff --git a/wilddet3d/ops/__pycache__/__init__.cpython-311.pyc b/wilddet3d/ops/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..75a2081652f643f2e3ee74d6379eac00e276e047
Binary files /dev/null and b/wilddet3d/ops/__pycache__/__init__.cpython-311.pyc differ
diff --git a/wilddet3d/ops/__pycache__/attention.cpython-311.pyc b/wilddet3d/ops/__pycache__/attention.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..50a698ee12f9b2aba6df8a21744b4f4dd7a266e6
Binary files /dev/null and b/wilddet3d/ops/__pycache__/attention.cpython-311.pyc differ
diff --git a/wilddet3d/ops/__pycache__/box2d.cpython-311.pyc b/wilddet3d/ops/__pycache__/box2d.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa43381a26ff7d207d2c97175a33ad4e7df9d4dc
Binary files /dev/null and b/wilddet3d/ops/__pycache__/box2d.cpython-311.pyc differ
diff --git a/wilddet3d/ops/__pycache__/mlp.cpython-311.pyc b/wilddet3d/ops/__pycache__/mlp.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4661d22de7cc3762130c1ae31c785b43f2ec3cf
Binary files /dev/null and b/wilddet3d/ops/__pycache__/mlp.cpython-311.pyc differ
diff --git a/wilddet3d/ops/__pycache__/nystrom.cpython-311.pyc b/wilddet3d/ops/__pycache__/nystrom.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c0788d650bf627d4ce5a5f5b8098d4debb71eb6e
Binary files /dev/null and b/wilddet3d/ops/__pycache__/nystrom.cpython-311.pyc differ
diff --git a/wilddet3d/ops/__pycache__/profiler.cpython-311.pyc b/wilddet3d/ops/__pycache__/profiler.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0d2704871f91c73fb42c6b1c242180b4a39d5565
Binary files /dev/null and b/wilddet3d/ops/__pycache__/profiler.cpython-311.pyc differ
diff --git a/wilddet3d/ops/__pycache__/ray.cpython-311.pyc b/wilddet3d/ops/__pycache__/ray.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c54126675e1b0127cd2af998dd2da6b7928adeb
Binary files /dev/null and b/wilddet3d/ops/__pycache__/ray.cpython-311.pyc differ
diff --git a/wilddet3d/ops/__pycache__/rotation.cpython-311.pyc b/wilddet3d/ops/__pycache__/rotation.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5db103271b6baff2bbdee82a3754ef3a16809b1
Binary files /dev/null and b/wilddet3d/ops/__pycache__/rotation.cpython-311.pyc differ
diff --git a/wilddet3d/ops/__pycache__/upsample.cpython-311.pyc b/wilddet3d/ops/__pycache__/upsample.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..528c02c2784e5edb615b339d9dbc0c5ef29f39f4
Binary files /dev/null and b/wilddet3d/ops/__pycache__/upsample.cpython-311.pyc differ
diff --git a/wilddet3d/ops/__pycache__/util.cpython-311.pyc b/wilddet3d/ops/__pycache__/util.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d8ee3f3acc4625bff40096db60d4097ab6ec24d
Binary files /dev/null and b/wilddet3d/ops/__pycache__/util.cpython-311.pyc differ
diff --git a/wilddet3d/ops/attention.py b/wilddet3d/ops/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ffb2fa5a54a8a5ce4f1b7634db16d25fb0cd51e
--- /dev/null
+++ b/wilddet3d/ops/attention.py
@@ -0,0 +1,284 @@
+"""Attention layer."""
+
+from functools import partial
+from math import log2, pi
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import Tensor, nn
+
+from .mlp import MLP
+from .nystrom import NystromAttention
+
+
+class LayerScale(nn.Module):
+ """Layer scale."""
+
+ def __init__(
+ self,
+ dim: int,
+ init_values: float | Tensor = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ """Initialize."""
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward."""
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class AttentionBlock(nn.Module):
+ """Attention block."""
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 4,
+ expansion: int = 4,
+ dropout: float = 0.0,
+ cosine: bool = False,
+ gated: bool = False,
+ layer_scale: float = 1.0,
+ context_dim: int | None = None,
+ ) -> None:
+ """Initialize."""
+ super().__init__()
+ self.num_heads = num_heads
+ self.hidden_dim = dim
+ self.context_dim = context_dim or dim
+
+ self.norm_attnx = nn.LayerNorm(self.hidden_dim)
+ self.norm_attnctx = nn.LayerNorm(self.context_dim)
+
+ self.q = nn.Linear(self.hidden_dim, self.hidden_dim)
+ self.kv = nn.Linear(self.context_dim, self.hidden_dim * 2)
+
+ self.cosine = cosine
+ self.dropout = dropout
+ self.out = nn.Linear(self.hidden_dim, self.hidden_dim)
+
+ self.ls1 = (
+ LayerScale(dim, layer_scale)
+ if layer_scale > 0.0
+ else nn.Identity()
+ )
+
+ self.mlp = MLP(
+ self.hidden_dim, expansion=expansion, dropout=dropout, gated=gated
+ )
+
+ self.ls2 = (
+ LayerScale(dim, layer_scale)
+ if layer_scale > 0.0
+ else nn.Identity()
+ )
+
+ def attn(
+ self,
+ x: Tensor,
+ attn_bias: Tensor | None = None,
+ context: Tensor | None = None,
+ pos_embed: Tensor | None = None,
+ pos_embed_context: Tensor | None = None,
+ rope: nn.Module | None = None,
+ ) -> Tensor:
+ """Attention."""
+ x = self.norm_attnx(x)
+
+ context = self.norm_attnctx(context)
+
+ q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads)
+
+ k, v = rearrange(
+ self.kv(context),
+ "b n (kv h d) -> b h n d kv",
+ h=self.num_heads,
+ kv=2,
+ ).unbind(dim=-1)
+
+ if rope is not None:
+ q = rope(q)
+ k = rope(k)
+ else:
+ if pos_embed is not None:
+ pos_embed = rearrange(
+ pos_embed, "b n (h d) -> b h n d", h=self.num_heads
+ )
+ q = q + pos_embed
+
+ if pos_embed_context is not None:
+ pos_embed_context = rearrange(
+ pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads
+ )
+ k = k + pos_embed_context
+
+ if self.cosine:
+ q, k = map(partial(F.normalize, p=2, dim=-1), (q, k))
+
+ x = F.scaled_dot_product_attention(
+ q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
+ )
+ x = rearrange(x, "b h n d -> b n (h d)")
+ x = self.out(x)
+ return x
+
+ def forward(
+ self,
+ x: Tensor,
+ attn_bias: Tensor | None = None,
+ context: Tensor | None = None,
+ pos_embed: Tensor | None = None,
+ pos_embed_context: Tensor | None = None,
+ rope: nn.Module | None = None,
+ ) -> Tensor:
+ """Forward."""
+ context = x if context is None else context
+
+ x = (
+ self.ls1(
+ self.attn(
+ x,
+ rope=rope,
+ attn_bias=attn_bias,
+ context=context,
+ pos_embed=pos_embed,
+ pos_embed_context=pos_embed_context,
+ )
+ )
+ + x
+ )
+
+ return self.ls2(self.mlp(x)) + x
+
+
+class NystromBlock(AttentionBlock):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 4,
+ expansion: int = 4,
+ dropout: float = 0.0,
+ cosine: bool = False,
+ gated: bool = False,
+ layer_scale: float = 1.0,
+ context_dim: int | None = None,
+ ):
+ super().__init__(
+ dim=dim,
+ num_heads=num_heads,
+ expansion=expansion,
+ dropout=dropout,
+ cosine=cosine,
+ gated=gated,
+ layer_scale=layer_scale,
+ context_dim=context_dim,
+ )
+ self.attention_fn = NystromAttention(
+ num_landmarks=128, num_heads=num_heads, dropout=dropout
+ )
+
+ def attn(
+ self,
+ x: torch.Tensor,
+ attn_bias: torch.Tensor | None = None,
+ context: torch.Tensor | None = None,
+ pos_embed: torch.Tensor | None = None,
+ pos_embed_context: torch.Tensor | None = None,
+ rope: nn.Module | None = None,
+ ) -> torch.Tensor:
+ x = self.norm_attnx(x)
+ context = self.norm_attnctx(context)
+ k, v = rearrange(
+ self.kv(context),
+ "b n (kv h d) -> b n h d kv",
+ h=self.num_heads,
+ kv=2,
+ ).unbind(dim=-1)
+ q = rearrange(self.q(x), "b n (h d) -> b n h d", h=self.num_heads)
+
+ if rope is not None:
+ q = rope(q)
+ k = rope(k)
+ else:
+ if pos_embed is not None:
+ pos_embed = rearrange(
+ pos_embed, "b n (h d) -> b n h d", h=self.num_heads
+ )
+ q = q + pos_embed
+ if pos_embed_context is not None:
+ pos_embed_context = rearrange(
+ pos_embed_context, "b n (h d) -> b n h d", h=self.num_heads
+ )
+ k = k + pos_embed_context
+
+ if self.cosine:
+ q, k = map(partial(F.normalize, p=2, dim=-1), (q, k))
+ x = self.attention_fn(q, k, v, key_padding_mask=attn_bias)
+ x = rearrange(x, "b n h d -> b n (h d)")
+ x = self.out(x)
+ return x
+
+
+class PositionEmbeddingSine(nn.Module):
+ def __init__(
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
+ ):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * pi
+ self.scale = scale
+
+ def forward(
+ self, x: torch.Tensor, mask: Tensor | None = None
+ ) -> torch.Tensor:
+ if mask is None:
+ mask = torch.zeros(
+ (x.size(0), x.size(2), x.size(3)),
+ device=x.device,
+ dtype=torch.bool,
+ )
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(
+ self.num_pos_feats, dtype=torch.float32, device=x.device
+ )
+ dim_t = self.temperature ** (
+ 2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats
+ )
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+ def __repr__(self, _repr_indent=4):
+ head = "Positional encoding " + self.__class__.__name__
+ body = [
+ "num_pos_feats: {}".format(self.num_pos_feats),
+ "temperature: {}".format(self.temperature),
+ "normalize: {}".format(self.normalize),
+ "scale: {}".format(self.scale),
+ ]
+ lines = [head] + [" " * _repr_indent + line for line in body]
+ return "\n".join(lines)
diff --git a/wilddet3d/ops/box2d.py b/wilddet3d/ops/box2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..36bbb243d68c3c4ab8e054defd1c13cd8a379c8a
--- /dev/null
+++ b/wilddet3d/ops/box2d.py
@@ -0,0 +1,101 @@
+"""Box operations for 2D bounding boxes."""
+
+import numpy as np
+import torch
+from torch import Tensor
+
+
+def fp16_clamp(x, min=None, max=None):
+ if not x.is_cuda and x.dtype == torch.float16:
+ return x.float().clamp(min, max).half()
+ return x.clamp(min, max)
+
+
+def bbox_cxcywh_to_xyxy(bbox: Tensor) -> Tensor:
+ """Convert bbox coordinates from (cx, cy, w, h) to (x1, y1, x2, y2)."""
+ cx, cy, w, h = bbox.split((1, 1, 1, 1), dim=-1)
+ bbox_new = [(cx - 0.5 * w), (cy - 0.5 * h), (cx + 0.5 * w), (cy + 0.5 * h)]
+ return torch.cat(bbox_new, dim=-1)
+
+
+def bbox_xyxy_to_cxcywh(bbox: Tensor) -> Tensor:
+ """Convert bbox coordinates from (x1, y1, x2, y2) to (cx, cy, w, h)."""
+ x1, y1, x2, y2 = bbox.split((1, 1, 1, 1), dim=-1)
+ bbox_new = [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)]
+ return torch.cat(bbox_new, dim=-1)
+
+
+def bbox_overlaps(bboxes1, bboxes2, mode="iou", is_aligned=False, eps=1e-6):
+ """Calculate overlap between two set of bboxes."""
+ assert mode in ["iou", "iof", "giou"], f"Unsupported mode {mode}"
+ assert bboxes1.size(-1) == 4 or bboxes1.size(0) == 0
+ assert bboxes2.size(-1) == 4 or bboxes2.size(0) == 0
+
+ assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
+ batch_shape = bboxes1.shape[:-2]
+
+ rows = bboxes1.size(-2)
+ cols = bboxes2.size(-2)
+ if is_aligned:
+ assert rows == cols
+
+ if rows * cols == 0:
+ if is_aligned:
+ return bboxes1.new(batch_shape + (rows,))
+ else:
+ return bboxes1.new(batch_shape + (rows, cols))
+
+ area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (
+ bboxes1[..., 3] - bboxes1[..., 1]
+ )
+ area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (
+ bboxes2[..., 3] - bboxes2[..., 1]
+ )
+
+ if is_aligned:
+ lt = torch.max(bboxes1[..., :2], bboxes2[..., :2])
+ rb = torch.min(bboxes1[..., 2:], bboxes2[..., 2:])
+
+ wh = fp16_clamp(rb - lt, min=0)
+ overlap = wh[..., 0] * wh[..., 1]
+
+ if mode in ["iou", "giou"]:
+ union = area1 + area2 - overlap
+ else:
+ union = area1
+ if mode == "giou":
+ enclosed_lt = torch.min(bboxes1[..., :2], bboxes2[..., :2])
+ enclosed_rb = torch.max(bboxes1[..., 2:], bboxes2[..., 2:])
+ else:
+ lt = torch.max(
+ bboxes1[..., :, None, :2], bboxes2[..., None, :, :2]
+ )
+ rb = torch.min(
+ bboxes1[..., :, None, 2:], bboxes2[..., None, :, 2:]
+ )
+
+ wh = fp16_clamp(rb - lt, min=0)
+ overlap = wh[..., 0] * wh[..., 1]
+
+ if mode in ["iou", "giou"]:
+ union = area1[..., None] + area2[..., None, :] - overlap
+ else:
+ union = area1[..., None]
+ if mode == "giou":
+ enclosed_lt = torch.min(
+ bboxes1[..., :, None, :2], bboxes2[..., None, :, :2]
+ )
+ enclosed_rb = torch.max(
+ bboxes1[..., :, None, 2:], bboxes2[..., None, :, 2:]
+ )
+
+ eps = union.new_tensor([eps])
+ union = torch.max(union, eps)
+ ious = overlap / union
+ if mode in ["iou", "iof"]:
+ return ious
+ enclose_wh = fp16_clamp(enclosed_rb - enclosed_lt, min=0)
+ enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
+ enclose_area = torch.max(enclose_area, eps)
+ gious = ious - (enclose_area - union) / enclose_area
+ return gious
diff --git a/wilddet3d/ops/box3d.py b/wilddet3d/ops/box3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2067700d2b73e4d84717034d83285846d047e4d
--- /dev/null
+++ b/wilddet3d/ops/box3d.py
@@ -0,0 +1,79 @@
+"""Box3D ops."""
+
+from torch import Tensor
+from vis4d_cuda_ops import iou_box3d
+
+from wilddet3d.ops.iou_box3d import check_coplanar, check_nonzero
+
+
+def box3d_overlap(
+ boxes_dt: Tensor,
+ boxes_gt: Tensor,
+ eps_coplanar: float = 1e-3,
+ eps_nonzero: float = 1e-8,
+) -> Tensor:
+ """
+ Computes the intersection of 3D boxes_dt and boxes_gt.
+
+ Inputs boxes_dt, boxes_gt are tensors of shape (B, 8, 3)
+ (where B doesn't have to be the same for boxes_dt and boxes_gt),
+ containing the 8 corners of the boxes, as follows:
+
+ (4) +---------+. (5)
+ | ` . | ` .
+ | (0) +---+-----+ (1)
+ | | | |
+ (7) +-----+---+. (6)|
+ ` . | ` . |
+ (3) ` +---------+ (2)
+
+
+ NOTE: Throughout this implementation, we assume that boxes
+ are defined by their 8 corners exactly in the order specified in the
+ diagram above for the function to give correct results. In addition
+ the vertices on each plane must be coplanar.
+ As an alternative to the diagram, this is a unit bounding
+ box which has the correct vertex ordering:
+
+ box_corner_vertices = [
+ [0, 0, 0],
+ [1, 0, 0],
+ [1, 1, 0],
+ [0, 1, 0],
+ [0, 0, 1],
+ [1, 0, 1],
+ [1, 1, 1],
+ [0, 1, 1],
+ ]
+
+ Args:
+ boxes_dt: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes
+ boxes_gt: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes
+ Returns:
+ iou: (N, M) tensor of the intersection over union which is
+ defined as: `iou = vol / (vol1 + vol2 - vol)`
+ """
+ # Make sure predictions are coplanar and nonzero
+ invalid_coplanar = ~check_coplanar(boxes_dt, eps=eps_coplanar)
+ invalid_nonzero = ~check_nonzero(boxes_dt, eps=eps_nonzero)
+
+ ious = iou_box3d(boxes_dt, boxes_gt)[1]
+
+ # Offending boxes are set to zero IoU
+ if invalid_coplanar.any():
+ ious[invalid_coplanar] = 0
+ print(
+ "Warning: skipping {:d} non-coplanar boxes at eval.".format(
+ int(invalid_coplanar.float().sum())
+ )
+ )
+
+ if invalid_nonzero.any():
+ ious[invalid_nonzero] = 0
+ print(
+ "Warning: skipping {:d} zero volume boxes at eval.".format(
+ int(invalid_nonzero.float().sum())
+ )
+ )
+
+ return ious
diff --git a/wilddet3d/ops/iou_box3d.py b/wilddet3d/ops/iou_box3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..95c3f745d697f716979112e30ff50d0af6aca624
--- /dev/null
+++ b/wilddet3d/ops/iou_box3d.py
@@ -0,0 +1,174 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch.autograd import Function
+from vis4d_cuda_ops import iou_box3d
+
+# -------------------------------------------------- #
+# CONSTANTS #
+# -------------------------------------------------- #
+"""
+_box_planes and _box_triangles define the 4- and 3-connectivity
+of the 8 box corners.
+_box_planes gives the quad faces of the 3D box
+_box_triangles gives the triangle faces of the 3D box
+"""
+_box_planes = [
+ [0, 1, 2, 3],
+ [3, 2, 6, 7],
+ [0, 1, 5, 4],
+ [0, 3, 7, 4],
+ [1, 2, 6, 5],
+ [4, 5, 6, 7],
+]
+_box_triangles = [
+ [0, 1, 2],
+ [0, 3, 2],
+ [4, 5, 6],
+ [4, 6, 7],
+ [1, 5, 6],
+ [1, 6, 2],
+ [0, 4, 7],
+ [0, 7, 3],
+ [3, 2, 6],
+ [3, 6, 7],
+ [0, 1, 5],
+ [0, 4, 5],
+]
+
+
+def check_coplanar(boxes: Tensor, eps: float = 1e-4) -> torch.BoolTensor:
+ """
+ Checks that plane vertices are coplanar.
+ Returns a bool tensor of size B, where True indicates a box is coplanar.
+ """
+ faces = torch.tensor(_box_planes, dtype=torch.int64, device=boxes.device)
+ verts = boxes.index_select(index=faces.view(-1), dim=1)
+ B = boxes.shape[0]
+ P, V = faces.shape
+ # (B, P, 4, 3) -> (B, P, 3)
+ v0, v1, v2, v3 = verts.reshape(B, P, V, 3).unbind(2)
+
+ # Compute the normal
+ e0 = F.normalize(v1 - v0, dim=-1)
+ e1 = F.normalize(v2 - v0, dim=-1)
+ normal = F.normalize(torch.cross(e0, e1, dim=-1), dim=-1)
+
+ # Check the fourth vertex is also on the same plane
+ mat1 = (v3 - v0).view(B, 1, -1) # (B, 1, P*3)
+ mat2 = normal.view(B, -1, 1) # (B, P*3, 1)
+
+ return (mat1.bmm(mat2).abs() < eps).view(B)
+
+
+def check_nonzero(boxes: Tensor, eps: float = 1e-4) -> torch.BoolTensor:
+ """
+ Checks that the sides of the box have a non zero area
+ """
+ faces = torch.tensor(
+ _box_triangles, dtype=torch.int64, device=boxes.device
+ )
+ verts = boxes.index_select(index=faces.view(-1), dim=1)
+ B = boxes.shape[0]
+ T, V = faces.shape
+ # (B, T, 3, 3) -> (B, T, 3)
+ v0, v1, v2 = verts.reshape(B, T, V, 3).unbind(2)
+
+ normals = torch.cross(v1 - v0, v2 - v0, dim=-1) # (B, T, 3)
+ face_areas = normals.norm(dim=-1) / 2
+
+ return (face_areas > eps).all(1).view(B)
+
+
+class _box3d_overlap(Function):
+ """
+ Torch autograd Function wrapper for box3d_overlap C++/CUDA implementations.
+ Backward is not supported.
+ """
+
+ @staticmethod
+ def forward(ctx, boxes1, boxes2):
+ """
+ Arguments defintions the same as in the box3d_overlap function
+ """
+ vol, iou = iou_box3d(boxes1, boxes2)
+ return vol, iou
+
+ @staticmethod
+ def backward(ctx, grad_vol, grad_iou):
+ raise ValueError("box3d_overlap backward is not supported")
+
+
+def box3d_overlap(
+ boxes1: Tensor, boxes2: Tensor, eps: float = 1e-4
+) -> Tuple[Tensor, Tensor]:
+ """
+ Computes the intersection of 3D boxes1 and boxes2.
+
+ Inputs boxes1, boxes2 are tensors of shape (B, 8, 3)
+ (where B doesn't have to be the same for boxes1 and boxes2),
+ containing the 8 corners of the boxes, as follows:
+
+ (4) +---------+. (5)
+ | ` . | ` .
+ | (0) +---+-----+ (1)
+ | | | |
+ (7) +-----+---+. (6)|
+ ` . | ` . |
+ (3) ` +---------+ (2)
+
+
+ NOTE: Throughout this implementation, we assume that boxes
+ are defined by their 8 corners exactly in the order specified in the
+ diagram above for the function to give correct results. In addition
+ the vertices on each plane must be coplanar.
+ As an alternative to the diagram, this is a unit bounding
+ box which has the correct vertex ordering:
+
+ box_corner_vertices = [
+ [0, 0, 0],
+ [1, 0, 0],
+ [1, 1, 0],
+ [0, 1, 0],
+ [0, 0, 1],
+ [1, 0, 1],
+ [1, 1, 1],
+ [0, 1, 1],
+ ]
+
+ Args:
+ boxes1: tensor of shape (N, 8, 3) of the coordinates of the 1st boxes
+ boxes2: tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes
+ Returns:
+ vol: (N, M) tensor of the volume of the intersecting convex shapes
+ iou: (N, M) tensor of the intersection over union which is
+ defined as: `iou = vol / (vol1 + vol2 - vol)`
+ """
+ if not all((8, 3) == box.shape[1:] for box in [boxes1, boxes2]):
+ raise ValueError("Each box in the batch must be of shape (8, 3)")
+
+ if not check_coplanar(boxes1, eps):
+ raise ValueError("boxes1 plane vertices are not coplanar")
+
+ if not check_coplanar(boxes2, eps):
+ raise ValueError("boxes2 plane vertices are not coplanar")
+
+ if not check_nonzero(boxes1, eps):
+ raise ValueError("boxes1 planes have zero areas")
+
+ if not check_nonzero(boxes2, eps):
+ raise ValueError("boxes2 planes have zero areas")
+
+ vol, iou = _box3d_overlap.apply(boxes1, boxes2)
+
+ return vol, iou
diff --git a/wilddet3d/ops/language/__init__.py b/wilddet3d/ops/language/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/wilddet3d/ops/language/grounding.py b/wilddet3d/ops/language/grounding.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2afcc80796d496be319e700ec04a53d2fee407c
--- /dev/null
+++ b/wilddet3d/ops/language/grounding.py
@@ -0,0 +1,206 @@
+"""Language grounding utilities."""
+
+import re
+
+import nltk
+import torch
+from torch import Tensor
+from transformers import BatchEncoding
+from vis4d.common.logging import rank_zero_info, rank_zero_warn
+
+
+def find_noun_phrases(caption: str) -> list:
+ """Find noun phrases in a caption using nltk.
+ Args:
+ caption (str): The caption to analyze.
+
+ Returns:
+ list: List of noun phrases found in the caption.
+
+ Examples:
+ >>> caption = 'There is two cat and a remote in the picture'
+ >>> find_noun_phrases(caption) # ['cat', 'a remote', 'the picture']
+ """
+ caption = caption.lower()
+ tokens = nltk.word_tokenize(caption)
+ pos_tags = nltk.pos_tag(tokens)
+
+ grammar = "NP: {?*+}"
+ cp = nltk.RegexpParser(grammar)
+ result = cp.parse(pos_tags)
+
+ noun_phrases = []
+ for subtree in result.subtrees():
+ if subtree.label() == "NP":
+ noun_phrases.append(" ".join(t[0] for t in subtree.leaves()))
+
+ return noun_phrases
+
+
+def remove_punctuation(text: str) -> str:
+ """Remove punctuation from a text.
+ Args:
+ text (str): The input text.
+
+ Returns:
+ str: The text with punctuation removed.
+ """
+ punctuation = [
+ "|",
+ ":",
+ ";",
+ "@",
+ "(",
+ ")",
+ "[",
+ "]",
+ "{",
+ "}",
+ "^",
+ "'",
+ '"',
+ "’",
+ "`",
+ "?",
+ "$",
+ "%",
+ "#",
+ "!",
+ "&",
+ "*",
+ "+",
+ ",",
+ ".",
+ ]
+ for p in punctuation:
+ text = text.replace(p, "")
+ return text.strip()
+
+
+def run_ner(caption: str) -> tuple[list[list[int]], list[str]]:
+ """Run NER on a caption and return the tokens and noun phrases.
+ Args:
+ caption (str): The input caption.
+
+ Returns:
+ Tuple[List, List]: A tuple containing the tokens and noun phrases.
+ - tokens_positive (List): A list of token positions.
+ - noun_phrases (List): A list of noun phrases.
+ """
+ noun_phrases = find_noun_phrases(caption)
+ noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
+ noun_phrases = [phrase for phrase in noun_phrases if phrase != ""]
+ rank_zero_info("noun_phrases:", noun_phrases)
+ relevant_phrases = noun_phrases
+ labels = noun_phrases
+
+ tokens_positive = []
+ for entity, label in zip(relevant_phrases, labels):
+ try:
+ # search all occurrences and mark them as different entities
+ # TODO: Not Robust
+ for m in re.finditer(entity, caption.lower()):
+ tokens_positive.append([[m.start(), m.end()]])
+ except Exception:
+ rank_zero_warn("noun entities:", noun_phrases)
+ rank_zero_warn("entity:", entity)
+ rank_zero_warn("caption:", caption.lower())
+ return tokens_positive, noun_phrases
+
+
+def create_positive_map(
+ tokenized: BatchEncoding,
+ tokens_positive: list[list[int]],
+ max_num_entities: int = 256,
+) -> Tensor:
+ """construct a map such that positive_map[i,j] = True
+ if box i is associated to token j
+
+ Args:
+ tokenized: The tokenized input.
+ tokens_positive (list): A list of token ranges
+ associated with positive boxes.
+ max_num_entities (int, optional): The maximum number of entities.
+ Defaults to 256.
+
+ Returns:
+ torch.Tensor: The positive map.
+
+ Raises:
+ Exception: If an error occurs during token-to-char mapping.
+ """
+ positive_map = torch.zeros(
+ (len(tokens_positive), max_num_entities), dtype=torch.float
+ )
+
+ for j, tok_list in enumerate(tokens_positive):
+ for beg, end in tok_list:
+ try:
+ beg_pos = tokenized.char_to_token(beg)
+ end_pos = tokenized.char_to_token(end - 1)
+ except Exception as e:
+ print("beg:", beg, "end:", end)
+ print("token_positive:", tokens_positive)
+ raise e
+ if beg_pos is None:
+ try:
+ beg_pos = tokenized.char_to_token(beg + 1)
+ if beg_pos is None:
+ beg_pos = tokenized.char_to_token(beg + 2)
+ except Exception:
+ beg_pos = None
+ if end_pos is None:
+ try:
+ end_pos = tokenized.char_to_token(end - 2)
+ if end_pos is None:
+ end_pos = tokenized.char_to_token(end - 3)
+ except Exception:
+ end_pos = None
+ if beg_pos is None or end_pos is None:
+ continue
+
+ assert beg_pos is not None and end_pos is not None
+ positive_map[j, beg_pos : end_pos + 1].fill_(1)
+ return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
+
+
+def create_positive_map_label_to_token(
+ positive_map: Tensor, plus: int = 0
+) -> dict:
+ """Create a dictionary mapping the label to the token.
+ Args:
+ positive_map (Tensor): The positive map tensor.
+ plus (int, optional): Value added to the label for indexing.
+ Defaults to 0.
+
+ Returns:
+ dict: The dictionary mapping the label to the token.
+ """
+ positive_map_label_to_token = {}
+ for i in range(len(positive_map)):
+ positive_map_label_to_token[i + plus] = torch.nonzero(
+ positive_map[i], as_tuple=True
+ )[0].tolist()
+ return positive_map_label_to_token
+
+
+def clean_label_name(name: str) -> str:
+ """Clean label name."""
+ name = re.sub(r"\(.*\)", "", name)
+ name = re.sub(r"_", " ", name)
+ name = re.sub(r" ", " ", name)
+ return name
+
+
+def chunks(lst: list, n: int) -> list:
+ """Yield successive n-sized chunks from lst."""
+ all_ = []
+ for i in range(0, len(lst), n):
+ data_index = lst[i : i + n]
+ all_.append(data_index)
+ counter = 0
+ for i in all_:
+ counter += len(i)
+ assert counter == len(lst)
+
+ return all_
diff --git a/wilddet3d/ops/match_cost.py b/wilddet3d/ops/match_cost.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1cc1341913b130cff92343f85ce5728f44ccba1
--- /dev/null
+++ b/wilddet3d/ops/match_cost.py
@@ -0,0 +1,273 @@
+"""Matcher cost op."""
+
+import torch
+from torch import Tensor
+from vis4d.op.box.box2d import bbox_iou
+
+from wilddet3d.ops.box2d import bbox_overlaps, bbox_xyxy_to_cxcywh
+
+
+class MatchCost:
+
+ def __init__(self, weight: float = 1.0) -> None:
+ """Create an instance of the class."""
+ self.weight = weight
+
+
+class ClassificationCost(MatchCost):
+ """ClsSoftmaxCost.
+
+ Args:
+ weight (Union[float, int]): Cost weight. Defaults to 1.
+
+ Examples:
+ >>> from mmdet.models.task_modules.assigners.
+ ... match_costs.match_cost import ClassificationCost
+ >>> import torch
+ >>> self = ClassificationCost()
+ >>> cls_pred = torch.rand(4, 3)
+ >>> gt_labels = torch.tensor([0, 1, 2])
+ >>> factor = torch.tensor([10, 8, 10, 8])
+ >>> self(cls_pred, gt_labels)
+ tensor([[-0.3430, -0.3525, -0.3045],
+ [-0.3077, -0.2931, -0.3992],
+ [-0.3664, -0.3455, -0.2881],
+ [-0.3343, -0.2701, -0.3956]])
+ """
+
+ def __init__(self, weight: float = 1.0) -> None:
+ """Create an instance of the class."""
+ super().__init__(weight=weight)
+
+ def __call__(self, cls_pred, gt_labels) -> Tensor:
+ """Compute match cost.
+
+ Args:
+ pred_instances (:obj:`InstanceData`): ``scores`` inside is
+ predicted classification logits, of shape
+ (num_queries, num_class).
+ gt_instances (:obj:`InstanceData`): ``labels`` inside should have
+ shape (num_gt, ).
+ img_meta (Optional[dict]): _description_. Defaults to None.
+
+ Returns:
+ Tensor: Match Cost matrix of shape (num_preds, num_gts).
+ """
+ pred_scores = cls_pred.softmax(-1)
+ cls_cost = -pred_scores[:, gt_labels]
+
+ return cls_cost * self.weight
+
+
+class BBoxL1Cost(MatchCost):
+ """BBoxL1Cost.
+
+ Note: ``bboxes`` in ``InstanceData`` passed in is of format 'xyxy'
+ and its coordinates are unnormalized.
+
+ Args:
+ box_format (str, optional): 'xyxy' for DETR, 'xywh' for Sparse_RCNN.
+ Defaults to 'xyxy'.
+ weight (Union[float, int]): Cost weight. Defaults to 1.
+
+ Examples:
+ >>> from mmdet.models.task_modules.assigners.
+ ... match_costs.match_cost import BBoxL1Cost
+ >>> import torch
+ >>> self = BBoxL1Cost()
+ >>> bbox_pred = torch.rand(1, 4)
+ >>> gt_bboxes= torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
+ >>> factor = torch.tensor([10, 8, 10, 8])
+ >>> self(bbox_pred, gt_bboxes, factor)
+ tensor([[1.6172, 1.6422]])
+ """
+
+ def __init__(self, box_format: str = "xyxy", weight: float = 1.0) -> None:
+ """Create an instance of the class."""
+ super().__init__(weight=weight)
+ assert box_format in ["xyxy", "xywh"]
+ self.box_format = box_format
+
+ def __call__(
+ self,
+ pred_bboxes,
+ gt_bboxes,
+ img_h,
+ img_w,
+ ) -> Tensor:
+ """Compute match cost.
+
+ Args:
+ pred_instances (:obj:`InstanceData`): ``bboxes`` inside is
+ predicted boxes with unnormalized coordinate
+ (x, y, x, y).
+ gt_instances (:obj:`InstanceData`): ``bboxes`` inside is gt
+ bboxes with unnormalized coordinate (x, y, x, y).
+ img_meta (Optional[dict]): Image information. Defaults to None.
+
+ Returns:
+ Tensor: Match Cost matrix of shape (num_preds, num_gts).
+ """
+ # convert box format
+ if self.box_format == "xywh":
+ gt_bboxes = bbox_xyxy_to_cxcywh(gt_bboxes)
+ pred_bboxes = bbox_xyxy_to_cxcywh(pred_bboxes)
+
+ # normalized
+ factor = gt_bboxes.new_tensor([img_w, img_h, img_w, img_h]).unsqueeze(
+ 0
+ )
+ gt_bboxes = gt_bboxes / factor
+ pred_bboxes = pred_bboxes / factor
+
+ bbox_cost = torch.cdist(pred_bboxes, gt_bboxes, p=1)
+
+ return bbox_cost * self.weight
+
+
+class IoUCost(MatchCost):
+ """IoUCost.
+
+ Note: ``bboxes`` in ``InstanceData`` passed in is of format 'xyxy'
+ and its coordinates are unnormalized.
+
+ Args:
+ iou_mode (str): iou mode such as 'iou', 'giou'. Defaults to 'giou'.
+ weight (Union[float, int]): Cost weight. Defaults to 1.
+
+ Examples:
+ >>> from mmdet.models.task_modules.assigners.
+ ... match_costs.match_cost import IoUCost
+ >>> import torch
+ >>> self = IoUCost()
+ >>> bboxes = torch.FloatTensor([[1,1, 2, 2], [2, 2, 3, 4]])
+ >>> gt_bboxes = torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
+ >>> self(bboxes, gt_bboxes)
+ tensor([[-0.1250, 0.1667],
+ [ 0.1667, -0.5000]])
+ """
+
+ def __init__(self, iou_mode: str = "giou", weight: float = 1.0):
+ super().__init__(weight=weight)
+ self.iou_mode = iou_mode
+
+ def __call__(
+ self,
+ pred_bboxes,
+ gt_bboxes,
+ ):
+ """Compute match cost.
+
+ Args:
+ pred_instances (:obj:`InstanceData`): ``bboxes`` inside is
+ predicted boxes with unnormalized coordinate
+ (x, y, x, y).
+ gt_instances (:obj:`InstanceData`): ``bboxes`` inside is gt
+ bboxes with unnormalized coordinate (x, y, x, y).
+ img_meta (Optional[dict]): Image information. Defaults to None.
+
+ Returns:
+ Tensor: Match Cost matrix of shape (num_preds, num_gts).
+ """
+ # avoid fp16 overflow
+ if pred_bboxes.dtype == torch.float16:
+ fp16 = True
+ pred_bboxes = pred_bboxes.to(torch.float32)
+ else:
+ fp16 = False
+
+ if self.iou_mode == "iou":
+ overlaps = bbox_iou(pred_bboxes, gt_bboxes)
+ else:
+ overlaps = bbox_overlaps(
+ pred_bboxes, gt_bboxes, mode=self.iou_mode
+ )
+
+ if fp16:
+ overlaps = overlaps.to(torch.float16)
+
+ # The 1 is a constant that doesn't change the matching, so omitted.
+ iou_cost = -overlaps
+ return iou_cost * self.weight
+
+
+class BinaryFocalLossCost(MatchCost):
+ """BinaryFocalLossCost.
+
+ Args:
+ alpha (Union[float, int]): focal_loss alpha. Defaults to 0.25.
+ gamma (Union[float, int]): focal_loss gamma. Defaults to 2.
+ eps (float): Defaults to 1e-12.
+ binary_input (bool): Whether the input is binary. Currently,
+ binary_input = True is for masks input, binary_input = False
+ is for label input. Defaults to False.
+ weight (Union[float, int]): Cost weight. Defaults to 1.
+ """
+
+ def __init__(
+ self,
+ alpha: float = 0.25,
+ gamma: float = 2.0,
+ eps: float = 1e-12,
+ binary_input: bool = False,
+ weight: float = 1.0,
+ ) -> None:
+ super().__init__(weight=weight)
+ self.alpha = alpha
+ self.gamma = gamma
+ self.eps = eps
+ self.binary_input = binary_input
+
+ def _focal_loss_cost(self, cls_pred: Tensor, gt_labels: Tensor) -> Tensor:
+ """
+ Args:
+ cls_pred (Tensor): Predicted classification logits, shape
+ (num_queries, num_class).
+ gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
+
+ Returns:
+ torch.Tensor: cls_cost value with weight
+ """
+ cls_pred = cls_pred.flatten(1)
+ gt_labels = gt_labels.flatten(1).float()
+ cls_pred = cls_pred.sigmoid()
+ neg_cost = (
+ -(1 - cls_pred + self.eps).log()
+ * (1 - self.alpha)
+ * cls_pred.pow(self.gamma)
+ )
+ pos_cost = (
+ -(cls_pred + self.eps).log()
+ * self.alpha
+ * (1 - cls_pred).pow(self.gamma)
+ )
+
+ cls_cost = torch.einsum(
+ "nc,mc->nm", pos_cost, gt_labels
+ ) + torch.einsum("nc,mc->nm", neg_cost, (1 - gt_labels))
+ return cls_cost * self.weight
+
+ def __call__(
+ self,
+ cls_pred: Tensor,
+ text_token_mask: Tensor,
+ positive_map: Tensor,
+ ) -> Tensor:
+ """Compute match cost.
+
+ Args:
+ pred_instances (:obj:`InstanceData`): Predicted instances which
+ must contain ``scores`` or ``masks``.
+ gt_instances (:obj:`InstanceData`): Ground truth which must contain
+ ``labels`` or ``mask``.
+ img_meta (Optional[dict]): Image information. Defaults to None.
+
+ Returns:
+ Tensor: Match Cost matrix of shape (num_preds, num_gts).
+ """
+ text_token_mask = torch.nonzero(text_token_mask).squeeze(-1)
+
+ pred_scores = cls_pred[:, text_token_mask]
+ gt_labels = positive_map[:, text_token_mask]
+
+ return self._focal_loss_cost(pred_scores, gt_labels)
diff --git a/wilddet3d/ops/matchers/__init__.py b/wilddet3d/ops/matchers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/wilddet3d/ops/matchers/hungarian.py b/wilddet3d/ops/matchers/hungarian.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a2fd9f1756bd6fe04db2a8ebf0a01d5e3f92442
--- /dev/null
+++ b/wilddet3d/ops/matchers/hungarian.py
@@ -0,0 +1,117 @@
+"""Box Hungarian Assigner."""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from scipy.optimize import linear_sum_assignment
+from torch import Tensor
+from vis4d.op.box.box2d import bbox_iou
+from vis4d.op.box.matchers.base import MatchResult
+
+
+class HungarianMatcher:
+ """Computes one-to-one matching between predictions and ground truth.
+
+ This class computes an assignment between the targets and the predictions
+ based on the costs. The targets don't include the no_object, so generally
+ there are more predictions than targets. After the one-to-one matching, the
+ un-matched are treated as backgrounds. Thus each query prediction will be
+ assigned with `0` or a positive integer indicating the ground truth index:
+
+ - 0: negative sample, no assigned gt
+ - positive integer: positive sample, index (1-based) of assigned gt
+ """
+
+ def __call__(
+ self,
+ cost: Tensor,
+ boxes: Tensor,
+ targets: Tensor,
+ target_classes: Tensor,
+ ) -> MatchResult:
+ """Computes one-to-one matching based on the weighted costs.
+
+ This method assign each query prediction to a ground truth or
+ background. The `assigned_gt_inds` with -1 means don't care,
+ 0 means negative sample, and positive number is the index (1-based)
+ of assigned gt.
+ The assignment is done in the following steps, the order matters.
+
+ 1. assign every prediction to -1
+ 2. compute the weighted costs
+ 3. do Hungarian matching on CPU based on the costs
+ 4. assign all to 0 (background) first, then for each matched pair
+ between predictions and gts, treat this prediction as foreground
+ and assign the corresponding gt index (plus 1) to it.
+
+ Args:
+ boxes (Tensor): Predicted boxes with normalized coordinates
+ (cx, cy, w, h), which are all in range [0, 1]. Shape
+ [num_query, 4].
+ boxes_classes (Tensor): Predicted classification logits, shape
+ [num_query, num_class].
+ targets (Tensor): Ground truth boxes with unnormalized
+ coordinates (x1, y1, x2, y2). Shape [num_gt, 4].
+ gt_labels (Tensor): Label of `targets`, shape (num_gt,).
+ img_meta (dict): Meta information for current image.
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`. Default None.
+ eps (int | float, optional): A value added to the denominator for
+ numerical stability. Default 1e-7.
+
+ gt_depth is a single channel map
+ depth_pred is per-label maps
+
+ Returns:
+ MatchResult: Matching results.
+ """
+ num_gts, num_bboxes = targets.size(0), boxes.size(0)
+
+ match_iou = boxes.new_zeros((len(boxes),))
+
+ # 1. assign -1 by default
+ assigned_gt_inds = boxes.new_full((num_bboxes,), -1, dtype=torch.long)
+ assigned_labels = boxes.new_full((num_bboxes,), -1, dtype=torch.long)
+
+ if num_gts == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ if num_gts == 0:
+ # No ground truth, assign all to background
+ assigned_gt_inds[:] = 0
+ return MatchResult(assigned_gt_inds, match_iou, assigned_labels)
+
+ # 2. compute the weighted costs.
+ # NOTE: We dissentangle the cost computation and Hungarian matching
+
+ # 3. do Hungarian matching on CPU using linear_sum_assignment
+ cost = cost.detach().cpu()
+ cost = np.nan_to_num(cost)
+
+ matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
+
+ matched_row_inds = torch.from_numpy(matched_row_inds).to(boxes.device)
+ matched_col_inds = torch.from_numpy(matched_col_inds).to(boxes.device)
+
+ # 4. assign backgrounds and foregrounds
+ # assign all indices to backgrounds first
+ assigned_gt_inds[:] = 0
+
+ # assign foregrounds based on matching results
+ assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
+ assigned_labels[matched_row_inds] = target_classes[matched_col_inds]
+
+ pos_inds = (
+ torch.nonzero(assigned_gt_inds > 0, as_tuple=False)
+ .squeeze(-1)
+ .unique()
+ )
+
+ _ious = bbox_iou(boxes[pos_inds], targets)
+
+ for i, pid in enumerate(pos_inds):
+ matched_gt_idx = assigned_gt_inds[pid] - 1
+ match_iou[pid] = _ious[i, matched_gt_idx]
+
+ return MatchResult(assigned_gt_inds, match_iou, assigned_labels)
diff --git a/wilddet3d/ops/mlp.py b/wilddet3d/ops/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..60ee8cea981c806c37f5343744a1a914703041d5
--- /dev/null
+++ b/wilddet3d/ops/mlp.py
@@ -0,0 +1,67 @@
+"""Multi-layer perceptron (MLP)."""
+
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+class SwiGLU(nn.Module):
+ """SwiGLU activation function."""
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ x, gates = x.chunk(2, dim=-1)
+ return x * F.silu(gates)
+
+
+class MLP(nn.Module):
+ """Multi-layer perceptron (MLP) module."""
+
+ def __init__(
+ self,
+ input_dim: int,
+ expansion: int = 4,
+ dropout: float = 0.0,
+ gated: bool = False,
+ output_dim: int | None = None,
+ ) -> None:
+ """Creates an instance of the class."""
+ super().__init__()
+ if gated:
+ expansion = int(expansion * 2 / 3)
+ hidden_dim = int(input_dim * expansion)
+ output_dim = output_dim if output_dim is not None else input_dim
+ self.norm = nn.LayerNorm(input_dim)
+ self.proj1 = nn.Linear(input_dim, hidden_dim)
+ self.proj2 = nn.Linear(hidden_dim, output_dim)
+ self.act = nn.GELU() if not gated else SwiGLU()
+ self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass."""
+ x = self.norm(x)
+ x = self.proj1(x)
+ x = self.act(x)
+ x = self.proj2(x)
+ x = self.dropout(x)
+ return x
+
+ def __call__(self, x: Tensor) -> Tensor:
+ """Type definition for call implementation."""
+ return self._call_impl(x)
+
+
+class SimpleMLP(nn.Module):
+ """Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+ )
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
diff --git a/wilddet3d/ops/nystrom.py b/wilddet3d/ops/nystrom.py
new file mode 100644
index 0000000000000000000000000000000000000000..669ee5707bd24f1fdd1f926bb93ca7e02ccfdee3
--- /dev/null
+++ b/wilddet3d/ops/nystrom.py
@@ -0,0 +1,374 @@
+"""Nystrom Attention.
+
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+
+import math
+import warnings
+from contextlib import nullcontext
+
+import torch
+from torch import Tensor, nn
+
+
+class AvgPool(nn.Module):
+ def __init__(self, n: int):
+ super().__init__()
+ self.n = n
+
+ def forward(self, x: Tensor):
+ seq_len = x.shape[1]
+ head_dim = x.shape[2]
+ segments = seq_len // self.n
+ assert (
+ segments > 0
+ ), "num_landmarks should be smaller than the sequence length"
+
+ if seq_len % self.n == 0:
+ return x.reshape(
+ -1,
+ self.n,
+ segments,
+ head_dim,
+ ).mean(dim=-2)
+
+ n_round = self.n - seq_len % self.n
+
+ x_avg_round = (
+ x[:, : n_round * segments, :]
+ .reshape(-1, n_round, segments, head_dim)
+ .mean(dim=-2)
+ )
+ x_avg_off = (
+ x[:, n_round * segments :, :]
+ .reshape(-1, self.n - n_round, segments + 1, head_dim)
+ .mean(dim=-2)
+ )
+ return torch.cat((x_avg_round, x_avg_off), dim=-2)
+
+
+def bmm(a: Tensor, b: Tensor) -> Tensor:
+ return a @ b
+
+
+def _apply_dropout(att, dropout):
+ if dropout is None:
+ return att
+ att = dropout(att)
+ return att
+
+
+def _matmul_with_mask(
+ a: Tensor,
+ b: Tensor,
+ mask: Tensor | None = None,
+) -> Tensor:
+ if mask is None:
+ return a @ b
+
+ att = a @ b
+ if mask.dtype == torch.bool:
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
+ att[~mask] = float("-inf")
+ else:
+ if (
+ mask.ndim == 3
+ and mask.shape[0] != att.shape[0]
+ and (att.shape[0] % mask.shape[0]) == 0
+ ):
+ repeat_factor = att.shape[0] // mask.shape[0]
+ mask = mask.repeat([repeat_factor, 1, 1])
+ warnings.warn(
+ "Mismatched batch dimensions for mask, repeating mask."
+ )
+ att += mask
+ return att
+
+
+def _softmax(a: Tensor) -> Tensor:
+ if a.is_sparse:
+ return torch.sparse.softmax(a, dim=a.ndim - 1)
+ return torch.softmax(a, dim=a.ndim - 1)
+
+
+def scaled_query_key_softmax(
+ q: Tensor,
+ k: Tensor,
+ att_mask: Tensor | None = None,
+) -> Tensor:
+ q = q / math.sqrt(k.size(-1))
+ mask = att_mask
+ att = _matmul_with_mask(q, k.transpose(-2, -1), mask)
+ att = _softmax(att)
+ return att
+
+
+def scaled_dot_product_attention(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ att_mask: Tensor | None = None,
+ dropout: nn.Module | None = None,
+) -> Tensor:
+ autocast_disabled = att_mask is not None and att_mask.is_sparse
+
+ with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext():
+ if autocast_disabled:
+ q, k, v = q.float(), k.float(), v.float()
+
+ att = scaled_query_key_softmax(q, k, att_mask=att_mask)
+ att = _apply_dropout(att, dropout)
+ y = bmm(att, v)
+ return y
+
+
+def bool_mask_to_additive(
+ mask: Tensor, dtype: torch.dtype | None = torch.float32
+) -> Tensor:
+ assert (
+ mask.dtype == torch.bool
+ ), "This util is meant to convert in between bool masks and additive ones"
+
+ mask_ = torch.zeros_like(mask, dtype=dtype)
+ mask_[~mask] = float("-inf")
+ return mask_
+
+
+def iterative_pinv(
+ softmax_mat: Tensor, n_iter=6, pinverse_original_init=False
+):
+ """Computing the Moore-Penrose inverse via iterative method."""
+ i = torch.eye(
+ softmax_mat.size(-1),
+ device=softmax_mat.device,
+ dtype=softmax_mat.dtype,
+ )
+ k = softmax_mat
+
+ if pinverse_original_init:
+ v = 1 / torch.max(torch.sum(k, dim=-2)) * k.transpose(-1, -2)
+ else:
+ v = (
+ 1
+ / torch.max(torch.sum(k, dim=-2), dim=-1).values[:, None, None]
+ * k.transpose(-1, -2)
+ )
+
+ for _ in range(n_iter):
+ kv = torch.matmul(k, v)
+ v = torch.matmul(
+ 0.25 * v,
+ 13 * i - torch.matmul(kv, 15 * i - torch.matmul(kv, 7 * i - kv)),
+ )
+ return v
+
+
+def reshape_key_padding_mask(
+ key_padding_mask: Tensor, batched_dim: int
+) -> Tensor:
+ assert key_padding_mask.ndim == 2
+ batch_size, src_len = key_padding_mask.size()
+ num_heads = batched_dim // batch_size
+ return _reshape_key_padding_mask(
+ key_padding_mask, batch_size, src_len, num_heads
+ )
+
+
+def _reshape_key_padding_mask(
+ key_padding_mask: Tensor,
+ batch_size: int,
+ src_len: int,
+ num_heads: int,
+) -> Tensor:
+ assert key_padding_mask.shape == (batch_size, src_len)
+ key_padding_mask = (
+ key_padding_mask.view(batch_size, 1, 1, src_len)
+ .expand(-1, num_heads, -1, -1)
+ .reshape(batch_size * num_heads, 1, src_len)
+ )
+ return key_padding_mask
+
+
+class NystromAttention(nn.Module):
+ """Nystrom attention mechanism."""
+
+ def __init__(
+ self,
+ dropout: float,
+ num_heads: int,
+ num_landmarks: int = 64,
+ landmark_pooling: nn.Module | None = None,
+ causal: bool = False,
+ use_razavi_pinverse: bool = True,
+ pinverse_original_init: bool = False,
+ inv_iterations: int = 6,
+ v_skip_connection: nn.Module | None = None,
+ conv_kernel_size: int | int = None,
+ ):
+ """Creates an instance of the class."""
+ super().__init__()
+ self.requires_separate_masks = True
+ self.num_landmarks = num_landmarks
+ self.num_heads = num_heads
+ self.use_razavi_pinverse = use_razavi_pinverse
+ self.pinverse_original_init = pinverse_original_init
+ self.inv_iterations = inv_iterations
+ self.attn_drop = nn.Dropout(dropout)
+ self.skip_connection = v_skip_connection
+ self.causal = causal
+
+ if self.skip_connection is None and conv_kernel_size is not None:
+ self.skip_connection = nn.Conv2d(
+ in_channels=self.num_heads,
+ out_channels=self.num_heads,
+ kernel_size=(conv_kernel_size, 1),
+ padding=(conv_kernel_size // 2, 0),
+ bias=False,
+ groups=self.num_heads,
+ )
+
+ if landmark_pooling is not None:
+ self.landmark_pooling = landmark_pooling
+ else:
+ self.landmark_pooling = AvgPool(n=self.num_landmarks)
+
+ self.causal_mask_1: Tensor | None = None
+ self.causal_mask_2: Tensor | None = None
+ self.causal_mask_3: Tensor | None = None
+
+ self.supports_attention_mask = False
+ self.supports_key_padding_mask = True
+
+ def forward(
+ self,
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ key_padding_mask: Tensor | None = None,
+ *args,
+ **kwargs,
+ ):
+ batched_dim = k.size(0)
+ seq_len = k.size(-2)
+ tt = {"dtype": q.dtype, "device": q.device}
+
+ if key_padding_mask is not None:
+ if key_padding_mask.dtype == torch.bool:
+ warnings.warn(
+ "Bool mask found, but an additive mask is expected. "
+ "Converting but this is slow"
+ )
+ key_padding_mask = bool_mask_to_additive(key_padding_mask)
+
+ if key_padding_mask.ndim == 2:
+ key_padding_mask = reshape_key_padding_mask(
+ key_padding_mask, batched_dim
+ )
+
+ zeros = torch.zeros_like(key_padding_mask)
+ ones = torch.ones_like(key_padding_mask)
+ is_masked = torch.isinf(-key_padding_mask)
+
+ _mask = torch.where(is_masked, zeros, ones)
+ _mask = _mask.transpose(2, 1)
+ assert _mask.shape == (batched_dim, q.shape[1], 1)
+
+ q = q * _mask
+ k = k * _mask
+
+ assert key_padding_mask.size() == (batched_dim, 1, seq_len), (
+ f"key_padding_mask has invalid dimensions {key_padding_mask.size()}."
+ f" Must have dimensions {batched_dim, 1, seq_len} or (batch_size, {seq_len})."
+ )
+
+ if self.num_landmarks >= seq_len:
+ mask: Tensor | None = None
+
+ if self.causal:
+ mask = self._triu_mask(batched_dim, seq_len, seq_len, **tt)
+
+ if key_padding_mask is not None:
+ mask = (
+ key_padding_mask
+ if mask is None
+ else mask + key_padding_mask
+ )
+
+ x = scaled_dot_product_attention(q=q, k=k, v=v, att_mask=mask)
+
+ else:
+ q_landmarks = self.landmark_pooling(q)
+ k_landmarks = self.landmark_pooling(k)
+
+ if self.causal and (
+ self.causal_mask_1 is None
+ or (batched_dim, seq_len, self.num_landmarks)
+ != self.causal_mask_1.size()
+ ):
+ self.causal_mask_1 = self._triu_mask(
+ batched_dim, seq_len, self.num_landmarks, **tt
+ )
+ self.causal_mask_2 = self._triu_mask(
+ batched_dim, self.num_landmarks, self.num_landmarks, **tt
+ )
+ self.causal_mask_3 = self._triu_mask(
+ batched_dim, self.num_landmarks, seq_len, **tt
+ )
+
+ mask_3: Tensor | None = self.causal_mask_3
+ if key_padding_mask is not None:
+ mask_3 = (
+ key_padding_mask
+ if mask_3 is None
+ else mask_3 + key_padding_mask
+ )
+
+ kernel_1 = scaled_query_key_softmax(
+ q=q, k=k_landmarks, att_mask=None
+ )
+ kernel_2 = scaled_query_key_softmax(
+ q=q_landmarks, k=k_landmarks, att_mask=None
+ )
+ kernel_3 = scaled_dot_product_attention(
+ q=q_landmarks, k=k, v=v, att_mask=mask_3
+ )
+
+ kernel_2_inv = (
+ iterative_pinv(
+ kernel_2, self.inv_iterations, self.pinverse_original_init
+ )
+ if self.use_razavi_pinverse
+ else torch.linalg.pinv(kernel_2)
+ )
+
+ x = torch.matmul(
+ torch.matmul(
+ kernel_1,
+ kernel_2_inv,
+ ),
+ kernel_3,
+ )
+
+ if self.skip_connection:
+ v_conv = self.skip_connection(
+ v.reshape(-1, self.num_heads, v.size(-2), v.size(-1))
+ )
+ x += v_conv.reshape(-1, v_conv.size(-2), v_conv.size(-1))
+ x = self.attn_drop(x)
+ return x
+
+ def _triu_mask(
+ self, dim_1: int, dim_2: int, dim_3: int, **kwargs
+ ) -> Tensor:
+ device = kwargs["device"]
+ dtype = kwargs["dtype"]
+
+ return torch.triu(
+ torch.ones(dim_2, dim_3, dtype=dtype, device=device)
+ * float("-inf"),
+ diagonal=1,
+ ).expand(dim_1, -1, -1)
diff --git a/wilddet3d/ops/profiler.py b/wilddet3d/ops/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..702b0f05f80677ba7fd998dc40c386091d686020
--- /dev/null
+++ b/wilddet3d/ops/profiler.py
@@ -0,0 +1,98 @@
+"""Training profiler for performance analysis.
+
+Usage:
+ Set environment variable PROFILE_WILDDET3D=1 to enable profiling.
+ Timing results are printed every N iterations.
+"""
+
+import os
+import time
+from collections import defaultdict
+from typing import Dict, List, Optional
+
+import torch
+import torch.distributed as dist
+
+
+class TrainingProfiler:
+ """Profiler for measuring training component timings."""
+
+ _instance: Optional["TrainingProfiler"] = None
+
+ def __init__(self, print_interval: int = 10, enabled: bool = True):
+ self.print_interval = print_interval
+ self.enabled = enabled
+ self.timings: Dict[str, List[float]] = defaultdict(list)
+ self.step_count = 0
+ self.current_step_timings: Dict[str, float] = {}
+ self._start_times: Dict[str, float] = {}
+
+ @classmethod
+ def get_instance(cls) -> "TrainingProfiler":
+ """Get singleton instance."""
+ if cls._instance is None:
+ enabled = os.environ.get("PROFILE_WILDDET3D", "0") == "1"
+ print_interval = int(os.environ.get("PROFILE_INTERVAL", "10"))
+ cls._instance = cls(print_interval=print_interval, enabled=enabled)
+ if enabled:
+ print(f"[TrainingProfiler] Enabled, printing every {print_interval} steps")
+ return cls._instance
+
+ def _is_main_process(self) -> bool:
+ import multiprocessing
+ current = multiprocessing.current_process()
+ return current.name == "MainProcess"
+
+ def _safe_cuda_sync(self) -> None:
+ if self._is_main_process() and torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ def start(self, name: str) -> None:
+ if not self.enabled:
+ return
+ if not self._is_main_process():
+ return
+ self._safe_cuda_sync()
+ self._start_times[name] = time.perf_counter()
+
+ def stop(self, name: str) -> float:
+ if not self.enabled:
+ return 0.0
+ if not self._is_main_process():
+ return 0.0
+ self._safe_cuda_sync()
+ elapsed = time.perf_counter() - self._start_times.get(name, time.perf_counter())
+ self.current_step_timings[name] = elapsed
+ return elapsed
+
+ def step(self) -> None:
+ if not self.enabled:
+ return
+ for name, elapsed in self.current_step_timings.items():
+ self.timings[name].append(elapsed)
+ self.step_count += 1
+
+ def _is_rank_zero(self) -> bool:
+ if not dist.is_initialized():
+ return True
+ return dist.get_rank() == 0
+
+
+def profiler() -> TrainingProfiler:
+ """Get the global profiler instance."""
+ return TrainingProfiler.get_instance()
+
+
+def profile_start(name: str) -> None:
+ """Start profiling a named section."""
+ TrainingProfiler.get_instance().start(name)
+
+
+def profile_stop(name: str) -> float:
+ """Stop profiling a named section and return elapsed time."""
+ return TrainingProfiler.get_instance().stop(name)
+
+
+def profile_step() -> None:
+ """Mark end of training step."""
+ TrainingProfiler.get_instance().step()
diff --git a/wilddet3d/ops/ray.py b/wilddet3d/ops/ray.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bc7af6159d07a08e1ad770bd365bdc007a8c542
--- /dev/null
+++ b/wilddet3d/ops/ray.py
@@ -0,0 +1,771 @@
+"""Ray utilities for 3D reconstruction."""
+
+import torch
+from torch import Tensor
+from torch.nn import functional as F
+
+
+def generate_rays(
+ camera_intrinsics: Tensor,
+ image_shape: tuple[int, int],
+ noisy: bool = False,
+) -> tuple[Tensor, Tensor]:
+ """Generates rays from camera intrinsics and image shape."""
+ batch_size, device, dtype = (
+ camera_intrinsics.shape[0],
+ camera_intrinsics.device,
+ camera_intrinsics.dtype,
+ )
+
+ height, width = image_shape
+
+ # Generate grid of pixel coordinates
+ pixel_coords_x = torch.linspace(
+ 0, width - 1, width, device=device, dtype=dtype
+ )
+ pixel_coords_y = torch.linspace(
+ 0, height - 1, height, device=device, dtype=dtype
+ )
+
+ if noisy:
+ pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5
+ pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5
+
+ pixel_coords = torch.stack(
+ [
+ pixel_coords_x.repeat(height, 1),
+ pixel_coords_y.repeat(width, 1).t(),
+ ],
+ dim=2,
+ ) # (H, W, 2)
+ pixel_coords = pixel_coords + 0.5
+
+ # Calculate ray directions
+ intrinsics_inv = (
+ torch.eye(3, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
+ )
+ intrinsics_inv[:, 0, 0] = 1.0 / camera_intrinsics[:, 0, 0]
+ intrinsics_inv[:, 1, 1] = 1.0 / camera_intrinsics[:, 1, 1]
+ intrinsics_inv[:, 0, 2] = (
+ -camera_intrinsics[:, 0, 2] / camera_intrinsics[:, 0, 0]
+ )
+ intrinsics_inv[:, 1, 2] = (
+ -camera_intrinsics[:, 1, 2] / camera_intrinsics[:, 1, 1]
+ )
+ homogeneous_coords = torch.cat(
+ [pixel_coords, torch.ones_like(pixel_coords[:, :, :1])], dim=2
+ ) # (H, W, 3)
+ ray_directions = torch.matmul(
+ intrinsics_inv, homogeneous_coords.permute(2, 0, 1).flatten(1)
+ ) # (3, H*W)
+ ray_directions = F.normalize(ray_directions, dim=1) # (B, 3, H*W)
+ ray_directions = ray_directions.permute(0, 2, 1) # (B, H*W, 3)
+
+ theta = torch.atan2(ray_directions[..., 0], ray_directions[..., -1])
+ phi = torch.acos(ray_directions[..., 1].clamp(-1.0, 1.0))
+ # pitch = torch.asin(ray_directions[..., 1])
+ # roll = torch.atan2(ray_directions[..., 0], - ray_directions[..., 1])
+ angles = torch.stack([theta, phi], dim=-1)
+ return ray_directions, angles
+
+
+def spherical_zbuffer_to_euclidean(
+ spherical_tensor: Tensor,
+) -> Tensor:
+ """Converts a spherical zbuffer tensor to euclidean coordinates."""
+ theta = spherical_tensor[..., 0] # Extract polar angle
+ phi = spherical_tensor[..., 1] # Extract azimuthal angle
+ z = spherical_tensor[..., 2] # Extract zbuffer depth
+
+ x = z * torch.tan(theta)
+ y = z / torch.tan(phi) / torch.cos(theta)
+
+ euclidean_tensor = torch.stack((x, y, z), dim=-1)
+ return euclidean_tensor
+
+
+def rsh_cart_3(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 3.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,16) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ x2 = x**2
+ y2 = y**2
+ z2 = z**2
+ xy = x * y
+ xz = x * z
+ yz = y * z
+
+ return torch.stack(
+ [
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ 1.09254843059208 * xy,
+ -1.09254843059208 * yz,
+ 0.94617469575756 * z2 - 0.31539156525252,
+ -1.09254843059208 * xz,
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
+ -0.590043589926644 * y * (3.0 * x2 - y2),
+ 2.89061144264055 * xy * z,
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
+ 1.44530572132028 * z * (x2 - y2),
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
+ ],
+ -1,
+ )
+
+
+def rsh_cart_8(xyz: Tensor):
+ """Computes all real spherical harmonics up to degree 8.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,81) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ x2 = x**2
+ y2 = y**2
+ z2 = z**2
+ xy = x * y
+ xz = x * z
+ yz = y * z
+ x4 = x2**2
+ y4 = y2**2
+ # z4 = z2**2
+ return torch.stack(
+ [
+ 0.282094791773878
+ * torch.ones(1, device=xyz.device).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ 1.09254843059208 * xy,
+ -1.09254843059208 * yz,
+ 0.94617469575756 * z2 - 0.31539156525252,
+ -1.09254843059208 * xz,
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
+ -0.590043589926644 * y * (3.0 * x2 - y2),
+ 2.89061144264055 * xy * z,
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
+ 1.44530572132028 * z * (x2 - y2),
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
+ 2.5033429417967 * xy * (x2 - y2),
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
+ 0.267618617422916
+ * y
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 1.48099765681286
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 0.952069922236839 * z2
+ + 0.317356640745613,
+ 0.267618617422916
+ * x
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
+ -3.75501441269506 * x2 * y2
+ + 0.625835735449176 * x4
+ + 0.625835735449176 * y4,
+ -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 8.30264925952416 * xy * z * (x2 - y2),
+ 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
+ 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.241571547304372
+ * y
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ -1.24747010616985 * z * (1.5 * z2 - 0.5)
+ + 1.6840846433293
+ * z
+ * (
+ 1.75
+ * z
+ * (
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
+ - 0.666666666666667 * z
+ )
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.498988042467941 * z,
+ 0.241571547304372
+ * x
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ 0.0456527312854602
+ * (x2 - y2)
+ * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
+ 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
+ -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 4.09910463115149 * x**4 * xy
+ - 13.6636821038383 * xy**3
+ + 4.09910463115149 * xy * y**4,
+ -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
+ 0.00584892228263444
+ * y
+ * (3.0 * x2 - y2)
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+ 0.0701870673916132
+ * xy
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ ),
+ 0.221950995245231
+ * y
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25
+ * z
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ ),
+ -1.48328138624466
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ + 1.86469659985043
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
+ - 0.666666666666667 * z
+ )
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 0.953538034014426 * z2
+ - 0.317846011338142,
+ 0.221950995245231
+ * x
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25
+ * z
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ ),
+ 0.0350935336958066
+ * (x2 - y2)
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ ),
+ 0.00584892228263444
+ * x
+ * (x2 - 3.0 * y2)
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+ 0.0010678622237645
+ * (5197.5 * z2 - 472.5)
+ * (-6.0 * x2 * y2 + x4 + y4),
+ -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 0.683184105191914 * x2**3
+ + 10.2477615778787 * x2 * y4
+ - 10.2477615778787 * x4 * y2
+ - 0.683184105191914 * y2**3,
+ -0.707162732524596
+ * y
+ * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
+ 2.6459606618019
+ * z
+ * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
+ 9.98394571852353e-5
+ * y
+ * (5197.5 - 67567.5 * z2)
+ * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 0.00239614697244565
+ * xy
+ * (x2 - y2)
+ * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
+ 0.00397356022507413
+ * y
+ * (3.0 * x2 - y2)
+ * (
+ 3.25
+ * z
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ + 1063.125 * z2
+ - 118.125
+ ),
+ 0.0561946276120613
+ * xy
+ * (
+ -4.8 * z * (52.5 * z2 - 7.5)
+ + 2.6
+ * z
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ )
+ + 48.0 * z
+ ),
+ 0.206472245902897
+ * y
+ * (
+ -2.625
+ * z
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 2.16666666666667
+ * z
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25
+ * z
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ )
+ - 10.9375 * z2
+ + 2.1875
+ ),
+ 1.24862677781952 * z * (1.5 * z2 - 0.5)
+ - 1.68564615005635
+ * z
+ * (
+ 1.75
+ * z
+ * (
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
+ - 0.666666666666667 * z
+ )
+ - 1.125 * z2
+ + 0.375
+ )
+ + 2.02901851395672
+ * z
+ * (
+ -1.45833333333333
+ * z
+ * (
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
+ - 0.666666666666667 * z
+ )
+ + 1.83333333333333
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
+ - 0.666666666666667 * z
+ )
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 0.9375 * z2
+ - 0.3125
+ )
+ - 0.499450711127808 * z,
+ 0.206472245902897
+ * x
+ * (
+ -2.625
+ * z
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 2.16666666666667
+ * z
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25
+ * z
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ )
+ - 10.9375 * z2
+ + 2.1875
+ ),
+ 0.0280973138060306
+ * (x2 - y2)
+ * (
+ -4.8 * z * (52.5 * z2 - 7.5)
+ + 2.6
+ * z
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ )
+ + 48.0 * z
+ ),
+ 0.00397356022507413
+ * x
+ * (x2 - 3.0 * y2)
+ * (
+ 3.25
+ * z
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ + 1063.125 * z2
+ - 118.125
+ ),
+ 0.000599036743111412
+ * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
+ * (-6.0 * x2 * y2 + x4 + y4),
+ 9.98394571852353e-5
+ * x
+ * (5197.5 - 67567.5 * z2)
+ * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 2.6459606618019
+ * z
+ * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
+ -0.707162732524596
+ * x
+ * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
+ 5.83141328139864
+ * xy
+ * (x2**3 + 7.0 * x2 * y4 - 7.0 * x4 * y2 - y2**3),
+ -2.91570664069932
+ * yz
+ * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
+ 7.87853281621404e-6
+ * (1013512.5 * z2 - 67567.5)
+ * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
+ 5.10587282657803e-5
+ * y
+ * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
+ * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 0.00147275890257803
+ * xy
+ * (x2 - y2)
+ * (
+ 3.75
+ * z
+ * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
+ - 14293.125 * z2
+ + 1299.375
+ ),
+ 0.0028519853513317
+ * y
+ * (3.0 * x2 - y2)
+ * (
+ -7.33333333333333 * z * (52.5 - 472.5 * z2)
+ + 3.0
+ * z
+ * (
+ 3.25
+ * z
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ + 1063.125 * z2
+ - 118.125
+ )
+ - 560.0 * z
+ ),
+ 0.0463392770473559
+ * xy
+ * (
+ -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ + 2.5
+ * z
+ * (
+ -4.8 * z * (52.5 * z2 - 7.5)
+ + 2.6
+ * z
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ )
+ + 48.0 * z
+ )
+ + 137.8125 * z2
+ - 19.6875
+ ),
+ 0.193851103820053
+ * y
+ * (
+ 3.2 * z * (1.5 - 7.5 * z2)
+ - 2.51428571428571
+ * z
+ * (
+ 2.25
+ * z
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ + 2.14285714285714
+ * z
+ * (
+ -2.625
+ * z
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 2.16666666666667
+ * z
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25
+ * z
+ * (
+ 2.33333333333333 * z * (1.5 - 7.5 * z2)
+ + 4.0 * z
+ )
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ )
+ - 10.9375 * z2
+ + 2.1875
+ )
+ + 5.48571428571429 * z
+ ),
+ 1.48417251362228
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.86581687426801
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
+ - 0.666666666666667 * z
+ )
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 2.1808249179756
+ * z
+ * (
+ 1.14285714285714 * z * (1.5 * z2 - 0.5)
+ - 1.54285714285714
+ * z
+ * (
+ 1.75
+ * z
+ * (
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
+ - 0.666666666666667 * z
+ )
+ - 1.125 * z2
+ + 0.375
+ )
+ + 1.85714285714286
+ * z
+ * (
+ -1.45833333333333
+ * z
+ * (
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
+ - 0.666666666666667 * z
+ )
+ + 1.83333333333333
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
+ - 0.666666666666667 * z
+ )
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 0.9375 * z2
+ - 0.3125
+ )
+ - 0.457142857142857 * z
+ )
+ - 0.954110901614325 * z2
+ + 0.318036967204775,
+ 0.193851103820053
+ * x
+ * (
+ 3.2 * z * (1.5 - 7.5 * z2)
+ - 2.51428571428571
+ * z
+ * (
+ 2.25
+ * z
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ + 2.14285714285714
+ * z
+ * (
+ -2.625
+ * z
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 2.16666666666667
+ * z
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25
+ * z
+ * (
+ 2.33333333333333 * z * (1.5 - 7.5 * z2)
+ + 4.0 * z
+ )
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ )
+ - 10.9375 * z2
+ + 2.1875
+ )
+ + 5.48571428571429 * z
+ ),
+ 0.0231696385236779
+ * (x2 - y2)
+ * (
+ -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ + 2.5
+ * z
+ * (
+ -4.8 * z * (52.5 * z2 - 7.5)
+ + 2.6
+ * z
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ )
+ + 48.0 * z
+ )
+ + 137.8125 * z2
+ - 19.6875
+ ),
+ 0.0028519853513317
+ * x
+ * (x2 - 3.0 * y2)
+ * (
+ -7.33333333333333 * z * (52.5 - 472.5 * z2)
+ + 3.0
+ * z
+ * (
+ 3.25
+ * z
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ + 1063.125 * z2
+ - 118.125
+ )
+ - 560.0 * z
+ ),
+ 0.000368189725644507
+ * (-6.0 * x2 * y2 + x4 + y4)
+ * (
+ 3.75
+ * z
+ * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
+ - 14293.125 * z2
+ + 1299.375
+ ),
+ 5.10587282657803e-5
+ * x
+ * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
+ * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 7.87853281621404e-6
+ * (1013512.5 * z2 - 67567.5)
+ * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
+ -2.91570664069932
+ * xz
+ * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
+ -20.4099464848952 * x2**3 * y2
+ - 20.4099464848952 * x2 * y2**3
+ + 0.72892666017483 * x4**2
+ + 51.0248662122381 * x4 * y4
+ + 0.72892666017483 * y4**2,
+ ],
+ -1,
+ )
diff --git a/wilddet3d/ops/rotation.py b/wilddet3d/ops/rotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..96ec75eb0e6f0887f33de3fd9cf077a9948cfc7c
--- /dev/null
+++ b/wilddet3d/ops/rotation.py
@@ -0,0 +1,198 @@
+"""Rotation ops."""
+
+from __future__ import annotations
+
+import math
+
+import torch
+from torch import Tensor
+from torch.nn import functional as F
+from vis4d.op.geometry.rotation import quaternion_to_matrix
+
+DEFAULT_ACOS_BOUND: float = 1.0 - 1e-4
+
+
+def _acos_linear_approximation(x: Tensor, x0: float) -> Tensor:
+ return (x - x0) * _dacos_dx(x0) + math.acos(x0)
+
+
+def _dacos_dx(x: float) -> float:
+ return (-1.0) / math.sqrt(1.0 - x * x)
+
+
+def acos_linear_extrapolation(
+ x: Tensor,
+ bounds: tuple[float, float] = (-DEFAULT_ACOS_BOUND, DEFAULT_ACOS_BOUND),
+) -> Tensor:
+ """Implements arccos(x) with linear extrapolation outside (-1, 1)."""
+ lower_bound, upper_bound = bounds
+
+ if lower_bound > upper_bound:
+ raise ValueError(
+ "lower bound has to be smaller or equal to upper bound."
+ )
+
+ if lower_bound <= -1.0 or upper_bound >= 1.0:
+ raise ValueError(
+ "Both lower bound and upper bound have to be within (-1, 1)."
+ )
+
+ acos_extrap = torch.empty_like(x)
+ x_upper = x >= upper_bound
+ x_lower = x <= lower_bound
+ x_mid = (~x_upper) & (~x_lower)
+
+ acos_extrap[x_mid] = torch.acos(x[x_mid])
+ acos_extrap[x_upper] = _acos_linear_approximation(x[x_upper], upper_bound)
+ acos_extrap[x_lower] = _acos_linear_approximation(x[x_lower], lower_bound)
+
+ return acos_extrap
+
+
+def so3_rotation_angle(
+ R: Tensor,
+ eps: float = 1e-4,
+ cos_angle: bool = False,
+ cos_bound: float = 1e-4,
+) -> Tensor:
+ """Calculates angles (in radians) of a batch of rotation matrices."""
+ _, dim1, dim2 = R.shape
+ if dim1 != 3 or dim2 != 3:
+ raise ValueError("Input has to be a batch of 3x3 Tensors.")
+
+ rot_trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
+
+ if ((rot_trace < -1.0 - eps) + (rot_trace > 3.0 + eps)).any():
+ raise ValueError(
+ "A matrix has trace outside valid range [-1-eps,3+eps]."
+ )
+
+ phi_cos = (rot_trace - 1.0) * 0.5
+
+ if cos_angle:
+ return phi_cos
+ else:
+ if cos_bound > 0.0:
+ bound = 1.0 - cos_bound
+ return acos_linear_extrapolation(phi_cos, (-bound, bound))
+ else:
+ return torch.acos(phi_cos)
+
+
+def so3_relative_angle(
+ R1: Tensor,
+ R2: Tensor,
+ cos_angle: bool = False,
+ cos_bound: float = 1e-4,
+ eps: float = 1e-4,
+) -> Tensor:
+ """Calculates the relative angle between pairs of rotation matrices."""
+ R12 = torch.bmm(R1, R2.permute(0, 2, 1))
+ return so3_rotation_angle(
+ R12, cos_angle=cos_angle, cos_bound=cos_bound, eps=eps
+ )
+
+
+def axis_angle_to_quaternion(axis_angle: Tensor) -> Tensor:
+ """Convert rotations given as axis/angle to quaternions."""
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
+ half_angles = angles * 0.5
+ eps = 1e-6
+ small_angles = angles.abs() < eps
+ sin_half_angles_over_angles = torch.empty_like(angles)
+ sin_half_angles_over_angles[~small_angles] = (
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+ )
+ sin_half_angles_over_angles[small_angles] = (
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
+ )
+ quaternions = torch.cat(
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles],
+ dim=-1,
+ )
+ return quaternions
+
+
+def axis_angle_to_matrix(axis_angle: Tensor) -> Tensor:
+ """Convert rotations given as axis/angle to rotation matrices."""
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
+
+
+def rotation_6d_to_matrix(d6: Tensor) -> Tensor:
+ """Converts 6D rotation representation to rotation matrix."""
+ a1, a2 = d6[..., :3], d6[..., 3:]
+ b1 = F.normalize(a1, dim=-1)
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
+ b2 = F.normalize(b2, dim=-1)
+ b3 = torch.cross(b1, b2, dim=-1)
+ return torch.stack((b1, b2, b3), dim=-2)
+
+
+def matrix_to_rotation_6d(matrix: Tensor) -> Tensor:
+ """Converts rotation matrices to 6D rotation representation."""
+ batch_dim = matrix.size()[:-2]
+ return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
+
+
+def R_from_allocentric(K: Tensor, R_view, u=None, v=None):
+ """Convert rotation matrix to egocentric representation."""
+ fx = K[:, 0, 0]
+ fy = K[:, 1, 1]
+ sx = K[:, 0, 2]
+ sy = K[:, 1, 2]
+
+ if u is None:
+ u = sx
+ if v is None:
+ v = sy
+
+ oray = torch.stack(((u - sx) / fx, (v - sy) / fy, torch.ones_like(u))).T
+ oray = oray / torch.linalg.norm(oray, dim=1).unsqueeze(1)
+ angle = torch.acos(oray[:, -1])
+
+ axis = torch.zeros_like(oray)
+ axis[:, 0] = axis[:, 0] - oray[:, 1]
+ axis[:, 1] = axis[:, 1] + oray[:, 0]
+ norms = torch.linalg.norm(axis, dim=1)
+
+ valid_angle = angle > 0
+
+ M = axis_angle_to_matrix(angle.unsqueeze(1) * axis / norms.unsqueeze(1))
+
+ R = R_view.clone()
+ R[valid_angle] = torch.bmm(M[valid_angle], R_view[valid_angle])
+
+ return R
+
+
+def R_to_allocentric(K: Tensor, R, u=None, v=None):
+ """Convert rotation matrix to allocentric representation."""
+ fx = K[:, 0, 0]
+ fy = K[:, 1, 1]
+ sx = K[:, 0, 2]
+ sy = K[:, 1, 2]
+
+ if u is None:
+ u = sx
+ if v is None:
+ v = sy
+
+ oray = torch.stack(((u - sx) / fx, (v - sy) / fy, torch.ones_like(u))).T
+ oray = oray / torch.linalg.norm(oray, dim=1).unsqueeze(1)
+ angle = torch.acos(oray[:, -1])
+
+ axis = torch.zeros_like(oray)
+ axis[:, 0] = axis[:, 0] - oray[:, 1]
+ axis[:, 1] = axis[:, 1] + oray[:, 0]
+ norms = torch.linalg.norm(axis, dim=1)
+
+ valid_angle = angle > 0
+
+ M = axis_angle_to_matrix(angle.unsqueeze(1) * axis / norms.unsqueeze(1))
+
+ R_view = R.clone()
+ R_view[valid_angle] = torch.bmm(
+ M[valid_angle].transpose(2, 1), R[valid_angle]
+ )
+
+ return R_view
diff --git a/wilddet3d/ops/upsample.py b/wilddet3d/ops/upsample.py
new file mode 100644
index 0000000000000000000000000000000000000000..e74da9b8ac2d4bcb0bac07a2aff0dce4f8b0dc66
--- /dev/null
+++ b/wilddet3d/ops/upsample.py
@@ -0,0 +1,127 @@
+"""Upsampling layers."""
+
+import torch
+from einops import rearrange
+from torch import Tensor, nn
+
+
+class CvnxtBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ kernel_size=7,
+ layer_scale=1.0,
+ expansion=4,
+ dilation=1,
+ padding_mode: str = "zeros",
+ ):
+ super().__init__()
+ self.dwconv = nn.Conv2d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=dilation * (kernel_size - 1) // 2,
+ groups=dim,
+ dilation=dilation,
+ padding_mode=padding_mode,
+ )
+ self.norm = nn.LayerNorm(dim)
+ self.pwconv1 = nn.Linear(dim, expansion * dim)
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(expansion * dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale * torch.ones((dim)))
+ if layer_scale > 0.0
+ else 1.0
+ )
+
+ def forward(self, x):
+ input = x
+ x = self.dwconv(x)
+ x = x.permute(0, 2, 3, 1)
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+
+ x = self.gamma * x
+ x = input + x.permute(0, 3, 1, 2)
+ return x
+
+
+class ConvUpsample(nn.Module):
+ """Convolutional upsampling layer."""
+
+ def __init__(
+ self,
+ hidden_dim: int,
+ output_dim: int | None = None,
+ num_layers: int = 2,
+ expansion: int = 4,
+ layer_scale: float = 1.0,
+ kernel_size: int = 7,
+ ) -> None:
+ """Init."""
+ super().__init__()
+
+ if output_dim is None:
+ output_dim = hidden_dim // 2
+
+ self.convs = nn.ModuleList([])
+ for _ in range(num_layers):
+ self.convs.append(
+ CvnxtBlock(
+ hidden_dim,
+ kernel_size=kernel_size,
+ expansion=expansion,
+ layer_scale=layer_scale,
+ )
+ )
+
+ self.up = nn.Sequential(
+ nn.Conv2d(hidden_dim, output_dim, kernel_size=1, padding=0),
+ nn.UpsamplingBilinear2d(scale_factor=2),
+ nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
+ )
+
+ def forward(self, x: Tensor):
+ for conv in self.convs:
+ x = conv(x)
+ x = self.up(x)
+ x = rearrange(x, "b c h w -> b (h w) c")
+ return x
+
+
+class ConvUpsampleShuffle(nn.Module):
+ def __init__(
+ self,
+ hidden_dim,
+ num_layers: int = 2,
+ expansion: int = 4,
+ layer_scale: float = 1.0,
+ kernel_size: int = 7,
+ ):
+ super().__init__()
+ self.convs = nn.ModuleList([])
+ for _ in range(num_layers):
+ self.convs.append(
+ CvnxtBlock(
+ hidden_dim,
+ kernel_size=kernel_size,
+ expansion=expansion,
+ layer_scale=layer_scale,
+ )
+ )
+ self.up = nn.Sequential(
+ nn.PixelShuffle(2),
+ nn.Conv2d(
+ hidden_dim // 4, hidden_dim // 2, kernel_size=3, padding=1
+ ),
+ )
+
+ def forward(self, x: Tensor):
+ for conv in self.convs:
+ x = conv(x)
+ x = self.up(x)
+ x = rearrange(x, "b c h w -> b (h w) c")
+ return x
diff --git a/wilddet3d/ops/util.py b/wilddet3d/ops/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..621c9834abc94c41726be625ba44146d016c8522
--- /dev/null
+++ b/wilddet3d/ops/util.py
@@ -0,0 +1,44 @@
+"""Op utility functions."""
+
+from __future__ import annotations
+
+from functools import partial
+
+import torch.nn.functional as F
+from torch import Tensor
+
+
+def multi_apply(func, *args, **kwargs):
+ """Apply function to a list of arguments."""
+ pfunc = partial(func, **kwargs) if kwargs else func
+ map_results = map(pfunc, *args)
+ return tuple(map(list, zip(*map_results)))
+
+
+def flat_interpolate(
+ flat_tensor: Tensor,
+ old: tuple[int, int],
+ new: tuple[int, int],
+ antialias: bool = True,
+ mode: str = "bilinear",
+) -> Tensor:
+ if old[0] == new[0] and old[1] == new[1]:
+ return flat_tensor
+ tensor = flat_tensor.view(
+ flat_tensor.shape[0], old[0], old[1], -1
+ ).permute(
+ 0, 3, 1, 2
+ )
+ tensor_interp = F.interpolate(
+ tensor,
+ size=(new[0], new[1]),
+ mode=mode,
+ align_corners=False,
+ antialias=antialias,
+ )
+ flat_tensor_interp = tensor_interp.view(
+ flat_tensor.shape[0], -1, new[0] * new[1]
+ ).permute(
+ 0, 2, 1
+ )
+ return flat_tensor_interp.contiguous()
diff --git a/wilddet3d/preprocessing.py b/wilddet3d/preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..880a8aa5b96b80330bf39f94cb5c875d3dd72992
--- /dev/null
+++ b/wilddet3d/preprocessing.py
@@ -0,0 +1,83 @@
+"""Preprocessing utilities for WildDet3D inference.
+
+Handles image resizing, normalization, center padding, and intrinsics
+adjustment to prepare raw inputs for the WildDet3D model.
+"""
+
+from typing import Optional
+
+import numpy as np
+
+from vis4d.data.transforms.base import compose
+from vis4d.data.transforms.normalize import NormalizeImages
+from vis4d.data.transforms.resize import ResizeImages, ResizeIntrinsics
+from vis4d.data.transforms.to_tensor import ToTensor
+
+from wilddet3d.data.transforms.pad import (
+ CenterPadImages,
+ CenterPadIntrinsics,
+)
+from wilddet3d.data.transforms.resize import GenResizeParameters
+
+# WildDet3D expects 1008x1008 images
+IMAGE_SIZE = (1008, 1008)
+
+
+def preprocess(
+ image: np.ndarray,
+ intrinsics: Optional[np.ndarray] = None,
+) -> dict:
+ """Preprocess image for WildDet3D.
+
+ Args:
+ image: RGB image as numpy array (H, W, 3)
+ intrinsics: Camera intrinsics (3, 3), or None to use default/predicted
+
+ Returns:
+ Dict with preprocessed tensors and metadata
+ """
+ images = image.astype(np.float32)[None, ...]
+ H, W = images.shape[1], images.shape[2]
+
+ # If no intrinsics provided, create a placeholder.
+ # When use_predicted_intrinsics=True in the model, the geometry backend's
+ # K_pred will be used for 3D box decoding instead of this placeholder.
+ # The placeholder is still needed so the data pipeline doesn't crash.
+ if intrinsics is None:
+ focal = max(H, W)
+ intrinsics = np.array(
+ [
+ [focal, 0, W / 2],
+ [0, focal, H / 2],
+ [0, 0, 1],
+ ],
+ dtype=np.float32,
+ )
+
+ data_dict = {
+ "images": images,
+ "original_images": images.copy(),
+ "input_hw": (H, W),
+ "original_hw": (H, W),
+ "intrinsics": intrinsics.astype(np.float32),
+ "original_intrinsics": intrinsics.astype(np.float32).copy(),
+ }
+
+ preprocess_transforms = compose(
+ transforms=[
+ GenResizeParameters(shape=IMAGE_SIZE),
+ ResizeImages(),
+ ResizeIntrinsics(),
+ NormalizeImages(),
+ CenterPadImages(
+ stride=1, shape=IMAGE_SIZE, update_input_hw=True
+ ),
+ CenterPadIntrinsics(),
+ ]
+ )
+
+ data = preprocess_transforms([data_dict])[0]
+ to_tensor = ToTensor()
+ data = to_tensor([data])[0]
+
+ return data
diff --git a/wilddet3d/vis/__init__.py b/wilddet3d/vis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/wilddet3d/vis/__pycache__/__init__.cpython-311.pyc b/wilddet3d/vis/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..df8bab22fe019b68da5b00cab0d9dc421ab2ae78
Binary files /dev/null and b/wilddet3d/vis/__pycache__/__init__.cpython-311.pyc differ
diff --git a/wilddet3d/vis/__pycache__/visualize.cpython-311.pyc b/wilddet3d/vis/__pycache__/visualize.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de0dbe8a4aa3349fc38cd8e3812cc208b191cc4c
Binary files /dev/null and b/wilddet3d/vis/__pycache__/visualize.cpython-311.pyc differ
diff --git a/wilddet3d/vis/fonts/Manrope-Bold.ttf b/wilddet3d/vis/fonts/Manrope-Bold.ttf
new file mode 100644
index 0000000000000000000000000000000000000000..52d93a3fbf43352085a4089047aeb600929fd583
Binary files /dev/null and b/wilddet3d/vis/fonts/Manrope-Bold.ttf differ
diff --git a/wilddet3d/vis/fonts/Manrope-SemiBold.ttf b/wilddet3d/vis/fonts/Manrope-SemiBold.ttf
new file mode 100644
index 0000000000000000000000000000000000000000..85e036efd34d272b11398d7320f5505611913003
Binary files /dev/null and b/wilddet3d/vis/fonts/Manrope-SemiBold.ttf differ
diff --git a/wilddet3d/vis/image/__init__.py b/wilddet3d/vis/image/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/wilddet3d/vis/image/depth_visualizer.py b/wilddet3d/vis/image/depth_visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b911c4caad26a9250d7f17375b420c54c5822c11
--- /dev/null
+++ b/wilddet3d/vis/image/depth_visualizer.py
@@ -0,0 +1,200 @@
+"""Depth visualizer."""
+
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass
+
+import numpy as np
+from PIL import Image
+from vis4d.common.array import array_to_numpy
+from vis4d.common.typing import (
+ ArgsType,
+ ArrayLikeFloat,
+ NDArrayF32,
+ NDArrayUI8,
+)
+from vis4d.vis.base import Visualizer
+from vis4d.vis.image.util import preprocess_image
+from vis4d.vis.util import generate_color_map
+
+from .util import (
+ colorize,
+ get_pointcloud_from_rgbd,
+ save_depth_map,
+ save_file_ply,
+)
+
+
+@dataclass
+class DataSample:
+ """Dataclass storing a data sample that can be visualized."""
+
+ image: NDArrayUI8
+ image_name: str
+ depth: NDArrayF32
+ depth_gt: NDArrayF32 | None = None
+ depth_error: NDArrayF32 | None = None
+ points_rgb: NDArrayF32 | None = None
+
+
+class DepthVisualizer(Visualizer):
+ """Depth visualizer class."""
+
+ def __init__(
+ self,
+ *args: ArgsType,
+ max_depth: None | float = None,
+ plot_error: bool = False,
+ lift: bool = False,
+ color_palette: list[tuple[int, int, int]] | None = None,
+ **kwargs: ArgsType,
+ ) -> None:
+ """Creates a new Visualizer for Depth.
+
+ Args:
+ max_depth (None | float): Maximum depth to visualize.
+ """
+ super().__init__(*args, **kwargs)
+ self.max_depth = max_depth
+ self._samples: list[DataSample] = []
+ self._gt_samples = []
+ self.plot_error = plot_error
+ self.lift = lift
+ self.color_palette = (
+ generate_color_map(50) if color_palette is None else color_palette
+ )
+
+ def reset(self) -> None:
+ """Reset the visualizer."""
+ self._samples.clear()
+ self._gt_samples.clear()
+
+ def process(
+ self,
+ cur_iter: int,
+ images: list[ArrayLikeFloat],
+ image_names: list[str],
+ depths: ArrayLikeFloat,
+ depth_gts: ArrayLikeFloat | None = None,
+ intrinsics: ArrayLikeFloat | None = None,
+ ) -> None:
+ """Process data of a batch of data."""
+ if self._run_on_batch(cur_iter):
+ for i, image in enumerate(images):
+ image = preprocess_image(image)
+ self._samples.append(
+ self.process_single_image(
+ image,
+ image_names[i],
+ array_to_numpy(depths[i]),
+ (
+ array_to_numpy(depth_gts[i])
+ if depth_gts is not None
+ else None
+ ),
+ (
+ array_to_numpy(intrinsics[i])
+ if intrinsics is not None
+ else None
+ ),
+ )
+ )
+
+ def process_single_image(
+ self,
+ image: NDArrayUI8,
+ image_name: str,
+ depth: NDArrayF32,
+ depth_gt: NDArrayF32 | None = None,
+ intrinsic: NDArrayF32 | None = None,
+ ) -> DataSample:
+ """Process data of a batch of data."""
+ if self.max_depth is not None:
+ mask = depth <= self.max_depth
+ else:
+ mask = np.full(depth.shape, True)
+
+ if self.plot_error:
+ assert (
+ depth_gt is not None
+ ), "Ground truth depth is required for plotting error."
+ error = np.zeros_like(depth_gt)
+ error[depth_gt > 0] = (
+ np.abs(depth_gt - depth)[depth_gt > 0] / depth_gt[depth_gt > 0]
+ )
+ else:
+ error = None
+
+ if self.lift:
+ assert (
+ intrinsic is not None
+ ), "Intrinsic matrix is required for lifting."
+ points_rgb = get_pointcloud_from_rgbd(
+ image, depth, intrinsic, mask
+ )
+ else:
+ points_rgb = None
+
+ return DataSample(
+ image=image,
+ image_name=image_name,
+ depth=depth,
+ depth_gt=depth_gt,
+ depth_error=error,
+ points_rgb=points_rgb,
+ )
+
+ def save_to_disk(self, cur_iter: int, output_folder: str) -> None:
+ """Saves the visualization to disk.
+
+ Args:
+ cur_iter (int): Current iteration.
+ output_folder (str): Folder where the output should be written.
+ """
+ if self._run_on_batch(cur_iter):
+ for sample in self._samples:
+ save_dir = os.path.join(output_folder, "depth")
+ os.makedirs(save_dir, exist_ok=True)
+
+ Image.fromarray(sample.image).save(
+ f"{save_dir}/{sample.image_name}.png",
+ )
+
+ if self.plot_error:
+ error = sample.depth_error
+
+ error_image = Image.fromarray(
+ colorize(
+ error.clip(0.0, 0.3),
+ vmin=0.001,
+ vmax=0.3,
+ cmap="coolwarm",
+ )
+ )
+
+ error_image.save(
+ f"{save_dir}/{sample.image_name}_error.png"
+ )
+
+ save_depth_map(
+ sample.depth,
+ f"{save_dir}/{sample.image_name}_pred.png",
+ )
+
+ if sample.depth_gt is not None:
+ save_depth_map(
+ sample.depth_gt,
+ f"{save_dir}/{sample.image_name}_gt.png",
+ )
+
+ if self.lift:
+ save_dir = os.path.join(output_folder, "points")
+ os.makedirs(save_dir, exist_ok=True)
+
+ if sample.points_rgb is not None:
+ save_file_ply(
+ sample.points_rgb[:, :3],
+ sample.points_rgb[:, 3:],
+ os.path.join(save_dir, f"{sample.image_name}.ply"),
+ )
diff --git a/wilddet3d/vis/image/util.py b/wilddet3d/vis/image/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7adffc042a79ed0918ab0c4c616b7ade6a9a487
--- /dev/null
+++ b/wilddet3d/vis/image/util.py
@@ -0,0 +1,151 @@
+"""Utility functions for image processing operations."""
+
+from __future__ import annotations
+
+import numpy as np
+from matplotlib.pyplot import get_cmap
+from PIL import Image
+from vis4d.common.typing import (
+ NDArrayBool,
+ NDArrayF32,
+ NDArrayUI8,
+ NDArrayUI16,
+)
+
+
+def save_depth_map(
+ depth_map: NDArrayF32, filename: str, depth_scale: float = 256.0
+) -> None:
+ """Dump depth map.
+
+ Args:
+ depth_map (NDArrayF32): Depth map to dump.
+ filename (str): Path to dump depth map.
+ depth_scale (float): Depth scale.
+ """
+ numpy_image = (depth_map * depth_scale).astype(np.uint16)
+ numpy_image = colorize(numpy_image)
+ Image.fromarray(numpy_image).save(filename)
+
+
+def colorize(
+ value: NDArrayUI16,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap: str = "magma_r",
+) -> Image.Image:
+ if value.ndim > 2:
+ return value
+ invalid_mask = value < 1e-3
+ # normalize
+ vmin = value.min() if vmin is None else vmin
+ vmax = value.max() if vmax is None else vmax
+ value = (value - vmin) / (vmax - vmin) # vmin..vmax
+
+ # set color
+ cmapper = get_cmap(cmap)
+ value = cmapper(value, bytes=True) # (nxmx4)
+ value[invalid_mask] = 0
+ img = value[..., :3]
+ return img
+
+
+def get_pointcloud_from_rgbd(
+ image: NDArrayUI8,
+ depth: NDArrayF32,
+ intrinsic_matrix: NDArrayF32,
+ mask: NDArrayBool,
+ remove_height: float | None = None,
+) -> NDArrayF32:
+ """Get pointcloud from RGBD image.
+
+ Args:
+ image (np.array): RGB image. Shape: (H, W, 3)
+ depth (np.array): Depth image. Shape: (H, W)
+ mask (np.ndarray): Mask of valid depth values. Shape: (H, W)
+ intrinsic_matrix (np.array): Intrinsic matrix of camera. Shape: (3, 3)
+ extrinsic_matrix (np.array, optional): Extrinsic matrix of camera.
+ Shape: (4, 4). Defaults to None.
+ voxelize (bool, optional): Whether to voxelize the pointcloud.
+
+ Returns:
+ NDArrayF32: Pointcloud. Shape: (N, 6)
+ """
+ # Mask the depth array
+ masked_depth = np.ma.masked_where(mask == False, depth)
+
+ # Create idx array
+ idxs = np.indices(masked_depth.shape)
+ u_idxs = idxs[1]
+ v_idxs = idxs[0]
+
+ # Get only non-masked depth and idxs
+ z = masked_depth[~masked_depth.mask]
+ compressed_u_idxs = u_idxs[~masked_depth.mask]
+ compressed_v_idxs = v_idxs[~masked_depth.mask]
+ image = np.stack(
+ [image[..., i][~masked_depth.mask] for i in range(image.shape[-1])],
+ axis=-1,
+ )
+
+ # Calculate local position of each point
+ # Apply vectorized math to depth using compressed arrays
+ cx = intrinsic_matrix[0, 2]
+ fx = intrinsic_matrix[0, 0]
+ x = (compressed_u_idxs - cx) * z / fx
+ cy = intrinsic_matrix[1, 2]
+ fy = intrinsic_matrix[1, 1]
+
+ # Flip y as we want +y pointing up not down
+ y = (compressed_v_idxs - cy) * z / fy
+
+ # Remove height
+ if remove_height is not None:
+ mask = y >= remove_height
+ x = x[mask]
+ y = y[mask]
+ z = z[mask]
+ image = image[mask]
+ else:
+ x = x.reshape(-1)
+ y = y.reshape(-1)
+ z = z.reshape(-1)
+ image = image.reshape(-1, 3)
+
+ x_y_z_local = np.stack((x, y, z), axis=-1)
+
+ return np.concatenate([x_y_z_local, image], axis=-1)
+
+
+def save_file_ply(xyz: NDArrayF32, rgb: NDArrayF32, pc_file: str) -> None:
+ """Save point cloud to ply file."""
+ if rgb.max() < 1.001:
+ rgb = rgb * 255.0
+ rgb = rgb.astype(np.uint8)
+
+ with open(pc_file, "w") as f:
+ # headers
+ f.writelines(
+ [
+ "ply\n" "format ascii 1.0\n",
+ "element vertex {}\n".format(xyz.shape[0]),
+ "property float x\n",
+ "property float y\n",
+ "property float z\n",
+ "property uchar red\n",
+ "property uchar green\n",
+ "property uchar blue\n",
+ "end_header\n",
+ ]
+ )
+
+ for i in range(xyz.shape[0]):
+ str_v = "{:10.6f} {:10.6f} {:10.6f} {:d} {:d} {:d}\n".format(
+ xyz[i][0],
+ xyz[i, 1],
+ xyz[i, 2],
+ rgb[i, 0],
+ rgb[i, 1],
+ rgb[i, 2],
+ )
+ f.write(str_v)
diff --git a/wilddet3d/vis/visualize.py b/wilddet3d/vis/visualize.py
new file mode 100644
index 0000000000000000000000000000000000000000..3768eb58ac5434250669758c479b2c76f74eb2f0
--- /dev/null
+++ b/wilddet3d/vis/visualize.py
@@ -0,0 +1,261 @@
+"""WildDet3D visualization utilities.
+
+Anti-aliased 3D bounding boxes with Manrope font score labels.
+Uses vis4d's preprocess_boxes3d for correct 3D corner projection,
+cv2 LINE_AA for smooth lines, PIL + Manrope for text rendering.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from torch import Tensor
+
+from vis4d.common.array import array_to_numpy
+from vis4d.data.const import AxisMode
+from vis4d.op.box.box3d import boxes3d_to_corners
+from vis4d.vis.util import generate_color_map
+
+_FONT_DIR = Path(__file__).parent / "fonts"
+
+# vis4d edge order (from PillowCanvasBackend.draw_box_3d)
+# Front face: 0-1-5-4, Back face: 2-3-7-6, Sides: 0-2, 1-3, 4-6, 5-7
+_EDGES = [
+ # Front
+ (0, 1), (1, 5), (5, 4), (4, 0),
+ # Sides
+ (0, 2), (1, 3), (4, 6), (5, 7),
+ # Back
+ (2, 3), (3, 7), (7, 6), (6, 2),
+]
+
+
+def _get_font(size: int = 14) -> ImageFont.FreeTypeFont:
+ """Get Manrope Bold font with fallbacks."""
+ for path in [
+ _FONT_DIR / "Manrope-Bold.ttf",
+ _FONT_DIR / "Manrope-SemiBold.ttf",
+ Path("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"),
+ ]:
+ if path.exists():
+ return ImageFont.truetype(str(path), size)
+ return ImageFont.load_default()
+
+
+def _project_pt_simple(pt_3d, K_np):
+ """Project single 3D point to 2D using intrinsics (no torch overhead)."""
+ x, y, z = pt_3d
+ fx, fy = K_np[0, 0], K_np[1, 1]
+ cx, cy = K_np[0, 2], K_np[1, 2]
+ u = fx * x / z + cx
+ v = fy * y / z + cy
+ return float(u), float(v)
+
+
+def _clip_to_near(p1, p2, near=0.15):
+ """Clip line to near plane, return clipped point."""
+ x1, y1, z1 = p1
+ x2, y2, z2 = p2
+ k_up = abs(z1 - near)
+ k_down = abs(z1 - z2)
+ k = min(k_up / k_down, 1.0) if k_down > 0 else 1.0
+ return ((1 - k) * x1 + k * x2, (1 - k) * y1 + k * y2, near)
+
+
+def draw_3d_boxes(
+ image: np.ndarray,
+ boxes3d: Tensor | np.ndarray,
+ intrinsics: np.ndarray,
+ scores_2d: Tensor | np.ndarray | None = None,
+ scores_3d: Tensor | np.ndarray | None = None,
+ class_ids: Tensor | np.ndarray | None = None,
+ class_names: list[str] | None = None,
+ line_width: int = 2,
+ font_size: int = 13,
+ n_colors: int = 50,
+ score_format: str = "{name} 2D:{s2d:.2f} 3D:{s3d:.2f}",
+ near_clip: float = 0.15,
+ save_path: str | None = None,
+) -> Image.Image:
+ """Draw anti-aliased 3D bounding boxes with 2D/3D score labels.
+
+ Args:
+ image: RGB image (H, W, 3) uint8.
+ boxes3d: 3D boxes (N, 10) in OPENCV camera coordinates.
+ intrinsics: Camera intrinsics (3, 3).
+ scores_2d: 2D confidence scores (N,).
+ scores_3d: 3D confidence scores (N,).
+ class_ids: Class indices (N,).
+ class_names: List of class names.
+ line_width: Width of 3D box edges.
+ font_size: Font size for labels.
+ n_colors: Number of colors in palette.
+ score_format: Format string. Available: {name}, {s2d}, {s3d}.
+ near_clip: Camera near clipping plane.
+ save_path: If provided, save the result.
+
+ Returns:
+ PIL Image with drawn boxes and score labels.
+ """
+ if isinstance(image, Tensor):
+ image = image.cpu().numpy()
+ if image.dtype != np.uint8:
+ image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)
+ if isinstance(boxes3d, Tensor):
+ boxes3d_t = boxes3d.cpu().float()
+ else:
+ boxes3d_t = torch.tensor(boxes3d, dtype=torch.float32)
+ if isinstance(scores_2d, Tensor):
+ scores_2d = scores_2d.cpu().numpy()
+ if isinstance(scores_3d, Tensor):
+ scores_3d = scores_3d.cpu().numpy()
+ if isinstance(class_ids, Tensor):
+ class_ids = class_ids.cpu().numpy()
+
+ N = len(boxes3d_t)
+ H, W = image.shape[:2]
+ K_np = intrinsics.astype(np.float32)
+
+ if N == 0:
+ pil_img = Image.fromarray(image)
+ if save_path:
+ pil_img.save(save_path, quality=95)
+ return pil_img
+
+ # Get 3D corners (N, 8, 3) using vis4d's OPENCV convention
+ corners_3d = boxes3d_to_corners(boxes3d_t, AxisMode.OPENCV).numpy()
+
+ color_map = generate_color_map(n_colors)
+
+ # --- Draw lines with cv2 (anti-aliased) ---
+ canvas = image.copy()
+ canvas_bgr = cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR)
+
+ for i in range(N):
+ cid = int(class_ids[i]) if class_ids is not None else i
+ color_rgb = color_map[cid % len(color_map)]
+ color_bgr = (int(color_rgb[2]), int(color_rgb[1]), int(color_rgb[0]))
+
+ corners = corners_3d[i] # (8, 3)
+
+ for e0, e1 in _EDGES:
+ p1 = tuple(corners[e0].tolist())
+ p2 = tuple(corners[e1].tolist())
+
+ # Near-plane clipping
+ if p1[2] < near_clip and p2[2] < near_clip:
+ continue
+ if p1[2] < near_clip:
+ p1 = _clip_to_near(p1, p2, near_clip)
+ elif p2[2] < near_clip:
+ p2 = _clip_to_near(p2, p1, near_clip)
+
+ # Project to 2D
+ u1, v1 = _project_pt_simple(p1, K_np)
+ u2, v2 = _project_pt_simple(p2, K_np)
+
+ # Skip if way outside image
+ margin = max(W, H)
+ if (abs(u1) > margin * 2 or abs(v1) > margin * 2 or
+ abs(u2) > margin * 2 or abs(v2) > margin * 2):
+ continue
+
+ cv2.line(
+ canvas_bgr,
+ (int(round(u1)), int(round(v1))),
+ (int(round(u2)), int(round(v2))),
+ color_bgr,
+ thickness=line_width,
+ lineType=cv2.LINE_AA,
+ )
+
+ canvas_rgb = cv2.cvtColor(canvas_bgr, cv2.COLOR_BGR2RGB)
+
+ # --- Draw text labels with PIL (Manrope font) ---
+ # Use RGBA for rounded rectangle with alpha
+ pil_img = Image.fromarray(canvas_rgb).convert("RGBA")
+ overlay = Image.new("RGBA", pil_img.size, (0, 0, 0, 0))
+ draw_overlay = ImageDraw.Draw(overlay)
+ draw_main = ImageDraw.Draw(pil_img)
+ font = _get_font(font_size)
+
+ for i in range(N):
+ cid = int(class_ids[i]) if class_ids is not None else 0
+ color = color_map[cid % len(color_map)]
+
+ # Project center to 2D
+ center_3d = boxes3d_t[i, :3].numpy()
+ if center_3d[2] < near_clip:
+ continue
+ cx, cy = _project_pt_simple(tuple(center_3d.tolist()), K_np)
+ if cx < -50 or cx >= W + 50 or cy < -50 or cy >= H + 50:
+ continue
+
+ name = class_names[cid] if class_names is not None else str(cid)
+ s2d = float(scores_2d[i]) if scores_2d is not None else 0.0
+ s3d = float(scores_3d[i]) if scores_3d is not None else 0.0
+ label = score_format.format(name=name, s2d=s2d, s3d=s3d)
+
+ # Measure text size (textbbox returns actual glyph bounds)
+ left, top, right, bottom = draw_main.textbbox((0, 0), label, font=font)
+ tw = right - left
+ th = bottom - top
+ y_offset = top # font ascent offset (glyphs don't start at y=0)
+
+ # Position: inside the box, near the projected center
+ pad_x, pad_y = 6, 4
+ radius = 5
+
+ # Place label centered at projected center
+ rx0 = cx - tw / 2 - pad_x
+ ry0 = cy - th / 2 - pad_y
+ rx1 = cx + tw / 2 + pad_x
+ ry1 = cy + th / 2 + pad_y
+
+ # Clamp to image bounds
+ if rx0 < 2:
+ shift = 2 - rx0
+ rx0 += shift
+ rx1 += shift
+ if rx1 > W - 2:
+ shift = rx1 - (W - 2)
+ rx0 -= shift
+ rx1 -= shift
+ if ry0 < 2:
+ shift = 2 - ry0
+ ry0 += shift
+ ry1 += shift
+ if ry1 > H - 2:
+ shift = ry1 - (H - 2)
+ ry0 -= shift
+ ry1 -= shift
+
+ # Draw rounded rectangle on overlay (semi-transparent)
+ fill_color = tuple(color) + (210,)
+ draw_overlay.rounded_rectangle(
+ [rx0, ry0, rx1, ry1],
+ radius=radius,
+ fill=fill_color,
+ )
+
+ # Text centered in the rounded rect (compensate font ascent offset)
+ text_x = rx0 + pad_x - left
+ text_y = ry0 + pad_y - y_offset
+ draw_overlay.text(
+ (text_x, text_y),
+ label,
+ fill=(255, 255, 255, 255),
+ font=font,
+ )
+
+ # Composite overlay onto main image
+ pil_img = Image.alpha_composite(pil_img, overlay).convert("RGB")
+
+ if save_path:
+ pil_img.save(save_path, quality=95)
+
+ return pil_img