Spaces:
Runtime error
Runtime error
Upload 25 files
Browse files- kugelaudio_open/__init__.py +73 -0
- kugelaudio_open/cli.py +121 -0
- kugelaudio_open/configs/__init__.py +22 -0
- kugelaudio_open/configs/kugelaudio_1.5b.json +68 -0
- kugelaudio_open/configs/kugelaudio_7b.json +68 -0
- kugelaudio_open/configs/model_config.py +290 -0
- kugelaudio_open/models/__init__.py +47 -0
- kugelaudio_open/models/conv_layers.py +289 -0
- kugelaudio_open/models/diffusion_head.py +288 -0
- kugelaudio_open/models/kugelaudio_inference.py +800 -0
- kugelaudio_open/models/kugelaudio_model.py +721 -0
- kugelaudio_open/models/tokenizer.py +1197 -0
- kugelaudio_open/processors/__init__.py +10 -0
- kugelaudio_open/processors/audio_processor.py +268 -0
- kugelaudio_open/processors/kugelaudio_processor.py +366 -0
- kugelaudio_open/processors/text_tokenizer.py +93 -0
- kugelaudio_open/schedule/__init__.py +5 -0
- kugelaudio_open/schedule/dpm_solver.py +1084 -0
- kugelaudio_open/ui/__init__.py +5 -0
- kugelaudio_open/ui/__main__.py +41 -0
- kugelaudio_open/ui/app.py +506 -0
- kugelaudio_open/utils/__init__.py +5 -0
- kugelaudio_open/utils/generation.py +118 -0
- kugelaudio_open/watermark/__init__.py +5 -0
- kugelaudio_open/watermark/watermark.py +390 -0
kugelaudio_open/__init__.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""KugelAudio - Open Source Text-to-Speech Model
|
| 2 |
+
|
| 3 |
+
KugelAudio is a state-of-the-art neural text-to-speech model that generates
|
| 4 |
+
natural, expressive speech from text with voice cloning capabilities.
|
| 5 |
+
|
| 6 |
+
Example:
|
| 7 |
+
>>> from kugelaudio import KugelAudioForConditionalGenerationInference
|
| 8 |
+
>>> from transformers import AutoModel
|
| 9 |
+
>>>
|
| 10 |
+
>>> # Load the model
|
| 11 |
+
>>> model = AutoModel.from_pretrained("kugelaudio/kugelaudio-0-open")
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
__version__ = "0.1.0"
|
| 15 |
+
|
| 16 |
+
from .configs import (
|
| 17 |
+
KugelAudioAcousticTokenizerConfig,
|
| 18 |
+
KugelAudioConfig,
|
| 19 |
+
KugelAudioDiffusionHeadConfig,
|
| 20 |
+
KugelAudioSemanticTokenizerConfig,
|
| 21 |
+
)
|
| 22 |
+
from .models import (
|
| 23 |
+
KugelAudioAcousticTokenizerModel,
|
| 24 |
+
KugelAudioDiffusionHead,
|
| 25 |
+
KugelAudioForConditionalGeneration,
|
| 26 |
+
KugelAudioForConditionalGenerationInference,
|
| 27 |
+
KugelAudioModel,
|
| 28 |
+
KugelAudioPreTrainedModel,
|
| 29 |
+
KugelAudioSemanticTokenizerModel,
|
| 30 |
+
)
|
| 31 |
+
from .processors import KugelAudioProcessor
|
| 32 |
+
from .schedule import DPMSolverMultistepScheduler
|
| 33 |
+
from .watermark import AudioWatermark
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Lazy imports for optional components
|
| 37 |
+
def launch_ui(*args, **kwargs):
|
| 38 |
+
"""Launch the Gradio web interface."""
|
| 39 |
+
try:
|
| 40 |
+
from .ui import launch_ui as _launch_ui
|
| 41 |
+
|
| 42 |
+
return _launch_ui(*args, **kwargs)
|
| 43 |
+
except ImportError:
|
| 44 |
+
raise ImportError(
|
| 45 |
+
"Gradio is required for the web interface. " "Install it with: pip install gradio"
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
__all__ = [
|
| 50 |
+
# Version
|
| 51 |
+
"__version__",
|
| 52 |
+
# Configs
|
| 53 |
+
"KugelAudioConfig",
|
| 54 |
+
"KugelAudioAcousticTokenizerConfig",
|
| 55 |
+
"KugelAudioSemanticTokenizerConfig",
|
| 56 |
+
"KugelAudioDiffusionHeadConfig",
|
| 57 |
+
# Models
|
| 58 |
+
"KugelAudioModel",
|
| 59 |
+
"KugelAudioPreTrainedModel",
|
| 60 |
+
"KugelAudioForConditionalGeneration",
|
| 61 |
+
"KugelAudioForConditionalGenerationInference",
|
| 62 |
+
"KugelAudioAcousticTokenizerModel",
|
| 63 |
+
"KugelAudioSemanticTokenizerModel",
|
| 64 |
+
"KugelAudioDiffusionHead",
|
| 65 |
+
# Scheduler
|
| 66 |
+
"DPMSolverMultistepScheduler",
|
| 67 |
+
# Processors
|
| 68 |
+
"KugelAudioProcessor",
|
| 69 |
+
# Watermark
|
| 70 |
+
"AudioWatermark",
|
| 71 |
+
# UI
|
| 72 |
+
"launch_ui",
|
| 73 |
+
]
|
kugelaudio_open/cli.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""Command-line interface for KugelAudio."""

import argparse
import sys


def main():
    """Entry point for the ``kugelaudio`` console command.

    Dispatches to one of three sub-commands (``ui``, ``generate``,
    ``verify``); prints help and exits with status 1 when no command is
    given. Heavy dependencies (torch, soundfile, gradio) are imported
    lazily inside each branch so ``--help`` stays fast.
    """
    parser = argparse.ArgumentParser(
        description="KugelAudio - Open-source text-to-speech",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Launch web interface
  kugelaudio ui

  # Launch with public share link
  kugelaudio ui --share

  # Generate speech from command line
  kugelaudio generate "Hello world!" -o output.wav

  # Check watermark in audio file
  kugelaudio verify audio.wav
""",
    )

    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # UI command
    ui_parser = subparsers.add_parser("ui", help="Launch Gradio web interface")
    ui_parser.add_argument("--share", action="store_true", help="Create public share link")
    ui_parser.add_argument("--host", default="127.0.0.1", help="Server hostname")
    ui_parser.add_argument("--port", type=int, default=7860, help="Server port")

    # Generate command
    gen_parser = subparsers.add_parser("generate", help="Generate speech from text")
    gen_parser.add_argument("text", help="Text to synthesize")
    gen_parser.add_argument("-o", "--output", default="output.wav", help="Output file path")
    gen_parser.add_argument("-r", "--reference", help="Reference audio for voice cloning")
    gen_parser.add_argument("--model", default="kugelaudio/kugelaudio-0-open", help="Model ID")
    gen_parser.add_argument("--cfg-scale", type=float, default=3.0, help="Guidance scale")

    # Verify command
    verify_parser = subparsers.add_parser("verify", help="Check watermark in audio")
    verify_parser.add_argument("audio", help="Audio file to check")

    args = parser.parse_args()

    if args.command == "ui":
        from kugelaudio_open.ui import launch_app

        launch_app(
            share=args.share,
            server_name=args.host,
            server_port=args.port,
        )

    elif args.command == "generate":
        import torch
        from kugelaudio_open.models import KugelAudioForConditionalGenerationInference
        from kugelaudio_open.processors import KugelAudioProcessor

        device = "cuda" if torch.cuda.is_available() else "cpu"
        # bfloat16 only on GPU; CPU inference stays in float32 for support/accuracy.
        dtype = torch.bfloat16 if device == "cuda" else torch.float32

        print(f"Loading model {args.model}...")
        model = KugelAudioForConditionalGenerationInference.from_pretrained(
            args.model, torch_dtype=dtype
        ).to(device)
        model.eval()

        processor = KugelAudioProcessor.from_pretrained(args.model)

        # Process inputs (voice_prompt passed to processor for proper handling)
        inputs = processor(
            text=args.text,
            voice_prompt=args.reference,  # Pass reference audio path directly
            return_tensors="pt",
        )
        inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

        print("Generating speech...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                cfg_scale=args.cfg_scale,
                max_new_tokens=4096,
            )

        # Audio is already watermarked by the model's generate method
        audio = outputs.speech_outputs[0]

        # Save
        processor.save_audio(audio, args.output)
        print(f"Audio saved to {args.output}")

    elif args.command == "verify":
        import soundfile as sf
        from kugelaudio_open.watermark import AudioWatermark

        audio, sr = sf.read(args.audio)

        watermark = AudioWatermark()
        result = watermark.detect(audio, sample_rate=sr)

        if result.detected:
            print(f"✅ Watermark DETECTED (confidence: {result.confidence:.1%})")
            print("This audio was generated by KugelAudio.")
        else:
            print(f"❌ No watermark detected (confidence: {result.confidence:.1%})")
            print("This audio does not appear to be generated by KugelAudio.")

    else:
        # No (or unknown) sub-command: show usage and exit non-zero.
        parser.print_help()
        sys.exit(1)


if __name__ == "__main__":
    main()
|
kugelaudio_open/configs/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""KugelAudio configuration classes."""
|
| 2 |
+
|
| 3 |
+
from .model_config import (
|
| 4 |
+
KugelAudioConfig,
|
| 5 |
+
KugelAudioAcousticTokenizerConfig,
|
| 6 |
+
KugelAudioSemanticTokenizerConfig,
|
| 7 |
+
KugelAudioDiffusionHeadConfig,
|
| 8 |
+
# Aliases
|
| 9 |
+
AcousticTokenizerConfig,
|
| 10 |
+
SemanticTokenizerConfig,
|
| 11 |
+
DiffusionHeadConfig,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"KugelAudioConfig",
|
| 16 |
+
"KugelAudioAcousticTokenizerConfig",
|
| 17 |
+
"KugelAudioSemanticTokenizerConfig",
|
| 18 |
+
"KugelAudioDiffusionHeadConfig",
|
| 19 |
+
"AcousticTokenizerConfig",
|
| 20 |
+
"SemanticTokenizerConfig",
|
| 21 |
+
"DiffusionHeadConfig",
|
| 22 |
+
]
|
kugelaudio_open/configs/kugelaudio_1.5b.json
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "kugelaudio",
|
| 3 |
+
"_attn_implementation_autoset": true,
|
| 4 |
+
"acoustic_vae_dim": 64,
|
| 5 |
+
"tts_backbone_num_hidden_layers": 20,
|
| 6 |
+
"acoustic_tokenizer_config": {
|
| 7 |
+
"model_type": "kugelaudio_acoustic_tokenizer",
|
| 8 |
+
"causal": true,
|
| 9 |
+
"channels": 1,
|
| 10 |
+
"conv_bias": true,
|
| 11 |
+
"conv_norm": "none",
|
| 12 |
+
"decoder_depths": null,
|
| 13 |
+
"decoder_n_filters": 32,
|
| 14 |
+
"decoder_ratios": [8, 5, 5, 4, 2, 2],
|
| 15 |
+
"disable_last_norm": true,
|
| 16 |
+
"encoder_depths": "3-3-3-3-3-3-8",
|
| 17 |
+
"encoder_n_filters": 32,
|
| 18 |
+
"encoder_ratios": [8, 5, 5, 4, 2, 2],
|
| 19 |
+
"fix_std": 0.5,
|
| 20 |
+
"layer_scale_init_value": 1e-06,
|
| 21 |
+
"layernorm": "RMSNorm",
|
| 22 |
+
"layernorm_elementwise_affine": true,
|
| 23 |
+
"layernorm_eps": 1e-05,
|
| 24 |
+
"mixer_layer": "depthwise_conv",
|
| 25 |
+
"pad_mode": "constant",
|
| 26 |
+
"std_dist_type": "gaussian",
|
| 27 |
+
"vae_dim": 64,
|
| 28 |
+
"weight_init_value": 0.01
|
| 29 |
+
},
|
| 30 |
+
"decoder_config": {
|
| 31 |
+
"model_type": "qwen2",
|
| 32 |
+
"attention_dropout": 0.0,
|
| 33 |
+
"hidden_act": "silu",
|
| 34 |
+
"hidden_size": 1536,
|
| 35 |
+
"initializer_range": 0.02,
|
| 36 |
+
"intermediate_size": 8960,
|
| 37 |
+
"max_position_embeddings": 65536,
|
| 38 |
+
"max_window_layers": 28,
|
| 39 |
+
"num_attention_heads": 12,
|
| 40 |
+
"num_hidden_layers": 28,
|
| 41 |
+
"num_key_value_heads": 2,
|
| 42 |
+
"rms_norm_eps": 1e-06,
|
| 43 |
+
"rope_scaling": null,
|
| 44 |
+
"rope_theta": 1000000.0,
|
| 45 |
+
"sliding_window": null,
|
| 46 |
+
"tie_word_embeddings": true,
|
| 47 |
+
"torch_dtype": "bfloat16",
|
| 48 |
+
"use_cache": true,
|
| 49 |
+
"use_sliding_window": false,
|
| 50 |
+
"vocab_size": 151936
|
| 51 |
+
},
|
| 52 |
+
"diffusion_head_config": {
|
| 53 |
+
"model_type": "kugelaudio_diffusion_head",
|
| 54 |
+
"ddpm_batch_mul": 4,
|
| 55 |
+
"ddpm_beta_schedule": "cosine",
|
| 56 |
+
"ddpm_num_inference_steps": 20,
|
| 57 |
+
"ddpm_num_steps": 1000,
|
| 58 |
+
"diffusion_type": "ddpm",
|
| 59 |
+
"head_ffn_ratio": 3.0,
|
| 60 |
+
"head_layers": 4,
|
| 61 |
+
"hidden_size": 1536,
|
| 62 |
+
"latent_size": 64,
|
| 63 |
+
"prediction_type": "v_prediction",
|
| 64 |
+
"rms_norm_eps": 1e-05,
|
| 65 |
+
"speech_vae_dim": 64
|
| 66 |
+
},
|
| 67 |
+
"torch_dtype": "bfloat16"
|
| 68 |
+
}
|
kugelaudio_open/configs/kugelaudio_7b.json
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "kugelaudio",
|
| 3 |
+
"_attn_implementation_autoset": true,
|
| 4 |
+
"acoustic_vae_dim": 64,
|
| 5 |
+
"tts_backbone_num_hidden_layers": 20,
|
| 6 |
+
"acoustic_tokenizer_config": {
|
| 7 |
+
"model_type": "kugelaudio_acoustic_tokenizer",
|
| 8 |
+
"causal": true,
|
| 9 |
+
"channels": 1,
|
| 10 |
+
"conv_bias": true,
|
| 11 |
+
"conv_norm": "none",
|
| 12 |
+
"decoder_depths": null,
|
| 13 |
+
"decoder_n_filters": 32,
|
| 14 |
+
"decoder_ratios": [8, 5, 5, 4, 2, 2],
|
| 15 |
+
"disable_last_norm": true,
|
| 16 |
+
"encoder_depths": "3-3-3-3-3-3-8",
|
| 17 |
+
"encoder_n_filters": 32,
|
| 18 |
+
"encoder_ratios": [8, 5, 5, 4, 2, 2],
|
| 19 |
+
"fix_std": 0.5,
|
| 20 |
+
"layer_scale_init_value": 1e-06,
|
| 21 |
+
"layernorm": "RMSNorm",
|
| 22 |
+
"layernorm_elementwise_affine": true,
|
| 23 |
+
"layernorm_eps": 1e-05,
|
| 24 |
+
"mixer_layer": "depthwise_conv",
|
| 25 |
+
"pad_mode": "constant",
|
| 26 |
+
"std_dist_type": "gaussian",
|
| 27 |
+
"vae_dim": 64,
|
| 28 |
+
"weight_init_value": 0.01
|
| 29 |
+
},
|
| 30 |
+
"decoder_config": {
|
| 31 |
+
"model_type": "qwen2",
|
| 32 |
+
"attention_dropout": 0.0,
|
| 33 |
+
"hidden_act": "silu",
|
| 34 |
+
"hidden_size": 3584,
|
| 35 |
+
"initializer_range": 0.02,
|
| 36 |
+
"intermediate_size": 18944,
|
| 37 |
+
"max_position_embeddings": 32768,
|
| 38 |
+
"max_window_layers": 28,
|
| 39 |
+
"num_attention_heads": 28,
|
| 40 |
+
"num_hidden_layers": 28,
|
| 41 |
+
"num_key_value_heads": 4,
|
| 42 |
+
"rms_norm_eps": 1e-06,
|
| 43 |
+
"rope_scaling": null,
|
| 44 |
+
"rope_theta": 1000000.0,
|
| 45 |
+
"sliding_window": null,
|
| 46 |
+
"tie_word_embeddings": false,
|
| 47 |
+
"torch_dtype": "bfloat16",
|
| 48 |
+
"use_cache": true,
|
| 49 |
+
"use_sliding_window": false,
|
| 50 |
+
"vocab_size": 152064
|
| 51 |
+
},
|
| 52 |
+
"diffusion_head_config": {
|
| 53 |
+
"model_type": "kugelaudio_diffusion_head",
|
| 54 |
+
"ddpm_batch_mul": 4,
|
| 55 |
+
"ddpm_beta_schedule": "cosine",
|
| 56 |
+
"ddpm_num_inference_steps": 20,
|
| 57 |
+
"ddpm_num_steps": 1000,
|
| 58 |
+
"diffusion_type": "ddpm",
|
| 59 |
+
"head_ffn_ratio": 3.0,
|
| 60 |
+
"head_layers": 4,
|
| 61 |
+
"hidden_size": 3584,
|
| 62 |
+
"latent_size": 64,
|
| 63 |
+
"prediction_type": "v_prediction",
|
| 64 |
+
"rms_norm_eps": 1e-05,
|
| 65 |
+
"speech_vae_dim": 64
|
| 66 |
+
},
|
| 67 |
+
"torch_dtype": "bfloat16"
|
| 68 |
+
}
|
kugelaudio_open/configs/model_config.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration classes for KugelAudio models."""
|
| 2 |
+
|
| 3 |
+
from typing import Optional, List, Union
|
| 4 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 5 |
+
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
| 6 |
+
from transformers.utils import logging
|
| 7 |
+
|
| 8 |
+
logger = logging.get_logger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class KugelAudioAcousticTokenizerConfig(PretrainedConfig):
    """Configuration for the acoustic tokenizer.

    The acoustic tokenizer turns continuous speech latents back into audio
    waveforms via a hierarchical convolutional architecture with several
    upsampling stages.
    """

    model_type = "kugelaudio_acoustic_tokenizer"

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0.5,
        std_dist_type: str = "gaussian",
        # Common settings
        mixer_layer: str = "depthwise_conv",
        conv_norm: str = "none",
        pad_mode: str = "constant",
        disable_last_norm: bool = True,
        layernorm: str = "RMSNorm",
        layernorm_eps: float = 1e-5,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-6,
        weight_init_value: float = 1e-2,
        # Encoder specific
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[List[int]] = None,
        encoder_depths: str = "3-3-3-3-3-3-8",
        # Decoder specific
        decoder_n_filters: int = 32,
        decoder_ratios: Optional[List[int]] = None,
        decoder_depths: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Canonical down/upsampling ratios used when none are supplied; the
        # decoder shares the encoder's ratio list (same object) by default.
        if encoder_ratios is None:
            encoder_ratios = [8, 5, 5, 4, 2, 2]
        if decoder_ratios is None:
            decoder_ratios = encoder_ratios

        # VAE / distribution settings
        self.channels = channels
        self.corpus_normalize = corpus_normalize
        self.causal = causal
        self.vae_dim = vae_dim
        self.fix_std = fix_std
        self.std_dist_type = std_dist_type

        # Shared conv / norm settings
        self.conv_norm = conv_norm
        self.pad_mode = pad_mode
        self.layernorm_eps = layernorm_eps
        self.disable_last_norm = disable_last_norm
        self.layernorm = layernorm
        self.layernorm_elementwise_affine = layernorm_elementwise_affine
        self.conv_bias = conv_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.weight_init_value = weight_init_value
        self.mixer_layer = mixer_layer

        # Encoder-only settings
        self.encoder_n_filters = encoder_n_filters
        self.encoder_ratios = encoder_ratios
        self.encoder_depths = encoder_depths

        # Decoder-only settings
        self.decoder_ratios = decoder_ratios
        self.decoder_n_filters = decoder_n_filters
        self.decoder_depths = decoder_depths
+
class KugelAudioSemanticTokenizerConfig(PretrainedConfig):
    """Configuration for the semantic tokenizer.

    The semantic tokenizer extracts semantic features from audio for
    conditioning; it is encoder-only (no decoder parameters).
    """

    model_type = "kugelaudio_semantic_tokenizer"

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0,
        std_dist_type: str = "none",
        # Common settings
        mixer_layer: str = "depthwise_conv",
        conv_norm: str = "none",
        pad_mode: str = "constant",
        disable_last_norm: bool = True,
        layernorm: str = "RMSNorm",
        layernorm_eps: float = 1e-5,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-6,
        weight_init_value: float = 1e-2,
        # Encoder specific
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[List[int]] = None,
        encoder_depths: str = "3-3-3-3-3-3-8",
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Store every constructor argument verbatim as an attribute.
        for attr_name, attr_value in (
            ("channels", channels),
            ("corpus_normalize", corpus_normalize),
            ("causal", causal),
            ("vae_dim", vae_dim),
            ("fix_std", fix_std),
            ("std_dist_type", std_dist_type),
            ("conv_norm", conv_norm),
            ("pad_mode", pad_mode),
            ("layernorm_eps", layernorm_eps),
            ("disable_last_norm", disable_last_norm),
            ("layernorm", layernorm),
            ("layernorm_elementwise_affine", layernorm_elementwise_affine),
            ("conv_bias", conv_bias),
            ("layer_scale_init_value", layer_scale_init_value),
            ("weight_init_value", weight_init_value),
            ("mixer_layer", mixer_layer),
            ("encoder_n_filters", encoder_n_filters),
            ("encoder_depths", encoder_depths),
        ):
            setattr(self, attr_name, attr_value)

        # Default downsampling ratios when none are provided.
        self.encoder_ratios = (
            encoder_ratios if encoder_ratios is not None else [8, 5, 5, 4, 2, 2]
        )
+
class KugelAudioDiffusionHeadConfig(PretrainedConfig):
    """Configuration for the diffusion prediction head.

    The diffusion head predicts speech latents from text-conditioned hidden
    states via a denoising diffusion process.
    """

    model_type = "kugelaudio_diffusion_head"

    def __init__(
        self,
        hidden_size: int = 768,
        head_layers: int = 4,
        head_ffn_ratio: float = 3.0,
        rms_norm_eps: float = 1e-5,
        latent_size: int = 64,
        speech_vae_dim: Optional[int] = None,
        prediction_type: str = "v_prediction",
        diffusion_type: str = "ddpm",
        ddpm_num_steps: int = 1000,
        ddpm_num_inference_steps: int = 20,
        ddpm_beta_schedule: str = "cosine",
        ddpm_algorithm_type: str = "sde-dpmsolver++",
        ddpm_batch_mul: int = 4,
        **kwargs,
    ):
        # Architecture of the prediction head.
        self.hidden_size = hidden_size
        self.head_layers = head_layers
        self.head_ffn_ratio = head_ffn_ratio
        self.rms_norm_eps = rms_norm_eps
        self.latent_size = latent_size
        self.speech_vae_dim = speech_vae_dim

        # Diffusion-process settings.
        self.prediction_type = prediction_type
        self.diffusion_type = diffusion_type
        self.ddpm_num_steps = ddpm_num_steps
        self.ddpm_num_inference_steps = ddpm_num_inference_steps
        self.ddpm_beta_schedule = ddpm_beta_schedule
        self.ddpm_algorithm_type = ddpm_algorithm_type
        self.ddpm_batch_mul = ddpm_batch_mul

        # NOTE: attributes are assigned *before* delegating to the base
        # class, matching the original ordering.
        super().__init__(**kwargs)
+
class KugelAudioConfig(PretrainedConfig):
    """Main configuration for KugelAudio TTS model.

    This configuration combines:
    - A language model backbone (Qwen2) for text understanding
    - An acoustic tokenizer for audio encoding/decoding
    - A semantic tokenizer for semantic feature extraction
    - A diffusion head for speech latent prediction

    Example:
        >>> from kugelaudio import KugelAudioConfig
        >>> config = KugelAudioConfig.from_pretrained("kugelaudio/kugelaudio-0-open")
    """

    model_type = "kugelaudio"
    is_composition = True

    sub_configs = {
        "acoustic_tokenizer_config": KugelAudioAcousticTokenizerConfig,
        "semantic_tokenizer_config": KugelAudioSemanticTokenizerConfig,
        "decoder_config": Qwen2Config,
        "diffusion_head_config": KugelAudioDiffusionHeadConfig,
    }

    # Tensor parallel plan for distributed inference
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def _resolve_sub_config(self, key, value, forced_model_type):
        """Coerce ``value`` (None | dict | config instance) into the config
        class registered for ``key`` in ``sub_configs``.

        Raises:
            TypeError: if ``value`` has an unsupported type. (Previously an
                unsupported type was silently ignored, leaving the attribute
                unset and causing a confusing AttributeError later.)
        """
        config_cls = self.sub_configs[key]
        if value is None:
            return config_cls()
        if isinstance(value, dict):
            # Copy before forcing model_type so the caller's dict is not mutated.
            return config_cls(**{**value, "model_type": forced_model_type})
        if isinstance(value, config_cls):
            return value
        raise TypeError(
            f"{key} must be None, a dict, or a {config_cls.__name__} instance, "
            f"got {type(value).__name__}"
        )

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        semantic_tokenizer_config=None,
        decoder_config=None,
        diffusion_head_config=None,
        **kwargs,
    ):
        # Disable auto attention implementation selection
        kwargs["_attn_implementation_autoset"] = False

        # Initialize acoustic / semantic tokenizer configs
        self.acoustic_tokenizer_config = self._resolve_sub_config(
            "acoustic_tokenizer_config",
            acoustic_tokenizer_config,
            "kugelaudio_acoustic_tokenizer",
        )
        self.semantic_tokenizer_config = self._resolve_sub_config(
            "semantic_tokenizer_config",
            semantic_tokenizer_config,
            "kugelaudio_semantic_tokenizer",
        )

        # Initialize decoder (language model) config; only Qwen2 is supported.
        if decoder_config is None:
            self.decoder_config = self.sub_configs["decoder_config"]()
        elif isinstance(decoder_config, dict):
            if decoder_config.get("model_type", "") == "qwen2":
                self.decoder_config = Qwen2Config(**decoder_config)
            else:
                raise ValueError(
                    f"Unsupported decoder model type: {decoder_config.get('model_type', '')}"
                )
        elif isinstance(decoder_config, Qwen2Config):
            self.decoder_config = decoder_config
        else:
            raise TypeError(
                "decoder_config must be None, a dict, or a Qwen2Config instance, "
                f"got {type(decoder_config).__name__}"
            )

        # Initialize diffusion head config
        self.diffusion_head_config = self._resolve_sub_config(
            "diffusion_head_config",
            diffusion_head_config,
            "kugelaudio_diffusion_head",
        )

        # Derived parameters
        self.acoustic_vae_dim = self.acoustic_tokenizer_config.vae_dim
        self.semantic_vae_dim = self.semantic_tokenizer_config.vae_dim

        super().__init__(**kwargs)
| 275 |
+
# Aliases for backwards compatibility: older callers imported the configs
# without the ``KugelAudio`` prefix; keep both names resolvable.
AcousticTokenizerConfig = KugelAudioAcousticTokenizerConfig
SemanticTokenizerConfig = KugelAudioSemanticTokenizerConfig
DiffusionHeadConfig = KugelAudioDiffusionHeadConfig


# Explicit public API of this module (canonical names plus the aliases).
__all__ = [
    "KugelAudioAcousticTokenizerConfig",
    "KugelAudioSemanticTokenizerConfig",
    "KugelAudioDiffusionHeadConfig",
    "KugelAudioConfig",
    # Aliases
    "AcousticTokenizerConfig",
    "SemanticTokenizerConfig",
    "DiffusionHeadConfig",
]
|
kugelaudio_open/models/__init__.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""KugelAudio model components."""
|
| 2 |
+
|
| 3 |
+
from .kugelaudio_model import (
|
| 4 |
+
KugelAudioModel,
|
| 5 |
+
KugelAudioPreTrainedModel,
|
| 6 |
+
KugelAudioForConditionalGeneration,
|
| 7 |
+
)
|
| 8 |
+
from .kugelaudio_inference import (
|
| 9 |
+
KugelAudioForConditionalGenerationInference,
|
| 10 |
+
KugelAudioCausalLMOutputWithPast,
|
| 11 |
+
KugelAudioGenerationOutput,
|
| 12 |
+
)
|
| 13 |
+
from .tokenizer import (
|
| 14 |
+
KugelAudioAcousticTokenizerModel,
|
| 15 |
+
KugelAudioSemanticTokenizerModel,
|
| 16 |
+
KugelAudioTokenizerEncoderOutput,
|
| 17 |
+
)
|
| 18 |
+
from .diffusion_head import KugelAudioDiffusionHead
|
| 19 |
+
from .conv_layers import (
|
| 20 |
+
RMSNorm,
|
| 21 |
+
ConvRMSNorm,
|
| 22 |
+
ConvLayerNorm,
|
| 23 |
+
SConv1d,
|
| 24 |
+
SConvTranspose1d,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
__all__ = [
|
| 28 |
+
# Main models
|
| 29 |
+
"KugelAudioModel",
|
| 30 |
+
"KugelAudioPreTrainedModel",
|
| 31 |
+
"KugelAudioForConditionalGeneration",
|
| 32 |
+
"KugelAudioForConditionalGenerationInference",
|
| 33 |
+
# Outputs
|
| 34 |
+
"KugelAudioCausalLMOutputWithPast",
|
| 35 |
+
"KugelAudioGenerationOutput",
|
| 36 |
+
# Tokenizers
|
| 37 |
+
"KugelAudioAcousticTokenizerModel",
|
| 38 |
+
"KugelAudioSemanticTokenizerModel",
|
| 39 |
+
"KugelAudioTokenizerEncoderOutput",
|
| 40 |
+
# Components
|
| 41 |
+
"KugelAudioDiffusionHead",
|
| 42 |
+
"RMSNorm",
|
| 43 |
+
"ConvRMSNorm",
|
| 44 |
+
"ConvLayerNorm",
|
| 45 |
+
"SConv1d",
|
| 46 |
+
"SConvTranspose1d",
|
| 47 |
+
]
|
kugelaudio_open/models/conv_layers.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Convolutional layers for KugelAudio tokenizers.
|
| 2 |
+
|
| 3 |
+
This module provides the building blocks for the acoustic and semantic tokenizers,
|
| 4 |
+
including streaming-capable convolutions and normalization layers.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import typing as tp
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
from transformers.utils import logging
|
| 16 |
+
|
| 17 |
+
logger = logging.get_logger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Normalization modules
|
| 21 |
+
class ConvLayerNorm(nn.LayerNorm):
    """LayerNorm for channels-first (B, C, T) tensors.

    Moves the channel dimension to the end, normalizes in float32 for
    numerical stability, casts back to the input dtype, and restores the
    channels-first layout.
    """
    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        # (B, C, T) -> (B, T, C): layer_norm acts on the trailing dims.
        transposed = x.transpose(1, 2)
        normed = nn.functional.layer_norm(
            transposed.float(),
            self.normalized_shape,
            self.weight.float(),
            self.bias.float(),
            self.eps,
        ).type_as(transposed)
        # (B, T, C) -> (B, C, T)
        return normed.transpose(1, 2)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization (no mean subtraction).

    Normalizes the last dimension by its RMS and optionally applies a
    learnable per-element scale.
    """

    def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not elementwise_affine:
            self.register_parameter('weight', None)
        else:
            shape = weight_shape if weight_shape is not None else (self.dim,)
            self.weight = nn.Parameter(torch.ones(shape))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), via rsqrt.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32, then cast back to the input dtype.
        normed = self._norm(x.float()).type_as(x)
        return normed if self.weight is None else normed * self.weight

    def extra_repr(self) -> str:
        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class ConvRMSNorm(RMSNorm):
    """RMSNorm for channels-first (B, C, T) tensors.

    Transposes so channels are the last axis, applies the RMS
    normalization (in float32), then transposes back.
    """

    def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
        super().__init__(dim, eps, elementwise_affine, weight_shape)

    def forward(self, x):
        # (B, C, T) -> (B, T, C)
        y = x.transpose(1, 2)
        y_normed = self._norm(y.float()).type_as(y)
        if self.weight is not None:
            y_normed = y_normed * self.weight
        # (B, T, C) -> (B, C, T)
        return y_normed.transpose(1, 2)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# Convolutional layers and utilities
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
                                 'time_layer_norm', 'layer_norm', 'time_group_norm'])


