multimodalart HF Staff commited on
Commit
bbb0e68
Β·
verified Β·
1 Parent(s): 2160ec3

Upload 25 files

Browse files
kugelaudio_open/__init__.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""KugelAudio - Open Source Text-to-Speech Model

KugelAudio is a state-of-the-art neural text-to-speech model that generates
natural, expressive speech from text with voice cloning capabilities.

Example:
    >>> from kugelaudio_open import KugelAudioForConditionalGenerationInference
    >>> from transformers import AutoModel
    >>>
    >>> # Load the model
    >>> model = AutoModel.from_pretrained("kugelaudio/kugelaudio-0-open")
"""

# Package version; keep in sync with the distribution metadata.
__version__ = "0.1.0"

from .configs import (
    KugelAudioAcousticTokenizerConfig,
    KugelAudioConfig,
    KugelAudioDiffusionHeadConfig,
    KugelAudioSemanticTokenizerConfig,
)
from .models import (
    KugelAudioAcousticTokenizerModel,
    KugelAudioDiffusionHead,
    KugelAudioForConditionalGeneration,
    KugelAudioForConditionalGenerationInference,
    KugelAudioModel,
    KugelAudioPreTrainedModel,
    KugelAudioSemanticTokenizerModel,
)
from .processors import KugelAudioProcessor
from .schedule import DPMSolverMultistepScheduler
from .watermark import AudioWatermark
34
+
35
+
36
# Lazy imports for optional components
def launch_ui(*args, **kwargs):
    """Launch the Gradio web interface.

    All positional and keyword arguments are forwarded to
    ``kugelaudio_open.ui.launch_ui``.

    Raises:
        ImportError: If the optional UI dependencies (gradio) are not
            installed.
    """
    try:
        from .ui import launch_ui as _launch_ui
    except ImportError as err:
        # Chain the original error so the real cause (e.g. a missing
        # transitive dependency) stays visible in the traceback.
        raise ImportError(
            "Gradio is required for the web interface. " "Install it with: pip install gradio"
        ) from err
    # Call OUTSIDE the try-block: an ImportError raised while the UI is
    # running must not be misreported as "gradio is missing".
    return _launch_ui(*args, **kwargs)
47
+
48
+
49
# Names exported for ``from kugelaudio_open import *``.
__all__ = [
    # Version
    "__version__",
    # Configs
    "KugelAudioConfig",
    "KugelAudioAcousticTokenizerConfig",
    "KugelAudioSemanticTokenizerConfig",
    "KugelAudioDiffusionHeadConfig",
    # Models
    "KugelAudioModel",
    "KugelAudioPreTrainedModel",
    "KugelAudioForConditionalGeneration",
    "KugelAudioForConditionalGenerationInference",
    "KugelAudioAcousticTokenizerModel",
    "KugelAudioSemanticTokenizerModel",
    "KugelAudioDiffusionHead",
    # Scheduler
    "DPMSolverMultistepScheduler",
    # Processors
    "KugelAudioProcessor",
    # Watermark
    "AudioWatermark",
    # UI
    "launch_ui",
]
kugelaudio_open/cli.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Command-line interface for KugelAudio."""
3
+
4
+ import argparse
5
+ import sys
6
+
7
+
8
def _build_parser():
    """Create the top-level argument parser with the ui/generate/verify subcommands."""
    parser = argparse.ArgumentParser(
        description="KugelAudio - Open-source text-to-speech",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Launch web interface
  kugelaudio ui

  # Launch with public share link
  kugelaudio ui --share

  # Generate speech from command line
  kugelaudio generate "Hello world!" -o output.wav

  # Check watermark in audio file
  kugelaudio verify audio.wav
""",
    )

    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # UI command
    ui_parser = subparsers.add_parser("ui", help="Launch Gradio web interface")
    ui_parser.add_argument("--share", action="store_true", help="Create public share link")
    ui_parser.add_argument("--host", default="127.0.0.1", help="Server hostname")
    ui_parser.add_argument("--port", type=int, default=7860, help="Server port")

    # Generate command
    gen_parser = subparsers.add_parser("generate", help="Generate speech from text")
    gen_parser.add_argument("text", help="Text to synthesize")
    gen_parser.add_argument("-o", "--output", default="output.wav", help="Output file path")
    gen_parser.add_argument("-r", "--reference", help="Reference audio for voice cloning")
    gen_parser.add_argument("--model", default="kugelaudio/kugelaudio-0-open", help="Model ID")
    gen_parser.add_argument("--cfg-scale", type=float, default=3.0, help="Guidance scale")

    # Verify command
    verify_parser = subparsers.add_parser("verify", help="Check watermark in audio")
    verify_parser.add_argument("audio", help="Audio file to check")

    return parser


def _run_ui(args):
    """Launch the Gradio web interface with the parsed server options."""
    from kugelaudio_open.ui import launch_app

    launch_app(
        share=args.share,
        server_name=args.host,
        server_port=args.port,
    )


def _run_generate(args):
    """Synthesize speech for ``args.text`` and save it to ``args.output``."""
    import torch

    from kugelaudio_open.models import KugelAudioForConditionalGenerationInference
    from kugelaudio_open.processors import KugelAudioProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # bfloat16 keeps GPU memory low; CPU inference stays in float32.
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    print(f"Loading model {args.model}...")
    model = KugelAudioForConditionalGenerationInference.from_pretrained(
        args.model, torch_dtype=dtype
    ).to(device)
    model.eval()

    processor = KugelAudioProcessor.from_pretrained(args.model)

    # Process inputs (voice_prompt passed to processor for proper handling)
    inputs = processor(
        text=args.text,
        voice_prompt=args.reference,  # Pass reference audio path directly
        return_tensors="pt",
    )
    inputs = {
        k: v.to(device) if isinstance(v, torch.Tensor) else v
        for k, v in inputs.items()
    }

    print("Generating speech...")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            cfg_scale=args.cfg_scale,
            max_new_tokens=4096,
        )

    # Audio is already watermarked by the model's generate method
    audio = outputs.speech_outputs[0]

    processor.save_audio(audio, args.output)
    print(f"Audio saved to {args.output}")


def _run_verify(args):
    """Check ``args.audio`` for the KugelAudio watermark and report the result."""
    # NOTE: the original also imported numpy here, but never used it.
    import soundfile as sf

    from kugelaudio_open.watermark import AudioWatermark

    audio, sr = sf.read(args.audio)

    watermark = AudioWatermark()
    result = watermark.detect(audio, sample_rate=sr)

    if result.detected:
        print(f"✅ Watermark DETECTED (confidence: {result.confidence:.1%})")
        print("This audio was generated by KugelAudio.")
    else:
        print(f"❌ No watermark detected (confidence: {result.confidence:.1%})")
        print("This audio does not appear to be generated by KugelAudio.")


def main():
    """CLI entry point: parse arguments and dispatch to the chosen subcommand.

    Prints help and exits with status 1 when no subcommand is given.
    Heavy dependencies (torch, soundfile, gradio) are imported lazily per
    subcommand so that ``--help`` stays fast.
    """
    parser = _build_parser()
    args = parser.parse_args()

    if args.command == "ui":
        _run_ui(args)
    elif args.command == "generate":
        _run_generate(args)
    elif args.command == "verify":
        _run_verify(args)
    else:
        parser.print_help()
        sys.exit(1)


if __name__ == "__main__":
    main()
kugelaudio_open/configs/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """KugelAudio configuration classes."""
2
+
3
+ from .model_config import (
4
+ KugelAudioConfig,
5
+ KugelAudioAcousticTokenizerConfig,
6
+ KugelAudioSemanticTokenizerConfig,
7
+ KugelAudioDiffusionHeadConfig,
8
+ # Aliases
9
+ AcousticTokenizerConfig,
10
+ SemanticTokenizerConfig,
11
+ DiffusionHeadConfig,
12
+ )
13
+
14
# Public API of the configs subpackage.
__all__ = [
    "KugelAudioConfig",
    "KugelAudioAcousticTokenizerConfig",
    "KugelAudioSemanticTokenizerConfig",
    "KugelAudioDiffusionHeadConfig",
    # Aliases kept for backwards compatibility
    "AcousticTokenizerConfig",
    "SemanticTokenizerConfig",
    "DiffusionHeadConfig",
]
kugelaudio_open/configs/kugelaudio_1.5b.json ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "kugelaudio",
3
+ "_attn_implementation_autoset": true,
4
+ "acoustic_vae_dim": 64,
5
+ "tts_backbone_num_hidden_layers": 20,
6
+ "acoustic_tokenizer_config": {
7
+ "model_type": "kugelaudio_acoustic_tokenizer",
8
+ "causal": true,
9
+ "channels": 1,
10
+ "conv_bias": true,
11
+ "conv_norm": "none",
12
+ "decoder_depths": null,
13
+ "decoder_n_filters": 32,
14
+ "decoder_ratios": [8, 5, 5, 4, 2, 2],
15
+ "disable_last_norm": true,
16
+ "encoder_depths": "3-3-3-3-3-3-8",
17
+ "encoder_n_filters": 32,
18
+ "encoder_ratios": [8, 5, 5, 4, 2, 2],
19
+ "fix_std": 0.5,
20
+ "layer_scale_init_value": 1e-06,
21
+ "layernorm": "RMSNorm",
22
+ "layernorm_elementwise_affine": true,
23
+ "layernorm_eps": 1e-05,
24
+ "mixer_layer": "depthwise_conv",
25
+ "pad_mode": "constant",
26
+ "std_dist_type": "gaussian",
27
+ "vae_dim": 64,
28
+ "weight_init_value": 0.01
29
+ },
30
+ "decoder_config": {
31
+ "model_type": "qwen2",
32
+ "attention_dropout": 0.0,
33
+ "hidden_act": "silu",
34
+ "hidden_size": 1536,
35
+ "initializer_range": 0.02,
36
+ "intermediate_size": 8960,
37
+ "max_position_embeddings": 65536,
38
+ "max_window_layers": 28,
39
+ "num_attention_heads": 12,
40
+ "num_hidden_layers": 28,
41
+ "num_key_value_heads": 2,
42
+ "rms_norm_eps": 1e-06,
43
+ "rope_scaling": null,
44
+ "rope_theta": 1000000.0,
45
+ "sliding_window": null,
46
+ "tie_word_embeddings": true,
47
+ "torch_dtype": "bfloat16",
48
+ "use_cache": true,
49
+ "use_sliding_window": false,
50
+ "vocab_size": 151936
51
+ },
52
+ "diffusion_head_config": {
53
+ "model_type": "kugelaudio_diffusion_head",
54
+ "ddpm_batch_mul": 4,
55
+ "ddpm_beta_schedule": "cosine",
56
+ "ddpm_num_inference_steps": 20,
57
+ "ddpm_num_steps": 1000,
58
+ "diffusion_type": "ddpm",
59
+ "head_ffn_ratio": 3.0,
60
+ "head_layers": 4,
61
+ "hidden_size": 1536,
62
+ "latent_size": 64,
63
+ "prediction_type": "v_prediction",
64
+ "rms_norm_eps": 1e-05,
65
+ "speech_vae_dim": 64
66
+ },
67
+ "torch_dtype": "bfloat16"
68
+ }
kugelaudio_open/configs/kugelaudio_7b.json ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "kugelaudio",
3
+ "_attn_implementation_autoset": true,
4
+ "acoustic_vae_dim": 64,
5
+ "tts_backbone_num_hidden_layers": 20,
6
+ "acoustic_tokenizer_config": {
7
+ "model_type": "kugelaudio_acoustic_tokenizer",
8
+ "causal": true,
9
+ "channels": 1,
10
+ "conv_bias": true,
11
+ "conv_norm": "none",
12
+ "decoder_depths": null,
13
+ "decoder_n_filters": 32,
14
+ "decoder_ratios": [8, 5, 5, 4, 2, 2],
15
+ "disable_last_norm": true,
16
+ "encoder_depths": "3-3-3-3-3-3-8",
17
+ "encoder_n_filters": 32,
18
+ "encoder_ratios": [8, 5, 5, 4, 2, 2],
19
+ "fix_std": 0.5,
20
+ "layer_scale_init_value": 1e-06,
21
+ "layernorm": "RMSNorm",
22
+ "layernorm_elementwise_affine": true,
23
+ "layernorm_eps": 1e-05,
24
+ "mixer_layer": "depthwise_conv",
25
+ "pad_mode": "constant",
26
+ "std_dist_type": "gaussian",
27
+ "vae_dim": 64,
28
+ "weight_init_value": 0.01
29
+ },
30
+ "decoder_config": {
31
+ "model_type": "qwen2",
32
+ "attention_dropout": 0.0,
33
+ "hidden_act": "silu",
34
+ "hidden_size": 3584,
35
+ "initializer_range": 0.02,
36
+ "intermediate_size": 18944,
37
+ "max_position_embeddings": 32768,
38
+ "max_window_layers": 28,
39
+ "num_attention_heads": 28,
40
+ "num_hidden_layers": 28,
41
+ "num_key_value_heads": 4,
42
+ "rms_norm_eps": 1e-06,
43
+ "rope_scaling": null,
44
+ "rope_theta": 1000000.0,
45
+ "sliding_window": null,
46
+ "tie_word_embeddings": false,
47
+ "torch_dtype": "bfloat16",
48
+ "use_cache": true,
49
+ "use_sliding_window": false,
50
+ "vocab_size": 152064
51
+ },
52
+ "diffusion_head_config": {
53
+ "model_type": "kugelaudio_diffusion_head",
54
+ "ddpm_batch_mul": 4,
55
+ "ddpm_beta_schedule": "cosine",
56
+ "ddpm_num_inference_steps": 20,
57
+ "ddpm_num_steps": 1000,
58
+ "diffusion_type": "ddpm",
59
+ "head_ffn_ratio": 3.0,
60
+ "head_layers": 4,
61
+ "hidden_size": 3584,
62
+ "latent_size": 64,
63
+ "prediction_type": "v_prediction",
64
+ "rms_norm_eps": 1e-05,
65
+ "speech_vae_dim": 64
66
+ },
67
+ "torch_dtype": "bfloat16"
68
+ }
kugelaudio_open/configs/model_config.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Configuration classes for KugelAudio models."""

from typing import Optional, List, Union
from transformers.configuration_utils import PretrainedConfig
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.utils import logging

# Module-level logger following the transformers convention.
logger = logging.get_logger(__name__)
9
+
10
+
11
class KugelAudioAcousticTokenizerConfig(PretrainedConfig):
    """Configuration for the acoustic tokenizer.

    The acoustic tokenizer converts continuous speech latents back to audio waveforms.
    It uses a hierarchical convolutional architecture with multiple upsampling stages.
    """

    model_type = "kugelaudio_acoustic_tokenizer"

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0.5,
        std_dist_type: str = "gaussian",
        # Common settings
        mixer_layer: str = "depthwise_conv",
        conv_norm: str = "none",
        pad_mode: str = "constant",
        disable_last_norm: bool = True,
        layernorm: str = "RMSNorm",
        layernorm_eps: float = 1e-5,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-6,
        weight_init_value: float = 1e-2,
        # Encoder specific
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[List[int]] = None,
        encoder_depths: str = "3-3-3-3-3-3-8",
        # Decoder specific
        decoder_n_filters: int = 32,
        decoder_ratios: Optional[List[int]] = None,
        decoder_depths: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Resolve defaults up front: the decoder mirrors the encoder's
        # up/downsampling ratios unless overridden.
        if encoder_ratios is None:
            encoder_ratios = [8, 5, 5, 4, 2, 2]
        if decoder_ratios is None:
            decoder_ratios = encoder_ratios

        # Waveform / latent-distribution parameters.
        self.channels = channels
        self.corpus_normalize = corpus_normalize
        self.causal = causal
        self.vae_dim = vae_dim
        self.fix_std = fix_std
        self.std_dist_type = std_dist_type

        # Shared convolution / normalization settings.
        self.mixer_layer = mixer_layer
        self.conv_norm = conv_norm
        self.pad_mode = pad_mode
        self.disable_last_norm = disable_last_norm
        self.layernorm = layernorm
        self.layernorm_eps = layernorm_eps
        self.layernorm_elementwise_affine = layernorm_elementwise_affine
        self.conv_bias = conv_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.weight_init_value = weight_init_value

        # Encoder topology.
        self.encoder_n_filters = encoder_n_filters
        self.encoder_ratios = encoder_ratios
        self.encoder_depths = encoder_depths

        # Decoder topology.
        self.decoder_n_filters = decoder_n_filters
        self.decoder_ratios = decoder_ratios
        self.decoder_depths = decoder_depths
78
+
79
+
80
class KugelAudioSemanticTokenizerConfig(PretrainedConfig):
    """Configuration for the semantic tokenizer.

    The semantic tokenizer extracts semantic features from audio for conditioning.
    """

    model_type = "kugelaudio_semantic_tokenizer"

    def __init__(
        self,
        channels: int = 1,
        corpus_normalize: float = 0.0,
        causal: bool = True,
        vae_dim: int = 64,
        fix_std: float = 0,
        std_dist_type: str = "none",
        # Common settings
        mixer_layer: str = "depthwise_conv",
        conv_norm: str = "none",
        pad_mode: str = "constant",
        disable_last_norm: bool = True,
        layernorm: str = "RMSNorm",
        layernorm_eps: float = 1e-5,
        layernorm_elementwise_affine: bool = True,
        conv_bias: bool = True,
        layer_scale_init_value: float = 1e-6,
        weight_init_value: float = 1e-2,
        # Encoder specific
        encoder_n_filters: int = 32,
        encoder_ratios: Optional[List[int]] = None,
        encoder_depths: str = "3-3-3-3-3-3-8",
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Default downsampling ratios mirror the acoustic tokenizer.
        if encoder_ratios is None:
            encoder_ratios = [8, 5, 5, 4, 2, 2]

        # Waveform / latent-distribution parameters.
        self.channels = channels
        self.corpus_normalize = corpus_normalize
        self.causal = causal
        self.vae_dim = vae_dim
        self.fix_std = fix_std
        self.std_dist_type = std_dist_type

        # Shared convolution / normalization settings.
        self.mixer_layer = mixer_layer
        self.conv_norm = conv_norm
        self.pad_mode = pad_mode
        self.disable_last_norm = disable_last_norm
        self.layernorm = layernorm
        self.layernorm_eps = layernorm_eps
        self.layernorm_elementwise_affine = layernorm_elementwise_affine
        self.conv_bias = conv_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.weight_init_value = weight_init_value

        # Encoder topology (this tokenizer is encoder-only).
        self.encoder_n_filters = encoder_n_filters
        self.encoder_ratios = encoder_ratios
        self.encoder_depths = encoder_depths
137
+
138
+
139
class KugelAudioDiffusionHeadConfig(PretrainedConfig):
    """Configuration for the diffusion prediction head.

    The diffusion head predicts speech latents from text-conditioned hidden states
    using a denoising diffusion process.

    Args:
        hidden_size: Width of the conditioning hidden states.
        head_layers: Number of layers in the prediction head.
        head_ffn_ratio: FFN expansion ratio inside the head.
        rms_norm_eps: Epsilon for the head's RMS normalization.
        latent_size: Dimensionality of the predicted latents.
        speech_vae_dim: Speech VAE latent dimensionality (``None`` to defer).
        prediction_type: Diffusion parameterization (e.g. ``"v_prediction"``).
        diffusion_type: Diffusion formulation (e.g. ``"ddpm"``).
        ddpm_num_steps: Number of diffusion steps in the full schedule.
        ddpm_num_inference_steps: Number of denoising steps at inference.
        ddpm_beta_schedule: Name of the noise (beta) schedule.
        ddpm_algorithm_type: Solver algorithm identifier passed to the scheduler.
        ddpm_batch_mul: Diffusion batch multiplier (consumed by the head
            implementation).
    """

    model_type = "kugelaudio_diffusion_head"

    def __init__(
        self,
        hidden_size: int = 768,
        head_layers: int = 4,
        head_ffn_ratio: float = 3.0,
        rms_norm_eps: float = 1e-5,
        latent_size: int = 64,
        speech_vae_dim: Optional[int] = None,
        prediction_type: str = "v_prediction",
        diffusion_type: str = "ddpm",
        ddpm_num_steps: int = 1000,
        ddpm_num_inference_steps: int = 20,
        ddpm_beta_schedule: str = "cosine",
        ddpm_algorithm_type: str = "sde-dpmsolver++",
        ddpm_batch_mul: int = 4,
        **kwargs,
    ):
        # NOTE(review): unlike the other configs in this module, attributes are
        # assigned *before* super().__init__() — presumably deliberate; confirm
        # before reordering.
        self.hidden_size = hidden_size
        self.head_layers = head_layers
        self.head_ffn_ratio = head_ffn_ratio
        self.rms_norm_eps = rms_norm_eps
        self.latent_size = latent_size
        self.speech_vae_dim = speech_vae_dim
        self.prediction_type = prediction_type
        self.diffusion_type = diffusion_type
        self.ddpm_num_steps = ddpm_num_steps
        self.ddpm_num_inference_steps = ddpm_num_inference_steps
        self.ddpm_beta_schedule = ddpm_beta_schedule
        self.ddpm_algorithm_type = ddpm_algorithm_type
        self.ddpm_batch_mul = ddpm_batch_mul

        super().__init__(**kwargs)
180
+
181
+
182
class KugelAudioConfig(PretrainedConfig):
    """Main configuration for KugelAudio TTS model.

    This configuration combines:
    - A language model backbone (Qwen2) for text understanding
    - An acoustic tokenizer for audio encoding/decoding
    - A semantic tokenizer for semantic feature extraction
    - A diffusion head for speech latent prediction

    Example:
        >>> from kugelaudio_open import KugelAudioConfig
        >>> config = KugelAudioConfig.from_pretrained("kugelaudio/kugelaudio-0-open")
    """

    model_type = "kugelaudio"
    is_composition = True

    sub_configs = {
        "acoustic_tokenizer_config": KugelAudioAcousticTokenizerConfig,
        "semantic_tokenizer_config": KugelAudioSemanticTokenizerConfig,
        "decoder_config": Qwen2Config,
        "diffusion_head_config": KugelAudioDiffusionHeadConfig,
    }

    # Tensor parallel plan for distributed inference
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def _normalize_sub_config(self, key, value, forced_model_type):
        """Coerce ``value`` (None | dict | config instance) into the config
        class registered under ``key`` in :attr:`sub_configs`.

        Raises:
            TypeError: If ``value`` has an unsupported type. (Previously an
                unsupported type fell through all branches silently, leaving
                the attribute unset and failing later with AttributeError.)
        """
        config_cls = self.sub_configs[key]
        if value is None:
            return config_cls()
        if isinstance(value, dict):
            value["model_type"] = forced_model_type
            return config_cls(**value)
        if isinstance(value, config_cls):
            return value
        raise TypeError(
            f"{key} must be None, a dict, or a {config_cls.__name__}, "
            f"got {type(value).__name__}"
        )

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        semantic_tokenizer_config=None,
        decoder_config=None,
        diffusion_head_config=None,
        **kwargs,
    ):
        # Disable auto attention implementation selection
        kwargs["_attn_implementation_autoset"] = False

        self.acoustic_tokenizer_config = self._normalize_sub_config(
            "acoustic_tokenizer_config",
            acoustic_tokenizer_config,
            "kugelaudio_acoustic_tokenizer",
        )
        self.semantic_tokenizer_config = self._normalize_sub_config(
            "semantic_tokenizer_config",
            semantic_tokenizer_config,
            "kugelaudio_semantic_tokenizer",
        )

        # The decoder (language model) accepts only Qwen2-style configs.
        if decoder_config is None:
            self.decoder_config = self.sub_configs["decoder_config"]()
        elif isinstance(decoder_config, dict):
            if decoder_config.get("model_type", "") == "qwen2":
                self.decoder_config = Qwen2Config(**decoder_config)
            else:
                raise ValueError(
                    f"Unsupported decoder model type: {decoder_config.get('model_type', '')}"
                )
        elif isinstance(decoder_config, Qwen2Config):
            self.decoder_config = decoder_config
        else:
            raise TypeError(
                f"decoder_config must be None, a dict, or a Qwen2Config, "
                f"got {type(decoder_config).__name__}"
            )

        self.diffusion_head_config = self._normalize_sub_config(
            "diffusion_head_config",
            diffusion_head_config,
            "kugelaudio_diffusion_head",
        )

        # Derived parameters
        self.acoustic_vae_dim = self.acoustic_tokenizer_config.vae_dim
        self.semantic_vae_dim = self.semantic_tokenizer_config.vae_dim

        super().__init__(**kwargs)
273
+
274
+
275
# Aliases for backwards compatibility: older callers imported the short names.
AcousticTokenizerConfig = KugelAudioAcousticTokenizerConfig
SemanticTokenizerConfig = KugelAudioSemanticTokenizerConfig
DiffusionHeadConfig = KugelAudioDiffusionHeadConfig


# Public API of this module; the aliases are re-exported deliberately.
__all__ = [
    "KugelAudioAcousticTokenizerConfig",
    "KugelAudioSemanticTokenizerConfig",
    "KugelAudioDiffusionHeadConfig",
    "KugelAudioConfig",
    # Aliases
    "AcousticTokenizerConfig",
    "SemanticTokenizerConfig",
    "DiffusionHeadConfig",
]
kugelaudio_open/models/__init__.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """KugelAudio model components."""
2
+
3
+ from .kugelaudio_model import (
4
+ KugelAudioModel,
5
+ KugelAudioPreTrainedModel,
6
+ KugelAudioForConditionalGeneration,
7
+ )
8
+ from .kugelaudio_inference import (
9
+ KugelAudioForConditionalGenerationInference,
10
+ KugelAudioCausalLMOutputWithPast,
11
+ KugelAudioGenerationOutput,
12
+ )
13
+ from .tokenizer import (
14
+ KugelAudioAcousticTokenizerModel,
15
+ KugelAudioSemanticTokenizerModel,
16
+ KugelAudioTokenizerEncoderOutput,
17
+ )
18
+ from .diffusion_head import KugelAudioDiffusionHead
19
+ from .conv_layers import (
20
+ RMSNorm,
21
+ ConvRMSNorm,
22
+ ConvLayerNorm,
23
+ SConv1d,
24
+ SConvTranspose1d,
25
+ )
26
+
27
# Public names re-exported by ``kugelaudio_open.models``.
__all__ = [
    # Main models
    "KugelAudioModel",
    "KugelAudioPreTrainedModel",
    "KugelAudioForConditionalGeneration",
    "KugelAudioForConditionalGenerationInference",
    # Outputs
    "KugelAudioCausalLMOutputWithPast",
    "KugelAudioGenerationOutput",
    # Tokenizers
    "KugelAudioAcousticTokenizerModel",
    "KugelAudioSemanticTokenizerModel",
    "KugelAudioTokenizerEncoderOutput",
    # Components
    "KugelAudioDiffusionHead",
    "RMSNorm",
    "ConvRMSNorm",
    "ConvLayerNorm",
    "SConv1d",
    "SConvTranspose1d",
]
kugelaudio_open/models/conv_layers.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Convolutional layers for KugelAudio tokenizers.

This module provides the building blocks for the acoustic and semantic tokenizers,
including streaming-capable convolutions and normalization layers.
"""

import math
import typing as tp
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.utils import logging

# Module-level logger following the transformers convention.
logger = logging.get_logger(__name__)
18
+
19
+
20
+ # Normalization modules
21
class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to last dimensions
    before running the normalization and moves them back to original position right after.
    """
    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        x = x.transpose(1, 2)  # b ... t -> b t ...
        # Fix: weight/bias are None when elementwise_affine=False; the
        # unconditional .float() calls crashed with AttributeError.
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        # Normalize in float32 for numerical stability, cast back to input dtype.
        x = nn.functional.layer_norm(
            x.float(), self.normalized_shape, weight, bias, self.eps
        ).type_as(x)
        x = x.transpose(1, 2)  # b t ... -> b ... t
        return x
36
+
37
+
38
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            shape = weight_shape if weight_shape is not None else (dim,)
            self.weight = nn.Parameter(torch.ones(shape))
        else:
            self.register_parameter('weight', None)

    def _norm(self, x):
        # Divide each vector by its root-mean-square (computed over the last dim).
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32, then cast back to the input dtype.
        normed = self._norm(x.float()).type_as(x)
        return normed if self.weight is None else normed * self.weight

    def extra_repr(self) -> str:
        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'


class ConvRMSNorm(RMSNorm):
    """Convolution-friendly RMSNorm."""

    def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
        super().__init__(dim, eps, elementwise_affine, weight_shape)

    def forward(self, x):
        # Channels-first input: round-trip through (B, T, C) so the RMS is
        # taken over the channel axis.
        channels_last = x.transpose(1, 2)
        normed = self._norm(channels_last.float()).type_as(channels_last)
        if self.weight is not None:
            normed = normed * self.weight
        return normed.transpose(1, 2)
78
+
79
+
80
# Convolutional layers and utilities
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
                                 'time_layer_norm', 'layer_norm', 'time_group_norm'])


