| """
|
| V6 Codec — MioCodec 25Hz wrapper
|
| ==================================
|
| Single codebook, 12800 codes, 25fps, 24kHz.
|
| Supports global_embedding for voice cloning.
|
| """
|
|
|
| import torch
|
| import numpy as np
|
| import soundfile as sf
|
| from pathlib import Path
|
| from typing import Optional, Union
|
|
|
| from config import (
|
| CODEC_MODEL_NAME, CODEC_SAMPLE_RATE,
|
| CODEC_CODEBOOK_SIZE, CODEC_FRAME_RATE,
|
| )
|
|
|
|
|
class CodecV6:
    """Wrapper around a pretrained MioCodec model.

    Encodes audio (file or tensor) into content token indices plus a global
    embedding, and decodes those back into a waveform.

    Attributes:
        device: Torch device string the model and all inputs are moved to.
        sample_rate: ``CODEC_SAMPLE_RATE`` from config; inputs with a
            different rate are resampled to it before encoding.
        codebook_size: ``CODEC_CODEBOOK_SIZE`` from config.
        frame_rate: ``CODEC_FRAME_RATE`` from config.
        model: The loaded MioCodec model, in eval mode.
    """

    def __init__(self, device: str = "cuda"):
        """Store codec constants from config and load the model onto *device*."""
        self.device = device
        self.sample_rate = CODEC_SAMPLE_RATE
        self.codebook_size = CODEC_CODEBOOK_SIZE
        self.frame_rate = CODEC_FRAME_RATE
        self._load_model()

    def _load_model(self) -> None:
        """Load ``CODEC_MODEL_NAME``, move it to ``self.device``, set eval mode."""
        # Local import: miocodec is only required once a codec is instantiated.
        from miocodec import MioCodecModel

        self.model = MioCodecModel.from_pretrained(CODEC_MODEL_NAME)
        self.model = self.model.to(self.device).eval()
        print(f"MioCodec loaded: {CODEC_MODEL_NAME}, {self.sample_rate}Hz, "
              f"{self.frame_rate}fps, {self.codebook_size} codes")

    @staticmethod
    def _read_mono(wav_path: Union[str, Path]):
        """Read an audio file and return ``(mono float32 waveform [samples], sr)``."""
        data, sr = sf.read(str(wav_path), dtype='float32')
        waveform = torch.from_numpy(data)
        # soundfile yields (frames,) for mono and (frames, channels) otherwise,
        # so the channel axis is known here: downmix explicitly over axis 1.
        if waveform.dim() == 2:
            waveform = waveform.mean(dim=1)
        return waveform, sr

    @torch.no_grad()
    def encode(self, wav_path: Union[str, Path]) -> dict:
        """Encode an audio file into MioCodec codes and a global embedding.

        Args:
            wav_path: Path of an audio file readable by soundfile.

        Returns:
            The dict produced by :meth:`encode_waveform`.
        """
        waveform, sr = self._read_mono(wav_path)
        return self.encode_waveform(waveform, sr)

    @torch.no_grad()
    def encode_waveform(self, waveform: torch.Tensor, sr: int) -> dict:
        """Encode directly from a waveform tensor.

        Args:
            waveform: ``[samples]``, ``[channels, samples]`` or
                ``[samples, channels]``. For 2-D input the shorter axis is
                treated as the channel axis and averaged away.
            sr: Sample rate of *waveform*; resampled to ``self.sample_rate``
                when different.

        Returns:
            ``{'codes': 1-D CPU tensor of token indices,
            'global_embedding': squeezed CPU tensor}``.
        """
        if waveform.dim() == 2:
            # The previous unconditional mean(1) averaged over *samples* for
            # the documented [channels, samples] layout. Picking the shorter
            # axis handles both layouts; ties keep the old axis-1 behaviour.
            channel_axis = 0 if waveform.shape[0] < waveform.shape[1] else 1
            waveform = waveform.mean(dim=channel_axis)
        if waveform.dim() == 1:
            # The model is fed a leading batch dimension of size 1.
            waveform = waveform.unsqueeze(0)

        if sr != self.sample_rate:
            # Local import: torchaudio is only needed when resampling.
            import torchaudio
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)

        audio = waveform.to(self.device).float()

        result = self.model.encode(audio)
        # reshape(-1) rather than squeeze(): a single-frame result must stay
        # 1-D so that len(codes) keeps working.
        # NOTE(review): assumes a single codebook, i.e. one index per frame.
        codes = result.content_token_indices.reshape(-1).cpu()
        global_emb = result.global_embedding.squeeze().cpu()

        return {
            'codes': codes,
            'global_embedding': global_emb,
        }

    @torch.no_grad()
    def decode(self, codes: torch.Tensor,
               global_embedding: torch.Tensor) -> torch.Tensor:
        """Decode MioCodec codes into a waveform.

        Args:
            codes: Token indices; flattened to ``[num_frames]`` before use.
                Presumably values lie in ``[0, codebook_size - 1]``.
            global_embedding: Global (speaker) embedding; size-1 dimensions
                are squeezed away. Presumably ``[128]`` — confirm against
                the model.

        Returns:
            Squeezed float32 CPU waveform tensor.
        """
        # Flatten so that [1, T], [T] and scalar inputs all reach the model 1-D.
        codes = codes.to(self.device).reshape(-1)
        global_embedding = global_embedding.to(self.device)
        if global_embedding.dim() > 1:
            global_embedding = global_embedding.squeeze()

        audio = self.model.decode(
            global_embedding=global_embedding,
            content_token_indices=codes,
        )
        return audio.squeeze().cpu().float()

    def encode_to_tokens(self, wav_path: Union[str, Path]) -> dict:
        """Convenience alias for :meth:`encode`."""
        return self.encode(wav_path)

    def tokens_to_wav(self, codes: torch.Tensor,
                      global_embedding: torch.Tensor,
                      output: Optional[Union[str, Path]] = None) -> torch.Tensor:
        """Decode tokens to a waveform and optionally write it to *output*.

        Args:
            codes: Token indices, as accepted by :meth:`decode`.
            global_embedding: Global embedding, as accepted by :meth:`decode`.
            output: Destination file; nothing is written when falsy.

        Returns:
            The decoded waveform from :meth:`decode`.
        """
        wav = self.decode(codes, global_embedding)
        if output:
            # Convert Path objects only; anything else is passed through as-is.
            target = str(output) if isinstance(output, Path) else output
            sf.write(target, wav.numpy(), self.sample_rate)
        return wav

    def get_stats(self, wav_path: Union[str, Path]) -> dict:
        """Return duration and token statistics for an audio file.

        The file is read once; duration is computed at its native sample
        rate, before any resampling done for encoding.

        Returns:
            Dict with ``duration_sec``, ``num_tokens``, ``tokens_per_sec``
            (0 for zero-length audio) and ``global_emb_shape``.
        """
        waveform, sr = self._read_mono(wav_path)
        dur = waveform.shape[0] / sr
        result = self.encode_waveform(waveform, sr)
        n_tokens = len(result['codes'])
        return {
            "duration_sec": dur,
            "num_tokens": n_tokens,
            "tokens_per_sec": n_tokens / dur if dur > 0 else 0,
            "global_emb_shape": tuple(result['global_embedding'].shape),
        }
|
|
|