Dalzymodderever commited on
Commit
2cba492
·
1 Parent(s): a680a6c

Initial Commit

Browse files
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
import os
import time
import torch
import gradio as gr

# --- 1. PATH SETUP ---
# Make the bundled `src/` directory importable so kanade_tokenizer can be
# imported without installing the package.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.join(current_dir, "src")
if src_path not in sys.path:
    sys.path.append(src_path)

# --- 2. Imports ---
try:
    from kanade_tokenizer.model import KanadeModel
    from kanade_tokenizer.util import load_vocoder, vocode, load_audio
except ImportError as e:
    # Surface import problems loudly in the Space logs, then fail startup.
    print(f"❌ IMPORT ERROR: {e}")
    raise e

# --- Configuration ---
KANADE_REPO = "frothywater/kanade-25hz-clean"  # HF Hub repo the weights are pulled from
KANADE_VOCODER = "hift"  # vocoder used to turn mel spectrograms back into audio
DEVICE = "cpu"
SAMPLE_RATE = 24000  # working sample rate for loading and output
MAX_AUDIO_SECONDS = 30  # Limit audio to 30 seconds

print(f"🚀 Initializing on {DEVICE}...")

# --- 3. Load Models ---
# Both models are loaded once at startup and reused for every request.
print(f"📥 Loading Kanade...")
kanade_model = KanadeModel.from_pretrained(repo_id=KANADE_REPO).to(DEVICE).eval()

print(f"🔊 Loading HiFT Vocoder...")
kanade_vocoder = load_vocoder(name=KANADE_VOCODER).to(DEVICE).eval()

print("✅ Models Loaded.")
38
+
39
# --- Core Inference ---
def run_inference(source_wav, ref_wav):
    """Run voice conversion inference on CPU.

    Converts the content of `source_wav` into the voice of `ref_wav`,
    then vocodes the predicted mel spectrogram into a waveform.
    """
    # inference_mode disables autograd bookkeeping for the whole pipeline.
    with torch.inference_mode():
        mel = kanade_model.voice_conversion(source_wav, ref_wav)
        # Vocoder expects a batch dimension on the mel spectrogram.
        return vocode(kanade_vocoder, mel.unsqueeze(0))
46
+
47
# --- Main Handler ---
def voice_conversion(source_path, reference_path):
    """Gradio handler: convert the audio at `source_path` into the voice heard
    in `reference_path`.

    Returns a `(sample_rate, numpy_audio)` pair for the output widget plus a
    status string; on failure returns `None` and an error message instead.
    """
    if not (source_path and reference_path):
        return None, "⚠️ Please provide both source and reference audio."

    try:
        max_samples = MAX_AUDIO_SECONDS * SAMPLE_RATE

        def _load_trimmed(path):
            # Load at the working sample rate and enforce the 30-second cap.
            wav = load_audio(path, sample_rate=SAMPLE_RATE).to(DEVICE)
            return wav[..., :max_samples] if wav.shape[-1] > max_samples else wav

        source_wav = _load_trimmed(source_path)
        ref_wav = _load_trimmed(reference_path)

        # Run inference, timing it for the status readout
        start = time.time()
        final_wav = run_inference(source_wav, ref_wav)
        proc_time = time.time() - start

        output_np = final_wav.squeeze().cpu().float().numpy()
        output_duration = len(output_np) / SAMPLE_RATE

        # RTF = processing time / audio duration (lower is better, <1 means faster than real-time)
        rtf = proc_time / output_duration if output_duration > 0 else 0

        return (SAMPLE_RATE, output_np), f"✅ {proc_time:.2f}s to convert {output_duration:.1f}s of audio | RTF: {rtf:.2f}x"

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
83
+
84
# --- Gradio Interface ---
# Two-column layout: inputs + button on the left, result + status on the right.
with gr.Blocks(title="Kanade Voice Cloning") as demo:
    gr.Markdown("""
    # 🗣️ Kanade Voice Cloning
    **Model:** `frothywater/kanade-25hz-clean`

    Convert any audio into a target voice. Upload a source audio (what to say) and a reference audio (whose voice to use).

    ⏱️ **Limit:** Audio is trimmed to 30 seconds max.
    """)

    with gr.Row():
        with gr.Column():
            # Both inputs are passed as file paths to voice_conversion()
            source_audio = gr.Audio(label="Source Audio (Content - what to say)", type="filepath")
            reference_audio = gr.Audio(label="Reference Audio (Target Voice - whose voice)", type="filepath")
            convert_btn = gr.Button("🎤 Convert Voice", variant="primary")

        with gr.Column():
            output_audio = gr.Audio(label="Result")
            status_text = gr.Textbox(label="Status", interactive=False)

    # Wire the button to the handler; outputs map to (audio, status) in order.
    convert_btn.click(
        voice_conversion,
        inputs=[source_audio, reference_audio],
        outputs=[output_audio, status_text]
    )

    gr.Markdown("""
    ---
    **Tips:**
    - For best results, use clean reference audio (3-10 seconds of clear speech)
    - Source and reference should ideally be similar in speaking pace
    """)

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ huggingface_hub
2
+ jsonargparse[signatures]
3
+ numpy
4
+ safetensors
5
+ soundfile
6
+ torch
7
+ torchaudio
8
+ tqdm
9
+ vocos
10
+ gradio
src/kanade_tokenizer/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# Public API of the kanade_tokenizer package.
from .model import KanadeFeatures, KanadeModel, KanadeModelConfig
from .util import load_audio, load_vocoder, vocode

# Names re-exported by `from kanade_tokenizer import *`.
__all__ = [
    "KanadeModel",
    "KanadeModelConfig",
    "KanadeFeatures",
    "load_audio",
    "load_vocoder",
    "vocode",
]
src/kanade_tokenizer/data/datamodule.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+
4
+ import lightning as L
5
+ import torch
6
+ from torch.utils.data import DataLoader, Dataset
7
+
8
+ from ..util import get_logger
9
+ from .dataset import AudioItem, ChunkedAudioDataset, pad_audio
10
+
11
+ logger = get_logger()
12
+
13
+
14
@dataclass
class AudioBatch:
    """A collated batch of audio examples produced by `audio_collate_fn`."""

    waveform: torch.Tensor  # [batch, channels, samples], zero-padded to the batch max length
    audio_ids: list[str]  # per-item dataset identifiers
    paths: list[Path]  # per-item source file paths
    sample_rates: list[int]  # per-item sample rates
    frame_offsets: list[int] | None  # For chunked audio
21
+
22
+
23
@dataclass
class AudioDataConfig:
    """Configuration for one dataset split: CSV manifest plus loading options."""

    # CSV manifest (columns: audio_id, path, length, sample_rate) and audio root dir
    csv_path: str
    audio_root: str

    # Audio processing
    sample_rate: int | None = 16000  # resample target; None keeps each file's own rate
    mono: bool = True
    normalize: bool = True  # peak-normalize each waveform

    # Chunking options (frame counts at the target sample rate)
    chunk_size: int | None = None  # None disables chunking
    chunk_hop_size: int | None = None  # None means non-overlapping chunks

    # DataLoader options
    batch_size: int = 32
    num_workers: int = 4
    pin_memory: bool = False
    persistent_workers: bool = False  # only honored when num_workers > 0
    shuffle: bool = False
    drop_last: bool = False
44
+
45
+
46
def audio_collate_fn(batch: list[AudioItem]) -> AudioBatch:
    """Collate AudioItems into one AudioBatch, zero-padding waveforms to equal length."""
    target_len = max(item.waveform.shape[1] for item in batch)
    # pad_audio is a no-op for waveforms already at target_len, so it is safe
    # to apply unconditionally.
    padded = [pad_audio(item.waveform, target_len) for item in batch]

    return AudioBatch(
        waveform=torch.stack(padded),
        audio_ids=[item.audio_id for item in batch],
        paths=[item.path for item in batch],
        sample_rates=[item.sample_rate for item in batch],
        frame_offsets=[item.frame_offset for item in batch],
    )
61
+
62
+
63
class AudioDataModule(L.LightningDataModule):
    """LightningDataModule serving ChunkedAudioDataset splits.

    The val/test configs fall back to the training config when not provided.
    Datasets are built lazily in `setup()` per Lightning stage.
    """

    def __init__(
        self,
        train_config: AudioDataConfig,
        val_config: AudioDataConfig | None = None,
        test_config: AudioDataConfig | None = None,
    ):
        super().__init__()
        self.train_config = train_config
        self.val_config = val_config or train_config
        self.test_config = test_config or self.val_config

        # Set to be initialized in setup()
        self.train_dataset: Dataset | None = None
        self.val_dataset: Dataset | None = None
        self.test_dataset: Dataset | None = None

    def _create_dataset(self, config: AudioDataConfig) -> Dataset:
        """Build one ChunkedAudioDataset from a split config."""
        return ChunkedAudioDataset(
            csv_path=config.csv_path,
            audio_root=config.audio_root,
            chunk_size=config.chunk_size,
            hop_size=config.chunk_hop_size,
            mono=config.mono,
            normalize=config.normalize,
            target_sample_rate=config.sample_rate,
        )

    def setup(self, stage: str | None = None):
        """Instantiate the datasets required for the given Lightning stage."""
        if stage == "fit" or stage is None:
            self.train_dataset = self._create_dataset(self.train_config)
            self.val_dataset = self._create_dataset(self.val_config)
        elif stage == "validate":
            self.val_dataset = self._create_dataset(self.val_config)
        elif stage == "test" or stage == "predict":
            self.test_dataset = self._create_dataset(self.test_config)

    def _make_dataloader(self, dataset: Dataset, config: AudioDataConfig, *, training: bool = False) -> DataLoader:
        """Shared DataLoader factory for all four hooks.

        The original train/val/test/predict hooks duplicated this call site
        verbatim, so any option change had to be made four times; evaluation
        loaders never shuffle or drop the last batch.
        """
        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory,
            # persistent_workers is only valid when num_workers > 0
            persistent_workers=config.persistent_workers if config.num_workers > 0 else False,
            shuffle=config.shuffle if training else False,
            drop_last=config.drop_last if training else False,
            collate_fn=audio_collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        return self._make_dataloader(self.train_dataset, self.train_config, training=True)

    def val_dataloader(self) -> DataLoader:
        return self._make_dataloader(self.val_dataset, self.val_config)

    def test_dataloader(self) -> DataLoader:
        return self._make_dataloader(self.test_dataset, self.test_config)

    def predict_dataloader(self) -> DataLoader:
        # Prediction reuses the test split, matching setup(stage="predict").
        return self._make_dataloader(self.test_dataset, self.test_config)
src/kanade_tokenizer/data/dataset.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torchaudio
7
+ from torch.utils.data import Dataset
8
+
9
+ from ..util import _load_audio_internal, get_logger
10
+
11
+ logger = get_logger()
12
+
13
+
14
@dataclass
class AudioItem:
    """A single (possibly chunked) audio example returned by the dataset."""

    waveform: torch.Tensor  # (channels, samples)
    audio_id: str  # identifier from the CSV manifest
    path: Path  # resolved path the audio was read from
    sample_rate: int  # sample rate after any resampling
    frame_offset: int | None = None  # For chunked audio
21
+
22
+
23
def convert_to_mono(waveform: torch.Tensor) -> torch.Tensor:
    """Mix a (channels, samples) waveform down to (1, samples) by channel averaging.

    Single-channel input is returned unchanged.
    """
    if waveform.shape[0] <= 1:
        return waveform
    return waveform.mean(dim=0, keepdim=True)
28
+
29
+
30
def resample_audio(waveform: torch.Tensor, orig_freq: int, new_freq: int) -> torch.Tensor:
    """Resample `waveform` from `orig_freq` to `new_freq` Hz; no-op when equal."""
    if orig_freq == new_freq:
        return waveform
    resample = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=new_freq)
    return resample(waveform)
35
+
36
+
37
def normalize_audio(waveform: torch.Tensor) -> torch.Tensor:
    """Peak-normalize so the maximum absolute sample is ~1.

    The epsilon keeps all-zero (silent) input finite instead of dividing by zero.
    """
    peak = waveform.abs().max() + 1e-8
    return waveform / peak
40
+
41
+
42
def preprocess_audio(
    waveform: torch.Tensor, sample_rate: int, mono: bool, normalize: bool, target_sample_rate: int | None = None
) -> tuple[torch.Tensor, int]:
    """Apply optional mono mixdown, resampling, and peak normalization.

    Returns the processed waveform together with its (possibly updated) sample rate.
    Normalization happens last so the peak is exact after resampling.
    """
    if mono:
        waveform = convert_to_mono(waveform)

    needs_resample = target_sample_rate is not None and sample_rate != target_sample_rate
    if needs_resample:
        waveform = resample_audio(waveform, sample_rate, target_sample_rate)
        sample_rate = target_sample_rate

    if normalize:
        waveform = normalize_audio(waveform)

    return waveform, sample_rate
59
+
60
+
61
def pad_audio(waveform: torch.Tensor, target_length: int) -> torch.Tensor:
    """Right-pad a (channels, samples) waveform with zeros up to `target_length` samples.

    Input already at or beyond the target length is returned unchanged.
    """
    deficit = target_length - waveform.shape[1]
    if deficit <= 0:
        return waveform

    # new_zeros inherits dtype and device from the input waveform.
    tail = waveform.new_zeros((waveform.shape[0], deficit))
    return torch.cat([waveform, tail], dim=1)
72
+
73
+
74
@dataclass
class ChunkInfo:
    """Location of one chunk within a source file, in target-sample-rate frames."""

    audio_id: str
    frame_offset: int  # In target sample rate
    num_frames: int  # In target sample rate
79
+
80
+
81
class ChunkedAudioDataset(Dataset):
    """
    Dataset that loads audio from CSV with optional chunking.

    Args:
        csv_path: Path to the CSV file with columns: audio_id, path, length, sample_rate
        audio_root: Root directory for audio files (prepended to paths in CSV)
        chunk_size: Size of each chunk in frames (None = no chunking)
        hop_size: Hop size between chunks in frames (None = use chunk_size)
        mono: Convert to mono if True
        normalize: Normalize audio if True
        target_sample_rate: Resample to this sample rate if provided
    """

    def __init__(
        self,
        csv_path: str,
        audio_root: str,
        chunk_size: int | None = None,
        hop_size: int | None = None,
        mono: bool = True,
        normalize: bool = True,
        target_sample_rate: int | None = None,
    ):
        self.csv_path = csv_path
        self.audio_root = audio_root
        self.chunk_size = chunk_size
        # Default to non-overlapping chunks when no hop size is given.
        self.hop_size = hop_size if hop_size is not None else chunk_size
        self.mono = mono
        self.normalize = normalize
        self.target_sample_rate = target_sample_rate

        # Load CSV and compute chunks
        self.file_entries = self._load_csv()
        self.chunks = self._compute_chunks()

        logger.info(f"Loaded dataset from {csv_path}: {len(self.file_entries)} files, {len(self.chunks)} chunks")

    def _load_csv(self) -> dict[str, dict]:
        """Load audio metadata from CSV, keyed by audio_id."""
        entries = {}
        with open(self.csv_path, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                entries[row["audio_id"]] = {
                    "path": row["path"],
                    "length": int(row["length"]),
                    "sample_rate": int(row["sample_rate"]),
                }
        return entries

    def _compute_chunks(self) -> list[ChunkInfo]:
        """Compute all chunks from the file entries.

        Chunk offsets/lengths are expressed in frames at the target sample
        rate (or the file's own rate when no target is set).
        """
        chunks = []
        for audio_id, entry in self.file_entries.items():
            length = entry["length"]
            sample_rate = entry["sample_rate"]

            # Adjust length if resampling to target sample rate
            if self.target_sample_rate is not None and sample_rate != self.target_sample_rate:
                length = int(length * self.target_sample_rate / sample_rate)
                sample_rate = self.target_sample_rate

            if self.chunk_size is None or length <= self.chunk_size:
                # No chunking, or file is shorter than chunk size: use entire file
                chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=0, num_frames=length))
            else:
                # Chunking: compute all chunks with last chunk aligned to end
                frame_offset = 0
                while frame_offset + self.chunk_size <= length:
                    chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=frame_offset, num_frames=self.chunk_size))
                    frame_offset += self.hop_size

                # Add the last chunk aligned to the end
                # (only when it starts later than the previous chunk, i.e. adds new coverage)
                last_start = length - self.chunk_size
                if last_start > frame_offset - self.hop_size:
                    chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=last_start, num_frames=self.chunk_size))

        return chunks

    def __len__(self) -> int:
        # One dataset item per chunk, not per file.
        return len(self.chunks)

    def __getitem__(self, idx: int) -> AudioItem:
        """Load and return a single audio chunk."""
        chunk = self.chunks[idx]
        entry = self.file_entries[chunk.audio_id]
        orig_sample_rate = entry["sample_rate"]
        full_path = Path(self.audio_root) / entry["path"]

        # Calculate start frame and num frames in original sample rate
        # (chunk offsets are stored at the target sample rate)
        if self.target_sample_rate is not None and orig_sample_rate != self.target_sample_rate:
            orig_frame_offset = int(chunk.frame_offset * orig_sample_rate / self.target_sample_rate)
            orig_num_frames = int(chunk.num_frames * orig_sample_rate / self.target_sample_rate)
        else:
            orig_frame_offset = chunk.frame_offset
            orig_num_frames = chunk.num_frames

        # Read only the frames this chunk needs from disk.
        waveform, sample_rate = _load_audio_internal(
            full_path, frame_offset=orig_frame_offset, num_frames=orig_num_frames
        )

        waveform, sample_rate = preprocess_audio(
            waveform=waveform,
            sample_rate=sample_rate,
            mono=self.mono,
            normalize=self.normalize,
            target_sample_rate=self.target_sample_rate,
        )

        # Pad if necessary (in case file is shorter than expected)
        if self.chunk_size is not None and waveform.shape[1] < self.chunk_size:
            waveform = pad_audio(waveform, self.chunk_size)

        return AudioItem(
            waveform=waveform,
            audio_id=chunk.audio_id,
            path=full_path,
            sample_rate=sample_rate,
            frame_offset=chunk.frame_offset,
        )
src/kanade_tokenizer/model.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Literal
4
+
5
+ import jsonargparse
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .module.fsq import FiniteScalarQuantizer
11
+ from .module.global_encoder import GlobalEncoder
12
+ from .module.postnet import PostNet
13
+ from .module.ssl_extractor import SSLFeatureExtractor
14
+ from .module.transformer import Transformer
15
+ from .util import freeze_modules, get_logger
16
+
17
+ logger = get_logger()
18
+
19
+
20
@dataclass
class KanadeModelConfig:
    """Static hyperparameters for KanadeModel (features, sampling, mel, vocoder)."""

    # SSL Feature settings
    local_ssl_layers: tuple[int, ...] = (6, 9)  # Indices of SSL layers for local branch
    global_ssl_layers: tuple[int, ...] = (1, 2)  # Indices of SSL layers for global branch
    normalize_ssl_features: bool = True  # Whether to normalize local SSL features before encoding

    # Down/up-sampling settings
    downsample_factor: int = 2  # Temporal downsampling factor for local features
    mel_upsample_factor: int = 4  # Conv1DTranspose upsampling factor for mel features before interpolation
    use_conv_downsample: bool = True  # Whether to use Conv1D for downsampling instead average pooling
    local_interpolation_mode: str = "linear"  # Interpolation mode for local upsampling ("linear", "nearest")
    mel_interpolation_mode: str = "linear"  # Interpolation mode for mel upsampling ("linear", "nearest")

    # Mel spectrogram settings
    sample_rate: int = 24000
    n_fft: int = 1024
    hop_length: int = 256
    n_mels: int = 100
    padding: str = "center"  # STFT padding mode; affects mel frame count (see _calculate_target_mel_length)
    mel_fmin: int = 0  # Minimum frequency for mel spectrograms
    mel_fmax: int | None = None  # Maximum frequency for mel spectrograms
    bigvgan_style_mel: bool = False  # Whether to use BigVGAN-style mel spectrograms

    # Vocoder settings
    vocoder_name: Literal["vocos", "hift"] = "vocos"  # Vocoder to use for waveform synthesis