def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    """Optionally wrap ``module`` in a weight parametrization.

    ``'weight_norm'`` / ``'spectral_norm'`` reparametrize the module's
    weights; all other supported values return the module unchanged.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        module = nn.utils.weight_norm(module)
    elif norm == 'spectral_norm':
        module = nn.utils.spectral_norm(module)
    return module


def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Return the proper normalization module."""
    assert norm in CONV_NORMALIZATIONS
    if norm == 'layer_norm':
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    if norm == 'time_group_norm':
        # Group statistics span the whole sequence, which would leak future
        # frames in a causal/streaming setting.
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    return nn.Identity()
108
+
109
+
110
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Calculate extra padding needed for convolution to have the same output length."""
    length = x.shape[-1]
    # Number of (possibly fractional) frames the conv would produce.
    n_frames = (length - kernel_size + padding_total) / stride + 1
    # Smallest length that yields exactly ceil(n_frames) whole frames.
    target_length = (math.ceil(n_frames) - 1) * stride + kernel_size - padding_total
    return target_length - length
117
+
118
+
119
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Pad 1D input with handling for small inputs in reflect mode.

    Args:
        x: Tensor of shape (..., T).
        paddings: (padding_left, padding_right); both must be non-negative.
        mode: Padding mode. ``'zero'`` is accepted as an alias for
            ``'constant'`` — ``F.pad`` has no ``'zero'`` mode, so the
            original default crashed when used.
        value: Fill value for constant padding.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'zero':
        # Fix: translate the documented-but-invalid 'zero' mode to the
        # F.pad mode it was meant to be.
        mode = 'constant'
    if mode == 'reflect':
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        # Reflect padding requires pad < input length; right-pad short
        # inputs first, then trim the surplus afterwards.
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)
135
+
136
+
137
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    assert left + right <= x.shape[-1]
    # Slice away `left` samples from the front and `right` from the back.
    return x[..., left: x.shape[-1] - right]
144
+
145
+
146
class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv.

    Args:
        causal: Forwarded to the normalization-module factory.
        norm: Normalization flavor; see ``CONV_NORMALIZATIONS``.
        norm_kwargs: Extra keyword arguments for the normalization module.
            Default changed from a shared mutable ``{}`` to ``None`` to avoid
            the mutable-default-argument pitfall (backward compatible).
    """

    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None, **kwargs):
        super().__init__()
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x
160
+
161
+
162
class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv.

    Args:
        causal: Forwarded to the normalization-module factory.
        norm: Normalization flavor; see ``CONV_NORMALIZATIONS``.
        norm_kwargs: Extra keyword arguments for the normalization module.
            Default changed from a shared mutable ``{}`` to ``None`` to avoid
            the mutable-default-argument pitfall (backward compatible).
    """

    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None, **kwargs):
        super().__init__()
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x
176
+
177
+
178
class SConv1d(nn.Module):
    """Conv1d with built-in handling of asymmetric or causal padding and normalization.

    Causal mode pads entirely on the left; non-causal mode pads symmetrically.
    Extra right padding is added so that every input frame contributes to a
    full stride window (see `get_extra_padding_for_conv1d`).
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
                 pad_mode: str = 'reflect'):
        super().__init__()
        # Fix: use None instead of a shared mutable default dict.
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

        # Store configuration
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

        # For non-streaming mode, calculate padding needed so the effective
        # receptive field lines up with the stride.
        self.padding_total = (kernel_size - 1) * dilation - (stride - 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass (non-streaming).

        Args:
            x: Input of shape [B, C, T].
        """
        B, C, T = x.shape  # also asserts a 3-D input
        kernel_size = self.kernel_size
        stride = self.stride
        dilation = self.dilation
        padding_total = self.padding_total

        # Compute extra padding for stride alignment
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)

        if self.causal:
            # Left padding for causal; extra alignment padding goes on the right.
            if self.pad_mode == 'constant':
                x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode, value=0)
            else:
                x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Symmetric padding for non-causal (left gets the odd remainder).
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)

        output = self.conv(x)
        return output
228
+
229
+
230
class SConvTranspose1d(nn.Module):
    """ConvTranspose1d with built-in handling of asymmetric or causal padding and normalization.

    Args:
        trim_right_ratio: Fraction of the removed padding taken from the right
            side. Only meaningful for causal convolutions (must be 1.0 otherwise).
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None, bias: bool = True):
        super().__init__()
        # Fix: use None instead of a shared mutable default dict.
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs, bias=bias)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.

        # Store configuration
        self.kernel_size = kernel_size
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

        # For transposed convolution, padding calculation is different
        self.padding_total = kernel_size - stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass (non-streaming): transposed conv, then trim padding."""
        padding_total = self.padding_total

        y = self.convtr(x)

        # Remove the padding from output
        if self.causal:
            # Trim mostly (or entirely) on the right side for causal models,
            # controlled by trim_right_ratio.
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Symmetric unpadding for non-causal (left gets the odd remainder).
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))

        return y
276
+
277
+
278
# Public API of this module.
__all__ = [
    "ConvLayerNorm",
    "RMSNorm",
    "ConvRMSNorm",
    "NormConv1d",
    "NormConvTranspose1d",
    "SConv1d",
    "SConvTranspose1d",
    "pad1d",
    "unpad1d",
    "get_extra_padding_for_conv1d",
]
kugelaudio_open/models/diffusion_head.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from transformers.models.auto import AutoModel
9
+ from transformers.modeling_utils import PreTrainedModel
10
+ # from transformers.modeling_layers import GradientCheckpointingLayer
11
+ from transformers.activations import ACT2FN
12
+ from transformers.utils import logging
13
+
14
+ from ..configs import KugelAudioDiffusionHeadConfig
15
+
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+
20
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction)."""

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        # Learnable per-channel gain when affine; otherwise expose `weight` as None.
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def _norm(self, x):
        # Multiply by the reciprocal RMS over the last dimension.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalize in float32 for numerical stability, then restore input dtype.
        normed = self._norm(x.float()).type_as(x)
        return normed if self.weight is None else normed * self.weight

    def extra_repr(self) -> str:
        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
42
+
43
def modulate(x, shift, scale):
    """Apply adaLN-style modulation: scale x around identity, then shift."""
    return shift + x * (scale + 1)
46
+
47
+
48
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    Args:
        hidden_size (`int`): Size of the output embedding
        frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        # Two-layer MLP (SiLU in between) that maps sinusoidal frequency
        # features to the model's hidden size. No biases, matching the rest
        # of the diffusion head.
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=False),
            ACT2FN['silu'],
            nn.Linear(hidden_size, hidden_size, bias=False),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element.
                These may be fractional.
            dim (`int`): The dimension of the output.
            max_period (`int`, optional): Controls the minimum frequency of the embeddings.

        Returns:
            `torch.Tensor`: An [N, D] Tensor of positional embeddings.
        """
        half = dim // 2
        # Create freqs directly on the target device to avoid transfers during CUDA graph capture
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
        )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        # For odd dims, pad with one zero column so the output is exactly `dim` wide.
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        # NOTE(review): casting to t.dtype assumes callers pass floating-point
        # timesteps; integer timesteps would truncate the sinusoids — confirm
        # at call sites (the CFG sampler casts t to the latent dtype first).
        return embedding.to(t.dtype)

    def forward(self, t):
        # Sinusoidal featurization followed by the learned projection.
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
95
+
96
+
97
class FeedForwardNetwork(nn.Module):
    """
    Feed-forward block with SwiGLU activation: down(SiLU(gate(x)) * up(x)).

    Args:
        embed_dim (`int`): Input dimension
        ffn_dim (`int`): Hidden dimension
    """
    def __init__(
        self,
        embed_dim,
        ffn_dim,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False)
        self.act_fn = ACT2FN['silu']  # SiLU nonlinearity for the gated branch

    def forward(self, x):
        # SwiGLU: only the gate branch is passed through the activation,
        # the "up" branch stays linear before the elementwise product.
        activated_gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(activated_gate * self.up_proj(x))
125
+
126
+
127
class HeadLayer(nn.Module):
    """
    Single adaLN-modulated residual FFN block of the diffusion head.

    Args:
        embed_dim (`int`): Input dimension
        ffn_dim (`int`): Hidden dimension
        cond_dim (`int`): Condition embedding dimension
        norm_eps (`float`, optional): Epsilon for normalization
    """
    def __init__(
        self,
        embed_dim,
        ffn_dim,
        cond_dim,
        norm_eps=1e-5,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.cond_dim = cond_dim
        self.ffn_dim = ffn_dim
        self.ffn = FeedForwardNetwork(
            self.embed_dim,
            self.ffn_dim,
        )
        self.norm = RMSNorm(self.embed_dim, eps=norm_eps)
        # Conditioning network emitting shift / scale / gate for the FFN branch.
        self.adaLN_modulation = nn.Sequential(
            ACT2FN['silu'],
            nn.Linear(cond_dim, 3 * self.embed_dim, bias=False)
        )

    def forward(self, x, c):
        # Split the condition projection into the three modulation tensors.
        shift, scale, gate = self.adaLN_modulation(c).chunk(3, dim=-1)
        # Pre-norm residual FFN, modulated and gated by the condition.
        return x + gate * self.ffn(modulate(self.norm(x), shift, scale))
163
+
164
+
165
class FinalLayer(nn.Module):
    """
    Final adaLN-conditioned projection of the diffusion head.

    Args:
        hidden_size (`int`): Input dimension
        output_size (`int`): Output dimension
        cond_size (`int`): Condition embedding dimension
        norm_eps (`float`, optional): Epsilon for normalization
    """
    def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
        super().__init__()
        # Non-affine norm: scale/shift come entirely from the condition below.
        self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, output_size, bias=False)
        self.adaLN_modulation = nn.Sequential(
            ACT2FN['silu'],
            nn.Linear(cond_size, 2 * hidden_size, bias=False)
        )

    def forward(self, x, c):
        # Condition-driven shift/scale, then project to the output size.
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        return self.linear(modulate(self.norm_final(x), shift, scale))
190
+
191
+
192
class KugelAudioDiffusionHead(PreTrainedModel):
    """
    Diffusion head model for kugelaudio.

    Predicts noise/velocity for acoustic latents, conditioned on LM hidden
    states plus a timestep embedding, via a stack of adaLN-modulated FFN
    blocks and a final projection back to the latent size.

    Args:
        config (`KugelAudioDiffusionHeadConfig`): Model configuration
    """
    config_class = KugelAudioDiffusionHeadConfig
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(
        self,
        config,
    ):
        super().__init__(config)
        self.config = config
        # Condition embedding dimension; reuses the LM hidden size.
        self.cond_dim = config.hidden_size
        latent_size = config.latent_size

        # Project noisy latents and the conditioning vector into model space.
        self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False)
        self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False)
        self.t_embedder = TimestepEmbedder(self.cond_dim)

        ffn_dim = int(config.hidden_size * config.head_ffn_ratio)

        # Create the intermediate layers
        self.layers = nn.ModuleList([
            HeadLayer(
                embed_dim=config.hidden_size,
                ffn_dim=ffn_dim,
                cond_dim=self.cond_dim,
                norm_eps=config.rms_norm_eps
            )
            for _ in range(config.head_layers)
        ])

        # Final layer for output (back to the latent dimensionality)
        self.final_layer = FinalLayer(
            hidden_size=config.hidden_size,
            output_size=latent_size,
            cond_size=self.cond_dim,
            norm_eps=config.rms_norm_eps
        )

        self.initialize_weights()

    def initialize_weights(self):
        """Initialize the weights of the model."""
        # Initialize timestep embedder MLP with small normal weights.
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers so every block starts as identity
        # (shift = scale = gate = 0 -> residual passthrough).
        for layer in self.layers:
            nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)

        # Zero-out output layers: the head initially predicts zeros.
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)

    def forward(
        self,
        noisy_images,
        timesteps,
        condition,
    ):
        """
        Forward pass of the prediction head.

        Args:
            noisy_images (`torch.Tensor`): Noisy images/latents to denoise
            timesteps (`torch.Tensor`): Timesteps for diffusion
            condition (`torch.Tensor`): Conditioning information

        Returns:
            `torch.Tensor`: The predicted noise/velocity
        """
        x = self.noisy_images_proj(noisy_images)
        t = self.t_embedder(timesteps)
        condition = self.cond_proj(condition)
        # Combined conditioning: projected LM state + timestep embedding.
        c = condition + t

        for layer in self.layers:
            x = layer(x, c)

        x = self.final_layer(x, c)
        return x
282
+
283
+
284
# Make the head loadable through AutoModel.from_pretrained with its config class.
AutoModel.register(KugelAudioDiffusionHeadConfig, KugelAudioDiffusionHead)

# Public API of this module.
__all__ = [
    "KugelAudioDiffusionHead",
]
kugelaudio_open/models/kugelaudio_inference.py ADDED
@@ -0,0 +1,800 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """KugelAudio inference model for speech generation.
2
+
3
+ This is the open-source inference implementation without optimizations.
4
+ Based on the original VibeVoice model architecture.
5
+ """
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from tqdm import tqdm
13
+ from transformers import modeling_utils
14
+ from transformers.cache_utils import DynamicCache
15
+ from transformers.generation import (
16
+ GenerationConfig,
17
+ GenerationMixin,
18
+ LogitsProcessor,
19
+ LogitsProcessorList,
20
+ StoppingCriteriaList,
21
+ )
22
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
23
+ from transformers.modeling_utils import PreTrainedModel
24
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM
25
+ from transformers.utils import logging
26
+
27
+ from ..configs import KugelAudioConfig
28
+ from ..schedule.dpm_solver import DPMSolverMultistepScheduler
29
+ from .diffusion_head import KugelAudioDiffusionHead
30
+ from .kugelaudio_model import KugelAudioModel, KugelAudioPreTrainedModel
31
+ from .tokenizer import (
32
+ KugelAudioTokenizerEncoderOutput,
33
+ KugelAudioTokenizerStreamingCache,
34
+ )
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
# Compatibility shim: some transformers versions lack (or None-out)
# ALL_PARALLEL_STYLES; provide a fallback so the `_tp_plan` declared below
# does not fail validation on import — presumably only needed on older
# releases, confirm against the pinned transformers version.
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
40
+
41
+
42
+ def _get_cache_tensors(cache) -> Tuple[List, List]:
43
+ """Get key and value cache tensors from a cache object."""
44
+ if hasattr(cache, "key_cache") and hasattr(cache, "value_cache"):
45
+ return cache.key_cache, cache.value_cache
46
+ raise AttributeError(f"Cannot get cache tensors from {type(cache).__name__}")
47
+
48
+
49
@dataclass
class KugelAudioCausalLMOutputWithPast(BaseModelOutputWithPast):
    """Model output carrying LM logits alongside the base hidden-state/cache fields."""

    # Logits over the decoder vocabulary (may be sliced via `logits_to_keep`).
    logits: Optional[torch.FloatTensor] = None
52
+
53
+
54
@dataclass
class KugelAudioGenerationOutput(ModelOutput):
    """Output type for KugelAudio generation."""

    # Generated token id sequences.
    sequences: torch.LongTensor = None
    # Per-sample generated speech tensors; populated by `generate` —
    # NOTE(review): exact element shape depends on the decoder, confirm there.
    speech_outputs: Optional[List[torch.FloatTensor]] = None
60
+
61
+
62
class KugelAudioTokenConstraintProcessor(LogitsProcessor):
    """Masks logits so only a fixed whitelist of token ids can be generated."""

    def __init__(self, valid_token_ids: List[int], device: torch.device = None):
        # Keep the whitelist on-device so __call__ needs no host transfer.
        self.valid_token_ids = torch.tensor(valid_token_ids, dtype=torch.long, device=device)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Additive mask: -inf everywhere except the whitelisted columns (0 there).
        penalty = torch.full_like(scores, float("-inf"))
        penalty[:, self.valid_token_ids] = 0
        return scores + penalty
73
+
74
+
75
+ class KugelAudioForConditionalGenerationInference(KugelAudioPreTrainedModel, GenerationMixin):
76
+ """KugelAudio model for inference with speech generation capabilities."""
77
+
78
+ _tied_weights_keys = ["lm_head.weight"]
79
+ _tp_plan = {"lm_head": "colwise_rep"}
80
+
81
    def __init__(self, config):
        super().__init__(config)
        # Backbone: LM decoder plus tokenizers/connectors/diffusion components.
        self.model = KugelAudioModel(config)
        # Text-token output head over the decoder vocabulary (weights tied,
        # see _tied_weights_keys on the class).
        self.lm_head = nn.Linear(
            config.decoder_config.hidden_size,
            config.decoder_config.vocab_size,
            bias=False,
        )
        # Denoising step count used at inference; overridable later via
        # `set_ddpm_inference_steps`.
        self.ddpm_inference_steps = config.diffusion_head_config.ddpm_num_inference_steps
        self.post_init()
91
+
92
    @property
    def noise_scheduler(self):
        """Delegate to the inner model's diffusion noise scheduler."""
        return self.model.noise_scheduler
95
+
96
    @property
    def prediction_head(self):
        """Delegate to the inner model's diffusion prediction head."""
        return self.model.prediction_head
99
+
100
    @property
    def speech_scaling_factor(self):
        """Delegate to the inner model's acoustic-latent scaling factor."""
        return self.model.speech_scaling_factor
103
+
104
    @property
    def speech_bias_factor(self):
        """Delegate to the inner model's acoustic-latent bias factor."""
        return self.model.speech_bias_factor
107
+
108
    @property
    def acoustic_tokenizer(self):
        """Delegate to the inner model's acoustic tokenizer."""
        return self.model.acoustic_tokenizer
111
+
112
    @property
    def semantic_tokenizer(self):
        """Delegate to the inner model's semantic tokenizer."""
        return self.model.semantic_tokenizer
115
+
116
    @property
    def acoustic_connector(self):
        """Delegate to the inner model's acoustic connector (latent -> LM space)."""
        return self.model.acoustic_connector
119
+
120
    @property
    def semantic_connector(self):
        """Delegate to the inner model's semantic connector (latent -> LM space)."""
        return self.model.semantic_connector
123
+
124
    def get_input_embeddings(self):
        """Return the token embedding module of the inner model."""
        return self.model.get_input_embeddings()
126
+
127
    def set_input_embeddings(self, value):
        """Replace the token embedding module of the inner model."""
        self.model.set_input_embeddings(value)
129
+
130
    def get_output_embeddings(self):
        """Return the LM output head (used by weight tying)."""
        return self.lm_head
132
+
133
    def set_output_embeddings(self, new_embeddings):
        """Replace the LM output head (used by weight tying / resizing)."""
        self.lm_head = new_embeddings
