primepake committed
Commit: 7940474 · Parent(s): 0108756

correct speaker encoder
Browse files
- speech/config.yaml +8 -5
- speech/cosyvoice/dataset/processor.py +3 -19
- speech/cosyvoice/flow/flow.py +129 -5
- speech/cosyvoice/llm/llm.py +72 -70
- speech/cosyvoice/transformer/arch_util.py +373 -0
- speech/cosyvoice/transformer/xtransformers.py +1248 -0
- speech/tools/download_dataset.py +371 -0
- speech/tools/download_vivoice.py +210 -0
speech/config.yaml
CHANGED

@@ -12,12 +12,12 @@ spk_embed_dim: 192
 qwen_pretrain_path: ''
 token_frame_rate: 25
 token_mel_ratio: 2
-
+use_speaker_encoder: True
+speaker_encoder_path: '/data/checkpoint/llm/epoch_29_step_20001.pt'
 # stream related params
 chunk_size: 25 # streaming inference chunk size, in token
 num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks

-use_speaker_encoder: True # New flag
 speaker_encoder_config:
     mel_dim: 80
     model_dim: 512
@@ -51,8 +51,8 @@ llm: !new:cosyvoice.llm.llm.Qwen2LM
 extract_reference_mel: !name:cosyvoice.dataset.processor.extract_reference_mel_from_speech
     feat_extractor: !ref <feat_extractor>
     min_length: 0.5
-    max_length:
-    num_crops:
+    max_length: 12.0
+    num_crops: 3 # Multiple crops from same utterance
     training: True
     sample_rate: !ref <sample_rate>

@@ -66,6 +66,9 @@ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
     only_mask_loss: True
     token_mel_ratio: !ref <token_mel_ratio>
     pre_lookahead_len: 3
+    use_speaker_encoder: !ref <use_speaker_encoder> # Add this
+    freeze_speaker_encoder: True # Freeze by default for flow training
+    speaker_encoder_path: !ref <speaker_encoder_path> # Add this
     encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
         output_size: 512
         attention_heads: 8
@@ -198,7 +201,7 @@ sort: !name:cosyvoice.dataset.processor.sort
     sort_size: 500 # sort_size should be less than shuffle_size
 batch: !name:cosyvoice.dataset.processor.batch
     batch_type: 'dynamic'
-    max_frames_in_batch:
+    max_frames_in_batch: 5000
 padding: !name:cosyvoice.dataset.processor.padding
     use_spk_embedding: False # change to True during sft
     use_speaker_encoder: !ref <use_speaker_encoder>
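The !new:, !name:, and !ref tags in this config are HyperPyYAML extensions: !new: instantiates a class, !name: binds a callable as a functools.partial, and !ref <key> aliases another key, which is how use_speaker_encoder and speaker_encoder_path fan out to the flow and padding blocks above. A minimal loading sketch, assuming the hyperpyyaml package that CosyVoice-style recipes build on; the override value is illustrative:

from hyperpyyaml import load_hyperpyyaml

# Resolves !ref aliases, instantiates !new: classes, binds !name: callables.
with open('speech/config.yaml', 'r') as f:
    configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': '/path/to/qwen'})

flow = configs['flow']                 # constructed CausalMaskedDiffWithXvec instance
padding = configs['padding']           # functools.partial over processor.padding
print(configs['use_speaker_encoder'])  # plain scalar shared via !ref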
speech/cosyvoice/dataset/processor.py
CHANGED

@@ -103,7 +103,7 @@ def individual_file_opener(data, mode='train', tts_data={}):

         # Load embeddings if they exist
         if os.path.exists(file_info['embedding_path']):
-            utt_embedding = torch.load(file_info['embedding_path'])
+            utt_embedding = torch.load(file_info['embedding_path'], weights_only=False)
             if isinstance(utt_embedding, torch.Tensor):
                 utt_embedding = utt_embedding.tolist()
         else:
@@ -113,7 +113,7 @@ def individual_file_opener(data, mode='train', tts_data={}):

         # Load tokens if they exist
         if os.path.exists(file_info['token_path']):
-            speech_token = torch.load(file_info['token_path'])
+            speech_token = torch.load(file_info['token_path'], weights_only=False)
             if isinstance(speech_token, torch.Tensor):
                 speech_token = speech_token.tolist()
         else:
@@ -122,7 +122,7 @@ def individual_file_opener(data, mode='train', tts_data={}):

         # Load speaker embedding
         if os.path.exists(file_info['spk_embedding_path']):
-            spk_embedding = torch.load(file_info['spk_embedding_path'])
+            spk_embedding = torch.load(file_info['spk_embedding_path'], weights_only=False)
             if isinstance(spk_embedding, torch.Tensor):
                 spk_embedding = spk_embedding.tolist()
         else:
@@ -686,17 +686,6 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False, use_spe
            batch['reference_mels'] = padded_reference_mels
            batch['reference_mel_lengths'] = padded_reference_mel_lengths
            batch['reference_mel_masks'] = reference_mel_masks
-        else:
-            # No valid mels, create dummy
-            batch['reference_mels'] = torch.zeros(batch_size, 1, mel_dim, 100)
-            batch['reference_mel_lengths'] = torch.zeros(batch_size, 1, dtype=torch.int32)
-            batch['reference_mel_masks'] = torch.zeros(batch_size, 1, 100)
-    else:
-        # No references available, create dummy for batch
-        batch_size = len(order)
-        batch['reference_mels'] = torch.zeros(batch_size, 1, 80, 100)
-        batch['reference_mel_lengths'] = torch.zeros(batch_size, 1, dtype=torch.int32)
-        batch['reference_mel_masks'] = torch.zeros(batch_size, 1, 100)

     if gan is True:
         # in gan train, we need pitch_feat
@@ -726,9 +715,4 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False, use_spe
         batch['reject_speech_token'] = reject_speech_token
         batch['reject_speech_token_len'] = reject_speech_token_len

-    if use_spk_embedding is True:
-        batch["embedding"] = batch["spk_embedding"]
-    else:
-        batch["embedding"] = batch["utt_embedding"]
-
     yield batch
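The only functional change to the three loaders is the explicit weights_only=False. PyTorch 2.6 flipped torch.load's default to weights_only=True, which refuses pickled objects beyond tensors and primitive containers, so cached artifacts that contain, say, numpy arrays stop loading. A small sketch of the failure mode and the fix; the file name and numpy payload are hypothetical stand-ins for the cached embeddings:

import numpy as np
import torch

torch.save(np.zeros(192), 'spk_embedding.pt')  # hypothetical cached x-vector stored as numpy

# Under PyTorch >= 2.6 the default weights_only=True raises on the numpy global;
# weights_only=False restores the old behaviour (only safe for files you trust).
emb = torch.load('spk_embedding.pt', weights_only=False)
print(type(emb))  # <class 'numpy.ndarray'>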
speech/cosyvoice/flow/flow.py
CHANGED

@@ -210,6 +210,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         only_mask_loss: bool = True,
         token_mel_ratio: int = 2,
         pre_lookahead_len: int = 3,
+        use_speaker_encoder: bool = False,  # Add this
+        freeze_speaker_encoder: bool = False,  # Add this
+        max_conditioning_inputs: int = 2,  # Add this
+        speaker_encoder_path: str = None,
         encoder: torch.nn.Module = None,
         decoder: torch.nn.Module = None,
         decoder_conf: Dict = {
@@ -261,6 +265,60 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         self.input_frame_rate = input_frame_rate
         logging.info(f"input frame rate={self.input_frame_rate}")
         self.input_embedding = nn.Embedding(vocab_size, input_size)
+        self.use_speaker_encoder = use_speaker_encoder
+        # Speaker encoder setup
+        if use_speaker_encoder:
+            from cosyvoice.llm.llm import LearnableSpeakerEncoder
+            self.speaker_encoder = LearnableSpeakerEncoder(
+                mel_dim=80,
+                model_dim=512,
+                output_dim=spk_embed_dim,
+                num_blocks=6,
+                num_heads=8,
+            )
+
+            # Load speaker encoder weights from LLM checkpoint if provided
+            if speaker_encoder_path is not None:
+                logging.info(f"Loading speaker encoder from {speaker_encoder_path}")
+                checkpoint = torch.load(speaker_encoder_path, map_location='cpu')
+
+                # Debug: print checkpoint structure
+                print(f'Checkpoint keys: {checkpoint.keys()}')
+
+                # Extract speaker encoder weights
+                speaker_encoder_state = {}
+
+                # Check if checkpoint has 'state_dict' key or direct model weights
+                if 'state_dict' in checkpoint:
+                    state_dict = checkpoint['state_dict']
+                else:
+                    # Direct model weights (based on your save function)
+                    state_dict = {k: v for k, v in checkpoint.items() if not k in ['epoch', 'step']}
+
+                # Extract speaker encoder weights
+                for key, value in state_dict.items():
+                    if 'speaker_encoder.' in key:
+                        # Remove module. prefix if exists (from DDP)
+                        new_key = key.replace('module.', '')
+                        # Remove speaker_encoder. prefix to match the local module
+                        new_key = new_key.replace('speaker_encoder.', '')
+                        speaker_encoder_state[new_key] = value
+
+                if len(speaker_encoder_state) == 0:
+                    logging.warning("No speaker encoder weights found in checkpoint!")
+                    logging.warning(f"Available keys: {list(state_dict.keys())[:10]}...")  # Show first 10 keys
+                else:
+                    logging.info(f"Found {len(speaker_encoder_state)} speaker encoder weights")
+                    # Load the weights
+                    self.speaker_encoder.load_state_dict(speaker_encoder_state, strict=True)
+                    logging.info("Speaker encoder loaded successfully")
+            self.freeze_speaker_encoder = freeze_speaker_encoder
+            if freeze_speaker_encoder:
+                # Freeze speaker encoder parameters
+                for param in self.speaker_encoder.parameters():
+                    param.requires_grad = False
+                logging.info("Speaker encoder frozen in flow matching")
+
         self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
         self.encoder = encoder
         self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
@@ -271,6 +329,55 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         print(" decoder_conf['cfm_params']: ", decoder_conf["cfm_params"])
         self.use_contrastive_fm = decoder_conf["cfm_params"]["use_contrastive_fm"]

+    def get_speaker_embedding(self, batch, device):
+        """Extract speaker embedding from reference mels or use provided embeddings"""
+
+        if self.use_speaker_encoder and 'reference_mels' in batch:
+            reference_mels = batch['reference_mels'].to(device)
+
+            # Handle multiple references
+            if reference_mels.dim() == 4:  # [B, N, C, T]
+                B, N, C, T = reference_mels.shape
+                embeddings = []
+
+                for i in range(N):
+                    ref_mel = reference_mels[:, i, :, :]  # [B, C, T]
+                    if 'reference_mel_masks' in batch:
+                        mask = batch['reference_mel_masks'][:, i, :].unsqueeze(1).to(device)
+                    else:
+                        mask = None
+                    print('ref_mel mask: ', ref_mel.shape, mask.shape)
+                    # Apply speaker encoder
+                    with torch.set_grad_enabled(not self.freeze_speaker_encoder):
+                        emb = self.speaker_encoder(ref_mel, mask)  # [B, spk_embed_dim]
+                    embeddings.append(emb)
+
+                # Average multiple references
+                embedding = torch.stack(embeddings, dim=1).mean(dim=1)  # [B, spk_embed_dim]
+
+            else:  # Single reference [B, C, T]
+                if 'reference_mel_mask' in batch:
+                    mask = batch['reference_mel_mask'].unsqueeze(1).to(device)
+                else:
+                    mask = None
+
+                with torch.set_grad_enabled(not self.freeze_speaker_encoder):
+                    embedding = self.speaker_encoder(reference_mels, mask)
+
+            # Normalize (already normalized in speaker encoder, but just to be safe)
+            embedding = F.normalize(embedding, dim=1)
+
+        elif 'embedding' in batch:
+            # Use provided embeddings (backward compatibility)
+            embedding = batch['embedding'].to(device)
+            embedding = F.normalize(embedding, dim=1)
+        else:
+            # No speaker conditioning
+            B = batch['speech_token'].shape[0]
+            embedding = torch.zeros(B, self.spk_embed_dim).to(device)
+
+        return embedding
+
     def forward(
         self,
         batch: dict,
@@ -280,10 +387,11 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         token_len = batch["speech_token_len"].to(device)
         feat = batch["speech_feat"].to(device)
         feat_len = batch["speech_feat_len"].to(device)
-        embedding = batch["embedding"].to(device)

         # NOTE unified training, static_chunk_size > 0 or = 0
         streaming = False  # if random.random() < 0.5 else False
+        print("get speaker embedding")
+        embedding = self.get_speaker_embedding(batch, device)

         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -335,13 +443,29 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
                   prompt_token_len,
                   prompt_feat,
                   prompt_feat_len,
-                  embedding,
-
-
+                  embedding=None,
+                  reference_mels=None,
+                  reference_mel_lengths=None,
+                  reference_mel_masks=None,
+                  streaming=False,
+                  finalize=False,
                   ):
         assert token.shape[0] == 1
+
+        # Get speaker embedding
+        if self.use_speaker_encoder and reference_mels is not None:
+            batch = {
+                'reference_mels': reference_mels,
+                'reference_mel_lengths': reference_mel_lengths,
+                'reference_mel_masks': reference_mel_masks
+            }
+            embedding = self.get_speaker_embedding(batch, token.device)
+        elif embedding is not None:
+            embedding = F.normalize(embedding, dim=1)
+        else:
+            embedding = torch.zeros(1, self.spk_embed_dim).to(token.device)
+
         # xvec projection
-        embedding = F.normalize(embedding, dim=1)
         embedding = self.spk_embed_affine_layer(embedding)

         # concat text and prompt_text
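The constructor's checkpoint surgery is the usual prefix-stripping pattern: pick the speaker_encoder.* entries out of a full LLM checkpoint, dropping any DDP 'module.' wrapper, so they can be loaded into the bare sub-module. A standalone sketch of the same idea, written slightly stricter than the replace()-based loop above in that it only strips leading prefixes; the path in the usage comment is a placeholder:

import torch

def extract_submodule_state(ckpt_path, prefix='speaker_encoder.'):
    # Full-model checkpoints may nest weights under 'state_dict' and/or a DDP 'module.' prefix.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    sub_state = {}
    for key, value in state_dict.items():
        if key.startswith('module.'):
            key = key[len('module.'):]
        if key.startswith(prefix):
            sub_state[key[len(prefix):]] = value
    return sub_state

# Usage sketch (placeholder path):
#   state = extract_submodule_state('/path/to/llm_checkpoint.pt')
#   flow.speaker_encoder.load_state_dict(state, strict=True)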
speech/cosyvoice/llm/llm.py
CHANGED

@@ -29,11 +29,11 @@ from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.mask import make_pad_mask
 from cosyvoice.transformer.attention import MultiHeadedAttention
 from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
+from cosyvoice.transformer.arch_util import AttentionBlock

 class LearnableSpeakerEncoder(nn.Module):
     """
-    Speaker encoder
-    Processes mel-spectrograms through attention blocks to extract speaker characteristics.
+    Speaker encoder using the same architecture as Tortoise-TTS ConditioningEncoder.
     """
     def __init__(
         self,
@@ -42,58 +42,60 @@ class LearnableSpeakerEncoder(nn.Module):
         output_dim: int = 192,
         num_blocks: int = 6,
         num_heads: int = 8,
-        kernel_size: int = 1,
         dropout: float = 0.0,
+        mean_pooling: bool = False,  # Tortoise uses first position by default
     ):
         super().__init__()

-        #
-        self.init = nn.Conv1d(mel_dim, model_dim, kernel_size=
-
-        self.blocks = nn.ModuleList([
-            ConformerEncoderLayer(
-                size=model_dim,
-                self_attn=MultiHeadedAttention(
-                    n_head=num_heads,
-                    n_feat=model_dim,
-                    dropout_rate=dropout,
-                ),
-                feed_forward=None,  # Can add if needed
-                feed_forward_macaron=None,
-                conv_module=None,
-                dropout_rate=dropout,
-                normalize_before=True,
-            ) for _ in range(num_blocks)
-        ])
+        # Same as Tortoise ConditioningEncoder
+        self.init = nn.Conv1d(mel_dim, model_dim, kernel_size=1)

-        #
-
+        # AttentionBlock from Tortoise
+        attn = []
+        for _ in range(num_blocks):
+            attn.append(AttentionBlock(model_dim, num_heads))
+        self.attn = nn.Sequential(*attn)
+
+        self.dim = model_dim
+        self.mean_pooling = mean_pooling

+        # Output projection to match CosyVoice embedding dimension
+        self.output_proj = nn.Linear(model_dim, output_dim)
+
     def forward(self, x, mask=None):
         """
         Args:
             x: mel-spectrogram [B, 80, T]
-            mask: padding mask [B, 1, T]
+            mask: padding mask [B, 1, T] (not used in Tortoise version)
         Returns:
             speaker embedding [B, output_dim]
         """
         # Initial conv
         h = self.init(x)  # [B, model_dim, T]
-        h = h.transpose(1, 2)  # [B, T, model_dim]

         # Apply attention blocks
-
-        h, _ = block(h, mask)
+        h = self.attn(h)  # [B, model_dim, T]

-        #
-
-
+        # Pooling - Tortoise uses first position
+        if self.mean_pooling:
+            if mask is not None:
+                # Masked mean pooling
+                h_masked = h * mask
+                h_sum = h_masked.sum(dim=2)
+                mask_sum = mask.sum(dim=2).clamp(min=1)
+                h_pooled = h_sum / mask_sum
+            else:
+                h_pooled = h.mean(dim=2)
+        else:
+            # Use first position like Tortoise
+            h_pooled = h[:, :, 0]

         # Project to output dimension
-        output = self.output_proj(
+        output = self.output_proj(h_pooled)  # [B, output_dim]

         return F.normalize(output, p=2, dim=1)

+
 class TransformerLM(torch.nn.Module):
     def __init__(
         self,
@@ -163,19 +165,19 @@ class TransformerLM(torch.nn.Module):

         if self.use_speaker_encoder and 'reference_mels' in batch:
             reference_mels = batch['reference_mels'].to(device)
-
+            print('reference_mels shape: ', reference_mels.shape)
             # Handle multiple references
             if reference_mels.dim() == 4:  # [B, N, T, 80]
-                B, N,
+                B, N, C, T = reference_mels.shape
                 conds = []

                 for i in range(N):
-                    ref_mel = reference_mels[:, i, :, :]
+                    ref_mel = reference_mels[:, i, :, :]
                     if 'reference_mel_masks' in batch:
                         mask = batch['reference_mel_masks'][:, i, :].unsqueeze(1).to(device)
                     else:
                         mask = None
-
+                    print('ref_mel shape: ', ref_mel.shape, 'mask: ', mask.shape)
                     cond = self.speaker_encoder(ref_mel, mask)  # [B, spk_embed_dim]
                     conds.append(cond)

@@ -547,42 +549,42 @@ class Qwen2LM(TransformerLM):
         lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
         return lm_target, lm_input, lm_input_len

-    def forward(
-        self,
-        batch: dict,
-        device: torch.device,
-    ) -> Dict[str, Optional[torch.Tensor]]:
-        """
-        Args:
-            text: (B, L, D)
-            text_lengths: (B,)
-            audio: (B, T, N) or (B, T)
-            audio_lengths: (B,)
-        """
-        text_token = batch['text_token'].to(device)
-        text_token_len = batch['text_token_len'].to(device)
-        speech_token = batch['speech_token'].to(device)
-        speech_token_len = batch['speech_token_len'].to(device)
-
-
-        # 1. encode text_token
-        text_token_emb = self.llm.model.model.embed_tokens(text_token)
-
-        # 2. encode speech_token
-        speech_token_emb = self.speech_embedding(speech_token)
-
-        # 3. prepare llm_input/target
-        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
-        lm_target = lm_target.to(device)
-
-        # 4. run lm forward
-        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
-        logits = self.llm_decoder(lm_output)
-        loss = self.criterion_ce(logits, lm_target.to(device))
-        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
-        return {'loss': loss, 'acc': acc}
+    # def forward(
+    #     self,
+    #     batch: dict,
+    #     device: torch.device,
+    # ) -> Dict[str, Optional[torch.Tensor]]:
+    #     """
+    #     Args:
+    #         text: (B, L, D)
+    #         text_lengths: (B,)
+    #         audio: (B, T, N) or (B, T)
+    #         audio_lengths: (B,)
+    #     """
+    #     text_token = batch['text_token'].to(device)
+    #     text_token_len = batch['text_token_len'].to(device)
+    #     speech_token = batch['speech_token'].to(device)
+    #     speech_token_len = batch['speech_token_len'].to(device)
+
+
+    #     # 1. encode text_token
+    #     text_token_emb = self.llm.model.model.embed_tokens(text_token)
+
+    #     # 2. encode speech_token
+    #     speech_token_emb = self.speech_embedding(speech_token)
+
+    #     # 3. prepare llm_input/target
+    #     lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
+    #     lm_target = lm_target.to(device)
+
+    #     # 4. run lm forward
+    #     lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
+    #     logits = self.llm_decoder(lm_output)
+    #     loss = self.criterion_ce(logits, lm_target.to(device))
+    #     acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
+    #     return {'loss': loss, 'acc': acc}

-    def
+    def forward(
         self,
         batch: dict,
         device: torch.device,
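LearnableSpeakerEncoder collapses a variable-length mel sequence to one vector either by reading the first attended position (the Tortoise default kept here) or by a length-aware mean. A self-contained sketch contrasting the two pooling modes; shapes follow the docstring above, the numbers are arbitrary:

import torch

B, D, T = 2, 512, 130
h = torch.randn(B, D, T)                  # post-attention features [B, model_dim, T]
lengths = torch.tensor([130, 90])         # true frame counts per batch item

# Tortoise-style pooling: position 0, which attention has let attend to every frame.
first = h[:, :, 0]                        # [B, D]

# Masked mean pooling: average only over real (unpadded) frames.
mask = (torch.arange(T)[None, :] < lengths[:, None]).unsqueeze(1).float()  # [B, 1, T]
mean = (h * mask).sum(dim=2) / mask.sum(dim=2).clamp(min=1)                # [B, D]
print(first.shape, mean.shape)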
speech/cosyvoice/transformer/arch_util.py
ADDED

@@ -0,0 +1,373 @@
import os
import functools
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from cosyvoice.transformer.xtransformers import ContinuousTransformerWrapper, RelativePositionBias


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    groups = 32
    if channels <= 16:
        groups = 8
    elif channels <= 64:
        groups = 16
    while channels % groups != 0:
        groups = int(groups / 2)
    assert groups > 2
    return GroupNorm32(groups, channels)


class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, mask=None, rel_pos=None):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        if rel_pos is not None:
            weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        if mask is not None:
            # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
            mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
            weight = weight * mask
        a = torch.einsum("bts,bcs->bct", weight, v)

        return a.reshape(bs, -1, length)


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        do_checkpoint=True,
        relative_pos_embeddings=False,
    ):
        super().__init__()
        self.channels = channels
        self.do_checkpoint = do_checkpoint
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.norm = normalization(channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        # split heads before split qkv
        self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
        if relative_pos_embeddings:
            self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
        else:
            self.relative_pos_embeddings = None

    def forward(self, x, mask=None):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv, mask, self.relative_pos_embeddings)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    """

    def __init__(self, channels, use_conv, out_channels=None, factor=4):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.factor = factor
        if use_conv:
            ksize = 5
            pad = 2
            self.conv = nn.Conv1d(self.channels, self.out_channels, ksize, padding=pad)

    def forward(self, x):
        assert x.shape[1] == self.channels
        x = F.interpolate(x, scale_factor=self.factor, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    """

    def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv

        stride = factor
        if use_conv:
            self.op = nn.Conv1d(
                self.channels, self.out_channels, ksize, stride=stride, padding=pad
            )
        else:
            assert self.channels == self.out_channels
            self.op = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResBlock(nn.Module):
    def __init__(
        self,
        channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        up=False,
        down=False,
        kernel_size=3,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        padding = 1 if kernel_size == 3 else 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False)
            self.x_upd = Upsample(channels, False)
        elif down:
            self.h_upd = Downsample(channels, False)
            self.x_upd = Downsample(channels, False)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = nn.Conv1d(
                channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)

    def forward(self, x):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        h = self.out_layers(h)
        return self.skip_connection(x) + h


class AudioMiniEncoder(nn.Module):
    def __init__(self,
                 spec_dim,
                 embedding_dim,
                 base_channels=128,
                 depth=2,
                 resnet_blocks=2,
                 attn_blocks=4,
                 num_attn_heads=4,
                 dropout=0,
                 downsample_factor=2,
                 kernel_size=3):
        super().__init__()
        self.init = nn.Sequential(
            nn.Conv1d(spec_dim, base_channels, 3, padding=1)
        )
        ch = base_channels
        res = []
        for l in range(depth):
            for r in range(resnet_blocks):
                res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
            res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
            ch *= 2
        self.res = nn.Sequential(*res)
        self.final = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            nn.Conv1d(ch, embedding_dim, 1)
        )
        attn = []
        for a in range(attn_blocks):
            attn.append(AttentionBlock(embedding_dim, num_attn_heads,))
        self.attn = nn.Sequential(*attn)
        self.dim = embedding_dim

    def forward(self, x):
        h = self.init(x)
        h = self.res(h)
        h = self.final(h)
        h = self.attn(h)
        return h[:, :, 0]


DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data/mel_norms.pth')


class TorchMelSpectrogram(nn.Module):
    def __init__(self, filter_length=1024, hop_length=256, win_length=1024, n_mel_channels=80, mel_fmin=0, mel_fmax=8000,
                 sampling_rate=22050, normalize=False, mel_norm_file=DEFAULT_MEL_NORM_FILE):
        super().__init__()
        # These are the default tacotron values for the MEL spectrogram.
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.sampling_rate = sampling_rate
        self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
                                                             win_length=self.win_length, power=2, normalized=normalize,
                                                             sample_rate=self.sampling_rate, f_min=self.mel_fmin,
                                                             f_max=self.mel_fmax, n_mels=self.n_mel_channels,
                                                             norm="slaney")
        self.mel_norm_file = mel_norm_file
        if self.mel_norm_file is not None:
            self.mel_norms = torch.load(self.mel_norm_file)
        else:
            self.mel_norms = None

    def forward(self, inp):
        if len(inp.shape) == 3:  # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
            inp = inp.squeeze(1)
        assert len(inp.shape) == 2
        if torch.backends.mps.is_available():
            inp = inp.to('cpu')
        self.mel_stft = self.mel_stft.to(inp.device)
        mel = self.mel_stft(inp)
        # Perform dynamic range compression
        mel = torch.log(torch.clamp(mel, min=1e-5))
        if self.mel_norms is not None:
            self.mel_norms = self.mel_norms.to(mel.device)
            mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
        return mel


class CheckpointedLayer(nn.Module):
    """
    Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
    checkpoint for all other args.
    """
    def __init__(self, wrap):
        super().__init__()
        self.wrap = wrap

    def forward(self, x, *args, **kwargs):
        for k, v in kwargs.items():
            assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
        partial = functools.partial(self.wrap, **kwargs)
        return partial(x, *args)


class CheckpointedXTransformerEncoder(nn.Module):
    """
    Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
    to channels-last that XTransformer expects.
    """
    def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs):
        super().__init__()
        self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
        self.needs_permute = needs_permute
        self.exit_permute = exit_permute

        if not checkpoint:
            return
        for i in range(len(self.transformer.attn_layers.layers)):
            n, b, r = self.transformer.attn_layers.layers[i]
            self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])

    def forward(self, x, **kwargs):
        if self.needs_permute:
            x = x.permute(0,2,1)
        h = self.transformer(x, **kwargs)
        if self.exit_permute:
            h = h.permute(0,2,1)
        return h
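Since arch_util.py is a new file, a quick shape check may help orient readers; the sketch assumes the modules are importable at the paths this commit adds:

import torch
from cosyvoice.transformer.arch_util import AttentionBlock, AudioMiniEncoder

x = torch.randn(2, 512, 100)              # [B, channels, T]
block = AttentionBlock(512, num_heads=8)
print(block(x).shape)                     # [2, 512, 100]: residual, shape-preserving

enc = AudioMiniEncoder(spec_dim=80, embedding_dim=192)
mel = torch.randn(2, 80, 400)             # [B, mel_dim, T]; T shrinks 4x over depth=2
print(enc(mel).shape)                     # [2, 192]: pooled at position 0, as in the speaker encoder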
speech/cosyvoice/transformer/xtransformers.py
ADDED

@@ -0,0 +1,1248 @@
import math
from collections import namedtuple
from functools import partial
from inspect import isfunction

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn, einsum

DEFAULT_DIM_HEAD = 64

Intermediates = namedtuple('Intermediates', [
    'pre_softmax_attn',
    'post_softmax_attn'
])

LayerIntermediates = namedtuple('Intermediates', [
    'hiddens',
    'attn_intermediates',
    'past_key_values',
])


# helpers

def exists(val):
    return val is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def cast_tuple(val, depth):
    return val if isinstance(val, tuple) else (val,) * depth


class always():
    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        return self.val


class not_equals():
    def __init__(self, val):
        self.val = val

    def __call__(self, x, *args, **kwargs):
        return x != self.val


class equals():
    def __init__(self, val):
        self.val = val

    def __call__(self, x, *args, **kwargs):
        return x == self.val


def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max


def l2norm(t):
    return F.normalize(t, p=2, dim=-1)


# init helpers

def init_zero_(layer):
    nn.init.constant_(layer.weight, 0.)
    if exists(layer.bias):
        nn.init.constant_(layer.bias, 0.)


# keyword argument helpers

def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))


def group_dict_by_key(cond, d):
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)


def string_begins_with(prefix, str):
    return str.startswith(prefix)


def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)


def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs


# activations

class ReluSquared(nn.Module):
    def forward(self, x):
        return F.relu(x) ** 2


# positional embeddings

class AbsolutePositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.scale = dim ** -0.5
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x):
        n = torch.arange(x.shape[1], device=x.device)
        pos_emb = self.emb(n)
        pos_emb = rearrange(pos_emb, 'n d -> () n d')
        return pos_emb * self.scale


class FixedPositionalEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x, seq_dim=1, offset=0):
        t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
        sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        return rearrange(emb, 'n d -> () n d')


class RelativePositionBias(nn.Module):
    def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
        ret = 0
        n = -relative_position
        if not causal:
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, qk_dots):
        i, j, device = *qk_dots.shape[-2:], qk_dots.device
        q_pos = torch.arange(i, dtype=torch.long, device=device)
        k_pos = torch.arange(j, dtype=torch.long, device=device)
        rel_pos = k_pos[None, :] - q_pos[:, None]
        rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
                                                   max_distance=self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, 'i j h -> () h i j')
        return qk_dots + (bias * self.scale)


class AlibiPositionalBias(nn.Module):
    def __init__(self, heads, **kwargs):
        super().__init__()
        self.heads = heads
        slopes = torch.Tensor(self._get_slopes(heads))
        slopes = rearrange(slopes, 'h -> () h () ()')
        self.register_buffer('slopes', slopes, persistent=False)
        self.register_buffer('bias', None, persistent=False)

    @staticmethod
    def _get_slopes(heads):
        def get_slopes_power_of_2(n):
            start = (2 ** (-2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
                                                           :heads - closest_power_of_2]

    def forward(self, qk_dots):
        h, i, j, device = *qk_dots.shape[-3:], qk_dots.device

        if exists(self.bias) and self.bias.shape[-1] >= j:
            return qk_dots + self.bias[..., :j]

        bias = torch.arange(j, device=device)
        bias = rearrange(bias, 'j -> () () () j')
        bias = bias * self.slopes

        num_heads_unalibied = h - bias.shape[1]
        bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))

        self.register_buffer('bias', bias, persistent=False)
        return qk_dots + self.bias


class LearnedAlibiPositionalBias(AlibiPositionalBias):
    def __init__(self, heads, bidirectional=False):
        super().__init__(heads)
        los_slopes = torch.log(self.slopes)
        self.learned_logslopes = nn.Parameter(los_slopes)

        self.bidirectional = bidirectional
        if self.bidirectional:
            self.learned_logslopes_future = nn.Parameter(los_slopes)

    def forward(self, qk_dots):
+
h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
|
| 241 |
+
|
| 242 |
+
def get_slopes(param):
|
| 243 |
+
return F.pad(param.exp(), (0, 0, 0, 0, 0, h - param.shape[1]))
|
| 244 |
+
|
| 245 |
+
if exists(self.bias) and self.bias.shape[-1] >= j:
|
| 246 |
+
bias = self.bias[..., :i, :j]
|
| 247 |
+
else:
|
| 248 |
+
i_arange = torch.arange(i, device=device)
|
| 249 |
+
j_arange = torch.arange(j, device=device)
|
| 250 |
+
bias = rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1')
|
| 251 |
+
self.register_buffer('bias', bias, persistent=False)
|
| 252 |
+
|
| 253 |
+
if self.bidirectional:
|
| 254 |
+
past_slopes = get_slopes(self.learned_logslopes)
|
| 255 |
+
future_slopes = get_slopes(self.learned_logslopes_future)
|
| 256 |
+
bias = torch.tril(bias * past_slopes) + torch.triu(bias * future_slopes)
|
| 257 |
+
else:
|
| 258 |
+
slopes = get_slopes(self.learned_logslopes)
|
| 259 |
+
bias = bias * slopes
|
| 260 |
+
|
| 261 |
+
return qk_dots + bias
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class RotaryEmbedding(nn.Module):
|
| 265 |
+
def __init__(self, dim):
|
| 266 |
+
super().__init__()
|
| 267 |
+
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
| 268 |
+
self.register_buffer('inv_freq', inv_freq)
|
| 269 |
+
|
| 270 |
+
def forward(self, max_seq_len, device):
|
| 271 |
+
t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq)
|
| 272 |
+
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
|
| 273 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 274 |
+
return rearrange(emb, 'n d -> () () n d')
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def rotate_half(x):
|
| 278 |
+
x = rearrange(x, '... (j d) -> ... j d', j=2)
|
| 279 |
+
x1, x2 = x.unbind(dim=-2)
|
| 280 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def apply_rotary_pos_emb(t, freqs):
|
| 284 |
+
seq_len = t.shape[-2]
|
| 285 |
+
freqs = freqs[:, :, -seq_len:]
|
| 286 |
+
return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
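# A minimal, hypothetical smoke test of the rotary helpers above; the shapes
# are invented for illustration. Only the first `dim` channels are rotated and
# the remainder passes through, mirroring the split performed inside
# Attention.forward further below in this file.
if __name__ == '__main__':
    _rot = RotaryEmbedding(dim=32)
    _q = torch.randn(2, 8, 128, 64)                # (batch, heads, seq, dim_head)
    _freqs = _rot(128, _q.device)                  # broadcastable () () n d table
    _ql, _qr = _q[..., :32], _q[..., 32:]
    _q = torch.cat((apply_rotary_pos_emb(_ql, _freqs), _qr), dim=-1)
    assert _q.shape == (2, 8, 128, 64)
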
# norms

class Scale(nn.Module):
    def __init__(self, value, fn):
        super().__init__()
        self.value = value
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        scale_fn = lambda t: t * self.value

        if not isinstance(out, tuple):
            return scale_fn(out)

        return (scale_fn(out[0]), *out[1:])


class Rezero(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        rezero_fn = lambda t: t * self.g

        if not isinstance(out, tuple):
            return rezero_fn(out)

        return (rezero_fn(out[0]), *out[1:])


class ScaleNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g


class RMSScaleShiftNorm(nn.Module):
    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))
        self.scale_shift_process = nn.Linear(dim * 2, dim * 2)

    def forward(self, x, norm_scale_shift_inp):
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        norm = x / norm.clamp(min=self.eps) * self.g

        ss_emb = self.scale_shift_process(norm_scale_shift_inp)
        scale, shift = torch.chunk(ss_emb, 2, dim=1)
        h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return h


# residual and residual gates

class Residual(nn.Module):
    def __init__(self, dim, scale_residual=False):
        super().__init__()
        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None

    def forward(self, x, residual):
        if exists(self.residual_scale):
            residual = residual * self.residual_scale

        return x + residual


class GRUGating(nn.Module):
    def __init__(self, dim, scale_residual=False):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)
        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None

    def forward(self, x, residual):
        if exists(self.residual_scale):
            residual = residual * self.residual_scale

        gated_output = self.gru(
            rearrange(x, 'b n d -> (b n) d'),
            rearrange(residual, 'b n d -> (b n) d')
        )

        return gated_output.reshape_as(x)


# token shifting

def shift(t, amount, mask=None):
    if amount == 0:
        return t

    if exists(mask):
        t = t.masked_fill(~mask[..., None], 0.)

    return F.pad(t, (0, 0, amount, -amount), value=0.)


class ShiftTokens(nn.Module):
    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        mask = kwargs.get('mask', None)
        shifts = self.shifts
        segments = len(shifts)
        feats_per_shift = x.shape[-1] // segments
        splitted = x.split(feats_per_shift, dim=-1)
        segments_to_shift, rest = splitted[:segments], splitted[segments:]
        segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)))
        x = torch.cat((*segments_to_shift, *rest), dim=-1)
        return self.fn(x, **kwargs)


# feedforward

class GLU(nn.Module):
    def __init__(self, dim_in, dim_out, activation):
        super().__init__()
        self.act = activation
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * self.act(gate)


class FeedForward(nn.Module):
    def __init__(
            self,
            dim,
            dim_out=None,
            mult=4,
            glu=False,
            relu_squared=False,
            post_act_ln=False,
            dropout=0.,
            zero_init_output=False
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        activation = ReluSquared() if relu_squared else nn.GELU()

        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            activation
        ) if not glu else GLU(dim, inner_dim, activation)

        self.net = nn.Sequential(
            project_in,
            nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

        # init last linear layer to 0
        if zero_init_output:
            init_zero_(self.net[-1])

    def forward(self, x):
        return self.net(x)


# attention.

class Attention(nn.Module):
    def __init__(
            self,
            dim,
            dim_head=DEFAULT_DIM_HEAD,
            heads=8,
            causal=False,
            talking_heads=False,
            head_scale=False,
            collab_heads=False,
            collab_compression=.3,
            sparse_topk=None,
            use_entmax15=False,
            num_mem_kv=0,
            dropout=0.,
            on_attn=False,
            gate_values=False,
            zero_init_output=False,
            max_attend_past=None,
            qk_norm=False,
            scale_init_value=None,
            rel_pos_bias=False,
            rel_pos_num_buckets=32,
            rel_pos_max_distance=128,
    ):
        super().__init__()
        self.scale = dim_head ** -0.5

        self.heads = heads
        self.causal = causal
        self.max_attend_past = max_attend_past

        qk_dim = v_dim = dim_head * heads

        # collaborative heads
        self.collab_heads = collab_heads
        if self.collab_heads:
            qk_dim = int(collab_compression * qk_dim)
            self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))

        self.to_q = nn.Linear(dim, qk_dim, bias=False)
        self.to_k = nn.Linear(dim, qk_dim, bias=False)
        self.to_v = nn.Linear(dim, v_dim, bias=False)

        self.dropout = nn.Dropout(dropout)

        # add GLU gating for aggregated values, from alphafold2
        self.to_v_gate = None
        if gate_values:
            self.to_v_gate = nn.Linear(dim, v_dim)
            nn.init.constant_(self.to_v_gate.weight, 0)
            nn.init.constant_(self.to_v_gate.bias, 1)

        # cosine sim attention
        self.qk_norm = qk_norm
        if qk_norm:
            scale_init_value = default(scale_init_value,
                                       -3)  # if not provided, initialize as though it were sequence length of 1024
            self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)

        # talking heads
        self.talking_heads = talking_heads
        if talking_heads:
            self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
            self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))

        # head scaling
        self.head_scale = head_scale
        if head_scale:
            self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))

        # explicit topk sparse attention
        self.sparse_topk = sparse_topk

        # entmax
        self.attn_fn = F.softmax

        # add memory key / values
        self.num_mem_kv = num_mem_kv
        if num_mem_kv > 0:
            self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
            self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))

        # attention on attention
        self.attn_on_attn = on_attn
        self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim)

        self.rel_pos_bias = rel_pos_bias
        if rel_pos_bias:
            assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
            self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads,
                                                num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance)

        # init output projection 0
        if zero_init_output:
            init_zero_(self.to_out)

    def forward(
            self,
            x,
            context=None,
            mask=None,
            context_mask=None,
            attn_mask=None,
            sinusoidal_emb=None,
            rotary_pos_emb=None,
            prev_attn=None,
            mem=None,
            layer_past=None,
    ):
        b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
            context)
        kv_input = default(context, x)

        q_input = x
        k_input = kv_input
        v_input = kv_input

        if exists(mem):
            k_input = torch.cat((mem, k_input), dim=-2)
            v_input = torch.cat((mem, v_input), dim=-2)

        if exists(sinusoidal_emb):
            # in shortformer, the query would start at a position offset depending on the past cached memory
            offset = k_input.shape[-2] - q_input.shape[-2]
            q_input = q_input + sinusoidal_emb(q_input, offset=offset)
            k_input = k_input + sinusoidal_emb(k_input)

        q = self.to_q(q_input)
        k = self.to_k(k_input)
        v = self.to_v(v_input)

        if not collab_heads:
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
        else:
            q = einsum('b i d, h d -> b h i d', q, self.collab_mixing)
            k = rearrange(k, 'b n d -> b () n d')
            v = rearrange(v, 'b n (h d) -> b h n d', h=h)

        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat([past_key, k], dim=-2)
            v = torch.cat([past_value, v], dim=-2)
        k_cache = k
        v_cache = v

        if exists(rotary_pos_emb) and not has_context:
            l = rotary_pos_emb.shape[-1]
            (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
            ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
            q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))

        input_mask = None
        if any(map(exists, (mask, context_mask))):
            q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
            k_mask = q_mask if not exists(context) else context_mask
            k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
            q_mask = rearrange(q_mask, 'b i -> b () i ()')
            k_mask = rearrange(k_mask, 'b j -> b () () j')
            input_mask = q_mask * k_mask

        if self.num_mem_kv > 0:
            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
            k = torch.cat((mem_k, k), dim=-2)
            v = torch.cat((mem_v, v), dim=-2)
            if exists(input_mask):
                input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)

        if collab_heads:
            k = k.expand(-1, h, -1, -1)

        if self.qk_norm:
            q, k = map(l2norm, (q, k))
            scale = 1 / (self.scale.exp().clamp(min=1e-2))

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale
        mask_value = max_neg_value(dots)

        if exists(prev_attn):
            dots = dots + prev_attn

        pre_softmax_attn = dots.clone()

        if talking_heads:
            dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()

        if self.rel_pos_bias:
            dots = self.rel_pos(dots)

        if exists(input_mask):
            dots.masked_fill_(~input_mask, mask_value)
            del input_mask

        if exists(attn_mask):
            assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
            if attn_mask.ndim == 2:
                attn_mask = rearrange(attn_mask, 'i j -> () () i j')
            elif attn_mask.ndim == 3:
                attn_mask = rearrange(attn_mask, 'h i j -> () h i j')
            dots.masked_fill_(~attn_mask, mask_value)

        if exists(self.max_attend_past):
            i, j = dots.shape[-2:]
            range_q = torch.arange(j - i, j, device=device)
            range_k = torch.arange(j, device=device)
            dist = rearrange(range_q, 'i -> () () i ()') - rearrange(range_k, 'j -> () () () j')
            mask = dist > self.max_attend_past
            dots.masked_fill_(mask, mask_value)
            del mask

        if self.causal:
            i, j = dots.shape[-2:]
            r = torch.arange(i, device=device)
            mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
            mask = F.pad(mask, (j - i, 0), value=False)
            dots.masked_fill_(mask, mask_value)
            del mask

        if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
            top, _ = dots.topk(self.sparse_topk, dim=-1)
            vk = top[..., -1].unsqueeze(-1).expand_as(dots)
            mask = dots < vk
            dots.masked_fill_(mask, mask_value)
            del mask

        attn = self.attn_fn(dots, dim=-1)
        post_softmax_attn = attn.clone()

        attn = self.dropout(attn)

        if talking_heads:
            attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        if head_scale:
            out = out * self.head_scale_params

        out = rearrange(out, 'b h n d -> b n (h d)')

        if exists(self.to_v_gate):
            gates = self.to_v_gate(x)
            out = out * gates.sigmoid()

        intermediates = Intermediates(
            pre_softmax_attn=pre_softmax_attn,
            post_softmax_attn=post_softmax_attn
        )

        return self.to_out(out), intermediates, k_cache, v_cache

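# A hedged usage sketch of the Attention block above; dimensions are invented
# for illustration. The forward pass returns the projected output together
# with the attention intermediates and the key/value cache.
if __name__ == '__main__':
    _attn = Attention(dim=512, heads=8, causal=True)
    _x = torch.randn(2, 100, 512)
    _mask = torch.ones(2, 100).bool()            # True = attend to this position
    _out, _inter, _k_cache, _v_cache = _attn(_x, mask=_mask)
    assert _out.shape == (2, 100, 512)
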
class AttentionLayers(nn.Module):
    def __init__(
            self,
            dim,
            depth,
            heads=8,
            causal=False,
            cross_attend=False,
            only_cross=False,
            use_scalenorm=False,
            use_rms_scaleshift_norm=False,
            use_rmsnorm=False,
            use_rezero=False,
            alibi_pos_bias=False,
            alibi_num_heads=None,
            alibi_learned=False,
            position_infused_attn=False,
            rotary_pos_emb=False,
            rotary_emb_dim=None,
            custom_layers=None,
            sandwich_coef=None,
            par_ratio=None,
            residual_attn=False,
            cross_residual_attn=False,
            macaron=False,
            pre_norm=True,
            gate_residual=False,
            scale_residual=False,
            shift_tokens=0,
            sandwich_norm=False,
            use_qk_norm_attn=False,
            qk_norm_attn_seq_len=None,
            zero_init_branch_output=False,
            **kwargs
    ):
        super().__init__()
        ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
        attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)

        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)

        self.dim = dim
        self.depth = depth
        self.layers = nn.ModuleList([])
        self.causal = causal

        rel_pos_bias = 'rel_pos_bias' in attn_kwargs
        self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
        self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None

        rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
        self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None

        assert not (
                alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'

        if alibi_pos_bias:
            alibi_num_heads = default(alibi_num_heads, heads)
            assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
            alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
            self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal)
        else:
            self.rel_pos = None

        assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
        self.pre_norm = pre_norm
        self.sandwich_norm = sandwich_norm

        self.residual_attn = residual_attn
        self.cross_residual_attn = cross_residual_attn
        self.cross_attend = cross_attend

        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
        norm_class = RMSNorm if use_rmsnorm else norm_class
        norm_class = RMSScaleShiftNorm if use_rms_scaleshift_norm else norm_class
        norm_fn = partial(norm_class, dim)

        norm_fn = nn.Identity if use_rezero else norm_fn
        branch_fn = Rezero if use_rezero else None

        if cross_attend and not only_cross:
            default_block = ('a', 'c', 'f')
        elif cross_attend and only_cross:
            default_block = ('c', 'f')
        else:
            default_block = ('a', 'f')

        if macaron:
            default_block = ('f',) + default_block

        # qk normalization

        if use_qk_norm_attn:
            attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(
                qk_norm_attn_seq_len) else None
            attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}

        # zero init

        if zero_init_branch_output:
            attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
            ff_kwargs = {**ff_kwargs, 'zero_init_output': True}

        # calculate layer block order

        if exists(custom_layers):
            layer_types = custom_layers
        elif exists(par_ratio):
            par_depth = depth * len(default_block)
            assert 1 < par_ratio <= par_depth, 'par ratio out of range'
            default_block = tuple(filter(not_equals('f'), default_block))
            par_attn = par_depth // par_ratio
            depth_cut = par_depth * 2 // 3  # 2 / 3 attention layer cutoff suggested by PAR paper
            par_width = (depth_cut + depth_cut // par_attn) // par_attn
            assert len(default_block) <= par_width, 'default block is too large for par_ratio'
            par_block = default_block + ('f',) * (par_width - len(default_block))
            par_head = par_block * par_attn
            layer_types = par_head + ('f',) * (par_depth - len(par_head))
        elif exists(sandwich_coef):
            assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
            layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
        else:
            layer_types = default_block * depth

        self.layer_types = layer_types
        self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

        # calculate token shifting

        shift_tokens = cast_tuple(shift_tokens, len(layer_types))

        # iterate and construct layers

        for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
            is_last_layer = ind == (len(self.layer_types) - 1)

            if layer_type == 'a':
                layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
            elif layer_type == 'c':
                layer = Attention(dim, heads=heads, **attn_kwargs)
            elif layer_type == 'f':
                layer = FeedForward(dim, **ff_kwargs)
                layer = layer if not macaron else Scale(0.5, layer)
            else:
                raise Exception(f'invalid layer type {layer_type}')

            if layer_shift_tokens > 0:
                shift_range_upper = layer_shift_tokens + 1
                shift_range_lower = -layer_shift_tokens if not causal else 0
                layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)

            if exists(branch_fn):
                layer = branch_fn(layer)

            residual_fn = GRUGating if gate_residual else Residual
            residual = residual_fn(dim, scale_residual=scale_residual)

            layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c')

            pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None
            post_branch_norm = norm_fn() if sandwich_norm or layer_uses_qk_norm else None
            post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None

            norms = nn.ModuleList([
                pre_branch_norm,
                post_branch_norm,
                post_main_norm
            ])

            self.layers.append(nn.ModuleList([
                norms,
                layer,
                residual
            ]))

    def forward(
            self,
            x,
            context=None,
            full_context=None,  # for passing a list of hidden states from an encoder
            mask=None,
            context_mask=None,
            attn_mask=None,
            mems=None,
            return_hiddens=False,
            norm_scale_shift_inp=None,
            past_key_values=None,
            expected_seq_len=None,
    ):

        assert not (self.cross_attend ^ (exists(context) or exists(
            full_context))), 'context must be passed in if cross_attend is set to True'
        assert context is None or full_context is None, 'only one of full_context or context can be provided'

        hiddens = []
        intermediates = []
        prev_attn = None
        prev_cross_attn = None

        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
        norm_args = {}
        if exists(norm_scale_shift_inp):
            norm_args['norm_scale_shift_inp'] = norm_scale_shift_inp

        rotary_pos_emb = None
        if exists(self.rotary_pos_emb):
            if not self.training and self.causal:
                assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
            elif expected_seq_len is None:
                expected_seq_len = 0
            seq_len = x.shape[1]
            if past_key_values is not None:
                seq_len += past_key_values[0][0].shape[-2]
            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len])
            rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)

        present_key_values = []
        cross_attn_count = 0
        for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
            if layer_type == 'a':
                layer_mem = mems.pop(0) if mems else None

            residual = x

            pre_branch_norm, post_branch_norm, post_main_norm = norm

            if exists(pre_branch_norm):
                x = pre_branch_norm(x, **norm_args)

            if layer_type == 'a' or layer_type == 'c':
                if past_key_values is not None:
                    layer_kv = past_key_values.pop(0)
                    layer_past = tuple(s.to(x.device) for s in layer_kv)
                else:
                    layer_past = None

            if layer_type == 'a':
                out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
                                         prev_attn, layer_mem, layer_past)
            elif layer_type == 'c':
                if exists(full_context):
                    out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None,
                                             None, prev_attn, None, layer_past)
                else:
                    out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
            elif layer_type == 'f':
                out = block(x)

            # cache k/v for attention layers only; parenthesized explicitly,
            # since `and` binds tighter than `or`
            if (layer_type == 'a' or layer_type == 'c') and present_key_values is not None:
                present_key_values.append((k.detach(), v.detach()))

            if exists(post_branch_norm):
                out = post_branch_norm(out, **norm_args)

            x = residual_fn(out, residual)

            if layer_type in ('a', 'c'):
                intermediates.append(inter)

            if layer_type == 'a' and self.residual_attn:
                prev_attn = inter.pre_softmax_attn
            elif layer_type == 'c' and self.cross_residual_attn:
                prev_cross_attn = inter.pre_softmax_attn

            if exists(post_main_norm):
                x = post_main_norm(x, **norm_args)

            if layer_type == 'c':
                cross_attn_count += 1

            if layer_type == 'f':
                hiddens.append(x)

        if return_hiddens:
            intermediates = LayerIntermediates(
                hiddens=hiddens,
                attn_intermediates=intermediates,
                past_key_values=present_key_values
            )

            return x, intermediates

        return x

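# A quick sanity check on the layer ordering logic above: with the default
# block and no cross attention, depth N expands to N ('a', 'f') pairs.
# Sizes here are placeholders.
if __name__ == '__main__':
    _layers = AttentionLayers(dim=512, depth=4, heads=8, causal=True)
    assert _layers.layer_types == ('a', 'f') * 4
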
class Encoder(AttentionLayers):
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on encoder'
        super().__init__(causal=False, **kwargs)


class Decoder(AttentionLayers):
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on decoder'
        super().__init__(causal=True, **kwargs)


class CrossAttender(AttentionLayers):
    def __init__(self, **kwargs):
        super().__init__(cross_attend=True, only_cross=True, **kwargs)


class ViTransformerWrapper(nn.Module):
    def __init__(
            self,
            *,
            image_size,
            patch_size,
            attn_layers,
            num_classes=None,
            dropout=0.,
            emb_dropout=0.
    ):
        super().__init__()
        assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        dim = attn_layers.dim
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)
        self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None

    def forward(
            self,
            img,
            return_embeddings=False
    ):
        p = self.patch_size

        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.attn_layers(x)
        x = self.norm(x)

        if not exists(self.mlp_head) or return_embeddings:
            return x

        return self.mlp_head(x[:, 0])


class TransformerWrapper(nn.Module):
    def __init__(
            self,
            *,
            num_tokens,
            max_seq_len,
            attn_layers,
            emb_dim=None,
            max_mem_len=0.,
            shift_mem_down=0,
            emb_dropout=0.,
            num_memory_tokens=None,
            tie_embedding=False,
            use_pos_emb=True
    ):
        super().__init__()
        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.shift_mem_down = shift_mem_down

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
                use_pos_emb and not attn_layers.has_pos_emb) else always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

        self.init_()

        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()

        # memory tokens (like [cls]) from Memory Transformers paper
        num_memory_tokens = default(num_memory_tokens, 0)
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

    def init_(self):
        nn.init.kaiming_normal_(self.token_emb.weight)

    def forward(
            self,
            x,
            return_embeddings=False,
            mask=None,
            return_hiddens=False,
            return_attn=False,
            mems=None,
            use_cache=False,
            **kwargs
    ):
        b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
        x = self.token_emb(x)
        x = x + self.pos_emb(x)
        x = self.emb_dropout(x)

        x = self.project_emb(x)

        if num_mem > 0:
            mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
            x = torch.cat((mem, x), dim=1)

            # auto-handle masking after appending memory tokens
            if exists(mask):
                mask = F.pad(mask, (num_mem, 0), value=True)

        if self.shift_mem_down and exists(mems):
            mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
            mems = [*mems_r, *mems_l]

        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
        x = self.norm(x)

        mem, x = x[:, :num_mem], x[:, num_mem:]

        out = self.to_logits(x) if not return_embeddings else x

        if return_hiddens:
            hiddens = intermediates.hiddens
            return out, hiddens

        res = [out]
        if return_attn:
            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
            res.append(attn_maps)
        if use_cache:
            res.append(intermediates.past_key_values)

        if len(res) > 1:
            return tuple(res)
        return res[0]


class ContinuousTransformerWrapper(nn.Module):
    def __init__(
            self,
            *,
            max_seq_len,
            attn_layers,
            dim_in=None,
            dim_out=None,
            emb_dim=None,
            emb_dropout=0.,
            use_pos_emb=True
    ):
        super().__init__()
        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

        dim = attn_layers.dim

        self.max_seq_len = max_seq_len

        self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (
                use_pos_emb and not attn_layers.has_pos_emb) else always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()

        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

        self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()

    def forward(
            self,
            x,
            return_embeddings=False,
            mask=None,
            return_attn=False,
            mems=None,
            use_cache=False,
            **kwargs
    ):
        b, n, _, device = *x.shape, x.device

        x = self.project_in(x)
        x = x + self.pos_emb(x)
        x = self.emb_dropout(x)

        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
        x = self.norm(x)

        out = self.project_out(x) if not return_embeddings else x

        res = [out]
        if return_attn:
            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
            res.append(attn_maps)
        if use_cache:
            res.append(intermediates.past_key_values)

        if len(res) > 1:
            return tuple(res)
        return res[0]
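
# End-to-end smoke test of the wrappers above; vocabulary size, depth and
# sequence length are placeholders, not values used elsewhere in this repo.
if __name__ == '__main__':
    _model = TransformerWrapper(
        num_tokens=256,
        max_seq_len=1024,
        attn_layers=Decoder(dim=512, depth=2, heads=8),
    )
    _tokens = torch.randint(0, 256, (2, 128))
    _logits = _model(_tokens)
    assert _logits.shape == (2, 128, 256)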
speech/tools/download_dataset.py
ADDED
@@ -0,0 +1,371 @@
| 1 |
+
import os
|
| 2 |
+
import io
|
| 3 |
+
import time
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
| 6 |
+
from multiprocessing import cpu_count, Queue, Process
|
| 7 |
+
from threading import Lock
|
| 8 |
+
from queue import Empty
|
| 9 |
+
import logging
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from functools import partial
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torchaudio
|
| 16 |
+
import soundfile as sf
|
| 17 |
+
from datasets import load_dataset
|
| 18 |
+
|
| 19 |
+
# Set up logging
|
| 20 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
# Global settings
|
| 24 |
+
CHUNK_SIZE = 100 # Process this many samples before showing progress
|
| 25 |
+
MAX_WORKERS = cpu_count() - 1 # Leave one CPU free
|
| 26 |
+
PREFETCH_SIZE = 50 # How many samples to prefetch
|
| 27 |
+
|
| 28 |
+
def process_single_sample(data, root_save_path):
|
| 29 |
+
"""Process a single sample from the dataset"""
|
| 30 |
+
try:
|
| 31 |
+
metadata = data['json']
|
| 32 |
+
|
| 33 |
+
# Prepare paths
|
| 34 |
+
out_wav_path = os.path.join(root_save_path, metadata['wav'].replace('/mp3', '').replace('.mp3', '.wav'))
|
| 35 |
+
os.makedirs(os.path.dirname(out_wav_path), exist_ok=True)
|
| 36 |
+
out_text_path = out_wav_path.replace('.wav', '.txt')
|
| 37 |
+
# check if the files are existsed
|
| 38 |
+
if os.path.exists(out_wav_path) and os.path.exists(out_text_path):
|
| 39 |
+
return metadata['id'], True, None
|
| 40 |
+
|
| 41 |
+
# Save text
|
| 42 |
+
with open(out_text_path, 'w') as f:
|
| 43 |
+
f.write(metadata['text'])
|
| 44 |
+
|
| 45 |
+
# Decode and save audio
|
| 46 |
+
raw_bytes = data['mp3']._hf_encoded['bytes']
|
| 47 |
+
audio_tensor, sample_rate = torchaudio.load(io.BytesIO(raw_bytes), format='mp3')
|
| 48 |
+
audio_array = audio_tensor.squeeze().numpy()
|
| 49 |
+
sf.write(out_wav_path, audio_array, sample_rate)
|
| 50 |
+
|
| 51 |
+
return metadata['id'], True, None
|
| 52 |
+
except Exception as e:
|
| 53 |
+
return data.get('json', {}).get('id', 'unknown'), False, str(e)
|
| 54 |
+
|
| 55 |
+
def process_batch(batch_data, root_save_path):
|
| 56 |
+
"""Process a batch of samples"""
|
| 57 |
+
results = []
|
| 58 |
+
for data in batch_data:
|
| 59 |
+
result = process_single_sample(data, root_save_path)
|
| 60 |
+
results.append(result)
|
| 61 |
+
return results
|
| 62 |
+
|
| 63 |
+
class ParallelDatasetProcessor:
|
| 64 |
+
"""Main class for parallel processing of the dataset"""
|
| 65 |
+
|
| 66 |
+
def __init__(self, language, root_save_path, num_workers=None):
|
| 67 |
+
self.language = language
|
| 68 |
+
self.root_save_path = root_save_path
|
| 69 |
+
self.num_workers = num_workers or MAX_WORKERS
|
| 70 |
+
self.processed_count = 0
|
| 71 |
+
self.error_count = 0
|
| 72 |
+
self.lock = Lock()
|
| 73 |
+
|
| 74 |
+
def process_with_multiprocessing(self):
|
| 75 |
+
"""Process dataset using multiprocessing (fastest for CPU-bound tasks)"""
|
| 76 |
+
logger.info(f"Starting multiprocessing with {self.num_workers} workers")
|
| 77 |
+
|
| 78 |
+
# Load dataset
|
| 79 |
+
path = f"Emilia/{self.language.upper()}/*.tar"
|
| 80 |
+
dataset = load_dataset(
|
| 81 |
+
"amphion/Emilia-Dataset",
|
| 82 |
+
data_files={self.language: path},
|
| 83 |
+
split=self.language,
|
| 84 |
+
streaming=True
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# Create process pool
|
| 88 |
+
with ProcessPoolExecutor(max_workers=self.num_workers) as executor:
|
| 89 |
+
# Submit jobs in batches
|
| 90 |
+
futures = []
|
| 91 |
+
batch = []
|
| 92 |
+
|
| 93 |
+
# Progress bar
|
| 94 |
+
pbar = tqdm(desc="Processing samples", unit="samples")
|
| 95 |
+
|
| 96 |
+
for data in dataset:
|
| 97 |
+
batch.append(data)
|
| 98 |
+
|
| 99 |
+
if len(batch) >= CHUNK_SIZE:
|
| 100 |
+
# Submit batch for processing
|
| 101 |
+
future = executor.submit(process_batch, batch, self.root_save_path)
|
| 102 |
+
futures.append(future)
|
| 103 |
+
batch = []
|
| 104 |
+
|
| 105 |
+
# Process completed futures
|
| 106 |
+
self._process_completed_futures(futures, pbar, max_pending=self.num_workers * 2)
|
| 107 |
+
|
| 108 |
+
# Submit remaining batch
|
| 109 |
+
if batch:
|
| 110 |
+
future = executor.submit(process_batch, batch, self.root_save_path)
|
| 111 |
+
futures.append(future)
|
| 112 |
+
|
| 113 |
+
# Wait for all remaining futures
|
| 114 |
+
for future in as_completed(futures):
|
| 115 |
+
results = future.result()
|
| 116 |
+
for sample_id, success, error in results:
|
| 117 |
+
if success:
|
| 118 |
+
self.processed_count += 1
|
| 119 |
+
else:
|
| 120 |
+
self.error_count += 1
|
| 121 |
+
logger.error(f"Error processing {sample_id}: {error}")
|
| 122 |
+
pbar.update(1)
|
| 123 |
+
|
| 124 |
+
pbar.close()
|
| 125 |
+
|
| 126 |
+
def process_with_threading(self):
|
| 127 |
+
"""Process dataset using threading (good for I/O-bound tasks)"""
|
| 128 |
+
logger.info(f"Starting threading with {self.num_workers} workers")
|
| 129 |
+
|
| 130 |
+
# Load dataset
|
| 131 |
+
path = f"Emilia/{self.language.upper()}/*.tar"
|
| 132 |
+
dataset = load_dataset(
|
| 133 |
+
"amphion/Emilia-Dataset",
|
| 134 |
+
data_files={self.language: path},
|
| 135 |
+
split=self.language,
|
| 136 |
+
streaming=True
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
# Create thread pool
|
| 140 |
+
with ThreadPoolExecutor(max_workers=self.num_workers * 2) as executor:
|
| 141 |
+
futures = []
|
| 142 |
+
pbar = tqdm(desc="Processing samples", unit="samples")
|
| 143 |
+
|
| 144 |
+
for data in dataset:
|
| 145 |
+
# Submit individual samples
|
| 146 |
+
future = executor.submit(process_single_sample, data, self.root_save_path)
|
| 147 |
+
futures.append((future, data.get('json', {}).get('id', 'unknown')))
|
| 148 |
+
|
| 149 |
+
# Process completed futures
|
| 150 |
+
if len(futures) >= self.num_workers * 4:
|
| 151 |
+
completed = []
|
| 152 |
+
for i, (future, sample_id) in enumerate(futures):
|
| 153 |
+
if future.done():
|
| 154 |
+
try:
|
| 155 |
+
_, success, error = future.result()
|
| 156 |
+
if success:
|
| 157 |
+
self.processed_count += 1
|
| 158 |
+
else:
|
| 159 |
+
self.error_count += 1
|
| 160 |
+
logger.error(f"Error processing {sample_id}: {error}")
|
| 161 |
+
pbar.update(1)
|
| 162 |
+
completed.append(i)
|
| 163 |
+
except Exception as e:
|
| 164 |
+
logger.error(f"Exception processing {sample_id}: {e}")
|
| 165 |
+
self.error_count += 1
|
| 166 |
+
pbar.update(1)
|
| 167 |
+
completed.append(i)
|
| 168 |
+
|
| 169 |
+
# Remove completed futures
|
| 170 |
+
for i in reversed(completed):
|
| 171 |
+
futures.pop(i)
|
| 172 |
+
|
| 173 |
+
# Wait for remaining futures
|
| 174 |
+
for future, sample_id in futures:
|
| 175 |
+
try:
|
| 176 |
+
_, success, error = future.result()
|
| 177 |
+
if success:
|
| 178 |
+
self.processed_count += 1
|
| 179 |
+
else:
|
| 180 |
+
self.error_count += 1
|
| 181 |
+
logger.error(f"Error processing {sample_id}: {error}")
|
| 182 |
+
pbar.update(1)
|
| 183 |
+
except Exception as e:
|
| 184 |
+
logger.error(f"Exception processing {sample_id}: {e}")
|
| 185 |
+
self.error_count += 1
|
| 186 |
+
pbar.update(1)
|
| 187 |
+
|
| 188 |
+
pbar.close()
|
| 189 |
+
|
| 190 |
+
def process_with_producer_consumer(self):
|
| 191 |
+
"""Process dataset using producer-consumer pattern"""
|
| 192 |
+
logger.info(f"Starting producer-consumer with {self.num_workers} workers")
|
| 193 |
+
|
| 194 |
+
# Create queues
|
| 195 |
+
work_queue = Queue(maxsize=PREFETCH_SIZE)
|
| 196 |
+
result_queue = Queue()
|
| 197 |
+
|
| 198 |
+
# Start producer
|
| 199 |
+
producer = Process(target=self._producer, args=(work_queue,))
|
| 200 |
+
producer.start()
|
| 201 |
+
|
| 202 |
+
# Start consumers
|
| 203 |
+
consumers = []
|
| 204 |
+
for i in range(self.num_workers):
|
| 205 |
+
consumer = Process(target=self._consumer, args=(work_queue, result_queue, i))
|
| 206 |
+
consumer.start()
|
| 207 |
+
consumers.append(consumer)
|
| 208 |
+
|
| 209 |
+
# Start result processor
|
| 210 |
+
        result_processor = Process(target=self._result_processor, args=(result_queue,))
        result_processor.start()

        # Wait for completion
        producer.join()

        # Signal consumers to stop
        for _ in range(self.num_workers):
            work_queue.put(None)

        for consumer in consumers:
            consumer.join()

        # Signal result processor to stop
        result_queue.put(None)
        result_processor.join()

    def _producer(self, work_queue):
        """Producer process that reads from dataset"""
        path = f"Emilia/{self.language.upper()}/*.tar"
        dataset = load_dataset(
            "amphion/Emilia-Dataset",
            data_files={self.language: path},
            split=self.language,
            streaming=True
        )

        for data in dataset:
            work_queue.put(data)

    def _consumer(self, work_queue, result_queue, worker_id):
        """Consumer process that processes samples"""
        while True:
            try:
                data = work_queue.get(timeout=1)
                if data is None:
                    break

                result = process_single_sample(data, self.root_save_path)
                result_queue.put(result)
            except Empty:
                continue
            except Exception as e:
                logger.error(f"Worker {worker_id} error: {e}")

    def _result_processor(self, result_queue):
        """Process results and update progress"""
        pbar = tqdm(desc="Processing samples", unit="samples")

        while True:
            try:
                result = result_queue.get(timeout=1)
                if result is None:
                    break

                sample_id, success, error = result
                if success:
                    self.processed_count += 1
                else:
                    self.error_count += 1
                    logger.error(f"Error processing {sample_id}: {error}")
                pbar.update(1)
            except Empty:
                continue

        pbar.close()

    def _process_completed_futures(self, futures, pbar, max_pending):
        """Process completed futures to avoid memory buildup"""
        while len(futures) > max_pending:
            # Wait for at least one to complete
            completed_futures = []
            for i, future in enumerate(futures):
                if future.done():
                    results = future.result()
                    for sample_id, success, error in results:
                        if success:
                            self.processed_count += 1
                        else:
                            self.error_count += 1
                            logger.error(f"Error processing {sample_id}: {error}")
                        pbar.update(1)
                    completed_futures.append(i)

            # Remove completed futures
            for i in reversed(completed_futures):
                futures.pop(i)

            if not completed_futures:
                # If nothing completed, wait a bit
                time.sleep(0.1)


def main():
    """Main function with different processing options"""
    language = "en"
    root_save_path = f'/data/emilia/{language}'
    os.makedirs(root_save_path, exist_ok=True)

    # Choose processing method
    processor = ParallelDatasetProcessor(language, root_save_path)

    # Method 1: Multiprocessing (recommended for CPU-bound tasks)
    logger.info("Starting parallel processing...")
    start_time = time.time()

    # You can choose one of these methods:
    processor.process_with_multiprocessing()  # Fastest for this use case
    # processor.process_with_threading()  # Good for I/O heavy tasks
    # processor.process_with_producer_consumer()  # Good for memory efficiency

    elapsed_time = time.time() - start_time
    logger.info(f"Processing completed in {elapsed_time:.2f} seconds")
    logger.info(f"Successfully processed: {processor.processed_count} samples")
    logger.info(f"Errors: {processor.error_count} samples")


# Simplified version for quick use
def simple_parallel_process():
    """Simplified parallel processing function"""
    from joblib import Parallel, delayed

    language = "en"
    root_save_path = f'/data/emilia/{language}'
    os.makedirs(root_save_path, exist_ok=True)

    # Load dataset
    path = f"Emilia/{language.upper()}/*.tar"
    dataset = load_dataset(
        "amphion/Emilia-Dataset",
        data_files={language: path},
        split=language,
        streaming=True
    )

    # Collect samples in batches
    batch_size = 1000
    batch = []

    def process_batch_simple(batch_data):
        for data in batch_data:
            process_single_sample(data, root_save_path)

    # Process in parallel batches
    for i, data in enumerate(tqdm(dataset, desc="Loading samples")):
        batch.append(data)

        if len(batch) >= batch_size:
            # Process batch in parallel
            Parallel(n_jobs=-1, backend='multiprocessing')(
                delayed(process_single_sample)(sample, root_save_path)
                for sample in batch
            )
            batch = []

    # Process remaining batch
    if batch:
        Parallel(n_jobs=-1, backend='multiprocessing')(
            delayed(process_single_sample)(sample, root_save_path)
            for sample in batch
        )


if __name__ == "__main__":
    main()
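Note (illustrative sketch, not part of the committed file): before launching the full parallel job above, it can help to run process_single_sample from this script on a handful of streamed Emilia items. The sketch assumes, as the result-processor code above does, that process_single_sample returns (sample_id, success, error) tuples; the helper name smoke_test and the /tmp output path are made up for the example.

    import itertools
    import os

    from datasets import load_dataset

    def smoke_test(language="en", root_save_path="/tmp/emilia_smoke", n=5):
        os.makedirs(root_save_path, exist_ok=True)
        path = f"Emilia/{language.upper()}/*.tar"
        dataset = load_dataset(
            "amphion/Emilia-Dataset",
            data_files={language: path},
            split=language,
            streaming=True,
        )
        # Pull only the first n items from the stream and process them serially
        for data in itertools.islice(dataset, n):
            sample_id, success, error = process_single_sample(data, root_save_path)
            print(sample_id, "ok" if success else f"failed: {error}")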
speech/tools/download_vivoice.py
ADDED
@@ -0,0 +1,210 @@
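Note (illustrative sketch, not part of the committed file): the script below reads raw audio bytes through the private _hf_encoded attribute of the decoded audio object, an internal detail of the datasets library that may change between versions. A more conventional route, assuming the audio column can be cast to Audio(decode=False) so each item exposes the raw bytes and path directly, would look like:

    from datasets import Audio, load_dataset

    ds = load_dataset("capleaf/viVoice", streaming=True)
    # Disable decoding so each item's audio column carries {'bytes': ..., 'path': ...}
    ds = {split: d.cast_column("audio", Audio(decode=False)) for split, d in ds.items()}

    for data in next(iter(ds.values())):
        raw_bytes = data["audio"]["bytes"]   # raw encoded audio bytes
        audio_name = data["audio"]["path"]   # original file name inside the archive
        break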
import os
import io
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from functools import partial
import multiprocessing
from typing import Dict, Any, Optional

from datasets import load_dataset
import soundfile as sf
import torchaudio
from tqdm import tqdm


def process_single_item(data: Dict[str, Any], root_save_path: str) -> Optional[str]:
    """
    Process a single audio item: extract audio, convert, and save with text.

    Args:
        data: Dictionary containing audio data and text
        root_save_path: Root directory for saving files

    Returns:
        Audio file path if successful, None if error
    """
    try:
        # Extract audio data
        raw_bytes = data['audio']._hf_encoded['bytes']
        audio_name = data['audio']._hf_encoded['path']

        # Prepare output paths
        out_wav_path = os.path.join(root_save_path, audio_name)
        out_text_path = out_wav_path.replace('.wav', '.txt')

        if os.path.exists(out_wav_path) and os.path.exists(out_text_path):
            # print(f'skip {out_wav_path}')
            return out_wav_path

        text = data['text']

        # Load audio from bytes
        bytes_io = io.BytesIO(raw_bytes)
        audio_tensor, sample_rate = torchaudio.load(bytes_io)
        audio_array = audio_tensor.squeeze().numpy()

        # Create directory if needed
        os.makedirs(os.path.dirname(out_wav_path), exist_ok=True)

        # Save audio and text
        sf.write(out_wav_path, audio_array, sample_rate)
        with open(out_text_path, 'w', encoding='utf-8') as f:
            f.write(text)

        return out_wav_path

    except Exception as e:
        print(f"Error processing {data.get('audio', {})._hf_encoded.get('path', 'unknown')}: {e}")
        return None


def process_dataset_parallel(
    dataset_name: str = "capleaf/viVoice",
    root_save_path: str = '/mnt/nvme-temp/vivoice',
    max_workers: Optional[int] = None,
    batch_size: int = 100,
    limit: Optional[int] = None,
    use_threads: bool = False
):
    """
    Process dataset in parallel with progress tracking.

    Args:
        dataset_name: Name of the HuggingFace dataset
        root_save_path: Root directory for saving files
        max_workers: Maximum number of parallel workers (None for CPU count)
        batch_size: Number of items to process in each batch
        limit: Maximum number of items to process (None for all)
        use_threads: Use ThreadPoolExecutor instead of ProcessPoolExecutor
    """
    # Load dataset
    ds = load_dataset(dataset_name, streaming=True)

    # Set up executor
    if max_workers is None:
        max_workers = multiprocessing.cpu_count()

    Executor = ThreadPoolExecutor if use_threads else ProcessPoolExecutor

    # Process each split
    for mode in ds:
        print(f"\nProcessing '{mode}' split...")

        # Create partial function with root_save_path
        process_func = partial(process_single_item, root_save_path=root_save_path)

        # Collect items in batches for better progress tracking
        batch = []
        processed_count = 0

        with Executor(max_workers=max_workers) as executor:
            # Create progress bar
            pbar = tqdm(desc=f"Processing {mode}", unit="files")

            for idx, data in enumerate(ds[mode]):
                batch.append(data)

                # Process batch when full or at limit
                if len(batch) >= batch_size or (limit and idx + 1 >= limit):
                    # Submit batch for processing
                    futures = [executor.submit(process_func, item) for item in batch]

                    # Wait for completion and update progress
                    for future in futures:
                        result = future.result()
                        if result:
                            processed_count += 1
                        pbar.update(1)

                    batch = []

                    # Stop if limit reached
                    if limit and idx + 1 >= limit:
                        break

            # Process remaining items
            if batch:
                futures = [executor.submit(process_func, item) for item in batch]
                for future in futures:
                    result = future.result()
                    if result:
                        processed_count += 1
                    pbar.update(1)

            pbar.close()

        print(f"Completed {mode}: {processed_count} files processed successfully")


def process_dataset_streaming(
    dataset_name: str = "capleaf/viVoice",
    root_save_path: str = '/data/vivoice',
    max_workers: Optional[int] = None,
    limit: Optional[int] = None
):
    """
    Alternative approach using streaming with thread pool for I/O bound operations.
    More memory efficient for very large datasets.
    """
    ds = load_dataset(dataset_name, streaming=True)

    if max_workers is None:
        max_workers = min(32, multiprocessing.cpu_count() * 4)  # More threads for I/O

    for mode in ds:
        print(f"\nProcessing '{mode}' split...")

        process_func = partial(process_single_item, root_save_path=root_save_path)

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit items as they come from the stream
            futures = []

            for idx, data in enumerate(tqdm(ds[mode], desc=f"Submitting {mode}")):
                future = executor.submit(process_func, data)
                futures.append(future)

                # Limit number of pending futures to control memory usage
                if len(futures) >= max_workers * 2:
                    # Wait for some to complete
                    for f in futures[:max_workers]:
                        f.result()
                    futures = futures[max_workers:]

                if limit and idx + 1 >= limit:
                    break

            # Wait for remaining futures
            for future in tqdm(futures, desc="Finishing"):
                future.result()


if __name__ == "__main__":
    # Example usage - choose one approach:

    # Approach 1: Process with multiprocessing (good for CPU-bound operations)
    process_dataset_parallel(
        dataset_name="capleaf/viVoice",
        root_save_path='/data/vivoice',
        max_workers=24,  # Adjust based on your system
        batch_size=100,
        limit=None,  # Remove to process all data
        use_threads=False
    )

    # Approach 2: Process with threading (good for I/O-bound operations)
    # process_dataset_parallel(
    #     dataset_name="capleaf/viVoice",
    #     root_save_path='/mnt/nvme-temp/vivoice',
    #     max_workers=16,  # Can use more threads for I/O
    #     batch_size=100,
    #     limit=None,
    #     use_threads=True
    # )

    # Approach 3: Streaming approach (most memory efficient)
    # process_dataset_streaming(
    #     dataset_name="capleaf/viVoice",
    #     root_save_path='/mnt/nvme-temp/vivoice',
    #     max_workers=16,
    #     limit=None