File size: 4,156 Bytes
7d71c91
 
 
 
 
 
 
 
 
 
 
 
 
6eea8b2
7d71c91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
V6 Codec — MioCodec 25Hz wrapper
==================================
Single codebook, 12800 codes, 25fps, 24kHz.
Supports global_embedding for voice cloning.
"""

import torch
import numpy as np
import soundfile as sf
from pathlib import Path
from typing import Optional, Union

from config import (
    CODEC_MODEL_NAME, CODEC_SAMPLE_RATE,
    CODEC_CODEBOOK_SIZE, CODEC_FRAME_RATE,
)


class CodecV6:
    """Wrapper around the MioCodec 25 Hz neural audio codec.

    Single codebook (12800 codes), 25 frames/s, 24 kHz audio.
    Encodes a wav into discrete content tokens plus a speaker
    ``global_embedding`` (used for voice cloning), and decodes them back.
    """

    def __init__(self, device: str = "cuda"):
        """Load the pretrained MioCodec model onto *device*.

        Args:
            device: torch device string, e.g. "cuda" or "cpu".
        """
        self.device = device
        self.sample_rate = CODEC_SAMPLE_RATE      # 24000 Hz
        self.codebook_size = CODEC_CODEBOOK_SIZE  # 12800 codes
        self.frame_rate = CODEC_FRAME_RATE        # 25.0 fps
        self._load_model()

    def _load_model(self):
        """Instantiate the pretrained MioCodec model in eval mode."""
        # Imported lazily so importing this module does not require
        # the miocodec package unless a CodecV6 is actually constructed.
        from miocodec import MioCodecModel
        self.model = MioCodecModel.from_pretrained(CODEC_MODEL_NAME)
        self.model = self.model.to(self.device).eval()
        print(f"MioCodec loaded: {CODEC_MODEL_NAME}, {self.sample_rate}Hz, "
              f"{self.frame_rate}fps, {self.codebook_size} codes")

    @torch.no_grad()
    def encode(self, wav_path: str | Path) -> dict:
        """
        Encode wav file → MioCodec codes + global_embedding.

        Returns:
            dict with:
                'codes': torch.Tensor [num_frames] — token indices in [0, 12799]
                'global_embedding': torch.Tensor [128] — speaker embedding
        """
        data, sr = sf.read(str(wav_path), dtype='float32')
        waveform = torch.from_numpy(data)
        if waveform.dim() == 2:  # stereo → mono by channel average
            waveform = waveform.mean(1)
        waveform = waveform.unsqueeze(0)  # [1, samples]
        if sr != self.sample_rate:
            # Lazy import: torchaudio is only needed when resampling.
            import torchaudio
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)

        audio = waveform.to(self.device)

        # MioCodec encode returns (content_token_indices, global_embedding).
        result = self.model.encode(audio)
        codes = result.content_token_indices.squeeze().cpu()       # [num_frames]
        global_emb = result.global_embedding.squeeze().cpu()       # [128]

        # A one-frame clip would squeeze() down to a 0-dim tensor, which
        # breaks len()/iteration downstream — restore the frame axis.
        if codes.dim() == 0:
            codes = codes.unsqueeze(0)

        return {
            'codes': codes,
            'global_embedding': global_emb,
        }

    @torch.no_grad()
    def decode(self, codes: torch.Tensor,
               global_embedding: torch.Tensor) -> torch.Tensor:
        """
        Decode MioCodec codes → waveform.

        Args:
            codes: [num_frames] — token indices in [0, 12799]
            global_embedding: [128] — speaker embedding

        Returns:
            waveform: [samples] float32
        """
        codes = codes.to(self.device)
        global_embedding = global_embedding.to(self.device)

        # MioCodec expects flat tensors: codes [num_frames], emb [128];
        # tolerate batched inputs by squeezing extra singleton dims.
        if codes.dim() > 1:
            codes = codes.squeeze()
        if global_embedding.dim() > 1:
            global_embedding = global_embedding.squeeze()

        audio = self.model.decode(
            global_embedding=global_embedding,
            content_token_indices=codes,
        )
        return audio.squeeze().cpu().float()

    def encode_to_tokens(self, wav_path: str | Path) -> dict:
        """Convenience: encode and return codes + embedding."""
        return self.encode(wav_path)

    def tokens_to_wav(self, codes: torch.Tensor,
                      global_embedding: torch.Tensor,
                      output: Optional[str] = None) -> torch.Tensor:
        """Decode tokens to wav, optionally save.

        Args:
            codes: [num_frames] token indices.
            global_embedding: [128] speaker embedding.
            output: if given, path to write a 24 kHz wav file.

        Returns:
            waveform: [samples] float32 tensor (also returned when saved).
        """
        wav = self.decode(codes, global_embedding)
        if output:
            sf.write(output, wav.numpy(), self.sample_rate)
        return wav

    def get_stats(self, wav_path: str) -> dict:
        """Get encoding stats for a wav file.

        Returns:
            dict with duration_sec, num_tokens, tokens_per_sec,
            global_emb_shape.
        """
        result = self.encode(wav_path)
        # sf.info reads only the header — no need to re-read the whole
        # file (encode() already did a full read) just for the duration.
        info = sf.info(str(wav_path))
        dur = info.frames / info.samplerate
        # numel() is robust even if a caller hands back a 0-dim tensor,
        # where len() would raise.
        n_tokens = int(result['codes'].numel())
        return {
            "duration_sec": dur,
            "num_tokens": n_tokens,
            "tokens_per_sec": n_tokens / dur if dur > 0 else 0,
            "global_emb_shape": tuple(result['global_embedding'].shape),
        }