135
+
136
+ def set_ddpm_inference_steps(self, num_steps=None):
137
+ self.ddpm_inference_steps = (
138
+ num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps
139
+ )
140
+
141
    def _process_speech_inputs(
        self,
        speech_tensors: Optional[torch.Tensor],
        speech_masks: Optional[torch.Tensor],
        voice_cache: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Process speech inputs through acoustic and semantic encoders.

        Three mutually exclusive paths: pre-encoded `voice_cache`, raw
        `speech_tensors`, or a dummy fallback when neither is given.

        Returns:
            Tuple of (acoustic_features, speech_embeds) where speech_embeds has shape
            [num_valid_frames, hidden] - already indexed by speech_masks for direct
            assignment to inputs_embeds[speech_input_mask].
        """
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype

        if voice_cache is not None:
            # Use pre-encoded voice features
            acoustic_mean = voice_cache["acoustic_mean"].to(device=device, dtype=dtype)
            semantic_mean = voice_cache["semantic_mean"].to(device=device, dtype=dtype)

            # Sample from acoustic distribution (reparameterization with a
            # fixed std, falling back to the tokenizer's default).
            fix_std = voice_cache.get("acoustic_std", self.acoustic_tokenizer.fix_std)
            acoustic_features = acoustic_mean + fix_std * torch.randn_like(acoustic_mean)
            semantic_features = semantic_mean

            # Create speech_masks from cache dimensions (all frames valid)
            batch_size = acoustic_features.shape[0]
            seq_len = acoustic_features.shape[1]
            speech_masks = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)

        elif speech_tensors is not None:
            # Encode speech through tokenizers
            with torch.no_grad():
                # Acoustic encoding; tokenizers expect a channel dim: [B, 1, T]
                if speech_tensors.dim() == 2:
                    speech_tensors = speech_tensors.unsqueeze(1)

                acoustic_output = self.acoustic_tokenizer.encode(speech_tensors)
                acoustic_features, _ = self.acoustic_tokenizer.sampling(acoustic_output)

                # Semantic encoding (deterministic: uses the mean directly)
                semantic_output = self.semantic_tokenizer.encode(speech_tensors)
                semantic_features = semantic_output.mean

            # Create speech_masks if not provided (all frames valid)
            if speech_masks is None:
                batch_size = acoustic_features.shape[0]
                seq_len = acoustic_features.shape[1]
                speech_masks = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
        else:
            # Return dummy features (single zero frame) so downstream code
            # always has tensors of the expected rank.
            vae_dim = self.config.acoustic_vae_dim
            acoustic_features = torch.zeros(1, 1, vae_dim, device=device, dtype=dtype)
            semantic_features = torch.zeros(
                1, 1, self.config.semantic_vae_dim, device=device, dtype=dtype
            )
            speech_masks = torch.ones(1, 1, dtype=torch.bool, device=device)

        # Ensure acoustic and semantic have matching time dimensions
        # (the two tokenizers may emit slightly different frame counts).
        acoustic_len = acoustic_features.shape[1]
        semantic_len = semantic_features.shape[1]
        if semantic_len < acoustic_len:
            pad_size = acoustic_len - semantic_len
            semantic_features = torch.nn.functional.pad(
                semantic_features, (0, 0, 0, pad_size), mode="constant", value=0
            )
        elif semantic_len > acoustic_len:
            semantic_features = semantic_features[:, :acoustic_len, :]

        # Apply scaling to acoustic features; NaN scaling factor is treated as
        # "not initialized" — NOTE(review): confirm this sentinel convention
        # against KugelAudioModel's training-time setup.
        if not torch.isnan(self.speech_scaling_factor):
            acoustic_features = (
                acoustic_features + self.speech_bias_factor
            ) * self.speech_scaling_factor

        # Get embeddings through connectors (project latents into LM space)
        acoustic_embed = self.acoustic_connector(acoustic_features)
        semantic_embed = self.semantic_connector(semantic_features)

        # Combine embeddings and index by speech_masks
        combined_embed = acoustic_embed + semantic_embed

        # Move speech_masks to CPU for indexing (matches working implementation)
        speech_embeds = combined_embed[speech_masks.cpu()]

        return acoustic_features, speech_embeds
228
+
229
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        speech_tensors: Optional[torch.FloatTensor] = None,
        speech_masks: Optional[torch.BoolTensor] = None,
        speech_input_mask: Optional[torch.BoolTensor] = None,
        voice_cache: Optional[dict] = None,
        logits_to_keep: Union[int, slice] = 0,
        **kwargs,
    ) -> Union[Tuple, KugelAudioCausalLMOutputWithPast]:
        """Forward pass for the model.

        Embeds `input_ids`, optionally splices voice embeddings (from
        `speech_tensors` or `voice_cache`) into the positions flagged by
        `speech_input_mask`, runs the decoder, and projects the last hidden
        states through the LM head. `logits_to_keep` limits how many trailing
        positions get logits (0 keeps all, via `slice(0, None)` semantics).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.model.get_input_embeddings()(input_ids)

        # Process speech inputs if provided
        if voice_cache is not None or (speech_tensors is not None and speech_masks is not None):
            _, speech_embeds = self._process_speech_inputs(
                speech_tensors.to(self.dtype) if speech_tensors is not None else None,
                speech_masks,
                voice_cache=voice_cache,
            )
            if speech_input_mask is not None:
                # Scatter voice embeddings into the text sequence at the
                # flagged slots; counts must match by construction upstream.
                inputs_embeds[speech_input_mask] = speech_embeds

        outputs = self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
        # An int N means "logits for the last N positions only".
        slice_indices = (
            slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        )
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        return KugelAudioCausalLMOutputWithPast(
            logits=logits,
            past_key_values=outputs.past_key_values,
            last_hidden_state=hidden_states,
            attentions=outputs.attentions,
        )
290
+
291
    @torch.no_grad()
    def sample_speech_tokens(
        self, condition: torch.Tensor, neg_condition: torch.Tensor, cfg_scale: float = 3.0
    ) -> torch.Tensor:
        """Sample speech latents using diffusion with classifier-free guidance.

        Args:
            condition: Positive (conditioned) hidden states; batch on dim 0.
            neg_condition: Negative/unconditional hidden states, same shape.
            cfg_scale: Guidance strength; 1.0 disables CFG entirely.

        Returns:
            Denoised acoustic latents of shape [batch, acoustic_vae_dim].
        """
        self.model.noise_scheduler.set_timesteps(self.ddpm_inference_steps)

        if cfg_scale == 1.0:
            # No CFG - single forward pass
            speech = torch.randn(condition.shape[0], self.config.acoustic_vae_dim).to(condition)
            for t in self.model.noise_scheduler.timesteps:
                eps = self.model.prediction_head(
                    speech, t.repeat(speech.shape[0]).to(speech), condition=condition
                )
                speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample
            return speech

        # With CFG - batched forward pass: first half conditioned,
        # second half unconditioned, sharing identical latents.
        combined_condition = torch.cat([condition, neg_condition], dim=0).to(
            self.model.prediction_head.device
        )
        speech = torch.randn(combined_condition.shape[0], self.config.acoustic_vae_dim).to(
            combined_condition
        )

        for t in self.model.noise_scheduler.timesteps:
            # Duplicate the (shared) latent so both halves denoise the same state.
            half = speech[: len(speech) // 2]
            combined = torch.cat([half, half], dim=0)
            eps = self.model.prediction_head(
                combined, t.repeat(combined.shape[0]).to(combined), condition=combined_condition
            )
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            # Classifier-free guidance: extrapolate from uncond toward cond.
            half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample

        # Return only the conditioned half.
        return speech[: len(speech) // 2]
328
+
329
    @torch.no_grad()
    def encode_voice_prompt(
        self,
        voice_audio: torch.Tensor,
        sample_rate: int = 24000,
    ) -> dict:
        """Pre-encode a voice prompt for caching.

        Args:
            voice_audio: Waveform as [T], [B, T] or [B, 1, T]; normalized to
                [B, 1, T] below. No resampling is performed here —
                `sample_rate` is only recorded in the returned dict.

        Returns:
            dict consumable via `voice_cache=` in forward/_process_speech_inputs:
            CPU tensors "acoustic_mean" / "semantic_mean", plus "acoustic_std",
            "audio_length" and "sample_rate".
        """
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype

        # Normalize to [B, 1, T] for the tokenizers.
        if voice_audio.dim() == 1:
            voice_audio = voice_audio.unsqueeze(0).unsqueeze(0)
        elif voice_audio.dim() == 2:
            voice_audio = voice_audio.unsqueeze(1)

        voice_audio = voice_audio.to(device=device, dtype=dtype)

        with torch.no_grad():
            acoustic_output = self.model.acoustic_tokenizer.encode(voice_audio)
            semantic_output = self.model.semantic_tokenizer.encode(voice_audio)

        return {
            "acoustic_mean": acoustic_output.mean.cpu(),
            # Fall back to the tokenizer's fixed std when the encoder output
            # carries no std of its own.
            "acoustic_std": getattr(acoustic_output, "std", self.model.acoustic_tokenizer.fix_std),
            "semantic_mean": semantic_output.mean.cpu(),
            "audio_length": voice_audio.shape[-1],
            "sample_rate": sample_rate,
        }
357
+
358
+ @torch.no_grad()
359
+ def generate(
360
+ self,
361
+ text_ids: Optional[torch.Tensor] = None,
362
+ input_ids: Optional[torch.Tensor] = None,
363
+ voice_prompt: Optional[torch.Tensor] = None,
364
+ voice_cache: Optional[dict] = None,
365
+ speech_tensors: Optional[torch.Tensor] = None,
366
+ speech_masks: Optional[torch.Tensor] = None,
367
+ speech_input_mask: Optional[torch.Tensor] = None,
368
+ cfg_scale: float = 3.0,
369
+ max_new_tokens: int = 2048,
370
+ do_sample: bool = False,
371
+ temperature: float = 1.0,
372
+ show_progress: bool = True,
373
+ **kwargs,
374
+ ) -> KugelAudioGenerationOutput:
375
+ """Generate speech from text.
376
+
377
+ Args:
378
+ text_ids: Tokenized text input (from processor)
379
+ input_ids: Alternative name for text_ids
380
+ voice_prompt: Voice audio tensor for cloning (legacy, use speech_tensors instead)
381
+ voice_cache: Pre-encoded voice features (from encode_voice_prompt)
382
+ speech_tensors: Voice audio tensor from processor for cloning
383
+ speech_masks: Mask indicating valid voice frames
384
+ speech_input_mask: Boolean mask indicating where to insert voice embeddings
385
+ cfg_scale: Classifier-free guidance scale (higher = more faithful to text)
386
+ max_new_tokens: Maximum tokens to generate
387
+ do_sample: Whether to sample or use greedy decoding
388
+ temperature: Sampling temperature
389
+ show_progress: Whether to show progress bar
390
+
391
+ Returns:
392
+ KugelAudioGenerationOutput with sequences and speech_outputs
393
+ """
394
+ device = next(self.parameters()).device
395
+ dtype = next(self.parameters()).dtype
396
+
397
+ # Handle input_ids vs text_ids
398
+ if text_ids is None and input_ids is not None:
399
+ text_ids = input_ids
400
+ if text_ids is None:
401
+ raise ValueError("text_ids or input_ids is required")
402
+
403
+ text_ids = text_ids.to(device)
404
+ batch_size = text_ids.shape[0]
405
+
406
+ # Handle legacy voice_prompt parameter
407
+ if voice_prompt is not None and speech_tensors is None:
408
+ speech_tensors = voice_prompt
409
+ # Create default speech_masks if not provided
410
+ if speech_masks is None:
411
+ # Estimate number of frames from audio length
412
+ audio_len = voice_prompt.shape[-1]
413
+ num_frames = (audio_len + 3199) // 3200 # compression ratio
414
+ speech_masks = torch.ones(batch_size, num_frames, dtype=torch.bool, device=device)
415
+
416
+ # Get special token IDs
417
+ speech_start_id = getattr(self.config, "speech_start_id", None) or 151652
418
+ speech_end_id = getattr(self.config, "speech_end_id", None) or 151653
419
+ speech_diffusion_id = getattr(self.config, "speech_diffusion_id", None) or 151654
420
+ eos_token_id = getattr(self.config.decoder_config, "eos_token_id", None) or 151643
421
+
422
+ # Initialize streaming caches for tokenizers
423
+ acoustic_cache = KugelAudioTokenizerStreamingCache()
424
+ semantic_cache = KugelAudioTokenizerStreamingCache()
425
+
426
+ # Initialize sequences and attention masks
427
+ current_ids = text_ids
428
+ attention_mask = torch.ones_like(current_ids)
429
+
430
+ # For CFG, create negative prompt (just speech_start token)
431
+ negative_ids = torch.full((batch_size, 1), speech_start_id, dtype=torch.long, device=device)
432
+ negative_attention_mask = torch.ones_like(negative_ids)
433
+
434
+ # Storage for generated audio and tracking
435
+ audio_chunks = [[] for _ in range(batch_size)]
436
+ finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
437
+ correct_cnt = torch.zeros(batch_size, dtype=torch.long, device=device)
438
+
439
+ # Get initial embeddings
440
+ inputs_embeds = self.model.get_input_embeddings()(current_ids)
441
+
442
+ # Process voice/speech input if provided
443
+ if speech_tensors is not None or voice_cache is not None:
444
+ # Get speech embeddings
445
+ if voice_cache is not None:
446
+ _, speech_embeds = self._process_speech_inputs(
447
+ speech_tensors=None,
448
+ speech_masks=None,
449
+ voice_cache=voice_cache,
450
+ )
451
+ else:
452
+ # Encode speech_tensors directly
453
+ speech_tensors = speech_tensors.to(device=device, dtype=dtype)
454
+ if speech_masks is not None:
455
+ speech_masks = speech_masks.to(device)
456
+ _, speech_embeds = self._process_speech_inputs(
457
+ speech_tensors=speech_tensors,
458
+ speech_masks=speech_masks,
459
+ voice_cache=None,
460
+ )
461
+
462
+ # Insert speech embeddings at positions marked by speech_input_mask
463
+ # speech_embeds is already flattened to [num_valid_frames, hidden] by _process_speech_inputs
464
+ if speech_input_mask is not None:
465
+ speech_input_mask = speech_input_mask.to(device)
466
+ # Directly assign - shapes should match
467
+ inputs_embeds[speech_input_mask] = speech_embeds
468
+
469
+ negative_inputs_embeds = self.model.get_input_embeddings()(negative_ids)
470
+
471
+ # Setup logits processor to constrain to valid tokens
472
+ valid_tokens = [speech_start_id, speech_end_id, speech_diffusion_id, eos_token_id]
473
+ token_constraint = KugelAudioTokenConstraintProcessor(valid_tokens, device=device)
474
+
475
+ # Initialize KV caches
476
+ past_key_values = None
477
+ negative_past_key_values = None
478
+
479
+ # Progress bar
480
+ progress_iter = (
481
+ tqdm(range(max_new_tokens), desc="Generating", leave=False)
482
+ if show_progress
483
+ else range(max_new_tokens)
484
+ )
485
+
486
+ for step in progress_iter:
487
+ if finished.all():
488
+ break
489
+
490
+ # Forward pass for positive (main) model
491
+ if past_key_values is None:
492
+ outputs = self(
493
+ inputs_embeds=inputs_embeds,
494
+ attention_mask=attention_mask,
495
+ use_cache=True,
496
+ return_dict=True,
497
+ )
498
+ else:
499
+ outputs = self(
500
+ inputs_embeds=inputs_embeds[:, -1:],
501
+ attention_mask=attention_mask,
502
+ past_key_values=past_key_values,
503
+ use_cache=True,
504
+ return_dict=True,
505
+ )
506
+
507
+ past_key_values = outputs.past_key_values
508
+ logits = outputs.logits[:, -1, :]
509
+
510
+ # Apply token constraint
511
+ logits = token_constraint(current_ids, logits)
512
+
513
+ # Sample or greedy decode
514
+ if do_sample and temperature > 0:
515
+ probs = torch.softmax(logits / temperature, dim=-1)
516
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
517
+ else:
518
+ next_tokens = torch.argmax(logits, dim=-1)
519
+
520
+ # Force finished samples to output EOS
521
+ next_tokens = torch.where(
522
+ finished, torch.tensor(eos_token_id, device=device), next_tokens
523
+ )
524
+
525
+ # Update sequences
526
+ current_ids = torch.cat([current_ids, next_tokens.unsqueeze(-1)], dim=-1)
527
+ attention_mask = torch.cat(
528
+ [
529
+ attention_mask,
530
+ torch.ones((batch_size, 1), device=device, dtype=attention_mask.dtype),
531
+ ],
532
+ dim=-1,
533
+ )
534
+
535
+ # Check for EOS tokens
536
+ eos_mask = (next_tokens == eos_token_id) & ~finished
537
+ if eos_mask.any():
538
+ finished = finished | eos_mask
539
+
540
+ # Check for speech_end tokens - mark as finished and clear caches
541
+ speech_end_mask = (next_tokens == speech_end_id) & ~finished
542
+ if speech_end_mask.any():
543
+ finished = finished | speech_end_mask
544
+ speech_end_indices = speech_end_mask.nonzero(as_tuple=False).squeeze(-1)
545
+ acoustic_cache.set_to_zero(speech_end_indices)
546
+ semantic_cache.set_to_zero(speech_end_indices)
547
+
548
+ # Handle speech_start tokens - refresh negative model KV cache
549
+ speech_start_mask = (next_tokens == speech_start_id) & ~finished
550
+ if (
551
+ speech_start_mask.any()
552
+ and cfg_scale != 1.0
553
+ and negative_past_key_values is not None
554
+ ):
555
+ speech_start_indices = speech_start_mask.nonzero(as_tuple=False).squeeze(-1)
556
+ if speech_start_indices.dim() == 0:
557
+ speech_start_indices = speech_start_indices.unsqueeze(0)
558
+
559
+ for sample_idx in speech_start_indices.tolist():
560
+ negative_attention_mask[sample_idx, :] = 0
561
+ negative_attention_mask[sample_idx, -1] = 1
562
+
563
+ key_caches, value_caches = _get_cache_tensors(negative_past_key_values)
564
+ for k_cache, v_cache in zip(key_caches, value_caches):
565
+ k_cache[sample_idx, :, -1, :] = k_cache[sample_idx, :, 0, :].clone()
566
+ v_cache[sample_idx, :, -1, :] = v_cache[sample_idx, :, 0, :].clone()
567
+
568
+ negative_ids[sample_idx, -1] = speech_start_id
569
+
570
+ # Prepare next input embeddings
571
+ next_inputs_embeds = self.model.get_input_embeddings()(next_tokens).unsqueeze(1)
572
+
573
+ # Handle diffusion tokens - generate speech
574
+ diffusion_mask = (next_tokens == speech_diffusion_id) & ~finished
575
+ if diffusion_mask.any():
576
+ diffusion_indices = diffusion_mask.nonzero(as_tuple=False).squeeze(-1)
577
+ if diffusion_indices.dim() == 0:
578
+ diffusion_indices = diffusion_indices.unsqueeze(0)
579
+
580
+ # Run negative forward pass for CFG
581
+ if cfg_scale != 1.0:
582
+ if negative_past_key_values is None:
583
+ neg_outputs = self(
584
+ inputs_embeds=negative_inputs_embeds,
585
+ attention_mask=negative_attention_mask,
586
+ use_cache=True,
587
+ return_dict=True,
588
+ )
589
+ else:
590
+ neg_outputs = self(
591
+ inputs_embeds=negative_inputs_embeds[:, -1:],
592
+ attention_mask=negative_attention_mask,
593
+ past_key_values=negative_past_key_values,
594
+ use_cache=True,
595
+ return_dict=True,
596
+ )
597
+ negative_past_key_values = neg_outputs.past_key_values
598
+
599
+ # Handle non-diffusion samples KV cache correction
600
+ non_diffusion_mask = ~diffusion_mask & ~finished
601
+ if non_diffusion_mask.any():
602
+ non_diffusion_indices = non_diffusion_mask.nonzero(as_tuple=False).squeeze(
603
+ -1
604
+ )
605
+ if non_diffusion_indices.dim() == 0:
606
+ non_diffusion_indices = non_diffusion_indices.unsqueeze(0)
607
+
608
+ key_caches, value_caches = _get_cache_tensors(negative_past_key_values)
609
+ for sample_idx in non_diffusion_indices.tolist():
610
+ start_idx = correct_cnt[sample_idx].item()
611
+ seq_len = negative_attention_mask.shape[1]
612
+
613
+ if start_idx + 1 < seq_len - 1:
614
+ negative_attention_mask[sample_idx, start_idx + 1 :] = (
615
+ negative_attention_mask[sample_idx, start_idx:-1].clone()
616
+ )
617
+ negative_attention_mask[sample_idx, start_idx] = 0
618
+
619
+ for k_cache, v_cache in zip(key_caches, value_caches):
620
+ if start_idx + 1 < k_cache.shape[2] - 1:
621
+ k_cache[sample_idx, :, start_idx + 1 :, :] = k_cache[
622
+ sample_idx, :, start_idx:-1, :
623
+ ].clone()
624
+ v_cache[sample_idx, :, start_idx + 1 :, :] = v_cache[
625
+ sample_idx, :, start_idx:-1, :
626
+ ].clone()
627
+
628
+ if start_idx + 1 < negative_ids.shape[1] - 1:
629
+ negative_ids[sample_idx, start_idx + 1 :] = negative_ids[
630
+ sample_idx, start_idx:-1
631
+ ].clone()
632
+
633
+ correct_cnt[non_diffusion_indices] += 1
634
+
635
+ neg_condition = neg_outputs.last_hidden_state[diffusion_indices, -1, :]
636
+ else:
637
+ neg_condition = torch.zeros(
638
+ diffusion_indices.shape[0],
639
+ self.config.decoder_config.hidden_size,
640
+ device=device,
641
+ dtype=dtype,
642
+ )
643
+
644
+ # Get conditioning from last hidden state
645
+ condition = outputs.last_hidden_state[diffusion_indices, -1, :]
646
+
647
+ # Sample speech latents using diffusion
648
+ speech_latents = self.sample_speech_tokens(condition, neg_condition, cfg_scale)
649
+
650
+ # Unscale latents
651
+ scaled_latent = (
652
+ speech_latents / self.speech_scaling_factor - self.speech_bias_factor
653
+ )
654
+
655
+ # Decode through acoustic tokenizer with streaming cache
656
+ audio = self.acoustic_tokenizer.decode(
657
+ scaled_latent.unsqueeze(1).permute(0, 2, 1),
658
+ cache=acoustic_cache,
659
+ sample_indices=diffusion_indices,
660
+ use_cache=True,
661
+ )
662
+
663
+ # Store audio chunks
664
+ for i, idx in enumerate(diffusion_indices.tolist()):
665
+ if not finished[idx]:
666
+ audio_chunks[idx].append(audio[i].cpu())
667
+
668
+ # Encode audio to semantic features with streaming cache
669
+ semantic_output = self.semantic_tokenizer.encode(
670
+ audio,
671
+ cache=semantic_cache,
672
+ sample_indices=diffusion_indices,
673
+ use_cache=True,
674
+ )
675
+ semantic_features = semantic_output.mean
676
+
677
+ # Compute embeddings for next step
678
+ acoustic_embed = self.acoustic_connector(speech_latents.unsqueeze(1))
679
+ semantic_embed = self.semantic_connector(semantic_features)
680
+ diffusion_embeds = (acoustic_embed + semantic_embed).squeeze(1)
681
+
682
+ # Update embeddings for diffusion samples
683
+ next_inputs_embeds[diffusion_indices] = diffusion_embeds.unsqueeze(1)
684
+
685
+ # Update embeddings for next iteration
686
+ inputs_embeds = torch.cat([inputs_embeds, next_inputs_embeds], dim=1)
687
+
688
+ # Update negative model
689
+ negative_inputs_embeds = torch.cat([negative_inputs_embeds, next_inputs_embeds], dim=1)
690
+ negative_attention_mask = torch.cat(
691
+ [
692
+ negative_attention_mask,
693
+ torch.ones((batch_size, 1), device=device, dtype=negative_attention_mask.dtype),
694
+ ],
695
+ dim=-1,
696
+ )
697
+ negative_ids = torch.cat([negative_ids, next_tokens.unsqueeze(-1)], dim=-1)
698
+
699
+ # Concatenate audio chunks with normalization
700
+ speech_outputs = []
701
+ for chunks in audio_chunks:
702
+ if chunks:
703
+ concatenated = torch.cat(chunks, dim=-1).squeeze()
704
+ # Normalize audio to prevent clipping
705
+ max_val = concatenated.abs().max()
706
+ if max_val > 1.0:
707
+ concatenated = concatenated * (0.95 / max_val)
708
+ # Apply watermark to all generated audio
709
+ concatenated = self._apply_watermark(concatenated, sample_rate=24000)
710
+ speech_outputs.append(concatenated)
711
+ else:
712
+ speech_outputs.append(None)
713
+
714
+ return KugelAudioGenerationOutput(
715
+ sequences=current_ids,
716
+ speech_outputs=speech_outputs,
717
+ )
718
+
719
+ def _apply_watermark(self, audio: torch.Tensor, sample_rate: int = 24000) -> torch.Tensor:
720
+ """Apply imperceptible watermark to generated audio.
721
+
722
+ This watermark identifies audio as generated by KugelAudio and is designed
723
+ to be robust against various audio transformations while remaining inaudible.
724
+ """
725
+ try:
726
+ import torchaudio.functional as F
727
+ from audioseal import AudioSeal
728
+ except ImportError:
729
+ return audio # Graceful fallback if audioseal not available
730
+
731
+ device = audio.device
732
+ dtype = audio.dtype
733
+ original_shape = audio.shape
734
+
735
+ # Prepare audio for watermarking (AudioSeal expects [batch, channels, samples] at 16kHz)
736
+ if audio.dim() == 1:
737
+ audio_for_wm = audio.unsqueeze(0).unsqueeze(0)
738
+ elif audio.dim() == 2:
739
+ audio_for_wm = audio.unsqueeze(0)
740
+ else:
741
+ audio_for_wm = audio
742
+
743
+ audio_for_wm = audio_for_wm.float()
744
+
745
+ # Resample to 16kHz for AudioSeal
746
+ if sample_rate != 16000:
747
+ audio_16k = F.resample(audio_for_wm, sample_rate, 16000)
748
+ else:
749
+ audio_16k = audio_for_wm
750
+
751
+ # Load watermark generator (cached after first use)
752
+ if not hasattr(self, "_wm_generator"):
753
+ self._wm_generator = AudioSeal.load_generator("audioseal_wm_16bits").to(device)
754
+ self._wm_generator.eval()
755
+
756
+ # Generate and apply watermark
757
+ with torch.no_grad():
758
+ watermark_16k = self._wm_generator.get_watermark(audio_16k.to(device), 16000)
759
+
760
+ # Resample watermark back to original sample rate
761
+ if sample_rate != 16000:
762
+ watermark = F.resample(watermark_16k, 16000, sample_rate)
763
+ # Ensure same length
764
+ if watermark.shape[-1] != audio_for_wm.shape[-1]:
765
+ if watermark.shape[-1] > audio_for_wm.shape[-1]:
766
+ watermark = watermark[..., : audio_for_wm.shape[-1]]
767
+ else:
768
+ watermark = torch.nn.functional.pad(
769
+ watermark, (0, audio_for_wm.shape[-1] - watermark.shape[-1])
770
+ )
771
+ else:
772
+ watermark = watermark_16k
773
+
774
+ # Add watermark to audio
775
+ watermarked = audio_for_wm + watermark.to(audio_for_wm.device)
776
+
777
+ # Normalize to prevent clipping
778
+ max_val = watermarked.abs().max()
779
+ if max_val > 1.0:
780
+ watermarked = watermarked * (0.95 / max_val)
781
+
782
+ # Restore original shape
783
+ if len(original_shape) == 1:
784
+ watermarked = watermarked.squeeze(0).squeeze(0)
785
+ elif len(original_shape) == 2:
786
+ watermarked = watermarked.squeeze(0)
787
+
788
+ return watermarked.to(dtype=dtype)
789
+
790
+
791
# Register with AutoModel so that `AutoModel.from_pretrained(...)` and
# `AutoModelForCausalLM.from_pretrained(...)` can resolve KugelAudio
# checkpoints directly from their config class.
AutoModel.register(KugelAudioConfig, KugelAudioModel)
AutoModelForCausalLM.register(KugelAudioConfig, KugelAudioForConditionalGenerationInference)


# Public API of this module.
__all__ = [
    "KugelAudioForConditionalGenerationInference",
    "KugelAudioCausalLMOutputWithPast",
    "KugelAudioGenerationOutput",
]
kugelaudio_open/models/kugelaudio_model.py ADDED
@@ -0,0 +1,721 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict, List, Optional, Tuple, Union, Callable
3
+ from tqdm import tqdm
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.distributed as dist
8
+
9
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM
10
+
11
+ from transformers.activations import ACT2FN
12
+ from transformers.modeling_outputs import (
13
+ CausalLMOutput,
14
+ BaseModelOutputWithPast,
15
+ ModelOutput,
16
+ )
17
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
18
+ from transformers import modeling_utils
19
+ from transformers.modeling_utils import PreTrainedModel
20
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
21
+ from transformers.utils import logging
22
+
23
+
24
+ from .tokenizer import (
25
+ KugelAudioAcousticTokenizerModel,
26
+ KugelAudioSemanticTokenizerModel,
27
+ )
28
+ from .diffusion_head import KugelAudioDiffusionHead
29
+ from ..schedule.dpm_solver import DPMSolverMultistepScheduler
30
+
31
+ from ..configs import KugelAudioConfig
32
+
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
# Compatibility shim: some transformers versions leave `ALL_PARALLEL_STYLES`
# unset or None, which breaks tensor-parallel plan validation. Backfill the
# expected default styles defensively.
# NOTE(review): presumably targets a specific transformers release — confirm
# this is still required when bumping the dependency.
if (
    not hasattr(modeling_utils, "ALL_PARALLEL_STYLES")
    or modeling_utils.ALL_PARALLEL_STYLES is None
):
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
41
+
42
+
43
@dataclass
class KugelAudioCausalLMOutputWithPast(ModelOutput):
    """Output of KugelAudio's causal-LM forward pass.

    Mirrors the standard causal-LM-with-past output fields and adds the
    diffusion-head loss and the speech-token count for the batch.
    """

    loss: Optional[torch.FloatTensor] = None  # training loss (composition set by forward — TODO confirm)
    diffusion_loss: Optional[torch.FloatTensor] = None  # diffusion-head loss component
    speech_token_num: Optional[torch.LongTensor] = None  # number of speech tokens in the batch
    logits: torch.FloatTensor = None  # next-token logits over the text vocabulary
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # KV cache for incremental decoding
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
52
+
53
+
54
@dataclass
class KugelAudioGenerationOutput(ModelOutput):
    """
    Output type for KugelAudio generation.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences.
        speech_outputs (`List[torch.FloatTensor]`, *optional*):
            List of generated speech waveforms or latents for each speech segment.
            An entry is `None` for batch items that produced no audio.
    """

    sequences: torch.LongTensor = None
    speech_outputs: Optional[List[torch.FloatTensor]] = None
68
+
69
+
70
class SpeechConnector(nn.Module):
    """Bridges speech-feature space into the language model's hidden space.

    A linear projection into the target size, RMS normalization, then a second
    linear refinement layer. Submodule names (`fc1`, `norm`, `fc2`) are part of
    the checkpoint layout and must not change.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, features, **kwargs):
        # Project -> normalize -> refine, all within the target hidden size.
        hidden = self.norm(self.fc1(features))
        return self.fc2(hidden)
82
+
83
+
84
+ # @auto_docstring
85
class KugelAudioPreTrainedModel(PreTrainedModel):
    """Base class wiring KugelAudio models into the transformers framework.

    Declares the config class, weight-initialization behaviour, and the
    attention/cache capabilities advertised to transformers.
    """

    config_class = KugelAudioConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # KV caches manage their own placement; keep them out of device maps.
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        # The diffusion head defines its own initialization; delegate and stop.
        if isinstance(module, KugelAudioDiffusionHead):
            module.initialize_weights()
            return

        # Use the language model's initializer_range if available
        if hasattr(self.config, "language_model_config") and hasattr(
            self.config.language_model_config, "initializer_range"
        ):
            std = self.config.language_model_config.initializer_range
        elif hasattr(self.config, "decoder_config") and hasattr(
            self.config.decoder_config, "initializer_range"
        ):
            std = self.config.decoder_config.initializer_range
        else:
            std = 0.02  # Default value

        # Standard truncated-normal-style init for linears, identity for norms.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
121
+
122
+
123
+ # @auto_docstring
124
class KugelAudioModel(KugelAudioPreTrainedModel):
    """Core KugelAudio model: a causal LM plus speech tokenizers and connectors.

    Composes the language decoder with the acoustic/semantic tokenizers, the
    connectors projecting speech latents into the LM hidden space, and the
    diffusion prediction head with its noise scheduler.
    """

    def __init__(self, config):
        super().__init__(config)

        # Resolve the working dtype from the config (may be a string name).
        if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32

        # Initialize Qwen2 model for language modeling
        lm_config = config.decoder_config
        self.language_model = AutoModel.from_config(lm_config)

        # Speech tokenizers: acoustic (waveform codec) and semantic (content).
        self.acoustic_tokenizer = AutoModel.from_config(
            config.acoustic_tokenizer_config
        ).to(dtype)
        self.semantic_tokenizer = AutoModel.from_config(
            config.semantic_tokenizer_config
        ).to(dtype)

        # Connectors projecting tokenizer latents into the LM hidden size.
        self.acoustic_connector = SpeechConnector(
            config.acoustic_vae_dim, lm_config.hidden_size
        ).to(dtype)
        self.semantic_connector = SpeechConnector(
            config.semantic_vae_dim, lm_config.hidden_size
        ).to(dtype)

        # Register scaling factors as buffers - use 1D tensors for FSDP
        # compatibility. NaN marks "not yet estimated from data".
        self.register_buffer("speech_scaling_factor", torch.tensor(float("nan")))
        self.register_buffer("speech_bias_factor", torch.tensor(float("nan")))

        # Initialize prediction head for speech generation
        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(
            dtype
        )

        # Initialize noise scheduler with SDE-DPM-Solver++ for better quality
        algorithm_type = getattr(
            config.diffusion_head_config, "ddpm_algorithm_type", "sde-dpmsolver++"
        )
        self.noise_scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
            prediction_type=config.diffusion_head_config.prediction_type,
            algorithm_type=algorithm_type,
            solver_order=2,
        )

    def get_input_embeddings(self):
        """Return the language model's token-embedding module.

        Handles nnscaler tensor-parallel checkpoints, where `embed_tokens` is
        renamed and tracked via the model's `fullmap`.
        """
        if hasattr(self.language_model, "embed_tokens"):
            # If the language model has an embed_tokens attribute, return it
            return self.language_model.embed_tokens

        # Parallelized by nnscaler: the attribute name is changed, so look it
        # up through the original-name map.
        for name, attr in self.language_model.fullmap.items():
            if attr.orig_name == "embed_tokens.weight":
                return getattr(self.language_model, name)

        # BUGFIX: was `assert False` — asserts are stripped under `python -O`,
        # which would make this fall through and return None. Raise explicitly.
        raise RuntimeError(
            "Could not locate input embeddings on the language model"
        )

    def set_input_embeddings(self, value):
        """Replace the language model's token-embedding module."""
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
        """Set the speech tokenizers used for encoding and decoding speech."""
        self.acoustic_tokenizer = acoustic_tokenizer
        self.semantic_tokenizer = semantic_tokenizer

        # Tokenizers act as frozen feature extractors; keep them in eval mode.
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.eval()

        if self.semantic_tokenizer is not None:
            self.semantic_tokenizer.eval()

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device = None,
        cache_position: torch.Tensor = None,
        batch_size: int = None,
        config=None,
        past_key_values=None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Creates a 4D causal attention mask for use with static cache.

        This enables torch.compile to work efficiently without recompilation
        by providing a consistent mask shape during autoregressive generation.

        Based on the standard HuggingFace implementation without sliding window
        (KugelAudio doesn't use sliding window attention).

        Compatible with both old and new transformers API.
        """
        # Handle case where attention_mask is already 4D
        if attention_mask is not None and attention_mask.dim() == 4:
            return attention_mask

        # Get device from attention_mask or cache_position if not provided
        if device is None:
            if attention_mask is not None:
                device = attention_mask.device
            elif cache_position is not None:
                device = cache_position.device
            else:
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        min_dtype = torch.finfo(dtype).min

        # Start fully masked; allowed positions are opened up below.
        causal_mask = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=dtype,
            device=device,
        )

        if sequence_length != 1:
            # Apply upper triangular mask (can't attend to future tokens)
            causal_mask = torch.triu(causal_mask, diagonal=1)

        # Mask positions beyond current cache position
        if cache_position is not None:
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)

        # Expand to 4D: (batch_size, 1, sequence_length, target_length)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

        # Combine with input attention mask if provided
        if attention_mask is not None:
            causal_mask = causal_mask.clone()
            mask_length = attention_mask.shape[-1]
            # Positions that sum to exactly 0 are causally visible but padded
            # in the 2D mask; those get re-masked below.
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(dtype) * min_dtype
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

        return causal_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Delegate to the language model and normalize the output container."""
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Forward through language model
        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        if not return_dict:
            return outputs

        # Re-wrap so callers always receive this module's output type.
        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
320
+
321
+
322
+ class KugelAudioForConditionalGeneration(KugelAudioPreTrainedModel):
323
+ """
324
+ Unified model for both training and inference.
325
+
326
+ Supports:
327
+ - Training via forward() with loss computation
328
+ - Inference via generate() for audio generation
329
+ """
330
+
331
+ _tied_weights_keys = ["lm_head.weight"]
332
+ _tp_plan = {"lm_head": "colwise_rep"}
333
+
334
    def __init__(self, config):
        """Build the wrapped KugelAudioModel plus the LM head.

        Args:
            config: A KugelAudioConfig with `decoder_config` and (optionally)
                `diffusion_head_config`.
        """
        super().__init__(config)
        self.model = KugelAudioModel(config)
        self.vocab_size = config.decoder_config.vocab_size
        # Untied by default; `tie_weights` may share this with the input embeddings.
        self.lm_head = nn.Linear(
            config.decoder_config.hidden_size, self.vocab_size, bias=False
        )

        # Inference configuration (for generate() method)
        self.ddpm_inference_steps = (
            config.diffusion_head_config.ddpm_num_inference_steps
            if hasattr(config, "diffusion_head_config")
            else 5  # fallback when no diffusion head config is present
        )

        self.post_init()
350
+
351
    # Properties for easier access (used by generate())
    @property
    def noise_scheduler(self):
        # The inner model owns the scheduler instance; expose it read-only here.
        return self.model.noise_scheduler
355
+
356
    @property
    def prediction_head(self):
        # The diffusion head lives on the inner model; expose it for generate().
        return self.model.prediction_head
359
+
360
    def get_input_embeddings(self):
        # Delegate to the wrapped KugelAudioModel (handles parallel-renamed embeddings).
        return self.model.get_input_embeddings()
362
+
363
    def set_input_embeddings(self, value):
        # Delegate to the wrapped KugelAudioModel.
        self.model.set_input_embeddings(value)
365
+
366
    def get_output_embeddings(self):
        # The LM head maps hidden states to vocabulary logits.
        return self.lm_head
368
+
369
    def set_decoder(self, decoder):
        # Swap the underlying language model (decoder) module.
        self.model.language_model = decoder
371
+
372
    def get_decoder(self):
        # Return the underlying language model (decoder) module.
        return self.model.language_model
374
+
375
    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.

        Only ties when `decoder_config.tie_word_embeddings` is set. Relies on
        plain parameter-object assignment, which must happen BEFORE the model
        is wrapped (e.g. by FSDP/accelerate).
        """
        if getattr(self.config.decoder_config, "tie_word_embeddings", False):
            # The standard PreTrainedModel method will handle the tying.
            # It typically does a simple parameter object assignment, which is
            # CORRECT to do BEFORE FSDP wraps the model.
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()
            if hasattr(input_embeddings, "weight"):
                output_embeddings.weight = input_embeddings.weight
            else:
                # maybe returned input_embeddings a tensor directly
                output_embeddings.weight = input_embeddings

            # If the tied weight has more rows than the existing bias (e.g. a
            # resized vocabulary), zero-pad the bias to match.
            if getattr(output_embeddings, "bias", None) is not None:
                output_embeddings.bias.data = nn.functional.pad(
                    output_embeddings.bias.data,
                    (
                        0,
                        output_embeddings.weight.shape[0]
                        - output_embeddings.bias.shape[0],
                    ),
                    "constant",
                    0,
                )
            print("βœ… Tied input and output embeddings using standard assignment.")
        else:
            print("ℹ️ tie_word_embeddings is False, not tying weights.")
405
+
406
    # NOTE(review): swapping `lm_head` wholesale is only safe before the model
    # is wrapped (e.g. by accelerator.prepare() / FSDP); avoid calling it later.
    def set_output_embeddings(self, new_embeddings):
        # Direct module replacement; callers must re-tie weights if required.
        self.lm_head = new_embeddings
412
+
413
+ def forward_speech_features(
414
+ self,
415
+ speech_tensors=None,
416
+ speech_masks=None,
417
+ speech_type="audio",
418
+ return_unmask=False,
419
+ ):
420
+ if speech_tensors is None:
421
+ # Use config to get vae_dim instead of non-existent self.args
422
+ vae_dim = self.config.acoustic_tokenizer_config.vae_dim
423
+ audio_features = torch.zeros(1, 1, vae_dim).to(
424
+ self.get_input_embeddings().weight
425
+ )
426
+ connect_features = self.model.acoustic_connector(audio_features)
427
+ return audio_features, connect_features
428
+ else:
429
+ with torch.no_grad():
430
+ if speech_type == "audio":
431
+ with torch.no_grad():
432
+ frames_out = self.model.acoustic_tokenizer.encode(
433
+ speech_tensors.unsqueeze(1)
434
+ )
435
+ if isinstance(frames_out, (list, tuple)):
436
+ frames = frames_out[0][0]
437
+ else:
438
+ frames = frames_out
439
+ audio_tokens = frames.sample(
440
+ self.model.acoustic_tokenizer.std_dist_type
441
+ )[0]
442
+
443
+ elif speech_type == "vae":
444
+ # Use config to get vae_dim instead of non-existent self.args
445
+ vae_dim = self.config.acoustic_tokenizer_config.vae_dim
446
+ speech_mode = speech_tensors.reshape(
447
+ speech_tensors.size(0), -1, vae_dim
448
+ )
449
+
450
+ # gaussian sample from the speech_mode
451
+ batch_size = speech_mode.size(0)
452
+ value = self.model.acoustic_tokenizer.fix_std / 0.8
453
+ std = (
454
+ torch.randn(
455
+ batch_size,
456
+ dtype=speech_mode.dtype,
457
+ device=speech_mode.device,
458
+ )
459
+ * value
460
+ )
461
+ std = std.view(-1, *[1] * (speech_mode.dim() - 1))
462
+ audio_tokens = speech_mode + std * torch.randn(
463
+ speech_mode.shape
464
+ ).to(speech_mode)
465
+ else:
466
+ raise NotImplementedError(
467
+ f"Speech type {speech_type} not implemented"
468
+ )
469
+
470
+ if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(
471
+ self.model.speech_bias_factor
472
+ ):
473
+ scaling_factor = 1.0 / audio_tokens[speech_masks].flatten().std()
474
+ bias_factor = -audio_tokens[speech_masks].flatten().mean()
475
+
476
+ # Only use distributed operations if the process group is initialized
477
+ if dist.is_available() and dist.is_initialized():
478
+ dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM)
479
+ dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM)
480
+ world_size = dist.get_world_size()
481
+ self.model.speech_scaling_factor.copy_(
482
+ scaling_factor / world_size
483
+ )
484
+ self.model.speech_bias_factor.copy_(bias_factor / world_size)
485
+ print(
486
+ f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}",
487
+ flush=True,
488
+ )
489
+ else:
490
+ # Single process case
491
+ self.model.speech_scaling_factor.copy_(scaling_factor)
492
+ self.model.speech_bias_factor.copy_(bias_factor)
493
+ print(
494
+ f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}",
495
+ flush=True,
496
+ )
497
+
498
+ audio_features = (
499
+ audio_tokens + self.model.speech_bias_factor
500
+ ) * self.model.speech_scaling_factor
501
+
502
+ connect_features = self.model.acoustic_connector(audio_features)
503
+ if return_unmask:
504
+ return audio_features, connect_features
505
+ return audio_features[speech_masks], connect_features[speech_masks]
506
+
507
+ def forward(
508
+ self,
509
+ input_ids: torch.LongTensor = None,
510
+ attention_mask: Optional[torch.Tensor] = None,
511
+ position_ids: Optional[torch.LongTensor] = None,
512
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
513
+ inputs_embeds: Optional[torch.FloatTensor] = None,
514
+ labels: Optional[torch.LongTensor] = None,
515
+ use_cache: Optional[bool] = False,
516
+ output_attentions: Optional[bool] = None,
517
+ output_hidden_states: Optional[bool] = None,
518
+ return_dict: Optional[bool] = None,
519
+ cache_position: Optional[torch.LongTensor] = None,
520
+ # New arguments for speech processing and loss calculation
521
+ speech_tensors: Optional[torch.FloatTensor] = None,
522
+ speech_masks: Optional[torch.BoolTensor] = None,
523
+ speeches_loss_input: Optional[torch.FloatTensor] = None,
524
+ speech_semantic_tensors: Optional[torch.FloatTensor] = None,
525
+ acoustic_input_mask: Optional[torch.BoolTensor] = None,
526
+ acoustic_loss_mask: Optional[torch.BoolTensor] = None,
527
+ ddpm_batch_mul: int = 1,
528
+ **kwargs: Optional[Dict[str, Union[torch.Tensor, str]]],
529
+ ) -> Union[Tuple, KugelAudioCausalLMOutputWithPast]:
530
+
531
+ return_dict = (
532
+ return_dict if return_dict is not None else self.config.use_return_dict
533
+ )
534
+
535
+ x = self.get_input_embeddings()(input_ids)
536
+
537
+ semantic_speech_all_connect_features = self.model.semantic_connector(
538
+ speech_semantic_tensors
539
+ )
540
+ if speeches_loss_input is not None:
541
+ # only part audio need diffuse
542
+ speech_all_features, speech_all_connect_features = (
543
+ self.forward_speech_features(
544
+ speech_tensors=(
545
+ speech_tensors.type_as(x)
546
+ if speech_tensors is not None
547
+ else None
548
+ ),
549
+ speech_masks=speech_masks,
550
+ speech_type=kwargs.get("speech_type", "audio"),
551
+ return_unmask=True,
552
+ )
553
+ )
554
+ if speech_tensors is not None:
555
+ if semantic_speech_all_connect_features is not None:
556
+ x[acoustic_input_mask] = (
557
+ speech_all_connect_features[speech_masks]
558
+ + semantic_speech_all_connect_features[speech_masks]
559
+ )
560
+ else:
561
+ x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
562
+ speech_features = speech_all_features[
563
+ speeches_loss_input & speech_masks
564
+ ] # only part audio need diffuse
565
+ speech_connect_features = speech_all_connect_features[
566
+ speeches_loss_input & speech_masks
567
+ ]
568
+ # Forward-time consistency check: selected latent count should match number of acoustic placeholders
569
+ try:
570
+ if acoustic_input_mask is not None:
571
+ assert speech_connect_features.shape[0] == int(
572
+ acoustic_input_mask.sum().item()
573
+ ), f"Mismatch between selected speech connectors ({speech_connect_features.shape[0]}) and acoustic_input_mask sum ({int(acoustic_input_mask.sum().item())})"
574
+ except Exception:
575
+ pass
576
+ else:
577
+ speech_features, speech_connect_features = self.forward_speech_features(
578
+ speech_tensors=(
579
+ speech_tensors.type_as(x) if speech_tensors is not None else None
580
+ ),
581
+ speech_masks=speech_masks,
582
+ speech_type=kwargs.get("speech_type", "audio"),
583
+ )
584
+ if speech_tensors is not None:
585
+ x[acoustic_input_mask] = speech_connect_features
586
+
587
+ outputs = self.model(
588
+ input_ids=None,
589
+ attention_mask=attention_mask,
590
+ position_ids=position_ids,
591
+ past_key_values=past_key_values,
592
+ inputs_embeds=x,
593
+ use_cache=use_cache,
594
+ output_attentions=output_attentions,
595
+ output_hidden_states=False,
596
+ return_dict=return_dict,
597
+ cache_position=cache_position,
598
+ )
599
+
600
+ hidden_states = outputs.last_hidden_state
601
+ logits = self.lm_head(hidden_states)
602
+ # logits = logits.float()
603
+
604
+ loss = None
605
+ if labels is not None:
606
+ # The custom CE loss with masking is calculated in the training script.
607
+ # We leave the standard loss calculation here as None.
608
+ pass
609
+
610
+ # --- Diffusion Loss Calculation ---
611
+ diffusion_loss = None
612
+ # This block is executed only if we are in a context that involves speech.
613
+ if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0:
614
+ # Build conditioning mask from positions whose NEXT token is a speech latent (shift left by 1)
615
+ cond_mask = torch.zeros_like(acoustic_loss_mask, dtype=torch.bool)
616
+ cond_mask[:, :-1] = acoustic_loss_mask[:, 1:]
617
+ cond_mask[:, 0] = False
618
+ condition_features = hidden_states[cond_mask]
619
+
620
+ speech_len, latent_size = speech_features.shape
621
+ # Sanity check: ensure 1:1 alignment between selected conditions and latents
622
+ try:
623
+ assert (
624
+ condition_features.shape[0] == speech_len
625
+ ), f"Mismatch: condition_features={condition_features.shape[0]} vs speech_features={speech_len}"
626
+ except Exception:
627
+ pass
628
+
629
+ noise = torch.randn(
630
+ (speech_len * ddpm_batch_mul, latent_size),
631
+ device=hidden_states.device,
632
+ dtype=hidden_states.dtype,
633
+ )
634
+
635
+ timesteps = torch.multinomial(
636
+ torch.ones(self.config.diffusion_head_config.ddpm_num_steps),
637
+ speech_len * ddpm_batch_mul,
638
+ replacement=True,
639
+ ).to(hidden_states.device)
640
+
641
+ speech_features_repeated = speech_features.repeat_interleave(
642
+ ddpm_batch_mul, dim=0
643
+ )
644
+ condition_features_repeated = condition_features.repeat_interleave(
645
+ ddpm_batch_mul, dim=0
646
+ )
647
+
648
+ noisy_speech_features = self.model.noise_scheduler.add_noise(
649
+ speech_features_repeated, noise, timesteps
650
+ )
651
+
652
+ model_output = self.model.prediction_head(
653
+ noisy_speech_features, timesteps.type_as(x), condition_features_repeated
654
+ )
655
+
656
+ prediction_type = self.config.diffusion_head_config.prediction_type
657
+ if prediction_type == "epsilon":
658
+ target_for_loss = noise
659
+ elif prediction_type == "v_prediction":
660
+ target_for_loss = self.model.noise_scheduler.get_velocity(
661
+ speech_features_repeated, noise, timesteps
662
+ )
663
+ else:
664
+ raise NotImplementedError(
665
+ f"Prediction type {prediction_type} not implemented"
666
+ )
667
+
668
+ diffusion_loss = F.mse_loss(
669
+ model_output.float(), target_for_loss.float(), reduction="sum"
670
+ )
671
+ if latent_size > 0 and ddpm_batch_mul > 0:
672
+ # Normalize by latent dim, number of sampled diffusion steps per latent, and number of speech tokens
673
+ diffusion_loss = (
674
+ diffusion_loss / latent_size / ddpm_batch_mul / max(speech_len, 1)
675
+ )
676
+ else:
677
+ diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)
678
+
679
+ else:
680
+ # Dummy loss for DDP to work when there are no speech samples in a batch,
681
+ # but we are in a speech context.
682
+ diffusion_loss = (
683
+ sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0
684
+ )
685
+ diffusion_loss += (
686
+ sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0
687
+ )
688
+ diffusion_loss += (
689
+ sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0
690
+ )
691
+ # --- End Diffusion Loss Calculation ---
692
+
693
+ if not return_dict:
694
+ output = (logits, speech_len) + outputs.to_tuple()[1:]
695
+ return (loss, diffusion_loss) + output
696
+
697
+ return KugelAudioCausalLMOutputWithPast(
698
+ loss=loss,
699
+ diffusion_loss=diffusion_loss,
700
+ speech_token_num=torch.tensor(
701
+ speech_len if speech_tensors is not None else 0,
702
+ device=logits.device,
703
+ dtype=torch.long,
704
+ ),
705
+ logits=logits,
706
+ past_key_values=outputs.past_key_values,
707
+ hidden_states=outputs.hidden_states,
708
+ attentions=outputs.attentions,
709
+ )
710
+
711
+
712
+ AutoModel.register(KugelAudioConfig, KugelAudioModel)
713
+ AutoModelForCausalLM.register(KugelAudioConfig, KugelAudioForConditionalGeneration)
714
+
715
# Public re-export surface of this module.
__all__ = [
    "KugelAudioModel",
    "KugelAudioPreTrainedModel",
    "KugelAudioForConditionalGeneration",
    "KugelAudioCausalLMOutputWithPast",
    "KugelAudioGenerationOutput",
]
kugelaudio_open/models/tokenizer.py ADDED
@@ -0,0 +1,1197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import typing as tp
3
+ from functools import partial
4
+ from dataclasses import dataclass, field
5
+ from typing import Dict, List, Optional, Tuple, Union
6
+ import copy
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch._dynamo
13
+
14
+ from transformers.models.auto import AutoModel
15
+
16
+ from transformers.configuration_utils import PretrainedConfig
17
+ from transformers.utils import logging
18
+ from transformers.modeling_utils import PreTrainedModel
19
+ from transformers.activations import ACT2FN
20
+
21
+ from ..configs import KugelAudioAcousticTokenizerConfig, KugelAudioSemanticTokenizerConfig
22
+
23
logger = logging.get_logger(__name__)

# APEX fused kernels are not shipped with the open-source release; the flag is
# hard-coded False so normalization layers always take the native-PyTorch path.
APEX_AVAILABLE = False
27
+
28
+ # Normalization modules
29
class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to last dimensions
    before running the normalization and moves them back to original position right after.

    Note: the original class defined an ``__init__`` that merely forwarded its
    arguments to ``nn.LayerNorm.__init__``; it was redundant and has been
    removed (the inherited constructor accepts the same arguments).
    """

    def forward(self, x):
        # (B, C, T) -> (B, T, C): LayerNorm normalizes over the trailing dims.
        x = x.transpose(1, 2)  # b ... t -> b t ...
        # Normalize in float32 for numerical stability, then cast back.
        # NOTE(review): assumes elementwise affine params exist (weight/bias
        # not None), as did the original implementation.
        x = nn.functional.layer_norm(
            x.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps
        ).type_as(x)
        x = x.transpose(1, 2)  # b t ... -> b ... t
        return x
42
+
43
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction).

    Normalizes over the last dimension by the RMS of the activations, with an
    optional learnable per-feature scale.
    """

    def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not elementwise_affine:
            self.register_parameter('weight', None)
        else:
            shape = weight_shape if weight_shape is not None else (dim,)
            self.weight = nn.Parameter(torch.ones(shape))

    def _norm(self, x):
        # x / rms(x), computed via the reciprocal square root for efficiency.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32, then cast back to the input dtype.
        normed = self._norm(x.float()).type_as(x)
        if self.weight is None:
            return normed
        return normed * self.weight

    def extra_repr(self) -> str:
        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
66
+
67
class ConvRMSNorm(RMSNorm):
    """RMSNorm for convolutional tensors shaped (batch, channels, time).

    Transposes to channels-last, applies RMS normalization over the channel
    dimension, and transposes back.
    """

    def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
        super().__init__(dim, eps, elementwise_affine, weight_shape)

    def forward(self, x):
        x = x.transpose(1, 2)  # b ... t -> b t ...
        # Native implementation only: the fused-APEX branch was removed because
        # APEX_AVAILABLE is hard-coded False in the open-source release, making
        # that branch dead code that referenced an undefined name
        # (`fused_rms_norm_affine`).
        output = self._norm(x.float()).type_as(x)
        if self.weight is not None:
            output = output * self.weight
        output = output.transpose(1, 2)  # b t ... -> b ... t
        return output
82
+
83
+ # Convolutional layers and utilities
84
# Accepted values for the `norm` argument of the conv wrappers below.
# 'weight_norm'/'spectral_norm' reparametrize the conv weights; the remaining
# values select a normalization module applied to the conv output (or none).
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
                                 'time_layer_norm', 'layer_norm', 'time_group_norm'])
86
+
87
+
88
def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    """Wrap *module* with the requested weight reparametrization.

    Only 'weight_norm' and 'spectral_norm' actually reparametrize; every other
    value in CONV_NORMALIZATIONS leaves the module untouched.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        return nn.utils.weight_norm(module)
    if norm == 'spectral_norm':
        return nn.utils.spectral_norm(module)
    # Validated above against CONV_NORMALIZATIONS: the remaining choices
    # require no weight reparametrization.
    return module
98
+
99
+
100
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Return the proper normalization module. If causal is True, this will ensure the returned
    module is causal, or return an error if the normalization doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'layer_norm':
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    if norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    # All other (validated) choices need no output normalization.
    return nn.Identity()
115
+
116
+
117
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Return how many trailing samples must be padded so the final frame of a
    strided Conv1d covers the whole (padded) input without truncation."""
    length = x.shape[-1]
    # Fractional number of frames the conv would produce on the padded input.
    frames = (length - kernel_size + padding_total) / stride + 1
    # Smallest input length yielding an integral frame count.
    target_length = (math.ceil(frames) - 1) * stride + kernel_size - padding_total
    return target_length - length
124
+
125
+
126
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Pad the last dimension of *x*.

    For 'reflect' mode, inputs shorter than the pad width are first extended
    with zeros on the right (F.pad requires input > pad for reflection), and
    the temporary extension is trimmed off afterwards.
    """
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode != 'reflect':
        return F.pad(x, paddings, mode, value)
    length = x.shape[-1]
    max_pad = max(padding_left, padding_right)
    extra_pad = max_pad - length + 1 if length <= max_pad else 0
    if extra_pad:
        x = F.pad(x, (0, extra_pad))
    padded = F.pad(x, paddings, mode, value)
    # Drop the zero-extension added above so output length is exact.
    return padded[..., : padded.shape[-1] - extra_pad]
142
+
143
+
144
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    assert left + right <= x.shape[-1]
    # Slice off `left` samples at the start and `right` at the end.
    return x[..., left: x.shape[-1] - right]
151
+
152
+
153
class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv.

    Positional args/kwargs are forwarded to ``nn.Conv1d``; ``norm`` selects a
    weight reparametrization and/or output normalization module.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None, **kwargs):
        super().__init__()
        # Use a None sentinel instead of the original mutable default `{}`
        # (shared across instances); behavior is unchanged for all callers.
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x
166
+
167
+
168
class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv.

    Positional args/kwargs are forwarded to ``nn.ConvTranspose1d``; ``norm``
    selects a weight reparametrization and/or output normalization module.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None, **kwargs):
        super().__init__()
        # Use a None sentinel instead of the original mutable default `{}`
        # (shared across instances); behavior is unchanged for all callers.
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x
181
+
182
+
183
class KugelAudioTokenizerStreamingCache:
    """Cache for streaming convolution, similar to KV cache in attention.

    Keyed by ``(layer_id, sample_idx)`` so each conv layer keeps an
    independent per-sample context tensor between chunks.
    """
    def __init__(self):
        self.cache = {}  # Dict mapping (layer_id, sample_idx) to state tensor

    def get(self, layer_id: str, sample_indices: torch.Tensor) -> Optional[torch.Tensor]:
        """Get cached states for given layer and sample indices.

        Returns a batch-stacked tensor, or ``None`` if ANY requested sample is
        missing (caller then initializes fresh context for the whole batch).
        """
        states = []
        max_length = 0

        # First pass: collect states and find max length
        for idx in sample_indices.tolist():
            key = (layer_id, idx)
            if key not in self.cache:
                return None  # If any sample is missing, return None
            state = self.cache[key]
            states.append(state)
            max_length = max(max_length, state.shape[-1])

        # Second pass: pad states to max length if needed
        if len(states) > 0 and states[0].dim() >= 2:
            padded_states = []
            for state in states:
                if state.shape[-1] < max_length:
                    # Pad on the time dimension (last dimension)
                    pad_size = max_length - state.shape[-1]
                    # Pad with zeros on the LEFT to align the most recent samples
                    padded_state = F.pad(state, (pad_size, 0), mode='constant', value=0)
                    padded_states.append(padded_state)
                else:
                    padded_states.append(state)
            return torch.stack(padded_states, dim=0)
        else:
            return torch.stack(states, dim=0)

    def set(self, layer_id: str, sample_indices: torch.Tensor, states: torch.Tensor):
        """Set cached states for given layer and sample indices.

        States are detached so the cache never extends the autograd graph
        across chunks.
        """
        for i, idx in enumerate(sample_indices.tolist()):
            key = (layer_id, idx)
            self.cache[key] = states[i].detach()

    def set_to_zero(self, sample_indices: torch.Tensor):
        """Set all cached states to zero for given sample indices (all layers)."""
        for key in list(self.cache.keys()):
            layer_id, sample_idx = key
            if sample_idx in sample_indices.tolist():
                # Create zero tensor with same shape and dtype as cached tensor
                cached_tensor = self.cache[key]
                self.cache[key] = torch.zeros_like(cached_tensor)

    def clear(self, layer_id: Optional[str] = None, sample_indices: Optional[torch.Tensor] = None):
        """Clear cache for specific layer/samples or everything.

        NOTE(review): ``layer_id is None`` combined with explicit
        ``sample_indices`` is silently a no-op — confirm callers never rely on
        clearing a sample across all layers.
        """
        if layer_id is None and sample_indices is None:
            self.cache.clear()
        elif layer_id is not None and sample_indices is None:
            # Clear all samples for a specific layer
            keys_to_remove = [k for k in self.cache.keys() if k[0] == layer_id]
            for k in keys_to_remove:
                del self.cache[k]
        elif layer_id is not None and sample_indices is not None:
            # Clear specific samples for a specific layer
            for idx in sample_indices.tolist():
                key = (layer_id, idx)
                self.cache.pop(key, None)
247
+
248
class SConv1d(nn.Module):
    """Conv1d with built-in handling of asymmetric or causal padding and normalization."""
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect'):
        super().__init__()
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

        # Store configuration
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

        # For causal convolution, we need to maintain kernel_size - 1 samples as context
        # need to check use which context_size is more suitable
        # self.context_size = (kernel_size - 1) * dilation
        self.context_size = (kernel_size - 1) * dilation - (stride - 1)

        # For non-streaming mode, calculate padding
        # (identical formula to context_size: left-pad amount for causal convs)
        self.padding_total = (kernel_size - 1) * dilation - (stride - 1)

        # Create a unique layer ID for cache management
        self._layer_id = None

    @property
    def layer_id(self):
        # Lazily derive a process-unique key (based on object identity) used
        # to index this layer's entries in the streaming cache.
        if self._layer_id is None:
            self._layer_id = f"sconv1d_{id(self)}"
        return self._layer_id

    def forward(self, x: torch.Tensor,
                cache: Optional[KugelAudioTokenizerStreamingCache] = None,
                sample_indices: Optional[torch.Tensor] = None,
                use_cache: bool = False,
                debug: bool = False) -> torch.Tensor:
        """
        Forward pass with optional streaming support via cache.

        Args:
            x: Input tensor [batch_size, channels, time]
            cache: KugelAudioTokenizerStreamingCache object for maintaining states
            sample_indices: Indices identifying each sample for cache management
            use_cache: Whether to use cached states for streaming
            debug: Whether to print debug information

        Returns:
            Output tensor
        """
        B, C, T = x.shape

        # Non-streaming mode
        if not use_cache or cache is None:
            return self._forward_non_streaming(x, debug=debug)

        # Streaming mode
        assert self.causal, "Streaming mode is only supported for causal convolutions"
        assert sample_indices is not None, "sample_indices must be provided for streaming mode"
        assert len(sample_indices) == B, "sample_indices must match batch size"

        return self._forward_streaming(x, cache, sample_indices, debug)

    @torch._dynamo.disable()  # Disable compilation for streaming path - dynamic cache ops cause recompilations
    def _forward_streaming(self, x: torch.Tensor,
                           cache: KugelAudioTokenizerStreamingCache,
                           sample_indices: torch.Tensor,
                           debug: bool = False) -> torch.Tensor:
        """Streaming forward pass with cache operations kept separate from compiled code"""
        B, C, T = x.shape

        # Cache operations (not compiled)
        cached_states = cache.get(self.layer_id, sample_indices)

        if cached_states is None:
            # First chunk - initialize with zeros for context
            if self.context_size > 0:
                cached_states = torch.zeros(B, C, self.context_size, device=x.device, dtype=x.dtype)
                if debug:
                    print(f"[DEBUG] Initialized cache with shape: {cached_states.shape}, context_size={self.context_size}")
            else:
                cached_states = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
                if debug:
                    print(f"[DEBUG] No context needed (kernel_size=stride)")

        # Concatenate cached states with input
        if cached_states.shape[2] > 0:
            input_with_context = torch.cat([cached_states, x], dim=2)
        else:
            input_with_context = x

        if debug:
            print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}")

        # Apply convolution directly - no extra padding in streaming mode
        # The conv layer will handle its own padding internally
        output = self.conv(input_with_context)

        if debug:
            print(f"[DEBUG] Output shape: {output.shape}")

        # Update cache for next chunk
        if self.context_size > 0:
            # Calculate how many samples to keep
            total_input_length = input_with_context.shape[2]

            # Keep the last context_size samples
            if total_input_length >= self.context_size:
                new_cache_start = total_input_length - self.context_size
                new_cache = input_with_context[:, :, new_cache_start:]
            else:
                # If we have less than context_size samples, keep everything
                new_cache = input_with_context

            if debug:
                print(f"[DEBUG] New cache shape: {new_cache.shape}")

            cache.set(self.layer_id, sample_indices, new_cache)

        return output

    def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
        """Standard forward pass without streaming"""
        B, C, T = x.shape
        kernel_size = self.kernel_size
        stride = self.stride
        dilation = self.dilation
        padding_total = self.padding_total

        # Ensure weight is on the same device as input
        if hasattr(self, "conv") and hasattr(self.conv, "conv") and hasattr(self.conv.conv, "weight"):
            if self.conv.conv.weight.device != x.device:
                self.conv.conv.to(x.device)

        # Compute extra padding for stride alignment
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)

        if debug:
            print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}, padding_total={padding_total}, extra_padding={extra_padding}")

        if self.causal:
            # Left padding for causal
            if self.pad_mode == 'constant':
                x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode, value=0)
            else:
                x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Symmetric padding for non-causal
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)

        if debug:
            print(f"[DEBUG NON-STREAMING] After padding: {x.shape}")

        output = self.conv(x)

        if debug:
            print(f"[DEBUG NON-STREAMING] Output shape: {output.shape}")

        return output
