Spaces:
Running
Running
Dalzymodderever committed on
Commit ·
2cba492
1
Parent(s): a680a6c
Initial Commit
Browse files- app.py +119 -0
- requirements.txt +10 -0
- src/kanade_tokenizer/__init__.py +11 -0
- src/kanade_tokenizer/data/datamodule.py +146 -0
- src/kanade_tokenizer/data/dataset.py +201 -0
- src/kanade_tokenizer/model.py +500 -0
- src/kanade_tokenizer/module/adaln_zero.py +68 -0
- src/kanade_tokenizer/module/audio_feature.py +105 -0
- src/kanade_tokenizer/module/convnext.py +125 -0
- src/kanade_tokenizer/module/discriminator.py +78 -0
- src/kanade_tokenizer/module/fsq.py +140 -0
- src/kanade_tokenizer/module/global_encoder.py +75 -0
- src/kanade_tokenizer/module/hift.py +685 -0
- src/kanade_tokenizer/module/postnet.py +71 -0
- src/kanade_tokenizer/module/ssl_extractor.py +106 -0
- src/kanade_tokenizer/module/transformer.py +549 -0
- src/kanade_tokenizer/pipeline.py +760 -0
- src/kanade_tokenizer/util.py +106 -0
app.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys
import os
import time
import torch
import gradio as gr

# --- 1. PATH SETUP ---
# Make the bundled `src/` package importable without installation.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.join(current_dir, "src")
if src_path not in sys.path:
    sys.path.append(src_path)

# --- 2. Imports ---
try:
    from kanade_tokenizer.model import KanadeModel
    from kanade_tokenizer.util import load_vocoder, vocode, load_audio
except ImportError as e:
    # Surface the failure clearly in the Space logs before aborting startup.
    print(f"❌ IMPORT ERROR: {e}")
    raise e

# --- Configuration ---
KANADE_REPO = "frothywater/kanade-25hz-clean"  # HF Hub repo id of the pretrained model
KANADE_VOCODER = "hift"  # vocoder variant used to turn mels back into audio
DEVICE = "cpu"  # Spaces CPU tier; no GPU assumed
SAMPLE_RATE = 24000  # model's native sample rate (Hz)
MAX_AUDIO_SECONDS = 30  # Limit audio to 30 seconds

print(f"🚀 Initializing on {DEVICE}...")

# --- 3. Load Models ---
# Both models are loaded once at import time and reused for every request.
print(f"📥 Loading Kanade...")
kanade_model = KanadeModel.from_pretrained(repo_id=KANADE_REPO).to(DEVICE).eval()

print(f"🔊 Loading HiFT Vocoder...")
kanade_vocoder = load_vocoder(name=KANADE_VOCODER).to(DEVICE).eval()

print("✅ Models Loaded.")
# --- Core Inference ---
def run_inference(source_wav, ref_wav):
    """Convert source speech to the reference voice and vocode it to a waveform."""
    with torch.inference_mode():
        mel = kanade_model.voice_conversion(source_wav, ref_wav)
        return vocode(kanade_vocoder, mel.unsqueeze(0))
# --- Main Handler ---
def voice_conversion(source_path, reference_path):
    """Gradio handler: re-speak the source clip in the reference speaker's voice.

    Returns ((sample_rate, ndarray), status_message); the audio slot is None
    when an input is missing or inference fails.
    """
    if not source_path or not reference_path:
        return None, "⚠️ Please provide both source and reference audio."

    try:
        # Load both clips at the model's native sample rate.
        source_wav = load_audio(source_path, sample_rate=SAMPLE_RATE).to(DEVICE)
        ref_wav = load_audio(reference_path, sample_rate=SAMPLE_RATE).to(DEVICE)

        # Enforce the 30-second cap; slicing past the end leaves short clips intact.
        max_samples = MAX_AUDIO_SECONDS * SAMPLE_RATE
        source_wav = source_wav[..., :max_samples]
        ref_wav = ref_wav[..., :max_samples]

        # Time the conversion for the RTF report below.
        start = time.time()
        final_wav = run_inference(source_wav, ref_wav)
        proc_time = time.time() - start

        output_np = final_wav.squeeze().cpu().float().numpy()
        output_duration = len(output_np) / SAMPLE_RATE

        # RTF = processing time / audio duration (lower is better, <1 means faster than real-time)
        rtf = proc_time / output_duration if output_duration > 0 else 0

        return (SAMPLE_RATE, output_np), f"✅ {proc_time:.2f}s to convert {output_duration:.1f}s of audio | RTF: {rtf:.2f}x"

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
# --- Gradio Interface ---
# UI layout: left column takes the two input clips, right column shows the result.
with gr.Blocks(title="Kanade Voice Cloning") as demo:
    gr.Markdown("""
    # 🗣️ Kanade Voice Cloning
    **Model:** `frothywater/kanade-25hz-clean`

    Convert any audio into a target voice. Upload a source audio (what to say) and a reference audio (whose voice to use).

    ⏱️ **Limit:** Audio is trimmed to 30 seconds max.
    """)

    with gr.Row():
        with gr.Column():
            source_audio = gr.Audio(label="Source Audio (Content - what to say)", type="filepath")
            reference_audio = gr.Audio(label="Reference Audio (Target Voice - whose voice)", type="filepath")
            convert_btn = gr.Button("🎤 Convert Voice", variant="primary")

        with gr.Column():
            output_audio = gr.Audio(label="Result")
            status_text = gr.Textbox(label="Status", interactive=False)

    # Wire the button to the conversion handler defined above.
    convert_btn.click(
        voice_conversion,
        inputs=[source_audio, reference_audio],
        outputs=[output_audio, status_text]
    )

    gr.Markdown("""
    ---
    **Tips:**
    - For best results, use clean reference audio (3-10 seconds of clear speech)
    - Source and reference should ideally be similar in speaking pace
    """)

if __name__ == "__main__":
    demo.launch()
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
huggingface_hub
|
| 2 |
+
jsonargparse[signatures]
|
| 3 |
+
numpy
|
| 4 |
+
safetensors
|
| 5 |
+
soundfile
|
| 6 |
+
torch
|
| 7 |
+
torchaudio
|
| 8 |
+
tqdm
|
| 9 |
+
vocos
|
| 10 |
+
gradio
|
src/kanade_tokenizer/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .model import KanadeFeatures, KanadeModel, KanadeModelConfig
|
| 2 |
+
from .util import load_audio, load_vocoder, vocode
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"KanadeModel",
|
| 6 |
+
"KanadeModelConfig",
|
| 7 |
+
"KanadeFeatures",
|
| 8 |
+
"load_audio",
|
| 9 |
+
"load_vocoder",
|
| 10 |
+
"vocode",
|
| 11 |
+
]
|
src/kanade_tokenizer/data/datamodule.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import lightning as L
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import DataLoader, Dataset
|
| 7 |
+
|
| 8 |
+
from ..util import get_logger
|
| 9 |
+
from .dataset import AudioItem, ChunkedAudioDataset, pad_audio
|
| 10 |
+
|
| 11 |
+
logger = get_logger()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class AudioBatch:
    """A collated batch of audio examples, padded to a common length."""

    waveform: torch.Tensor  # [batch, channels, samples]
    audio_ids: list[str]  # one manifest id per example
    paths: list[Path]  # source file path per example
    sample_rates: list[int]  # sample rate per example (after any resampling)
    frame_offsets: list[int] | None  # For chunked audio
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
class AudioDataConfig:
    """Configuration for one dataset split and its DataLoader."""

    csv_path: str  # CSV manifest with columns: audio_id, path, length, sample_rate
    audio_root: str  # root directory prepended to the relative paths in the CSV

    # Audio processing
    sample_rate: int | None = 16000  # resample target; None keeps the original rate
    mono: bool = True  # downmix multi-channel audio to one channel
    normalize: bool = True  # peak-normalize each clip

    # Chunking options
    chunk_size: int | None = None  # chunk length in frames; None disables chunking
    chunk_hop_size: int | None = None  # hop between chunks; None means chunk_size (no overlap)

    # DataLoader options
    batch_size: int = 32
    num_workers: int = 4
    pin_memory: bool = False
    persistent_workers: bool = False
    shuffle: bool = False
    drop_last: bool = False
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def audio_collate_fn(batch: list[AudioItem]) -> AudioBatch:
    """Collate AudioItems into an AudioBatch, zero-padding clips to the longest one."""
    waves = [item.waveform for item in batch]

    # Only pay the padding cost when lengths actually differ.
    target_len = max(w.shape[1] for w in waves)
    if any(w.shape[1] != target_len for w in waves):
        waves = [pad_audio(w, target_len) for w in waves]

    ids, paths, rates, offsets = [], [], [], []
    for item in batch:
        ids.append(item.audio_id)
        paths.append(item.path)
        rates.append(item.sample_rate)
        offsets.append(item.frame_offset)

    return AudioBatch(
        waveform=torch.stack(waves),
        audio_ids=ids,
        paths=paths,
        sample_rates=rates,
        frame_offsets=offsets,
    )
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class AudioDataModule(L.LightningDataModule):
    """Lightning data module building ChunkedAudioDatasets per stage.

    The val config falls back to the train config, and the test config falls
    back to the val config, when not provided.
    """

    def __init__(
        self,
        train_config: AudioDataConfig,
        val_config: AudioDataConfig | None = None,
        test_config: AudioDataConfig | None = None,
    ):
        super().__init__()
        self.train_config = train_config
        self.val_config = val_config or train_config
        self.test_config = test_config or self.val_config

        # Set to be initialized in setup()
        self.train_dataset: Dataset | None = None
        self.val_dataset: Dataset | None = None
        self.test_dataset: Dataset | None = None

    def _create_dataset(self, config: AudioDataConfig) -> Dataset:
        """Build a ChunkedAudioDataset from a split config."""
        return ChunkedAudioDataset(
            csv_path=config.csv_path,
            audio_root=config.audio_root,
            chunk_size=config.chunk_size,
            hop_size=config.chunk_hop_size,
            mono=config.mono,
            normalize=config.normalize,
            target_sample_rate=config.sample_rate,
        )

    def _create_dataloader(
        self, dataset: Dataset, config: AudioDataConfig, *, shuffle: bool, drop_last: bool
    ) -> DataLoader:
        """Shared DataLoader factory used by all *_dataloader hooks.

        persistent_workers is only honored when workers exist, as required by
        torch's DataLoader.
        """
        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory,
            persistent_workers=config.persistent_workers if config.num_workers > 0 else False,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=audio_collate_fn,
        )

    def setup(self, stage: str | None = None):
        """Instantiate the datasets needed for the given Lightning stage."""
        if stage == "fit" or stage is None:
            self.train_dataset = self._create_dataset(self.train_config)
            self.val_dataset = self._create_dataset(self.val_config)
        elif stage == "validate":
            self.val_dataset = self._create_dataset(self.val_config)
        elif stage == "test" or stage == "predict":
            self.test_dataset = self._create_dataset(self.test_config)

    def train_dataloader(self) -> DataLoader:
        return self._create_dataloader(
            self.train_dataset,
            self.train_config,
            shuffle=self.train_config.shuffle,
            drop_last=self.train_config.drop_last,
        )

    def val_dataloader(self) -> DataLoader:
        return self._create_dataloader(self.val_dataset, self.val_config, shuffle=False, drop_last=False)

    def test_dataloader(self) -> DataLoader:
        return self._create_dataloader(self.test_dataset, self.test_config, shuffle=False, drop_last=False)

    def predict_dataloader(self) -> DataLoader:
        # Prediction reuses the test dataset and config.
        return self._create_dataloader(self.test_dataset, self.test_config, shuffle=False, drop_last=False)
src/kanade_tokenizer/data/dataset.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torchaudio
|
| 7 |
+
from torch.utils.data import Dataset
|
| 8 |
+
|
| 9 |
+
from ..util import _load_audio_internal, get_logger
|
| 10 |
+
|
| 11 |
+
logger = get_logger()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class AudioItem:
    """A single audio example (optionally one chunk of a longer file)."""

    waveform: torch.Tensor  # (channels, samples)
    audio_id: str  # manifest id of the source file
    path: Path  # resolved path on disk
    sample_rate: int  # sample rate of `waveform` after any resampling
    frame_offset: int | None = None  # For chunked audio
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def convert_to_mono(waveform: torch.Tensor) -> torch.Tensor:
    """Average multi-channel audio down to a single (1, samples) channel.

    Mono input is returned unchanged.
    """
    if waveform.shape[0] <= 1:
        return waveform
    return waveform.mean(dim=0, keepdim=True)
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def resample_audio(waveform: torch.Tensor, orig_freq: int, new_freq: int) -> torch.Tensor:
    """Resample `waveform` from `orig_freq` to `new_freq` Hz (no-op when equal)."""
    if orig_freq == new_freq:
        return waveform
    resample = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=new_freq)
    return resample(waveform)
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def normalize_audio(waveform: torch.Tensor) -> torch.Tensor:
    """Peak-normalize so the maximum absolute sample is ~1.

    The small epsilon keeps the division safe on all-zero (silent) input.
    """
    peak = waveform.abs().max() + 1e-8
    return waveform / peak
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def preprocess_audio(
    waveform: torch.Tensor, sample_rate: int, mono: bool, normalize: bool, target_sample_rate: int | None = None
) -> tuple[torch.Tensor, int]:
    """Apply the standard loading pipeline: downmix, resample, peak-normalize.

    Returns the processed waveform together with its (possibly updated)
    sample rate.
    """
    if mono:
        waveform = convert_to_mono(waveform)

    needs_resample = target_sample_rate is not None and sample_rate != target_sample_rate
    if needs_resample:
        waveform = resample_audio(waveform, sample_rate, target_sample_rate)
        sample_rate = target_sample_rate

    if normalize:
        waveform = normalize_audio(waveform)

    return waveform, sample_rate
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def pad_audio(waveform: torch.Tensor, target_length: int) -> torch.Tensor:
    """Right-pad `waveform` with zeros along time to `target_length` samples.

    Input that is already at least `target_length` long is returned unchanged
    (this function never truncates).
    """
    deficit = target_length - waveform.shape[1]
    if deficit <= 0:
        return waveform
    # F.pad with a (left, right) spec pads only the last dimension with zeros.
    return torch.nn.functional.pad(waveform, (0, deficit))
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@dataclass
class ChunkInfo:
    """Location of one chunk within a source audio file."""

    audio_id: str  # manifest id of the file this chunk comes from
    frame_offset: int  # In target sample rate
    num_frames: int  # In target sample rate
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class ChunkedAudioDataset(Dataset):
    """
    Dataset that loads audio from CSV with optional chunking.

    Args:
        csv_path: Path to the CSV file with columns: audio_id, path, length, sample_rate
        audio_root: Root directory for audio files (prepended to paths in CSV)
        chunk_size: Size of each chunk in frames (None = no chunking)
        hop_size: Hop size between chunks in frames (None = use chunk_size)
        mono: Convert to mono if True
        normalize: Normalize audio if True
        target_sample_rate: Resample to this sample rate if provided
    """

    def __init__(
        self,
        csv_path: str,
        audio_root: str,
        chunk_size: int | None = None,
        hop_size: int | None = None,
        mono: bool = True,
        normalize: bool = True,
        target_sample_rate: int | None = None,
    ):
        self.csv_path = csv_path
        self.audio_root = audio_root
        self.chunk_size = chunk_size
        # Non-overlapping chunks by default (hop equals chunk size).
        self.hop_size = hop_size if hop_size is not None else chunk_size
        self.mono = mono
        self.normalize = normalize
        self.target_sample_rate = target_sample_rate

        # Load CSV and compute chunks
        self.file_entries = self._load_csv()
        self.chunks = self._compute_chunks()

        logger.info(f"Loaded dataset from {csv_path}: {len(self.file_entries)} files, {len(self.chunks)} chunks")

    def _load_csv(self) -> dict[str, dict]:
        """Load audio metadata from CSV."""
        entries = {}
        with open(self.csv_path, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                entries[row["audio_id"]] = {
                    "path": row["path"],
                    "length": int(row["length"]),
                    "sample_rate": int(row["sample_rate"]),
                }
        return entries

    def _compute_chunks(self) -> list[ChunkInfo]:
        """Compute all chunks from the file entries."""
        chunks = []
        for audio_id, entry in self.file_entries.items():
            length = entry["length"]
            sample_rate = entry["sample_rate"]

            # Adjust length if resampling to target sample rate
            # (chunk bookkeeping is done in target-rate frame units).
            if self.target_sample_rate is not None and sample_rate != self.target_sample_rate:
                length = int(length * self.target_sample_rate / sample_rate)
                sample_rate = self.target_sample_rate

            if self.chunk_size is None or length <= self.chunk_size:
                # No chunking, or file is shorter than chunk size: use entire file
                chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=0, num_frames=length))
            else:
                # Chunking: compute all chunks with last chunk aligned to end
                frame_offset = 0
                while frame_offset + self.chunk_size <= length:
                    chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=frame_offset, num_frames=self.chunk_size))
                    frame_offset += self.hop_size

                # Add the last chunk aligned to the end
                # (only when the tail is not already covered by the previous chunk).
                last_start = length - self.chunk_size
                if last_start > frame_offset - self.hop_size:
                    chunks.append(ChunkInfo(audio_id=audio_id, frame_offset=last_start, num_frames=self.chunk_size))

        return chunks

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> AudioItem:
        """Load and return a single audio chunk."""
        chunk = self.chunks[idx]
        entry = self.file_entries[chunk.audio_id]
        orig_sample_rate = entry["sample_rate"]
        full_path = Path(self.audio_root) / entry["path"]

        # Calculate start frame and num frames in original sample rate
        # (chunk offsets were computed in target-rate units in _compute_chunks).
        if self.target_sample_rate is not None and orig_sample_rate != self.target_sample_rate:
            orig_frame_offset = int(chunk.frame_offset * orig_sample_rate / self.target_sample_rate)
            orig_num_frames = int(chunk.num_frames * orig_sample_rate / self.target_sample_rate)
        else:
            orig_frame_offset = chunk.frame_offset
            orig_num_frames = chunk.num_frames

        # Read only the required slice from disk, then downmix/resample/normalize.
        waveform, sample_rate = _load_audio_internal(
            full_path, frame_offset=orig_frame_offset, num_frames=orig_num_frames
        )

        waveform, sample_rate = preprocess_audio(
            waveform=waveform,
            sample_rate=sample_rate,
            mono=self.mono,
            normalize=self.normalize,
            target_sample_rate=self.target_sample_rate,
        )

        # Pad if necessary (in case file is shorter than expected)
        if self.chunk_size is not None and waveform.shape[1] < self.chunk_size:
            waveform = pad_audio(waveform, self.chunk_size)

        return AudioItem(
            waveform=waveform,
            audio_id=chunk.audio_id,
            path=full_path,
            sample_rate=sample_rate,
            frame_offset=chunk.frame_offset,
        )
