hf-upload-bot commited on
Commit ·
ad3fd89
1
Parent(s): 5962079
Upload Nano TTS checkpoint 500000
Browse files- __init__.py +31 -0
- config.json +197 -0
- configuration_nanotts.py +105 -0
- gpt2_decoder.py +605 -0
- modeling_nanotts_global_local.py +1757 -0
- prompting.py +92 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +30 -0
- tokenization_nanotts_sentencepiece.py +103 -0
- tokenizer.model +3 -0
- tokenizer_config.json +52 -0
__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .configuration_nanotts import NanoTTSConfig
|
| 2 |
+
from .modeling_nanotts_global_local import (
|
| 3 |
+
NanoTTSGenerationOutput,
|
| 4 |
+
NanoTTSGlobalLocalForCausalLM,
|
| 5 |
+
NanoTTSOutput,
|
| 6 |
+
)
|
| 7 |
+
from .tokenization_nanotts_sentencepiece import NanoTTSSentencePieceTokenizer
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
NanoTTSConfig.register_for_auto_class()
|
| 11 |
+
except Exception:
|
| 12 |
+
pass
|
| 13 |
+
|
| 14 |
+
for auto_class_name in ("AutoModel", "AutoModelForCausalLM"):
|
| 15 |
+
try:
|
| 16 |
+
NanoTTSGlobalLocalForCausalLM.register_for_auto_class(auto_class_name)
|
| 17 |
+
except Exception:
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
NanoTTSSentencePieceTokenizer.register_for_auto_class("AutoTokenizer")
|
| 22 |
+
except Exception:
|
| 23 |
+
pass
|
| 24 |
+
|
| 25 |
+
__all__ = [
|
| 26 |
+
"NanoTTSConfig",
|
| 27 |
+
"NanoTTSGlobalLocalForCausalLM",
|
| 28 |
+
"NanoTTSSentencePieceTokenizer",
|
| 29 |
+
"NanoTTSGenerationOutput",
|
| 30 |
+
"NanoTTSOutput",
|
| 31 |
+
]
|
config.json
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_cross_attention": false,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"NanoTTSGlobalLocalForCausalLM"
|
| 5 |
+
],
|
| 6 |
+
"attn_implementation": "flash_attention_2",
|
| 7 |
+
"audio_assistant_slot_token_id": 9,
|
| 8 |
+
"audio_codebook_sizes": [
|
| 9 |
+
1024,
|
| 10 |
+
1024,
|
| 11 |
+
1024,
|
| 12 |
+
1024,
|
| 13 |
+
1024,
|
| 14 |
+
1024,
|
| 15 |
+
1024,
|
| 16 |
+
1024,
|
| 17 |
+
1024,
|
| 18 |
+
1024,
|
| 19 |
+
1024,
|
| 20 |
+
1024,
|
| 21 |
+
1024,
|
| 22 |
+
1024,
|
| 23 |
+
1024,
|
| 24 |
+
1024
|
| 25 |
+
],
|
| 26 |
+
"audio_end_token_id": 7,
|
| 27 |
+
"audio_pad_token_id": 1024,
|
| 28 |
+
"audio_start_token_id": 6,
|
| 29 |
+
"audio_tokenizer_pretrained_name_or_path": "OpenMOSS-Team/MOSS-Audio-Tokenizer-Nano",
|
| 30 |
+
"audio_tokenizer_sample_rate": 48000,
|
| 31 |
+
"audio_tokenizer_type": "moss-audio-tokenizer-nano",
|
| 32 |
+
"audio_user_slot_token_id": 8,
|
| 33 |
+
"audio_vocab_size": 1024,
|
| 34 |
+
"bad_words_ids": null,
|
| 35 |
+
"begin_suppress_tokens": null,
|
| 36 |
+
"bos_token_id": null,
|
| 37 |
+
"chunk_size_feed_forward": 0,
|
| 38 |
+
"cross_attention_hidden_size": null,
|
| 39 |
+
"decoder_start_token_id": null,
|
| 40 |
+
"diversity_penalty": 0.0,
|
| 41 |
+
"do_sample": false,
|
| 42 |
+
"dtype": "float32",
|
| 43 |
+
"early_stopping": false,
|
| 44 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 45 |
+
"eos_token_id": null,
|
| 46 |
+
"exponential_decay_length_penalty": null,
|
| 47 |
+
"finetuning_task": null,
|
| 48 |
+
"forced_bos_token_id": null,
|
| 49 |
+
"forced_eos_token_id": null,
|
| 50 |
+
"gpt2_config": {
|
| 51 |
+
"_name_or_path": "",
|
| 52 |
+
"activation_function": "gelu_new",
|
| 53 |
+
"add_cross_attention": false,
|
| 54 |
+
"architectures": null,
|
| 55 |
+
"attn_pdrop": 0.0,
|
| 56 |
+
"bad_words_ids": null,
|
| 57 |
+
"begin_suppress_tokens": null,
|
| 58 |
+
"bos_token_id": 1,
|
| 59 |
+
"chunk_size_feed_forward": 0,
|
| 60 |
+
"cross_attention_hidden_size": null,
|
| 61 |
+
"decoder_start_token_id": null,
|
| 62 |
+
"diversity_penalty": 0.0,
|
| 63 |
+
"do_sample": false,
|
| 64 |
+
"dtype": null,
|
| 65 |
+
"early_stopping": false,
|
| 66 |
+
"embd_pdrop": 0.0,
|
| 67 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 68 |
+
"eos_token_id": 2,
|
| 69 |
+
"exponential_decay_length_penalty": null,
|
| 70 |
+
"finetuning_task": null,
|
| 71 |
+
"forced_bos_token_id": null,
|
| 72 |
+
"forced_eos_token_id": null,
|
| 73 |
+
"id2label": {
|
| 74 |
+
"0": "LABEL_0",
|
| 75 |
+
"1": "LABEL_1"
|
| 76 |
+
},
|
| 77 |
+
"initializer_range": 0.02,
|
| 78 |
+
"is_decoder": false,
|
| 79 |
+
"is_encoder_decoder": false,
|
| 80 |
+
"label2id": {
|
| 81 |
+
"LABEL_0": 0,
|
| 82 |
+
"LABEL_1": 1
|
| 83 |
+
},
|
| 84 |
+
"layer_norm_epsilon": 1e-05,
|
| 85 |
+
"length_penalty": 1.0,
|
| 86 |
+
"max_length": 20,
|
| 87 |
+
"min_length": 0,
|
| 88 |
+
"model_type": "gpt2",
|
| 89 |
+
"n_ctx": 32768,
|
| 90 |
+
"n_embd": 768,
|
| 91 |
+
"n_head": 12,
|
| 92 |
+
"n_inner": 3072,
|
| 93 |
+
"n_layer": 12,
|
| 94 |
+
"n_positions": 32768,
|
| 95 |
+
"no_repeat_ngram_size": 0,
|
| 96 |
+
"num_beam_groups": 1,
|
| 97 |
+
"num_beams": 1,
|
| 98 |
+
"num_return_sequences": 1,
|
| 99 |
+
"output_attentions": false,
|
| 100 |
+
"output_hidden_states": false,
|
| 101 |
+
"output_scores": false,
|
| 102 |
+
"pad_token_id": 3,
|
| 103 |
+
"position_embedding_type": "rope",
|
| 104 |
+
"prefix": null,
|
| 105 |
+
"problem_type": null,
|
| 106 |
+
"pruned_heads": {},
|
| 107 |
+
"remove_invalid_values": false,
|
| 108 |
+
"reorder_and_upcast_attn": false,
|
| 109 |
+
"repetition_penalty": 1.0,
|
| 110 |
+
"resid_pdrop": 0.0,
|
| 111 |
+
"return_dict": true,
|
| 112 |
+
"return_dict_in_generate": false,
|
| 113 |
+
"rope_base": 10000.0,
|
| 114 |
+
"scale_attn_by_inverse_layer_idx": false,
|
| 115 |
+
"scale_attn_weights": true,
|
| 116 |
+
"sep_token_id": null,
|
| 117 |
+
"summary_activation": null,
|
| 118 |
+
"summary_first_dropout": 0.1,
|
| 119 |
+
"summary_proj_to_labels": true,
|
| 120 |
+
"summary_type": "cls_index",
|
| 121 |
+
"summary_use_proj": true,
|
| 122 |
+
"suppress_tokens": null,
|
| 123 |
+
"task_specific_params": null,
|
| 124 |
+
"temperature": 1.0,
|
| 125 |
+
"tf_legacy_loss": false,
|
| 126 |
+
"tie_encoder_decoder": false,
|
| 127 |
+
"tie_word_embeddings": true,
|
| 128 |
+
"tokenizer_class": null,
|
| 129 |
+
"top_k": 50,
|
| 130 |
+
"top_p": 1.0,
|
| 131 |
+
"torchscript": false,
|
| 132 |
+
"transformers_version": "4.57.1",
|
| 133 |
+
"typical_p": 1.0,
|
| 134 |
+
"use_bfloat16": false,
|
| 135 |
+
"use_cache": true,
|
| 136 |
+
"vocab_size": 16384
|
| 137 |
+
},
|
| 138 |
+
"hidden_size": 768,
|
| 139 |
+
"id2label": {
|
| 140 |
+
"0": "LABEL_0",
|
| 141 |
+
"1": "LABEL_1"
|
| 142 |
+
},
|
| 143 |
+
"im_end_token_id": 5,
|
| 144 |
+
"im_start_token_id": 4,
|
| 145 |
+
"initializer_range": 0.02,
|
| 146 |
+
"is_decoder": false,
|
| 147 |
+
"is_encoder_decoder": false,
|
| 148 |
+
"label2id": {
|
| 149 |
+
"LABEL_0": 0,
|
| 150 |
+
"LABEL_1": 1
|
| 151 |
+
},
|
| 152 |
+
"length_penalty": 1.0,
|
| 153 |
+
"local_transformer_attn_implementation": "flash_attention_2",
|
| 154 |
+
"local_transformer_layers": 1,
|
| 155 |
+
"max_length": 20,
|
| 156 |
+
"max_position_embeddings": 32768,
|
| 157 |
+
"min_length": 0,
|
| 158 |
+
"model_architecture": "global_local_transformer",
|
| 159 |
+
"model_type": "nano_tts",
|
| 160 |
+
"n_vq": 16,
|
| 161 |
+
"no_repeat_ngram_size": 0,
|
| 162 |
+
"num_beam_groups": 1,
|
| 163 |
+
"num_beams": 1,
|
| 164 |
+
"num_return_sequences": 1,
|
| 165 |
+
"output_attentions": false,
|
| 166 |
+
"output_hidden_states": false,
|
| 167 |
+
"output_scores": false,
|
| 168 |
+
"pad_token_id": 3,
|
| 169 |
+
"prefix": null,
|
| 170 |
+
"problem_type": null,
|
| 171 |
+
"pruned_heads": {},
|
| 172 |
+
"remove_invalid_values": false,
|
| 173 |
+
"repetition_penalty": 1.0,
|
| 174 |
+
"return_dict": true,
|
| 175 |
+
"return_dict_in_generate": false,
|
| 176 |
+
"sep_token_id": null,
|
| 177 |
+
"suppress_tokens": null,
|
| 178 |
+
"task_specific_params": null,
|
| 179 |
+
"temperature": 1.0,
|
| 180 |
+
"tf_legacy_loss": false,
|
| 181 |
+
"tie_encoder_decoder": false,
|
| 182 |
+
"tie_word_embeddings": true,
|
| 183 |
+
"tokenizer_class": "NanoTTSSentencePieceTokenizer",
|
| 184 |
+
"tokenizer_use_fast": false,
|
| 185 |
+
"top_k": 50,
|
| 186 |
+
"top_p": 1.0,
|
| 187 |
+
"torchscript": false,
|
| 188 |
+
"transformers_version": "4.57.1",
|
| 189 |
+
"typical_p": 1.0,
|
| 190 |
+
"use_bfloat16": false,
|
| 191 |
+
"vocab_size": 16384,
|
| 192 |
+
"auto_map": {
|
| 193 |
+
"AutoConfig": "configuration_nanotts.NanoTTSConfig",
|
| 194 |
+
"AutoModel": "modeling_nanotts_global_local.NanoTTSGlobalLocalForCausalLM",
|
| 195 |
+
"AutoModelForCausalLM": "modeling_nanotts_global_local.NanoTTSGlobalLocalForCausalLM"
|
| 196 |
+
}
|
| 197 |
+
}
|
configuration_nanotts.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
from typing import Any, Dict, Optional, Union
|
| 3 |
+
|
| 4 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 5 |
+
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class NanoTTSConfig(PretrainedConfig):
|
| 9 |
+
model_type = "nano_tts"
|
| 10 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 11 |
+
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
gpt2_config: Optional[Union[GPT2Config, Dict[str, Any]]] = None,
|
| 15 |
+
n_vq: int = 8,
|
| 16 |
+
audio_vocab_size: Optional[int] = 1024,
|
| 17 |
+
audio_codebook_sizes: Optional[list[int]] = None,
|
| 18 |
+
audio_pad_token_id: int = 1024,
|
| 19 |
+
pad_token_id: int = 151643,
|
| 20 |
+
im_start_token_id: int = 151644,
|
| 21 |
+
im_end_token_id: int = 151645,
|
| 22 |
+
audio_start_token_id: int = 151652,
|
| 23 |
+
audio_end_token_id: int = 151653,
|
| 24 |
+
audio_user_slot_token_id: int = 151654,
|
| 25 |
+
audio_assistant_slot_token_id: int = 151656,
|
| 26 |
+
tokenizer_use_fast: bool = False,
|
| 27 |
+
audio_tokenizer_type: str = "moss-audio-tokenizer-nano",
|
| 28 |
+
audio_tokenizer_pretrained_name_or_path: Optional[str] = "OpenMOSS-Team/MOSS-Audio-Tokenizer-Nano",
|
| 29 |
+
audio_tokenizer_sample_rate: int = 48000,
|
| 30 |
+
attn_implementation: str = "flash_attention_2",
|
| 31 |
+
initializer_range: float = 0.02,
|
| 32 |
+
model_architecture: str = "global_local_transformer",
|
| 33 |
+
local_transformer_layers: int = 4,
|
| 34 |
+
local_transformer_attn_implementation: Optional[str] = None,
|
| 35 |
+
**kwargs: Any,
|
| 36 |
+
) -> None:
|
| 37 |
+
if isinstance(gpt2_config, dict):
|
| 38 |
+
self.gpt2_config = GPT2Config(**gpt2_config)
|
| 39 |
+
elif gpt2_config is None:
|
| 40 |
+
self.gpt2_config = GPT2Config()
|
| 41 |
+
else:
|
| 42 |
+
self.gpt2_config = gpt2_config
|
| 43 |
+
|
| 44 |
+
self.n_vq = int(n_vq)
|
| 45 |
+
if audio_codebook_sizes is None:
|
| 46 |
+
if audio_vocab_size is None:
|
| 47 |
+
raise ValueError("audio_vocab_size must be set when audio_codebook_sizes is not provided.")
|
| 48 |
+
resolved_audio_codebook_sizes = [int(audio_vocab_size)] * self.n_vq
|
| 49 |
+
else:
|
| 50 |
+
resolved_audio_codebook_sizes = [int(codebook_size) for codebook_size in audio_codebook_sizes]
|
| 51 |
+
if len(resolved_audio_codebook_sizes) != self.n_vq:
|
| 52 |
+
raise ValueError(
|
| 53 |
+
"audio_codebook_sizes must have length n_vq "
|
| 54 |
+
f"(expected {self.n_vq}, got {len(resolved_audio_codebook_sizes)})."
|
| 55 |
+
)
|
| 56 |
+
if any(codebook_size <= 0 for codebook_size in resolved_audio_codebook_sizes):
|
| 57 |
+
raise ValueError("audio_codebook_sizes must contain positive integers.")
|
| 58 |
+
|
| 59 |
+
max_audio_codebook_size = max(resolved_audio_codebook_sizes)
|
| 60 |
+
if audio_vocab_size is not None and int(audio_vocab_size) < max_audio_codebook_size:
|
| 61 |
+
raise ValueError(
|
| 62 |
+
"audio_vocab_size must be >= max(audio_codebook_sizes) "
|
| 63 |
+
f"(got {audio_vocab_size}, expected at least {max_audio_codebook_size})."
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
self.audio_codebook_sizes = resolved_audio_codebook_sizes
|
| 67 |
+
self.audio_vocab_size = (
|
| 68 |
+
max_audio_codebook_size if audio_vocab_size is None else int(audio_vocab_size)
|
| 69 |
+
)
|
| 70 |
+
self.audio_pad_token_id = int(audio_pad_token_id)
|
| 71 |
+
if self.audio_pad_token_id < max_audio_codebook_size:
|
| 72 |
+
raise ValueError(
|
| 73 |
+
"audio_pad_token_id must be >= max(audio_codebook_sizes) so pad stays outside every codebook "
|
| 74 |
+
f"(got {self.audio_pad_token_id}, max codebook size {max_audio_codebook_size})."
|
| 75 |
+
)
|
| 76 |
+
self.pad_token_id = pad_token_id
|
| 77 |
+
self.im_start_token_id = im_start_token_id
|
| 78 |
+
self.im_end_token_id = im_end_token_id
|
| 79 |
+
self.audio_start_token_id = audio_start_token_id
|
| 80 |
+
self.audio_end_token_id = audio_end_token_id
|
| 81 |
+
self.audio_user_slot_token_id = audio_user_slot_token_id
|
| 82 |
+
self.audio_assistant_slot_token_id = audio_assistant_slot_token_id
|
| 83 |
+
self.tokenizer_use_fast = tokenizer_use_fast
|
| 84 |
+
self.audio_tokenizer_type = audio_tokenizer_type
|
| 85 |
+
self.audio_tokenizer_pretrained_name_or_path = audio_tokenizer_pretrained_name_or_path
|
| 86 |
+
self.audio_tokenizer_sample_rate = audio_tokenizer_sample_rate
|
| 87 |
+
self.attn_implementation = attn_implementation
|
| 88 |
+
self.initializer_range = initializer_range
|
| 89 |
+
self.model_architecture = model_architecture
|
| 90 |
+
self.local_transformer_layers = local_transformer_layers
|
| 91 |
+
self.local_transformer_attn_implementation = (
|
| 92 |
+
attn_implementation
|
| 93 |
+
if local_transformer_attn_implementation is None
|
| 94 |
+
else local_transformer_attn_implementation
|
| 95 |
+
)
|
| 96 |
+
self.vocab_size = self.gpt2_config.vocab_size
|
| 97 |
+
self.hidden_size = self.gpt2_config.hidden_size
|
| 98 |
+
self.max_position_embeddings = self.gpt2_config.n_positions
|
| 99 |
+
|
| 100 |
+
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
| 101 |
+
|
| 102 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 103 |
+
output = super().to_dict()
|
| 104 |
+
output["gpt2_config"] = self.gpt2_config.to_dict()
|
| 105 |
+
return output
|
gpt2_decoder.py
ADDED
|
@@ -0,0 +1,605 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.utils.checkpoint
|
| 10 |
+
from transformers.activations import ACT2FN
|
| 11 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 12 |
+
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 16 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
| 17 |
+
|
| 18 |
+
_FLASH_ATTN_AVAILABLE = True
|
| 19 |
+
except Exception:
|
| 20 |
+
flash_attn_func = None
|
| 21 |
+
flash_attn_varlen_func = None
|
| 22 |
+
pad_input = None
|
| 23 |
+
unpad_input = None
|
| 24 |
+
_FLASH_ATTN_AVAILABLE = False
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class PackedSequenceMetadata:
|
| 29 |
+
cu_seqlens: torch.Tensor
|
| 30 |
+
max_seqlen: int
|
| 31 |
+
indices: Optional[torch.Tensor] = None
|
| 32 |
+
batch_size: Optional[int] = None
|
| 33 |
+
seq_len: Optional[int] = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class NanoGPT2RotaryEmbedding(nn.Module):
|
| 37 |
+
def __init__(self, dim: int, base: float = 10000.0) -> None:
|
| 38 |
+
super().__init__()
|
| 39 |
+
if dim % 2 != 0:
|
| 40 |
+
raise ValueError(f"RoPE head_dim must be even, got {dim}")
|
| 41 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
| 42 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 43 |
+
|
| 44 |
+
def forward(
|
| 45 |
+
self,
|
| 46 |
+
position_ids: torch.LongTensor,
|
| 47 |
+
*,
|
| 48 |
+
device: torch.device,
|
| 49 |
+
dtype: torch.dtype,
|
| 50 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 51 |
+
if position_ids.ndim == 1:
|
| 52 |
+
position_ids = position_ids.unsqueeze(0)
|
| 53 |
+
freqs = torch.einsum("bs,d->bsd", position_ids.to(device=device, dtype=self.inv_freq.dtype), self.inv_freq)
|
| 54 |
+
cos = freqs.cos().repeat_interleave(2, dim=-1).unsqueeze(2).to(dtype=dtype)
|
| 55 |
+
sin = freqs.sin().repeat_interleave(2, dim=-1).unsqueeze(2).to(dtype=dtype)
|
| 56 |
+
return cos, sin
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def rotate_half(hidden_states: torch.Tensor) -> torch.Tensor:
|
| 60 |
+
even = hidden_states[..., ::2]
|
| 61 |
+
odd = hidden_states[..., 1::2]
|
| 62 |
+
return torch.stack((-odd, even), dim=-1).reshape_as(hidden_states)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def apply_rotary_pos_emb(
|
| 66 |
+
hidden_states: torch.Tensor,
|
| 67 |
+
cos: torch.Tensor,
|
| 68 |
+
sin: torch.Tensor,
|
| 69 |
+
) -> torch.Tensor:
|
| 70 |
+
return (hidden_states * cos) + (rotate_half(hidden_states) * sin)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class NanoGPT2MLP(nn.Module):
|
| 74 |
+
def __init__(self, config: GPT2Config) -> None:
|
| 75 |
+
super().__init__()
|
| 76 |
+
hidden_size = int(config.hidden_size)
|
| 77 |
+
inner_size = int(config.n_inner or 4 * hidden_size)
|
| 78 |
+
self.fc_in = nn.Linear(hidden_size, inner_size)
|
| 79 |
+
self.fc_out = nn.Linear(inner_size, hidden_size)
|
| 80 |
+
self.act = ACT2FN[config.activation_function]
|
| 81 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
| 82 |
+
|
| 83 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 84 |
+
hidden_states = self.fc_in(hidden_states)
|
| 85 |
+
hidden_states = self.act(hidden_states)
|
| 86 |
+
hidden_states = self.fc_out(hidden_states)
|
| 87 |
+
return self.dropout(hidden_states)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class NanoGPT2Attention(nn.Module):
|
| 91 |
+
def __init__(self, config: GPT2Config, layer_idx: int, attn_implementation: str) -> None:
|
| 92 |
+
super().__init__()
|
| 93 |
+
hidden_size = int(config.hidden_size)
|
| 94 |
+
num_heads = int(config.num_attention_heads)
|
| 95 |
+
if hidden_size % num_heads != 0:
|
| 96 |
+
raise ValueError(f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_heads}")
|
| 97 |
+
|
| 98 |
+
self.num_heads = num_heads
|
| 99 |
+
self.head_dim = hidden_size // num_heads
|
| 100 |
+
self.embed_dim = hidden_size
|
| 101 |
+
self.layer_idx = layer_idx
|
| 102 |
+
self.attn_implementation = attn_implementation
|
| 103 |
+
self.attn_dropout = float(config.attn_pdrop)
|
| 104 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
| 105 |
+
self.scale_attn_weights = bool(getattr(config, "scale_attn_weights", True))
|
| 106 |
+
self.scale_attn_by_inverse_layer_idx = bool(getattr(config, "scale_attn_by_inverse_layer_idx", False))
|
| 107 |
+
self.position_embedding_type = str(getattr(config, "position_embedding_type", "absolute")).lower()
|
| 108 |
+
if self.position_embedding_type not in {"absolute", "rope"}:
|
| 109 |
+
raise ValueError(f"Unsupported position_embedding_type={self.position_embedding_type!r}")
|
| 110 |
+
|
| 111 |
+
self.c_attn = nn.Linear(hidden_size, 3 * hidden_size)
|
| 112 |
+
self.c_proj = nn.Linear(hidden_size, hidden_size)
|
| 113 |
+
self.rotary_emb = None
|
| 114 |
+
if self.position_embedding_type == "rope":
|
| 115 |
+
self.rotary_emb = NanoGPT2RotaryEmbedding(
|
| 116 |
+
self.head_dim,
|
| 117 |
+
base=float(getattr(config, "rope_base", 10000.0)),
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
def _split_heads(self, tensor: torch.Tensor) -> torch.Tensor:
|
| 121 |
+
if tensor.ndim == 3:
|
| 122 |
+
batch_size, seq_len, _ = tensor.shape
|
| 123 |
+
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
| 124 |
+
if tensor.ndim == 2:
|
| 125 |
+
total_tokens, _ = tensor.shape
|
| 126 |
+
return tensor.view(total_tokens, self.num_heads, self.head_dim)
|
| 127 |
+
raise ValueError(f"Unsupported tensor rank for attention split: {tensor.ndim}")
|
| 128 |
+
|
| 129 |
+
def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
|
| 130 |
+
if tensor.ndim == 4:
|
| 131 |
+
batch_size, seq_len, _, _ = tensor.shape
|
| 132 |
+
return tensor.reshape(batch_size, seq_len, self.embed_dim)
|
| 133 |
+
if tensor.ndim == 3:
|
| 134 |
+
total_tokens, _, _ = tensor.shape
|
| 135 |
+
return tensor.reshape(total_tokens, self.embed_dim)
|
| 136 |
+
raise ValueError(f"Unsupported tensor rank for attention merge: {tensor.ndim}")
|
| 137 |
+
|
| 138 |
+
def _causal_attention_mask(
|
| 139 |
+
self,
|
| 140 |
+
attention_mask: Optional[torch.Tensor],
|
| 141 |
+
query_length: int,
|
| 142 |
+
key_length: int,
|
| 143 |
+
device: torch.device,
|
| 144 |
+
) -> torch.Tensor:
|
| 145 |
+
query_positions = torch.arange(query_length, device=device, dtype=torch.long)
|
| 146 |
+
query_positions = query_positions + max(key_length - query_length, 0)
|
| 147 |
+
key_positions = torch.arange(key_length, device=device, dtype=torch.long)
|
| 148 |
+
causal = key_positions.unsqueeze(0) <= query_positions.unsqueeze(1)
|
| 149 |
+
causal = causal.unsqueeze(0).unsqueeze(0)
|
| 150 |
+
if attention_mask is None:
|
| 151 |
+
return causal
|
| 152 |
+
key_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)
|
| 153 |
+
return causal & key_mask
|
| 154 |
+
|
| 155 |
+
def _eager_attention(
|
| 156 |
+
self,
|
| 157 |
+
query: torch.Tensor,
|
| 158 |
+
key: torch.Tensor,
|
| 159 |
+
value: torch.Tensor,
|
| 160 |
+
attention_mask: Optional[torch.Tensor],
|
| 161 |
+
) -> torch.Tensor:
|
| 162 |
+
query = query.transpose(1, 2)
|
| 163 |
+
key = key.transpose(1, 2)
|
| 164 |
+
value = value.transpose(1, 2)
|
| 165 |
+
|
| 166 |
+
scale = 1.0
|
| 167 |
+
if self.scale_attn_weights:
|
| 168 |
+
scale /= self.head_dim ** 0.5
|
| 169 |
+
if self.scale_attn_by_inverse_layer_idx:
|
| 170 |
+
scale /= float(self.layer_idx + 1)
|
| 171 |
+
|
| 172 |
+
scores = torch.matmul(query, key.transpose(-1, -2)) * scale
|
| 173 |
+
causal_mask = self._causal_attention_mask(
|
| 174 |
+
attention_mask=attention_mask,
|
| 175 |
+
query_length=query.shape[-2],
|
| 176 |
+
key_length=key.shape[-2],
|
| 177 |
+
device=query.device,
|
| 178 |
+
)
|
| 179 |
+
scores = scores.masked_fill(~causal_mask, torch.finfo(scores.dtype).min)
|
| 180 |
+
probs = torch.softmax(scores, dim=-1)
|
| 181 |
+
if self.training and self.attn_dropout > 0:
|
| 182 |
+
probs = torch.dropout(probs, self.attn_dropout, train=True)
|
| 183 |
+
output = torch.matmul(probs, value)
|
| 184 |
+
return output.transpose(1, 2).contiguous()
|
| 185 |
+
|
| 186 |
+
def _sdpa_attention(
|
| 187 |
+
self,
|
| 188 |
+
query: torch.Tensor,
|
| 189 |
+
key: torch.Tensor,
|
| 190 |
+
value: torch.Tensor,
|
| 191 |
+
attention_mask: Optional[torch.Tensor],
|
| 192 |
+
) -> torch.Tensor:
|
| 193 |
+
query = query.transpose(1, 2)
|
| 194 |
+
key = key.transpose(1, 2)
|
| 195 |
+
value = value.transpose(1, 2)
|
| 196 |
+
mask = None
|
| 197 |
+
if attention_mask is not None:
|
| 198 |
+
mask = self._causal_attention_mask(
|
| 199 |
+
attention_mask=attention_mask,
|
| 200 |
+
query_length=query.shape[-2],
|
| 201 |
+
key_length=key.shape[-2],
|
| 202 |
+
device=query.device,
|
| 203 |
+
)
|
| 204 |
+
output = torch.nn.functional.scaled_dot_product_attention(
|
| 205 |
+
query,
|
| 206 |
+
key,
|
| 207 |
+
value,
|
| 208 |
+
attn_mask=mask,
|
| 209 |
+
dropout_p=self.attn_dropout if self.training else 0.0,
|
| 210 |
+
is_causal=mask is None,
|
| 211 |
+
)
|
| 212 |
+
return output.transpose(1, 2).contiguous()
|
| 213 |
+
|
| 214 |
+
def _flash_attention(
|
| 215 |
+
self,
|
| 216 |
+
query: torch.Tensor,
|
| 217 |
+
key: torch.Tensor,
|
| 218 |
+
value: torch.Tensor,
|
| 219 |
+
attention_mask: Optional[torch.Tensor],
|
| 220 |
+
packed_metadata: Optional[PackedSequenceMetadata],
|
| 221 |
+
) -> torch.Tensor:
|
| 222 |
+
if not _FLASH_ATTN_AVAILABLE:
|
| 223 |
+
raise ImportError("flash_attn is not installed, but attn_implementation='flash_attention_2' was requested.")
|
| 224 |
+
if query.device.type != "cuda":
|
| 225 |
+
raise ValueError("flash_attention_2 requires CUDA tensors.")
|
| 226 |
+
if query.dtype not in (torch.float16, torch.bfloat16):
|
| 227 |
+
raise ValueError(
|
| 228 |
+
f"flash_attention_2 requires fp16/bf16 tensors, but received dtype={query.dtype}."
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
dropout_p = self.attn_dropout if self.training else 0.0
|
| 232 |
+
if packed_metadata is not None:
|
| 233 |
+
if packed_metadata.indices is not None:
|
| 234 |
+
query = query.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
|
| 235 |
+
key = key.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
|
| 236 |
+
value = value.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
|
| 237 |
+
output = flash_attn_varlen_func(
|
| 238 |
+
query,
|
| 239 |
+
key,
|
| 240 |
+
value,
|
| 241 |
+
packed_metadata.cu_seqlens,
|
| 242 |
+
packed_metadata.cu_seqlens,
|
| 243 |
+
packed_metadata.max_seqlen,
|
| 244 |
+
packed_metadata.max_seqlen,
|
| 245 |
+
dropout_p=dropout_p,
|
| 246 |
+
causal=True,
|
| 247 |
+
)
|
| 248 |
+
if packed_metadata.indices is None:
|
| 249 |
+
return output
|
| 250 |
+
return pad_input(
|
| 251 |
+
output,
|
| 252 |
+
packed_metadata.indices,
|
| 253 |
+
packed_metadata.batch_size,
|
| 254 |
+
packed_metadata.seq_len,
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
if attention_mask is None or bool(attention_mask.all()):
|
| 258 |
+
return flash_attn_func(
|
| 259 |
+
query,
|
| 260 |
+
key,
|
| 261 |
+
value,
|
| 262 |
+
dropout_p=dropout_p,
|
| 263 |
+
causal=True,
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
unpadded_query, indices, cu_seqlens, max_seqlen, _ = unpad_input(query, attention_mask)
|
| 267 |
+
unpadded_key, _, _, _, _ = unpad_input(key, attention_mask)
|
| 268 |
+
unpadded_value, _, _, _, _ = unpad_input(value, attention_mask)
|
| 269 |
+
output = flash_attn_varlen_func(
|
| 270 |
+
unpadded_query,
|
| 271 |
+
unpadded_key,
|
| 272 |
+
unpadded_value,
|
| 273 |
+
cu_seqlens,
|
| 274 |
+
cu_seqlens,
|
| 275 |
+
max_seqlen,
|
| 276 |
+
max_seqlen,
|
| 277 |
+
dropout_p=dropout_p,
|
| 278 |
+
causal=True,
|
| 279 |
+
)
|
| 280 |
+
return pad_input(output, indices, query.shape[0], query.shape[1])
|
| 281 |
+
|
| 282 |
+
def forward(
|
| 283 |
+
self,
|
| 284 |
+
hidden_states: torch.Tensor,
|
| 285 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 286 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 287 |
+
packed_metadata: Optional[PackedSequenceMetadata] = None,
|
| 288 |
+
layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 289 |
+
use_cache: bool = False,
|
| 290 |
+
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
| 291 |
+
qkv = self.c_attn(hidden_states)
|
| 292 |
+
query, key, value = qkv.split(self.embed_dim, dim=-1)
|
| 293 |
+
query = self._split_heads(query)
|
| 294 |
+
key = self._split_heads(key)
|
| 295 |
+
value = self._split_heads(value)
|
| 296 |
+
|
| 297 |
+
if self.rotary_emb is not None:
|
| 298 |
+
if position_ids is None:
|
| 299 |
+
raise ValueError("position_ids must be provided when position_embedding_type='rope'.")
|
| 300 |
+
cos, sin = self.rotary_emb(
|
| 301 |
+
position_ids.to(device=query.device),
|
| 302 |
+
device=query.device,
|
| 303 |
+
dtype=query.dtype,
|
| 304 |
+
)
|
| 305 |
+
query = apply_rotary_pos_emb(query, cos, sin)
|
| 306 |
+
key = apply_rotary_pos_emb(key, cos, sin)
|
| 307 |
+
|
| 308 |
+
if layer_past is not None:
|
| 309 |
+
past_key, past_value = layer_past
|
| 310 |
+
key = torch.cat([past_key.to(device=key.device, dtype=key.dtype), key], dim=1)
|
| 311 |
+
value = torch.cat([past_value.to(device=value.device, dtype=value.dtype), value], dim=1)
|
| 312 |
+
|
| 313 |
+
present = (key, value) if use_cache else None
|
| 314 |
+
|
| 315 |
+
if self.attn_implementation == "flash_attention_2" and layer_past is None:
|
| 316 |
+
attn_output = self._flash_attention(
|
| 317 |
+
query=query,
|
| 318 |
+
key=key,
|
| 319 |
+
value=value,
|
| 320 |
+
attention_mask=attention_mask,
|
| 321 |
+
packed_metadata=packed_metadata,
|
| 322 |
+
)
|
| 323 |
+
elif self.attn_implementation == "sdpa":
|
| 324 |
+
attn_output = self._sdpa_attention(
|
| 325 |
+
query=query,
|
| 326 |
+
key=key,
|
| 327 |
+
value=value,
|
| 328 |
+
attention_mask=attention_mask,
|
| 329 |
+
)
|
| 330 |
+
else:
|
| 331 |
+
attn_output = self._eager_attention(
|
| 332 |
+
query=query,
|
| 333 |
+
key=key,
|
| 334 |
+
value=value,
|
| 335 |
+
attention_mask=attention_mask,
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
attn_output = self._merge_heads(attn_output)
|
| 339 |
+
attn_output = self.c_proj(attn_output)
|
| 340 |
+
return self.resid_dropout(attn_output), present
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
class NanoGPT2Block(nn.Module):
|
| 344 |
+
def __init__(self, config: GPT2Config, layer_idx: int, attn_implementation: str) -> None:
|
| 345 |
+
super().__init__()
|
| 346 |
+
hidden_size = int(config.hidden_size)
|
| 347 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 348 |
+
self.attn = NanoGPT2Attention(config, layer_idx=layer_idx, attn_implementation=attn_implementation)
|
| 349 |
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 350 |
+
self.mlp = NanoGPT2MLP(config)
|
| 351 |
+
|
| 352 |
+
def forward(
|
| 353 |
+
self,
|
| 354 |
+
hidden_states: torch.Tensor,
|
| 355 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 356 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 357 |
+
packed_metadata: Optional[PackedSequenceMetadata] = None,
|
| 358 |
+
layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 359 |
+
use_cache: bool = False,
|
| 360 |
+
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
| 361 |
+
attn_output, present = self.attn(
|
| 362 |
+
self.ln_1(hidden_states),
|
| 363 |
+
attention_mask=attention_mask,
|
| 364 |
+
position_ids=position_ids,
|
| 365 |
+
packed_metadata=packed_metadata,
|
| 366 |
+
layer_past=layer_past,
|
| 367 |
+
use_cache=use_cache,
|
| 368 |
+
)
|
| 369 |
+
hidden_states = hidden_states + attn_output
|
| 370 |
+
hidden_states = hidden_states + self.mlp(self.ln_2(hidden_states))
|
| 371 |
+
return hidden_states, present
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
class NanoGPT2Model(nn.Module):
|
| 375 |
+
def __init__(self, config: GPT2Config, attn_implementation: str = "eager") -> None:
|
| 376 |
+
super().__init__()
|
| 377 |
+
self.config = config
|
| 378 |
+
self.attn_implementation = attn_implementation
|
| 379 |
+
self.position_embedding_type = str(getattr(config, "position_embedding_type", "absolute")).lower()
|
| 380 |
+
if self.position_embedding_type not in {"absolute", "rope"}:
|
| 381 |
+
raise ValueError(f"Unsupported position_embedding_type={self.position_embedding_type!r}")
|
| 382 |
+
hidden_size = int(config.hidden_size)
|
| 383 |
+
self.wte = nn.Embedding(config.vocab_size, hidden_size)
|
| 384 |
+
self.wpe = nn.Embedding(config.n_positions, hidden_size) if self.position_embedding_type == "absolute" else nn.Identity()
|
| 385 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
| 386 |
+
self.h = nn.ModuleList(
|
| 387 |
+
[NanoGPT2Block(config, layer_idx=index, attn_implementation=attn_implementation) for index in range(config.n_layer)]
|
| 388 |
+
)
|
| 389 |
+
self.ln_f = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 390 |
+
self.gradient_checkpointing = False
|
| 391 |
+
self._reset_parameters()
|
| 392 |
+
|
| 393 |
+
def _reset_parameters(self) -> None:
|
| 394 |
+
init_std = float(self.config.initializer_range)
|
| 395 |
+
for module in self.modules():
|
| 396 |
+
if isinstance(module, nn.Linear):
|
| 397 |
+
nn.init.normal_(module.weight, mean=0.0, std=init_std)
|
| 398 |
+
if module.bias is not None:
|
| 399 |
+
nn.init.zeros_(module.bias)
|
| 400 |
+
elif isinstance(module, nn.Embedding):
|
| 401 |
+
nn.init.normal_(module.weight, mean=0.0, std=init_std)
|
| 402 |
+
elif isinstance(module, nn.LayerNorm):
|
| 403 |
+
nn.init.ones_(module.weight)
|
| 404 |
+
nn.init.zeros_(module.bias)
|
| 405 |
+
|
| 406 |
+
@staticmethod
|
| 407 |
+
def _normalize_num_sequences(
|
| 408 |
+
cu_seqlens: torch.Tensor,
|
| 409 |
+
num_sequences: Optional[torch.Tensor],
|
| 410 |
+
device: torch.device,
|
| 411 |
+
) -> torch.Tensor:
|
| 412 |
+
if cu_seqlens.ndim == 1:
|
| 413 |
+
cu_seqlens = cu_seqlens.unsqueeze(0)
|
| 414 |
+
if num_sequences is None:
|
| 415 |
+
counts = []
|
| 416 |
+
for boundary in cu_seqlens:
|
| 417 |
+
diffs = boundary[1:] - boundary[:-1]
|
| 418 |
+
counts.append(int((diffs > 0).sum().item()))
|
| 419 |
+
return torch.tensor(counts, dtype=torch.int32, device=device)
|
| 420 |
+
if num_sequences.ndim == 0:
|
| 421 |
+
return num_sequences.unsqueeze(0)
|
| 422 |
+
return num_sequences
|
| 423 |
+
|
| 424 |
+
@staticmethod
|
| 425 |
+
def build_packed_position_ids(
|
| 426 |
+
attention_mask: Optional[torch.Tensor],
|
| 427 |
+
cu_seqlens: torch.Tensor,
|
| 428 |
+
num_sequences: Optional[torch.Tensor],
|
| 429 |
+
) -> torch.Tensor:
|
| 430 |
+
if cu_seqlens.ndim == 1:
|
| 431 |
+
cu_seqlens = cu_seqlens.unsqueeze(0)
|
| 432 |
+
batch_size, seq_len = cu_seqlens.shape[0], cu_seqlens.shape[1] - 1
|
| 433 |
+
device = cu_seqlens.device
|
| 434 |
+
position_ids = torch.zeros((batch_size, seq_len), dtype=torch.long, device=device)
|
| 435 |
+
counts = NanoGPT2Model._normalize_num_sequences(cu_seqlens, num_sequences, device=device)
|
| 436 |
+
for batch_index in range(batch_size):
|
| 437 |
+
sequence_count = int(counts[batch_index].item())
|
| 438 |
+
boundaries = cu_seqlens[batch_index, : sequence_count + 1].tolist()
|
| 439 |
+
for start, end in zip(boundaries[:-1], boundaries[1:]):
|
| 440 |
+
start = int(start)
|
| 441 |
+
end = int(end)
|
| 442 |
+
if end > start:
|
| 443 |
+
position_ids[batch_index, start:end] = torch.arange(end - start, device=device)
|
| 444 |
+
if attention_mask is not None:
|
| 445 |
+
position_ids = position_ids * attention_mask.to(dtype=position_ids.dtype)
|
| 446 |
+
return position_ids
|
| 447 |
+
|
| 448 |
+
@staticmethod
|
| 449 |
+
def build_packed_metadata(
|
| 450 |
+
hidden_states: torch.Tensor,
|
| 451 |
+
cu_seqlens: torch.Tensor,
|
| 452 |
+
num_sequences: Optional[torch.Tensor],
|
| 453 |
+
) -> PackedSequenceMetadata:
|
| 454 |
+
if cu_seqlens.ndim == 1:
|
| 455 |
+
cu_seqlens = cu_seqlens.unsqueeze(0)
|
| 456 |
+
device = hidden_states.device
|
| 457 |
+
counts = NanoGPT2Model._normalize_num_sequences(cu_seqlens, num_sequences, device=device)
|
| 458 |
+
flat_indices = []
|
| 459 |
+
cumulative = [0]
|
| 460 |
+
max_seqlen = 0
|
| 461 |
+
seq_len = hidden_states.shape[1]
|
| 462 |
+
|
| 463 |
+
for batch_index in range(hidden_states.shape[0]):
|
| 464 |
+
sequence_count = int(counts[batch_index].item())
|
| 465 |
+
boundaries = cu_seqlens[batch_index, : sequence_count + 1].tolist()
|
| 466 |
+
for start, end in zip(boundaries[:-1], boundaries[1:]):
|
| 467 |
+
start = int(start)
|
| 468 |
+
end = int(end)
|
| 469 |
+
if end <= start:
|
| 470 |
+
continue
|
| 471 |
+
segment_indices = batch_index * seq_len + torch.arange(start, end, device=device)
|
| 472 |
+
flat_indices.append(segment_indices)
|
| 473 |
+
cumulative.append(cumulative[-1] + (end - start))
|
| 474 |
+
max_seqlen = max(max_seqlen, end - start)
|
| 475 |
+
|
| 476 |
+
if not flat_indices:
|
| 477 |
+
raise ValueError("cu_seqlens did not describe any non-empty packed sequences.")
|
| 478 |
+
|
| 479 |
+
indices = torch.cat(flat_indices, dim=0)
|
| 480 |
+
return PackedSequenceMetadata(
|
| 481 |
+
cu_seqlens=torch.tensor(cumulative, dtype=torch.int32, device=device),
|
| 482 |
+
max_seqlen=max_seqlen,
|
| 483 |
+
indices=indices,
|
| 484 |
+
batch_size=hidden_states.shape[0],
|
| 485 |
+
seq_len=hidden_states.shape[1],
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
def forward(
|
| 489 |
+
self,
|
| 490 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 491 |
+
past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
| 492 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 493 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 494 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 495 |
+
use_cache: Optional[bool] = None,
|
| 496 |
+
output_attentions: Optional[bool] = None,
|
| 497 |
+
output_hidden_states: Optional[bool] = None,
|
| 498 |
+
return_dict: bool = True,
|
| 499 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 500 |
+
num_sequences: Optional[torch.Tensor] = None,
|
| 501 |
+
) -> BaseModelOutputWithPast:
|
| 502 |
+
del input_ids, output_attentions
|
| 503 |
+
|
| 504 |
+
if inputs_embeds is None:
|
| 505 |
+
raise ValueError("inputs_embeds must be provided.")
|
| 506 |
+
|
| 507 |
+
use_cache = bool(use_cache)
|
| 508 |
+
if use_cache and cu_seqlens is not None:
|
| 509 |
+
raise ValueError("use_cache=True is not supported together with cu_seqlens packing.")
|
| 510 |
+
|
| 511 |
+
hidden_states = inputs_embeds
|
| 512 |
+
if attention_mask is None:
|
| 513 |
+
attention_mask = torch.ones(hidden_states.shape[:2], dtype=torch.bool, device=hidden_states.device)
|
| 514 |
+
else:
|
| 515 |
+
attention_mask = attention_mask.to(dtype=torch.bool, device=hidden_states.device)
|
| 516 |
+
query_attention_mask = attention_mask[:, -hidden_states.shape[1] :]
|
| 517 |
+
|
| 518 |
+
packed_metadata = None
|
| 519 |
+
if position_ids is None:
|
| 520 |
+
if cu_seqlens is not None:
|
| 521 |
+
position_ids = self.build_packed_position_ids(
|
| 522 |
+
attention_mask=attention_mask,
|
| 523 |
+
cu_seqlens=cu_seqlens.to(device=hidden_states.device),
|
| 524 |
+
num_sequences=num_sequences.to(device=hidden_states.device) if num_sequences is not None else None,
|
| 525 |
+
)
|
| 526 |
+
elif attention_mask is not None:
|
| 527 |
+
position_ids = attention_mask.long().cumsum(dim=-1) - 1
|
| 528 |
+
position_ids = position_ids.masked_fill(~attention_mask, 0)
|
| 529 |
+
position_ids = position_ids[:, -hidden_states.shape[1] :]
|
| 530 |
+
else:
|
| 531 |
+
past_length = 0
|
| 532 |
+
if past_key_values is not None and len(past_key_values) > 0:
|
| 533 |
+
past_length = past_key_values[0][0].shape[1]
|
| 534 |
+
position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device, dtype=torch.long)
|
| 535 |
+
position_ids = position_ids + past_length
|
| 536 |
+
position_ids = position_ids.unsqueeze(0).expand(hidden_states.shape[0], -1)
|
| 537 |
+
|
| 538 |
+
if cu_seqlens is not None and self.attn_implementation == "flash_attention_2":
|
| 539 |
+
packed_metadata = self.build_packed_metadata(
|
| 540 |
+
hidden_states=hidden_states,
|
| 541 |
+
cu_seqlens=cu_seqlens.to(device=hidden_states.device),
|
| 542 |
+
num_sequences=num_sequences.to(device=hidden_states.device) if num_sequences is not None else None,
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
if self.position_embedding_type == "absolute":
|
| 546 |
+
hidden_states = hidden_states + self.wpe(position_ids)
|
| 547 |
+
hidden_states = self.drop(hidden_states)
|
| 548 |
+
hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
|
| 549 |
+
|
| 550 |
+
all_hidden_states = () if output_hidden_states else None
|
| 551 |
+
presents = [] if use_cache else None
|
| 552 |
+
for layer_index, block in enumerate(self.h):
|
| 553 |
+
if output_hidden_states:
|
| 554 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 555 |
+
|
| 556 |
+
if self.gradient_checkpointing and self.training:
|
| 557 |
+
if use_cache:
|
| 558 |
+
raise ValueError("use_cache=True is not supported when gradient checkpointing is enabled during training.")
|
| 559 |
+
|
| 560 |
+
def custom_forward(*inputs):
|
| 561 |
+
output, _ = block(
|
| 562 |
+
inputs[0],
|
| 563 |
+
attention_mask=inputs[1],
|
| 564 |
+
position_ids=inputs[2],
|
| 565 |
+
packed_metadata=packed_metadata,
|
| 566 |
+
layer_past=None,
|
| 567 |
+
use_cache=False,
|
| 568 |
+
)
|
| 569 |
+
return output
|
| 570 |
+
|
| 571 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 572 |
+
custom_forward,
|
| 573 |
+
hidden_states,
|
| 574 |
+
attention_mask,
|
| 575 |
+
position_ids,
|
| 576 |
+
use_reentrant=False,
|
| 577 |
+
)
|
| 578 |
+
present = None
|
| 579 |
+
else:
|
| 580 |
+
hidden_states, present = block(
|
| 581 |
+
hidden_states,
|
| 582 |
+
attention_mask=attention_mask,
|
| 583 |
+
position_ids=position_ids,
|
| 584 |
+
packed_metadata=packed_metadata,
|
| 585 |
+
layer_past=None if past_key_values is None else past_key_values[layer_index],
|
| 586 |
+
use_cache=use_cache,
|
| 587 |
+
)
|
| 588 |
+
hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
|
| 589 |
+
if presents is not None:
|
| 590 |
+
presents.append(present)
|
| 591 |
+
|
| 592 |
+
hidden_states = self.ln_f(hidden_states)
|
| 593 |
+
hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
|
| 594 |
+
if output_hidden_states:
|
| 595 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 596 |
+
|
| 597 |
+
if not return_dict:
|
| 598 |
+
return (hidden_states, tuple(presents) if presents is not None else None, all_hidden_states, None)
|
| 599 |
+
|
| 600 |
+
return BaseModelOutputWithPast(
|
| 601 |
+
last_hidden_state=hidden_states,
|
| 602 |
+
past_key_values=tuple(presents) if presents is not None else None,
|
| 603 |
+
hidden_states=all_hidden_states,
|
| 604 |
+
attentions=None,
|
| 605 |
+
)
|
modeling_nanotts_global_local.py
ADDED
|
@@ -0,0 +1,1757 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
from contextlib import nullcontext
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, Optional, Union
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torchaudio
|
| 16 |
+
from transformers import AutoModel, AutoTokenizer
|
| 17 |
+
from transformers.modeling_outputs import ModelOutput
|
| 18 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 19 |
+
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
| 20 |
+
|
| 21 |
+
from .configuration_nanotts import NanoTTSConfig
|
| 22 |
+
from .gpt2_decoder import NanoGPT2Block, NanoGPT2Model
|
| 23 |
+
from .prompting import (
|
| 24 |
+
build_assistant_prompt_prefix,
|
| 25 |
+
build_prompt_token_ids,
|
| 26 |
+
build_user_prompt_after_reference,
|
| 27 |
+
build_user_prompt_prefix,
|
| 28 |
+
)
|
| 29 |
+
from .tokenization_nanotts_sentencepiece import NanoTTSSentencePieceTokenizer
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
|
| 33 |
+
class NanoTTSOutput(ModelOutput):
|
| 34 |
+
global_hidden_states: Optional[torch.FloatTensor] = None
|
| 35 |
+
past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None
|
| 36 |
+
hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 37 |
+
attentions: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
|
| 41 |
+
class NanoTTSGenerationOutput(ModelOutput):
|
| 42 |
+
audio_token_ids: torch.LongTensor
|
| 43 |
+
prompt_input_ids: Optional[torch.LongTensor] = None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
MOSS_AUDIO_TOKENIZER_NANO_TYPE = "moss-audio-tokenizer-nano"
|
| 47 |
+
DEFAULT_MOSS_AUDIO_TOKENIZER_PRETRAINED_NAME_OR_PATH = "OpenMOSS-Team/MOSS-Audio-Tokenizer-Nano"
|
| 48 |
+
DEFAULT_VOICE_CLONE_MAX_TEXT_TOKENS = 50
|
| 49 |
+
DEFAULT_VOICE_CLONE_MAX_MEMORY_PER_SAMPLE_GB = 1.0
|
| 50 |
+
DEFAULT_VOICE_CLONE_INTER_CHUNK_PAUSE_SHORT_SECONDS = 0.40
|
| 51 |
+
DEFAULT_VOICE_CLONE_INTER_CHUNK_PAUSE_LONG_SECONDS = 0.24
|
| 52 |
+
_SENTENCE_END_PUNCTUATION = frozenset(".!?。!?;;")
|
| 53 |
+
_CLAUSE_SPLIT_PUNCTUATION = frozenset(",,、;;::")
|
| 54 |
+
_CLOSING_PUNCTUATION = frozenset("\"'”’)]})】》」』")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class NanoTTSPreTrainedModel(PreTrainedModel):
|
| 58 |
+
config_class = NanoTTSConfig
|
| 59 |
+
base_model_prefix = "transformer"
|
| 60 |
+
supports_gradient_checkpointing = False
|
| 61 |
+
_no_split_modules = ["NanoGPT2Block"]
|
| 62 |
+
_supports_flash_attn_2 = True
|
| 63 |
+
_supports_sdpa = True
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class NanoTTSGlobalLocalForCausalLM(NanoTTSPreTrainedModel):
|
| 67 |
+
_keys_to_ignore_on_load_unexpected = [r"local_transformer\.wte\.weight"]
|
| 68 |
+
|
| 69 |
+
def __init__(self, config: NanoTTSConfig) -> None:
|
| 70 |
+
super().__init__(config)
|
| 71 |
+
config.gpt2_config.pad_token_id = config.pad_token_id
|
| 72 |
+
config.gpt2_config._attn_implementation = config.attn_implementation
|
| 73 |
+
|
| 74 |
+
self.transformer = NanoGPT2Model(
|
| 75 |
+
config.gpt2_config,
|
| 76 |
+
attn_implementation=config.attn_implementation,
|
| 77 |
+
)
|
| 78 |
+
hidden_size = config.gpt2_config.hidden_size
|
| 79 |
+
init_std = config.gpt2_config.initializer_range
|
| 80 |
+
|
| 81 |
+
self.audio_embeddings = nn.ModuleList(
|
| 82 |
+
[
|
| 83 |
+
nn.Embedding(int(config.audio_codebook_sizes[index]), hidden_size)
|
| 84 |
+
for index in range(config.n_vq)
|
| 85 |
+
]
|
| 86 |
+
)
|
| 87 |
+
self.text_lm_head = nn.Linear(hidden_size, config.gpt2_config.vocab_size, bias=False)
|
| 88 |
+
self.audio_lm_heads = nn.ModuleList(
|
| 89 |
+
[
|
| 90 |
+
nn.Linear(hidden_size, int(config.audio_codebook_sizes[index]), bias=False)
|
| 91 |
+
for index in range(config.n_vq)
|
| 92 |
+
]
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
local_gpt2_config = config.gpt2_config.to_dict()
|
| 96 |
+
local_gpt2_config["n_layer"] = int(config.local_transformer_layers)
|
| 97 |
+
local_gpt2_config["n_positions"] = config.n_vq + 1
|
| 98 |
+
local_gpt2_config["n_ctx"] = config.n_vq + 1
|
| 99 |
+
self.local_transformer = NanoGPT2Model(
|
| 100 |
+
GPT2Config(**local_gpt2_config),
|
| 101 |
+
attn_implementation=str(config.local_transformer_attn_implementation),
|
| 102 |
+
)
|
| 103 |
+
self.local_transformer.wte = nn.Identity()
|
| 104 |
+
|
| 105 |
+
for module in list(self.audio_embeddings) + [self.text_lm_head] + list(self.audio_lm_heads):
|
| 106 |
+
if hasattr(module, "weight") and module.weight is not None:
|
| 107 |
+
nn.init.normal_(module.weight, mean=0.0, std=init_std)
|
| 108 |
+
|
| 109 |
+
self._tied_weights_keys = tuple(self.all_tied_weights_keys.keys())
|
| 110 |
+
self.tie_weights()
|
| 111 |
+
|
| 112 |
+
@property
|
| 113 |
+
def all_tied_weights_keys(self) -> dict[str, str]:
|
| 114 |
+
tied_weights = {"text_lm_head.weight": "transformer.wte.weight"}
|
| 115 |
+
tied_weights.update(
|
| 116 |
+
{
|
| 117 |
+
f"audio_lm_heads.{index}.weight": f"audio_embeddings.{index}.weight"
|
| 118 |
+
for index in range(self.config.n_vq)
|
| 119 |
+
}
|
| 120 |
+
)
|
| 121 |
+
return tied_weights
|
| 122 |
+
|
| 123 |
+
def tie_weights(self, *args, **kwargs) -> None:
|
| 124 |
+
del args, kwargs
|
| 125 |
+
self.text_lm_head.weight = self.transformer.wte.weight
|
| 126 |
+
for embedding, lm_head in zip(self.audio_embeddings, self.audio_lm_heads):
|
| 127 |
+
lm_head.weight = embedding.weight
|
| 128 |
+
|
| 129 |
+
def get_input_embeddings(self) -> nn.Embedding:
|
| 130 |
+
return self.transformer.wte
|
| 131 |
+
|
| 132 |
+
def set_input_embeddings(self, value: nn.Embedding) -> None:
|
| 133 |
+
self.transformer.wte = value
|
| 134 |
+
self.tie_weights()
|
| 135 |
+
|
| 136 |
+
def _build_inputs_embeds(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
|
| 137 |
+
if input_ids.ndim != 3 or input_ids.shape[-1] != self.config.n_vq + 1:
|
| 138 |
+
raise ValueError(
|
| 139 |
+
f"Expected input_ids shape [batch, seq, {self.config.n_vq + 1}], got {tuple(input_ids.shape)}"
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
text_ids = input_ids[..., 0]
|
| 143 |
+
inputs_embeds = self.transformer.wte(text_ids)
|
| 144 |
+
|
| 145 |
+
for channel_index, embedding in enumerate(self.audio_embeddings):
|
| 146 |
+
channel_ids = input_ids[..., channel_index + 1]
|
| 147 |
+
valid_mask = channel_ids.ne(self.config.audio_pad_token_id)
|
| 148 |
+
invalid_mask = valid_mask & ((channel_ids < 0) | (channel_ids >= embedding.num_embeddings))
|
| 149 |
+
if invalid_mask.any():
|
| 150 |
+
invalid_token_ids = channel_ids[invalid_mask]
|
| 151 |
+
raise ValueError(
|
| 152 |
+
"Found out-of-range audio token ids for channel "
|
| 153 |
+
f"{channel_index}: min={int(invalid_token_ids.min().item())} "
|
| 154 |
+
f"max={int(invalid_token_ids.max().item())} "
|
| 155 |
+
f"codebook_size={embedding.num_embeddings} "
|
| 156 |
+
f"audio_pad_token_id={self.config.audio_pad_token_id}"
|
| 157 |
+
)
|
| 158 |
+
safe_ids = channel_ids.masked_fill(~valid_mask, 0)
|
| 159 |
+
audio_embeds = embedding(safe_ids)
|
| 160 |
+
audio_embeds = audio_embeds * valid_mask.unsqueeze(-1)
|
| 161 |
+
inputs_embeds = inputs_embeds + audio_embeds
|
| 162 |
+
|
| 163 |
+
return inputs_embeds
|
| 164 |
+
|
| 165 |
+
def forward(
|
| 166 |
+
self,
|
| 167 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 168 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 169 |
+
past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
| 170 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 171 |
+
use_cache: Optional[bool] = None,
|
| 172 |
+
output_attentions: Optional[bool] = None,
|
| 173 |
+
output_hidden_states: Optional[bool] = None,
|
| 174 |
+
return_dict: Optional[bool] = None,
|
| 175 |
+
**kwargs,
|
| 176 |
+
):
|
| 177 |
+
labels = kwargs.pop("labels", None)
|
| 178 |
+
if labels is not None:
|
| 179 |
+
raise NotImplementedError("This open-source package is inference-only and does not support training forward.")
|
| 180 |
+
if kwargs:
|
| 181 |
+
ignored = ", ".join(sorted(kwargs.keys()))
|
| 182 |
+
logging.debug("ignoring unsupported forward kwargs: %s", ignored)
|
| 183 |
+
|
| 184 |
+
return_dict = self.config.use_return_dict if return_dict is None else return_dict
|
| 185 |
+
if inputs_embeds is None:
|
| 186 |
+
if input_ids is None:
|
| 187 |
+
raise ValueError("Either input_ids or inputs_embeds must be provided.")
|
| 188 |
+
inputs_embeds = self._build_inputs_embeds(input_ids)
|
| 189 |
+
|
| 190 |
+
outputs = self.transformer(
|
| 191 |
+
input_ids=None,
|
| 192 |
+
past_key_values=past_key_values,
|
| 193 |
+
attention_mask=attention_mask,
|
| 194 |
+
position_ids=None,
|
| 195 |
+
inputs_embeds=inputs_embeds,
|
| 196 |
+
use_cache=use_cache,
|
| 197 |
+
output_attentions=output_attentions,
|
| 198 |
+
output_hidden_states=output_hidden_states,
|
| 199 |
+
return_dict=True,
|
| 200 |
+
cu_seqlens=None,
|
| 201 |
+
num_sequences=None,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
if not return_dict:
|
| 205 |
+
return (
|
| 206 |
+
outputs.last_hidden_state,
|
| 207 |
+
outputs.past_key_values,
|
| 208 |
+
outputs.hidden_states,
|
| 209 |
+
outputs.attentions,
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
return NanoTTSOutput(
|
| 213 |
+
global_hidden_states=outputs.last_hidden_state,
|
| 214 |
+
past_key_values=outputs.past_key_values,
|
| 215 |
+
hidden_states=outputs.hidden_states,
|
| 216 |
+
attentions=outputs.attentions,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
def _build_text_rows(
|
| 220 |
+
self,
|
| 221 |
+
token_ids: list[int],
|
| 222 |
+
device: torch.device,
|
| 223 |
+
) -> torch.LongTensor:
|
| 224 |
+
rows = torch.full(
|
| 225 |
+
(len(token_ids), self.config.n_vq + 1),
|
| 226 |
+
self.config.audio_pad_token_id,
|
| 227 |
+
dtype=torch.long,
|
| 228 |
+
device=device,
|
| 229 |
+
)
|
| 230 |
+
if token_ids:
|
| 231 |
+
rows[:, 0] = torch.tensor(token_ids, dtype=torch.long, device=device)
|
| 232 |
+
return rows
|
| 233 |
+
|
| 234 |
+
def _encode_text(self, tokenizer, text: str) -> list[int]:
|
| 235 |
+
try:
|
| 236 |
+
return list(tokenizer.encode(text, add_special_tokens=False))
|
| 237 |
+
except TypeError:
|
| 238 |
+
return list(tokenizer.encode(text))
|
| 239 |
+
|
| 240 |
+
@staticmethod
|
| 241 |
+
def _contains_cjk(text: str) -> bool:
|
| 242 |
+
return any(
|
| 243 |
+
"\u4e00" <= ch <= "\u9fff"
|
| 244 |
+
or "\u3400" <= ch <= "\u4dbf"
|
| 245 |
+
or "\u3040" <= ch <= "\u30ff"
|
| 246 |
+
or "\uac00" <= ch <= "\ud7af"
|
| 247 |
+
for ch in str(text)
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
@staticmethod
|
| 251 |
+
def _prepare_text_for_sentence_chunking(text: str) -> str:
|
| 252 |
+
normalized_text = str(text).strip()
|
| 253 |
+
if normalized_text == "":
|
| 254 |
+
raise ValueError("Text prompt cannot be empty.")
|
| 255 |
+
|
| 256 |
+
normalized_text = normalized_text.replace("\n", " ").replace("\r", " ")
|
| 257 |
+
while " " in normalized_text:
|
| 258 |
+
normalized_text = normalized_text.replace(" ", " ")
|
| 259 |
+
|
| 260 |
+
contains_cjk = NanoTTSGlobalLocalForCausalLM._contains_cjk(normalized_text)
|
| 261 |
+
if contains_cjk:
|
| 262 |
+
if normalized_text[-1] not in _SENTENCE_END_PUNCTUATION:
|
| 263 |
+
normalized_text = normalized_text + "。"
|
| 264 |
+
return normalized_text
|
| 265 |
+
|
| 266 |
+
if not normalized_text[0].isupper():
|
| 267 |
+
normalized_text = normalized_text[0].upper() + normalized_text[1:]
|
| 268 |
+
if normalized_text[-1].isalnum():
|
| 269 |
+
normalized_text = normalized_text + "."
|
| 270 |
+
if len(normalized_text.split()) < 5:
|
| 271 |
+
normalized_text = " " * 8 + normalized_text
|
| 272 |
+
return normalized_text
|
| 273 |
+
|
| 274 |
+
@staticmethod
|
| 275 |
+
def _split_text_by_punctuation(text: str, punctuation: set[str] | frozenset[str]) -> list[str]:
|
| 276 |
+
sentences: list[str] = []
|
| 277 |
+
current_chars: list[str] = []
|
| 278 |
+
text = str(text)
|
| 279 |
+
index = 0
|
| 280 |
+
while index < len(text):
|
| 281 |
+
char = text[index]
|
| 282 |
+
current_chars.append(char)
|
| 283 |
+
if char in punctuation:
|
| 284 |
+
lookahead = index + 1
|
| 285 |
+
while lookahead < len(text) and text[lookahead] in _CLOSING_PUNCTUATION:
|
| 286 |
+
current_chars.append(text[lookahead])
|
| 287 |
+
lookahead += 1
|
| 288 |
+
sentence = "".join(current_chars).strip()
|
| 289 |
+
if sentence:
|
| 290 |
+
sentences.append(sentence)
|
| 291 |
+
current_chars = []
|
| 292 |
+
while lookahead < len(text) and text[lookahead].isspace():
|
| 293 |
+
lookahead += 1
|
| 294 |
+
index = lookahead
|
| 295 |
+
continue
|
| 296 |
+
index += 1
|
| 297 |
+
|
| 298 |
+
tail = "".join(current_chars).strip()
|
| 299 |
+
if tail:
|
| 300 |
+
sentences.append(tail)
|
| 301 |
+
return sentences
|
| 302 |
+
|
| 303 |
+
def _count_text_tokens(self, text_tokenizer, text: str) -> int:
|
| 304 |
+
return len(self._encode_text(text_tokenizer, text))
|
| 305 |
+
|
| 306 |
+
def _split_text_by_token_budget(
|
| 307 |
+
self,
|
| 308 |
+
text_tokenizer,
|
| 309 |
+
text: str,
|
| 310 |
+
max_tokens: int,
|
| 311 |
+
) -> list[str]:
|
| 312 |
+
remaining_text = str(text).strip()
|
| 313 |
+
if remaining_text == "":
|
| 314 |
+
return []
|
| 315 |
+
|
| 316 |
+
pieces: list[str] = []
|
| 317 |
+
preferred_boundary_chars = _CLAUSE_SPLIT_PUNCTUATION | _SENTENCE_END_PUNCTUATION | frozenset({" "})
|
| 318 |
+
while remaining_text:
|
| 319 |
+
if self._count_text_tokens(text_tokenizer, remaining_text) <= int(max_tokens):
|
| 320 |
+
pieces.append(remaining_text)
|
| 321 |
+
break
|
| 322 |
+
|
| 323 |
+
low = 1
|
| 324 |
+
high = len(remaining_text)
|
| 325 |
+
best_prefix_length = 1
|
| 326 |
+
while low <= high:
|
| 327 |
+
middle = (low + high) // 2
|
| 328 |
+
candidate = remaining_text[:middle].strip()
|
| 329 |
+
if not candidate:
|
| 330 |
+
low = middle + 1
|
| 331 |
+
continue
|
| 332 |
+
if self._count_text_tokens(text_tokenizer, candidate) <= int(max_tokens):
|
| 333 |
+
best_prefix_length = middle
|
| 334 |
+
low = middle + 1
|
| 335 |
+
else:
|
| 336 |
+
high = middle - 1
|
| 337 |
+
|
| 338 |
+
cut_index = best_prefix_length
|
| 339 |
+
prefix = remaining_text[:best_prefix_length]
|
| 340 |
+
preferred_index = -1
|
| 341 |
+
for scan_index in range(len(prefix) - 1, max(-1, len(prefix) - 25), -1):
|
| 342 |
+
if prefix[scan_index] in preferred_boundary_chars:
|
| 343 |
+
preferred_index = scan_index + 1
|
| 344 |
+
break
|
| 345 |
+
if preferred_index > 0:
|
| 346 |
+
cut_index = preferred_index
|
| 347 |
+
|
| 348 |
+
piece = remaining_text[:cut_index].strip()
|
| 349 |
+
if not piece:
|
| 350 |
+
piece = remaining_text[:best_prefix_length].strip()
|
| 351 |
+
cut_index = best_prefix_length
|
| 352 |
+
pieces.append(piece)
|
| 353 |
+
remaining_text = remaining_text[cut_index:].strip()
|
| 354 |
+
return pieces
|
| 355 |
+
|
| 356 |
+
@staticmethod
|
| 357 |
+
def _join_sentence_parts(left: str, right: str) -> str:
|
| 358 |
+
if not left:
|
| 359 |
+
return right
|
| 360 |
+
if not right:
|
| 361 |
+
return left
|
| 362 |
+
if NanoTTSGlobalLocalForCausalLM._contains_cjk(left) or NanoTTSGlobalLocalForCausalLM._contains_cjk(right):
|
| 363 |
+
return left + right
|
| 364 |
+
return left + " " + right
|
| 365 |
+
|
| 366 |
+
def _split_text_into_best_sentences(
|
| 367 |
+
self,
|
| 368 |
+
text_tokenizer,
|
| 369 |
+
text: str,
|
| 370 |
+
max_tokens: int,
|
| 371 |
+
) -> list[str]:
|
| 372 |
+
if int(max_tokens) <= 0:
|
| 373 |
+
return [str(text)]
|
| 374 |
+
|
| 375 |
+
prepared_text = self._prepare_text_for_sentence_chunking(text)
|
| 376 |
+
sentence_candidates = self._split_text_by_punctuation(prepared_text, punctuation=_SENTENCE_END_PUNCTUATION)
|
| 377 |
+
if not sentence_candidates:
|
| 378 |
+
sentence_candidates = [prepared_text.strip()]
|
| 379 |
+
|
| 380 |
+
sentence_slices: list[tuple[int, str]] = []
|
| 381 |
+
for sentence_text in sentence_candidates:
|
| 382 |
+
normalized_sentence = sentence_text.strip()
|
| 383 |
+
if not normalized_sentence:
|
| 384 |
+
continue
|
| 385 |
+
sentence_token_count = self._count_text_tokens(text_tokenizer, normalized_sentence)
|
| 386 |
+
if sentence_token_count <= int(max_tokens):
|
| 387 |
+
sentence_slices.append((sentence_token_count, normalized_sentence))
|
| 388 |
+
continue
|
| 389 |
+
|
| 390 |
+
clause_candidates = self._split_text_by_punctuation(
|
| 391 |
+
normalized_sentence,
|
| 392 |
+
punctuation=_CLAUSE_SPLIT_PUNCTUATION,
|
| 393 |
+
)
|
| 394 |
+
if len(clause_candidates) <= 1:
|
| 395 |
+
clause_candidates = [normalized_sentence]
|
| 396 |
+
|
| 397 |
+
for clause_text in clause_candidates:
|
| 398 |
+
normalized_clause = clause_text.strip()
|
| 399 |
+
if not normalized_clause:
|
| 400 |
+
continue
|
| 401 |
+
clause_token_count = self._count_text_tokens(text_tokenizer, normalized_clause)
|
| 402 |
+
if clause_token_count <= int(max_tokens):
|
| 403 |
+
sentence_slices.append((clause_token_count, normalized_clause))
|
| 404 |
+
continue
|
| 405 |
+
for piece in self._split_text_by_token_budget(
|
| 406 |
+
text_tokenizer=text_tokenizer,
|
| 407 |
+
text=normalized_clause,
|
| 408 |
+
max_tokens=max_tokens,
|
| 409 |
+
):
|
| 410 |
+
normalized_piece = piece.strip()
|
| 411 |
+
if normalized_piece:
|
| 412 |
+
sentence_slices.append(
|
| 413 |
+
(self._count_text_tokens(text_tokenizer, normalized_piece), normalized_piece)
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
chunks: list[str] = []
|
| 417 |
+
current_chunk = ""
|
| 418 |
+
current_chunk_token_count = 0
|
| 419 |
+
for sentence_token_count, sentence_text in sentence_slices:
|
| 420 |
+
if current_chunk == "":
|
| 421 |
+
current_chunk = sentence_text
|
| 422 |
+
current_chunk_token_count = sentence_token_count
|
| 423 |
+
continue
|
| 424 |
+
if current_chunk_token_count + sentence_token_count > int(max_tokens):
|
| 425 |
+
chunks.append(current_chunk.strip())
|
| 426 |
+
current_chunk = sentence_text
|
| 427 |
+
current_chunk_token_count = sentence_token_count
|
| 428 |
+
else:
|
| 429 |
+
current_chunk = self._join_sentence_parts(current_chunk, sentence_text)
|
| 430 |
+
current_chunk_token_count = self._count_text_tokens(text_tokenizer, current_chunk)
|
| 431 |
+
|
| 432 |
+
if current_chunk:
|
| 433 |
+
chunks.append(current_chunk.strip())
|
| 434 |
+
return chunks or [prepared_text.strip()]
|
| 435 |
+
|
| 436 |
+
@staticmethod
|
| 437 |
+
def _estimate_voice_clone_inter_chunk_pause_seconds(text_chunk: str) -> float:
|
| 438 |
+
return (
|
| 439 |
+
DEFAULT_VOICE_CLONE_INTER_CHUNK_PAUSE_SHORT_SECONDS
|
| 440 |
+
if len(str(text_chunk).strip().split()) <= 4
|
| 441 |
+
else DEFAULT_VOICE_CLONE_INTER_CHUNK_PAUSE_LONG_SECONDS
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
def _concat_voice_clone_waveform_chunks(
|
| 445 |
+
self,
|
| 446 |
+
waveform_chunks: list[torch.FloatTensor],
|
| 447 |
+
text_chunks: list[str],
|
| 448 |
+
sample_rate: int,
|
| 449 |
+
) -> torch.FloatTensor:
|
| 450 |
+
if not waveform_chunks:
|
| 451 |
+
return torch.zeros((1, 0), dtype=torch.float32)
|
| 452 |
+
if len(waveform_chunks) != len(text_chunks):
|
| 453 |
+
raise ValueError("waveform_chunks and text_chunks must have the same length.")
|
| 454 |
+
if len(waveform_chunks) == 1:
|
| 455 |
+
return waveform_chunks[0]
|
| 456 |
+
|
| 457 |
+
segments: list[torch.FloatTensor] = []
|
| 458 |
+
for chunk_index, waveform_chunk in enumerate(waveform_chunks):
|
| 459 |
+
segments.append(waveform_chunk)
|
| 460 |
+
if chunk_index >= len(waveform_chunks) - 1:
|
| 461 |
+
continue
|
| 462 |
+
pause_seconds = self._estimate_voice_clone_inter_chunk_pause_seconds(text_chunks[chunk_index])
|
| 463 |
+
pause_samples = max(0, int(round(float(sample_rate) * pause_seconds)))
|
| 464 |
+
if pause_samples > 0:
|
| 465 |
+
silence = torch.zeros((waveform_chunk.shape[0], pause_samples), dtype=waveform_chunk.dtype)
|
| 466 |
+
segments.append(silence)
|
| 467 |
+
return torch.cat(segments, dim=-1)
|
| 468 |
+
|
| 469 |
+
@staticmethod
|
| 470 |
+
def _resolve_inference_mode(
|
| 471 |
+
mode: str,
|
| 472 |
+
has_prompt_text: bool,
|
| 473 |
+
has_prompt_audio: bool,
|
| 474 |
+
) -> str:
|
| 475 |
+
normalized_mode = str(mode or "continuation").strip().lower() or "continuation"
|
| 476 |
+
if normalized_mode not in {"continuation", "voice_clone"}:
|
| 477 |
+
raise ValueError(f"Unsupported inference mode {mode!r}.")
|
| 478 |
+
if normalized_mode == "voice_clone":
|
| 479 |
+
if not has_prompt_audio:
|
| 480 |
+
raise ValueError("voice_clone mode requires prompt_audio_path.")
|
| 481 |
+
if has_prompt_text:
|
| 482 |
+
raise ValueError("voice_clone mode does not accept prompt_text.")
|
| 483 |
+
elif has_prompt_text != has_prompt_audio:
|
| 484 |
+
raise ValueError(
|
| 485 |
+
"continuation mode accepts either target text only, or prompt_text and prompt_audio_path together."
|
| 486 |
+
)
|
| 487 |
+
return normalized_mode
|
| 488 |
+
|
| 489 |
+
def _resolve_inference_nq(self, nq: Optional[int] = None) -> int:
|
| 490 |
+
if nq is None:
|
| 491 |
+
return int(self.config.n_vq)
|
| 492 |
+
resolved_nq = int(nq)
|
| 493 |
+
if resolved_nq < 1 or resolved_nq > int(self.config.n_vq):
|
| 494 |
+
raise ValueError(f"nq must be in [1, {self.config.n_vq}], got {resolved_nq}.")
|
| 495 |
+
return resolved_nq
|
| 496 |
+
|
| 497 |
+
def _mask_unused_audio_channels(
|
| 498 |
+
self,
|
| 499 |
+
audio_token_ids: torch.LongTensor,
|
| 500 |
+
nq: int,
|
| 501 |
+
) -> torch.LongTensor:
|
| 502 |
+
tensor = torch.as_tensor(audio_token_ids, dtype=torch.long)
|
| 503 |
+
if tensor.shape[-1] != self.config.n_vq:
|
| 504 |
+
raise ValueError(
|
| 505 |
+
f"Expected audio token ids with trailing dim {self.config.n_vq}, got {tuple(tensor.shape)}"
|
| 506 |
+
)
|
| 507 |
+
if nq < self.config.n_vq:
|
| 508 |
+
tensor = tensor.clone()
|
| 509 |
+
tensor[..., nq:] = self.config.audio_pad_token_id
|
| 510 |
+
return tensor
|
| 511 |
+
|
| 512 |
+
def _build_audio_prefix_rows(
|
| 513 |
+
self,
|
| 514 |
+
prompt_audio_codes: torch.LongTensor,
|
| 515 |
+
slot_token_id: int,
|
| 516 |
+
device: torch.device,
|
| 517 |
+
) -> torch.LongTensor:
|
| 518 |
+
rows = torch.full(
|
| 519 |
+
(int(prompt_audio_codes.shape[0]), self.config.n_vq + 1),
|
| 520 |
+
self.config.audio_pad_token_id,
|
| 521 |
+
dtype=torch.long,
|
| 522 |
+
device=device,
|
| 523 |
+
)
|
| 524 |
+
if rows.shape[0] > 0:
|
| 525 |
+
rows[:, 0] = int(slot_token_id)
|
| 526 |
+
rows[:, 1:] = prompt_audio_codes
|
| 527 |
+
return rows
|
| 528 |
+
|
| 529 |
+
def build_inference_input_ids(
|
| 530 |
+
self,
|
| 531 |
+
text: str,
|
| 532 |
+
text_tokenizer,
|
| 533 |
+
mode: str = "continuation",
|
| 534 |
+
prompt_text: Optional[str] = None,
|
| 535 |
+
prompt_audio_codes: Optional[torch.LongTensor] = None,
|
| 536 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 537 |
+
) -> tuple[torch.LongTensor, torch.BoolTensor]:
|
| 538 |
+
resolved_device = self._resolve_device(device)
|
| 539 |
+
resolved_mode = self._resolve_inference_mode(
|
| 540 |
+
mode=mode,
|
| 541 |
+
has_prompt_text=prompt_text is not None,
|
| 542 |
+
has_prompt_audio=prompt_audio_codes is not None,
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
if resolved_mode == "voice_clone":
|
| 546 |
+
assert prompt_audio_codes is not None
|
| 547 |
+
text_token_ids = self._encode_text(text_tokenizer, text)
|
| 548 |
+
prompt_token_ids = build_user_prompt_prefix(text_tokenizer, self.config) + [self.config.audio_start_token_id]
|
| 549 |
+
suffix_token_ids = (
|
| 550 |
+
[self.config.audio_end_token_id]
|
| 551 |
+
+ build_user_prompt_after_reference(text_tokenizer)
|
| 552 |
+
+ text_token_ids
|
| 553 |
+
+ build_assistant_prompt_prefix(text_tokenizer, self.config)
|
| 554 |
+
+ [self.config.audio_start_token_id]
|
| 555 |
+
)
|
| 556 |
+
sections = [
|
| 557 |
+
self._build_text_rows(prompt_token_ids, device=resolved_device),
|
| 558 |
+
self._build_audio_prefix_rows(
|
| 559 |
+
prompt_audio_codes=prompt_audio_codes.to(resolved_device),
|
| 560 |
+
slot_token_id=self.config.audio_user_slot_token_id,
|
| 561 |
+
device=resolved_device,
|
| 562 |
+
),
|
| 563 |
+
self._build_text_rows(suffix_token_ids, device=resolved_device),
|
| 564 |
+
]
|
| 565 |
+
input_ids = torch.cat(sections, dim=0).unsqueeze(0)
|
| 566 |
+
attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.bool, device=resolved_device)
|
| 567 |
+
return input_ids, attention_mask
|
| 568 |
+
|
| 569 |
+
effective_text = text if prompt_text is None else prompt_text + text
|
| 570 |
+
prompt_token_ids = build_prompt_token_ids(
|
| 571 |
+
tokenizer=text_tokenizer,
|
| 572 |
+
config=self.config,
|
| 573 |
+
text_token_ids=self._encode_text(text_tokenizer, effective_text),
|
| 574 |
+
)
|
| 575 |
+
sections = [
|
| 576 |
+
self._build_text_rows(prompt_token_ids, device=resolved_device),
|
| 577 |
+
self._build_text_rows([self.config.audio_start_token_id], device=resolved_device),
|
| 578 |
+
]
|
| 579 |
+
if prompt_audio_codes is not None:
|
| 580 |
+
sections.append(
|
| 581 |
+
self._build_audio_prefix_rows(
|
| 582 |
+
prompt_audio_codes=prompt_audio_codes.to(resolved_device),
|
| 583 |
+
slot_token_id=self.config.audio_assistant_slot_token_id,
|
| 584 |
+
device=resolved_device,
|
| 585 |
+
)
|
| 586 |
+
)
|
| 587 |
+
input_ids = torch.cat(sections, dim=0).unsqueeze(0)
|
| 588 |
+
attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.bool, device=resolved_device)
|
| 589 |
+
return input_ids, attention_mask
|
| 590 |
+
|
| 591 |
+
def _left_pad_inference_batch(
|
| 592 |
+
self,
|
| 593 |
+
input_id_batches: list[torch.LongTensor],
|
| 594 |
+
attention_mask_batches: list[torch.BoolTensor],
|
| 595 |
+
device: torch.device,
|
| 596 |
+
) -> tuple[torch.LongTensor, torch.BoolTensor]:
|
| 597 |
+
if not input_id_batches:
|
| 598 |
+
raise ValueError("input_id_batches must not be empty.")
|
| 599 |
+
if len(input_id_batches) != len(attention_mask_batches):
|
| 600 |
+
raise ValueError("input_id_batches and attention_mask_batches must have the same length.")
|
| 601 |
+
|
| 602 |
+
batch_size = len(input_id_batches)
|
| 603 |
+
max_seq_len = max(int(batch.shape[1]) for batch in input_id_batches)
|
| 604 |
+
row_width = self.config.n_vq + 1
|
| 605 |
+
|
| 606 |
+
padded_input_ids = torch.full(
|
| 607 |
+
(batch_size, max_seq_len, row_width),
|
| 608 |
+
self.config.audio_pad_token_id,
|
| 609 |
+
dtype=torch.long,
|
| 610 |
+
device=device,
|
| 611 |
+
)
|
| 612 |
+
padded_input_ids[:, :, 0] = self.config.pad_token_id
|
| 613 |
+
padded_attention_mask = torch.zeros((batch_size, max_seq_len), dtype=torch.bool, device=device)
|
| 614 |
+
|
| 615 |
+
for batch_index, (input_ids, attention_mask) in enumerate(zip(input_id_batches, attention_mask_batches)):
|
| 616 |
+
normalized_input_ids = input_ids.squeeze(0).to(device=device, dtype=torch.long)
|
| 617 |
+
normalized_attention_mask = attention_mask.squeeze(0).to(device=device, dtype=torch.bool)
|
| 618 |
+
seq_len = int(normalized_input_ids.shape[0])
|
| 619 |
+
padded_input_ids[batch_index, -seq_len:, :] = normalized_input_ids
|
| 620 |
+
padded_attention_mask[batch_index, -seq_len:] = normalized_attention_mask
|
| 621 |
+
|
| 622 |
+
return padded_input_ids, padded_attention_mask
|
| 623 |
+
|
| 624 |
+
def _trim_generated_audio_token_ids(
|
| 625 |
+
self,
|
| 626 |
+
audio_token_ids: torch.LongTensor,
|
| 627 |
+
effective_nq: int,
|
| 628 |
+
) -> torch.LongTensor:
|
| 629 |
+
tensor = self._mask_unused_audio_channels(audio_token_ids, nq=effective_nq)
|
| 630 |
+
if tensor.ndim != 2:
|
| 631 |
+
raise ValueError(f"Expected a 2D audio token tensor, got {tuple(tensor.shape)}")
|
| 632 |
+
valid_rows = tensor[:, :effective_nq].ne(self.config.audio_pad_token_id).any(dim=-1)
|
| 633 |
+
if not bool(valid_rows.any()):
|
| 634 |
+
return tensor[:0]
|
| 635 |
+
last_valid_index = int(torch.nonzero(valid_rows, as_tuple=False)[-1].item()) + 1
|
| 636 |
+
return tensor[:last_valid_index]
|
| 637 |
+
|
| 638 |
+
def _resolve_voice_clone_chunk_batch_size(
|
| 639 |
+
self,
|
| 640 |
+
*,
|
| 641 |
+
resolved_device: torch.device,
|
| 642 |
+
chunk_count: int,
|
| 643 |
+
max_memory_per_sample_gb: float,
|
| 644 |
+
) -> int:
|
| 645 |
+
if chunk_count <= 1 or max_memory_per_sample_gb <= 0 or resolved_device.type != "cuda":
|
| 646 |
+
return 1
|
| 647 |
+
if not hasattr(torch.cuda, "mem_get_info"):
|
| 648 |
+
return 1
|
| 649 |
+
try:
|
| 650 |
+
free_bytes, _ = torch.cuda.mem_get_info(resolved_device)
|
| 651 |
+
except Exception:
|
| 652 |
+
return 1
|
| 653 |
+
bytes_per_sample = int(float(max_memory_per_sample_gb) * (1024**3))
|
| 654 |
+
if bytes_per_sample <= 0:
|
| 655 |
+
return 1
|
| 656 |
+
usable_free_bytes = max(0, int(free_bytes * 0.9))
|
| 657 |
+
batch_size = max(1, usable_free_bytes // bytes_per_sample)
|
| 658 |
+
resolved_batch_size = max(1, min(int(chunk_count), int(batch_size)))
|
| 659 |
+
logging.info(
|
| 660 |
+
"voice_clone chunk batching device=%s free_gb=%.2f max_memory_per_sample_gb=%.2f resolved_batch_size=%d chunk_count=%d",
|
| 661 |
+
resolved_device,
|
| 662 |
+
float(free_bytes) / float(1024**3),
|
| 663 |
+
float(max_memory_per_sample_gb),
|
| 664 |
+
resolved_batch_size,
|
| 665 |
+
int(chunk_count),
|
| 666 |
+
)
|
| 667 |
+
return resolved_batch_size
|
| 668 |
+
|
| 669 |
+
def _generate_audio_token_ids_with_fallback(
|
| 670 |
+
self,
|
| 671 |
+
*,
|
| 672 |
+
prompt_input_ids: torch.LongTensor,
|
| 673 |
+
attention_mask: torch.BoolTensor,
|
| 674 |
+
effective_nq: int,
|
| 675 |
+
max_new_frames: int,
|
| 676 |
+
do_sample: bool,
|
| 677 |
+
text_temperature: float,
|
| 678 |
+
text_top_p: float,
|
| 679 |
+
text_top_k: int,
|
| 680 |
+
audio_temperature: float,
|
| 681 |
+
audio_top_p: float,
|
| 682 |
+
audio_top_k: int,
|
| 683 |
+
audio_repetition_penalty: float,
|
| 684 |
+
use_kv_cache: bool,
|
| 685 |
+
resolved_device: torch.device,
|
| 686 |
+
) -> torch.LongTensor:
|
| 687 |
+
try:
|
| 688 |
+
generation = self.generate(
|
| 689 |
+
input_ids=prompt_input_ids,
|
| 690 |
+
attention_mask=attention_mask,
|
| 691 |
+
nq=effective_nq,
|
| 692 |
+
max_new_frames=max_new_frames,
|
| 693 |
+
do_sample=do_sample,
|
| 694 |
+
text_temperature=text_temperature,
|
| 695 |
+
text_top_p=text_top_p,
|
| 696 |
+
text_top_k=text_top_k,
|
| 697 |
+
audio_temperature=audio_temperature,
|
| 698 |
+
audio_top_p=audio_top_p,
|
| 699 |
+
audio_top_k=audio_top_k,
|
| 700 |
+
audio_repetition_penalty=audio_repetition_penalty,
|
| 701 |
+
use_kv_cache=use_kv_cache,
|
| 702 |
+
return_dict_in_generate=True,
|
| 703 |
+
)
|
| 704 |
+
except (RuntimeError, ValueError) as exc:
|
| 705 |
+
if not self._is_generation_stability_error(exc):
|
| 706 |
+
raise
|
| 707 |
+
self._apply_inference_stability_fallback(resolved_device)
|
| 708 |
+
generation = self.generate(
|
| 709 |
+
input_ids=prompt_input_ids,
|
| 710 |
+
attention_mask=attention_mask,
|
| 711 |
+
nq=effective_nq,
|
| 712 |
+
max_new_frames=max_new_frames,
|
| 713 |
+
do_sample=do_sample,
|
| 714 |
+
text_temperature=text_temperature,
|
| 715 |
+
text_top_p=text_top_p,
|
| 716 |
+
text_top_k=text_top_k,
|
| 717 |
+
audio_temperature=audio_temperature,
|
| 718 |
+
audio_top_p=audio_top_p,
|
| 719 |
+
audio_top_k=audio_top_k,
|
| 720 |
+
audio_repetition_penalty=audio_repetition_penalty,
|
| 721 |
+
use_kv_cache=use_kv_cache,
|
| 722 |
+
return_dict_in_generate=True,
|
| 723 |
+
)
|
| 724 |
+
return self._mask_unused_audio_channels(generation.audio_token_ids, nq=effective_nq)
|
| 725 |
+
|
| 726 |
+
def _decode_audio_token_ids_to_waveform(
|
| 727 |
+
self,
|
| 728 |
+
*,
|
| 729 |
+
audio_tokenizer,
|
| 730 |
+
audio_token_ids: torch.LongTensor,
|
| 731 |
+
target_sample_rate: int,
|
| 732 |
+
effective_nq: int,
|
| 733 |
+
resolved_device: torch.device,
|
| 734 |
+
) -> tuple[torch.FloatTensor, int]:
|
| 735 |
+
decoded = self._call_audio_decode(
|
| 736 |
+
audio_tokenizer=audio_tokenizer,
|
| 737 |
+
audio_token_ids=audio_token_ids.to(resolved_device),
|
| 738 |
+
sample_rate=target_sample_rate,
|
| 739 |
+
nq=effective_nq,
|
| 740 |
+
)
|
| 741 |
+
return self._extract_waveform_and_sample_rate(decoded, fallback_sample_rate=target_sample_rate)
|
| 742 |
+
|
| 743 |
+
def _build_generation_row(
|
| 744 |
+
self,
|
| 745 |
+
batch_size: int,
|
| 746 |
+
device: torch.device,
|
| 747 |
+
audio_token_ids: torch.LongTensor,
|
| 748 |
+
) -> torch.LongTensor:
|
| 749 |
+
row = torch.full(
|
| 750 |
+
(batch_size, 1, self.config.n_vq + 1),
|
| 751 |
+
self.config.audio_pad_token_id,
|
| 752 |
+
dtype=torch.long,
|
| 753 |
+
device=device,
|
| 754 |
+
)
|
| 755 |
+
row[:, :, 0] = self.config.audio_assistant_slot_token_id
|
| 756 |
+
row[:, :, 1:] = audio_token_ids.unsqueeze(1)
|
| 757 |
+
return row
|
| 758 |
+
|
| 759 |
+
def _sample_next_token(
|
| 760 |
+
self,
|
| 761 |
+
logits: torch.FloatTensor,
|
| 762 |
+
do_sample: bool,
|
| 763 |
+
temperature: float,
|
| 764 |
+
top_k: Optional[int],
|
| 765 |
+
top_p: Optional[float],
|
| 766 |
+
previous_token_ids: Optional[torch.LongTensor] = None,
|
| 767 |
+
repetition_penalty: float = 1.0,
|
| 768 |
+
) -> torch.LongTensor:
|
| 769 |
+
scores = self._apply_repetition_penalty(
|
| 770 |
+
logits=logits,
|
| 771 |
+
previous_token_ids=previous_token_ids,
|
| 772 |
+
repetition_penalty=repetition_penalty,
|
| 773 |
+
)
|
| 774 |
+
if not do_sample:
|
| 775 |
+
return scores.argmax(dim=-1)
|
| 776 |
+
if temperature <= 0:
|
| 777 |
+
raise ValueError("temperature must be positive when do_sample=True")
|
| 778 |
+
|
| 779 |
+
scores = scores / temperature
|
| 780 |
+
if top_k is not None and top_k > 0:
|
| 781 |
+
top_k = min(top_k, scores.shape[-1])
|
| 782 |
+
threshold = torch.topk(scores, top_k, dim=-1).values[..., -1, None]
|
| 783 |
+
scores = scores.masked_fill(scores < threshold, float("-inf"))
|
| 784 |
+
|
| 785 |
+
if top_p is not None and 0.0 < top_p < 1.0:
|
| 786 |
+
sorted_scores, sorted_indices = torch.sort(scores, descending=True, dim=-1)
|
| 787 |
+
sorted_probs = torch.softmax(sorted_scores, dim=-1)
|
| 788 |
+
sorted_cumsum = torch.cumsum(sorted_probs, dim=-1)
|
| 789 |
+
sorted_remove = sorted_cumsum > top_p
|
| 790 |
+
sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
|
| 791 |
+
sorted_remove[..., 0] = False
|
| 792 |
+
sorted_scores = sorted_scores.masked_fill(sorted_remove, float("-inf"))
|
| 793 |
+
scores = torch.full_like(scores, float("-inf"))
|
| 794 |
+
scores.scatter_(dim=-1, index=sorted_indices, src=sorted_scores)
|
| 795 |
+
|
| 796 |
+
probs = torch.softmax(scores, dim=-1)
|
| 797 |
+
return torch.multinomial(probs, num_samples=1).squeeze(-1)
|
| 798 |
+
|
| 799 |
+
@staticmethod
|
| 800 |
+
def _ensure_finite_generation_logits(logits: torch.FloatTensor, name: str) -> None:
|
| 801 |
+
if torch.isfinite(logits).all():
|
| 802 |
+
return
|
| 803 |
+
finite_mask = torch.isfinite(logits)
|
| 804 |
+
finite_logits = logits[finite_mask]
|
| 805 |
+
min_value = float(finite_logits.min().item()) if finite_logits.numel() > 0 else float("nan")
|
| 806 |
+
max_value = float(finite_logits.max().item()) if finite_logits.numel() > 0 else float("nan")
|
| 807 |
+
raise RuntimeError(
|
| 808 |
+
f"Non-finite {name} during generation: dtype={logits.dtype} "
|
| 809 |
+
f"shape={tuple(logits.shape)} finite={int(finite_mask.sum().item())}/{int(logits.numel())} "
|
| 810 |
+
f"min={min_value} max={max_value}"
|
| 811 |
+
)
|
| 812 |
+
|
| 813 |
+
def _apply_repetition_penalty(
|
| 814 |
+
self,
|
| 815 |
+
logits: torch.FloatTensor,
|
| 816 |
+
previous_token_ids: Optional[torch.LongTensor],
|
| 817 |
+
repetition_penalty: float,
|
| 818 |
+
) -> torch.FloatTensor:
|
| 819 |
+
if repetition_penalty <= 0:
|
| 820 |
+
raise ValueError("repetition_penalty must be positive")
|
| 821 |
+
if repetition_penalty == 1.0 or previous_token_ids is None:
|
| 822 |
+
return logits
|
| 823 |
+
|
| 824 |
+
token_ids = torch.as_tensor(previous_token_ids, device=logits.device, dtype=torch.long)
|
| 825 |
+
if token_ids.ndim == 1:
|
| 826 |
+
token_ids = token_ids.unsqueeze(0)
|
| 827 |
+
elif token_ids.ndim > 2:
|
| 828 |
+
token_ids = token_ids.reshape(token_ids.shape[0], -1)
|
| 829 |
+
|
| 830 |
+
scores = logits.clone()
|
| 831 |
+
vocab_size = scores.shape[-1]
|
| 832 |
+
for batch_index in range(scores.shape[0]):
|
| 833 |
+
valid_token_ids = token_ids[batch_index]
|
| 834 |
+
valid_token_ids = valid_token_ids[(valid_token_ids >= 0) & (valid_token_ids < vocab_size)]
|
| 835 |
+
if valid_token_ids.numel() == 0:
|
| 836 |
+
continue
|
| 837 |
+
unique_token_ids = torch.unique(valid_token_ids)
|
| 838 |
+
token_scores = scores[batch_index].index_select(0, unique_token_ids)
|
| 839 |
+
token_scores = torch.where(
|
| 840 |
+
token_scores < 0,
|
| 841 |
+
token_scores * repetition_penalty,
|
| 842 |
+
token_scores / repetition_penalty,
|
| 843 |
+
)
|
| 844 |
+
scores[batch_index].scatter_(0, unique_token_ids, token_scores)
|
| 845 |
+
return scores
|
| 846 |
+
|
| 847 |
+
def _sample_next_assistant_text_token(
|
| 848 |
+
self,
|
| 849 |
+
logits: torch.FloatTensor,
|
| 850 |
+
do_sample: bool,
|
| 851 |
+
temperature: float,
|
| 852 |
+
top_k: Optional[int] = None,
|
| 853 |
+
top_p: Optional[float] = None,
|
| 854 |
+
) -> torch.LongTensor:
|
| 855 |
+
candidate_ids = torch.tensor(
|
| 856 |
+
[
|
| 857 |
+
self.config.audio_assistant_slot_token_id,
|
| 858 |
+
self.config.audio_end_token_id,
|
| 859 |
+
],
|
| 860 |
+
dtype=torch.long,
|
| 861 |
+
device=logits.device,
|
| 862 |
+
)
|
| 863 |
+
candidate_logits = logits.index_select(dim=-1, index=candidate_ids)
|
| 864 |
+
sampled_indices = self._sample_next_token(
|
| 865 |
+
logits=candidate_logits,
|
| 866 |
+
do_sample=do_sample,
|
| 867 |
+
temperature=temperature,
|
| 868 |
+
top_k=top_k,
|
| 869 |
+
top_p=top_p,
|
| 870 |
+
)
|
| 871 |
+
return candidate_ids[sampled_indices]
|
| 872 |
+
|
| 873 |
+
def _resolve_device(self, device: Optional[Union[str, torch.device]] = None) -> torch.device:
|
| 874 |
+
return torch.device(device) if device is not None else next(self.parameters()).device
|
| 875 |
+
|
| 876 |
+
@staticmethod
|
| 877 |
+
def _looks_like_hf_tokenizer_dir(candidate_path: Path) -> bool:
|
| 878 |
+
if not candidate_path.is_dir():
|
| 879 |
+
return False
|
| 880 |
+
if (candidate_path / "tokenizer.model").is_file():
|
| 881 |
+
return True
|
| 882 |
+
if (candidate_path / "tokenizer.json").is_file():
|
| 883 |
+
return True
|
| 884 |
+
if (candidate_path / "tokenizer_config.json").is_file() and (
|
| 885 |
+
(candidate_path / "vocab.json").is_file()
|
| 886 |
+
or (candidate_path / "merges.txt").is_file()
|
| 887 |
+
or (candidate_path / "special_tokens_map.json").is_file()
|
| 888 |
+
):
|
| 889 |
+
return True
|
| 890 |
+
return False
|
| 891 |
+
|
| 892 |
+
def _resolve_text_tokenizer_path(self, raw_path: Union[str, Path]) -> Path:
|
| 893 |
+
candidate_path = Path(raw_path)
|
| 894 |
+
if candidate_path.is_file() and candidate_path.suffix == ".model":
|
| 895 |
+
return candidate_path
|
| 896 |
+
if not candidate_path.exists():
|
| 897 |
+
raise FileNotFoundError(f"Tokenizer path does not exist: {candidate_path}")
|
| 898 |
+
if candidate_path.is_dir():
|
| 899 |
+
if (candidate_path / "tokenizer.model").is_file():
|
| 900 |
+
return candidate_path
|
| 901 |
+
if self._looks_like_hf_tokenizer_dir(candidate_path):
|
| 902 |
+
return candidate_path
|
| 903 |
+
hf_dir = candidate_path / "hf_tokenizer"
|
| 904 |
+
if self._looks_like_hf_tokenizer_dir(hf_dir):
|
| 905 |
+
return hf_dir
|
| 906 |
+
sentencepiece_model = candidate_path / "sentencepiece" / "nanotts_spm_bpe.model"
|
| 907 |
+
if sentencepiece_model.is_file():
|
| 908 |
+
return sentencepiece_model
|
| 909 |
+
final_summary_path = candidate_path / "final_summary.json"
|
| 910 |
+
if final_summary_path.is_file():
|
| 911 |
+
final_summary = json.loads(final_summary_path.read_text(encoding="utf-8"))
|
| 912 |
+
latest_hf_dir = final_summary.get("latest_hf_tokenizer_dir")
|
| 913 |
+
if latest_hf_dir:
|
| 914 |
+
latest_hf_path = Path(str(latest_hf_dir))
|
| 915 |
+
if self._looks_like_hf_tokenizer_dir(latest_hf_path):
|
| 916 |
+
return latest_hf_path
|
| 917 |
+
raise ValueError(
|
| 918 |
+
"Could not resolve a tokenizer from the provided path. Expected a tokenizer dir, experiment dir, or SentencePiece .model file."
|
| 919 |
+
)
|
| 920 |
+
|
| 921 |
+
def _load_resolved_text_tokenizer(self, resolved_path: Path, cache_dir: str):
|
| 922 |
+
if resolved_path.is_file() and resolved_path.suffix == ".model":
|
| 923 |
+
return NanoTTSSentencePieceTokenizer(vocab_file=str(resolved_path))
|
| 924 |
+
try:
|
| 925 |
+
return AutoTokenizer.from_pretrained(
|
| 926 |
+
str(resolved_path),
|
| 927 |
+
trust_remote_code=True,
|
| 928 |
+
use_fast=bool(self.config.tokenizer_use_fast),
|
| 929 |
+
local_files_only=True,
|
| 930 |
+
cache_dir=cache_dir,
|
| 931 |
+
)
|
| 932 |
+
except Exception:
|
| 933 |
+
model_path = resolved_path / "tokenizer.model"
|
| 934 |
+
if model_path.is_file():
|
| 935 |
+
return NanoTTSSentencePieceTokenizer(vocab_file=str(model_path))
|
| 936 |
+
raise
|
| 937 |
+
|
| 938 |
+
@staticmethod
|
| 939 |
+
def _resolve_hf_cache_dir() -> str:
|
| 940 |
+
cache_dir = Path(__file__).resolve().parent / ".cache" / "huggingface"
|
| 941 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
| 942 |
+
return str(cache_dir)
|
| 943 |
+
|
| 944 |
+
@staticmethod
|
| 945 |
+
def _patch_hf_dynamic_module_cache_dir(cache_dir: str) -> None:
|
| 946 |
+
import transformers.dynamic_module_utils as dynamic_module_utils
|
| 947 |
+
|
| 948 |
+
modules_cache_dir = str(Path(cache_dir) / "modules")
|
| 949 |
+
Path(modules_cache_dir).mkdir(parents=True, exist_ok=True)
|
| 950 |
+
os.environ["HF_MODULES_CACHE"] = modules_cache_dir
|
| 951 |
+
dynamic_module_utils.HF_MODULES_CACHE = modules_cache_dir
|
| 952 |
+
|
| 953 |
+
def _resolve_default_text_tokenizer_path(self) -> Path:
|
| 954 |
+
candidates: list[Path] = []
|
| 955 |
+
|
| 956 |
+
raw_name_or_path = getattr(self.config, "_name_or_path", None)
|
| 957 |
+
if raw_name_or_path:
|
| 958 |
+
candidates.append(Path(str(raw_name_or_path)).expanduser())
|
| 959 |
+
|
| 960 |
+
raw_model_name_or_path = getattr(self, "name_or_path", None)
|
| 961 |
+
if raw_model_name_or_path:
|
| 962 |
+
candidates.append(Path(str(raw_model_name_or_path)).expanduser())
|
| 963 |
+
|
| 964 |
+
candidates.append(Path(__file__).resolve().parent)
|
| 965 |
+
|
| 966 |
+
checked: set[str] = set()
|
| 967 |
+
for candidate in candidates:
|
| 968 |
+
resolved_candidate = candidate.resolve()
|
| 969 |
+
key = str(resolved_candidate)
|
| 970 |
+
if key in checked:
|
| 971 |
+
continue
|
| 972 |
+
checked.add(key)
|
| 973 |
+
|
| 974 |
+
if (resolved_candidate / "tokenizer.model").is_file():
|
| 975 |
+
return resolved_candidate
|
| 976 |
+
if self._looks_like_hf_tokenizer_dir(resolved_candidate):
|
| 977 |
+
return resolved_candidate
|
| 978 |
+
|
| 979 |
+
return candidates[0].resolve()
|
| 980 |
+
|
| 981 |
+
def _load_text_tokenizer(self, text_tokenizer=None, text_tokenizer_path: Optional[str] = None):
|
| 982 |
+
if text_tokenizer is not None:
|
| 983 |
+
return text_tokenizer
|
| 984 |
+
|
| 985 |
+
resolved_path = (
|
| 986 |
+
self._resolve_text_tokenizer_path(text_tokenizer_path)
|
| 987 |
+
if text_tokenizer_path is not None
|
| 988 |
+
else self._resolve_default_text_tokenizer_path()
|
| 989 |
+
)
|
| 990 |
+
normalized_path = str(resolved_path.resolve())
|
| 991 |
+
cached = getattr(self, "_cached_text_tokenizer", None)
|
| 992 |
+
cached_path = getattr(self, "_cached_text_tokenizer_path", None)
|
| 993 |
+
if cached is not None and cached_path == normalized_path:
|
| 994 |
+
return cached
|
| 995 |
+
|
| 996 |
+
cache_dir = self._resolve_hf_cache_dir()
|
| 997 |
+
self._patch_hf_dynamic_module_cache_dir(cache_dir)
|
| 998 |
+
tokenizer = self._load_resolved_text_tokenizer(resolved_path=resolved_path, cache_dir=cache_dir)
|
| 999 |
+
if tokenizer.pad_token_id is None and tokenizer.eos_token is not None:
|
| 1000 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 1001 |
+
self._cached_text_tokenizer = tokenizer
|
| 1002 |
+
self._cached_text_tokenizer_path = normalized_path
|
| 1003 |
+
return tokenizer
|
| 1004 |
+
|
| 1005 |
+
@staticmethod
|
| 1006 |
+
def _normalize_audio_tokenizer_type(audio_tokenizer_type: Optional[str]) -> Optional[str]:
|
| 1007 |
+
if audio_tokenizer_type is None:
|
| 1008 |
+
return None
|
| 1009 |
+
normalized = str(audio_tokenizer_type).strip().lower()
|
| 1010 |
+
if not normalized:
|
| 1011 |
+
return None
|
| 1012 |
+
if normalized == MOSS_AUDIO_TOKENIZER_NANO_TYPE:
|
| 1013 |
+
return MOSS_AUDIO_TOKENIZER_NANO_TYPE
|
| 1014 |
+
raise ValueError(
|
| 1015 |
+
"Unsupported audio tokenizer type. "
|
| 1016 |
+
f"The open-source package only supports '{MOSS_AUDIO_TOKENIZER_NANO_TYPE}'."
|
| 1017 |
+
)
|
| 1018 |
+
|
| 1019 |
+
def _resolve_audio_tokenizer_type(self, audio_tokenizer_type: Optional[str]) -> str:
|
| 1020 |
+
explicit_type = self._normalize_audio_tokenizer_type(audio_tokenizer_type)
|
| 1021 |
+
if explicit_type is not None:
|
| 1022 |
+
return explicit_type
|
| 1023 |
+
config_type = self._normalize_audio_tokenizer_type(getattr(self.config, "audio_tokenizer_type", None))
|
| 1024 |
+
return MOSS_AUDIO_TOKENIZER_NANO_TYPE if config_type is None else config_type
|
| 1025 |
+
|
| 1026 |
+
@staticmethod
|
| 1027 |
+
def _set_decoder_attention_implementation(decoder, attn_implementation: str) -> None:
|
| 1028 |
+
decoder.attn_implementation = str(attn_implementation)
|
| 1029 |
+
if getattr(decoder, "config", None) is not None:
|
| 1030 |
+
decoder.config._attn_implementation = str(attn_implementation)
|
| 1031 |
+
for block in getattr(decoder, "h", []):
|
| 1032 |
+
block.attn.attn_implementation = str(attn_implementation)
|
| 1033 |
+
|
| 1034 |
+
def _set_attention_implementation(
|
| 1035 |
+
self,
|
| 1036 |
+
attn_implementation: str,
|
| 1037 |
+
local_attn_implementation: Optional[str] = None,
|
| 1038 |
+
) -> None:
|
| 1039 |
+
resolved_global = str(attn_implementation)
|
| 1040 |
+
resolved_local = resolved_global if local_attn_implementation is None else str(local_attn_implementation)
|
| 1041 |
+
self.config.attn_implementation = resolved_global
|
| 1042 |
+
self.config.gpt2_config._attn_implementation = resolved_global
|
| 1043 |
+
self._set_decoder_attention_implementation(self.transformer, resolved_global)
|
| 1044 |
+
self.config.local_transformer_attn_implementation = resolved_local
|
| 1045 |
+
self._set_decoder_attention_implementation(self.local_transformer, resolved_local)
|
| 1046 |
+
|
| 1047 |
+
@staticmethod
|
| 1048 |
+
def _select_fallback_attention_implementation(device: torch.device) -> str:
|
| 1049 |
+
return "sdpa" if device.type == "cuda" else "eager"
|
| 1050 |
+
|
| 1051 |
+
@staticmethod
|
| 1052 |
+
def _is_generation_stability_error(exc: Exception) -> bool:
|
| 1053 |
+
message = str(exc)
|
| 1054 |
+
return any(
|
| 1055 |
+
marker in message
|
| 1056 |
+
for marker in (
|
| 1057 |
+
"Non-finite",
|
| 1058 |
+
"device-side assert triggered",
|
| 1059 |
+
"probability tensor contains either",
|
| 1060 |
+
"flash_attention_2 requires fp16/bf16 tensors",
|
| 1061 |
+
)
|
| 1062 |
+
)
|
| 1063 |
+
|
| 1064 |
+
def _apply_inference_stability_fallback(self, device: torch.device) -> None:
|
| 1065 |
+
fallback_attn = self._select_fallback_attention_implementation(device)
|
| 1066 |
+
if next(self.parameters()).dtype != torch.float32:
|
| 1067 |
+
self.to(device=device, dtype=torch.float32)
|
| 1068 |
+
self._set_attention_implementation(fallback_attn)
|
| 1069 |
+
logging.warning(
|
| 1070 |
+
"retrying inference with dtype=float32 attn_implementation=%s due to numerical instability",
|
| 1071 |
+
fallback_attn,
|
| 1072 |
+
)
|
| 1073 |
+
|
| 1074 |
+
def _load_audio_tokenizer(
|
| 1075 |
+
self,
|
| 1076 |
+
audio_tokenizer=None,
|
| 1077 |
+
audio_tokenizer_type: Optional[str] = None,
|
| 1078 |
+
audio_tokenizer_pretrained_name_or_path: Optional[str] = None,
|
| 1079 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 1080 |
+
):
|
| 1081 |
+
if audio_tokenizer is not None:
|
| 1082 |
+
return audio_tokenizer
|
| 1083 |
+
|
| 1084 |
+
resolved_type = self._resolve_audio_tokenizer_type(audio_tokenizer_type=audio_tokenizer_type)
|
| 1085 |
+
if resolved_type != MOSS_AUDIO_TOKENIZER_NANO_TYPE:
|
| 1086 |
+
raise ValueError(
|
| 1087 |
+
f"Unsupported audio tokenizer type {resolved_type!r}; expected '{MOSS_AUDIO_TOKENIZER_NANO_TYPE}'."
|
| 1088 |
+
)
|
| 1089 |
+
|
| 1090 |
+
resolved_pretrained_name_or_path = (
|
| 1091 |
+
audio_tokenizer_pretrained_name_or_path
|
| 1092 |
+
or getattr(self.config, "audio_tokenizer_pretrained_name_or_path", None)
|
| 1093 |
+
or DEFAULT_MOSS_AUDIO_TOKENIZER_PRETRAINED_NAME_OR_PATH
|
| 1094 |
+
)
|
| 1095 |
+
candidate_path = Path(str(resolved_pretrained_name_or_path)).expanduser()
|
| 1096 |
+
if candidate_path.exists():
|
| 1097 |
+
load_source = str(candidate_path.resolve())
|
| 1098 |
+
load_kwargs: dict[str, object] = {
|
| 1099 |
+
"trust_remote_code": True,
|
| 1100 |
+
"local_files_only": True,
|
| 1101 |
+
"force_download": True,
|
| 1102 |
+
}
|
| 1103 |
+
cache_key = f"{resolved_type}|{load_source}"
|
| 1104 |
+
else:
|
| 1105 |
+
load_source = str(resolved_pretrained_name_or_path)
|
| 1106 |
+
load_kwargs = {
|
| 1107 |
+
"trust_remote_code": True,
|
| 1108 |
+
}
|
| 1109 |
+
cache_key = f"{resolved_type}|{load_source}"
|
| 1110 |
+
|
| 1111 |
+
cached = getattr(self, "_cached_audio_tokenizer", None)
|
| 1112 |
+
cached_path = getattr(self, "_cached_audio_tokenizer_path", None)
|
| 1113 |
+
if cached is not None and cached_path == cache_key:
|
| 1114 |
+
tokenizer = cached
|
| 1115 |
+
else:
|
| 1116 |
+
tokenizer = AutoModel.from_pretrained(load_source, **load_kwargs)
|
| 1117 |
+
if hasattr(tokenizer, "eval"):
|
| 1118 |
+
tokenizer.eval()
|
| 1119 |
+
self._cached_audio_tokenizer = tokenizer
|
| 1120 |
+
self._cached_audio_tokenizer_path = cache_key
|
| 1121 |
+
|
| 1122 |
+
resolved_device = self._resolve_device(device)
|
| 1123 |
+
return tokenizer.to(resolved_device) if hasattr(tokenizer, "to") else tokenizer
|
| 1124 |
+
|
| 1125 |
+
@staticmethod
|
| 1126 |
+
def _extract_tensor_candidate(output: Any) -> Any:
|
| 1127 |
+
if torch.is_tensor(output) or isinstance(output, np.ndarray):
|
| 1128 |
+
return output
|
| 1129 |
+
for attr_name in ("audio_codes", "audio_token_ids", "codes", "tokens", "input_ids"):
|
| 1130 |
+
value = getattr(output, attr_name, None)
|
| 1131 |
+
if value is not None:
|
| 1132 |
+
return value
|
| 1133 |
+
if isinstance(output, dict):
|
| 1134 |
+
for key in ("audio_codes", "audio_token_ids", "codes", "tokens", "input_ids"):
|
| 1135 |
+
if key in output:
|
| 1136 |
+
return output[key]
|
| 1137 |
+
if len(output) == 1:
|
| 1138 |
+
return next(iter(output.values()))
|
| 1139 |
+
if isinstance(output, (list, tuple)) and output:
|
| 1140 |
+
if len(output) == 2 and isinstance(output[1], (int, float)):
|
| 1141 |
+
return output[0]
|
| 1142 |
+
return NanoTTSGlobalLocalForCausalLM._extract_tensor_candidate(output[0])
|
| 1143 |
+
raise TypeError(f"Unsupported audio tokenizer output type: {type(output)!r}")
|
| 1144 |
+
|
| 1145 |
+
@staticmethod
|
| 1146 |
+
def _extract_audio_code_length(output: Any) -> Optional[int]:
|
| 1147 |
+
for attr_name in ("audio_codes_lengths", "audio_token_ids_lengths", "codes_lengths", "lengths"):
|
| 1148 |
+
candidate = getattr(output, attr_name, None)
|
| 1149 |
+
if candidate is not None:
|
| 1150 |
+
lengths = torch.as_tensor(candidate).reshape(-1)
|
| 1151 |
+
if lengths.numel() > 0:
|
| 1152 |
+
return int(lengths[0].item())
|
| 1153 |
+
if isinstance(output, dict):
|
| 1154 |
+
for key in ("audio_codes_lengths", "audio_token_ids_lengths", "codes_lengths", "lengths"):
|
| 1155 |
+
if key in output:
|
| 1156 |
+
lengths = torch.as_tensor(output[key]).reshape(-1)
|
| 1157 |
+
if lengths.numel() > 0:
|
| 1158 |
+
return int(lengths[0].item())
|
| 1159 |
+
if isinstance(output, (list, tuple)) and len(output) >= 2:
|
| 1160 |
+
candidate = output[1]
|
| 1161 |
+
if torch.is_tensor(candidate) or isinstance(candidate, np.ndarray):
|
| 1162 |
+
lengths = torch.as_tensor(candidate).reshape(-1)
|
| 1163 |
+
if lengths.numel() > 0:
|
| 1164 |
+
return int(lengths[0].item())
|
| 1165 |
+
if isinstance(candidate, (int, float)):
|
| 1166 |
+
return int(candidate)
|
| 1167 |
+
return None
|
| 1168 |
+
|
| 1169 |
+
def _normalize_audio_codes(self, audio_codes: Any) -> torch.LongTensor:
|
| 1170 |
+
code_length = self._extract_audio_code_length(audio_codes)
|
| 1171 |
+
tensor = torch.as_tensor(self._extract_tensor_candidate(audio_codes))
|
| 1172 |
+
if tensor.ndim == 1:
|
| 1173 |
+
tensor = tensor.unsqueeze(-1)
|
| 1174 |
+
if tensor.ndim == 3:
|
| 1175 |
+
if tensor.shape[1] == 1 and tensor.shape[0] >= self.config.n_vq:
|
| 1176 |
+
tensor = tensor[: self.config.n_vq, 0, :].transpose(0, 1)
|
| 1177 |
+
elif tensor.shape[0] == 1:
|
| 1178 |
+
tensor = tensor[0]
|
| 1179 |
+
elif tensor.shape[1] == self.config.n_vq:
|
| 1180 |
+
tensor = tensor.transpose(1, 2)[0]
|
| 1181 |
+
elif tensor.shape[-1] == self.config.n_vq:
|
| 1182 |
+
tensor = tensor[0]
|
| 1183 |
+
else:
|
| 1184 |
+
raise ValueError(f"Unable to normalize audio codes with shape {tuple(tensor.shape)}")
|
| 1185 |
+
|
| 1186 |
+
if tensor.ndim != 2:
|
| 1187 |
+
raise ValueError(f"Expected audio codes with 2 dims after normalization, got {tuple(tensor.shape)}")
|
| 1188 |
+
if tensor.shape[-1] != self.config.n_vq and tensor.shape[0] == self.config.n_vq:
|
| 1189 |
+
tensor = tensor.transpose(0, 1)
|
| 1190 |
+
elif tensor.shape[-1] != self.config.n_vq and tensor.shape[0] > self.config.n_vq:
|
| 1191 |
+
tensor = tensor[: self.config.n_vq].transpose(0, 1)
|
| 1192 |
+
elif tensor.shape[-1] > self.config.n_vq:
|
| 1193 |
+
tensor = tensor[:, : self.config.n_vq]
|
| 1194 |
+
if tensor.shape[-1] != self.config.n_vq:
|
| 1195 |
+
raise ValueError(
|
| 1196 |
+
f"Expected normalized audio codes with trailing dim {self.config.n_vq}, got {tuple(tensor.shape)}"
|
| 1197 |
+
)
|
| 1198 |
+
if code_length is not None:
|
| 1199 |
+
tensor = tensor[:code_length]
|
| 1200 |
+
return tensor.to(dtype=torch.long)
|
| 1201 |
+
|
| 1202 |
+
def _extract_waveform_and_sample_rate(
|
| 1203 |
+
self,
|
| 1204 |
+
decode_output: Any,
|
| 1205 |
+
fallback_sample_rate: int,
|
| 1206 |
+
) -> tuple[torch.FloatTensor, int]:
|
| 1207 |
+
sample_rate = fallback_sample_rate
|
| 1208 |
+
waveform = decode_output
|
| 1209 |
+
waveform_length = None
|
| 1210 |
+
|
| 1211 |
+
for key in ("sample_rate", "sampling_rate"):
|
| 1212 |
+
value = getattr(decode_output, key, None)
|
| 1213 |
+
if value is not None:
|
| 1214 |
+
sample_rate = int(value)
|
| 1215 |
+
break
|
| 1216 |
+
for key in ("waveform", "audio", "wav", "samples"):
|
| 1217 |
+
value = getattr(decode_output, key, None)
|
| 1218 |
+
if value is not None:
|
| 1219 |
+
waveform = value
|
| 1220 |
+
break
|
| 1221 |
+
for key in ("audio_lengths", "waveform_lengths", "lengths"):
|
| 1222 |
+
value = getattr(decode_output, key, None)
|
| 1223 |
+
if value is not None:
|
| 1224 |
+
lengths = torch.as_tensor(value).reshape(-1)
|
| 1225 |
+
if lengths.numel() > 0:
|
| 1226 |
+
waveform_length = int(lengths[0].item())
|
| 1227 |
+
break
|
| 1228 |
+
|
| 1229 |
+
if isinstance(decode_output, dict):
|
| 1230 |
+
for key in ("sample_rate", "sampling_rate"):
|
| 1231 |
+
if key in decode_output:
|
| 1232 |
+
sample_rate = int(decode_output[key])
|
| 1233 |
+
break
|
| 1234 |
+
for key in ("waveform", "audio", "wav", "samples"):
|
| 1235 |
+
if key in decode_output:
|
| 1236 |
+
waveform = decode_output[key]
|
| 1237 |
+
break
|
| 1238 |
+
for key in ("audio_lengths", "waveform_lengths", "lengths"):
|
| 1239 |
+
if key in decode_output:
|
| 1240 |
+
lengths = torch.as_tensor(decode_output[key]).reshape(-1)
|
| 1241 |
+
if lengths.numel() > 0:
|
| 1242 |
+
waveform_length = int(lengths[0].item())
|
| 1243 |
+
break
|
| 1244 |
+
elif isinstance(decode_output, (list, tuple)) and decode_output:
|
| 1245 |
+
if len(decode_output) == 2 and isinstance(decode_output[1], (int, float)):
|
| 1246 |
+
waveform = decode_output[0]
|
| 1247 |
+
sample_rate = int(decode_output[1])
|
| 1248 |
+
else:
|
| 1249 |
+
waveform = decode_output[0]
|
| 1250 |
+
|
| 1251 |
+
waveform_tensor = torch.as_tensor(waveform, dtype=torch.float32)
|
| 1252 |
+
if waveform_tensor.ndim == 3 and waveform_tensor.shape[0] == 1:
|
| 1253 |
+
waveform_tensor = waveform_tensor[0]
|
| 1254 |
+
if waveform_tensor.ndim == 2 and waveform_tensor.shape[0] > waveform_tensor.shape[1]:
|
| 1255 |
+
waveform_tensor = waveform_tensor.transpose(0, 1)
|
| 1256 |
+
if waveform_tensor.ndim == 1:
|
| 1257 |
+
waveform_tensor = waveform_tensor.unsqueeze(0)
|
| 1258 |
+
if waveform_tensor.ndim != 2:
|
| 1259 |
+
raise ValueError(f"Expected decoded waveform with 2 dims, got {tuple(waveform_tensor.shape)}")
|
| 1260 |
+
if waveform_length is not None:
|
| 1261 |
+
waveform_tensor = waveform_tensor[..., : max(0, waveform_length)]
|
| 1262 |
+
return waveform_tensor.cpu(), sample_rate
|
| 1263 |
+
|
| 1264 |
+
def _call_audio_encode(
|
| 1265 |
+
self,
|
| 1266 |
+
audio_tokenizer,
|
| 1267 |
+
waveform: torch.FloatTensor,
|
| 1268 |
+
sample_rate: int,
|
| 1269 |
+
) -> Any:
|
| 1270 |
+
del sample_rate
|
| 1271 |
+
batch_encode_fn = getattr(audio_tokenizer, "batch_encode", None)
|
| 1272 |
+
if batch_encode_fn is None:
|
| 1273 |
+
raise AttributeError("audio_tokenizer must provide a batch_encode method.")
|
| 1274 |
+
|
| 1275 |
+
waveform_tensor = torch.as_tensor(waveform, dtype=torch.float32, device=self._resolve_device(waveform.device))
|
| 1276 |
+
if waveform_tensor.ndim == 1:
|
| 1277 |
+
waveform_tensor = waveform_tensor.unsqueeze(0)
|
| 1278 |
+
if waveform_tensor.ndim != 2:
|
| 1279 |
+
raise ValueError(
|
| 1280 |
+
f"MOSS audio tokenizer encode expects waveform shaped like (C, T), got {tuple(waveform_tensor.shape)}"
|
| 1281 |
+
)
|
| 1282 |
+
|
| 1283 |
+
with self._audio_tokenizer_inference_context(audio_tokenizer, waveform_tensor.device):
|
| 1284 |
+
return batch_encode_fn([waveform_tensor], chunk_duration=None)
|
| 1285 |
+
|
| 1286 |
+
def _call_audio_decode(
|
| 1287 |
+
self,
|
| 1288 |
+
audio_tokenizer,
|
| 1289 |
+
audio_token_ids: torch.LongTensor,
|
| 1290 |
+
sample_rate: int,
|
| 1291 |
+
nq: Optional[int] = None,
|
| 1292 |
+
) -> Any:
|
| 1293 |
+
del sample_rate
|
| 1294 |
+
batch_decode_fn = getattr(audio_tokenizer, "batch_decode", None)
|
| 1295 |
+
if batch_decode_fn is None:
|
| 1296 |
+
raise AttributeError("audio_tokenizer must provide a batch_decode method.")
|
| 1297 |
+
|
| 1298 |
+
effective_nq = self._resolve_inference_nq(nq)
|
| 1299 |
+
decode_codes = self._prepare_audio_codes_for_decode(audio_token_ids, nq=effective_nq)
|
| 1300 |
+
with self._audio_tokenizer_inference_context(audio_tokenizer, decode_codes.device):
|
| 1301 |
+
return batch_decode_fn([decode_codes], num_quantizers=effective_nq, chunk_duration=None)
|
| 1302 |
+
|
| 1303 |
+
@staticmethod
|
| 1304 |
+
def _resolve_audio_tokenizer_downsample_rate(audio_tokenizer) -> int:
|
| 1305 |
+
for holder in (audio_tokenizer, getattr(audio_tokenizer, "config", None)):
|
| 1306 |
+
if holder is None:
|
| 1307 |
+
continue
|
| 1308 |
+
for attr_name in ("downsample_rate", "hop_length", "frame_size"):
|
| 1309 |
+
value = getattr(holder, attr_name, None)
|
| 1310 |
+
if value is not None:
|
| 1311 |
+
return int(value)
|
| 1312 |
+
sampling_rate = getattr(holder, "sampling_rate", None)
|
| 1313 |
+
frame_rate = getattr(holder, "frame_rate", None)
|
| 1314 |
+
if sampling_rate is not None and frame_rate not in (None, 0):
|
| 1315 |
+
return int(round(float(sampling_rate) / float(frame_rate)))
|
| 1316 |
+
raise ValueError("audio_tokenizer.downsample_rate is required for prompt-audio decoding.")
|
| 1317 |
+
|
| 1318 |
+
def _resolve_audio_tokenizer_sample_rate(self, audio_tokenizer) -> int:
|
| 1319 |
+
for holder in (audio_tokenizer, getattr(audio_tokenizer, "config", None)):
|
| 1320 |
+
if holder is None:
|
| 1321 |
+
continue
|
| 1322 |
+
for attr_name in ("sampling_rate", "sample_rate"):
|
| 1323 |
+
value = getattr(holder, attr_name, None)
|
| 1324 |
+
if value is not None:
|
| 1325 |
+
return int(value)
|
| 1326 |
+
return int(self.config.audio_tokenizer_sample_rate)
|
| 1327 |
+
|
| 1328 |
+
@staticmethod
|
| 1329 |
+
def _resolve_audio_tokenizer_channels(audio_tokenizer) -> int:
|
| 1330 |
+
for holder in (audio_tokenizer, getattr(audio_tokenizer, "config", None)):
|
| 1331 |
+
if holder is None:
|
| 1332 |
+
continue
|
| 1333 |
+
for attr_name in ("number_channels", "channels_numbers", "audio_channels", "channels", "num_channels"):
|
| 1334 |
+
value = getattr(holder, attr_name, None)
|
| 1335 |
+
if value is not None:
|
| 1336 |
+
return int(value)
|
| 1337 |
+
return 1
|
| 1338 |
+
|
| 1339 |
+
@staticmethod
|
| 1340 |
+
def _audio_tokenizer_inference_context(audio_tokenizer, device: Union[str, torch.device]):
|
| 1341 |
+
del audio_tokenizer, device
|
| 1342 |
+
return nullcontext()
|
| 1343 |
+
|
| 1344 |
+
def _prepare_audio_codes_for_decode(
|
| 1345 |
+
self,
|
| 1346 |
+
audio_token_ids: torch.LongTensor,
|
| 1347 |
+
nq: Optional[int] = None,
|
| 1348 |
+
) -> torch.LongTensor:
|
| 1349 |
+
effective_nq = self._resolve_inference_nq(nq)
|
| 1350 |
+
tensor = torch.as_tensor(audio_token_ids, dtype=torch.long)
|
| 1351 |
+
if tensor.ndim == 2:
|
| 1352 |
+
if tensor.shape[-1] == self.config.n_vq and tensor.shape[0] != self.config.n_vq:
|
| 1353 |
+
return tensor[:, :effective_nq].transpose(0, 1).contiguous()
|
| 1354 |
+
if tensor.shape[0] == self.config.n_vq:
|
| 1355 |
+
return tensor[:effective_nq].contiguous()
|
| 1356 |
+
elif tensor.ndim == 3:
|
| 1357 |
+
if tensor.shape[-1] == self.config.n_vq:
|
| 1358 |
+
return tensor[..., :effective_nq].permute(2, 0, 1).contiguous()
|
| 1359 |
+
if tensor.shape[0] == self.config.n_vq:
|
| 1360 |
+
return tensor[:effective_nq].contiguous()
|
| 1361 |
+
raise ValueError(
|
| 1362 |
+
f"Expected generated audio token ids shaped like (T, {self.config.n_vq}) or ({self.config.n_vq}, T); got {tuple(tensor.shape)}"
|
| 1363 |
+
)
|
| 1364 |
+
|
| 1365 |
+
def _load_reference_audio(
|
| 1366 |
+
self,
|
| 1367 |
+
reference_audio_path: Union[str, Path],
|
| 1368 |
+
target_sample_rate: int,
|
| 1369 |
+
target_channels: int,
|
| 1370 |
+
) -> tuple[torch.FloatTensor, int]:
|
| 1371 |
+
waveform, sample_rate = torchaudio.load(str(reference_audio_path))
|
| 1372 |
+
waveform = waveform.to(torch.float32)
|
| 1373 |
+
if sample_rate != target_sample_rate:
|
| 1374 |
+
waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
|
| 1375 |
+
sample_rate = target_sample_rate
|
| 1376 |
+
current_channels = int(waveform.shape[0])
|
| 1377 |
+
if current_channels == target_channels:
|
| 1378 |
+
return waveform, sample_rate
|
| 1379 |
+
if current_channels == 1 and target_channels > 1:
|
| 1380 |
+
return waveform.repeat(target_channels, 1), sample_rate
|
| 1381 |
+
if current_channels > 1 and target_channels == 1:
|
| 1382 |
+
return waveform.mean(dim=0, keepdim=True), sample_rate
|
| 1383 |
+
raise ValueError(f"Unsupported reference audio channel conversion: {current_channels} -> {target_channels}")
|
| 1384 |
+
|
| 1385 |
+
def _decode_local_last_hidden_state(
|
| 1386 |
+
self,
|
| 1387 |
+
local_inputs_embeds: torch.FloatTensor,
|
| 1388 |
+
) -> torch.FloatTensor:
|
| 1389 |
+
local_attention_mask = torch.ones(
|
| 1390 |
+
local_inputs_embeds.shape[:2],
|
| 1391 |
+
dtype=torch.bool,
|
| 1392 |
+
device=local_inputs_embeds.device,
|
| 1393 |
+
)
|
| 1394 |
+
local_outputs = self.local_transformer(
|
| 1395 |
+
input_ids=None,
|
| 1396 |
+
attention_mask=local_attention_mask,
|
| 1397 |
+
position_ids=None,
|
| 1398 |
+
inputs_embeds=local_inputs_embeds,
|
| 1399 |
+
use_cache=False,
|
| 1400 |
+
output_attentions=False,
|
| 1401 |
+
output_hidden_states=False,
|
| 1402 |
+
return_dict=True,
|
| 1403 |
+
cu_seqlens=None,
|
| 1404 |
+
num_sequences=None,
|
| 1405 |
+
)
|
| 1406 |
+
return local_outputs.last_hidden_state[:, -1, :]
|
| 1407 |
+
|
| 1408 |
+
@torch.no_grad()
|
| 1409 |
+
def generate(
|
| 1410 |
+
self,
|
| 1411 |
+
input_ids: torch.LongTensor,
|
| 1412 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1413 |
+
nq: Optional[int] = None,
|
| 1414 |
+
max_new_frames: int = 300,
|
| 1415 |
+
do_sample: bool = False,
|
| 1416 |
+
text_temperature: float = 1.5,
|
| 1417 |
+
text_top_p: float = 1.0,
|
| 1418 |
+
text_top_k: int = 50,
|
| 1419 |
+
audio_temperature: float = 1.7,
|
| 1420 |
+
audio_top_p: float = 0.8,
|
| 1421 |
+
audio_top_k: int = 25,
|
| 1422 |
+
audio_repetition_penalty: float = 1.0,
|
| 1423 |
+
use_kv_cache: bool = True,
|
| 1424 |
+
return_dict_in_generate: bool = True,
|
| 1425 |
+
):
|
| 1426 |
+
if input_ids.ndim == 2:
|
| 1427 |
+
input_ids = input_ids.unsqueeze(0)
|
| 1428 |
+
if input_ids.ndim != 3:
|
| 1429 |
+
raise ValueError(f"Expected input_ids with 3 dims, got shape {tuple(input_ids.shape)}")
|
| 1430 |
+
if attention_mask is None:
|
| 1431 |
+
attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.bool, device=input_ids.device)
|
| 1432 |
+
elif attention_mask.ndim == 1:
|
| 1433 |
+
attention_mask = attention_mask.unsqueeze(0)
|
| 1434 |
+
|
| 1435 |
+
effective_nq = self._resolve_inference_nq(nq)
|
| 1436 |
+
batch_size = input_ids.shape[0]
|
| 1437 |
+
current_input_ids = input_ids
|
| 1438 |
+
current_attention_mask = attention_mask.to(device=input_ids.device)
|
| 1439 |
+
current_model_input_ids = current_input_ids
|
| 1440 |
+
generated_frames = []
|
| 1441 |
+
finished = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
|
| 1442 |
+
past_key_values = None
|
| 1443 |
+
local_dtype = self.local_transformer.ln_f.weight.dtype
|
| 1444 |
+
|
| 1445 |
+
for _ in range(max_new_frames):
|
| 1446 |
+
generated_audio_history = torch.stack(generated_frames, dim=1) if generated_frames else None
|
| 1447 |
+
global_inputs_embeds = self._build_inputs_embeds(current_model_input_ids)
|
| 1448 |
+
global_outputs = self.transformer(
|
| 1449 |
+
input_ids=None,
|
| 1450 |
+
past_key_values=past_key_values,
|
| 1451 |
+
attention_mask=current_attention_mask,
|
| 1452 |
+
position_ids=None,
|
| 1453 |
+
inputs_embeds=global_inputs_embeds,
|
| 1454 |
+
use_cache=use_kv_cache,
|
| 1455 |
+
output_attentions=False,
|
| 1456 |
+
output_hidden_states=False,
|
| 1457 |
+
return_dict=True,
|
| 1458 |
+
cu_seqlens=None,
|
| 1459 |
+
num_sequences=None,
|
| 1460 |
+
)
|
| 1461 |
+
global_hidden_states = global_outputs.last_hidden_state[:, -1, :].to(dtype=local_dtype)
|
| 1462 |
+
|
| 1463 |
+
local_inputs_embeds = global_hidden_states.unsqueeze(1)
|
| 1464 |
+
local_hidden_states = self._decode_local_last_hidden_state(local_inputs_embeds)
|
| 1465 |
+
text_logits = self.text_lm_head(local_hidden_states)
|
| 1466 |
+
self._ensure_finite_generation_logits(text_logits, "text logits")
|
| 1467 |
+
next_text_tokens = self._sample_next_assistant_text_token(
|
| 1468 |
+
logits=text_logits,
|
| 1469 |
+
do_sample=do_sample,
|
| 1470 |
+
temperature=text_temperature,
|
| 1471 |
+
top_k=text_top_k,
|
| 1472 |
+
top_p=text_top_p,
|
| 1473 |
+
)
|
| 1474 |
+
should_continue = next_text_tokens.eq(self.config.audio_assistant_slot_token_id) & ~finished
|
| 1475 |
+
finished = finished | next_text_tokens.eq(self.config.audio_end_token_id)
|
| 1476 |
+
if not should_continue.any():
|
| 1477 |
+
break
|
| 1478 |
+
|
| 1479 |
+
next_frame_tokens = []
|
| 1480 |
+
current_local_input = self.transformer.wte(next_text_tokens).to(dtype=local_dtype)
|
| 1481 |
+
for channel_index in range(effective_nq):
|
| 1482 |
+
local_inputs_embeds = torch.cat([local_inputs_embeds, current_local_input.unsqueeze(1)], dim=1)
|
| 1483 |
+
local_hidden_states = self._decode_local_last_hidden_state(local_inputs_embeds)
|
| 1484 |
+
channel_logits = self.audio_lm_heads[channel_index](local_hidden_states)
|
| 1485 |
+
self._ensure_finite_generation_logits(channel_logits, f"audio logits[{channel_index}]")
|
| 1486 |
+
channel_token = self._sample_next_token(
|
| 1487 |
+
logits=channel_logits,
|
| 1488 |
+
do_sample=do_sample,
|
| 1489 |
+
temperature=audio_temperature,
|
| 1490 |
+
top_k=audio_top_k,
|
| 1491 |
+
top_p=audio_top_p,
|
| 1492 |
+
previous_token_ids=(
|
| 1493 |
+
None if generated_audio_history is None else generated_audio_history[:, :, channel_index]
|
| 1494 |
+
),
|
| 1495 |
+
repetition_penalty=audio_repetition_penalty,
|
| 1496 |
+
)
|
| 1497 |
+
next_frame_tokens.append(channel_token)
|
| 1498 |
+
current_local_input = self.audio_embeddings[channel_index](channel_token).to(dtype=local_dtype)
|
| 1499 |
+
|
| 1500 |
+
next_frame_prefix = torch.stack(next_frame_tokens, dim=-1)
|
| 1501 |
+
if effective_nq < self.config.n_vq:
|
| 1502 |
+
next_frame = torch.full(
|
| 1503 |
+
(batch_size, self.config.n_vq),
|
| 1504 |
+
self.config.audio_pad_token_id,
|
| 1505 |
+
dtype=next_frame_prefix.dtype,
|
| 1506 |
+
device=next_frame_prefix.device,
|
| 1507 |
+
)
|
| 1508 |
+
next_frame[:, :effective_nq] = next_frame_prefix
|
| 1509 |
+
else:
|
| 1510 |
+
next_frame = next_frame_prefix
|
| 1511 |
+
padded_next_frame = next_frame.masked_fill(~should_continue.unsqueeze(-1), self.config.audio_pad_token_id)
|
| 1512 |
+
generated_frames.append(padded_next_frame)
|
| 1513 |
+
|
| 1514 |
+
next_row = self._build_generation_row(
|
| 1515 |
+
batch_size=batch_size,
|
| 1516 |
+
device=input_ids.device,
|
| 1517 |
+
audio_token_ids=padded_next_frame,
|
| 1518 |
+
)
|
| 1519 |
+
if (~should_continue).any():
|
| 1520 |
+
next_row[~should_continue, 0, 0] = self.config.pad_token_id
|
| 1521 |
+
next_row[~should_continue, 0, 1:] = self.config.audio_pad_token_id
|
| 1522 |
+
|
| 1523 |
+
current_input_ids = torch.cat([current_input_ids, next_row], dim=1)
|
| 1524 |
+
current_attention_mask = torch.cat([current_attention_mask, should_continue.unsqueeze(1)], dim=1)
|
| 1525 |
+
if use_kv_cache:
|
| 1526 |
+
current_model_input_ids = next_row
|
| 1527 |
+
past_key_values = global_outputs.past_key_values
|
| 1528 |
+
else:
|
| 1529 |
+
current_model_input_ids = current_input_ids
|
| 1530 |
+
|
| 1531 |
+
if generated_frames:
|
| 1532 |
+
audio_token_ids = torch.stack(generated_frames, dim=1)
|
| 1533 |
+
else:
|
| 1534 |
+
audio_token_ids = torch.empty((batch_size, 0, self.config.n_vq), dtype=torch.long, device=input_ids.device)
|
| 1535 |
+
|
| 1536 |
+
if not return_dict_in_generate:
|
| 1537 |
+
return audio_token_ids
|
| 1538 |
+
return NanoTTSGenerationOutput(audio_token_ids=audio_token_ids, prompt_input_ids=input_ids)
|
| 1539 |
+
|
| 1540 |
+
@torch.no_grad()
|
| 1541 |
+
def inference(
|
| 1542 |
+
self,
|
| 1543 |
+
text: str,
|
| 1544 |
+
output_audio_path: Union[str, Path],
|
| 1545 |
+
mode: str = "continuation",
|
| 1546 |
+
prompt_text: Optional[str] = None,
|
| 1547 |
+
prompt_audio_path: Optional[Union[str, Path]] = None,
|
| 1548 |
+
reference_audio_path: Optional[Union[str, Path]] = None,
|
| 1549 |
+
text_tokenizer=None,
|
| 1550 |
+
text_tokenizer_path: Optional[str] = None,
|
| 1551 |
+
audio_tokenizer=None,
|
| 1552 |
+
audio_tokenizer_type: Optional[str] = None,
|
| 1553 |
+
audio_tokenizer_pretrained_name_or_path: Optional[str] = None,
|
| 1554 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 1555 |
+
nq: Optional[int] = None,
|
| 1556 |
+
max_new_frames: int = 300,
|
| 1557 |
+
do_sample: bool = False,
|
| 1558 |
+
text_temperature: float = 1.5,
|
| 1559 |
+
text_top_p: float = 1.0,
|
| 1560 |
+
text_top_k: int = 50,
|
| 1561 |
+
audio_temperature: float = 1.7,
|
| 1562 |
+
audio_top_p: float = 0.8,
|
| 1563 |
+
audio_top_k: int = 25,
|
| 1564 |
+
audio_repetition_penalty: float = 1.0,
|
| 1565 |
+
use_kv_cache: bool = True,
|
| 1566 |
+
voice_clone_max_text_tokens: int = DEFAULT_VOICE_CLONE_MAX_TEXT_TOKENS,
|
| 1567 |
+
voice_clone_max_memory_per_sample_gb: float = DEFAULT_VOICE_CLONE_MAX_MEMORY_PER_SAMPLE_GB,
|
| 1568 |
+
) -> dict[str, Any]:
|
| 1569 |
+
resolved_device = self._resolve_device(device)
|
| 1570 |
+
effective_nq = self._resolve_inference_nq(nq)
|
| 1571 |
+
if next(self.parameters()).device != resolved_device:
|
| 1572 |
+
self.to(resolved_device)
|
| 1573 |
+
|
| 1574 |
+
was_training = self.training
|
| 1575 |
+
self.eval()
|
| 1576 |
+
|
| 1577 |
+
text_tokenizer = self._load_text_tokenizer(
|
| 1578 |
+
text_tokenizer=text_tokenizer,
|
| 1579 |
+
text_tokenizer_path=text_tokenizer_path,
|
| 1580 |
+
)
|
| 1581 |
+
audio_tokenizer = self._load_audio_tokenizer(
|
| 1582 |
+
audio_tokenizer=audio_tokenizer,
|
| 1583 |
+
audio_tokenizer_type=audio_tokenizer_type,
|
| 1584 |
+
audio_tokenizer_pretrained_name_or_path=audio_tokenizer_pretrained_name_or_path,
|
| 1585 |
+
device=resolved_device,
|
| 1586 |
+
)
|
| 1587 |
+
|
| 1588 |
+
target_sample_rate = self._resolve_audio_tokenizer_sample_rate(audio_tokenizer)
|
| 1589 |
+
target_channels = self._resolve_audio_tokenizer_channels(audio_tokenizer)
|
| 1590 |
+
effective_prompt_audio_path = prompt_audio_path or reference_audio_path
|
| 1591 |
+
resolved_mode = self._resolve_inference_mode(
|
| 1592 |
+
mode=mode,
|
| 1593 |
+
has_prompt_text=prompt_text is not None,
|
| 1594 |
+
has_prompt_audio=effective_prompt_audio_path is not None,
|
| 1595 |
+
)
|
| 1596 |
+
if reference_audio_path is not None and prompt_audio_path is None:
|
| 1597 |
+
logging.warning(
|
| 1598 |
+
"reference_audio_path=%s is treated as prompt_audio_path for backward compatibility.",
|
| 1599 |
+
reference_audio_path,
|
| 1600 |
+
)
|
| 1601 |
+
|
| 1602 |
+
prompt_audio_codes = None
|
| 1603 |
+
if effective_prompt_audio_path is not None:
|
| 1604 |
+
waveform, sample_rate = self._load_reference_audio(
|
| 1605 |
+
effective_prompt_audio_path,
|
| 1606 |
+
target_sample_rate,
|
| 1607 |
+
target_channels,
|
| 1608 |
+
)
|
| 1609 |
+
encoded = self._call_audio_encode(
|
| 1610 |
+
audio_tokenizer=audio_tokenizer,
|
| 1611 |
+
waveform=waveform.to(resolved_device),
|
| 1612 |
+
sample_rate=sample_rate,
|
| 1613 |
+
)
|
| 1614 |
+
prompt_audio_codes = self._mask_unused_audio_channels(
|
| 1615 |
+
self._normalize_audio_codes(encoded),
|
| 1616 |
+
nq=effective_nq,
|
| 1617 |
+
).to(resolved_device)
|
| 1618 |
+
|
| 1619 |
+
if resolved_mode == "voice_clone":
|
| 1620 |
+
split_voice_clone_text_chunks = self._split_text_into_best_sentences(
|
| 1621 |
+
text_tokenizer=text_tokenizer,
|
| 1622 |
+
text=text,
|
| 1623 |
+
max_tokens=voice_clone_max_text_tokens,
|
| 1624 |
+
)
|
| 1625 |
+
voice_clone_text_chunks = split_voice_clone_text_chunks if len(split_voice_clone_text_chunks) > 1 else [text]
|
| 1626 |
+
else:
|
| 1627 |
+
voice_clone_text_chunks = [text]
|
| 1628 |
+
|
| 1629 |
+
generated_audio_token_chunks: list[torch.LongTensor] = []
|
| 1630 |
+
decoded_waveform_chunks: list[torch.FloatTensor] = []
|
| 1631 |
+
decoded_sample_rate: Optional[int] = None
|
| 1632 |
+
|
| 1633 |
+
if resolved_mode == "voice_clone" and len(voice_clone_text_chunks) > 1:
|
| 1634 |
+
voice_clone_chunk_batch_size = self._resolve_voice_clone_chunk_batch_size(
|
| 1635 |
+
resolved_device=resolved_device,
|
| 1636 |
+
chunk_count=len(voice_clone_text_chunks),
|
| 1637 |
+
max_memory_per_sample_gb=float(voice_clone_max_memory_per_sample_gb),
|
| 1638 |
+
)
|
| 1639 |
+
else:
|
| 1640 |
+
voice_clone_chunk_batch_size = 1
|
| 1641 |
+
|
| 1642 |
+
for batch_start in range(0, len(voice_clone_text_chunks), voice_clone_chunk_batch_size):
|
| 1643 |
+
batch_chunks = voice_clone_text_chunks[batch_start : batch_start + voice_clone_chunk_batch_size]
|
| 1644 |
+
batch_prompt_input_ids: list[torch.LongTensor] = []
|
| 1645 |
+
batch_attention_masks: list[torch.BoolTensor] = []
|
| 1646 |
+
for text_chunk in batch_chunks:
|
| 1647 |
+
prompt_input_ids, attention_mask = self.build_inference_input_ids(
|
| 1648 |
+
text=text_chunk,
|
| 1649 |
+
text_tokenizer=text_tokenizer,
|
| 1650 |
+
mode=resolved_mode,
|
| 1651 |
+
prompt_text=prompt_text,
|
| 1652 |
+
prompt_audio_codes=prompt_audio_codes,
|
| 1653 |
+
device=resolved_device,
|
| 1654 |
+
)
|
| 1655 |
+
batch_prompt_input_ids.append(prompt_input_ids)
|
| 1656 |
+
batch_attention_masks.append(attention_mask)
|
| 1657 |
+
|
| 1658 |
+
batched_prompt_input_ids, batched_attention_mask = self._left_pad_inference_batch(
|
| 1659 |
+
input_id_batches=batch_prompt_input_ids,
|
| 1660 |
+
attention_mask_batches=batch_attention_masks,
|
| 1661 |
+
device=resolved_device,
|
| 1662 |
+
)
|
| 1663 |
+
batched_audio_token_ids = self._generate_audio_token_ids_with_fallback(
|
| 1664 |
+
prompt_input_ids=batched_prompt_input_ids,
|
| 1665 |
+
attention_mask=batched_attention_mask,
|
| 1666 |
+
effective_nq=effective_nq,
|
| 1667 |
+
max_new_frames=max_new_frames,
|
| 1668 |
+
do_sample=do_sample,
|
| 1669 |
+
text_temperature=text_temperature,
|
| 1670 |
+
text_top_p=text_top_p,
|
| 1671 |
+
text_top_k=text_top_k,
|
| 1672 |
+
audio_temperature=audio_temperature,
|
| 1673 |
+
audio_top_p=audio_top_p,
|
| 1674 |
+
audio_top_k=audio_top_k,
|
| 1675 |
+
audio_repetition_penalty=audio_repetition_penalty,
|
| 1676 |
+
use_kv_cache=use_kv_cache,
|
| 1677 |
+
resolved_device=resolved_device,
|
| 1678 |
+
)
|
| 1679 |
+
|
| 1680 |
+
for sample_index in range(len(batch_chunks)):
|
| 1681 |
+
audio_token_ids = self._trim_generated_audio_token_ids(
|
| 1682 |
+
batched_audio_token_ids[sample_index],
|
| 1683 |
+
effective_nq=effective_nq,
|
| 1684 |
+
)
|
| 1685 |
+
generated_audio_token_chunks.append(audio_token_ids)
|
| 1686 |
+
|
| 1687 |
+
if resolved_mode == "voice_clone" and len(voice_clone_text_chunks) > 1:
|
| 1688 |
+
decoded_waveform, current_sample_rate = self._decode_audio_token_ids_to_waveform(
|
| 1689 |
+
audio_tokenizer=audio_tokenizer,
|
| 1690 |
+
audio_token_ids=audio_token_ids,
|
| 1691 |
+
target_sample_rate=target_sample_rate,
|
| 1692 |
+
effective_nq=effective_nq,
|
| 1693 |
+
resolved_device=resolved_device,
|
| 1694 |
+
)
|
| 1695 |
+
if decoded_sample_rate is None:
|
| 1696 |
+
decoded_sample_rate = current_sample_rate
|
| 1697 |
+
elif decoded_sample_rate != current_sample_rate:
|
| 1698 |
+
raise ValueError(
|
| 1699 |
+
f"Decoded sample rates differ across voice_clone chunks: {decoded_sample_rate} vs {current_sample_rate}"
|
| 1700 |
+
)
|
| 1701 |
+
decoded_waveform_chunks.append(decoded_waveform)
|
| 1702 |
+
|
| 1703 |
+
if generated_audio_token_chunks:
|
| 1704 |
+
audio_token_ids = torch.cat(generated_audio_token_chunks, dim=0)
|
| 1705 |
+
else:
|
| 1706 |
+
audio_token_ids = torch.empty((0, self.config.n_vq), dtype=torch.long, device=resolved_device)
|
| 1707 |
+
|
| 1708 |
+
if resolved_mode == "voice_clone" and len(voice_clone_text_chunks) > 1:
|
| 1709 |
+
waveform = (
|
| 1710 |
+
self._concat_voice_clone_waveform_chunks(
|
| 1711 |
+
waveform_chunks=decoded_waveform_chunks,
|
| 1712 |
+
text_chunks=voice_clone_text_chunks,
|
| 1713 |
+
sample_rate=decoded_sample_rate,
|
| 1714 |
+
)
|
| 1715 |
+
if decoded_waveform_chunks
|
| 1716 |
+
else torch.zeros((target_channels, 0), dtype=torch.float32)
|
| 1717 |
+
)
|
| 1718 |
+
else:
|
| 1719 |
+
decode_audio_token_ids = audio_token_ids
|
| 1720 |
+
prompt_waveform_prefix_samples = 0
|
| 1721 |
+
if resolved_mode == "continuation" and prompt_audio_codes is not None:
|
| 1722 |
+
decode_audio_token_ids = torch.cat([prompt_audio_codes, audio_token_ids], dim=0)
|
| 1723 |
+
prompt_waveform_prefix_samples = (
|
| 1724 |
+
int(prompt_audio_codes.shape[0]) * self._resolve_audio_tokenizer_downsample_rate(audio_tokenizer)
|
| 1725 |
+
)
|
| 1726 |
+
|
| 1727 |
+
waveform, decoded_sample_rate = self._decode_audio_token_ids_to_waveform(
|
| 1728 |
+
audio_tokenizer=audio_tokenizer,
|
| 1729 |
+
audio_token_ids=decode_audio_token_ids,
|
| 1730 |
+
target_sample_rate=target_sample_rate,
|
| 1731 |
+
effective_nq=effective_nq,
|
| 1732 |
+
resolved_device=resolved_device,
|
| 1733 |
+
)
|
| 1734 |
+
if prompt_waveform_prefix_samples > 0:
|
| 1735 |
+
if decoded_sample_rate != target_sample_rate:
|
| 1736 |
+
prompt_waveform_prefix_samples = int(
|
| 1737 |
+
round(prompt_waveform_prefix_samples * float(decoded_sample_rate) / float(target_sample_rate))
|
| 1738 |
+
)
|
| 1739 |
+
prompt_waveform_prefix_samples = min(prompt_waveform_prefix_samples, int(waveform.shape[-1]))
|
| 1740 |
+
waveform = waveform[:, prompt_waveform_prefix_samples:]
|
| 1741 |
+
|
| 1742 |
+
assert decoded_sample_rate is not None
|
| 1743 |
+
|
| 1744 |
+
output_path = Path(output_audio_path)
|
| 1745 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 1746 |
+
torchaudio.save(str(output_path), waveform, decoded_sample_rate)
|
| 1747 |
+
|
| 1748 |
+
if was_training:
|
| 1749 |
+
self.train()
|
| 1750 |
+
|
| 1751 |
+
return {
|
| 1752 |
+
"audio_path": str(output_path),
|
| 1753 |
+
"sample_rate": decoded_sample_rate,
|
| 1754 |
+
"audio_token_ids": audio_token_ids.detach().cpu(),
|
| 1755 |
+
"waveform": waveform,
|
| 1756 |
+
"reference_audio_token_ids": None if prompt_audio_codes is None else prompt_audio_codes.detach().cpu(),
|
| 1757 |
+
}
|
prompting.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import List, Sequence
|
| 4 |
+
|
| 5 |
+
from .configuration_nanotts import NanoTTSConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
USER_ROLE_PREFIX = "user\n"
|
| 9 |
+
USER_TEMPLATE_REFERENCE_PREFIX = (
|
| 10 |
+
"<user_inst>\n"
|
| 11 |
+
"- Reference(s):\n"
|
| 12 |
+
)
|
| 13 |
+
USER_TEMPLATE_AFTER_REFERENCE = (
|
| 14 |
+
"\n- Instruction:\nNone\n"
|
| 15 |
+
"- Tokens:\nNone\n"
|
| 16 |
+
"- Quality:\nNone\n"
|
| 17 |
+
"- Sound Event:\nNone\n"
|
| 18 |
+
"- Ambient Sound:\nNone\n"
|
| 19 |
+
"- Language:\nNone\n"
|
| 20 |
+
"- Text:\n"
|
| 21 |
+
)
|
| 22 |
+
USER_TEMPLATE_PREFIX = USER_TEMPLATE_REFERENCE_PREFIX + "None" + USER_TEMPLATE_AFTER_REFERENCE
|
| 23 |
+
USER_TEMPLATE_SUFFIX = "\n</user_inst>"
|
| 24 |
+
ASSISTANT_TURN_PREFIX = "\n"
|
| 25 |
+
ASSISTANT_ROLE_PREFIX = "assistant\n"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def encode_text(tokenizer, text: str) -> List[int]:
|
| 29 |
+
try:
|
| 30 |
+
return list(tokenizer.encode(text, add_special_tokens=False))
|
| 31 |
+
except TypeError:
|
| 32 |
+
return list(tokenizer.encode(text))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def decode_text(tokenizer, token_ids: Sequence[int]) -> str:
|
| 36 |
+
try:
|
| 37 |
+
return str(
|
| 38 |
+
tokenizer.decode(
|
| 39 |
+
list(token_ids),
|
| 40 |
+
skip_special_tokens=False,
|
| 41 |
+
clean_up_tokenization_spaces=False,
|
| 42 |
+
)
|
| 43 |
+
)
|
| 44 |
+
except TypeError:
|
| 45 |
+
try:
|
| 46 |
+
return str(tokenizer.decode(list(token_ids), skip_special_tokens=False))
|
| 47 |
+
except TypeError:
|
| 48 |
+
return str(tokenizer.decode(list(token_ids)))
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def build_user_prompt_prefix(tokenizer, config: NanoTTSConfig) -> List[int]:
|
| 52 |
+
return [config.im_start_token_id] + encode_text(tokenizer, USER_ROLE_PREFIX) + encode_text(
|
| 53 |
+
tokenizer,
|
| 54 |
+
USER_TEMPLATE_REFERENCE_PREFIX,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def build_user_prompt_after_reference(tokenizer) -> List[int]:
|
| 59 |
+
return encode_text(tokenizer, USER_TEMPLATE_AFTER_REFERENCE)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def build_assistant_prompt_prefix(tokenizer, config: NanoTTSConfig) -> List[int]:
|
| 63 |
+
return encode_text(tokenizer, USER_TEMPLATE_SUFFIX) + [config.im_end_token_id] + encode_text(
|
| 64 |
+
tokenizer,
|
| 65 |
+
ASSISTANT_TURN_PREFIX,
|
| 66 |
+
) + [config.im_start_token_id] + encode_text(
|
| 67 |
+
tokenizer,
|
| 68 |
+
ASSISTANT_ROLE_PREFIX,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def build_prompt_prefix(tokenizer, config: NanoTTSConfig) -> List[int]:
|
| 73 |
+
return (
|
| 74 |
+
build_user_prompt_prefix(tokenizer, config)
|
| 75 |
+
+ encode_text(tokenizer, "None")
|
| 76 |
+
+ build_user_prompt_after_reference(tokenizer)
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def build_prompt_suffix(tokenizer, config: NanoTTSConfig) -> List[int]:
|
| 81 |
+
return build_assistant_prompt_prefix(tokenizer, config)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def build_prompt_token_ids(
|
| 85 |
+
tokenizer,
|
| 86 |
+
config: NanoTTSConfig,
|
| 87 |
+
text_token_ids: Sequence[int],
|
| 88 |
+
) -> List[int]:
|
| 89 |
+
return build_prompt_prefix(tokenizer, config) + [int(token_id) for token_id in text_token_ids] + build_prompt_suffix(
|
| 90 |
+
tokenizer,
|
| 91 |
+
config,
|
| 92 |
+
)
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d19e63fdc6a35f61a7d1c27e06bfacaba7c1ed40ea3c619c86efc64bcd50a496
|
| 3 |
+
size 234693095
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "</s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "<pad>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"unk_token": {
|
| 24 |
+
"content": "<unk>",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
}
|
| 30 |
+
}
|
tokenization_nanotts_sentencepiece.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import shutil
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import sentencepiece as spm
|
| 8 |
+
from transformers import PreTrainedTokenizer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class NanoTTSSentencePieceTokenizer(PreTrainedTokenizer):
|
| 15 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 16 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 17 |
+
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
vocab_file: str,
|
| 21 |
+
unk_token: str = "<unk>",
|
| 22 |
+
bos_token: str = "<s>",
|
| 23 |
+
eos_token: str = "</s>",
|
| 24 |
+
pad_token: str = "<pad>",
|
| 25 |
+
sp_model_kwargs: dict[str, Any] | None = None,
|
| 26 |
+
**kwargs,
|
| 27 |
+
) -> None:
|
| 28 |
+
self.vocab_file = str(vocab_file)
|
| 29 |
+
self.sp_model_kwargs = {} if sp_model_kwargs is None else dict(sp_model_kwargs)
|
| 30 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 31 |
+
self.sp_model.Load(self.vocab_file)
|
| 32 |
+
super().__init__(
|
| 33 |
+
unk_token=unk_token,
|
| 34 |
+
bos_token=bos_token,
|
| 35 |
+
eos_token=eos_token,
|
| 36 |
+
pad_token=pad_token,
|
| 37 |
+
**kwargs,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
@property
|
| 41 |
+
def vocab_size(self) -> int:
|
| 42 |
+
return int(self.sp_model.get_piece_size())
|
| 43 |
+
|
| 44 |
+
def get_vocab(self) -> dict[str, int]:
|
| 45 |
+
vocab = {self.sp_model.id_to_piece(i): i for i in range(self.vocab_size)}
|
| 46 |
+
vocab.update(self.added_tokens_encoder)
|
| 47 |
+
return vocab
|
| 48 |
+
|
| 49 |
+
def _tokenize(self, text: str) -> list[str]:
|
| 50 |
+
return list(self.sp_model.encode(text, out_type=str))
|
| 51 |
+
|
| 52 |
+
def _convert_token_to_id(self, token: str) -> int:
|
| 53 |
+
token_id = int(self.sp_model.piece_to_id(token))
|
| 54 |
+
return token_id
|
| 55 |
+
|
| 56 |
+
def _convert_id_to_token(self, index: int) -> str:
|
| 57 |
+
return str(self.sp_model.id_to_piece(int(index)))
|
| 58 |
+
|
| 59 |
+
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
| 60 |
+
return str(self.sp_model.decode(tokens))
|
| 61 |
+
|
| 62 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
|
| 63 |
+
save_dir = Path(save_directory)
|
| 64 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
| 65 |
+
out_name = "tokenizer.model" if filename_prefix is None else f"{filename_prefix}-tokenizer.model"
|
| 66 |
+
out_path = save_dir / out_name
|
| 67 |
+
if Path(self.vocab_file).resolve() != out_path.resolve():
|
| 68 |
+
shutil.copyfile(self.vocab_file, out_path)
|
| 69 |
+
return (str(out_path),)
|
| 70 |
+
|
| 71 |
+
def build_inputs_with_special_tokens(
|
| 72 |
+
self,
|
| 73 |
+
token_ids_0: list[int],
|
| 74 |
+
token_ids_1: list[int] | None = None,
|
| 75 |
+
) -> list[int]:
|
| 76 |
+
if token_ids_1 is None:
|
| 77 |
+
return list(token_ids_0)
|
| 78 |
+
return list(token_ids_0) + list(token_ids_1)
|
| 79 |
+
|
| 80 |
+
def get_special_tokens_mask(
|
| 81 |
+
self,
|
| 82 |
+
token_ids_0: list[int],
|
| 83 |
+
token_ids_1: list[int] | None = None,
|
| 84 |
+
already_has_special_tokens: bool = False,
|
| 85 |
+
) -> list[int]:
|
| 86 |
+
if already_has_special_tokens:
|
| 87 |
+
return super().get_special_tokens_mask(
|
| 88 |
+
token_ids_0=token_ids_0,
|
| 89 |
+
token_ids_1=token_ids_1,
|
| 90 |
+
already_has_special_tokens=True,
|
| 91 |
+
)
|
| 92 |
+
if token_ids_1 is None:
|
| 93 |
+
return [0] * len(token_ids_0)
|
| 94 |
+
return [0] * (len(token_ids_0) + len(token_ids_1))
|
| 95 |
+
|
| 96 |
+
def create_token_type_ids_from_sequences(
|
| 97 |
+
self,
|
| 98 |
+
token_ids_0: list[int],
|
| 99 |
+
token_ids_1: list[int] | None = None,
|
| 100 |
+
) -> list[int]:
|
| 101 |
+
if token_ids_1 is None:
|
| 102 |
+
return [0] * len(token_ids_0)
|
| 103 |
+
return [0] * (len(token_ids_0) + len(token_ids_1))
|
tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c353ee1479b536bf414c1b247f5542b6607fb8ae91320e5af1781fee200fddff
|
| 3 |
+
size 470897
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "<unk>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "<s>",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"2": {
|
| 20 |
+
"content": "</s>",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"3": {
|
| 28 |
+
"content": "<pad>",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
}
|
| 35 |
+
},
|
| 36 |
+
"additional_special_tokens": [],
|
| 37 |
+
"auto_map": {
|
| 38 |
+
"AutoTokenizer": [
|
| 39 |
+
"tokenization_nanotts_sentencepiece.NanoTTSSentencePieceTokenizer",
|
| 40 |
+
null
|
| 41 |
+
]
|
| 42 |
+
},
|
| 43 |
+
"backend": "custom",
|
| 44 |
+
"bos_token": "<s>",
|
| 45 |
+
"clean_up_tokenization_spaces": false,
|
| 46 |
+
"eos_token": "</s>",
|
| 47 |
+
"extra_special_tokens": {},
|
| 48 |
+
"model_max_length": 16384,
|
| 49 |
+
"pad_token": "<pad>",
|
| 50 |
+
"tokenizer_class": "NanoTTSSentencePieceTokenizer",
|
| 51 |
+
"unk_token": "<unk>"
|
| 52 |
+
}
|