# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.

import logging
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F

import xformers.ops as xops

from ..utils import named_apply, named_replace

logger = logging.getLogger("dinov3")
class LinearW24(torch.nn.Linear):
    """``nn.Linear`` that can apply 2:4 structured weight sparsity at runtime.

    While ``sparsity_enabled`` is False (the default) this layer is an
    ordinary dense linear. Once enabled (see ``update_24sparsity``), the
    weight is sparsified on the fly with xformers' 2:4 kernels and the
    matmul is dispatched to the cuSPARSELt backend.
    """

    # Strategy xformers uses to choose which 2 of every 4 weights survive.
    ALGO = "largest_abs_values_greedy"

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Dense by default; toggled externally via update_24sparsity().
        self.sparsity_enabled = False

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Dense fast path: identical to a plain nn.Linear.
        if not self.sparsity_enabled:
            return super().forward(input)

        original_shape = input.shape
        flat = input.flatten(end_dim=-2)
        n_rows = flat.shape[0]
        if n_rows % 8 != 0:
            # The sparse kernel wants the row count to be a multiple of 8;
            # pad with zero rows, dropped again after the matmul.
            # NOTE: This should be torch-compiled away
            flat = F.pad(flat, [0, 0, 0, -n_rows % 8])

        w_sparse = xops.sparsify24(
            self.weight,
            algo=self.ALGO,
            gradient="ste",
            backend="cusparselt",
        )
        out = F.linear(flat, w_sparse, self.bias)
        # Strip the padding rows and restore the original leading dims.
        return out[:n_rows].unflatten(dim=0, sizes=original_shape[:-1])
def replace_linears_with_sparse_linear(root_module: nn.Module, *, filter_fn: Callable[[str], bool]) -> nn.Module:
    """Swap every plain ``nn.Linear`` accepted by ``filter_fn`` for a ``LinearW24``.

    The replacement reuses the original weight and bias parameters (no copy),
    so optimizer state and checkpoints keep pointing at the same tensors.
    Raises AssertionError if a selected module is a Linear *subclass*, or if
    no layer at all was replaced.
    """
    replaced: list[str] = []

    def _swap(module: nn.Module, name: str) -> nn.Module:
        if not (isinstance(module, nn.Linear) and filter_fn(name)):
            return module
        assert type(module) == nn.Linear, "Subtypes not supported"
        sparse = LinearW24(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
            dtype=module.weight.dtype,
            device=module.weight.device,
        )
        # Share the existing parameters rather than copying them.
        sparse.weight = module.weight
        sparse.bias = module.bias
        replaced.append(name)
        return sparse

    result = named_replace(_swap, root_module)
    assert replaced, "2:4 sparsity: no layer found to sparsify"
    return result
def update_24sparsity(root_module: nn.Module, enabled: bool) -> int:
    """Enable or disable 2:4 sparsity on every ``LinearW24`` under ``root_module``.

    Returns the number of ``LinearW24`` modules whose flag was updated.
    Also invalidates torch.compile code caches and captured CUDA graphs,
    since flipping the flag changes the traced computation.
    """
    touched = 0

    def _set_flag(module: nn.Module, name: str) -> nn.Module:
        nonlocal touched
        if isinstance(module, LinearW24):
            touched += 1
            module.sparsity_enabled = enabled
            logger.info(f"- {'' if module.sparsity_enabled else 'de'}sparsifying {name}")
        return module

    named_apply(_set_flag, root_module)

    # Force re-compile everything
    torch._dynamo.reset_code_caches()
    from torch._inductor.cudagraph_trees import reset_cudagraph_trees

    reset_cudagraph_trees()
    return touched