|
src/kanade_tokenizer/model.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Literal
|
| 4 |
+
|
| 5 |
+
import jsonargparse
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from .module.fsq import FiniteScalarQuantizer
|
| 11 |
+
from .module.global_encoder import GlobalEncoder
|
| 12 |
+
from .module.postnet import PostNet
|
| 13 |
+
from .module.ssl_extractor import SSLFeatureExtractor
|
| 14 |
+
from .module.transformer import Transformer
|
| 15 |
+
from .util import freeze_modules, get_logger
|
| 16 |
+
|
| 17 |
+
logger = get_logger()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class KanadeModelConfig:
    """Hyperparameters for KanadeModel: SSL layer selection, sampling factors, mel and vocoder settings."""

    # SSL Feature settings
    local_ssl_layers: tuple[int, ...] = (6, 9)  # Indices of SSL layers for local branch
    global_ssl_layers: tuple[int, ...] = (1, 2)  # Indices of SSL layers for global branch
    normalize_ssl_features: bool = True  # Whether to normalize local SSL features before encoding

    # Down/up-sampling settings
    downsample_factor: int = 2  # Temporal downsampling factor for local features
    mel_upsample_factor: int = 4  # Conv1DTranspose upsampling factor for mel features before interpolation
    use_conv_downsample: bool = True  # Whether to use Conv1D for downsampling instead of average pooling
    local_interpolation_mode: str = "linear"  # Interpolation mode for local upsampling ("linear", "nearest")
    mel_interpolation_mode: str = "linear"  # Interpolation mode for mel upsampling ("linear", "nearest")

    # Mel spectrogram settings
    sample_rate: int = 24000
    n_fft: int = 1024
    hop_length: int = 256
    n_mels: int = 100
    padding: str = "center"
    mel_fmin: int = 0  # Minimum frequency for mel spectrograms
    mel_fmax: int | None = None  # Maximum frequency for mel spectrograms
    bigvgan_style_mel: bool = False  # Whether to use BigVGAN-style mel spectrograms

    # Vocoder settings
    vocoder_name: Literal["vocos", "hift"] = "vocos"  # Vocoder to use for waveform synthesis
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass
class KanadeFeatures:
    """Extracted Kanade representations; any field may be absent (None)."""

    content_embedding: torch.Tensor | None = None  # (seq_len, dim)
    content_token_indices: torch.Tensor | None = None  # (seq_len,)
    global_embedding: torch.Tensor | None = None  # (dim,)
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class KanadeModel(nn.Module):
|
| 56 |
+
"""Model architecture and forward pass logic for Kanade tokenizer."""
|
| 57 |
+
|
| 58 |
+
def __init__(
|
| 59 |
+
self,
|
| 60 |
+
config: KanadeModelConfig,
|
| 61 |
+
ssl_feature_extractor: SSLFeatureExtractor,
|
| 62 |
+
local_encoder: Transformer,
|
| 63 |
+
local_quantizer: FiniteScalarQuantizer,
|
| 64 |
+
global_encoder: GlobalEncoder,
|
| 65 |
+
mel_prenet: Transformer,
|
| 66 |
+
mel_decoder: Transformer,
|
| 67 |
+
mel_postnet: PostNet,
|
| 68 |
+
feature_decoder: Transformer | None = None,
|
| 69 |
+
):
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.config = config
|
| 72 |
+
self._init_ssl_extractor(config, ssl_feature_extractor)
|
| 73 |
+
self._init_local_branch(config, local_encoder, local_quantizer, feature_decoder)
|
| 74 |
+
self._init_global_branch(global_encoder)
|
| 75 |
+
self._init_mel_decoder(config, mel_prenet, mel_decoder, mel_postnet)
|
| 76 |
+
|
| 77 |
+
def _init_ssl_extractor(self, config: KanadeModelConfig, ssl_feature_extractor: SSLFeatureExtractor):
|
| 78 |
+
"""Initialize and configure SSL feature extractor."""
|
| 79 |
+
self.ssl_feature_extractor = ssl_feature_extractor
|
| 80 |
+
freeze_modules([self.ssl_feature_extractor])
|
| 81 |
+
logger.debug(
|
| 82 |
+
f"SSL feature extractor initialized and frozen, feature dim: {self.ssl_feature_extractor.feature_dim}"
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# Configure local SSL layers
|
| 86 |
+
self.local_ssl_layers = list(config.local_ssl_layers)
|
| 87 |
+
if len(self.local_ssl_layers) > 1:
|
| 88 |
+
logger.debug(
|
| 89 |
+
f"Using average of {len(self.local_ssl_layers)} SSL layers for local branch: {self.local_ssl_layers}"
|
| 90 |
+
)
|
| 91 |
+
else:
|
| 92 |
+
logger.debug(f"Using single SSL layer {self.local_ssl_layers[0]} for local branch")
|
| 93 |
+
|
| 94 |
+
if config.normalize_ssl_features:
|
| 95 |
+
logger.debug("Normalizing local SSL features before encoding")
|
| 96 |
+
|
| 97 |
+
# Configure global SSL layers
|
| 98 |
+
self.global_ssl_layers = list(config.global_ssl_layers)
|
| 99 |
+
if len(self.global_ssl_layers) > 1:
|
| 100 |
+
logger.debug(
|
| 101 |
+
f"Using average of {len(self.global_ssl_layers)} SSL layers for global branch: {self.global_ssl_layers}"
|
| 102 |
+
)
|
| 103 |
+
else:
|
| 104 |
+
logger.debug(f"Using single SSL layer {self.global_ssl_layers[0]} for global branch")
|
| 105 |
+
|
| 106 |
+
def _init_local_branch(
    self,
    config: KanadeModelConfig,
    local_encoder: Transformer,
    local_quantizer: FiniteScalarQuantizer,
    feature_decoder: Transformer | None,
):
    """Set up the local (content) branch: encoder, optional temporal resampling, quantizer, decoder."""
    self.local_encoder = local_encoder
    self.local_quantizer = local_quantizer
    self.feature_decoder = feature_decoder

    self.downsample_factor = config.downsample_factor
    # Default: no learned resampling modules.
    self.conv_downsample = None
    self.conv_upsample = None
    if self.downsample_factor <= 1:
        return

    logger.debug(f"Using temporal downsampling with factor {self.downsample_factor}")
    if not config.use_conv_downsample:
        logger.debug("Using average pooling and linear interpolation for downsampling/upsampling")
        return

    # Strided Conv1d compresses time; the transposed conv restores it.
    # The transposed conv is only exercised when training feature reconstruction.
    feature_dim = local_encoder.output_dim
    self.conv_downsample = nn.Conv1d(
        feature_dim, feature_dim, kernel_size=config.downsample_factor, stride=config.downsample_factor
    )
    self.conv_upsample = nn.ConvTranspose1d(
        feature_dim, feature_dim, kernel_size=config.downsample_factor, stride=config.downsample_factor
    )
    logger.debug(f"Using Conv1d downsampling/upsampling with kernel size {config.downsample_factor}")
| 140 |
+
def _init_global_branch(self, global_encoder: GlobalEncoder):
    """Set up the global (utterance-level) branch, which consists solely of the global encoder."""
    self.global_encoder = global_encoder
| 144 |
+
def _init_mel_decoder(
    self, config: KanadeModelConfig, mel_prenet: Transformer, mel_decoder: Transformer, mel_postnet: PostNet
):
    """Set up the mel-decoding stack: prenet -> optional learned upsample -> decoder -> postnet."""
    self.mel_prenet = mel_prenet
    self.mel_decoder = mel_decoder
    self.mel_postnet = mel_postnet

    # Learned temporal upsampling toward the mel frame rate; exact length matching
    # is finished later by interpolation inside forward_mel.
    self.mel_conv_upsample = None
    if config.mel_upsample_factor > 1:
        prenet_dim = mel_prenet.output_dim
        self.mel_conv_upsample = nn.ConvTranspose1d(
            prenet_dim, prenet_dim, kernel_size=config.mel_upsample_factor, stride=config.mel_upsample_factor
        )
        logger.debug(f"Using Conv1DTranspose for mel upsampling with factor {config.mel_upsample_factor}")
| 162 |
+
def _calculate_waveform_padding(self, audio_length: int, ensure_recon_length: bool = False) -> int:
|
| 163 |
+
"""Calculate required padding for input waveform to ensure consistent SSL feature lengths."""
|
| 164 |
+
extractor = self.ssl_feature_extractor
|
| 165 |
+
sample_rate = self.config.sample_rate
|
| 166 |
+
# SSL may resample the input to its own sample rate, so calculate the number of samples after resampling
|
| 167 |
+
num_samples_after_resampling = audio_length / sample_rate * extractor.ssl_sample_rate
|
| 168 |
+
# We expect the SSL feature extractor to be consistent with its hop size
|
| 169 |
+
expected_ssl_output_length = math.ceil(num_samples_after_resampling / extractor.hop_size)
|
| 170 |
+
# If ensure_recon_length is True, we want to make sure the output length is exactly divisible by downsample factor
|
| 171 |
+
if ensure_recon_length and (remainder := expected_ssl_output_length % self.downsample_factor) != 0:
|
| 172 |
+
expected_ssl_output_length += self.downsample_factor - remainder
|
| 173 |
+
# But it may require more input samples to produce that output length, so calculate the required input length
|
| 174 |
+
num_samples_required_after_resampling = extractor.get_minimum_input_length(expected_ssl_output_length)
|
| 175 |
+
# That number of samples is at the SSL sample rate, so convert back to our original sample rate
|
| 176 |
+
num_samples_required = num_samples_required_after_resampling / extractor.ssl_sample_rate * sample_rate
|
| 177 |
+
# Calculate padding needed on each side
|
| 178 |
+
padding = math.ceil((num_samples_required - audio_length) / 2)
|
| 179 |
+
return padding
|
| 180 |
+
|
| 181 |
+
def _calculate_original_audio_length(self, token_length: int) -> int:
|
| 182 |
+
"""Calculate the original audio length based on token length."""
|
| 183 |
+
extractor = self.ssl_feature_extractor
|
| 184 |
+
sample_rate = self.config.sample_rate
|
| 185 |
+
# Calculate the feature length before downsampling
|
| 186 |
+
feature_length = token_length * self.downsample_factor
|
| 187 |
+
num_samples_required_after_resampling = extractor.get_minimum_input_length(feature_length)
|
| 188 |
+
num_samples_required = num_samples_required_after_resampling / extractor.ssl_sample_rate * sample_rate
|
| 189 |
+
return math.ceil(num_samples_required)
|
| 190 |
+
|
| 191 |
+
def _calculate_target_mel_length(self, audio_length: int) -> int:
|
| 192 |
+
"""Calculate the target mel spectrogram length based on audio length."""
|
| 193 |
+
if self.config.padding == "center":
|
| 194 |
+
return audio_length // self.config.hop_length + 1
|
| 195 |
+
elif self.config.padding == "same":
|
| 196 |
+
return audio_length // self.config.hop_length
|
| 197 |
+
else:
|
| 198 |
+
return (audio_length - self.config.n_fft) // self.config.hop_length + 1
|
| 199 |
+
|
| 200 |
+
def _process_ssl_features(self, features: list[torch.Tensor], layers: list[int]) -> torch.Tensor:
|
| 201 |
+
if len(layers) > 1:
|
| 202 |
+
# Get features from multiple layers and average them
|
| 203 |
+
selected_features = [features[i - 1] for i in layers]
|
| 204 |
+
mixed_features = torch.stack(selected_features, dim=0).mean(dim=0)
|
| 205 |
+
else:
|
| 206 |
+
# Just take the single specified layer
|
| 207 |
+
mixed_features = features[layers[0] - 1]
|
| 208 |
+
return mixed_features
|
| 209 |
+
|
| 210 |
+
def _normalize_ssl_features(self, features: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
|
| 211 |
+
if not self.config.normalize_ssl_features:
|
| 212 |
+
return features
|
| 213 |
+
|
| 214 |
+
# Compute mean and std across time steps for each sample and feature dimension
|
| 215 |
+
mean = torch.mean(features, dim=1, keepdim=True) # (B, 1, C)
|
| 216 |
+
std = torch.std(features, dim=1, keepdim=True) # (B, 1, C)
|
| 217 |
+
return (features - mean) / (std + eps)
|
| 218 |
+
|
| 219 |
+
def forward_ssl_features(
|
| 220 |
+
self, waveform: torch.Tensor, padding: int | None = None
|
| 221 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 222 |
+
"""Forward pass to extract SSL features. (B, T, C)
|
| 223 |
+
Args:
|
| 224 |
+
waveform: Input waveform tensor of shape (B, channels, samples)
|
| 225 |
+
padding: Optional padding to apply on both sides of the waveform. This is useful to ensure
|
| 226 |
+
that the SSL feature extractor produces consistent output lengths.
|
| 227 |
+
Returns:
|
| 228 |
+
local_ssl_features: Local SSL features for local branch. (B, T, C)
|
| 229 |
+
global_ssl_features: Global SSL features for global branch. (B, T, C)
|
| 230 |
+
"""
|
| 231 |
+
# Prepare input waveform
|
| 232 |
+
if waveform.dim() == 3:
|
| 233 |
+
waveform = waveform.squeeze(1)
|
| 234 |
+
|
| 235 |
+
# 1. Extract SSL features
|
| 236 |
+
if padding > 0:
|
| 237 |
+
waveform = F.pad(waveform, (padding, padding), mode="constant")
|
| 238 |
+
|
| 239 |
+
with torch.no_grad():
|
| 240 |
+
ssl_features = self.ssl_feature_extractor(waveform)
|
| 241 |
+
|
| 242 |
+
local_ssl_features = self._process_ssl_features(ssl_features, self.local_ssl_layers)
|
| 243 |
+
local_ssl_features = self._normalize_ssl_features(local_ssl_features)
|
| 244 |
+
|
| 245 |
+
global_ssl_features = self._process_ssl_features(ssl_features, self.global_ssl_layers)
|
| 246 |
+
|
| 247 |
+
return local_ssl_features, global_ssl_features
|
| 248 |
+
|
| 249 |
+
def forward_content(
    self, local_ssl_features: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor]:
    """Forward pass to extract content embeddings from the local branch.

    Encodes local SSL features, optionally downsamples them in time, and quantizes
    the result. When a feature decoder is configured (training feature
    reconstruction), also reconstructs SSL features from the quantized latent and
    reports the quantizer perplexity.

    Args:
        local_ssl_features: Local SSL features tensor of shape (B, T, C)

    Returns:
        local_quantized: Quantized local embeddings. (B, T/factor, C)
        indices: Content token indices. (B, T/factor)
        ssl_recon: Reconstructed SSL features, or None when no feature decoder is present. (B, T, C)
        perplexity: Mean quantizer perplexity (0.0 when no feature decoder is present). Scalar tensor.
    """
    local_encoded = self.local_encoder(local_ssl_features)

    # Downsample temporally if needed: (B, T, C) -> (B, T/factor, C)
    if self.downsample_factor > 1:
        if self.config.use_conv_downsample:
            # Conv1d expects channels-first, hence the transposes around the call.
            local_encoded = self.conv_downsample(local_encoded.transpose(1, 2)).transpose(1, 2)
        else:
            # Parameter-free alternative: average pooling over time.
            local_encoded = F.avg_pool1d(
                local_encoded.transpose(1, 2), kernel_size=self.downsample_factor, stride=self.downsample_factor
            ).transpose(1, 2)

    # If training feature reconstruction, decode local embeddings
    ssl_recon = None
    perplexity = torch.tensor(0.0)
    if self.feature_decoder is not None:
        # Training path: full quantizer call also yields indices and perplexity stats.
        local_quantized, local_quantize_info = self.local_quantizer(local_encoded)
        indices = local_quantize_info["indices"]
        perplexity = torch.mean(local_quantize_info["perplexity"])

        local_latent_for_ssl = local_quantized
        # Upsample back to the SSL frame rate before reconstructing SSL features.
        if self.downsample_factor > 1:
            if self.config.use_conv_downsample:
                # Use conv transpose for upsampling: (B, T/factor, C) -> (B, C, T/factor) -> conv -> (B, C, T) -> (B, T, C)
                local_latent_for_ssl = self.conv_upsample(local_latent_for_ssl.transpose(1, 2)).transpose(1, 2)
            else:
                # (B, T/factor, C) -> (B, T, C); interpolate to the exact SSL length.
                local_latent_for_ssl = F.interpolate(
                    local_latent_for_ssl.transpose(1, 2),
                    size=local_ssl_features.shape[1],
                    mode=self.config.local_interpolation_mode,
                ).transpose(1, 2)

        ssl_recon = self.feature_decoder(local_latent_for_ssl)
    else:
        # If not training feature reconstruction, just get quantized local embeddings
        local_quantized, indices = self.local_quantizer.encode(local_encoded)

    return local_quantized, indices, ssl_recon, perplexity
| 301 |
+
def forward_global(self, global_ssl_features: torch.Tensor) -> torch.Tensor:
    """Encode SSL features into a single utterance-level embedding.

    Args:
        global_ssl_features: Global SSL features tensor of shape (B, T, C)

    Returns:
        Global embedding of shape (B, C).
    """
    return self.global_encoder(global_ssl_features)
| 311 |
+
def forward_mel(
    self, content_embeddings: torch.Tensor, global_embeddings: torch.Tensor, mel_length: int
) -> torch.Tensor:
    """Forward pass to generate mel spectrogram from content and global embeddings.

    Args:
        content_embeddings: Content embeddings tensor of shape (B, T, C)
        global_embeddings: Global embeddings tensor of shape (B, C)
        mel_length: Target mel spectrogram length (T_mel)

    Returns:
        mel_recon: Reconstructed mel spectrogram tensor of shape (B, n_mels, T_mel)
    """
    local_latent = self.mel_prenet(content_embeddings)

    # Upsample local latent to match mel spectrogram length:
    # first a learned Conv1DTranspose (when configured) does the coarse upsampling ...
    if self.mel_conv_upsample is not None:
        # (B, T/factor, C) -> (B, C, T/factor) -> conv -> (B, C, T*upsample_factor) -> (B, T*upsample_factor, C)
        local_latent = self.mel_conv_upsample(local_latent.transpose(1, 2)).transpose(1, 2)
    # ... then interpolation hits the exact target frame count.
    local_latent = F.interpolate(
        local_latent.transpose(1, 2), size=mel_length, mode=self.config.mel_interpolation_mode
    ).transpose(1, 2)  # (B, T_current, C) -> (B, T_mel, C)

    # Generate mel spectrogram, conditioned on global embeddings
    # (unsqueezed to (B, 1, C) so the condition broadcasts over time).
    mel_recon = self.mel_decoder(local_latent, condition=global_embeddings.unsqueeze(1))
    mel_recon = mel_recon.transpose(1, 2)  # (B, n_mels, T)

    # Final refinement pass over the raw decoder output.
    mel_recon = self.mel_postnet(mel_recon)
    return mel_recon
| 340 |
+
# ======== Inference methods ========
|
| 341 |
+
|
| 342 |
+
def weights_to_save(self, *, include_modules: list[str]) -> dict[str, torch.Tensor]:
    """Collect parameters for checkpointing, dropping modules not needed for inference.

    Modules listed in ``include_modules`` are kept even if they are normally excluded.
    """
    # Modules that are only needed during training unless explicitly requested.
    droppable = ("ssl_feature_extractor", "feature_decoder", "conv_upsample")
    excluded = tuple(m for m in droppable if m not in include_modules)
    # str.startswith accepts a tuple of prefixes; an empty tuple matches nothing.
    return {name: param for name, param in self.named_parameters() if not name.startswith(excluded)}
| 354 |
+
@classmethod
def from_hparams(cls, config_path: str) -> "KanadeModel":
    """Instantiate a KanadeModel from a jsonargparse/YAML configuration file.

    Args:
        config_path (str): Path to model configuration file (.yaml).

    Returns:
        KanadeModel: The instantiated (randomly initialized) model.
    """
    parser = jsonargparse.ArgumentParser(exit_on_error=False)
    parser.add_argument("--model", type=KanadeModel)
    # parse_path reads the YAML; instantiate_classes builds the nested submodules.
    parsed = parser.parse_path(config_path)
    instantiated = parser.instantiate_classes(parsed)
    return instantiated.model
| 368 |
+
@classmethod
def from_pretrained(
    cls,
    repo_id: str | None = None,
    revision: str | None = None,
    config_path: str | None = None,
    weights_path: str | None = None,
) -> "KanadeModel":
    """Load a KanadeModel from the HuggingFace Hub or from local config/weights files.

    Args:
        repo_id (str, optional): HuggingFace Hub repository ID. When given, config and
            weights are downloaded from the hub.
        revision (str, optional): Branch, tag, or commit for the hub repository.
        config_path (str, optional): Local model configuration file (.yaml). Required
            together with ``weights_path`` when ``repo_id`` is not given.
        weights_path (str, optional): Local model weights file (.safetensors).

    Returns:
        KanadeModel: The loaded model.

    Raises:
        ValueError: If neither ``repo_id`` nor both local paths are provided.
    """
    # Validate arguments up front before any downloads or parsing.
    if repo_id is None and (config_path is None or weights_path is None):
        raise ValueError(
            "Please provide either HuggingFace Hub repo_id or both config_path and weights_path for model loading."
        )

    if repo_id is not None:
        # Load from HuggingFace Hub
        from huggingface_hub import hf_hub_download

        config_path = hf_hub_download(repo_id, "config.yaml", revision=revision)
        weights_path = hf_hub_download(repo_id, "model.safetensors", revision=revision)

    # Build the architecture from the config, then fill in the weights.
    model = cls.from_hparams(config_path)

    from safetensors.torch import load_file

    # strict=False: checkpoints may omit training-only modules (see weights_to_save).
    state_dict = load_file(weights_path, device="cpu")
    model.load_state_dict(state_dict, strict=False)
    logger.info(f"Loaded weights from safetensors file: {weights_path}")

    return model
| 410 |
+
@torch.inference_mode()
def encode(self, waveform: torch.Tensor, return_content: bool = True, return_global: bool = True) -> KanadeFeatures:
    """Extract content and/or global features from audio using Kanade model.

    Args:
        waveform (torch.Tensor): Input audio waveform tensor (samples,). The sample rate should match model config.
        return_content (bool): Whether to extract content features.
        return_global (bool): Whether to extract global features.

    Returns:
        KanadeFeatures: Extracted features; fields that were not requested keep their defaults.
    """
    audio_length = waveform.size(0)
    # Pad so the SSL extractor emits a consistent number of frames for this length.
    padding = self._calculate_waveform_padding(audio_length)
    local_ssl_features, global_ssl_features = self.forward_ssl_features(waveform.unsqueeze(0), padding=padding)

    result = KanadeFeatures()
    # NOTE(review): autocast is hard-coded to device_type="cuda"; confirm intended
    # behavior when running on CPU-only machines.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
        if return_content:
            content_embedding, token_indices, _, _ = self.forward_content(local_ssl_features)
            # Drop the batch dimension added above.
            result.content_embedding = content_embedding.squeeze(0)  # (seq_len, dim)
            result.content_token_indices = token_indices.squeeze(0)  # (seq_len,)

        if return_global:
            global_embedding = self.forward_global(global_ssl_features)
            result.global_embedding = global_embedding.squeeze(0)  # (dim,)

    return result
| 437 |
+
def decode_token_indices(self, indices: torch.Tensor) -> torch.Tensor:
    """Look up content embeddings for content token indices. (..., seq_len) -> (..., seq_len, dim)"""
    return self.local_quantizer.decode(indices)
| 442 |
+
@torch.inference_mode()
def decode(
    self,
    global_embedding: torch.Tensor,
    content_token_indices: torch.Tensor | None = None,
    content_embedding: torch.Tensor | None = None,
    target_audio_length: int | None = None,
) -> torch.Tensor:
    """Synthesize audio from content and global features using Kanade model and Vocos.

    Args:
        global_embedding (torch.Tensor): Global embedding tensor (dim,).
        content_token_indices (torch.Tensor, optional): Optional content token indices tensor (seq_len).
        content_embedding (torch.Tensor, optional): Optional content embedding tensor (seq_len, dim).
            If both content_token_indices and content_embedding are provided, content_embedding takes precedence.
        target_audio_length (int, optional): Target length of the output audio in samples.
            If None, uses the original audio length estimated from the sequence length of content tokens.

    Returns:
        torch.Tensor: Generated mel spectrogram tensor (n_mels, T).

    Raises:
        ValueError: If neither content_token_indices nor content_embedding is provided.
    """
    # Obtain content embedding if not provided
    if content_embedding is None:
        if content_token_indices is None:
            raise ValueError("Either content_token_indices or content_embedding must be provided.")
        content_embedding = self.decode_token_indices(content_token_indices)

    if target_audio_length is None:
        # Estimate original audio length from content token sequence length
        seq_len = content_embedding.size(0)
        target_audio_length = self._calculate_original_audio_length(seq_len)

    # NOTE(review): autocast is hard-coded to device_type="cuda"; confirm intended
    # behavior when running on CPU-only machines.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
        mel_length = self._calculate_target_mel_length(target_audio_length)
        # Add the batch dimension expected by forward_mel.
        content_embedding = content_embedding.unsqueeze(0)  # (1, seq_len, dim)
        global_embedding = global_embedding.unsqueeze(0)  # (1, dim)
        mel_spectrogram = self.forward_mel(content_embedding, global_embedding, mel_length=mel_length)

    return mel_spectrogram.squeeze(0)  # (n_mels, T)
| 480 |
+
@torch.inference_mode()
def voice_conversion(self, source_waveform: torch.Tensor, reference_waveform: torch.Tensor) -> torch.Tensor:
    """Convert the source utterance to the reference speaker's voice.

    Keeps the content of the source and the global characteristics of the reference.
    Only supports single audio input; a thin wrapper around :meth:`encode` and :meth:`decode`.

    Args:
        source_waveform (torch.Tensor): Source audio waveform tensor (samples,).
        reference_waveform (torch.Tensor): Reference audio waveform tensor (samples_ref,).

    Returns:
        torch.Tensor: Converted mel spectrogram tensor (n_mels, T).
    """
    # Content from the source; global characteristics from the reference.
    content = self.encode(source_waveform, return_content=True, return_global=False)
    speaker = self.encode(reference_waveform, return_content=False, return_global=True)

    # Target length follows the source so timing is preserved.
    return self.decode(
        content_embedding=content.content_embedding,
        global_embedding=speaker.global_embedding,
        target_audio_length=source_waveform.size(0),
    )
src/kanade_tokenizer/module/adaln_zero.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from: https://github.com/facebookresearch/DiT
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AdaLNZero(nn.Module):
|
| 9 |
+
"""
|
| 10 |
+
Adaptive Layer Normalization Zero (AdaLNZero) module.
|
| 11 |
+
|
| 12 |
+
Combines LayerNorm with adaptive conditioning to produce shift, scale, and gate values.
|
| 13 |
+
The gate is used to scale features before residual connection.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
dim: Feature dimension
|
| 17 |
+
condition_dim: Conditioning dimension
|
| 18 |
+
eps: LayerNorm epsilon
|
| 19 |
+
return_gate: If True, returns gate value for scaling.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
dim: int,
|
| 25 |
+
condition_dim: int,
|
| 26 |
+
eps: float = 1e-5,
|
| 27 |
+
return_gate: bool = True,
|
| 28 |
+
):
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.dim = dim
|
| 31 |
+
self.condition_dim = condition_dim
|
| 32 |
+
self.return_gate = return_gate
|
| 33 |
+
|
| 34 |
+
# LayerNorm without learnable parameters
|
| 35 |
+
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
| 36 |
+
|
| 37 |
+
# Conditioning network: condition -> shift, scale, gate
|
| 38 |
+
output_dim = 3 * dim if return_gate else 2 * dim
|
| 39 |
+
self.condition_proj = nn.Sequential(
|
| 40 |
+
nn.SiLU(),
|
| 41 |
+
nn.Linear(condition_dim, output_dim),
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# Initialize to zero for stable training
|
| 45 |
+
nn.init.zeros_(self.condition_proj[1].weight)
|
| 46 |
+
nn.init.zeros_(self.condition_proj[1].bias)
|
| 47 |
+
|
| 48 |
+
def forward(self, x: torch.Tensor, condition: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor] | None:
|
| 49 |
+
"""
|
| 50 |
+
Args:
|
| 51 |
+
x: Input tensor of shape (B, L, dim)
|
| 52 |
+
condition: Conditioning tensor of shape (B, L, condition_dim) or (B, 1, condition_dim)
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
modulated_x: Normalized and modulated features
|
| 56 |
+
gate: Gate values for scaling (None if return_gate=False)
|
| 57 |
+
"""
|
| 58 |
+
x_norm = self.norm(x)
|
| 59 |
+
condition_params = self.condition_proj(condition)
|
| 60 |
+
|
| 61 |
+
if self.return_gate:
|
| 62 |
+
shift, scale, gate = condition_params.chunk(3, dim=-1)
|
| 63 |
+
else:
|
| 64 |
+
shift, scale = condition_params.chunk(2, dim=-1)
|
| 65 |
+
gate = None
|
| 66 |
+
|
| 67 |
+
modulated_x = x_norm * (1 + scale) + shift
|
| 68 |
+
return modulated_x, gate
|
src/kanade_tokenizer/module/audio_feature.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from:
|
| 2 |
+
# Vocos: https://github.com/gemelo-ai/vocos/blob/main/vocos/feature_extractors.py
|
| 3 |
+
# BigVGAN: https://github.com/NVIDIA/BigVGAN/blob/main/meldataset.py (Also used by HiFT)
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torchaudio
|
| 7 |
+
from librosa.filters import mel as librosa_mel_fn
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """Elementwise log with the input clamped to at least ``clip_val`` to avoid log(0)."""
    return torch.clip(x, min=clip_val).log()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class MelSpectrogramFeature(nn.Module):
    """Log-mel spectrogram front end supporting two conventions.

    - Vocos style (default): torchaudio mel spectrogram with HTK mel scale, no filterbank
      normalization, and "center" or "same" padding.
    - BigVGAN style: manual STFT with reflect padding and a Slaney-normalized librosa
      mel filterbank (also used by HiFT).

    Args:
        sample_rate: Audio sample rate in Hz.
        n_fft: FFT size (also used as the window size).
        hop_length: Hop size in samples.
        n_mels: Number of mel bands.
        padding: "center" or "same"; only used by the Vocos-style path.
        fmin: Lowest mel filterbank frequency in Hz.
        fmax: Highest mel filterbank frequency in Hz (None = Nyquist).
        bigvgan_style_mel: If True, use the BigVGAN-style computation.
    """

    def __init__(
        self,
        sample_rate: int = 24000,
        n_fft: int = 1024,
        hop_length: int = 256,
        n_mels: int = 100,
        padding: str = "center",
        fmin: int = 0,
        fmax: int | None = None,
        bigvgan_style_mel: bool = False,
    ):
        super().__init__()

        self.bigvgan_style_mel = bigvgan_style_mel
        if bigvgan_style_mel:
            # BigVGAN style: same padding, Slaney mel scale, with normalization
            self.n_fft = n_fft
            self.win_size = n_fft
            self.hop_size = hop_length
            # (n_mels, n_fft // 2 + 1)
            mel_basis = librosa_mel_fn(
                sr=sample_rate, n_fft=n_fft, n_mels=n_mels, norm="slaney", htk=False, fmin=fmin, fmax=fmax
            )
            mel_basis = torch.from_numpy(mel_basis).float()
            hann_window = torch.hann_window(n_fft)
            # Buffers so they follow .to(device) without being trained.
            self.register_buffer("mel_basis", mel_basis)
            self.register_buffer("hann_window", hann_window)
        else:
            # Vocos style: center padding, HTK mel scale, without normalization
            if padding not in ["center", "same"]:
                raise ValueError("Padding must be 'center' or 'same'.")

            self.padding = padding
            # BUGFIX: torchaudio's MelSpectrogram takes `f_min`/`f_max`; the previous
            # `fmin=`/`fmax=` keywords raised TypeError on construction.
            self.mel_spec = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate,
                n_fft=n_fft,
                hop_length=hop_length,
                n_mels=n_mels,
                center=padding == "center",
                power=1,
                f_min=fmin,
                f_max=fmax,
            )

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Returns:
            mel_specgram (Tensor): Mel spectrogram of the input audio. (B, C, L)
        """
        if self.bigvgan_style_mel:
            return self.bigvgan_mel(audio)
        else:
            return self.vocos_mel(audio)

    def vocos_mel(self, audio: torch.Tensor) -> torch.Tensor:
        """Vocos-style log-mel: optional "same" reflect padding, magnitude spectrogram, HTK mel."""
        if self.padding == "same":
            pad = self.mel_spec.win_length - self.mel_spec.hop_length
            audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")

        specgram = self.mel_spec.spectrogram(audio)
        mel_specgram = self.mel_spec.mel_scale(specgram)

        # Convert to log scale (clipped to avoid log(0)).
        mel_specgram = safe_log(mel_specgram)
        return mel_specgram

    def bigvgan_mel(self, audio: torch.Tensor) -> torch.Tensor:
        """BigVGAN-style log-mel: manual STFT with reflect padding and Slaney filterbank."""
        # Pad so that the output length T = L // hop_length
        padding = (self.n_fft - self.hop_size) // 2
        audio = torch.nn.functional.pad(audio, (padding, padding), mode="reflect")
        # Flatten any leading dims into a batch for torch.stft.
        audio = audio.reshape(-1, audio.shape[-1])

        spec = torch.stft(
            audio,
            n_fft=self.n_fft,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.hann_window,
            center=False,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        # Restore the original leading dims.
        spec = spec.reshape(audio.shape[:-1] + spec.shape[-2:])

        # Magnitude with a small floor for numerical stability.
        spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
        mel_spec = torch.matmul(self.mel_basis, spec)
        mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
        return mel_spec
|
src/kanade_tokenizer/module/convnext.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from: https://github.com/gemelo-ai/vocos/blob/main/vocos/models.py
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ConvNeXtBlock(nn.Module):
|
| 9 |
+
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
|
| 10 |
+
|
| 11 |
+
Args:
|
| 12 |
+
dim (int): Number of input channels.
|
| 13 |
+
intermediate_dim (int): Dimensionality of the intermediate layer.
|
| 14 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
| 15 |
+
Defaults to None.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
dim: int,
|
| 21 |
+
intermediate_dim: int,
|
| 22 |
+
layer_scale_init_value: float,
|
| 23 |
+
):
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
|
| 26 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
| 27 |
+
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
|
| 28 |
+
self.act = nn.GELU()
|
| 29 |
+
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
| 30 |
+
self.gamma = (
|
| 31 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
| 32 |
+
if layer_scale_init_value > 0
|
| 33 |
+
else None
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 37 |
+
residual = x
|
| 38 |
+
x = self.dwconv(x)
|
| 39 |
+
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
| 40 |
+
x = self.norm(x)
|
| 41 |
+
x = self.pwconv1(x)
|
| 42 |
+
x = self.act(x)
|
| 43 |
+
x = self.pwconv2(x)
|
| 44 |
+
if self.gamma is not None:
|
| 45 |
+
x = self.gamma * x
|
| 46 |
+
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
|
| 47 |
+
|
| 48 |
+
x = residual + x
|
| 49 |
+
return x
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class ConvNextBackbone(nn.Module):
|
| 53 |
+
"""
|
| 54 |
+
Backbone module built with ConvNeXt blocks.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
input_channels (int): Number of input features channels.
|
| 58 |
+
dim (int): Hidden dimension of the model.
|
| 59 |
+
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
|
| 60 |
+
num_layers (int): Number of ConvNeXtBlock layers.
|
| 61 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
def __init__(
|
| 65 |
+
self,
|
| 66 |
+
input_channels: int,
|
| 67 |
+
dim: int,
|
| 68 |
+
intermediate_dim: int,
|
| 69 |
+
num_layers: int,
|
| 70 |
+
output_channels: int | None = None,
|
| 71 |
+
layer_scale_init_value: float | None = None,
|
| 72 |
+
skip_embed: bool = False,
|
| 73 |
+
):
|
| 74 |
+
super().__init__()
|
| 75 |
+
self.input_channels = input_channels
|
| 76 |
+
self.output_channels = output_channels
|
| 77 |
+
self.dim = dim
|
| 78 |
+
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) if not skip_embed else nn.Identity()
|
| 79 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
| 80 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
|
| 81 |
+
self.convnext = nn.ModuleList(
|
| 82 |
+
[
|
| 83 |
+
ConvNeXtBlock(
|
| 84 |
+
dim=dim,
|
| 85 |
+
intermediate_dim=intermediate_dim,
|
| 86 |
+
layer_scale_init_value=layer_scale_init_value,
|
| 87 |
+
)
|
| 88 |
+
for _ in range(num_layers)
|
| 89 |
+
]
|
| 90 |
+
)
|
| 91 |
+
self.proj_out = nn.Linear(dim, output_channels) if output_channels else nn.Identity()
|
| 92 |
+
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
|
| 93 |
+
self.apply(self._init_weights)
|
| 94 |
+
|
| 95 |
+
@property
|
| 96 |
+
def input_dim(self) -> int:
|
| 97 |
+
return self.input_channels
|
| 98 |
+
|
| 99 |
+
@property
|
| 100 |
+
def output_dim(self) -> int:
|
| 101 |
+
return self.output_channels if self.output_channels else self.dim
|
| 102 |
+
|
| 103 |
+
def _init_weights(self, m):
|
| 104 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
| 105 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 106 |
+
nn.init.constant_(m.bias, 0)
|
| 107 |
+
|
| 108 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 109 |
+
"""
|
| 110 |
+
Args:
|
| 111 |
+
x (Tensor): Input tensor of shape (B, L, C), where B is the batch size,
|
| 112 |
+
C denotes output features, and L is the sequence length.
|
| 113 |
+
Returns:
|
| 114 |
+
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
|
| 115 |
+
and H denotes the model dimension.
|
| 116 |
+
"""
|
| 117 |
+
x = x.transpose(1, 2) # (B, L, C) -> (B, C, L)
|
| 118 |
+
x = self.embed(x)
|
| 119 |
+
x = self.norm(x.transpose(1, 2))
|
| 120 |
+
x = x.transpose(1, 2)
|
| 121 |
+
for conv_block in self.convnext:
|
| 122 |
+
x = conv_block(x)
|
| 123 |
+
x = self.final_layer_norm(x.transpose(1, 2))
|
| 124 |
+
x = self.proj_out(x)
|
| 125 |
+
return x
|
src/kanade_tokenizer/module/discriminator.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from:
|
| 2 |
+
# https://github.com/gemelo-ai/vocos/blob/main/vocos/discriminators.py
|
| 3 |
+
# https://github.com/gemelo-ai/vocos/blob/main/vocos/loss.py
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_2d_padding(kernel_size: tuple[int, int], dilation: tuple[int, int] = (1, 1)):
|
| 12 |
+
return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SpectrogramDiscriminator(nn.Module):
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
frequency_bins: int,
|
| 19 |
+
channels: int = 32,
|
| 20 |
+
kernel_size: tuple[int, int] = (3, 3),
|
| 21 |
+
dilation: list[int] = [1, 2, 4],
|
| 22 |
+
bands: tuple[tuple[float, float], ...] = ((0.0, 0.2), (0.2, 0.4), (0.4, 0.6), (0.6, 0.8), (0.8, 1.0)),
|
| 23 |
+
use_downsample: bool = True,
|
| 24 |
+
):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.bands = [(int(b[0] * frequency_bins), int(b[1] * frequency_bins)) for b in bands]
|
| 27 |
+
|
| 28 |
+
self.stacks = nn.ModuleList()
|
| 29 |
+
for _ in self.bands:
|
| 30 |
+
stack = nn.ModuleList(
|
| 31 |
+
[weight_norm(nn.Conv2d(1, channels, kernel_size, padding=get_2d_padding(kernel_size)))]
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
for d in dilation:
|
| 35 |
+
# dilation on time axis
|
| 36 |
+
pad = get_2d_padding(kernel_size, (d, 1))
|
| 37 |
+
stack.append(weight_norm(nn.Conv2d(channels, channels, kernel_size, dilation=(d, 1), padding=pad)))
|
| 38 |
+
|
| 39 |
+
stack.append(weight_norm(nn.Conv2d(channels, channels, kernel_size, padding=get_2d_padding(kernel_size))))
|
| 40 |
+
|
| 41 |
+
self.stacks.append(stack)
|
| 42 |
+
|
| 43 |
+
self.conv_post = weight_norm(nn.Conv2d(channels, 1, kernel_size, padding=get_2d_padding(kernel_size)))
|
| 44 |
+
if use_downsample:
|
| 45 |
+
self.downsample = nn.AvgPool2d(4, stride=2, padding=1, count_include_pad=False)
|
| 46 |
+
else:
|
| 47 |
+
self.downsample = nn.Identity()
|
| 48 |
+
|
| 49 |
+
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
|
| 50 |
+
"""
|
| 51 |
+
Args:
|
| 52 |
+
x (Tensor): Input spectrogram (B, C, F, T).
|
| 53 |
+
Returns:
|
| 54 |
+
output (Tensor): Discriminator output.
|
| 55 |
+
intermediates (list[Tensor]): List of intermediate feature maps.
|
| 56 |
+
"""
|
| 57 |
+
if x.dim() == 3:
|
| 58 |
+
x = x.unsqueeze(1)
|
| 59 |
+
assert x.dim() == 4, f"Expected 4D input, got {x.dim()}D"
|
| 60 |
+
|
| 61 |
+
# Split into bands
|
| 62 |
+
x = rearrange(x, "b c f t -> b c t f")
|
| 63 |
+
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
|
| 64 |
+
|
| 65 |
+
x = []
|
| 66 |
+
intermediates = []
|
| 67 |
+
for x_band, stack in zip(x_bands, self.stacks):
|
| 68 |
+
for layer in stack:
|
| 69 |
+
x_band = layer(x_band)
|
| 70 |
+
x_band = torch.nn.functional.leaky_relu(x_band, 0.1)
|
| 71 |
+
intermediates.append(x_band)
|
| 72 |
+
x.append(x_band)
|
| 73 |
+
|
| 74 |
+
# Concatenate the outputs from all bands
|
| 75 |
+
x = torch.cat(x, dim=-1)
|
| 76 |
+
x = self.conv_post(x)
|
| 77 |
+
x = self.downsample(x)
|
| 78 |
+
return x, intermediates
|
src/kanade_tokenizer/module/fsq.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Finite Scalar Quantization: https://arxiv.org/abs/2309.15505
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
from ..util import get_logger
|
| 7 |
+
|
| 8 |
+
logger = get_logger()
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def round_ste(z: torch.Tensor) -> torch.Tensor:
    """Round with straight-through gradients: forward rounds, backward is identity."""
    zhat = z.round()
    return z + (zhat - z).detach()


def get_entropy(prob: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    """Shannon entropy along the last dim; `eps` avoids log(0)."""
    return -torch.sum(prob * torch.log(prob + eps), dim=-1)


class FSQ(nn.Module):
    """Finite Scalar Quantization (https://arxiv.org/abs/2309.15505).

    Each latent dimension i is bounded and rounded to one of `levels[i]` values;
    the joint per-dimension codes are flattened into a single integer index
    using a mixed-radix basis.
    """

    def __init__(self, levels: list[int]):
        super().__init__()
        self.levels = levels
        self.dim = len(levels)

        _levels = torch.tensor(levels, dtype=torch.long)
        self.register_buffer("_levels", _levels, persistent=False)
        # Mixed-radix basis: index = sum_i code_i * prod_{j<i} levels_j
        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.long)
        self.register_buffer("_basis", _basis, persistent=False)

    def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
        """Bound `z`, an array of shape (..., d), into the quantizable range via tanh."""
        half_l = (self._levels - 1) * (1 - eps) / 2
        # Even level counts have no centered zero code; shift by half a step.
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        shift = (offset / half_l).tan()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: torch.Tensor) -> torch.Tensor:
        """Quantizes z; returns quantized zhat (normalized to [-1, 1]), same shape as z."""
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2  # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
        """Converts a `code` to an index in the codebook. (B, T, C) -> (B, T)."""
        assert zhat.shape[-1] == len(self.levels)
        zhat = self._scale_and_shift(zhat)
        # round() before the long cast: the float renormalize/denormalize round trip
        # can land fractionally below an integer (e.g. 2.9999998), and a bare
        # truncating cast would then corrupt the index.
        return (zhat * self._basis.to(torch.float64)).round().to(torch.long).sum(dim=-1)

    def indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor:
        """Inverse of `codes_to_indices`. (B, T) -> (B, T, C)."""
        indices = indices.unsqueeze(-1)
        codes_non_centered = (indices // self._basis) % self._levels
        return self._scale_and_shift_inverse(codes_non_centered)

    def encode(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (quantized codes in [-1, 1], flat indices of shape (B, T))."""
        z_q = self.quantize(z)
        indices = self.codes_to_indices(z_q)  # (B, T)
        return z_q, indices

    def decode(self, indices: torch.Tensor) -> torch.Tensor:
        """Recover normalized codes (B, T, C) from flat indices."""
        z_q = self.indices_to_codes(indices)
        return z_q

    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        z_q = self.quantize(z)
        indices = self.codes_to_indices(z_q)  # (B, T)
        return z_q, indices
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class FiniteScalarQuantizer(nn.Module):
    """Wraps FSQ with linear projections into and out of the quantized latent space."""

    def __init__(self, input_dim: int, output_dim: int, levels: list[int]) -> None:
        super().__init__()
        self.input_dim_ = input_dim
        self.output_dim_ = output_dim

        self.fsq = FSQ(levels)
        logger.debug(
            f"Finite Scalar Quantizer with levels: {levels}, input_dim: {input_dim}, output_dim: {output_dim}, codebook_size: {self.all_codebook_size}"
        )

        # Projections collapse to identity when the dimensions already match.
        num_levels = len(levels)
        self.proj_in = nn.Identity() if num_levels == input_dim else nn.Linear(input_dim, num_levels)
        self.proj_out = nn.Identity() if num_levels == output_dim else nn.Linear(num_levels, output_dim)

    def build_codebook(self) -> None:
        # FSQ has no explicit codebook to materialize.
        pass

    @property
    def output_dim(self) -> int:
        return self.output_dim_

    @property
    def all_codebook_size(self) -> int:
        total = 1
        for n in self.fsq.levels:
            total *= n
        return total

    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """Quantize z; return the projected quantized latent plus an info dict."""
        latent = self.proj_in(z)  # Latent projected by proj_in
        quantized_latent, indices = self.fsq(latent)  # Quantized latent before proj_out
        z_q = self.proj_out(quantized_latent)

        # Perplexity of the empirical code-usage distribution in this batch.
        flat_indices = indices.view(-1)
        unique_indices, counts = torch.unique(flat_indices, return_counts=True)
        used_indices_probs = counts.float() / flat_indices.numel()
        perplexity = torch.exp(get_entropy(used_indices_probs))

        return z_q, {
            "latent": latent,
            "quantized_latent": quantized_latent,
            "indices": indices,
            "perplexity": perplexity,
        }

    def encode(self, z: torch.Tensor, skip_proj: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
        projected = self.proj_in(z)
        z_q, indices = self.fsq.encode(projected)
        return (z_q, indices) if skip_proj else (self.proj_out(z_q), indices)

    def decode(self, indices: torch.Tensor) -> torch.Tensor:
        return self.proj_out(self.fsq.decode(indices))
|
src/kanade_tokenizer/module/global_encoder.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from: https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/ecapa_tdnn.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from .convnext import ConvNextBackbone
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AttentiveStatsPool(nn.Module):
    """Attentive statistics pooling: attention-weighted mean and std over time,
    concatenated and projected to `output_channels`."""

    def __init__(self, input_channels: int, output_channels: int, attention_channels: int = 128):
        super().__init__()

        # Per-frame attention weights, normalized over the time axis.
        self.attn = nn.Sequential(
            nn.Conv1d(input_channels, attention_channels, kernel_size=1),
            nn.Tanh(),
            nn.Conv1d(attention_channels, input_channels, kernel_size=1),
            nn.Softmax(dim=2),
        )
        self.proj = nn.Linear(input_channels * 2, output_channels)
        self.norm = nn.LayerNorm(output_channels)

    def forward(self, x):
        # x: (B, C, T)
        weights = self.attn(x)

        weighted_mean = (weights * x).sum(dim=2)
        variance = (weights * x.pow(2)).sum(dim=2) - weighted_mean.pow(2)
        # Clamp guards against tiny negative values from floating-point cancellation.
        weighted_std = variance.clamp(min=1e-4, max=1e4).sqrt()

        stats = torch.cat([weighted_mean, weighted_std], dim=1)
        return self.norm(self.proj(stats))
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class GlobalEncoder(nn.Module):
    """Encode a feature sequence into one global embedding: a ConvNeXt backbone
    followed by pooling over the time axis."""

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        dim: int,
        intermediate_dim: int,
        num_layers: int,
        skip_embed: bool = False,
        attention_channels: int = 128,
        use_attn_pool: bool = True,
    ):
        super().__init__()

        self.backbone = ConvNextBackbone(
            input_channels=input_channels,
            dim=dim,
            intermediate_dim=intermediate_dim,
            num_layers=num_layers,
            skip_embed=skip_embed,
        )
        if use_attn_pool:
            # Attention-weighted mean/std pooling.
            self.pooling = AttentiveStatsPool(
                input_channels=dim, output_channels=output_channels, attention_channels=attention_channels
            )
        else:
            # Plain average pooling over time, then projection + norm.
            self.pooling = nn.Sequential(
                nn.AdaptiveAvgPool1d(1),
                nn.Flatten(1),
                nn.Linear(dim, output_channels),
                nn.LayerNorm(output_channels),
            )
        self.output_channels = output_channels

    @property
    def output_dim(self):
        return self.output_channels

    def forward(self, x):
        hidden = self.backbone(x)        # (B, T, dim)
        hidden = hidden.transpose(1, 2)  # (B, T, C) -> (B, C, T) for the pooling layers
        return self.pooling(hidden)      # (B, C_out)
|
src/kanade_tokenizer/module/hift.py
ADDED
|
@@ -0,0 +1,685 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from: https://github.com/yl4579/HiFTNet/blob/main/models.py
|
| 2 |
+
# https://github.com/FunAudioLLM/CosyVoice/blob/main/cosyvoice/hifigan/generator.py
|
| 3 |
+
|
| 4 |
+
from typing import Dict, List, Optional
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from scipy.signal import get_window
|
| 11 |
+
from torch.distributions.uniform import Uniform
|
| 12 |
+
from torch.nn import Conv1d, ConvTranspose1d
|
| 13 |
+
from torch.nn.utils import remove_weight_norm
|
| 14 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_padding(kernel_size, dilation=1):
    """'Same'-style padding for a dilated 1D conv: (kernel_size - 1) * dilation / 2."""
    return (kernel_size - 1) * dilation // 2
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def init_weights(m, mean=0.0, std=0.01):
    """Initialize conv-like modules' weights from N(mean, std); others are untouched."""
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def mel_spec_transform(
    audio: torch.Tensor,
    n_fft: int,
    n_mels: int,
    sample_rate: int,
    hop_size: int,
    win_size: int,
    fmin: int = 0,
    fmax: Optional[int] = None,
):
    """Compute a log-magnitude mel spectrogram of `audio`.

    Args:
        audio: Waveform tensor (..., L); leading dims are flattened before the STFT.
        n_fft: FFT size.
        n_mels: Number of mel bands.
        sample_rate: Sampling rate in Hz.
        hop_size: STFT hop length in samples.
        win_size: Analysis (Hann) window length in samples.
        fmin: Lower edge of the mel filterbank in Hz.
        fmax: Upper edge of the mel filterbank in Hz (None -> librosa's default).

    Returns:
        Log mel spectrogram of shape (B, n_mels, T) with T = L // hop_size,
        where B is the flattened batch dimension.
    """
    # Imported lazily so librosa is only required when this transform is used.
    from librosa.filters import mel as librosa_mel_fn

    # (n_mels, n_fft // 2 + 1)
    mel_basis = librosa_mel_fn(
        sr=sample_rate, n_fft=n_fft, n_mels=n_mels, norm="slaney", htk=False, fmin=fmin, fmax=fmax
    )
    mel_basis = torch.from_numpy(mel_basis).float()
    hann_window = torch.hann_window(win_size)

    # Pad so that the output length T = L // hop_length
    padding = (n_fft - hop_size) // 2
    audio = torch.nn.functional.pad(audio, (padding, padding), mode="reflect")
    audio = audio.reshape(-1, audio.shape[-1])

    # (B, n_fft // 2 + 1, T=1 + (L' - n_fft) // hop_length)
    # L' = L + n_fft - hop_length
    # T = L // hop_length
    spec = torch.stft(
        audio,
        n_fft=n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=False,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    # NOTE(review): `audio` was already flattened to 2D above, so this reshape uses
    # the flattened shape and is effectively a no-op — confirm intent for >2D input.
    spec = spec.reshape(audio.shape[:-1] + spec.shape[-2:])

    # Magnitude spectrum with a small epsilon for numerical stability.
    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
    mel_spec = torch.matmul(mel_basis, spec)
    # Log-compress, clamped to avoid log(0).
    mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))

    return mel_spec
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class Snake(nn.Module):
    """
    Sine-based periodic activation: Snake(x) = x + (1/alpha) * sin^2(alpha * x).

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable per-channel parameter; higher values = higher frequency
    References:
        - Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        """
        Args:
            in_features: number of channels (one alpha per channel).
            alpha: initial value for linear-scale alphas.
            alpha_trainable: whether alpha is learned with the rest of the model.
            alpha_logscale: store alpha in log space, initialized to 0 (exp(0) == 1).
        """
        super().__init__()
        self.in_features = in_features

        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            # Log-scale alphas start at zero; the original code multiplied zeros by
            # `alpha`, which is a no-op, so `alpha` is intentionally not applied here.
            init = torch.zeros(in_features)
        else:
            # Linear-scale alphas start at `alpha`.
            init = alpha * torch.ones(in_features)
        self.alpha = nn.Parameter(init, requires_grad=alpha_trainable)

        # Small epsilon guarding the 1/alpha division.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2 (xa)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)

        return x
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class ResBlock(torch.nn.Module):
    """Residual block module in HiFiGAN/BigVGAN: pairs of (dilated, plain) convs
    with Snake activations and a residual connection around each pair."""

    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: List[int] = [1, 3, 5],
    ):
        super().__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()

        for dilation in dilations:
            dilated_conv = Conv1d(
                channels,
                channels,
                kernel_size,
                1,
                dilation=dilation,
                padding=get_padding(kernel_size, dilation),
            )
            plain_conv = Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
            self.convs1.append(weight_norm(dilated_conv))
            self.convs2.append(weight_norm(plain_conv))

        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)
        self.activations1 = nn.ModuleList(Snake(channels, alpha_logscale=False) for _ in self.convs1)
        self.activations2 = nn.ModuleList(Snake(channels, alpha_logscale=False) for _ in self.convs2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for act1, conv1, act2, conv2 in zip(self.activations1, self.convs1, self.activations2, self.convs2):
            residual = x
            x = conv1(act1(x))
            x = conv2(act2(x))
            x = x + residual
        return x

    def remove_weight_norm(self):
        for conv1, conv2 in zip(self.convs1, self.convs2):
            remove_weight_norm(conv1)
            remove_weight_norm(conv2)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class ConvRNNF0Predictor(nn.Module):
|
| 179 |
+
def __init__(self, num_class: int = 1, in_channels: int = 80, cond_channels: int = 512):
|
| 180 |
+
super().__init__()
|
| 181 |
+
|
| 182 |
+
self.num_class = num_class
|
| 183 |
+
self.condnet = nn.Sequential(
|
| 184 |
+
weight_norm(nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)),
|
| 185 |
+
nn.ELU(),
|
| 186 |
+
weight_norm(nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)),
|
| 187 |
+
nn.ELU(),
|
| 188 |
+
weight_norm(nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)),
|
| 189 |
+
nn.ELU(),
|
| 190 |
+
weight_norm(nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)),
|
| 191 |
+
nn.ELU(),
|
| 192 |
+
weight_norm(nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)),
|
| 193 |
+
nn.ELU(),
|
| 194 |
+
)
|
| 195 |
+
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
|
| 196 |
+
|
| 197 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 198 |
+
x = self.condnet(x)
|
| 199 |
+
x = x.transpose(1, 2)
|
| 200 |
+
return torch.abs(self.classifier(x).squeeze(-1))
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class SineGen(torch.nn.Module):
    """Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine-waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
    segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp          # amplitude of the sine component
        self.noise_std = noise_std        # additive-noise std inside voiced regions
        self.harmonic_num = harmonic_num  # number of overtones above the fundamental
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # generate uv signal: 1.0 where f0 exceeds the voiced threshold, else 0.0
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: (sine_waves, uv, noise); sine_waves and noise are
            [B, harmonic_num + 1, sample_len], uv matches f0's shape
        """

        # Normalized frequency per harmonic: f0 * (i + 1) / sample_rate for harmonic i.
        F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
        for i in range(self.harmonic_num + 1):
            F_mat[:, i : i + 1, :] = f0 * (i + 1) / self.sampling_rate

        # Instantaneous phase from the cumulative normalized frequency; the mod-1
        # keeps the cumulative sum numerically small before scaling to radians.
        theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
        # Random initial phase per harmonic; the fundamental keeps phase 0.
        u_dist = Uniform(low=-np.pi, high=np.pi)
        phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
        phase_vec[:, 0, :] = 0

        # generate sine waveforms
        sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)

        # generate uv signal
        uv = self._f02uv(f0)

        # noise: for unvoiced should be similar to sine_amp
        # std = self.sine_amp/3 -> max value ~ self.sine_amp
        # . for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class SourceModuleHnNSF(torch.nn.Module):
    """SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonic above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length 1)
    uv (batchsize, length, 1)
    """

    def __init__(
        self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0
    ):
        # NOTE(review): `upsample_scale` is accepted but never used in this class;
        # it appears to exist only so the signature matches SourceModuleHnNSF2 — confirm.
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length 1)
        """
        # source for harmonic branch
        with torch.no_grad():
            # SineGen expects (B, 1, samples); transpose in and back out.
            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
            sine_wavs = sine_wavs.transpose(1, 2)
            uv = uv.transpose(1, 2)
        # Learned mix of the harmonic channels into a single excitation signal.
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class SineGen2(torch.nn.Module):
    """Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine-waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(np.pi) or cos(0)
    """

    def __init__(
        self,
        samp_rate,
        upsample_scale,
        harmonic_num=0,
        sine_amp=0.1,
        noise_std=0.003,
        voiced_threshold=0,
        flag_for_pulse=False,
    ):
        super(SineGen2, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        # Output channel count: fundamental + harmonics.
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        # Factor between the F0 frame rate fed to _f02sine and the audio sample rate.
        self.upsample_scale = upsample_scale

    def _f02uv(self, f0):
        """Return a voiced/unvoiced mask: 1.0 where f0 > voiced_threshold, else 0.0."""
        # generate uv signal
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def _f02sine(self, f0_values):
        """f0_values: (batchsize, length, dim)
        where dim indicates fundamental tone and overtones
        """
        # convert to F0 in rad. The integer part n can be ignored
        # because 2 * np.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1

        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
        if not self.flag_for_pulse:
            # Accumulate phase at the (lower) F0 frame rate, then interpolate the
            # phase back up to sample rate — avoids phase drift from interpolating
            # frequency directly.
            rad_values = torch.nn.functional.interpolate(
                rad_values.transpose(1, 2), scale_factor=1 / self.upsample_scale, mode="linear"
            ).transpose(1, 2)

            phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
            phase = torch.nn.functional.interpolate(
                phase.transpose(1, 2) * self.upsample_scale, scale_factor=self.upsample_scale, mode="linear"
            ).transpose(1, 2)
            sines = torch.sin(phase)
        else:
            # If necessary, make sure that the first time step of every
            # voiced segments is sin(pi) or cos(0)
            # This is used for pulse-train generation

            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)

            # get the instantaneous phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batch needs to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segments
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum

            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)

            # get the sines
            sines = torch.cos(i_phase * 2 * np.pi)
        return sines

    def forward(self, f0):
        """sine_tensor, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
            f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        """
        # fundamental component
        # NOTE(review): legacy torch.FloatTensor construction from a nested range —
        # verify this still builds the intended (1, 1, harmonic_num + 1) multiplier
        # on the installed torch version.
        fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))

        # generate sine waveforms
        sine_waves = self._f02sine(fn) * self.sine_amp

        # generate uv signal
        uv = self._f02uv(f0)

        # noise: for unvoiced should be similar to sine_amp
        # std = self.sine_amp/3 -> max value ~ self.sine_amp
        # . for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