def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    """Optionally wrap ``module`` with a weight parametrization.

    Only 'weight_norm' and 'spectral_norm' wrap the module; every other
    valid value returns the module untouched.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'spectral_norm':
        return nn.utils.spectral_norm(module)
    if norm == 'weight_norm':
        return nn.utils.weight_norm(module)
    return module
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Return the proper normalization module for a (transposed) convolution.

    Raises:
        ValueError: if 'time_group_norm' is requested for a causal model
            (GroupNorm mixes information across time).
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    if norm == 'layer_norm':
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    # 'none', 'weight_norm', 'spectral_norm', 'time_layer_norm': no extra module.
    return nn.Identity()
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Extra right padding so the conv consumes the whole input.

    Computes the padding that makes the last frame land exactly at the end
    of the (already padded) signal, i.e. output length = ceil(T / stride).
    """
    length = x.shape[-1]
    frames = (length - kernel_size + padding_total) / stride + 1
    target_length = (math.ceil(frames) - 1) * stride + (kernel_size - padding_total)
    return target_length - length
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Pad the last dimension of ``x``, handling small inputs in reflect mode.

    Args:
        x: Input tensor, padded on its last dimension.
        paddings: (left, right) non-negative pad sizes.
        mode: Padding mode. ``'zero'`` is accepted as an alias for
            ``'constant'`` — ``F.pad`` itself only understands
            'constant'/'reflect'/'replicate'/'circular', so passing 'zero'
            straight through (the previous behavior of the default) raised.
        value: Fill value for constant padding.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'zero':
        # Alias: F.pad has no 'zero' mode; constant-with-0 is the intent.
        mode = 'constant'
    if mode == 'reflect':
        # F.pad's reflect mode requires length > pad; extend with zeros
        # first, then trim the extension back off after padding.
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Strip (left, right) padding from the last dimension. Only for 1d!"""
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    # Cannot strip more than the signal actually contains.
    assert left + right <= x.shape[-1]
    return x[..., left: x.shape[-1] - right]
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class NormConv1d(nn.Module):
    """Conv1d followed by an optional normalization layer.

    Args:
        causal: Whether the conv is used causally (restricts valid norms).
        norm: One of ``CONV_NORMALIZATIONS``.
        norm_kwargs: Extra keyword arguments for the norm layer.
    """

    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: Optional[tp.Dict[str, tp.Any]] = None, **kwargs):
        super().__init__()
        # Avoid the mutable-default-argument pitfall (shared {} across calls).
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.conv(x))
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class NormConvTranspose1d(nn.Module):
    """ConvTranspose1d followed by an optional normalization layer.

    Args:
        causal: Whether the conv is used causally (restricts valid norms).
        norm: One of ``CONV_NORMALIZATIONS``.
        norm_kwargs: Extra keyword arguments for the norm layer.
    """

    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: Optional[tp.Dict[str, tp.Any]] = None, **kwargs):
        super().__init__()
        # Avoid the mutable-default-argument pitfall (shared {} across calls).
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.convtr(x))
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class SConv1d(nn.Module):
    """Conv1d with built-in handling of asymmetric or causal padding and normalization.

    Padding is applied so the output length is ceil(T / stride): all on the
    left when ``causal``, symmetric otherwise, plus any extra right padding
    needed for stride alignment.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: Optional[tp.Dict[str, tp.Any]] = None,
                 pad_mode: str = 'reflect'):
        super().__init__()
        # Avoid the mutable-default-argument pitfall (shared {} across calls).
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

        # Store configuration for padding computation.
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Total padding needed (non-streaming) for full input coverage.
        self.padding_total = (kernel_size - 1) * dilation - (stride - 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass (non-streaming): pad, then convolve."""
        # Extra right padding for stride alignment.
        extra_padding = get_extra_padding_for_conv1d(
            x, self.kernel_size, self.stride, self.padding_total)

        if self.causal:
            # Left padding only, so no future samples leak into the output.
            if self.pad_mode == 'constant':
                x = pad1d(x, (self.padding_total, extra_padding), mode=self.pad_mode, value=0)
            else:
                x = pad1d(x, (self.padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Symmetric padding for non-causal use.
            padding_right = self.padding_total // 2
            padding_left = self.padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)

        return self.conv(x)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class SConvTranspose1d(nn.Module):
    """ConvTranspose1d with built-in handling of asymmetric or causal padding and normalization.

    A transposed convolution over-produces ``kernel_size - stride`` samples;
    this wrapper trims them (mostly from the right when causal, symmetric
    otherwise) so the output length is exactly T * stride.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: Optional[tp.Dict[str, tp.Any]] = None, bias: bool = True):
        super().__init__()
        # Avoid the mutable-default-argument pitfall (shared {} across calls).
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs, bias=bias)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert 0. <= self.trim_right_ratio <= 1.

        # Store configuration for padding computation.
        self.kernel_size = kernel_size
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Amount of over-produced output to trim.
        self.padding_total = kernel_size - stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass (non-streaming): transpose-convolve, then trim padding."""
        y = self.convtr(x)

        if self.causal:
            # Trim (mostly) from the right so no future context is used.
            padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
        else:
            # Symmetric trim for non-causal use.
            padding_right = self.padding_total // 2
        padding_left = self.padding_total - padding_right
        return unpad1d(y, (padding_left, padding_right))
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
__all__ = [
|
| 279 |
+
"ConvLayerNorm",
|
| 280 |
+
"RMSNorm",
|
| 281 |
+
"ConvRMSNorm",
|
| 282 |
+
"NormConv1d",
|
| 283 |
+
"NormConvTranspose1d",
|
| 284 |
+
"SConv1d",
|
| 285 |
+
"SConvTranspose1d",
|
| 286 |
+
"pad1d",
|
| 287 |
+
"unpad1d",
|
| 288 |
+
"get_extra_padding_for_conv1d",
|
| 289 |
+
]
|
kugelaudio_open/models/diffusion_head.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from transformers.models.auto import AutoModel
|
| 9 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 10 |
+
# from transformers.modeling_layers import GradientCheckpointingLayer
|
| 11 |
+
from transformers.activations import ACT2FN
|
| 12 |
+
from transformers.utils import logging
|
| 13 |
+
|
| 14 |
+
from ..configs import KugelAudioDiffusionHeadConfig
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.get_logger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization over the last dimension.

    ``memory_efficient`` is accepted for interface compatibility but is not
    used by this implementation.
    """

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not elementwise_affine:
            self.register_parameter('weight', None)
        else:
            self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), via rsqrt.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32, then cast back to the input dtype.
        normed = self._norm(x.float()).type_as(x)
        return normed if self.weight is None else normed * self.weight

    def extra_repr(self) -> str:
        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
|
| 42 |
+
|
| 43 |
+
def modulate(x, shift, scale):
    """Adaptive (AdaLN-style) modulation: scale around identity, then shift."""
    return x * (scale + 1) + shift
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class TimestepEmbedder(nn.Module):
    """Embeds scalar diffusion timesteps via sinusoidal features + a 2-layer MLP.

    Args:
        hidden_size (`int`): Size of the output embedding.
        frequency_embedding_size (`int`, optional): Size of the intermediate
            sinusoidal frequency embedding.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=False),
            ACT2FN['silu'],
            nn.Linear(hidden_size, hidden_size, bias=False),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Create sinusoidal timestep embeddings.

        Args:
            t (`torch.Tensor`): 1-D tensor of N (possibly fractional) timesteps.
            dim (`int`): Output embedding dimension.
            max_period (`int`, optional): Minimum frequency of the embeddings.

        Returns:
            `torch.Tensor`: An [N, dim] tensor of positional embeddings.
        """
        half = dim // 2
        # Build frequencies directly on t's device so CUDA-graph capture
        # sees no host-to-device transfers.
        exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
        freqs = torch.exp(-math.log(max_period) * exponents)
        angles = t[:, None].float() * freqs[None]
        emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            # Pad odd dims with a zero column.
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb.to(t.dtype)

    def forward(self, t):
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class FeedForwardNetwork(nn.Module):
    """SwiGLU feed-forward network.

    Args:
        embed_dim (`int`): Input/output dimension.
        ffn_dim (`int`): Hidden dimension.
    """

    def __init__(self, embed_dim, ffn_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.gate_proj = nn.Linear(embed_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(embed_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, embed_dim, bias=False)
        self.act_fn = ACT2FN['silu']  # SiLU gate activation

    def forward(self, x):
        # SwiGLU: down(silu(gate(x)) * up(x))
        activated_gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(activated_gate * self.up_proj(x))
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class HeadLayer(nn.Module):
    """A single adaLN-modulated feed-forward layer of the diffusion head.

    Args:
        embed_dim (`int`): Input dimension.
        ffn_dim (`int`): Feed-forward hidden dimension.
        cond_dim (`int`): Condition embedding dimension.
        norm_eps (`float`, optional): Epsilon for normalization.
    """

    def __init__(self, embed_dim, ffn_dim, cond_dim, norm_eps=1e-5):
        super().__init__()
        self.embed_dim = embed_dim
        self.cond_dim = cond_dim
        self.ffn_dim = ffn_dim
        self.ffn = FeedForwardNetwork(self.embed_dim, self.ffn_dim)
        self.norm = RMSNorm(self.embed_dim, eps=norm_eps)
        # Produces (shift, scale, gate) from the conditioning vector.
        self.adaLN_modulation = nn.Sequential(
            ACT2FN['silu'],
            nn.Linear(cond_dim, 3 * self.embed_dim, bias=False),
        )

    def forward(self, x, c):
        shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1)
        # Gated residual update around the modulated feed-forward block.
        return x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn))
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class FinalLayer(nn.Module):
    """Final adaLN-modulated projection of the diffusion head.

    Args:
        hidden_size (`int`): Input dimension.
        output_size (`int`): Output (latent) dimension.
        cond_size (`int`): Condition embedding dimension.
        norm_eps (`float`, optional): Epsilon for normalization.
    """

    def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
        super().__init__()
        self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, output_size, bias=False)
        # Produces (shift, scale) from the conditioning vector.
        self.adaLN_modulation = nn.Sequential(
            ACT2FN['silu'],
            nn.Linear(cond_size, 2 * hidden_size, bias=False),
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        return self.linear(modulate(self.norm_final(x), shift, scale))
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class KugelAudioDiffusionHead(PreTrainedModel):
    """
    Diffusion prediction head for KugelAudio.

    Projects noisy latents into the hidden space, conditions on the sum of a
    projected condition vector and a timestep embedding through a stack of
    adaLN-modulated feed-forward layers, and projects back to latent space.

    Args:
        config (`KugelAudioDiffusionHeadConfig`): Model configuration.
    """
    config_class = KugelAudioDiffusionHeadConfig
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.cond_dim = config.hidden_size
        latent_dim = config.latent_size

        # Input / condition projections and timestep embedding.
        self.noisy_images_proj = nn.Linear(latent_dim, config.hidden_size, bias=False)
        self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False)
        self.t_embedder = TimestepEmbedder(self.cond_dim)

        ffn_dim = int(config.hidden_size * config.head_ffn_ratio)

        # Stack of adaLN-modulated feed-forward layers.
        self.layers = nn.ModuleList(
            HeadLayer(
                embed_dim=config.hidden_size,
                ffn_dim=ffn_dim,
                cond_dim=self.cond_dim,
                norm_eps=config.rms_norm_eps,
            )
            for _ in range(config.head_layers)
        )

        # Projection back to the latent space.
        self.final_layer = FinalLayer(
            hidden_size=config.hidden_size,
            output_size=latent_dim,
            cond_size=self.cond_dim,
            norm_eps=config.rms_norm_eps,
        )

        self.initialize_weights()

    def initialize_weights(self):
        """Initialize weights: normal for the timestep MLP, zeros for adaLN/output."""
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-init the modulation layers so each block starts near identity.
        for layer in self.layers:
            nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)

        # Zero-init the output path so the head initially predicts zeros.
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)

    def forward(self, noisy_images, timesteps, condition):
        """
        Forward pass of the prediction head.

        Args:
            noisy_images (`torch.Tensor`): Noisy latents to denoise.
            timesteps (`torch.Tensor`): Diffusion timesteps.
            condition (`torch.Tensor`): Conditioning information.

        Returns:
            `torch.Tensor`: The predicted noise/velocity.
        """
        hidden = self.noisy_images_proj(noisy_images)
        # Combined conditioning: projected condition + timestep embedding.
        cond = self.cond_proj(condition) + self.t_embedder(timesteps)

        for layer in self.layers:
            hidden = layer(hidden, cond)

        return self.final_layer(hidden, cond)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
