|
|
""" |
|
|
WavTokenizer model implementation for HuggingFace. |
|
|
|
|
|
This implementation exactly matches the checkpoint structure for direct weight loading. |
|
|
""" |
|
|
|
|
|
import math |
|
|
from typing import Optional, Tuple, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import PreTrainedModel |
|
|
from transformers.modeling_outputs import BaseModelOutput |
|
|
|
|
|
from .configuration_wavtokenizer import WavTokenizerConfig |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_audio(wav, sr, target_sr, target_channels=1):
    """Convert audio to a target sample rate and channel count.

    Args:
        wav: Waveform tensor. A 1-D tensor is treated as one mono clip ``(T,)``,
            a 2-D tensor as a batch of mono clips ``(B, T)``, and a 3-D tensor
            as ``(B, C, T)``.
        sr: Sample rate of ``wav``.
        target_sr: Desired output sample rate.
        target_channels: Desired number of output channels.

    Returns:
        Tensor of shape ``(B, target_channels, T')`` where
        ``T' = int(T * target_sr / sr)``.

    Raises:
        ValueError: If ``wav`` is not 1-, 2- or 3-D, or if a multi-channel
            signal would have to be up-mixed to a larger channel count
            (there is no well-defined mapping for that).
    """
    if wav.dim() == 1:
        wav = wav.unsqueeze(0).unsqueeze(0)
    elif wav.dim() == 2:
        wav = wav.unsqueeze(1)
    elif wav.dim() != 3:
        raise ValueError(
            f"Expected a 1-D, 2-D or 3-D waveform, got shape {tuple(wav.shape)}"
        )

    channels = wav.shape[1]
    if channels != target_channels:
        if target_channels == 1:
            # Down-mix by averaging instead of keeping only the first channel,
            # so content present only in the other channels is not dropped.
            wav = wav.mean(dim=1, keepdim=True)
        elif channels == 1:
            # Duplicate a mono signal into every requested channel.
            wav = wav.repeat(1, target_channels, 1)
        elif channels > target_channels:
            wav = wav[:, :target_channels, :]
        else:
            raise ValueError(
                f"Cannot up-mix {channels}-channel audio to {target_channels} channels"
            )

    if sr != target_sr:
        # NOTE: linear interpolation applies no anti-aliasing filter, so
        # downsampling can alias high frequencies; a windowed-sinc resampler
        # would be more faithful if one is available.
        new_length = int(wav.shape[-1] * target_sr / sr)
        wav = F.interpolate(wav, size=new_length, mode='linear', align_corners=False)

    return wav
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WNConv1d(nn.Module): |
|
|
"""Weight-normalized Conv1d using parametrizations API to match checkpoint structure.""" |
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): |
|
|
super().__init__() |
|
|
conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) |
|
|
|
|
|
self.conv = nn.utils.parametrizations.weight_norm(conv) |
|
|
|
|
|
def forward(self, x): |
|
|
return self.conv(x) |
|
|
|
|
|
|
|
|
class WNConvTranspose1d(nn.Module): |
|
|
"""Weight-normalized ConvTranspose1d using parametrizations API.""" |
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True): |
|
|
super().__init__() |
|
|
convtr = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias) |
|
|
self.convtr = nn.utils.parametrizations.weight_norm(convtr) |
|
|
|
|
|
def forward(self, x): |
|
|
return self.convtr(x) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _ConvWrapper(nn.Module):
    """Extra nesting level around :class:`WNConv1d` to reproduce the checkpoint layout.

    Parameters of this module live under ``conv.conv.*``. Here the weight is
    ``conv.conv.parametrizations.weight.original0/original1``, while the
    checkpoint names them ``conv.conv.weight_g`` / ``conv.conv.weight_v``.
    NOTE(review): loading presumably relies on PyTorch's parametrized
    weight_norm remapping those legacy names — confirm.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0):
        super().__init__()
        inner = WNConv1d(in_ch, out_ch, kernel_size, stride=stride, padding=padding)
        self.conv = inner

    def forward(self, x):
        """Run the wrapped weight-normalised convolution."""
        out = self.conv(x)
        return out
|
|
|
|
|
|
|
|
class _ResBlockWrapper(nn.Module):
    """Residual unit computing ``shortcut(x) + block(x)``.

    ``block`` is ELU -> conv(k=3, dim -> dim//2) -> ELU -> conv(k=1, dim//2 -> dim).
    The children are named ``block.0`` .. ``block.3``; only ``block.1`` and
    ``block.3`` carry parameters. ``shortcut`` is a 1x1 conv (dim -> dim).
    """

    def __init__(self, dim):
        super().__init__()
        hidden = dim // 2
        self.block = nn.Sequential(
            nn.ELU(),                                  # block.0 (no parameters)
            _ConvWrapper(dim, hidden, 3, padding=1),   # block.1
            nn.ELU(),                                  # block.2 (no parameters)
            _ConvWrapper(hidden, dim, 1),              # block.3
        )
        self.shortcut = _ConvWrapper(dim, dim, 1)

    def forward(self, x):
        """Sum of the 1x1 shortcut branch and the bottleneck branch."""
        skip = self.shortcut(x)
        return skip + self.block(x)
|
|
|
|
|
|
|
|
class _LSTMWrapper(nn.Module): |
|
|
"""LSTM wrapper matching checkpoint: lstm.weight_ih_l0, etc.""" |
|
|
def __init__(self, dim, num_layers=2): |
|
|
super().__init__() |
|
|
self.lstm = nn.LSTM(dim, dim, num_layers=num_layers, batch_first=True) |
|
|
|
|
|
def forward(self, x): |
|
|
x = x.transpose(1, 2) |
|
|
y, _ = self.lstm(x) |
|
|
y = y + x |
|
|
return y.transpose(1, 2) |
|
|
|
|
|
|
|
|
class EncoderModel(nn.Module):
    """
    Convolutional encoder matching checkpoint prefix ``feature_extractor.encodec.encoder.model.*``.

    With the default arguments the layers are:

    - model.0: initial conv (1 -> 32, k=7)
    - model.1: residual block (32)
    - model.2: ELU (no parameters)
    - model.3: downsample conv (32 -> 64, stride=2)
    - model.4: residual block (64)
    - model.5: ELU
    - model.6: downsample conv (64 -> 128, stride=4)
    - model.7: residual block (128)
    - model.8: ELU
    - model.9: downsample conv (128 -> 256, stride=5)
    - model.10: residual block (256)
    - model.11: ELU
    - model.12: downsample conv (256 -> 512, stride=8)
    - model.13: LSTM
    - model.14: ELU
    - model.15: output conv (512 -> dimension, k=7)

    Each downsampling conv uses kernel ``2 * ratio``, stride ``ratio`` and fixed
    zero padding ``ratio // 2``, and the channel count doubles at every stage.

    NOTE(review): the fixed zero padding may not reproduce the frame count and
    edge behaviour of the encoder the checkpoint was trained with, if that one
    padded dynamically — confirm against reference outputs.

    Args:
        channels: Number of input audio channels.
        n_filters: Channels after the initial conv.
        dimension: Output latent dimension.
        ratios: Downsampling stride of each stage, applied in order. A tuple is
            used as the default to avoid a shared mutable default argument; any
            iterable of ints is accepted.
    """

    def __init__(self, channels=1, n_filters=32, dimension=512, ratios=(2, 4, 5, 8)):
        super().__init__()

        layers = [_ConvWrapper(channels, n_filters, 7, padding=3)]

        in_ch = n_filters
        for ratio in ratios:
            out_ch = in_ch * 2
            layers += [
                _ResBlockWrapper(in_ch),
                nn.ELU(),
                _ConvWrapper(in_ch, out_ch, ratio * 2, stride=ratio, padding=ratio // 2),
            ]
            in_ch = out_ch

        layers += [
            _LSTMWrapper(in_ch),
            nn.ELU(),
            _ConvWrapper(in_ch, dimension, 7, padding=3),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` (channel-first audio) into latent frames."""
        return self.model(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Codebook(nn.Module):
    """Euclidean nearest-neighbour codebook.

    Buffers match the checkpoint keys ``_codebook.{inited, cluster_size, embed, embed_avg}``.
    Only ``embed`` (``(num_embeddings, embedding_dim)``) is read for
    quantization; the other buffers are kept so the state dict loads cleanly.

    Lookup uses squared Euclidean distance against the raw (un-normalised)
    code vectors. ``forward`` and ``decode`` return those raw vectors, so the
    residual subtraction in the multi-stage quantizer stays consistent.
    """

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()

        self.register_buffer('inited', torch.zeros(1))
        self.register_buffer('cluster_size', torch.zeros(num_embeddings))
        self.register_buffer('embed', torch.randn(num_embeddings, embedding_dim))
        self.register_buffer('embed_avg', torch.randn(num_embeddings, embedding_dim))

    def forward(self, x):
        """
        Args:
            x: (B, T, D) input

        Returns:
            quantized: (B, T, D) nearest code vectors (straight-through w.r.t. ``x``)
            indices: (B, T) codebook indices
        """
        embed = self.embed
        flat = x.reshape(-1, x.shape[-1])

        # ||x||^2 - 2 x.e + ||e||^2 for every (vector, code) pair.
        dist = (
            flat.pow(2).sum(dim=1, keepdim=True)
            - 2 * flat @ embed.t()
            + embed.pow(2).sum(dim=1).unsqueeze(0)
        )
        indices = dist.argmin(dim=-1).reshape(x.shape[:-1])

        quantized = F.embedding(indices, embed)

        # Straight-through estimator: forward value is the code vector,
        # gradients flow back to x unchanged.
        quantized = x + (quantized - x).detach()

        return quantized, indices

    def decode(self, indices):
        """Map indices of any shape to their code vectors (shape ``indices.shape + (D,)``)."""
        return F.embedding(indices, self.embed)
|
|
|
|
|
|
|
|
class VQLayer(nn.Module):
    """Single quantization stage; owns a :class:`Codebook` at ``_codebook``.

    Works on channel-first tensors ``(B, D, T)`` and transposes to the
    codebook's ``(B, T, D)`` layout internally.
    """

    def __init__(self, dim, codebook_size):
        super().__init__()
        self._codebook = Codebook(codebook_size, dim)

    def forward(self, x):
        """Quantize ``x`` of shape (B, D, T); returns ((B, D, T) vectors, (B, T) indices)."""
        quantized, indices = self._codebook(x.transpose(1, 2))
        return quantized.transpose(1, 2), indices

    def decode(self, indices):
        """Turn (B, T) indices back into (B, D, T) code vectors."""
        return self._codebook.decode(indices).transpose(1, 2)
|
|
|
|
|
|
|
|
class VQ(nn.Module):
    """Residual vector quantizer; stages live at ``layers.{i}`` (checkpoint: ``vq.layers``).

    Each stage quantizes the residual left over by the previous stages, and
    the stage outputs are summed.
    """

    def __init__(self, dim, codebook_size, num_quantizers=1):
        super().__init__()
        self.layers = nn.ModuleList([
            VQLayer(dim, codebook_size) for _ in range(num_quantizers)
        ])

    def forward(self, x):
        """
        Args:
            x: (B, D, T) features.

        Returns:
            quantized: (B, D, T) sum of all stage outputs.
            indices: (B, num_quantizers, T) codes, one stream per stage.
        """
        quantized = torch.zeros_like(x)
        residual = x
        indices_list = []

        for layer in self.layers:
            q, idx = layer(residual)
            residual = residual - q
            quantized = quantized + q
            indices_list.append(idx)

        return quantized, torch.stack(indices_list, dim=1)

    def decode(self, indices):
        """
        Sum the code vectors of the supplied code streams.

        Args:
            indices: (B, n_q, T) codes. ``n_q`` may be smaller than the number
                of stages; only the first ``n_q`` stages are then used.

        Returns:
            (B, D, T) features.

        Raises:
            ValueError: if ``indices`` has no code streams, or more streams
                than there are stages.
        """
        n_streams = indices.shape[1]
        if n_streams == 0 or n_streams > len(self.layers):
            raise ValueError(
                f"Expected between 1 and {len(self.layers)} code streams, got {n_streams}"
            )

        quantized = None
        # zip stops at the shorter sequence, so only len(streams) stages are used.
        for layer, codes in zip(self.layers, indices.unbind(dim=1)):
            q = layer.decode(codes)
            quantized = q if quantized is None else quantized + q
        return quantized
|
|
|
|
|
|
|
|
class Quantizer(nn.Module):
    """Container exposing the residual quantizer under ``vq`` (checkpoint: ``quantizer.vq``)."""

    def __init__(self, dim, codebook_size, num_quantizers=1):
        super().__init__()
        self.vq = VQ(dim, codebook_size, num_quantizers)

    def forward(self, x):
        """Quantize ``x``; returns the ``(quantized, indices)`` pair produced by ``VQ``."""
        vq = self.vq
        return vq(x)

    def decode(self, indices):
        """Reconstruct features from codes via ``VQ.decode``."""
        vq = self.vq
        return vq.decode(indices)
|
|
|
|
|
|
|
|
class EnCodecWrapper(nn.Module):
    """Encoder + quantizer pair matching checkpoint prefixes ``encodec.encoder`` / ``encodec.quantizer``.

    Args:
        channels, n_filters, dimension, ratios: forwarded to :class:`EncoderModel`.
            ``ratios`` defaults to a tuple to avoid a shared mutable default.
        codebook_size, num_quantizers: forwarded to :class:`Quantizer`.
    """

    def __init__(self, channels=1, n_filters=32, dimension=512, ratios=(2, 4, 5, 8),
                 codebook_size=4096, num_quantizers=1):
        super().__init__()
        self.encoder = EncoderModel(channels, n_filters, dimension, ratios)
        self.quantizer = Quantizer(dimension, codebook_size, num_quantizers)

    def encode(self, x):
        """Encode audio ``x`` and quantize it; returns ``(quantized_features, codes)``."""
        z = self.encoder(x)
        z_q, codes = self.quantizer(z)
        return z_q, codes
|
|
|
|
|
|
|
|
class FeatureExtractor(nn.Module):
    """Top-level encoder holder (checkpoint prefix ``feature_extractor.encodec``).

    All keyword arguments are forwarded unchanged to :class:`EnCodecWrapper`.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.encodec = EnCodecWrapper(**kwargs)

    def encode(self, x):
        """Return ``(quantized_features, codes)`` for audio ``x``."""
        encodec = self.encodec
        return encodec.encode(x)

    def decode_codes(self, codes):
        """Turn discrete codes back into continuous features via the quantizer."""
        quantizer = self.encodec.quantizer
        return quantizer.decode(codes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AdaLayerNorm(nn.Module):
    """
    Bandwidth-conditioned Adaptive LayerNorm over the channel axis of (B, C, T) input.

    Checkpoint structure:
    - norm.scale.weight: [4, 768] (4 bandwidth conditions)
    - norm.shift.weight: [4, 768]

    ``scale`` starts at ones and ``shift`` at zeros, so an untrained layer is a
    plain (non-affine) layer norm.
    """

    def __init__(self, dim, num_bandwidths=4, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.dim = dim

        self.scale = nn.Embedding(num_bandwidths, dim)
        self.shift = nn.Embedding(num_bandwidths, dim)

        nn.init.ones_(self.scale.weight)
        nn.init.zeros_(self.shift.weight)

    def forward(self, x, bandwidth_id=None):
        """
        Args:
            x: (B, C, T) input
            bandwidth_id: bandwidth index as a Python int, a 0-d tensor, or a
                (B,) / (1,) tensor. ``None`` selects index 0 for every sample.

        Returns:
            (B, C, T) normalised, scaled and shifted tensor.
        """
        # Normalise each time step across channels (biased variance, like LayerNorm).
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, keepdim=True, unbiased=False)
        x = (x - mean) / torch.sqrt(var + self.eps)

        if bandwidth_id is None:
            bandwidth_id = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
        else:
            # Accept plain ints and tensors on other devices / integer dtypes;
            # nn.Embedding itself requires an integer tensor on its own device.
            bandwidth_id = torch.as_tensor(bandwidth_id, dtype=torch.long, device=x.device)

        scale = self.scale(bandwidth_id)
        shift = self.shift(bandwidth_id)

        # (B, C) or (C,) -> broadcast over the time axis.
        return x * scale.unsqueeze(-1) + shift.unsqueeze(-1)
|
|
|
|
|
|
|
|
class ConvNeXtBlock(nn.Module):
    """
    ConvNeXt residual block on channel-first (B, C, T) tensors.

    Computes ``x + gamma * pwconv2(gelu(pwconv1(norm(dwconv(x)))))``. The two
    pointwise layers are ``nn.Linear`` applied on the channel-last view.

    Checkpoint keys:
    - dwconv.weight: [768, 1, 7]
    - dwconv.bias: [768]
    - norm.scale.weight: [4, 768]
    - norm.shift.weight: [4, 768]
    - pwconv1.weight: [2304, 768]
    - pwconv1.bias: [2304]
    - pwconv2.weight: [768, 2304]
    - pwconv2.bias: [768]
    - gamma: [768]
    """

    def __init__(self, dim, intermediate_dim, kernel_size=7, layer_scale_init=1e-6, num_bandwidths=4):
        super().__init__()
        # Depthwise conv with "same" padding for odd kernel sizes.
        self.dwconv = nn.Conv1d(dim, dim, kernel_size, padding=(kernel_size - 1) // 2, groups=dim)
        self.norm = AdaLayerNorm(dim, num_bandwidths)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        # Per-channel layer scale applied to the residual branch.
        self.gamma = nn.Parameter(layer_scale_init * torch.ones(dim))

    def forward(self, x, bandwidth_id=None):
        """Apply the block; ``bandwidth_id`` is forwarded to the adaptive norm."""
        h = self.norm(self.dwconv(x), bandwidth_id)
        h = self.pwconv2(F.gelu(self.pwconv1(h.transpose(1, 2)))).transpose(1, 2)
        return x + self.gamma[None, :, None] * h
|
|
|
|
|
|
|
|
class Backbone(nn.Module):
    """
    Decoder backbone: embed conv -> adaptive norm -> ConvNeXt stack -> LayerNorm.

    Input and output are channel-first: (B, input_dim, T) -> (B, dim, T).

    Checkpoint keys:
    - embed.weight, embed.bias
    - norm.scale.weight, norm.shift.weight
    - convnext.0-11.*
    - final_layer_norm.weight, final_layer_norm.bias
    """

    def __init__(self, input_dim=512, dim=768, intermediate_dim=2304, num_blocks=12,
                 num_bandwidths=4):
        super().__init__()
        self.embed = nn.Conv1d(input_dim, dim, kernel_size=7, padding=3)
        self.norm = AdaLayerNorm(dim, num_bandwidths)
        self.convnext = nn.ModuleList(
            ConvNeXtBlock(dim, intermediate_dim, num_bandwidths=num_bandwidths)
            for _ in range(num_blocks)
        )
        self.final_layer_norm = nn.LayerNorm(dim)

    def forward(self, x, bandwidth_id=None):
        """Run the backbone; ``bandwidth_id`` conditions every adaptive norm."""
        h = self.norm(self.embed(x), bandwidth_id)
        for block in self.convnext:
            h = block(h, bandwidth_id)
        # nn.LayerNorm normalises the last axis, so go channel-last and back.
        return self.final_layer_norm(h.transpose(1, 2)).transpose(1, 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ISTFT(nn.Module):
    """Holder for the synthesis window, exposed as the ``window`` buffer (checkpoint: ``istft.window``).

    The window is a Hann window of length ``n_fft``.
    """

    def __init__(self, n_fft=1280):
        super().__init__()
        self.n_fft = n_fft
        hann = torch.hann_window(n_fft)
        self.register_buffer('window', hann)
|
|
|
|
|
|
|
|
class ISTFTHead(nn.Module):
    """
    iSTFT head: projects backbone features to a complex spectrogram and inverts it.

    Checkpoint keys:
    - out.weight: [1282, 768]
    - out.bias: [1282]
    - istft.window: [1280]

    The first ``n_fft // 2 + 1`` outputs of ``out`` are log-magnitudes and the
    rest are phases (for even ``n_fft`` the two halves have equal size).

    ``padding`` selects the inverse transform:
    - ``'center'``: ``torch.istft(..., center=True)``.
    - ``'same'``: manual overlap-add trimmed so that T frames give exactly
      ``T * hop_length`` samples.
    Any other value raises ``ValueError`` in :meth:`forward`.
    """

    def __init__(self, dim, n_fft=1280, hop_length=320, padding='center'):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.padding = padding
        self.out = nn.Linear(dim, n_fft + 2)
        self.istft = ISTFT(n_fft)

    def forward(self, x):
        """
        Args:
            x: (B, C, T) backbone output

        Returns:
            audio: (B, 1, samples)

        Raises:
            ValueError: if ``self.padding`` is neither ``'center'`` nor ``'same'``.
        """
        x = self.out(x.transpose(1, 2))  # (B, T, n_fft + 2)

        n_bins = self.n_fft // 2 + 1
        # Clamp the exponentiated log-magnitude so a large prediction cannot
        # overflow or blow up the synthesized waveform.
        mag = torch.clamp(torch.exp(x[:, :, :n_bins]), max=1e2)
        phase = x[:, :, n_bins:]

        spec = torch.complex(mag * torch.cos(phase), mag * torch.sin(phase))
        spec = spec.transpose(1, 2)  # (B, n_bins, T)

        if self.padding == 'center':
            audio = torch.istft(
                spec,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.n_fft,
                window=self.istft.window,
                center=True,
                return_complex=False,
            )
        elif self.padding == 'same':
            # torch.istft(center=False) cannot be used here: the Hann window's
            # zero end points fail its window-envelope (NOLA) check.
            audio = self._istft_same(spec)
        else:
            raise ValueError(f"Padding must be 'center' or 'same', got {self.padding!r}")

        return audio.unsqueeze(1)

    def _istft_same(self, spec):
        """Inverse STFT with 'same' padding: (B, n_bins, T) complex -> (B, T * hop_length) real."""
        window = self.istft.window
        win_length = window.shape[0]
        hop = self.hop_length
        n_frames = spec.shape[-1]
        pad = (win_length - hop) // 2

        # Per-frame inverse FFT followed by the synthesis window.
        frames = torch.fft.irfft(spec, self.n_fft, dim=1, norm='backward')
        frames = frames * window[None, :, None]

        # Overlap-add the frames, then drop `pad` samples from each side.
        out_len = (n_frames - 1) * hop + win_length
        y = F.fold(
            frames,
            output_size=(1, out_len),
            kernel_size=(1, win_length),
            stride=(1, hop),
        )[:, 0, 0, pad:out_len - pad]

        # Sum of squared windows at every output sample, used for normalisation.
        window_sq = window.square().expand(1, n_frames, -1).transpose(1, 2)
        envelope = F.fold(
            window_sq,
            output_size=(1, out_len),
            kernel_size=(1, win_length),
            stride=(1, hop),
        )[0, 0, 0, pad:out_len - pad]

        if not bool((envelope > 1e-11).all()):
            raise RuntimeError("ISTFT window envelope is zero inside the output range")

        return y / envelope
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WavTokenizer(PreTrainedModel):
    """
    WavTokenizer model for audio tokenization.

    Pipeline: waveform -> ``feature_extractor`` (encoder + residual VQ) ->
    quantized features and discrete codes -> ``backbone`` -> ``head`` (iSTFT)
    -> waveform. Submodule names follow the checkpoint key prefixes
    (``feature_extractor.*``, ``backbone.*``, ``head.*``) for direct weight loading.
    """

    config_class = WavTokenizerConfig
    base_model_prefix = "wavtokenizer"

    def __init__(self, config: WavTokenizerConfig):
        super().__init__(config)
        self.config = config

        # Encoder + quantizer (checkpoint prefix: feature_extractor.encodec.*).
        self.feature_extractor = FeatureExtractor(
            channels=1,
            n_filters=config.encoder_dim,
            dimension=config.latent_dim,
            ratios=config.encoder_rates,
            codebook_size=config.codebook_size,
            num_quantizers=config.num_quantizers,
        )

        # Decoder backbone (checkpoint prefix: backbone.*). Falls back to the
        # 4 bandwidth conditions of the documented checkpoint layout when the
        # config does not define `num_bandwidths`.
        self.backbone = Backbone(
            input_dim=config.latent_dim,
            dim=config.backbone_dim,
            intermediate_dim=config.backbone_intermediate_dim,
            num_blocks=config.backbone_num_blocks,
            num_bandwidths=getattr(config, "num_bandwidths", 4),
        )

        # Spectrogram head + inverse STFT (checkpoint prefix: head.*).
        self.head = ISTFTHead(
            dim=config.backbone_dim,
            n_fft=config.n_fft,
            hop_length=config.hop_length,
            padding=config.padding,
        )

        self.post_init()

    def encode(self, audio, bandwidth_id=None):
        """
        Encode audio to quantized features and codes.

        Args:
            audio: (B, 1, T) audio waveform
            bandwidth_id: Accepted for API symmetry with :meth:`decode`; the
                encoder does not use it.

        Returns:
            features: (B, D, T') quantized features
            codes: (B, num_quantizers, T') discrete codes
        """
        return self.feature_extractor.encode(audio)

    def encode_infer(self, audio, bandwidth_id=None):
        """
        Encode audio for inference.

        Args:
            audio: (B, 1, T) audio waveform
            bandwidth_id: Optional bandwidth index (unused by the encoder)

        Returns:
            features: (B, D, T') quantized features
            codes: (B, T') discrete codes when there is a single quantizer,
                otherwise (B, num_quantizers, T')
        """
        features, codes = self.encode(audio, bandwidth_id)
        if codes.shape[1] == 1:
            codes = codes.squeeze(1)
        return features, codes

    def decode(self, features, bandwidth_id=None):
        """
        Decode features to audio.

        Args:
            features: (B, D, T') quantized features
            bandwidth_id: Optional bandwidth index (int or tensor) for the
                adaptive norms; ``None`` selects index 0.

        Returns:
            audio: (B, 1, T) reconstructed waveform
        """
        x = self.backbone(features, bandwidth_id)
        return self.head(x)

    def codes_to_features(self, codes):
        """
        Convert discrete codes back to continuous features.

        Args:
            codes: (T,), (B, T) or (B, num_quantizers, T) discrete codes. A 1-D
                input is treated as a single sequence from one quantizer.

        Returns:
            features: (B, D, T) continuous features
        """
        if codes.dim() == 1:
            codes = codes.reshape(1, 1, -1)
        elif codes.dim() == 2:
            codes = codes.unsqueeze(1)
        return self.feature_extractor.decode_codes(codes)

    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        bandwidth_id: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        HuggingFace-style forward pass.

        Args:
            input_values: (T,), (B, T) or (B, 1, T) audio waveform
            input_ids: (T,), (B, T) or (B, num_quantizers, T) discrete codes
            bandwidth_id: Optional bandwidth index for decoding
            **kwargs: Ignored.

        If both ``input_values`` and ``input_ids`` are given, ``input_values``
        takes precedence.

        Returns:
            BaseModelOutput with last_hidden_state (features) and
            hidden_states (codes, audio)

        Raises:
            ValueError: if neither ``input_values`` nor ``input_ids`` is given.
        """
        if input_values is not None:
            if input_values.dim() == 1:
                input_values = input_values.unsqueeze(0).unsqueeze(0)
            elif input_values.dim() == 2:
                input_values = input_values.unsqueeze(1)

            features, codes = self.encode(input_values, bandwidth_id)
            audio = self.decode(features, bandwidth_id)

            return BaseModelOutput(
                last_hidden_state=features,
                hidden_states=(codes, audio),
            )

        if input_ids is not None:
            features = self.codes_to_features(input_ids)
            audio = self.decode(features, bandwidth_id)

            return BaseModelOutput(
                last_hidden_state=features,
                hidden_states=(input_ids, audio),
            )

        raise ValueError("Either input_values or input_ids must be provided")
|
|
|