class SourceModuleHnNSF2(torch.nn.Module):
    """SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonic above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length 1)
    uv (batchsize, length, 1)
    """

    def __init__(
        self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0
    ):
        super(SourceModuleHnNSF2, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num, sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def remove_weight_norm(self):
        """No-op weight-norm removal hook.

        Fix: HiFTGenerator.remove_weight_norm() calls
        ``self.m_source.remove_weight_norm()``; without this method that call
        raises AttributeError. None of this module's layers (Linear, Tanh,
        SineGen2) use weight_norm, so there is nothing to strip here.
        """
        pass

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length 1)
        """
        # source for harmonic branch (no gradients through the sine generator)
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x)
        # Learned mix of the harmonic channels into a single excitation signal.
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
class HiFTGenerator(nn.Module):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493

    Pipeline: mel features -> predicted F0 -> harmonic+noise source signal ->
    upsampling stack fused with the downsampled source STFT -> magnitude/phase
    heads -> inverse STFT -> waveform.
    """

    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        nb_harmonics: int = 8,
        sampling_rate: int = 24000,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: list[int] = [8, 5, 3],
        upsample_kernel_sizes: list[int] = [16, 11, 7],
        istft_n_fft: int = 16,
        istft_hop_len: int = 4,
        resblock_kernel_sizes: list[int] = [3, 7, 11],
        resblock_dilation_sizes: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        source_resblock_kernel_sizes: list[int] = [7, 7, 11],
        source_resblock_dilation_sizes: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        lrelu_slope: float = 0.1,
        audio_limit: float = 0.99,
        f0_predictor_channels: int = 512,
    ):
        # NOTE(review): the list defaults above are mutable default arguments; they
        # are only read here, but confirm no caller mutates them.
        super(HiFTGenerator, self).__init__()

        self.out_channels = 1
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_n_fft = istft_n_fft
        self.istft_hop_len = istft_hop_len
        self.lrelu_slope = lrelu_slope
        self.audio_limit = audio_limit

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        # NSF source module: turns the sample-rate F0 track into a harmonic excitation.
        self.m_source = SourceModuleHnNSF2(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_hop_len,
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold,
        )
        # F0 is per mel frame; total hop per frame is prod(upsample_rates) * istft_hop_len samples.
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_hop_len)

        self.conv_pre = weight_norm(Conv1d(in_channels, base_channels, 7, 1, padding=3))

        # Up: transposed-conv upsampling stack; channel count halves at each stage.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i), base_channels // (2 ** (i + 1)), k, u, padding=(k - u) // 2
                    )
                )
            )

        # Down: bring the source STFT (istft_n_fft + 2 channels = real + imag parts)
        # to each upsampling stage's time resolution and channel width.
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(
            zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)
        ):
            if u == 1:
                # Same rate as the source STFT: 1x1 channel projection only.
                self.source_downs.append(Conv1d(istft_n_fft + 2, base_channels // (2 ** (i + 1)), 1, 1))
            else:
                # Strided conv downsamples by the cumulative factor u.
                self.source_downs.append(
                    Conv1d(istft_n_fft + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
                )

            self.source_resblocks.append(ResBlock(base_channels // (2 ** (i + 1)), k, d))

        # Multi-receptive-field resblocks: num_kernels parallel blocks per upsampling stage.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2 ** (i + 1))
            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d))

        # Final head: istft_n_fft + 2 channels = (n_fft/2 + 1) magnitudes + (n_fft/2 + 1) phases.
        self.conv_post = weight_norm(Conv1d(ch, istft_n_fft + 2, 7, 1, padding=3))

        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        self.stft_window = torch.from_numpy(get_window("hann", istft_n_fft, fftbins=True).astype(np.float32))

        self.f0_predictor = ConvRNNF0Predictor(
            num_class=1, in_channels=in_channels, cond_channels=f0_predictor_channels
        )

    def remove_weight_norm(self):
        """Strip the weight_norm reparameterization from all layers (inference/export)."""
        for layer in self.ups:
            remove_weight_norm(layer)
        for layer in self.resblocks:
            layer.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        # NOTE(review): this expects m_source to expose remove_weight_norm();
        # verify SourceModuleHnNSF2 implements it, otherwise this raises AttributeError.
        self.m_source.remove_weight_norm()
        for layer in self.source_downs:
            remove_weight_norm(layer)
        for layer in self.source_resblocks:
            layer.remove_weight_norm()

    def _stft(self, x):
        """STFT of a (B, samples) signal; returns (real, imag), each (B, F, TT)."""
        spec = torch.stft(
            x,
            self.istft_n_fft,
            self.istft_hop_len,
            self.istft_n_fft,
            window=self.stft_window.to(x.device),
            return_complex=True,
        )
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        """Inverse STFT from (B, F, TT) magnitude and phase; returns (B, samples)."""
        # Clip magnitude to avoid overflow from the exp() in decode().
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        img = magnitude * torch.sin(phase)
        inverse_transform = torch.istft(
            torch.complex(real, img),
            self.istft_n_fft,
            self.istft_hop_len,
            self.istft_n_fft,
            window=self.stft_window.to(magnitude.device),
        )
        return inverse_transform

    def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
        """Decode mel features x (B, in_channels, T) and source signal s (B, 1, samples) to a waveform.

        NOTE(review): the tensor default for `s` is evaluated once at import time
        (an empty CPU source); callers normally pass the NSF source explicitly —
        confirm the default path (empty source, CPU device) is ever intended.
        """
        # Source signal in the STFT domain, real and imaginary parts stacked on channels.
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if i == self.num_upsamples - 1:
                # Pad one frame on the left so lengths line up with the source STFT.
                x = self.reflection_pad(x)

            # fusion: add the downsampled source branch into the main branch.
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            # Average of the parallel multi-receptive-field resblocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        # First half of the channels -> log-magnitude, second half -> phase.
        magnitude = torch.exp(x[:, : self.istft_n_fft // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_n_fft // 2 + 1 :, :])  # actually, sin is redundancy

        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x

    def forward(self, speech_feat: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Training-mode synthesis.

        Args:
            speech_feat: (B, T, in_channels) mel features.

        Returns:
            (generated_speech (B, samples), f0 (B, T)).
        """
        speech_feat = speech_feat.transpose(1, 2)
        # mel->f0
        f0 = self.f0_predictor(speech_feat)
        # f0->source
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        # mel+source->speech
        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech, f0

    @torch.inference_mode()
    def inference(self, speech_feat: torch.Tensor) -> torch.Tensor:
        """Inference-mode synthesis from (B, in_channels, T) mel features; returns the waveform only."""
        # mel->f0
        f0 = self.f0_predictor(speech_feat)
        # f0->source
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech

    def load_weights(self, weights_path: str):
        """Load a checkpoint, stripping a leading 'generator.' prefix from every key."""
        checkpoint = torch.load(weights_path, map_location="cpu")
        state_dict = {k.replace("generator.", ""): v for k, v in checkpoint.items()}
        self.load_state_dict(state_dict, strict=True)
|
src/kanade_tokenizer/module/postnet.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from: https://github.com/ming024/FastSpeech2
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get_padding(kernel_size: int, dilation: int = 1):
    """Return the padding that keeps a dilated 1D convolution length-preserving."""
    effective_span = (kernel_size - 1) * dilation
    return effective_span // 2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Norm(nn.Module):
    """LayerNorm over the channel axis for (batch, channels, sequence_length) tensors."""

    def __init__(self, channels: int):
        super().__init__()
        self.norm = nn.LayerNorm(channels)

    def forward(self, x):
        # nn.LayerNorm normalizes the trailing dimension, so move channels last,
        # normalize, then restore the original layout.
        channels_last = x.transpose(1, 2)
        normalized = self.norm(channels_last)
        return normalized.transpose(1, 2)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class PostNet(nn.Module):
    """Residual convolutional post-filter (FastSpeech2-style).

    Refines an input feature map with a stack of Conv1d + norm layers (tanh on
    all but the last), applies dropout after every layer, and adds the result
    back onto the input.
    """

    def __init__(
        self,
        input_channels: int = 100,
        channels: int = 512,
        kernel_size: int = 5,
        num_layers: int = 5,
        dropout: float = 0.5,
        use_layer_norm: bool = False,
    ):
        super().__init__()

        pad = get_padding(kernel_size)

        def conv_block(in_ch, out_ch):
            # One conv layer followed by the configured normalization.
            return nn.Sequential(
                nn.Conv1d(in_ch, out_ch, kernel_size=kernel_size, padding=pad),
                Norm(out_ch) if use_layer_norm else nn.BatchNorm1d(out_ch),
            )

        blocks = [conv_block(input_channels, channels)]
        blocks.extend(conv_block(channels, channels) for _ in range(1, num_layers - 1))
        blocks.append(conv_block(channels, input_channels))

        self.convolutions = nn.ModuleList(blocks)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x

        *hidden, last = self.convolutions
        for block in hidden:
            # tanh activation on every layer except the final projection.
            x = self.dropout(torch.tanh(block(x)))

        x = self.dropout(last(x))

        return x + residual
|
src/kanade_tokenizer/module/ssl_extractor.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torchaudio
|
| 4 |
+
import torchaudio.pipelines as pipelines
|
| 5 |
+
from torchaudio.models.wav2vec2 import Wav2Vec2Model
|
| 6 |
+
from torchaudio.models.wav2vec2.components import ConvLayerBlock
|
| 7 |
+
|
| 8 |
+
from ..util import get_logger
|
| 9 |
+
|
| 10 |
+
logger = get_logger()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Map of friendly names to torchaudio pipeline bundles.
# Keys are the accepted `model_name` values for SSLFeatureExtractor.
MODEL_REGISTRY = {
    "wav2vec2_base": pipelines.WAV2VEC2_BASE,
    "wav2vec2_large": pipelines.WAV2VEC2_LARGE,
    "wav2vec2_large_lv60k": pipelines.WAV2VEC2_LARGE_LV60K,
    "hubert_base": pipelines.HUBERT_BASE,
    "hubert_large": pipelines.HUBERT_LARGE,
    "hubert_xlarge": pipelines.HUBERT_XLARGE,
    "wavlm_base": pipelines.WAVLM_BASE,
    "wavlm_base_plus": pipelines.WAVLM_BASE_PLUS,
    "wavlm_large": pipelines.WAVLM_LARGE,
}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SSLFeatureExtractor(nn.Module):
    """Wrapper around torchaudio SSL bundles (wav2vec2 / HuBERT / WavLM) for feature extraction."""

    def __init__(self, model_name: str = "wavlm_base_plus", output_layer: int | None = None, sample_rate: int = 16000):
        """
        Args:
            model_name: Name of the SSL model to use
            output_layer: Which layer's features to extract (None for last layer), 1-based indexing
            sample_rate: Sample rate of input audio
        """
        super().__init__()
        # -1 is later forwarded as `num_layers` to Wav2Vec2Model.extract_features.
        # NOTE(review): torchaudio documents num_layers=None (not -1) as "all layers";
        # confirm that -1 selects the intended layers here.
        self.output_layer = output_layer if output_layer is not None else -1

        if model_name not in MODEL_REGISTRY:
            raise ValueError(f"Unknown model: {model_name}. Available models: {list(MODEL_REGISTRY.keys())}")
        bundle = MODEL_REGISTRY[model_name]
        self.model: Wav2Vec2Model = bundle.get_model()
        self.model.eval()
        # Encoder embedding dimension = feature dim of every returned layer.
        # NOTE(review): reads the bundle's private `_params` — may break across torchaudio versions.
        self.feature_dim: int = bundle._params["encoder_embed_dim"]

        self.ssl_sample_rate = bundle.sample_rate
        # Create resampler if needed
        if sample_rate != self.ssl_sample_rate:
            logger.debug(f"Resampling from {sample_rate} to {self.ssl_sample_rate} required by {model_name}.")
            self.resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.ssl_sample_rate)
        else:
            self.resampler = None

    @property
    def hop_size(self) -> int:
        """Get the hop size of the model's convolutional layers."""
        # Product of all conv strides = waveform samples per output frame.
        hop_size = 1
        for _, stride in self.conv_config:
            hop_size *= stride
        return hop_size

    @property
    def conv_config(self) -> list[tuple[int, int]]:
        """Get the configuration of the convolutional layers in the model."""
        conv_layers = []
        for layer in self.model.feature_extractor.conv_layers:
            layer: ConvLayerBlock
            conv_layers.append((layer.kernel_size, layer.stride))
        return conv_layers

    def get_minimum_input_length(self, desired_output_length: int) -> int:
        """Calculate the minimum input length required to produce a given output length."""
        # Invert each conv layer's output-length formula, last layer back to the input.
        length = desired_output_length
        for kernel_size, stride in reversed(self.conv_config):
            length = (length - 1) * stride + kernel_size
        return length

    @torch.no_grad()
    def forward(
        self,
        waveform: torch.Tensor,
        lengths: torch.Tensor | None = None,
        num_layers: int | None = None,
        return_lengths: bool = False,
    ) -> "list[torch.Tensor] | tuple[list[torch.Tensor], torch.Tensor]":
        """
        Args:
            waveform: (batch_size, num_samples)
            lengths: Optional tensor of sequence lengths for each batch item (used for attention masking)
            num_layers: Optional override for how many encoder layers to return (falls back to output_layer)
            return_lengths: If True, also return per-item frame lengths

        Returns:
            features: List of feature tensors for each layer (batch_size, frame, dim)
            lengths: Sequence lengths for each batch item (only when return_lengths is True)
        """
        if waveform.dim() == 1:
            # Promote a single unbatched waveform to batch size 1.
            waveform = waveform.unsqueeze(0)
        # Resample if needed
        if self.resampler is not None:
            waveform = self.resampler(waveform)

        features, feature_lengths = self.model.extract_features(
            waveform, lengths, num_layers=num_layers or self.output_layer
        )

        if return_lengths:
            return features, feature_lengths
        return features
|
src/kanade_tokenizer/module/transformer.py
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/meta-llama/llama3/blob/main/llama/model.py
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 3 |
+
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
from ..util import get_logger
|
| 11 |
+
from .adaln_zero import AdaLNZero
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
logger = get_logger()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from flash_attn import flash_attn_func, flash_attn_with_kvcache
|
| 19 |
+
|
| 20 |
+
FLASH_ATTN_AVAILABLE = True
|
| 21 |
+
except ImportError:
|
| 22 |
+
FLASH_ATTN_AVAILABLE = False
|
| 23 |
+
logger.warning(
|
| 24 |
+
"FlashAttention is not installed. Falling back to PyTorch SDPA implementation. There is no warranty that the model will work correctly."
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the complex rotary-embedding table, shape (end, dim // 2), dtype complex64."""
    half_dim = dim // 2
    exponents = torch.arange(0, dim, 2)[:half_dim].float() / dim
    inv_freq = 1.0 / (theta**exponents)
    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    # Unit-magnitude complex numbers e^{i*angle}.
    return torch.polar(torch.ones_like(angles), angles)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Return a view of freqs_cis (seq_len, head_dim) that broadcasts against x.

    x must have its sequence axis at dim 1 and head_dim last; every other axis
    of the returned view is a singleton.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    target_shape = [size if axis in (1, ndim - 1) else 1 for axis, size in enumerate(x.shape)]
    return freqs_cis.view(*target_shape)


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply rotary positional embedding: rotate pairs of x by the phases in freqs_cis."""
    # Pair up the last dimension into (real, imag) and view as complex numbers.
    as_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    rotation = reshape_for_broadcast(freqs_cis, as_complex)
    rotated = torch.view_as_real(as_complex * rotation).flatten(3)
    return rotated.type_as(x)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class Attention(nn.Module):
|
| 52 |
+
def __init__(
|
| 53 |
+
self,
|
| 54 |
+
dim: int,
|
| 55 |
+
n_heads: int,
|
| 56 |
+
dropout: float,
|
| 57 |
+
window_size: int | None,
|
| 58 |
+
qkv_bias: bool = False,
|
| 59 |
+
proj_bias: bool = False,
|
| 60 |
+
use_flash_attention: bool = False,
|
| 61 |
+
causal: bool = False,
|
| 62 |
+
):
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.n_heads = n_heads
|
| 65 |
+
self.head_dim = dim // n_heads
|
| 66 |
+
|
| 67 |
+
self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=qkv_bias)
|
| 68 |
+
self.wk = nn.Linear(dim, n_heads * self.head_dim, bias=qkv_bias)
|
| 69 |
+
self.wv = nn.Linear(dim, n_heads * self.head_dim, bias=qkv_bias)
|
| 70 |
+
self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=proj_bias)
|
| 71 |
+
|
| 72 |
+
self.scale = self.head_dim**-0.5
|
| 73 |
+
self.dropout = dropout
|
| 74 |
+
|
| 75 |
+
# Enable local attention if window_size is specified
|
| 76 |
+
self.use_local_attention = window_size is not None
|
| 77 |
+
if self.use_local_attention:
|
| 78 |
+
assert window_size % 2 == 1, "Window size must be odd for local attention."
|
| 79 |
+
self.window_per_side = window_size // 2
|
| 80 |
+
|
| 81 |
+
self.use_flash_attention = use_flash_attention
|
| 82 |
+
|
| 83 |
+
self.causal = causal
|
| 84 |
+
|
| 85 |
+
def create_mask(
|
| 86 |
+
self, bsz: int, seqlen: int, mask: torch.Tensor | None, device: torch.device
|
| 87 |
+
) -> torch.Tensor | None:
|
| 88 |
+
"""Create attention mask combining provided mask and local attention constraints"""
|
| 89 |
+
if not self.use_local_attention and mask is None:
|
| 90 |
+
return None
|
| 91 |
+
|
| 92 |
+
# Start with all positions allowed
|
| 93 |
+
attn_mask = torch.ones((seqlen, seqlen), dtype=torch.bool, device=device)
|
| 94 |
+
|
| 95 |
+
if self.causal:
|
| 96 |
+
# Causal mask: no future positions allowed
|
| 97 |
+
attn_mask = torch.tril(attn_mask)
|
| 98 |
+
|
| 99 |
+
# Apply local attention constraints
|
| 100 |
+
if self.use_local_attention:
|
| 101 |
+
attn_mask = torch.triu(attn_mask, diagonal=-self.window_per_side)
|
| 102 |
+
attn_mask = torch.tril(attn_mask, diagonal=self.window_per_side)
|
| 103 |
+
|
| 104 |
+
# Expand mask to batch size
|
| 105 |
+
attn_mask = attn_mask.unsqueeze(0).expand(bsz, -1, -1)
|
| 106 |
+
|
| 107 |
+
# Apply global mask if provided
|
| 108 |
+
if mask is not None:
|
| 109 |
+
assert mask.shape[-1] == seqlen and mask.shape[-2] == seqlen, (
|
| 110 |
+
"Mask must be square and match sequence length."
|
| 111 |
+
)
|
| 112 |
+
# Ensure mask has correct batch dimensions
|
| 113 |
+
if mask.dim() == 2:
|
| 114 |
+
mask = mask.unsqueeze(0).expand(bsz, -1, -1)
|
| 115 |
+
attn_mask = attn_mask & mask
|
| 116 |
+
|
| 117 |
+
# Expand to head dimension
|
| 118 |
+
attn_mask = attn_mask.unsqueeze(1).expand(-1, self.n_heads, -1, -1)
|
| 119 |
+
return attn_mask
|
| 120 |
+
|
| 121 |
+
def forward(
|
| 122 |
+
self,
|
| 123 |
+
x: torch.Tensor,
|
| 124 |
+
freqs_cis: torch.Tensor | None,
|
| 125 |
+
mask: torch.Tensor | None,
|
| 126 |
+
return_kv: bool = False,
|
| 127 |
+
) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
| 128 |
+
"""Forward pass for multi-head attention.
|
| 129 |
+
Args:
|
| 130 |
+
x (torch.Tensor): Input tensor of shape (bsz, seqlen, dim).
|
| 131 |
+
freqs_cis (torch.Tensor, optional): Precomputed rotary frequencies.
|
| 132 |
+
mask (torch.Tensor, optional): Attention mask.
|
| 133 |
+
return_kv (bool): Whether to return KV pairs for caching.
|
| 134 |
+
Returns:
|
| 135 |
+
output (torch.Tensor): Output tensor of shape (bsz, seqlen, dim).
|
| 136 |
+
new_kv (tuple, optional): KV pairs if return_kv is True.
|
| 137 |
+
"""
|
| 138 |
+
bsz, seqlen, _ = x.shape
|
| 139 |
+
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
| 140 |
+
|
| 141 |
+
xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
|
| 142 |
+
xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
|
| 143 |
+
xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)
|
| 144 |
+
|
| 145 |
+
# Apply rotary embeddings if provided
|
| 146 |
+
if freqs_cis is not None:
|
| 147 |
+
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis[:seqlen])
|
| 148 |
+
xk = apply_rotary_emb(xk, freqs_cis=freqs_cis[:seqlen])
|
| 149 |
+
|
| 150 |
+
if self.use_flash_attention and FLASH_ATTN_AVAILABLE:
|
| 151 |
+
assert mask is None, "Flash attention does not support arbitrary masking."
|
| 152 |
+
|
| 153 |
+
# Flash Attention
|
| 154 |
+
window_size = (self.window_per_side, self.window_per_side) if self.use_local_attention else (-1, -1)
|
| 155 |
+
output = flash_attn_func(
|
| 156 |
+
xq, # (bsz, seqlen, n_heads, head_dim)
|
| 157 |
+
xk, # (bsz, seqlen, n_heads, head_dim)
|
| 158 |
+
xv, # (bsz, seqlen, n_heads, head_dim)
|
| 159 |
+
dropout_p=(self.dropout if self.training else 0.0),
|
| 160 |
+
softmax_scale=self.scale,
|
| 161 |
+
window_size=window_size,
|
| 162 |
+
causal=self.causal,
|
| 163 |
+
) # (bsz, seqlen, n_heads, head_dim)
|
| 164 |
+
|
| 165 |
+
else:
|
| 166 |
+
attn_mask = self.create_mask(bsz, seqlen, mask, x.device)
|
| 167 |
+
|
| 168 |
+
# SDPA Attention
|
| 169 |
+
output = F.scaled_dot_product_attention(
|
| 170 |
+
xq.transpose(1, 2), # (bsz, n_heads, seqlen, head_dim)
|
| 171 |
+
xk.transpose(1, 2), # (bsz, n_heads, seqlen, head_dim)
|
| 172 |
+
xv.transpose(1, 2), # (bsz, n_heads, seqlen, head_dim)
|
| 173 |
+
attn_mask=attn_mask, # (bsz, n_heads, seqlen, seqlen) boolean mask
|
| 174 |
+
dropout_p=self.dropout,
|
| 175 |
+
scale=self.scale,
|
| 176 |
+
).transpose(1, 2) # (bsz, seqlen, n_heads, head_dim)
|
| 177 |
+
|
| 178 |
+
output = output.contiguous().view(bsz, seqlen, -1)
|
| 179 |
+
output = self.wo(output)
|
| 180 |
+
|
| 181 |
+
if return_kv:
|
| 182 |
+
return output, (xk, xv)
|
| 183 |
+
return output
|
| 184 |
+
|
| 185 |
+
def forward_with_cache(
|
| 186 |
+
self,
|
| 187 |
+
x: torch.Tensor,
|
| 188 |
+
kv_cache: tuple[torch.Tensor, torch.Tensor],
|
| 189 |
+
freqs_cis: torch.Tensor,
|
| 190 |
+
start_pos: int,
|
| 191 |
+
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
| 192 |
+
"""
|
| 193 |
+
Forward pass with KV cache for efficient inference. Only used for inference.
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
x (torch.Tensor): Input tensor for the current step. Shape: (bsz, 1, dim)
|
| 197 |
+
kv_cache: A tuple of (key_cache, value_cache) from previous steps.
|
| 198 |
+
start_pos (int): The starting position of the new token in the sequence.
|
| 199 |
+
freqs_cis (torch.Tensor): Precomputed rotary frequencies.
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
output (torch.Tensor): Output tensor after attention. Shape: (bsz, 1, dim)
|
| 203 |
+
new_kv (tuple): Updated KV cache including the new key and value.
|
| 204 |
+
"""
|
| 205 |
+
bsz, seqlen, _ = x.shape
|
| 206 |
+
assert seqlen == 1, "KV cache method is designed for single-token generation."
|
| 207 |
+
|
| 208 |
+
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
| 209 |
+
|
| 210 |
+
xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
|
| 211 |
+
xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
|
| 212 |
+
xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)
|
| 213 |
+
|
| 214 |
+
# Apply rotary embeddings using the correct positional slice
|
| 215 |
+
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis[start_pos : start_pos + seqlen])
|
| 216 |
+
xk = apply_rotary_emb(xk, freqs_cis=freqs_cis[start_pos : start_pos + seqlen])
|
| 217 |
+
|
| 218 |
+
# Update the KV cache
|
| 219 |
+
k_cache, v_cache = kv_cache
|
| 220 |
+
new_kv = (xk, xv)
|
| 221 |
+
xk = torch.cat([k_cache, xk], dim=1)
|
| 222 |
+
xv = torch.cat([v_cache, xv], dim=1)
|
| 223 |
+
|
| 224 |
+
# For single token generation, causal mask is implicitly handled.
|
| 225 |
+
# We attend to all keys (prefix + previous tokens).
|
| 226 |
+
if self.use_flash_attention and FLASH_ATTN_AVAILABLE:
|
| 227 |
+
# Flash Attention
|
| 228 |
+
output = flash_attn_with_kvcache(
|
| 229 |
+
xq, # (bsz, 1, n_heads, head_dim)
|
| 230 |
+
xk, # (bsz, 1+kv_len, n_heads, head_dim)
|
| 231 |
+
xv, # (bsz, 1+kv_len, n_heads, head_dim)
|
| 232 |
+
softmax_scale=self.scale,
|
| 233 |
+
) # (bsz, 1, n_heads, head_dim)
|
| 234 |
+
else:
|
| 235 |
+
# SDPA Attention
|
| 236 |
+
output = F.scaled_dot_product_attention(
|
| 237 |
+
xq.transpose(1, 2), # (bsz, n_heads, 1, head_dim)
|
| 238 |
+
xk.transpose(1, 2), # (bsz, n_heads, 1+kv_len, head_dim)
|
| 239 |
+
xv.transpose(1, 2), # (bsz, n_heads, 1+kv_len, head_dim)
|
| 240 |
+
scale=self.scale,
|
| 241 |
+
).transpose(1, 2) # (bsz, 1, n_heads, head_dim)
|
| 242 |
+
|
| 243 |
+
output = output.contiguous().view(bsz, seqlen, -1)
|
| 244 |
+
return self.wo(output), new_kv
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class FeedForward(nn.Module):
|
| 248 |
+
def __init__(
|
| 249 |
+
self,
|
| 250 |
+
dim: int,
|
| 251 |
+
hidden_dim: int,
|
| 252 |
+
multiple_of: int,
|
| 253 |
+
ffn_dim_multiplier: float | None,
|
| 254 |
+
):
|
| 255 |
+
super().__init__()
|
| 256 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
| 257 |
+
# custom dim factor multiplier
|
| 258 |
+
if ffn_dim_multiplier is not None:
|
| 259 |
+
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
| 260 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
| 261 |
+
|
| 262 |
+
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
| 263 |
+
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
| 264 |
+
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
| 265 |
+
|
| 266 |
+
def forward(self, x):
|
| 267 |
+
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class TransformerBlock(nn.Module):
    """A single pre-norm Transformer block: attention + SwiGLU feed-forward.

    Normalization is either plain LayerNorm or AdaLNZero (adaptive layer norm
    conditioned on an external embedding, with a learned output gate), selected
    by `use_adaln_zero`. Supports KV-cached single-token decoding via the
    underlying Attention module.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        qkv_bias: bool,
        proj_bias: bool,
        window_size: int | None,
        multiple_of: int,
        ffn_dim_multiplier: float | None,
        dropout: float,
        norm_eps: float,
        adanorm_condition_dim: int | None = None,
        use_flash_attention: bool = False,
        use_adaln_zero: bool = False,
        causal: bool = False,
    ):
        """Build the attention, feed-forward, and normalization submodules.

        Args:
            dim: Model dimension.
            n_heads: Number of attention heads.
            qkv_bias: Bias on Q/K/V projections.
            proj_bias: Bias on attention output projection.
            window_size: Sliding-window width for local attention (odd), or None.
            multiple_of: Round the FFN hidden size up to a multiple of this.
            ffn_dim_multiplier: Optional scale factor for the FFN hidden size.
            dropout: Attention dropout probability.
            norm_eps: Epsilon for normalization layers.
            adanorm_condition_dim: Condition embedding dim; required when use_adaln_zero.
            use_flash_attention: Use flash_attn kernels when available.
            use_adaln_zero: Use AdaLNZero (conditioned norm + gate) instead of LayerNorm.
            causal: Causal attention masking.
        """
        super().__init__()
        self.attention = Attention(
            dim=dim,
            n_heads=n_heads,
            dropout=dropout,
            window_size=window_size,
            use_flash_attention=use_flash_attention,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            causal=causal,
        )

        self.feed_forward = FeedForward(
            dim=dim,
            hidden_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )

        # Choose between AdaLNZero and regular LayerNorm
        self.use_adaln_zero = use_adaln_zero
        if self.use_adaln_zero:
            assert adanorm_condition_dim is not None, "condition_dim must be provided when using AdaLNZero"
            self.attention_norm = AdaLNZero(dim, adanorm_condition_dim, eps=norm_eps, return_gate=True)
            self.ffn_norm = AdaLNZero(dim, adanorm_condition_dim, eps=norm_eps, return_gate=True)
        else:
            self.attention_norm = nn.LayerNorm(dim, eps=norm_eps)
            self.ffn_norm = nn.LayerNorm(dim, eps=norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor | None,
        mask: torch.Tensor | None,
        condition: torch.Tensor | None = None,
        return_kv: bool = False,
        kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None,
        start_pos: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass for a single Transformer block.
        Args:
            x (torch.Tensor): Input tensor of shape (bsz, seqlen, dim).
            freqs_cis (torch.Tensor, optional): Precomputed rotary frequencies.
            mask (torch.Tensor, optional): Attention mask.
            condition (torch.Tensor, optional): Conditioning tensor for AdaLNZero.
            return_kv (bool): Whether to return KV pairs for caching.
            kv_cache (tuple, optional): KV cache for efficient inference.
            start_pos (int, optional): Starting position for KV cache.
        Returns:
            out (torch.Tensor): Output tensor of shape (bsz, seqlen, dim).
            new_kv (tuple, optional): New KV pairs if return_kv is True or kv_cache is provided.
        """
        # Apply normalization (pre-norm); AdaLNZero also returns a residual gate
        if self.use_adaln_zero:
            assert condition is not None, "condition must be provided when using AdaLNZero"
            attn_normed, attn_gate = self.attention_norm(x, condition=condition)
        else:
            attn_normed = self.attention_norm(x)

        # Forward attention with KV cache if provided
        new_kv = None
        if kv_cache is not None and start_pos is not None:
            # Use KV cache for efficient inference
            attn_out, new_kv = self.attention.forward_with_cache(attn_normed, kv_cache, freqs_cis, start_pos)
        elif return_kv:
            # Return KV pairs for caching
            attn_out, new_kv = self.attention(attn_normed, freqs_cis, mask, return_kv=True)
        else:
            attn_out = self.attention(attn_normed, freqs_cis, mask)

        # Apply gating for attention if using AdaLNZero
        if self.use_adaln_zero:
            h = x + attn_gate * attn_out  # residual + gate * x
        else:
            h = x + attn_out

        # Apply normalization for feedforward
        if self.use_adaln_zero:
            ffn_normed, ffn_gate = self.ffn_norm(h, condition=condition)
        else:
            ffn_normed = self.ffn_norm(h)

        ffn_out = self.feed_forward(ffn_normed)

        # Apply gating for feedforward if using AdaLNZero
        if self.use_adaln_zero:
            out = h + ffn_gate * ffn_out  # residual + gate * x
        else:
            out = h + ffn_out

        # If using KV cache, return the new KV pairs
        if new_kv is not None:
            return out, new_kv
        return out
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class Transformer(nn.Module):
    """Stack of TransformerBlocks with optional RoPE, I/O projections,
    AdaLNZero conditioning, and a per-layer KV-cache path for decoding.

    NOTE(review): `freqs_cis` is a plain attribute (not a registered buffer),
    is moved to the input device lazily in `forward`, and is re-allocated when
    the requested length exceeds the precomputed table — so `forward` mutates
    module state. Presumably intentional to avoid checkpointing the table;
    confirm before relying on state_dict round-trips including it.
    """

    def __init__(
        self,
        dim: int = 4096,
        n_layers: int = 32,
        n_heads: int = 32,
        qkv_bias: bool = False,
        proj_bias: bool = False,
        window_size: int | None = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: float | None = None,
        dropout: float = 0.1,
        norm_eps: float = 1e-5,
        use_rope: bool = True,
        rope_theta: float = 500000.0,
        max_seq_len: int = 2048,
        input_dim: int | None = None,
        output_dim: int | None = None,
        adanorm_condition_dim: int | None = None,
        use_flash_attention: bool = False,
        use_adaln_zero: bool = False,
        use_xavier_init: bool = True,
        causal: bool = False,
    ):
        """Build the layer stack, final norm, I/O projections, and RoPE table.

        Args:
            dim: Model dimension.
            n_layers: Number of TransformerBlocks.
            n_heads: Attention heads per block.
            qkv_bias / proj_bias: Bias flags for attention projections.
            window_size: Local-attention window (odd) or None for full attention.
            multiple_of / ffn_dim_multiplier: FFN hidden-size shaping (see FeedForward).
            dropout: Attention dropout probability.
            norm_eps: Epsilon for normalization layers.
            use_rope: Precompute rotary embeddings for 2 * max_seq_len positions.
            rope_theta: RoPE base frequency.
            max_seq_len: Expected maximum sequence length for RoPE precomputation.
            input_dim / output_dim: Optional linear projections at the boundaries
                (Identity when None).
            adanorm_condition_dim: Condition dim; required when use_adaln_zero.
            use_flash_attention: Use flash_attn kernels when available.
            use_adaln_zero: Use AdaLNZero conditioning in every block and final norm.
            use_xavier_init: Xavier-init all Linears, then zero AdaLNZero projections.
            causal: Causal attention masking.
        """
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.rope_theta = rope_theta
        self.use_adaln_zero = use_adaln_zero

        self.layers = nn.ModuleList()
        for layer_id in range(n_layers):
            self.layers.append(
                TransformerBlock(
                    dim=dim,
                    n_heads=n_heads,
                    window_size=window_size,
                    multiple_of=multiple_of,
                    ffn_dim_multiplier=ffn_dim_multiplier,
                    dropout=dropout,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    norm_eps=norm_eps,
                    adanorm_condition_dim=adanorm_condition_dim,
                    use_flash_attention=use_flash_attention,
                    use_adaln_zero=use_adaln_zero,
                    causal=causal,
                )
            )

        # Choose between AdaLNZero (without gate) and regular LayerNorm for final norm
        if self.use_adaln_zero:
            assert adanorm_condition_dim is not None, "condition_dim must be provided when using AdaLNZero"
            self.norm = AdaLNZero(dim, adanorm_condition_dim, eps=norm_eps, return_gate=False)
        else:
            self.norm = nn.LayerNorm(dim, eps=norm_eps)
        self.input_proj = nn.Linear(input_dim, dim) if input_dim is not None else nn.Identity()
        self.output_proj = nn.Linear(dim, output_dim) if output_dim is not None else nn.Identity()
        self.output_dim_ = output_dim if output_dim is not None else dim

        if use_rope:
            self.freqs_cis = precompute_freqs_cis(dim // n_heads, max_seq_len * 2, rope_theta)
            logger.debug(
                f"Using RoPE with theta={rope_theta}, max_seq_len={max_seq_len}, "
                f"dim={dim}, n_heads={n_heads}, freqs_cis shape={self.freqs_cis.shape}"
            )
        else:
            self.freqs_cis = None

        if window_size is not None:
            logger.debug(f"Using local attention with window size {window_size}")

        if self.use_adaln_zero:
            logger.debug(f"Using AdaLNZero conditioning with condition_dim={adanorm_condition_dim}")

        if use_flash_attention:
            logger.debug("Using Flash Attention for memory-efficient attention computation")

        if use_xavier_init:
            logger.debug("Using Xavier initialization for linear layers")
            self.apply(self._init_weights)
            # Runs after Xavier init so AdaLNZero projections end up zeroed
            self.apply(self._init_adaln_zero)

    @property
    def output_dim(self) -> int:
        # Effective output width: `output_dim` if a projection exists, else `dim`.
        return self.output_dim_

    def _init_weights(self, module: nn.Module):
        # Xavier-normal for all Linear weights; zero biases.
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def _init_adaln_zero(self, module: nn.Module):
        # Zero the AdaLNZero condition projection so conditioning starts as a no-op
        # (the "zero" in AdaLN-Zero).
        if isinstance(module, AdaLNZero):
            # Initialize condition projection weights to zero
            nn.init.zeros_(module.condition_proj[1].weight)
            nn.init.zeros_(module.condition_proj[1].bias)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
        condition: torch.Tensor | None = None,
        return_kv: bool = False,
        kv_cache: list[tuple[torch.Tensor, torch.Tensor]] | None = None,
        start_pos: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
        """
        Forward pass for the Transformer model.
        Args:
            x (torch.Tensor): Input tensor of shape (bsz, seqlen, input_dim).
            mask (torch.Tensor, optional): Attention mask.
            condition (torch.Tensor, optional): Conditioning tensor for AdaLNZero.
            return_kv (bool): Whether to return KV pairs for caching.
            kv_cache (list, optional): List of KV caches for each layer for efficient inference.
            start_pos (int, optional): Starting position for KV cache.
        Returns:
            output (torch.Tensor): Output tensor of shape (bsz, seqlen, output_dim).
            new_kv_list (list, optional): List of new KV pairs for each layer if return_kv is True or kv_cache is provided.
        """
        bsz, seqlen, _dim = x.shape

        if self.use_adaln_zero:
            assert condition is not None, "condition must be provided when using AdaLNZero"

        # Rotary embeddings
        if self.freqs_cis is not None:
            # Recompute freqs_cis if the sequence length or starting position exceeds the precomputed length
            expected_len = (start_pos + 1) if start_pos is not None else seqlen
            if expected_len > self.freqs_cis.shape[0]:
                logger.warning(
                    f"Input sequence length {expected_len} exceeds precomputed RoPE length {self.freqs_cis.shape[0]}. Recomputing freqs_cis."
                )
                # Over-allocate (4x) to avoid recomputing on every subsequent step
                self.freqs_cis = precompute_freqs_cis(self.dim // self.n_heads, expected_len * 4, self.rope_theta)

            self.freqs_cis = self.freqs_cis.to(x.device)
            freqs_cis = self.freqs_cis
        else:
            freqs_cis = None

        x = self.input_proj(x)
        new_kv_list = []
        for i, layer in enumerate(self.layers):
            # Collect KV cache if provided
            if kv_cache is not None and start_pos is not None:
                x, new_kv = layer(x, freqs_cis, mask, condition, kv_cache=kv_cache[i], start_pos=start_pos)
                new_kv_list.append(new_kv)
            elif return_kv:
                x, new_kv = layer(x, freqs_cis, mask, condition, return_kv=True)
                new_kv_list.append(new_kv)
            else:
                x = layer(x, freqs_cis, mask, condition)

        # Apply final normalization
        if self.use_adaln_zero:
            x, _ = self.norm(x, condition=condition)  # Final norm doesn't use gate
        else:
            x = self.norm(x)

        output = self.output_proj(x)

        # If using KV cache, return the new KV pairs
        if new_kv_list:
            return output, new_kv_list
        return output
|
src/kanade_tokenizer/pipeline.py
ADDED
|
@@ -0,0 +1,760 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Literal
|
| 3 |
+
|
| 4 |
+
import jsonargparse
|
| 5 |
+
import lightning as L
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import yaml
|
| 10 |
+
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
|
| 11 |
+
from torch.optim import AdamW
|
| 12 |
+
from torch.optim.lr_scheduler import OneCycleLR
|
| 13 |
+
|
| 14 |
+
from .data.datamodule import AudioBatch
|
| 15 |
+
from .model import KanadeModel, KanadeModelConfig
|
| 16 |
+
from .module.audio_feature import MelSpectrogramFeature
|
| 17 |
+
from .module.discriminator import SpectrogramDiscriminator
|
| 18 |
+
from .module.fsq import FiniteScalarQuantizer
|
| 19 |
+
from .module.global_encoder import GlobalEncoder
|
| 20 |
+
from .module.postnet import PostNet
|
| 21 |
+
from .module.ssl_extractor import SSLFeatureExtractor
|
| 22 |
+
from .module.transformer import Transformer
|
| 23 |
+
from .util import freeze_modules, get_logger, load_vocoder, vocode
|
| 24 |
+
|
| 25 |
+
logger = get_logger()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
class KanadePipelineConfig:
    """Training-time configuration for KanadePipeline: branch toggles,
    optimizer/scheduler settings, loss weights, GAN options, checkpoint
    loading, and logging knobs.
    """

    # Training control
    train_feature: bool = True  # Whether to train the feature reconstruction branch
    train_mel: bool = True  # Whether to train the mel spectrogram generation branch

    # Audio settings
    audio_length: int = 138240  # Length of audio input in samples

    # Optimization settings
    lr: float = 2e-4  # peak learning rate (pipeline imports AdamW + OneCycleLR)
    weight_decay: float = 1e-4  # AdamW weight decay
    betas: tuple[float, float] = (0.9, 0.99)  # AdamW (beta1, beta2)
    gradient_clip_val: float | None = 1.0  # gradient clipping value; None disables

    # LR scheduling parameters
    warmup_percent: float = 0.1  # fraction of training spent warming up (OneCycleLR pct_start — confirm)
    lr_div_factor: float = 10.0  # presumably OneCycleLR div_factor (initial_lr = lr / div) — confirm
    lr_final_div_factor: float = 1.0  # presumably OneCycleLR final_div_factor — confirm
    anneal_mode: str = "cos"  # annealing strategy name passed to the scheduler

    # Loss weights
    feature_l1_weight: float = 30.0  # L1 loss weight on the feature branch
    feature_l2_weight: float = 0.0  # L2 loss weight on the feature branch
    mel_l1_weight: float = 30.0  # L1 loss weight on predicted mels
    mel_l2_weight: float = 0.0  # L2 loss weight on predicted mels
    adv_weight: float = 1.0  # adversarial (generator) loss weight
    feature_matching_weight: float = 10.0  # discriminator feature-matching loss weight

    # GAN settings
    use_discriminator: bool = False  # enable GAN training (needs a discriminator and train_mel)
    adv_loss_type: Literal["hinge", "least_square"] = "hinge"  # Type of adversarial loss
    discriminator_lr: float | None = None  # Learning rate for discriminator
    discriminator_start_step: int = 0  # Step to start training discriminator
    discriminator_update_prob: float = 1.0  # Probability of updating discriminator at each step

    # Checkpoint loading
    ckpt_path: str | None = None  # Path to checkpoint to load from
    skip_loading_modules: tuple[str, ...] = ()  # Modules to skip when loading checkpoint

    # Other settings
    log_mel_samples: int = 10  # number of validation mel examples kept for logging
    use_torch_compile: bool = True  # presumably gates torch.compile of the model — confirm in pipeline
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class KanadePipeline(L.LightningModule):
|
| 74 |
+
"""LightningModule wrapper for KanadeModel, handling training (including GAN)."""
|
| 75 |
+
|
| 76 |
+
def __init__(
|
| 77 |
+
self,
|
| 78 |
+
model_config: KanadeModelConfig,
|
| 79 |
+
pipeline_config: KanadePipelineConfig,
|
| 80 |
+
ssl_feature_extractor: SSLFeatureExtractor,
|
| 81 |
+
local_encoder: Transformer,
|
| 82 |
+
local_quantizer: FiniteScalarQuantizer,
|
| 83 |
+
feature_decoder: Transformer | None,
|
| 84 |
+
global_encoder: GlobalEncoder,
|
| 85 |
+
mel_prenet: Transformer,
|
| 86 |
+
mel_decoder: Transformer,
|
| 87 |
+
mel_postnet: PostNet,
|
| 88 |
+
discriminator: SpectrogramDiscriminator | None = None,
|
| 89 |
+
):
|
| 90 |
+
super().__init__()
|
| 91 |
+
self.config = pipeline_config
|
| 92 |
+
self.save_hyperparameters("model_config", "pipeline_config")
|
| 93 |
+
self.strict_loading = False
|
| 94 |
+
self.automatic_optimization = False
|
| 95 |
+
self.torch_compiled = False
|
| 96 |
+
|
| 97 |
+
# Validate components required for training
|
| 98 |
+
assert not pipeline_config.train_feature or feature_decoder is not None, (
|
| 99 |
+
"Feature decoder must be provided if training feature reconstruction"
|
| 100 |
+
)
|
| 101 |
+
logger.info(
|
| 102 |
+
f"Training configuration: train_feature={pipeline_config.train_feature}, train_mel={pipeline_config.train_mel}"
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# 1. Kanade model
|
| 106 |
+
self.model = KanadeModel(
|
| 107 |
+
config=model_config,
|
| 108 |
+
ssl_feature_extractor=ssl_feature_extractor,
|
| 109 |
+
local_encoder=local_encoder,
|
| 110 |
+
local_quantizer=local_quantizer,
|
| 111 |
+
feature_decoder=feature_decoder,
|
| 112 |
+
global_encoder=global_encoder,
|
| 113 |
+
mel_decoder=mel_decoder,
|
| 114 |
+
mel_prenet=mel_prenet,
|
| 115 |
+
mel_postnet=mel_postnet,
|
| 116 |
+
)
|
| 117 |
+
self._freeze_unused_modules(pipeline_config.train_feature, pipeline_config.train_mel)
|
| 118 |
+
|
| 119 |
+
# Calculate padding for expected SSL output length
|
| 120 |
+
self.padding = self.model._calculate_waveform_padding(pipeline_config.audio_length)
|
| 121 |
+
logger.info(f"Input waveform padding for SSL feature extractor: {self.padding} samples")
|
| 122 |
+
|
| 123 |
+
# Calculate target mel spectrogram length
|
| 124 |
+
self.target_mel_length = self.model._calculate_target_mel_length(pipeline_config.audio_length)
|
| 125 |
+
logger.info(f"Target mel spectrogram length: {self.target_mel_length} frames")
|
| 126 |
+
|
| 127 |
+
# 2. Discriminator
|
| 128 |
+
self._init_discriminator(pipeline_config, discriminator)
|
| 129 |
+
|
| 130 |
+
# 3. Mel spectrogram feature extractor for loss computation
|
| 131 |
+
if pipeline_config.train_mel:
|
| 132 |
+
self.mel_spec = MelSpectrogramFeature(
|
| 133 |
+
sample_rate=model_config.sample_rate,
|
| 134 |
+
n_fft=model_config.n_fft,
|
| 135 |
+
hop_length=model_config.hop_length,
|
| 136 |
+
n_mels=model_config.n_mels,
|
| 137 |
+
padding=model_config.padding,
|
| 138 |
+
fmin=model_config.mel_fmin,
|
| 139 |
+
fmax=model_config.mel_fmax,
|
| 140 |
+
bigvgan_style_mel=model_config.bigvgan_style_mel,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# Mel sample storage for logging
|
| 144 |
+
self.vocoder = None
|
| 145 |
+
self.validation_examples = []
|
| 146 |
+
self.log_mel_samples = pipeline_config.log_mel_samples
|
| 147 |
+
|
| 148 |
+
def _freeze_unused_modules(self, train_feature: bool, train_mel: bool):
    """Freeze the parameters of whichever branch is not being trained.

    Args:
        train_feature: Whether the SSL-feature reconstruction branch is trained.
        train_mel: Whether the mel-generation branch is trained.
    """
    m = self.model

    if not train_feature:
        # Local (content) branch is inactive: freeze its encoder/quantizer/decoder.
        freeze_modules([m.local_encoder, m.local_quantizer, m.feature_decoder])
        if m.conv_downsample is not None:
            freeze_modules([m.conv_downsample, m.conv_upsample])
        logger.info("Feature reconstruction branch frozen: local_encoder, local_quantizer, feature_decoder")

    if not train_mel:
        # Global branch and mel synthesis stack are inactive: freeze them all.
        mel_modules = [m.global_encoder, m.mel_prenet, m.mel_conv_upsample, m.mel_decoder, m.mel_postnet]
        freeze_modules(mel_modules)
        logger.info(
            "Mel generation branch frozen: global_encoder, mel_prenet, mel_conv_upsample, mel_decoder, mel_postnet"
        )
def _init_discriminator(self, config: KanadePipelineConfig, discriminator: SpectrogramDiscriminator | None):
    """Store the discriminator and decide whether adversarial (GAN) training is active.

    GAN training requires all three of: the config flag, an actual discriminator
    instance, and mel training being enabled.
    """
    self.discriminator = discriminator
    has_disc = discriminator is not None
    self.use_discriminator = config.use_discriminator and has_disc and config.train_mel

    # Surface configuration mismatches loudly rather than failing silently.
    if config.use_discriminator and not has_disc:
        logger.error(
            "Discriminator is enabled in config but no discriminator model provided. Disabling GAN training."
        )
    if config.use_discriminator and has_disc and not config.train_mel:
        logger.warning(
            "Discriminator is enabled but train_mel=False. Discriminator will not be effective without mel training."
        )

    self.discriminator_start_step = config.discriminator_start_step
    self.discriminator_update_prob = config.discriminator_update_prob

    if self.use_discriminator:
        logger.info("Discriminator initialized for GAN training")
        logger.info(f"Discriminator start step: {self.discriminator_start_step}")
        logger.info(f"Discriminator update probability: {self.discriminator_update_prob}")
def setup(self, stage: str):
    """Lightning setup hook: optionally compile the models and load checkpoint weights.

    Args:
        stage: Lightning stage name ("fit", "validate", "test", "predict"); unused here.
    """
    # Torch compile model if enabled.
    # FIX: the original compared version strings lexicographically
    # (torch.__version__ >= "2.0"), which misorders multi-digit majors
    # (e.g. "10.0" < "2.0"). Compare the numeric major version instead.
    torch_major = int(torch.__version__.split(".")[0])
    if torch_major >= 2 and self.config.use_torch_compile:
        self.model = torch.compile(self.model)
        if self.discriminator is not None:
            self.discriminator = torch.compile(self.discriminator)
        self.torch_compiled = True

    # Load checkpoint if provided
    if self.config.ckpt_path:
        ckpt_path = self.config.ckpt_path

        # "hf:" scheme: download weights from the HuggingFace Hub first.
        if ckpt_path.startswith("hf:"):
            from huggingface_hub import hf_hub_download

            repo_id = ckpt_path[len("hf:") :]
            # An optional "@revision" suffix selects a branch/tag/commit.
            revision = None
            if "@" in repo_id:
                repo_id, revision = repo_id.split("@", 1)

            ckpt_path = hf_hub_download(repo_id, filename="model.safetensors", revision=revision)

        self._load_weights(ckpt_path)
def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
    """Run the full tokenizer pipeline on a raw waveform.

    Returns:
        ssl_real: Extracted SSL features for local branch (B, T, C)
        ssl_recon: Reconstructed SSL features (B, T, C) - only if train_feature=True
        mel_recon: Generated mel spectrogram (B, n_mels, T) - only if train_mel=True
        loss_dict: Dictionary with auxiliary information (codes, losses, etc.)
    """
    aux = {}

    # 1. SSL feature extraction (shared front-end for both branches).
    local_ssl_features, global_ssl_features = self.model.forward_ssl_features(waveform, padding=self.padding)

    # 2. Local/content branch: quantize and reconstruct SSL features.
    content_embeddings, _, ssl_recon, perplexity = self.model.forward_content(local_ssl_features)
    aux["local/perplexity"] = perplexity

    # 3. Global branch + mel decoder, only when mel training is enabled.
    mel_recon = None
    if self.config.train_mel:
        global_embeddings = self.model.forward_global(global_ssl_features)
        mel_recon = self.model.forward_mel(content_embeddings, global_embeddings, mel_length=self.target_mel_length)

    return local_ssl_features, ssl_recon, mel_recon, aux
def _get_reconstruction_loss(
|
| 239 |
+
self, audio_real: torch.Tensor, ssl_real: torch.Tensor, ssl_recon: torch.Tensor, mel_recon: torch.Tensor
|
| 240 |
+
) -> tuple[torch.Tensor, dict, torch.Tensor]:
|
| 241 |
+
"""Compute L1 + L2 loss for SSL feature and mel spectrogram reconstruction.
|
| 242 |
+
Returns:
|
| 243 |
+
total_loss: Combined reconstruction loss
|
| 244 |
+
loss_dict: Dictionary with individual loss components
|
| 245 |
+
mel_real: Real mel spectrogram for reference
|
| 246 |
+
"""
|
| 247 |
+
if audio_real.dim() == 3:
|
| 248 |
+
audio_real = audio_real.squeeze(1)
|
| 249 |
+
|
| 250 |
+
loss_dict = {}
|
| 251 |
+
feature_loss, mel_loss = 0, 0
|
| 252 |
+
|
| 253 |
+
# Compute SSL feature reconstruction losses if training features
|
| 254 |
+
if self.config.train_feature and self.model.feature_decoder is not None:
|
| 255 |
+
assert ssl_real is not None and ssl_recon is not None, (
|
| 256 |
+
"SSL features must be provided for training feature reconstruction"
|
| 257 |
+
)
|
| 258 |
+
ssl_l1 = F.l1_loss(ssl_recon, ssl_real)
|
| 259 |
+
ssl_l2 = F.mse_loss(ssl_recon, ssl_real)
|
| 260 |
+
|
| 261 |
+
feature_loss = self.config.feature_l1_weight * ssl_l1 + self.config.feature_l2_weight * ssl_l2
|
| 262 |
+
loss_dict.update({"ssl_l1": ssl_l1, "ssl_l2": ssl_l2, "feature_loss": feature_loss})
|
| 263 |
+
|
| 264 |
+
# Compute mel spectrogram reconstruction losses if training mel
|
| 265 |
+
mel_real = None
|
| 266 |
+
if self.config.train_mel:
|
| 267 |
+
assert mel_recon is not None, "Mel reconstruction must be provided for training mel generation"
|
| 268 |
+
# Extract reference mel spectrogram from audio
|
| 269 |
+
mel_real = self.mel_spec(audio_real)
|
| 270 |
+
|
| 271 |
+
mel_l1 = F.l1_loss(mel_recon, mel_real)
|
| 272 |
+
mel_l2 = F.mse_loss(mel_recon, mel_real)
|
| 273 |
+
mel_loss = self.config.mel_l1_weight * mel_l1 + self.config.mel_l2_weight * mel_l2
|
| 274 |
+
loss_dict.update({"mel_l1": mel_l1, "mel_l2": mel_l2, "mel_loss": mel_loss})
|
| 275 |
+
|
| 276 |
+
total_loss = feature_loss + mel_loss
|
| 277 |
+
return total_loss, loss_dict, mel_real
|
| 278 |
+
|
| 279 |
+
def _get_discriminator_loss(
|
| 280 |
+
self, real_outputs: torch.Tensor, fake_outputs: torch.Tensor
|
| 281 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 282 |
+
"""Compute the adversarial loss for discriminator.
|
| 283 |
+
Returns:
|
| 284 |
+
disc_loss: Total discriminator loss
|
| 285 |
+
real_loss: Loss component from real samples
|
| 286 |
+
fake_loss: Loss component from fake samples
|
| 287 |
+
"""
|
| 288 |
+
if self.config.adv_loss_type == "hinge":
|
| 289 |
+
real_loss = torch.mean(torch.clamp(1 - real_outputs, min=0))
|
| 290 |
+
fake_loss = torch.mean(torch.clamp(1 + fake_outputs, min=0))
|
| 291 |
+
elif self.config.adv_loss_type == "least_square":
|
| 292 |
+
real_loss = torch.mean((real_outputs - 1) ** 2)
|
| 293 |
+
fake_loss = torch.mean(fake_outputs**2)
|
| 294 |
+
else:
|
| 295 |
+
raise ValueError(f"Unknown adversarial loss type: {self.config.adv_loss_type}")
|
| 296 |
+
|
| 297 |
+
disc_loss = real_loss + fake_loss
|
| 298 |
+
return disc_loss, real_loss, fake_loss
|
| 299 |
+
|
| 300 |
+
def _get_generator_loss(self, fake_outputs: torch.Tensor) -> torch.Tensor:
|
| 301 |
+
"""Compute the adversarial loss for generator."""
|
| 302 |
+
if self.config.adv_loss_type == "hinge":
|
| 303 |
+
return torch.mean(torch.clamp(1 - fake_outputs, min=0))
|
| 304 |
+
elif self.config.adv_loss_type == "least_square":
|
| 305 |
+
return torch.mean((fake_outputs - 1) ** 2)
|
| 306 |
+
else:
|
| 307 |
+
raise ValueError(f"Unknown adversarial loss type: {self.config.adv_loss_type}")
|
| 308 |
+
|
| 309 |
+
def _get_feature_matching_loss(
|
| 310 |
+
self, real_intermediates: list[torch.Tensor], fake_intermediates: list[torch.Tensor]
|
| 311 |
+
) -> torch.Tensor:
|
| 312 |
+
losses = []
|
| 313 |
+
for real_feat, fake_feat in zip(real_intermediates, fake_intermediates):
|
| 314 |
+
losses.append(torch.mean(torch.abs(real_feat.detach() - fake_feat)))
|
| 315 |
+
fm_loss = torch.mean(torch.stack(losses))
|
| 316 |
+
return fm_loss
|
| 317 |
+
|
| 318 |
+
def _discriminator_step(
    self, batch: AudioBatch, optimizer_disc: torch.optim.Optimizer
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor], list[torch.Tensor]]:
    """Run one manual optimization step of the discriminator.

    Returns:
        ssl_real: Real SSL features
        ssl_recon: Reconstructed SSL features from generator
        mel_recon: Generated mel spectrogram
        loss_dict: Dictionary with auxiliary information
        real_intermediates: Intermediate feature maps from discriminator for real mel
    """
    assert self.use_discriminator, "Discriminator step called but discriminator is not enabled"

    ssl_real, ssl_recon, mel_recon, loss_dict = self(batch.waveform)
    assert mel_recon is not None, "Mel reconstruction must be available for discriminator step"

    # Reference mel is always derived from the original waveform.
    mel_real = self.mel_spec(batch.waveform)

    # Score both mels; detach the fake one so no gradient reaches the generator.
    real_outputs, real_intermediates = self.discriminator(mel_real)
    fake_outputs, _ = self.discriminator(mel_recon.detach())

    disc_loss, real_loss, fake_loss = self._get_discriminator_loss(real_outputs, fake_outputs)

    n = batch.waveform.size(0)
    self.log("train/disc/real", real_loss, batch_size=n)
    self.log("train/disc/fake", fake_loss, batch_size=n)
    self.log("train/disc/loss", disc_loss, batch_size=n, prog_bar=True)
    for name, value in loss_dict.items():
        self.log(f"train/{name}", value, batch_size=n)

    # Manual optimization (Lightning automatic optimization is off).
    optimizer_disc.zero_grad()
    self.manual_backward(disc_loss)

    # Clip and log the gradient norm (torch.inf disables clipping in effect).
    grad_norm = torch.nn.utils.clip_grad_norm_(
        self.discriminator.parameters(), max_norm=self.config.gradient_clip_val or torch.inf
    )
    self.log("train/disc/grad_norm", grad_norm, batch_size=n)

    optimizer_disc.step()

    return ssl_real, ssl_recon, mel_recon, loss_dict, real_intermediates
def _generator_step(
    self,
    batch: AudioBatch,
    optimizer_gen: torch.optim.Optimizer,
    ssl_real: torch.Tensor | None = None,
    ssl_recon: torch.Tensor | None = None,
    mel_recon: torch.Tensor | None = None,
    loss_dict: dict | None = None,
    real_intermediates: list[torch.Tensor] | None = None,
    training_disc: bool = False,
) -> torch.Tensor:
    """Run one manual optimization step of the generator.

    Args:
        batch: Audio batch with waveform and augmented_waveform
        optimizer_gen: Generator optimizer
        ssl_real: Real SSL features (optional)
        ssl_recon: Reconstructed SSL features (optional)
        mel_recon: Generated mel spectrogram (optional)
        loss_dict: Dictionary with auxiliary information (optional)
        real_intermediates: Intermediate feature maps from discriminator for real mel (optional)
        training_disc: Whether discriminator is being trained in this step

    Returns:
        gen_loss: Total generator loss
    """
    # Reuse the forward pass already done in the discriminator step when available.
    if loss_dict is None:
        ssl_real, ssl_recon, mel_recon, loss_dict = self(batch.waveform)

    # Reconstruction losses (mel target always comes from the original waveform).
    recon_loss, recon_dict, mel_real = self._get_reconstruction_loss(batch.waveform, ssl_real, ssl_recon, mel_recon)
    gen_loss = recon_loss

    n = batch.waveform.size(0)
    if training_disc:
        assert mel_real is not None and mel_recon is not None, "Mel spectrograms must be provided for GAN training"

        if real_intermediates is None:
            _, real_intermediates = self.discriminator(mel_real)

        fake_outputs, fake_intermediates = self.discriminator(mel_recon)

        # Adversarial term.
        adv_loss = self._get_generator_loss(fake_outputs)
        gen_loss = gen_loss + self.config.adv_weight * adv_loss
        self.log("train/gen/adv_loss", adv_loss, batch_size=n)

        # Feature-matching term.
        fm_loss = self._get_feature_matching_loss(real_intermediates, fake_intermediates)
        gen_loss = gen_loss + self.config.feature_matching_weight * fm_loss
        self.log("train/gen/feature_matching_loss", fm_loss, batch_size=n)

    # Log auxiliary and reconstruction losses.
    for name, value in loss_dict.items():
        self.log(f"train/{name}", value, batch_size=n)
    for name, value in recon_dict.items():
        self.log(f"train/gen/{name}", value, batch_size=n)
    self.log("train/loss", gen_loss, batch_size=n, prog_bar=True)

    # Manual optimization of the generator.
    optimizer_gen.zero_grad()
    self.manual_backward(gen_loss)

    grad_norm = torch.nn.utils.clip_grad_norm_(
        self.model.parameters(), max_norm=self.config.gradient_clip_val or torch.inf
    )
    self.log("train/gen/grad_norm", grad_norm, batch_size=n)

    optimizer_gen.step()

    return gen_loss
def training_step(self, batch: AudioBatch, batch_idx: int):
    """Manual-optimization training step: optionally update the discriminator, then the generator."""
    if self.use_discriminator:
        optimizer_disc, optimizer_gen = self.optimizers()
        scheduler_disc, scheduler_gen = self.lr_schedulers()
    else:
        optimizer_gen = self.optimizers()
        scheduler_gen = self.lr_schedulers()

    # Discriminator updates begin after a warmup step count and then fire stochastically.
    training_disc = (
        self.use_discriminator
        and self.global_step >= self.discriminator_start_step
        and torch.rand(1).item() < self.discriminator_update_prob
    )
    if self.use_discriminator and self.global_step == self.discriminator_start_step:
        logger.info(f"Discriminator training starts at step {self.global_step}")

    ssl_real = ssl_recon = mel_recon = loss_dict = real_intermediates = None

    if training_disc:
        ssl_real, ssl_recon, mel_recon, loss_dict, real_intermediates = self._discriminator_step(
            batch, optimizer_disc
        )
        scheduler_disc.step()
    elif self.use_discriminator:
        # Keep the discriminator LR schedule advancing even on skipped updates.
        scheduler_disc.step()

    # Generator update, reusing forward-pass outputs when the discriminator ran.
    self._generator_step(
        batch, optimizer_gen, ssl_real, ssl_recon, mel_recon, loss_dict, real_intermediates, training_disc
    )
    scheduler_gen.step()
def validation_step(self, batch: AudioBatch, batch_idx: int):
    """Validation: forward pass, loss logging, and sample capture for epoch-end media logging."""
    waveform = batch.waveform
    ssl_real, ssl_recon, mel_recon, loss_dict = self(waveform)

    n = waveform.size(0)

    # Reconstruction losses (also yields the reference mel when mel training is on).
    recon_loss, recon_dict, mel_real = self._get_reconstruction_loss(waveform, ssl_real, ssl_recon, mel_recon)
    total_loss = recon_loss

    for name, value in loss_dict.items():
        self.log(f"val/{name}", value, batch_size=n)
    for name, value in recon_dict.items():
        self.log(f"val/gen/{name}", value, batch_size=n)
    self.log("val/loss", total_loss, batch_size=n)

    # Keep the first few samples of the epoch for spectrogram/audio logging.
    if self.config.train_mel and len(self.validation_examples) < self.log_mel_samples:
        assert mel_real is not None and mel_recon is not None, (
            "Mel spectrograms must be provided for validation logging"
        )
        sample_real = waveform[0].cpu()
        sample_gen = None
        if self.vocoder is not None:
            sample_gen = self.vocode(mel_recon[0:1])[0].cpu()

        self.validation_examples.append((mel_real[0].cpu(), mel_recon[0].detach().cpu(), sample_real, sample_gen))
def predict_step(self, batch: AudioBatch, batch_idx: int):
    """Round-trip a batch of audio through the tokenizer and vocoder.

    Returns a dict with the batch ids, the input audio, and the resynthesized audio
    (shaped (B, 1, T)).
    """
    source_audio = batch.waveform
    _, _, mel_gen, _ = self(source_audio)

    generated = self.vocode(mel_gen)

    # Ensure a channel dimension: (B, T) -> (B, 1, T).
    if generated.dim() == 2:
        generated = generated.unsqueeze(1)
    return {"audio_ids": batch.audio_ids, "audio_real": source_audio, "audio_gen": generated}
def configure_optimizers(self):
    """Build AdamW optimizers and OneCycleLR schedulers for the generator (and discriminator).

    Returns a Lightning (optimizers, schedulers) pair; when GAN training is active the
    discriminator entries come first, matching `self.optimizers()` unpacking in
    `training_step`.
    """

    def make_scheduler(optimizer: torch.optim.Optimizer, max_lr: float) -> OneCycleLR:
        # Both schedulers share identical OneCycleLR settings apart from max_lr;
        # factored out to keep the two configurations in sync.
        return OneCycleLR(
            optimizer,
            max_lr=max_lr,
            div_factor=self.config.lr_div_factor,
            final_div_factor=self.config.lr_final_div_factor,
            pct_start=self.config.warmup_percent,
            anneal_strategy=self.config.anneal_mode,
            total_steps=self.trainer.estimated_stepping_batches,
        )

    # Generator optimizer + scheduler.
    optimizer_gen = AdamW(
        self.model.parameters(), lr=self.config.lr, betas=self.config.betas, weight_decay=self.config.weight_decay
    )
    scheduler_gen = make_scheduler(optimizer_gen, self.config.lr)

    if not self.use_discriminator:
        return ([optimizer_gen], [{"scheduler": scheduler_gen, "interval": "step"}])

    # Discriminator optimizer + scheduler (falls back to the generator LR).
    disc_lr = self.config.discriminator_lr or self.config.lr
    optimizer_disc = AdamW(
        self.discriminator.parameters(),
        lr=disc_lr,
        betas=self.config.betas,
        weight_decay=self.config.weight_decay,
    )
    scheduler_disc = make_scheduler(optimizer_disc, disc_lr)

    # Restore optimizer states when resuming from a Lightning checkpoint.
    if self.config.ckpt_path:
        if self.config.ckpt_path.endswith(".ckpt"):
            checkpoint = torch.load(self.config.ckpt_path)
            optimizer_states = checkpoint["optimizer_states"]
            if len(optimizer_states) > 1 and self.use_discriminator:
                optimizer_disc.load_state_dict(optimizer_states[0])
                optimizer_gen.load_state_dict(optimizer_states[1])
                logger.info("Loaded discriminator and generator's optimizer states from checkpoint")
            elif len(optimizer_states) == 1 and not self.use_discriminator:
                # NOTE: unreachable here (use_discriminator is True past the early
                # return above); kept for parity with the single-optimizer path.
                optimizer_gen.load_state_dict(optimizer_states[0])
                logger.info("Loaded generator's optimizer state from checkpoint")
        else:
            logger.info("No optimizer state loaded since checkpoint is not a .ckpt file")

    return (
        [optimizer_disc, optimizer_gen],
        [{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}],
    )
def _setup_vocoder(self):
    """Try to load the configured vocoder; return None when its dependencies are missing."""
    vocoder = None
    try:
        vocoder = load_vocoder(name=self.model.config.vocoder_name)
    except ImportError:
        logger.error("Vocoder could not be loaded. Please install the required dependencies.")
    return vocoder
def vocode(self, mel: torch.Tensor) -> torch.Tensor:
    """Convert a mel spectrogram to a waveform with the loaded vocoder.

    The vocoder is moved to the mel's device; the output is returned on CPU as float32.
    """
    device = mel.device
    self.vocoder = self.vocoder.to(device)
    # Delegates to the module-level vocode() helper.
    audio = vocode(self.vocoder, mel)
    return audio.cpu().float()
def on_validation_start(self):
    # Lazily load the vocoder so validation can synthesize and log audio samples.
    self.vocoder = self._setup_vocoder()
def on_predict_start(self):
    # Lazily load the vocoder so predict_step can resynthesize waveforms.
    self.vocoder = self._setup_vocoder()
def on_validation_end(self):
    """Log collected validation spectrograms/audio, then release buffers and the vocoder."""
    if len(self.validation_examples) > 0:
        from matplotlib import pyplot as plt

        for i, (mel_real, mel_recon, audio_real, audio_gen) in enumerate(self.validation_examples):
            # Log spectrograms
            fig_real = self._get_spectrogram_plot(mel_real)
            fig_gen = self._get_spectrogram_plot(mel_recon)
            self._log_figure(f"val/{i}_mel_real", fig_real)
            self._log_figure(f"val/{i}_mel_gen", fig_gen)
            # FIX: close figures once logged; matplotlib keeps figures alive
            # otherwise, leaking memory across validation epochs.
            plt.close(fig_real)
            plt.close(fig_gen)

            # Log audio samples
            if audio_gen is not None:
                audio_real = audio_real.cpu().numpy()
                audio_gen = audio_gen.cpu().numpy()
                self._log_audio(f"val/{i}_audio_real", audio_real)
                self._log_audio(f"val/{i}_audio_gen", audio_gen)

        self.validation_examples = []

    # Clear vocoder to free memory
    self.vocoder = None
def _log_figure(self, tag: str, fig):
    """Log a matplotlib figure to the active experiment logger (TensorBoard or W&B)."""
    if isinstance(self.logger, TensorBoardLogger):
        self.logger.experiment.add_figure(tag, fig, self.global_step)
    elif isinstance(self.logger, WandbLogger):
        import PIL.Image as Image

        fig.canvas.draw()
        # FIX: canvas.buffer_rgba() yields straight (non-premultiplied) RGBA bytes,
        # so the matching PIL mode is "RGBA". The original "RGBa" (premultiplied)
        # would misinterpret any pixel with partial alpha.
        image = Image.frombytes("RGBA", fig.canvas.get_width_height(), fig.canvas.buffer_rgba())
        image = image.convert("RGB")
        self.logger.log_image(tag, [image], step=self.global_step)
def _log_audio(self, tag: str, audio: np.ndarray):
    """Log an audio sample to the active experiment logger (TensorBoard or W&B)."""
    sr = self.model.config.sample_rate
    if isinstance(self.logger, TensorBoardLogger):
        self.logger.experiment.add_audio(tag, audio, self.global_step, sample_rate=sr)
    elif isinstance(self.logger, WandbLogger):
        # W&B expects 1-D audio arrays and per-sample sample rates.
        self.logger.log_audio(tag, [audio.flatten()], sample_rate=[sr], step=self.global_step)
def _get_spectrogram_plot(self, mel: torch.Tensor):
    """Render a mel spectrogram tensor as a matplotlib figure (caller owns the figure)."""
    from matplotlib import pyplot as plt

    data = mel.detach().cpu().numpy()
    fig, ax = plt.subplots(figsize=(10, 4))
    # Fixed color range keeps figures comparable across steps.
    image = ax.imshow(data, aspect="auto", origin="lower", cmap="magma", vmin=-8.0, vmax=5.0)
    fig.colorbar(image, ax=ax)
    ax.set_ylabel("Mel bins")
    ax.set_xlabel("Time steps")
    fig.tight_layout()
    return fig
def _load_weights(self, ckpt_path: str | None, model_state_dict: dict[str, torch.Tensor] | None = None):
    """Load model and discriminator weights from checkpoint. Supports .ckpt (Lightning), .safetensors, .pt/.pth formats.
    If model_state_dict is provided, load weights from it instead of ckpt_path."""

    def select_keys(state_dict: dict, prefix: str) -> dict:
        """Select keys from state_dict that start with the given prefix. Remove the prefix from keys."""
        return {k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)}

    def remove_prefix(state_dict: dict, prefix: str) -> dict:
        """Remove a prefix from keys that start with that prefix."""
        return {k[len(prefix) :] if k.startswith(prefix) else k: v for k, v in state_dict.items()}

    def add_prefix(state_dict: dict, prefix: str) -> dict:
        """Add a prefix to keys that do not start with that prefix."""
        return {f"{prefix}{k}" if not k.startswith(prefix) else k: v for k, v in state_dict.items()}

    # Load state dict
    if model_state_dict is not None:
        # Load from provided state dict
        # ckpt_path is ignored in this case; no discriminator weights are available.
        disc_state_dict = {}
    elif ckpt_path.endswith(".ckpt"):
        # Lightning checkpoint
        # Lightning stores both submodules under one "state_dict"; split by attribute prefix.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        model_state_dict = select_keys(checkpoint["state_dict"], "model.")
        disc_state_dict = select_keys(checkpoint["state_dict"], "discriminator.")
    elif ckpt_path.endswith(".safetensors"):
        # Safetensors checkpoint
        from safetensors.torch import load_file

        checkpoint = load_file(ckpt_path, device="cpu")
        model_state_dict = checkpoint
        disc_state_dict = {}
    elif ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
        # Standard PyTorch checkpoint
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        model_state_dict = checkpoint
        disc_state_dict = {}
    else:
        raise ValueError(f"Unsupported checkpoint format: {ckpt_path}")

    # Load model weights
    # Strip the torch.compile wrapper prefix so keys match an uncompiled model...
    model_state_dict = remove_prefix(model_state_dict, "_orig_mod.")
    # ...and drop any modules the config asks to skip (e.g. branches being retrained).
    model_state_dict = {
        k: v
        for k, v in model_state_dict.items()
        if not any(k.startswith(module) for module in self.config.skip_loading_modules)
    }
    if self.torch_compiled:
        # Re-add the prefix when this instance is itself wrapped by torch.compile.
        model_state_dict = add_prefix(model_state_dict, "_orig_mod.")

    if len(model_state_dict) > 0:
        # strict=False: missing/unexpected keys are tolerated and only logged.
        result = self.model.load_state_dict(model_state_dict, strict=False)
        logger.info(f"Loaded model weights from {ckpt_path or 'provided state_dict'}.")
        if result.missing_keys:
            logger.debug(f"Missing keys in model state_dict: {result.missing_keys}")
        if result.unexpected_keys:
            logger.debug(f"Unexpected keys in model state_dict: {result.unexpected_keys}")

    # Load discriminator weights if available
    if self.use_discriminator:
        disc_state_dict = remove_prefix(disc_state_dict, "_orig_mod.")
        if self.torch_compiled:
            disc_state_dict = add_prefix(disc_state_dict, "_orig_mod.")

        if len(disc_state_dict) > 0:
            result = self.discriminator.load_state_dict(disc_state_dict, strict=False)
            logger.info(f"Loaded discriminator weights from {ckpt_path}.")
            if result.missing_keys:
                logger.debug(f"Missing keys in discriminator state_dict: {result.missing_keys}")
            if result.unexpected_keys:
                logger.debug(f"Unexpected keys in discriminator state_dict: {result.unexpected_keys}")
@classmethod
def from_hparams(cls, config_path: str) -> "KanadePipeline":
    """Instantiate KanadePipeline from config file.

    Args:
        config_path (str): Path to model configuration file (.yaml).
    Returns:
        KanadePipeline: Instantiated KanadePipeline.
    """
    with open(config_path, "r") as f:
        raw_config = yaml.safe_load(f)

    # Keep only the model section, and strip checkpoint-related fields so no
    # weights are loaded during instantiation.
    trimmed = {"model": raw_config["model"]}
    pipe_cfg = trimmed["model"]["init_args"]["pipeline_config"]
    for key in ("ckpt_path", "skip_loading_modules"):
        if key in pipe_cfg:
            del pipe_cfg[key]

    # Let jsonargparse build the full object graph from the config dict.
    parser = jsonargparse.ArgumentParser(exit_on_error=False)
    parser.add_argument("--model", type=KanadePipeline)
    parsed = parser.parse_object(trimmed)
    parsed = parser.instantiate_classes(parsed)
    return parsed.model
@staticmethod
def from_pretrained(config_path: str, ckpt_path: str) -> "KanadePipeline":
    """Load KanadePipeline from training configuration and checkpoint files.

    Args:
        config_path: Path to pipeline configuration file (YAML).
        ckpt_path: Path to checkpoint file (.ckpt) or model weights (.safetensors).
    Returns:
        KanadePipeline: Instantiated KanadePipeline with loaded weights.
    """
    # Build the pipeline skeleton from the config, then pull in the weights.
    pipeline = KanadePipeline.from_hparams(config_path)
    pipeline._load_weights(ckpt_path)
    return pipeline
src/kanade_tokenizer/util.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Literal
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
# Package-wide logger, configured once at import time: INFO level with a
# timestamped stream handler.
logger = logging.getLogger("kanade_tokenizer")
logger.setLevel(logging.INFO)

_formatter = logging.Formatter("[%(asctime)s] %(levelname)s %(name)s: %(message)s")
handler = logging.StreamHandler()
handler.setFormatter(_formatter)
handler.setLevel(logging.INFO)
logger.addHandler(handler)
def get_logger() -> logging.Logger:
    """Return the package-level "kanade_tokenizer" logger configured at import time."""
    return logger
def freeze_modules(modules: list[nn.Module] | None):
|
| 21 |
+
for module in modules:
|
| 22 |
+
if module is not None:
|
| 23 |
+
for param in module.parameters():
|
| 24 |
+
param.requires_grad = False
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _load_audio_internal(
    path: str, frame_offset: int | None = None, num_frames: int | None = None
) -> tuple[torch.Tensor, int]:
    """Decode an audio file and return ``(waveform, sample_rate)``.

    The waveform is a float32 tensor of shape (channels, frames).
    """
    # TorchAudio >= 2.9.0 removed decoding and encoding capabilities to TorchCodec.
    # See: https://github.com/pytorch/audio/issues/3902
    # waveform, sample_rate = torchaudio.load(path, frame_offset=frame_offset or 0, num_frames=num_frames or -1)
    import soundfile as sf

    with sf.SoundFile(path) as audio_file:
        if frame_offset is not None:
            audio_file.seek(frame_offset)
        # always_2d keeps a (frames, channels) layout even for mono files.
        data = audio_file.read(frames=num_frames or -1, dtype="float32", always_2d=True)
        rate = audio_file.samplerate
    return torch.from_numpy(data.T), rate
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def load_audio(audio_path: str, sample_rate: int = 24000) -> torch.Tensor:
    """Load and preprocess audio file.

    Loads the file, downmixes to mono, resamples to ``sample_rate``
    when needed, and peak-normalizes the result to [-1, 1].

    Args:
        audio_path: Path to an audio file readable by soundfile.
        sample_rate: Target sampling rate in Hz.

    Returns:
        torch.Tensor: 1-D waveform tensor of shape (samples,).
    """
    # NOTE: the docstring above was previously placed *after* this import,
    # turning it into a dead string statement rather than a docstring.
    # Local import: torchaudio is only needed here for resampling.
    import torchaudio

    waveform, sr = _load_audio_internal(audio_path)

    # Convert to mono if stereo/multi-channel by averaging channels.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample if the source rate differs from the requested rate.
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(sr, sample_rate)
        waveform = resampler(waveform)

    # Peak-normalize to [-1, 1]; the epsilon avoids division by zero on silence.
    max_val = torch.max(torch.abs(waveform)) + 1e-8
    waveform = waveform / max_val

    return waveform.squeeze(0)  # Remove channel dimension
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def load_vocoder(name: Literal["vocos", "hift"] = "vocos") -> torch.nn.Module:
    """Load a pretrained neural vocoder and set it to eval mode.

    Args:
        name: Which vocoder to load — "vocos" (charactr/vocos-mel-24khz from
            the Hugging Face Hub) or "hift" (HiFT generator weights from
            FunAudioLLM/CosyVoice2-0.5B).

    Returns:
        torch.nn.Module: The vocoder in eval mode.

    Raises:
        ValueError: If ``name`` is not a supported vocoder.
    """
    if name == "vocos":
        from vocos import Vocos

        model = Vocos.from_pretrained("charactr/vocos-mel-24khz")
        model = model.eval()
        return model
    elif name == "hift":
        from huggingface_hub import hf_hub_download
        from .module.hift import HiFTGenerator

        # Download the HiFT model weights from FunAudioLLM/CosyVoice2-0.5B
        model_path = hf_hub_download(repo_id="FunAudioLLM/CosyVoice2-0.5B", filename="hift.pt")
        model = HiFTGenerator()
        model.load_weights(model_path)
        model = model.eval()
        return model
    else:
        raise ValueError(f"Unsupported vocoder name: {name}")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def vocode(vocoder, mel_spectrogram: torch.Tensor) -> torch.Tensor:
    """Synthesize an audio waveform from a mel spectrogram.

    Dispatches on the vocoder's class name: Vocos-style models expose
    ``decode``, HiFT-style models expose ``inference``.

    Args:
        vocoder: Pretrained vocoder model.
        mel_spectrogram (torch.Tensor): Input mel spectrogram tensor (..., n_mels, frame).

    Returns:
        torch.Tensor: Generated audio waveform tensor (..., samples).

    Raises:
        ValueError: If the vocoder's class is not recognized.
    """
    # Vocoders expect float32 input regardless of the caller's dtype.
    mel_spectrogram = mel_spectrogram.to(torch.float32)

    vocoder_class_name = type(vocoder).__name__
    if "Vocos" in vocoder_class_name:
        return vocoder.decode(mel_spectrogram)
    if "HiFT" in vocoder_class_name:
        return vocoder.inference(mel_spectrogram)
    raise ValueError(f"Unsupported vocoder class: {vocoder_class_name}")
|