"""
V6 Codec — MioCodec 25Hz wrapper
==================================
Single codebook, 12800 codes, 25fps, 24kHz.
Supports global_embedding for voice cloning.
"""
import torch
import numpy as np
import soundfile as sf
from pathlib import Path
from typing import Optional, Union
from config import (
CODEC_MODEL_NAME, CODEC_SAMPLE_RATE,
CODEC_CODEBOOK_SIZE, CODEC_FRAME_RATE,
)
class CodecV6:
    """MioCodec 25 Hz codec wrapper (single codebook, 12800 codes, 24 kHz).

    Round-trips audio through MioCodec: encode a wav into content token
    indices plus a global speaker embedding (for voice cloning), and decode
    those back into a waveform.
    """

    def __init__(self, device: str = "cuda"):
        """Cache codec constants and load the pretrained model onto `device`."""
        self.device = device
        self.sample_rate = CODEC_SAMPLE_RATE      # 24000
        self.codebook_size = CODEC_CODEBOOK_SIZE  # 12800
        self.frame_rate = CODEC_FRAME_RATE        # 25.0
        self._load_model()

    def _load_model(self):
        """Load the pretrained MioCodec model and put it in eval mode on `self.device`."""
        from miocodec import MioCodecModel
        self.model = MioCodecModel.from_pretrained(CODEC_MODEL_NAME)
        self.model = self.model.to(self.device).eval()
        print(f"MioCodec loaded: {CODEC_MODEL_NAME}, {self.sample_rate}Hz, "
              f"{self.frame_rate}fps, {self.codebook_size} codes")

    @torch.no_grad()
    def encode(self, wav_path: str | Path) -> dict:
        """
        Encode wav file → MioCodec codes + global_embedding.

        Args:
            wav_path: path to an audio file readable by soundfile; stereo is
                downmixed to mono, and the audio is resampled to the codec
                rate if needed.

        Returns:
            dict with:
                'codes': torch.Tensor [num_frames] — token indices in [0, 12799]
                'global_embedding': torch.Tensor [128] — speaker embedding
        """
        data, sr = sf.read(str(wav_path), dtype='float32')
        waveform = torch.from_numpy(data)
        if waveform.dim() == 2:  # stereo [frames, channels] → mono by channel average
            waveform = waveform.mean(1)
        waveform = waveform.unsqueeze(0)  # [1, samples]
        if sr != self.sample_rate:
            import torchaudio
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        audio = waveform.to(self.device)
        # MioCodec encode returns (content_token_indices, global_embedding)
        result = self.model.encode(audio)
        # reshape(-1) instead of squeeze(): squeeze() collapses a single-frame
        # [1, 1] result into a 0-dim scalar, which breaks len() in get_stats()
        # and the flattening logic in decode(). reshape(-1) is always 1-D.
        codes = result.content_token_indices.reshape(-1).cpu()   # [num_frames]
        global_emb = result.global_embedding.reshape(-1).cpu()   # [128]
        return {
            'codes': codes,
            'global_embedding': global_emb,
        }

    @torch.no_grad()
    def decode(self, codes: torch.Tensor,
               global_embedding: torch.Tensor) -> torch.Tensor:
        """
        Decode MioCodec codes → waveform.

        Args:
            codes: [num_frames] — token indices in [0, 12799]
            global_embedding: [128] — speaker embedding

        Returns:
            waveform: [samples] float32 tensor on CPU
        """
        # MioCodec expects flat tensors: codes [num_frames], emb [128].
        # reshape(-1) flattens any leading batch dims and, unlike squeeze(),
        # never yields a 0-dim scalar for a single-frame input.
        codes = codes.reshape(-1).to(self.device)
        global_embedding = global_embedding.reshape(-1).to(self.device)
        audio = self.model.decode(
            global_embedding=global_embedding,
            content_token_indices=codes,
        )
        return audio.squeeze().cpu().float()

    def encode_to_tokens(self, wav_path: str) -> dict:
        """Convenience: encode and return codes + embedding."""
        return self.encode(wav_path)

    def tokens_to_wav(self, codes: torch.Tensor,
                      global_embedding: torch.Tensor,
                      output: Optional[str] = None) -> torch.Tensor:
        """Decode tokens to wav; if `output` is given, also write a wav file there."""
        wav = self.decode(codes, global_embedding)
        if output:
            sf.write(output, wav.numpy(), self.sample_rate)
        return wav

    def get_stats(self, wav_path: str) -> dict:
        """Get encoding stats (duration, token count, token rate) for a wav file."""
        result = self.encode(wav_path)
        # sf.info reads only the header — no need to load all samples a
        # second time just to compute the duration.
        info = sf.info(str(wav_path))
        dur = info.frames / info.samplerate
        n_tokens = int(result['codes'].numel())
        return {
            "duration_sec": dur,
            "num_tokens": n_tokens,
            "tokens_per_sec": n_tokens / dur if dur > 0 else 0,
            "global_emb_shape": tuple(result['global_embedding'].shape),
        }