Text-to-Speech
PyTorch
moss_tts_nano
custom_code
hf-upload-bot committed on
Commit
ad3fd89
·
1 Parent(s): 5962079

Upload Nano TTS checkpoint 500000

Browse files
__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .configuration_nanotts import NanoTTSConfig
2
+ from .modeling_nanotts_global_local import (
3
+ NanoTTSGenerationOutput,
4
+ NanoTTSGlobalLocalForCausalLM,
5
+ NanoTTSOutput,
6
+ )
7
+ from .tokenization_nanotts_sentencepiece import NanoTTSSentencePieceTokenizer
8
+
9
# Register the custom classes with the transformers Auto* factories.
# Registration stays best-effort (importing the package must never fail because
# of it), but a failure is now logged at debug level instead of being discarded.
import logging

_logger = logging.getLogger(__name__)


def _register_for_auto_class(target, *auto_class_args):
    """Call ``target.register_for_auto_class(*auto_class_args)`` without letting it raise.

    Args:
        target: Class exposing a ``register_for_auto_class`` method.
        *auto_class_args: Positional arguments forwarded unchanged (for example
            the auto-class name); empty to use the method's own default.
    """
    try:
        target.register_for_auto_class(*auto_class_args)
    except Exception:  # deliberate best-effort: never break package import
        _logger.debug(
            "Auto-class registration skipped for %s%r",
            getattr(target, "__name__", target),
            auto_class_args,
            exc_info=True,
        )


_register_for_auto_class(NanoTTSConfig)

for auto_class_name in ("AutoModel", "AutoModelForCausalLM"):
    _register_for_auto_class(NanoTTSGlobalLocalForCausalLM, auto_class_name)

_register_for_auto_class(NanoTTSSentencePieceTokenizer, "AutoTokenizer")