415
+
416
+
417
+ class SConvTranspose1d(nn.Module):
418
+ """ConvTranspose1d with built-in handling of asymmetric or causal padding and normalization."""
419
+ def __init__(self, in_channels: int, out_channels: int,
420
+ kernel_size: int, stride: int = 1, causal: bool = False,
421
+ norm: str = 'none', trim_right_ratio: float = 1.,
422
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, bias: bool = True):
423
+ super().__init__()
424
+ self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
425
+ causal=causal, norm=norm, norm_kwargs=norm_kwargs, bias=bias)
426
+ self.causal = causal
427
+ self.trim_right_ratio = trim_right_ratio
428
+ assert self.causal or self.trim_right_ratio == 1., \
429
+ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
430
+ assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
431
+
432
+ # Store configuration
433
+ self.kernel_size = kernel_size
434
+ self.stride = stride
435
+ self.in_channels = in_channels
436
+ self.out_channels = out_channels
437
+
438
+ # For transposed convolution, padding calculation is different
439
+ self.padding_total = kernel_size - stride
440
+
441
+ # For streaming, we need to keep track of input history
442
+ # Transposed conv needs to see multiple input samples to produce correct output
443
+ self.context_size = kernel_size - 1
444
+
445
+ # Create a unique layer ID for cache management
446
+ self._layer_id = None
447
+
448
+ @property
449
+ def layer_id(self):
450
+ if self._layer_id is None:
451
+ self._layer_id = f"sconvtr1d_{id(self)}"
452
+ return self._layer_id
453
+
454
+ def forward(self, x: torch.Tensor,
455
+ cache: Optional[KugelAudioTokenizerStreamingCache] = None,
456
+ sample_indices: Optional[torch.Tensor] = None,
457
+ use_cache: bool = False,
458
+ debug: bool = False) -> torch.Tensor:
459
+ """
460
+ Forward pass with optional streaming support via cache.
461
+ """
462
+ B, C, T = x.shape
463
+
464
+ # Non-streaming mode
465
+ if not use_cache or cache is None:
466
+ return self._forward_non_streaming(x, debug=debug)
467
+
468
+ # Streaming mode
469
+ assert sample_indices is not None, "sample_indices must be provided for streaming mode"
470
+ assert len(sample_indices) == B, "sample_indices must match batch size"
471
+
472
+ return self._forward_streaming(x, cache, sample_indices, debug)
473
+
474
    @torch._dynamo.disable()  # Disable compilation for streaming path - dynamic cache ops cause recompilations
    def _forward_streaming(self, x: torch.Tensor,
                           cache: KugelAudioTokenizerStreamingCache,
                           sample_indices: torch.Tensor,
                           debug: bool = False) -> torch.Tensor:
        """Streaming forward pass with cache operations kept separate from compiled code.

        Concatenates the cached input tail with the new chunk, runs the full
        transposed convolution over the combined sequence, trims the synthesis
        padding, and returns only the samples produced by the new chunk.
        The cache stores raw *input* samples (not outputs), so overlapping
        context is re-convolved on every call.
        """
        # x is expected as (batch, channels, time) — the unpack enforces rank 3.
        B, C, T = x.shape

        # Cache operations (not compiled)
        cached_input = cache.get(self.layer_id, sample_indices)

        if cached_input is None:
            # First chunk - no history yet; use a zero-length placeholder so
            # the concatenation below is unconditional.
            cached_input = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
            if debug:
                print(f"[DEBUG] Initialized empty cache for transposed conv")

        # Concatenate cached input with new input
        full_input = torch.cat([cached_input, x], dim=2)

        if debug:
            print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_input.shape}, Combined: {full_input.shape}")

        # First chunk or debug mode - use uncompiled version
        full_output = self.convtr(full_input)

        if debug:
            print(f"[DEBUG] Full transposed conv output shape: {full_output.shape}")

        # Calculate padding to remove. Causal mode splits the total padding
        # according to trim_right_ratio; otherwise it is split evenly.
        if self.causal:
            padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
            padding_left = self.padding_total - padding_right
        else:
            padding_right = self.padding_total // 2
            padding_left = self.padding_total - padding_right

        # Remove padding
        if padding_left + padding_right > 0:
            full_output = unpad1d(full_output, (padding_left, padding_right))

        if debug:
            print(f"[DEBUG] After unpadding: {full_output.shape}")

        # Determine which part of the output corresponds to the new input
        if cached_input.shape[2] == 0:
            # First chunk - return all output
            output = full_output
        else:
            # Subsequent chunks - return only the new output. Each input step
            # expands to `stride` output samples after the transposed conv.
            expected_new_output = T * self.stride

            # Take the last expected_new_output samples; if the trimmed output
            # is shorter than expected, fall back to returning everything.
            if full_output.shape[2] >= expected_new_output:
                output = full_output[:, :, -expected_new_output:]
            else:
                output = full_output

        if debug:
            print(f"[DEBUG] Final streaming output shape: {output.shape}")

        # Update cache: keep at most the last `context_size` input samples as
        # overlap context for the next chunk.
        if full_input.shape[2] > self.context_size:
            new_cache = full_input[:, :, -self.context_size:]
        else:
            new_cache = full_input

        if debug:
            print(f"[DEBUG] New cache shape: {new_cache.shape}")

        cache.set(self.layer_id, sample_indices, new_cache)

        return output