AutoModel.register(KugelAudioDiffusionHeadConfig, KugelAudioDiffusionHead)
|
| 285 |
+
|
| 286 |
+
__all__ = [
|
| 287 |
+
"KugelAudioDiffusionHead",
|
| 288 |
+
]
|
kugelaudio_open/models/kugelaudio_inference.py
ADDED
|
@@ -0,0 +1,800 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""KugelAudio inference model for speech generation.
|
| 2 |
+
|
| 3 |
+
This is the open-source inference implementation without optimizations.
|
| 4 |
+
Based on the original VibeVoice model architecture.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from transformers import modeling_utils
|
| 14 |
+
from transformers.cache_utils import DynamicCache
|
| 15 |
+
from transformers.generation import (
|
| 16 |
+
GenerationConfig,
|
| 17 |
+
GenerationMixin,
|
| 18 |
+
LogitsProcessor,
|
| 19 |
+
LogitsProcessorList,
|
| 20 |
+
StoppingCriteriaList,
|
| 21 |
+
)
|
| 22 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
| 23 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 24 |
+
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
| 25 |
+
from transformers.utils import logging
|
| 26 |
+
|
| 27 |
+
from ..configs import KugelAudioConfig
|
| 28 |
+
from ..schedule.dpm_solver import DPMSolverMultistepScheduler
|
| 29 |
+
from .diffusion_head import KugelAudioDiffusionHead
|
| 30 |
+
from .kugelaudio_model import KugelAudioModel, KugelAudioPreTrainedModel
|
| 31 |
+
from .tokenizer import (
|
| 32 |
+
KugelAudioTokenizerEncoderOutput,
|
| 33 |
+
KugelAudioTokenizerStreamingCache,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
logger = logging.get_logger(__name__)

# NOTE(review): compatibility shim — some transformers builds ship without
# ALL_PARALLEL_STYLES (or with it set to None), which breaks validation of
# tensor-parallel plans such as the `_tp_plan` declared on the inference
# class below. Patch in the standard style names so model init succeeds.
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _get_cache_tensors(cache) -> Tuple[List, List]:
    """Return the per-layer key and value cache tensor lists from ``cache``.

    Supports both modern ``Cache`` objects exposing ``key_cache`` /
    ``value_cache`` attributes (e.g. ``DynamicCache``) and the legacy
    tuple-of-layers format ``((k0, v0), (k1, v1), ...)`` still produced by
    older transformers versions.

    Args:
        cache: A KV cache returned by a transformers decoder forward pass.

    Returns:
        Tuple ``(key_caches, value_caches)`` of per-layer tensor lists.

    Raises:
        AttributeError: If the cache layout is not recognized.
    """
    if hasattr(cache, "key_cache") and hasattr(cache, "value_cache"):
        return cache.key_cache, cache.value_cache
    # Legacy layout: a sequence of (key, value) pairs, one entry per layer.
    if isinstance(cache, (tuple, list)):
        keys = [layer[0] for layer in cache]
        values = [layer[1] for layer in cache]
        return keys, values
    raise AttributeError(f"Cannot get cache tensors from {type(cache).__name__}")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass
class KugelAudioCausalLMOutputWithPast(BaseModelOutputWithPast):
    """Causal-LM forward output: inherited hidden states / KV cache plus logits."""

    # Vocabulary logits from the LM head; appended after the inherited fields
    # so the positional (tuple) ordering of ModelOutput stays compatible.
    logits: Optional[torch.FloatTensor] = None
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass
class KugelAudioGenerationOutput(ModelOutput):
    """Output type for KugelAudio generation."""

    # Generated token id sequences, shape [batch, seq_len] (prompt + new tokens).
    sequences: torch.LongTensor = None
    # One decoded waveform per batch sample; entries are None for samples
    # that produced no speech frames. Populated by `generate`.
    speech_outputs: Optional[List[torch.FloatTensor]] = None
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class KugelAudioTokenConstraintProcessor(LogitsProcessor):
    """Restricts decoding to the small allow-list of control tokens
    (speech start/end/diffusion and EOS) used during speech generation."""

    def __init__(self, valid_token_ids: List[int], device: torch.device = None):
        # Keep the allow-list as a long tensor so it can index logits directly.
        self.valid_token_ids = torch.tensor(valid_token_ids, dtype=torch.long, device=device)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Additive penalty: -inf everywhere except the allowed ids, so any
        # token outside the allow-list can never be selected.
        penalty = torch.full_like(scores, float("-inf"))
        penalty.index_fill_(1, self.valid_token_ids, 0.0)
        return scores + penalty
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class KugelAudioForConditionalGenerationInference(KugelAudioPreTrainedModel, GenerationMixin):
|
| 76 |
+
"""KugelAudio model for inference with speech generation capabilities."""
|
| 77 |
+
|
| 78 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 79 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 80 |
+
|
| 81 |
+
    def __init__(self, config):
        """Build the inference model: decoder backbone plus LM head.

        Args:
            config: KugelAudioConfig with `decoder_config` and
                `diffusion_head_config` sub-configs.
        """
        super().__init__(config)
        # Backbone container: text decoder, tokenizers, connectors, diffusion head.
        self.model = KugelAudioModel(config)
        # Projects decoder hidden states to vocabulary logits; weight tying
        # with the input embeddings is declared via `_tied_weights_keys`.
        self.lm_head = nn.Linear(
            config.decoder_config.hidden_size,
            config.decoder_config.vocab_size,
            bias=False,
        )
        # Number of DDPM denoising steps used by `sample_speech_tokens`.
        self.ddpm_inference_steps = config.diffusion_head_config.ddpm_num_inference_steps
        self.post_init()
|
| 91 |
+
|
| 92 |
+
    # ------------------------------------------------------------------
    # Thin read-only accessors delegating to sub-modules owned by `self.model`.
    # ------------------------------------------------------------------

    @property
    def noise_scheduler(self):
        # Diffusion noise scheduler driven by `sample_speech_tokens`.
        return self.model.noise_scheduler

    @property
    def prediction_head(self):
        # Diffusion head predicting noise (eps) from latents + condition.
        return self.model.prediction_head

    @property
    def speech_scaling_factor(self):
        # Multiplicative normalization applied to acoustic latents
        # (NaN means "no normalization configured" — see _process_speech_inputs).
        return self.model.speech_scaling_factor

    @property
    def speech_bias_factor(self):
        # Additive normalization applied to acoustic latents.
        return self.model.speech_bias_factor

    @property
    def acoustic_tokenizer(self):
        # Audio codec: waveform <-> acoustic latents (encode/sampling/decode).
        return self.model.acoustic_tokenizer

    @property
    def semantic_tokenizer(self):
        # Encoder producing semantic features (mean) from waveforms.
        return self.model.semantic_tokenizer

    @property
    def acoustic_connector(self):
        # Projects acoustic latents into the decoder embedding space.
        return self.model.acoustic_connector

    @property
    def semantic_connector(self):
        # Projects semantic features into the decoder embedding space.
        return self.model.semantic_connector
|
| 123 |
+
|
| 124 |
+
    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module."""
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the LM head (output projection to vocabulary logits)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head (output projection)."""
        self.lm_head = new_embeddings
|
| 135 |
+
|
| 136 |
+
def set_ddpm_inference_steps(self, num_steps=None):
|
| 137 |
+
self.ddpm_inference_steps = (
|
| 138 |
+
num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
    def _process_speech_inputs(
        self,
        speech_tensors: Optional[torch.Tensor],
        speech_masks: Optional[torch.Tensor],
        voice_cache: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Process speech inputs through acoustic and semantic encoders.

        Args:
            speech_tensors: Raw voice waveform(s); a 2-D input gets a channel
                dimension inserted before encoding.
            speech_masks: Boolean mask of valid latent frames; rebuilt as
                all-ones when it can be inferred from the encoded features.
            voice_cache: Pre-encoded features from `encode_voice_prompt`;
                takes precedence over `speech_tensors` when given.

        Returns:
            Tuple of (acoustic_features, speech_embeds) where speech_embeds has shape
            [num_valid_frames, hidden] - already indexed by speech_masks for direct
            assignment to inputs_embeds[speech_input_mask].
        """
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype

        if voice_cache is not None:
            # Use pre-encoded voice features
            acoustic_mean = voice_cache["acoustic_mean"].to(device=device, dtype=dtype)
            semantic_mean = voice_cache["semantic_mean"].to(device=device, dtype=dtype)

            # Sample from acoustic distribution (reparameterization with a
            # fixed std when the cache does not carry one)
            fix_std = voice_cache.get("acoustic_std", self.acoustic_tokenizer.fix_std)
            acoustic_features = acoustic_mean + fix_std * torch.randn_like(acoustic_mean)
            semantic_features = semantic_mean

            # Create speech_masks from cache dimensions (all frames valid)
            batch_size = acoustic_features.shape[0]
            seq_len = acoustic_features.shape[1]
            speech_masks = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)

        elif speech_tensors is not None:
            # Encode speech through tokenizers
            with torch.no_grad():
                # Acoustic encoding
                if speech_tensors.dim() == 2:
                    speech_tensors = speech_tensors.unsqueeze(1)

                acoustic_output = self.acoustic_tokenizer.encode(speech_tensors)
                acoustic_features, _ = self.acoustic_tokenizer.sampling(acoustic_output)

                # Semantic encoding
                semantic_output = self.semantic_tokenizer.encode(speech_tensors)
                semantic_features = semantic_output.mean

            # Create speech_masks if not provided (all frames valid)
            if speech_masks is None:
                batch_size = acoustic_features.shape[0]
                seq_len = acoustic_features.shape[1]
                speech_masks = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
        else:
            # Return dummy features (nothing to encode)
            vae_dim = self.config.acoustic_vae_dim
            acoustic_features = torch.zeros(1, 1, vae_dim, device=device, dtype=dtype)
            semantic_features = torch.zeros(
                1, 1, self.config.semantic_vae_dim, device=device, dtype=dtype
            )
            speech_masks = torch.ones(1, 1, dtype=torch.bool, device=device)

        # Ensure acoustic and semantic have matching time dimensions —
        # the two encoders may emit slightly different frame counts.
        acoustic_len = acoustic_features.shape[1]
        semantic_len = semantic_features.shape[1]
        if semantic_len < acoustic_len:
            pad_size = acoustic_len - semantic_len
            semantic_features = torch.nn.functional.pad(
                semantic_features, (0, 0, 0, pad_size), mode="constant", value=0
            )
        elif semantic_len > acoustic_len:
            semantic_features = semantic_features[:, :acoustic_len, :]

        # Apply scaling to acoustic features.
        # A NaN scaling factor is treated as "no normalization configured".
        if not torch.isnan(self.speech_scaling_factor):
            acoustic_features = (
                acoustic_features + self.speech_bias_factor
            ) * self.speech_scaling_factor

        # Get embeddings through connectors
        acoustic_embed = self.acoustic_connector(acoustic_features)
        semantic_embed = self.semantic_connector(semantic_features)

        # Combine embeddings and index by speech_masks
        combined_embed = acoustic_embed + semantic_embed

        # Move speech_masks to CPU for indexing (matches working implementation)
        speech_embeds = combined_embed[speech_masks.cpu()]

        return acoustic_features, speech_embeds
|
| 228 |
+
|
| 229 |
+
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        speech_tensors: Optional[torch.FloatTensor] = None,
        speech_masks: Optional[torch.BoolTensor] = None,
        speech_input_mask: Optional[torch.BoolTensor] = None,
        voice_cache: Optional[dict] = None,
        logits_to_keep: Union[int, slice] = 0,
        **kwargs,
    ) -> Union[Tuple, KugelAudioCausalLMOutputWithPast]:
        """Forward pass for the model.

        Embeds the token ids (unless `inputs_embeds` is given), optionally
        splices encoded voice embeddings into the positions marked by
        `speech_input_mask`, runs the decoder, and projects to logits.

        NOTE: `labels` is accepted for API compatibility but not used here
        (inference-only implementation).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.model.get_input_embeddings()(input_ids)

        # Process speech inputs if provided
        if voice_cache is not None or (speech_tensors is not None and speech_masks is not None):
            _, speech_embeds = self._process_speech_inputs(
                speech_tensors.to(self.dtype) if speech_tensors is not None else None,
                speech_masks,
                voice_cache=voice_cache,
            )
            if speech_input_mask is not None:
                # In-place splice: overwrite the text-token embeddings at the
                # masked positions with the encoded voice embeddings.
                inputs_embeds[speech_input_mask] = speech_embeds

        outputs = self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
        # logits_to_keep == 0 gives slice(0, None) -> keep every position;
        # a positive int keeps only the last N positions.
        slice_indices = (
            slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        )
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        return KugelAudioCausalLMOutputWithPast(
            logits=logits,
            past_key_values=outputs.past_key_values,
            last_hidden_state=hidden_states,
            attentions=outputs.attentions,
        )
|
| 290 |
+
|
| 291 |
+
    @torch.no_grad()
    def sample_speech_tokens(
        self, condition: torch.Tensor, neg_condition: torch.Tensor, cfg_scale: float = 3.0
    ) -> torch.Tensor:
        """Sample speech latents using diffusion with classifier-free guidance.

        Args:
            condition: Conditional hidden states, [num_samples, hidden].
            neg_condition: Unconditional hidden states, same shape
                (ignored when ``cfg_scale == 1.0``).
            cfg_scale: Guidance weight; 1.0 disables CFG entirely.

        Returns:
            Denoised acoustic latents, [num_samples, acoustic_vae_dim].
        """
        self.model.noise_scheduler.set_timesteps(self.ddpm_inference_steps)

        if cfg_scale == 1.0:
            # No CFG - single forward pass
            speech = torch.randn(condition.shape[0], self.config.acoustic_vae_dim).to(condition)
            for t in self.model.noise_scheduler.timesteps:
                eps = self.model.prediction_head(
                    speech, t.repeat(speech.shape[0]).to(speech), condition=condition
                )
                speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample
            return speech

        # With CFG - batched forward pass
        # NOTE(review): relies on `prediction_head` exposing a `.device`
        # attribute/property — confirm against KugelAudioDiffusionHead.
        combined_condition = torch.cat([condition, neg_condition], dim=0).to(
            self.model.prediction_head.device
        )
        speech = torch.randn(combined_condition.shape[0], self.config.acoustic_vae_dim).to(
            combined_condition
        )

        for t in self.model.noise_scheduler.timesteps:
            # Both halves share one latent state; only the condition differs
            # (positive prompt vs negative prompt).
            half = speech[: len(speech) // 2]
            combined = torch.cat([half, half], dim=0)
            eps = self.model.prediction_head(
                combined, t.repeat(combined.shape[0]).to(combined), condition=combined_condition
            )
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            # CFG: push the noise estimate away from the unconditional
            # prediction toward the conditional one.
            half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample

        return speech[: len(speech) // 2]
|
| 328 |
+
|
| 329 |
+
@torch.no_grad()
|
| 330 |
+
def encode_voice_prompt(
|
| 331 |
+
self,
|
| 332 |
+
voice_audio: torch.Tensor,
|
| 333 |
+
sample_rate: int = 24000,
|
| 334 |
+
) -> dict:
|
| 335 |
+
"""Pre-encode a voice prompt for caching."""
|
| 336 |
+
device = next(self.parameters()).device
|
| 337 |
+
dtype = next(self.parameters()).dtype
|
| 338 |
+
|
| 339 |
+
if voice_audio.dim() == 1:
|
| 340 |
+
voice_audio = voice_audio.unsqueeze(0).unsqueeze(0)
|
| 341 |
+
elif voice_audio.dim() == 2:
|
| 342 |
+
voice_audio = voice_audio.unsqueeze(1)
|
| 343 |
+
|
| 344 |
+
voice_audio = voice_audio.to(device=device, dtype=dtype)
|
| 345 |
+
|
| 346 |
+
with torch.no_grad():
|
| 347 |
+
acoustic_output = self.model.acoustic_tokenizer.encode(voice_audio)
|
| 348 |
+
semantic_output = self.model.semantic_tokenizer.encode(voice_audio)
|
| 349 |
+
|
| 350 |
+
return {
|
| 351 |
+
"acoustic_mean": acoustic_output.mean.cpu(),
|
| 352 |
+
"acoustic_std": getattr(acoustic_output, "std", self.model.acoustic_tokenizer.fix_std),
|
| 353 |
+
"semantic_mean": semantic_output.mean.cpu(),
|
| 354 |
+
"audio_length": voice_audio.shape[-1],
|
| 355 |
+
"sample_rate": sample_rate,
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
    @torch.no_grad()
    def generate(
        self,
        text_ids: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        voice_prompt: Optional[torch.Tensor] = None,
        voice_cache: Optional[dict] = None,
        speech_tensors: Optional[torch.Tensor] = None,
        speech_masks: Optional[torch.Tensor] = None,
        speech_input_mask: Optional[torch.Tensor] = None,
        cfg_scale: float = 3.0,
        max_new_tokens: int = 2048,
        do_sample: bool = False,
        temperature: float = 1.0,
        show_progress: bool = True,
        **kwargs,
    ) -> KugelAudioGenerationOutput:
        """Generate speech from text.

        Runs a token-by-token decoding loop; whenever the model emits a
        speech-diffusion token, one latent frame is sampled via diffusion
        (with CFG against a negative "speech-start only" stream), decoded to
        audio, and fed back as the next-step embedding.

        Args:
            text_ids: Tokenized text input (from processor)
            input_ids: Alternative name for text_ids
            voice_prompt: Voice audio tensor for cloning (legacy, use speech_tensors instead)
            voice_cache: Pre-encoded voice features (from encode_voice_prompt)
            speech_tensors: Voice audio tensor from processor for cloning
            speech_masks: Mask indicating valid voice frames
            speech_input_mask: Boolean mask indicating where to insert voice embeddings
            cfg_scale: Classifier-free guidance scale (higher = more faithful to text)
            max_new_tokens: Maximum tokens to generate
            do_sample: Whether to sample or use greedy decoding
            temperature: Sampling temperature
            show_progress: Whether to show progress bar

        Returns:
            KugelAudioGenerationOutput with sequences and speech_outputs
        """
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype

        # Handle input_ids vs text_ids
        if text_ids is None and input_ids is not None:
            text_ids = input_ids
        if text_ids is None:
            raise ValueError("text_ids or input_ids is required")

        text_ids = text_ids.to(device)
        batch_size = text_ids.shape[0]

        # Handle legacy voice_prompt parameter
        if voice_prompt is not None and speech_tensors is None:
            speech_tensors = voice_prompt
            # Create default speech_masks if not provided
            if speech_masks is None:
                # Estimate number of frames from audio length
                # (3200 audio samples per latent frame, rounded up)
                audio_len = voice_prompt.shape[-1]
                num_frames = (audio_len + 3199) // 3200  # compression ratio
                speech_masks = torch.ones(batch_size, num_frames, dtype=torch.bool, device=device)

        # Get special token IDs
        # NOTE(review): the numeric fallbacks look like Qwen2-style special
        # token ids — confirm they match the tokenizer actually in use.
        speech_start_id = getattr(self.config, "speech_start_id", None) or 151652
        speech_end_id = getattr(self.config, "speech_end_id", None) or 151653
        speech_diffusion_id = getattr(self.config, "speech_diffusion_id", None) or 151654
        eos_token_id = getattr(self.config.decoder_config, "eos_token_id", None) or 151643

        # Initialize streaming caches for tokenizers
        acoustic_cache = KugelAudioTokenizerStreamingCache()
        semantic_cache = KugelAudioTokenizerStreamingCache()

        # Initialize sequences and attention masks
        current_ids = text_ids
        attention_mask = torch.ones_like(current_ids)

        # For CFG, create negative prompt (just speech_start token)
        negative_ids = torch.full((batch_size, 1), speech_start_id, dtype=torch.long, device=device)
        negative_attention_mask = torch.ones_like(negative_ids)

        # Storage for generated audio and tracking
        audio_chunks = [[] for _ in range(batch_size)]
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        # Per-sample count of KV-cache "corrections" applied to the negative
        # stream — see the non-diffusion handling inside the loop.
        correct_cnt = torch.zeros(batch_size, dtype=torch.long, device=device)

        # Get initial embeddings
        inputs_embeds = self.model.get_input_embeddings()(current_ids)

        # Process voice/speech input if provided
        if speech_tensors is not None or voice_cache is not None:
            # Get speech embeddings
            if voice_cache is not None:
                _, speech_embeds = self._process_speech_inputs(
                    speech_tensors=None,
                    speech_masks=None,
                    voice_cache=voice_cache,
                )
            else:
                # Encode speech_tensors directly
                speech_tensors = speech_tensors.to(device=device, dtype=dtype)
                if speech_masks is not None:
                    speech_masks = speech_masks.to(device)
                _, speech_embeds = self._process_speech_inputs(
                    speech_tensors=speech_tensors,
                    speech_masks=speech_masks,
                    voice_cache=None,
                )

            # Insert speech embeddings at positions marked by speech_input_mask
            # speech_embeds is already flattened to [num_valid_frames, hidden] by _process_speech_inputs
            if speech_input_mask is not None:
                speech_input_mask = speech_input_mask.to(device)
                # Directly assign - shapes should match
                inputs_embeds[speech_input_mask] = speech_embeds

        negative_inputs_embeds = self.model.get_input_embeddings()(negative_ids)

        # Setup logits processor to constrain to valid tokens
        valid_tokens = [speech_start_id, speech_end_id, speech_diffusion_id, eos_token_id]
        token_constraint = KugelAudioTokenConstraintProcessor(valid_tokens, device=device)

        # Initialize KV caches (positive stream and CFG negative stream)
        past_key_values = None
        negative_past_key_values = None

        # Progress bar
        progress_iter = (
            tqdm(range(max_new_tokens), desc="Generating", leave=False)
            if show_progress
            else range(max_new_tokens)
        )

        for step in progress_iter:
            if finished.all():
                break

            # Forward pass for positive (main) model
            if past_key_values is None:
                # Prefill: encode the whole prompt once
                outputs = self(
                    inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    use_cache=True,
                    return_dict=True,
                )
            else:
                # Decode step: only the newest embedding, reusing the KV cache
                outputs = self(
                    inputs_embeds=inputs_embeds[:, -1:],
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                    return_dict=True,
                )

            past_key_values = outputs.past_key_values
            logits = outputs.logits[:, -1, :]

            # Apply token constraint
            logits = token_constraint(current_ids, logits)

            # Sample or greedy decode
            if do_sample and temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
            else:
                next_tokens = torch.argmax(logits, dim=-1)

            # Force finished samples to output EOS
            next_tokens = torch.where(
                finished, torch.tensor(eos_token_id, device=device), next_tokens
            )

            # Update sequences
            current_ids = torch.cat([current_ids, next_tokens.unsqueeze(-1)], dim=-1)
            attention_mask = torch.cat(
                [
                    attention_mask,
                    torch.ones((batch_size, 1), device=device, dtype=attention_mask.dtype),
                ],
                dim=-1,
            )

            # Check for EOS tokens
            eos_mask = (next_tokens == eos_token_id) & ~finished
            if eos_mask.any():
                finished = finished | eos_mask

            # Check for speech_end tokens - mark as finished and clear caches
            speech_end_mask = (next_tokens == speech_end_id) & ~finished
            if speech_end_mask.any():
                finished = finished | speech_end_mask
                speech_end_indices = speech_end_mask.nonzero(as_tuple=False).squeeze(-1)
                acoustic_cache.set_to_zero(speech_end_indices)
                semantic_cache.set_to_zero(speech_end_indices)

            # Handle speech_start tokens - refresh negative model KV cache so
            # the unconditional stream restarts from its speech_start state
            speech_start_mask = (next_tokens == speech_start_id) & ~finished
            if (
                speech_start_mask.any()
                and cfg_scale != 1.0
                and negative_past_key_values is not None
            ):
                speech_start_indices = speech_start_mask.nonzero(as_tuple=False).squeeze(-1)
                if speech_start_indices.dim() == 0:
                    speech_start_indices = speech_start_indices.unsqueeze(0)

                for sample_idx in speech_start_indices.tolist():
                    # Keep only the final position visible for this sample...
                    negative_attention_mask[sample_idx, :] = 0
                    negative_attention_mask[sample_idx, -1] = 1

                    # ...and overwrite its last KV slot with the position-0
                    # (original speech_start) entry.
                    key_caches, value_caches = _get_cache_tensors(negative_past_key_values)
                    for k_cache, v_cache in zip(key_caches, value_caches):
                        k_cache[sample_idx, :, -1, :] = k_cache[sample_idx, :, 0, :].clone()
                        v_cache[sample_idx, :, -1, :] = v_cache[sample_idx, :, 0, :].clone()

                    negative_ids[sample_idx, -1] = speech_start_id

            # Prepare next input embeddings
            next_inputs_embeds = self.model.get_input_embeddings()(next_tokens).unsqueeze(1)

            # Handle diffusion tokens - generate speech
            diffusion_mask = (next_tokens == speech_diffusion_id) & ~finished
            if diffusion_mask.any():
                diffusion_indices = diffusion_mask.nonzero(as_tuple=False).squeeze(-1)
                if diffusion_indices.dim() == 0:
                    diffusion_indices = diffusion_indices.unsqueeze(0)

                # Run negative forward pass for CFG
                if cfg_scale != 1.0:
                    if negative_past_key_values is None:
                        neg_outputs = self(
                            inputs_embeds=negative_inputs_embeds,
                            attention_mask=negative_attention_mask,
                            use_cache=True,
                            return_dict=True,
                        )
                    else:
                        neg_outputs = self(
                            inputs_embeds=negative_inputs_embeds[:, -1:],
                            attention_mask=negative_attention_mask,
                            past_key_values=negative_past_key_values,
                            use_cache=True,
                            return_dict=True,
                        )
                    negative_past_key_values = neg_outputs.past_key_values

                    # Handle non-diffusion samples KV cache correction.
                    # NOTE(review): this shifts the mask/KV/id history from
                    # position `correct_cnt[sample]` one slot to the right and
                    # masks that position — effectively dropping it from the
                    # negative history so both streams stay aligned.
                    non_diffusion_mask = ~diffusion_mask & ~finished
                    if non_diffusion_mask.any():
                        non_diffusion_indices = non_diffusion_mask.nonzero(as_tuple=False).squeeze(
                            -1
                        )
                        if non_diffusion_indices.dim() == 0:
                            non_diffusion_indices = non_diffusion_indices.unsqueeze(0)

                        key_caches, value_caches = _get_cache_tensors(negative_past_key_values)
                        for sample_idx in non_diffusion_indices.tolist():
                            start_idx = correct_cnt[sample_idx].item()
                            seq_len = negative_attention_mask.shape[1]

                            if start_idx + 1 < seq_len - 1:
                                negative_attention_mask[sample_idx, start_idx + 1 :] = (
                                    negative_attention_mask[sample_idx, start_idx:-1].clone()
                                )
                                negative_attention_mask[sample_idx, start_idx] = 0

                            for k_cache, v_cache in zip(key_caches, value_caches):
                                if start_idx + 1 < k_cache.shape[2] - 1:
                                    k_cache[sample_idx, :, start_idx + 1 :, :] = k_cache[
                                        sample_idx, :, start_idx:-1, :
                                    ].clone()
                                    v_cache[sample_idx, :, start_idx + 1 :, :] = v_cache[
                                        sample_idx, :, start_idx:-1, :
                                    ].clone()

                            if start_idx + 1 < negative_ids.shape[1] - 1:
                                negative_ids[sample_idx, start_idx + 1 :] = negative_ids[
                                    sample_idx, start_idx:-1
                                ].clone()

                        correct_cnt[non_diffusion_indices] += 1

                    neg_condition = neg_outputs.last_hidden_state[diffusion_indices, -1, :]
                else:
                    # CFG disabled: zero tensor stands in for the (unused)
                    # unconditional condition.
                    neg_condition = torch.zeros(
                        diffusion_indices.shape[0],
                        self.config.decoder_config.hidden_size,
                        device=device,
                        dtype=dtype,
                    )

                # Get conditioning from last hidden state
                condition = outputs.last_hidden_state[diffusion_indices, -1, :]

                # Sample speech latents using diffusion
                speech_latents = self.sample_speech_tokens(condition, neg_condition, cfg_scale)

                # Unscale latents (inverse of the normalization applied in
                # _process_speech_inputs)
                scaled_latent = (
                    speech_latents / self.speech_scaling_factor - self.speech_bias_factor
                )

                # Decode through acoustic tokenizer with streaming cache
                audio = self.acoustic_tokenizer.decode(
                    scaled_latent.unsqueeze(1).permute(0, 2, 1),
                    cache=acoustic_cache,
                    sample_indices=diffusion_indices,
                    use_cache=True,
                )

                # Store audio chunks
                for i, idx in enumerate(diffusion_indices.tolist()):
                    if not finished[idx]:
                        audio_chunks[idx].append(audio[i].cpu())

                # Encode audio to semantic features with streaming cache
                semantic_output = self.semantic_tokenizer.encode(
                    audio,
                    cache=semantic_cache,
                    sample_indices=diffusion_indices,
                    use_cache=True,
                )
                semantic_features = semantic_output.mean

                # Compute embeddings for next step
                acoustic_embed = self.acoustic_connector(speech_latents.unsqueeze(1))
                semantic_embed = self.semantic_connector(semantic_features)
                diffusion_embeds = (acoustic_embed + semantic_embed).squeeze(1)

                # Update embeddings for diffusion samples
                next_inputs_embeds[diffusion_indices] = diffusion_embeds.unsqueeze(1)

            # Update embeddings for next iteration
            inputs_embeds = torch.cat([inputs_embeds, next_inputs_embeds], dim=1)

            # Update negative model
            negative_inputs_embeds = torch.cat([negative_inputs_embeds, next_inputs_embeds], dim=1)
            negative_attention_mask = torch.cat(
                [
                    negative_attention_mask,
                    torch.ones((batch_size, 1), device=device, dtype=negative_attention_mask.dtype),
                ],
                dim=-1,
            )
            negative_ids = torch.cat([negative_ids, next_tokens.unsqueeze(-1)], dim=-1)

        # Concatenate audio chunks with normalization
        speech_outputs = []
        for chunks in audio_chunks:
            if chunks:
                concatenated = torch.cat(chunks, dim=-1).squeeze()
                # Normalize audio to prevent clipping
                max_val = concatenated.abs().max()
                if max_val > 1.0:
                    concatenated = concatenated * (0.95 / max_val)
                # Apply watermark to all generated audio
                concatenated = self._apply_watermark(concatenated, sample_rate=24000)
                speech_outputs.append(concatenated)
            else:
                speech_outputs.append(None)

        return KugelAudioGenerationOutput(
            sequences=current_ids,
            speech_outputs=speech_outputs,
        )
|
| 718 |
+
|
| 719 |
+
def _apply_watermark(self, audio: torch.Tensor, sample_rate: int = 24000) -> torch.Tensor:
|
| 720 |
+
"""Apply imperceptible watermark to generated audio.
|
| 721 |
+
|
| 722 |
+
This watermark identifies audio as generated by KugelAudio and is designed
|
| 723 |
+
to be robust against various audio transformations while remaining inaudible.
|
| 724 |
+
"""
|
| 725 |
+
try:
|
| 726 |
+
import torchaudio.functional as F
|
| 727 |
+
from audioseal import AudioSeal
|
| 728 |
+
except ImportError:
|
| 729 |
+
return audio # Graceful fallback if audioseal not available
|
| 730 |
+
|
| 731 |
+
device = audio.device
|
| 732 |
+
dtype = audio.dtype
|
| 733 |
+
original_shape = audio.shape
|
| 734 |
+
|
| 735 |
+
# Prepare audio for watermarking (AudioSeal expects [batch, channels, samples] at 16kHz)
|
| 736 |
+
if audio.dim() == 1:
|
| 737 |
+
audio_for_wm = audio.unsqueeze(0).unsqueeze(0)
|
| 738 |
+
elif audio.dim() == 2:
|
| 739 |
+
audio_for_wm = audio.unsqueeze(0)
|
| 740 |
+
else:
|
| 741 |
+
audio_for_wm = audio
|
| 742 |
+
|
| 743 |
+
audio_for_wm = audio_for_wm.float()
|
| 744 |
+
|
| 745 |
+
# Resample to 16kHz for AudioSeal
|
| 746 |
+
if sample_rate != 16000:
|
| 747 |
+
audio_16k = F.resample(audio_for_wm, sample_rate, 16000)
|
| 748 |
+
else:
|
| 749 |
+
audio_16k = audio_for_wm
|
| 750 |
+
|
| 751 |
+
# Load watermark generator (cached after first use)
|
| 752 |
+
if not hasattr(self, "_wm_generator"):
|
| 753 |
+
self._wm_generator = AudioSeal.load_generator("audioseal_wm_16bits").to(device)
|
| 754 |
+
self._wm_generator.eval()
|
| 755 |
+
|
| 756 |
+
# Generate and apply watermark
|
| 757 |
+
with torch.no_grad():
|
| 758 |
+
watermark_16k = self._wm_generator.get_watermark(audio_16k.to(device), 16000)
|
| 759 |
+
|
| 760 |
+
# Resample watermark back to original sample rate
|
| 761 |
+
if sample_rate != 16000:
|
| 762 |
+
watermark = F.resample(watermark_16k, 16000, sample_rate)
|
| 763 |
+
# Ensure same length
|
| 764 |
+
if watermark.shape[-1] != audio_for_wm.shape[-1]:
|
| 765 |
+
if watermark.shape[-1] > audio_for_wm.shape[-1]:
|
| 766 |
+
watermark = watermark[..., : audio_for_wm.shape[-1]]
|
| 767 |
+
else:
|
| 768 |
+
watermark = torch.nn.functional.pad(
|
| 769 |
+
watermark, (0, audio_for_wm.shape[-1] - watermark.shape[-1])
|
| 770 |
+
)
|
| 771 |
+
else:
|
| 772 |
+
watermark = watermark_16k
|
| 773 |
+
|
| 774 |
+
# Add watermark to audio
|
| 775 |
+
watermarked = audio_for_wm + watermark.to(audio_for_wm.device)
|
| 776 |
+
|
| 777 |
+
# Normalize to prevent clipping
|
| 778 |
+
max_val = watermarked.abs().max()
|
| 779 |
+
if max_val > 1.0:
|
| 780 |
+
watermarked = watermarked * (0.95 / max_val)
|
| 781 |
+
|
| 782 |
+
# Restore original shape
|
| 783 |
+
if len(original_shape) == 1:
|
| 784 |
+
watermarked = watermarked.squeeze(0).squeeze(0)
|
| 785 |
+
elif len(original_shape) == 2:
|
| 786 |
+
watermarked = watermarked.squeeze(0)
|
| 787 |
+
|
| 788 |
+
return watermarked.to(dtype=dtype)
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
# Register the KugelAudio config/model pair with the Hugging Face Auto classes
# so that AutoModel.from_pretrained / AutoModelForCausalLM.from_pretrained can
# resolve checkpoints whose config declares KugelAudioConfig.
AutoModel.register(KugelAudioConfig, KugelAudioModel)
AutoModelForCausalLM.register(KugelAudioConfig, KugelAudioForConditionalGenerationInference)


# Public API of this module.
__all__ = [
    "KugelAudioForConditionalGenerationInference",
    "KugelAudioCausalLMOutputWithPast",
    "KugelAudioGenerationOutput",
]
|
kugelaudio_open/models/kugelaudio_model.py
ADDED
|
@@ -0,0 +1,721 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Dict, List, Optional, Tuple, Union, Callable
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
|
| 9 |
+
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
| 10 |
+
|
| 11 |
+
from transformers.activations import ACT2FN
|
| 12 |
+
from transformers.modeling_outputs import (
|
| 13 |
+
CausalLMOutput,
|
| 14 |
+
BaseModelOutputWithPast,
|
| 15 |
+
ModelOutput,
|
| 16 |
+
)
|
| 17 |
+
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
| 18 |
+
from transformers import modeling_utils
|
| 19 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 20 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 21 |
+
from transformers.utils import logging
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
from .tokenizer import (
|
| 25 |
+
KugelAudioAcousticTokenizerModel,
|
| 26 |
+
KugelAudioSemanticTokenizerModel,
|
| 27 |
+
)
|
| 28 |
+
from .diffusion_head import KugelAudioDiffusionHead
|
| 29 |
+
from ..schedule.dpm_solver import DPMSolverMultistepScheduler
|
| 30 |
+
|
| 31 |
+
from ..configs import KugelAudioConfig
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__)
|
| 35 |
+
|
| 36 |
+
# Compatibility shim: some transformers versions do not define (or define as
# None) modeling_utils.ALL_PARALLEL_STYLES, which tensor-parallel plan
# validation reads. Provide the expected default list so model loading does
# not fail on those versions.
if (
    not hasattr(modeling_utils, "ALL_PARALLEL_STYLES")
    or modeling_utils.ALL_PARALLEL_STYLES is None
):
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass
class KugelAudioCausalLMOutputWithPast(ModelOutput):
    """Output of the KugelAudio causal LM forward pass.

    Extends the standard causal-LM output with the diffusion-head loss and
    the number of speech tokens seen in the batch.
    """

    # Combined language-modeling loss (None at inference time).
    loss: Optional[torch.FloatTensor] = None
    # Loss of the diffusion prediction head, reported separately.
    diffusion_loss: Optional[torch.FloatTensor] = None
    # Number of speech tokens contributing to the diffusion loss.
    speech_token_num: Optional[torch.LongTensor] = None
    # Next-token prediction scores over the text vocabulary.
    logits: Optional[torch.FloatTensor] = None
    # KV cache for incremental decoding.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Per-layer hidden states (when output_hidden_states=True).
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Per-layer attention maps (when output_attentions=True).
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass
class KugelAudioGenerationOutput(ModelOutput):
    """
    Output type for KugelAudio generation.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences.
        speech_outputs (`List[torch.FloatTensor]`, *optional*):
            List of generated speech waveforms or latents for each speech segment.
    """

    sequences: Optional[torch.LongTensor] = None
    speech_outputs: Optional[List[torch.FloatTensor]] = None
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class SpeechConnector(nn.Module):
    """Project speech-tokenizer features into the language model's hidden space.

    A two-layer MLP: linear projection to the target width, RMS normalization,
    then a second linear layer at the target width.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Attribute names are part of the checkpoint state_dict — keep them.
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, features, **kwargs):
        """Map `features` (..., input_dim) to (..., output_dim)."""
        projected = self.fc1(features)
        normalized = self.norm(projected)
        return self.fc2(normalized)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# @auto_docstring
|
| 85 |
+
class KugelAudioPreTrainedModel(PreTrainedModel):
    """Base class hooking KugelAudio models into the HF PreTrainedModel stack.

    Declares the config class, cache/attention capabilities, and the weight
    initialization used by all KugelAudio model variants.
    """

    config_class = KugelAudioConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize `module` weights; the diffusion head initializes itself."""
        if isinstance(module, KugelAudioDiffusionHead):
            module.initialize_weights()
            return

        # Pick the initializer std: prefer the language model's config, then
        # the decoder's, and fall back to the conventional 0.02 default.
        std = 0.02
        for sub_config_name in ("language_model_config", "decoder_config"):
            sub_config = getattr(self.config, sub_config_name, None)
            if sub_config is not None and hasattr(sub_config, "initializer_range"):
                std = sub_config.initializer_range
                break

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# @auto_docstring
|
| 124 |
+
class KugelAudioModel(KugelAudioPreTrainedModel):
    """Core KugelAudio model.

    Bundles a decoder-only language model with the speech stack: acoustic and
    semantic tokenizers, connectors that project tokenizer features into the
    LM hidden space, a diffusion prediction head, and a DPM-Solver noise
    scheduler for speech latent generation.
    """

    def __init__(self, config):
        super().__init__(config)

        # Resolve the working dtype from config.torch_dtype (may be given as a
        # string like "bfloat16"); default to float32 when unset.
        if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32

        # Initialize Qwen2 model for language modeling
        lm_config = config.decoder_config
        self.language_model = AutoModel.from_config(lm_config)

        # Initialize speech components if needed
        self.acoustic_tokenizer = AutoModel.from_config(
            config.acoustic_tokenizer_config
        ).to(dtype)
        self.semantic_tokenizer = AutoModel.from_config(
            config.semantic_tokenizer_config
        ).to(dtype)

        # Connectors project VAE feature dims into the LM hidden size.
        self.acoustic_connector = SpeechConnector(
            config.acoustic_vae_dim, lm_config.hidden_size
        ).to(dtype)
        self.semantic_connector = SpeechConnector(
            config.semantic_vae_dim, lm_config.hidden_size
        ).to(dtype)

        # Register scaling factors as buffers - use 1D tensors for FSDP compatibility.
        # NaN marks "not yet computed"; they are filled lazily from the first
        # batch statistics (see forward_speech_features in the LM head class).
        self.register_buffer("speech_scaling_factor", torch.tensor(float("nan")))
        self.register_buffer("speech_bias_factor", torch.tensor(float("nan")))

        # Initialize prediction head for speech generation
        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(
            dtype
        )

        # Initialize noise scheduler with SDE-DPM-Solver++ for better quality
        algorithm_type = getattr(
            config.diffusion_head_config, "ddpm_algorithm_type", "sde-dpmsolver++"
        )
        self.noise_scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
            prediction_type=config.diffusion_head_config.prediction_type,
            algorithm_type=algorithm_type,
            solver_order=2,
        )

    def get_input_embeddings(self):
        """Return the token embedding module of the underlying language model.

        Falls back to scanning `fullmap` for models whose parameter names were
        rewritten by nnscaler parallelization.
        """
        if hasattr(self.language_model, "embed_tokens"):
            # If the language model has an embed_tokens attribute, return it
            return self.language_model.embed_tokens

        for (
            name,
            attr,
        ) in (
            self.language_model.fullmap.items()
        ):  # parallel by nnscaler, the name is changed
            if attr.orig_name == "embed_tokens.weight":
                return getattr(self.language_model, name)
        assert False, "should not arrive here"

    def set_input_embeddings(self, value):
        """Replace the language model's token embedding module."""
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
        """Set the speech tokenizers used for encoding and decoding speech."""
        self.acoustic_tokenizer = acoustic_tokenizer
        self.semantic_tokenizer = semantic_tokenizer

        # Reset the encoder to evaluation mode
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.eval()

        if self.semantic_tokenizer is not None:
            self.semantic_tokenizer.eval()

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device = None,
        cache_position: torch.Tensor = None,
        batch_size: int = None,
        config=None,
        past_key_values=None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Creates a 4D causal attention mask for use with static cache.

        This enables torch.compile to work efficiently without recompilation
        by providing a consistent mask shape during autoregressive generation.

        Based on the standard HuggingFace implementation without sliding window
        (KugelAudio doesn't use sliding window attention).

        Compatible with both old and new transformers API.
        """
        # Handle case where attention_mask is already 4D
        if attention_mask is not None and attention_mask.dim() == 4:
            return attention_mask

        # Get device from attention_mask or cache_position if not provided
        if device is None:
            if attention_mask is not None:
                device = attention_mask.device
            elif cache_position is not None:
                device = cache_position.device
            else:
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Most negative representable value in `dtype` acts as -inf.
        min_dtype = torch.finfo(dtype).min

        # Create causal mask: (sequence_length, target_length)
        causal_mask = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=dtype,
            device=device,
        )

        if sequence_length != 1:
            # Apply upper triangular mask (can't attend to future tokens)
            causal_mask = torch.triu(causal_mask, diagonal=1)

        # Mask positions beyond current cache position
        if cache_position is not None:
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)

        # Expand to 4D: (batch_size, 1, sequence_length, target_length)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

        # Combine with input attention mask if provided
        if attention_mask is not None:
            causal_mask = causal_mask.clone()
            mask_length = attention_mask.shape[-1]
            # Create padding mask from attention_mask: the sum is exactly 0 only
            # where the position is causally visible (0) AND padded
            # (attention_mask == 0); those positions get filled with -inf.
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(dtype) * min_dtype
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

        return causal_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Thin pass-through to the underlying language model.

        All arguments follow the standard transformers decoder contract; the
        result is repackaged as a BaseModelOutputWithPast when
        `return_dict` is true.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Forward through language model
        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        if not return_dict:
            return outputs

        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class KugelAudioForConditionalGeneration(KugelAudioPreTrainedModel):
|
| 323 |
+
"""
|
| 324 |
+
Unified model for both training and inference.
|
| 325 |
+
|
| 326 |
+
Supports:
|
| 327 |
+
- Training via forward() with loss computation
|
| 328 |
+
- Inference via generate() for audio generation
|
| 329 |
+
"""
|
| 330 |
+
|
| 331 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 332 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 333 |
+
|
| 334 |
+
def __init__(self, config):
    """Build the conditional-generation wrapper: core model + LM head.

    Args:
        config: A KugelAudioConfig carrying `decoder_config` (LM dims/vocab)
            and optionally `diffusion_head_config` (inference step count).
    """
    super().__init__(config)
    self.model = KugelAudioModel(config)
    self.vocab_size = config.decoder_config.vocab_size
    # Untied projection onto the text vocabulary; tie_weights() may later
    # share it with the input embeddings.
    self.lm_head = nn.Linear(
        config.decoder_config.hidden_size, self.vocab_size, bias=False
    )

    # Inference configuration (for generate() method)
    self.ddpm_inference_steps = (
        config.diffusion_head_config.ddpm_num_inference_steps
        if hasattr(config, "diffusion_head_config")
        else 5
    )

    self.post_init()
|
| 350 |
+
|
| 351 |
+
# Properties for easier access (used by generate())
|
| 352 |
+
@property
def noise_scheduler(self):
    """Convenience accessor for the core model's DPM-Solver scheduler."""
    return self.model.noise_scheduler
|
| 355 |
+
|
| 356 |
+
@property
def prediction_head(self):
    """Convenience accessor for the core model's diffusion prediction head."""
    return self.model.prediction_head
|
| 359 |
+
|
| 360 |
+
def get_input_embeddings(self):
    """Return the token embedding module (delegates to the core model)."""
    return self.model.get_input_embeddings()
|
| 362 |
+
|
| 363 |
+
def set_input_embeddings(self, value):
    """Replace the token embedding module (delegates to the core model)."""
    self.model.set_input_embeddings(value)
|
| 365 |
+
|
| 366 |
+
def get_output_embeddings(self):
    """Return the LM head used to project hidden states onto the vocabulary."""
    return self.lm_head
|
| 368 |
+
|
| 369 |
+
def set_decoder(self, decoder):
    """Swap in a different decoder language model."""
    self.model.language_model = decoder
|
| 371 |
+
|
| 372 |
+
def get_decoder(self):
    """Return the decoder language model."""
    return self.model.language_model
|
| 374 |
+
|
| 375 |
+
def tie_weights(self):
    """
    Tie the weights between the input embeddings and the output embeddings.

    Only acts when `decoder_config.tie_word_embeddings` is true. The tying is
    a plain parameter-object assignment, which is correct to do BEFORE FSDP
    wraps the model; do not call this after accelerator.prepare().
    """
    if getattr(self.config.decoder_config, "tie_word_embeddings", False):
        output_embeddings = self.get_output_embeddings()
        input_embeddings = self.get_input_embeddings()
        if hasattr(input_embeddings, "weight"):
            output_embeddings.weight = input_embeddings.weight
        else:
            # get_input_embeddings may return the weight tensor directly
            # (e.g. after nnscaler parameter renaming).
            output_embeddings.weight = input_embeddings

        # If the head has a bias, pad it to the (possibly larger) tied
        # vocabulary size so shapes stay consistent.
        if getattr(output_embeddings, "bias", None) is not None:
            output_embeddings.bias.data = nn.functional.pad(
                output_embeddings.bias.data,
                (
                    0,
                    output_embeddings.weight.shape[0]
                    - output_embeddings.bias.shape[0],
                ),
                "constant",
                0,
            )
        # Use the module logger instead of print() for consistency with the
        # rest of the file (transformers logging).
        logger.info("Tied input and output embeddings using standard assignment.")
    else:
        logger.info("tie_word_embeddings is False, not tying weights.")
|
| 405 |
+
|
| 406 |
+
# Also, ensure set_output_embeddings is safe, though your implementation looks okay.
|
| 407 |
+
# The key is to avoid calling it after accelerator.prepare().
|
| 408 |
+
def set_output_embeddings(self, new_embeddings):
    """Replace the LM head module.

    NOTE(review): safe before FSDP/accelerator.prepare(); avoid calling it
    afterwards, since the wrapped parameters would no longer be swapped.
    """
    self.lm_head = new_embeddings
|
| 412 |
+
|
| 413 |
+
def forward_speech_features(
    self,
    speech_tensors=None,
    speech_masks=None,
    speech_type="audio",
    return_unmask=False,
):
    """Encode speech into normalized latent features and LM-space embeddings.

    Args:
        speech_tensors: Raw waveforms (``speech_type="audio"``) or flattened
            VAE means (``speech_type="vae"``); None yields a zero placeholder.
        speech_masks: Boolean mask selecting valid latent frames.
        speech_type: "audio" (encode with the acoustic tokenizer) or "vae"
            (re-noise precomputed VAE means).
        return_unmask: If True, return full (unmasked) feature tensors.

    Returns:
        Tuple of (normalized latent features, connector outputs), masked by
        `speech_masks` unless `return_unmask` is set.
    """
    if speech_tensors is None:
        # Use config to get vae_dim instead of non-existent self.args
        vae_dim = self.config.acoustic_tokenizer_config.vae_dim
        # Zero placeholder matching the embedding weight's device/dtype.
        audio_features = torch.zeros(1, 1, vae_dim).to(
            self.get_input_embeddings().weight
        )
        connect_features = self.model.acoustic_connector(audio_features)
        return audio_features, connect_features
    else:
        with torch.no_grad():
            if speech_type == "audio":
                with torch.no_grad():
                    frames_out = self.model.acoustic_tokenizer.encode(
                        speech_tensors.unsqueeze(1)
                    )
                    # Tokenizer may return nested (list/tuple) outputs;
                    # unwrap to the distribution object.
                    if isinstance(frames_out, (list, tuple)):
                        frames = frames_out[0][0]
                    else:
                        frames = frames_out
                    audio_tokens = frames.sample(
                        self.model.acoustic_tokenizer.std_dist_type
                    )[0]

            elif speech_type == "vae":
                # Use config to get vae_dim instead of non-existent self.args
                vae_dim = self.config.acoustic_tokenizer_config.vae_dim
                speech_mode = speech_tensors.reshape(
                    speech_tensors.size(0), -1, vae_dim
                )

                # gaussian sample from the speech_mode: draw a per-sample std
                # scaled by the tokenizer's fixed std, then add noise.
                batch_size = speech_mode.size(0)
                value = self.model.acoustic_tokenizer.fix_std / 0.8
                std = (
                    torch.randn(
                        batch_size,
                        dtype=speech_mode.dtype,
                        device=speech_mode.device,
                    )
                    * value
                )
                std = std.view(-1, *[1] * (speech_mode.dim() - 1))
                audio_tokens = speech_mode + std * torch.randn(
                    speech_mode.shape
                ).to(speech_mode)
            else:
                raise NotImplementedError(
                    f"Speech type {speech_type} not implemented"
                )

            # Lazily compute global normalization stats on the first batch:
            # NaN buffers mean "not yet initialized" (see KugelAudioModel).
            if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(
                self.model.speech_bias_factor
            ):
                scaling_factor = 1.0 / audio_tokens[speech_masks].flatten().std()
                bias_factor = -audio_tokens[speech_masks].flatten().mean()

                # Only use distributed operations if the process group is initialized
                if dist.is_available() and dist.is_initialized():
                    # SUM-reduce then divide by world size = average the
                    # per-rank statistics.
                    dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM)
                    dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM)
                    world_size = dist.get_world_size()
                    self.model.speech_scaling_factor.copy_(
                        scaling_factor / world_size
                    )
                    self.model.speech_bias_factor.copy_(bias_factor / world_size)
                    print(
                        f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}",
                        flush=True,
                    )
                else:
                    # Single process case
                    self.model.speech_scaling_factor.copy_(scaling_factor)
                    self.model.speech_bias_factor.copy_(bias_factor)
                    print(
                        f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}",
                        flush=True,
                    )

            # Shift then scale latents into the normalized space.
            audio_features = (
                audio_tokens + self.model.speech_bias_factor
            ) * self.model.speech_scaling_factor

        # Connector runs outside no_grad so its weights receive gradients.
        # NOTE(review): indentation reconstructed from a mangled source —
        # confirm against the original file.
        connect_features = self.model.acoustic_connector(audio_features)
        if return_unmask:
            return audio_features, connect_features
        return audio_features[speech_masks], connect_features[speech_masks]
|
| 506 |
+
|
| 507 |
+
def forward(
|
| 508 |
+
self,
|
| 509 |
+
input_ids: torch.LongTensor = None,
|
| 510 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 511 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 512 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 513 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 514 |
+
labels: Optional[torch.LongTensor] = None,
|
| 515 |
+
use_cache: Optional[bool] = False,
|
| 516 |
+
output_attentions: Optional[bool] = None,
|
| 517 |
+
output_hidden_states: Optional[bool] = None,
|
| 518 |
+
return_dict: Optional[bool] = None,
|
| 519 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 520 |
+
# New arguments for speech processing and loss calculation
|
| 521 |
+
speech_tensors: Optional[torch.FloatTensor] = None,
|
| 522 |
+
speech_masks: Optional[torch.BoolTensor] = None,
|
| 523 |
+
speeches_loss_input: Optional[torch.FloatTensor] = None,
|
| 524 |
+
speech_semantic_tensors: Optional[torch.FloatTensor] = None,
|
| 525 |
+
acoustic_input_mask: Optional[torch.BoolTensor] = None,
|
| 526 |
+
acoustic_loss_mask: Optional[torch.BoolTensor] = None,
|
| 527 |
+
ddpm_batch_mul: int = 1,
|
| 528 |
+
**kwargs: Optional[Dict[str, Union[torch.Tensor, str]]],
|
| 529 |
+
) -> Union[Tuple, KugelAudioCausalLMOutputWithPast]:
|
| 530 |
+
|
| 531 |
+
return_dict = (
|
| 532 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
x = self.get_input_embeddings()(input_ids)
|
| 536 |
+
|
| 537 |
+
semantic_speech_all_connect_features = self.model.semantic_connector(
|
| 538 |
+
speech_semantic_tensors
|
| 539 |
+
)
|
| 540 |
+
if speeches_loss_input is not None:
|
| 541 |
+
# only part audio need diffuse
|
| 542 |
+
speech_all_features, speech_all_connect_features = (
|
| 543 |
+
self.forward_speech_features(
|
| 544 |
+
speech_tensors=(
|
| 545 |
+
speech_tensors.type_as(x)
|
| 546 |
+
if speech_tensors is not None
|
| 547 |
+
else None
|
| 548 |
+
),
|
| 549 |
+
speech_masks=speech_masks,
|
| 550 |
+
speech_type=kwargs.get("speech_type", "audio"),
|
| 551 |
+
return_unmask=True,
|
| 552 |
+
)
|
| 553 |
+
)
|
| 554 |
+
if speech_tensors is not None:
|
| 555 |
+
if semantic_speech_all_connect_features is not None:
|
| 556 |
+
x[acoustic_input_mask] = (
|
| 557 |
+
speech_all_connect_features[speech_masks]
|
| 558 |
+
+ semantic_speech_all_connect_features[speech_masks]
|
| 559 |
+
)
|
| 560 |
+
else:
|
| 561 |
+
x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
|
| 562 |
+
speech_features = speech_all_features[
|
| 563 |
+
speeches_loss_input & speech_masks
|
| 564 |
+
] # only part audio need diffuse
|
| 565 |
+
speech_connect_features = speech_all_connect_features[
|
| 566 |
+
speeches_loss_input & speech_masks
|
| 567 |
+
]
|
| 568 |
+
# Forward-time consistency check: selected latent count should match number of acoustic placeholders
|
| 569 |
+
try:
|
| 570 |
+
if acoustic_input_mask is not None:
|
| 571 |
+
assert speech_connect_features.shape[0] == int(
|
| 572 |
+
acoustic_input_mask.sum().item()
|
| 573 |
+
), f"Mismatch between selected speech connectors ({speech_connect_features.shape[0]}) and acoustic_input_mask sum ({int(acoustic_input_mask.sum().item())})"
|
| 574 |
+
except Exception:
|
| 575 |
+
pass
|
| 576 |
+
else:
|
| 577 |
+
speech_features, speech_connect_features = self.forward_speech_features(
|
| 578 |
+
speech_tensors=(
|
| 579 |
+
speech_tensors.type_as(x) if speech_tensors is not None else None
|
| 580 |
+
),
|
| 581 |
+
speech_masks=speech_masks,
|
| 582 |
+
speech_type=kwargs.get("speech_type", "audio"),
|
| 583 |
+
)
|
| 584 |
+
if speech_tensors is not None:
|
| 585 |
+
x[acoustic_input_mask] = speech_connect_features
|
| 586 |
+
|
| 587 |
+
outputs = self.model(
|
| 588 |
+
input_ids=None,
|
| 589 |
+
attention_mask=attention_mask,
|
| 590 |
+
position_ids=position_ids,
|
| 591 |
+
past_key_values=past_key_values,
|
| 592 |
+
inputs_embeds=x,
|
| 593 |
+
use_cache=use_cache,
|
| 594 |
+
output_attentions=output_attentions,
|
| 595 |
+
output_hidden_states=False,
|
| 596 |
+
return_dict=return_dict,
|
| 597 |
+
cache_position=cache_position,
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
hidden_states = outputs.last_hidden_state
|
| 601 |
+
logits = self.lm_head(hidden_states)
|
| 602 |
+
# logits = logits.float()
|
| 603 |
+
|
| 604 |
+
loss = None
|
| 605 |
+
if labels is not None:
|
| 606 |
+
# The custom CE loss with masking is calculated in the training script.
|
| 607 |
+
# We leave the standard loss calculation here as None.
|
| 608 |
+
pass
|
| 609 |
+
|
| 610 |
+
# --- Diffusion Loss Calculation ---
|
| 611 |
+
diffusion_loss = None
|
| 612 |
+
# This block is executed only if we are in a context that involves speech.
|
| 613 |
+
if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0:
|
| 614 |
+
# Build conditioning mask from positions whose NEXT token is a speech latent (shift left by 1)
|
| 615 |
+
cond_mask = torch.zeros_like(acoustic_loss_mask, dtype=torch.bool)
|
| 616 |
+
cond_mask[:, :-1] = acoustic_loss_mask[:, 1:]
|
| 617 |
+
cond_mask[:, 0] = False
|
| 618 |
+
condition_features = hidden_states[cond_mask]
|
| 619 |
+
|
| 620 |
+
speech_len, latent_size = speech_features.shape
|
| 621 |
+
# Sanity check: ensure 1:1 alignment between selected conditions and latents
|
| 622 |
+
try:
|
| 623 |
+
assert (
|
| 624 |
+
condition_features.shape[0] == speech_len
|
| 625 |
+
), f"Mismatch: condition_features={condition_features.shape[0]} vs speech_features={speech_len}"
|
| 626 |
+
except Exception:
|
| 627 |
+
pass
|
| 628 |
+
|
| 629 |
+
noise = torch.randn(
|
| 630 |
+
(speech_len * ddpm_batch_mul, latent_size),
|
| 631 |
+
device=hidden_states.device,
|
| 632 |
+
dtype=hidden_states.dtype,
|
| 633 |
+
)
|
| 634 |
+
|
| 635 |
+
timesteps = torch.multinomial(
|
| 636 |
+
torch.ones(self.config.diffusion_head_config.ddpm_num_steps),
|
| 637 |
+
speech_len * ddpm_batch_mul,
|
| 638 |
+
replacement=True,
|
| 639 |
+
).to(hidden_states.device)
|
| 640 |
+
|
| 641 |
+
speech_features_repeated = speech_features.repeat_interleave(
|
| 642 |
+
ddpm_batch_mul, dim=0
|
| 643 |
+
)
|
| 644 |
+
condition_features_repeated = condition_features.repeat_interleave(
|
| 645 |
+
ddpm_batch_mul, dim=0
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
noisy_speech_features = self.model.noise_scheduler.add_noise(
|
| 649 |
+
speech_features_repeated, noise, timesteps
|
| 650 |
+
)
|
| 651 |
+
|
| 652 |
+
model_output = self.model.prediction_head(
|
| 653 |
+
noisy_speech_features, timesteps.type_as(x), condition_features_repeated
|
| 654 |
+
)
|
| 655 |
+
|
| 656 |
+
prediction_type = self.config.diffusion_head_config.prediction_type
|
| 657 |
+
if prediction_type == "epsilon":
|
| 658 |
+
target_for_loss = noise
|
| 659 |
+
elif prediction_type == "v_prediction":
|
| 660 |
+
target_for_loss = self.model.noise_scheduler.get_velocity(
|
| 661 |
+
speech_features_repeated, noise, timesteps
|
| 662 |
+
)
|
| 663 |
+
else:
|
| 664 |
+
raise NotImplementedError(
|
| 665 |
+
f"Prediction type {prediction_type} not implemented"
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
diffusion_loss = F.mse_loss(
|
| 669 |
+
model_output.float(), target_for_loss.float(), reduction="sum"
|
| 670 |
+
)
|
| 671 |
+
if latent_size > 0 and ddpm_batch_mul > 0:
|
| 672 |
+
# Normalize by latent dim, number of sampled diffusion steps per latent, and number of speech tokens
|
| 673 |
+
diffusion_loss = (
|
| 674 |
+
diffusion_loss / latent_size / ddpm_batch_mul / max(speech_len, 1)
|
| 675 |
+
)
|
| 676 |
+
else:
|
| 677 |
+
diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)
|
| 678 |
+
|
| 679 |
+
else:
|
| 680 |
+
# Dummy loss for DDP to work when there are no speech samples in a batch,
|
| 681 |
+
# but we are in a speech context.
|
| 682 |
+
diffusion_loss = (
|
| 683 |
+
sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0
|
| 684 |
+
)
|
| 685 |
+
diffusion_loss += (
|
| 686 |
+
sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0
|
| 687 |
+
)
|
| 688 |
+
diffusion_loss += (
|
| 689 |
+
sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0
|
| 690 |
+
)
|
| 691 |
+
# --- End Diffusion Loss Calculation ---
|
| 692 |
+
|
| 693 |
+
if not return_dict:
|
| 694 |
+
output = (logits, speech_len) + outputs.to_tuple()[1:]
|
| 695 |
+
return (loss, diffusion_loss) + output
|
| 696 |
+
|
| 697 |
+
return KugelAudioCausalLMOutputWithPast(
|
| 698 |
+
loss=loss,
|
| 699 |
+
diffusion_loss=diffusion_loss,
|
| 700 |
+
speech_token_num=torch.tensor(
|
| 701 |
+
speech_len if speech_tensors is not None else 0,
|
| 702 |
+
device=logits.device,
|
| 703 |
+
dtype=torch.long,
|
| 704 |
+
),
|
| 705 |
+
logits=logits,
|
| 706 |
+
past_key_values=outputs.past_key_values,
|
| 707 |
+
hidden_states=outputs.hidden_states,
|
| 708 |
+
attentions=outputs.attentions,
|
| 709 |
+
)
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
# Register the custom config/model pair with the transformers Auto* factories
# so AutoModel.from_pretrained / AutoModelForCausalLM.from_pretrained can
# resolve KugelAudio checkpoints by their config class.
AutoModel.register(KugelAudioConfig, KugelAudioModel)
AutoModelForCausalLM.register(KugelAudioConfig, KugelAudioForConditionalGeneration)

# Public API of this module.
__all__ = [
    "KugelAudioModel",
    "KugelAudioPreTrainedModel",
    "KugelAudioForConditionalGeneration",
    "KugelAudioCausalLMOutputWithPast",
    "KugelAudioGenerationOutput",
]
|
kugelaudio_open/models/tokenizer.py
ADDED
|
@@ -0,0 +1,1197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import typing as tp
|
| 3 |
+
from functools import partial
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 6 |
+
import copy
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import torch._dynamo
|
| 13 |
+
|
| 14 |
+
from transformers.models.auto import AutoModel
|
| 15 |
+
|
| 16 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 17 |
+
from transformers.utils import logging
|
| 18 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 19 |
+
from transformers.activations import ACT2FN
|
| 20 |
+
|
| 21 |
+
from ..configs import KugelAudioAcousticTokenizerConfig, KugelAudioSemanticTokenizerConfig
|
| 22 |
+
|
| 23 |
+
logger = logging.get_logger(__name__)
|
| 24 |
+
|
| 25 |
+
# APEX is not used in the open-source version
|
| 26 |
+
APEX_AVAILABLE = False
|
| 27 |
+
|
| 28 |
+
# Normalization modules
|
| 29 |
+
class ConvLayerNorm(nn.LayerNorm):
    """LayerNorm for channel-first (B, C, T) feature maps.

    Moves the channel axis to the last position, runs layer normalization
    in float32 (casting the result back to the input dtype), then restores
    the channel-first layout.
    """

    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        # (B, C, T) -> (B, T, C): put the normalized axis last.
        channels_last = x.transpose(1, 2)
        normed = nn.functional.layer_norm(
            channels_last.float(),
            self.normalized_shape,
            self.weight.float(),
            self.bias.float(),
            self.eps,
        ).type_as(x)
        # (B, T, C) -> (B, C, T): restore the original layout.
        return normed.transpose(1, 2)
|
| 42 |
+
|
| 43 |
+
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Normalizes in float32 for numerical stability, casts back to the input
    dtype, and optionally applies a learned elementwise scale.
    """

    def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not elementwise_affine:
            # Keep a 'weight' entry so state dicts stay consistent.
            self.register_parameter('weight', None)
        else:
            shape = (dim,) if weight_shape is None else weight_shape
            self.weight = nn.Parameter(torch.ones(shape))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), via rsqrt.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        out = self._norm(x.float()).type_as(x)
        return out if self.weight is None else out * self.weight

    def extra_repr(self) -> str:
        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
|
| 66 |
+
|
| 67 |
+
class ConvRMSNorm(RMSNorm):
    """RMSNorm for channel-first (B, C, T) feature maps.

    Transposes to channels-last, applies RMS normalization over the channel
    dimension, and transposes back.
    """

    def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
        super().__init__(dim, eps, elementwise_affine, weight_shape)

    def forward(self, x):
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        # Fix: the previous APEX branch called fused_rms_norm_affine, which is
        # undefined in this file (APEX_AVAILABLE is hardcoded False), so it was
        # dead code that would raise NameError if ever reached. Always use the
        # native implementation.
        output = self._norm(x.float()).type_as(x)
        if self.weight is not None:
            output = output * self.weight
        return output.transpose(1, 2)  # (B, T, C) -> (B, C, T)
|
| 82 |
+
|
| 83 |
+
# Convolutional layers and utilities
# Normalization strategies accepted by apply_parametrization_norm /
# get_norm_module for the (transposed) convolution wrappers below.
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
                                 'time_layer_norm', 'layer_norm', 'time_group_norm'])
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    """Optionally wrap *module* with a weight reparametrization.

    Only 'weight_norm' and 'spectral_norm' reparametrize the weights; every
    other valid choice returns the module unchanged.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        return nn.utils.weight_norm(module)
    if norm == 'spectral_norm':
        return nn.utils.spectral_norm(module)
    # Membership in CONV_NORMALIZATIONS was asserted above; the remaining
    # choices need no reparametrization.
    return module
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Return the proper normalization module for a conv layer.

    If ``causal`` is True, ensures the returned module supports causal
    evaluation, raising for normalizations that cannot (time_group_norm).
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'layer_norm':
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    if norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        # A single group normalizes over all channels jointly per timestep.
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    # 'none', 'weight_norm', 'spectral_norm', 'time_layer_norm': no extra
    # normalization module is attached here.
    return nn.Identity()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Return the extra right-padding needed so the conv consumes all input.

    Computes the fractional frame count the input would produce and rounds
    the input length up to the smallest length that yields a whole number
    of frames.
    """
    length = x.shape[-1]
    frames = (length - kernel_size + padding_total) / stride + 1
    target_length = (math.ceil(frames) - 1) * stride + (kernel_size - padding_total)
    return target_length - length
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Pad the last dimension of a 1D signal.

    In 'reflect' mode, inputs no longer than the padding are first extended
    with zeros on the right (F.pad's reflection requires length > pad), then
    the temporary extension is trimmed off after padding.
    """
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode != 'reflect':
        return F.pad(x, paddings, mode, value)
    length = x.shape[-1]
    max_pad = max(padding_left, padding_right)
    # Reflection needs the signal to be strictly longer than the pad size.
    extra_pad = max_pad - length + 1 if length <= max_pad else 0
    if extra_pad:
        x = F.pad(x, (0, extra_pad))
    padded = F.pad(x, paddings, mode, value)
    # Drop the temporary zero extension from the right end.
    return padded[..., :padded.shape[-1] - extra_pad]
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Strip (left, right) padding from the last dimension. Only for 1d!"""
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    assert left + right <= x.shape[-1]
    return x[..., left: x.shape[-1] - right]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class NormConv1d(nn.Module):
    """Conv1d followed by the normalization selected via ``norm``."""

    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        conv = nn.Conv1d(*args, **kwargs)
        # Weight reparametrization (weight/spectral norm) wraps the conv;
        # activation normalization is applied after it in forward().
        self.conv = apply_parametrization_norm(conv, norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.conv(x))
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class NormConvTranspose1d(nn.Module):
    """ConvTranspose1d followed by the normalization selected via ``norm``."""

    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        convtr = nn.ConvTranspose1d(*args, **kwargs)
        # Weight reparametrization wraps the transposed conv; activation
        # normalization is applied after it in forward().
        self.convtr = apply_parametrization_norm(convtr, norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.convtr(x))
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class KugelAudioTokenizerStreamingCache:
    """Per-(layer, sample) state cache for streaming convolutions.

    Plays the role a KV cache plays for attention: each causal conv layer
    keeps the trailing context of every sample so successive audio chunks
    can be processed independently.
    """

    def __init__(self):
        # Maps (layer_id, sample_idx) -> cached state tensor.
        self.cache = {}

    def get(self, layer_id: str, sample_indices: torch.Tensor) -> Optional[torch.Tensor]:
        """Return stacked cached states for the given samples.

        Returns None if any requested sample has no cached entry. When the
        per-sample states have unequal time lengths, shorter ones are
        zero-padded on the LEFT so the most recent samples stay aligned.
        """
        collected = []
        longest = 0
        for idx in sample_indices.tolist():
            entry = self.cache.get((layer_id, idx))
            if entry is None:
                return None  # a single missing sample invalidates the batch
            collected.append(entry)
            longest = max(longest, entry.shape[-1])

        if collected and collected[0].dim() >= 2:
            aligned = []
            for entry in collected:
                deficit = longest - entry.shape[-1]
                if deficit > 0:
                    # Zero-pad on the left of the time (last) dimension.
                    aligned.append(F.pad(entry, (deficit, 0), mode='constant', value=0))
                else:
                    aligned.append(entry)
            return torch.stack(aligned, dim=0)
        return torch.stack(collected, dim=0)

    def set(self, layer_id: str, sample_indices: torch.Tensor, states: torch.Tensor):
        """Store (detached) per-sample states for a layer."""
        for i, idx in enumerate(sample_indices.tolist()):
            self.cache[(layer_id, idx)] = states[i].detach()

    def set_to_zero(self, sample_indices: torch.Tensor):
        """Zero out every cached state belonging to the given samples."""
        targets = sample_indices.tolist()
        for key in list(self.cache.keys()):
            _, sample_idx = key
            if sample_idx in targets:
                # Preserve shape and dtype, reset values.
                self.cache[key] = torch.zeros_like(self.cache[key])

    def clear(self, layer_id: Optional[str] = None, sample_indices: Optional[torch.Tensor] = None):
        """Drop entries for a layer and/or specific samples, or everything."""
        if layer_id is None and sample_indices is None:
            self.cache.clear()
        elif layer_id is not None and sample_indices is None:
            # Drop every sample cached for this layer.
            for key in [k for k in self.cache if k[0] == layer_id]:
                del self.cache[key]
        elif layer_id is not None and sample_indices is not None:
            # Drop specific samples for this layer.
            for idx in sample_indices.tolist():
                self.cache.pop((layer_id, idx), None)
        # (layer_id is None, sample_indices is not None): intentionally a
        # no-op, matching the original behavior.
|
| 247 |
+
|
| 248 |
+
class SConv1d(nn.Module):
|
| 249 |
+
"""Conv1d with built-in handling of asymmetric or causal padding and normalization."""
|
| 250 |
+
def __init__(self, in_channels: int, out_channels: int,
|
| 251 |
+
kernel_size: int, stride: int = 1, dilation: int = 1,
|
| 252 |
+
groups: int = 1, bias: bool = True, causal: bool = False,
|
| 253 |
+
norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
|
| 254 |
+
pad_mode: str = 'reflect'):
|
| 255 |
+
super().__init__()
|
| 256 |
+
self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
|
| 257 |
+
dilation=dilation, groups=groups, bias=bias, causal=causal,
|
| 258 |
+
norm=norm, norm_kwargs=norm_kwargs)
|
| 259 |
+
self.causal = causal
|
| 260 |
+
self.pad_mode = pad_mode
|
| 261 |
+
|
| 262 |
+
# Store configuration
|
| 263 |
+
self.kernel_size = kernel_size
|
| 264 |
+
self.dilation = dilation
|
| 265 |
+
self.stride = stride
|
| 266 |
+
self.in_channels = in_channels
|
| 267 |
+
self.out_channels = out_channels
|
| 268 |
+
|
| 269 |
+
# For causal convolution, we need to maintain kernel_size - 1 samples as context
|
| 270 |
+
# need to check use which context_size is more suitable
|
| 271 |
+
# self.context_size = (kernel_size - 1) * dilation
|
| 272 |
+
self.context_size = (kernel_size - 1) * dilation - (stride - 1)
|
| 273 |
+
|
| 274 |
+
# For non-streaming mode, calculate padding
|
| 275 |
+
self.padding_total = (kernel_size - 1) * dilation - (stride - 1)
|
| 276 |
+
|
| 277 |
+
# Create a unique layer ID for cache management
|
| 278 |
+
self._layer_id = None
|
| 279 |
+
|
| 280 |
+
@property
|
| 281 |
+
def layer_id(self):
|
| 282 |
+
if self._layer_id is None:
|
| 283 |
+
self._layer_id = f"sconv1d_{id(self)}"
|
| 284 |
+
return self._layer_id
|
| 285 |
+
|
| 286 |
+
def forward(self, x: torch.Tensor,
|
| 287 |
+
cache: Optional[KugelAudioTokenizerStreamingCache] = None,
|
| 288 |
+
sample_indices: Optional[torch.Tensor] = None,
|
| 289 |
+
use_cache: bool = False,
|
| 290 |
+
debug: bool = False) -> torch.Tensor:
|
| 291 |
+
"""
|
| 292 |
+
Forward pass with optional streaming support via cache.
|
| 293 |
+
|
| 294 |
+
Args:
|
| 295 |
+
x: Input tensor [batch_size, channels, time]
|
| 296 |
+
cache: KugelAudioTokenizerStreamingCache object for maintaining states
|
| 297 |
+
sample_indices: Indices identifying each sample for cache management
|
| 298 |
+
use_cache: Whether to use cached states for streaming
|
| 299 |
+
debug: Whether to print debug information
|
| 300 |
+
|
| 301 |
+
Returns:
|
| 302 |
+
Output tensor
|
| 303 |
+
"""
|
| 304 |
+
B, C, T = x.shape
|
| 305 |
+
|
| 306 |
+
# Non-streaming mode
|
| 307 |
+
if not use_cache or cache is None:
|
| 308 |
+
return self._forward_non_streaming(x, debug=debug)
|
| 309 |
+
|
| 310 |
+
# Streaming mode
|
| 311 |
+
assert self.causal, "Streaming mode is only supported for causal convolutions"
|
| 312 |
+
assert sample_indices is not None, "sample_indices must be provided for streaming mode"
|
| 313 |
+
assert len(sample_indices) == B, "sample_indices must match batch size"
|
| 314 |
+
|
| 315 |
+
return self._forward_streaming(x, cache, sample_indices, debug)
|
| 316 |
+
|
| 317 |
+
@torch._dynamo.disable()  # Disable compilation for streaming path - dynamic cache ops cause recompilations
def _forward_streaming(self, x: torch.Tensor,
                       cache: KugelAudioTokenizerStreamingCache,
                       sample_indices: torch.Tensor,
                       debug: bool = False) -> torch.Tensor:
    """Streaming forward pass with cache operations kept separate from compiled code.

    Prepends the cached left context (zeros for the very first chunk) to the
    incoming chunk, applies the convolution on the combined signal, then saves
    the last ``context_size`` input samples as context for the next chunk.

    Args:
        x: Input chunk [batch_size, channels, time].
        cache: Streaming cache, keyed by this layer's ``layer_id``.
        sample_indices: Per-sample identifiers so cache entries follow batch items.
        debug: Whether to print shape information at each step.

    Returns:
        Convolution output for this chunk.
    """
    B, C, T = x.shape

    # Cache operations (not compiled)
    cached_states = cache.get(self.layer_id, sample_indices)

    if cached_states is None:
        # First chunk - initialize with zeros for context
        if self.context_size > 0:
            cached_states = torch.zeros(B, C, self.context_size, device=x.device, dtype=x.dtype)
            if debug:
                print(f"[DEBUG] Initialized cache with shape: {cached_states.shape}, context_size={self.context_size}")
        else:
            # No left context is needed; keep an empty tensor so shapes stay consistent.
            cached_states = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
            if debug:
                print(f"[DEBUG] No context needed (kernel_size=stride)")

    # Concatenate cached states with input
    if cached_states.shape[2] > 0:
        input_with_context = torch.cat([cached_states, x], dim=2)
    else:
        input_with_context = x

    if debug:
        print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}")

    # Apply convolution directly - no extra padding in streaming mode
    # The conv layer will handle its own padding internally
    output = self.conv(input_with_context)

    if debug:
        print(f"[DEBUG] Output shape: {output.shape}")

    # Update cache for next chunk
    if self.context_size > 0:
        # Calculate how many samples to keep
        total_input_length = input_with_context.shape[2]

        # Keep the last context_size samples
        if total_input_length >= self.context_size:
            new_cache_start = total_input_length - self.context_size
            new_cache = input_with_context[:, :, new_cache_start:]
        else:
            # If we have less than context_size samples, keep everything
            new_cache = input_with_context

        if debug:
            print(f"[DEBUG] New cache shape: {new_cache.shape}")

        cache.set(self.layer_id, sample_indices, new_cache)

    return output
|
| 374 |
+
|
| 375 |
+
def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
    """Standard forward pass without streaming.

    Pads the input — left-only when causal, symmetric otherwise, plus extra
    right padding so the padded length aligns with the stride — and then
    applies the convolution.

    Args:
        x: Input tensor [batch_size, channels, time].
        debug: Whether to print shape information.

    Returns:
        Convolution output tensor.
    """
    # Unused locals removed: the original unpacked B, C, T and read
    # self.dilation without ever using them (padding_total is precomputed).
    kernel_size = self.kernel_size
    stride = self.stride
    padding_total = self.padding_total

    # Ensure weight is on the same device as input
    if hasattr(self, "conv") and hasattr(self.conv, "conv") and hasattr(self.conv.conv, "weight"):
        if self.conv.conv.weight.device != x.device:
            self.conv.conv.to(x.device)

    # Compute extra padding for stride alignment
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)

    if debug:
        print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}, padding_total={padding_total}, extra_padding={extra_padding}")

    if self.causal:
        # Left padding for causal ('constant' mode needs an explicit fill value)
        if self.pad_mode == 'constant':
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode, value=0)
        else:
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
    else:
        # Symmetric padding for non-causal
        padding_right = padding_total // 2
        padding_left = padding_total - padding_right
        x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)

    if debug:
        print(f"[DEBUG NON-STREAMING] After padding: {x.shape}")

    output = self.conv(x)

    if debug:
        print(f"[DEBUG NON-STREAMING] Output shape: {output.shape}")

    return output
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
class SConvTranspose1d(nn.Module):
    """ConvTranspose1d with built-in handling of asymmetric or causal padding and normalization.

    Supports an optional streaming mode in which a rolling window of the last
    ``kernel_size - 1`` input samples is cached per sample index so that each
    chunk can be upsampled consistently with its predecessors.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, bias: bool = True):
        # NOTE(review): mutable default `norm_kwargs={}` is shared across calls —
        # harmless as long as it is never mutated, but worth confirming.
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs, bias=bias)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.

        # Store configuration
        self.kernel_size = kernel_size
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

        # For transposed convolution, padding calculation is different:
        # the transposed conv produces kernel_size - stride excess samples to trim.
        self.padding_total = kernel_size - stride

        # For streaming, we need to keep track of input history
        # Transposed conv needs to see multiple input samples to produce correct output
        self.context_size = kernel_size - 1

        # Create a unique layer ID for cache management (built lazily from id(self))
        self._layer_id = None

    @property
    def layer_id(self):
        # Lazily-built unique cache key for this layer instance.
        if self._layer_id is None:
            self._layer_id = f"sconvtr1d_{id(self)}"
        return self._layer_id

    def forward(self, x: torch.Tensor,
                cache: Optional[KugelAudioTokenizerStreamingCache] = None,
                sample_indices: Optional[torch.Tensor] = None,
                use_cache: bool = False,
                debug: bool = False) -> torch.Tensor:
        """
        Forward pass with optional streaming support via cache.

        Args:
            x: Input tensor [batch_size, channels, time].
            cache: Streaming cache; when provided together with use_cache=True,
                the streaming path is taken.
            sample_indices: Per-sample identifiers for cache management.
            use_cache: Whether to use cached states for streaming.
            debug: Whether to print debug information.

        Returns:
            Upsampled output tensor.
        """
        B, C, T = x.shape

        # Non-streaming mode
        if not use_cache or cache is None:
            return self._forward_non_streaming(x, debug=debug)

        # Streaming mode
        assert sample_indices is not None, "sample_indices must be provided for streaming mode"
        assert len(sample_indices) == B, "sample_indices must match batch size"

        return self._forward_streaming(x, cache, sample_indices, debug)

    @torch._dynamo.disable()  # Disable compilation for streaming path - dynamic cache ops cause recompilations
    def _forward_streaming(self, x: torch.Tensor,
                           cache: KugelAudioTokenizerStreamingCache,
                           sample_indices: torch.Tensor,
                           debug: bool = False) -> torch.Tensor:
        """Streaming forward pass with cache operations kept separate from compiled code.

        Re-runs the transposed conv over the cached input history plus the new
        chunk, trims padding, and keeps only the output that corresponds to the
        new input (T * stride samples on non-first chunks).
        """
        B, C, T = x.shape

        # Cache operations (not compiled)
        cached_input = cache.get(self.layer_id, sample_indices)

        if cached_input is None:
            # First chunk - no history yet
            cached_input = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
            if debug:
                print(f"[DEBUG] Initialized empty cache for transposed conv")

        # Concatenate cached input with new input
        full_input = torch.cat([cached_input, x], dim=2)

        if debug:
            print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_input.shape}, Combined: {full_input.shape}")

        # First chunk or debug mode - use uncompiled version
        full_output = self.convtr(full_input)

        if debug:
            print(f"[DEBUG] Full transposed conv output shape: {full_output.shape}")

        # Calculate padding to remove: causal trims mostly on the right
        # (controlled by trim_right_ratio); non-causal trims symmetrically.
        if self.causal:
            padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
            padding_left = self.padding_total - padding_right
        else:
            padding_right = self.padding_total // 2
            padding_left = self.padding_total - padding_right

        # Remove padding
        if padding_left + padding_right > 0:
            full_output = unpad1d(full_output, (padding_left, padding_right))

        if debug:
            print(f"[DEBUG] After unpadding: {full_output.shape}")

        # Determine which part of the output corresponds to the new input
        if cached_input.shape[2] == 0:
            # First chunk - return all output
            output = full_output
        else:
            # Subsequent chunks - return only the new output
            expected_new_output = T * self.stride

            # Take the last expected_new_output samples
            if full_output.shape[2] >= expected_new_output:
                output = full_output[:, :, -expected_new_output:]
            else:
                output = full_output

        if debug:
            print(f"[DEBUG] Final streaming output shape: {output.shape}")

        # Update cache: keep the last context_size input samples as history.
        if full_input.shape[2] > self.context_size:
            new_cache = full_input[:, :, -self.context_size:]
        else:
            new_cache = full_input

        if debug:
            print(f"[DEBUG] New cache shape: {new_cache.shape}")

        cache.set(self.layer_id, sample_indices, new_cache)

        return output

    def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
        """Standard forward pass without streaming: transposed conv, then trim padding."""
        # Ensure weight is on the same device as input
        if hasattr(self, "convtr") and hasattr(self.convtr, "convtr") and hasattr(self.convtr.convtr, "weight"):
            if self.convtr.convtr.weight.device != x.device:
                self.convtr.convtr.to(x.device)

        if debug:
            print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}")

        # Apply transposed convolution
        y = self.convtr(x)

        if debug:
            print(f"[DEBUG NON-STREAMING] After transposed conv: {y.shape}")

        # Calculate and remove padding (same split as the streaming path)
        if self.causal:
            padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
            padding_left = self.padding_total - padding_right
        else:
            padding_right = self.padding_total // 2
            padding_left = self.padding_total - padding_right

        if padding_left + padding_right > 0:
            y = unpad1d(y, (padding_left, padding_right))

        if debug:
            print(f"[DEBUG NON-STREAMING] Final output shape: {y.shape}")

        return y