46
+
47
+
48
@dataclass
class KanadeFeatures:
    """Extracted features for one utterance (fields optional, no batch dim)."""

    content_embedding: torch.Tensor | None = None  # (seq_len, dim)
    content_token_indices: torch.Tensor | None = None  # (seq_len,)
    global_embedding: torch.Tensor | None = None  # (dim,)
53
+
54
+
55
+ class KanadeModel(nn.Module):
56
+ """Model architecture and forward pass logic for Kanade tokenizer."""
57
+
58
    def __init__(
        self,
        config: KanadeModelConfig,
        ssl_feature_extractor: SSLFeatureExtractor,
        local_encoder: Transformer,
        local_quantizer: FiniteScalarQuantizer,
        global_encoder: GlobalEncoder,
        mel_prenet: Transformer,
        mel_decoder: Transformer,
        mel_postnet: PostNet,
        feature_decoder: Transformer | None = None,
    ):
        """Assemble the tokenizer from externally constructed sub-modules.

        Args:
            config: Static hyperparameters for the model.
            ssl_feature_extractor: SSL backbone (frozen here) providing layer-wise features.
            local_encoder: Transformer encoding local (content) SSL features.
            local_quantizer: FSQ quantizer producing discrete content tokens.
            global_encoder: Encoder pooling SSL features into one global embedding.
            mel_prenet: Transformer applied to content embeddings before upsampling.
            mel_decoder: Transformer decoding upsampled features into mel frames.
            mel_postnet: PostNet refining the decoded mel spectrogram.
            feature_decoder: Optional decoder enabling SSL feature reconstruction training.
        """
        super().__init__()
        self.config = config
        self._init_ssl_extractor(config, ssl_feature_extractor)
        self._init_local_branch(config, local_encoder, local_quantizer, feature_decoder)
        self._init_global_branch(global_encoder)
        self._init_mel_decoder(config, mel_prenet, mel_decoder, mel_postnet)

    def _init_ssl_extractor(self, config: KanadeModelConfig, ssl_feature_extractor: SSLFeatureExtractor):
        """Initialize and configure SSL feature extractor."""
        self.ssl_feature_extractor = ssl_feature_extractor
        # The SSL backbone is used purely as a frozen feature extractor.
        freeze_modules([self.ssl_feature_extractor])
        logger.debug(
            f"SSL feature extractor initialized and frozen, feature dim: {self.ssl_feature_extractor.feature_dim}"
        )

        # Configure local SSL layers (indices are 1-based; see _process_ssl_features)
        self.local_ssl_layers = list(config.local_ssl_layers)
        if len(self.local_ssl_layers) > 1:
            logger.debug(
                f"Using average of {len(self.local_ssl_layers)} SSL layers for local branch: {self.local_ssl_layers}"
            )
        else:
            logger.debug(f"Using single SSL layer {self.local_ssl_layers[0]} for local branch")

        if config.normalize_ssl_features:
            logger.debug("Normalizing local SSL features before encoding")

        # Configure global SSL layers
        self.global_ssl_layers = list(config.global_ssl_layers)
        if len(self.global_ssl_layers) > 1:
            logger.debug(
                f"Using average of {len(self.global_ssl_layers)} SSL layers for global branch: {self.global_ssl_layers}"
            )
        else:
            logger.debug(f"Using single SSL layer {self.global_ssl_layers[0]} for global branch")

    def _init_local_branch(
        self,
        config: KanadeModelConfig,
        local_encoder: Transformer,
        local_quantizer: FiniteScalarQuantizer,
        feature_decoder: Transformer | None,
    ):
        """Initialize local branch components (encoder, downsampling, quantizer, decoder)."""
        self.local_encoder = local_encoder
        self.local_quantizer = local_quantizer
        self.feature_decoder = feature_decoder

        # Configure downsampling
        self.downsample_factor = config.downsample_factor
        if self.downsample_factor > 1:
            logger.debug(f"Using temporal downsampling with factor {self.downsample_factor}")
            if config.use_conv_downsample:
                # Create Conv1d layers for downsampling and upsampling local embeddings
                feature_dim = local_encoder.output_dim
                self.conv_downsample = nn.Conv1d(
                    feature_dim, feature_dim, kernel_size=config.downsample_factor, stride=config.downsample_factor
                )
                self.conv_upsample = nn.ConvTranspose1d(
                    feature_dim, feature_dim, kernel_size=config.downsample_factor, stride=config.downsample_factor
                )  # won't be used unless training feature reconstruction
                logger.debug(f"Using Conv1d downsampling/upsampling with kernel size {config.downsample_factor}")
            else:
                self.conv_downsample = None
                self.conv_upsample = None
                logger.debug("Using average pooling and linear interpolation for downsampling/upsampling")
        else:
            # No downsampling: the conv layers are never needed.
            self.conv_downsample = None
            self.conv_upsample = None

    def _init_global_branch(self, global_encoder: GlobalEncoder):
        """Initialize global branch components."""
        self.global_encoder = global_encoder

    def _init_mel_decoder(
        self, config: KanadeModelConfig, mel_prenet: Transformer, mel_decoder: Transformer, mel_postnet: PostNet
    ):
        """Initialize mel decoder components (prenet, upsampling, decoder, postnet)."""
        self.mel_prenet = mel_prenet
        self.mel_decoder = mel_decoder
        self.mel_postnet = mel_postnet

        # Configure mel upsampling
        self.mel_conv_upsample = None
        if config.mel_upsample_factor > 1:
            # Create Conv1DTranspose layer for mel upsampling
            input_dim = mel_prenet.output_dim
            self.mel_conv_upsample = nn.ConvTranspose1d(
                input_dim, input_dim, kernel_size=config.mel_upsample_factor, stride=config.mel_upsample_factor
            )
            logger.debug(f"Using Conv1DTranspose for mel upsampling with factor {config.mel_upsample_factor}")
161
+
162
    def _calculate_waveform_padding(self, audio_length: int, ensure_recon_length: bool = False) -> int:
        """Calculate required padding for input waveform to ensure consistent SSL feature lengths.

        Args:
            audio_length: Number of input samples at the model sample rate.
            ensure_recon_length: If True, round the SSL output length up to a
                multiple of the downsample factor so tokens divide evenly.
        Returns:
            Number of samples to pad on EACH side of the waveform.
        """
        extractor = self.ssl_feature_extractor
        sample_rate = self.config.sample_rate
        # SSL may resample the input to its own sample rate, so calculate the number of samples after resampling
        num_samples_after_resampling = audio_length / sample_rate * extractor.ssl_sample_rate
        # We expect the SSL feature extractor to be consistent with its hop size
        expected_ssl_output_length = math.ceil(num_samples_after_resampling / extractor.hop_size)
        # If ensure_recon_length is True, we want to make sure the output length is exactly divisible by downsample factor
        if ensure_recon_length and (remainder := expected_ssl_output_length % self.downsample_factor) != 0:
            expected_ssl_output_length += self.downsample_factor - remainder
        # But it may require more input samples to produce that output length, so calculate the required input length
        num_samples_required_after_resampling = extractor.get_minimum_input_length(expected_ssl_output_length)
        # That number of samples is at the SSL sample rate, so convert back to our original sample rate
        num_samples_required = num_samples_required_after_resampling / extractor.ssl_sample_rate * sample_rate
        # Calculate padding needed on each side
        padding = math.ceil((num_samples_required - audio_length) / 2)
        return padding
180
+
181
    def _calculate_original_audio_length(self, token_length: int) -> int:
        """Calculate the original audio length based on token length.

        Args:
            token_length: Number of content tokens (post-downsampling frames).
        Returns:
            Minimum number of input samples (model sample rate) that yields
            `token_length` tokens from the SSL extractor.
        """
        extractor = self.ssl_feature_extractor
        sample_rate = self.config.sample_rate
        # Calculate the feature length before downsampling
        feature_length = token_length * self.downsample_factor
        num_samples_required_after_resampling = extractor.get_minimum_input_length(feature_length)
        # Convert from the SSL extractor's sample rate back to the model sample rate.
        num_samples_required = num_samples_required_after_resampling / extractor.ssl_sample_rate * sample_rate
        return math.ceil(num_samples_required)