__all__ = [
    "NanoTTSConfig",
    "NanoTTSGlobalLocalForCausalLM",
    "NanoTTSSentencePieceTokenizer",
    "NanoTTSGenerationOutput",
    "NanoTTSOutput",
]
config.json ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_cross_attention": false,
3
+ "architectures": [
4
+ "NanoTTSGlobalLocalForCausalLM"
5
+ ],
6
+ "attn_implementation": "flash_attention_2",
7
+ "audio_assistant_slot_token_id": 9,
8
+ "audio_codebook_sizes": [
9
+ 1024,
10
+ 1024,
11
+ 1024,
12
+ 1024,
13
+ 1024,
14
+ 1024,
15
+ 1024,
16
+ 1024,
17
+ 1024,
18
+ 1024,
19
+ 1024,
20
+ 1024,
21
+ 1024,
22
+ 1024,
23
+ 1024,
24
+ 1024
25
+ ],
26
+ "audio_end_token_id": 7,
27
+ "audio_pad_token_id": 1024,
28
+ "audio_start_token_id": 6,
29
+ "audio_tokenizer_pretrained_name_or_path": "OpenMOSS-Team/MOSS-Audio-Tokenizer-Nano",
30
+ "audio_tokenizer_sample_rate": 48000,
31
+ "audio_tokenizer_type": "moss-audio-tokenizer-nano",
32
+ "audio_user_slot_token_id": 8,
33
+ "audio_vocab_size": 1024,
34
+ "bad_words_ids": null,
35
+ "begin_suppress_tokens": null,
36
+ "bos_token_id": null,
37
+ "chunk_size_feed_forward": 0,
38
+ "cross_attention_hidden_size": null,
39
+ "decoder_start_token_id": null,
40
+ "diversity_penalty": 0.0,
41
+ "do_sample": false,
42
+ "dtype": "float32",
43
+ "early_stopping": false,
44
+ "encoder_no_repeat_ngram_size": 0,
45
+ "eos_token_id": null,
46
+ "exponential_decay_length_penalty": null,
47
+ "finetuning_task": null,
48
+ "forced_bos_token_id": null,
49
+ "forced_eos_token_id": null,
50
+ "gpt2_config": {
51
+ "_name_or_path": "",
52
+ "activation_function": "gelu_new",
53
+ "add_cross_attention": false,
54
+ "architectures": null,
55
+ "attn_pdrop": 0.0,
56
+ "bad_words_ids": null,
57
+ "begin_suppress_tokens": null,
58
+ "bos_token_id": 1,
59
+ "chunk_size_feed_forward": 0,
60
+ "cross_attention_hidden_size": null,
61
+ "decoder_start_token_id": null,
62
+ "diversity_penalty": 0.0,
63
+ "do_sample": false,
64
+ "dtype": null,
65
+ "early_stopping": false,
66
+ "embd_pdrop": 0.0,
67
+ "encoder_no_repeat_ngram_size": 0,
68
+ "eos_token_id": 2,
69
+ "exponential_decay_length_penalty": null,
70
+ "finetuning_task": null,
71
+ "forced_bos_token_id": null,
72
+ "forced_eos_token_id": null,
73
+ "id2label": {
74
+ "0": "LABEL_0",
75
+ "1": "LABEL_1"
76
+ },
77
+ "initializer_range": 0.02,
78
+ "is_decoder": false,
79
+ "is_encoder_decoder": false,
80
+ "label2id": {
81
+ "LABEL_0": 0,
82
+ "LABEL_1": 1
83
+ },
84
+ "layer_norm_epsilon": 1e-05,
85
+ "length_penalty": 1.0,
86
+ "max_length": 20,
87
+ "min_length": 0,
88
+ "model_type": "gpt2",
89
+ "n_ctx": 32768,
90
+ "n_embd": 768,
91
+ "n_head": 12,
92
+ "n_inner": 3072,
93
+ "n_layer": 12,
94
+ "n_positions": 32768,
95
+ "no_repeat_ngram_size": 0,
96
+ "num_beam_groups": 1,
97
+ "num_beams": 1,
98
+ "num_return_sequences": 1,
99
+ "output_attentions": false,
100
+ "output_hidden_states": false,
101
+ "output_scores": false,
102
+ "pad_token_id": 3,
103
+ "position_embedding_type": "rope",
104
+ "prefix": null,
105
+ "problem_type": null,
106
+ "pruned_heads": {},
107
+ "remove_invalid_values": false,
108
+ "reorder_and_upcast_attn": false,
109
+ "repetition_penalty": 1.0,
110
+ "resid_pdrop": 0.0,
111
+ "return_dict": true,
112
+ "return_dict_in_generate": false,
113
+ "rope_base": 10000.0,
114
+ "scale_attn_by_inverse_layer_idx": false,
115
+ "scale_attn_weights": true,
116
+ "sep_token_id": null,
117
+ "summary_activation": null,
118
+ "summary_first_dropout": 0.1,
119
+ "summary_proj_to_labels": true,
120
+ "summary_type": "cls_index",
121
+ "summary_use_proj": true,
122
+ "suppress_tokens": null,
123
+ "task_specific_params": null,
124
+ "temperature": 1.0,
125
+ "tf_legacy_loss": false,
126
+ "tie_encoder_decoder": false,
127
+ "tie_word_embeddings": true,
128
+ "tokenizer_class": null,
129
+ "top_k": 50,
130
+ "top_p": 1.0,
131
+ "torchscript": false,
132
+ "transformers_version": "4.57.1",
133
+ "typical_p": 1.0,
134
+ "use_bfloat16": false,
135
+ "use_cache": true,
136
+ "vocab_size": 16384
137
+ },
138
+ "hidden_size": 768,
139
+ "id2label": {
140
+ "0": "LABEL_0",
141
+ "1": "LABEL_1"
142
+ },
143
+ "im_end_token_id": 5,
144
+ "im_start_token_id": 4,
145
+ "initializer_range": 0.02,
146
+ "is_decoder": false,
147
+ "is_encoder_decoder": false,
148
+ "label2id": {
149
+ "LABEL_0": 0,
150
+ "LABEL_1": 1
151
+ },
152
+ "length_penalty": 1.0,
153
+ "local_transformer_attn_implementation": "flash_attention_2",
154
+ "local_transformer_layers": 1,
155
+ "max_length": 20,
156
+ "max_position_embeddings": 32768,
157
+ "min_length": 0,
158
+ "model_architecture": "global_local_transformer",
159
+ "model_type": "nano_tts",
160
+ "n_vq": 16,
161
+ "no_repeat_ngram_size": 0,
162
+ "num_beam_groups": 1,
163
+ "num_beams": 1,
164
+ "num_return_sequences": 1,
165
+ "output_attentions": false,
166
+ "output_hidden_states": false,
167
+ "output_scores": false,
168
+ "pad_token_id": 3,
169
+ "prefix": null,
170
+ "problem_type": null,
171
+ "pruned_heads": {},
172
+ "remove_invalid_values": false,
173
+ "repetition_penalty": 1.0,
174
+ "return_dict": true,
175
+ "return_dict_in_generate": false,
176
+ "sep_token_id": null,
177
+ "suppress_tokens": null,
178
+ "task_specific_params": null,
179
+ "temperature": 1.0,
180
+ "tf_legacy_loss": false,
181
+ "tie_encoder_decoder": false,
182
+ "tie_word_embeddings": true,
183
+ "tokenizer_class": "NanoTTSSentencePieceTokenizer",
184
+ "tokenizer_use_fast": false,
185
+ "top_k": 50,
186
+ "top_p": 1.0,
187
+ "torchscript": false,
188
+ "transformers_version": "4.57.1",
189
+ "typical_p": 1.0,
190
+ "use_bfloat16": false,
191
+ "vocab_size": 16384,
192
+ "auto_map": {
193
+ "AutoConfig": "configuration_nanotts.NanoTTSConfig",
194
+ "AutoModel": "modeling_nanotts_global_local.NanoTTSGlobalLocalForCausalLM",
195
+ "AutoModelForCausalLM": "modeling_nanotts_global_local.NanoTTSGlobalLocalForCausalLM"
196
+ }
197
+ }
configuration_nanotts.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ from typing import Any, Dict, Optional, Union
3
+
4
+ from transformers.configuration_utils import PretrainedConfig
5
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
6
+
7
+
8
class NanoTTSConfig(PretrainedConfig):
    """Configuration object for the ``nano_tts`` model type.

    Holds a nested :class:`GPT2Config` (``gpt2_config``) together with the audio
    codebook layout, special-token ids, audio-tokenizer settings and the
    attention backend names. ``vocab_size``, ``hidden_size`` and
    ``max_position_embeddings`` are copied from the nested GPT-2 config.

    Args:
        gpt2_config: A ``GPT2Config`` instance, a dict of ``GPT2Config`` keyword
            arguments, or ``None`` for a default ``GPT2Config()``.
        n_vq: Number of audio codebooks; must be positive.
        audio_vocab_size: Size used for every codebook when
            ``audio_codebook_sizes`` is not given; when both are given it must
            be at least ``max(audio_codebook_sizes)``. May be ``None`` only if
            ``audio_codebook_sizes`` is provided.
        audio_codebook_sizes: Optional explicit per-codebook sizes; its length
            must equal ``n_vq`` and every entry must be positive.
        audio_pad_token_id: Audio padding id; must be at least
            ``max(audio_codebook_sizes)`` so it lies outside every codebook.
        pad_token_id, im_start_token_id, im_end_token_id, audio_start_token_id,
        audio_end_token_id, audio_user_slot_token_id,
        audio_assistant_slot_token_id: Special-token ids, stored as given.
        tokenizer_use_fast, audio_tokenizer_type,
        audio_tokenizer_pretrained_name_or_path, audio_tokenizer_sample_rate:
            Tokenizer-related settings, stored as given.
        attn_implementation: Attention backend name, stored as given.
        initializer_range: Stored as given.
        model_architecture: Stored as given.
        local_transformer_layers: Stored as given.
        local_transformer_attn_implementation: Backend for the local
            transformer; falls back to ``attn_implementation`` when ``None``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``n_vq`` is not positive, or the codebook sizes,
            ``audio_vocab_size`` and ``audio_pad_token_id`` are inconsistent.
    """

    model_type = "nano_tts"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        gpt2_config: Optional[Union[GPT2Config, Dict[str, Any]]] = None,
        n_vq: int = 8,
        audio_vocab_size: Optional[int] = 1024,
        audio_codebook_sizes: Optional[list[int]] = None,
        audio_pad_token_id: int = 1024,
        pad_token_id: int = 151643,
        im_start_token_id: int = 151644,
        im_end_token_id: int = 151645,
        audio_start_token_id: int = 151652,
        audio_end_token_id: int = 151653,
        audio_user_slot_token_id: int = 151654,
        audio_assistant_slot_token_id: int = 151656,
        tokenizer_use_fast: bool = False,
        audio_tokenizer_type: str = "moss-audio-tokenizer-nano",
        audio_tokenizer_pretrained_name_or_path: Optional[str] = "OpenMOSS-Team/MOSS-Audio-Tokenizer-Nano",
        audio_tokenizer_sample_rate: int = 48000,
        attn_implementation: str = "flash_attention_2",
        initializer_range: float = 0.02,
        model_architecture: str = "global_local_transformer",
        local_transformer_layers: int = 4,
        local_transformer_attn_implementation: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        # Accept the nested config as a dict (e.g. when loaded from JSON),
        # as nothing (defaults), or as a ready-made config object.
        if isinstance(gpt2_config, dict):
            self.gpt2_config = GPT2Config(**gpt2_config)
        elif gpt2_config is None:
            self.gpt2_config = GPT2Config()
        else:
            self.gpt2_config = gpt2_config

        self.n_vq = int(n_vq)
        # Without this check a non-positive n_vq produces an empty codebook
        # list and fails further down with "max() arg is an empty sequence".
        if self.n_vq <= 0:
            raise ValueError(f"n_vq must be a positive integer (got {n_vq}).")

        if audio_codebook_sizes is None:
            if audio_vocab_size is None:
                raise ValueError("audio_vocab_size must be set when audio_codebook_sizes is not provided.")
            resolved_audio_codebook_sizes = [int(audio_vocab_size)] * self.n_vq
        else:
            resolved_audio_codebook_sizes = [int(codebook_size) for codebook_size in audio_codebook_sizes]
        if len(resolved_audio_codebook_sizes) != self.n_vq:
            raise ValueError(
                "audio_codebook_sizes must have length n_vq "
                f"(expected {self.n_vq}, got {len(resolved_audio_codebook_sizes)})."
            )
        if any(codebook_size <= 0 for codebook_size in resolved_audio_codebook_sizes):
            raise ValueError("audio_codebook_sizes must contain positive integers.")

        max_audio_codebook_size = max(resolved_audio_codebook_sizes)
        if audio_vocab_size is not None and int(audio_vocab_size) < max_audio_codebook_size:
            raise ValueError(
                "audio_vocab_size must be >= max(audio_codebook_sizes) "
                f"(got {audio_vocab_size}, expected at least {max_audio_codebook_size})."
            )

        self.audio_codebook_sizes = resolved_audio_codebook_sizes
        # When no explicit vocab size is given, the largest codebook defines it.
        self.audio_vocab_size = (
            max_audio_codebook_size if audio_vocab_size is None else int(audio_vocab_size)
        )
        self.audio_pad_token_id = int(audio_pad_token_id)
        if self.audio_pad_token_id < max_audio_codebook_size:
            raise ValueError(
                "audio_pad_token_id must be >= max(audio_codebook_sizes) so pad stays outside every codebook "
                f"(got {self.audio_pad_token_id}, max codebook size {max_audio_codebook_size})."
            )
        self.pad_token_id = pad_token_id
        self.im_start_token_id = im_start_token_id
        self.im_end_token_id = im_end_token_id
        self.audio_start_token_id = audio_start_token_id
        self.audio_end_token_id = audio_end_token_id
        self.audio_user_slot_token_id = audio_user_slot_token_id
        self.audio_assistant_slot_token_id = audio_assistant_slot_token_id
        self.tokenizer_use_fast = tokenizer_use_fast
        self.audio_tokenizer_type = audio_tokenizer_type
        self.audio_tokenizer_pretrained_name_or_path = audio_tokenizer_pretrained_name_or_path
        self.audio_tokenizer_sample_rate = audio_tokenizer_sample_rate
        self.attn_implementation = attn_implementation
        self.initializer_range = initializer_range
        self.model_architecture = model_architecture
        self.local_transformer_layers = local_transformer_layers
        # The local transformer inherits the global backend unless overridden.
        self.local_transformer_attn_implementation = (
            attn_implementation
            if local_transformer_attn_implementation is None
            else local_transformer_attn_implementation
        )
        # Mirror the nested GPT-2 dimensions onto the top-level config.
        self.vocab_size = self.gpt2_config.vocab_size
        self.hidden_size = self.gpt2_config.hidden_size
        self.max_position_embeddings = self.gpt2_config.n_positions

        super().__init__(pad_token_id=pad_token_id, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """Return the config as a dict with ``gpt2_config`` expanded via its own ``to_dict()``."""
        output = super().to_dict()
        output["gpt2_config"] = self.gpt2_config.to_dict()
        return output
gpt2_decoder.py ADDED
@@ -0,0 +1,605 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.utils.checkpoint
10
+ from transformers.activations import ACT2FN
11
+ from transformers.modeling_outputs import BaseModelOutputWithPast
12
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
13
+
14
+ try:
15
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
16
+ from flash_attn.bert_padding import pad_input, unpad_input
17
+
18
+ _FLASH_ATTN_AVAILABLE = True
19
+ except Exception:
20
+ flash_attn_func = None
21
+ flash_attn_varlen_func = None
22
+ pad_input = None
23
+ unpad_input = None
24
+ _FLASH_ATTN_AVAILABLE = False
25
+
26
+
27
@dataclass
class PackedSequenceMetadata:
    """Bookkeeping handed to the flash-attention variable-length kernel for packed sequences."""

    # Cumulative sequence boundaries, passed as both the query and key
    # cu_seqlens arguments of flash_attn_varlen_func.
    cu_seqlens: torch.Tensor
    # Passed as both the query and key max_seqlen arguments of the kernel.
    max_seqlen: int
    # Flat token positions gathered with index_select before the kernel call;
    # when None the attention inputs are used as-is and the raw kernel output
    # is returned.
    indices: Optional[torch.Tensor] = None
    # Batch size and sequence length handed to pad_input to restore the padded
    # layout; only read when `indices` is set.
    batch_size: Optional[int] = None
    seq_len: Optional[int] = None
34
+
35
+
36
class NanoGPT2RotaryEmbedding(nn.Module):
    """Produces cos/sin tables for rotary position embeddings.

    Each frequency is repeated for two adjacent channels (``repeat_interleave``),
    so the tables pair channel ``2i`` with channel ``2i + 1``.
    """

    def __init__(self, dim: int, base: float = 10000.0) -> None:
        """Precompute the inverse frequencies for a head of width ``dim``.

        Raises:
            ValueError: If ``dim`` is odd.
        """
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE head_dim must be even, got {dim}")
        exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
        # Non-persistent: excluded from the module's state_dict.
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)

    def forward(
        self,
        position_ids: torch.LongTensor,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` tables for ``position_ids``.

        A 1-D ``position_ids`` gets a leading axis of size 1. For
        ``(batch, seq)`` positions each table has shape
        ``(batch, seq, 1, dim)`` and is cast to ``dtype``.
        """
        positions = position_ids.unsqueeze(0) if position_ids.ndim == 1 else position_ids
        positions = positions.to(device=device, dtype=self.inv_freq.dtype)
        angles = torch.einsum("bs,d->bsd", positions, self.inv_freq)

        def _as_table(values: torch.Tensor) -> torch.Tensor:
            # Duplicate each frequency for its channel pair and add a
            # singleton axis at dim 2 so the table broadcasts there.
            return values.repeat_interleave(2, dim=-1).unsqueeze(2).to(dtype=dtype)

        return _as_table(angles.cos()), _as_table(angles.sin())
57
+
58
+
59
def rotate_half(hidden_states: torch.Tensor) -> torch.Tensor:
    """Map each adjacent channel pair ``(a, b)`` on the last axis to ``(-b, a)``.

    The result has the same shape as ``hidden_states``.
    """
    evens = hidden_states[..., 0::2]
    odds = hidden_states[..., 1::2]
    # Stack as (..., pairs, 2), then fold the pair axis back into the last axis.
    rotated_pairs = torch.stack((odds.neg(), evens), dim=-1)
    return rotated_pairs.flatten(start_dim=-2)
63
+
64
+
65
def apply_rotary_pos_emb(
    hidden_states: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Apply rotary position embedding: ``x * cos + rotate_half(x) * sin``.

    ``cos`` and ``sin`` must be broadcastable against ``hidden_states``.
    """
    unrotated_term = hidden_states * cos
    rotated_term = rotate_half(hidden_states) * sin
    return unrotated_term + rotated_term
71
+
72
+
73
class NanoGPT2MLP(nn.Module):
    """Two-layer feed-forward sub-layer: linear, activation, linear, dropout."""

    def __init__(self, config: GPT2Config) -> None:
        """Build the layers from ``config``.

        The inner width is ``config.n_inner`` when truthy, otherwise four times
        ``config.hidden_size``.
        """
        super().__init__()
        width = int(config.hidden_size)
        inner_width = int(config.n_inner or 4 * width)
        self.fc_in = nn.Linear(width, inner_width)
        self.fc_out = nn.Linear(inner_width, width)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(fc_out(act(fc_in(hidden_states))))``."""
        expanded = self.act(self.fc_in(hidden_states))
        return self.dropout(self.fc_out(expanded))
88
+
89
+
90
class NanoGPT2Attention(nn.Module):
    """Causal multi-head self-attention with eager, SDPA and flash-attention-2 backends.

    Attention inputs are kept in ``(batch, seq, heads, head_dim)`` layout (see
    ``_split_heads``); the eager and SDPA paths transpose to
    ``(batch, heads, seq, head_dim)`` internally. When
    ``position_embedding_type == "rope"`` rotary embeddings are applied to the
    query and key before any cached keys are prepended.

    NOTE(review): ``scale_attn_weights`` and ``scale_attn_by_inverse_layer_idx``
    are only applied by the eager path; the SDPA and flash-attention calls pass
    no explicit scale and therefore use those kernels' default scaling.
    NOTE(review): ``packed_metadata`` is only consumed by the flash-attention
    path; the eager and SDPA paths ignore it.
    """

    def __init__(self, config: GPT2Config, layer_idx: int, attn_implementation: str) -> None:
        """Build the projections and, for RoPE configs, the rotary embedding.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by the head count or
                ``position_embedding_type`` is neither ``"absolute"`` nor ``"rope"``.
        """
        super().__init__()
        hidden_size = int(config.hidden_size)
        num_heads = int(config.num_attention_heads)
        if hidden_size % num_heads != 0:
            raise ValueError(f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_heads}")

        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.embed_dim = hidden_size
        self.layer_idx = layer_idx
        self.attn_implementation = attn_implementation
        self.attn_dropout = float(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.scale_attn_weights = bool(getattr(config, "scale_attn_weights", True))
        self.scale_attn_by_inverse_layer_idx = bool(getattr(config, "scale_attn_by_inverse_layer_idx", False))
        self.position_embedding_type = str(getattr(config, "position_embedding_type", "absolute")).lower()
        if self.position_embedding_type not in {"absolute", "rope"}:
            raise ValueError(f"Unsupported position_embedding_type={self.position_embedding_type!r}")

        # Fused query/key/value projection followed by the output projection.
        self.c_attn = nn.Linear(hidden_size, 3 * hidden_size)
        self.c_proj = nn.Linear(hidden_size, hidden_size)
        self.rotary_emb = None
        if self.position_embedding_type == "rope":
            self.rotary_emb = NanoGPT2RotaryEmbedding(
                self.head_dim,
                base=float(getattr(config, "rope_base", 10000.0)),
            )

    def _split_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        """View the last axis as ``(num_heads, head_dim)`` for 3-D or 2-D input.

        Raises:
            ValueError: For any other tensor rank.
        """
        if tensor.ndim == 3:
            batch_size, seq_len, _ = tensor.shape
            return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim)
        if tensor.ndim == 2:
            total_tokens, _ = tensor.shape
            return tensor.view(total_tokens, self.num_heads, self.head_dim)
        raise ValueError(f"Unsupported tensor rank for attention split: {tensor.ndim}")

    def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        """Inverse of ``_split_heads``: fold ``(num_heads, head_dim)`` back into ``embed_dim``.

        Raises:
            ValueError: If the tensor is neither 4-D nor 3-D.
        """
        if tensor.ndim == 4:
            batch_size, seq_len, _, _ = tensor.shape
            return tensor.reshape(batch_size, seq_len, self.embed_dim)
        if tensor.ndim == 3:
            total_tokens, _, _ = tensor.shape
            return tensor.reshape(total_tokens, self.embed_dim)
        raise ValueError(f"Unsupported tensor rank for attention merge: {tensor.ndim}")

    def _causal_attention_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        query_length: int,
        key_length: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Build a boolean "may attend" mask (``True`` = visible).

        Queries are treated as the *last* ``query_length`` positions of the key
        axis, so cached keys stay visible. The result is
        ``(1, 1, query_length, key_length)``, or broadcast with the per-key
        ``attention_mask`` (indexed as ``(batch, key_length)``) when one is given.
        """
        query_positions = torch.arange(query_length, device=device, dtype=torch.long)
        # Shift the queries to the end of the key axis when keys include a cache.
        query_positions = query_positions + max(key_length - query_length, 0)
        key_positions = torch.arange(key_length, device=device, dtype=torch.long)
        causal = key_positions.unsqueeze(0) <= query_positions.unsqueeze(1)
        causal = causal.unsqueeze(0).unsqueeze(0)
        if attention_mask is None:
            return causal
        key_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)
        return causal & key_mask

    def _eager_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Reference attention: explicit scores, masked softmax, optional dropout.

        Takes and returns tensors in ``(batch, seq, heads, head_dim)`` layout.
        """
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        scale = 1.0
        if self.scale_attn_weights:
            scale /= self.head_dim ** 0.5
        if self.scale_attn_by_inverse_layer_idx:
            scale /= float(self.layer_idx + 1)

        scores = torch.matmul(query, key.transpose(-1, -2)) * scale
        causal_mask = self._causal_attention_mask(
            attention_mask=attention_mask,
            query_length=query.shape[-2],
            key_length=key.shape[-2],
            device=query.device,
        )
        # Hidden positions get the most negative finite value, not -inf.
        scores = scores.masked_fill(~causal_mask, torch.finfo(scores.dtype).min)
        probs = torch.softmax(scores, dim=-1)
        if self.training and self.attn_dropout > 0:
            probs = torch.dropout(probs, self.attn_dropout, train=True)
        output = torch.matmul(probs, value)
        return output.transpose(1, 2).contiguous()

    def _sdpa_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Attention through ``torch.nn.functional.scaled_dot_product_attention``.

        Takes and returns tensors in ``(batch, seq, heads, head_dim)`` layout.
        """
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        query_length = query.shape[-2]
        key_length = key.shape[-2]

        mask = None
        # SDPA's is_causal=True aligns the causal triangle to the top-left
        # corner, which matches the eager path only when the query and key
        # lengths are equal. With cached keys (query shorter than key) an
        # explicit, offset causal mask is required even without a padding mask.
        if attention_mask is not None or query_length != key_length:
            mask = self._causal_attention_mask(
                attention_mask=attention_mask,
                query_length=query_length,
                key_length=key_length,
                device=query.device,
            )
        output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=mask,
            dropout_p=self.attn_dropout if self.training else 0.0,
            is_causal=mask is None,
        )
        return output.transpose(1, 2).contiguous()

    def _flash_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        packed_metadata: Optional[PackedSequenceMetadata],
    ) -> torch.Tensor:
        """Attention through the ``flash_attn`` kernels (always causal).

        Three paths are used:
        * ``packed_metadata`` given: variable-length kernel over the packed
          sequences; tokens are gathered with ``packed_metadata.indices`` first
          (when set) and the output is re-padded afterwards.
        * no mask, or a mask that is entirely true: the dense kernel.
        * otherwise: padding is stripped with ``unpad_input``, the
          variable-length kernel runs, and ``pad_input`` restores the layout.

        Raises:
            ImportError: If ``flash_attn`` could not be imported.
            ValueError: If the tensors are not on CUDA or not fp16/bf16.
        """
        if not _FLASH_ATTN_AVAILABLE:
            raise ImportError("flash_attn is not installed, but attn_implementation='flash_attention_2' was requested.")
        if query.device.type != "cuda":
            raise ValueError("flash_attention_2 requires CUDA tensors.")
        if query.dtype not in (torch.float16, torch.bfloat16):
            raise ValueError(
                f"flash_attention_2 requires fp16/bf16 tensors, but received dtype={query.dtype}."
            )

        dropout_p = self.attn_dropout if self.training else 0.0
        if packed_metadata is not None:
            if packed_metadata.indices is not None:
                # Flatten (batch, seq) into one token axis and keep only the packed tokens.
                query = query.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
                key = key.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
                value = value.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
            output = flash_attn_varlen_func(
                query,
                key,
                value,
                packed_metadata.cu_seqlens,
                packed_metadata.cu_seqlens,
                packed_metadata.max_seqlen,
                packed_metadata.max_seqlen,
                dropout_p=dropout_p,
                causal=True,
            )
            if packed_metadata.indices is None:
                return output
            return pad_input(
                output,
                packed_metadata.indices,
                packed_metadata.batch_size,
                packed_metadata.seq_len,
            )

        if attention_mask is None or bool(attention_mask.all()):
            return flash_attn_func(
                query,
                key,
                value,
                dropout_p=dropout_p,
                causal=True,
            )

        unpadded_query, indices, cu_seqlens, max_seqlen, _ = unpad_input(query, attention_mask)
        unpadded_key, _, _, _, _ = unpad_input(key, attention_mask)
        unpadded_value, _, _, _, _ = unpad_input(value, attention_mask)
        output = flash_attn_varlen_func(
            unpadded_query,
            unpadded_key,
            unpadded_value,
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p=dropout_p,
            causal=True,
        )
        return pad_input(output, indices, query.shape[0], query.shape[1])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        packed_metadata: Optional[PackedSequenceMetadata] = None,
        layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Input whose last axis is ``embed_dim``.
            attention_mask: Optional per-key validity mask.
            position_ids: Required when rotary embeddings are enabled.
            packed_metadata: Optional packing info for the flash-attention path.
            layer_past: Optional cached ``(key, value)`` prepended along dim 1.
            use_cache: When true, the (concatenated) key/value pair is returned.

        Returns:
            ``(attention_output, present)`` where ``present`` is the
            ``(key, value)`` pair if ``use_cache`` else ``None``.

        Raises:
            ValueError: If RoPE is enabled and ``position_ids`` is ``None``.
        """
        qkv = self.c_attn(hidden_states)
        query, key, value = qkv.split(self.embed_dim, dim=-1)
        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        if self.rotary_emb is not None:
            if position_ids is None:
                raise ValueError("position_ids must be provided when position_embedding_type='rope'.")
            cos, sin = self.rotary_emb(
                position_ids.to(device=query.device),
                device=query.device,
                dtype=query.dtype,
            )
            query = apply_rotary_pos_emb(query, cos, sin)
            key = apply_rotary_pos_emb(key, cos, sin)

        if layer_past is not None:
            # Cached keys/values are prepended on the sequence axis (dim 1).
            past_key, past_value = layer_past
            key = torch.cat([past_key.to(device=key.device, dtype=key.dtype), key], dim=1)
            value = torch.cat([past_value.to(device=value.device, dtype=value.dtype), value], dim=1)

        present = (key, value) if use_cache else None

        # Flash attention is only used without a cache; with a cache the
        # flash_attention_2 setting falls through to the eager path below.
        if self.attn_implementation == "flash_attention_2" and layer_past is None:
            attn_output = self._flash_attention(
                query=query,
                key=key,
                value=value,
                attention_mask=attention_mask,
                packed_metadata=packed_metadata,
            )
        elif self.attn_implementation == "sdpa":
            attn_output = self._sdpa_attention(
                query=query,
                key=key,
                value=value,
                attention_mask=attention_mask,
            )
        else:
            attn_output = self._eager_attention(
                query=query,
                key=key,
                value=value,
                attention_mask=attention_mask,
            )

        attn_output = self._merge_heads(attn_output)
        attn_output = self.c_proj(attn_output)
        return self.resid_dropout(attn_output), present
341
+
342
+
343
class NanoGPT2Block(nn.Module):
    """Transformer block: LayerNorm -> attention and LayerNorm -> MLP, each added back residually."""

    def __init__(self, config: GPT2Config, layer_idx: int, attn_implementation: str) -> None:
        """Create the two LayerNorms, the attention module and the MLP from ``config``."""
        super().__init__()
        width = int(config.hidden_size)
        self.ln_1 = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.attn = NanoGPT2Attention(config, layer_idx=layer_idx, attn_implementation=attn_implementation)
        self.ln_2 = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.mlp = NanoGPT2MLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        packed_metadata: Optional[PackedSequenceMetadata] = None,
        layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
        """Apply the block and return ``(hidden_states, present)``.

        All keyword arguments are forwarded unchanged to the attention module;
        ``present`` is whatever the attention module returns as its second value.
        """
        normed_input = self.ln_1(hidden_states)
        attn_output, present = self.attn(
            normed_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            packed_metadata=packed_metadata,
            layer_past=layer_past,
            use_cache=use_cache,
        )
        after_attention = hidden_states + attn_output
        after_mlp = after_attention + self.mlp(self.ln_2(after_attention))
        return after_mlp, present
372
+
373
+
374
+ class NanoGPT2Model(nn.Module):
375
+ def __init__(self, config: GPT2Config, attn_implementation: str = "eager") -> None:
376
+ super().__init__()
377
+ self.config = config
378
+ self.attn_implementation = attn_implementation
379
+ self.position_embedding_type = str(getattr(config, "position_embedding_type", "absolute")).lower()
380
+ if self.position_embedding_type not in {"absolute", "rope"}:
381
+ raise ValueError(f"Unsupported position_embedding_type={self.position_embedding_type!r}")
382
+ hidden_size = int(config.hidden_size)
383
+ self.wte = nn.Embedding(config.vocab_size, hidden_size)
384
+ self.wpe = nn.Embedding(config.n_positions, hidden_size) if self.position_embedding_type == "absolute" else nn.Identity()
385
+ self.drop = nn.Dropout(config.embd_pdrop)
386
+ self.h = nn.ModuleList(
387
+ [NanoGPT2Block(config, layer_idx=index, attn_implementation=attn_implementation) for index in range(config.n_layer)]
388
+ )
389
+ self.ln_f = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
390
+ self.gradient_checkpointing = False
391
+ self._reset_parameters()
392
+
393
+ def _reset_parameters(self) -> None:
394
+ init_std = float(self.config.initializer_range)
395
+ for module in self.modules():
396
+ if isinstance(module, nn.Linear):
397
+ nn.init.normal_(module.weight, mean=0.0, std=init_std)
398
+ if module.bias is not None:
399
+ nn.init.zeros_(module.bias)
400
+ elif isinstance(module, nn.Embedding):
401
+ nn.init.normal_(module.weight, mean=0.0, std=init_std)
402
+ elif isinstance(module, nn.LayerNorm):
403
+ nn.init.ones_(module.weight)
404
+ nn.init.zeros_(module.bias)
405
+
406
+ @staticmethod
407
+ def _normalize_num_sequences(
408
+ cu_seqlens: torch.Tensor,
409
+ num_sequences: Optional[torch.Tensor],
410
+ device: torch.device,
411
+ ) -> torch.Tensor:
412
+ if cu_seqlens.ndim == 1:
413
+ cu_seqlens = cu_seqlens.unsqueeze(0)
414
+ if num_sequences is None:
415
+ counts = []
416
+ for boundary in cu_seqlens:
417
+ diffs = boundary[1:] - boundary[:-1]
418
+ counts.append(int((diffs > 0).sum().item()))
419
+ return torch.tensor(counts, dtype=torch.int32, device=device)
420
+ if num_sequences.ndim == 0:
421
+ return num_sequences.unsqueeze(0)
422
+ return num_sequences
423
+
424
+ @staticmethod
425
+ def build_packed_position_ids(
426
+ attention_mask: Optional[torch.Tensor],
427
+ cu_seqlens: torch.Tensor,
428
+ num_sequences: Optional[torch.Tensor],
429
+ ) -> torch.Tensor:
430
+ if cu_seqlens.ndim == 1:
431
+ cu_seqlens = cu_seqlens.unsqueeze(0)
432
+ batch_size, seq_len = cu_seqlens.shape[0], cu_seqlens.shape[1] - 1
433
+ device = cu_seqlens.device
434
+ position_ids = torch.zeros((batch_size, seq_len), dtype=torch.long, device=device)
435
+ counts = NanoGPT2Model._normalize_num_sequences(cu_seqlens, num_sequences, device=device)
436
+ for batch_index in range(batch_size):
437
+ sequence_count = int(counts[batch_index].item())
438
+ boundaries = cu_seqlens[batch_index, : sequence_count + 1].tolist()
439
+ for start, end in zip(boundaries[:-1], boundaries[1:]):
440
+ start = int(start)
441
+ end = int(end)
442
+ if end > start:
443
+ position_ids[batch_index, start:end] = torch.arange(end - start, device=device)
444
+ if attention_mask is not None:
445
+ position_ids = position_ids * attention_mask.to(dtype=position_ids.dtype)
446
+ return position_ids
447
+
448
@staticmethod
def build_packed_metadata(
    hidden_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    num_sequences: Optional[torch.Tensor],
) -> PackedSequenceMetadata:
    """Turn per-row packed-sequence boundaries into flat metadata.

    Every non-empty ``[start, end)`` segment described by ``cu_seqlens`` is
    mapped to indices into the row-major flattening of the first two
    dimensions of ``hidden_states`` (``row * seq_len + position``), and the
    segment lengths are accumulated into a fresh cumulative-length vector.

    Args:
        hidden_states: Tensor whose first two dimensions are read as
            (rows, positions); only its shape and device are used.
        cu_seqlens: Segment boundaries, 1-D (treated as a single row) or 2-D.
        num_sequences: Optional per-row segment counts, forwarded to
            ``NanoGPT2Model._normalize_num_sequences``.

    Returns:
        A ``PackedSequenceMetadata`` with int32 cumulative lengths, the
        longest segment length, the concatenated flat indices and the
        original (rows, positions) sizes.

    Raises:
        ValueError: If no segment has a positive length.
    """
    if cu_seqlens.ndim == 1:
        cu_seqlens = cu_seqlens.unsqueeze(0)

    target_device = hidden_states.device
    row_total = hidden_states.shape[0]
    padded_len = hidden_states.shape[1]
    per_row_counts = NanoGPT2Model._normalize_num_sequences(cu_seqlens, num_sequences, device=target_device)

    gathered = []
    running_offsets = [0]
    longest = 0

    for row in range(row_total):
        segment_total = int(per_row_counts[row].item())
        edges = [int(edge) for edge in cu_seqlens[row, : segment_total + 1].tolist()]
        for seg_start, seg_end in zip(edges, edges[1:]):
            seg_len = seg_end - seg_start
            if seg_len <= 0:
                # Zero-length (or inverted) segments contribute nothing.
                continue
            gathered.append(row * padded_len + torch.arange(seg_start, seg_end, device=target_device))
            running_offsets.append(running_offsets[-1] + seg_len)
            longest = max(longest, seg_len)

    if not gathered:
        raise ValueError("cu_seqlens did not describe any non-empty packed sequences.")

    return PackedSequenceMetadata(
        cu_seqlens=torch.tensor(running_offsets, dtype=torch.int32, device=target_device),
        max_seqlen=longest,
        indices=torch.cat(gathered, dim=0),
        batch_size=row_total,
        seq_len=padded_len,
    )
487
+
488
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: bool = True,
    cu_seqlens: Optional[torch.Tensor] = None,
    num_sequences: Optional[torch.Tensor] = None,
) -> BaseModelOutputWithPast:
    """Run the decoder stack on pre-computed input embeddings.

    Args:
        input_ids: Accepted for signature compatibility and ignored.
        past_key_values: Optional per-layer cache; entry ``i`` is handed to
            block ``i`` as ``layer_past``.
        attention_mask: Optional mask, cast to bool. When omitted, an all-ones
            mask over the current embeddings is used.
        position_ids: Optional explicit positions. When omitted they are
            derived from ``cu_seqlens``, from the mask, or from the cache
            length (in that order of precedence).
        inputs_embeds: Required input embeddings.
        use_cache: Whether to collect each block's ``present`` output.
        output_attentions: Accepted for signature compatibility and ignored;
            ``attentions`` is always ``None`` in the result.
        output_hidden_states: Whether to collect the hidden states entering
            each block plus the final normalized states.
        return_dict: Return a ``BaseModelOutputWithPast`` (default) or a
            4-tuple ``(hidden, presents, all_hidden_states, None)``.
        cu_seqlens: Optional packed-sequence boundaries.
        num_sequences: Optional per-row sequence counts for ``cu_seqlens``.

    Raises:
        ValueError: If ``inputs_embeds`` is missing, if caching is combined
            with ``cu_seqlens`` packing, or if caching is requested while
            gradient checkpointing is active in training mode.
    """
    del input_ids, output_attentions

    if inputs_embeds is None:
        raise ValueError("inputs_embeds must be provided.")

    use_cache = bool(use_cache)
    if use_cache and cu_seqlens is not None:
        raise ValueError("use_cache=True is not supported together with cu_seqlens packing.")

    hidden_states = inputs_embeds
    device = hidden_states.device
    query_length = hidden_states.shape[1]

    # Remember whether the caller supplied a mask *before* defaulting it.
    # The original tested the mask after defaulting, which made the
    # cache-length branch below unreachable.
    mask_was_provided = attention_mask is not None
    if mask_was_provided:
        attention_mask = attention_mask.to(dtype=torch.bool, device=device)
    else:
        attention_mask = torch.ones(hidden_states.shape[:2], dtype=torch.bool, device=device)
    # The mask may also cover cached positions; keep only the trailing part
    # that lines up with the current embeddings.
    query_attention_mask = attention_mask[:, -query_length:]
    query_keep = query_attention_mask.unsqueeze(-1)

    if cu_seqlens is not None:
        cu_seqlens = cu_seqlens.to(device=device)
        if num_sequences is not None:
            num_sequences = num_sequences.to(device=device)

    packed_metadata = None
    if position_ids is None:
        if cu_seqlens is not None:
            position_ids = self.build_packed_position_ids(
                attention_mask=attention_mask,
                cu_seqlens=cu_seqlens,
                num_sequences=num_sequences,
            )
        elif mask_was_provided:
            # Position = number of attended tokens seen so far, minus one.
            position_ids = attention_mask.long().cumsum(dim=-1) - 1
            position_ids = position_ids.masked_fill(~attention_mask, 0)
            position_ids = position_ids[:, -query_length:]
        else:
            # No mask from the caller: continue counting after the cache.
            # NOTE(review): dimension 1 of the cached key is taken as the
            # cached length, as in the original code - confirm the layout.
            past_length = 0
            if past_key_values is not None and len(past_key_values) > 0:
                past_length = past_key_values[0][0].shape[1]
            position_ids = torch.arange(query_length, device=device, dtype=torch.long)
            position_ids = position_ids + past_length
            position_ids = position_ids.unsqueeze(0).expand(hidden_states.shape[0], -1)

    if cu_seqlens is not None and self.attn_implementation == "flash_attention_2":
        packed_metadata = self.build_packed_metadata(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            num_sequences=num_sequences,
        )

    if self.position_embedding_type == "absolute":
        hidden_states = hidden_states + self.wpe(position_ids)
    hidden_states = self.drop(hidden_states)
    # Zero out masked query positions before and after every block.
    hidden_states = hidden_states * query_keep.to(dtype=hidden_states.dtype)

    all_hidden_states = () if output_hidden_states else None
    presents = [] if use_cache else None
    for layer_index, block in enumerate(self.h):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                raise ValueError("use_cache=True is not supported when gradient checkpointing is enabled during training.")

            # Bind the current block as a default argument. The non-reentrant
            # checkpoint calls this function again during backward, after the
            # loop has moved on; a plain closure over ``block`` would then
            # recompute every layer with the *last* block.
            def custom_forward(*inputs, _block=block):
                output, _ = _block(
                    inputs[0],
                    attention_mask=inputs[1],
                    position_ids=inputs[2],
                    packed_metadata=packed_metadata,
                    layer_past=None,
                    use_cache=False,
                )
                return output

            hidden_states = torch.utils.checkpoint.checkpoint(
                custom_forward,
                hidden_states,
                attention_mask,
                position_ids,
                use_reentrant=False,
            )
            present = None
        else:
            hidden_states, present = block(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                packed_metadata=packed_metadata,
                layer_past=None if past_key_values is None else past_key_values[layer_index],
                use_cache=use_cache,
            )
        hidden_states = hidden_states * query_keep.to(dtype=hidden_states.dtype)
        if presents is not None:
            presents.append(present)

    hidden_states = self.ln_f(hidden_states)
    hidden_states = hidden_states * query_keep.to(dtype=hidden_states.dtype)
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    cache = tuple(presents) if presents is not None else None
    if not return_dict:
        return (hidden_states, cache, all_hidden_states, None)

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=cache,
        hidden_states=all_hidden_states,
        attentions=None,
    )
modeling_nanotts_global_local.py ADDED
@@ -0,0 +1,1757 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ import logging
6
+ import os
7
+ from contextlib import nullcontext
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Any, Optional, Union
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torchaudio
16
+ from transformers import AutoModel, AutoTokenizer
17
+ from transformers.modeling_outputs import ModelOutput
18
+ from transformers.modeling_utils import PreTrainedModel
19
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
20
+
21
+ from .configuration_nanotts import NanoTTSConfig
22
+ from .gpt2_decoder import NanoGPT2Block, NanoGPT2Model
23
+ from .prompting import (
24
+ build_assistant_prompt_prefix,
25
+ build_prompt_token_ids,
26
+ build_user_prompt_after_reference,
27
+ build_user_prompt_prefix,
28
+ )
29
+ from .tokenization_nanotts_sentencepiece import NanoTTSSentencePieceTokenizer
30
+
31
+
32
@dataclass
class NanoTTSOutput(ModelOutput):
    """Output container built by ``NanoTTSGlobalLocalForCausalLM.forward``."""

    # Filled from the global transformer's ``last_hidden_state``.
    global_hidden_states: Optional[torch.FloatTensor] = None
    # Cache returned by the global transformer (``None`` when not requested).
    past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None
    # Per-layer hidden states as returned by the global transformer, if any.
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    # Passed through from the global transformer's ``attentions`` field.
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
38
+
39
+
40
@dataclass
class NanoTTSGenerationOutput(ModelOutput):
    """Result container for audio-token generation."""

    # Generated audio token ids.
    # NOTE(review): presumably shaped [..., n_vq] - confirm against the generation code.
    audio_token_ids: torch.LongTensor
    # Optional prompt the generation was conditioned on.
    # NOTE(review): exact contents are set by the producer, which is not visible here.
    prompt_input_ids: Optional[torch.LongTensor] = None
44
+
45
+
46
+ MOSS_AUDIO_TOKENIZER_NANO_TYPE = "moss-audio-tokenizer-nano"
47
+ DEFAULT_MOSS_AUDIO_TOKENIZER_PRETRAINED_NAME_OR_PATH = "OpenMOSS-Team/MOSS-Audio-Tokenizer-Nano"
48
+ DEFAULT_VOICE_CLONE_MAX_TEXT_TOKENS = 50
49
+ DEFAULT_VOICE_CLONE_MAX_MEMORY_PER_SAMPLE_GB = 1.0
50
+ DEFAULT_VOICE_CLONE_INTER_CHUNK_PAUSE_SHORT_SECONDS = 0.40
51
+ DEFAULT_VOICE_CLONE_INTER_CHUNK_PAUSE_LONG_SECONDS = 0.24
52
+ _SENTENCE_END_PUNCTUATION = frozenset(".!?。!?;;")
53
+ _CLAUSE_SPLIT_PUNCTUATION = frozenset(",,、;;::")
54
+ _CLOSING_PUNCTUATION = frozenset("\"'”’)]})】》」』")
55
+
56
+
57
class NanoTTSPreTrainedModel(PreTrainedModel):
    """Base class binding ``NanoTTSConfig`` to the ``PreTrainedModel`` machinery."""

    # Configuration class associated with this model family.
    config_class = NanoTTSConfig
    # Attribute name under which subclasses store the main decoder.
    base_model_prefix = "transformer"
    # Gradient checkpointing is declared as unsupported for this package.
    supports_gradient_checkpointing = False
    # NOTE(review): presumably keeps each decoder block on a single device when
    # the model is sharded - confirm against the transformers documentation.
    _no_split_modules = ["NanoGPT2Block"]
    # Attention backends this model declares support for.
    _supports_flash_attn_2 = True
    _supports_sdpa = True
64
+
65
+
66
+ class NanoTTSGlobalLocalForCausalLM(NanoTTSPreTrainedModel):
67
+ _keys_to_ignore_on_load_unexpected = [r"local_transformer\.wte\.weight"]
68
+
69
def __init__(self, config: NanoTTSConfig) -> None:
    """Build the global decoder, audio embeddings/heads and local decoder.

    The GPT-2 sub-config stored on ``config`` is updated in place with the
    pad token id and attention implementation before the global decoder is
    created. The local decoder reuses that sub-config with its own layer
    count and a context of ``n_vq + 1`` positions, and has its token
    embedding replaced by ``nn.Identity``.
    """
    super().__init__(config)
    config.gpt2_config.pad_token_id = config.pad_token_id
    config.gpt2_config._attn_implementation = config.attn_implementation

    self.transformer = NanoGPT2Model(
        config.gpt2_config,
        attn_implementation=config.attn_implementation,
    )
    width = config.gpt2_config.hidden_size
    init_std = config.gpt2_config.initializer_range

    # One embedding table and one output head per audio codebook.
    self.audio_embeddings = nn.ModuleList(
        [nn.Embedding(int(config.audio_codebook_sizes[vq]), width) for vq in range(config.n_vq)]
    )
    self.text_lm_head = nn.Linear(width, config.gpt2_config.vocab_size, bias=False)
    self.audio_lm_heads = nn.ModuleList(
        [nn.Linear(width, int(config.audio_codebook_sizes[vq]), bias=False) for vq in range(config.n_vq)]
    )

    local_kwargs = config.gpt2_config.to_dict()
    local_context = config.n_vq + 1
    local_kwargs["n_layer"] = int(config.local_transformer_layers)
    local_kwargs["n_positions"] = local_context
    local_kwargs["n_ctx"] = local_context
    self.local_transformer = NanoGPT2Model(
        GPT2Config(**local_kwargs),
        attn_implementation=str(config.local_transformer_attn_implementation),
    )
    self.local_transformer.wte = nn.Identity()

    newly_created = [*self.audio_embeddings, self.text_lm_head, *self.audio_lm_heads]
    for module in newly_created:
        if getattr(module, "weight", None) is not None:
            nn.init.normal_(module.weight, mean=0.0, std=init_std)

    self._tied_weights_keys = tuple(self.all_tied_weights_keys.keys())
    self.tie_weights()
111
+
112
@property
def all_tied_weights_keys(self) -> dict[str, str]:
    """Map each tied output-head weight name to the embedding weight it shares.

    The text head comes first, followed by one entry per audio codebook in
    index order.
    """
    mapping = {"text_lm_head.weight": "transformer.wte.weight"}
    for vq in range(self.config.n_vq):
        mapping[f"audio_lm_heads.{vq}.weight"] = f"audio_embeddings.{vq}.weight"
    return mapping
122
+
123
def tie_weights(self, *args, **kwargs) -> None:
    """Share every output head's weight with its input embedding.

    Positional and keyword arguments are accepted and discarded so that the
    method can be called with whatever the framework passes.
    """
    del args, kwargs
    self.text_lm_head.weight = self.transformer.wte.weight
    for vq_head, vq_embedding in zip(self.audio_lm_heads, self.audio_embeddings):
        vq_head.weight = vq_embedding.weight
128
+
129
def get_input_embeddings(self) -> nn.Embedding:
    """Return the text token embedding of the global decoder."""
    return self.transformer.wte
131
+
132
def set_input_embeddings(self, value: nn.Embedding) -> None:
    """Replace the text token embedding and re-tie the output heads to it."""
    self.transformer.wte = value
    # The text head shares this weight, so the tie must be refreshed.
    self.tie_weights()
135
+
136
+ def _build_inputs_embeds(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
137
+ if input_ids.ndim != 3 or input_ids.shape[-1] != self.config.n_vq + 1:
138
+ raise ValueError(
139
+ f"Expected input_ids shape [batch, seq, {self.config.n_vq + 1}], got {tuple(input_ids.shape)}"
140
+ )
141
+
142
+ text_ids = input_ids[..., 0]
143
+ inputs_embeds = self.transformer.wte(text_ids)
144
+
145
+ for channel_index, embedding in enumerate(self.audio_embeddings):
146
+ channel_ids = input_ids[..., channel_index + 1]
147
+ valid_mask = channel_ids.ne(self.config.audio_pad_token_id)
148
+ invalid_mask = valid_mask & ((channel_ids < 0) | (channel_ids >= embedding.num_embeddings))
149
+ if invalid_mask.any():
150
+ invalid_token_ids = channel_ids[invalid_mask]
151
+ raise ValueError(
152
+ "Found out-of-range audio token ids for channel "
153
+ f"{channel_index}: min={int(invalid_token_ids.min().item())} "
154
+ f"max={int(invalid_token_ids.max().item())} "
155
+ f"codebook_size={embedding.num_embeddings} "
156
+ f"audio_pad_token_id={self.config.audio_pad_token_id}"
157
+ )
158
+ safe_ids = channel_ids.masked_fill(~valid_mask, 0)
159
+ audio_embeds = embedding(safe_ids)
160
+ audio_embeds = audio_embeds * valid_mask.unsqueeze(-1)
161
+ inputs_embeds = inputs_embeds + audio_embeds
162
+
163
+ return inputs_embeds
164
+
165
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
):
    """Run the global decoder and return its hidden states.

    Embeddings are built from ``input_ids`` unless ``inputs_embeds`` is
    given. Unknown keyword arguments are logged at debug level and dropped.

    Returns:
        A ``NanoTTSOutput``, or the 4-tuple ``(last_hidden_state,
        past_key_values, hidden_states, attentions)`` when ``return_dict``
        resolves to false.

    Raises:
        NotImplementedError: If ``labels`` is passed (inference-only package).
        ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
    """
    if kwargs.pop("labels", None) is not None:
        raise NotImplementedError("This open-source package is inference-only and does not support training forward.")
    if kwargs:
        logging.debug("ignoring unsupported forward kwargs: %s", ", ".join(sorted(kwargs.keys())))

    if return_dict is None:
        return_dict = self.config.use_return_dict

    if inputs_embeds is None:
        if input_ids is None:
            raise ValueError("Either input_ids or inputs_embeds must be provided.")
        inputs_embeds = self._build_inputs_embeds(input_ids)

    decoder_out = self.transformer(
        input_ids=None,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        position_ids=None,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
        cu_seqlens=None,
        num_sequences=None,
    )

    if return_dict:
        return NanoTTSOutput(
            global_hidden_states=decoder_out.last_hidden_state,
            past_key_values=decoder_out.past_key_values,
            hidden_states=decoder_out.hidden_states,
            attentions=decoder_out.attentions,
        )

    return (
        decoder_out.last_hidden_state,
        decoder_out.past_key_values,
        decoder_out.hidden_states,
        decoder_out.attentions,
    )
218
+
219
+ def _build_text_rows(
220
+ self,
221
+ token_ids: list[int],
222
+ device: torch.device,
223
+ ) -> torch.LongTensor:
224
+ rows = torch.full(
225
+ (len(token_ids), self.config.n_vq + 1),
226
+ self.config.audio_pad_token_id,
227
+ dtype=torch.long,
228
+ device=device,
229
+ )
230
+ if token_ids:
231
+ rows[:, 0] = torch.tensor(token_ids, dtype=torch.long, device=device)
232
+ return rows
233
+
234
+ def _encode_text(self, tokenizer, text: str) -> list[int]:
235
+ try:
236
+ return list(tokenizer.encode(text, add_special_tokens=False))
237
+ except TypeError:
238
+ return list(tokenizer.encode(text))
239
+
240
+ @staticmethod
241
+ def _contains_cjk(text: str) -> bool:
242
+ return any(
243
+ "\u4e00" <= ch <= "\u9fff"
244
+ or "\u3400" <= ch <= "\u4dbf"
245
+ or "\u3040" <= ch <= "\u30ff"
246
+ or "\uac00" <= ch <= "\ud7af"
247
+ for ch in str(text)
248
+ )
249
+
250
+ @staticmethod
251
+ def _prepare_text_for_sentence_chunking(text: str) -> str:
252
+ normalized_text = str(text).strip()
253
+ if normalized_text == "":
254
+ raise ValueError("Text prompt cannot be empty.")
255
+
256
+ normalized_text = normalized_text.replace("\n", " ").replace("\r", " ")
257
+ while " " in normalized_text:
258
+ normalized_text = normalized_text.replace(" ", " ")
259
+
260
+ contains_cjk = NanoTTSGlobalLocalForCausalLM._contains_cjk(normalized_text)
261
+ if contains_cjk:
262
+ if normalized_text[-1] not in _SENTENCE_END_PUNCTUATION:
263
+ normalized_text = normalized_text + "。"
264
+ return normalized_text
265
+
266
+ if not normalized_text[0].isupper():
267
+ normalized_text = normalized_text[0].upper() + normalized_text[1:]
268
+ if normalized_text[-1].isalnum():
269
+ normalized_text = normalized_text + "."
270
+ if len(normalized_text.split()) < 5:
271
+ normalized_text = " " * 8 + normalized_text
272
+ return normalized_text
273
+
274
+ @staticmethod
275
+ def _split_text_by_punctuation(text: str, punctuation: set[str] | frozenset[str]) -> list[str]:
276
+ sentences: list[str] = []
277
+ current_chars: list[str] = []
278
+ text = str(text)
279
+ index = 0
280
+ while index < len(text):
281
+ char = text[index]
282
+ current_chars.append(char)
283
+ if char in punctuation:
284
+ lookahead = index + 1
285
+ while lookahead < len(text) and text[lookahead] in _CLOSING_PUNCTUATION:
286
+ current_chars.append(text[lookahead])
287
+ lookahead += 1
288
+ sentence = "".join(current_chars).strip()
289
+ if sentence:
290
+ sentences.append(sentence)
291
+ current_chars = []
292
+ while lookahead < len(text) and text[lookahead].isspace():
293
+ lookahead += 1
294
+ index = lookahead
295
+ continue
296
+ index += 1
297
+
298
+ tail = "".join(current_chars).strip()
299
+ if tail:
300
+ sentences.append(tail)
301
+ return sentences
302
+
303
def _count_text_tokens(self, text_tokenizer, text: str) -> int:
    """Return how many token ids ``_encode_text`` produces for ``text``."""
    return len(self._encode_text(text_tokenizer, text))
305
+
306
+ def _split_text_by_token_budget(
307
+ self,
308
+ text_tokenizer,
309
+ text: str,
310
+ max_tokens: int,
311
+ ) -> list[str]:
312
+ remaining_text = str(text).strip()
313
+ if remaining_text == "":
314
+ return []
315
+
316
+ pieces: list[str] = []
317
+ preferred_boundary_chars = _CLAUSE_SPLIT_PUNCTUATION | _SENTENCE_END_PUNCTUATION | frozenset({" "})
318
+ while remaining_text:
319
+ if self._count_text_tokens(text_tokenizer, remaining_text) <= int(max_tokens):
320
+ pieces.append(remaining_text)
321
+ break
322
+
323
+ low = 1
324
+ high = len(remaining_text)
325
+ best_prefix_length = 1
326
+ while low <= high:
327
+ middle = (low + high) // 2
328
+ candidate = remaining_text[:middle].strip()
329
+ if not candidate:
330
+ low = middle + 1
331
+ continue
332
+ if self._count_text_tokens(text_tokenizer, candidate) <= int(max_tokens):
333
+ best_prefix_length = middle
334
+ low = middle + 1
335
+ else:
336
+ high = middle - 1
337
+
338
+ cut_index = best_prefix_length
339
+ prefix = remaining_text[:best_prefix_length]
340
+ preferred_index = -1
341
+ for scan_index in range(len(prefix) - 1, max(-1, len(prefix) - 25), -1):
342
+ if prefix[scan_index] in preferred_boundary_chars:
343
+ preferred_index = scan_index + 1
344
+ break
345
+ if preferred_index > 0:
346
+ cut_index = preferred_index
347
+
348
+ piece = remaining_text[:cut_index].strip()
349
+ if not piece:
350
+ piece = remaining_text[:best_prefix_length].strip()
351
+ cut_index = best_prefix_length
352
+ pieces.append(piece)
353
+ remaining_text = remaining_text[cut_index:].strip()
354
+ return pieces
355
+
356
+ @staticmethod
357
+ def _join_sentence_parts(left: str, right: str) -> str:
358
+ if not left:
359
+ return right
360
+ if not right:
361
+ return left
362
+ if NanoTTSGlobalLocalForCausalLM._contains_cjk(left) or NanoTTSGlobalLocalForCausalLM._contains_cjk(right):
363
+ return left + right
364
+ return left + " " + right
365
+
366
+ def _split_text_into_best_sentences(
367
+ self,
368
+ text_tokenizer,
369
+ text: str,
370
+ max_tokens: int,
371
+ ) -> list[str]:
372
+ if int(max_tokens) <= 0:
373
+ return [str(text)]
374
+
375
+ prepared_text = self._prepare_text_for_sentence_chunking(text)
376
+ sentence_candidates = self._split_text_by_punctuation(prepared_text, punctuation=_SENTENCE_END_PUNCTUATION)
377
+ if not sentence_candidates:
378
+ sentence_candidates = [prepared_text.strip()]
379
+
380
+ sentence_slices: list[tuple[int, str]] = []
381
+ for sentence_text in sentence_candidates:
382
+ normalized_sentence = sentence_text.strip()
383
+ if not normalized_sentence:
384
+ continue
385
+ sentence_token_count = self._count_text_tokens(text_tokenizer, normalized_sentence)
386
+ if sentence_token_count <= int(max_tokens):
387
+ sentence_slices.append((sentence_token_count, normalized_sentence))
388
+ continue
389
+
390
+ clause_candidates = self._split_text_by_punctuation(
391
+ normalized_sentence,
392
+ punctuation=_CLAUSE_SPLIT_PUNCTUATION,
393
+ )
394
+ if len(clause_candidates) <= 1:
395
+ clause_candidates = [normalized_sentence]
396
+
397
+ for clause_text in clause_candidates:
398
+ normalized_clause = clause_text.strip()
399
+ if not normalized_clause:
400
+ continue
401
+ clause_token_count = self._count_text_tokens(text_tokenizer, normalized_clause)
402
+ if clause_token_count <= int(max_tokens):
403
+ sentence_slices.append((clause_token_count, normalized_clause))
404
+ continue
405
+ for piece in self._split_text_by_token_budget(
406
+ text_tokenizer=text_tokenizer,
407
+ text=normalized_clause,
408
+ max_tokens=max_tokens,
409
+ ):
410
+ normalized_piece = piece.strip()
411
+ if normalized_piece:
412
+ sentence_slices.append(
413
+ (self._count_text_tokens(text_tokenizer, normalized_piece), normalized_piece)
414
+ )
415
+
416
+ chunks: list[str] = []
417
+ current_chunk = ""
418
+ current_chunk_token_count = 0
419
+ for sentence_token_count, sentence_text in sentence_slices:
420
+ if current_chunk == "":
421
+ current_chunk = sentence_text
422
+ current_chunk_token_count = sentence_token_count
423
+ continue
424
+ if current_chunk_token_count + sentence_token_count > int(max_tokens):
425
+ chunks.append(current_chunk.strip())
426
+ current_chunk = sentence_text
427
+ current_chunk_token_count = sentence_token_count
428
+ else:
429
+ current_chunk = self._join_sentence_parts(current_chunk, sentence_text)
430
+ current_chunk_token_count = self._count_text_tokens(text_tokenizer, current_chunk)
431
+
432
+ if current_chunk:
433
+ chunks.append(current_chunk.strip())
434
+ return chunks or [prepared_text.strip()]
435
+
436
@staticmethod
def _estimate_voice_clone_inter_chunk_pause_seconds(text_chunk: str) -> float:
    """Pick the pause to insert after ``text_chunk``.

    Chunks of at most four whitespace-separated words get the "short"
    constant; all others get the "long" constant.
    """
    word_total = len(str(text_chunk).strip().split())
    if word_total <= 4:
        return DEFAULT_VOICE_CLONE_INTER_CHUNK_PAUSE_SHORT_SECONDS
    return DEFAULT_VOICE_CLONE_INTER_CHUNK_PAUSE_LONG_SECONDS
443
+
444
+ def _concat_voice_clone_waveform_chunks(
445
+ self,
446
+ waveform_chunks: list[torch.FloatTensor],
447
+ text_chunks: list[str],
448
+ sample_rate: int,
449
+ ) -> torch.FloatTensor:
450
+ if not waveform_chunks:
451
+ return torch.zeros((1, 0), dtype=torch.float32)
452
+ if len(waveform_chunks) != len(text_chunks):
453
+ raise ValueError("waveform_chunks and text_chunks must have the same length.")
454
+ if len(waveform_chunks) == 1:
455
+ return waveform_chunks[0]
456
+
457
+ segments: list[torch.FloatTensor] = []
458
+ for chunk_index, waveform_chunk in enumerate(waveform_chunks):
459
+ segments.append(waveform_chunk)
460
+ if chunk_index >= len(waveform_chunks) - 1:
461
+ continue
462
+ pause_seconds = self._estimate_voice_clone_inter_chunk_pause_seconds(text_chunks[chunk_index])
463
+ pause_samples = max(0, int(round(float(sample_rate) * pause_seconds)))
464
+ if pause_samples > 0:
465
+ silence = torch.zeros((waveform_chunk.shape[0], pause_samples), dtype=waveform_chunk.dtype)
466
+ segments.append(silence)
467
+ return torch.cat(segments, dim=-1)
468
+
469
+ @staticmethod
470
+ def _resolve_inference_mode(
471
+ mode: str,
472
+ has_prompt_text: bool,
473
+ has_prompt_audio: bool,
474
+ ) -> str:
475
+ normalized_mode = str(mode or "continuation").strip().lower() or "continuation"
476
+ if normalized_mode not in {"continuation", "voice_clone"}:
477
+ raise ValueError(f"Unsupported inference mode {mode!r}.")
478
+ if normalized_mode == "voice_clone":
479
+ if not has_prompt_audio:
480
+ raise ValueError("voice_clone mode requires prompt_audio_path.")
481
+ if has_prompt_text:
482
+ raise ValueError("voice_clone mode does not accept prompt_text.")
483
+ elif has_prompt_text != has_prompt_audio:
484
+ raise ValueError(
485
+ "continuation mode accepts either target text only, or prompt_text and prompt_audio_path together."
486
+ )
487
+ return normalized_mode
488
+
489
+ def _resolve_inference_nq(self, nq: Optional[int] = None) -> int:
490
+ if nq is None:
491
+ return int(self.config.n_vq)
492
+ resolved_nq = int(nq)
493
+ if resolved_nq < 1 or resolved_nq > int(self.config.n_vq):
494
+ raise ValueError(f"nq must be in [1, {self.config.n_vq}], got {resolved_nq}.")
495
+ return resolved_nq
496
+
497
+ def _mask_unused_audio_channels(
498
+ self,
499
+ audio_token_ids: torch.LongTensor,
500
+ nq: int,
501
+ ) -> torch.LongTensor:
502
+ tensor = torch.as_tensor(audio_token_ids, dtype=torch.long)
503
+ if tensor.shape[-1] != self.config.n_vq:
504
+ raise ValueError(
505
+ f"Expected audio token ids with trailing dim {self.config.n_vq}, got {tuple(tensor.shape)}"
506
+ )
507
+ if nq < self.config.n_vq:
508
+ tensor = tensor.clone()
509
+ tensor[..., nq:] = self.config.audio_pad_token_id
510
+ return tensor
511
+
512
+ def _build_audio_prefix_rows(
513
+ self,
514
+ prompt_audio_codes: torch.LongTensor,
515
+ slot_token_id: int,
516
+ device: torch.device,
517
+ ) -> torch.LongTensor:
518
+ rows = torch.full(
519
+ (int(prompt_audio_codes.shape[0]), self.config.n_vq + 1),
520
+ self.config.audio_pad_token_id,
521
+ dtype=torch.long,
522
+ device=device,
523
+ )
524
+ if rows.shape[0] > 0:
525
+ rows[:, 0] = int(slot_token_id)
526
+ rows[:, 1:] = prompt_audio_codes
527
+ return rows
528
+
529
def build_inference_input_ids(
    self,
    text: str,
    text_tokenizer,
    mode: str = "continuation",
    prompt_text: Optional[str] = None,
    prompt_audio_codes: Optional[torch.LongTensor] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> tuple[torch.LongTensor, torch.BoolTensor]:
    """Assemble the prompt rows for a single inference request.

    In ``voice_clone`` mode the layout is: user prefix + audio start,
    reference audio rows (user slot), then audio end + post-reference user
    prompt + target text + assistant prefix + audio start.
    In ``continuation`` mode the layout is: prompt built from the
    (prompt + target) text, an audio start row, and, when prompt audio is
    given, its rows using the assistant slot.

    Returns:
        ``(input_ids, attention_mask)`` where ``input_ids`` has a leading
        batch dimension of 1 and the mask is all True.

    Raises:
        ValueError: Propagated from ``_resolve_inference_mode`` for invalid
            mode / prompt combinations.
    """
    target_device = self._resolve_device(device)
    inference_mode = self._resolve_inference_mode(
        mode=mode,
        has_prompt_text=prompt_text is not None,
        has_prompt_audio=prompt_audio_codes is not None,
    )
    cfg = self.config

    if inference_mode == "voice_clone":
        # Guaranteed by the mode check above; kept for type narrowing.
        assert prompt_audio_codes is not None
        target_token_ids = self._encode_text(text_tokenizer, text)
        head_token_ids = build_user_prompt_prefix(text_tokenizer, cfg) + [cfg.audio_start_token_id]
        tail_token_ids = (
            [cfg.audio_end_token_id]
            + build_user_prompt_after_reference(text_tokenizer)
            + target_token_ids
            + build_assistant_prompt_prefix(text_tokenizer, cfg)
            + [cfg.audio_start_token_id]
        )
        row_blocks = [
            self._build_text_rows(head_token_ids, device=target_device),
            self._build_audio_prefix_rows(
                prompt_audio_codes=prompt_audio_codes.to(target_device),
                slot_token_id=cfg.audio_user_slot_token_id,
                device=target_device,
            ),
            self._build_text_rows(tail_token_ids, device=target_device),
        ]
    else:
        full_text = text if prompt_text is None else prompt_text + text
        prompt_token_ids = build_prompt_token_ids(
            tokenizer=text_tokenizer,
            config=cfg,
            text_token_ids=self._encode_text(text_tokenizer, full_text),
        )
        row_blocks = [
            self._build_text_rows(prompt_token_ids, device=target_device),
            self._build_text_rows([cfg.audio_start_token_id], device=target_device),
        ]
        if prompt_audio_codes is not None:
            row_blocks.append(
                self._build_audio_prefix_rows(
                    prompt_audio_codes=prompt_audio_codes.to(target_device),
                    slot_token_id=cfg.audio_assistant_slot_token_id,
                    device=target_device,
                )
            )

    input_ids = torch.cat(row_blocks, dim=0).unsqueeze(0)
    attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.bool, device=target_device)
    return input_ids, attention_mask
590
+
591
def _left_pad_inference_batch(
    self,
    input_id_batches: list[torch.LongTensor],
    attention_mask_batches: list[torch.BoolTensor],
    device: torch.device,
) -> tuple[torch.LongTensor, torch.BoolTensor]:
    """Left-pad single-sample prompts into one batch.

    Each entry of ``input_id_batches`` is squeezed on dim 0 and copied into
    the right-hand end of its batch row; the matching attention mask is
    copied the same way. Padding rows carry ``config.pad_token_id`` in
    column 0 and ``config.audio_pad_token_id`` in the remaining columns,
    with attention mask ``False``.

    Args:
        input_id_batches: Per-sample id tensors; ``shape[1]`` is read as the
            sequence length.
        attention_mask_batches: Per-sample masks, one per id tensor.
        device: Device on which the padded tensors are allocated.

    Returns:
        ``(padded_input_ids, padded_attention_mask)`` of shapes
        ``(batch, max_len, n_vq + 1)`` and ``(batch, max_len)``.

    Raises:
        ValueError: If the input list is empty or the two lists differ in
            length.
    """
    if not input_id_batches:
        raise ValueError("input_id_batches must not be empty.")
    if len(input_id_batches) != len(attention_mask_batches):
        raise ValueError("input_id_batches and attention_mask_batches must have the same length.")

    batch_size = len(input_id_batches)
    max_seq_len = max(int(batch.shape[1]) for batch in input_id_batches)
    row_width = self.config.n_vq + 1

    padded_input_ids = torch.full(
        (batch_size, max_seq_len, row_width),
        self.config.audio_pad_token_id,
        dtype=torch.long,
        device=device,
    )
    padded_input_ids[:, :, 0] = self.config.pad_token_id
    padded_attention_mask = torch.zeros((batch_size, max_seq_len), dtype=torch.bool, device=device)

    for batch_index, (input_ids, attention_mask) in enumerate(zip(input_id_batches, attention_mask_batches)):
        normalized_input_ids = input_ids.squeeze(0).to(device=device, dtype=torch.long)
        normalized_attention_mask = attention_mask.squeeze(0).to(device=device, dtype=torch.bool)
        seq_len = int(normalized_input_ids.shape[0])
        # Use an explicit start offset: ``-seq_len:`` degenerates to ``-0:``
        # (the whole row) when a sample is empty, which breaks the copy.
        start = max_seq_len - seq_len
        padded_input_ids[batch_index, start:, :] = normalized_input_ids
        padded_attention_mask[batch_index, start:] = normalized_attention_mask

    return padded_input_ids, padded_attention_mask
623
+
624
def _trim_generated_audio_token_ids(
    self,
    audio_token_ids: torch.LongTensor,
    effective_nq: int,
) -> torch.LongTensor:
    """Drop trailing rows whose first ``effective_nq`` columns are all padding.

    The input is first passed through ``_mask_unused_audio_channels``. The
    result is cut just after the last row that has at least one value other
    than ``config.audio_pad_token_id`` in its first ``effective_nq`` columns;
    an empty slice is returned when no such row exists.

    Raises:
        ValueError: If the masked tensor is not 2-D.
    """
    masked = self._mask_unused_audio_channels(audio_token_ids, nq=effective_nq)
    if masked.ndim != 2:
        raise ValueError(f"Expected a 2D audio token tensor, got {tuple(masked.shape)}")

    pad_id = self.config.audio_pad_token_id
    row_has_audio = masked[:, :effective_nq].ne(pad_id).any(dim=-1)
    audio_row_indices = torch.nonzero(row_has_audio, as_tuple=False)
    if audio_row_indices.numel() == 0:
        return masked[:0]

    keep_rows = int(audio_row_indices[-1].item()) + 1
    return masked[:keep_rows]
637
+
638
def _resolve_voice_clone_chunk_batch_size(
    self,
    *,
    resolved_device: torch.device,
    chunk_count: int,
    max_memory_per_sample_gb: float,
) -> int:
    """Pick how many chunks to batch together from free CUDA memory.

    Returns 1 when there is at most one chunk, the per-sample budget is not
    positive, the device is not CUDA, or free memory cannot be queried.
    Otherwise 90% of the reported free bytes is divided by the per-sample
    budget and the result is clamped to ``[1, chunk_count]``.
    """
    if chunk_count <= 1 or max_memory_per_sample_gb <= 0 or resolved_device.type != "cuda":
        return 1
    if not hasattr(torch.cuda, "mem_get_info"):
        return 1

    try:
        free_bytes, _ = torch.cuda.mem_get_info(resolved_device)
    except Exception:
        return 1

    gib = 1024**3
    per_sample_bytes = int(float(max_memory_per_sample_gb) * gib)
    if per_sample_bytes <= 0:
        return 1

    # Leave 10% of the reported free memory as headroom.
    budget_bytes = max(0, int(free_bytes * 0.9))
    affordable = max(1, budget_bytes // per_sample_bytes)
    resolved_batch_size = max(1, min(int(chunk_count), int(affordable)))
    logging.info(
        "voice_clone chunk batching device=%s free_gb=%.2f max_memory_per_sample_gb=%.2f resolved_batch_size=%d chunk_count=%d",
        resolved_device,
        float(free_bytes) / float(gib),
        float(max_memory_per_sample_gb),
        resolved_batch_size,
        int(chunk_count),
    )
    return resolved_batch_size
668
+
669
def _generate_audio_token_ids_with_fallback(
    self,
    *,
    prompt_input_ids: torch.LongTensor,
    attention_mask: torch.BoolTensor,
    effective_nq: int,
    max_new_frames: int,
    do_sample: bool,
    text_temperature: float,
    text_top_p: float,
    text_top_k: int,
    audio_temperature: float,
    audio_top_p: float,
    audio_top_k: int,
    audio_repetition_penalty: float,
    use_kv_cache: bool,
    resolved_device: torch.device,
) -> torch.LongTensor:
    """Run ``generate`` and retry once after a stability fallback.

    If the first attempt raises a ``RuntimeError`` or ``ValueError`` that
    ``_is_generation_stability_error`` recognises, the model is switched via
    ``_apply_inference_stability_fallback(resolved_device)`` and generation
    is repeated with identical arguments. Any other error propagates
    unchanged, as does an error from the retry.

    Returns:
        ``generation.audio_token_ids`` passed through
        ``_mask_unused_audio_channels`` with ``nq=effective_nq``.
    """
    # Built once so the first attempt and the retry cannot drift apart.
    generate_kwargs = {
        "input_ids": prompt_input_ids,
        "attention_mask": attention_mask,
        "nq": effective_nq,
        "max_new_frames": max_new_frames,
        "do_sample": do_sample,
        "text_temperature": text_temperature,
        "text_top_p": text_top_p,
        "text_top_k": text_top_k,
        "audio_temperature": audio_temperature,
        "audio_top_p": audio_top_p,
        "audio_top_k": audio_top_k,
        "audio_repetition_penalty": audio_repetition_penalty,
        "use_kv_cache": use_kv_cache,
        "return_dict_in_generate": True,
    }
    try:
        generation = self.generate(**generate_kwargs)
    except (RuntimeError, ValueError) as exc:
        if not self._is_generation_stability_error(exc):
            raise
        self._apply_inference_stability_fallback(resolved_device)
        generation = self.generate(**generate_kwargs)
    return self._mask_unused_audio_channels(generation.audio_token_ids, nq=effective_nq)
725
+
726
def _decode_audio_token_ids_to_waveform(
    self,
    *,
    audio_tokenizer,
    audio_token_ids: torch.LongTensor,
    target_sample_rate: int,
    effective_nq: int,
    resolved_device: torch.device,
) -> tuple[torch.FloatTensor, int]:
    """Decode audio token ids and return ``(waveform, sample_rate)``.

    The ids are moved to ``resolved_device`` and decoded through
    ``_call_audio_decode``. The decoder output is then unpacked by
    ``_extract_waveform_and_sample_rate``, with ``target_sample_rate`` as
    the fallback rate.
    """
    ids_on_device = audio_token_ids.to(resolved_device)
    decode_output = self._call_audio_decode(
        audio_tokenizer=audio_tokenizer,
        audio_token_ids=ids_on_device,
        sample_rate=target_sample_rate,
        nq=effective_nq,
    )
    return self._extract_waveform_and_sample_rate(
        decode_output,
        fallback_sample_rate=target_sample_rate,
    )
742
+
743
def _build_generation_row(
    self,
    batch_size: int,
    device: torch.device,
    audio_token_ids: torch.LongTensor,
) -> torch.LongTensor:
    """Build one ``(batch, 1, n_vq + 1)`` input row from audio token ids.

    Column 0 holds ``config.audio_assistant_slot_token_id``; the remaining
    columns are filled from ``audio_token_ids`` after a length-1 axis is
    inserted at dim 1.
    """
    row_shape = (batch_size, 1, self.config.n_vq + 1)
    row = torch.full(row_shape, self.config.audio_pad_token_id, dtype=torch.long, device=device)
    row[..., 0] = self.config.audio_assistant_slot_token_id
    row[..., 1:] = audio_token_ids.unsqueeze(1)
    return row
758
+
759
def _sample_next_token(
    self,
    logits: torch.FloatTensor,
    do_sample: bool,
    temperature: float,
    top_k: Optional[int],
    top_p: Optional[float],
    previous_token_ids: Optional[torch.LongTensor] = None,
    repetition_penalty: float = 1.0,
) -> torch.LongTensor:
    """Pick the next token id from ``logits`` along the last dimension.

    The repetition penalty is applied first. With ``do_sample=False`` the
    argmax is returned. Otherwise scores are divided by ``temperature``,
    filtered by top-k (when ``top_k > 0``) and by nucleus top-p (when
    ``0 < top_p < 1``), and one id is drawn with ``torch.multinomial``.

    Raises:
        ValueError: If sampling is requested with ``temperature <= 0``.
    """
    penalized = self._apply_repetition_penalty(
        logits=logits,
        previous_token_ids=previous_token_ids,
        repetition_penalty=repetition_penalty,
    )
    if not do_sample:
        return penalized.argmax(dim=-1)
    if temperature <= 0:
        raise ValueError("temperature must be positive when do_sample=True")

    neg_inf = float("-inf")
    scaled = penalized / temperature

    if top_k is not None and top_k > 0:
        k = min(top_k, scaled.shape[-1])
        # Everything strictly below the k-th best score is discarded.
        kth_best = torch.topk(scaled, k, dim=-1).values[..., -1, None]
        scaled = scaled.masked_fill(scaled < kth_best, neg_inf)

    if top_p is not None and 0.0 < top_p < 1.0:
        ordered_scores, order = torch.sort(scaled, descending=True, dim=-1)
        cumulative = torch.cumsum(torch.softmax(ordered_scores, dim=-1), dim=-1)
        drop = cumulative > top_p
        # Shift right by one so the token that crosses the threshold is
        # kept, and never drop the best-scoring token.
        drop[..., 1:] = drop[..., :-1].clone()
        drop[..., 0] = False
        ordered_scores = ordered_scores.masked_fill(drop, neg_inf)
        scaled = torch.full_like(scaled, neg_inf)
        scaled.scatter_(dim=-1, index=order, src=ordered_scores)

    probabilities = torch.softmax(scaled, dim=-1)
    return torch.multinomial(probabilities, num_samples=1).squeeze(-1)
798
+
799
@staticmethod
def _ensure_finite_generation_logits(logits: torch.FloatTensor, name: str) -> None:
    """Raise ``RuntimeError`` if ``logits`` contains any NaN or infinity.

    The error message reports dtype, shape, the finite-element count and the
    min/max of the finite values (``nan`` when there are none).
    """
    finite_mask = torch.isfinite(logits)
    if finite_mask.all():
        return

    finite_values = logits[finite_mask]
    if finite_values.numel() > 0:
        min_value = float(finite_values.min().item())
        max_value = float(finite_values.max().item())
    else:
        min_value = float("nan")
        max_value = float("nan")
    raise RuntimeError(
        f"Non-finite {name} during generation: dtype={logits.dtype} "
        f"shape={tuple(logits.shape)} finite={int(finite_mask.sum().item())}/{int(logits.numel())} "
        f"min={min_value} max={max_value}"
    )
812
+
813
def _apply_repetition_penalty(
    self,
    logits: torch.FloatTensor,
    previous_token_ids: Optional[torch.LongTensor],
    repetition_penalty: float,
) -> torch.FloatTensor:
    """Penalise the logits of token ids that already appear in the history.

    For every distinct in-vocabulary id in a row's history, a negative logit
    is multiplied by ``repetition_penalty`` and a non-negative logit is
    divided by it. ``logits`` is returned untouched (same object) when the
    penalty is 1.0 or there is no history; otherwise a modified clone is
    returned.

    Raises:
        ValueError: If ``repetition_penalty`` is not positive.
    """
    if repetition_penalty <= 0:
        raise ValueError("repetition_penalty must be positive")
    if previous_token_ids is None or repetition_penalty == 1.0:
        return logits

    history = torch.as_tensor(previous_token_ids, device=logits.device, dtype=torch.long)
    if history.ndim == 1:
        history = history.unsqueeze(0)
    elif history.ndim > 2:
        # Flatten everything after the leading dimension into one history axis.
        history = history.reshape(history.shape[0], -1)

    penalized = logits.clone()
    vocab_size = penalized.shape[-1]
    for row_index in range(penalized.shape[0]):
        row_history = history[row_index]
        in_vocab = row_history[(row_history >= 0) & (row_history < vocab_size)]
        if in_vocab.numel() == 0:
            continue
        seen_ids = torch.unique(in_vocab)
        row_scores = penalized[row_index]  # view: the scatter_ below writes into `penalized`
        seen_scores = row_scores.index_select(0, seen_ids)
        adjusted = torch.where(
            seen_scores < 0,
            seen_scores * repetition_penalty,
            seen_scores / repetition_penalty,
        )
        row_scores.scatter_(0, seen_ids, adjusted)
    return penalized
846
+
847
def _sample_next_assistant_text_token(
    self,
    logits: torch.FloatTensor,
    do_sample: bool,
    temperature: float,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
) -> torch.LongTensor:
    """Choose between the assistant-slot token and the audio-end token.

    Only the logits of ``config.audio_assistant_slot_token_id`` and
    ``config.audio_end_token_id`` are passed to ``_sample_next_token``; the
    sampled position is mapped back to the corresponding token id.
    """
    allowed_ids = torch.tensor(
        [self.config.audio_assistant_slot_token_id, self.config.audio_end_token_id],
        dtype=torch.long,
        device=logits.device,
    )
    allowed_logits = logits.index_select(dim=-1, index=allowed_ids)
    choice = self._sample_next_token(
        logits=allowed_logits,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    return allowed_ids[choice]
872
+
873
def _resolve_device(self, device: Optional[Union[str, torch.device]] = None) -> torch.device:
    """Return ``device`` as a ``torch.device``, or the first parameter's device when ``device`` is None."""
    if device is None:
        return next(self.parameters()).device
    return torch.device(device)
875
+
876
@staticmethod
def _looks_like_hf_tokenizer_dir(candidate_path: Path) -> bool:
    """Report whether ``candidate_path`` is a directory holding tokenizer files.

    True when the directory contains ``tokenizer.model`` or
    ``tokenizer.json``, or ``tokenizer_config.json`` together with at least
    one of ``vocab.json``, ``merges.txt`` or ``special_tokens_map.json``.
    """
    if not candidate_path.is_dir():
        return False

    def has_file(file_name: str) -> bool:
        return (candidate_path / file_name).is_file()

    if has_file("tokenizer.model") or has_file("tokenizer.json"):
        return True
    companion_files = ("vocab.json", "merges.txt", "special_tokens_map.json")
    return has_file("tokenizer_config.json") and any(has_file(name) for name in companion_files)
891
+
892
def _resolve_text_tokenizer_path(self, raw_path: Union[str, Path]) -> Path:
    """Map a user-supplied path to a tokenizer directory or ``.model`` file.

    Accepted, in order: a ``.model`` file; a directory that holds
    ``tokenizer.model`` or looks like a tokenizer directory; its
    ``hf_tokenizer`` subdirectory; ``sentencepiece/nanotts_spm_bpe.model``
    inside it; or the ``latest_hf_tokenizer_dir`` named in its
    ``final_summary.json``.

    Raises:
        FileNotFoundError: If the path does not exist.
        ValueError: If nothing usable is found at the path.
    """
    root = Path(raw_path)
    if root.is_file() and root.suffix == ".model":
        return root
    if not root.exists():
        raise FileNotFoundError(f"Tokenizer path does not exist: {root}")

    if root.is_dir():
        if (root / "tokenizer.model").is_file() or self._looks_like_hf_tokenizer_dir(root):
            return root

        nested_hf_dir = root / "hf_tokenizer"
        if self._looks_like_hf_tokenizer_dir(nested_hf_dir):
            return nested_hf_dir

        spm_model = root / "sentencepiece" / "nanotts_spm_bpe.model"
        if spm_model.is_file():
            return spm_model

        summary_file = root / "final_summary.json"
        if summary_file.is_file():
            summary = json.loads(summary_file.read_text(encoding="utf-8"))
            recorded_dir = summary.get("latest_hf_tokenizer_dir")
            if recorded_dir:
                recorded_path = Path(str(recorded_dir))
                if self._looks_like_hf_tokenizer_dir(recorded_path):
                    return recorded_path

    raise ValueError(
        "Could not resolve a tokenizer from the provided path. Expected a tokenizer dir, experiment dir, or SentencePiece .model file."
    )
920
+
921
def _load_resolved_text_tokenizer(self, resolved_path: Path, cache_dir: str):
    """Instantiate a text tokenizer from an already-resolved path.

    A ``.model`` file is loaded directly as a
    ``NanoTTSSentencePieceTokenizer``. Anything else goes through
    ``AutoTokenizer.from_pretrained`` (local files only). If that fails and
    the path contains ``tokenizer.model``, the SentencePiece tokenizer is
    used instead; otherwise the original error is re-raised.
    """
    points_at_spm_file = resolved_path.is_file() and resolved_path.suffix == ".model"
    if points_at_spm_file:
        return NanoTTSSentencePieceTokenizer(vocab_file=str(resolved_path))

    try:
        return AutoTokenizer.from_pretrained(
            str(resolved_path),
            trust_remote_code=True,
            use_fast=bool(self.config.tokenizer_use_fast),
            local_files_only=True,
            cache_dir=cache_dir,
        )
    except Exception:
        # Best-effort fallback: only swallow the failure when a bundled
        # SentencePiece model is available to load instead.
        bundled_model = resolved_path / "tokenizer.model"
        if not bundled_model.is_file():
            raise
        return NanoTTSSentencePieceTokenizer(vocab_file=str(bundled_model))
937
+
938
@staticmethod
def _resolve_hf_cache_dir() -> str:
    """Create (if needed) and return ``.cache/huggingface`` next to this module file."""
    module_dir = Path(__file__).resolve().parent
    hf_cache = module_dir / ".cache" / "huggingface"
    hf_cache.mkdir(parents=True, exist_ok=True)
    return str(hf_cache)
943
+
944
@staticmethod
def _patch_hf_dynamic_module_cache_dir(cache_dir: str) -> None:
    """Point the transformers dynamic-module cache at ``<cache_dir>/modules``.

    Creates the directory, then sets both the ``HF_MODULES_CACHE``
    environment variable and ``transformers.dynamic_module_utils.HF_MODULES_CACHE``.
    """
    import transformers.dynamic_module_utils as dynamic_module_utils

    modules_dir = Path(cache_dir) / "modules"
    modules_cache_dir = str(modules_dir)
    modules_dir.mkdir(parents=True, exist_ok=True)
    os.environ["HF_MODULES_CACHE"] = modules_cache_dir
    dynamic_module_utils.HF_MODULES_CACHE = modules_cache_dir
952
+
953
def _resolve_default_text_tokenizer_path(self) -> Path:
    """Find a tokenizer location when the caller did not supply one.

    Candidates, in order: ``config._name_or_path``, ``self.name_or_path``
    and the directory containing this module. The first resolved candidate
    that holds ``tokenizer.model`` or looks like a tokenizer directory is
    returned; if none qualifies, the first candidate (resolved) is returned.
    """
    configured_locations = (
        getattr(self.config, "_name_or_path", None),
        getattr(self, "name_or_path", None),
    )
    candidates: list[Path] = [
        Path(str(location)).expanduser() for location in configured_locations if location
    ]
    candidates.append(Path(__file__).resolve().parent)

    already_checked: set[str] = set()
    for candidate in candidates:
        resolved = candidate.resolve()
        resolved_key = str(resolved)
        if resolved_key in already_checked:
            continue
        already_checked.add(resolved_key)

        if (resolved / "tokenizer.model").is_file() or self._looks_like_hf_tokenizer_dir(resolved):
            return resolved

    return candidates[0].resolve()
980
+
981
def _load_text_tokenizer(self, text_tokenizer=None, text_tokenizer_path: Optional[str] = None):
    """Return a text tokenizer, loading and caching it on first use.

    A caller-supplied ``text_tokenizer`` is returned as-is. Otherwise the
    path is resolved (explicit path or default lookup), and the tokenizer
    cached on ``self`` is reused when it was loaded from the same resolved
    path. A freshly loaded tokenizer with no pad token id gets its EOS token
    as pad token, when an EOS token exists.
    """
    if text_tokenizer is not None:
        return text_tokenizer

    if text_tokenizer_path is not None:
        resolved_path = self._resolve_text_tokenizer_path(text_tokenizer_path)
    else:
        resolved_path = self._resolve_default_text_tokenizer_path()
    normalized_path = str(resolved_path.resolve())

    cached_tokenizer = getattr(self, "_cached_text_tokenizer", None)
    cached_from = getattr(self, "_cached_text_tokenizer_path", None)
    if cached_tokenizer is not None and cached_from == normalized_path:
        return cached_tokenizer

    cache_dir = self._resolve_hf_cache_dir()
    self._patch_hf_dynamic_module_cache_dir(cache_dir)
    loaded = self._load_resolved_text_tokenizer(resolved_path=resolved_path, cache_dir=cache_dir)
    if loaded.pad_token_id is None and loaded.eos_token is not None:
        loaded.pad_token = loaded.eos_token

    self._cached_text_tokenizer = loaded
    self._cached_text_tokenizer_path = normalized_path
    return loaded
1004
+
1005
@staticmethod
def _normalize_audio_tokenizer_type(audio_tokenizer_type: Optional[str]) -> Optional[str]:
    """Normalise an audio tokenizer type string.

    ``None`` and blank strings map to ``None``. After trimming and
    lower-casing, the only accepted value is
    ``MOSS_AUDIO_TOKENIZER_NANO_TYPE``.

    Raises:
        ValueError: For any other non-empty value.
    """
    if audio_tokenizer_type is None:
        return None
    normalized = str(audio_tokenizer_type).strip().lower()
    if not normalized:
        return None
    if normalized != MOSS_AUDIO_TOKENIZER_NANO_TYPE:
        raise ValueError(
            "Unsupported audio tokenizer type. "
            f"The open-source package only supports '{MOSS_AUDIO_TOKENIZER_NANO_TYPE}'."
        )
    return MOSS_AUDIO_TOKENIZER_NANO_TYPE
1018
+
1019
def _resolve_audio_tokenizer_type(self, audio_tokenizer_type: Optional[str]) -> str:
    """Pick the audio tokenizer type: explicit argument, then config, then the default constant."""
    requested = self._normalize_audio_tokenizer_type(audio_tokenizer_type)
    if requested is not None:
        return requested

    configured = self._normalize_audio_tokenizer_type(getattr(self.config, "audio_tokenizer_type", None))
    if configured is None:
        return MOSS_AUDIO_TOKENIZER_NANO_TYPE
    return configured
1025
+
1026
@staticmethod
def _set_decoder_attention_implementation(decoder, attn_implementation: str) -> None:
    """Write the attention implementation name onto a decoder.

    Sets ``decoder.attn_implementation``, ``decoder.config._attn_implementation``
    (when the decoder has a config) and ``block.attn.attn_implementation``
    for every block in ``decoder.h`` (when present).
    """
    implementation = str(attn_implementation)
    decoder.attn_implementation = implementation

    decoder_config = getattr(decoder, "config", None)
    if decoder_config is not None:
        decoder_config._attn_implementation = implementation

    for layer in getattr(decoder, "h", []):
        layer.attn.attn_implementation = implementation
1033
+
1034
def _set_attention_implementation(
    self,
    attn_implementation: str,
    local_attn_implementation: Optional[str] = None,
) -> None:
    """Switch the attention implementation of both transformers.

    ``attn_implementation`` is applied to the config, ``config.gpt2_config``
    and ``self.transformer``. ``local_attn_implementation`` (defaulting to
    the same value) is applied to ``self.local_transformer`` and recorded in
    ``config.local_transformer_attn_implementation``.
    """
    global_impl = str(attn_implementation)
    if local_attn_implementation is None:
        local_impl = global_impl
    else:
        local_impl = str(local_attn_implementation)

    config = self.config
    config.attn_implementation = global_impl
    config.gpt2_config._attn_implementation = global_impl
    self._set_decoder_attention_implementation(self.transformer, global_impl)

    config.local_transformer_attn_implementation = local_impl
    self._set_decoder_attention_implementation(self.local_transformer, local_impl)
1046
+
1047
@staticmethod
def _select_fallback_attention_implementation(device: torch.device) -> str:
    """Return ``"sdpa"`` for CUDA devices and ``"eager"`` for everything else."""
    if device.type == "cuda":
        return "sdpa"
    return "eager"
1050
+
1051
@staticmethod
def _is_generation_stability_error(exc: Exception) -> bool:
    """Report whether the exception message contains one of the known instability markers."""
    known_markers = (
        "Non-finite",
        "device-side assert triggered",
        "probability tensor contains either",
        "flash_attention_2 requires fp16/bf16 tensors",
    )
    message = str(exc)
    for marker in known_markers:
        if marker in message:
            return True
    return False
1063
+
1064
def _apply_inference_stability_fallback(self, device: torch.device) -> None:
    """Switch the model to float32 and a fallback attention implementation.

    The model is cast with ``self.to`` only when its first parameter is not
    already float32. The attention implementation chosen by
    ``_select_fallback_attention_implementation`` is then applied and a
    warning is logged.
    """
    fallback_attn = self._select_fallback_attention_implementation(device)
    first_param_dtype = next(self.parameters()).dtype
    if first_param_dtype != torch.float32:
        self.to(device=device, dtype=torch.float32)
    self._set_attention_implementation(fallback_attn)
    logging.warning(
        "retrying inference with dtype=float32 attn_implementation=%s due to numerical instability",
        fallback_attn,
    )
1073
+
1074
def _load_audio_tokenizer(
    self,
    audio_tokenizer=None,
    audio_tokenizer_type: Optional[str] = None,
    audio_tokenizer_pretrained_name_or_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
):
    """Return an audio tokenizer, loading and caching it on first use.

    A caller-supplied ``audio_tokenizer`` is returned as-is. Otherwise the
    source is taken from the argument, then the config, then the module
    default. An existing local path is loaded with local files only; any
    other value is passed to ``AutoModel.from_pretrained`` unchanged. The
    loaded model is cached on ``self`` under a ``type|source`` key and moved
    to the resolved device when it has a ``to`` method.

    Raises:
        ValueError: If the resolved tokenizer type is not
            ``MOSS_AUDIO_TOKENIZER_NANO_TYPE``.
    """
    if audio_tokenizer is not None:
        return audio_tokenizer

    resolved_type = self._resolve_audio_tokenizer_type(audio_tokenizer_type=audio_tokenizer_type)
    if resolved_type != MOSS_AUDIO_TOKENIZER_NANO_TYPE:
        raise ValueError(
            f"Unsupported audio tokenizer type {resolved_type!r}; expected '{MOSS_AUDIO_TOKENIZER_NANO_TYPE}'."
        )

    source_spec = (
        audio_tokenizer_pretrained_name_or_path
        or getattr(self.config, "audio_tokenizer_pretrained_name_or_path", None)
        or DEFAULT_MOSS_AUDIO_TOKENIZER_PRETRAINED_NAME_OR_PATH
    )
    local_candidate = Path(str(source_spec)).expanduser()
    load_kwargs: dict[str, object] = {"trust_remote_code": True}
    if local_candidate.exists():
        load_source = str(local_candidate.resolve())
        load_kwargs["local_files_only"] = True
        # NOTE(review): force_download alongside local_files_only looks
        # contradictory; kept as in the original - confirm the intent.
        load_kwargs["force_download"] = True
    else:
        load_source = str(source_spec)
    cache_key = f"{resolved_type}|{load_source}"

    cached_tokenizer = getattr(self, "_cached_audio_tokenizer", None)
    cached_key = getattr(self, "_cached_audio_tokenizer_path", None)
    if cached_tokenizer is not None and cached_key == cache_key:
        tokenizer = cached_tokenizer
    else:
        tokenizer = AutoModel.from_pretrained(load_source, **load_kwargs)
        if hasattr(tokenizer, "eval"):
            tokenizer.eval()
        self._cached_audio_tokenizer = tokenizer
        self._cached_audio_tokenizer_path = cache_key

    target_device = self._resolve_device(device)
    if hasattr(tokenizer, "to"):
        return tokenizer.to(target_device)
    return tokenizer
1124
+
1125
@staticmethod
def _extract_tensor_candidate(output: Any) -> Any:
    """Pull the code tensor out of an audio tokenizer's return value.

    Tensors and NumPy arrays are returned unchanged. Otherwise a set of
    known names is tried as attributes, then as dict keys; a single-entry
    dict yields its only value. For a non-empty list or tuple, a
    ``(value, number)`` pair yields ``value`` and anything else recurses
    into the first element.

    Raises:
        TypeError: If none of the above applies.
    """
    if torch.is_tensor(output) or isinstance(output, np.ndarray):
        return output

    known_names = ("audio_codes", "audio_token_ids", "codes", "tokens", "input_ids")
    for name in known_names:
        attribute = getattr(output, name, None)
        if attribute is not None:
            return attribute

    if isinstance(output, dict):
        for name in known_names:
            if name in output:
                return output[name]
        if len(output) == 1:
            return next(iter(output.values()))

    if isinstance(output, (list, tuple)) and output:
        is_value_number_pair = len(output) == 2 and isinstance(output[1], (int, float))
        if is_value_number_pair:
            return output[0]
        return NanoTTSGlobalLocalForCausalLM._extract_tensor_candidate(output[0])

    raise TypeError(f"Unsupported audio tokenizer output type: {type(output)!r}")
1144
+
1145
@staticmethod
def _extract_audio_code_length(output: Any) -> Optional[int]:
    """Return the first reported code length in a tokenizer output, if any.

    Known length names are tried as attributes, then as dict keys. For a
    list or tuple with at least two items, the second item is used when it
    is a tensor/array (first element) or a plain number. Empty length
    containers are skipped. Returns ``None`` when nothing is found.
    """
    length_names = ("audio_codes_lengths", "audio_token_ids_lengths", "codes_lengths", "lengths")

    def first_length(raw: Any) -> Optional[int]:
        flat = torch.as_tensor(raw).reshape(-1)
        if flat.numel() > 0:
            return int(flat[0].item())
        return None

    for name in length_names:
        raw = getattr(output, name, None)
        if raw is not None:
            found = first_length(raw)
            if found is not None:
                return found

    if isinstance(output, dict):
        for name in length_names:
            if name in output:
                found = first_length(output[name])
                if found is not None:
                    return found

    if isinstance(output, (list, tuple)) and len(output) >= 2:
        second = output[1]
        if torch.is_tensor(second) or isinstance(second, np.ndarray):
            found = first_length(second)
            if found is not None:
                return found
        if isinstance(second, (int, float)):
            return int(second)

    return None
1168
+
1169
+ def _normalize_audio_codes(self, audio_codes: Any) -> torch.LongTensor:
1170
+ code_length = self._extract_audio_code_length(audio_codes)
1171
+ tensor = torch.as_tensor(self._extract_tensor_candidate(audio_codes))
1172
+ if tensor.ndim == 1:
1173
+ tensor = tensor.unsqueeze(-1)
1174
+ if tensor.ndim == 3:
1175
+ if tensor.shape[1] == 1 and tensor.shape[0] >= self.config.n_vq:
1176
+ tensor = tensor[: self.config.n_vq, 0, :].transpose(0, 1)
1177
+ elif tensor.shape[0] == 1:
1178
+ tensor = tensor[0]
1179
+ elif tensor.shape[1] == self.config.n_vq:
1180
+ tensor = tensor.transpose(1, 2)[0]
1181
+ elif tensor.shape[-1] == self.config.n_vq:
1182
+ tensor = tensor[0]
1183
+ else:
1184
+ raise ValueError(f"Unable to normalize audio codes with shape {tuple(tensor.shape)}")
1185
+
1186
+ if tensor.ndim != 2:
1187
+ raise ValueError(f"Expected audio codes with 2 dims after normalization, got {tuple(tensor.shape)}")
1188
+ if tensor.shape[-1] != self.config.n_vq and tensor.shape[0] == self.config.n_vq:
1189
+ tensor = tensor.transpose(0, 1)
1190
+ elif tensor.shape[-1] != self.config.n_vq and tensor.shape[0] > self.config.n_vq:
1191
+ tensor = tensor[: self.config.n_vq].transpose(0, 1)
1192
+ elif tensor.shape[-1] > self.config.n_vq:
1193
+ tensor = tensor[:, : self.config.n_vq]
1194
+ if tensor.shape[-1] != self.config.n_vq:
1195
+ raise ValueError(
1196
+ f"Expected normalized audio codes with trailing dim {self.config.n_vq}, got {tuple(tensor.shape)}"
1197
+ )
1198
+ if code_length is not None:
1199
+ tensor = tensor[:code_length]
1200
+ return tensor.to(dtype=torch.long)
1201
+
1202
def _extract_waveform_and_sample_rate(
    self,
    decode_output: Any,
    fallback_sample_rate: int,
) -> tuple[torch.FloatTensor, int]:
    """Unpack a decoder result into a 2-D float32 CPU waveform and a sample rate.

    Sample rate, waveform and length are looked up by known names, first as
    attributes and then (for dicts) as keys; a ``(waveform, number)`` pair
    is also accepted. The waveform is squeezed/transposed/unsqueezed to two
    dimensions and cut to the reported length along its last axis.

    Raises:
        ValueError: If the waveform cannot be brought to two dimensions.
    """
    rate_names = ("sample_rate", "sampling_rate")
    waveform_names = ("waveform", "audio", "wav", "samples")
    length_names = ("audio_lengths", "waveform_lengths", "lengths")

    sample_rate = fallback_sample_rate
    waveform = decode_output
    waveform_length = None

    # Pass 1: attribute-style outputs.
    for name in rate_names:
        found = getattr(decode_output, name, None)
        if found is not None:
            sample_rate = int(found)
            break
    for name in waveform_names:
        found = getattr(decode_output, name, None)
        if found is not None:
            waveform = found
            break
    for name in length_names:
        found = getattr(decode_output, name, None)
        if found is not None:
            flat_lengths = torch.as_tensor(found).reshape(-1)
            if flat_lengths.numel() > 0:
                waveform_length = int(flat_lengths[0].item())
                break

    # Pass 2: mapping or sequence outputs.
    if isinstance(decode_output, dict):
        for name in rate_names:
            if name in decode_output:
                sample_rate = int(decode_output[name])
                break
        for name in waveform_names:
            if name in decode_output:
                waveform = decode_output[name]
                break
        for name in length_names:
            if name in decode_output:
                flat_lengths = torch.as_tensor(decode_output[name]).reshape(-1)
                if flat_lengths.numel() > 0:
                    waveform_length = int(flat_lengths[0].item())
                    break
    elif isinstance(decode_output, (list, tuple)) and decode_output:
        waveform = decode_output[0]
        if len(decode_output) == 2 and isinstance(decode_output[1], (int, float)):
            sample_rate = int(decode_output[1])

    samples = torch.as_tensor(waveform, dtype=torch.float32)
    if samples.ndim == 3 and samples.shape[0] == 1:
        samples = samples[0]
    if samples.ndim == 2 and samples.shape[0] > samples.shape[1]:
        # More rows than columns: swap the axes so the longer one comes last.
        samples = samples.transpose(0, 1)
    if samples.ndim == 1:
        samples = samples.unsqueeze(0)
    if samples.ndim != 2:
        raise ValueError(f"Expected decoded waveform with 2 dims, got {tuple(samples.shape)}")
    if waveform_length is not None:
        samples = samples[..., : max(0, waveform_length)]
    return samples.cpu(), sample_rate
1263
+
1264
def _call_audio_encode(
    self,
    audio_tokenizer,
    waveform: torch.FloatTensor,
    sample_rate: int,
) -> Any:
    """Encode one waveform through ``audio_tokenizer.batch_encode``.

    ``sample_rate`` is accepted but not used. A 1-D waveform is given a
    leading axis; the tokenizer receives a one-element list and
    ``chunk_duration=None``.

    Raises:
        AttributeError: If the tokenizer has no ``batch_encode``.
        ValueError: If the waveform is not 1-D or 2-D.
    """
    del sample_rate
    encode = getattr(audio_tokenizer, "batch_encode", None)
    if encode is None:
        raise AttributeError("audio_tokenizer must provide a batch_encode method.")

    target_device = self._resolve_device(waveform.device)
    waveform_tensor = torch.as_tensor(waveform, dtype=torch.float32, device=target_device)
    if waveform_tensor.ndim == 1:
        waveform_tensor = waveform_tensor.unsqueeze(0)
    if waveform_tensor.ndim != 2:
        raise ValueError(
            f"MOSS audio tokenizer encode expects waveform shaped like (C, T), got {tuple(waveform_tensor.shape)}"
        )

    with self._audio_tokenizer_inference_context(audio_tokenizer, waveform_tensor.device):
        return encode([waveform_tensor], chunk_duration=None)
1285
+
1286
def _call_audio_decode(
    self,
    audio_tokenizer,
    audio_token_ids: torch.LongTensor,
    sample_rate: int,
    nq: Optional[int] = None,
) -> Any:
    """Decode audio token ids through ``audio_tokenizer.batch_decode``.

    ``sample_rate`` is accepted but not used. The ids are laid out by
    ``_prepare_audio_codes_for_decode`` and passed as a one-element list
    with ``num_quantizers`` set to the resolved ``nq`` and
    ``chunk_duration=None``.

    Raises:
        AttributeError: If the tokenizer has no ``batch_decode``.
    """
    del sample_rate
    decode = getattr(audio_tokenizer, "batch_decode", None)
    if decode is None:
        raise AttributeError("audio_tokenizer must provide a batch_decode method.")

    quantizer_count = self._resolve_inference_nq(nq)
    codes = self._prepare_audio_codes_for_decode(audio_token_ids, nq=quantizer_count)
    with self._audio_tokenizer_inference_context(audio_tokenizer, codes.device):
        return decode([codes], num_quantizers=quantizer_count, chunk_duration=None)
1302
+
1303
@staticmethod
def _resolve_audio_tokenizer_downsample_rate(audio_tokenizer) -> int:
    """Find the tokenizer's downsample rate.

    The tokenizer and then its ``config`` are checked for
    ``downsample_rate``, ``hop_length`` or ``frame_size``; failing that,
    ``sampling_rate / frame_rate`` is rounded to an int.

    Raises:
        ValueError: If neither source provides a value.
    """
    direct_names = ("downsample_rate", "hop_length", "frame_size")
    holders = (audio_tokenizer, getattr(audio_tokenizer, "config", None))
    for holder in holders:
        if holder is None:
            continue
        for name in direct_names:
            direct_value = getattr(holder, name, None)
            if direct_value is not None:
                return int(direct_value)
        sampling_rate = getattr(holder, "sampling_rate", None)
        frame_rate = getattr(holder, "frame_rate", None)
        if sampling_rate is not None and frame_rate not in (None, 0):
            return int(round(float(sampling_rate) / float(frame_rate)))
    raise ValueError("audio_tokenizer.downsample_rate is required for prompt-audio decoding.")
1317
+
1318
def _resolve_audio_tokenizer_sample_rate(self, audio_tokenizer) -> int:
    """Return the tokenizer's sample rate, falling back to ``config.audio_tokenizer_sample_rate``.

    ``sampling_rate`` and ``sample_rate`` are checked on the tokenizer
    first, then on its ``config``.
    """
    holders = (audio_tokenizer, getattr(audio_tokenizer, "config", None))
    for holder in holders:
        if holder is None:
            continue
        for name in ("sampling_rate", "sample_rate"):
            rate = getattr(holder, name, None)
            if rate is not None:
                return int(rate)
    return int(self.config.audio_tokenizer_sample_rate)
1327
+
1328
@staticmethod
def _resolve_audio_tokenizer_channels(audio_tokenizer) -> int:
    """Return the tokenizer's channel count, defaulting to 1.

    Several attribute spellings are tried on the tokenizer and then on its
    ``config``; the first non-None value wins.
    """
    channel_names = ("number_channels", "channels_numbers", "audio_channels", "channels", "num_channels")
    holders = (audio_tokenizer, getattr(audio_tokenizer, "config", None))
    for holder in holders:
        if holder is None:
            continue
        for name in channel_names:
            channels = getattr(holder, name, None)
            if channels is not None:
                return int(channels)
    return 1
1338
+
1339
@staticmethod
def _audio_tokenizer_inference_context(audio_tokenizer, device: Union[str, torch.device]):
    """Return a no-op context manager; both arguments are ignored."""
    del audio_tokenizer
    del device
    return nullcontext()
1343
+
1344
def _prepare_audio_codes_for_decode(
    self,
    audio_token_ids: torch.LongTensor,
    nq: Optional[int] = None,
) -> torch.LongTensor:
    """Lay out audio codes with the quantizer axis first and keep ``nq`` quantizers.

    For 2-D and 3-D input, the axis of size ``config.n_vq`` (last or first)
    is treated as the quantizer axis. It is moved to the front when needed
    and cut to the resolved ``nq``.

    Raises:
        ValueError: If no axis arrangement matches.
    """
    quantizer_count = self._resolve_inference_nq(nq)
    codes = torch.as_tensor(audio_token_ids, dtype=torch.long)
    n_vq = self.config.n_vq

    if codes.ndim == 2:
        if codes.shape[-1] == n_vq and codes.shape[0] != n_vq:
            return codes[:, :quantizer_count].transpose(0, 1).contiguous()
        if codes.shape[0] == n_vq:
            return codes[:quantizer_count].contiguous()
    elif codes.ndim == 3:
        if codes.shape[-1] == n_vq:
            return codes[..., :quantizer_count].permute(2, 0, 1).contiguous()
        if codes.shape[0] == n_vq:
            return codes[:quantizer_count].contiguous()

    raise ValueError(
        f"Expected generated audio token ids shaped like (T, {self.config.n_vq}) or ({self.config.n_vq}, T); got {tuple(codes.shape)}"
    )
1364
+
1365
def _load_reference_audio(
    self,
    reference_audio_path: Union[str, Path],
    target_sample_rate: int,
    target_channels: int,
) -> tuple[torch.FloatTensor, int]:
    """Load an audio file as float32 at the requested rate and channel count.

    The file is read with ``torchaudio.load`` and resampled when its rate
    differs from ``target_sample_rate``. Mono is repeated to reach more
    channels; multi-channel audio is averaged down to mono.

    Raises:
        ValueError: For any other channel conversion.
    """
    waveform, sample_rate = torchaudio.load(str(reference_audio_path))
    waveform = waveform.to(torch.float32)
    if sample_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
        sample_rate = target_sample_rate

    source_channels = int(waveform.shape[0])
    if source_channels == target_channels:
        return waveform, sample_rate
    if source_channels == 1 and target_channels > 1:
        upmixed = waveform.repeat(target_channels, 1)
        return upmixed, sample_rate
    if source_channels > 1 and target_channels == 1:
        downmixed = waveform.mean(dim=0, keepdim=True)
        return downmixed, sample_rate
    raise ValueError(f"Unsupported reference audio channel conversion: {source_channels} -> {target_channels}")
1384
+
1385
def _decode_local_last_hidden_state(
    self,
    local_inputs_embeds: torch.FloatTensor,
) -> torch.FloatTensor:
    """Run the local transformer on the embeddings and return its last position's hidden state.

    An all-ones boolean attention mask matching the first two dims of the
    input is used, and caching is disabled.
    """
    full_attention = torch.ones(
        local_inputs_embeds.shape[:2],
        dtype=torch.bool,
        device=local_inputs_embeds.device,
    )
    outputs = self.local_transformer(
        input_ids=None,
        attention_mask=full_attention,
        position_ids=None,
        inputs_embeds=local_inputs_embeds,
        use_cache=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        cu_seqlens=None,
        num_sequences=None,
    )
    last_hidden = outputs.last_hidden_state
    return last_hidden[:, -1, :]
1407
+
1408
@torch.no_grad()
def generate(
    self,
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.Tensor] = None,
    nq: Optional[int] = None,
    max_new_frames: int = 300,
    do_sample: bool = False,
    text_temperature: float = 1.5,
    text_top_p: float = 1.0,
    text_top_k: int = 50,
    audio_temperature: float = 1.7,
    audio_top_p: float = 0.8,
    audio_top_k: int = 25,
    audio_repetition_penalty: float = 1.0,
    use_kv_cache: bool = True,
    return_dict_in_generate: bool = True,
):
    """Autoregressively generate audio frames for a (batched) prompt.

    Each step runs the global transformer, samples one assistant text token
    from the local transformer, and -- for rows that emitted the audio slot
    token and have not finished -- samples one token per audio channel.

    Args:
        input_ids: Prompt rows, 3-D; a 2-D tensor is treated as a batch of one.
        attention_mask: Optional prompt mask; all-ones when omitted.  A 1-D
            mask is treated as a batch of one.
        nq: Requested number of audio channels, resolved through
            ``self._resolve_inference_nq``.
        max_new_frames: Upper bound on the number of generation steps.
        do_sample: Forwarded to both samplers.
        text_temperature, text_top_p, text_top_k: Sampling options for the
            assistant text token.
        audio_temperature, audio_top_p, audio_top_k: Sampling options for the
            audio channel tokens.
        audio_repetition_penalty: Forwarded to the audio sampler together with
            the per-channel history of previously generated tokens.
        use_kv_cache: When True only the newest row is fed to the global
            transformer on later steps; otherwise the full sequence is re-fed.
        return_dict_in_generate: When True return a ``NanoTTSGenerationOutput``;
            otherwise return the raw token tensor.

    Returns:
        Generated audio token ids stacked along dim 1 (an empty
        ``(batch, 0, n_vq)`` tensor when nothing was generated), optionally
        wrapped in ``NanoTTSGenerationOutput`` with the prompt ids.

    Raises:
        ValueError: If ``input_ids`` is not 2-D or 3-D.
    """
    if input_ids.ndim == 2:
        input_ids = input_ids.unsqueeze(0)
    if input_ids.ndim != 3:
        raise ValueError(f"Expected input_ids with 3 dims, got shape {tuple(input_ids.shape)}")

    device = input_ids.device
    if attention_mask is None:
        attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.bool, device=device)
    elif attention_mask.ndim == 1:
        attention_mask = attention_mask.unsqueeze(0)

    effective_nq = self._resolve_inference_nq(nq)
    batch_size = input_ids.shape[0]
    full_sequence_ids = input_ids
    running_mask = attention_mask.to(device=device)
    step_input_ids = full_sequence_ids
    frames = []
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
    kv_cache = None
    local_dtype = self.local_transformer.ln_f.weight.dtype

    for _ in range(max_new_frames):
        # History of already generated frames, used by the audio sampler.
        history = torch.stack(frames, dim=1) if frames else None

        step_embeds = self._build_inputs_embeds(step_input_ids)
        global_outputs = self.transformer(
            input_ids=None,
            past_key_values=kv_cache,
            attention_mask=running_mask,
            position_ids=None,
            inputs_embeds=step_embeds,
            use_cache=use_kv_cache,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
            cu_seqlens=None,
            num_sequences=None,
        )
        global_last = global_outputs.last_hidden_state[:, -1, :].to(dtype=local_dtype)

        # The local sequence starts with the global state and grows by one
        # embedding per sampled token.
        local_sequence = global_last.unsqueeze(1)
        local_state = self._decode_local_last_hidden_state(local_sequence)
        text_logits = self.text_lm_head(local_state)
        self._ensure_finite_generation_logits(text_logits, "text logits")
        text_tokens = self._sample_next_assistant_text_token(
            logits=text_logits,
            do_sample=do_sample,
            temperature=text_temperature,
            top_k=text_top_k,
            top_p=text_top_p,
        )

        # Rows keep going only if they asked for an audio slot and had not
        # already finished before this step.
        active = text_tokens.eq(self.config.audio_assistant_slot_token_id) & ~finished
        finished = finished | text_tokens.eq(self.config.audio_end_token_id)
        if not active.any():
            break
        inactive = ~active

        channel_tokens = []
        local_step_embed = self.transformer.wte(text_tokens).to(dtype=local_dtype)
        for channel_index in range(effective_nq):
            local_sequence = torch.cat([local_sequence, local_step_embed.unsqueeze(1)], dim=1)
            local_state = self._decode_local_last_hidden_state(local_sequence)
            channel_logits = self.audio_lm_heads[channel_index](local_state)
            self._ensure_finite_generation_logits(channel_logits, f"audio logits[{channel_index}]")
            channel_history = None if history is None else history[:, :, channel_index]
            sampled = self._sample_next_token(
                logits=channel_logits,
                do_sample=do_sample,
                temperature=audio_temperature,
                top_k=audio_top_k,
                top_p=audio_top_p,
                previous_token_ids=channel_history,
                repetition_penalty=audio_repetition_penalty,
            )
            channel_tokens.append(sampled)
            local_step_embed = self.audio_embeddings[channel_index](sampled).to(dtype=local_dtype)

        frame_prefix = torch.stack(channel_tokens, dim=-1)
        if effective_nq < self.config.n_vq:
            # Unused trailing channels are filled with the audio pad id.
            frame = torch.full(
                (batch_size, self.config.n_vq),
                self.config.audio_pad_token_id,
                dtype=frame_prefix.dtype,
                device=frame_prefix.device,
            )
            frame[:, :effective_nq] = frame_prefix
        else:
            frame = frame_prefix
        frame = frame.masked_fill(inactive.unsqueeze(-1), self.config.audio_pad_token_id)
        frames.append(frame)

        next_row = self._build_generation_row(
            batch_size=batch_size,
            device=device,
            audio_token_ids=frame,
        )
        if inactive.any():
            # Rows that are not generating get a fully padded row.
            next_row[inactive, 0, 0] = self.config.pad_token_id
            next_row[inactive, 0, 1:] = self.config.audio_pad_token_id

        full_sequence_ids = torch.cat([full_sequence_ids, next_row], dim=1)
        running_mask = torch.cat([running_mask, active.unsqueeze(1)], dim=1)
        if use_kv_cache:
            step_input_ids = next_row
            kv_cache = global_outputs.past_key_values
        else:
            step_input_ids = full_sequence_ids

    if frames:
        audio_token_ids = torch.stack(frames, dim=1)
    else:
        audio_token_ids = torch.empty((batch_size, 0, self.config.n_vq), dtype=torch.long, device=device)

    if return_dict_in_generate:
        return NanoTTSGenerationOutput(audio_token_ids=audio_token_ids, prompt_input_ids=input_ids)
    return audio_token_ids
1539
+
1540
@torch.no_grad()
def inference(
    self,
    text: str,
    output_audio_path: Union[str, Path],
    mode: str = "continuation",
    prompt_text: Optional[str] = None,
    prompt_audio_path: Optional[Union[str, Path]] = None,
    reference_audio_path: Optional[Union[str, Path]] = None,
    text_tokenizer=None,
    text_tokenizer_path: Optional[str] = None,
    audio_tokenizer=None,
    audio_tokenizer_type: Optional[str] = None,
    audio_tokenizer_pretrained_name_or_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    nq: Optional[int] = None,
    max_new_frames: int = 300,
    do_sample: bool = False,
    text_temperature: float = 1.5,
    text_top_p: float = 1.0,
    text_top_k: int = 50,
    audio_temperature: float = 1.7,
    audio_top_p: float = 0.8,
    audio_top_k: int = 25,
    audio_repetition_penalty: float = 1.0,
    use_kv_cache: bool = True,
    voice_clone_max_text_tokens: int = DEFAULT_VOICE_CLONE_MAX_TEXT_TOKENS,
    voice_clone_max_memory_per_sample_gb: float = DEFAULT_VOICE_CLONE_MAX_MEMORY_PER_SAMPLE_GB,
) -> dict[str, Any]:
    """Synthesize speech for ``text`` and save it to ``output_audio_path``.

    The model is moved to the resolved device if needed and switched to eval
    mode for the duration of the call; the previous training mode is restored
    afterwards, including when an error is raised part-way through.

    Args:
        text: Text to synthesize.
        output_audio_path: Destination file; parent directories are created.
        mode: Requested inference mode, resolved by
            ``self._resolve_inference_mode`` (this method special-cases
            ``"voice_clone"`` and ``"continuation"``).
        prompt_text: Optional prompt transcript forwarded to prompt building.
        prompt_audio_path: Optional prompt audio file.
        reference_audio_path: Backward-compatible alias used when
            ``prompt_audio_path`` is not given (a warning is logged).
        text_tokenizer, text_tokenizer_path: Text tokenizer or where to load it.
        audio_tokenizer, audio_tokenizer_type,
        audio_tokenizer_pretrained_name_or_path: Audio tokenizer or how to
            load it.
        device: Target device, resolved by ``self._resolve_device``.
        nq: Requested number of audio channels.
        max_new_frames, do_sample, text_temperature, text_top_p, text_top_k,
        audio_temperature, audio_top_p, audio_top_k, audio_repetition_penalty,
        use_kv_cache: Generation options forwarded to
            ``self._generate_audio_token_ids_with_fallback``.
        voice_clone_max_text_tokens: Token budget used to split ``text`` into
            chunks in voice-clone mode.
        voice_clone_max_memory_per_sample_gb: Memory budget used to choose the
            chunk batch size in voice-clone mode.

    Returns:
        Dict with ``audio_path``, ``sample_rate``, ``audio_token_ids`` (CPU),
        ``waveform`` and ``reference_audio_token_ids`` (CPU, or ``None`` when
        no prompt audio was supplied).

    Raises:
        ValueError: If decoded sample rates differ across voice-clone chunks.
    """
    resolved_device = self._resolve_device(device)
    effective_nq = self._resolve_inference_nq(nq)
    if next(self.parameters()).device != resolved_device:
        self.to(resolved_device)

    was_training = self.training
    self.eval()
    # try/finally guarantees the caller's train/eval mode is restored even
    # when tokenizer loading, generation, decoding or saving raises.
    try:
        text_tokenizer = self._load_text_tokenizer(
            text_tokenizer=text_tokenizer,
            text_tokenizer_path=text_tokenizer_path,
        )
        audio_tokenizer = self._load_audio_tokenizer(
            audio_tokenizer=audio_tokenizer,
            audio_tokenizer_type=audio_tokenizer_type,
            audio_tokenizer_pretrained_name_or_path=audio_tokenizer_pretrained_name_or_path,
            device=resolved_device,
        )

        target_sample_rate = self._resolve_audio_tokenizer_sample_rate(audio_tokenizer)
        target_channels = self._resolve_audio_tokenizer_channels(audio_tokenizer)
        effective_prompt_audio_path = prompt_audio_path or reference_audio_path
        resolved_mode = self._resolve_inference_mode(
            mode=mode,
            has_prompt_text=prompt_text is not None,
            has_prompt_audio=effective_prompt_audio_path is not None,
        )
        if reference_audio_path is not None and prompt_audio_path is None:
            logging.warning(
                "reference_audio_path=%s is treated as prompt_audio_path for backward compatibility.",
                reference_audio_path,
            )

        # Encode the prompt audio (if any) into audio codes.
        prompt_audio_codes = None
        if effective_prompt_audio_path is not None:
            waveform, sample_rate = self._load_reference_audio(
                effective_prompt_audio_path,
                target_sample_rate,
                target_channels,
            )
            encoded = self._call_audio_encode(
                audio_tokenizer=audio_tokenizer,
                waveform=waveform.to(resolved_device),
                sample_rate=sample_rate,
            )
            prompt_audio_codes = self._mask_unused_audio_channels(
                self._normalize_audio_codes(encoded),
                nq=effective_nq,
            ).to(resolved_device)

        # Voice cloning may split long text; a single chunk falls back to the
        # unsplit original text.
        if resolved_mode == "voice_clone":
            split_voice_clone_text_chunks = self._split_text_into_best_sentences(
                text_tokenizer=text_tokenizer,
                text=text,
                max_tokens=voice_clone_max_text_tokens,
            )
            voice_clone_text_chunks = (
                split_voice_clone_text_chunks if len(split_voice_clone_text_chunks) > 1 else [text]
            )
        else:
            voice_clone_text_chunks = [text]

        is_chunked_voice_clone = resolved_mode == "voice_clone" and len(voice_clone_text_chunks) > 1

        generated_audio_token_chunks: list[torch.LongTensor] = []
        decoded_waveform_chunks: list[torch.FloatTensor] = []
        decoded_sample_rate: Optional[int] = None

        if is_chunked_voice_clone:
            voice_clone_chunk_batch_size = self._resolve_voice_clone_chunk_batch_size(
                resolved_device=resolved_device,
                chunk_count=len(voice_clone_text_chunks),
                max_memory_per_sample_gb=float(voice_clone_max_memory_per_sample_gb),
            )
        else:
            voice_clone_chunk_batch_size = 1

        for batch_start in range(0, len(voice_clone_text_chunks), voice_clone_chunk_batch_size):
            batch_chunks = voice_clone_text_chunks[batch_start : batch_start + voice_clone_chunk_batch_size]
            batch_prompt_input_ids: list[torch.LongTensor] = []
            batch_attention_masks: list[torch.BoolTensor] = []
            for text_chunk in batch_chunks:
                prompt_input_ids, attention_mask = self.build_inference_input_ids(
                    text=text_chunk,
                    text_tokenizer=text_tokenizer,
                    mode=resolved_mode,
                    prompt_text=prompt_text,
                    prompt_audio_codes=prompt_audio_codes,
                    device=resolved_device,
                )
                batch_prompt_input_ids.append(prompt_input_ids)
                batch_attention_masks.append(attention_mask)

            batched_prompt_input_ids, batched_attention_mask = self._left_pad_inference_batch(
                input_id_batches=batch_prompt_input_ids,
                attention_mask_batches=batch_attention_masks,
                device=resolved_device,
            )
            batched_audio_token_ids = self._generate_audio_token_ids_with_fallback(
                prompt_input_ids=batched_prompt_input_ids,
                attention_mask=batched_attention_mask,
                effective_nq=effective_nq,
                max_new_frames=max_new_frames,
                do_sample=do_sample,
                text_temperature=text_temperature,
                text_top_p=text_top_p,
                text_top_k=text_top_k,
                audio_temperature=audio_temperature,
                audio_top_p=audio_top_p,
                audio_top_k=audio_top_k,
                audio_repetition_penalty=audio_repetition_penalty,
                use_kv_cache=use_kv_cache,
                resolved_device=resolved_device,
            )

            for sample_index in range(len(batch_chunks)):
                audio_token_ids = self._trim_generated_audio_token_ids(
                    batched_audio_token_ids[sample_index],
                    effective_nq=effective_nq,
                )
                generated_audio_token_chunks.append(audio_token_ids)

                if is_chunked_voice_clone:
                    # Chunks are decoded individually and joined as waveforms.
                    decoded_waveform, current_sample_rate = self._decode_audio_token_ids_to_waveform(
                        audio_tokenizer=audio_tokenizer,
                        audio_token_ids=audio_token_ids,
                        target_sample_rate=target_sample_rate,
                        effective_nq=effective_nq,
                        resolved_device=resolved_device,
                    )
                    if decoded_sample_rate is None:
                        decoded_sample_rate = current_sample_rate
                    elif decoded_sample_rate != current_sample_rate:
                        raise ValueError(
                            f"Decoded sample rates differ across voice_clone chunks: {decoded_sample_rate} vs {current_sample_rate}"
                        )
                    decoded_waveform_chunks.append(decoded_waveform)

        if generated_audio_token_chunks:
            audio_token_ids = torch.cat(generated_audio_token_chunks, dim=0)
        else:
            audio_token_ids = torch.empty((0, self.config.n_vq), dtype=torch.long, device=resolved_device)

        if is_chunked_voice_clone:
            waveform = (
                self._concat_voice_clone_waveform_chunks(
                    waveform_chunks=decoded_waveform_chunks,
                    text_chunks=voice_clone_text_chunks,
                    sample_rate=decoded_sample_rate,
                )
                if decoded_waveform_chunks
                else torch.zeros((target_channels, 0), dtype=torch.float32)
            )
        else:
            decode_audio_token_ids = audio_token_ids
            prompt_waveform_prefix_samples = 0
            if resolved_mode == "continuation" and prompt_audio_codes is not None:
                # Decode prompt + continuation together, then cut the prompt
                # part off the front of the decoded waveform.
                decode_audio_token_ids = torch.cat([prompt_audio_codes, audio_token_ids], dim=0)
                prompt_waveform_prefix_samples = (
                    int(prompt_audio_codes.shape[0]) * self._resolve_audio_tokenizer_downsample_rate(audio_tokenizer)
                )

            waveform, decoded_sample_rate = self._decode_audio_token_ids_to_waveform(
                audio_tokenizer=audio_tokenizer,
                audio_token_ids=decode_audio_token_ids,
                target_sample_rate=target_sample_rate,
                effective_nq=effective_nq,
                resolved_device=resolved_device,
            )
            if prompt_waveform_prefix_samples > 0:
                if decoded_sample_rate != target_sample_rate:
                    # Rescale the prefix length to the decoder's actual rate.
                    prompt_waveform_prefix_samples = int(
                        round(prompt_waveform_prefix_samples * float(decoded_sample_rate) / float(target_sample_rate))
                    )
                prompt_waveform_prefix_samples = min(prompt_waveform_prefix_samples, int(waveform.shape[-1]))
                waveform = waveform[:, prompt_waveform_prefix_samples:]

        # Internal invariant: every path above that yields audio sets the rate.
        assert decoded_sample_rate is not None

        output_path = Path(output_audio_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        torchaudio.save(str(output_path), waveform, decoded_sample_rate)

        return {
            "audio_path": str(output_path),
            "sample_rate": decoded_sample_rate,
            "audio_token_ids": audio_token_ids.detach().cpu(),
            "waveform": waveform,
            "reference_audio_token_ids": None if prompt_audio_codes is None else prompt_audio_codes.detach().cpu(),
        }
    finally:
        if was_training:
            self.train()
prompting.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Sequence
4
+
5
+ from .configuration_nanotts import NanoTTSConfig
6
+
7
+
8
# Role tag text for the user turn.
USER_ROLE_PREFIX = "user\n"
# Opening of the user instruction template, up to (not including) the
# reference value.
USER_TEMPLATE_REFERENCE_PREFIX = (
    "<user_inst>\n"
    "- Reference(s):\n"
)
# Remaining template fields after the reference value; every field is the
# literal "None" and the block ends right before the text to be spoken.
USER_TEMPLATE_AFTER_REFERENCE = (
    "\n- Instruction:\nNone\n"
    "- Tokens:\nNone\n"
    "- Quality:\nNone\n"
    "- Sound Event:\nNone\n"
    "- Ambient Sound:\nNone\n"
    "- Language:\nNone\n"
    "- Text:\n"
)
# Full user template prefix for the case where the reference is "None".
USER_TEMPLATE_PREFIX = USER_TEMPLATE_REFERENCE_PREFIX + "None" + USER_TEMPLATE_AFTER_REFERENCE
# Closing tag of the user instruction template.
USER_TEMPLATE_SUFFIX = "\n</user_inst>"
# Separator text placed between the end of the user turn and the assistant turn.
ASSISTANT_TURN_PREFIX = "\n"
# Role tag text for the assistant turn.
ASSISTANT_ROLE_PREFIX = "assistant\n"
26
+
27
+
28
def encode_text(tokenizer, text: str) -> List[int]:
    """Encode ``text`` to a list of token ids without special tokens.

    ``add_special_tokens=False`` is requested first; tokenizers whose
    ``encode`` rejects that keyword (``TypeError``) are called with the text
    only.
    """
    try:
        token_ids = list(tokenizer.encode(text, add_special_tokens=False))
    except TypeError:
        # Tokenizer has no such keyword; use its plain encode().
        token_ids = list(tokenizer.encode(text))
    return token_ids
33
+
34
+
35
def decode_text(tokenizer, token_ids: Sequence[int]) -> str:
    """Decode ``token_ids`` to text, keeping special tokens.

    The richest keyword set is tried first and progressively reduced for
    tokenizers whose ``decode`` rejects a keyword with ``TypeError``:

    1. ``skip_special_tokens=False, clean_up_tokenization_spaces=False``
    2. ``skip_special_tokens=False``
    3. no keywords

    Args:
        tokenizer: Object exposing a ``decode`` method.
        token_ids: Token ids to decode; any iterable is accepted.

    Returns:
        The decoded text, coerced to ``str``.
    """
    # Materialize once so a one-shot iterable is not exhausted by a failed
    # attempt before the fallback call runs.
    ids = list(token_ids)
    keyword_fallbacks = (
        {"skip_special_tokens": False, "clean_up_tokenization_spaces": False},
        {"skip_special_tokens": False},
    )
    for keywords in keyword_fallbacks:
        try:
            # Each attempt gets its own copy, as the original code did.
            return str(tokenizer.decode(list(ids), **keywords))
        except TypeError:
            continue
    return str(tokenizer.decode(list(ids)))
49
+
50
+
51
def build_user_prompt_prefix(tokenizer, config: NanoTTSConfig) -> List[int]:
    """Return token ids for the start of the user turn up to the reference slot.

    Layout: ``im_start`` id, encoded ``USER_ROLE_PREFIX``, encoded
    ``USER_TEMPLATE_REFERENCE_PREFIX``.
    """
    token_ids = [config.im_start_token_id]
    token_ids.extend(encode_text(tokenizer, USER_ROLE_PREFIX))
    token_ids.extend(encode_text(tokenizer, USER_TEMPLATE_REFERENCE_PREFIX))
    return token_ids
56
+
57
+
58
def build_user_prompt_after_reference(tokenizer) -> List[int]:
    """Return token ids for the template fields that follow the reference value."""
    after_reference_ids = encode_text(tokenizer, USER_TEMPLATE_AFTER_REFERENCE)
    return after_reference_ids
60
+
61
+
62
def build_assistant_prompt_prefix(tokenizer, config: NanoTTSConfig) -> List[int]:
    """Return token ids that close the user turn and open the assistant turn.

    Layout: encoded ``USER_TEMPLATE_SUFFIX``, ``im_end`` id, encoded
    ``ASSISTANT_TURN_PREFIX``, ``im_start`` id, encoded
    ``ASSISTANT_ROLE_PREFIX``.
    """
    return [
        *encode_text(tokenizer, USER_TEMPLATE_SUFFIX),
        config.im_end_token_id,
        *encode_text(tokenizer, ASSISTANT_TURN_PREFIX),
        config.im_start_token_id,
        *encode_text(tokenizer, ASSISTANT_ROLE_PREFIX),
    ]
70
+
71
+
72
def build_prompt_prefix(tokenizer, config: NanoTTSConfig) -> List[int]:
    """Return token ids for the user turn up to the text slot, with reference "None"."""
    turn_opening = build_user_prompt_prefix(tokenizer, config)
    no_reference = encode_text(tokenizer, "None")
    remaining_fields = build_user_prompt_after_reference(tokenizer)
    return turn_opening + no_reference + remaining_fields
78
+
79
+
80
def build_prompt_suffix(tokenizer, config: NanoTTSConfig) -> List[int]:
    """Return the token ids that follow the text; same as the assistant prompt prefix."""
    suffix_ids = build_assistant_prompt_prefix(tokenizer, config)
    return suffix_ids
82
+
83
+
84
def build_prompt_token_ids(
    tokenizer,
    config: NanoTTSConfig,
    text_token_ids: Sequence[int],
) -> List[int]:
    """Wrap already-encoded text ids in the full prompt template.

    Returns the prompt prefix, then ``text_token_ids`` (each coerced to
    ``int``), then the prompt suffix.
    """
    prompt_ids = build_prompt_prefix(tokenizer, config)
    text_ids = [int(token_id) for token_id in text_token_ids]
    suffix_ids = build_prompt_suffix(tokenizer, config)
    return prompt_ids + text_ids + suffix_ids
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d19e63fdc6a35f61a7d1c27e06bfacaba7c1ed40ea3c619c86efc64bcd50a496
3
+ size 234693095
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<pad>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenization_nanotts_sentencepiece.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import shutil
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import sentencepiece as spm
8
+ from transformers import PreTrainedTokenizer
9
+
10
+
11
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
12
+
13
+
14
class NanoTTSSentencePieceTokenizer(PreTrainedTokenizer):
    """SentencePiece-backed tokenizer that never inserts special tokens on its own.

    ``build_inputs_with_special_tokens`` returns the ids unchanged, so the
    special-token mask and token-type ids produced here are all zeros.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: str,
        unk_token: str = "<unk>",
        bos_token: str = "<s>",
        eos_token: str = "</s>",
        pad_token: str = "<pad>",
        sp_model_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        """Load the SentencePiece model at ``vocab_file`` and register special tokens.

        Args:
            vocab_file: Path to the serialized SentencePiece model.
            unk_token, bos_token, eos_token, pad_token: Special token strings
                forwarded to ``PreTrainedTokenizer``.
            sp_model_kwargs: Extra keyword arguments for
                ``spm.SentencePieceProcessor``; copied, never mutated.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
        """
        self.vocab_file = str(vocab_file)
        self.sp_model_kwargs = {} if sp_model_kwargs is None else dict(sp_model_kwargs)
        # The SentencePiece model is loaded before calling the base
        # constructor; keep this order.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)
        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Number of pieces in the SentencePiece model (added tokens excluded)."""
        return int(self.sp_model.get_piece_size())

    def get_vocab(self) -> dict[str, int]:
        """Return piece -> id for the SentencePiece model plus any added tokens."""
        vocab = {self.sp_model.id_to_piece(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> list[str]:
        """Split ``text`` into SentencePiece pieces."""
        return list(self.sp_model.encode(text, out_type=str))

    def _convert_token_to_id(self, token: str) -> int:
        """Map a piece string to its id."""
        return int(self.sp_model.piece_to_id(token))

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id to its piece string."""
        return str(self.sp_model.id_to_piece(int(index)))

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        """Join SentencePiece pieces back into text."""
        return str(self.sp_model.decode(tokens))

    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
        """Write the SentencePiece model into ``save_directory``.

        The original model file is copied when it still exists.  If it has
        been moved or deleted since loading, the in-memory model is
        serialized instead so saving does not fail.

        Args:
            save_directory: Target directory; created if missing.
            filename_prefix: Optional prefix, giving
                ``"<prefix>-tokenizer.model"``.

        Returns:
            One-element tuple with the path of the written model file.
        """
        save_dir = Path(save_directory)
        save_dir.mkdir(parents=True, exist_ok=True)
        out_name = "tokenizer.model" if filename_prefix is None else f"{filename_prefix}-tokenizer.model"
        out_path = save_dir / out_name
        source_path = Path(self.vocab_file)
        if not source_path.is_file():
            # Source file is gone; fall back to the loaded model's bytes.
            out_path.write_bytes(self.sp_model.serialized_model_proto())
        elif source_path.resolve() != out_path.resolve():
            # Skip the copy when saving onto the file that was loaded.
            shutil.copyfile(self.vocab_file, out_path)
        return (str(out_path),)

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: list[int],
        token_ids_1: list[int] | None = None,
    ) -> list[int]:
        """Return the ids unchanged (concatenated for a pair); no special tokens are added."""
        if token_ids_1 is None:
            return list(token_ids_0)
        return list(token_ids_0) + list(token_ids_1)

    def get_special_tokens_mask(
        self,
        token_ids_0: list[int],
        token_ids_1: list[int] | None = None,
        already_has_special_tokens: bool = False,
    ) -> list[int]:
        """Return the special-token mask for the given sequence(s).

        When the ids already contain special tokens the base class computes
        the mask; otherwise it is all zeros because this tokenizer adds none.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        if token_ids_1 is None:
            return [0] * len(token_ids_0)
        return [0] * (len(token_ids_0) + len(token_ids_1))

    def create_token_type_ids_from_sequences(
        self,
        token_ids_0: list[int],
        token_ids_1: list[int] | None = None,
    ) -> list[int]:
        """Return all-zero token type ids covering both sequences."""
        if token_ids_1 is None:
            return [0] * len(token_ids_0)
        return [0] * (len(token_ids_0) + len(token_ids_1))
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c353ee1479b536bf414c1b247f5542b6607fb8ae91320e5af1781fee200fddff
3
+ size 470897
tokenizer_config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<unk>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<s>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<pad>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ }
35
+ },
36
+ "additional_special_tokens": [],
37
+ "auto_map": {
38
+ "AutoTokenizer": [
39
+ "tokenization_nanotts_sentencepiece.NanoTTSSentencePieceTokenizer",
40
+ null
41
+ ]
42
+ },
43
+ "backend": "custom",
44
+ "bos_token": "<s>",
45
+ "clean_up_tokenization_spaces": false,
46
+ "eos_token": "</s>",
47
+ "extra_special_tokens": {},
48
+ "model_max_length": 16384,
49
+ "pad_token": "<pad>",
50
+ "tokenizer_class": "NanoTTSSentencePieceTokenizer",
51
+ "unk_token": "<unk>"
52
+ }