| | """ |
| | V6 Codec β MioCodec 25Hz wrapper |
| | ================================== |
| | Single codebook, 12800 codes, 25fps, 24kHz. |
| | Supports global_embedding for voice cloning. |
| | """ |
| |
|
import torch
import numpy as np
import soundfile as sf
from pathlib import Path
from typing import Optional, Union

from config import (
    CODEC_MODEL_NAME, CODEC_SAMPLE_RATE,
    CODEC_CODEBOOK_SIZE, CODEC_FRAME_RATE,
)
| |
|
| |
|
class CodecV6:
    """Wrapper around MioCodec: single codebook, 25 fps frame rate, 24 kHz audio.

    Encodes wav files into discrete token indices plus a global speaker
    embedding (used for voice cloning), and decodes them back to waveforms.
    """

    def __init__(self, device: str = "cuda"):
        """
        Args:
            device: torch device string the model is moved to (e.g. "cuda", "cpu").
        """
        self.device = device
        self.sample_rate = CODEC_SAMPLE_RATE      # Hz (24000 per module docstring)
        self.codebook_size = CODEC_CODEBOOK_SIZE  # number of codes (12800 per module docstring)
        self.frame_rate = CODEC_FRAME_RATE        # token frames per second (25 per module docstring)
        self._load_model()

    def _load_model(self):
        """Load the pretrained MioCodec model onto self.device in eval mode."""
        # Lazy import so that importing this module does not require miocodec.
        from miocodec import MioCodecModel
        self.model = MioCodecModel.from_pretrained(CODEC_MODEL_NAME)
        self.model = self.model.to(self.device).eval()
        print(f"MioCodec loaded: {CODEC_MODEL_NAME}, {self.sample_rate}Hz, "
              f"{self.frame_rate}fps, {self.codebook_size} codes")

    @torch.no_grad()
    def encode(self, wav_path: str | Path) -> dict:
        """
        Encode wav file -> MioCodec codes + global_embedding.

        Args:
            wav_path: path to the input audio file (any format soundfile reads).

        Returns:
            dict with:
                'codes': torch.Tensor [num_frames] -> token indices in [0, 12799]
                'global_embedding': torch.Tensor [128] -> speaker embedding
        """
        data, sr = sf.read(str(wav_path), dtype='float32')
        waveform = torch.from_numpy(data)
        if waveform.dim() == 2:
            # soundfile returns (frames, channels); downmix to mono.
            waveform = waveform.mean(1)
        waveform = waveform.unsqueeze(0)  # add batch dim: [1, samples]
        if sr != self.sample_rate:
            import torchaudio
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)

        audio = waveform.to(self.device)

        result = self.model.encode(audio)
        # NOTE(review): .squeeze() collapses ALL size-1 dims; for a 1-frame clip
        # 'codes' would become 0-d. Preserved as-is — confirm against callers.
        codes = result.content_token_indices.squeeze().cpu()
        global_emb = result.global_embedding.squeeze().cpu()

        return {
            'codes': codes,
            'global_embedding': global_emb,
        }

    @torch.no_grad()
    def decode(self, codes: torch.Tensor,
               global_embedding: torch.Tensor) -> torch.Tensor:
        """
        Decode MioCodec codes -> waveform.

        Args:
            codes: [num_frames] -> token indices in [0, 12799]
            global_embedding: [128] -> speaker embedding

        Returns:
            waveform: [samples] float32
        """
        codes = codes.to(self.device)
        global_embedding = global_embedding.to(self.device)

        # Accept batched inputs by flattening away extra size-1 dims.
        if codes.dim() > 1:
            codes = codes.squeeze()
        if global_embedding.dim() > 1:
            global_embedding = global_embedding.squeeze()

        audio = self.model.decode(
            global_embedding=global_embedding,
            content_token_indices=codes,
        )
        return audio.squeeze().cpu().float()

    def encode_to_tokens(self, wav_path: str) -> dict:
        """Convenience: encode and return codes + embedding."""
        return self.encode(wav_path)

    def tokens_to_wav(self, codes: torch.Tensor,
                      global_embedding: torch.Tensor,
                      output: str | None = None) -> torch.Tensor:
        """Decode tokens to wav, optionally save to `output` path."""
        wav = self.decode(codes, global_embedding)
        if output:
            sf.write(output, wav.numpy(), self.sample_rate)
        return wav

    def get_stats(self, wav_path: str) -> dict:
        """Get encoding stats (duration, token count/rate, embedding shape) for a wav file."""
        result = self.encode(wav_path)
        # sf.info reads only the header — avoids decoding the whole file a
        # second time just to compute the duration.
        info = sf.info(str(wav_path))
        dur = info.frames / info.samplerate
        n_tokens = len(result['codes'])
        return {
            "duration_sec": dur,
            "num_tokens": n_tokens,
            "tokens_per_sec": n_tokens / dur if dur > 0 else 0,
            "global_emb_shape": tuple(result['global_embedding'].shape),
        }
| |
|