190
+
191
+ def _calculate_target_mel_length(self, audio_length: int) -> int:
192
+ """Calculate the target mel spectrogram length based on audio length."""
193
+ if self.config.padding == "center":
194
+ return audio_length // self.config.hop_length + 1
195
+ elif self.config.padding == "same":
196
+ return audio_length // self.config.hop_length
197
+ else:
198
+ return (audio_length - self.config.n_fft) // self.config.hop_length + 1
199
+
200
+ def _process_ssl_features(self, features: list[torch.Tensor], layers: list[int]) -> torch.Tensor:
201
+ if len(layers) > 1:
202
+ # Get features from multiple layers and average them
203
+ selected_features = [features[i - 1] for i in layers]
204
+ mixed_features = torch.stack(selected_features, dim=0).mean(dim=0)
205
+ else:
206
+ # Just take the single specified layer
207
+ mixed_features = features[layers[0] - 1]
208
+ return mixed_features
209
+
210
+ def _normalize_ssl_features(self, features: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
211
+ if not self.config.normalize_ssl_features:
212
+ return features
213
+
214
+ # Compute mean and std across time steps for each sample and feature dimension
215
+ mean = torch.mean(features, dim=1, keepdim=True) # (B, 1, C)
216
+ std = torch.std(features, dim=1, keepdim=True) # (B, 1, C)
217
+ return (features - mean) / (std + eps)
218
+
219
+ def forward_ssl_features(
220
+ self, waveform: torch.Tensor, padding: int | None = None
221
+ ) -> tuple[torch.Tensor, torch.Tensor]:
222
+ """Forward pass to extract SSL features. (B, T, C)
223
+ Args:
224
+ waveform: Input waveform tensor of shape (B, channels, samples)
225
+ padding: Optional padding to apply on both sides of the waveform. This is useful to ensure
226
+ that the SSL feature extractor produces consistent output lengths.
227
+ Returns:
228
+ local_ssl_features: Local SSL features for local branch. (B, T, C)
229
+ global_ssl_features: Global SSL features for global branch. (B, T, C)
230
+ """
231
+ # Prepare input waveform
232
+ if waveform.dim() == 3:
233
+ waveform = waveform.squeeze(1)
234
+
235
+ # 1. Extract SSL features
236
+ if padding > 0:
237
+ waveform = F.pad(waveform, (padding, padding), mode="constant")
238
+
239
+ with torch.no_grad():
240
+ ssl_features = self.ssl_feature_extractor(waveform)
241
+
242
+ local_ssl_features = self._process_ssl_features(ssl_features, self.local_ssl_layers)
243
+ local_ssl_features = self._normalize_ssl_features(local_ssl_features)
244
+
245
+ global_ssl_features = self._process_ssl_features(ssl_features, self.global_ssl_layers)
246
+
247
+ return local_ssl_features, global_ssl_features
248
+
249
    def forward_content(
        self, local_ssl_features: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor]:
        """Forward pass to extract content embeddings from the local branch.

        Args:
            local_ssl_features: Local SSL features tensor of shape (B, T, C)
        Returns:
            local_quantized: Quantized local embeddings. (B, T/factor, C)
            indices: Content token indices. (B, T/factor)
            ssl_recon: Reconstructed SSL features, or None when no feature decoder
                is configured. (B, T, C)
            perplexity: Quantizer perplexity (0.0 when no feature decoder). Scalar tensor.
        """
        local_encoded = self.local_encoder(local_ssl_features)

        # Downsample temporally if needed: (B, T, C) -> (B, T/factor, C)
        if self.downsample_factor > 1:
            if self.config.use_conv_downsample:
                local_encoded = self.conv_downsample(local_encoded.transpose(1, 2)).transpose(1, 2)
            else:
                local_encoded = F.avg_pool1d(
                    local_encoded.transpose(1, 2), kernel_size=self.downsample_factor, stride=self.downsample_factor
                ).transpose(1, 2)

        # If training feature reconstruction, decode local embeddings
        ssl_recon = None
        perplexity = torch.tensor(0.0)
        if self.feature_decoder is not None:
            local_quantized, local_quantize_info = self.local_quantizer(local_encoded)
            indices = local_quantize_info["indices"]
            perplexity = torch.mean(local_quantize_info["perplexity"])

            local_latent_for_ssl = local_quantized
            # Upsample back to the SSL frame rate if we downsampled above
            if self.downsample_factor > 1:
                if self.config.use_conv_downsample:
                    # Use conv transpose for upsampling: (B, T/factor, C) -> (B, C, T/factor) -> conv -> (B, C, T) -> (B, T, C)
                    local_latent_for_ssl = self.conv_upsample(local_latent_for_ssl.transpose(1, 2)).transpose(1, 2)
                else:
                    # (B, T/factor, C) -> (B, T, C)
                    local_latent_for_ssl = F.interpolate(
                        local_latent_for_ssl.transpose(1, 2),
                        size=local_ssl_features.shape[1],
                        mode=self.config.local_interpolation_mode,
                    ).transpose(1, 2)

            ssl_recon = self.feature_decoder(local_latent_for_ssl)
        else:
            # If not training feature reconstruction, just get quantized local embeddings
            local_quantized, indices = self.local_quantizer.encode(local_encoded)

        return local_quantized, indices, ssl_recon, perplexity
300
+
301
+ def forward_global(self, global_ssl_features: torch.Tensor) -> torch.Tensor:
302
+ """Forward pass to extract global embeddings from the global branch.
303
+ Args:
304
+ global_ssl_features: Global SSL features tensor of shape (B, T, C)
305
+ Returns:
306
+ global_encoded: Global embeddings. (B, C)
307
+ """
308
+ global_encoded = self.global_encoder(global_ssl_features)
309
+ return global_encoded
310
+
311
    def forward_mel(
        self, content_embeddings: torch.Tensor, global_embeddings: torch.Tensor, mel_length: int
    ) -> torch.Tensor:
        """Forward pass to generate mel spectrogram from content and global embeddings.

        Args:
            content_embeddings: Content embeddings tensor of shape (B, T, C)
            global_embeddings: Global embeddings tensor of shape (B, C)
            mel_length: Target mel spectrogram length (T_mel)
        Returns:
            mel_recon: Reconstructed mel spectrogram tensor of shape (B, n_mels, T_mel)
        """
        local_latent = self.mel_prenet(content_embeddings)

        # Upsample local latent to match mel spectrogram length:
        # a transposed conv first (if configured), then interpolation for the exact size.
        if self.mel_conv_upsample is not None:
            # (B, T/factor, C) -> (B, C, T/factor) -> conv -> (B, C, T*upsample_factor) -> (B, T*upsample_factor, C)
            local_latent = self.mel_conv_upsample(local_latent.transpose(1, 2)).transpose(1, 2)
        local_latent = F.interpolate(
            local_latent.transpose(1, 2), size=mel_length, mode=self.config.mel_interpolation_mode
        ).transpose(1, 2)  # (B, T_current, C) -> (B, T_mel, C)

        # Generate mel spectrogram, conditioned on global embeddings (broadcast over time)
        mel_recon = self.mel_decoder(local_latent, condition=global_embeddings.unsqueeze(1))
        mel_recon = mel_recon.transpose(1, 2)  # (B, n_mels, T)

        mel_recon = self.mel_postnet(mel_recon)
        return mel_recon
339
+
340
+ # ======== Inference methods ========
341
+
342
    def weights_to_save(self, *, include_modules: list[str]) -> dict[str, torch.Tensor]:
        """Get model weights for saving. Excludes certain modules not needed for inference.

        Args:
            include_modules: Names of normally-excluded modules to keep anyway (any of
                "ssl_feature_extractor", "feature_decoder", "conv_upsample").
        Returns:
            Mapping of parameter name to tensor for the retained modules.
        """
        # Training-only modules are dropped unless explicitly requested.
        excluded_modules = [
            m for m in ["ssl_feature_extractor", "feature_decoder", "conv_upsample"] if m not in include_modules
        ]
        # NOTE(review): only named_parameters() is walked, so registered buffers are never
        # part of the saved dict — confirm no inference-time module depends on buffers.
        state_dict = {
            name: param
            for name, param in self.named_parameters()
            if not any(name.startswith(excl) for excl in excluded_modules)
        }
        return state_dict
353
+
354
    @classmethod
    def from_hparams(cls, config_path: str) -> "KanadeModel":
        """Instantiate KanadeModel from config file.

        Args:
            config_path (str): Path to model configuration file (.yaml).
        Returns:
            KanadeModel: Instantiated KanadeModel (no weights are loaded here).
        """
        # jsonargparse validates the YAML against KanadeModel's signature and
        # recursively instantiates the configured submodule classes.
        parser = jsonargparse.ArgumentParser(exit_on_error=False)
        parser.add_argument("--model", type=KanadeModel)
        cfg = parser.parse_path(config_path)
        cfg = parser.instantiate_classes(cfg)
        return cfg.model
367
+
368
+ @classmethod
369
+ def from_pretrained(
370
+ cls,
371
+ repo_id: str | None = None,
372
+ revision: str | None = None,
373
+ config_path: str | None = None,
374
+ weights_path: str | None = None,
375
+ ) -> "KanadeModel":
376
+ """Load KanadeModel either from HuggingFace Hub or local config and weights files.
377
+ Args:
378
+ repo_id (str, optional): HuggingFace Hub repository ID. If provided, loads config and weights from the hub.
379
+ revision (str, optional): Revision (branch, tag, commit) for the HuggingFace Hub repo.
380
+ config_path (str, optional): Path to model configuration file (.yaml). Required if repo_id is not provided.
381
+ weights_path (str, optional): Path to model weights file (.safetensors). Required if repo_id is not provided.
382
+ Returns:
383
+ KanadeModel: Loaded KanadeModel instance.
384
+ """
385
+ if repo_id is not None:
386
+ # Load from HuggingFace Hub
387
+ from huggingface_hub import hf_hub_download
388
+
389
+ config_path = hf_hub_download(repo_id, "config.yaml", revision=revision)
390
+ weights_path = hf_hub_download(repo_id, "model.safetensors", revision=revision)
391
+ else:
392
+ # Check local paths
393
+ if config_path is None or weights_path is None:
394
+ raise ValueError(
395
+ "Please provide either HuggingFace Hub repo_id or both config_path and weights_path for model loading."
396
+ )
397
+
398
+ # Load model from config
399
+ model = cls.from_hparams(config_path)
400
+
401
+ # Load weights
402
+ from safetensors.torch import load_file
403
+
404
+ state_dict = load_file(weights_path, device="cpu")
405
+ model.load_state_dict(state_dict, strict=False)
406
+ logger.info(f"Loaded weights from safetensors file: {weights_path}")
407
+
408
+ return model
409
+
410
    @torch.inference_mode()
    def encode(self, waveform: torch.Tensor, return_content: bool = True, return_global: bool = True) -> KanadeFeatures:
        """Extract content and/or global features from audio using Kanade model.

        Args:
            waveform (torch.Tensor): Input audio waveform tensor (samples,). The sample rate should match model config.
            return_content (bool): Whether to extract content features.
            return_global (bool): Whether to extract global features.
        Returns:
            KanadeFeatures: Extracted features; fields not requested keep their defaults.
        """
        audio_length = waveform.size(0)
        # Pad so the SSL feature grid aligns with the model's frame/downsample layout.
        padding = self._calculate_waveform_padding(audio_length)
        local_ssl_features, global_ssl_features = self.forward_ssl_features(waveform.unsqueeze(0), padding=padding)

        result = KanadeFeatures()
        # NOTE(review): autocast is hard-coded to device_type="cuda"; on a CPU-only run this
        # context does not affect CPU ops — confirm this is intended for CPU inference.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
            if return_content:
                content_embedding, token_indices, _, _ = self.forward_content(local_ssl_features)
                result.content_embedding = content_embedding.squeeze(0)  # (seq_len, dim)
                result.content_token_indices = token_indices.squeeze(0)  # (seq_len,)

            if return_global:
                global_embedding = self.forward_global(global_ssl_features)
                result.global_embedding = global_embedding.squeeze(0)  # (dim,)

        return result
436
+
437
+ def decode_token_indices(self, indices: torch.Tensor) -> torch.Tensor:
438
+ """Get content embeddings from content token indices. (..., seq_len) -> (..., seq_len, dim)"""
439
+ content_embedding = self.local_quantizer.decode(indices)
440
+ return content_embedding
441
+
442
    @torch.inference_mode()
    def decode(
        self,
        global_embedding: torch.Tensor,
        content_token_indices: torch.Tensor | None = None,
        content_embedding: torch.Tensor | None = None,
        target_audio_length: int | None = None,
    ) -> torch.Tensor:
        """Synthesize a mel spectrogram from content and global features.

        (Vocoding to a waveform is done separately by the caller.)

        Args:
            global_embedding (torch.Tensor): Global embedding tensor (dim,).
            content_token_indices (torch.Tensor, optional): Optional content token indices tensor (seq_len).
            content_embedding (torch.Tensor, optional): Optional content embedding tensor (seq_len, dim).
                If both content_token_indices and content_embedding are provided, content_embedding takes precedence.
            target_audio_length (int, optional): Target length of the output audio in samples.
                If None, uses the original audio length estimated from the sequence length of content tokens.
        Raises:
            ValueError: If neither content_token_indices nor content_embedding is given.
        Returns:
            torch.Tensor: Generated mel spectrogram tensor (n_mels, T).
        """
        # Obtain content embedding if not provided
        if content_embedding is None:
            if content_token_indices is None:
                raise ValueError("Either content_token_indices or content_embedding must be provided.")
            content_embedding = self.decode_token_indices(content_token_indices)

        if target_audio_length is None:
            # Estimate original audio length from content token sequence length
            seq_len = content_embedding.size(0)
            target_audio_length = self._calculate_original_audio_length(seq_len)

        # NOTE(review): autocast hard-codes device_type="cuda" — no effect on CPU-only runs.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
            mel_length = self._calculate_target_mel_length(target_audio_length)
            content_embedding = content_embedding.unsqueeze(0)  # (1, seq_len, dim)
            global_embedding = global_embedding.unsqueeze(0)  # (1, dim)
            mel_spectrogram = self.forward_mel(content_embedding, global_embedding, mel_length=mel_length)

        return mel_spectrogram.squeeze(0)  # (n_mels, T)
479
+
480
+ @torch.inference_mode()
481
+ def voice_conversion(self, source_waveform: torch.Tensor, reference_waveform: torch.Tensor) -> torch.Tensor:
482
+ """Convert voice using Kanade model and Vocos, keeping content from source and global characteristics from reference.
483
+ Only supports single audio input. Just a convenient wrapper around encode and decode methods.
484
+ Args:
485
+ source_waveform (torch.Tensor): Source audio waveform tensor (samples,).
486
+ reference_waveform (torch.Tensor): Reference audio waveform tensor (samples_ref,).
487
+ Returns:
488
+ torch.Tensor: Converted mel spectrogram tensor (n_mels, T).
489
+ """
490
+ # Extract source content features and reference global features
491
+ source_features = self.encode(source_waveform, return_content=True, return_global=False)
492
+ reference_features = self.encode(reference_waveform, return_content=False, return_global=True)
493
+
494
+ # Synthesize mel spectrogram using source content and reference global features
495
+ mel_spectrogram = self.decode(
496
+ content_embedding=source_features.content_embedding,
497
+ global_embedding=reference_features.global_embedding,
498
+ target_audio_length=source_waveform.size(0),
499
+ )
500
+ return mel_spectrogram
src/kanade_tokenizer/module/adaln_zero.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from: https://github.com/facebookresearch/DiT
2
+
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+
8
+ class AdaLNZero(nn.Module):
9
+ """
10
+ Adaptive Layer Normalization Zero (AdaLNZero) module.
11
+
12
+ Combines LayerNorm with adaptive conditioning to produce shift, scale, and gate values.
13
+ The gate is used to scale features before residual connection.
14
+
15
+ Args:
16
+ dim: Feature dimension
17
+ condition_dim: Conditioning dimension
18
+ eps: LayerNorm epsilon
19
+ return_gate: If True, returns gate value for scaling.
20
+ """
21
+
22
+ def __init__(
23
+ self,
24
+ dim: int,
25
+ condition_dim: int,
26
+ eps: float = 1e-5,
27
+ return_gate: bool = True,
28
+ ):
29
+ super().__init__()
30
+ self.dim = dim
31
+ self.condition_dim = condition_dim
32
+ self.return_gate = return_gate
33
+
34
+ # LayerNorm without learnable parameters
35
+ self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
36
+
37
+ # Conditioning network: condition -> shift, scale, gate
38
+ output_dim = 3 * dim if return_gate else 2 * dim
39
+ self.condition_proj = nn.Sequential(
40
+ nn.SiLU(),
41
+ nn.Linear(condition_dim, output_dim),
42
+ )
43
+
44
+ # Initialize to zero for stable training
45
+ nn.init.zeros_(self.condition_proj[1].weight)
46
+ nn.init.zeros_(self.condition_proj[1].bias)
47
+
48
+ def forward(self, x: torch.Tensor, condition: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor] | None:
49
+ """
50
+ Args:
51
+ x: Input tensor of shape (B, L, dim)
52
+ condition: Conditioning tensor of shape (B, L, condition_dim) or (B, 1, condition_dim)
53
+
54
+ Returns:
55
+ modulated_x: Normalized and modulated features
56
+ gate: Gate values for scaling (None if return_gate=False)
57
+ """
58
+ x_norm = self.norm(x)
59
+ condition_params = self.condition_proj(condition)
60
+
61
+ if self.return_gate:
62
+ shift, scale, gate = condition_params.chunk(3, dim=-1)
63
+ else:
64
+ shift, scale = condition_params.chunk(2, dim=-1)
65
+ gate = None
66
+
67
+ modulated_x = x_norm * (1 + scale) + shift
68
+ return modulated_x, gate
src/kanade_tokenizer/module/audio_feature.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from:
2
+ # Vocos: https://github.com/gemelo-ai/vocos/blob/main/vocos/feature_extractors.py
3
+ # BigVGAN: https://github.com/NVIDIA/BigVGAN/blob/main/meldataset.py (Also used by HiFT)
4
+
5
+ import torch
6
+ import torchaudio
7
+ from librosa.filters import mel as librosa_mel_fn
8
+ from torch import nn
9
+
10
+
11
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """Logarithm with the input clamped from below, avoiding log(0) = -inf."""
    clamped = torch.clip(x, min=clip_val)
    return torch.log(clamped)
13
+
14
+
15
class MelSpectrogramFeature(nn.Module):
    """Log-mel spectrogram extractor with two flavors.

    - Vocos style: center/"same" padding, HTK mel scale, no filterbank
      normalization (torchaudio implementation).
    - BigVGAN/HiFT style: reflect "same" padding, Slaney mel scale with Slaney
      normalization (librosa filterbank + manual STFT).

    Args:
        sample_rate: Audio sample rate in Hz.
        n_fft: FFT size (also used as the window size).
        hop_length: Hop between frames in samples.
        n_mels: Number of mel bins.
        padding: "center" or "same" (Vocos style only).
        fmin: Lowest mel filterbank frequency in Hz.
        fmax: Highest mel filterbank frequency in Hz (None = Nyquist).
        bigvgan_style_mel: Select the BigVGAN-style computation.
    """

    def __init__(
        self,
        sample_rate: int = 24000,
        n_fft: int = 1024,
        hop_length: int = 256,
        n_mels: int = 100,
        padding: str = "center",
        fmin: int = 0,
        fmax: int | None = None,
        bigvgan_style_mel: bool = False,
    ):
        super().__init__()

        self.bigvgan_style_mel = bigvgan_style_mel
        if bigvgan_style_mel:
            # BigVGAN style: same padding, Slaney mel scale, with normalization
            self.n_fft = n_fft
            self.win_size = n_fft
            self.hop_size = hop_length
            # (n_mels, n_fft // 2 + 1)
            mel_basis = librosa_mel_fn(
                sr=sample_rate, n_fft=n_fft, n_mels=n_mels, norm="slaney", htk=False, fmin=fmin, fmax=fmax
            )
            mel_basis = torch.from_numpy(mel_basis).float()
            hann_window = torch.hann_window(n_fft)
            # Buffers (not parameters) so they move with .to(device) but aren't trained.
            self.register_buffer("mel_basis", mel_basis)
            self.register_buffer("hann_window", hann_window)
        else:
            # Vocos style: center padding, HTK mel scale, without normalization
            if padding not in ["center", "same"]:
                raise ValueError("Padding must be 'center' or 'same'.")

            self.padding = padding
            # FIX: torchaudio's MelSpectrogram keywords are `f_min`/`f_max`, not
            # `fmin`/`fmax` — the old spelling raised TypeError on construction.
            self.mel_spec = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate,
                n_fft=n_fft,
                hop_length=hop_length,
                n_mels=n_mels,
                center=padding == "center",
                power=1,
                f_min=fmin,
                f_max=fmax,
            )

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Returns:
            mel_specgram (Tensor): Log-mel spectrogram of the input audio. (B, C, L)
        """
        if self.bigvgan_style_mel:
            return self.bigvgan_mel(audio)
        else:
            return self.vocos_mel(audio)

    def vocos_mel(self, audio: torch.Tensor) -> torch.Tensor:
        """Vocos-style log-mel: optional reflect 'same' padding, then torchaudio mel."""
        if self.padding == "same":
            pad = self.mel_spec.win_length - self.mel_spec.hop_length
            audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")

        specgram = self.mel_spec.spectrogram(audio)
        mel_specgram = self.mel_spec.mel_scale(specgram)

        # Convert to log scale
        mel_specgram = safe_log(mel_specgram)
        return mel_specgram

    def bigvgan_mel(self, audio: torch.Tensor) -> torch.Tensor:
        """BigVGAN-style log-mel with a manual (non-centered) STFT."""
        # Pad so that the output length T = L // hop_length
        padding = (self.n_fft - self.hop_size) // 2
        audio = torch.nn.functional.pad(audio, (padding, padding), mode="reflect")
        audio = audio.reshape(-1, audio.shape[-1])

        spec = torch.stft(
            audio,
            n_fft=self.n_fft,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.hann_window,
            center=False,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = spec.reshape(audio.shape[:-1] + spec.shape[-2:])

        # Magnitude with a small floor for numerical stability.
        spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
        mel_spec = torch.matmul(self.mel_basis, spec)
        mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
        return mel_spec
src/kanade_tokenizer/module/convnext.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from: https://github.com/gemelo-ai/vocos/blob/main/vocos/models.py
2
+
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+
8
class ConvNeXtBlock(nn.Module):
    """1D ConvNeXt block: depthwise conv -> LayerNorm -> pointwise MLP, with residual.

    Adapted from https://github.com/facebookresearch/ConvNeXt for 1D audio signals.

    Args:
        dim (int): Number of channels.
        intermediate_dim (int): Hidden width of the pointwise MLP.
        layer_scale_init_value (float): Initial per-channel layer-scale value;
            a non-positive value disables layer scaling.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: float,
    ):
        super().__init__()
        # Depthwise conv mixes information along time only (one filter per channel).
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convolutions implemented as linear layers on (B, T, C).
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.gamma = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to a (B, C, T) input; output has the same shape."""
        skip = x
        h = self.dwconv(x)
        h = h.transpose(1, 2)  # (B, C, T) -> (B, T, C) for LayerNorm/Linear
        h = self.pwconv2(self.act(self.pwconv1(self.norm(h))))
        if self.gamma is not None:
            h = self.gamma * h
        h = h.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        return skip + h
50
+
51
+
52
class ConvNextBackbone(nn.Module):
    """
    Backbone module built with ConvNeXt blocks.

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
        num_layers (int): Number of ConvNeXtBlock layers.
        output_channels (int, optional): If set, a final linear projection maps dim -> output_channels.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
        skip_embed (bool): Skip the input embedding conv.
            NOTE(review): with skip_embed=True the input must already have `dim` channels
            (nn.Identity does no projection) — confirm callers guarantee this.
    """

    def __init__(
        self,
        input_channels: int,
        dim: int,
        intermediate_dim: int,
        num_layers: int,
        output_channels: int | None = None,
        layer_scale_init_value: float | None = None,
        skip_embed: bool = False,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.dim = dim
        self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) if not skip_embed else nn.Identity()
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # Default layer scale shrinks with depth, as in the ConvNeXt paper.
        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
        self.convnext = nn.ModuleList(
            [
                ConvNeXtBlock(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                )
                for _ in range(num_layers)
            ]
        )
        self.proj_out = nn.Linear(dim, output_channels) if output_channels else nn.Identity()
        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        self.apply(self._init_weights)

    @property
    def input_dim(self) -> int:
        # Channels expected on the input's last axis.
        return self.input_channels

    @property
    def output_dim(self) -> int:
        # Effective output width: projection target if set, else the hidden dim.
        return self.output_channels if self.output_channels else self.dim

    def _init_weights(self, m):
        # Truncated-normal weights, zero bias for all conv/linear layers.
        # NOTE(review): assumes every Conv1d/Linear here has a bias (true for the
        # layers constructed above; would fail for bias=False layers).
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, L, C), where B is the batch size,
                C denotes output features, and L is the sequence length.
                Extra keyword arguments are accepted and ignored.
        Returns:
            Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
            and H denotes the model dimension (or output_channels when projected).
        """
        x = x.transpose(1, 2)  # (B, L, C) -> (B, C, L)
        x = self.embed(x)
        # LayerNorm operates on the channel-last layout.
        x = self.norm(x.transpose(1, 2))
        x = x.transpose(1, 2)
        for conv_block in self.convnext:
            x = conv_block(x)
        x = self.final_layer_norm(x.transpose(1, 2))
        x = self.proj_out(x)
        return x
