| |
|
|
| |
|
|
| from functools import wraps |
|
|
| import torch |
| from ..model.data_misc import BatchedDatapoint, NestedTensor |
| from torch.utils._pytree import tree_map_only |
|
|
|
|
def recursive_fn_factory(fn):
    """Build a function that applies ``fn`` to every tensor in a nested value.

    The returned callable walks dicts, lists and tuples recursively and
    rebuilds them with the same container kind (plain ``dict`` / ``list`` /
    ``tuple``). ``NestedTensor`` values are rebuilt from ``fn(tensors)`` and,
    when a mask is present, ``fn(mask)``. ``torch.Tensor`` leaves are replaced
    by ``fn(leaf)``. ``None``, ``bool``, ``int`` and ``float`` leaves are
    returned unchanged.

    Raises:
        TypeError: if a value of any other type is encountered.
    """
    # Leaf types that carry no tensor data and are passed through untouched.
    passthrough_types = (bool, int, float)

    def recursive_fn(b):
        if isinstance(b, dict):
            return {key: recursive_fn(value) for key, value in b.items()}
        if isinstance(b, list):
            return list(map(recursive_fn, b))
        if isinstance(b, tuple):
            return tuple(map(recursive_fn, b))
        # NestedTensor is tested before torch.Tensor, as in the original order.
        if isinstance(b, NestedTensor):
            return NestedTensor(
                tensors=fn(b.tensors),
                mask=None if b.mask is None else fn(b.mask),
            )
        if isinstance(b, torch.Tensor):
            return fn(b)
        if b is None or isinstance(b, passthrough_types):
            return b
        raise TypeError(f"Unexpected type {type(b)}")

    return recursive_fn
|
|
|
|
# Ready-made tree walkers built with recursive_fn_factory: one makes every
# tensor in a nested structure contiguous, the other clones every tensor.
recursive_contiguous = recursive_fn_factory(lambda tensor: tensor.contiguous())
recursive_clone = recursive_fn_factory(torch.clone)
|
|
|
|
def clone_output_wrapper(f):
    """
    Wrap ``f`` so that every CUDA tensor in its return value is cloned.

    The output is traversed with ``tree_map_only`` restricted to
    ``torch.Tensor`` leaves: tensors that live on a CUDA device are replaced
    by a clone, all other tensors and all non-tensor leaves are returned
    as-is. The intent (per the original note) is to avoid in-place operations
    on the function's own output tensors.

    Custom containers such as NestedTensor are only traversed if they are
    registered as pytree nodes (the original note points to data_misc.py for
    that registration).
    """

    def _clone_if_cuda(tensor):
        # CPU tensors are handed back untouched; only CUDA tensors are copied.
        if tensor.is_cuda:
            return tensor.clone()
        return tensor

    @wraps(f)
    def wrapped(*args, **kwargs):
        return tree_map_only(torch.Tensor, _clone_if_cuda, f(*args, **kwargs))

    return wrapped
|
|
|
|
def compile_wrapper(
    fn, *, mode="max-autotune", fullgraph=True, dynamic=False, name=None
):
    """Compile ``fn`` and wrap it with input/output tensor normalization.

    On every call the wrapper makes all tensors in ``args`` / ``kwargs``
    contiguous (via ``recursive_contiguous``), runs the ``torch.compile``-d
    function, and, for the "max-autotune" and "reduce-overhead" modes, clones
    all tensors in the result (via ``recursive_clone``). The whole call runs
    inside a profiler ``record_function`` range labelled ``name`` or, when no
    name is given, ``f"compiled {fn}"``.

    Per the original note, this is used for SAM2 tracker components that need
    contiguous inputs for CUDA graphs.
    """
    compiled_fn = torch.compile(fn, mode=mode, fullgraph=fullgraph, dynamic=dynamic)
    # Only these modes get their outputs cloned — presumably because they use
    # CUDA graphs, whose output buffers may be reused across calls (TODO confirm).
    clone_outputs = mode in ("max-autotune", "reduce-overhead")

    def compiled_fn_wrapper(*args, **kwargs):
        label = name if name is not None else f"compiled {fn}"
        with torch.autograd.profiler.record_function(label):
            contiguous_args = recursive_contiguous(args)
            contiguous_kwargs = recursive_contiguous(kwargs)
            result = compiled_fn(*contiguous_args, **contiguous_kwargs)
            return recursive_clone(result) if clone_outputs else result

    return compiled_fn_wrapper
|
|
|
|
def shape_logging_wrapper(fn, keep_kwargs, enable_logging=False):
    """
    Wrap ``fn`` and print the shapes of its inputs the first time each
    distinct combination of shapes is seen.

    Args:
        fn: The callable to wrap. It is always called with the original
            arguments and its return value is passed through.
        keep_kwargs: Container of keyword-argument names to include in the
            shape signature. A keyword argument contributes only if its name
            is in ``keep_kwargs`` and its value is a ``torch.Tensor`` or
            ``list``. Positional arguments always contribute.
        enable_logging: Whether new shape combinations are printed. Can be
            changed later through ``wrapper.set_logging(enabled)``.

    Returns:
        The wrapping function, carrying an ``enable_logging`` attribute and a
        ``set_logging`` method.
    """
    seen_shapes = set()

    def get_shape(obj):
        """Return a hashable summary of ``obj``'s shape structure."""
        if isinstance(obj, torch.Tensor):
            return obj.shape
        if isinstance(obj, (list, tuple)):
            # A one-element sequence is summarized by its only element.
            # Empty sequences take the generic path and yield ``()`` instead
            # of indexing element 0, which used to raise IndexError.
            if len(obj) == 1:
                return get_shape(obj[0])
            return tuple(get_shape(item) for item in obj)
        if isinstance(obj, dict):
            return tuple(sorted((k, get_shape(v)) for k, v in obj.items()))
        # Non-tensor leaves are identified by their type name only.
        return type(obj).__name__

    def wrapper(*args, **kwargs):
        shapes = tuple(get_shape(arg) for arg in args) + tuple(
            (k, get_shape(v))
            for k, v in kwargs.items()
            if isinstance(v, (torch.Tensor, list)) and k in keep_kwargs
        )
        if shapes not in seen_shapes:
            seen_shapes.add(shapes)
            if enable_logging:
                # Callables such as functools.partial objects have no
                # __qualname__; fall back to repr() so logging never crashes.
                fn_label = getattr(fn, "__qualname__", None) or repr(fn)
                print(f"[ShapeLogger] New input shapes for {fn_label}: {shapes}")
        return fn(*args, **kwargs)

    wrapper.enable_logging = enable_logging

    def set_logging(enabled=False):
        """Turn printing on or off and mirror the flag on the wrapper."""
        nonlocal enable_logging
        enable_logging = enabled
        wrapper.enable_logging = enable_logging

    wrapper.set_logging = set_logging
    return wrapper
|
|