File size: 1,675 Bytes
cecbc0f
 
 
 
 
 
 
 
 
4e8811f
 
 
 
 
cecbc0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from __future__ import annotations

import torch
from torch import nn

from moshi.utils import quantize


def patch_bitsandbytes_import_for_unquantized_layers() -> None:
    """Make unquantized layers bypass moshi's original ``linear``/``multi_linear``.

    Replaces ``quantize.linear`` and ``quantize.multi_linear`` with wrappers that
    call the original helpers only when ``quantize.is_quantized(module, name)``
    is true. Otherwise the weight is applied directly with
    ``nn.functional.linear`` (no bias), so unquantized layers never enter the
    original helpers. NOTE(review): the function name suggests the originals
    import bitsandbytes; confirm against the pinned moshi version.

    Does nothing when the installed moshi no longer exposes these helpers.
    Safe to call more than once: a helper that already carries this patch's
    marker is left in place instead of being wrapped again.
    """
    # moshi >=0.2.12 dropped the module-level linear/multi_linear helpers (now QLinear),
    # and its Mimi no longer routes through them, so there is nothing to patch.
    if not hasattr(quantize, "linear") or not hasattr(quantize, "multi_linear"):
        return

    # Attribute set on our wrappers so a repeated call can recognise them and
    # avoid stacking a new wrapper on top of an existing one.
    marker = "_patched_for_unquantized_layers"

    original_linear = quantize.linear
    original_multi_linear = quantize.multi_linear

    def linear(module: nn.Module, x: torch.Tensor, name: str = "weight") -> torch.Tensor:
        """Apply ``module.<name>`` to ``x``; defer to moshi only for quantized weights."""
        if quantize.is_quantized(module, name):
            return original_linear(module, x, name)
        return nn.functional.linear(x, getattr(module, name))

    def multi_linear(
        num_steps: int,
        schedule: list[int] | None,
        module: nn.Module,
        x: torch.Tensor,
        offset: int,
        name: str = "weight",
    ) -> torch.Tensor:
        """Apply a per-time-step linear layer; defer to moshi only for quantized weights.

        The flat weight ``module.<name>`` is viewed as ``num_linear`` stacked
        blocks along its first dimension. For each position ``t`` along
        ``x``'s second axis, block ``t + offset`` is used. When ``schedule`` is
        given, that index is first remapped through it. The per-step outputs
        are stacked back along axis 1. ``x`` is presumably
        ``(batch, time, channels)``; TODO confirm with callers.
        """
        if quantize.is_quantized(module, name):
            return original_multi_linear(num_steps, schedule, module, x, offset, name)

        weight = getattr(module, name)
        # With a schedule, several steps may map to the same block, so the
        # weight holds max(schedule) + 1 blocks rather than num_steps.
        num_linear = num_steps if schedule is None else max(schedule) + 1
        weight = weight.view(num_linear, -1, weight.shape[-1])

        outputs = []
        for t in range(x.shape[1]):
            linear_index = t + offset
            if schedule is not None:
                linear_index = schedule[linear_index]
            outputs.append(nn.functional.linear(x[:, t], weight[linear_index]))
        return torch.stack(outputs, 1)

    setattr(linear, marker, True)
    setattr(multi_linear, marker, True)

    # Patch each helper independently, so one already wrapped (e.g. by an
    # earlier call) is kept rather than wrapped a second time.
    if not getattr(original_linear, marker, False):
        quantize.linear = linear
    if not getattr(original_multi_linear, marker, False):
        quantize.multi_linear = multi_linear