src/kanade_tokenizer/module/discriminator.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from:
2
+ # https://github.com/gemelo-ai/vocos/blob/main/vocos/discriminators.py
3
+ # https://github.com/gemelo-ai/vocos/blob/main/vocos/loss.py
4
+
5
+ import torch
6
+ from einops import rearrange
7
+ from torch import nn
8
+ from torch.nn.utils.parametrizations import weight_norm
9
+
10
+
11
+ def get_2d_padding(kernel_size: tuple[int, int], dilation: tuple[int, int] = (1, 1)):
12
+ return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
13
+
14
+
15
class SpectrogramDiscriminator(nn.Module):
    """Multi-band 2D-conv discriminator over spectrograms.

    The frequency axis is split into relative `bands`; each band has its own conv
    stack, and every layer's activations are collected for feature-matching losses.

    Args:
        frequency_bins: Number of frequency bins F in the input spectrogram.
        channels: Conv channel width of each band stack.
        kernel_size: 2D conv kernel (time, frequency).
        dilation: Time-axis dilation of the middle conv layers of each stack.
        bands: Relative (start, end) frequency-band edges in [0, 1].
        use_downsample: Average-pool the final score map, halving its resolution.
    """

    def __init__(
        self,
        frequency_bins: int,
        channels: int = 32,
        kernel_size: tuple[int, int] = (3, 3),
        dilation: list[int] = [1, 2, 4],
        bands: tuple[tuple[float, float], ...] = ((0.0, 0.2), (0.2, 0.4), (0.4, 0.6), (0.6, 0.8), (0.8, 1.0)),
        use_downsample: bool = True,
    ):
        super().__init__()
        # Convert relative band edges to absolute frequency-bin index ranges.
        self.bands = [(int(b[0] * frequency_bins), int(b[1] * frequency_bins)) for b in bands]

        self.stacks = nn.ModuleList()
        for _ in self.bands:
            stack = nn.ModuleList(
                [weight_norm(nn.Conv2d(1, channels, kernel_size, padding=get_2d_padding(kernel_size)))]
            )

            for d in dilation:
                # dilation on time axis
                pad = get_2d_padding(kernel_size, (d, 1))
                stack.append(weight_norm(nn.Conv2d(channels, channels, kernel_size, dilation=(d, 1), padding=pad)))

            stack.append(weight_norm(nn.Conv2d(channels, channels, kernel_size, padding=get_2d_padding(kernel_size))))

            self.stacks.append(stack)

        self.conv_post = weight_norm(nn.Conv2d(channels, 1, kernel_size, padding=get_2d_padding(kernel_size)))
        if use_downsample:
            self.downsample = nn.AvgPool2d(4, stride=2, padding=1, count_include_pad=False)
        else:
            self.downsample = nn.Identity()

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Args:
            x (Tensor): Input spectrogram (B, C, F, T); a 3D (B, F, T) input gets a
                channel dim inserted.
        Returns:
            output (Tensor): Discriminator score map.
            intermediates (list[Tensor]): Activations after each conv layer of every
                band, in band order, for feature-matching losses.
        """
        if x.dim() == 3:
            x = x.unsqueeze(1)
        assert x.dim() == 4, f"Expected 4D input, got {x.dim()}D"

        # Split into bands (frequency moved to the last axis first)
        x = rearrange(x, "b c f t -> b c t f")
        x_bands = [x[..., b[0] : b[1]] for b in self.bands]

        # `x` is reused as the list of per-band outputs from here on.
        x = []
        intermediates = []
        for x_band, stack in zip(x_bands, self.stacks):
            for layer in stack:
                x_band = layer(x_band)
                x_band = torch.nn.functional.leaky_relu(x_band, 0.1)
                intermediates.append(x_band)
            x.append(x_band)

        # Concatenate the outputs from all bands along frequency
        x = torch.cat(x, dim=-1)
        x = self.conv_post(x)
        x = self.downsample(x)
        return x, intermediates
src/kanade_tokenizer/module/fsq.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Finite Scalar Quantization: https://arxiv.org/abs/2309.15505
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from ..util import get_logger
7
+
8
+ logger = get_logger()
9
+
10
+
11
def round_ste(z: torch.Tensor) -> torch.Tensor:
    """Round with straight-through gradients: forward rounds, backward is identity."""
    zhat = z.round()
    return z + (zhat - z).detach()


def get_entropy(prob: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    """Shannon entropy (in nats) along the last dim; eps guards against log(0)."""
    return -torch.sum(prob * torch.log(prob + eps), dim=-1)


class FSQ(nn.Module):
    """Finite Scalar Quantization (https://arxiv.org/abs/2309.15505).

    Each of the len(levels) channels is bounded and rounded to one of levels[i]
    values, then renormalized to [-1, 1]. The codebook is implicit: an index is
    the mixed-radix combination of the per-channel codes.
    """

    def __init__(self, levels: list[int]):
        super().__init__()
        self.levels = levels
        self.dim = len(levels)

        _levels = torch.tensor(levels, dtype=torch.long)
        self.register_buffer("_levels", _levels, persistent=False)
        # Mixed-radix place values: basis[i] = prod(levels[:i]).
        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.long)
        self.register_buffer("_basis", _basis, persistent=False)

    def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
        """Bound `z` (shape (..., d)) so rounding lands on a valid level; eps keeps values strictly inside."""
        half_l = (self._levels - 1) * (1 - eps) / 2
        # Even level counts need a half-step offset so the rounding grid is symmetric.
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        shift = (offset / half_l).tan()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: torch.Tensor) -> torch.Tensor:
        """Quantize z (straight-through round); returns zhat in [-1, 1], same shape as z."""
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2  # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
        # [-1, 1] -> [0, levels - 1]
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
        # [0, levels - 1] -> [-1, 1]
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
        """Converts per-channel codes to codebook indices: (..., C) -> (...)."""
        assert zhat.shape[-1] == len(self.levels)
        zhat = self._scale_and_shift(zhat)
        # float64 keeps the mixed-radix sum exact for large codebooks.
        return (zhat * self._basis.to(torch.float64)).to(torch.long).sum(dim=-1)

    def indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor:
        """Inverse of `codes_to_indices`: (...) -> (..., C), float codes in [-1, 1]."""
        indices = indices.unsqueeze(-1)
        codes_non_centered = (indices // self._basis) % self._levels
        return self._scale_and_shift_inverse(codes_non_centered)

    def encode(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize and index. Returns (z_q, indices).

        (Fixed annotation: previously declared `-> torch.Tensor` but always
        returned a 2-tuple.)
        """
        z_q = self.quantize(z)
        indices = self.codes_to_indices(z_q)  # (B, T)
        return z_q, indices

    def decode(self, indices: torch.Tensor) -> torch.Tensor:
        """Map indices back to quantized codes: (B, T) -> (B, T, C)."""
        z_q = self.indices_to_codes(indices)
        return z_q

    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Same as `encode`, exposed through nn.Module call syntax."""
        z_q = self.quantize(z)
        indices = self.codes_to_indices(z_q)  # (B, T)
        return z_q, indices
80
+
81
+
82
class FiniteScalarQuantizer(nn.Module):
    """FSQ wrapped with input/output projections so it fits between arbitrary dims.

    Projects input_dim -> len(levels), quantizes with FSQ, then projects back to
    output_dim. Either projection collapses to Identity when the dims already match.
    """

    def __init__(self, input_dim: int, output_dim: int, levels: list[int]) -> None:
        super().__init__()
        self.input_dim_ = input_dim
        self.output_dim_ = output_dim

        self.fsq = FSQ(levels)
        logger.debug(
            f"Finite Scalar Quantizer with levels: {levels}, input_dim: {input_dim}, output_dim: {output_dim}, codebook_size: {self.all_codebook_size}"
        )

        num_dims = len(levels)
        self.proj_in = nn.Identity() if num_dims == input_dim else nn.Linear(input_dim, num_dims)
        self.proj_out = nn.Identity() if num_dims == output_dim else nn.Linear(num_dims, output_dim)

    def build_codebook(self) -> None:
        """No-op: FSQ's codebook is implicit in its levels."""

    @property
    def output_dim(self) -> int:
        """Dimension of quantized outputs after proj_out."""
        return self.output_dim_

    @property
    def all_codebook_size(self) -> int:
        """Total number of distinct codes: the product of all FSQ levels."""
        total = 1
        for level in self.fsq.levels:
            total *= level
        return total

    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """Quantize z; returns (z_q, info) with latents, indices, and batch perplexity."""
        latent = self.proj_in(z)  # Latent projected by proj_in
        quantized_latent, indices = self.fsq(latent)  # Quantized latent before proj_out
        z_q = self.proj_out(quantized_latent)

        # Perplexity of the empirical code-usage distribution within this batch.
        flat = indices.reshape(-1)
        _, counts = torch.unique(flat, return_counts=True)
        usage = counts.float() / flat.numel()
        perplexity = torch.exp(get_entropy(usage))

        return z_q, {
            "latent": latent,
            "quantized_latent": quantized_latent,
            "indices": indices,
            "perplexity": perplexity,
        }

    def encode(self, z: torch.Tensor, skip_proj: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize and return (z_q, indices); skip_proj leaves z_q in FSQ space."""
        z_q, indices = self.fsq.encode(self.proj_in(z))
        return (z_q, indices) if skip_proj else (self.proj_out(z_q), indices)

    def decode(self, indices: torch.Tensor) -> torch.Tensor:
        """Map token indices back to output-space embeddings."""
        return self.proj_out(self.fsq.decode(indices))
src/kanade_tokenizer/module/global_encoder.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from: https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/ecapa_tdnn.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .convnext import ConvNextBackbone
7
+
8
+
9
class AttentiveStatsPool(nn.Module):
    """Attentive statistics pooling: attention-weighted mean and std over time.

    Takes (B, C, T) features and produces a LayerNorm-ed (B, output_channels)
    embedding from the concatenated weighted mean and standard deviation.
    """

    def __init__(self, input_channels: int, output_channels: int, attention_channels: int = 128):
        super().__init__()

        # Per-frame attention weights, normalized over the time axis.
        self.attn = nn.Sequential(
            nn.Conv1d(input_channels, attention_channels, kernel_size=1),
            nn.Tanh(),
            nn.Conv1d(attention_channels, input_channels, kernel_size=1),
            nn.Softmax(dim=2),
        )
        # Mean and std are concatenated, hence the doubled input width.
        self.proj = nn.Linear(input_channels * 2, output_channels)
        self.norm = nn.LayerNorm(output_channels)

    def forward(self, x):
        """Pool (B, C, T) features into a (B, output_channels) embedding."""
        weights = self.attn(x)

        # Weighted first moment and clamped second moment along time.
        mean = (weights * x).sum(dim=2)
        var = (weights * x.pow(2)).sum(dim=2) - mean.pow(2)
        std = var.clamp(min=1e-4, max=1e4).sqrt()

        pooled = torch.cat([mean, std], dim=1)
        return self.norm(self.proj(pooled))
31
+
32
+
33
class GlobalEncoder(nn.Module):
    """Utterance-level encoder: ConvNeXt backbone followed by temporal pooling.

    Produces one global embedding per utterance (e.g. for speaker/style conditioning).

    Args:
        input_channels: Input feature channels (C of a (B, T, C) input).
        output_channels: Size of the pooled global embedding.
        dim: Backbone hidden dimension.
        intermediate_dim: ConvNeXt block MLP width.
        num_layers: Number of ConvNeXt blocks.
        skip_embed: Passed through to the backbone; skips its input projection.
        attention_channels: Hidden width of the attentive pooling network.
        use_attn_pool: Use attentive statistics pooling instead of mean pooling.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        dim: int,
        intermediate_dim: int,
        num_layers: int,
        skip_embed: bool = False,
        attention_channels: int = 128,
        use_attn_pool: bool = True,
    ):
        super().__init__()

        self.backbone = ConvNextBackbone(
            input_channels=input_channels,
            dim=dim,
            intermediate_dim=intermediate_dim,
            num_layers=num_layers,
            skip_embed=skip_embed,
        )
        if use_attn_pool:
            # Attention-weighted mean+std pooling over time.
            self.pooling = AttentiveStatsPool(
                input_channels=dim, output_channels=output_channels, attention_channels=attention_channels
            )
        else:
            # Plain average pooling over time, then a linear head.
            self.pooling = nn.Sequential(
                nn.AdaptiveAvgPool1d(1),
                nn.Flatten(1),
                nn.Linear(dim, output_channels),
                nn.LayerNorm(output_channels),
            )
        self.output_channels = output_channels

    @property
    def output_dim(self) -> int:
        # Exposed for symmetry with the other encoder modules.
        return self.output_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode (B, T, C) features into a (B, output_channels) global embedding."""
        features = self.backbone(x)
        # (B, T, C) -> (B, C, T)
        features = features.transpose(1, 2)
        return self.pooling(features)  # (B, C_out)
src/kanade_tokenizer/module/hift.py ADDED
@@ -0,0 +1,685 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from: https://github.com/yl4579/HiFTNet/blob/main/models.py
2
+ # https://github.com/FunAudioLLM/CosyVoice/blob/main/cosyvoice/hifigan/generator.py
3
+
4
+ from typing import Dict, List, Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from scipy.signal import get_window
11
+ from torch.distributions.uniform import Uniform
12
+ from torch.nn import Conv1d, ConvTranspose1d
13
+ from torch.nn.utils import remove_weight_norm
14
+ from torch.nn.utils.parametrizations import weight_norm
15
+
16
+
17
def get_padding(kernel_size, dilation=1):
    """Return the padding that keeps a stride-1 dilated conv length-preserving.

    Uses integer floor division instead of the original float division plus
    ``int()`` cast, which stays exact for arbitrarily large arguments.
    """
    return (kernel_size - 1) * dilation // 2
19
+
20
+
21
def init_weights(m, mean=0.0, std=0.01):
    """Initialize conv-layer weights in-place from N(mean, std).

    Intended for use with ``nn.Module.apply``; modules whose class name does
    not contain "Conv" are left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
25
+
26
+
27
def mel_spec_transform(
    audio: torch.Tensor,
    n_fft: int,
    n_mels: int,
    sample_rate: int,
    hop_size: int,
    win_size: int,
    fmin: int = 0,
    fmax: Optional[int] = None,
):
    """Compute a log-compressed mel spectrogram from a waveform.

    Args:
        audio: waveform tensor whose last dimension is time (length L).
        n_fft: FFT size.
        n_mels: number of mel bands.
        sample_rate: audio sampling rate in Hz.
        hop_size: STFT hop length in samples.
        win_size: analysis window length in samples.
        fmin: lowest mel filter frequency in Hz.
        fmax: highest mel filter frequency in Hz (None lets librosa default
            to sample_rate / 2).

    Returns:
        Log-mel spectrogram of shape (B, n_mels, L // hop_size), where B is
        the flattened product of the input's leading dimensions.
    """
    # Imported lazily so librosa is only required when this helper is used.
    from librosa.filters import mel as librosa_mel_fn

    # (n_mels, n_fft // 2 + 1) Slaney-normalized mel filterbank.
    mel_basis = librosa_mel_fn(
        sr=sample_rate, n_fft=n_fft, n_mels=n_mels, norm="slaney", htk=False, fmin=fmin, fmax=fmax
    )
    mel_basis = torch.from_numpy(mel_basis).float()
    hann_window = torch.hann_window(win_size)

    # Reflect-pad so that the output length T = L // hop_length.
    padding = (n_fft - hop_size) // 2
    audio = torch.nn.functional.pad(audio, (padding, padding), mode="reflect")
    # Flatten any leading batch dims: torch.stft expects (B, L).
    audio = audio.reshape(-1, audio.shape[-1])

    # (B, n_fft // 2 + 1, T=1 + (L' - n_fft) // hop_length)
    # L' = L + n_fft - hop_length
    # T = L // hop_length
    spec = torch.stft(
        audio,
        n_fft=n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=False,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    # NOTE(review): ``audio`` was already flattened to 2-D above, so this
    # reshape is a no-op (the original leading dims are not restored) —
    # confirm whether restoring the pre-flatten shape was intended.
    spec = spec.reshape(audio.shape[:-1] + spec.shape[-2:])

    # Magnitude with a small epsilon for numerical stability.
    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
    # Project onto the mel filterbank and log-compress with a 1e-5 floor.
    mel_spec = torch.matmul(mel_basis, spec)
    mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))

    return mel_spec