|
| 579 |
+
|
| 580 |
+
# FFN
|
| 581 |
+
class FFN(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear."""

    def __init__(
        self,
        embed_dim,
        ffn_dim,
        bias=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # Expand to the hidden FFN width, then project back to the embedding width.
        self.linear1 = nn.Linear(embed_dim, ffn_dim, bias=bias)
        self.gelu = ACT2FN["gelu"]
        self.linear2 = nn.Linear(ffn_dim, embed_dim, bias=bias)

    def forward(self, x):
        """Apply the two-layer GELU MLP along the last dimension of `x`."""
        return self.linear2(self.gelu(self.linear1(x)))
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
class Convlayer(nn.Module):
    """Thin wrapper around SConv1d so it can be used as a mixer inside Block1D."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
        pad_mode='zeros',
        norm='weight_norm',
        causal=True,
    ):
        super().__init__()
        # Delegate all convolution/padding behavior to the streaming-aware SConv1d.
        self.conv = SConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            pad_mode=pad_mode,
            norm=norm,
            causal=causal,
        )

    def forward(self, x):
        """Apply the wrapped convolution."""
        return self.conv(x)
|
| 621 |
+
|
| 622 |
+
class Block1D(nn.Module):
    """ConvNeXt-style residual block over [B, C, T] tensors.

    Structure: norm -> conv mixer -> (layer scale) -> residual add,
    then norm -> FFN (channels-last) -> (layer scale) -> residual add.
    """
    def __init__(self, dim, kernel_size=7, drop_path=0., mixer_layer='conv',
                 layer_scale_init_value=1e-6, **kwargs):
        super().__init__()

        # NOTE(review): if kwargs['layernorm'] is neither 'LN' nor 'RMSNorm',
        # neither branch runs and self.norm/self.ffn_norm are never set
        # (forward would raise AttributeError) — confirm callers always pass
        # one of the two values.
        if kwargs.get('layernorm', 'LN') == 'LN':
            self.norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
            self.ffn_norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
        elif kwargs.get('layernorm', 'RMSNorm') == 'RMSNorm':
            self.norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))
            self.ffn_norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))

        # Token mixer: either a regular conv or a depthwise conv (groups=dim).
        if mixer_layer == 'conv':
            self.mixer = Convlayer(dim, dim, groups=kwargs.get('groups', 1),
                                   kernel_size=kernel_size,
                                   pad_mode=kwargs.get('pad_mode', 'reflect'),
                                   norm=kwargs.get('norm', 'none'),
                                   causal=kwargs.get('causal', True),
                                   bias=kwargs.get('bias', True),
                                   )
        elif mixer_layer == 'depthwise_conv':
            self.mixer = Convlayer(dim, dim, groups=dim,
                                   kernel_size=kernel_size,
                                   pad_mode=kwargs.get('pad_mode', 'reflect'),
                                   norm=kwargs.get('norm', 'none'),
                                   causal=kwargs.get('causal', True),
                                   bias=kwargs.get('bias', True),
                                   )
        else:
            raise ValueError(f"Unsupported mixer layer: {mixer_layer}")

        # Channel MLP with configurable expansion factor (default 4x).
        self.ffn = FFN(
            dim,
            kwargs.get('ffn_expansion', 4) * dim,
            bias=kwargs.get('bias', False),
        )
        # NOTE(review): torch.nn.modules has no DropPath attribute — with
        # drop_path > 0 this line would raise AttributeError. Presumably the
        # class was copied from code using timm's DropPath; confirm drop_path
        # is always 0 in practice or restore the intended import.
        self.drop_path = nn.Identity() if drop_path <= 0. else nn.modules.DropPath(drop_path)

        # Optional per-channel layer-scale parameters (disabled when init value <= 0).
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.ffn_gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma = None
            self.ffn_gamma = None

    def forward(self, x):
        """Apply mixer and FFN sub-blocks, each with a (scaled) residual connection."""
        # mixer
        residual = x
        x = self.norm(x)
        x = self.mixer(x)
        if self.gamma is not None:
            x = x * self.gamma.unsqueeze(-1)
        x = residual + self.drop_path(x)

        # ffn: FFN operates channels-last, so permute [B, C, T] <-> [B, T, C] around it.
        residual = x
        x = self.ffn_norm(x)
        x = x.permute(0, 2, 1)
        x = self.ffn(x)
        x = x.permute(0, 2, 1)
        if self.ffn_gamma is not None:
            x = x * self.ffn_gamma.unsqueeze(-1)
        x = residual + self.drop_path(x)

        return x
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
class TokenizerEncoder(nn.Module):
    """
    Encoder component for the KugelAudio tokenizer that converts audio to latent representations.

    Pipeline: stem conv -> alternating (downsample conv, Block1D stage) per ratio
    -> final norm -> head conv projecting to `dimension` latent channels.

    Args:
        config: Configuration object with model parameters
    """
    def __init__(self, config):
        super().__init__()

        # Extract parameters from config
        self.channels = config.channels
        self.dimension = config.dimension
        self.n_filters = config.n_filters
        # Ratios are stored reversed so iteration order matches the downsampling order.
        self.ratios = list(reversed(config.ratios))
        self.depths = config.depths
        self.n_residual_layers = getattr(config, "n_residual_layers", 1)
        self.hop_length = np.prod(self.ratios)
        self.causal = config.causal

        # Additional config parameters with defaults
        kernel_size = getattr(config, "kernel_size", 7)
        last_kernel_size = getattr(config, "last_kernel_size", 7)
        norm = getattr(config, "norm", "none")
        norm_params = getattr(config, "norm_params", {})
        pad_mode = getattr(config, "pad_mode", "reflect")
        bias = getattr(config, "bias", True)
        layernorm = getattr(config, "layernorm", "LN")
        layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
        layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
        drop_path_rate = getattr(config, "drop_path_rate", 0.0)
        mixer_layer = getattr(config, "mixer_layer", "conv")
        layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
        disable_last_norm = getattr(config, "disable_last_norm", False)

        # determine the norm type based on layernorm
        if layernorm == 'LN':
            norm_type = ConvLayerNorm
        elif layernorm == 'RMSNorm':
            norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
        else:
            raise ValueError(f"Unsupported norm type: {layernorm}")

        # stem and intermediate downsampling conv layers
        stem = nn.Sequential(
            SConv1d(self.channels, self.n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
        )

        self.downsample_layers = nn.ModuleList()
        self.downsample_layers.append(stem)
        # Each downsample doubles the channel count and strides by the stage ratio.
        for i in range(len(self.ratios)):
            in_ch = self.n_filters * (2 ** i)
            out_ch = self.n_filters * (2 ** (i + 1))
            downsample_layer = nn.Sequential(
                SConv1d(in_ch, out_ch, kernel_size=self.ratios[i] * 2, stride=self.ratios[i], causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
            )
            self.downsample_layers.append(downsample_layer)

        # configure the transformer blocks (Block1D factory with shared settings)
        layer_type = partial(
            Block1D,
            mixer_layer=mixer_layer,
            layernorm=layernorm,
            eps=layernorm_eps,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
            layer_scale_init_value=layer_scale_init_value,
        )

        self.stages = nn.ModuleList()
        # Linearly increasing stochastic-depth rate across all blocks.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0

        for i in range(len(self.depths)):
            in_ch = self.n_filters * (2 ** i)
            stage = nn.Sequential(
                *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
            )
            self.stages.append(stage)
            cur += self.depths[i]

        if not disable_last_norm:
            self.norm = norm_type(in_ch, eps=layernorm_eps)
        else:
            self.norm = nn.Identity()
        # Project to the latent dimension.
        self.head = SConv1d(in_ch, self.dimension, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)

    def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Run downsampling + stages, threading the streaming cache through every SConv1d."""
        for i in range(len(self.depths)):
            # Apply downsampling
            for layer in self.downsample_layers[i]:
                if isinstance(layer, SConv1d):
                    x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                else:
                    x = layer(x)

            # Apply stage (Block1D contains Convlayer which contains SConv1d)
            for block in self.stages[i]:
                if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
                    # Block1D forward with cache support
                    # (re-implements Block1D.forward so the cache kwargs reach the
                    # inner SConv1d; note drop_path is not applied on this path)
                    residual = x
                    x = block.norm(x)
                    x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                    if block.gamma is not None:
                        x = x * block.gamma.unsqueeze(-1)
                    x = residual + x

                    # FFN part (channels-last, hence the permutes)
                    residual = x
                    x = block.ffn_norm(x)
                    x = x.permute(0, 2, 1)
                    x = block.ffn(x)
                    x = x.permute(0, 2, 1)
                    if block.ffn_gamma is not None:
                        x = x * block.ffn_gamma.unsqueeze(-1)
                    x = residual + x
                else:
                    x = block(x)

        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Encode audio [B, channels, T] into latents [B, dimension, T / hop_length]."""
        x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return x
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
class TokenizerDecoder(nn.Module):
    """
    Decoder component for the KugelAudio tokenizer that converts latent representations back to audio.

    Mirror of TokenizerEncoder: stem conv -> alternating (transposed-conv
    upsample, Block1D stage) per ratio -> final norm -> head conv back to
    `channels` waveform channels.

    Args:
        config: Configuration object with model parameters
    """
    def __init__(self, config):
        super().__init__()

        # Extract parameters from config
        self.dimension = config.dimension
        self.channels = config.channels
        self.n_filters = config.n_filters
        self.ratios = config.ratios

        # IMPORTANT CHANGE: Don't reverse depths again since they're already reversed in KugelAudioAcousticTokenizerModel
        self.depths = config.depths  # Changed from list(reversed(config.depths))

        self.n_residual_layers = getattr(config, "n_residual_layers", 1)
        self.hop_length = np.prod(self.ratios)
        self.causal = config.causal

        # Additional config parameters with defaults
        kernel_size = getattr(config, "kernel_size", 7)
        last_kernel_size = getattr(config, "last_kernel_size", 7)
        norm = getattr(config, "norm", "none")
        norm_params = getattr(config, "norm_params", {})
        pad_mode = getattr(config, "pad_mode", "reflect")
        bias = getattr(config, "bias", True)
        layernorm = getattr(config, "layernorm", "LN")
        layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
        trim_right_ratio = getattr(config, "trim_right_ratio", 1.0)
        layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
        drop_path_rate = getattr(config, "drop_path_rate", 0.0)
        mixer_layer = getattr(config, "mixer_layer", "conv")
        layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
        disable_last_norm = getattr(config, "disable_last_norm", False)

        # determine the norm type based on layernorm
        if layernorm == 'LN':
            norm_type = ConvLayerNorm
        elif layernorm == 'RMSNorm':
            norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
        else:
            raise ValueError(f"Unsupported norm type: {layernorm}")

        # stem and upsampling layers: stem maps latents to the widest channel count.
        stem = nn.Sequential(
            SConv1d(self.dimension, self.n_filters * 2 ** (len(self.depths) - 1), kernel_size, norm=norm,
                    norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
        )

        self.upsample_layers = nn.ModuleList()
        self.upsample_layers.append(stem)
        # Each upsample halves the channel count and expands time by the stage ratio.
        for i in range(len(self.ratios)):
            in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
            out_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i - 1))
            upsample_layer = nn.Sequential(
                SConvTranspose1d(in_ch, out_ch,
                                 kernel_size=self.ratios[i] * 2, stride=self.ratios[i],
                                 norm=norm, norm_kwargs=norm_params, bias=bias,
                                 causal=self.causal, trim_right_ratio=trim_right_ratio),
            )
            self.upsample_layers.append(upsample_layer)

        # configure transformer blocks (Block1D factory with shared settings)
        layer_type = partial(
            Block1D,
            mixer_layer=mixer_layer,
            layernorm=layernorm,
            eps=layernorm_eps,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
            layer_scale_init_value=layer_scale_init_value,
        )

        self.stages = nn.ModuleList()
        # Linearly increasing stochastic-depth rate across all blocks.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0

        # Create stages in the same order as the original model
        for i in range(len(self.depths)):
            in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
            stage = nn.Sequential(
                *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
            )
            self.stages.append(stage)
            cur += self.depths[i]

        if not disable_last_norm:
            self.norm = norm_type(in_ch, eps=layernorm_eps)
        else:
            self.norm = nn.Identity()
        # Project back to the waveform channel count.
        self.head = SConv1d(in_ch, self.channels, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)

    def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Run upsampling + stages, threading the streaming cache through every conv layer."""
        for i in range(len(self.depths)):
            # Apply upsampling
            for layer in self.upsample_layers[i]:
                if isinstance(layer, (SConv1d, SConvTranspose1d)):
                    x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                else:
                    x = layer(x)

            # Apply stage (Block1D contains Convlayer which contains SConv1d)
            for block in self.stages[i]:
                if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
                    # Block1D forward with cache support
                    # (re-implements Block1D.forward so the cache kwargs reach the
                    # inner SConv1d; note drop_path is not applied on this path)
                    residual = x
                    x = block.norm(x)
                    x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                    if block.gamma is not None:
                        x = x * block.gamma.unsqueeze(-1)
                    x = residual + x

                    # FFN part (channels-last, hence the permutes)
                    residual = x
                    x = block.ffn_norm(x)
                    x = x.permute(0, 2, 1)
                    x = block.ffn(x)
                    x = x.permute(0, 2, 1)
                    if block.ffn_gamma is not None:
                        x = x * block.ffn_gamma.unsqueeze(-1)
                    x = residual + x
                else:
                    x = block(x)

        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Decode latents [B, dimension, T] into audio [B, channels, T * hop_length]."""
        x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return x
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
@dataclass
class KugelAudioTokenizerEncoderOutput:
    """
    Output of KugelAudio tokenizer encoder, representing a Gaussian distribution with fixed variance.

    Args:
        mean (`torch.FloatTensor`): The mean parameters of the distribution.
        std (`float` or `torch.FloatTensor`): Fixed standard deviation value.
    """
    mean: torch.Tensor
    std: Optional[Union[float, torch.Tensor]] = None

    def sample(self, dist_type='fix'):
        """
        Sample from the distribution.

        Args:
            dist_type (`str`): Sampling method, either 'fix' or 'gaussian'.
                Any other value returns the mean unchanged.

        Returns:
            `torch.FloatTensor`: Sampled values.
            `torch.FloatTensor` (optional): Standard deviation used (only when dist_type='gaussian').
        """
        if dist_type == 'fix':
            # Fixed scalar std applied uniformly to every element.
            noise = torch.randn_like(self.mean)
            return self.mean + self.std * noise, self.std

        if dist_type == 'gaussian':
            # Draw one std per batch item, itself scaled by std / 0.8.
            n = self.mean.size(0)
            scale = self.std / 0.8
            per_sample_std = torch.randn(n, device=self.mean.device, dtype=self.mean.dtype) * scale

            # Broadcast the per-sample std across the remaining dimensions.
            while per_sample_std.dim() < self.mean.dim():
                per_sample_std = per_sample_std.unsqueeze(-1)

            return self.mean + per_sample_std * torch.randn_like(self.mean), per_sample_std

        # Unknown dist_type: no noise is added.
        return self.mean, self.std

    def kl(self):
        """Compute KL divergence between this distribution and a standard normal."""
        return F.mse_loss(self.mean, torch.zeros_like(self.mean), reduction='none')

    def mode(self):
        """Return the distribution mode (which is the mean for Gaussian)."""
        return self.mean
|
| 1003 |
+
|
| 1004 |
+
class KugelAudioAcousticTokenizerModel(PreTrainedModel):
|
| 1005 |
+
"""KugelAudio speech tokenizer model combining encoder and decoder for acoustic tokens"""
|
| 1006 |
+
|
| 1007 |
+
config_class = KugelAudioAcousticTokenizerConfig
|
| 1008 |
+
base_model_prefix = "kugelaudio_acoustic_tokenizer"
|
| 1009 |
+
_supports_flash_attn_2 = True
|
| 1010 |
+
_supports_sdpa = True
|
| 1011 |
+
_no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"]
|
| 1012 |
+
|
| 1013 |
+
def __init__(self, config):
    """Build the acoustic tokenizer: a VAE-style encoder/decoder pair over audio.

    Args:
        config: KugelAudioAcousticTokenizerConfig with encoder/decoder widths,
            ratios, depths (possibly '-'-separated strings), and norm settings.
    """
    super().__init__(config)

    # Fixed std of the latent distribution; buffer so it follows .to(device).
    self.register_buffer('fix_std', torch.tensor(config.fix_std), persistent=False)
    self.std_dist_type = getattr(config, "std_dist_type", "fix")

    # Parse encoder depths ('a-b-c' string or a ready-made sequence)
    if isinstance(config.encoder_depths, str):
        encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
    else:
        encoder_depths = config.encoder_depths

    # Parse decoder depths if provided.
    # BUGFIX: a list-valued decoder_depths was previously discarded (only the
    # string form was accepted) and silently replaced by reversed encoder
    # depths; now any non-None value is honored, matching the encoder parsing.
    if config.decoder_depths is not None:
        if isinstance(config.decoder_depths, str):
            decoder_depths = [int(d) for d in config.decoder_depths.split('-')]
        else:
            decoder_depths = config.decoder_depths
    else:
        # Default: use reversed encoder depths if decoder_depths is None
        decoder_depths = list(reversed(encoder_depths))

    def _component_config(n_filters, ratios, depths):
        # Clone the top-level config and specialize it for one component;
        # field names are remapped to what TokenizerEncoder/Decoder expect.
        cfg = copy.deepcopy(config)
        cfg.dimension = config.vae_dim
        cfg.n_filters = n_filters
        cfg.ratios = ratios
        cfg.depths = depths
        cfg.norm = config.conv_norm
        cfg.pad_mode = config.pad_mode
        cfg.bias = config.conv_bias
        cfg.layernorm_eps = config.layernorm_eps
        cfg.layernorm_elementwise_affine = config.layernorm_elementwise_affine
        cfg.mixer_layer = config.mixer_layer
        cfg.layer_scale_init_value = config.layer_scale_init_value
        cfg.disable_last_norm = config.disable_last_norm
        return cfg

    # Create encoder and decoder configs (previously duplicated inline)
    encoder_config = _component_config(config.encoder_n_filters, config.encoder_ratios, encoder_depths)
    decoder_config = _component_config(config.decoder_n_filters, config.decoder_ratios, decoder_depths)

    # Initialize encoder and decoder
    self.encoder = TokenizerEncoder(encoder_config)
    self.decoder = TokenizerDecoder(decoder_config)

    # Initialize weights
    self.apply(self._init_weights)
|
| 1068 |
+
|
| 1069 |
+
def _init_weights(self, module):
|
| 1070 |
+
"""Initialize weights for the model"""
|
| 1071 |
+
if isinstance(module, nn.Linear):
|
| 1072 |
+
nn.init.normal_(module.weight, std=self.config.weight_init_value)
|
| 1073 |
+
if module.bias is not None:
|
| 1074 |
+
nn.init.zeros_(module.bias)
|
| 1075 |
+
elif isinstance(module, nn.LayerNorm):
|
| 1076 |
+
nn.init.ones_(module.weight)
|
| 1077 |
+
nn.init.zeros_(module.bias)
|
| 1078 |
+
elif isinstance(module, nn.Conv1d):
|
| 1079 |
+
nn.init.normal_(module.weight, std=self.config.weight_init_value)
|
| 1080 |
+
if module.bias is not None:
|
| 1081 |
+
nn.init.zeros_(module.bias)
|
| 1082 |
+
|
| 1083 |
+
@torch.no_grad()
|
| 1084 |
+
def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
|
| 1085 |
+
"""Convert audio to latent representations"""
|
| 1086 |
+
latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
| 1087 |
+
return KugelAudioTokenizerEncoderOutput(mean=latents.permute(0, 2, 1), std=self.fix_std)
|
| 1088 |
+
|
| 1089 |
+
@torch.no_grad()
|
| 1090 |
+
def sampling(self, encoder_output, dist_type=None):
|
| 1091 |
+
"""Sample from the encoder output distribution"""
|
| 1092 |
+
dist_type = dist_type or self.std_dist_type
|
| 1093 |
+
|
| 1094 |
+
if dist_type == 'fix':
|
| 1095 |
+
return encoder_output.sample(dist_type='fix')
|
| 1096 |
+
elif dist_type == 'gaussian':
|
| 1097 |
+
return encoder_output.sample(dist_type='gaussian')
|
| 1098 |
+
else:
|
| 1099 |
+
raise ValueError(f"Unsupported dist_type: {dist_type}, expected 'fix' or 'gaussian'")
|
| 1100 |
+
|
| 1101 |
+
@torch.no_grad()
|
| 1102 |
+
def decode(self, latents, cache=None, sample_indices=None, use_cache=False, debug=False):
|
| 1103 |
+
"""Convert latent representations back to audio"""
|
| 1104 |
+
if latents.shape[1] == self.config.vae_dim:
|
| 1105 |
+
pass
|
| 1106 |
+
else:
|
| 1107 |
+
latents = latents.permute(0, 2, 1)
|
| 1108 |
+
|
| 1109 |
+
audio = self.decoder(latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
| 1110 |
+
return audio
|
| 1111 |
+
|
| 1112 |
+
def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
|
| 1113 |
+
"""Full forward pass: encode audio to latents, then decode back to audio"""
|
| 1114 |
+
encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
| 1115 |
+
sampled_latents, _ = self.sampling(encoder_output)
|
| 1116 |
+
reconstructed = self.decode(sampled_latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
| 1117 |
+
return reconstructed, sampled_latents
|
| 1118 |
+
|
| 1119 |
+
|
| 1120 |
+
class KugelAudioSemanticTokenizerModel(PreTrainedModel):
    """KugelAudio speech tokenizer with only an encoder, for semantic tokens.

    Unlike the acoustic tokenizer, this model has no decoder: it maps raw
    audio to deterministic latent features and cannot reconstruct audio.
    """

    config_class = KugelAudioSemanticTokenizerConfig
    base_model_prefix = "kugelaudio_semantic_tokenizer"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    # Keep the encoder on a single device when the model is sharded.
    _no_split_modules = ["TokenizerEncoder"]

    def __init__(self, config):
        super().__init__(config)

        # Parse encoder depths ("a-b-c" string form or an explicit list of ints)
        if isinstance(config.encoder_depths, str):
            encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
        else:
            encoder_depths = config.encoder_depths

        # Create encoder config: a deep copy of the top-level config with the
        # encoder-specific fields mapped onto the generic names the
        # TokenizerEncoder expects (dimension, n_filters, ratios, ...).
        encoder_config = copy.deepcopy(config)
        encoder_config.dimension = config.vae_dim
        encoder_config.n_filters = config.encoder_n_filters
        encoder_config.ratios = config.encoder_ratios
        encoder_config.depths = encoder_depths
        encoder_config.norm = config.conv_norm
        encoder_config.pad_mode = config.pad_mode
        encoder_config.bias = config.conv_bias
        encoder_config.layernorm_eps = config.layernorm_eps
        encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
        encoder_config.mixer_layer = config.mixer_layer
        encoder_config.layer_scale_init_value = config.layer_scale_init_value
        encoder_config.disable_last_norm = config.disable_last_norm

        # Initialize encoder (this model intentionally has no decoder)
        self.encoder = TokenizerEncoder(encoder_config)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights for the model"""
        # Linear/Conv1d: normal-distributed weights, zero biases;
        # LayerNorm: identity affine (weight=1, bias=0).
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv1d):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @torch.no_grad()
    def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Convert audio to latent representations (mean only, no std)."""
        latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        # Encoder output is channel-first; expose it as (batch, time, channels).
        return KugelAudioTokenizerEncoderOutput(mean=latents.permute(0, 2, 1))

    @torch.no_grad()
    def sampling(self, encoder_output, dist_type=None):
        """Return deterministic latents from the encoder output.

        NOTE: ``dist_type`` is accepted for interface parity with the
        acoustic tokenizer but is ignored; semantic latents always use
        dist_type='none'.
        """
        return encoder_output.sample(dist_type='none')

    def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Encode audio to semantic latents.

        Returns:
            Tuple of (None, sampled_latents). The leading None keeps the
            return shape compatible with the acoustic tokenizer's
            (reconstructed_audio, latents) pair; no audio is reconstructed
            because this model has no decoder.
        """
        encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        sampled_latents, _ = self.sampling(encoder_output, dist_type='none')
        return None, sampled_latents
|
| 1189 |
+
|
| 1190 |
+
# Make both tokenizer variants loadable through transformers' AutoModel
# factory (AutoModel.from_pretrained resolves them via their config class).
AutoModel.register(KugelAudioAcousticTokenizerConfig, KugelAudioAcousticTokenizerModel)
AutoModel.register(KugelAudioSemanticTokenizerConfig, KugelAudioSemanticTokenizerModel)

# Public API of this module.
__all__ = [
    "KugelAudioTokenizerStreamingCache",
    "KugelAudioAcousticTokenizerModel",
    "KugelAudioSemanticTokenizerModel",
]
|
kugelaudio_open/processors/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Processors for KugelAudio text and audio handling."""
|
| 2 |
+
|
| 3 |
+
from kugelaudio_open.processors.audio_processor import AudioProcessor, AudioNormalizer
|
| 4 |
+
from kugelaudio_open.processors.kugelaudio_processor import KugelAudioProcessor
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"AudioProcessor",
|
| 8 |
+
"AudioNormalizer",
|
| 9 |
+
"KugelAudioProcessor",
|
| 10 |
+
]
|
kugelaudio_open/processors/audio_processor.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio processing utilities for KugelAudio."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from typing import Optional, Union, List, Dict, Any
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from transformers.feature_extraction_utils import FeatureExtractionMixin
|
| 10 |
+
from transformers.utils import logging
|
| 11 |
+
|
| 12 |
+
logger = logging.get_logger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class AudioNormalizer:
    """Normalize audio to a target dB FS level.

    Keeps input levels consistent for the model while protecting against
    clipping when the normalized signal would exceed full scale.
    """

    def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
        # eps guards divisions against silent (all-zero) input.
        self.target_dB_FS = target_dB_FS
        self.eps = eps

    def normalize_db(self, audio: np.ndarray) -> tuple:
        """Scale audio so its RMS level matches the target dB FS.

        Returns:
            Tuple of (scaled_audio, original_rms, applied_gain).
        """
        rms = np.sqrt(np.mean(audio**2))
        gain = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
        return audio * gain, rms, gain

    def avoid_clipping(self, audio: np.ndarray) -> tuple:
        """Rescale below full scale when the peak exceeds 1.0.

        Returns:
            Tuple of (safe_audio, divisor); divisor is 1.0 when no
            rescaling was needed.
        """
        peak = np.max(np.abs(audio))
        if peak > 1.0:
            divisor = peak + self.eps
            return audio / divisor, divisor
        return audio, 1.0

    def __call__(self, audio: np.ndarray) -> np.ndarray:
        """Apply dB-FS normalization followed by clipping protection."""
        leveled, _rms, _gain = self.normalize_db(audio)
        safe, _divisor = self.avoid_clipping(leveled)
        return safe
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class AudioProcessor(FeatureExtractionMixin):
    """Processor for audio preprocessing and postprocessing.

    Handles:
    - Audio format conversion (stereo to mono)
    - Normalization
    - Loading from various file formats
    - Saving to WAV files

    Example:
        >>> processor = AudioProcessor(sampling_rate=24000)
        >>> audio = processor("path/to/audio.wav")
        >>> processor.save_audio(generated_audio, "output.wav")
    """

    model_input_names = ["input_features"]

    def __init__(
        self,
        sampling_rate: int = 24000,
        normalize_audio: bool = True,
        target_dB_FS: float = -25,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.sampling_rate = sampling_rate
        self.normalize_audio = normalize_audio
        # Fix: expose target_dB_FS/eps as real attributes, not only inside
        # feature_extractor_dict. Callers that do
        # getattr(audio_processor, "target_dB_FS", -25) (e.g. the combined
        # KugelAudioProcessor when saving) would otherwise silently get the
        # hard-coded default instead of the configured value.
        self.target_dB_FS = target_dB_FS
        self.eps = eps
        self.normalizer = AudioNormalizer(target_dB_FS, eps) if normalize_audio else None

        # Serializable snapshot of the constructor arguments.
        self.feature_extractor_dict = {
            "sampling_rate": sampling_rate,
            "normalize_audio": normalize_audio,
            "target_dB_FS": target_dB_FS,
            "eps": eps,
        }

    def _ensure_mono(self, audio: np.ndarray) -> np.ndarray:
        """Convert stereo to mono if needed.

        Accepts 1D mono audio, or 2D audio with the channel axis on either
        dimension (size 1 or 2); stereo channels are averaged.
        """
        if len(audio.shape) == 1:
            return audio
        elif len(audio.shape) == 2:
            # Heuristic: whichever axis has size 2 is treated as channels.
            if audio.shape[0] == 2:
                return np.mean(audio, axis=0)
            elif audio.shape[1] == 2:
                return np.mean(audio, axis=1)
            elif audio.shape[0] == 1:
                return audio.squeeze(0)
            elif audio.shape[1] == 1:
                return audio.squeeze(1)
            else:
                raise ValueError(f"Unexpected audio shape: {audio.shape}")
        else:
            raise ValueError(f"Audio should be 1D or 2D, got shape: {audio.shape}")

    def _process_single(self, audio: Union[np.ndarray, List[float]]) -> np.ndarray:
        """Process a single audio array: cast to float32, mono, normalize."""
        if not isinstance(audio, np.ndarray):
            audio = np.array(audio, dtype=np.float32)
        else:
            audio = audio.astype(np.float32)

        audio = self._ensure_mono(audio)

        if self.normalize_audio and self.normalizer:
            audio = self.normalizer(audio)

        return audio

    def _load_from_path(self, audio_path: str) -> np.ndarray:
        """Load audio from a file path (.wav/.mp3/.flac/.m4a/.ogg/.pt/.npy).

        Raises:
            ValueError: For unsupported file extensions.
        """
        ext = os.path.splitext(audio_path)[1].lower()

        if ext in [".wav", ".mp3", ".flac", ".m4a", ".ogg"]:
            # Lazy import: librosa is only needed when decoding audio files.
            import librosa
            audio, _ = librosa.load(audio_path, sr=self.sampling_rate, mono=True)
            return audio
        elif ext == ".pt":
            # weights_only=True avoids arbitrary-code execution from pickles.
            tensor = torch.load(audio_path, map_location="cpu", weights_only=True).squeeze()
            return tensor.numpy().astype(np.float32)
        elif ext == ".npy":
            return np.load(audio_path).astype(np.float32)
        else:
            raise ValueError(f"Unsupported format: {ext}")

    def __call__(
        self,
        audio: Union[str, np.ndarray, List[float], List[np.ndarray], List[str]] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Process audio input(s).

        Args:
            audio: Audio input - path, array, or list of either
            sampling_rate: Input sampling rate (for validation)
            return_tensors: Return format ("pt" for PyTorch, "np" for NumPy)

        Returns:
            Dictionary with processed audio under the "audio" key

        Raises:
            ValueError: If no audio input is given.
        """
        if audio is None:
            raise ValueError("Audio input is required")

        if sampling_rate is not None and sampling_rate != self.sampling_rate:
            # We do not resample here; warn so the caller can fix the input.
            logger.warning(
                f"Input sampling rate ({sampling_rate}) differs from expected ({self.sampling_rate}). "
                "Please resample your audio."
            )

        # Handle different input types
        if isinstance(audio, str):
            audio = self._load_from_path(audio)
            is_batched = False
        elif isinstance(audio, list):
            if all(isinstance(item, str) for item in audio):
                audio = [self._load_from_path(p) for p in audio]
                is_batched = True
            else:
                # A list of arrays/lists is a batch; a flat list of floats is one clip.
                is_batched = isinstance(audio[0], (np.ndarray, list))
        else:
            is_batched = False

        # Process
        if is_batched:
            processed = [self._process_single(a) for a in audio]
        else:
            processed = [self._process_single(audio)]

        # Convert to tensors, shaped (batch, 1, samples).
        # NOTE(review): stacking a batch assumes all clips have equal length;
        # ragged batches will raise from stack().
        if return_tensors == "pt":
            if len(processed) == 1:
                features = torch.from_numpy(processed[0]).unsqueeze(0).unsqueeze(1)
            else:
                features = torch.stack([torch.from_numpy(a) for a in processed]).unsqueeze(1)
        elif return_tensors == "np":
            if len(processed) == 1:
                features = processed[0][np.newaxis, np.newaxis, :]
            else:
                features = np.stack(processed)[:, np.newaxis, :]
        else:
            features = processed[0] if len(processed) == 1 else processed

        return {"audio": features}

    def save_audio(
        self,
        audio: Union[torch.Tensor, np.ndarray, List],
        output_path: str = "output.wav",
        sampling_rate: Optional[int] = None,
        normalize: bool = False,
        batch_prefix: str = "audio_",
    ) -> List[str]:
        """Save audio to WAV file(s).

        Args:
            audio: Audio data to save
            output_path: Output path (directory for batched audio)
            sampling_rate: Sampling rate (defaults to processor's rate)
            normalize: Whether to peak-normalize before saving
            batch_prefix: Prefix for batch files

        Returns:
            List of saved file paths
        """
        # Lazy import: soundfile is only needed when writing files.
        import soundfile as sf

        if sampling_rate is None:
            sampling_rate = self.sampling_rate

        # Convert to numpy
        if isinstance(audio, torch.Tensor):
            audio_np = audio.float().detach().cpu().numpy()
        elif isinstance(audio, list):
            if all(isinstance(a, torch.Tensor) for a in audio):
                audio_np = [a.float().detach().cpu().numpy() for a in audio]
            else:
                audio_np = audio
        else:
            audio_np = audio

        saved_paths = []

        if isinstance(audio_np, list):
            # List input: treat output_path as a directory, one file per item.
            os.makedirs(output_path, exist_ok=True)
            for i, item in enumerate(audio_np):
                item = self._prepare_for_save(item, normalize)
                path = os.path.join(output_path, f"{batch_prefix}{i}.wav")
                sf.write(path, item, sampling_rate)
                saved_paths.append(path)
        elif len(audio_np.shape) >= 3 and audio_np.shape[0] > 1:
            # Batched tensor (batch, ..., samples): one file per batch element.
            os.makedirs(output_path, exist_ok=True)
            for i in range(audio_np.shape[0]):
                item = audio_np[i].squeeze()
                item = self._prepare_for_save(item, normalize)
                path = os.path.join(output_path, f"{batch_prefix}{i}.wav")
                sf.write(path, item, sampling_rate)
                saved_paths.append(path)
        else:
            # Single clip: output_path is the target file itself.
            item = self._prepare_for_save(audio_np.squeeze(), normalize)
            sf.write(output_path, item, sampling_rate)
            saved_paths.append(output_path)

        return saved_paths

    def _prepare_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray:
        """Squeeze a leading channel dim and optionally peak-normalize."""
        if len(audio.shape) > 1 and audio.shape[0] == 1:
            audio = audio.squeeze(0)

        if normalize:
            max_val = np.abs(audio).max()
            if max_val > 0:
                audio = audio / max_val

        return audio

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization."""
        return self.feature_extractor_dict
|
kugelaudio_open/processors/kugelaudio_processor.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Main processor for KugelAudio combining text and audio processing."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
from typing import Any, Dict, List, Optional, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from kugelaudio_open.processors.audio_processor import AudioNormalizer, AudioProcessor
|
| 11 |
+
from transformers.tokenization_utils_base import (
|
| 12 |
+
BatchEncoding,
|
| 13 |
+
PaddingStrategy,
|
| 14 |
+
TruncationStrategy,
|
| 15 |
+
)
|
| 16 |
+
from transformers.utils import TensorType, cached_file, logging
|
| 17 |
+
|
| 18 |
+
logger = logging.get_logger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class KugelAudioProcessor:
|
| 22 |
+
"""Combined processor for KugelAudio text and audio.
|
| 23 |
+
|
| 24 |
+
Wraps a text tokenizer and audio processor into a single interface
|
| 25 |
+
for preparing inputs for KugelAudio models.
|
| 26 |
+
|
| 27 |
+
Example:
|
| 28 |
+
>>> processor = KugelAudioProcessor.from_pretrained("kugelaudio/kugelaudio-0-open")
|
| 29 |
+
>>> inputs = processor(text="Hello world", voice_prompt=voice_audio)
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
tokenizer=None,
|
| 35 |
+
audio_processor: Optional[AudioProcessor] = None,
|
| 36 |
+
speech_compression_ratio: int = 3200,
|
| 37 |
+
db_normalize: bool = True,
|
| 38 |
+
**kwargs,
|
| 39 |
+
):
|
| 40 |
+
self.tokenizer = tokenizer
|
| 41 |
+
self.audio_processor = audio_processor or AudioProcessor()
|
| 42 |
+
self.speech_compression_ratio = speech_compression_ratio
|
| 43 |
+
self.db_normalize = db_normalize
|
| 44 |
+
self.audio_normalizer = AudioNormalizer() if db_normalize else None
|
| 45 |
+
|
| 46 |
+
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load processor from pretrained model.

        Resolution order for ``preprocessor_config.json``: local directory,
        then the Hugging Face hub/cache, then built-in defaults.

        Args:
            pretrained_model_name_or_path: Model ID or local path

        Returns:
            KugelAudioProcessor instance
        """
        # Local import avoids a circular dependency at module load time.
        from kugelaudio_open.processors.text_tokenizer import KugelAudioTextTokenizer

        # Try to load config
        config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
        config = None

        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)
        else:
            try:
                config_file = cached_file(
                    pretrained_model_name_or_path, "preprocessor_config.json", **kwargs
                )
                with open(config_file, "r") as f:
                    config = json.load(f)
            except Exception as e:
                # Best-effort: a missing or unreachable config is not fatal.
                logger.warning(f"Could not load config: {e}. Using defaults.")
                config = {
                    "speech_compression_ratio": 3200,
                    "db_normalize": True,
                }

        # Extract parameters
        speech_compression_ratio = config.get("speech_compression_ratio", 3200)
        db_normalize = config.get("db_normalize", True)

        # Load tokenizer: the config value wins over the keyword argument,
        # falling back to the default Qwen checkpoint.
        lm_name = config.get("language_model_pretrained_name") or kwargs.pop(
            "language_model_pretrained_name", "Qwen/Qwen2.5-1.5B"
        )
        logger.info(f"Loading tokenizer from {lm_name}")
        tokenizer = KugelAudioTextTokenizer.from_pretrained(lm_name, **kwargs)

        # Load audio processor (defaults when the config has no section for it)
        if "audio_processor" in config:
            audio_config = config["audio_processor"]
            audio_processor = AudioProcessor(
                sampling_rate=audio_config.get("sampling_rate", 24000),
                normalize_audio=audio_config.get("normalize_audio", True),
                target_dB_FS=audio_config.get("target_dB_FS", -25),
            )
        else:
            audio_processor = AudioProcessor()

        return cls(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
            speech_compression_ratio=speech_compression_ratio,
            db_normalize=db_normalize,
        )
|
| 107 |
+
|
| 108 |
+
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
|
| 109 |
+
"""Save processor to directory."""
|
| 110 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 111 |
+
|
| 112 |
+
config = {
|
| 113 |
+
"processor_class": "KugelAudioProcessor",
|
| 114 |
+
"speech_compression_ratio": self.speech_compression_ratio,
|
| 115 |
+
"db_normalize": self.db_normalize,
|
| 116 |
+
"audio_processor": {
|
| 117 |
+
"feature_extractor_type": "AudioProcessor",
|
| 118 |
+
"sampling_rate": getattr(self.audio_processor, "sampling_rate", 24000),
|
| 119 |
+
"normalize_audio": getattr(self.audio_processor, "normalize_audio", True),
|
| 120 |
+
"target_dB_FS": getattr(self.audio_processor, "target_dB_FS", -25),
|
| 121 |
+
},
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
config_path = os.path.join(save_directory, "preprocessor_config.json")
|
| 125 |
+
with open(config_path, "w") as f:
|
| 126 |
+
json.dump(config, f, indent=2)
|
| 127 |
+
|
| 128 |
+
logger.info(f"Processor saved to {config_path}")
|
| 129 |
+
|
| 130 |
+
    def __call__(
        self,
        text: Optional[str] = None,
        voice_prompt: Optional[Union[np.ndarray, torch.Tensor, str]] = None,
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchEncoding:
        """Process text and optional voice prompt into model-ready token IDs.

        Builds the prompt template: system prompt, optional "Voice input"
        section (with placeholder speech tokens), "Text input" section, and
        a "Speech output" section terminated by the speech-start token.

        Args:
            text: Input text to synthesize
            voice_prompt: Voice prompt audio for speaker identity (raw audio tensor or path)
            padding: Padding strategy
            truncation: Truncation strategy
            max_length: Maximum sequence length
            return_tensors: Return format ("pt" for tensors, else Python lists)

        Returns:
            BatchEncoding with processed inputs including speech_input_mask for voice cloning

        NOTE(review): ``padding``, ``truncation`` and ``max_length`` are
        accepted for interface compatibility but are not referenced anywhere
        in this method body.
        """
        if text is None:
            raise ValueError("Text input is required")

        # Special token IDs
        speech_start_id = 151652  # <|vision_start|> repurposed for speech
        speech_diffusion_id = 151654  # VAE token used as placeholder

        # Format text with proper template
        # Add speaker prefix if not present (use Speaker 0 to match training format)
        formatted_text = text.strip()
        if not formatted_text.startswith("Speaker"):
            formatted_text = f"Speaker 0: {formatted_text}"

        # Build the full prompt template matching the training format
        system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"

        # Start building tokens and speech_input_mask.
        # speech_input_mask is True at positions whose embeddings will be
        # replaced by speech (voice prompt) embeddings downstream.
        full_tokens = []
        speech_input_mask = []
        voice_audio = None

        # System prompt tokens
        system_tokens = self.tokenizer.encode(system_prompt, add_special_tokens=False)
        full_tokens.extend(system_tokens)
        speech_input_mask.extend([False] * len(system_tokens))

        # Process voice prompt if provided
        if voice_prompt is not None:
            # Load audio if it's a path; arrays/tensors are used as-is
            # (normalization is applied only to path inputs here).
            if isinstance(voice_prompt, str):
                voice_audio = self.audio_processor._load_from_path(voice_prompt)
                if self.db_normalize and self.audio_normalizer:
                    voice_audio = self.audio_normalizer(voice_audio)
            elif isinstance(voice_prompt, np.ndarray):
                voice_audio = voice_prompt.astype(np.float32)
            elif isinstance(voice_prompt, torch.Tensor):
                voice_audio = voice_prompt.cpu().numpy()
                if voice_audio.ndim > 1:
                    voice_audio = voice_audio.squeeze()
                voice_audio = voice_audio.astype(np.float32)

            # Voice input section with placeholder tokens
            voice_input_tokens = self.tokenizer.encode(" Voice input:\n", add_special_tokens=False)
            full_tokens.extend(voice_input_tokens)
            speech_input_mask.extend([False] * len(voice_input_tokens))

            # Speaker prefix for voice
            speaker_prefix = self.tokenizer.encode(" Speaker 0:", add_special_tokens=False)
            full_tokens.extend(speaker_prefix)
            speech_input_mask.extend([False] * len(speaker_prefix))

            # Calculate number of VAE tokens needed based on audio length
            # compression ratio is typically 3200 samples per token at 24kHz
            num_voice_tokens = math.ceil(len(voice_audio) / self.speech_compression_ratio)

            # Add placeholder VAE tokens that will be replaced with speech embeddings
            full_tokens.extend([speech_diffusion_id] * num_voice_tokens)
            speech_input_mask.extend([True] * num_voice_tokens)  # These positions get speech embeddings

            # Newline after voice
            newline_tokens = self.tokenizer.encode("\n", add_special_tokens=False)
            full_tokens.extend(newline_tokens)
            speech_input_mask.extend([False] * len(newline_tokens))

        # Text input section
        text_input_tokens = self.tokenizer.encode(" Text input:\n", add_special_tokens=False)
        full_tokens.extend(text_input_tokens)
        speech_input_mask.extend([False] * len(text_input_tokens))

        # Speaker text
        speaker_text_tokens = self.tokenizer.encode(f" {formatted_text}\n", add_special_tokens=False)
        full_tokens.extend(speaker_text_tokens)
        speech_input_mask.extend([False] * len(speaker_text_tokens))

        # Speech output section
        speech_output_tokens = self.tokenizer.encode(" Speech output:\n", add_special_tokens=False)
        full_tokens.extend(speech_output_tokens)
        speech_input_mask.extend([False] * len(speech_output_tokens))

        # Add speech_start token
        full_tokens.append(speech_start_id)
        speech_input_mask.append(False)

        result = BatchEncoding()
        result["text_ids"] = full_tokens
        result["speech_input_mask"] = speech_input_mask

        if return_tensors == "pt":
            # Batch dimension of 1 is added here.
            result["text_ids"] = torch.tensor([full_tokens], dtype=torch.long)
            result["speech_input_mask"] = torch.tensor([speech_input_mask], dtype=torch.bool)

        # Include processed voice audio for the model to encode
        if voice_audio is not None:
            if return_tensors == "pt":
                # Shape (1, 1, samples): batch and channel dims for the tokenizer encoder.
                result["speech_tensors"] = torch.tensor(voice_audio, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
                # Create speech_masks (all True for the voice frames)
                num_frames = math.ceil(len(voice_audio) / self.speech_compression_ratio)
                result["speech_masks"] = torch.ones(1, num_frames, dtype=torch.bool)
            else:
                result["speech_tensors"] = voice_audio
                num_frames = math.ceil(len(voice_audio) / self.speech_compression_ratio)
                result["speech_masks"] = [True] * num_frames

        return result
|
| 257 |
+
|
| 258 |
+
def process_with_cached_prompt(
    self,
    text: str,
    cached_prompt: Dict[str, Any],
    return_tensors: Optional[Union[str, TensorType]] = "pt",
    **kwargs,
) -> BatchEncoding:
    """Build generation inputs for ``text`` using a pre-computed voice-prompt cache.

    Args:
        text: Input text to synthesize.
        cached_prompt: Pre-computed KV cache from a voice prompt; must expose
            ``cached_prompt["lm"]["last_hidden_state"]`` and
            ``cached_prompt["tts_lm"]["last_hidden_state"]``.
        return_tensors: Return format ("pt" for PyTorch tensors, anything
            else for plain Python lists).

    Returns:
        BatchEncoding ready for generation.
    """
    script_tokens = self.tokenizer.encode(text.strip() + "\n", add_special_tokens=False)

    # The cached hidden states determine how many placeholder positions the
    # model expects for each stream.
    lm_length = cached_prompt["lm"]["last_hidden_state"].size(1)
    tts_lm_length = cached_prompt["tts_lm"]["last_hidden_state"].size(1)

    # Pseudo inputs: the real content lives in the cache, so the IDs are pads.
    pad = self.tokenizer.pad_id
    input_ids = [pad] * lm_length
    tts_lm_input_ids = [pad] * tts_lm_length
    speech_input_mask = [False] * tts_lm_length

    if return_tensors == "pt":
        fields = {
            "input_ids": torch.tensor([input_ids], dtype=torch.long),
            "tts_lm_input_ids": torch.tensor([tts_lm_input_ids], dtype=torch.long),
            "tts_text_ids": torch.tensor([script_tokens], dtype=torch.long),
            "attention_mask": torch.ones(1, lm_length, dtype=torch.long),
            "tts_lm_attention_mask": torch.ones(1, tts_lm_length, dtype=torch.long),
            "speech_input_mask": torch.tensor([speech_input_mask], dtype=torch.bool),
        }
    else:
        fields = {
            "input_ids": [input_ids],
            "tts_lm_input_ids": [tts_lm_input_ids],
            "tts_text_ids": [script_tokens],
            "attention_mask": [[1] * lm_length],
            "tts_lm_attention_mask": [[1] * tts_lm_length],
            "speech_input_mask": [speech_input_mask],
        }

    result = BatchEncoding()
    for key, value in fields.items():
        result[key] = value
    return result
|
| 303 |
+
|
| 304 |
+
def prepare_speech_inputs(
|
| 305 |
+
self,
|
| 306 |
+
speech_inputs: List[np.ndarray],
|
| 307 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 308 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 309 |
+
dtype: Optional[torch.dtype] = None,
|
| 310 |
+
) -> Dict[str, Any]:
|
| 311 |
+
"""Prepare speech inputs for model.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
speech_inputs: List of speech arrays
|
| 315 |
+
return_tensors: Return format
|
| 316 |
+
device: Device to place tensors
|
| 317 |
+
dtype: Data type for tensors
|
| 318 |
+
|
| 319 |
+
Returns:
|
| 320 |
+
Dictionary with padded speeches and masks
|
| 321 |
+
"""
|
| 322 |
+
if not speech_inputs:
|
| 323 |
+
return {"padded_speeches": None, "speech_masks": None}
|
| 324 |
+
|
| 325 |
+
# Calculate sequence lengths
|
| 326 |
+
seq_lens = [math.ceil(s.shape[0] / self.speech_compression_ratio) for s in speech_inputs]
|
| 327 |
+
max_speech_len = max(s.shape[0] for s in speech_inputs)
|
| 328 |
+
|
| 329 |
+
# Pad speeches
|
| 330 |
+
padded = np.zeros((len(speech_inputs), max_speech_len), dtype=np.float32)
|
| 331 |
+
masks = np.zeros((len(speech_inputs), max(seq_lens)), dtype=np.bool_)
|
| 332 |
+
|
| 333 |
+
for i, (speech, seq_len) in enumerate(zip(speech_inputs, seq_lens)):
|
| 334 |
+
padded[i, : len(speech)] = speech
|
| 335 |
+
masks[i, :seq_len] = True
|
| 336 |
+
|
| 337 |
+
result = {"padded_speeches": padded, "speech_masks": masks}
|
| 338 |
+
|
| 339 |
+
if return_tensors == "pt":
|
| 340 |
+
result["padded_speeches"] = torch.tensor(
|
| 341 |
+
padded, device=device, dtype=dtype or torch.float32
|
| 342 |
+
)
|
| 343 |
+
result["speech_masks"] = torch.tensor(masks, device=device, dtype=torch.bool)
|
| 344 |
+
|
| 345 |
+
return result
|
| 346 |
+
|
| 347 |
+
def batch_decode(self, *args, **kwargs):
    """Decode token IDs to text.

    Thin pass-through to ``self.tokenizer.batch_decode``; see the underlying
    tokenizer for the supported arguments and return type.
    """
    return self.tokenizer.batch_decode(*args, **kwargs)
|
| 350 |
+
|
| 351 |
+
def decode(self, *args, **kwargs):
    """Decode token IDs to text.

    Thin pass-through to ``self.tokenizer.decode``; see the underlying
    tokenizer for the supported arguments and return type.
    """
    return self.tokenizer.decode(*args, **kwargs)
|
| 354 |
+
|
| 355 |
+
def save_audio(self, audio, output_path: str = "output.wav", **kwargs) -> List[str]:
    """Save generated audio to file.

    Thin pass-through to ``self.audio_processor.save_audio``; extra keyword
    arguments are forwarded unchanged.
    """
    return self.audio_processor.save_audio(audio, output_path, **kwargs)
|
| 358 |
+
|
| 359 |
+
@property
def model_input_names(self) -> List[str]:
    """Return the combined, de-duplicated list of model input names.

    Merges the tokenizer's and audio processor's declared input names with
    the speech-specific inputs, preserving first-seen order.
    """
    names = list(getattr(self.tokenizer, "model_input_names", []))
    names += getattr(self.audio_processor, "model_input_names", [])
    names += ["speech_inputs", "speech_input_mask"]

    seen = set()
    ordered = []
    for name in names:
        if name not in seen:
            seen.add(name)
            ordered.append(name)
    return ordered
|
kugelaudio_open/processors/text_tokenizer.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Text tokenizer for KugelAudio based on Qwen2."""
|
| 2 |
+
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
|
| 5 |
+
from transformers.utils import logging
|
| 6 |
+
from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
|
| 7 |
+
|
| 8 |
+
logger = logging.get_logger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class KugelAudioTextTokenizer(Qwen2TokenizerFast):
    """Text tokenizer for KugelAudio with speech special tokens.

    Based on Qwen2 tokenizer with additional tokens for speech synthesis:
    - speech_start: Marks the beginning of speech generation
    - speech_end: Marks the end of speech generation
    - speech_diffusion: Placeholder for diffusion tokens

    Example:
        >>> tokenizer = KugelAudioTextTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")
        >>> tokens = tokenizer.encode("Hello world")
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        add_prefix_space=False,
        **kwargs,
    ):
        """Initialize the tokenizer and register KugelAudio speech tokens.

        All arguments are forwarded unchanged to ``Qwen2TokenizerFast``.
        """
        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )

        # Register speech markers and cache their IDs once at construction.
        self._add_speech_special_tokens()

    def _add_speech_special_tokens(self):
        """Add KugelAudio-specific special tokens for speech."""
        special_tokens = {
            "additional_special_tokens": [
                "<|vision_start|>",  # Speech start (reusing vision tokens for compatibility)
                "<|vision_end|>",  # Speech end
                "<|vision_pad|>",  # Speech diffusion pad
            ]
        }
        self.add_special_tokens(special_tokens)

        # Cache special token IDs
        self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
        self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
        self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
        self._eos_id = self.eos_token_id
        # NOTE(review): "<|image_pad|>" is NOT added above — this assumes it
        # already exists in the base Qwen2 vocabulary; otherwise the lookup
        # may resolve to None/unk. TODO confirm against the shipped vocab.
        self._pad_id = self.convert_tokens_to_ids("<|image_pad|>")

    @property
    def eos_id(self) -> int:
        """End of sequence token ID."""
        return self._eos_id

    @property
    def speech_start_id(self) -> int:
        """Speech start token ID."""
        return self._speech_start_id

    @property
    def speech_end_id(self) -> int:
        """Speech end token ID."""
        return self._speech_end_id

    @property
    def speech_diffusion_id(self) -> int:
        """Speech diffusion placeholder token ID."""
        return self._speech_diffusion_id

    @property
    def pad_id(self) -> int:
        """Padding token ID (the cached ``<|image_pad|>`` token ID)."""
        return self._pad_id
|
kugelaudio_open/schedule/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""KugelAudio scheduling components."""
|
| 2 |
+
|
| 3 |
+
from .dpm_solver import DPMSolverMultistepScheduler
|
| 4 |
+
|
| 5 |
+
__all__ = ["DPMSolverMultistepScheduler"]
|
kugelaudio_open/schedule/dpm_solver.py
ADDED
|
@@ -0,0 +1,1084 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 24 |
+
from diffusers.utils import deprecate
|
| 25 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 26 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
| 27 |
+
|
| 28 |
+
def betas_for_alpha_bar(
|
| 29 |
+
num_diffusion_timesteps,
|
| 30 |
+
max_beta=0.999,
|
| 31 |
+
alpha_transform_type="cosine",
|
| 32 |
+
):
|
| 33 |
+
"""
|
| 34 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
| 35 |
+
(1-beta) over time from t = [0,1].
|
| 36 |
+
|
| 37 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
| 38 |
+
to that part of the diffusion process.
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
| 43 |
+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
| 44 |
+
prevent singularities.
|
| 45 |
+
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
| 46 |
+
Choose from `cosine` or `exp`
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
| 50 |
+
"""
|
| 51 |
+
if alpha_transform_type == "cosine":
|
| 52 |
+
|
| 53 |
+
def alpha_bar_fn(t):
|
| 54 |
+
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
| 55 |
+
# return math.cos(t * math.pi / 2 * 0.95) ** 2
|
| 56 |
+
|
| 57 |
+
elif alpha_transform_type == "exp":
|
| 58 |
+
|
| 59 |
+
def alpha_bar_fn(t):
|
| 60 |
+
return math.exp(t * -12.0)
|
| 61 |
+
|
| 62 |
+
elif alpha_transform_type == "cauchy":
|
| 63 |
+
# Β΅ + Ξ³ tan (Ο (0.5 - x)) Ξ³ = 1, Β΅ = 3
|
| 64 |
+
# alpha^2 = 1-1/(exp(Ξ»)+1)
|
| 65 |
+
def alpha_bar_fn(t, gamma=1, mu=3):
|
| 66 |
+
snr = mu + gamma * math.tan(math.pi * (0.5 - t) * 0.9)
|
| 67 |
+
return 1 - 1 / (math.exp(snr) + 1.1)
|
| 68 |
+
|
| 69 |
+
elif alpha_transform_type == "laplace":
|
| 70 |
+
# Β΅ β bsgn(0.5 β t) log(1 β 2|t β 0.5|) Β΅ = 0, b = 1
|
| 71 |
+
def alpha_bar_fn(t, mu=0, b=1):
|
| 72 |
+
snr = mu - b * math.copysign(1, 0.5 - t) * math.log(1 - 2 * abs(t - 0.5) * 0.98)
|
| 73 |
+
return 1 - 1 / (math.exp(snr) + 1.02)
|
| 74 |
+
|
| 75 |
+
else:
|
| 76 |
+
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
|
| 77 |
+
|
| 78 |
+
betas = []
|
| 79 |
+
for i in range(num_diffusion_timesteps):
|
| 80 |
+
t1 = i / num_diffusion_timesteps
|
| 81 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
| 82 |
+
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
| 83 |
+
return torch.tensor(betas, dtype=torch.float32)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
| 87 |
+
def rescale_zero_terminal_snr(betas):
    """Rescale a beta schedule so the terminal timestep has zero SNR.

    Implements Algorithm 1 of "Common Diffusion Noise Schedules and Sample
    Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf): the
    sqrt(alpha_bar) curve is shifted so its last value is exactly 0, rescaled
    so its first value is unchanged, then converted back to betas.

    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR.
    """
    # Work in sqrt(alpha_bar) space.
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    first = sqrt_alpha_bar[0].clone()
    last = sqrt_alpha_bar[-1].clone()

    # Shift so the terminal value becomes zero, then rescale so the initial
    # value matches the original schedule.
    sqrt_alpha_bar = (sqrt_alpha_bar - last) * (first / (first - last))

    # Convert back: alpha_bar -> per-step alphas -> betas.
    alpha_bar = sqrt_alpha_bar**2
    step_alphas = alpha_bar[1:] / alpha_bar[:-1]
    step_alphas = torch.cat([alpha_bar[0:1], step_alphas])
    return 1 - step_alphas
|
| 121 |
+
|
| 122 |
+
class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
| 123 |
+
"""
|
| 124 |
+
`DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
|
| 125 |
+
|
| 126 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
| 127 |
+
methods the library implements for all schedulers such as loading and saving.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
num_train_timesteps (`int`, defaults to 1000):
|
| 131 |
+
The number of diffusion steps to train the model.
|
| 132 |
+
beta_start (`float`, defaults to 0.0001):
|
| 133 |
+
The starting `beta` value of inference.
|
| 134 |
+
beta_end (`float`, defaults to 0.02):
|
| 135 |
+
The final `beta` value.
|
| 136 |
+
beta_schedule (`str`, defaults to `"linear"`):
|
| 137 |
+
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
| 138 |
+
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
| 139 |
+
trained_betas (`np.ndarray`, *optional*):
|
| 140 |
+
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
| 141 |
+
solver_order (`int`, defaults to 2):
|
| 142 |
+
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
|
| 143 |
+
sampling, and `solver_order=3` for unconditional sampling.
|
| 144 |
+
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
| 145 |
+
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
| 146 |
+
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
| 147 |
+
Video](https://imagen.research.google/video/paper.pdf) paper).
|
| 148 |
+
thresholding (`bool`, defaults to `False`):
|
| 149 |
+
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
| 150 |
+
as Stable Diffusion.
|
| 151 |
+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
| 152 |
+
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
| 153 |
+
sample_max_value (`float`, defaults to 1.0):
|
| 154 |
+
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
|
| 155 |
+
`algorithm_type="dpmsolver++"`.
|
| 156 |
+
algorithm_type (`str`, defaults to `dpmsolver++`):
|
| 157 |
+
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
|
| 158 |
+
`dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
|
| 159 |
+
paper, and the `dpmsolver++` type implements the algorithms in the
|
| 160 |
+
[DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
|
| 161 |
+
`sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
|
| 162 |
+
solver_type (`str`, defaults to `midpoint`):
|
| 163 |
+
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
|
| 164 |
+
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
|
| 165 |
+
lower_order_final (`bool`, defaults to `True`):
|
| 166 |
+
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
| 167 |
+
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
| 168 |
+
euler_at_final (`bool`, defaults to `False`):
|
| 169 |
+
Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
|
| 170 |
+
richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
|
| 171 |
+
steps, but sometimes may result in blurring.
|
| 172 |
+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
| 173 |
+
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
| 174 |
+
the sigmas are determined according to a sequence of noise levels {Οi}.
|
| 175 |
+
use_lu_lambdas (`bool`, *optional*, defaults to `False`):
|
| 176 |
+
Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
|
| 177 |
+
the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
|
| 178 |
+
`lambda(t)`.
|
| 179 |
+
final_sigmas_type (`str`, defaults to `"zero"`):
|
| 180 |
+
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
| 181 |
+
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
| 182 |
+
lambda_min_clipped (`float`, defaults to `-inf`):
|
| 183 |
+
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
| 184 |
+
cosine (`squaredcos_cap_v2`) noise schedule.
|
| 185 |
+
variance_type (`str`, *optional*):
|
| 186 |
+
Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
|
| 187 |
+
contains the predicted Gaussian variance.
|
| 188 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
| 189 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
| 190 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
| 191 |
+
steps_offset (`int`, defaults to 0):
|
| 192 |
+
An offset added to the inference steps, as required by some model families.
|
| 193 |
+
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
| 194 |
+
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
| 195 |
+
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
| 196 |
+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
| 197 |
+
"""
|
| 198 |
+
|
| 199 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
| 200 |
+
order = 1
|
| 201 |
+
|
| 202 |
+
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    solver_order: int = 2,
    prediction_type: str = "epsilon",
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    algorithm_type: str = "dpmsolver++",
    solver_type: str = "midpoint",
    lower_order_final: bool = True,
    euler_at_final: bool = False,
    use_karras_sigmas: Optional[bool] = False,
    use_lu_lambdas: Optional[bool] = False,
    final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
    lambda_min_clipped: float = -float("inf"),
    variance_type: Optional[str] = None,
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
    rescale_betas_zero_snr: bool = False,
):
    # Legacy non-"++" algorithm types still work but emit a deprecation warning.
    if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
        deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
        deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)

    # Build the training beta schedule. Explicit `trained_betas` wins over
    # any named schedule; "cauchy" and "laplace" are KugelAudio additions on
    # top of the stock diffusers schedules.
    if trained_betas is not None:
        self.betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
    elif beta_schedule == "squaredcos_cap_v2" or beta_schedule == "cosine":
        # Glide cosine schedule
        self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
    elif beta_schedule == "cauchy":
        self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cauchy")
    elif beta_schedule == "laplace":
        self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="laplace")
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    # Optionally force zero terminal SNR (arXiv:2305.08891, Algorithm 1).
    if rescale_betas_zero_snr:
        self.betas = rescale_zero_terminal_snr(self.betas)

    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    if rescale_betas_zero_snr:
        # Close to 0 without being 0 so first sigma is not inf
        # FP16 smallest positive subnormal works well here
        self.alphas_cumprod[-1] = 2**-24

    # Currently we only support VP-type noise schedule
    self.alpha_t = torch.sqrt(self.alphas_cumprod)
    self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
    # lambda_t = log(alpha_t / sigma_t) is the half-logSNR used by DPM-Solver.
    self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
    self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = 1.0

    # settings for DPM-Solver
    # "deis" is silently mapped to "dpmsolver++"; anything else unknown raises.
    if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
        if algorithm_type == "deis":
            self.register_to_config(algorithm_type="dpmsolver++")
        else:
            raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")

    # Unknown solver types from other schedulers fall back to "midpoint".
    if solver_type not in ["midpoint", "heun"]:
        if solver_type in ["logrho", "bh1", "bh2"]:
            self.register_to_config(solver_type="midpoint")
        else:
            raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

    # final_sigmas_type="zero" is only meaningful for the "++" data-prediction
    # variants; reject the combination early.
    if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
        raise ValueError(
            f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
        )

    # setable values
    # `num_inference_steps`/`timesteps` are overwritten by `set_timesteps`;
    # the multistep solver keeps the last `solver_order` model outputs.
    self.num_inference_steps = None
    timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
    self.timesteps = torch.from_numpy(timesteps)
    self.model_outputs = [None] * solver_order
    self.lower_order_nums = 0
    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
|
| 296 |
+
|
| 297 |
+
@property
|
| 298 |
+
def step_index(self):
|
| 299 |
+
"""
|
| 300 |
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
| 301 |
+
"""
|
| 302 |
+
return self._step_index
|
| 303 |
+
|
| 304 |
+
@property
|
| 305 |
+
def begin_index(self):
|
| 306 |
+
"""
|
| 307 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
| 308 |
+
"""
|
| 309 |
+
return self._begin_index
|
| 310 |
+
|
| 311 |
+
def set_begin_index(self, begin_index: int = 0):
|
| 312 |
+
"""
|
| 313 |
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
| 314 |
+
|
| 315 |
+
Args:
|
| 316 |
+
begin_index (`int`):
|
| 317 |
+
The begin index for the scheduler.
|
| 318 |
+
"""
|
| 319 |
+
self._begin_index = begin_index
|
| 320 |
+
|
| 321 |
+
def set_timesteps(
|
| 322 |
+
self,
|
| 323 |
+
num_inference_steps: int = None,
|
| 324 |
+
device: Union[str, torch.device] = None,
|
| 325 |
+
timesteps: Optional[List[int]] = None,
|
| 326 |
+
):
|
| 327 |
+
"""
|
| 328 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
| 329 |
+
|
| 330 |
+
Args:
|
| 331 |
+
num_inference_steps (`int`):
|
| 332 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
| 333 |
+
device (`str` or `torch.device`, *optional*):
|
| 334 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 335 |
+
timesteps (`List[int]`, *optional*):
|
| 336 |
+
Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
|
| 337 |
+
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
|
| 338 |
+
must be `None`, and `timestep_spacing` attribute will be ignored.
|
| 339 |
+
"""
|
| 340 |
+
if num_inference_steps is None and timesteps is None:
|
| 341 |
+
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
|
| 342 |
+
if num_inference_steps is not None and timesteps is not None:
|
| 343 |
+
raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
|
| 344 |
+
if timesteps is not None and self.config.use_karras_sigmas:
|
| 345 |
+
raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
|
| 346 |
+
if timesteps is not None and self.config.use_lu_lambdas:
|
| 347 |
+
raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
|
| 348 |
+
|
| 349 |
+
if timesteps is not None:
|
| 350 |
+
timesteps = np.array(timesteps).astype(np.int64)
|
| 351 |
+
else:
|
| 352 |
+
# Clipping the minimum of all lambda(t) for numerical stability.
|
| 353 |
+
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
|
| 354 |
+
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
|
| 355 |
+
last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
|
| 356 |
+
|
| 357 |
+
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
|
| 358 |
+
if self.config.timestep_spacing == "linspace":
|
| 359 |
+
timesteps = (
|
| 360 |
+
np.linspace(0, last_timestep - 1, num_inference_steps + 1)
|
| 361 |
+
.round()[::-1][:-1]
|
| 362 |
+
.copy()
|
| 363 |
+
.astype(np.int64)
|
| 364 |
+
)
|
| 365 |
+
elif self.config.timestep_spacing == "leading":
|
| 366 |
+
step_ratio = last_timestep // (num_inference_steps + 1)
|
| 367 |
+
# creates integer timesteps by multiplying by ratio
|
| 368 |
+
# casting to int to avoid issues when num_inference_step is power of 3
|
| 369 |
+
timesteps = (
|
| 370 |
+
(np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
|
| 371 |
+
)
|
| 372 |
+
timesteps += self.config.steps_offset
|
| 373 |
+
elif self.config.timestep_spacing == "trailing":
|
| 374 |
+
step_ratio = self.config.num_train_timesteps / num_inference_steps
|
| 375 |
+
# creates integer timesteps by multiplying by ratio
|
| 376 |
+
# casting to int to avoid issues when num_inference_step is power of 3
|
| 377 |
+
timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
|
| 378 |
+
timesteps -= 1
|
| 379 |
+
else:
|
| 380 |
+
raise ValueError(
|
| 381 |
+
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
| 385 |
+
log_sigmas = np.log(sigmas)
|
| 386 |
+
|
| 387 |
+
if self.config.use_karras_sigmas:
|
| 388 |
+
sigmas = np.flip(sigmas).copy()
|
| 389 |
+
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
| 390 |
+
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
| 391 |
+
elif self.config.use_lu_lambdas:
|
| 392 |
+
lambdas = np.flip(log_sigmas.copy())
|
| 393 |
+
lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
|
| 394 |
+
sigmas = np.exp(lambdas)
|
| 395 |
+
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
| 396 |
+
else:
|
| 397 |
+
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
| 398 |
+
|
| 399 |
+
if self.config.final_sigmas_type == "sigma_min":
|
| 400 |
+
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
| 401 |
+
elif self.config.final_sigmas_type == "zero":
|
| 402 |
+
sigma_last = 0
|
| 403 |
+
else:
|
| 404 |
+
raise ValueError(
|
| 405 |
+
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
| 409 |
+
|
| 410 |
+
self.sigmas = torch.from_numpy(sigmas)
|
| 411 |
+
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
|
| 412 |
+
|
| 413 |
+
self.num_inference_steps = len(timesteps)
|
| 414 |
+
|
| 415 |
+
self.model_outputs = [
|
| 416 |
+
None,
|
| 417 |
+
] * self.config.solver_order
|
| 418 |
+
self.lower_order_nums = 0
|
| 419 |
+
|
| 420 |
+
# add an index counter for schedulers that allow duplicated timesteps
|
| 421 |
+
self._step_index = None
|
| 422 |
+
self._begin_index = None
|
| 423 |
+
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
| 424 |
+
|
| 425 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
| 426 |
+
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
| 427 |
+
"""
|
| 428 |
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
| 429 |
+
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
| 430 |
+
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
| 431 |
+
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
| 432 |
+
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
| 433 |
+
|
| 434 |
+
https://arxiv.org/abs/2205.11487
|
| 435 |
+
"""
|
| 436 |
+
dtype = sample.dtype
|
| 437 |
+
batch_size, channels, *remaining_dims = sample.shape
|
| 438 |
+
|
| 439 |
+
if dtype not in (torch.float32, torch.float64):
|
| 440 |
+
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
| 441 |
+
|
| 442 |
+
# Flatten sample for doing quantile calculation along each image
|
| 443 |
+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
|
| 444 |
+
|
| 445 |
+
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
| 446 |
+
|
| 447 |
+
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
| 448 |
+
s = torch.clamp(
|
| 449 |
+
s, min=1, max=self.config.sample_max_value
|
| 450 |
+
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
| 451 |
+
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
| 452 |
+
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
| 453 |
+
|
| 454 |
+
sample = sample.reshape(batch_size, channels, *remaining_dims)
|
| 455 |
+
sample = sample.to(dtype)
|
| 456 |
+
|
| 457 |
+
return sample
|
| 458 |
+
|
| 459 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
| 460 |
+
def _sigma_to_t(self, sigma, log_sigmas):
|
| 461 |
+
# get log sigma
|
| 462 |
+
log_sigma = np.log(np.maximum(sigma, 1e-10))
|
| 463 |
+
|
| 464 |
+
# get distribution
|
| 465 |
+
dists = log_sigma - log_sigmas[:, np.newaxis]
|
| 466 |
+
|
| 467 |
+
# get sigmas range
|
| 468 |
+
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
|
| 469 |
+
high_idx = low_idx + 1
|
| 470 |
+
|
| 471 |
+
low = log_sigmas[low_idx]
|
| 472 |
+
high = log_sigmas[high_idx]
|
| 473 |
+
|
| 474 |
+
# interpolate sigmas
|
| 475 |
+
w = (low - log_sigma) / (low - high)
|
| 476 |
+
w = np.clip(w, 0, 1)
|
| 477 |
+
|
| 478 |
+
# transform interpolation to time range
|
| 479 |
+
t = (1 - w) * low_idx + w * high_idx
|
| 480 |
+
t = t.reshape(sigma.shape)
|
| 481 |
+
return t
|
| 482 |
+
|
| 483 |
+
def _sigma_to_alpha_sigma_t(self, sigma):
|
| 484 |
+
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
|
| 485 |
+
sigma_t = sigma * alpha_t
|
| 486 |
+
|
| 487 |
+
return alpha_t, sigma_t
|
| 488 |
+
|
| 489 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
| 490 |
+
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
| 491 |
+
"""Constructs the noise schedule of Karras et al. (2022)."""
|
| 492 |
+
|
| 493 |
+
# Hack to make sure that other schedulers which copy this function don't break
|
| 494 |
+
# TODO: Add this logic to the other schedulers
|
| 495 |
+
if hasattr(self.config, "sigma_min"):
|
| 496 |
+
sigma_min = self.config.sigma_min
|
| 497 |
+
else:
|
| 498 |
+
sigma_min = None
|
| 499 |
+
|
| 500 |
+
if hasattr(self.config, "sigma_max"):
|
| 501 |
+
sigma_max = self.config.sigma_max
|
| 502 |
+
else:
|
| 503 |
+
sigma_max = None
|
| 504 |
+
|
| 505 |
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
| 506 |
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
| 507 |
+
|
| 508 |
+
rho = 7.0 # 7.0 is the value used in the paper
|
| 509 |
+
ramp = np.linspace(0, 1, num_inference_steps)
|
| 510 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
| 511 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
| 512 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
| 513 |
+
return sigmas
|
| 514 |
+
|
| 515 |
+
def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
| 516 |
+
"""Constructs the noise schedule of Lu et al. (2022)."""
|
| 517 |
+
|
| 518 |
+
lambda_min: float = in_lambdas[-1].item()
|
| 519 |
+
lambda_max: float = in_lambdas[0].item()
|
| 520 |
+
|
| 521 |
+
rho = 1.0 # 1.0 is the value used in the paper
|
| 522 |
+
ramp = np.linspace(0, 1, num_inference_steps)
|
| 523 |
+
min_inv_rho = lambda_min ** (1 / rho)
|
| 524 |
+
max_inv_rho = lambda_max ** (1 / rho)
|
| 525 |
+
lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
| 526 |
+
return lambdas
|
| 527 |
+
|
| 528 |
+
def convert_model_output(
|
| 529 |
+
self,
|
| 530 |
+
model_output: torch.Tensor,
|
| 531 |
+
*args,
|
| 532 |
+
sample: torch.Tensor = None,
|
| 533 |
+
**kwargs,
|
| 534 |
+
) -> torch.Tensor:
|
| 535 |
+
"""
|
| 536 |
+
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
|
| 537 |
+
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
|
| 538 |
+
integral of the data prediction model.
|
| 539 |
+
|
| 540 |
+
<Tip>
|
| 541 |
+
|
| 542 |
+
The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
|
| 543 |
+
prediction and data prediction models.
|
| 544 |
+
|
| 545 |
+
</Tip>
|
| 546 |
+
|
| 547 |
+
Args:
|
| 548 |
+
model_output (`torch.Tensor`):
|
| 549 |
+
The direct output from the learned diffusion model.
|
| 550 |
+
sample (`torch.Tensor`):
|
| 551 |
+
A current instance of a sample created by the diffusion process.
|
| 552 |
+
|
| 553 |
+
Returns:
|
| 554 |
+
`torch.Tensor`:
|
| 555 |
+
The converted model output.
|
| 556 |
+
"""
|
| 557 |
+
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
| 558 |
+
if sample is None:
|
| 559 |
+
if len(args) > 1:
|
| 560 |
+
sample = args[1]
|
| 561 |
+
else:
|
| 562 |
+
raise ValueError("missing `sample` as a required keyward argument")
|
| 563 |
+
if timestep is not None:
|
| 564 |
+
deprecate(
|
| 565 |
+
"timesteps",
|
| 566 |
+
"1.0.0",
|
| 567 |
+
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
# Guard against out-of-bounds access (can occur in concurrent scenarios)
|
| 571 |
+
safe_step_index = min(self.step_index, len(self.sigmas) - 1)
|
| 572 |
+
|
| 573 |
+
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
| 574 |
+
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
|
| 575 |
+
if self.config.prediction_type == "epsilon":
|
| 576 |
+
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
| 577 |
+
if self.config.variance_type in ["learned", "learned_range"]:
|
| 578 |
+
model_output = model_output[:, :3]
|
| 579 |
+
sigma = self.sigmas[safe_step_index]
|
| 580 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
| 581 |
+
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
| 582 |
+
elif self.config.prediction_type == "sample":
|
| 583 |
+
x0_pred = model_output
|
| 584 |
+
elif self.config.prediction_type == "v_prediction":
|
| 585 |
+
sigma = self.sigmas[safe_step_index]
|
| 586 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
| 587 |
+
x0_pred = alpha_t * sample - sigma_t * model_output
|
| 588 |
+
else:
|
| 589 |
+
raise ValueError(
|
| 590 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
| 591 |
+
" `v_prediction` for the DPMSolverMultistepScheduler."
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
if self.config.thresholding:
|
| 595 |
+
x0_pred = self._threshold_sample(x0_pred)
|
| 596 |
+
|
| 597 |
+
return x0_pred
|
| 598 |
+
|
| 599 |
+
# DPM-Solver needs to solve an integral of the noise prediction model.
|
| 600 |
+
elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
| 601 |
+
if self.config.prediction_type == "epsilon":
|
| 602 |
+
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
| 603 |
+
if self.config.variance_type in ["learned", "learned_range"]:
|
| 604 |
+
epsilon = model_output[:, :3]
|
| 605 |
+
else:
|
| 606 |
+
epsilon = model_output
|
| 607 |
+
elif self.config.prediction_type == "sample":
|
| 608 |
+
sigma = self.sigmas[safe_step_index]
|
| 609 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
| 610 |
+
epsilon = (sample - alpha_t * model_output) / sigma_t
|
| 611 |
+
elif self.config.prediction_type == "v_prediction":
|
| 612 |
+
sigma = self.sigmas[safe_step_index]
|
| 613 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
| 614 |
+
epsilon = alpha_t * model_output + sigma_t * sample
|
| 615 |
+
else:
|
| 616 |
+
raise ValueError(
|
| 617 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
| 618 |
+
" `v_prediction` for the DPMSolverMultistepScheduler."
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
if self.config.thresholding:
|
| 622 |
+
sigma = self.sigmas[safe_step_index]
|
| 623 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
| 624 |
+
x0_pred = (sample - sigma_t * epsilon) / alpha_t
|
| 625 |
+
x0_pred = self._threshold_sample(x0_pred)
|
| 626 |
+
epsilon = (sample - alpha_t * x0_pred) / sigma_t
|
| 627 |
+
|
| 628 |
+
return epsilon
|
| 629 |
+
|
| 630 |
+
def dpm_solver_first_order_update(
|
| 631 |
+
self,
|
| 632 |
+
model_output: torch.Tensor,
|
| 633 |
+
*args,
|
| 634 |
+
sample: torch.Tensor = None,
|
| 635 |
+
noise: Optional[torch.Tensor] = None,
|
| 636 |
+
**kwargs,
|
| 637 |
+
) -> torch.Tensor:
|
| 638 |
+
"""
|
| 639 |
+
One step for the first-order DPMSolver (equivalent to DDIM).
|
| 640 |
+
|
| 641 |
+
Args:
|
| 642 |
+
model_output (`torch.Tensor`):
|
| 643 |
+
The direct output from the learned diffusion model.
|
| 644 |
+
sample (`torch.Tensor`):
|
| 645 |
+
A current instance of a sample created by the diffusion process.
|
| 646 |
+
|
| 647 |
+
Returns:
|
| 648 |
+
`torch.Tensor`:
|
| 649 |
+
The sample tensor at the previous timestep.
|
| 650 |
+
"""
|
| 651 |
+
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
| 652 |
+
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
| 653 |
+
if sample is None:
|
| 654 |
+
if len(args) > 2:
|
| 655 |
+
sample = args[2]
|
| 656 |
+
else:
|
| 657 |
+
raise ValueError(" missing `sample` as a required keyward argument")
|
| 658 |
+
if timestep is not None:
|
| 659 |
+
deprecate(
|
| 660 |
+
"timesteps",
|
| 661 |
+
"1.0.0",
|
| 662 |
+
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
if prev_timestep is not None:
|
| 666 |
+
deprecate(
|
| 667 |
+
"prev_timestep",
|
| 668 |
+
"1.0.0",
|
| 669 |
+
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
# Guard against out-of-bounds access (can occur in concurrent scenarios)
|
| 673 |
+
current_index = min(self.step_index, len(self.sigmas) - 1)
|
| 674 |
+
next_index = min(self.step_index + 1, len(self.sigmas) - 1)
|
| 675 |
+
sigma_t, sigma_s = self.sigmas[next_index], self.sigmas[current_index]
|
| 676 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
| 677 |
+
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
| 678 |
+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
| 679 |
+
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
|
| 680 |
+
|
| 681 |
+
h = lambda_t - lambda_s
|
| 682 |
+
if self.config.algorithm_type == "dpmsolver++":
|
| 683 |
+
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
| 684 |
+
elif self.config.algorithm_type == "dpmsolver":
|
| 685 |
+
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
| 686 |
+
elif self.config.algorithm_type == "sde-dpmsolver++":
|
| 687 |
+
assert noise is not None
|
| 688 |
+
x_t = (
|
| 689 |
+
(sigma_t / sigma_s * torch.exp(-h)) * sample
|
| 690 |
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
|
| 691 |
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
| 692 |
+
)
|
| 693 |
+
elif self.config.algorithm_type == "sde-dpmsolver":
|
| 694 |
+
assert noise is not None
|
| 695 |
+
x_t = (
|
| 696 |
+
(alpha_t / alpha_s) * sample
|
| 697 |
+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
| 698 |
+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
| 699 |
+
)
|
| 700 |
+
return x_t
|
| 701 |
+
|
| 702 |
+
def multistep_dpm_solver_second_order_update(
|
| 703 |
+
self,
|
| 704 |
+
model_output_list: List[torch.Tensor],
|
| 705 |
+
*args,
|
| 706 |
+
sample: torch.Tensor = None,
|
| 707 |
+
noise: Optional[torch.Tensor] = None,
|
| 708 |
+
**kwargs,
|
| 709 |
+
) -> torch.Tensor:
|
| 710 |
+
"""
|
| 711 |
+
One step for the second-order multistep DPMSolver.
|
| 712 |
+
|
| 713 |
+
Args:
|
| 714 |
+
model_output_list (`List[torch.Tensor]`):
|
| 715 |
+
The direct outputs from learned diffusion model at current and latter timesteps.
|
| 716 |
+
sample (`torch.Tensor`):
|
| 717 |
+
A current instance of a sample created by the diffusion process.
|
| 718 |
+
|
| 719 |
+
Returns:
|
| 720 |
+
`torch.Tensor`:
|
| 721 |
+
The sample tensor at the previous timestep.
|
| 722 |
+
"""
|
| 723 |
+
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
| 724 |
+
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
| 725 |
+
if sample is None:
|
| 726 |
+
if len(args) > 2:
|
| 727 |
+
sample = args[2]
|
| 728 |
+
else:
|
| 729 |
+
raise ValueError(" missing `sample` as a required keyward argument")
|
| 730 |
+
if timestep_list is not None:
|
| 731 |
+
deprecate(
|
| 732 |
+
"timestep_list",
|
| 733 |
+
"1.0.0",
|
| 734 |
+
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
| 735 |
+
)
|
| 736 |
+
|
| 737 |
+
if prev_timestep is not None:
|
| 738 |
+
deprecate(
|
| 739 |
+
"prev_timestep",
|
| 740 |
+
"1.0.0",
|
| 741 |
+
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
| 742 |
+
)
|
| 743 |
+
|
| 744 |
+
# Guard against out-of-bounds access (can occur in concurrent scenarios)
|
| 745 |
+
current_index = min(self.step_index, len(self.sigmas) - 1)
|
| 746 |
+
next_index = min(self.step_index + 1, len(self.sigmas) - 1)
|
| 747 |
+
prev_index = max(self.step_index - 1, 0)
|
| 748 |
+
sigma_t, sigma_s0, sigma_s1 = (
|
| 749 |
+
self.sigmas[next_index],
|
| 750 |
+
self.sigmas[current_index],
|
| 751 |
+
self.sigmas[prev_index],
|
| 752 |
+
)
|
| 753 |
+
|
| 754 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
| 755 |
+
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
| 756 |
+
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
| 757 |
+
|
| 758 |
+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
| 759 |
+
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
| 760 |
+
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
| 761 |
+
|
| 762 |
+
m0, m1 = model_output_list[-1], model_output_list[-2]
|
| 763 |
+
|
| 764 |
+
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
|
| 765 |
+
r0 = h_0 / h
|
| 766 |
+
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
|
| 767 |
+
if self.config.algorithm_type == "dpmsolver++":
|
| 768 |
+
# See https://arxiv.org/abs/2211.01095 for detailed derivations
|
| 769 |
+
if self.config.solver_type == "midpoint":
|
| 770 |
+
x_t = (
|
| 771 |
+
(sigma_t / sigma_s0) * sample
|
| 772 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
| 773 |
+
- 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
|
| 774 |
+
)
|
| 775 |
+
elif self.config.solver_type == "heun":
|
| 776 |
+
x_t = (
|
| 777 |
+
(sigma_t / sigma_s0) * sample
|
| 778 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
| 779 |
+
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
|
| 780 |
+
)
|
| 781 |
+
elif self.config.algorithm_type == "dpmsolver":
|
| 782 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
| 783 |
+
if self.config.solver_type == "midpoint":
|
| 784 |
+
x_t = (
|
| 785 |
+
(alpha_t / alpha_s0) * sample
|
| 786 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
| 787 |
+
- 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
|
| 788 |
+
)
|
| 789 |
+
elif self.config.solver_type == "heun":
|
| 790 |
+
x_t = (
|
| 791 |
+
(alpha_t / alpha_s0) * sample
|
| 792 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
| 793 |
+
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
| 794 |
+
)
|
| 795 |
+
elif self.config.algorithm_type == "sde-dpmsolver++":
|
| 796 |
+
assert noise is not None
|
| 797 |
+
if self.config.solver_type == "midpoint":
|
| 798 |
+
x_t = (
|
| 799 |
+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
| 800 |
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
| 801 |
+
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
|
| 802 |
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
| 803 |
+
)
|
| 804 |
+
elif self.config.solver_type == "heun":
|
| 805 |
+
x_t = (
|
| 806 |
+
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
| 807 |
+
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
| 808 |
+
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
|
| 809 |
+
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
| 810 |
+
)
|
| 811 |
+
elif self.config.algorithm_type == "sde-dpmsolver":
|
| 812 |
+
assert noise is not None
|
| 813 |
+
if self.config.solver_type == "midpoint":
|
| 814 |
+
x_t = (
|
| 815 |
+
(alpha_t / alpha_s0) * sample
|
| 816 |
+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
|
| 817 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D1
|
| 818 |
+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
| 819 |
+
)
|
| 820 |
+
elif self.config.solver_type == "heun":
|
| 821 |
+
x_t = (
|
| 822 |
+
(alpha_t / alpha_s0) * sample
|
| 823 |
+
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
|
| 824 |
+
- 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
| 825 |
+
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
|
| 826 |
+
)
|
| 827 |
+
return x_t
|
| 828 |
+
|
| 829 |
+
def multistep_dpm_solver_third_order_update(
|
| 830 |
+
self,
|
| 831 |
+
model_output_list: List[torch.Tensor],
|
| 832 |
+
*args,
|
| 833 |
+
sample: torch.Tensor = None,
|
| 834 |
+
**kwargs,
|
| 835 |
+
) -> torch.Tensor:
|
| 836 |
+
"""
|
| 837 |
+
One step for the third-order multistep DPMSolver.
|
| 838 |
+
|
| 839 |
+
Args:
|
| 840 |
+
model_output_list (`List[torch.Tensor]`):
|
| 841 |
+
The direct outputs from learned diffusion model at current and latter timesteps.
|
| 842 |
+
sample (`torch.Tensor`):
|
| 843 |
+
A current instance of a sample created by diffusion process.
|
| 844 |
+
|
| 845 |
+
Returns:
|
| 846 |
+
`torch.Tensor`:
|
| 847 |
+
The sample tensor at the previous timestep.
|
| 848 |
+
"""
|
| 849 |
+
|
| 850 |
+
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
| 851 |
+
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
|
| 852 |
+
if sample is None:
|
| 853 |
+
if len(args) > 2:
|
| 854 |
+
sample = args[2]
|
| 855 |
+
else:
|
| 856 |
+
raise ValueError(" missing`sample` as a required keyward argument")
|
| 857 |
+
if timestep_list is not None:
|
| 858 |
+
deprecate(
|
| 859 |
+
"timestep_list",
|
| 860 |
+
"1.0.0",
|
| 861 |
+
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
| 862 |
+
)
|
| 863 |
+
|
| 864 |
+
if prev_timestep is not None:
|
| 865 |
+
deprecate(
|
| 866 |
+
"prev_timestep",
|
| 867 |
+
"1.0.0",
|
| 868 |
+
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
| 869 |
+
)
|
| 870 |
+
|
| 871 |
+
# Guard against out-of-bounds access (can occur in concurrent scenarios)
|
| 872 |
+
current_index = min(self.step_index, len(self.sigmas) - 1)
|
| 873 |
+
next_index = min(self.step_index + 1, len(self.sigmas) - 1)
|
| 874 |
+
prev_index_1 = max(self.step_index - 1, 0)
|
| 875 |
+
prev_index_2 = max(self.step_index - 2, 0)
|
| 876 |
+
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
|
| 877 |
+
self.sigmas[next_index],
|
| 878 |
+
self.sigmas[current_index],
|
| 879 |
+
self.sigmas[prev_index_1],
|
| 880 |
+
self.sigmas[prev_index_2],
|
| 881 |
+
)
|
| 882 |
+
|
| 883 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
| 884 |
+
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
| 885 |
+
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
|
| 886 |
+
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
|
| 887 |
+
|
| 888 |
+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
| 889 |
+
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
| 890 |
+
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
|
| 891 |
+
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
|
| 892 |
+
|
| 893 |
+
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
|
| 894 |
+
|
| 895 |
+
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
|
| 896 |
+
r0, r1 = h_0 / h, h_1 / h
|
| 897 |
+
D0 = m0
|
| 898 |
+
D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
|
| 899 |
+
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
|
| 900 |
+
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
|
| 901 |
+
if self.config.algorithm_type == "dpmsolver++":
|
| 902 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
| 903 |
+
x_t = (
|
| 904 |
+
(sigma_t / sigma_s0) * sample
|
| 905 |
+
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
| 906 |
+
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
|
| 907 |
+
- (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
|
| 908 |
+
)
|
| 909 |
+
elif self.config.algorithm_type == "dpmsolver":
|
| 910 |
+
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
| 911 |
+
x_t = (
|
| 912 |
+
(alpha_t / alpha_s0) * sample
|
| 913 |
+
- (sigma_t * (torch.exp(h) - 1.0)) * D0
|
| 914 |
+
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
|
| 915 |
+
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
|
| 916 |
+
)
|
| 917 |
+
return x_t
|
| 918 |
+
|
| 919 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
| 920 |
+
if schedule_timesteps is None:
|
| 921 |
+
schedule_timesteps = self.timesteps
|
| 922 |
+
|
| 923 |
+
index_candidates = (schedule_timesteps == timestep).nonzero()
|
| 924 |
+
|
| 925 |
+
if len(index_candidates) == 0:
|
| 926 |
+
step_index = len(self.timesteps) - 1
|
| 927 |
+
# The sigma index that is taken for the **very** first `step`
|
| 928 |
+
# is always the second index (or the last index if there is only 1)
|
| 929 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
| 930 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
| 931 |
+
elif len(index_candidates) > 1:
|
| 932 |
+
step_index = index_candidates[1].item()
|
| 933 |
+
else:
|
| 934 |
+
step_index = index_candidates[0].item()
|
| 935 |
+
|
| 936 |
+
return step_index
|
| 937 |
+
|
| 938 |
+
def _init_step_index(self, timestep):
|
| 939 |
+
"""
|
| 940 |
+
Initialize the step_index counter for the scheduler.
|
| 941 |
+
"""
|
| 942 |
+
|
| 943 |
+
if self.begin_index is None:
|
| 944 |
+
if isinstance(timestep, torch.Tensor):
|
| 945 |
+
timestep = timestep.to(self.timesteps.device)
|
| 946 |
+
self._step_index = self.index_for_timestep(timestep)
|
| 947 |
+
else:
|
| 948 |
+
self._step_index = self._begin_index
|
| 949 |
+
|
| 950 |
+
    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        generator=None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the multistep DPMSolver.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.Tensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`LEdits++`].
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # Lazily resolve the step counter on the first call of a sampling run.
        if self.step_index is None:
            self._init_step_index(timestep)

        # Improve numerical stability for small number of steps
        # Also guard against out-of-bounds access: if step_index >= len(timesteps) - 1,
        # we must use first-order to avoid accessing sigmas[step_index + 1] out of bounds
        is_last_or_past = self.step_index >= len(self.timesteps) - 1
        lower_order_final = is_last_or_past and (
            self.config.euler_at_final
            or (self.config.lower_order_final and len(self.timesteps) < 15)
            or self.config.final_sigmas_type == "zero"
            or self.step_index >= len(self.sigmas) - 1  # Safety: prevent OOB access
        )
        lower_order_second = (
            (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
        )

        # Shift the history buffer left and append the freshly converted
        # model output; higher-order updates consume this history.
        model_output = self.convert_model_output(model_output, sample=sample)
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)
        # SDE variants need fresh (or caller-supplied) variance noise; ODE
        # variants run deterministically with noise=None.
        if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
            noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
            )
        elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
            noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
        else:
            noise = None

        # Pick the highest-order update that is both configured and warmed up
        # (lower_order_nums tracks how much history has accumulated).
        if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
            prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
        elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
            prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
        else:
            prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        # Cast sample back to expected dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)
| 1042 |
+
|
| 1043 |
+
def add_noise(
|
| 1044 |
+
self,
|
| 1045 |
+
original_samples: torch.Tensor,
|
| 1046 |
+
noise: torch.Tensor,
|
| 1047 |
+
timesteps: torch.IntTensor,
|
| 1048 |
+
) -> torch.Tensor:
|
| 1049 |
+
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
| 1050 |
+
# alpha_t = self.alpha_t.to(device=original_samples.device, dtype=original_samples.dtype)
|
| 1051 |
+
# sigma_t = self.sigma_t.to(device=original_samples.device, dtype=original_samples.dtype)
|
| 1052 |
+
alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
|
| 1053 |
+
sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype)
|
| 1054 |
+
timesteps = timesteps.to(original_samples.device)
|
| 1055 |
+
alpha_t = alpha_t[timesteps].flatten()
|
| 1056 |
+
while len(alpha_t.shape) < len(original_samples.shape):
|
| 1057 |
+
alpha_t = alpha_t.unsqueeze(-1)
|
| 1058 |
+
|
| 1059 |
+
sigma_t = sigma_t[timesteps].flatten()
|
| 1060 |
+
while len(sigma_t.shape) < len(original_samples.shape):
|
| 1061 |
+
sigma_t = sigma_t.unsqueeze(-1)
|
| 1062 |
+
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
| 1063 |
+
return noisy_samples
|
| 1064 |
+
|
| 1065 |
+
def get_velocity(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
|
| 1066 |
+
# alpha_t = self.alpha_t.to(device=original_samples.device, dtype=original_samples.dtype)
|
| 1067 |
+
# sigma_t = self.sigma_t.to(device=original_samples.device, dtype=original_samples.dtype)
|
| 1068 |
+
alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
|
| 1069 |
+
sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype)
|
| 1070 |
+
|
| 1071 |
+
timesteps = timesteps.to(original_samples.device)
|
| 1072 |
+
alpha_t = alpha_t[timesteps].flatten()
|
| 1073 |
+
while len(alpha_t.shape) < len(original_samples.shape):
|
| 1074 |
+
alpha_t = alpha_t.unsqueeze(-1)
|
| 1075 |
+
|
| 1076 |
+
sigma_t = sigma_t[timesteps].flatten()
|
| 1077 |
+
while len(sigma_t.shape) < len(original_samples.shape):
|
| 1078 |
+
sigma_t = sigma_t.unsqueeze(-1)
|
| 1079 |
+
|
| 1080 |
+
velocity = alpha_t * noise - sigma_t * original_samples
|
| 1081 |
+
return velocity
|
| 1082 |
+
|
| 1083 |
+
def __len__(self):
|
| 1084 |
+
return self.config.num_train_timesteps
|
kugelaudio_open/ui/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gradio web interface for KugelAudio."""
|
| 2 |
+
|
| 3 |
+
from kugelaudio_open.ui.app import create_app, launch_app
|
| 4 |
+
|
| 5 |
+
__all__ = ["create_app", "launch_app"]
|
kugelaudio_open/ui/__main__.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CLI entry point for KugelAudio UI."""
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
from kugelaudio_open.ui import launch_app
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def main():
    """Parse CLI flags and launch the KugelAudio Gradio UI."""
    arg_parser = argparse.ArgumentParser(description="Launch KugelAudio Gradio UI")
    arg_parser.add_argument(
        "--share",
        action="store_true",
        help="Create a public Gradio share link",
    )
    arg_parser.add_argument(
        "--host",
        default="127.0.0.1",
        help="Server hostname (default: 127.0.0.1, use 0.0.0.0 for network access)",
    )
    arg_parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Server port (default: 7860)",
    )
    opts = arg_parser.parse_args()

    print(f"🎙️ Starting KugelAudio UI on {opts.host}:{opts.port}")
    if opts.share:
        print("📡 Creating public share link...")

    launch_app(share=opts.share, server_name=opts.host, server_port=opts.port)


if __name__ == "__main__":
    main()
|
kugelaudio_open/ui/app.py
ADDED
|
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gradio web interface for KugelAudio text-to-speech."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import tempfile
|
| 5 |
+
import warnings
|
| 6 |
+
from typing import Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import gradio as gr
|
| 13 |
+
|
| 14 |
+
GRADIO_AVAILABLE = True
|
| 15 |
+
except ImportError:
|
| 16 |
+
GRADIO_AVAILABLE = False
|
| 17 |
+
warnings.warn("Gradio not installed. Install with: pip install gradio")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Global model instances (lazy loaded)
|
| 21 |
+
_model = None
|
| 22 |
+
_processor = None
|
| 23 |
+
_watermark = None
|
| 24 |
+
_current_model_id = None # Track which model is loaded
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_device():
    """Get the best available device ("cuda" > "mps" > "cpu")."""
    for name, available in (
        ("cuda", torch.cuda.is_available()),
        ("mps", hasattr(torch.backends, "mps") and torch.backends.mps.is_available()),
    ):
        if available:
            return name
    return "cpu"
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _warmup_model(model, processor=None):
    """Warmup model components to eliminate CUDA kernel compilation overhead on first generation.

    This runs dummy data through all model components (acoustic decoder, semantic encoder,
    diffusion head, language model) to trigger JIT compilation before actual inference.

    Args:
        model: Loaded KugelAudio inference model; its parameters determine the
            device and dtype used for all dummy tensors.
        processor: Optional processor; when given, a tiny end-to-end generation
            is also run to warm the full pipeline.
    """
    # Infer device/dtype from the model's own parameters.
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    with torch.no_grad():
        # 1. Warmup acoustic decoder (biggest impact - saves ~190ms on first call)
        latent_dim = model.config.acoustic_vae_dim
        dummy_latent = torch.randn(1, latent_dim, 1, device=device, dtype=dtype)
        _ = model.acoustic_tokenizer.decode(dummy_latent)

        # 2. Warmup semantic encoder
        # NOTE(review): 3200 samples presumably matches the encoder's frame
        # hop at 24 kHz — confirm against the tokenizer config.
        dummy_audio = torch.randn(1, 1, 3200, device=device, dtype=dtype)
        _ = model.semantic_tokenizer.encode(dummy_audio)

        # 3. Warmup diffusion/prediction head
        hidden_size = model.config.decoder_config.hidden_size
        model.noise_scheduler.set_timesteps(model.ddpm_inference_steps)

        dummy_condition = torch.randn(2, hidden_size, device=device, dtype=dtype)
        dummy_speech = torch.randn(2, latent_dim, device=device, dtype=dtype)

        # Run a full denoising loop so every scheduler/head kernel is compiled.
        for t in model.noise_scheduler.timesteps:
            half = dummy_speech[:1]
            combined = torch.cat([half, half], dim=0)
            _ = model.prediction_head(
                combined,
                t.repeat(combined.shape[0]).to(combined),
                condition=dummy_condition,
            )
            dummy_eps = torch.randn_like(dummy_speech)
            dummy_speech = model.noise_scheduler.step(dummy_eps, t, dummy_speech).prev_sample

        # 4. Warmup language model with KV cache path
        dummy_ids = torch.randint(0, 32000, (1, 64), device=device)
        dummy_mask = torch.ones_like(dummy_ids)
        _ = model.model.language_model(input_ids=dummy_ids, attention_mask=dummy_mask, use_cache=True)

        # 5. Warmup acoustic encoder (for voice prompts)
        dummy_voice = torch.randn(1, 1, 24000, device=device, dtype=dtype)
        _ = model.acoustic_tokenizer.encode(dummy_voice)

        # 6. Run a minimal generation to warmup the full generation path
        if processor is not None:
            dummy_inputs = processor(text="Hi.", return_tensors="pt")
            dummy_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in dummy_inputs.items()}
            _ = model.generate(**dummy_inputs, cfg_scale=3.0, max_new_tokens=10, show_progress=False)

    # Clear memory allocated by the warmup passes.
    if device.type == "cuda":
        torch.cuda.empty_cache()
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def load_models(model_id: str = "kugelaudio/kugelaudio-0-open"):
    """Load model and processor. Switches model if a different model_id is requested.

    Caches the model, processor, and watermark detector in module-level
    globals so repeated UI calls reuse them.

    Args:
        model_id: Hugging Face model identifier to load.

    Returns:
        Tuple of (model, processor, watermark).
    """
    global _model, _processor, _watermark, _current_model_id

    from kugelaudio_open.models import KugelAudioForConditionalGenerationInference
    from kugelaudio_open.processors import KugelAudioProcessor
    from kugelaudio_open.watermark import AudioWatermark

    device = get_device()
    # bfloat16 only on CUDA; CPU/MPS fall back to float32.
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    # Check if we need to load a different model
    if _model is None or _current_model_id != model_id:
        # Clean up old model if switching
        if _model is not None and _current_model_id != model_id:
            print(f"Switching model from {_current_model_id} to {model_id}...")
            del _model
            del _processor
            _model = None
            _processor = None
            # Clear CUDA cache to free memory
            if device == "cuda":
                torch.cuda.empty_cache()

        print(f"Loading model {model_id} on {device}...")
        try:
            _model = KugelAudioForConditionalGenerationInference.from_pretrained(
                model_id,
                torch_dtype=dtype,
                attn_implementation="flash_attention_2" if device == "cuda" else "sdpa",
            ).to(device)
        except Exception:
            # Fall back to the default attention implementation (e.g. when
            # flash-attention is not installed on this machine).
            _model = KugelAudioForConditionalGenerationInference.from_pretrained(
                model_id,
                torch_dtype=dtype,
            ).to(device)
        _model.eval()
        _current_model_id = model_id
        print(f"Model {model_id} loaded!")

    if _processor is None:
        _processor = KugelAudioProcessor.from_pretrained(model_id)

    # Warmup to eliminate first-generation slowness from CUDA kernel compilation
    # Do this after processor is loaded so we can run a mini-generation
    if device == "cuda" and _model is not None:
        # Check if we need to warmup (only on first load)
        if not getattr(_model, "_warmed_up", False):
            print("Warming up model (this may take a moment)...")
            _warmup_model(_model, _processor)
            _model._warmed_up = True
            print("Warmup complete!")

    if _watermark is None:
        _watermark = AudioWatermark(device=device)

    return _model, _processor, _watermark
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def generate_speech(
    text: str,
    reference_audio: Optional[Tuple[int, np.ndarray]] = None,
    model_choice: str = "kugelaudio-0-open",
    cfg_scale: float = 3.0,
    max_tokens: int = 2048,
) -> Tuple[int, np.ndarray]:
    """Generate speech from text.

    Args:
        text: Text to synthesize
        reference_audio: Optional (sample_rate, audio_array) for voice cloning
        model_choice: Model variant to use
        cfg_scale: Classifier-free guidance scale
        max_tokens: Maximum generation tokens

    Returns:
        Tuple of (sample_rate, audio_array)

    Raises:
        gr.Error: If the text is empty or generation produced no audio.

    Note:
        All generated audio is automatically watermarked for identification.
    """
    if not text.strip():
        raise gr.Error("Please enter some text to synthesize.")

    model_id = f"kugelaudio/{model_choice}"
    # NOTE(review): `watermark` is unused here — the model's generate() already
    # embeds the watermark (see comment below); kept for the cached side effect.
    model, processor, watermark = load_models(model_id)
    device = next(model.parameters()).device

    # Process reference audio if provided
    voice_audio = None
    if reference_audio is not None:
        ref_sr, ref_audio = reference_audio
        print(f"[Voice Cloning] Input audio: sr={ref_sr}, shape={ref_audio.shape}, dtype={ref_audio.dtype}")

        # Convert to float32 and normalize based on dtype
        if ref_audio.dtype == np.int16:
            ref_audio = ref_audio.astype(np.float32) / 32768.0
        elif ref_audio.dtype == np.int32:
            ref_audio = ref_audio.astype(np.float32) / 2147483648.0
        elif ref_audio.dtype == np.float64:
            ref_audio = ref_audio.astype(np.float32)
        elif ref_audio.dtype != np.float32:
            ref_audio = ref_audio.astype(np.float32)

        # Ensure mono BEFORE resampling (important for stereo files)
        if ref_audio.ndim > 1:
            if ref_audio.shape[0] == 2:  # [2, samples] format (channels first)
                ref_audio = ref_audio.mean(axis=0)
            elif ref_audio.shape[-1] == 2:  # [samples, 2] format (channels last)
                ref_audio = ref_audio.mean(axis=-1)
            elif ref_audio.shape[0] < ref_audio.shape[-1]:  # Likely [channels, samples]
                ref_audio = ref_audio.mean(axis=0)
            else:  # Likely [samples, channels]
                ref_audio = ref_audio.mean(axis=-1)

        # Ensure 1D
        ref_audio = ref_audio.squeeze()

        print(f"[Voice Cloning] After mono conversion: shape={ref_audio.shape}, dtype={ref_audio.dtype}")

        # Resample to 24kHz if needed - this is CRITICAL for voice cloning
        if ref_sr != 24000:
            import librosa
            print(f"[Voice Cloning] Resampling from {ref_sr}Hz to 24000Hz (ratio: {ref_sr/24000:.4f})")
            ref_audio = librosa.resample(ref_audio, orig_sr=ref_sr, target_sr=24000)
            print(f"[Voice Cloning] After resampling: shape={ref_audio.shape}, duration={len(ref_audio)/24000:.2f}s")
        else:
            print(f"[Voice Cloning] No resampling needed, already at 24kHz")

        # Normalize audio to reasonable range
        max_val = np.abs(ref_audio).max()
        if max_val > 0:
            ref_audio = ref_audio / max_val * 0.95

        voice_audio = ref_audio
        print(f"[Voice Cloning] Final voice audio: shape={voice_audio.shape}, min={voice_audio.min():.4f}, max={voice_audio.max():.4f}, std={voice_audio.std():.4f}")

    # Process text input with optional voice prompt
    inputs = processor(text=text.strip(), voice_prompt=voice_audio, return_tensors="pt")
    inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

    print(f"[Generation] Using model: {model_id}, cfg_scale={cfg_scale}, max_tokens={max_tokens}")

    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            cfg_scale=cfg_scale,
            max_new_tokens=max_tokens,
        )

    if not outputs.speech_outputs or outputs.speech_outputs[0] is None:
        raise gr.Error("Generation failed. Please try again with different settings.")

    # Audio is already watermarked by the model's generate method
    audio = outputs.speech_outputs[0]
    print(f"[Generation] Raw output: shape={audio.shape}, dtype={audio.dtype}")

    # Convert to numpy (convert to float32 first since numpy doesn't support bfloat16)
    if isinstance(audio, torch.Tensor):
        audio = audio.cpu().float().numpy()

    # Ensure correct shape (1D array)
    audio = audio.squeeze()

    # Normalize to prevent clipping (important for Gradio playback)
    max_val = np.abs(audio).max()
    if max_val > 1.0:
        audio = audio / max_val * 0.95

    print(f"[Generation] Final output: shape={audio.shape}, dtype={audio.dtype}, duration={len(audio)/24000:.2f}s")
    print(f"[Generation] Audio stats: min={audio.min():.4f}, max={audio.max():.4f}, std={audio.std():.4f}")

    # Return with explicit sample rate - Gradio expects (sample_rate, audio_array)
    return (24000, audio)
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def check_watermark(audio: Tuple[int, np.ndarray]) -> str:
    """Report whether a (sample_rate, samples) pair carries the KugelAudio watermark."""
    if audio is None:
        return "No audio provided."

    from kugelaudio_open.watermark import AudioWatermark

    sample_rate, samples = audio

    # Integer PCM -> float32 in [-1, 1); other dtypes pass through unchanged.
    pcm_scale = {np.int16: 32768.0, np.int32: 2147483648.0}.get(samples.dtype.type)
    if pcm_scale is not None:
        samples = samples.astype(np.float32) / pcm_scale

    detection = AudioWatermark().detect(samples, sample_rate=sample_rate)

    if detection.detected:
        return f"✅ **Watermark Detected**\n\nConfidence: {detection.confidence:.1%}\n\nThis audio was generated by KugelAudio."
    return f"❌ **No Watermark Detected**\n\nConfidence: {detection.confidence:.1%}\n\nThis audio does not appear to be generated by KugelAudio."
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def create_app() -> "gr.Blocks":
    """Create the Gradio application.

    Builds a three-tab UI: speech generation (with optional voice cloning),
    watermark verification, and an About page.

    Returns:
        The assembled (but not yet launched) `gr.Blocks` app.

    Raises:
        ImportError: If gradio is not installed.
    """
    if not GRADIO_AVAILABLE:
        raise ImportError("Gradio not installed. Install with: pip install gradio")

    # Logo URLs
    kugelaudio_logo = "https://www.kugelaudio.com/logos/Logo%20Short.svg"
    kisz_logo = "https://docs.sc.hpi.de/attachments/aisc/aisc-logo.png"
    bmftr_logo = (
        "https://hpi.de/fileadmin/_processed_/a/3/csm_BMFTR_de_Web_RGB_gef_durch_cd1f5345bd.jpg"
    )

    with gr.Blocks(title="KugelAudio - Text to Speech") as app:
        # Page header with project and sponsor logos.
        gr.HTML(
            f"""
            <div style="text-align: center; margin-bottom: 1.5rem;">
                <h1 style="margin-bottom: 0.5rem;">🎙️ KugelAudio</h1>
                <p style="color: #666; margin-bottom: 1rem;">Open-source text-to-speech with voice cloning capabilities</p>
                <div style="display: flex; justify-content: center; align-items: center; gap: 2rem; flex-wrap: wrap;">
                    <a href="https://kugelaudio.com" target="_blank">
                        <img src="{kugelaudio_logo}" alt="KugelAudio" style="height: 50px; width: auto;">
                    </a>
                    <a href="https://hpi.de/ki-servicezentrum/" target="_blank">
                        <img src="{kisz_logo}" alt="KI-Servicezentrum Berlin-Brandenburg" style="height: 50px; width: auto;">
                    </a>
                    <a href="https://www.bmftr.bund.de" target="_blank">
                        <img src="{bmftr_logo}" alt="Gefördert durch BMFTR" style="height: 70px; width: auto;">
                    </a>
                </div>
            </div>
            """
        )

        with gr.Tabs():
            # Tab 1: Text to Speech
            with gr.TabItem("🗣️ Generate Speech"):
                with gr.Row():
                    with gr.Column(scale=1):
                        text_input = gr.Textbox(
                            label="Text to Synthesize",
                            placeholder="Enter the text you want to convert to speech...",
                            lines=5,
                            max_lines=20,
                        )

                        reference_audio = gr.Audio(
                            label="Reference Audio (Optional)",
                            type="numpy",
                            sources=["upload", "microphone"],
                        )
                        gr.Markdown("*Upload a voice sample to clone the speaker's voice*")

                        with gr.Accordion("Advanced Settings", open=False):
                            model_choice = gr.Dropdown(
                                choices=["kugelaudio-0-open"],
                                value="kugelaudio-0-open",
                                label="Model",
                            )
                            cfg_scale = gr.Slider(
                                minimum=1.0,
                                maximum=10.0,
                                value=3.0,
                                step=0.5,
                                label="Guidance Scale",
                                info="Higher values = more adherence to text",
                            )
                            max_tokens = gr.Slider(
                                minimum=512,
                                maximum=8192,
                                value=2048,
                                step=256,
                                label="Max Tokens",
                                info="Maximum generation length",
                            )

                        generate_btn = gr.Button("🎤 Generate Speech", variant="primary", size="lg")

                    with gr.Column(scale=1):
                        output_audio = gr.Audio(
                            label="Generated Speech",
                            type="numpy",
                            interactive=False,
                        )

                        gr.Markdown(
                            """
                            ### Tips
                            - For best results, use clear and well-punctuated text
                            - Reference audio should be 5-30 seconds of clear speech
                            - The 7B model produces higher quality but is slower
                            """
                        )

                generate_btn.click(
                    fn=generate_speech,
                    inputs=[text_input, reference_audio, model_choice, cfg_scale, max_tokens],
                    outputs=[output_audio],
                )

            # Tab 2: Watermark Detection
            with gr.TabItem("🔍 Verify Watermark"):
                gr.Markdown(
                    """
                    ### Watermark Verification
                    Check if an audio file was generated by KugelAudio. All audio generated
                    by KugelAudio contains an imperceptible watermark for identification.
                    """
                )

                with gr.Row():
                    with gr.Column():
                        verify_audio = gr.Audio(
                            label="Audio to Verify",
                            type="numpy",
                            sources=["upload"],
                        )
                        verify_btn = gr.Button("🔍 Check Watermark", variant="secondary")

                    with gr.Column():
                        verify_result = gr.Markdown(
                            label="Result",
                            value="Upload an audio file to check for watermark.",
                        )

                verify_btn.click(
                    fn=check_watermark,
                    inputs=[verify_audio],
                    outputs=[verify_result],
                )

            # Tab 3: About
            with gr.TabItem("ℹ️ About"):
                gr.Markdown(
                    """
                    ## About KugelAudio

                    KugelAudio is an open-source text-to-speech system that combines:

                    - **AR + Diffusion Architecture**: Uses autoregressive language modeling
                      with diffusion-based speech synthesis for high-quality output
                    - **Voice Cloning**: Clone any voice with just a few seconds of reference audio
                    - **Audio Watermarking**: All generated audio contains an imperceptible watermark
                      using [Facebook's AudioSeal](https://huggingface.co/facebook/audioseal) technology

                    ### Models

                    | Model | Parameters | Quality | Speed |
                    |-------|------------|---------|-------|
                    | kugelaudio-0-open | 7B | Best | Standard |

                    ### Responsible Use

                    This technology is intended for legitimate purposes such as:
                    - Accessibility (text-to-speech for visually impaired)
                    - Content creation (podcasts, videos, audiobooks)
                    - Voice assistants and chatbots

                    **Please do not use this technology for:**
                    - Creating deepfakes or misleading content
                    - Impersonating individuals without consent
                    - Any illegal or harmful purposes

                    All generated audio is watermarked to enable detection.
                    """
                )

        # Page footer with author and project links.
        gr.HTML(
            """
            <div style="text-align: center; margin-top: 2rem; padding: 1rem; border-top: 1px solid #eee;">
                <p style="color: #888; margin-bottom: 0.5rem;">
                    <strong>KugelAudio</strong> • Open Source TTS with Voice Cloning
                </p>
                <p style="color: #aaa; font-size: 0.9rem;">
                    Created by <a href="mailto:kajo@kugelaudio.com" style="color: #667eea;">Kajo Kratzenstein</a> •
                    <a href="https://kugelaudio.com" style="color: #667eea;">kugelaudio.com</a> •
                    <a href="https://github.com/kugelaudio/kugelaudio" style="color: #667eea;">GitHub</a>
                </p>
            </div>
            """
        )

    return app
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def launch_app(
    share: bool = False,
    server_name: str = "127.0.0.1",
    server_port: int = 7860,
    **kwargs,
):
    """Launch the Gradio web interface.

    Args:
        share: Create a public share link
        server_name: Server hostname (use "0.0.0.0" for network access)
        server_port: Server port
        **kwargs: Additional arguments passed to gr.Blocks.launch()
    """
    app = create_app()
    # NOTE: `gr.Blocks.launch()` does not accept a `theme` keyword — themes
    # must be passed to the `gr.Blocks(...)` constructor (i.e. inside
    # create_app()). Passing one here raised TypeError at startup, so the
    # theme argument was removed from this call.
    app.launch(
        share=share,
        server_name=server_name,
        server_port=server_port,
        **kwargs,
    )
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
if __name__ == "__main__":
|
| 506 |
+
launch_app()
|
kugelaudio_open/utils/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility functions for KugelAudio."""
|
| 2 |
+
|
| 3 |
+
from kugelaudio_open.utils.generation import generate_speech, load_model_and_processor
|
| 4 |
+
|
| 5 |
+
__all__ = ["generate_speech", "load_model_and_processor"]
|
kugelaudio_open/utils/generation.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""High-level generation utilities for KugelAudio."""
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Union, Tuple
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def load_model_and_processor(
    model_name_or_path: str = "kugelaudio/kugelaudio-0-open",
    device: Optional[Union[str, torch.device]] = None,
    torch_dtype: Optional[torch.dtype] = None,
    use_flash_attention: bool = True,
):
    """Load KugelAudio model and processor.

    Args:
        model_name_or_path: HuggingFace model ID or local path
        device: Device to load model on (auto-detected if None)
        torch_dtype: Data type for model weights (bf16 on CUDA, fp32 on CPU
            when not specified)
        use_flash_attention: Whether to use flash attention if available

    Returns:
        Tuple of (model, processor)

    Example:
        >>> model, processor = load_model_and_processor("kugelaudio/kugelaudio-0-open")
    """
    import warnings

    from kugelaudio_open.models import KugelAudioForConditionalGenerationInference
    from kugelaudio_open.processors import KugelAudioProcessor

    # Auto-detect device and dtype when the caller did not pin them.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if torch_dtype is None:
        torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    attn_impl = "flash_attention_2" if use_flash_attention else "sdpa"
    try:
        model = KugelAudioForConditionalGenerationInference.from_pretrained(
            model_name_or_path,
            torch_dtype=torch_dtype,
            attn_implementation=attn_impl,
        ).to(device)
    except Exception as exc:
        if not use_flash_attention:
            # The sdpa path has no alternative implementation to fall back to;
            # re-raise so the caller sees the real failure (e.g. bad path).
            raise
        # Only the flash-attention path gets a best-effort retry. The old code
        # silently retried on *every* error, masking genuine load failures.
        warnings.warn(
            f"Loading with {attn_impl!r} failed ({exc!r}); "
            "retrying with the default attention implementation."
        )
        model = KugelAudioForConditionalGenerationInference.from_pretrained(
            model_name_or_path,
            torch_dtype=torch_dtype,
        ).to(device)

    model.eval()

    processor = KugelAudioProcessor.from_pretrained(model_name_or_path)

    return model, processor
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def generate_speech(
    model,
    processor,
    text: str,
    voice_prompt: Optional[torch.Tensor] = None,
    voice_prompt_path: Optional[str] = None,
    cfg_scale: float = 3.0,
    max_new_tokens: int = 4096,
    device: Optional[Union[str, torch.device]] = None,
) -> torch.Tensor:
    """Synthesize speech for *text* with a KugelAudio model.

    All generated audio is automatically watermarked for identification.

    Args:
        model: KugelAudio model
        processor: KugelAudio processor
        text: Text to synthesize
        voice_prompt: Voice prompt tensor for speaker identity
        voice_prompt_path: Path to a voice prompt audio file (only used when
            ``voice_prompt`` is not provided)
        cfg_scale: Classifier-free guidance scale
        max_new_tokens: Maximum number of tokens to generate
        device: Device for generation (defaults to the model's device)

    Returns:
        Generated audio tensor (watermarked)

    Raises:
        RuntimeError: If the model produced no audio output.

    Example:
        >>> audio = generate_speech(model, processor, "Hello world!")
        >>> processor.save_audio(audio, "output.wav")
    """
    if device is None:
        device = next(model.parameters()).device

    # A file-based prompt is loaded only when no tensor prompt was supplied.
    if voice_prompt is None and voice_prompt_path is not None:
        prompt_features = processor.audio_processor(voice_prompt_path, return_tensors="pt")
        voice_prompt = prompt_features["audio"].to(device)

    # Tokenize the text and move every tensor field to the target device.
    model_inputs = processor(text=text, return_tensors="pt")
    model_inputs = {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in model_inputs.items()
    }

    # The watermark is applied by the model itself during generation.
    with torch.no_grad():
        generation = model.generate(
            **model_inputs,
            voice_prompt=voice_prompt,
            cfg_scale=cfg_scale,
            max_new_tokens=max_new_tokens,
        )

    speech = generation.speech_outputs[0] if generation.speech_outputs else None
    if speech is None:
        raise RuntimeError("Generation failed - no audio output")

    return speech
|
kugelaudio_open/watermark/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio watermarking for KugelAudio generated speech."""
|
| 2 |
+
|
| 3 |
+
from kugelaudio_open.watermark.watermark import AudioWatermark
|
| 4 |
+
|
| 5 |
+
__all__ = ["AudioWatermark"]
|
kugelaudio_open/watermark/watermark.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio watermarking for KugelAudio using Facebook's AudioSeal.
|
| 2 |
+
|
| 3 |
+
AudioSeal provides state-of-the-art speech localized watermarking with:
|
| 4 |
+
- High robustness to audio editing and compression
|
| 5 |
+
- Fast single-pass detection (real-time capable)
|
| 6 |
+
- Sample-level detection (1/16k second resolution)
|
| 7 |
+
- Optional 16-bit message embedding
|
| 8 |
+
|
| 9 |
+
Reference: https://huggingface.co/facebook/audioseal
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from typing import Optional, Union, Tuple
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
import warnings
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
# Optional dependency: AudioSeal provides the actual watermark models.
# When it is missing we degrade gracefully at call time (embed() becomes a
# no-op, detect() reports "not detected") instead of failing on import.
try:
    from audioseal import AudioSeal
    AUDIOSEAL_AVAILABLE = True
except ImportError:
    AUDIOSEAL_AVAILABLE = False
    warnings.warn(
        "AudioSeal not installed. Install with: pip install audioseal\n"
        "Watermarking will use fallback implementation."
    )
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
class WatermarkResult:
    """Result of watermark detection.

    Attributes:
        detected: True when the mean frame probability exceeds the threshold.
        confidence: Mean per-frame watermark probability in [0, 1].
        message: Decoded 16-bit message tensor from the detector, if any.
        frame_probabilities: Per-frame watermark probabilities (batch, frames).
    """
    detected: bool
    confidence: float
    message: Optional[torch.Tensor] = None
    frame_probabilities: Optional[torch.Tensor] = None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class AudioWatermark:
    """Professional audio watermarking using Facebook's AudioSeal.

    AudioSeal is a state-of-the-art watermarking system that embeds
    imperceptible watermarks in audio that are robust to various
    audio transformations.

    Features:
        - Imperceptible watermarks with minimal quality degradation
        - Robust to compression, resampling, and editing
        - Fast detection suitable for real-time applications
        - Optional 16-bit message embedding for tracking

    Example:
        >>> watermark = AudioWatermark()
        >>> watermarked_audio = watermark.embed(audio)
        >>> result = watermark.detect(watermarked_audio)
        >>> print(f"Detected: {result.detected}, Confidence: {result.confidence:.2%}")

    Args:
        model_name: AudioSeal generator variant ("audioseal_wm_16bits")
        detector_name: AudioSeal detector variant ("audioseal_detector_16bits")
        device: Device for inference ("cuda" or "cpu"); auto-detected if None
        message: Optional 16-bit message to embed (for tracking)
    """

    # Default message identifying KugelAudio-generated content
    # (fixed alternating bit pattern).
    KUGELAUDIO_MESSAGE = torch.tensor([[1, 0, 1, 0, 1, 0, 1, 0,
                                        0, 1, 0, 1, 0, 1, 0, 1]])

    # AudioSeal models operate at 16 kHz; other rates are resampled.
    AUDIOSEAL_SAMPLE_RATE = 16000

    def __init__(
        self,
        model_name: str = "audioseal_wm_16bits",
        detector_name: str = "audioseal_detector_16bits",
        device: Optional[Union[str, torch.device]] = None,
        message: Optional[torch.Tensor] = None,
    ):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # Generator/detector are loaded lazily on first use (see properties).
        self._generator = None
        self._detector = None
        self._model_name = model_name
        self._detector_name = detector_name

        # Use KugelAudio identifier message by default.
        self.message = message if message is not None else self.KUGELAUDIO_MESSAGE.clone()

        if not AUDIOSEAL_AVAILABLE:
            warnings.warn(
                "AudioSeal not available. Watermarking disabled. "
                "Install with: pip install audioseal"
            )

    @property
    def generator(self):
        """Lazy load the watermark generator model."""
        if self._generator is None and AUDIOSEAL_AVAILABLE:
            self._generator = AudioSeal.load_generator(self._model_name)
            self._generator = self._generator.to(self.device)
            self._generator.eval()
        return self._generator

    @property
    def detector(self):
        """Lazy load the watermark detector model."""
        if self._detector is None and AUDIOSEAL_AVAILABLE:
            self._detector = AudioSeal.load_detector(self._detector_name)
            self._detector = self._detector.to(self.device)
            self._detector.eval()
        return self._detector

    def _resample(
        self,
        audio: torch.Tensor,
        orig_sr: int,
        target_sr: int,
    ) -> torch.Tensor:
        """Resample (batch, channels, samples) audio to ``target_sr``."""
        if orig_sr == target_sr:
            return audio

        try:
            import torchaudio.functional as F
            return F.resample(audio, orig_sr, target_sr)
        except ImportError:
            # scipy fallback. Cast back to the input dtype explicitly —
            # scipy returns float64, which previously leaked upstream.
            # NOTE(review): flattening across batch/channels only round-trips
            # correctly for a single channel; callers currently pass mono.
            from scipy import signal
            flat = audio.detach().cpu().numpy().flatten()
            num_samples = int(len(flat) * target_sr / orig_sr)
            resampled = signal.resample(flat, num_samples)
            return (
                torch.from_numpy(resampled)
                .reshape(audio.shape[0], audio.shape[1], -1)
                .to(device=audio.device, dtype=audio.dtype)
            )

    def embed(
        self,
        audio: Union[np.ndarray, torch.Tensor],
        sample_rate: int = 24000,
        message: Optional[torch.Tensor] = None,
    ) -> Union[np.ndarray, torch.Tensor]:
        """Embed watermark into audio.

        The watermark is generated at AudioSeal's native 16 kHz, resampled
        back to ``sample_rate``, and added to the *original* audio. (The
        previous implementation replaced the audio with a 16 kHz resample
        round-trip before adding the watermark, needlessly degrading quality
        and mangling the output length logic.)

        Args:
            audio: Input audio of shape (samples,), (channels, samples),
                or (batch, channels, samples)
            sample_rate: Sample rate of input audio (default: 24000 for KugelAudio)
            message: Optional 16-bit message to embed

        Returns:
            Watermarked audio with same shape, dtype, and container type as input
        """
        if not AUDIOSEAL_AVAILABLE:
            # Graceful no-op when AudioSeal is not installed.
            return audio

        # Track the caller's container/type so we can restore it at the end.
        is_numpy = isinstance(audio, np.ndarray)
        if is_numpy:
            audio = torch.from_numpy(audio)

        original_device = audio.device
        original_dtype = audio.dtype
        original_ndim = audio.ndim

        # Process in float32.
        audio = audio.float()

        # Normalize to (batch, channels, samples).
        if audio.ndim == 1:
            audio = audio.unsqueeze(0).unsqueeze(0)
        elif audio.ndim == 2:
            audio = audio.unsqueeze(0)

        audio = audio.to(self.device)

        # AudioSeal expects 16 kHz input.
        if sample_rate != self.AUDIOSEAL_SAMPLE_RATE:
            audio_16k = self._resample(audio, sample_rate, self.AUDIOSEAL_SAMPLE_RATE)
        else:
            audio_16k = audio

        # Broadcast the message across the batch if needed.
        msg = message if message is not None else self.message
        msg = msg.to(self.device)
        if msg.shape[0] != audio_16k.shape[0]:
            msg = msg.expand(audio_16k.shape[0], -1)

        # Generate the additive watermark signal at 16 kHz.
        with torch.no_grad():
            watermark_16k = self.generator.get_watermark(
                audio_16k, self.AUDIOSEAL_SAMPLE_RATE, message=msg
            )

        # Bring the watermark back to the caller's rate and align its length
        # to the (untouched) original audio.
        if sample_rate != self.AUDIOSEAL_SAMPLE_RATE:
            watermark = self._resample(
                watermark_16k, self.AUDIOSEAL_SAMPLE_RATE, sample_rate
            )
            if watermark.shape[-1] > audio.shape[-1]:
                watermark = watermark[..., :audio.shape[-1]]
            elif watermark.shape[-1] < audio.shape[-1]:
                watermark = torch.nn.functional.pad(
                    watermark, (0, audio.shape[-1] - watermark.shape[-1])
                )
        else:
            watermark = watermark_16k

        watermarked = audio + watermark

        # Prevent clipping.
        max_val = watermarked.abs().max()
        if max_val > 1.0:
            watermarked = watermarked / max_val

        # Restore the caller's shape.
        if original_ndim == 1:
            watermarked = watermarked.squeeze(0).squeeze(0)
        elif original_ndim == 2:
            watermarked = watermarked.squeeze(0)

        # Restore device, dtype, and container type.
        watermarked = watermarked.to(device=original_device, dtype=original_dtype)
        if is_numpy:
            watermarked = watermarked.numpy()

        return watermarked

    def detect(
        self,
        audio: Union[np.ndarray, torch.Tensor],
        sample_rate: int = 24000,
        threshold: float = 0.5,
    ) -> "WatermarkResult":
        """Detect watermark in audio.

        Args:
            audio: Input audio to check for watermark
            sample_rate: Sample rate of input audio
            threshold: Detection threshold (0.0-1.0) on the mean frame probability

        Returns:
            WatermarkResult with detection status, confidence, and decoded message
        """
        if not AUDIOSEAL_AVAILABLE:
            return WatermarkResult(
                detected=False,
                confidence=0.0,
                message=None,
                frame_probabilities=None,
            )

        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)

        audio = audio.float()

        # Normalize to (batch, channels, samples).
        if audio.ndim == 1:
            audio = audio.unsqueeze(0).unsqueeze(0)
        elif audio.ndim == 2:
            audio = audio.unsqueeze(0)

        audio = audio.to(self.device)

        # Detection runs at AudioSeal's native 16 kHz.
        if sample_rate != self.AUDIOSEAL_SAMPLE_RATE:
            audio = self._resample(audio, sample_rate, self.AUDIOSEAL_SAMPLE_RATE)

        with torch.no_grad():
            result, message = self.detector(audio, self.AUDIOSEAL_SAMPLE_RATE)

        # result shape: (batch, 2, frames) — probabilities for
        # [no_watermark, watermark]; keep the positive class.
        watermark_probs = result[:, 1, :]  # (batch, frames)

        # Overall confidence is the mean frame probability.
        confidence = watermark_probs.mean().item()
        detected = confidence > threshold

        return WatermarkResult(
            detected=detected,
            confidence=confidence,
            message=message.cpu() if message is not None else None,
            frame_probabilities=watermark_probs.cpu(),
        )

    def verify(self, audio: Union[np.ndarray, torch.Tensor], sample_rate: int = 24000) -> bool:
        """Quick verification that audio contains KugelAudio watermark.

        Args:
            audio: Audio to verify
            sample_rate: Sample rate of audio

        Returns:
            True if watermark detected with high confidence (> 0.6)
        """
        result = self.detect(audio, sample_rate)
        return result.detected and result.confidence > 0.6
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class WatermarkPostProcessor:
    """Post-processor that automatically adds watermarks to generated audio.

    Designed to be integrated into the generation pipeline so that every
    generated clip is watermarked transparently.

    Example:
        >>> post_processor = WatermarkPostProcessor()
        >>> # In generation pipeline:
        >>> audio = model.generate(...)
        >>> audio = post_processor(audio)  # Watermark added automatically
    """

    def __init__(
        self,
        enabled: bool = True,
        device: Optional[Union[str, torch.device]] = None,
        sample_rate: int = 24000,
    ):
        # Watermark model construction is deferred until first use.
        self.enabled = enabled
        self.sample_rate = sample_rate
        self._device = device
        self._watermark = None

    @property
    def watermark(self) -> AudioWatermark:
        """Lazily construct and cache the underlying AudioWatermark."""
        if self._watermark is None:
            self._watermark = AudioWatermark(device=self._device)
        return self._watermark

    def __call__(
        self,
        audio: Union[np.ndarray, torch.Tensor],
        sample_rate: Optional[int] = None,
    ) -> Union[np.ndarray, torch.Tensor]:
        """Return *audio* with a watermark embedded; no-op when disabled."""
        if not self.enabled:
            return audio

        effective_rate = sample_rate if sample_rate else self.sample_rate
        return self.watermark.embed(audio, sample_rate=effective_rate)

    def disable(self):
        """Disable watermarking."""
        self.enabled = False

    def enable(self):
        """Enable watermarking."""
        self.enabled = True
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def is_watermarked(
    audio: Union[np.ndarray, torch.Tensor],
    sample_rate: int = 24000,
    threshold: float = 0.5,
) -> bool:
    """Convenience check for whether *audio* carries an AudioSeal watermark.

    Args:
        audio: Audio to check
        sample_rate: Sample rate of audio
        threshold: Detection threshold

    Returns:
        True if watermark detected
    """
    detector = AudioWatermark()
    return detector.detect(audio, sample_rate, threshold).detected
|