# NOTE: removed web-scrape artifact ("himipo's picture / first / 11aa70b") —
# these lines were page residue, not Python, and broke the module.
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
import logging
from typing import Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
import xformers.ops as xops
from ..utils import named_apply, named_replace
logger = logging.getLogger("dinov3")
class LinearW24(torch.nn.Linear):
    """Drop-in replacement for ``nn.Linear`` with optional 2:4 weight sparsity.

    The layer starts in the dense state; ``sparsity_enabled`` is flipped
    externally (see ``update_24sparsity``) to route ``forward`` through the
    xformers 2:4 sparsification path.
    """

    # Pruning algorithm handed to ``xops.sparsify24`` on every sparse forward.
    ALGO = "largest_abs_values_greedy"

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Dense by default; callers toggle this flag explicitly.
        self.sparsity_enabled = False

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.sparsity_enabled:
            # Dense fast path: behave exactly like nn.Linear.
            return super().forward(input)

        leading_shape = input.shape[:-1]
        flat = input.flatten(end_dim=-2)
        n_rows = flat.shape[0]
        pad_rows = -n_rows % 8
        if pad_rows:
            # The cusparselt backend needs the row count to be a multiple of 8.
            # NOTE: This should be torch-compiled away
            flat = F.pad(flat, [0, 0, 0, pad_rows])

        sparse_weight = xops.sparsify24(
            self.weight,
            algo=self.ALGO,
            gradient="ste",
            backend="cusparselt",
        )
        out = F.linear(flat, sparse_weight, self.bias)
        # Drop any padding rows, then restore the original leading dims.
        return out[:n_rows].unflatten(dim=0, sizes=leading_shape)
def replace_linears_with_sparse_linear(root_module: nn.Module, *, filter_fn: Callable[[str], bool]) -> nn.Module:
    """Swap every ``nn.Linear`` accepted by ``filter_fn`` for a ``LinearW24``.

    The existing weight/bias parameter objects are reused by reference, so
    the replacement is cheap and keeps tensor identity (e.g. for optimizer
    state). Asserts that at least one layer was replaced.
    """
    n_replaced = 0

    def _swap(module: nn.Module, name: str) -> nn.Module:
        nonlocal n_replaced
        if isinstance(module, nn.Linear) and filter_fn(name):
            # Exact-type check: nn.Linear subclasses may carry state we'd lose.
            assert type(module) == nn.Linear, "Subtypes not supported"
            sparse = LinearW24(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias is not None,
                dtype=module.weight.dtype,
                device=module.weight.device,
            )
            # Reuse the original parameters instead of copying them.
            sparse.weight = module.weight
            sparse.bias = module.bias
            n_replaced += 1
            return sparse
        return module

    replaced_root = named_replace(_swap, root_module)
    assert n_replaced > 0, "2:4 sparsity: no layer found to sparsify"
    return replaced_root
def update_24sparsity(root_module: nn.Module, enabled: bool) -> int:
    """Set ``sparsity_enabled`` on every ``LinearW24`` under ``root_module``.

    Also resets torch.compile / CUDA-graph caches so subsequent runs are
    re-traced with the new dense/sparse code path. Returns the number of
    ``LinearW24`` modules that were updated.
    """
    toggled = 0

    def _set_flag(module: nn.Module, name: str) -> nn.Module:
        nonlocal toggled
        if isinstance(module, LinearW24):
            toggled += 1
            module.sparsity_enabled = enabled
            logger.info(f"- {'' if module.sparsity_enabled else 'de'}sparsifying {name}")
        return module

    named_apply(_set_flag, root_module)
    # Force re-compile everything
    torch._dynamo.reset_code_caches()
    from torch._inductor.cudagraph_trees import reset_cudagraph_trees

    reset_cudagraph_trees()
    return toggled