73
+
74
+
75
class Snake(nn.Module):
    """Snake periodic activation: x + (1/alpha) * sin^2(alpha * x).

    From "Neural Networks Fail to Learn Periodic Functions and How to Fix It"
    (Ziyin, Hartwig, Ueda — https://arxiv.org/abs/2006.08195).

    Shape: (B, C, T) in, (B, C, T) out. ``alpha`` is one trainable frequency
    parameter per channel; larger alpha means higher frequency.

    Example:
        >>> act = Snake(256)
        >>> y = act(torch.randn(1, 256, 10))
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        """
        Args:
            in_features: number of channels (size of the per-channel alpha).
            alpha: initial frequency value.
            alpha_trainable: whether alpha is optimized during training.
            alpha_logscale: store alpha in log scale (initialized to zeros,
                i.e. an effective frequency of exp(0) == 1).
        """
        super().__init__()
        self.in_features = in_features

        self.alpha_logscale = alpha_logscale
        if alpha_logscale:
            # Log-scale alphas start at zero.
            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
        else:
            # Linear-scale alphas start at ``alpha``.
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        # Small epsilon guarding the 1/alpha division.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """Apply the activation elementwise: x + 1/a * sin^2(a * x)."""
        # Broadcast per-channel alpha to (1, C, 1) so it lines up with (B, C, T).
        a = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            a = torch.exp(a)
        return x + (1.0 / (a + self.no_div_by_zero)) * torch.pow(torch.sin(x * a), 2)
127
+
128
+
129
class ResBlock(torch.nn.Module):
    """Residual block from HiFiGAN/BigVGAN.

    For each dilation d: Snake -> dilated conv -> Snake -> undilated conv,
    with the result added back onto the block input.
    """

    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: List[int] = [1, 3, 5],
    ):
        super().__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()

        for dilation in dilations:
            # First conv of the pair carries the dilation; the second is plain.
            self.convs1.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation,
                        padding=get_padding(kernel_size, dilation),
                    )
                )
            )
            self.convs2.append(
                weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)))
            )
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)
        # One Snake activation in front of every conv.
        self.activations1 = nn.ModuleList(Snake(channels, alpha_logscale=False) for _ in self.convs1)
        self.activations2 = nn.ModuleList(Snake(channels, alpha_logscale=False) for _ in self.convs2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for act1, conv1, act2, conv2 in zip(self.activations1, self.convs1, self.activations2, self.convs2):
            branch = conv1(act1(x))
            branch = conv2(act2(branch))
            x = x + branch
        return x

    def remove_weight_norm(self):
        """Strip weight norm from every conv (for inference/export)."""
        for conv1, conv2 in zip(self.convs1, self.convs2):
            remove_weight_norm(conv1)
            remove_weight_norm(conv2)
176
+
177
+
178
+ class ConvRNNF0Predictor(nn.Module):
179
+ def __init__(self, num_class: int = 1, in_channels: int = 80, cond_channels: int = 512):
180
+ super().__init__()
181
+
182
+ self.num_class = num_class
183
+ self.condnet = nn.Sequential(
184
+ weight_norm(nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)),
185
+ nn.ELU(),
186
+ weight_norm(nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)),
187
+ nn.ELU(),
188
+ weight_norm(nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)),
189
+ nn.ELU(),
190
+ weight_norm(nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)),
191
+ nn.ELU(),
192
+ weight_norm(nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)),
193
+ nn.ELU(),
194
+ )
195
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
196
+
197
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
198
+ x = self.condnet(x)
199
+ x = x.transpose(1, 2)
200
+ return torch.abs(self.classifier(x).squeeze(-1))
201
+
202
+
203
class SineGen(torch.nn.Module):
    """Harmonic sine-source generator for the hn-NSF vocoder.

    Given a sample-rate F0 contour (Hz), produces ``harmonic_num + 1`` sine
    waveforms (fundamental plus overtones) with random initial phases for the
    overtones, plus Gaussian noise; unvoiced regions contain noise only.

    Args:
        samp_rate: sampling rate in Hz.
        harmonic_num: number of harmonic overtones above the fundamental.
        sine_amp: amplitude of the sine waveforms.
        noise_std: noise std inside voiced regions.
        voiced_threshold: F0 (Hz) above which a sample counts as voiced.
    """

    def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0):
        super().__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # Voiced/unvoiced mask: 1.0 where F0 exceeds the threshold.
        return (f0 > self.voiced_threshold).type(torch.float32)

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: (sine_waves, uv, noise) — sines [B, harmonic_num + 1,
            sample_len], uv mask [B, 1, sample_len], and the additive noise.
        """
        num_harmonics = self.harmonic_num + 1

        # Normalized per-harmonic frequency: harmonic k runs at (k + 1) * f0.
        F_mat = torch.zeros((f0.size(0), num_harmonics, f0.size(-1))).to(f0.device)
        for k in range(num_harmonics):
            F_mat[:, k : k + 1, :] = f0 * (k + 1) / self.sampling_rate

        # Integrate frequency to phase; random initial phase per overtone,
        # zero initial phase for the fundamental.
        theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
        phase_dist = Uniform(low=-np.pi, high=np.pi)
        init_phase = phase_dist.sample(sample_shape=(f0.size(0), num_harmonics, 1)).to(F_mat.device)
        init_phase[:, 0, :] = 0

        sine_waves = self.sine_amp * torch.sin(theta_mat + init_phase)

        uv = self._f02uv(f0)

        # Noise std: noise_std in voiced regions, sine_amp / 3 in unvoiced
        # ones (so unvoiced noise peaks near the sine amplitude).
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # Zero the sines in unvoiced regions, then add noise everywhere.
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
264
+
265
+
266
class SourceModuleHnNSF(torch.nn.Module):
    """hn-NSF source module: merges harmonic sines into one excitation signal.

    Given an upsampled F0 track of shape (B, T, 1), returns:
      - sine_merge (B, T, 1): tanh of a learned linear mix of the harmonics,
      - noise (B, T, 1): Gaussian noise for the noise branch,
      - uv (B, T, 1): voiced/unvoiced mask.

    Note: ``upsample_scale`` is accepted for signature parity with
    ``SourceModuleHnNSF2`` but is not used by this module.
    """

    def __init__(
        self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0
    ):
        super().__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # Harmonic sine source.
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)

        # Learned merge of the harmonic stack into a single channel.
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """Return (sine_merge, noise, uv) for an F0 input of shape (B, T, 1)."""
        # The raw sine source carries no gradients; only the merge is learned.
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
            sine_wavs = sine_wavs.transpose(1, 2)
            uv = uv.transpose(1, 2)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # Noise branch, in the same shape as uv.
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv
316
+
317
+
318
class SineGen2(torch.nn.Module):
    """Sine-source generator driven by a frame-rate F0 contour.

    Unlike ``SineGen``, the per-sample phase increment is first downsampled by
    ``upsample_scale``, integrated at frame rate, and then linearly
    interpolated back up to sample rate, yielding smoother phase trajectories.

    Args:
        samp_rate: sampling rate in Hz.
        upsample_scale: ratio between the sample rate and the F0 frame rate.
        harmonic_num: number of harmonic overtones (default 0).
        sine_amp: amplitude of the sine waveforms (default 0.1).
        noise_std: std of Gaussian noise in voiced regions (default 0.003).
        voiced_threshold: F0 threshold for voiced/unvoiced classification.
        flag_for_pulse: True when used inside a pulse-train generator; then the
            first time step of each voiced segment gets a fixed phase
            (sin(np.pi) or cos(0)).
    """

    def __init__(
        self,
        samp_rate,
        upsample_scale,
        harmonic_num=0,
        sine_amp=0.1,
        noise_std=0.003,
        voiced_threshold=0,
        flag_for_pulse=False,
    ):
        super(SineGen2, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        # Number of generated sine channels: fundamental + overtones.
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale

    def _f02uv(self, f0):
        # Voiced/unvoiced mask: 1.0 where F0 exceeds the voiced threshold.
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def _f02sine(self, f0_values):
        """Convert F0 values (batchsize, length, dim) to sine waveforms,
        where dim indexes the fundamental tone and its overtones.
        """
        # Convert to normalized phase increment. The integer part n can be
        # ignored because 2 * np.pi * n does not affect the phase.
        rad_values = (f0_values / self.sampling_rate) % 1

        # Initial phase noise (no noise for the fundamental component).
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        # Instantaneous phase: sine[t] = sin(2*pi * sum_{i=1}^{t} rad).
        if not self.flag_for_pulse:
            # Integrate at frame rate (downsample by upsample_scale) and
            # interpolate the accumulated phase back up to sample rate.
            rad_values = torch.nn.functional.interpolate(
                rad_values.transpose(1, 2), scale_factor=1 / self.upsample_scale, mode="linear"
            ).transpose(1, 2)

            phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
            phase = torch.nn.functional.interpolate(
                phase.transpose(1, 2) * self.upsample_scale, scale_factor=self.upsample_scale, mode="linear"
            ).transpose(1, 2)
            sines = torch.sin(phase)
        else:
            # If necessary, make sure that the first time step of every
            # voiced segment is sin(pi) or cos(0).
            # This is used for pulse-train generation.

            # Identify the last time step in unvoiced segments.
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)

            # Get the instantaneous phase.
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # Each batch item needs to be processed separately.
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # Stores the accumulation of instantaneous phase within
                # each voiced segment.
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum

            # rad_values - tmp_cumsum: remove the accumulation of phase
            # carried over from the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)

            # Get the sines (cosine so voiced segments start at cos(0)).
            sines = torch.cos(i_phase * 2 * np.pi)
        return sines

    def forward(self, f0):
        """sine_tensor, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
        f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        """
        # Fundamental component plus overtones: harmonic k runs at (k+1) * f0.
        fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))

        # Generate sine waveforms.
        sine_waves = self._f02sine(fn) * self.sine_amp

        # Generate the voiced/unvoiced mask.
        uv = self._f02uv(f0)

        # Noise amplitude: noise_std in voiced regions; sine_amp / 3 in
        # unvoiced ones, so the unvoiced noise peaks near the sine amplitude.
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # First zero the unvoiced part via uv, then add the noise everywhere.
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
439
+
440
+
441
class SourceModuleHnNSF2(torch.nn.Module):
    """hn-NSF source module built on ``SineGen2``.

    Given an upsampled F0 track of shape (B, T, 1), returns:
      - sine_merge (B, T, 1): tanh of a learned linear mix of the harmonic sines,
      - noise (B, T, 1): Gaussian noise for the noise branch,
      - uv (B, T, 1): voiced/unvoiced mask.
    """

    def __init__(
        self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0
    ):
        super().__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # Harmonic sine source (frame-rate phase integration inside SineGen2).
        self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num, sine_amp, add_noise_std, voiced_threshod)

        # Learned merge of the harmonic stack into one excitation channel.
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """Return (sine_merge, noise, uv) for an F0 input of shape (B, T, 1)."""
        # The raw sine source carries no gradients; only the merge is learned.
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # Noise branch, in the same shape as uv.
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv
489
+
490
+
491
class HiFTGenerator(nn.Module):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493

    Pipeline: mel -> predicted F0 -> harmonic/noise source signal; the mel is
    upsampled through transposed convs, fused at each resolution with a
    downsampled STFT of the source, and finally rendered as a waveform by
    predicting STFT magnitude/phase and applying the inverse STFT.
    """

    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        nb_harmonics: int = 8,
        sampling_rate: int = 24000,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: list[int] = [8, 5, 3],
        upsample_kernel_sizes: list[int] = [16, 11, 7],
        istft_n_fft: int = 16,
        istft_hop_len: int = 4,
        resblock_kernel_sizes: list[int] = [3, 7, 11],
        resblock_dilation_sizes: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        source_resblock_kernel_sizes: list[int] = [7, 7, 11],
        source_resblock_dilation_sizes: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        lrelu_slope: float = 0.1,
        audio_limit: float = 0.99,
        f0_predictor_channels: int = 512,
    ):
        super().__init__()

        self.out_channels = 1
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_n_fft = istft_n_fft
        self.istft_hop_len = istft_hop_len
        self.lrelu_slope = lrelu_slope
        self.audio_limit = audio_limit

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        # Harmonic-plus-noise source driven by F0 upsampled to sample rate.
        self.m_source = SourceModuleHnNSF2(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_hop_len,
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold,
        )
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_hop_len)

        self.conv_pre = weight_norm(Conv1d(in_channels, base_channels, 7, 1, padding=3))

        # Upsampling stack; channel count halves at every stage.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i), base_channels // (2 ** (i + 1)), k, u, padding=(k - u) // 2
                    )
                )
            )

        # Source branch: downsample the source STFT to match each stage's
        # temporal resolution. NOTE: these convs are NOT weight-normed.
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(
            zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)
        ):
            if u == 1:
                self.source_downs.append(Conv1d(istft_n_fft + 2, base_channels // (2 ** (i + 1)), 1, 1))
            else:
                self.source_downs.append(
                    Conv1d(istft_n_fft + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
                )

            self.source_resblocks.append(ResBlock(base_channels // (2 ** (i + 1)), k, d))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(ResBlock(ch, k, d))

        # Head predicts (n_fft/2 + 1) log-magnitude bins plus (n_fft/2 + 1)
        # phase bins, consumed by the inverse STFT in decode().
        self.conv_post = weight_norm(Conv1d(ch, istft_n_fft + 2, 7, 1, padding=3))

        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        self.stft_window = torch.from_numpy(get_window("hann", istft_n_fft, fftbins=True).astype(np.float32))

        self.f0_predictor = ConvRNNF0Predictor(
            num_class=1, in_channels=in_channels, cond_channels=f0_predictor_channels
        )

    def remove_weight_norm(self):
        """Strip weight norm from all weight-normed submodules (inference/export).

        Bug fix vs the original: ``m_source`` has no ``remove_weight_norm``
        method and ``source_downs`` holds plain (non-weight-normed) Conv1d
        layers, so calling remove_weight_norm on either raised at runtime.
        Only genuinely weight-normed modules are processed here.
        """
        for layer in self.ups:
            remove_weight_norm(layer)
        for layer in self.resblocks:
            layer.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        for layer in self.source_resblocks:
            layer.remove_weight_norm()

    def _stft(self, x):
        """STFT of the source signal; returns (real, imag), each (B, F, TT)."""
        spec = torch.stft(
            x,
            self.istft_n_fft,
            self.istft_hop_len,
            self.istft_n_fft,
            window=self.stft_window.to(x.device),
            return_complex=True,
        )
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        """Inverse STFT from predicted magnitude/phase back to a waveform."""
        # Cap magnitude to avoid inf/NaN from the exp() applied in decode().
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        img = magnitude * torch.sin(phase)
        inverse_transform = torch.istft(
            torch.complex(real, img),
            self.istft_n_fft,
            self.istft_hop_len,
            self.istft_n_fft,
            window=self.stft_window.to(magnitude.device),
        )
        return inverse_transform

    def decode(self, x: torch.Tensor, s: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Decode mel features ``x`` (B, in_channels, T) plus source signal
        ``s`` (B, 1, num_samples) into a clamped waveform.

        ``s`` now defaults to an empty source created on ``x``'s device; the
        previous tensor default argument was built once at import time on the
        CPU and could mismatch the model's device.
        """
        if s is None:
            s = torch.zeros(1, 1, 0, device=x.device, dtype=x.dtype)
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            # Fuse the downsampled source spectrum at this resolution.
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            # Average the parallel resblock branches.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        magnitude = torch.exp(x[:, : self.istft_n_fft // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_n_fft // 2 + 1 :, :])  # actually, sin is redundancy

        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x

    def forward(self, speech_feat: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Generate a waveform from time-first mel features (B, T, in_channels).

        Returns:
            (generated_speech, f0): the waveform and the frame-rate F0 track
            predicted from the mel. (The original annotation claimed a Dict,
            but a tuple was always returned.)
        """
        speech_feat = speech_feat.transpose(1, 2)
        # mel -> f0
        f0 = self.f0_predictor(speech_feat)
        # f0 -> source excitation at sample rate: (B, num_samples, 1)
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        # mel + source -> speech
        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech, f0

    @torch.inference_mode()
    def inference(self, speech_feat: torch.Tensor) -> torch.Tensor:
        """Inference entry point; expects channel-first mel (B, in_channels, T)."""
        # mel -> f0
        f0 = self.f0_predictor(speech_feat)
        # f0 -> source excitation at sample rate: (B, num_samples, 1)
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech

    def load_weights(self, weights_path: str):
        """Load a generator checkpoint, stripping any leading "generator."
        prefix from parameter names."""
        checkpoint = torch.load(weights_path, map_location="cpu")
        state_dict = {k.replace("generator.", ""): v for k, v in checkpoint.items()}
        self.load_state_dict(state_dict, strict=True)
src/kanade_tokenizer/module/postnet.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from: https://github.com/ming024/FastSpeech2
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
def get_padding(kernel_size: int, dilation: int = 1):
    """'Same' padding for a stride-1 conv with the given kernel and dilation."""
    return (kernel_size - 1) * dilation // 2
9
+
10
+
11
class Norm(nn.Module):
    """LayerNorm over the channel axis of channel-first input.

    Accepts (B, C, T), normalizes each time step across C, returns (B, C, T).
    """

    def __init__(self, channels: int):
        super().__init__()
        self.norm = nn.LayerNorm(channels)

    def forward(self, x):
        # nn.LayerNorm normalizes the last dim, so move channels there and back.
        return self.norm(x.transpose(1, 2)).transpose(1, 2)
21
+
22
+
23
class PostNet(nn.Module):
    """FastSpeech2-style PostNet: a residual stack of 1-D convolutions that
    predicts a correction added back to the input mel spectrogram.

    All layers but the last are followed by tanh; every layer is followed by a
    norm (BatchNorm1d, or channel-wise LayerNorm when ``use_layer_norm`` is
    set) and dropout. Input and output are both (B, input_channels, T).
    """

    def __init__(
        self,
        input_channels: int = 100,
        channels: int = 512,
        kernel_size: int = 5,
        num_layers: int = 5,
        dropout: float = 0.5,
        use_layer_norm: bool = False,
    ):
        super().__init__()

        pad = get_padding(kernel_size)

        def conv_block(in_ch, out_ch):
            # Conv followed by the configured normalization.
            norm = Norm(out_ch) if use_layer_norm else nn.BatchNorm1d(out_ch)
            return nn.Sequential(nn.Conv1d(in_ch, out_ch, kernel_size=kernel_size, padding=pad), norm)

        self.convolutions = nn.ModuleList()
        # First layer widens to the hidden width.
        self.convolutions.append(conv_block(input_channels, channels))
        # Hidden layers keep the hidden width.
        for _ in range(num_layers - 2):
            self.convolutions.append(conv_block(channels, channels))
        # Last layer projects back to the input width.
        self.convolutions.append(conv_block(channels, input_channels))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x

        *hidden_layers, last_layer = self.convolutions
        for layer in hidden_layers:
            x = self.dropout(torch.tanh(layer(x)))
        # Final layer: no tanh, just norm and dropout.
        x = self.dropout(last_layer(x))

        return x + residual
src/kanade_tokenizer/module/ssl_extractor.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchaudio
4
+ import torchaudio.pipelines as pipelines
5
+ from torchaudio.models.wav2vec2 import Wav2Vec2Model
6
+ from torchaudio.models.wav2vec2.components import ConvLayerBlock
7
+
8
+ from ..util import get_logger
9
+
10
+ logger = get_logger()
11
+
12
+
13
+ # Map of friendly names to torchaudio pipeline bundles
14
+ MODEL_REGISTRY = {
15
+ "wav2vec2_base": pipelines.WAV2VEC2_BASE,
16
+ "wav2vec2_large": pipelines.WAV2VEC2_LARGE,
17
+ "wav2vec2_large_lv60k": pipelines.WAV2VEC2_LARGE_LV60K,
18
+ "hubert_base": pipelines.HUBERT_BASE,
19
+ "hubert_large": pipelines.HUBERT_LARGE,
20
+ "hubert_xlarge": pipelines.HUBERT_XLARGE,
21
+ "wavlm_base": pipelines.WAVLM_BASE,
22
+ "wavlm_base_plus": pipelines.WAVLM_BASE_PLUS,
23
+ "wavlm_large": pipelines.WAVLM_LARGE,
24
+ }
25
+
26
+
27
+ class SSLFeatureExtractor(nn.Module):
28
+ def __init__(self, model_name: str = "wavlm_base_plus", output_layer: int | None = None, sample_rate: int = 16000):
29
+ """
30
+ Args:
31
+ model_name: Name of the SSL model to use
32
+ output_layer: Which layer's features to extract (None for last layer), 1-based indexing
33
+ sample_rate: Sample rate of input audio
34
+ """
35
+ super().__init__()
36
+ self.output_layer = output_layer if output_layer is not None else -1
37
+
38
+ if model_name not in MODEL_REGISTRY:
39
+ raise ValueError(f"Unknown model: {model_name}. Available models: {list(MODEL_REGISTRY.keys())}")
40
+ bundle = MODEL_REGISTRY[model_name]
41
+ self.model: Wav2Vec2Model = bundle.get_model()
42
+ self.model.eval()
43
+ self.feature_dim: int = bundle._params["encoder_embed_dim"]
44
+
45
+ self.ssl_sample_rate = bundle.sample_rate
46
+ # Create resampler if needed
47
+ if sample_rate != self.ssl_sample_rate:
48
+ logger.debug(f"Resampling from {sample_rate} to {self.ssl_sample_rate} required by {model_name}.")
49
+ self.resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.ssl_sample_rate)
50
+ else:
51
+ self.resampler = None
52
+
53
+ @property
54
+ def hop_size(self) -> int:
55
+ """Get the hop size of the model's convolutional layers."""
56
+ hop_size = 1
57
+ for _, stride in self.conv_config:
58
+ hop_size *= stride
59
+ return hop_size
60
+
61
+ @property
62
+ def conv_config(self) -> list[tuple[int, int]]:
63
+ """Get the configuration of the convolutional layers in the model."""
64
+ conv_layers = []
65
+ for layer in self.model.feature_extractor.conv_layers:
66
+ layer: ConvLayerBlock
67
+ conv_layers.append((layer.kernel_size, layer.stride))
68
+ return conv_layers
69
+
70
+ def get_minimum_input_length(self, desired_output_length: int) -> int:
71
+ """Calculate the minimum input length required to produce a given output length."""
72
+ length = desired_output_length
73
+ for kernel_size, stride in reversed(self.conv_config):
74
+ length = (length - 1) * stride + kernel_size
75
+ return length
76
+
77
+ @torch.no_grad()
78
+ def forward(
79
+ self,
80
+ waveform: torch.Tensor,
81
+ lengths: torch.Tensor | None = None,
82
+ num_layers: int | None = None,
83
+ return_lengths: bool = False,
84
+ ) -> list[torch.Tensor]:
85
+ """
86
+ Args:
87
+ waveform: (batch_size, num_samples)
88
+ lengths: Optional tensor of sequence lengths for each batch item (used for attention masking)
89
+
90
+ Returns:
91
+ features: List of feature tensors for each layer (batch_size, frame, dim)
92
+ lengths: Sequence lengths for each batch item
93
+ """
94
+ if waveform.dim() == 1:
95
+ waveform = waveform.unsqueeze(0)
96
+ # Resample if needed
97
+ if self.resampler is not None:
98
+ waveform = self.resampler(waveform)
99
+
100
+ features, feature_lengths = self.model.extract_features(
101
+ waveform, lengths, num_layers=num_layers or self.output_layer
102
+ )
103
+
104
+ if return_lengths:
105
+ return features, feature_lengths
106
+ return features
src/kanade_tokenizer/module/transformer.py ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/meta-llama/llama3/blob/main/llama/model.py
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
4
+
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from ..util import get_logger
11
+ from .adaln_zero import AdaLNZero
12
+
13
+
14
+ logger = get_logger()
15
+
16
+
17
+ try:
18
+ from flash_attn import flash_attn_func, flash_attn_with_kvcache
19
+
20
+ FLASH_ATTN_AVAILABLE = True
21
+ except ImportError:
22
+ FLASH_ATTN_AVAILABLE = False
23
+ logger.warning(
24
+ "FlashAttention is not installed. Falling back to PyTorch SDPA implementation. There is no warranty that the model will work correctly."
25
+ )
26
+
27
+
28
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
29
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
30
+ t = torch.arange(end, device=freqs.device, dtype=torch.float32)
31
+ freqs = torch.outer(t, freqs)
32
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
33
+ return freqs_cis
34
+
35
+
36
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
37
+ ndim = x.ndim
38
+ assert 0 <= 1 < ndim
39
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
40
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
41
+ return freqs_cis.view(*shape)
42
+
43
+
44
+ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
45
+ x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
46
+ freqs_cis = reshape_for_broadcast(freqs_cis, x_)
47
+ x_out = torch.view_as_real(x_ * freqs_cis).flatten(3)
48
+ return x_out.type_as(x)
49
+
50
+
51
+ class Attention(nn.Module):
52
+ def __init__(
53
+ self,
54
+ dim: int,
55
+ n_heads: int,
56
+ dropout: float,
57
+ window_size: int | None,
58
+ qkv_bias: bool = False,
59
+ proj_bias: bool = False,
60
+ use_flash_attention: bool = False,
61
+ causal: bool = False,
62
+ ):
63
+ super().__init__()
64
+ self.n_heads = n_heads
65
+ self.head_dim = dim // n_heads
66
+
67
+ self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=qkv_bias)
68
+ self.wk = nn.Linear(dim, n_heads * self.head_dim, bias=qkv_bias)
69
+ self.wv = nn.Linear(dim, n_heads * self.head_dim, bias=qkv_bias)
70
+ self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=proj_bias)
71
+
72
+ self.scale = self.head_dim**-0.5
73
+ self.dropout = dropout
74
+
75
+ # Enable local attention if window_size is specified
76
+ self.use_local_attention = window_size is not None
77
+ if self.use_local_attention:
78
+ assert window_size % 2 == 1, "Window size must be odd for local attention."
79
+ self.window_per_side = window_size // 2
80
+
81
+ self.use_flash_attention = use_flash_attention
82
+
83
+ self.causal = causal
84
+
85
+ def create_mask(
86
+ self, bsz: int, seqlen: int, mask: torch.Tensor | None, device: torch.device
87
+ ) -> torch.Tensor | None:
88
+ """Create attention mask combining provided mask and local attention constraints"""
89
+ if not self.use_local_attention and mask is None:
90
+ return None
91
+
92
+ # Start with all positions allowed
93
+ attn_mask = torch.ones((seqlen, seqlen), dtype=torch.bool, device=device)
94
+
95
+ if self.causal:
96
+ # Causal mask: no future positions allowed
97
+ attn_mask = torch.tril(attn_mask)
98
+
99
+ # Apply local attention constraints
100
+ if self.use_local_attention:
101
+ attn_mask = torch.triu(attn_mask, diagonal=-self.window_per_side)
102
+ attn_mask = torch.tril(attn_mask, diagonal=self.window_per_side)
103
+
104
+ # Expand mask to batch size
105
+ attn_mask = attn_mask.unsqueeze(0).expand(bsz, -1, -1)
106
+
107
+ # Apply global mask if provided
108
+ if mask is not None:
109
+ assert mask.shape[-1] == seqlen and mask.shape[-2] == seqlen, (
110
+ "Mask must be square and match sequence length."
111
+ )
112
+ # Ensure mask has correct batch dimensions
113
+ if mask.dim() == 2:
114
+ mask = mask.unsqueeze(0).expand(bsz, -1, -1)
115
+ attn_mask = attn_mask & mask
116
+
117
+ # Expand to head dimension
118
+ attn_mask = attn_mask.unsqueeze(1).expand(-1, self.n_heads, -1, -1)
119
+ return attn_mask
120
+
121
+ def forward(
122
+ self,
123
+ x: torch.Tensor,
124
+ freqs_cis: torch.Tensor | None,
125
+ mask: torch.Tensor | None,
126
+ return_kv: bool = False,
127
+ ) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
128
+ """Forward pass for multi-head attention.
129
+ Args:
130
+ x (torch.Tensor): Input tensor of shape (bsz, seqlen, dim).
131
+ freqs_cis (torch.Tensor, optional): Precomputed rotary frequencies.
132
+ mask (torch.Tensor, optional): Attention mask.
133
+ return_kv (bool): Whether to return KV pairs for caching.
134
+ Returns:
135
+ output (torch.Tensor): Output tensor of shape (bsz, seqlen, dim).
136
+ new_kv (tuple, optional): KV pairs if return_kv is True.
137
+ """
138
+ bsz, seqlen, _ = x.shape
139
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
140
+
141
+ xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
142
+ xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
143
+ xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)
144
+
145
+ # Apply rotary embeddings if provided
146
+ if freqs_cis is not None:
147
+ xq = apply_rotary_emb(xq, freqs_cis=freqs_cis[:seqlen])
148
+ xk = apply_rotary_emb(xk, freqs_cis=freqs_cis[:seqlen])
149
+
150
+ if self.use_flash_attention and FLASH_ATTN_AVAILABLE:
151
+ assert mask is None, "Flash attention does not support arbitrary masking."
152
+
153
+ # Flash Attention
154
+ window_size = (self.window_per_side, self.window_per_side) if self.use_local_attention else (-1, -1)
155
+ output = flash_attn_func(
156
+ xq, # (bsz, seqlen, n_heads, head_dim)
157
+ xk, # (bsz, seqlen, n_heads, head_dim)
158
+ xv, # (bsz, seqlen, n_heads, head_dim)
159
+ dropout_p=(self.dropout if self.training else 0.0),
160
+ softmax_scale=self.scale,
161
+ window_size=window_size,
162
+ causal=self.causal,
163
+ ) # (bsz, seqlen, n_heads, head_dim)
164
+
165
+ else:
166
+ attn_mask = self.create_mask(bsz, seqlen, mask, x.device)
167
+
168
+ # SDPA Attention
169
+ output = F.scaled_dot_product_attention(
170
+ xq.transpose(1, 2), # (bsz, n_heads, seqlen, head_dim)
171
+ xk.transpose(1, 2), # (bsz, n_heads, seqlen, head_dim)
172
+ xv.transpose(1, 2), # (bsz, n_heads, seqlen, head_dim)
173
+ attn_mask=attn_mask, # (bsz, n_heads, seqlen, seqlen) boolean mask
174
+ dropout_p=self.dropout,
175
+ scale=self.scale,
176
+ ).transpose(1, 2) # (bsz, seqlen, n_heads, head_dim)
177
+
178
+ output = output.contiguous().view(bsz, seqlen, -1)
179
+ output = self.wo(output)
180
+
181
+ if return_kv:
182
+ return output, (xk, xv)
183
+ return output
184
+
185
+ def forward_with_cache(
186
+ self,
187
+ x: torch.Tensor,
188
+ kv_cache: tuple[torch.Tensor, torch.Tensor],
189
+ freqs_cis: torch.Tensor,
190
+ start_pos: int,
191
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
192
+ """
193
+ Forward pass with KV cache for efficient inference. Only used for inference.
194
+
195
+ Args:
196
+ x (torch.Tensor): Input tensor for the current step. Shape: (bsz, 1, dim)
197
+ kv_cache: A tuple of (key_cache, value_cache) from previous steps.
198
+ start_pos (int): The starting position of the new token in the sequence.
199
+ freqs_cis (torch.Tensor): Precomputed rotary frequencies.
200
+
201
+ Returns:
202
+ output (torch.Tensor): Output tensor after attention. Shape: (bsz, 1, dim)
203
+ new_kv (tuple): Updated KV cache including the new key and value.
204
+ """
205
+ bsz, seqlen, _ = x.shape
206
+ assert seqlen == 1, "KV cache method is designed for single-token generation."
207
+
208
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
209
+
210
+ xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
211
+ xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
212
+ xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)
213
+
214
+ # Apply rotary embeddings using the correct positional slice
215
+ xq = apply_rotary_emb(xq, freqs_cis=freqs_cis[start_pos : start_pos + seqlen])
216
+ xk = apply_rotary_emb(xk, freqs_cis=freqs_cis[start_pos : start_pos + seqlen])
217
+
218
+ # Update the KV cache
219
+ k_cache, v_cache = kv_cache
220
+ new_kv = (xk, xv)
221
+ xk = torch.cat([k_cache, xk], dim=1)
222
+ xv = torch.cat([v_cache, xv], dim=1)
223
+
224
+ # For single token generation, causal mask is implicitly handled.
225
+ # We attend to all keys (prefix + previous tokens).
226
+ if self.use_flash_attention and FLASH_ATTN_AVAILABLE:
227
+ # Flash Attention
228
+ output = flash_attn_with_kvcache(
229
+ xq, # (bsz, 1, n_heads, head_dim)
230
+ xk, # (bsz, 1+kv_len, n_heads, head_dim)
231
+ xv, # (bsz, 1+kv_len, n_heads, head_dim)
232
+ softmax_scale=self.scale,
233
+ ) # (bsz, 1, n_heads, head_dim)
234
+ else:
235
+ # SDPA Attention
236
+ output = F.scaled_dot_product_attention(
237
+ xq.transpose(1, 2), # (bsz, n_heads, 1, head_dim)
238
+ xk.transpose(1, 2), # (bsz, n_heads, 1+kv_len, head_dim)
239
+ xv.transpose(1, 2), # (bsz, n_heads, 1+kv_len, head_dim)
240
+ scale=self.scale,
241
+ ).transpose(1, 2) # (bsz, 1, n_heads, head_dim)
242
+
243
+ output = output.contiguous().view(bsz, seqlen, -1)
244
+ return self.wo(output), new_kv
245
+
246
+
247
+ class FeedForward(nn.Module):
248
+ def __init__(
249
+ self,
250
+ dim: int,
251
+ hidden_dim: int,
252
+ multiple_of: int,
253
+ ffn_dim_multiplier: float | None,
254
+ ):
255
+ super().__init__()
256
+ hidden_dim = int(2 * hidden_dim / 3)
257
+ # custom dim factor multiplier
258
+ if ffn_dim_multiplier is not None:
259
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
260
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
261
+
262
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
263
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
264
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
265
+
266
+ def forward(self, x):
267
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
268
+
269
+
270
+ class TransformerBlock(nn.Module):
271
+ def __init__(
272
+ self,
273
+ dim: int,
274
+ n_heads: int,
275
+ qkv_bias: bool,
276
+ proj_bias: bool,
277
+ window_size: int | None,
278
+ multiple_of: int,
279
+ ffn_dim_multiplier: float | None,
280
+ dropout: float,
281
+ norm_eps: float,
282
+ adanorm_condition_dim: int | None = None,
283
+ use_flash_attention: bool = False,
284
+ use_adaln_zero: bool = False,
285
+ causal: bool = False,
286
+ ):
287
+ super().__init__()
288
+ self.attention = Attention(
289
+ dim=dim,
290
+ n_heads=n_heads,
291
+ dropout=dropout,
292
+ window_size=window_size,
293
+ use_flash_attention=use_flash_attention,
294
+ qkv_bias=qkv_bias,
295
+ proj_bias=proj_bias,
296
+ causal=causal,
297
+ )
298
+
299
+ self.feed_forward = FeedForward(
300
+ dim=dim,
301
+ hidden_dim=4 * dim,
302
+ multiple_of=multiple_of,
303
+ ffn_dim_multiplier=ffn_dim_multiplier,
304
+ )
305
+
306
+ # Choose between AdaLNZero and regular LayerNorm
307
+ self.use_adaln_zero = use_adaln_zero
308
+ if self.use_adaln_zero:
309
+ assert adanorm_condition_dim is not None, "condition_dim must be provided when using AdaLNZero"
310
+ self.attention_norm = AdaLNZero(dim, adanorm_condition_dim, eps=norm_eps, return_gate=True)
311
+ self.ffn_norm = AdaLNZero(dim, adanorm_condition_dim, eps=norm_eps, return_gate=True)
312
+ else:
313
+ self.attention_norm = nn.LayerNorm(dim, eps=norm_eps)
314
+ self.ffn_norm = nn.LayerNorm(dim, eps=norm_eps)
315
+
316
+ def forward(
317
+ self,
318
+ x: torch.Tensor,
319
+ freqs_cis: torch.Tensor | None,
320
+ mask: torch.Tensor | None,
321
+ condition: torch.Tensor | None = None,
322
+ return_kv: bool = False,
323
+ kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None,
324
+ start_pos: int | None = None,
325
+ ) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
326
+ """
327
+ Forward pass for a single Transformer block.
328
+ Args:
329
+ x (torch.Tensor): Input tensor of shape (bsz, seqlen, dim).
330
+ freqs_cis (torch.Tensor, optional): Precomputed rotary frequencies.
331
+ mask (torch.Tensor, optional): Attention mask.
332
+ condition (torch.Tensor, optional): Conditioning tensor for AdaLNZero.
333
+ return_kv (bool): Whether to return KV pairs for caching.
334
+ kv_cache (tuple, optional): KV cache for efficient inference.
335
+ start_pos (int, optional): Starting position for KV cache.
336
+ Returns:
337
+ out (torch.Tensor): Output tensor of shape (bsz, seqlen, dim).
338
+ new_kv (tuple, optional): New KV pairs if return_kv is True or kv_cache is provided.
339
+ """
340
+ # Apply normalization
341
+ if self.use_adaln_zero:
342
+ assert condition is not None, "condition must be provided when using AdaLNZero"
343
+ attn_normed, attn_gate = self.attention_norm(x, condition=condition)
344
+ else:
345
+ attn_normed = self.attention_norm(x)
346
+
347
+ # Forward attention with KV cache if provided
348
+ new_kv = None
349
+ if kv_cache is not None and start_pos is not None:
350
+ # Use KV cache for efficient inference
351
+ attn_out, new_kv = self.attention.forward_with_cache(attn_normed, kv_cache, freqs_cis, start_pos)
352
+ elif return_kv:
353
+ # Return KV pairs for caching
354
+ attn_out, new_kv = self.attention(attn_normed, freqs_cis, mask, return_kv=True)
355
+ else:
356
+ attn_out = self.attention(attn_normed, freqs_cis, mask)
357
+
358
+ # Apply gating for attention if using AdaLNZero
359
+ if self.use_adaln_zero:
360
+ h = x + attn_gate * attn_out # residual + gate * x
361
+ else:
362
+ h = x + attn_out
363
+
364
+ # Apply normalization for feedforward
365
+ if self.use_adaln_zero:
366
+ ffn_normed, ffn_gate = self.ffn_norm(h, condition=condition)
367
+ else:
368
+ ffn_normed = self.ffn_norm(h)
369
+
370
+ ffn_out = self.feed_forward(ffn_normed)
371
+
372
+ # Apply gating for feedforward if using AdaLNZero
373
+ if self.use_adaln_zero:
374
+ out = h + ffn_gate * ffn_out # residual + gate * x
375
+ else:
376
+ out = h + ffn_out
377
+
378
+ # If using KV cache, return the new KV pairs
379
+ if new_kv is not None:
380
+ return out, new_kv
381
+ return out
382
+
383
+
384
+ class Transformer(nn.Module):
385
+ def __init__(
386
+ self,
387
+ dim: int = 4096,
388
+ n_layers: int = 32,
389
+ n_heads: int = 32,
390
+ qkv_bias: bool = False,
391
+ proj_bias: bool = False,
392
+ window_size: int | None = None,
393
+ multiple_of: int = 256,
394
+ ffn_dim_multiplier: float | None = None,
395
+ dropout: float = 0.1,
396
+ norm_eps: float = 1e-5,
397
+ use_rope: bool = True,
398
+ rope_theta: float = 500000.0,
399
+ max_seq_len: int = 2048,
400
+ input_dim: int | None = None,
401
+ output_dim: int | None = None,
402
+ adanorm_condition_dim: int | None = None,
403
+ use_flash_attention: bool = False,
404
+ use_adaln_zero: bool = False,
405
+ use_xavier_init: bool = True,
406
+ causal: bool = False,
407
+ ):
408
+ super().__init__()
409
+ self.dim = dim
410
+ self.n_heads = n_heads
411
+ self.rope_theta = rope_theta
412
+ self.use_adaln_zero = use_adaln_zero
413
+
414
+ self.layers = nn.ModuleList()
415
+ for layer_id in range(n_layers):
416
+ self.layers.append(
417
+ TransformerBlock(
418
+ dim=dim,
419
+ n_heads=n_heads,
420
+ window_size=window_size,
421
+ multiple_of=multiple_of,
422
+ ffn_dim_multiplier=ffn_dim_multiplier,
423
+ dropout=dropout,
424
+ qkv_bias=qkv_bias,
425
+ proj_bias=proj_bias,
426
+ norm_eps=norm_eps,
427
+ adanorm_condition_dim=adanorm_condition_dim,
428
+ use_flash_attention=use_flash_attention,
429
+ use_adaln_zero=use_adaln_zero,
430
+ causal=causal,
431
+ )
432
+ )
433
+
434
+ # Choose between AdaLNZero (without gate) and regular LayerNorm for final norm
435
+ if self.use_adaln_zero:
436
+ assert adanorm_condition_dim is not None, "condition_dim must be provided when using AdaLNZero"
437
+ self.norm = AdaLNZero(dim, adanorm_condition_dim, eps=norm_eps, return_gate=False)
438
+ else:
439
+ self.norm = nn.LayerNorm(dim, eps=norm_eps)
440
+ self.input_proj = nn.Linear(input_dim, dim) if input_dim is not None else nn.Identity()
441
+ self.output_proj = nn.Linear(dim, output_dim) if output_dim is not None else nn.Identity()
442
+ self.output_dim_ = output_dim if output_dim is not None else dim
443
+
444
+ if use_rope:
445
+ self.freqs_cis = precompute_freqs_cis(dim // n_heads, max_seq_len * 2, rope_theta)
446
+ logger.debug(
447
+ f"Using RoPE with theta={rope_theta}, max_seq_len={max_seq_len}, "
448
+ f"dim={dim}, n_heads={n_heads}, freqs_cis shape={self.freqs_cis.shape}"
449
+ )
450
+ else:
451
+ self.freqs_cis = None
452
+
453
+ if window_size is not None:
454
+ logger.debug(f"Using local attention with window size {window_size}")
455
+
456
+ if self.use_adaln_zero:
457
+ logger.debug(f"Using AdaLNZero conditioning with condition_dim={adanorm_condition_dim}")
458
+
459
+ if use_flash_attention:
460
+ logger.debug("Using Flash Attention for memory-efficient attention computation")
461
+
462
+ if use_xavier_init:
463
+ logger.debug("Using Xavier initialization for linear layers")
464
+ self.apply(self._init_weights)
465
+ self.apply(self._init_adaln_zero)
466
+
467
+ @property
468
+ def output_dim(self) -> int:
469
+ return self.output_dim_
470
+
471
+ def _init_weights(self, module: nn.Module):
472
+ if isinstance(module, nn.Linear):
473
+ nn.init.xavier_normal_(module.weight)
474
+ if module.bias is not None:
475
+ nn.init.zeros_(module.bias)
476
+
477
+ def _init_adaln_zero(self, module: nn.Module):
478
+ if isinstance(module, AdaLNZero):
479
+ # Initialize condition projection weights to zero
480
+ nn.init.zeros_(module.condition_proj[1].weight)
481
+ nn.init.zeros_(module.condition_proj[1].bias)
482
+
483
+ def forward(
484
+ self,
485
+ x: torch.Tensor,
486
+ mask: torch.Tensor | None = None,
487
+ condition: torch.Tensor | None = None,
488
+ return_kv: bool = False,
489
+ kv_cache: list[tuple[torch.Tensor, torch.Tensor]] | None = None,
490
+ start_pos: int | None = None,
491
+ ) -> torch.Tensor | tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
492
+ """
493
+ Forward pass for the Transformer model.
494
+ Args:
495
+ x (torch.Tensor): Input tensor of shape (bsz, seqlen, input_dim).
496
+ mask (torch.Tensor, optional): Attention mask.
497
+ condition (torch.Tensor, optional): Conditioning tensor for AdaLNZero.
498
+ return_kv (bool): Whether to return KV pairs for caching.
499
+ kv_cache (list, optional): List of KV caches for each layer for efficient inference.
500
+ start_pos (int, optional): Starting position for KV cache.
501
+ Returns:
502
+ output (torch.Tensor): Output tensor of shape (bsz, seqlen, output_dim).
503
+ new_kv_list (list, optional): List of new KV pairs for each layer if return_kv is True or kv_cache is provided.
504
+ """
505
+ bsz, seqlen, _dim = x.shape
506
+
507
+ if self.use_adaln_zero:
508
+ assert condition is not None, "condition must be provided when using AdaLNZero"
509
+
510
+ # Rotary embeddings
511
+ if self.freqs_cis is not None:
512
+ # Recompute freqs_cis if the sequence length or starting position exceeds the precomputed length
513
+ expected_len = (start_pos + 1) if start_pos is not None else seqlen
514
+ if expected_len > self.freqs_cis.shape[0]:
515
+ logger.warning(
516
+ f"Input sequence length {expected_len} exceeds precomputed RoPE length {self.freqs_cis.shape[0]}. Recomputing freqs_cis."
517
+ )
518
+ self.freqs_cis = precompute_freqs_cis(self.dim // self.n_heads, expected_len * 4, self.rope_theta)
519
+
520
+ self.freqs_cis = self.freqs_cis.to(x.device)
521
+ freqs_cis = self.freqs_cis
522
+ else:
523
+ freqs_cis = None
524
+
525
+ x = self.input_proj(x)
526
+ new_kv_list = []
527
+ for i, layer in enumerate(self.layers):
528
+ # Collect KV cache if provided
529
+ if kv_cache is not None and start_pos is not None:
530
+ x, new_kv = layer(x, freqs_cis, mask, condition, kv_cache=kv_cache[i], start_pos=start_pos)
531
+ new_kv_list.append(new_kv)
532
+ elif return_kv:
533
+ x, new_kv = layer(x, freqs_cis, mask, condition, return_kv=True)
534
+ new_kv_list.append(new_kv)
535
+ else:
536
+ x = layer(x, freqs_cis, mask, condition)
537
+
538
+ # Apply final normalization
539
+ if self.use_adaln_zero:
540
+ x, _ = self.norm(x, condition=condition) # Final norm doesn't use gate
541
+ else:
542
+ x = self.norm(x)
543
+
544
+ output = self.output_proj(x)
545
+
546
+ # If using KV cache, return the new KV pairs
547
+ if new_kv_list:
548
+ return output, new_kv_list
549
+ return output
src/kanade_tokenizer/pipeline.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Literal
3
+
4
+ import jsonargparse
5
+ import lightning as L
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import yaml
10
+ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
11
+ from torch.optim import AdamW
12
+ from torch.optim.lr_scheduler import OneCycleLR
13
+
14
+ from .data.datamodule import AudioBatch
15
+ from .model import KanadeModel, KanadeModelConfig
16
+ from .module.audio_feature import MelSpectrogramFeature
17
+ from .module.discriminator import SpectrogramDiscriminator
18
+ from .module.fsq import FiniteScalarQuantizer
19
+ from .module.global_encoder import GlobalEncoder
20
+ from .module.postnet import PostNet
21
+ from .module.ssl_extractor import SSLFeatureExtractor
22
+ from .module.transformer import Transformer
23
+ from .util import freeze_modules, get_logger, load_vocoder, vocode
24
+
25
+ logger = get_logger()
26
+
27
+
28
@dataclass
class KanadePipelineConfig:
    """Hyperparameters for KanadePipeline training (optimizer, losses, GAN, checkpointing).

    Field order defines the generated ``__init__`` signature — do not reorder.
    """

    # Training control
    train_feature: bool = True  # Whether to train the feature reconstruction branch
    train_mel: bool = True  # Whether to train the mel spectrogram generation branch

    # Audio settings
    audio_length: int = 138240  # Length of audio input in samples

    # Optimization settings (AdamW)
    lr: float = 2e-4
    weight_decay: float = 1e-4
    betas: tuple[float, float] = (0.9, 0.99)
    gradient_clip_val: float | None = 1.0  # None disables gradient clipping

    # LR scheduling parameters (OneCycleLR)
    warmup_percent: float = 0.1
    lr_div_factor: float = 10.0
    lr_final_div_factor: float = 1.0
    anneal_mode: str = "cos"

    # Loss weights (a weight of 0.0 disables that loss term)
    feature_l1_weight: float = 30.0
    feature_l2_weight: float = 0.0
    mel_l1_weight: float = 30.0
    mel_l2_weight: float = 0.0
    adv_weight: float = 1.0
    feature_matching_weight: float = 10.0

    # GAN settings
    use_discriminator: bool = False
    adv_loss_type: Literal["hinge", "least_square"] = "hinge"  # Type of adversarial loss
    discriminator_lr: float | None = None  # Learning rate for discriminator
    discriminator_start_step: int = 0  # Step to start training discriminator
    discriminator_update_prob: float = 1.0  # Probability of updating discriminator at each step

    # Checkpoint loading ("hf:<repo_id>[@revision]" pulls from HuggingFace Hub)
    ckpt_path: str | None = None  # Path to checkpoint to load from
    skip_loading_modules: tuple[str, ...] = ()  # Modules to skip when loading checkpoint

    # Other settings
    log_mel_samples: int = 10  # Number of validation mel samples to log
    use_torch_compile: bool = True
71
+
72
+
73
+ class KanadePipeline(L.LightningModule):
74
+ """LightningModule wrapper for KanadeModel, handling training (including GAN)."""
75
+
76
    def __init__(
        self,
        model_config: KanadeModelConfig,
        pipeline_config: KanadePipelineConfig,
        ssl_feature_extractor: SSLFeatureExtractor,
        local_encoder: Transformer,
        local_quantizer: FiniteScalarQuantizer,
        feature_decoder: Transformer | None,
        global_encoder: GlobalEncoder,
        mel_prenet: Transformer,
        mel_decoder: Transformer,
        mel_postnet: PostNet,
        discriminator: SpectrogramDiscriminator | None = None,
    ):
        """Assemble the KanadeModel and training-only components.

        Args:
            model_config: Architecture/audio configuration passed to KanadeModel.
            pipeline_config: Training hyperparameters (losses, GAN, checkpoint).
            ssl_feature_extractor: Frozen-or-trainable SSL frontend.
            local_encoder / local_quantizer / feature_decoder: Content branch;
                feature_decoder may be None when train_feature is False.
            global_encoder / mel_prenet / mel_decoder / mel_postnet: Mel branch.
            discriminator: Optional spectrogram discriminator for GAN training.
        """
        super().__init__()
        self.config = pipeline_config
        self.save_hyperparameters("model_config", "pipeline_config")
        # Allow partial checkpoints (e.g. when modules are skipped on load).
        self.strict_loading = False
        # GAN training needs interleaved optimizer steps -> manual optimization.
        self.automatic_optimization = False
        self.torch_compiled = False

        # Validate components required for training
        assert not pipeline_config.train_feature or feature_decoder is not None, (
            "Feature decoder must be provided if training feature reconstruction"
        )
        logger.info(
            f"Training configuration: train_feature={pipeline_config.train_feature}, train_mel={pipeline_config.train_mel}"
        )

        # 1. Kanade model
        self.model = KanadeModel(
            config=model_config,
            ssl_feature_extractor=ssl_feature_extractor,
            local_encoder=local_encoder,
            local_quantizer=local_quantizer,
            feature_decoder=feature_decoder,
            global_encoder=global_encoder,
            mel_decoder=mel_decoder,
            mel_prenet=mel_prenet,
            mel_postnet=mel_postnet,
        )
        self._freeze_unused_modules(pipeline_config.train_feature, pipeline_config.train_mel)

        # Calculate padding for expected SSL output length
        self.padding = self.model._calculate_waveform_padding(pipeline_config.audio_length)
        logger.info(f"Input waveform padding for SSL feature extractor: {self.padding} samples")

        # Calculate target mel spectrogram length
        self.target_mel_length = self.model._calculate_target_mel_length(pipeline_config.audio_length)
        logger.info(f"Target mel spectrogram length: {self.target_mel_length} frames")

        # 2. Discriminator
        self._init_discriminator(pipeline_config, discriminator)

        # 3. Mel spectrogram feature extractor for loss computation
        if pipeline_config.train_mel:
            self.mel_spec = MelSpectrogramFeature(
                sample_rate=model_config.sample_rate,
                n_fft=model_config.n_fft,
                hop_length=model_config.hop_length,
                n_mels=model_config.n_mels,
                padding=model_config.padding,
                fmin=model_config.mel_fmin,
                fmax=model_config.mel_fmax,
                bigvgan_style_mel=model_config.bigvgan_style_mel,
            )

        # Mel sample storage for logging (vocoder is loaded lazily elsewhere)
        self.vocoder = None
        self.validation_examples = []
        self.log_mel_samples = pipeline_config.log_mel_samples
147
+
148
+ def _freeze_unused_modules(self, train_feature: bool, train_mel: bool):
149
+ model = self.model
150
+ if not train_feature:
151
+ # Freeze local branch components if not training feature reconstruction
152
+ freeze_modules([model.local_encoder, model.local_quantizer, model.feature_decoder])
153
+ if model.conv_downsample is not None:
154
+ freeze_modules([model.conv_downsample, model.conv_upsample])
155
+ logger.info("Feature reconstruction branch frozen: local_encoder, local_quantizer, feature_decoder")
156
+
157
+ if not train_mel:
158
+ # Freeze global branch and mel generation components if not training mel generation
159
+ freeze_modules(
160
+ [model.global_encoder, model.mel_prenet, model.mel_conv_upsample, model.mel_decoder, model.mel_postnet]
161
+ )
162
+ logger.info(
163
+ "Mel generation branch frozen: global_encoder, mel_prenet, mel_conv_upsample, mel_decoder, mel_postnet"
164
+ )
165
+
166
+ def _init_discriminator(self, config: KanadePipelineConfig, discriminator: SpectrogramDiscriminator | None):
167
+ # Setup discriminator if provided
168
+ self.discriminator = discriminator
169
+ self.use_discriminator = config.use_discriminator and discriminator is not None and config.train_mel
170
+
171
+ if config.use_discriminator and discriminator is None:
172
+ logger.error(
173
+ "Discriminator is enabled in config but no discriminator model provided. Disabling GAN training."
174
+ )
175
+ if config.use_discriminator and discriminator is not None and not config.train_mel:
176
+ logger.warning(
177
+ "Discriminator is enabled but train_mel=False. Discriminator will not be effective without mel training."
178
+ )
179
+
180
+ self.discriminator_start_step = config.discriminator_start_step
181
+ self.discriminator_update_prob = config.discriminator_update_prob
182
+ if self.use_discriminator:
183
+ logger.info("Discriminator initialized for GAN training")
184
+ logger.info(f"Discriminator start step: {self.discriminator_start_step}")
185
+ logger.info(f"Discriminator update probability: {self.discriminator_update_prob}")
186
+
187
+ def setup(self, stage: str):
188
+ # Torch compile model if enabled
189
+ if torch.__version__ >= "2.0" and self.config.use_torch_compile:
190
+ self.model = torch.compile(self.model)
191
+ if self.discriminator is not None:
192
+ self.discriminator = torch.compile(self.discriminator)
193
+ self.torch_compiled = True
194
+
195
+ # Load checkpoint if provided
196
+ if self.config.ckpt_path:
197
+ ckpt_path = self.config.ckpt_path
198
+
199
+ # Download weights from HuggingFace Hub if needed
200
+ if ckpt_path.startswith("hf:"):
201
+ from huggingface_hub import hf_hub_download
202
+
203
+ repo_id = ckpt_path[len("hf:") :]
204
+ # Separate out revision if specified
205
+ revision = None
206
+ if "@" in repo_id:
207
+ repo_id, revision = repo_id.split("@", 1)
208
+
209
+ ckpt_path = hf_hub_download(repo_id, filename="model.safetensors", revision=revision)
210
+
211
+ self._load_weights(ckpt_path)
212
+
213
    def forward(
        self, waveform: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None, dict]:
        """Run the full training forward pass (SSL extraction + both branches).

        Args:
            waveform: Input audio batch passed to the SSL feature extractor.

        Returns:
            ssl_real: Extracted SSL features for local branch (B, T, C)
            ssl_recon: Reconstructed SSL features (B, T, C) — None-ness depends
                on forward_content; feature training asserts it downstream
            mel_recon: Generated mel spectrogram (B, n_mels, T) — None when
                train_mel=False
            loss_dict: Dictionary with auxiliary information (codes, losses, etc.)
        """
        loss_dict = {}

        # 1. Extract SSL features (padded so frame counts match targets)
        local_ssl_features, global_ssl_features = self.model.forward_ssl_features(waveform, padding=self.padding)

        # 2. Content branch processing (encode -> quantize -> decode)
        content_embeddings, _, ssl_recon, perplexity = self.model.forward_content(local_ssl_features)
        # Codebook perplexity: logged as a quantizer-usage diagnostic.
        loss_dict["local/perplexity"] = perplexity

        # 3. Global branch processing and mel reconstruction
        mel_recon = None
        if self.config.train_mel:
            global_embeddings = self.model.forward_global(global_ssl_features)
            mel_recon = self.model.forward_mel(content_embeddings, global_embeddings, mel_length=self.target_mel_length)

        return local_ssl_features, ssl_recon, mel_recon, loss_dict
237
+
238
    def _get_reconstruction_loss(
        self, audio_real: torch.Tensor, ssl_real: torch.Tensor, ssl_recon: torch.Tensor, mel_recon: torch.Tensor
    ) -> tuple[torch.Tensor, dict, torch.Tensor]:
        """Compute L1 + L2 loss for SSL feature and mel spectrogram reconstruction.

        Args:
            audio_real: Reference waveform; a leading channel dim of 1 is squeezed out.
            ssl_real: Target SSL features (required when train_feature=True).
            ssl_recon: Reconstructed SSL features (required when train_feature=True).
            mel_recon: Reconstructed mel spectrogram (required when train_mel=True).

        Returns:
            total_loss: Combined reconstruction loss
            loss_dict: Dictionary with individual loss components
            mel_real: Real mel spectrogram for reference (None when train_mel=False)
        """
        # Drop the channel dimension so mel extraction sees (B, samples).
        if audio_real.dim() == 3:
            audio_real = audio_real.squeeze(1)

        loss_dict = {}
        # Initialized as plain 0 so a disabled branch contributes nothing to the sum.
        feature_loss, mel_loss = 0, 0

        # Compute SSL feature reconstruction losses if training features
        if self.config.train_feature and self.model.feature_decoder is not None:
            assert ssl_real is not None and ssl_recon is not None, (
                "SSL features must be provided for training feature reconstruction"
            )
            ssl_l1 = F.l1_loss(ssl_recon, ssl_real)
            ssl_l2 = F.mse_loss(ssl_recon, ssl_real)

            # Weighted combination per config.
            feature_loss = self.config.feature_l1_weight * ssl_l1 + self.config.feature_l2_weight * ssl_l2
            loss_dict.update({"ssl_l1": ssl_l1, "ssl_l2": ssl_l2, "feature_loss": feature_loss})

        # Compute mel spectrogram reconstruction losses if training mel
        mel_real = None
        if self.config.train_mel:
            assert mel_recon is not None, "Mel reconstruction must be provided for training mel generation"
            # Extract reference mel spectrogram from audio
            mel_real = self.mel_spec(audio_real)

            mel_l1 = F.l1_loss(mel_recon, mel_real)
            mel_l2 = F.mse_loss(mel_recon, mel_real)
            mel_loss = self.config.mel_l1_weight * mel_l1 + self.config.mel_l2_weight * mel_l2
            loss_dict.update({"mel_l1": mel_l1, "mel_l2": mel_l2, "mel_loss": mel_loss})

        total_loss = feature_loss + mel_loss
        return total_loss, loss_dict, mel_real
278
+
279
+ def _get_discriminator_loss(
280
+ self, real_outputs: torch.Tensor, fake_outputs: torch.Tensor
281
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
282
+ """Compute the adversarial loss for discriminator.
283
+ Returns:
284
+ disc_loss: Total discriminator loss
285
+ real_loss: Loss component from real samples
286
+ fake_loss: Loss component from fake samples
287
+ """
288
+ if self.config.adv_loss_type == "hinge":
289
+ real_loss = torch.mean(torch.clamp(1 - real_outputs, min=0))
290
+ fake_loss = torch.mean(torch.clamp(1 + fake_outputs, min=0))
291
+ elif self.config.adv_loss_type == "least_square":
292
+ real_loss = torch.mean((real_outputs - 1) ** 2)
293
+ fake_loss = torch.mean(fake_outputs**2)
294
+ else:
295
+ raise ValueError(f"Unknown adversarial loss type: {self.config.adv_loss_type}")
296
+
297
+ disc_loss = real_loss + fake_loss
298
+ return disc_loss, real_loss, fake_loss
299
+
300
+ def _get_generator_loss(self, fake_outputs: torch.Tensor) -> torch.Tensor:
301
+ """Compute the adversarial loss for generator."""
302
+ if self.config.adv_loss_type == "hinge":
303
+ return torch.mean(torch.clamp(1 - fake_outputs, min=0))
304
+ elif self.config.adv_loss_type == "least_square":
305
+ return torch.mean((fake_outputs - 1) ** 2)
306
+ else:
307
+ raise ValueError(f"Unknown adversarial loss type: {self.config.adv_loss_type}")
308
+
309
+ def _get_feature_matching_loss(
310
+ self, real_intermediates: list[torch.Tensor], fake_intermediates: list[torch.Tensor]
311
+ ) -> torch.Tensor:
312
+ losses = []
313
+ for real_feat, fake_feat in zip(real_intermediates, fake_intermediates):
314
+ losses.append(torch.mean(torch.abs(real_feat.detach() - fake_feat)))
315
+ fm_loss = torch.mean(torch.stack(losses))
316
+ return fm_loss
317
+
318
    def _discriminator_step(
        self, batch: AudioBatch, optimizer_disc: torch.optim.Optimizer
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor], list[torch.Tensor]]:
        """Run one manual-optimization update of the discriminator.

        The generator forward pass is reused by the caller (`_generator_step`), so its
        outputs are returned alongside the discriminator's real-feature intermediates.

        Args:
            batch: Audio batch providing `waveform`.
            optimizer_disc: Discriminator optimizer (stepped here).

        Returns:
            ssl_real: Real SSL features
            ssl_recon: Reconstructed SSL features from generator
            mel_recon: Generated mel spectrogram
            loss_dict: Dictionary with auxiliary information
            real_intermediates: Intermediate feature maps from discriminator for real mel
        """
        assert self.use_discriminator, "Discriminator step called but discriminator is not enabled"

        ssl_real, ssl_recon, mel_recon, loss_dict = self(batch.waveform)
        assert mel_recon is not None, "Mel reconstruction must be available for discriminator step"

        # Get true mel spectrogram (always use original waveform)
        mel_real = self.mel_spec(batch.waveform)

        # Get discriminator outputs and intermediates for real mel.
        # mel_recon is detached so this step does not backprop into the generator.
        real_outputs, real_intermediates = self.discriminator(mel_real)
        fake_outputs, _ = self.discriminator(mel_recon.detach())

        # Compute discriminator loss
        disc_loss, real_loss, fake_loss = self._get_discriminator_loss(real_outputs, fake_outputs)

        # Log discriminator losses
        batch_size = batch.waveform.size(0)
        self.log("train/disc/real", real_loss, batch_size=batch_size)
        self.log("train/disc/fake", fake_loss, batch_size=batch_size)
        self.log("train/disc/loss", disc_loss, batch_size=batch_size, prog_bar=True)
        for name, value in loss_dict.items():
            self.log(f"train/{name}", value, batch_size=batch_size)

        # Optimize discriminator (manual optimization: zero_grad -> backward -> clip -> step)
        optimizer_disc.zero_grad()
        self.manual_backward(disc_loss)

        # Log gradient norm; clipping is disabled when gradient_clip_val is falsy (max_norm=inf)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.discriminator.parameters(), max_norm=self.config.gradient_clip_val or torch.inf
        )
        self.log("train/disc/grad_norm", grad_norm, batch_size=batch_size)

        optimizer_disc.step()

        return ssl_real, ssl_recon, mel_recon, loss_dict, real_intermediates
365
+
366
    def _generator_step(
        self,
        batch: AudioBatch,
        optimizer_gen: torch.optim.Optimizer,
        ssl_real: torch.Tensor | None = None,
        ssl_recon: torch.Tensor | None = None,
        mel_recon: torch.Tensor | None = None,
        loss_dict: dict | None = None,
        real_intermediates: list[torch.Tensor] | None = None,
        training_disc: bool = False,
    ) -> torch.Tensor:
        """Run one manual-optimization update of the generator.

        When the discriminator was trained in this step, the forward-pass outputs from
        `_discriminator_step` are passed in to avoid recomputing them.

        Args:
            batch: Audio batch with waveform and augmented_waveform
            optimizer_gen: Generator optimizer
            ssl_real: Real SSL features (optional)
            ssl_recon: Reconstructed SSL features (optional)
            mel_recon: Generated mel spectrogram (optional)
            loss_dict: Dictionary with auxiliary information (optional)
            real_intermediates: Intermediate feature maps from discriminator for real mel (optional)
            training_disc: Whether discriminator is being trained in this step

        Returns:
            gen_loss: Total generator loss
        """
        # Forward pass through the model if not already done in discriminator step
        if loss_dict is None:
            ssl_real, ssl_recon, mel_recon, loss_dict = self(batch.waveform)

        # Compute reconstruction loss (always use original waveform for mel target)
        recon_loss, recon_dict, mel_real = self._get_reconstruction_loss(batch.waveform, ssl_real, ssl_recon, mel_recon)
        gen_loss = recon_loss

        # Compute adversarial and feature matching losses if using discriminator
        batch_size = batch.waveform.size(0)
        if training_disc:
            assert mel_real is not None and mel_recon is not None, "Mel spectrograms must be provided for GAN training"

            # Recompute the real-mel intermediates only when not supplied by the disc step.
            if real_intermediates is None:
                _, real_intermediates = self.discriminator(mel_real)

            # mel_recon is NOT detached here: adversarial gradients flow into the generator.
            fake_outputs, fake_intermediates = self.discriminator(mel_recon)

            # Compute adversarial loss
            adv_loss = self._get_generator_loss(fake_outputs)
            gen_loss += self.config.adv_weight * adv_loss
            self.log("train/gen/adv_loss", adv_loss, batch_size=batch_size)

            # Compute feature matching loss
            feature_matching_loss = self._get_feature_matching_loss(real_intermediates, fake_intermediates)
            gen_loss += self.config.feature_matching_weight * feature_matching_loss
            self.log("train/gen/feature_matching_loss", feature_matching_loss, batch_size=batch_size)

        # Log reconstruction losses
        for name, value in loss_dict.items():
            self.log(f"train/{name}", value, batch_size=batch_size)
        for name, value in recon_dict.items():
            self.log(f"train/gen/{name}", value, batch_size=batch_size)

        self.log("train/loss", gen_loss, batch_size=batch_size, prog_bar=True)

        # Optimize generator (manual optimization: zero_grad -> backward -> clip -> step)
        optimizer_gen.zero_grad()
        self.manual_backward(gen_loss)

        # Log gradient norm; clipping is disabled when gradient_clip_val is falsy (max_norm=inf)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), max_norm=self.config.gradient_clip_val or torch.inf
        )
        self.log("train/gen/grad_norm", grad_norm, batch_size=batch_size)

        optimizer_gen.step()

        return gen_loss
