Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,785 Bytes
aa16b75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from transformers import MimiModel
# Model identifier used as the default ``model_id`` both for ``MimiConfig`` and
# for ``MimiCodec.from_pretrained`` (presumably a Hugging Face Hub repo id --
# TODO confirm).
DEFAULT_MIMI_MODEL_ID = "kyutai/mimi"
@dataclass(frozen=True)
class MimiConfig:
    """Immutable pair of settings naming a Mimi checkpoint and a dtype.

    Both fields have defaults, so ``MimiConfig()`` is valid. Instances are
    frozen: assigning to a field after construction raises.
    """

    # Checkpoint identifier; defaults to the module-level constant.
    model_id: str = DEFAULT_MIMI_MODEL_ID
    # Optional torch dtype; ``None`` leaves the dtype unspecified.
    dtype: Optional[torch.dtype] = None
class MimiCodec(nn.Module):
    """Small ``nn.Module`` shell over a transformers ``MimiModel``.

    Keeps the wrapped model together with the device it is expected to live
    on, and offers ``decode`` / ``encode`` helpers that move their input to
    that device and call the model under ``torch.inference_mode()``.

    Attributes assigned in ``__init__``:
        model: the wrapped ``MimiModel``.
        device: device every input tensor is moved to before a call.
        sample_rate: ``model.config.sampling_rate``, or 24000 if missing.
        frame_rate: ``model.config.frame_rate``, or 12.5 if missing.
        samples_per_frame: ``sample_rate / frame_rate`` rounded to an int,
            or 0 when ``frame_rate`` is falsy.
    """

    def __init__(self, model: MimiModel, device: torch.device) -> None:
        """Store ``model`` and ``device`` and derive the rate attributes."""
        super().__init__()
        self.model = model
        self.device = device

        # The config lookup is defensive: a model without ``config`` simply
        # gets the fallback rates below.
        model_config = getattr(model, "config", None)
        self.sample_rate = getattr(model_config, "sampling_rate", 24000)
        self.frame_rate = getattr(model_config, "frame_rate", 12.5)

        if self.frame_rate:
            self.samples_per_frame = int(round(self.sample_rate / self.frame_rate))
        else:
            # A falsy frame rate (0 / None) would make the division
            # meaningless, so report zero samples per frame instead.
            self.samples_per_frame = 0

    @classmethod
    def from_pretrained(
        cls,
        model_id: str = DEFAULT_MIMI_MODEL_ID,
        *,
        device: torch.device,
        dtype: Optional[torch.dtype] = None,
    ) -> "MimiCodec":
        """Load a ``MimiModel`` and wrap it in a ``MimiCodec``.

        Args:
            model_id: identifier forwarded to ``MimiModel.from_pretrained``.
            device: device the loaded model is moved to; also stored on the
                returned codec.
            dtype: forwarded as ``torch_dtype``; ``None`` passes no explicit
                dtype.

        Returns:
            A codec whose model has been moved to ``device`` and put in
            eval mode.
        """
        loaded = MimiModel.from_pretrained(
            model_id,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        ).to(device)
        loaded.eval()
        return cls(loaded, device)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode ``codes`` with the wrapped model and clamp to [-1.0, 1.0].

        ``codes`` is moved to ``self.device`` first. The model is called with
        ``return_dict=False`` and the first element of the returned pair is
        taken as the audio; the second element is discarded.
        """
        on_device = codes.to(self.device)
        with torch.inference_mode():
            waveform, _ = self.model.decode(on_device, return_dict=False)
        return torch.clamp(waveform, min=-1.0, max=1.0)

    def encode(self, audio: torch.Tensor, *, return_dict: bool = False):
        """Encode ``audio`` with the wrapped model under inference mode.

        ``audio`` is moved to ``self.device`` first. Whatever
        ``self.model.encode`` returns for the given ``return_dict`` flag is
        passed back unchanged.
        """
        on_device = audio.to(self.device)
        with torch.inference_mode():
            encoded = self.model.encode(on_device, return_dict=return_dict)
        return encoded
|