547
+
548
+ def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
549
+ """Standard forward pass without streaming"""
550
+ # Ensure weight is on the same device as input
551
+ if hasattr(self, "convtr") and hasattr(self.convtr, "convtr") and hasattr(self.convtr.convtr, "weight"):
552
+ if self.convtr.convtr.weight.device != x.device:
553
+ self.convtr.convtr.to(x.device)
554
+
555
+ if debug:
556
+ print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}")
557
+
558
+ # Apply transposed convolution
559
+ y = self.convtr(x)
560
+
561
+ if debug:
562
+ print(f"[DEBUG NON-STREAMING] After transposed conv: {y.shape}")
563
+
564
+ # Calculate and remove padding
565
+ if self.causal:
566
+ padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
567
+ padding_left = self.padding_total - padding_right
568
+ else:
569
+ padding_right = self.padding_total // 2
570
+ padding_left = self.padding_total - padding_right
571
+
572
+ if padding_left + padding_right > 0:
573
+ y = unpad1d(y, (padding_left, padding_right))
574
+
575
+ if debug:
576
+ print(f"[DEBUG NON-STREAMING] Final output shape: {y.shape}")
577
+
578
+ return y
579
+
580
# FFN
class FFN(nn.Module):
    """Position-wise feed-forward block: linear projection up, GELU, projection back."""

    def __init__(self, embed_dim, ffn_dim, bias=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.linear1 = nn.Linear(self.embed_dim, ffn_dim, bias=bias)
        self.gelu = ACT2FN["gelu"]
        self.linear2 = nn.Linear(ffn_dim, self.embed_dim, bias=bias)

    def forward(self, x):
        """Map (..., embed_dim) -> (..., embed_dim) through the two-layer MLP."""
        return self.linear2(self.gelu(self.linear1(x)))
599
+
600
+
601
class Convlayer(nn.Module):
    """Thin wrapper exposing a single streaming-aware SConv1d as a module."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 dilation=1, groups=1, bias=True, pad_mode='zeros',
                 norm='weight_norm', causal=True):
        super().__init__()
        self.conv = SConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            pad_mode=pad_mode,
            norm=norm,
            causal=causal,
        )

    def forward(self, x):
        """Apply the wrapped convolution to (batch, channels, time) input."""
        return self.conv(x)
621
+
622
class Block1D(nn.Module):
    """Residual 1D block: (norm -> mixer conv) then (norm -> FFN), each branch
    with optional per-channel LayerScale and a shared drop-path wrapper.

    Args:
        dim: Channel dimension of the block's input and output.
        kernel_size: Kernel size of the mixer convolution.
        drop_path: Stochastic-depth rate (<= 0 disables it).
        mixer_layer: 'conv' for a grouped convolution or 'depthwise_conv'
            for a depthwise (groups=dim) convolution.
        layer_scale_init_value: Initial LayerScale value; <= 0 disables the
            learnable per-channel residual scaling.
        **kwargs: Extra options read via .get(): 'layernorm', 'eps', 'groups',
            'pad_mode', 'norm', 'causal', 'bias', 'ffn_expansion'.
    """

    def __init__(self, dim, kernel_size=7, drop_path=0., mixer_layer='conv',
                 layer_scale_init_value=1e-6, **kwargs):
        super().__init__()

        # Normalization layers. Fail fast on unsupported values instead of
        # leaving self.norm unset (the original elif chain silently skipped
        # unknown kinds, deferring the failure to forward()). The error
        # message matches TokenizerEncoder's for consistency.
        layernorm = kwargs.get('layernorm', 'LN')
        eps = kwargs.get('eps', 1e-6)
        if layernorm == 'LN':
            self.norm = ConvLayerNorm(dim, eps=eps)
            self.ffn_norm = ConvLayerNorm(dim, eps=eps)
        elif layernorm == 'RMSNorm':
            self.norm = ConvRMSNorm(dim, eps=eps)
            self.ffn_norm = ConvRMSNorm(dim, eps=eps)
        else:
            raise ValueError(f"Unsupported norm type: {layernorm}")

        # Token mixer: both variants share every argument except `groups`.
        if mixer_layer == 'conv':
            groups = kwargs.get('groups', 1)
        elif mixer_layer == 'depthwise_conv':
            groups = dim
        else:
            raise ValueError(f"Unsupported mixer layer: {mixer_layer}")
        self.mixer = Convlayer(dim, dim, groups=groups,
                               kernel_size=kernel_size,
                               pad_mode=kwargs.get('pad_mode', 'reflect'),
                               norm=kwargs.get('norm', 'none'),
                               causal=kwargs.get('causal', True),
                               bias=kwargs.get('bias', True),
                               )

        self.ffn = FFN(
            dim,
            kwargs.get('ffn_expansion', 4) * dim,
            bias=kwargs.get('bias', False),
        )
        # NOTE(review): torch.nn has no DropPath module, so drop_path > 0
        # raises AttributeError here. Stochastic depth would need e.g.
        # timm's DropPath — confirm the intended source before enabling it.
        self.drop_path = nn.Identity() if drop_path <= 0. else nn.modules.DropPath(drop_path)

        # LayerScale: learnable per-channel residual scaling (one scale per
        # branch), disabled when the init value is non-positive.
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.ffn_gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma = None
            self.ffn_gamma = None

    def forward(self, x):
        """Apply mixer and FFN sub-blocks, each with a residual connection.

        Args:
            x: Tensor of shape (batch, dim, time).

        Returns:
            Tensor of the same shape.
        """
        # mixer branch (channel-first layout)
        residual = x
        x = self.norm(x)
        x = self.mixer(x)
        if self.gamma is not None:
            x = x * self.gamma.unsqueeze(-1)
        x = residual + self.drop_path(x)

        # ffn branch — linear layers expect channel-last, hence the permutes
        residual = x
        x = self.ffn_norm(x)
        x = x.permute(0, 2, 1)
        x = self.ffn(x)
        x = x.permute(0, 2, 1)
        if self.ffn_gamma is not None:
            x = x * self.ffn_gamma.unsqueeze(-1)
        x = residual + self.drop_path(x)

        return x
687
+
688
+
689
class TokenizerEncoder(nn.Module):
    """
    Encoder component for the KugelAudio tokenizer that converts audio to latent representations.

    Structure: a stem convolution, then alternating (downsample conv, stage of
    Block1D layers), a final norm, and a head conv projecting to `dimension`.

    Args:
        config: Configuration object with model parameters
    """
    def __init__(self, config):
        super().__init__()

        # Extract parameters from config
        self.channels = config.channels
        self.dimension = config.dimension
        self.n_filters = config.n_filters
        # Ratios are reversed here; the decoder consumes them unreversed, so
        # encoder downsampling mirrors decoder upsampling order.
        self.ratios = list(reversed(config.ratios))
        self.depths = config.depths
        self.n_residual_layers = getattr(config, "n_residual_layers", 1)
        # Product of all stride ratios (overall temporal reduction factor).
        self.hop_length = np.prod(self.ratios)
        self.causal = config.causal

        # Additional config parameters with defaults
        kernel_size = getattr(config, "kernel_size", 7)
        last_kernel_size = getattr(config, "last_kernel_size", 7)
        norm = getattr(config, "norm", "none")
        norm_params = getattr(config, "norm_params", {})
        pad_mode = getattr(config, "pad_mode", "reflect")
        bias = getattr(config, "bias", True)
        layernorm = getattr(config, "layernorm", "LN")
        layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
        layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
        drop_path_rate = getattr(config, "drop_path_rate", 0.0)
        mixer_layer = getattr(config, "mixer_layer", "conv")
        layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
        disable_last_norm = getattr(config, "disable_last_norm", False)

        # determine the norm type based on layernorm
        if layernorm == 'LN':
            norm_type = ConvLayerNorm
        elif layernorm == 'RMSNorm':
            norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
        else:
            raise ValueError(f"Unsupported norm type: {layernorm}")

        # stem and intermediate downsampling conv layers
        stem = nn.Sequential(
            SConv1d(self.channels, self.n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
        )

        # downsample_layers holds len(ratios) + 1 entries (stem + one per
        # ratio); forward_features indexes it by stage, so len(depths) is
        # presumably len(ratios) + 1 — TODO confirm against configs.
        self.downsample_layers = nn.ModuleList()
        self.downsample_layers.append(stem)
        for i in range(len(self.ratios)):
            in_ch = self.n_filters * (2 ** i)   # channels double at each level
            out_ch = self.n_filters * (2 ** (i + 1))
            downsample_layer = nn.Sequential(
                SConv1d(in_ch, out_ch, kernel_size=self.ratios[i] * 2, stride=self.ratios[i], causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
            )
            self.downsample_layers.append(downsample_layer)

        # configure the transformer blocks (all Block1D options pre-bound)
        layer_type = partial(
            Block1D,
            mixer_layer=mixer_layer,
            layernorm=layernorm,
            eps=layernorm_eps,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
            layer_scale_init_value=layer_scale_init_value,
        )

        self.stages = nn.ModuleList()
        # Linearly increasing drop-path rates across all blocks.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0

        for i in range(len(self.depths)):
            in_ch = self.n_filters * (2 ** i)
            stage = nn.Sequential(
                *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
            )
            self.stages.append(stage)
            cur += self.depths[i]

        # Final norm over the deepest channel count (in_ch from the last
        # loop iteration), optionally disabled.
        if not disable_last_norm:
            self.norm = norm_type(in_ch, eps=layernorm_eps)
        else:
            self.norm = nn.Identity()
        self.head = SConv1d(in_ch, self.dimension, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)

    def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Run downsampling + stages; threads the streaming cache into every
        SConv1d it can reach (unrolling Block1D.forward manually for that)."""
        for i in range(len(self.depths)):
            # Apply downsampling
            for layer in self.downsample_layers[i]:
                if isinstance(layer, SConv1d):
                    x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                else:
                    x = layer(x)

            # Apply stage (Block1D contains Convlayer which contains SConv1d)
            for block in self.stages[i]:
                if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
                    # Block1D forward with cache support — this re-implements
                    # Block1D.forward inline (minus drop_path) so the cache
                    # kwargs can be passed to the inner SConv1d.
                    residual = x
                    x = block.norm(x)
                    x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                    if block.gamma is not None:
                        x = x * block.gamma.unsqueeze(-1)
                    x = residual + x

                    # FFN part (channel-last for the linear layers)
                    residual = x
                    x = block.ffn_norm(x)
                    x = x.permute(0, 2, 1)
                    x = block.ffn(x)
                    x = x.permute(0, 2, 1)
                    if block.ffn_gamma is not None:
                        x = x * block.ffn_gamma.unsqueeze(-1)
                    x = residual + x
                else:
                    x = block(x)

        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Encode (batch, channels, time) audio to (batch, dimension, frames)."""
        x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return x
816
+
817
+
818
class TokenizerDecoder(nn.Module):
    """
    Decoder component for the KugelAudio tokenizer that converts latent representations back to audio.

    Mirror of TokenizerEncoder: a stem convolution, alternating (upsample
    transposed conv, stage of Block1D layers), a final norm, and a head conv
    projecting back to the waveform channel count.

    Args:
        config: Configuration object with model parameters
    """
    def __init__(self, config):
        super().__init__()

        # Extract parameters from config
        self.dimension = config.dimension
        self.channels = config.channels
        self.n_filters = config.n_filters
        self.ratios = config.ratios

        # IMPORTANT CHANGE: Don't reverse depths again since they're already reversed in KugelAudioAcousticTokenizerModel
        self.depths = config.depths  # Changed from list(reversed(config.depths))

        self.n_residual_layers = getattr(config, "n_residual_layers", 1)
        # Product of all stride ratios (overall temporal upsampling factor).
        self.hop_length = np.prod(self.ratios)
        self.causal = config.causal

        # Additional config parameters with defaults
        kernel_size = getattr(config, "kernel_size", 7)
        last_kernel_size = getattr(config, "last_kernel_size", 7)
        norm = getattr(config, "norm", "none")
        norm_params = getattr(config, "norm_params", {})
        pad_mode = getattr(config, "pad_mode", "reflect")
        bias = getattr(config, "bias", True)
        layernorm = getattr(config, "layernorm", "LN")
        layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
        trim_right_ratio = getattr(config, "trim_right_ratio", 1.0)
        layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
        drop_path_rate = getattr(config, "drop_path_rate", 0.0)
        mixer_layer = getattr(config, "mixer_layer", "conv")
        layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
        disable_last_norm = getattr(config, "disable_last_norm", False)

        # determine the norm type based on layernorm
        if layernorm == 'LN':
            norm_type = ConvLayerNorm
        elif layernorm == 'RMSNorm':
            norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
        else:
            raise ValueError(f"Unsupported norm type: {layernorm}")

        # stem and upsampling layers — stem starts at the widest channel count
        # (n_filters * 2**(len(depths)-1)) and each level halves it.
        stem = nn.Sequential(
            SConv1d(self.dimension, self.n_filters * 2 ** (len(self.depths) - 1), kernel_size, norm=norm,
                    norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
        )

        self.upsample_layers = nn.ModuleList()
        self.upsample_layers.append(stem)
        for i in range(len(self.ratios)):
            in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
            out_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i - 1))
            upsample_layer = nn.Sequential(
                SConvTranspose1d(in_ch, out_ch,
                                 kernel_size=self.ratios[i] * 2, stride=self.ratios[i],
                                 norm=norm, norm_kwargs=norm_params, bias=bias,
                                 causal=self.causal, trim_right_ratio=trim_right_ratio),
            )
            self.upsample_layers.append(upsample_layer)

        # configure transformer blocks (all Block1D options pre-bound)
        layer_type = partial(
            Block1D,
            mixer_layer=mixer_layer,
            layernorm=layernorm,
            eps=layernorm_eps,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
            layer_scale_init_value=layer_scale_init_value,
        )

        self.stages = nn.ModuleList()
        # Linearly increasing drop-path rates across all blocks.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0

        # Create stages in the same order as the original model
        for i in range(len(self.depths)):
            in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
            stage = nn.Sequential(
                *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
            )
            self.stages.append(stage)
            cur += self.depths[i]

        # Final norm over the narrowest channel count (in_ch from the last
        # loop iteration), optionally disabled.
        if not disable_last_norm:
            self.norm = norm_type(in_ch, eps=layernorm_eps)
        else:
            self.norm = nn.Identity()
        self.head = SConv1d(in_ch, self.channels, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)

    def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Run upsampling + stages; threads the streaming cache into every
        SConv1d/SConvTranspose1d it can reach (unrolling Block1D.forward)."""
        for i in range(len(self.depths)):
            # Apply upsampling
            for layer in self.upsample_layers[i]:
                if isinstance(layer, (SConv1d, SConvTranspose1d)):
                    x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                else:
                    x = layer(x)

            # Apply stage (Block1D contains Convlayer which contains SConv1d)
            for block in self.stages[i]:
                if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
                    # Block1D forward with cache support — re-implements
                    # Block1D.forward inline (minus drop_path) so the cache
                    # kwargs can be passed to the inner SConv1d.
                    residual = x
                    x = block.norm(x)
                    x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                    if block.gamma is not None:
                        x = x * block.gamma.unsqueeze(-1)
                    x = residual + x

                    # FFN part (channel-last for the linear layers)
                    residual = x
                    x = block.ffn_norm(x)
                    x = x.permute(0, 2, 1)
                    x = block.ffn(x)
                    x = x.permute(0, 2, 1)
                    if block.ffn_gamma is not None:
                        x = x * block.ffn_gamma.unsqueeze(-1)
                    x = residual + x
                else:
                    x = block(x)

        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Decode (batch, dimension, frames) latents to (batch, channels, time) audio."""
        x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return x
954
+
955
+
956
@dataclass
class KugelAudioTokenizerEncoderOutput:
    """
    Output of KugelAudio tokenizer encoder, representing a Gaussian distribution with fixed variance.

    Args:
        mean (`torch.FloatTensor`): The mean parameters of the distribution.
        std (`float` or `torch.FloatTensor`): Fixed standard deviation value.
    """
    mean: torch.Tensor
    std: Optional[Union[float, torch.Tensor]] = None

    def sample(self, dist_type='fix'):
        """Draw a sample from the distribution.

        Args:
            dist_type (`str`): 'fix' uses the stored std directly; 'gaussian'
                first draws a per-batch-element std; any other value returns
                the mean unchanged.

        Returns:
            Tuple of (sample, std used for sampling).
        """
        if dist_type == 'fix':
            noise = torch.randn_like(self.mean)
            return self.mean + self.std * noise, self.std

        if dist_type == 'gaussian':
            # Randomize the std itself (one value per batch element, scaled
            # by std / 0.8), then sample with that std.
            scale = self.std / 0.8
            std = torch.randn(self.mean.size(0), device=self.mean.device, dtype=self.mean.dtype) * scale
            while std.dim() < self.mean.dim():
                std = std.unsqueeze(-1)
            noise = torch.randn_like(self.mean)
            return self.mean + std * noise, std

        # Fallback: deterministic — return the mean itself.
        return self.mean, self.std

    def kl(self):
        """KL surrogate vs. a standard normal: elementwise squared mean."""
        return F.mse_loss(self.mean, torch.zeros_like(self.mean), reduction='none')

    def mode(self):
        """Return the distribution mode (equal to the mean for a Gaussian)."""
        return self.mean
1003
+
1004
class KugelAudioAcousticTokenizerModel(PreTrainedModel):
    """KugelAudio speech tokenizer model combining encoder and decoder for acoustic tokens.

    Encodes waveforms to a Gaussian over VAE latents (fixed std), samples
    from it, and decodes samples back to audio. All public entry points run
    under ``torch.no_grad()``.
    """

    config_class = KugelAudioAcousticTokenizerConfig
    base_model_prefix = "kugelaudio_acoustic_tokenizer"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"]

    def __init__(self, config):
        super().__init__(config)

        # Fixed sampling std; non-persistent so it is not saved in checkpoints.
        self.register_buffer('fix_std', torch.tensor(config.fix_std), persistent=False)
        self.std_dist_type = getattr(config, "std_dist_type", "fix")

        # Parse encoder depths (either a "2-2-2" style string or a list)
        if isinstance(config.encoder_depths, str):
            encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
        else:
            encoder_depths = config.encoder_depths

        # Parse decoder depths if provided
        if config.decoder_depths is not None and isinstance(config.decoder_depths, str):
            decoder_depths = [int(d) for d in config.decoder_depths.split('-')]
        else:
            # Default: use reversed encoder depths if decoder_depths is None
            decoder_depths = list(reversed(encoder_depths))

        # Create encoder config: a deep copy of the top-level config with
        # encoder-specific fields mapped onto the generic names that
        # TokenizerEncoder reads (dimension, n_filters, ratios, ...).
        encoder_config = copy.deepcopy(config)
        encoder_config.dimension = config.vae_dim
        encoder_config.n_filters = config.encoder_n_filters
        encoder_config.ratios = config.encoder_ratios
        encoder_config.depths = encoder_depths
        encoder_config.norm = config.conv_norm
        encoder_config.pad_mode = config.pad_mode
        encoder_config.bias = config.conv_bias
        encoder_config.layernorm_eps = config.layernorm_eps
        encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
        encoder_config.mixer_layer = config.mixer_layer
        encoder_config.layer_scale_init_value = config.layer_scale_init_value
        encoder_config.disable_last_norm = config.disable_last_norm

        # Create decoder config (same mapping with the decoder-side fields)
        decoder_config = copy.deepcopy(config)
        decoder_config.dimension = config.vae_dim
        decoder_config.n_filters = config.decoder_n_filters
        decoder_config.ratios = config.decoder_ratios
        decoder_config.depths = decoder_depths
        decoder_config.norm = config.conv_norm
        decoder_config.pad_mode = config.pad_mode
        decoder_config.bias = config.conv_bias
        decoder_config.layernorm_eps = config.layernorm_eps
        decoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
        decoder_config.mixer_layer = config.mixer_layer
        decoder_config.layer_scale_init_value = config.layer_scale_init_value
        decoder_config.disable_last_norm = config.disable_last_norm

        # Initialize encoder and decoder
        self.encoder = TokenizerEncoder(encoder_config)
        self.decoder = TokenizerDecoder(decoder_config)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights: normal(0, weight_init_value) for Linear/Conv1d
        weights, ones/zeros for LayerNorm, zeros for all biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv1d):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @torch.no_grad()
    def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Convert audio to latent representations.

        Returns a KugelAudioTokenizerEncoderOutput whose mean is permuted to
        (batch, frames, vae_dim) and whose std is the fixed buffer.
        """
        latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return KugelAudioTokenizerEncoderOutput(mean=latents.permute(0, 2, 1), std=self.fix_std)

    @torch.no_grad()
    def sampling(self, encoder_output, dist_type=None):
        """Sample from the encoder output distribution.

        Args:
            encoder_output: KugelAudioTokenizerEncoderOutput from encode().
            dist_type: 'fix' or 'gaussian'; defaults to config's std_dist_type.
        """
        dist_type = dist_type or self.std_dist_type

        if dist_type == 'fix':
            return encoder_output.sample(dist_type='fix')
        elif dist_type == 'gaussian':
            return encoder_output.sample(dist_type='gaussian')
        else:
            raise ValueError(f"Unsupported dist_type: {dist_type}, expected 'fix' or 'gaussian'")

    @torch.no_grad()
    def decode(self, latents, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Convert latent representations back to audio.

        Accepts latents as (batch, vae_dim, frames) or (batch, frames,
        vae_dim); the latter is permuted to channel-first. NOTE(review): this
        check is ambiguous when frames == vae_dim — confirm callers' layout.
        """
        if latents.shape[1] == self.config.vae_dim:
            pass
        else:
            latents = latents.permute(0, 2, 1)

        audio = self.decoder(latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return audio

    def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Full forward pass: encode audio to latents, then decode back to audio.

        Returns:
            Tuple of (reconstructed audio, sampled latents).
        """
        encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        sampled_latents, _ = self.sampling(encoder_output)
        reconstructed = self.decode(sampled_latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return reconstructed, sampled_latents
1118
+
1119
+
1120
class KugelAudioSemanticTokenizerModel(PreTrainedModel):
    """KugelAudio speech tokenizer model with only encoder for semantic tokens.

    Encoder-only variant: produces deterministic latents (no std, no decoder).
    """

    config_class = KugelAudioSemanticTokenizerConfig
    base_model_prefix = "kugelaudio_semantic_tokenizer"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _no_split_modules = ["TokenizerEncoder"]

    def __init__(self, config):
        super().__init__(config)

        # Parse encoder depths (either a "2-2-2" style string or a list)
        if isinstance(config.encoder_depths, str):
            encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
        else:
            encoder_depths = config.encoder_depths

        # Create encoder config: a deep copy of the top-level config with
        # encoder-specific fields mapped onto the generic names that
        # TokenizerEncoder reads (dimension, n_filters, ratios, ...).
        encoder_config = copy.deepcopy(config)
        encoder_config.dimension = config.vae_dim
        encoder_config.n_filters = config.encoder_n_filters
        encoder_config.ratios = config.encoder_ratios
        encoder_config.depths = encoder_depths
        encoder_config.norm = config.conv_norm
        encoder_config.pad_mode = config.pad_mode
        encoder_config.bias = config.conv_bias
        encoder_config.layernorm_eps = config.layernorm_eps
        encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
        encoder_config.mixer_layer = config.mixer_layer
        encoder_config.layer_scale_init_value = config.layer_scale_init_value
        encoder_config.disable_last_norm = config.disable_last_norm

        # Initialize encoder and decoder
        self.encoder = TokenizerEncoder(encoder_config)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights: normal(0, weight_init_value) for Linear/Conv1d
        weights, ones/zeros for LayerNorm, zeros for all biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv1d):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @torch.no_grad()
    def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Convert audio to latent representations (mean only, std left None)."""
        latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return KugelAudioTokenizerEncoderOutput(mean=latents.permute(0, 2, 1))

    @torch.no_grad()
    def sampling(self, encoder_output, dist_type=None):
        """Sample from the encoder output distribution.

        NOTE(review): the dist_type argument is ignored; sampling is always
        deterministic ('none' returns the mean unchanged).
        """
        return encoder_output.sample(dist_type='none')

    def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Full forward pass: encode audio to latents.

        Returns:
            Tuple of (None, latents) — first slot mirrors the acoustic
            model's (reconstruction, latents) return shape.
        """
        encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        sampled_latents, _ = self.sampling(encoder_output, dist_type='none')
        return None, sampled_latents
1189
+
1190
# Make both tokenizers loadable via transformers.AutoModel.from_pretrained
# when their config classes are encountered.
AutoModel.register(KugelAudioAcousticTokenizerConfig, KugelAudioAcousticTokenizerModel)
AutoModel.register(KugelAudioSemanticTokenizerConfig, KugelAudioSemanticTokenizerModel)

# Public API of this module.
__all__ = [
    "KugelAudioTokenizerStreamingCache",
    "KugelAudioAcousticTokenizerModel",
    "KugelAudioSemanticTokenizerModel",
]
kugelaudio_open/processors/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """Processors for KugelAudio text and audio handling."""
2
+
3
+ from kugelaudio_open.processors.audio_processor import AudioProcessor, AudioNormalizer
4
+ from kugelaudio_open.processors.kugelaudio_processor import KugelAudioProcessor
5
+
6
+ __all__ = [
7
+ "AudioProcessor",
8
+ "AudioNormalizer",
9
+ "KugelAudioProcessor",
10
+ ]
kugelaudio_open/processors/audio_processor.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio processing utilities for KugelAudio."""
2
+
3
+ import os
4
+ from typing import Optional, Union, List, Dict, Any
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from transformers.feature_extraction_utils import FeatureExtractionMixin
10
+ from transformers.utils import logging
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
class AudioNormalizer:
    """Loudness normalizer that brings audio to a target dB FS level.

    Scaling to a fixed RMS level gives the model consistent input levels;
    a final safety step rescales the signal if the result would clip.
    """

    def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
        self.target_dB_FS = target_dB_FS
        self.eps = eps

    def normalize_db(self, audio: np.ndarray) -> tuple:
        """Scale audio so its RMS sits at the target dB FS level.

        Returns:
            (scaled_audio, rms, scalar) where scalar is the gain applied.
        """
        rms = np.sqrt(np.mean(audio ** 2))
        target_linear = 10 ** (self.target_dB_FS / 20)
        scalar = target_linear / (rms + self.eps)
        return audio * scalar, rms, scalar

    def avoid_clipping(self, audio: np.ndarray) -> tuple:
        """Rescale into [-1, 1] if any sample exceeds full scale.

        Returns:
            (audio, scalar) where scalar is the divisor used (1.0 if none).
        """
        peak = np.max(np.abs(audio))
        if peak > 1.0:
            divisor = peak + self.eps
            return audio / divisor, divisor
        return audio, 1.0

    def __call__(self, audio: np.ndarray) -> np.ndarray:
        """Normalize to the target level, then guard against clipping."""
        leveled, _, _ = self.normalize_db(audio)
        safe, _ = self.avoid_clipping(leveled)
        return safe
45
+
46
+
47
+ class AudioProcessor(FeatureExtractionMixin):
48
+ """Processor for audio preprocessing and postprocessing.
49
+
50
+ Handles:
51
+ - Audio format conversion (stereo to mono)
52
+ - Normalization
53
+ - Loading from various file formats
54
+ - Saving to WAV files
55
+
56
+ Example:
57
+ >>> processor = AudioProcessor(sampling_rate=24000)
58
+ >>> audio = processor("path/to/audio.wav")
59
+ >>> processor.save_audio(generated_audio, "output.wav")
60
+ """
61
+
62
+ model_input_names = ["input_features"]
63
+
64
+ def __init__(
65
+ self,
66
+ sampling_rate: int = 24000,
67
+ normalize_audio: bool = True,
68
+ target_dB_FS: float = -25,
69
+ eps: float = 1e-6,
70
+ **kwargs,
71
+ ):
72
+ super().__init__(**kwargs)
73
+
74
+ self.sampling_rate = sampling_rate
75
+ self.normalize_audio = normalize_audio
76
+ self.normalizer = AudioNormalizer(target_dB_FS, eps) if normalize_audio else None
77
+
78
+ self.feature_extractor_dict = {
79
+ "sampling_rate": sampling_rate,
80
+ "normalize_audio": normalize_audio,
81
+ "target_dB_FS": target_dB_FS,
82
+ "eps": eps,
83
+ }
84
+
85
+ def _ensure_mono(self, audio: np.ndarray) -> np.ndarray:
86
+ """Convert stereo to mono if needed."""
87
+ if len(audio.shape) == 1:
88
+ return audio
89
+ elif len(audio.shape) == 2:
90
+ if audio.shape[0] == 2:
91
+ return np.mean(audio, axis=0)
92
+ elif audio.shape[1] == 2:
93
+ return np.mean(audio, axis=1)
94
+ elif audio.shape[0] == 1:
95
+ return audio.squeeze(0)
96
+ elif audio.shape[1] == 1:
97
+ return audio.squeeze(1)
98
+ else:
99
+ raise ValueError(f"Unexpected audio shape: {audio.shape}")
100
+ else:
101
+ raise ValueError(f"Audio should be 1D or 2D, got shape: {audio.shape}")
102
+
103
+ def _process_single(self, audio: Union[np.ndarray, List[float]]) -> np.ndarray:
104
+ """Process a single audio array."""
105
+ if not isinstance(audio, np.ndarray):
106
+ audio = np.array(audio, dtype=np.float32)
107
+ else:
108
+ audio = audio.astype(np.float32)
109
+
110
+ audio = self._ensure_mono(audio)
111
+
112
+ if self.normalize_audio and self.normalizer:
113
+ audio = self.normalizer(audio)
114
+
115
+ return audio
116
+
117
+ def _load_from_path(self, audio_path: str) -> np.ndarray:
118
+ """Load audio from file path."""
119
+ ext = os.path.splitext(audio_path)[1].lower()
120
+
121
+ if ext in [".wav", ".mp3", ".flac", ".m4a", ".ogg"]:
122
+ import librosa
123
+ audio, _ = librosa.load(audio_path, sr=self.sampling_rate, mono=True)
124
+ return audio
125
+ elif ext == ".pt":
126
+ tensor = torch.load(audio_path, map_location="cpu", weights_only=True).squeeze()
127
+ return tensor.numpy().astype(np.float32)
128
+ elif ext == ".npy":
129
+ return np.load(audio_path).astype(np.float32)
130
+ else:
131
+ raise ValueError(f"Unsupported format: {ext}")
132
+
133
    def __call__(
        self,
        audio: Union[str, np.ndarray, List[float], List[np.ndarray], List[str]] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Process audio input(s).

        Args:
            audio: Audio input - path, array, or list of either
            sampling_rate: Input sampling rate (for validation only; no
                resampling is performed here)
            return_tensors: Return format ("pt" for PyTorch, "np" for NumPy)

        Returns:
            Dictionary with key "audio". For "pt"/"np" the value has shape
            (batch, 1, samples); otherwise it is the processed array itself
            (or a list of arrays for batched input).

        Raises:
            ValueError: If `audio` is None.
        """
        if audio is None:
            raise ValueError("Audio input is required")

        # Mismatched rates only produce a warning — callers must resample.
        if sampling_rate is not None and sampling_rate != self.sampling_rate:
            logger.warning(
                f"Input sampling rate ({sampling_rate}) differs from expected ({self.sampling_rate}). "
                "Please resample your audio."
            )

        # Handle different input types
        if isinstance(audio, str):
            audio = self._load_from_path(audio)
            is_batched = False
        elif isinstance(audio, list):
            if all(isinstance(item, str) for item in audio):
                audio = [self._load_from_path(p) for p in audio]
                is_batched = True
            else:
                # A list whose first element is array-like is a batch;
                # a flat list of floats counts as one sample.
                is_batched = isinstance(audio[0], (np.ndarray, list))
        else:
            is_batched = False

        # Process: each sample becomes a mono float32 array (normalized
        # when the processor was built with normalize_audio=True).
        if is_batched:
            processed = [self._process_single(a) for a in audio]
        else:
            processed = [self._process_single(audio)]

        # Convert to tensors with an explicit channel dim: (batch, 1, samples).
        # NOTE(review): stacking batched inputs assumes equal lengths — no
        # padding is applied here; confirm callers pre-pad if necessary.
        if return_tensors == "pt":
            if len(processed) == 1:
                features = torch.from_numpy(processed[0]).unsqueeze(0).unsqueeze(1)
            else:
                features = torch.stack([torch.from_numpy(a) for a in processed]).unsqueeze(1)
        elif return_tensors == "np":
            if len(processed) == 1:
                features = processed[0][np.newaxis, np.newaxis, :]
            else:
                features = np.stack(processed)[:, np.newaxis, :]
        else:
            features = processed[0] if len(processed) == 1 else processed

        return {"audio": features}
193
+
194
    def save_audio(
        self,
        audio: Union[torch.Tensor, np.ndarray, List],
        output_path: str = "output.wav",
        sampling_rate: Optional[int] = None,
        normalize: bool = False,
        batch_prefix: str = "audio_",
    ) -> List[str]:
        """Save audio to WAV file(s).

        Args:
            audio: Audio data to save (tensor, array, or list of either).
                Lists and 3D+ arrays with a leading batch dim > 1 are treated
                as batches; `output_path` is then used as a directory.
            output_path: Output file path (or directory for batched audio)
            sampling_rate: Sampling rate (defaults to processor's rate)
            normalize: Whether to peak-normalize each item before saving
            batch_prefix: Filename prefix for batch files

        Returns:
            List of saved file paths
        """
        # Imported lazily so soundfile is only required when saving.
        import soundfile as sf

        if sampling_rate is None:
            sampling_rate = self.sampling_rate

        # Convert to numpy; tensors may be on GPU or in half precision.
        if isinstance(audio, torch.Tensor):
            audio_np = audio.float().detach().cpu().numpy()
        elif isinstance(audio, list):
            if all(isinstance(a, torch.Tensor) for a in audio):
                audio_np = [a.float().detach().cpu().numpy() for a in audio]
            else:
                # NOTE(review): a mixed list (tensors + arrays) passes through
                # unchanged and any tensors in it reach sf.write as-is.
                audio_np = audio
        else:
            audio_np = audio

        saved_paths = []

        if isinstance(audio_np, list):
            # Batched list input: one file per item, inside output_path dir.
            os.makedirs(output_path, exist_ok=True)
            for i, item in enumerate(audio_np):
                item = self._prepare_for_save(item, normalize)
                path = os.path.join(output_path, f"{batch_prefix}{i}.wav")
                sf.write(path, item, sampling_rate)
                saved_paths.append(path)
        elif len(audio_np.shape) >= 3 and audio_np.shape[0] > 1:
            # Batched array input, e.g. (batch, 1, samples) with batch > 1.
            os.makedirs(output_path, exist_ok=True)
            for i in range(audio_np.shape[0]):
                item = audio_np[i].squeeze()
                item = self._prepare_for_save(item, normalize)
                path = os.path.join(output_path, f"{batch_prefix}{i}.wav")
                sf.write(path, item, sampling_rate)
                saved_paths.append(path)
        else:
            # Single clip: output_path is the target filename itself.
            item = self._prepare_for_save(audio_np.squeeze(), normalize)
            sf.write(output_path, item, sampling_rate)
            saved_paths.append(output_path)

        return saved_paths
253
+
254
+ def _prepare_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray:
255
+ """Prepare audio for saving."""
256
+ if len(audio.shape) > 1 and audio.shape[0] == 1:
257
+ audio = audio.squeeze(0)
258
+
259
+ if normalize:
260
+ max_val = np.abs(audio).max()
261
+ if max_val > 0:
262
+ audio = audio / max_val
263
+
264
+ return audio
265
+
266
    def to_dict(self) -> Dict[str, Any]:
        """Return the serializable config captured at construction time.

        NOTE: this returns the internal dict itself, not a copy — callers
        that mutate it will change the processor's recorded configuration.
        """
        return self.feature_extractor_dict
kugelaudio_open/processors/kugelaudio_processor.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Main processor for KugelAudio combining text and audio processing."""
2
+
3
+ import json
4
+ import math
5
+ import os
6
+ from typing import Any, Dict, List, Optional, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ from kugelaudio_open.processors.audio_processor import AudioNormalizer, AudioProcessor
11
+ from transformers.tokenization_utils_base import (
12
+ BatchEncoding,
13
+ PaddingStrategy,
14
+ TruncationStrategy,
15
+ )
16
+ from transformers.utils import TensorType, cached_file, logging
17
+
18
+ logger = logging.get_logger(__name__)
19
+
20
+
21
class KugelAudioProcessor:
    """Combined processor for KugelAudio text and audio.

    Wraps a text tokenizer and audio processor into a single interface
    for preparing inputs for KugelAudio models.

    Example:
        >>> processor = KugelAudioProcessor.from_pretrained("kugelaudio/kugelaudio-0-open")
        >>> inputs = processor(text="Hello world", voice_prompt=voice_audio)
    """

    def __init__(
        self,
        tokenizer=None,
        audio_processor: Optional[AudioProcessor] = None,
        speech_compression_ratio: int = 3200,
        db_normalize: bool = True,
        **kwargs,
    ):
        """Create a processor from its two sub-components.

        Args:
            tokenizer: Text tokenizer (expected to expose `encode`,
                `decode`, `batch_decode` and `pad_id`).
            audio_processor: Audio feature extractor; a default
                AudioProcessor is created when None.
            speech_compression_ratio: Raw audio samples represented by one
                speech token (3200 samples per token at 24 kHz).
            db_normalize: Whether to dB-FS-normalize voice prompts.
        """
        self.tokenizer = tokenizer
        self.audio_processor = audio_processor or AudioProcessor()
        self.speech_compression_ratio = speech_compression_ratio
        self.db_normalize = db_normalize
        self.audio_normalizer = AudioNormalizer() if db_normalize else None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load processor from pretrained model.

        Args:
            pretrained_model_name_or_path: Model ID or local path

        Returns:
            KugelAudioProcessor instance
        """
        from kugelaudio_open.processors.text_tokenizer import KugelAudioTextTokenizer

        # Try to load config: local path first, then hub download.
        config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
        config = None

        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)
        else:
            try:
                config_file = cached_file(
                    pretrained_model_name_or_path, "preprocessor_config.json", **kwargs
                )
                with open(config_file, "r") as f:
                    config = json.load(f)
            except Exception as e:
                # Best-effort fallback to defaults when no config is found.
                logger.warning(f"Could not load config: {e}. Using defaults.")
                config = {
                    "speech_compression_ratio": 3200,
                    "db_normalize": True,
                }

        # Extract parameters
        speech_compression_ratio = config.get("speech_compression_ratio", 3200)
        db_normalize = config.get("db_normalize", True)

        # Load tokenizer; config value takes precedence over the kwarg.
        # NOTE(review): when the config provides the name, the kwarg is not
        # popped and would still be forwarded to the tokenizer below.
        lm_name = config.get("language_model_pretrained_name") or kwargs.pop(
            "language_model_pretrained_name", "Qwen/Qwen2.5-1.5B"
        )
        logger.info(f"Loading tokenizer from {lm_name}")
        tokenizer = KugelAudioTextTokenizer.from_pretrained(lm_name, **kwargs)

        # Load audio processor
        if "audio_processor" in config:
            audio_config = config["audio_processor"]
            audio_processor = AudioProcessor(
                sampling_rate=audio_config.get("sampling_rate", 24000),
                normalize_audio=audio_config.get("normalize_audio", True),
                target_dB_FS=audio_config.get("target_dB_FS", -25),
            )
        else:
            audio_processor = AudioProcessor()

        return cls(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
            speech_compression_ratio=speech_compression_ratio,
            db_normalize=db_normalize,
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Save processor configuration as preprocessor_config.json.

        Only the processor config is written; the tokenizer is not saved here.
        """
        os.makedirs(save_directory, exist_ok=True)

        config = {
            "processor_class": "KugelAudioProcessor",
            "speech_compression_ratio": self.speech_compression_ratio,
            "db_normalize": self.db_normalize,
            "audio_processor": {
                "feature_extractor_type": "AudioProcessor",
                # getattr defaults guard against audio processors that do not
                # expose these attributes.
                "sampling_rate": getattr(self.audio_processor, "sampling_rate", 24000),
                "normalize_audio": getattr(self.audio_processor, "normalize_audio", True),
                "target_dB_FS": getattr(self.audio_processor, "target_dB_FS", -25),
            },
        }

        config_path = os.path.join(save_directory, "preprocessor_config.json")
        with open(config_path, "w") as f:
            json.dump(config, f, indent=2)

        logger.info(f"Processor saved to {config_path}")

    def __call__(
        self,
        text: Optional[str] = None,
        voice_prompt: Optional[Union[np.ndarray, torch.Tensor, str]] = None,
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchEncoding:
        """Process text and optional voice prompt.

        Args:
            text: Input text to synthesize
            voice_prompt: Voice prompt audio for speaker identity (raw audio tensor or path)
            padding: Padding strategy (NOTE(review): accepted but currently
                not applied in this method)
            truncation: Truncation strategy (currently not applied)
            max_length: Maximum sequence length (currently not applied)
            return_tensors: Return format

        Returns:
            BatchEncoding with processed inputs including speech_input_mask for voice cloning
        """
        if text is None:
            raise ValueError("Text input is required")

        # Special token IDs — hard-coded Qwen2.5 vocab IDs; these must stay
        # in sync with KugelAudioTextTokenizer's special tokens.
        speech_start_id = 151652  # <|vision_start|> repurposed for speech
        speech_diffusion_id = 151654  # VAE token used as placeholder

        # Format text with proper template
        # Add speaker prefix if not present (use Speaker 0 to match training format)
        formatted_text = text.strip()
        if not formatted_text.startswith("Speaker"):
            formatted_text = f"Speaker 0: {formatted_text}"

        # Build the full prompt template matching the training format
        system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"

        # Start building tokens and speech_input_mask. speech_input_mask is
        # True at positions whose embeddings will be replaced by speech
        # embeddings (the placeholder VAE tokens), False for plain text.
        full_tokens = []
        speech_input_mask = []
        voice_audio = None

        # System prompt tokens
        system_tokens = self.tokenizer.encode(system_prompt, add_special_tokens=False)
        full_tokens.extend(system_tokens)
        speech_input_mask.extend([False] * len(system_tokens))

        # Process voice prompt if provided
        if voice_prompt is not None:
            # Load audio if it's a path
            if isinstance(voice_prompt, str):
                voice_audio = self.audio_processor._load_from_path(voice_prompt)
                if self.db_normalize and self.audio_normalizer:
                    voice_audio = self.audio_normalizer(voice_audio)
            elif isinstance(voice_prompt, np.ndarray):
                voice_audio = voice_prompt.astype(np.float32)
            elif isinstance(voice_prompt, torch.Tensor):
                voice_audio = voice_prompt.cpu().numpy()
                if voice_audio.ndim > 1:
                    voice_audio = voice_audio.squeeze()
                voice_audio = voice_audio.astype(np.float32)

            # Voice input section with placeholder tokens
            voice_input_tokens = self.tokenizer.encode(" Voice input:\n", add_special_tokens=False)
            full_tokens.extend(voice_input_tokens)
            speech_input_mask.extend([False] * len(voice_input_tokens))

            # Speaker prefix for voice
            speaker_prefix = self.tokenizer.encode(" Speaker 0:", add_special_tokens=False)
            full_tokens.extend(speaker_prefix)
            speech_input_mask.extend([False] * len(speaker_prefix))

            # Calculate number of VAE tokens needed based on audio length
            # compression ratio is typically 3200 samples per token at 24kHz
            num_voice_tokens = math.ceil(len(voice_audio) / self.speech_compression_ratio)

            # Add placeholder VAE tokens that will be replaced with speech embeddings
            full_tokens.extend([speech_diffusion_id] * num_voice_tokens)
            speech_input_mask.extend([True] * num_voice_tokens)  # These positions get speech embeddings

            # Newline after voice
            newline_tokens = self.tokenizer.encode("\n", add_special_tokens=False)
            full_tokens.extend(newline_tokens)
            speech_input_mask.extend([False] * len(newline_tokens))

        # Text input section
        text_input_tokens = self.tokenizer.encode(" Text input:\n", add_special_tokens=False)
        full_tokens.extend(text_input_tokens)
        speech_input_mask.extend([False] * len(text_input_tokens))

        # Speaker text
        speaker_text_tokens = self.tokenizer.encode(f" {formatted_text}\n", add_special_tokens=False)
        full_tokens.extend(speaker_text_tokens)
        speech_input_mask.extend([False] * len(speaker_text_tokens))

        # Speech output section
        speech_output_tokens = self.tokenizer.encode(" Speech output:\n", add_special_tokens=False)
        full_tokens.extend(speech_output_tokens)
        speech_input_mask.extend([False] * len(speech_output_tokens))

        # Add speech_start token
        full_tokens.append(speech_start_id)
        speech_input_mask.append(False)

        result = BatchEncoding()
        result["text_ids"] = full_tokens
        result["speech_input_mask"] = speech_input_mask

        if return_tensors == "pt":
            result["text_ids"] = torch.tensor([full_tokens], dtype=torch.long)
            result["speech_input_mask"] = torch.tensor([speech_input_mask], dtype=torch.bool)

        # Include processed voice audio for the model to encode
        if voice_audio is not None:
            if return_tensors == "pt":
                # Shape (1, 1, samples): batch and channel dims added.
                result["speech_tensors"] = torch.tensor(voice_audio, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
                # Create speech_masks (all True for the voice frames)
                num_frames = math.ceil(len(voice_audio) / self.speech_compression_ratio)
                result["speech_masks"] = torch.ones(1, num_frames, dtype=torch.bool)
            else:
                result["speech_tensors"] = voice_audio
                num_frames = math.ceil(len(voice_audio) / self.speech_compression_ratio)
                result["speech_masks"] = [True] * num_frames

        return result

    def process_with_cached_prompt(
        self,
        text: str,
        cached_prompt: Dict[str, Any],
        return_tensors: Optional[Union[str, TensorType]] = "pt",
        **kwargs,
    ) -> BatchEncoding:
        """Process text with pre-computed voice prompt cache.

        Args:
            text: Input text to synthesize
            cached_prompt: Pre-computed KV cache from voice prompt; must
                contain "lm" and "tts_lm" entries with "last_hidden_state"
            return_tensors: Return format

        Returns:
            BatchEncoding ready for generation
        """
        script_tokens = self.tokenizer.encode(text.strip() + "\n", add_special_tokens=False)

        # Sequence lengths already encoded in the cache; the pseudo inputs
        # below only need to match these lengths.
        lm_length = cached_prompt["lm"]["last_hidden_state"].size(1)
        tts_lm_length = cached_prompt["tts_lm"]["last_hidden_state"].size(1)

        # Create pseudo input IDs (pad tokens act as placeholders — the real
        # context comes from the cached hidden states).
        input_ids = [self.tokenizer.pad_id] * lm_length
        tts_lm_input_ids = [self.tokenizer.pad_id] * tts_lm_length
        speech_input_mask = [False] * tts_lm_length

        result = BatchEncoding()

        if return_tensors == "pt":
            result["input_ids"] = torch.tensor([input_ids], dtype=torch.long)
            result["tts_lm_input_ids"] = torch.tensor([tts_lm_input_ids], dtype=torch.long)
            result["tts_text_ids"] = torch.tensor([script_tokens], dtype=torch.long)
            result["attention_mask"] = torch.ones(1, lm_length, dtype=torch.long)
            result["tts_lm_attention_mask"] = torch.ones(1, tts_lm_length, dtype=torch.long)
            result["speech_input_mask"] = torch.tensor([speech_input_mask], dtype=torch.bool)
        else:
            result["input_ids"] = [input_ids]
            result["tts_lm_input_ids"] = [tts_lm_input_ids]
            result["tts_text_ids"] = [script_tokens]
            result["attention_mask"] = [[1] * lm_length]
            result["tts_lm_attention_mask"] = [[1] * tts_lm_length]
            result["speech_input_mask"] = [speech_input_mask]

        return result

    def prepare_speech_inputs(
        self,
        speech_inputs: List[np.ndarray],
        return_tensors: Optional[Union[str, TensorType]] = None,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Prepare speech inputs for model.

        Args:
            speech_inputs: List of speech arrays
            return_tensors: Return format
            device: Device to place tensors
            dtype: Data type for tensors

        Returns:
            Dictionary with padded speeches and masks
        """
        if not speech_inputs:
            return {"padded_speeches": None, "speech_masks": None}

        # Calculate sequence lengths (in speech tokens, not raw samples).
        seq_lens = [math.ceil(s.shape[0] / self.speech_compression_ratio) for s in speech_inputs]
        max_speech_len = max(s.shape[0] for s in speech_inputs)

        # Pad speeches: zero-pad raw audio to the longest clip; masks mark
        # valid token frames per clip.
        padded = np.zeros((len(speech_inputs), max_speech_len), dtype=np.float32)
        masks = np.zeros((len(speech_inputs), max(seq_lens)), dtype=np.bool_)

        for i, (speech, seq_len) in enumerate(zip(speech_inputs, seq_lens)):
            padded[i, : len(speech)] = speech
            masks[i, :seq_len] = True

        result = {"padded_speeches": padded, "speech_masks": masks}

        if return_tensors == "pt":
            result["padded_speeches"] = torch.tensor(
                padded, device=device, dtype=dtype or torch.float32
            )
            result["speech_masks"] = torch.tensor(masks, device=device, dtype=torch.bool)

        return result

    def batch_decode(self, *args, **kwargs):
        """Decode token IDs to text (delegates to the tokenizer)."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Decode token IDs to text (delegates to the tokenizer)."""
        return self.tokenizer.decode(*args, **kwargs)

    def save_audio(self, audio, output_path: str = "output.wav", **kwargs) -> List[str]:
        """Save generated audio to file (delegates to the audio processor)."""
        return self.audio_processor.save_audio(audio, output_path, **kwargs)

    @property
    def model_input_names(self) -> List[str]:
        """Return list of model input names (deduplicated, order-preserving)."""
        tokenizer_names = getattr(self.tokenizer, "model_input_names", [])
        audio_names = getattr(self.audio_processor, "model_input_names", [])
        return list(
            dict.fromkeys(tokenizer_names + audio_names + ["speech_inputs", "speech_input_mask"])
        )
+ )
kugelaudio_open/processors/text_tokenizer.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Text tokenizer for KugelAudio based on Qwen2."""
2
+
3
+ from typing import List, Optional
4
+
5
+ from transformers.utils import logging
6
+ from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
7
+
8
+ logger = logging.get_logger(__name__)
9
+
10
+
11
class KugelAudioTextTokenizer(Qwen2TokenizerFast):
    """Text tokenizer for KugelAudio with speech special tokens.

    Based on Qwen2 tokenizer with additional tokens for speech synthesis:
    - speech_start: Marks the beginning of speech generation
    - speech_end: Marks the end of speech generation
    - speech_diffusion: Placeholder for diffusion tokens

    Example:
        >>> tokenizer = KugelAudioTextTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")
        >>> tokens = tokenizer.encode("Hello world")
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        add_prefix_space=False,
        **kwargs,
    ):
        """Initialize the tokenizer and register the speech special tokens."""
        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )

        self._add_speech_special_tokens()

    def _add_speech_special_tokens(self):
        """Add KugelAudio-specific special tokens for speech."""
        special_tokens = {
            "additional_special_tokens": [
                "<|vision_start|>",  # Speech start (reusing vision tokens for compatibility)
                "<|vision_end|>",  # Speech end
                "<|vision_pad|>",  # Speech diffusion pad
            ]
        }
        self.add_special_tokens(special_tokens)

        # Cache special token IDs
        self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
        self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
        self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
        self._eos_id = self.eos_token_id
        # <|image_pad|> is not registered above — presumably it already
        # exists in the Qwen2.5 vocabulary; verify against the base model.
        self._pad_id = self.convert_tokens_to_ids("<|image_pad|>")

    @property
    def eos_id(self) -> int:
        """End of sequence token ID."""
        return self._eos_id

    @property
    def speech_start_id(self) -> int:
        """Speech start token ID."""
        return self._speech_start_id

    @property
    def speech_end_id(self) -> int:
        """Speech end token ID."""
        return self._speech_end_id

    @property
    def speech_diffusion_id(self) -> int:
        """Speech diffusion placeholder token ID."""
        return self._speech_diffusion_id

    @property
    def pad_id(self) -> int:
        """Padding token ID (the `<|image_pad|>` token's vocab ID)."""
        return self._pad_id