440
+
441
    def training_step(self, batch: AudioBatch, batch_idx: int):
        """One manual-optimization training step.

        Optionally trains the discriminator first (after `discriminator_start_step`,
        with probability `discriminator_update_prob`), then always trains the generator.
        Both LR schedulers are stepped every call so their schedules stay aligned.
        """
        # Optimizer/scheduler ordering matches configure_optimizers: [disc, gen].
        if self.use_discriminator:
            optimizer_disc, optimizer_gen = self.optimizers()
            scheduler_disc, scheduler_gen = self.lr_schedulers()
        else:
            optimizer_gen = self.optimizers()
            scheduler_gen = self.lr_schedulers()

        # Determine if discriminator should be trained in this step
        # (stochastic update with probability discriminator_update_prob).
        training_disc = (
            self.use_discriminator
            and self.global_step >= self.discriminator_start_step
            and torch.rand(1).item() < self.discriminator_update_prob
        )
        if self.global_step == self.discriminator_start_step and self.use_discriminator:
            logger.info(f"Discriminator training starts at step {self.global_step}")

        ssl_real, ssl_recon, mel_recon, loss_dict, real_intermediates = None, None, None, None, None

        # Train discriminator if conditions are met
        if training_disc:
            ssl_real, ssl_recon, mel_recon, loss_dict, real_intermediates = self._discriminator_step(
                batch, optimizer_disc
            )
            scheduler_disc.step()
        elif self.use_discriminator:
            # Step the discriminator scheduler even when not training discriminator
            scheduler_disc.step()

        # Train generator (reuses the discriminator step's forward outputs when available)
        self._generator_step(
            batch, optimizer_gen, ssl_real, ssl_recon, mel_recon, loss_dict, real_intermediates, training_disc
        )
        scheduler_gen.step()
