File size: 2,859 Bytes
11aa70b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.

import logging
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
import xformers.ops as xops

from ..utils import named_apply, named_replace

logger = logging.getLogger("dinov3")


class LinearW24(torch.nn.Linear):
    """A ``nn.Linear`` that can apply 2:4 structured sparsity to its weight at runtime.

    While ``sparsity_enabled`` is False (the default) this behaves exactly like a
    plain dense linear layer. When enabled, the weight is pruned on the fly to the
    2:4 pattern via xformers and the matmul is dispatched to the cuSPARSELt backend.
    """

    # Pruning strategy passed to xformers' sparsify24 (keep the two largest
    # magnitudes out of every group of four).
    ALGO = "largest_abs_values_greedy"

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Dense by default; flipped externally (see update_24sparsity).
        self.sparsity_enabled = False

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Dense fast path: identical to nn.Linear.
        if not self.sparsity_enabled:
            return super().forward(input)

        original_shape = input.shape
        flat = input.flatten(end_dim=-2)
        n_rows = flat.shape[0]
        pad_rows = -n_rows % 8  # rows needed to reach the next multiple of 8
        if pad_rows:
            # The sparse kernel requires the row count to be a multiple of 8.
            # NOTE: This should be torch-compiled away
            flat = F.pad(flat, [0, 0, 0, pad_rows])
        sparse_weight = xops.sparsify24(
            self.weight,
            algo=self.ALGO,
            gradient="ste",
            backend="cusparselt",
        )
        out = F.linear(flat, sparse_weight, self.bias)
        # Drop the padded rows and restore the original leading dimensions.
        return out[:n_rows].unflatten(dim=0, sizes=original_shape[:-1])


def replace_linears_with_sparse_linear(root_module: nn.Module, *, filter_fn: Callable[[str], bool]) -> nn.Module:
    """Swap every ``nn.Linear`` selected by ``filter_fn`` for an equivalent ``LinearW24``.

    The replacement layer shares the original weight/bias parameters (no copy).
    Raises an ``AssertionError`` if no layer matched, or if a matched module is a
    subclass of ``nn.Linear`` rather than ``nn.Linear`` itself.
    """
    num_replaced = 0

    def _swap(module: nn.Module, name: str) -> nn.Module:
        nonlocal num_replaced
        if isinstance(module, nn.Linear) and filter_fn(name):
            # Exact type check on purpose: subclasses may carry extra state
            # that a plain LinearW24 would silently drop.
            assert type(module) == nn.Linear, "Subtypes not supported"
            sparse_linear = LinearW24(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias is not None,
                dtype=module.weight.dtype,
                device=module.weight.device,
            )
            # Reuse the existing parameters so optimizer state stays valid.
            sparse_linear.weight = module.weight
            sparse_linear.bias = module.bias
            num_replaced += 1
            return sparse_linear
        return module

    result = named_replace(_swap, root_module)
    assert num_replaced > 0, "2:4 sparsity: no layer found to sparsify"
    return result


def update_24sparsity(root_module: nn.Module, enabled: bool) -> int:
    """Toggle 2:4 sparsity on every ``LinearW24`` under ``root_module``.

    Returns the number of layers whose flag was updated, and resets torch
    compile/cudagraph caches so subsequent runs pick up the new code path.
    """
    num_toggled = 0

    def _toggle(module: nn.Module, name: str) -> nn.Module:
        nonlocal num_toggled
        if isinstance(module, LinearW24):
            num_toggled += 1
            module.sparsity_enabled = enabled
            logger.info(f"- {'' if module.sparsity_enabled else 'de'}sparsifying {name}")
        return module

    named_apply(_toggle, root_module)
    # Force re-compile everything
    torch._dynamo.reset_code_caches()
    from torch._inductor.cudagraph_trees import reset_cudagraph_trees

    reset_cudagraph_trees()
    return num_toggled