kugelaudio_open/schedule/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """KugelAudio scheduling components."""
2
+
3
+ from .dpm_solver import DPMSolverMultistepScheduler
4
+
5
+ __all__ = ["DPMSolverMultistepScheduler"]
kugelaudio_open/schedule/dpm_solver.py ADDED
@@ -0,0 +1,1084 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
16
+
17
+ import math
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.utils import deprecate
25
+ from diffusers.utils.torch_utils import randn_tensor
26
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
27
+
28
+ def betas_for_alpha_bar(
29
+ num_diffusion_timesteps,
30
+ max_beta=0.999,
31
+ alpha_transform_type="cosine",
32
+ ):
33
+ """
34
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
35
+ (1-beta) over time from t = [0,1].
36
+
37
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
38
+ to that part of the diffusion process.
39
+
40
+
41
+ Args:
42
+ num_diffusion_timesteps (`int`): the number of betas to produce.
43
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
44
+ prevent singularities.
45
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
46
+ Choose from `cosine` or `exp`
47
+
48
+ Returns:
49
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
50
+ """
51
+ if alpha_transform_type == "cosine":
52
+
53
+ def alpha_bar_fn(t):
54
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
55
+ # return math.cos(t * math.pi / 2 * 0.95) ** 2
56
+
57
+ elif alpha_transform_type == "exp":
58
+
59
+ def alpha_bar_fn(t):
60
+ return math.exp(t * -12.0)
61
+
62
+ elif alpha_transform_type == "cauchy":
63
+ # Β΅ + Ξ³ tan (Ο€ (0.5 - x)) Ξ³ = 1, Β΅ = 3
64
+ # alpha^2 = 1-1/(exp(Ξ»)+1)
65
+ def alpha_bar_fn(t, gamma=1, mu=3):
66
+ snr = mu + gamma * math.tan(math.pi * (0.5 - t) * 0.9)
67
+ return 1 - 1 / (math.exp(snr) + 1.1)
68
+
69
+ elif alpha_transform_type == "laplace":
70
+ # Β΅ βˆ’ bsgn(0.5 βˆ’ t) log(1 βˆ’ 2|t βˆ’ 0.5|) Β΅ = 0, b = 1
71
+ def alpha_bar_fn(t, mu=0, b=1):
72
+ snr = mu - b * math.copysign(1, 0.5 - t) * math.log(1 - 2 * abs(t - 0.5) * 0.98)
73
+ return 1 - 1 / (math.exp(snr) + 1.02)
74
+
75
+ else:
76
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
77
+
78
+ betas = []
79
+ for i in range(num_diffusion_timesteps):
80
+ t1 = i / num_diffusion_timesteps
81
+ t2 = (i + 1) / num_diffusion_timesteps
82
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
83
+ return torch.tensor(betas, dtype=torch.float32)
84
+
85
+
86
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
87
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)


    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # Work in sqrt(alpha_bar) space, where the rescale is shift-and-scale.
    sqrt_cumprod = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember the endpoints before adjusting.
    first = sqrt_cumprod[0].clone()
    last = sqrt_cumprod[-1].clone()

    # Shift so the terminal value is zero, then scale so the initial value
    # is restored to its original magnitude.
    adjusted = (sqrt_cumprod - last) * (first / (first - last))

    # Convert back: square to alpha_bar, undo the cumulative product to get
    # per-step alphas, and finally recover betas.
    alpha_bar = adjusted**2
    step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - step_alphas