475
+
476
    def validation_step(self, batch: AudioBatch, batch_idx: int):
        """Compute and log validation reconstruction losses; stash a few examples for
        spectrogram/audio logging at epoch end (see `on_validation_end`)."""
        audio_real = batch.waveform
        ssl_real, ssl_recon, mel_recon, loss_dict = self(audio_real)

        # Convert to waveform using vocoder for logging
        batch_size = audio_real.size(0)

        # Compute reconstruction loss
        recon_loss, recon_dict, mel_real = self._get_reconstruction_loss(audio_real, ssl_real, ssl_recon, mel_recon)
        gen_loss = recon_loss

        # Log reconstruction losses
        for name, value in loss_dict.items():
            self.log(f"val/{name}", value, batch_size=batch_size)
        for name, value in recon_dict.items():
            self.log(f"val/gen/{name}", value, batch_size=batch_size)
        self.log("val/loss", gen_loss, batch_size=batch_size)

        # Save first few samples for visualization at end of epoch if training mel generation
        # (only the first item of each batch is kept, up to log_mel_samples total).
        if self.config.train_mel and len(self.validation_examples) < self.log_mel_samples:
            assert mel_real is not None and mel_recon is not None, (
                "Mel spectrograms must be provided for validation logging"
            )
            audio_real = audio_real[0].cpu()
            audio_gen = None
            # Vocoder may be None if its dependencies are missing (see _setup_vocoder).
            if self.vocoder is not None:
                audio_gen = self.vocode(mel_recon[0:1])[0].cpu()

            self.validation_examples.append((mel_real[0].cpu(), mel_recon[0].detach().cpu(), audio_real, audio_gen))
