MisoTTS / moshi_compat.py
multimodalart's picture
multimodalart HF Staff
Guard moshi quantize patch for moshi>=0.2.12
4e8811f verified
from __future__ import annotations
import torch
from torch import nn
from moshi.utils import quantize
def patch_bitsandbytes_import_for_unquantized_layers() -> None:
    """Install unquantized fast paths over moshi's ``quantize`` linear helpers.

    Rebinds ``quantize.linear`` and ``quantize.multi_linear`` to wrappers that
    hand quantized weights to the original helpers and compute everything else
    directly with ``torch.nn.functional.linear``.

    NOTE(review): going by the function name, the goal is to keep unquantized
    layers from triggering a bitsandbytes import inside the original helpers;
    confirm against the installed moshi version.

    Only code that resolves the helpers through the ``quantize`` module at call
    time picks up the replacements; names bound earlier with
    ``from moshi.utils.quantize import linear`` keep the original functions.

    Returns:
        None. The effect is the rebinding of the two module attributes, or
        nothing at all when either helper is missing.
    """
    # moshi >=0.2.12 removed the module-level linear/multi_linear helpers (QLinear
    # replaces them) and its Mimi no longer calls through them, so in that case
    # there is nothing to patch.
    if not (hasattr(quantize, "linear") and hasattr(quantize, "multi_linear")):
        return

    # Keep handles on the stock helpers; the wrappers defer to them for
    # quantized weights.
    stock_linear = quantize.linear
    stock_multi_linear = quantize.multi_linear

    def linear(module: nn.Module, x: torch.Tensor, name: str = "weight") -> torch.Tensor:
        """Apply ``module.<name>`` to ``x`` as a bias-free linear map."""
        if not quantize.is_quantized(module, name):
            return nn.functional.linear(x, getattr(module, name))
        return stock_linear(module, x, name)

    def multi_linear(
        num_steps: int,
        schedule: list[int] | None,
        module: nn.Module,
        x: torch.Tensor,
        offset: int,
        name: str = "weight",
    ) -> torch.Tensor:
        """Apply a per-step slice of ``module.<name>`` to each step of ``x``.

        Steps are taken along dim 1 of ``x``. Step ``t`` uses weight slice
        ``t + offset``, or ``schedule[t + offset]`` when a schedule is given.
        The per-step results are stacked back along dim 1.
        """
        if quantize.is_quantized(module, name):
            return stock_multi_linear(num_steps, schedule, module, x, offset, name)

        packed = getattr(module, name)
        # With a schedule, the number of distinct weight slices is set by the
        # largest index it refers to; otherwise there is one slice per step.
        slice_count = num_steps if schedule is None else max(schedule) + 1
        # Split the leading dimension(s) into `slice_count` separate matrices
        # that share the original last dimension.
        per_slice = packed.view(slice_count, -1, packed.shape[-1])

        def slice_index(step: int) -> int:
            position = step + offset
            return position if schedule is None else schedule[position]

        step_outputs = [
            nn.functional.linear(x[:, step], per_slice[slice_index(step)])
            for step in range(x.shape[1])
        ]
        return torch.stack(step_outputs, 1)

    quantize.linear = linear
    quantize.multi_linear = multi_linear