121
+
122
+ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
123
+ """
124
+ `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
125
+
126
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
127
+ methods the library implements for all schedulers such as loading and saving.
128
+
129
+ Args:
130
+ num_train_timesteps (`int`, defaults to 1000):
131
+ The number of diffusion steps to train the model.
132
+ beta_start (`float`, defaults to 0.0001):
133
+ The starting `beta` value of inference.
134
+ beta_end (`float`, defaults to 0.02):
135
+ The final `beta` value.
136
+ beta_schedule (`str`, defaults to `"linear"`):
137
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
138
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
139
+ trained_betas (`np.ndarray`, *optional*):
140
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
141
+ solver_order (`int`, defaults to 2):
142
+ The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
143
+ sampling, and `solver_order=3` for unconditional sampling.
144
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
145
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
146
+ `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
147
+ Video](https://imagen.research.google/video/paper.pdf) paper).
148
+ thresholding (`bool`, defaults to `False`):
149
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
150
+ as Stable Diffusion.
151
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
152
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
153
+ sample_max_value (`float`, defaults to 1.0):
154
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
155
+ `algorithm_type="dpmsolver++"`.
156
+ algorithm_type (`str`, defaults to `dpmsolver++`):
157
+ Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
158
+ `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
159
+ paper, and the `dpmsolver++` type implements the algorithms in the
160
+ [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
161
+ `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
162
+ solver_type (`str`, defaults to `midpoint`):
163
+ Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
164
+ sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
165
+ lower_order_final (`bool`, defaults to `True`):
166
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
167
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
168
+ euler_at_final (`bool`, defaults to `False`):
169
+ Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
170
+ richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
171
+ steps, but sometimes may result in blurring.
172
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
173
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
174
+ the sigmas are determined according to a sequence of noise levels {Οƒi}.
175
+ use_lu_lambdas (`bool`, *optional*, defaults to `False`):
176
+ Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
177
+ the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
178
+ `lambda(t)`.
179
+ final_sigmas_type (`str`, defaults to `"zero"`):
180
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
181
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
182
+ lambda_min_clipped (`float`, defaults to `-inf`):
183
+ Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
184
+ cosine (`squaredcos_cap_v2`) noise schedule.
185
+ variance_type (`str`, *optional*):
186
+ Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
187
+ contains the predicted Gaussian variance.
188
+ timestep_spacing (`str`, defaults to `"linspace"`):
189
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
190
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
191
+ steps_offset (`int`, defaults to 0):
192
+ An offset added to the inference steps, as required by some model families.
193
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
194
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
195
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
196
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
197
+ """
198
+
199
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
200
+ order = 1
201
+
202
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "midpoint",
        lower_order_final: bool = True,
        euler_at_final: bool = False,
        use_karras_sigmas: Optional[bool] = False,
        use_lu_lambdas: Optional[bool] = False,
        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
        lambda_min_clipped: float = -float("inf"),
        variance_type: Optional[str] = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
    ):
        """Build the training noise schedule and validate the solver configuration.

        All arguments are recorded on ``self.config`` via ``@register_to_config``;
        see the class docstring for their meaning.
        """
        # The plain (non-++) solver variants are kept for compatibility but deprecated.
        if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
            deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
            deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)

        # Select the training beta schedule; an explicit `trained_betas` array overrides it.
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2" or beta_schedule == "cosine":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
        elif beta_schedule == "cauchy":
            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cauchy")
        elif beta_schedule == "laplace":
            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="laplace")
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # Optionally rescale so the terminal SNR is exactly zero (https://huggingface.co/papers/2305.08891).
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        # log-SNR/2; used by `set_timesteps` for `lambda_min_clipped` clipping.
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # settings for DPM-Solver: unknown algorithm/solver names either map to a
        # supported equivalent (via register_to_config) or raise.
        if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
            if algorithm_type == "deis":
                self.register_to_config(algorithm_type="dpmsolver++")
            else:
                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")

        if solver_type not in ["midpoint", "heun"]:
            if solver_type in ["logrho", "bh1", "bh2"]:
                self.register_to_config(solver_type="midpoint")
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

        # A zero final sigma is only well-defined for the data-prediction (++) variants.
        if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
            raise ValueError(
                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
            )

        # setable values (overwritten by `set_timesteps` before inference)
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.model_outputs = [None] * solver_order
        self.lower_order_nums = 0
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
296
+
297
+ @property
298
+ def step_index(self):
299
+ """
300
+ The index counter for current timestep. It will increase 1 after each scheduler step.
301
+ """
302
+ return self._step_index
303
+
304
+ @property
305
+ def begin_index(self):
306
+ """
307
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
308
+ """
309
+ return self._begin_index
310
+
311
+ def set_begin_index(self, begin_index: int = 0):
312
+ """
313
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
314
+
315
+ Args:
316
+ begin_index (`int`):
317
+ The begin index for the scheduler.
318
+ """
319
+ self._begin_index = begin_index
320
+
321
    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        timesteps: Optional[List[int]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
                based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` must be
                `None`, and the `timestep_spacing` attribute will be ignored.
        """
        # Exactly one of `num_inference_steps` / `timesteps` must be given; custom
        # timesteps are incompatible with the Karras / Lu sigma re-spacing options.
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
        if timesteps is not None and self.config.use_karras_sigmas:
            raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
        if timesteps is not None and self.config.use_lu_lambdas:
            raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")

        if timesteps is not None:
            timesteps = np.array(timesteps).astype(np.int64)
        else:
            # Clipping the minimum of all lambda(t) for numerical stability.
            # This is critical for cosine (squaredcos_cap_v2) noise schedule.
            clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
            last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()

            # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
            if self.config.timestep_spacing == "linspace":
                timesteps = (
                    np.linspace(0, last_timestep - 1, num_inference_steps + 1)
                    .round()[::-1][:-1]
                    .copy()
                    .astype(np.int64)
                )
            elif self.config.timestep_spacing == "leading":
                step_ratio = last_timestep // (num_inference_steps + 1)
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = (
                    (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
                )
                timesteps += self.config.steps_offset
            elif self.config.timestep_spacing == "trailing":
                step_ratio = self.config.num_train_timesteps / num_inference_steps
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
                timesteps -= 1
            else:
                raise ValueError(
                    f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
                )

        # Per-timestep sigmas of the training schedule (VP parameterization).
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)

        if self.config.use_karras_sigmas:
            # Re-space sigmas per Karras et al. (2022), then map them back to timesteps.
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        elif self.config.use_lu_lambdas:
            # Uniform-logSNR spacing from Lu's DPM-Solver.
            lambdas = np.flip(log_sigmas.copy())
            lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
            sigmas = np.exp(lambdas)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        else:
            # Default: sample the training sigmas at the selected (possibly fractional) timesteps.
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        # Terminal sigma appended after the schedule: 0 or the smallest training sigma.
        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
            )

        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

        self.num_inference_steps = len(timesteps)

        # Reset the multistep history so no stale model outputs leak across runs.
        self.model_outputs = [
            None,
        ] * self.config.solver_order
        self.lower_order_nums = 0

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
424
+
425
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
426
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
427
+ """
428
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
429
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
430
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
431
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
432
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
433
+
434
+ https://arxiv.org/abs/2205.11487
435
+ """
436
+ dtype = sample.dtype
437
+ batch_size, channels, *remaining_dims = sample.shape
438
+
439
+ if dtype not in (torch.float32, torch.float64):
440
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
441
+
442
+ # Flatten sample for doing quantile calculation along each image
443
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
444
+
445
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
446
+
447
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
448
+ s = torch.clamp(
449
+ s, min=1, max=self.config.sample_max_value
450
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
451
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
452
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
453
+
454
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
455
+ sample = sample.to(dtype)
456
+
457
+ return sample
458
+
459
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
460
+ def _sigma_to_t(self, sigma, log_sigmas):
461
+ # get log sigma
462
+ log_sigma = np.log(np.maximum(sigma, 1e-10))
463
+
464
+ # get distribution
465
+ dists = log_sigma - log_sigmas[:, np.newaxis]
466
+
467
+ # get sigmas range
468
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
469
+ high_idx = low_idx + 1
470
+
471
+ low = log_sigmas[low_idx]
472
+ high = log_sigmas[high_idx]
473
+
474
+ # interpolate sigmas
475
+ w = (low - log_sigma) / (low - high)
476
+ w = np.clip(w, 0, 1)
477
+
478
+ # transform interpolation to time range
479
+ t = (1 - w) * low_idx + w * high_idx
480
+ t = t.reshape(sigma.shape)
481
+ return t
482
+
483
+ def _sigma_to_alpha_sigma_t(self, sigma):
484
+ alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
485
+ sigma_t = sigma * alpha_t
486
+
487
+ return alpha_t, sigma_t
488
+
489
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
490
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
491
+ """Constructs the noise schedule of Karras et al. (2022)."""
492
+
493
+ # Hack to make sure that other schedulers which copy this function don't break
494
+ # TODO: Add this logic to the other schedulers
495
+ if hasattr(self.config, "sigma_min"):
496
+ sigma_min = self.config.sigma_min
497
+ else:
498
+ sigma_min = None
499
+
500
+ if hasattr(self.config, "sigma_max"):
501
+ sigma_max = self.config.sigma_max
502
+ else:
503
+ sigma_max = None
504
+
505
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
506
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
507
+
508
+ rho = 7.0 # 7.0 is the value used in the paper
509
+ ramp = np.linspace(0, 1, num_inference_steps)
510
+ min_inv_rho = sigma_min ** (1 / rho)
511
+ max_inv_rho = sigma_max ** (1 / rho)
512
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
513
+ return sigmas
514
+
515
+ def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
516
+ """Constructs the noise schedule of Lu et al. (2022)."""
517
+
518
+ lambda_min: float = in_lambdas[-1].item()
519
+ lambda_max: float = in_lambdas[0].item()
520
+
521
+ rho = 1.0 # 1.0 is the value used in the paper
522
+ ramp = np.linspace(0, 1, num_inference_steps)
523
+ min_inv_rho = lambda_min ** (1 / rho)
524
+ max_inv_rho = lambda_max ** (1 / rho)
525
+ lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
526
+ return lambdas
527
+
528
+ def convert_model_output(
529
+ self,
530
+ model_output: torch.Tensor,
531
+ *args,
532
+ sample: torch.Tensor = None,
533
+ **kwargs,
534
+ ) -> torch.Tensor:
535
+ """
536
+ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
537
+ designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
538
+ integral of the data prediction model.
539
+
540
+ <Tip>
541
+
542
+ The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
543
+ prediction and data prediction models.
544
+
545
+ </Tip>
546
+
547
+ Args:
548
+ model_output (`torch.Tensor`):
549
+ The direct output from the learned diffusion model.
550
+ sample (`torch.Tensor`):
551
+ A current instance of a sample created by the diffusion process.
552
+
553
+ Returns:
554
+ `torch.Tensor`:
555
+ The converted model output.
556
+ """
557
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
558
+ if sample is None:
559
+ if len(args) > 1:
560
+ sample = args[1]
561
+ else:
562
+ raise ValueError("missing `sample` as a required keyward argument")
563
+ if timestep is not None:
564
+ deprecate(
565
+ "timesteps",
566
+ "1.0.0",
567
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
568
+ )
569
+
570
+ # Guard against out-of-bounds access (can occur in concurrent scenarios)
571
+ safe_step_index = min(self.step_index, len(self.sigmas) - 1)
572
+
573
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
574
+ if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
575
+ if self.config.prediction_type == "epsilon":
576
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
577
+ if self.config.variance_type in ["learned", "learned_range"]:
578
+ model_output = model_output[:, :3]
579
+ sigma = self.sigmas[safe_step_index]
580
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
581
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
582
+ elif self.config.prediction_type == "sample":
583
+ x0_pred = model_output
584
+ elif self.config.prediction_type == "v_prediction":
585
+ sigma = self.sigmas[safe_step_index]
586
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
587
+ x0_pred = alpha_t * sample - sigma_t * model_output
588
+ else:
589
+ raise ValueError(
590
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
591
+ " `v_prediction` for the DPMSolverMultistepScheduler."
592
+ )
593
+
594
+ if self.config.thresholding:
595
+ x0_pred = self._threshold_sample(x0_pred)
596
+
597
+ return x0_pred
598
+
599
+ # DPM-Solver needs to solve an integral of the noise prediction model.
600
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
601
+ if self.config.prediction_type == "epsilon":
602
+ # DPM-Solver and DPM-Solver++ only need the "mean" output.
603
+ if self.config.variance_type in ["learned", "learned_range"]:
604
+ epsilon = model_output[:, :3]
605
+ else:
606
+ epsilon = model_output
607
+ elif self.config.prediction_type == "sample":
608
+ sigma = self.sigmas[safe_step_index]
609
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
610
+ epsilon = (sample - alpha_t * model_output) / sigma_t
611
+ elif self.config.prediction_type == "v_prediction":
612
+ sigma = self.sigmas[safe_step_index]
613
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
614
+ epsilon = alpha_t * model_output + sigma_t * sample
615
+ else:
616
+ raise ValueError(
617
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
618
+ " `v_prediction` for the DPMSolverMultistepScheduler."
619
+ )
620
+
621
+ if self.config.thresholding:
622
+ sigma = self.sigmas[safe_step_index]
623
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
624
+ x0_pred = (sample - sigma_t * epsilon) / alpha_t
625
+ x0_pred = self._threshold_sample(x0_pred)
626
+ epsilon = (sample - alpha_t * x0_pred) / sigma_t
627
+
628
+ return epsilon
629
+
630
+ def dpm_solver_first_order_update(
631
+ self,
632
+ model_output: torch.Tensor,
633
+ *args,
634
+ sample: torch.Tensor = None,
635
+ noise: Optional[torch.Tensor] = None,
636
+ **kwargs,
637
+ ) -> torch.Tensor:
638
+ """
639
+ One step for the first-order DPMSolver (equivalent to DDIM).
640
+
641
+ Args:
642
+ model_output (`torch.Tensor`):
643
+ The direct output from the learned diffusion model.
644
+ sample (`torch.Tensor`):
645
+ A current instance of a sample created by the diffusion process.
646
+
647
+ Returns:
648
+ `torch.Tensor`:
649
+ The sample tensor at the previous timestep.
650
+ """
651
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
652
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
653
+ if sample is None:
654
+ if len(args) > 2:
655
+ sample = args[2]
656
+ else:
657
+ raise ValueError(" missing `sample` as a required keyward argument")
658
+ if timestep is not None:
659
+ deprecate(
660
+ "timesteps",
661
+ "1.0.0",
662
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
663
+ )
664
+
665
+ if prev_timestep is not None:
666
+ deprecate(
667
+ "prev_timestep",
668
+ "1.0.0",
669
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
670
+ )
671
+
672
+ # Guard against out-of-bounds access (can occur in concurrent scenarios)
673
+ current_index = min(self.step_index, len(self.sigmas) - 1)
674
+ next_index = min(self.step_index + 1, len(self.sigmas) - 1)
675
+ sigma_t, sigma_s = self.sigmas[next_index], self.sigmas[current_index]
676
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
677
+ alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
678
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
679
+ lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
680
+
681
+ h = lambda_t - lambda_s
682
+ if self.config.algorithm_type == "dpmsolver++":
683
+ x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
684
+ elif self.config.algorithm_type == "dpmsolver":
685
+ x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
686
+ elif self.config.algorithm_type == "sde-dpmsolver++":
687
+ assert noise is not None
688
+ x_t = (
689
+ (sigma_t / sigma_s * torch.exp(-h)) * sample
690
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
691
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
692
+ )
693
+ elif self.config.algorithm_type == "sde-dpmsolver":
694
+ assert noise is not None
695
+ x_t = (
696
+ (alpha_t / alpha_s) * sample
697
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
698
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
699
+ )
700
+ return x_t
701
+
702
+ def multistep_dpm_solver_second_order_update(
703
+ self,
704
+ model_output_list: List[torch.Tensor],
705
+ *args,
706
+ sample: torch.Tensor = None,
707
+ noise: Optional[torch.Tensor] = None,
708
+ **kwargs,
709
+ ) -> torch.Tensor:
710
+ """
711
+ One step for the second-order multistep DPMSolver.
712
+
713
+ Args:
714
+ model_output_list (`List[torch.Tensor]`):
715
+ The direct outputs from learned diffusion model at current and latter timesteps.
716
+ sample (`torch.Tensor`):
717
+ A current instance of a sample created by the diffusion process.
718
+
719
+ Returns:
720
+ `torch.Tensor`:
721
+ The sample tensor at the previous timestep.
722
+ """
723
+ timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
724
+ prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
725
+ if sample is None:
726
+ if len(args) > 2:
727
+ sample = args[2]
728
+ else:
729
+ raise ValueError(" missing `sample` as a required keyward argument")
730
+ if timestep_list is not None:
731
+ deprecate(
732
+ "timestep_list",
733
+ "1.0.0",
734
+ "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
735
+ )
736
+
737
+ if prev_timestep is not None:
738
+ deprecate(
739
+ "prev_timestep",
740
+ "1.0.0",
741
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
742
+ )
743
+
744
+ # Guard against out-of-bounds access (can occur in concurrent scenarios)
745
+ current_index = min(self.step_index, len(self.sigmas) - 1)
746
+ next_index = min(self.step_index + 1, len(self.sigmas) - 1)
747
+ prev_index = max(self.step_index - 1, 0)
748
+ sigma_t, sigma_s0, sigma_s1 = (
749
+ self.sigmas[next_index],
750
+ self.sigmas[current_index],
751
+ self.sigmas[prev_index],
752
+ )
753
+
754
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
755
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
756
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
757
+
758
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
759
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
760
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
761
+
762
+ m0, m1 = model_output_list[-1], model_output_list[-2]
763
+
764
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
765
+ r0 = h_0 / h
766
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
767
+ if self.config.algorithm_type == "dpmsolver++":
768
+ # See https://arxiv.org/abs/2211.01095 for detailed derivations
769
+ if self.config.solver_type == "midpoint":
770
+ x_t = (
771
+ (sigma_t / sigma_s0) * sample
772
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
773
+ - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
774
+ )
775
+ elif self.config.solver_type == "heun":
776
+ x_t = (
777
+ (sigma_t / sigma_s0) * sample
778
+ - (alpha_t * (torch.exp(-h) - 1.0)) * D0
779
+ + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
780
+ )
781
+ elif self.config.algorithm_type == "dpmsolver":
782
+ # See https://arxiv.org/abs/2206.00927 for detailed derivations
783
+ if self.config.solver_type == "midpoint":
784
+ x_t = (
785
+ (alpha_t / alpha_s0) * sample
786
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
787
+ - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
788
+ )
789
+ elif self.config.solver_type == "heun":
790
+ x_t = (
791
+ (alpha_t / alpha_s0) * sample
792
+ - (sigma_t * (torch.exp(h) - 1.0)) * D0
793
+ - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
794
+ )
795
+ elif self.config.algorithm_type == "sde-dpmsolver++":
796
+ assert noise is not None
797
+ if self.config.solver_type == "midpoint":
798
+ x_t = (
799
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
800
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
801
+ + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
802
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
803
+ )
804
+ elif self.config.solver_type == "heun":
805
+ x_t = (
806
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
807
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
808
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
809
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
810
+ )
811
+ elif self.config.algorithm_type == "sde-dpmsolver":
812
+ assert noise is not None
813
+ if self.config.solver_type == "midpoint":
814
+ x_t = (
815
+ (alpha_t / alpha_s0) * sample
816
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
817
+ - (sigma_t * (torch.exp(h) - 1.0)) * D1
818
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
819
+ )
820
+ elif self.config.solver_type == "heun":
821
+ x_t = (
822
+ (alpha_t / alpha_s0) * sample
823
+ - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
824
+ - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
825
+ + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
826
+ )
827
+ return x_t
828
+
829
def multistep_dpm_solver_third_order_update(
    self,
    model_output_list: List[torch.Tensor],
    *args,
    sample: torch.Tensor = None,
    **kwargs,
) -> torch.Tensor:
    """
    One step for the third-order multistep DPMSolver.

    Args:
        model_output_list (`List[torch.Tensor]`):
            The direct outputs from learned diffusion model at current and latter timesteps.
        sample (`torch.Tensor`):
            A current instance of a sample created by diffusion process.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.
    """

    # Legacy positional arguments (`timestep_list`, `prev_timestep`) are still
    # accepted but deprecated; scheduling state now lives in `self.step_index`.
    timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError(" missing`sample` as a required keyward argument")
    if timestep_list is not None:
        deprecate(
            "timestep_list",
            "1.0.0",
            "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    # Guard against out-of-bounds access (can occur in concurrent scenarios)
    current_index = min(self.step_index, len(self.sigmas) - 1)
    next_index = min(self.step_index + 1, len(self.sigmas) - 1)
    prev_index_1 = max(self.step_index - 1, 0)
    prev_index_2 = max(self.step_index - 2, 0)
    # sigma_t is the target (next) sigma; sigma_s0..s2 are the three most
    # recent source sigmas consumed by the third-order multistep formula.
    sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
        self.sigmas[next_index],
        self.sigmas[current_index],
        self.sigmas[prev_index_1],
        self.sigmas[prev_index_2],
    )

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
    alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

    # lambda = log(alpha) - log(sigma): the half-log-SNR coordinate in which
    # DPM-Solver step sizes are measured.
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
    lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)

    # Last three converted model outputs (newest first).
    m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

    # Step sizes in lambda-space, and finite-difference estimates D0/D1/D2 of
    # the model output and its first two derivatives.
    h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
    r0, r1 = h_0 / h, h_1 / h
    D0 = m0
    D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
    D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
    D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
    if self.config.algorithm_type == "dpmsolver++":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        x_t = (
            (sigma_t / sigma_s0) * sample
            - (alpha_t * (torch.exp(-h) - 1.0)) * D0
            + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
            - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
        )
    elif self.config.algorithm_type == "dpmsolver":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        x_t = (
            (alpha_t / alpha_s0) * sample
            - (sigma_t * (torch.exp(h) - 1.0)) * D0
            - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
            - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
        )
    # NOTE(review): `x_t` is never assigned for the "sde-dpmsolver" /
    # "sde-dpmsolver++" algorithm types, which would raise NameError here.
    # Presumably `step()` never routes those types to third order — confirm.
    return x_t
918
+
919
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """Return the position of `timestep` within the timestep schedule.

    Falls back to `self.timesteps` when no explicit schedule is given.
    If the timestep is absent, the last schedule index is returned.
    """
    timesteps = self.timesteps if schedule_timesteps is None else schedule_timesteps

    matches = (timesteps == timestep).nonzero()
    if len(matches) == 0:
        # Timestep not in the schedule: clamp to the final index.
        return len(self.timesteps) - 1

    # The sigma index that is taken for the **very** first `step`
    # is always the second index (or the last index if there is only 1)
    # This way we can ensure we don't accidentally skip a sigma in
    # case we start in the middle of the denoising schedule (e.g. for image-to-image)
    pick = 1 if len(matches) > 1 else 0
    return matches[pick].item()
937
+
938
+ def _init_step_index(self, timestep):
939
+ """
940
+ Initialize the step_index counter for the scheduler.
941
+ """
942
+
943
+ if self.begin_index is None:
944
+ if isinstance(timestep, torch.Tensor):
945
+ timestep = timestep.to(self.timesteps.device)
946
+ self._step_index = self.index_for_timestep(timestep)
947
+ else:
948
+ self._step_index = self._begin_index
949
+
950
def step(
    self,
    model_output: torch.Tensor,
    timestep: int,
    sample: torch.Tensor,
    generator=None,
    variance_noise: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
    the multistep DPMSolver.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        generator (`torch.Generator`, *optional*):
            A random number generator.
        variance_noise (`torch.Tensor`):
            Alternative to generating noise with `generator` by directly providing the noise for the variance
            itself. Useful for methods such as [`LEdits++`].
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # Lazily resolve the internal step counter from the given timestep.
    if self.step_index is None:
        self._init_step_index(timestep)

    # Improve numerical stability for small number of steps
    # Also guard against out-of-bounds access: if step_index >= len(timesteps) - 1,
    # we must use first-order to avoid accessing sigmas[step_index + 1] out of bounds
    is_last_or_past = self.step_index >= len(self.timesteps) - 1
    lower_order_final = is_last_or_past and (
        self.config.euler_at_final
        or (self.config.lower_order_final and len(self.timesteps) < 15)
        or self.config.final_sigmas_type == "zero"
        or self.step_index >= len(self.sigmas) - 1  # Safety: prevent OOB access
    )
    lower_order_second = (
        (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
    )

    # Shift the model-output history window and append the newest output.
    model_output = self.convert_model_output(model_output, sample=sample)
    for i in range(self.config.solver_order - 1):
        self.model_outputs[i] = self.model_outputs[i + 1]
    self.model_outputs[-1] = model_output

    # Upcast to avoid precision issues when computing prev_sample
    sample = sample.to(torch.float32)
    # SDE variants need a noise term; either generate it or use caller-provided noise.
    if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
        noise = randn_tensor(
            model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
        )
    elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
        noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
    else:
        noise = None

    # Dispatch to first/second/third order depending on configured order,
    # warm-up progress (`lower_order_nums`), and end-of-schedule guards.
    if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
        prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
    elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
        prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
    else:
        prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)

    if self.lower_order_nums < self.config.solver_order:
        self.lower_order_nums += 1

    # Cast sample back to expected dtype
    prev_sample = prev_sample.to(model_output.dtype)

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (prev_sample,)

    return SchedulerOutput(prev_sample=prev_sample)
1042
+
1043
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """Forward-diffuse `original_samples` to the given `timesteps`.

    Computes ``alpha_t * x_0 + sigma_t * noise`` with the per-timestep
    coefficients broadcast over all trailing sample dimensions.
    """
    # Move coefficient tables onto the samples' device/dtype before indexing.
    alpha_table = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
    sigma_table = self.sigma_t.to(original_samples.device).to(original_samples.dtype)
    timesteps = timesteps.to(original_samples.device)

    alpha = alpha_table[timesteps].flatten()
    sigma = sigma_table[timesteps].flatten()

    # Append singleton dims so the 1-D coefficients broadcast over samples.
    trailing = original_samples.ndim - alpha.ndim
    alpha = alpha.reshape(alpha.shape + (1,) * trailing)
    sigma = sigma.reshape(sigma.shape + (1,) * trailing)

    return alpha * original_samples + sigma * noise
1064
+
1065
+ def get_velocity(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
1066
+ # alpha_t = self.alpha_t.to(device=original_samples.device, dtype=original_samples.dtype)
1067
+ # sigma_t = self.sigma_t.to(device=original_samples.device, dtype=original_samples.dtype)
1068
+ alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype)
1069
+ sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype)
1070
+
1071
+ timesteps = timesteps.to(original_samples.device)
1072
+ alpha_t = alpha_t[timesteps].flatten()
1073
+ while len(alpha_t.shape) < len(original_samples.shape):
1074
+ alpha_t = alpha_t.unsqueeze(-1)
1075
+
1076
+ sigma_t = sigma_t[timesteps].flatten()
1077
+ while len(sigma_t.shape) < len(original_samples.shape):
1078
+ sigma_t = sigma_t.unsqueeze(-1)
1079
+
1080
+ velocity = alpha_t * noise - sigma_t * original_samples
1081
+ return velocity
1082
+
1083
def __len__(self):
    # The scheduler's length is the number of *training* timesteps it was
    # configured with, not the number of inference steps currently set.
    return self.config.num_train_timesteps
kugelaudio_open/ui/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Gradio web interface for KugelAudio."""
2
+
3
+ from kugelaudio_open.ui.app import create_app, launch_app
4
+
5
+ __all__ = ["create_app", "launch_app"]
kugelaudio_open/ui/__main__.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CLI entry point for KugelAudio UI."""
2
+
3
+ import argparse
4
+
5
+ from kugelaudio_open.ui import launch_app
6
+
7
+
8
def main():
    """Parse command-line options and start the KugelAudio Gradio UI."""
    parser = argparse.ArgumentParser(description="Launch KugelAudio Gradio UI")
    parser.add_argument("--share", action="store_true",
                        help="Create a public Gradio share link")
    parser.add_argument("--host", default="127.0.0.1",
                        help="Server hostname (default: 127.0.0.1, use 0.0.0.0 for network access)")
    parser.add_argument("--port", type=int, default=7860,
                        help="Server port (default: 7860)")
    opts = parser.parse_args()

    print(f"πŸŽ™οΈ Starting KugelAudio UI on {opts.host}:{opts.port}")
    if opts.share:
        print("πŸ“‘ Creating public share link...")

    launch_app(share=opts.share, server_name=opts.host, server_port=opts.port)
38
+
39
+
40
+ if __name__ == "__main__":
41
+ main()
kugelaudio_open/ui/app.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio web interface for KugelAudio text-to-speech."""
2
+
3
+ import os
4
+ import tempfile
5
+ import warnings
6
+ from typing import Optional, Tuple
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ try:
12
+ import gradio as gr
13
+
14
+ GRADIO_AVAILABLE = True
15
+ except ImportError:
16
+ GRADIO_AVAILABLE = False
17
+ warnings.warn("Gradio not installed. Install with: pip install gradio")
18
+
19
+
20
+ # Global model instances (lazy loaded)
21
+ _model = None
22
+ _processor = None
23
+ _watermark = None
24
+ _current_model_id = None # Track which model is loaded
25
+
26
+
27
def get_device():
    """Pick the most capable torch backend available on this machine."""
    if torch.cuda.is_available():
        return "cuda"
    # Apple Silicon: `mps` backend may not exist on older torch builds.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
34
+
35
+
36
def _warmup_model(model, processor=None):
    """Warmup model components to eliminate CUDA kernel compilation overhead on first generation.

    This runs dummy data through all model components (acoustic decoder, semantic encoder,
    diffusion head, language model) to trigger JIT compilation before actual inference.
    """
    # Infer placement from the model itself; `device` is a torch.device here
    # (hence `.type` below), unlike the string used elsewhere in this module.
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    with torch.no_grad():
        # 1. Warmup acoustic decoder (biggest impact - saves ~190ms on first call)
        latent_dim = model.config.acoustic_vae_dim
        dummy_latent = torch.randn(1, latent_dim, 1, device=device, dtype=dtype)
        _ = model.acoustic_tokenizer.decode(dummy_latent)

        # 2. Warmup semantic encoder
        # NOTE(review): 3200 samples presumably matches the encoder's minimum
        # frame size — confirm against the tokenizer config.
        dummy_audio = torch.randn(1, 1, 3200, device=device, dtype=dtype)
        _ = model.semantic_tokenizer.encode(dummy_audio)

        # 3. Warmup diffusion/prediction head
        hidden_size = model.config.decoder_config.hidden_size
        model.noise_scheduler.set_timesteps(model.ddpm_inference_steps)

        # Batch of 2 mirrors the CFG (conditional + unconditional) path.
        dummy_condition = torch.randn(2, hidden_size, device=device, dtype=dtype)
        dummy_speech = torch.randn(2, latent_dim, device=device, dtype=dtype)

        for t in model.noise_scheduler.timesteps:
            half = dummy_speech[:1]
            combined = torch.cat([half, half], dim=0)
            _ = model.prediction_head(
                combined,
                t.repeat(combined.shape[0]).to(combined),
                condition=dummy_condition,
            )
            dummy_eps = torch.randn_like(dummy_speech)
            dummy_speech = model.noise_scheduler.step(dummy_eps, t, dummy_speech).prev_sample

        # 4. Warmup language model with KV cache path
        dummy_ids = torch.randint(0, 32000, (1, 64), device=device)
        dummy_mask = torch.ones_like(dummy_ids)
        _ = model.model.language_model(input_ids=dummy_ids, attention_mask=dummy_mask, use_cache=True)

        # 5. Warmup acoustic encoder (for voice prompts)
        dummy_voice = torch.randn(1, 1, 24000, device=device, dtype=dtype)
        _ = model.acoustic_tokenizer.encode(dummy_voice)

        # 6. Run a minimal generation to warmup the full generation path
        if processor is not None:
            dummy_inputs = processor(text="Hi.", return_tensors="pt")
            dummy_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in dummy_inputs.items()}
            _ = model.generate(**dummy_inputs, cfg_scale=3.0, max_new_tokens=10, show_progress=False)

    # Clear memory
    if device.type == "cuda":
        torch.cuda.empty_cache()
91
+
92
+
93
def load_models(model_id: str = "kugelaudio/kugelaudio-0-open"):
    """Load model and processor. Switches model if a different model_id is requested.

    Returns:
        Tuple of (model, processor, watermark). Results are cached in module
        globals; calling again with the same `model_id` is cheap.
    """
    global _model, _processor, _watermark, _current_model_id

    # Imported lazily so the UI module can be imported without the heavy deps.
    from kugelaudio_open.models import KugelAudioForConditionalGenerationInference
    from kugelaudio_open.processors import KugelAudioProcessor
    from kugelaudio_open.watermark import AudioWatermark

    device = get_device()
    # bf16 only on CUDA; CPU/MPS fall back to fp32.
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    # Check if we need to load a different model
    if _model is None or _current_model_id != model_id:
        # Clean up old model if switching
        if _model is not None and _current_model_id != model_id:
            print(f"Switching model from {_current_model_id} to {model_id}...")
            del _model
            del _processor
            _model = None
            _processor = None
            # Clear CUDA cache to free memory
            if device == "cuda":
                torch.cuda.empty_cache()

        print(f"Loading model {model_id} on {device}...")
        # Prefer flash attention on CUDA; fall back to default attention if
        # the flash-attn package is unavailable (any load error retries).
        try:
            _model = KugelAudioForConditionalGenerationInference.from_pretrained(
                model_id,
                torch_dtype=dtype,
                attn_implementation="flash_attention_2" if device == "cuda" else "sdpa",
            ).to(device)
        except Exception:
            _model = KugelAudioForConditionalGenerationInference.from_pretrained(
                model_id,
                torch_dtype=dtype,
            ).to(device)
        _model.eval()
        _current_model_id = model_id
        print(f"Model {model_id} loaded!")

    # NOTE(review): the processor is only loaded once and is NOT refreshed on
    # model switch (it is deleted above but only re-created if None) — with a
    # single supported model this is fine; confirm if more models are added.
    if _processor is None:
        _processor = KugelAudioProcessor.from_pretrained(model_id)

    # Warmup to eliminate first-generation slowness from CUDA kernel compilation
    # Do this after processor is loaded so we can run a mini-generation
    if device == "cuda" and _model is not None:
        # Check if we need to warmup (only on first load)
        if not getattr(_model, "_warmed_up", False):
            print("Warming up model (this may take a moment)...")
            _warmup_model(_model, _processor)
            _model._warmed_up = True
            print("Warmup complete!")

    if _watermark is None:
        _watermark = AudioWatermark(device=device)

    return _model, _processor, _watermark
150
+
151
+
152
def generate_speech(
    text: str,
    reference_audio: Optional[Tuple[int, np.ndarray]] = None,
    model_choice: str = "kugelaudio-0-open",
    cfg_scale: float = 3.0,
    max_tokens: int = 2048,
) -> Tuple[int, np.ndarray]:
    """Generate speech from text.

    Args:
        text: Text to synthesize
        reference_audio: Optional (sample_rate, audio_array) for voice cloning
        model_choice: Model variant to use
        cfg_scale: Classifier-free guidance scale
        max_tokens: Maximum generation tokens

    Returns:
        Tuple of (sample_rate, audio_array)

    Note:
        All generated audio is automatically watermarked for identification.
    """
    if not text.strip():
        raise gr.Error("Please enter some text to synthesize.")

    model_id = f"kugelaudio/{model_choice}"
    # NOTE(review): `watermark` is retrieved but unused here — the model's
    # generate() already embeds the watermark (see comment further down).
    model, processor, watermark = load_models(model_id)
    device = next(model.parameters()).device

    # Process reference audio if provided
    voice_audio = None
    if reference_audio is not None:
        ref_sr, ref_audio = reference_audio
        print(f"[Voice Cloning] Input audio: sr={ref_sr}, shape={ref_audio.shape}, dtype={ref_audio.dtype}")

        # Convert to float32 and normalize based on dtype
        if ref_audio.dtype == np.int16:
            ref_audio = ref_audio.astype(np.float32) / 32768.0
        elif ref_audio.dtype == np.int32:
            ref_audio = ref_audio.astype(np.float32) / 2147483648.0
        elif ref_audio.dtype == np.float64:
            ref_audio = ref_audio.astype(np.float32)
        elif ref_audio.dtype != np.float32:
            ref_audio = ref_audio.astype(np.float32)

        # Ensure mono BEFORE resampling (important for stereo files)
        if ref_audio.ndim > 1:
            if ref_audio.shape[0] == 2:  # [2, samples] format (channels first)
                ref_audio = ref_audio.mean(axis=0)
            elif ref_audio.shape[-1] == 2:  # [samples, 2] format (channels last)
                ref_audio = ref_audio.mean(axis=-1)
            elif ref_audio.shape[0] < ref_audio.shape[-1]:  # Likely [channels, samples]
                ref_audio = ref_audio.mean(axis=0)
            else:  # Likely [samples, channels]
                ref_audio = ref_audio.mean(axis=-1)

        # Ensure 1D
        ref_audio = ref_audio.squeeze()

        print(f"[Voice Cloning] After mono conversion: shape={ref_audio.shape}, dtype={ref_audio.dtype}")

        # Resample to 24kHz if needed - this is CRITICAL for voice cloning
        if ref_sr != 24000:
            import librosa
            print(f"[Voice Cloning] Resampling from {ref_sr}Hz to 24000Hz (ratio: {ref_sr/24000:.4f})")
            ref_audio = librosa.resample(ref_audio, orig_sr=ref_sr, target_sr=24000)
            print(f"[Voice Cloning] After resampling: shape={ref_audio.shape}, duration={len(ref_audio)/24000:.2f}s")
        else:
            print(f"[Voice Cloning] No resampling needed, already at 24kHz")

        # Normalize audio to reasonable range
        max_val = np.abs(ref_audio).max()
        if max_val > 0:
            ref_audio = ref_audio / max_val * 0.95

        voice_audio = ref_audio
        print(f"[Voice Cloning] Final voice audio: shape={voice_audio.shape}, min={voice_audio.min():.4f}, max={voice_audio.max():.4f}, std={voice_audio.std():.4f}")

    # Process text input with optional voice prompt
    inputs = processor(text=text.strip(), voice_prompt=voice_audio, return_tensors="pt")
    inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

    print(f"[Generation] Using model: {model_id}, cfg_scale={cfg_scale}, max_tokens={max_tokens}")

    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            cfg_scale=cfg_scale,
            max_new_tokens=max_tokens,
        )

    if not outputs.speech_outputs or outputs.speech_outputs[0] is None:
        raise gr.Error("Generation failed. Please try again with different settings.")

    # Audio is already watermarked by the model's generate method
    audio = outputs.speech_outputs[0]
    print(f"[Generation] Raw output: shape={audio.shape}, dtype={audio.dtype}")

    # Convert to numpy (convert to float32 first since numpy doesn't support bfloat16)
    if isinstance(audio, torch.Tensor):
        audio = audio.cpu().float().numpy()

    # Ensure correct shape (1D array)
    audio = audio.squeeze()

    # Normalize to prevent clipping (important for Gradio playback)
    max_val = np.abs(audio).max()
    if max_val > 1.0:
        audio = audio / max_val * 0.95

    print(f"[Generation] Final output: shape={audio.shape}, dtype={audio.dtype}, duration={len(audio)/24000:.2f}s")
    print(f"[Generation] Audio stats: min={audio.min():.4f}, max={audio.max():.4f}, std={audio.std():.4f}")

    # Return with explicit sample rate - Gradio expects (sample_rate, audio_array)
    return (24000, audio)
268
+
269
+
270
def check_watermark(audio: Tuple[int, np.ndarray]) -> str:
    """Report whether the given (sample_rate, samples) audio carries the
    KugelAudio watermark, as a markdown-formatted string."""
    if audio is None:
        return "No audio provided."

    from kugelaudio_open.watermark import AudioWatermark

    sr, samples = audio

    # Integer PCM -> float32 in [-1, 1]; other dtypes pass through untouched.
    int_scales = {np.dtype(np.int16): 32768.0, np.dtype(np.int32): 2147483648.0}
    scale = int_scales.get(samples.dtype)
    if scale is not None:
        samples = samples.astype(np.float32) / scale

    result = AudioWatermark().detect(samples, sample_rate=sr)

    if result.detected:
        return f"βœ… **Watermark Detected**\n\nConfidence: {result.confidence:.1%}\n\nThis audio was generated by KugelAudio."
    return f"❌ **No Watermark Detected**\n\nConfidence: {result.confidence:.1%}\n\nThis audio does not appear to be generated by KugelAudio."
292
+
293
+
294
def create_app() -> "gr.Blocks":
    """Create the Gradio application.

    Builds three tabs (generation, watermark verification, about) plus a
    branded header and footer, and wires the buttons to `generate_speech`
    and `check_watermark`. Raises ImportError when gradio is missing.
    """
    if not GRADIO_AVAILABLE:
        raise ImportError("Gradio not installed. Install with: pip install gradio")

    # Logo URLs
    kugelaudio_logo = "https://www.kugelaudio.com/logos/Logo%20Short.svg"
    kisz_logo = "https://docs.sc.hpi.de/attachments/aisc/aisc-logo.png"
    bmftr_logo = (
        "https://hpi.de/fileadmin/_processed_/a/3/csm_BMFTR_de_Web_RGB_gef_durch_cd1f5345bd.jpg"
    )

    with gr.Blocks(title="KugelAudio - Text to Speech") as app:
        # Header: project name plus partner/sponsor logos.
        gr.HTML(
            f"""
            <div style="text-align: center; margin-bottom: 1.5rem;">
                <h1 style="margin-bottom: 0.5rem;">πŸŽ™οΈ KugelAudio</h1>
                <p style="color: #666; margin-bottom: 1rem;">Open-source text-to-speech with voice cloning capabilities</p>
                <div style="display: flex; justify-content: center; align-items: center; gap: 2rem; flex-wrap: wrap;">
                    <a href="https://kugelaudio.com" target="_blank">
                        <img src="{kugelaudio_logo}" alt="KugelAudio" style="height: 50px; width: auto;">
                    </a>
                    <a href="https://hpi.de/ki-servicezentrum/" target="_blank">
                        <img src="{kisz_logo}" alt="KI-Servicezentrum Berlin-Brandenburg" style="height: 50px; width: auto;">
                    </a>
                    <a href="https://www.bmftr.bund.de" target="_blank">
                        <img src="{bmftr_logo}" alt="GefΓΆrdert durch BMFTR" style="height: 70px; width: auto;">
                    </a>
                </div>
            </div>
            """
        )

        with gr.Tabs():
            # Tab 1: Text to Speech
            with gr.TabItem("πŸ—£οΈ Generate Speech"):
                with gr.Row():
                    with gr.Column(scale=1):
                        text_input = gr.Textbox(
                            label="Text to Synthesize",
                            placeholder="Enter the text you want to convert to speech...",
                            lines=5,
                            max_lines=20,
                        )

                        reference_audio = gr.Audio(
                            label="Reference Audio (Optional)",
                            type="numpy",
                            sources=["upload", "microphone"],
                        )
                        gr.Markdown("*Upload a voice sample to clone the speaker's voice*")

                        with gr.Accordion("Advanced Settings", open=False):
                            model_choice = gr.Dropdown(
                                choices=["kugelaudio-0-open"],
                                value="kugelaudio-0-open",
                                label="Model",
                            )
                            cfg_scale = gr.Slider(
                                minimum=1.0,
                                maximum=10.0,
                                value=3.0,
                                step=0.5,
                                label="Guidance Scale",
                                info="Higher values = more adherence to text",
                            )
                            max_tokens = gr.Slider(
                                minimum=512,
                                maximum=8192,
                                value=2048,
                                step=256,
                                label="Max Tokens",
                                info="Maximum generation length",
                            )

                        generate_btn = gr.Button("🎀 Generate Speech", variant="primary", size="lg")

                    with gr.Column(scale=1):
                        output_audio = gr.Audio(
                            label="Generated Speech",
                            type="numpy",
                            interactive=False,
                        )

                        gr.Markdown(
                            """
                            ### Tips
                            - For best results, use clear and well-punctuated text
                            - Reference audio should be 5-30 seconds of clear speech
                            - The 7B model produces higher quality but is slower
                            """
                        )

                generate_btn.click(
                    fn=generate_speech,
                    inputs=[text_input, reference_audio, model_choice, cfg_scale, max_tokens],
                    outputs=[output_audio],
                )

            # Tab 2: Watermark Detection
            with gr.TabItem("πŸ” Verify Watermark"):
                gr.Markdown(
                    """
                    ### Watermark Verification
                    Check if an audio file was generated by KugelAudio. All audio generated
                    by KugelAudio contains an imperceptible watermark for identification.
                    """
                )

                with gr.Row():
                    with gr.Column():
                        verify_audio = gr.Audio(
                            label="Audio to Verify",
                            type="numpy",
                            sources=["upload"],
                        )
                        verify_btn = gr.Button("πŸ” Check Watermark", variant="secondary")

                    with gr.Column():
                        verify_result = gr.Markdown(
                            label="Result",
                            value="Upload an audio file to check for watermark.",
                        )

                verify_btn.click(
                    fn=check_watermark,
                    inputs=[verify_audio],
                    outputs=[verify_result],
                )

            # Tab 3: About
            with gr.TabItem("ℹ️ About"):
                gr.Markdown(
                    """
                    ## About KugelAudio

                    KugelAudio is an open-source text-to-speech system that combines:

                    - **AR + Diffusion Architecture**: Uses autoregressive language modeling
                      with diffusion-based speech synthesis for high-quality output
                    - **Voice Cloning**: Clone any voice with just a few seconds of reference audio
                    - **Audio Watermarking**: All generated audio contains an imperceptible watermark
                      using [Facebook's AudioSeal](https://huggingface.co/facebook/audioseal) technology

                    ### Models

                    | Model | Parameters | Quality | Speed |
                    |-------|------------|---------|-------|
                    | kugelaudio-0-open | 7B | Best | Standard |

                    ### Responsible Use

                    This technology is intended for legitimate purposes such as:
                    - Accessibility (text-to-speech for visually impaired)
                    - Content creation (podcasts, videos, audiobooks)
                    - Voice assistants and chatbots

                    **Please do not use this technology for:**
                    - Creating deepfakes or misleading content
                    - Impersonating individuals without consent
                    - Any illegal or harmful purposes

                    All generated audio is watermarked to enable detection.
                    """
                )

        # Footer: attribution and project links.
        gr.HTML(
            """
            <div style="text-align: center; margin-top: 2rem; padding: 1rem; border-top: 1px solid #eee;">
                <p style="color: #888; margin-bottom: 0.5rem;">
                    <strong>KugelAudio</strong> β€’ Open Source TTS with Voice Cloning
                </p>
                <p style="color: #aaa; font-size: 0.9rem;">
                    Created by <a href="mailto:kajo@kugelaudio.com" style="color: #667eea;">Kajo Kratzenstein</a> β€’
                    <a href="https://kugelaudio.com" style="color: #667eea;">kugelaudio.com</a> β€’
                    <a href="https://github.com/kugelaudio/kugelaudio" style="color: #667eea;">GitHub</a>
                </p>
            </div>
            """
        )

    return app
476
+
477
+
478
def launch_app(
    share: bool = False,
    server_name: str = "127.0.0.1",
    server_port: int = 7860,
    **kwargs,
):
    """Launch the Gradio web interface.

    Args:
        share: Create a public share link
        server_name: Server hostname (use "0.0.0.0" for network access)
        server_port: Server port
        **kwargs: Additional arguments passed to gr.Blocks.launch()
    """
    app = create_app()
    # Bug fix: `theme=` was previously passed to `launch()`, but `theme` is a
    # `gr.Blocks(...)` constructor argument, not a `Blocks.launch()` parameter —
    # current Gradio releases raise TypeError on the unexpected keyword.
    # Theming belongs in `create_app()`'s `gr.Blocks(...)` call.
    app.launch(
        share=share,
        server_name=server_name,
        server_port=server_port,
        **kwargs,
    )
503
+
504
+
505
+ if __name__ == "__main__":
506
+ launch_app()
kugelaudio_open/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Utility functions for KugelAudio."""
2
+
3
+ from kugelaudio_open.utils.generation import generate_speech, load_model_and_processor
4
+
5
+ __all__ = ["generate_speech", "load_model_and_processor"]
kugelaudio_open/utils/generation.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """High-level generation utilities for KugelAudio."""
2
+
3
+ from typing import Optional, Union, Tuple
4
+ import torch
5
+
6
+
7
+ def load_model_and_processor(
8
+ model_name_or_path: str = "kugelaudio/kugelaudio-0-open",
9
+ device: Optional[Union[str, torch.device]] = None,
10
+ torch_dtype: Optional[torch.dtype] = None,
11
+ use_flash_attention: bool = True,
12
+ ):
13
+ """Load KugelAudio model and processor.
14
+
15
+ Args:
16
+ model_name_or_path: HuggingFace model ID or local path
17
+ device: Device to load model on (auto-detected if None)
18
+ torch_dtype: Data type for model weights
19
+ use_flash_attention: Whether to use flash attention if available
20
+
21
+ Returns:
22
+ Tuple of (model, processor)
23
+
24
+ Example:
25
+ >>> model, processor = load_model_and_processor("kugelaudio/kugelaudio-0-open")
26
+ """
27
+ from kugelaudio_open.models import KugelAudioForConditionalGenerationInference
28
+ from kugelaudio_open.processors import KugelAudioProcessor
29
+
30
+ # Auto-detect device
31
+ if device is None:
32
+ device = "cuda" if torch.cuda.is_available() else "cpu"
33
+
34
+ # Auto-detect dtype
35
+ if torch_dtype is None:
36
+ torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
37
+
38
+ # Load model
39
+ attn_impl = "flash_attention_2" if use_flash_attention else "sdpa"
40
+ try:
41
+ model = KugelAudioForConditionalGenerationInference.from_pretrained(
42
+ model_name_or_path,
43
+ torch_dtype=torch_dtype,
44
+ attn_implementation=attn_impl,
45
+ ).to(device)
46
+ except Exception:
47
+ # Fallback without flash attention
48
+ model = KugelAudioForConditionalGenerationInference.from_pretrained(
49
+ model_name_or_path,
50
+ torch_dtype=torch_dtype,
51
+ ).to(device)
52
+
53
+ model.eval()
54
+
55
+ # Load processor
56
+ processor = KugelAudioProcessor.from_pretrained(model_name_or_path)
57
+
58
+ return model, processor
59
+
60
+
61
+ def generate_speech(
62
+ model,
63
+ processor,
64
+ text: str,
65
+ voice_prompt: Optional[torch.Tensor] = None,
66
+ voice_prompt_path: Optional[str] = None,
67
+ cfg_scale: float = 3.0,
68
+ max_new_tokens: int = 4096,
69
+ device: Optional[Union[str, torch.device]] = None,
70
+ ) -> torch.Tensor:
71
+ """Generate speech from text.
72
+
73
+ All generated audio is automatically watermarked for identification.
74
+
75
+ Args:
76
+ model: KugelAudio model
77
+ processor: KugelAudio processor
78
+ text: Text to synthesize
79
+ voice_prompt: Voice prompt tensor for speaker identity
80
+ voice_prompt_path: Path to voice prompt audio file
81
+ cfg_scale: Classifier-free guidance scale
82
+ max_new_tokens: Maximum number of tokens to generate
83
+ device: Device for generation
84
+
85
+ Returns:
86
+ Generated audio tensor (watermarked)
87
+
88
+ Example:
89
+ >>> audio = generate_speech(model, processor, "Hello world!")
90
+ >>> processor.save_audio(audio, "output.wav")
91
+ """
92
+ if device is None:
93
+ device = next(model.parameters()).device
94
+
95
+ # Load voice prompt if path provided
96
+ if voice_prompt is None and voice_prompt_path is not None:
97
+ voice_data = processor.audio_processor(voice_prompt_path, return_tensors="pt")
98
+ voice_prompt = voice_data["audio"].to(device)
99
+
100
+ # Process inputs
101
+ inputs = processor(text=text, return_tensors="pt")
102
+ inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
103
+
104
+ # Generate (watermark is automatically applied by the model)
105
+ with torch.no_grad():
106
+ outputs = model.generate(
107
+ **inputs,
108
+ voice_prompt=voice_prompt,
109
+ cfg_scale=cfg_scale,
110
+ max_new_tokens=max_new_tokens,
111
+ )
112
+
113
+ audio = outputs.speech_outputs[0] if outputs.speech_outputs else None
114
+
115
+ if audio is None:
116
+ raise RuntimeError("Generation failed - no audio output")
117
+
118
+ return audio
kugelaudio_open/watermark/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Audio watermarking for KugelAudio generated speech."""
2
+
3
+ from kugelaudio_open.watermark.watermark import AudioWatermark
4
+
5
+ __all__ = ["AudioWatermark"]
kugelaudio_open/watermark/watermark.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio watermarking for KugelAudio using Facebook's AudioSeal.
2
+
3
+ AudioSeal provides state-of-the-art speech localized watermarking with:
4
+ - High robustness to audio editing and compression
5
+ - Fast single-pass detection (real-time capable)
6
+ - Sample-level detection (1/16k second resolution)
7
+ - Optional 16-bit message embedding
8
+
9
+ Reference: https://huggingface.co/facebook/audioseal
10
+ """
11
+
12
+ from typing import Optional, Union, Tuple
13
+ from dataclasses import dataclass
14
+ import warnings
15
+
16
+ import numpy as np
17
+ import torch
18
+
19
+ # Try to import AudioSeal
20
+ try:
21
+ from audioseal import AudioSeal
22
+ AUDIOSEAL_AVAILABLE = True
23
+ except ImportError:
24
+ AUDIOSEAL_AVAILABLE = False
25
+ warnings.warn(
26
+ "AudioSeal not installed. Install with: pip install audioseal\n"
27
+ "Watermarking will use fallback implementation."
28
+ )
29
+
30
+
31
@dataclass
class WatermarkResult:
    """Result of watermark detection."""
    # True when the mean frame probability exceeded the detection threshold.
    detected: bool
    # Mean per-frame watermark probability, in [0, 1].
    confidence: float
    # Decoded 16-bit message tensor from the detector, if any.
    message: Optional[torch.Tensor] = None
    # Per-frame watermark probabilities, shape (batch, frames).
    frame_probabilities: Optional[torch.Tensor] = None
38
+
39
+
40
class AudioWatermark:
    """Professional audio watermarking using Facebook's AudioSeal.

    AudioSeal is a state-of-the-art watermarking system that embeds
    imperceptible watermarks in audio that are robust to various
    audio transformations.

    Features:
        - Imperceptible watermarks with minimal quality degradation
        - Robust to compression, resampling, and editing
        - Fast detection suitable for real-time applications
        - Optional 16-bit message embedding for tracking

    Example:
        >>> watermark = AudioWatermark()
        >>> watermarked_audio = watermark.embed(audio)
        >>> result = watermark.detect(watermarked_audio)
        >>> print(f"Detected: {result.detected}, Confidence: {result.confidence:.2%}")

    Args:
        model_name: AudioSeal generator variant ("audioseal_wm_16bits")
        detector_name: AudioSeal detector variant ("audioseal_detector_16bits")
        device: Device for inference ("cuda" or "cpu"); auto-detected if None
        message: Optional 16-bit message to embed (for tracking)
    """

    # Default 16-bit message identifying KugelAudio-generated content
    # (alternating bit pattern).
    KUGELAUDIO_MESSAGE = torch.tensor([[1, 0, 1, 0, 1, 0, 1, 0,
                                        0, 1, 0, 1, 0, 1, 0, 1]])

    # AudioSeal models operate on 16 kHz audio.
    AUDIOSEAL_SAMPLE_RATE = 16000

    def __init__(
        self,
        model_name: str = "audioseal_wm_16bits",
        detector_name: str = "audioseal_detector_16bits",
        device: Optional[Union[str, torch.device]] = None,
        message: Optional[torch.Tensor] = None,
    ):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # Generator/detector are loaded lazily on first use (see the
        # `generator` / `detector` properties) to keep construction cheap.
        self._generator = None
        self._detector = None
        self._model_name = model_name
        self._detector_name = detector_name

        # Use the KugelAudio identifier message by default.
        self.message = message if message is not None else self.KUGELAUDIO_MESSAGE.clone()

        if not AUDIOSEAL_AVAILABLE:
            warnings.warn(
                "AudioSeal not available. Watermarking disabled. "
                "Install with: pip install audioseal"
            )

    @property
    def generator(self):
        """Lazily load and cache the watermark generator model."""
        if self._generator is None and AUDIOSEAL_AVAILABLE:
            self._generator = AudioSeal.load_generator(self._model_name)
            self._generator = self._generator.to(self.device)
            self._generator.eval()
        return self._generator

    @property
    def detector(self):
        """Lazily load and cache the watermark detector model."""
        if self._detector is None and AUDIOSEAL_AVAILABLE:
            self._detector = AudioSeal.load_detector(self._detector_name)
            self._detector = self._detector.to(self.device)
            self._detector.eval()
        return self._detector

    def _resample(
        self,
        audio: torch.Tensor,
        orig_sr: int,
        target_sr: int
    ) -> torch.Tensor:
        """Resample ``audio`` from ``orig_sr`` to ``target_sr``.

        Prefers torchaudio; falls back to scipy when torchaudio is not
        installed.  The scipy path flattens and reshapes, which assumes
        the 3-D (batch, channels, samples) layout used by all internal
        callers.
        """
        if orig_sr == target_sr:
            return audio

        try:
            import torchaudio.functional as F
            return F.resample(audio, orig_sr, target_sr)
        except ImportError:
            # Fallback using scipy.
            from scipy import signal
            audio_np = audio.cpu().numpy()
            num_samples = int(len(audio_np.flatten()) * target_sr / orig_sr)
            resampled = signal.resample(audio_np.flatten(), num_samples)
            return torch.from_numpy(resampled).reshape(audio.shape[0], audio.shape[1], -1).to(audio.device)

    def embed(
        self,
        audio: Union[np.ndarray, torch.Tensor],
        sample_rate: int = 24000,
        message: Optional[torch.Tensor] = None,
    ) -> Union[np.ndarray, torch.Tensor]:
        """Embed watermark into audio.

        The watermark is imperceptible and robust to various audio
        transformations including compression and resampling.

        Args:
            audio: Input audio of shape (samples,), (channels, samples),
                or (batch, channels, samples)
            sample_rate: Sample rate of input audio (default: 24000 for KugelAudio)
            message: Optional 16-bit message to embed

        Returns:
            Watermarked audio with same shape and type as input
        """
        if not AUDIOSEAL_AVAILABLE:
            # Best-effort: return audio unchanged if AudioSeal is missing.
            return audio

        # Track input type so the return matches it.
        is_numpy = isinstance(audio, np.ndarray)
        if is_numpy:
            audio = torch.from_numpy(audio)

        original_device = audio.device
        original_dtype = audio.dtype

        # Work in float32.
        audio = audio.float()

        # Normalize to (batch, channels, samples).
        original_shape = audio.shape
        if audio.ndim == 1:
            audio = audio.unsqueeze(0).unsqueeze(0)
        elif audio.ndim == 2:
            audio = audio.unsqueeze(0)

        audio = audio.to(self.device)

        # AudioSeal operates at 16 kHz — resample a working copy for it;
        # the original-rate `audio` tensor is kept untouched.
        if sample_rate != self.AUDIOSEAL_SAMPLE_RATE:
            audio_16k = self._resample(audio, sample_rate, self.AUDIOSEAL_SAMPLE_RATE)
        else:
            audio_16k = audio

        # Prepare the (optionally batched) message.
        msg = message if message is not None else self.message
        msg = msg.to(self.device)
        if msg.shape[0] != audio_16k.shape[0]:
            msg = msg.expand(audio_16k.shape[0], -1)

        # Generate the additive watermark signal at 16 kHz.
        with torch.no_grad():
            watermark_16k = self.generator.get_watermark(audio_16k, self.AUDIOSEAL_SAMPLE_RATE, message=msg)

        # Bring the watermark back to the caller's sample rate and
        # length-match it to the ORIGINAL audio.
        # Bug fix: the previous implementation also replaced `audio` with
        # a 16 kHz round-trip of itself (resampling audio_16k back up),
        # which silently discarded all spectral content above 8 kHz from
        # the returned waveform; only the watermark needs resampling.
        if sample_rate != self.AUDIOSEAL_SAMPLE_RATE:
            watermark = self._resample(watermark_16k, self.AUDIOSEAL_SAMPLE_RATE, sample_rate)
            if watermark.shape[-1] > audio.shape[-1]:
                watermark = watermark[..., :audio.shape[-1]]
            elif watermark.shape[-1] < audio.shape[-1]:
                watermark = torch.nn.functional.pad(
                    watermark, (0, audio.shape[-1] - watermark.shape[-1])
                )
        else:
            watermark = watermark_16k

        # The watermark is additive.
        watermarked = audio + watermark

        # Renormalize only when the sum would clip.
        max_val = watermarked.abs().max()
        if max_val > 1.0:
            watermarked = watermarked / max_val

        # Restore the caller's original rank.
        if len(original_shape) == 1:
            watermarked = watermarked.squeeze(0).squeeze(0)
        elif len(original_shape) == 2:
            watermarked = watermarked.squeeze(0)

        # Restore device and dtype.  NOTE(review): casting back to an
        # integer dtype would truncate the float waveform to zeros —
        # callers are assumed to pass float audio; confirm if integer
        # PCM input is ever used.
        watermarked = watermarked.to(device=original_device, dtype=original_dtype)

        # Convert back to numpy if the input was numpy.
        if is_numpy:
            watermarked = watermarked.numpy()

        return watermarked

    def detect(
        self,
        audio: Union[np.ndarray, torch.Tensor],
        sample_rate: int = 24000,
        threshold: float = 0.5,
    ) -> WatermarkResult:
        """Detect watermark in audio.

        Args:
            audio: Input audio to check for watermark
            sample_rate: Sample rate of input audio
            threshold: Detection threshold (0.0-1.0)

        Returns:
            WatermarkResult with detection status, confidence, and decoded message
        """
        if not AUDIOSEAL_AVAILABLE:
            # Without AudioSeal we cannot detect anything.
            return WatermarkResult(
                detected=False,
                confidence=0.0,
                message=None,
                frame_probabilities=None,
            )

        # Convert to tensor if needed.
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)

        audio = audio.float()

        # Normalize to (batch, channels, samples).
        if audio.ndim == 1:
            audio = audio.unsqueeze(0).unsqueeze(0)
        elif audio.ndim == 2:
            audio = audio.unsqueeze(0)

        audio = audio.to(self.device)

        # Resample to 16 kHz for the detector.
        if sample_rate != self.AUDIOSEAL_SAMPLE_RATE:
            audio = self._resample(audio, sample_rate, self.AUDIOSEAL_SAMPLE_RATE)

        # Run detection.
        with torch.no_grad():
            result, message = self.detector(audio, self.AUDIOSEAL_SAMPLE_RATE)

        # result shape: (batch, 2, frames) — probabilities for
        # [no_watermark, watermark]; take the positive class.
        watermark_probs = result[:, 1, :]  # (batch, frames)

        # Overall confidence = mean of the per-frame probabilities.
        confidence = watermark_probs.mean().item()

        # Threshold the mean probability.
        detected = confidence > threshold

        return WatermarkResult(
            detected=detected,
            confidence=confidence,
            message=message.cpu() if message is not None else None,
            frame_probabilities=watermark_probs.cpu(),
        )

    def verify(self, audio: Union[np.ndarray, torch.Tensor], sample_rate: int = 24000) -> bool:
        """Quick verification that audio contains KugelAudio watermark.

        Args:
            audio: Audio to verify
            sample_rate: Sample rate of audio

        Returns:
            True if watermark detected with high confidence (> 0.6)
        """
        result = self.detect(audio, sample_rate)
        return result.detected and result.confidence > 0.6
319
+
320
+
321
class WatermarkPostProcessor:
    """Pipeline hook that transparently watermarks generated audio.

    Drop this into a generation pipeline so every produced waveform is
    watermarked without callers having to think about it.

    Example:
        >>> post_processor = WatermarkPostProcessor()
        >>> # In generation pipeline:
        >>> audio = model.generate(...)
        >>> audio = post_processor(audio)  # Watermark added automatically
    """

    def __init__(
        self,
        enabled: bool = True,
        device: Optional[Union[str, torch.device]] = None,
        sample_rate: int = 24000,
    ):
        self.enabled = enabled
        self.sample_rate = sample_rate
        # The underlying AudioWatermark is created on first use.
        self._device = device
        self._watermark = None

    @property
    def watermark(self) -> AudioWatermark:
        """Create the watermark model on first access and cache it."""
        if self._watermark is None:
            self._watermark = AudioWatermark(device=self._device)
        return self._watermark

    def __call__(
        self,
        audio: Union[np.ndarray, torch.Tensor],
        sample_rate: Optional[int] = None,
    ) -> Union[np.ndarray, torch.Tensor]:
        """Return ``audio`` watermarked, or unchanged when disabled."""
        if not self.enabled:
            return audio
        effective_sr = sample_rate or self.sample_rate
        return self.watermark.embed(audio, sample_rate=effective_sr)

    def disable(self):
        """Turn watermarking off."""
        self.enabled = False

    def enable(self):
        """Turn watermarking on."""
        self.enabled = True
371
+
372
+
373
def is_watermarked(
    audio: Union[np.ndarray, torch.Tensor],
    sample_rate: int = 24000,
    threshold: float = 0.5,
) -> bool:
    """Convenience check for the presence of a watermark.

    Args:
        audio: Audio to check
        sample_rate: Sample rate of audio
        threshold: Detection threshold

    Returns:
        True if watermark detected
    """
    # A throwaway detector is fine here; models load lazily inside it.
    detector = AudioWatermark()
    return detector.detect(audio, sample_rate, threshold).detected