505
+
506
    def predict_step(self, batch: AudioBatch, batch_idx: int):
        """Resynthesize audio for a batch and return real/generated waveforms.

        Returns:
            dict with `audio_ids`, `audio_real` and `audio_gen` (B, 1, samples).
        """
        audio_real = batch.waveform
        # NOTE(review): mel_gen is None when config.train_mel is False (see forward);
        # prediction presumably assumes mel training is enabled — confirm.
        _, _, mel_gen, _ = self(audio_real)

        audio_gen = self.vocode(mel_gen)

        # Add a channel dimension so the output shape is (B, 1, samples).
        if audio_gen.dim() == 2:
            audio_gen = audio_gen.unsqueeze(1)
        return {"audio_ids": batch.audio_ids, "audio_real": audio_real, "audio_gen": audio_gen}
515
+
516
    def configure_optimizers(self):
        """Build AdamW + OneCycleLR for the generator and (optionally) the discriminator.

        Returns Lightning's (optimizers, schedulers) tuple. When the discriminator is
        enabled the ordering is [disc, gen] — `training_step` relies on this order.
        Also restores optimizer states from a Lightning `.ckpt` if configured.
        """
        # Generator optimizer
        optimizer_gen = AdamW(
            self.model.parameters(), lr=self.config.lr, betas=self.config.betas, weight_decay=self.config.weight_decay
        )

        # Generator LR scheduler (per-step OneCycle over the whole training run)
        scheduler_gen = OneCycleLR(
            optimizer_gen,
            max_lr=self.config.lr,
            div_factor=self.config.lr_div_factor,
            final_div_factor=self.config.lr_final_div_factor,
            pct_start=self.config.warmup_percent,
            anneal_strategy=self.config.anneal_mode,
            total_steps=self.trainer.estimated_stepping_batches,
        )

        if not self.use_discriminator:
            return ([optimizer_gen], [{"scheduler": scheduler_gen, "interval": "step"}])

        # If using discriminator, also configure discriminator optimizer and scheduler.
        # discriminator_lr falls back to the generator lr when unset.
        optimizer_disc = AdamW(
            self.discriminator.parameters(),
            lr=self.config.discriminator_lr or self.config.lr,
            betas=self.config.betas,
            weight_decay=self.config.weight_decay,
        )

        # Discriminator LR scheduler
        scheduler_disc = OneCycleLR(
            optimizer_disc,
            max_lr=self.config.discriminator_lr or self.config.lr,
            div_factor=self.config.lr_div_factor,
            final_div_factor=self.config.lr_final_div_factor,
            pct_start=self.config.warmup_percent,
            anneal_strategy=self.config.anneal_mode,
            total_steps=self.trainer.estimated_stepping_batches,
        )

        # Load optimizer state (only Lightning .ckpt files carry optimizer states)
        if self.config.ckpt_path:
            if self.config.ckpt_path.endswith(".ckpt"):
                checkpoint = torch.load(self.config.ckpt_path)
                optimizer_states = checkpoint["optimizer_states"]
                # States are saved in the same [disc, gen] order they are returned here.
                if len(optimizer_states) > 1 and self.use_discriminator:
                    optimizer_disc.load_state_dict(optimizer_states[0])
                    optimizer_gen.load_state_dict(optimizer_states[1])
                    logger.info("Loaded discriminator and generator's optimizer states from checkpoint")
                elif len(optimizer_states) == 1 and not self.use_discriminator:
                    # Load generator optimizer state only
                    optimizer_gen.load_state_dict(optimizer_states[0])
                    logger.info("Loaded generator's optimizer state from checkpoint")
            else:
                logger.info("No optimizer state loaded since checkpoint is not a .ckpt file")

        return (
            [optimizer_disc, optimizer_gen],
            [{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}],
        )
575
+
576
+ def _setup_vocoder(self):
577
+ try:
578
+ return load_vocoder(name=self.model.config.vocoder_name)
579
+ except ImportError:
580
+ logger.error("Vocoder could not be loaded. Please install the required dependencies.")
581
+ return None
582
+
583
+ def vocode(self, mel: torch.Tensor) -> torch.Tensor:
584
+ self.vocoder = self.vocoder.to(mel.device)
585
+ waveform = vocode(self.vocoder, mel)
586
+ return waveform.cpu().float()
587
+
588
    def on_validation_start(self):
        """Lazily (re)create the vocoder for validation-time audio logging."""
        self.vocoder = self._setup_vocoder()
590
+
591
    def on_predict_start(self):
        """Lazily (re)create the vocoder for prediction."""
        self.vocoder = self._setup_vocoder()
593
+
594
    def on_validation_end(self):
        """Flush the examples stashed by `validation_step` to the logger (spectrogram
        figures and, when available, real/generated audio), then free the vocoder."""
        if len(self.validation_examples) > 0:
            for i, (mel_real, mel_recon, audio_real, audio_gen) in enumerate(self.validation_examples):
                # Log spectrograms
                fig_real = self._get_spectrogram_plot(mel_real)
                fig_gen = self._get_spectrogram_plot(mel_recon)
                self._log_figure(f"val/{i}_mel_real", fig_real)
                self._log_figure(f"val/{i}_mel_gen", fig_gen)

                # Log audio samples (audio_gen is None when no vocoder was available)
                if audio_gen is not None:
                    audio_real = audio_real.cpu().numpy()
                    audio_gen = audio_gen.cpu().numpy()
                    self._log_audio(f"val/{i}_audio_real", audio_real)
                    self._log_audio(f"val/{i}_audio_gen", audio_gen)

            self.validation_examples = []

        # Clear vocoder to free memory
        self.vocoder = None
614
+
615
+ def _log_figure(self, tag: str, fig):
616
+ """Log a matplotlib figure to the logger."""
617
+ if isinstance(self.logger, TensorBoardLogger):
618
+ self.logger.experiment.add_figure(tag, fig, self.global_step)
619
+ elif isinstance(self.logger, WandbLogger):
620
+ import PIL.Image as Image
621
+
622
+ fig.canvas.draw()
623
+ image = Image.frombytes("RGBa", fig.canvas.get_width_height(), fig.canvas.buffer_rgba())
624
+ image = image.convert("RGB")
625
+ self.logger.log_image(tag, [image], step=self.global_step)
626
+
627
    def _log_audio(self, tag: str, audio: np.ndarray):
        """Log an audio sample to the logger (TensorBoard or Weights & Biases).

        Args:
            tag: Logging key for the audio clip.
            audio: Waveform samples; flattened to 1-D for the W&B logger.
        """
        if isinstance(self.logger, TensorBoardLogger):
            self.logger.experiment.add_audio(tag, audio, self.global_step, sample_rate=self.model.config.sample_rate)
        elif isinstance(self.logger, WandbLogger):
            self.logger.log_audio(
                tag, [audio.flatten()], sample_rate=[self.model.config.sample_rate], step=self.global_step
            )
635
+
636
+ def _get_spectrogram_plot(self, mel: torch.Tensor):
637
+ from matplotlib import pyplot as plt
638
+
639
+ mel = mel.detach().cpu().numpy()
640
+ fig, ax = plt.subplots(figsize=(10, 4))
641
+ im = ax.imshow(mel, aspect="auto", origin="lower", cmap="magma", vmin=-8.0, vmax=5.0)
642
+ fig.colorbar(im, ax=ax)
643
+ ax.set_ylabel("Mel bins")
644
+ ax.set_xlabel("Time steps")
645
+ fig.tight_layout()
646
+ return fig
647
+
648
    def _load_weights(self, ckpt_path: str | None, model_state_dict: dict[str, torch.Tensor] | None = None):
        """Load model and discriminator weights from checkpoint. Supports .ckpt (Lightning), .safetensors, .pt/.pth formats.
        If model_state_dict is provided, load weights from it instead of ckpt_path.

        Args:
            ckpt_path: Path to the checkpoint file; may be None when model_state_dict is given.
            model_state_dict: Pre-loaded model weights keyed like the raw model's state dict.
        """

        def select_keys(state_dict: dict, prefix: str) -> dict:
            """Select keys from state_dict that start with the given prefix. Remove the prefix from keys."""
            return {k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)}

        def remove_prefix(state_dict: dict, prefix: str) -> dict:
            """Remove a prefix from keys that start with that prefix."""
            return {k[len(prefix) :] if k.startswith(prefix) else k: v for k, v in state_dict.items()}

        def add_prefix(state_dict: dict, prefix: str) -> dict:
            """Add a prefix to keys that do not start with that prefix."""
            return {f"{prefix}{k}" if not k.startswith(prefix) else k: v for k, v in state_dict.items()}

        # Load state dict (only Lightning .ckpt files carry discriminator weights)
        if model_state_dict is not None:
            # Load from provided state dict
            disc_state_dict = {}
        elif ckpt_path.endswith(".ckpt"):
            # Lightning checkpoint: sub-modules live under "model."/"discriminator." prefixes
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            model_state_dict = select_keys(checkpoint["state_dict"], "model.")
            disc_state_dict = select_keys(checkpoint["state_dict"], "discriminator.")
        elif ckpt_path.endswith(".safetensors"):
            # Safetensors checkpoint
            from safetensors.torch import load_file

            checkpoint = load_file(ckpt_path, device="cpu")
            model_state_dict = checkpoint
            disc_state_dict = {}
        elif ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
            # Standard PyTorch checkpoint
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            model_state_dict = checkpoint
            disc_state_dict = {}
        else:
            raise ValueError(f"Unsupported checkpoint format: {ckpt_path}")

        # Load model weights. Strip torch.compile's "_orig_mod." wrapper prefix first,
        # then re-add it only if this pipeline's model is itself compiled.
        model_state_dict = remove_prefix(model_state_dict, "_orig_mod.")
        model_state_dict = {
            k: v
            for k, v in model_state_dict.items()
            if not any(k.startswith(module) for module in self.config.skip_loading_modules)
        }
        if self.torch_compiled:
            model_state_dict = add_prefix(model_state_dict, "_orig_mod.")

        if len(model_state_dict) > 0:
            # strict=False: partial loads are expected (e.g. skip_loading_modules)
            result = self.model.load_state_dict(model_state_dict, strict=False)
            logger.info(f"Loaded model weights from {ckpt_path or 'provided state_dict'}.")
            if result.missing_keys:
                logger.debug(f"Missing keys in model state_dict: {result.missing_keys}")
            if result.unexpected_keys:
                logger.debug(f"Unexpected keys in model state_dict: {result.unexpected_keys}")

        # Load discriminator weights if available (same prefix handling as the model)
        if self.use_discriminator:
            disc_state_dict = remove_prefix(disc_state_dict, "_orig_mod.")
            if self.torch_compiled:
                disc_state_dict = add_prefix(disc_state_dict, "_orig_mod.")

            if len(disc_state_dict) > 0:
                result = self.discriminator.load_state_dict(disc_state_dict, strict=False)
                logger.info(f"Loaded discriminator weights from {ckpt_path}.")
                if result.missing_keys:
                    logger.debug(f"Missing keys in discriminator state_dict: {result.missing_keys}")
                if result.unexpected_keys:
                    logger.debug(f"Unexpected keys in discriminator state_dict: {result.unexpected_keys}")
719
+
720
    @classmethod
    def from_hparams(cls, config_path: str) -> "KanadePipeline":
        """Instantiate KanadePipeline from config file (no weights are loaded here).
        Args:
            config_path (str): Path to model configuration file (.yaml).
        Returns:
            KanadePipeline: Instantiated KanadePipeline.
        """
        # Load config
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)

        # Remove related fields to prevent loading actual weights here
        # (weight loading is done explicitly via _load_weights / from_pretrained).
        new_config = {"model": config["model"]}
        pipeline_config = new_config["model"]["init_args"]["pipeline_config"]
        if "ckpt_path" in pipeline_config:
            del pipeline_config["ckpt_path"]
        if "skip_loading_modules" in pipeline_config:
            del pipeline_config["skip_loading_modules"]

        # Instantiate model using jsonargparse (mirrors the LightningCLI config schema)
        parser = jsonargparse.ArgumentParser(exit_on_error=False)
        parser.add_argument("--model", type=KanadePipeline)
        cfg = parser.parse_object(new_config)
        cfg = parser.instantiate_classes(cfg)
        return cfg.model
746
+
747
+ @staticmethod
748
+ def from_pretrained(config_path: str, ckpt_path: str) -> "KanadePipeline":
749
+ """Load KanadePipeline from training configuration and checkpoint files.
750
+ Args:
751
+ config_path: Path to pipeline configuration file (YAML).
752
+ ckpt_path: Path to checkpoint file (.ckpt) or model weights (.safetensors).
753
+ Returns:
754
+ KanadePipeline: Instantied KanadePipeline with loaded weights.
755
+ """
756
+ # Load pipeline from config
757
+ model = KanadePipeline.from_hparams(config_path)
758
+ # Load the weights
759
+ model._load_weights(ckpt_path)
760
+ return model
src/kanade_tokenizer/util.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Literal
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
# Configure the package-level logger: INFO level, one stream handler with a
# timestamped "[time] LEVEL name: message" format.
logger = logging.getLogger("kanade_tokenizer")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter("[%(asctime)s] %(levelname)s %(name)s: %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
14
+
15
+
16
def get_logger() -> logging.Logger:
    """Return the shared "kanade_tokenizer" package logger."""
    return logger
18
+
19
+
20
+ def freeze_modules(modules: list[nn.Module] | None):
21
+ for module in modules:
22
+ if module is not None:
23
+ for param in module.parameters():
24
+ param.requires_grad = False
25
+
26
+
27
+ def _load_audio_internal(
28
+ path: str, frame_offset: int | None = None, num_frames: int | None = None
29
+ ) -> tuple[torch.Tensor, int]:
30
+ # TorchAudio >= 2.9.0 removed decoding and encoding capabilities to TorchCodec.
31
+ # See: https://github.com/pytorch/audio/issues/3902
32
+ # waveform, sample_rate = torchaudio.load(path, frame_offset=frame_offset or 0, num_frames=num_frames or -1)
33
+
34
+ import soundfile as sf
35
+
36
+ with sf.SoundFile(path) as f:
37
+ if frame_offset is not None:
38
+ f.seek(frame_offset)
39
+ frames = f.read(frames=num_frames or -1, dtype="float32", always_2d=True)
40
+ waveform = torch.from_numpy(frames.T)
41
+ sample_rate = f.samplerate
42
+ return waveform, sample_rate
43
+
44
+
45
+ def load_audio(audio_path: str, sample_rate: int = 24000) -> torch.Tensor:
46
+ import torchaudio
47
+
48
+ """Load and preprocess audio file."""
49
+ waveform, sr = _load_audio_internal(audio_path)
50
+
51
+ # Convert to mono if stereo
52
+ if waveform.shape[0] > 1:
53
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
54
+
55
+ # Resample if necessary
56
+ if sr != sample_rate:
57
+ resampler = torchaudio.transforms.Resample(sr, sample_rate)
58
+ waveform = resampler(waveform)
59
+
60
+ # Normalize waveform
61
+ max_val = torch.max(torch.abs(waveform)) + 1e-8
62
+ waveform = waveform / max_val # Normalize to [-1, 1]
63
+
64
+ return waveform.squeeze(0) # Remove channel dimension
65
+
66
+
67
+ def load_vocoder(name: Literal["vocos", "hift"] = "vocos") -> torch.nn.Module:
68
+ if name == "vocos":
69
+ from vocos import Vocos
70
+
71
+ model = Vocos.from_pretrained("charactr/vocos-mel-24khz")
72
+ model = model.eval()
73
+ return model
74
+ elif name == "hift":
75
+ from huggingface_hub import hf_hub_download
76
+ from .module.hift import HiFTGenerator
77
+
78
+ # Download hte HiFT model from FunAudioLLM/CosyVoice2-0.5B
79
+ model_path = hf_hub_download(repo_id="FunAudioLLM/CosyVoice2-0.5B", filename="hift.pt")
80
+ model = HiFTGenerator()
81
+ model.load_weights(model_path)
82
+ model = model.eval()
83
+ return model
84
+ else:
85
+ raise ValueError(f"Unsupported vocoder name: {name}")
86
+
87
+
88
+ def vocode(vocoder, mel_spectrogram: torch.Tensor) -> torch.Tensor:
89
+ """Convert mel spectrogram to waveform using Vocos vocoder.
90
+ Args:
91
+ vocoder: Pretrained vocoder model.
92
+ mel_spectrogram (torch.Tensor): Input mel spectrogram tensor (..., n_mels, frame).
93
+ Returns:
94
+ torch.Tensor: Generated audio waveform tensor (..., samples).
95
+ """
96
+ mel_spectrogram = mel_spectrogram.to(torch.float32) # Ensure mel spectrogram is in float32
97
+
98
+ vocoder_class_name = vocoder.__class__.__name__
99
+ if "Vocos" in vocoder_class_name:
100
+ generated_waveform = vocoder.decode(mel_spectrogram)
101
+ elif "HiFT" in vocoder_class_name:
102
+ generated_waveform = vocoder.inference(mel_spectrogram)
103
+ else:
104
+ raise ValueError(f"Unsupported vocoder class: {vocoder_class_name}")
105
+
106
+ return generated_waveform