primepake committed commit 7940474 (1 parent: 0108756)

correct speaker encoder

speech/config.yaml CHANGED
@@ -12,12 +12,12 @@ spk_embed_dim: 192
 qwen_pretrain_path: ''
 token_frame_rate: 25
 token_mel_ratio: 2
-
+use_speaker_encoder: True
+speaker_encoder_path: '/data/checkpoint/llm/epoch_29_step_20001.pt'
 # stream related params
 chunk_size: 25 # streaming inference chunk size, in token
 num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
 
-use_speaker_encoder: True # New flag
 speaker_encoder_config:
   mel_dim: 80
   model_dim: 512
@@ -51,8 +51,8 @@ llm: !new:cosyvoice.llm.llm.Qwen2LM
 extract_reference_mel: !name:cosyvoice.dataset.processor.extract_reference_mel_from_speech
   feat_extractor: !ref <feat_extractor>
   min_length: 0.5
-  max_length: 6.0
-  num_crops: 2 # Multiple crops from same utterance
+  max_length: 12.0
+  num_crops: 3 # Multiple crops from same utterance
   training: True
   sample_rate: !ref <sample_rate>
 
@@ -66,6 +66,9 @@ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
   only_mask_loss: True
   token_mel_ratio: !ref <token_mel_ratio>
   pre_lookahead_len: 3
+  use_speaker_encoder: !ref <use_speaker_encoder> # Add this
+  freeze_speaker_encoder: True # Freeze by default for flow training
+  speaker_encoder_path: !ref <speaker_encoder_path> # Add this
   encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
     output_size: 512
     attention_heads: 8
@@ -198,7 +201,7 @@ sort: !name:cosyvoice.dataset.processor.sort
   sort_size: 500 # sort_size should be less than shuffle_size
 batch: !name:cosyvoice.dataset.processor.batch
   batch_type: 'dynamic'
-  max_frames_in_batch: 30000
+  max_frames_in_batch: 5000
 padding: !name:cosyvoice.dataset.processor.padding
   use_spk_embedding: False # change to True during sft
   use_speaker_encoder: !ref <use_speaker_encoder>
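
Note: a rough sketch of what the min_length / max_length / num_crops settings above control when reference mels are cropped for the speaker encoder. The helper name and the 50 Hz mel frame rate are assumptions for illustration only; this is not the repository's extract_reference_mel_from_speech implementation.

    import torch

    def crop_reference_mels(mel, num_crops=3, min_length=0.5, max_length=12.0, frame_rate=50):
        """Illustrative stand-in: draw num_crops random crops whose durations fall
        between min_length and max_length seconds (shorter inputs are used whole)."""
        total = mel.shape[0]                      # mel: [T, 80]
        lo = int(min_length * frame_rate)
        hi = min(int(max_length * frame_rate), total)
        crops = []
        for _ in range(num_crops):
            if total <= lo:
                crops.append(mel)                 # too short to crop, use it whole
                continue
            length = torch.randint(lo, hi + 1, (1,)).item()
            start = torch.randint(0, total - length + 1, (1,)).item()
            crops.append(mel[start:start + length])
        return crops

    # e.g. three crops between 0.5 s and 12 s from a 20 s reference at 50 frames/s
    crops = crop_reference_mels(torch.randn(1000, 80))
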
speech/cosyvoice/dataset/processor.py CHANGED
@@ -103,7 +103,7 @@ def individual_file_opener(data, mode='train', tts_data={}):
 
             # Load embeddings if they exist
             if os.path.exists(file_info['embedding_path']):
-                utt_embedding = torch.load(file_info['embedding_path'])
+                utt_embedding = torch.load(file_info['embedding_path'], weights_only=False)
                 if isinstance(utt_embedding, torch.Tensor):
                     utt_embedding = utt_embedding.tolist()
             else:
@@ -113,7 +113,7 @@ def individual_file_opener(data, mode='train', tts_data={}):
 
             # Load tokens if they exist
             if os.path.exists(file_info['token_path']):
-                speech_token = torch.load(file_info['token_path'])
+                speech_token = torch.load(file_info['token_path'], weights_only=False)
                 if isinstance(speech_token, torch.Tensor):
                     speech_token = speech_token.tolist()
             else:
@@ -122,7 +122,7 @@ def individual_file_opener(data, mode='train', tts_data={}):
 
             # Load speaker embedding
             if os.path.exists(file_info['spk_embedding_path']):
-                spk_embedding = torch.load(file_info['spk_embedding_path'])
+                spk_embedding = torch.load(file_info['spk_embedding_path'], weights_only=False)
                 if isinstance(spk_embedding, torch.Tensor):
                     spk_embedding = spk_embedding.tolist()
             else:
@@ -686,17 +686,6 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False, use_spe
                 batch['reference_mels'] = padded_reference_mels
                 batch['reference_mel_lengths'] = padded_reference_mel_lengths
                 batch['reference_mel_masks'] = reference_mel_masks
-            else:
-                # No valid mels, create dummy
-                batch['reference_mels'] = torch.zeros(batch_size, 1, mel_dim, 100)
-                batch['reference_mel_lengths'] = torch.zeros(batch_size, 1, dtype=torch.int32)
-                batch['reference_mel_masks'] = torch.zeros(batch_size, 1, 100)
-        else:
-            # No references available, create dummy for batch
-            batch_size = len(order)
-            batch['reference_mels'] = torch.zeros(batch_size, 1, 80, 100)
-            batch['reference_mel_lengths'] = torch.zeros(batch_size, 1, dtype=torch.int32)
-            batch['reference_mel_masks'] = torch.zeros(batch_size, 1, 100)
 
     if gan is True:
         # in gan train, we need pitch_feat
@@ -726,9 +715,4 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False, use_spe
         batch['reject_speech_token'] = reject_speech_token
         batch['reject_speech_token_len'] = reject_speech_token_len
 
-    if use_spk_embedding is True:
-        batch["embedding"] = batch["spk_embedding"]
-    else:
-        batch["embedding"] = batch["utt_embedding"]
-
     yield batch
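
Note: the three torch.load calls above now pass weights_only=False. Recent PyTorch releases default torch.load to weights_only=True, which only unpickles tensors and a small allowlist of types, so embedding/token files saved as arbitrary pickled objects can refuse to load. Below is a minimal sketch of the loading pattern (the file name is hypothetical); only disable weights_only for files you generated yourself.

    import torch

    # Hypothetical embedding file produced by a preprocessing script.
    torch.save(torch.randn(192), 'utt_embedding.pt')

    # weights_only=False keeps the legacy pickle behaviour of torch.load.
    utt_embedding = torch.load('utt_embedding.pt', weights_only=False)
    if isinstance(utt_embedding, torch.Tensor):
        utt_embedding = utt_embedding.tolist()   # same normalization as the opener above
    print(len(utt_embedding))                    # 192
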
speech/cosyvoice/flow/flow.py CHANGED
@@ -210,6 +210,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
                  only_mask_loss: bool = True,
                  token_mel_ratio: int = 2,
                  pre_lookahead_len: int = 3,
+                 use_speaker_encoder: bool = False,  # Add this
+                 freeze_speaker_encoder: bool = False,  # Add this
+                 max_conditioning_inputs: int = 2,  # Add this
+                 speaker_encoder_path: str = None,
                  encoder: torch.nn.Module = None,
                  decoder: torch.nn.Module = None,
                  decoder_conf: Dict = {
@@ -261,6 +265,60 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         self.input_frame_rate = input_frame_rate
         logging.info(f"input frame rate={self.input_frame_rate}")
         self.input_embedding = nn.Embedding(vocab_size, input_size)
+        self.use_speaker_encoder = use_speaker_encoder
+        # Speaker encoder setup
+        if use_speaker_encoder:
+            from cosyvoice.llm.llm import LearnableSpeakerEncoder
+            self.speaker_encoder = LearnableSpeakerEncoder(
+                mel_dim=80,
+                model_dim=512,
+                output_dim=spk_embed_dim,
+                num_blocks=6,
+                num_heads=8,
+            )
+
+            # Load speaker encoder weights from LLM checkpoint if provided
+            if speaker_encoder_path is not None:
+                logging.info(f"Loading speaker encoder from {speaker_encoder_path}")
+                checkpoint = torch.load(speaker_encoder_path, map_location='cpu')
+
+                # Debug: print checkpoint structure
+                print(f'Checkpoint keys: {checkpoint.keys()}')
+
+                # Extract speaker encoder weights
+                speaker_encoder_state = {}
+
+                # Check if checkpoint has 'state_dict' key or direct model weights
+                if 'state_dict' in checkpoint:
+                    state_dict = checkpoint['state_dict']
+                else:
+                    # Direct model weights (based on your save function)
+                    state_dict = {k: v for k, v in checkpoint.items() if not k in ['epoch', 'step']}
+
+                # Extract speaker encoder weights
+                for key, value in state_dict.items():
+                    if 'speaker_encoder.' in key:
+                        # Remove module. prefix if exists (from DDP)
+                        new_key = key.replace('module.', '')
+                        # Remove speaker_encoder. prefix to match the local module
+                        new_key = new_key.replace('speaker_encoder.', '')
+                        speaker_encoder_state[new_key] = value
+
+                if len(speaker_encoder_state) == 0:
+                    logging.warning("No speaker encoder weights found in checkpoint!")
+                    logging.warning(f"Available keys: {list(state_dict.keys())[:10]}...")  # Show first 10 keys
+                else:
+                    logging.info(f"Found {len(speaker_encoder_state)} speaker encoder weights")
+                    # Load the weights
+                    self.speaker_encoder.load_state_dict(speaker_encoder_state, strict=True)
+                    logging.info("Speaker encoder loaded successfully")
+            self.freeze_speaker_encoder = freeze_speaker_encoder
+            if freeze_speaker_encoder:
+                # Freeze speaker encoder parameters
+                for param in self.speaker_encoder.parameters():
+                    param.requires_grad = False
+                logging.info("Speaker encoder frozen in flow matching")
+
         self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
         self.encoder = encoder
         self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
@@ -271,6 +329,55 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         print(" decoder_conf['cfm_params']: ", decoder_conf["cfm_params"])
         self.use_contrastive_fm = decoder_conf["cfm_params"]["use_contrastive_fm"]
 
+    def get_speaker_embedding(self, batch, device):
+        """Extract speaker embedding from reference mels or use provided embeddings"""
+
+        if self.use_speaker_encoder and 'reference_mels' in batch:
+            reference_mels = batch['reference_mels'].to(device)
+
+            # Handle multiple references
+            if reference_mels.dim() == 4:  # [B, N, C, T]
+                B, N, C, T = reference_mels.shape
+                embeddings = []
+
+                for i in range(N):
+                    ref_mel = reference_mels[:, i, :, :]  # [B, C, T]
+                    if 'reference_mel_masks' in batch:
+                        mask = batch['reference_mel_masks'][:, i, :].unsqueeze(1).to(device)
+                    else:
+                        mask = None
+                    print('ref_mel mask: ', ref_mel.shape, mask.shape)
+                    # Apply speaker encoder
+                    with torch.set_grad_enabled(not self.freeze_speaker_encoder):
+                        emb = self.speaker_encoder(ref_mel, mask)  # [B, spk_embed_dim]
+                    embeddings.append(emb)
+
+                # Average multiple references
+                embedding = torch.stack(embeddings, dim=1).mean(dim=1)  # [B, spk_embed_dim]
+
+            else:  # Single reference [B, C, T]
+                if 'reference_mel_mask' in batch:
+                    mask = batch['reference_mel_mask'].unsqueeze(1).to(device)
+                else:
+                    mask = None
+
+                with torch.set_grad_enabled(not self.freeze_speaker_encoder):
+                    embedding = self.speaker_encoder(reference_mels, mask)
+
+            # Normalize (already normalized in speaker encoder, but just to be safe)
+            embedding = F.normalize(embedding, dim=1)
+
+        elif 'embedding' in batch:
+            # Use provided embeddings (backward compatibility)
+            embedding = batch['embedding'].to(device)
+            embedding = F.normalize(embedding, dim=1)
+        else:
+            # No speaker conditioning
+            B = batch['speech_token'].shape[0]
+            embedding = torch.zeros(B, self.spk_embed_dim).to(device)
+
+        return embedding
+
     def forward(
             self,
             batch: dict,
@@ -280,10 +387,11 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         token_len = batch["speech_token_len"].to(device)
         feat = batch["speech_feat"].to(device)
         feat_len = batch["speech_feat_len"].to(device)
-        embedding = batch["embedding"].to(device)
 
         # NOTE unified training, static_chunk_size > 0 or = 0
         streaming = False # if random.random() < 0.5 else False
+        print("get speaker embedding")
+        embedding = self.get_speaker_embedding(batch, device)
 
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -335,13 +443,29 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
                   prompt_token_len,
                   prompt_feat,
                   prompt_feat_len,
-                  embedding,
-                  streaming,
-                  finalize,
+                  embedding=None,
+                  reference_mels=None,
+                  reference_mel_lengths=None,
+                  reference_mel_masks=None,
+                  streaming=False,
+                  finalize=False,
                   ):
         assert token.shape[0] == 1
+
+        # Get speaker embedding
+        if self.use_speaker_encoder and reference_mels is not None:
+            batch = {
+                'reference_mels': reference_mels,
+                'reference_mel_lengths': reference_mel_lengths,
+                'reference_mel_masks': reference_mel_masks
+            }
+            embedding = self.get_speaker_embedding(batch, token.device)
+        elif embedding is not None:
+            embedding = F.normalize(embedding, dim=1)
+        else:
+            embedding = torch.zeros(1, self.spk_embed_dim).to(token.device)
+
         # xvec projection
-        embedding = F.normalize(embedding, dim=1)
         embedding = self.spk_embed_affine_layer(embedding)
 
         # concat text and prompt_text
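
Note: a minimal sketch of the multi-reference path in get_speaker_embedding above: encode each of the N reference crops, average the embeddings, and re-normalize. The ToyEncoder stands in for LearnableSpeakerEncoder so the snippet runs on its own; it is not part of the commit.

    import torch
    import torch.nn.functional as F

    class ToyEncoder(torch.nn.Module):
        """Stand-in speaker encoder: [B, 80, T] -> [B, 192]."""
        def __init__(self, mel_dim=80, out_dim=192):
            super().__init__()
            self.proj = torch.nn.Linear(mel_dim, out_dim)
        def forward(self, mel):
            return self.proj(mel.mean(dim=-1))

    def average_reference_embeddings(reference_mels, speaker_encoder, freeze=True):
        B, N, C, T = reference_mels.shape          # [B, N, 80, T], N crops per utterance
        with torch.set_grad_enabled(not freeze):   # frozen encoder: no gradients
            embeddings = [speaker_encoder(reference_mels[:, i]) for i in range(N)]
        embedding = torch.stack(embeddings, dim=1).mean(dim=1)   # average over crops
        return F.normalize(embedding, dim=1)

    emb = average_reference_embeddings(torch.randn(2, 3, 80, 200), ToyEncoder())
    print(emb.shape)   # torch.Size([2, 192])
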
speech/cosyvoice/llm/llm.py CHANGED
@@ -29,11 +29,11 @@ from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.mask import make_pad_mask
 from cosyvoice.transformer.attention import MultiHeadedAttention
 from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
+from cosyvoice.transformer.arch_util import AttentionBlock
 
 class LearnableSpeakerEncoder(nn.Module):
     """
-    Speaker encoder inspired by Tortoise-TTS for CosyVoice2.
-    Processes mel-spectrograms through attention blocks to extract speaker characteristics.
+    Speaker encoder using the same architecture as Tortoise-TTS ConditioningEncoder.
     """
     def __init__(
             self,
@@ -42,58 +42,60 @@ class LearnableSpeakerEncoder(nn.Module):
             output_dim: int = 192,
             num_blocks: int = 6,
             num_heads: int = 8,
-            kernel_size: int = 1,
             dropout: float = 0.0,
+            mean_pooling: bool = False,  # Tortoise uses first position by default
     ):
         super().__init__()
 
-        # Initial projection (like Tortoise's ConditioningEncoder)
-        self.init = nn.Conv1d(mel_dim, model_dim, kernel_size=kernel_size)
-
-        self.blocks = nn.ModuleList([
-            ConformerEncoderLayer(
-                size=model_dim,
-                self_attn=MultiHeadedAttention(
-                    n_head=num_heads,
-                    n_feat=model_dim,
-                    dropout_rate=dropout,
-                ),
-                feed_forward=None,  # Can add if needed
-                feed_forward_macaron=None,
-                conv_module=None,
-                dropout_rate=dropout,
-                normalize_before=True,
-            ) for _ in range(num_blocks)
-        ])
+        # Same as Tortoise ConditioningEncoder
+        self.init = nn.Conv1d(mel_dim, model_dim, kernel_size=1)
 
-        # Output projection
-        self.output_proj = nn.Linear(model_dim, output_dim)
+        # AttentionBlock from Tortoise
+        attn = []
+        for _ in range(num_blocks):
+            attn.append(AttentionBlock(model_dim, num_heads))
+        self.attn = nn.Sequential(*attn)
+
+        self.dim = model_dim
+        self.mean_pooling = mean_pooling
 
+        # Output projection to match CosyVoice embedding dimension
+        self.output_proj = nn.Linear(model_dim, output_dim)
+
     def forward(self, x, mask=None):
         """
         Args:
             x: mel-spectrogram [B, 80, T]
-            mask: padding mask [B, 1, T]
+            mask: padding mask [B, 1, T] (not used in Tortoise version)
         Returns:
             speaker embedding [B, output_dim]
         """
         # Initial conv
         h = self.init(x)  # [B, model_dim, T]
-        h = h.transpose(1, 2)  # [B, T, model_dim]
 
         # Apply attention blocks
-        for block in self.blocks:
-            h, _ = block(h, mask)
+        h = self.attn(h)  # [B, model_dim, T]
 
-        # Pool over time (take first position like Tortoise)
-        # Could also use mean pooling
-        output = h[:, 0, :]  # [B, model_dim]
+        # Pooling - Tortoise uses first position
+        if self.mean_pooling:
+            if mask is not None:
+                # Masked mean pooling
+                h_masked = h * mask
+                h_sum = h_masked.sum(dim=2)
+                mask_sum = mask.sum(dim=2).clamp(min=1)
+                h_pooled = h_sum / mask_sum
+            else:
+                h_pooled = h.mean(dim=2)
+        else:
+            # Use first position like Tortoise
+            h_pooled = h[:, :, 0]
 
         # Project to output dimension
-        output = self.output_proj(output)  # [B, output_dim]
+        output = self.output_proj(h_pooled)  # [B, output_dim]
 
         return F.normalize(output, p=2, dim=1)
 
+
 class TransformerLM(torch.nn.Module):
     def __init__(
             self,
@@ -163,19 +165,19 @@ class TransformerLM(torch.nn.Module):
 
         if self.use_speaker_encoder and 'reference_mels' in batch:
             reference_mels = batch['reference_mels'].to(device)
-
+            print('reference_mels shape: ', reference_mels.shape)
             # Handle multiple references
             if reference_mels.dim() == 4:  # [B, N, T, 80]
-                B, N, T, D = reference_mels.shape
+                B, N, C, T = reference_mels.shape
                 conds = []
 
                 for i in range(N):
-                    ref_mel = reference_mels[:, i, :, :].transpose(1, 2)  # [B, 80, T]
+                    ref_mel = reference_mels[:, i, :, :]
                     if 'reference_mel_masks' in batch:
                         mask = batch['reference_mel_masks'][:, i, :].unsqueeze(1).to(device)
                     else:
                         mask = None
-
+                    print('ref_mel shape: ', ref_mel.shape, 'mask: ', mask.shape)
                     cond = self.speaker_encoder(ref_mel, mask)  # [B, spk_embed_dim]
                     conds.append(cond)
 
@@ -547,42 +549,42 @@ class Qwen2LM(TransformerLM):
         lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
         return lm_target, lm_input, lm_input_len
 
-    def forward(
-            self,
-            batch: dict,
-            device: torch.device,
-    ) -> Dict[str, Optional[torch.Tensor]]:
-        """
-        Args:
-            text: (B, L, D)
-            text_lengths: (B,)
-            audio: (B, T, N) or (B, T)
-            audio_lengths: (B,)
-        """
-        text_token = batch['text_token'].to(device)
-        text_token_len = batch['text_token_len'].to(device)
-        speech_token = batch['speech_token'].to(device)
-        speech_token_len = batch['speech_token_len'].to(device)
-
-
-        # 1. encode text_token
-        text_token_emb = self.llm.model.model.embed_tokens(text_token)
-
-        # 2. encode speech_token
-        speech_token_emb = self.speech_embedding(speech_token)
-
-        # 3. prepare llm_input/target
-        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
-        lm_target = lm_target.to(device)
-
-        # 4. run lm forward
-        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
-        logits = self.llm_decoder(lm_output)
-        loss = self.criterion_ce(logits, lm_target.to(device))
-        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
-        return {'loss': loss, 'acc': acc}
+    # def forward(
+    #         self,
+    #         batch: dict,
+    #         device: torch.device,
+    # ) -> Dict[str, Optional[torch.Tensor]]:
+    #     """
+    #     Args:
+    #         text: (B, L, D)
+    #         text_lengths: (B,)
+    #         audio: (B, T, N) or (B, T)
+    #         audio_lengths: (B,)
+    #     """
+    #     text_token = batch['text_token'].to(device)
+    #     text_token_len = batch['text_token_len'].to(device)
+    #     speech_token = batch['speech_token'].to(device)
+    #     speech_token_len = batch['speech_token_len'].to(device)
+
+
+    #     # 1. encode text_token
+    #     text_token_emb = self.llm.model.model.embed_tokens(text_token)
+
+    #     # 2. encode speech_token
+    #     speech_token_emb = self.speech_embedding(speech_token)
+
+    #     # 3. prepare llm_input/target
+    #     lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
+    #     lm_target = lm_target.to(device)
+
+    #     # 4. run lm forward
+    #     lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
+    #     logits = self.llm_decoder(lm_output)
+    #     loss = self.criterion_ce(logits, lm_target.to(device))
+    #     acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
+    #     return {'loss': loss, 'acc': acc}
 
-    def forward_spk(
+    def forward(
             self,
             batch: dict,
             device: torch.device,
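
Note: a quick shape and normalization check of the rewritten LearnableSpeakerEncoder, assuming the repository is on PYTHONPATH; it is not part of the commit. The constructor arguments mirror the ones flow.py passes above.

    import torch
    from cosyvoice.llm.llm import LearnableSpeakerEncoder

    enc = LearnableSpeakerEncoder(mel_dim=80, model_dim=512, output_dim=192,
                                  num_blocks=6, num_heads=8)
    mel = torch.randn(2, 80, 250)      # [B, 80, T] reference mel crop
    with torch.no_grad():
        emb = enc(mel)                 # mask is optional and unused by the default pooling
    print(emb.shape)                   # torch.Size([2, 192])
    print(emb.norm(dim=1))             # ~1.0, the output is L2-normalized
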
speech/cosyvoice/transformer/arch_util.py ADDED
@@ -0,0 +1,373 @@
1
+ import os
2
+ import functools
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchaudio
9
+ from cosyvoice.transformer.xtransformers import ContinuousTransformerWrapper, RelativePositionBias
10
+
11
+
12
+ def zero_module(module):
13
+ """
14
+ Zero out the parameters of a module and return it.
15
+ """
16
+ for p in module.parameters():
17
+ p.detach().zero_()
18
+ return module
19
+
20
+
21
+ class GroupNorm32(nn.GroupNorm):
22
+ def forward(self, x):
23
+ return super().forward(x.float()).type(x.dtype)
24
+
25
+
26
+ def normalization(channels):
27
+ """
28
+ Make a standard normalization layer.
29
+
30
+ :param channels: number of input channels.
31
+ :return: an nn.Module for normalization.
32
+ """
33
+ groups = 32
34
+ if channels <= 16:
35
+ groups = 8
36
+ elif channels <= 64:
37
+ groups = 16
38
+ while channels % groups != 0:
39
+ groups = int(groups / 2)
40
+ assert groups > 2
41
+ return GroupNorm32(groups, channels)
42
+
43
+
44
+ class QKVAttentionLegacy(nn.Module):
45
+ """
46
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
47
+ """
48
+
49
+ def __init__(self, n_heads):
50
+ super().__init__()
51
+ self.n_heads = n_heads
52
+
53
+ def forward(self, qkv, mask=None, rel_pos=None):
54
+ """
55
+ Apply QKV attention.
56
+
57
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
58
+ :return: an [N x (H * C) x T] tensor after attention.
59
+ """
60
+ bs, width, length = qkv.shape
61
+ assert width % (3 * self.n_heads) == 0
62
+ ch = width // (3 * self.n_heads)
63
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
64
+ scale = 1 / math.sqrt(math.sqrt(ch))
65
+ weight = torch.einsum(
66
+ "bct,bcs->bts", q * scale, k * scale
67
+ ) # More stable with f16 than dividing afterwards
68
+ if rel_pos is not None:
69
+ weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
70
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
71
+ if mask is not None:
72
+ # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
73
+ mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
74
+ weight = weight * mask
75
+ a = torch.einsum("bts,bcs->bct", weight, v)
76
+
77
+ return a.reshape(bs, -1, length)
78
+
79
+
80
+ class AttentionBlock(nn.Module):
81
+ """
82
+ An attention block that allows spatial positions to attend to each other.
83
+
84
+ Originally ported from here, but adapted to the N-d case.
85
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
86
+ """
87
+
88
+ def __init__(
89
+ self,
90
+ channels,
91
+ num_heads=1,
92
+ num_head_channels=-1,
93
+ do_checkpoint=True,
94
+ relative_pos_embeddings=False,
95
+ ):
96
+ super().__init__()
97
+ self.channels = channels
98
+ self.do_checkpoint = do_checkpoint
99
+ if num_head_channels == -1:
100
+ self.num_heads = num_heads
101
+ else:
102
+ assert (
103
+ channels % num_head_channels == 0
104
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
105
+ self.num_heads = channels // num_head_channels
106
+ self.norm = normalization(channels)
107
+ self.qkv = nn.Conv1d(channels, channels * 3, 1)
108
+ # split heads before split qkv
109
+ self.attention = QKVAttentionLegacy(self.num_heads)
110
+
111
+ self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
112
+ if relative_pos_embeddings:
113
+ self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
114
+ else:
115
+ self.relative_pos_embeddings = None
116
+
117
+ def forward(self, x, mask=None):
118
+ b, c, *spatial = x.shape
119
+ x = x.reshape(b, c, -1)
120
+ qkv = self.qkv(self.norm(x))
121
+ h = self.attention(qkv, mask, self.relative_pos_embeddings)
122
+ h = self.proj_out(h)
123
+ return (x + h).reshape(b, c, *spatial)
124
+
125
+
126
+ class Upsample(nn.Module):
127
+ """
128
+ An upsampling layer with an optional convolution.
129
+
130
+ :param channels: channels in the inputs and outputs.
131
+ :param use_conv: a bool determining if a convolution is applied.
132
+ """
133
+
134
+ def __init__(self, channels, use_conv, out_channels=None, factor=4):
135
+ super().__init__()
136
+ self.channels = channels
137
+ self.out_channels = out_channels or channels
138
+ self.use_conv = use_conv
139
+ self.factor = factor
140
+ if use_conv:
141
+ ksize = 5
142
+ pad = 2
143
+ self.conv = nn.Conv1d(self.channels, self.out_channels, ksize, padding=pad)
144
+
145
+ def forward(self, x):
146
+ assert x.shape[1] == self.channels
147
+ x = F.interpolate(x, scale_factor=self.factor, mode="nearest")
148
+ if self.use_conv:
149
+ x = self.conv(x)
150
+ return x
151
+
152
+
153
+ class Downsample(nn.Module):
154
+ """
155
+ A downsampling layer with an optional convolution.
156
+
157
+ :param channels: channels in the inputs and outputs.
158
+ :param use_conv: a bool determining if a convolution is applied.
159
+ """
160
+
161
+ def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2):
162
+ super().__init__()
163
+ self.channels = channels
164
+ self.out_channels = out_channels or channels
165
+ self.use_conv = use_conv
166
+
167
+ stride = factor
168
+ if use_conv:
169
+ self.op = nn.Conv1d(
170
+ self.channels, self.out_channels, ksize, stride=stride, padding=pad
171
+ )
172
+ else:
173
+ assert self.channels == self.out_channels
174
+ self.op = nn.AvgPool1d(kernel_size=stride, stride=stride)
175
+
176
+ def forward(self, x):
177
+ assert x.shape[1] == self.channels
178
+ return self.op(x)
179
+
180
+
181
+ class ResBlock(nn.Module):
182
+ def __init__(
183
+ self,
184
+ channels,
185
+ dropout,
186
+ out_channels=None,
187
+ use_conv=False,
188
+ use_scale_shift_norm=False,
189
+ up=False,
190
+ down=False,
191
+ kernel_size=3,
192
+ ):
193
+ super().__init__()
194
+ self.channels = channels
195
+ self.dropout = dropout
196
+ self.out_channels = out_channels or channels
197
+ self.use_conv = use_conv
198
+ self.use_scale_shift_norm = use_scale_shift_norm
199
+ padding = 1 if kernel_size == 3 else 2
200
+
201
+ self.in_layers = nn.Sequential(
202
+ normalization(channels),
203
+ nn.SiLU(),
204
+ nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
205
+ )
206
+
207
+ self.updown = up or down
208
+
209
+ if up:
210
+ self.h_upd = Upsample(channels, False)
211
+ self.x_upd = Upsample(channels, False)
212
+ elif down:
213
+ self.h_upd = Downsample(channels, False)
214
+ self.x_upd = Downsample(channels, False)
215
+ else:
216
+ self.h_upd = self.x_upd = nn.Identity()
217
+
218
+ self.out_layers = nn.Sequential(
219
+ normalization(self.out_channels),
220
+ nn.SiLU(),
221
+ nn.Dropout(p=dropout),
222
+ zero_module(
223
+ nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
224
+ ),
225
+ )
226
+
227
+ if self.out_channels == channels:
228
+ self.skip_connection = nn.Identity()
229
+ elif use_conv:
230
+ self.skip_connection = nn.Conv1d(
231
+ channels, self.out_channels, kernel_size, padding=padding
232
+ )
233
+ else:
234
+ self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
235
+
236
+ def forward(self, x):
237
+ if self.updown:
238
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
239
+ h = in_rest(x)
240
+ h = self.h_upd(h)
241
+ x = self.x_upd(x)
242
+ h = in_conv(h)
243
+ else:
244
+ h = self.in_layers(x)
245
+ h = self.out_layers(h)
246
+ return self.skip_connection(x) + h
247
+
248
+
249
+ class AudioMiniEncoder(nn.Module):
250
+ def __init__(self,
251
+ spec_dim,
252
+ embedding_dim,
253
+ base_channels=128,
254
+ depth=2,
255
+ resnet_blocks=2,
256
+ attn_blocks=4,
257
+ num_attn_heads=4,
258
+ dropout=0,
259
+ downsample_factor=2,
260
+ kernel_size=3):
261
+ super().__init__()
262
+ self.init = nn.Sequential(
263
+ nn.Conv1d(spec_dim, base_channels, 3, padding=1)
264
+ )
265
+ ch = base_channels
266
+ res = []
267
+ for l in range(depth):
268
+ for r in range(resnet_blocks):
269
+ res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
270
+ res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
271
+ ch *= 2
272
+ self.res = nn.Sequential(*res)
273
+ self.final = nn.Sequential(
274
+ normalization(ch),
275
+ nn.SiLU(),
276
+ nn.Conv1d(ch, embedding_dim, 1)
277
+ )
278
+ attn = []
279
+ for a in range(attn_blocks):
280
+ attn.append(AttentionBlock(embedding_dim, num_attn_heads,))
281
+ self.attn = nn.Sequential(*attn)
282
+ self.dim = embedding_dim
283
+
284
+ def forward(self, x):
285
+ h = self.init(x)
286
+ h = self.res(h)
287
+ h = self.final(h)
288
+ h = self.attn(h)
289
+ return h[:, :, 0]
290
+
291
+
292
+ DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data/mel_norms.pth')
293
+
294
+
295
+ class TorchMelSpectrogram(nn.Module):
296
+ def __init__(self, filter_length=1024, hop_length=256, win_length=1024, n_mel_channels=80, mel_fmin=0, mel_fmax=8000,
297
+ sampling_rate=22050, normalize=False, mel_norm_file=DEFAULT_MEL_NORM_FILE):
298
+ super().__init__()
299
+ # These are the default tacotron values for the MEL spectrogram.
300
+ self.filter_length = filter_length
301
+ self.hop_length = hop_length
302
+ self.win_length = win_length
303
+ self.n_mel_channels = n_mel_channels
304
+ self.mel_fmin = mel_fmin
305
+ self.mel_fmax = mel_fmax
306
+ self.sampling_rate = sampling_rate
307
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
308
+ win_length=self.win_length, power=2, normalized=normalize,
309
+ sample_rate=self.sampling_rate, f_min=self.mel_fmin,
310
+ f_max=self.mel_fmax, n_mels=self.n_mel_channels,
311
+ norm="slaney")
312
+ self.mel_norm_file = mel_norm_file
313
+ if self.mel_norm_file is not None:
314
+ self.mel_norms = torch.load(self.mel_norm_file)
315
+ else:
316
+ self.mel_norms = None
317
+
318
+ def forward(self, inp):
319
+ if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
320
+ inp = inp.squeeze(1)
321
+ assert len(inp.shape) == 2
322
+ if torch.backends.mps.is_available():
323
+ inp = inp.to('cpu')
324
+ self.mel_stft = self.mel_stft.to(inp.device)
325
+ mel = self.mel_stft(inp)
326
+ # Perform dynamic range compression
327
+ mel = torch.log(torch.clamp(mel, min=1e-5))
328
+ if self.mel_norms is not None:
329
+ self.mel_norms = self.mel_norms.to(mel.device)
330
+ mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
331
+ return mel
332
+
333
+
334
+ class CheckpointedLayer(nn.Module):
335
+ """
336
+ Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
337
+ checkpoint for all other args.
338
+ """
339
+ def __init__(self, wrap):
340
+ super().__init__()
341
+ self.wrap = wrap
342
+
343
+ def forward(self, x, *args, **kwargs):
344
+ for k, v in kwargs.items():
345
+ assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
346
+ partial = functools.partial(self.wrap, **kwargs)
347
+ return partial(x, *args)
348
+
349
+
350
+ class CheckpointedXTransformerEncoder(nn.Module):
351
+ """
352
+ Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
353
+ to channels-last that XTransformer expects.
354
+ """
355
+ def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs):
356
+ super().__init__()
357
+ self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
358
+ self.needs_permute = needs_permute
359
+ self.exit_permute = exit_permute
360
+
361
+ if not checkpoint:
362
+ return
363
+ for i in range(len(self.transformer.attn_layers.layers)):
364
+ n, b, r = self.transformer.attn_layers.layers[i]
365
+ self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
366
+
367
+ def forward(self, x, **kwargs):
368
+ if self.needs_permute:
369
+ x = x.permute(0,2,1)
370
+ h = self.transformer(x, **kwargs)
371
+ if self.exit_permute:
372
+ h = h.permute(0,2,1)
373
+ return h
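
Note: a small check of the shape contract of the ported AttentionBlock, assuming the repository is on PYTHONPATH; it is not part of the commit. Because proj_out is wrapped in zero_module, the block starts out as an identity mapping and only deviates from its input once training updates the weights.

    import torch
    from cosyvoice.transformer.arch_util import AttentionBlock

    block = AttentionBlock(channels=512, num_heads=8)
    x = torch.randn(2, 512, 120)       # [B, C, T], channels-first like the speaker encoder input
    with torch.no_grad():
        y = block(x)
    print(y.shape)                     # torch.Size([2, 512, 120]), shape is preserved
    print(torch.allclose(y, x))        # True at initialization (zero-initialized proj_out)
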
speech/cosyvoice/transformer/xtransformers.py ADDED
@@ -0,0 +1,1248 @@
1
+ import math
2
+ from collections import namedtuple
3
+ from functools import partial
4
+ from inspect import isfunction
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+ from torch import nn, einsum
10
+
11
+ DEFAULT_DIM_HEAD = 64
12
+
13
+ Intermediates = namedtuple('Intermediates', [
14
+ 'pre_softmax_attn',
15
+ 'post_softmax_attn'
16
+ ])
17
+
18
+ LayerIntermediates = namedtuple('Intermediates', [
19
+ 'hiddens',
20
+ 'attn_intermediates',
21
+ 'past_key_values',
22
+ ])
23
+
24
+
25
+ # helpers
26
+
27
+ def exists(val):
28
+ return val is not None
29
+
30
+
31
+ def default(val, d):
32
+ if exists(val):
33
+ return val
34
+ return d() if isfunction(d) else d
35
+
36
+
37
+ def cast_tuple(val, depth):
38
+ return val if isinstance(val, tuple) else (val,) * depth
39
+
40
+
41
+ class always():
42
+ def __init__(self, val):
43
+ self.val = val
44
+
45
+ def __call__(self, *args, **kwargs):
46
+ return self.val
47
+
48
+
49
+ class not_equals():
50
+ def __init__(self, val):
51
+ self.val = val
52
+
53
+ def __call__(self, x, *args, **kwargs):
54
+ return x != self.val
55
+
56
+
57
+ class equals():
58
+ def __init__(self, val):
59
+ self.val = val
60
+
61
+ def __call__(self, x, *args, **kwargs):
62
+ return x == self.val
63
+
64
+
65
+ def max_neg_value(tensor):
66
+ return -torch.finfo(tensor.dtype).max
67
+
68
+
69
+ def l2norm(t):
70
+ return F.normalize(t, p=2, dim=-1)
71
+
72
+
73
+ # init helpers
74
+
75
+ def init_zero_(layer):
76
+ nn.init.constant_(layer.weight, 0.)
77
+ if exists(layer.bias):
78
+ nn.init.constant_(layer.bias, 0.)
79
+
80
+
81
+ # keyword argument helpers
82
+
83
+ def pick_and_pop(keys, d):
84
+ values = list(map(lambda key: d.pop(key), keys))
85
+ return dict(zip(keys, values))
86
+
87
+
88
+ def group_dict_by_key(cond, d):
89
+ return_val = [dict(), dict()]
90
+ for key in d.keys():
91
+ match = bool(cond(key))
92
+ ind = int(not match)
93
+ return_val[ind][key] = d[key]
94
+ return (*return_val,)
95
+
96
+
97
+ def string_begins_with(prefix, str):
98
+ return str.startswith(prefix)
99
+
100
+
101
+ def group_by_key_prefix(prefix, d):
102
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
103
+
104
+
105
+ def groupby_prefix_and_trim(prefix, d):
106
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
107
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
108
+ return kwargs_without_prefix, kwargs
109
+
110
+
111
+ # activations
112
+
113
+ class ReluSquared(nn.Module):
114
+ def forward(self, x):
115
+ return F.relu(x) ** 2
116
+
117
+
118
+ # positional embeddings
119
+
120
+ class AbsolutePositionalEmbedding(nn.Module):
121
+ def __init__(self, dim, max_seq_len):
122
+ super().__init__()
123
+ self.scale = dim ** -0.5
124
+ self.emb = nn.Embedding(max_seq_len, dim)
125
+
126
+ def forward(self, x):
127
+ n = torch.arange(x.shape[1], device=x.device)
128
+ pos_emb = self.emb(n)
129
+ pos_emb = rearrange(pos_emb, 'n d -> () n d')
130
+ return pos_emb * self.scale
131
+
132
+
133
+ class FixedPositionalEmbedding(nn.Module):
134
+ def __init__(self, dim):
135
+ super().__init__()
136
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
137
+ self.register_buffer('inv_freq', inv_freq)
138
+
139
+ def forward(self, x, seq_dim=1, offset=0):
140
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
141
+ sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
142
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
143
+ return rearrange(emb, 'n d -> () n d')
144
+
145
+
146
+ class RelativePositionBias(nn.Module):
147
+ def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
148
+ super().__init__()
149
+ self.scale = scale
150
+ self.causal = causal
151
+ self.num_buckets = num_buckets
152
+ self.max_distance = max_distance
153
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
154
+
155
+ @staticmethod
156
+ def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
157
+ ret = 0
158
+ n = -relative_position
159
+ if not causal:
160
+ num_buckets //= 2
161
+ ret += (n < 0).long() * num_buckets
162
+ n = torch.abs(n)
163
+ else:
164
+ n = torch.max(n, torch.zeros_like(n))
165
+
166
+ max_exact = num_buckets // 2
167
+ is_small = n < max_exact
168
+
169
+ val_if_large = max_exact + (
170
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
171
+ ).long()
172
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
173
+
174
+ ret += torch.where(is_small, n, val_if_large)
175
+ return ret
176
+
177
+ def forward(self, qk_dots):
178
+ i, j, device = *qk_dots.shape[-2:], qk_dots.device
179
+ q_pos = torch.arange(i, dtype=torch.long, device=device)
180
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
181
+ rel_pos = k_pos[None, :] - q_pos[:, None]
182
+ rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
183
+ max_distance=self.max_distance)
184
+ values = self.relative_attention_bias(rp_bucket)
185
+ bias = rearrange(values, 'i j h -> () h i j')
186
+ return qk_dots + (bias * self.scale)
187
+
188
+
189
+ class AlibiPositionalBias(nn.Module):
190
+ def __init__(self, heads, **kwargs):
191
+ super().__init__()
192
+ self.heads = heads
193
+ slopes = torch.Tensor(self._get_slopes(heads))
194
+ slopes = rearrange(slopes, 'h -> () h () ()')
195
+ self.register_buffer('slopes', slopes, persistent=False)
196
+ self.register_buffer('bias', None, persistent=False)
197
+
198
+ @staticmethod
199
+ def _get_slopes(heads):
200
+ def get_slopes_power_of_2(n):
201
+ start = (2 ** (-2 ** -(math.log2(n) - 3)))
202
+ ratio = start
203
+ return [start * ratio ** i for i in range(n)]
204
+
205
+ if math.log2(heads).is_integer():
206
+ return get_slopes_power_of_2(heads)
207
+
208
+ closest_power_of_2 = 2 ** math.floor(math.log2(heads))
209
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
210
+ :heads - closest_power_of_2]
211
+
212
+ def forward(self, qk_dots):
213
+ h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
214
+
215
+ if exists(self.bias) and self.bias.shape[-1] >= j:
216
+ return qk_dots + self.bias[..., :j]
217
+
218
+ bias = torch.arange(j, device=device)
219
+ bias = rearrange(bias, 'j -> () () () j')
220
+ bias = bias * self.slopes
221
+
222
+ num_heads_unalibied = h - bias.shape[1]
223
+ bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
224
+
225
+ self.register_buffer('bias', bias, persistent=False)
226
+ return qk_dots + self.bias
227
+
228
+
229
+ class LearnedAlibiPositionalBias(AlibiPositionalBias):
230
+ def __init__(self, heads, bidirectional=False):
231
+ super().__init__(heads)
232
+ los_slopes = torch.log(self.slopes)
233
+ self.learned_logslopes = nn.Parameter(los_slopes)
234
+
235
+ self.bidirectional = bidirectional
236
+ if self.bidirectional:
237
+ self.learned_logslopes_future = nn.Parameter(los_slopes)
238
+
239
+ def forward(self, qk_dots):
240
+ h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
241
+
242
+ def get_slopes(param):
243
+ return F.pad(param.exp(), (0, 0, 0, 0, 0, h - param.shape[1]))
244
+
245
+ if exists(self.bias) and self.bias.shape[-1] >= j:
246
+ bias = self.bias[..., :i, :j]
247
+ else:
248
+ i_arange = torch.arange(i, device=device)
249
+ j_arange = torch.arange(j, device=device)
250
+ bias = rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1')
251
+ self.register_buffer('bias', bias, persistent=False)
252
+
253
+ if self.bidirectional:
254
+ past_slopes = get_slopes(self.learned_logslopes)
255
+ future_slopes = get_slopes(self.learned_logslopes_future)
256
+ bias = torch.tril(bias * past_slopes) + torch.triu(bias * future_slopes)
257
+ else:
258
+ slopes = get_slopes(self.learned_logslopes)
259
+ bias = bias * slopes
260
+
261
+ return qk_dots + bias
262
+
263
+
264
+ class RotaryEmbedding(nn.Module):
265
+ def __init__(self, dim):
266
+ super().__init__()
267
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
268
+ self.register_buffer('inv_freq', inv_freq)
269
+
270
+ def forward(self, max_seq_len, device):
271
+ t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq)
272
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
273
+ emb = torch.cat((freqs, freqs), dim=-1)
274
+ return rearrange(emb, 'n d -> () () n d')
275
+
276
+
277
+ def rotate_half(x):
278
+ x = rearrange(x, '... (j d) -> ... j d', j=2)
279
+ x1, x2 = x.unbind(dim=-2)
280
+ return torch.cat((-x2, x1), dim=-1)
281
+
282
+
283
+ def apply_rotary_pos_emb(t, freqs):
284
+ seq_len = t.shape[-2]
285
+ freqs = freqs[:, :, -seq_len:]
286
+ return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
287
+
288
+
289
+ # norms
290
+
291
+ class Scale(nn.Module):
292
+ def __init__(self, value, fn):
293
+ super().__init__()
294
+ self.value = value
295
+ self.fn = fn
296
+
297
+ def forward(self, x, **kwargs):
298
+ out = self.fn(x, **kwargs)
299
+ scale_fn = lambda t: t * self.value
300
+
301
+ if not isinstance(out, tuple):
302
+ return scale_fn(out)
303
+
304
+ return (scale_fn(out[0]), *out[1:])
305
+
306
+
307
+ class Rezero(nn.Module):
308
+ def __init__(self, fn):
309
+ super().__init__()
310
+ self.fn = fn
311
+ self.g = nn.Parameter(torch.zeros(1))
312
+
313
+ def forward(self, x, **kwargs):
314
+ out = self.fn(x, **kwargs)
315
+ rezero_fn = lambda t: t * self.g
316
+
317
+ if not isinstance(out, tuple):
318
+ return rezero_fn(out)
319
+
320
+ return (rezero_fn(out[0]), *out[1:])
321
+
322
+
323
+ class ScaleNorm(nn.Module):
324
+ def __init__(self, dim, eps=1e-5):
325
+ super().__init__()
326
+ self.scale = dim ** -0.5
327
+ self.eps = eps
328
+ self.g = nn.Parameter(torch.ones(1))
329
+
330
+ def forward(self, x):
331
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
332
+ return x / norm.clamp(min=self.eps) * self.g
333
+
334
+
335
+ class RMSNorm(nn.Module):
336
+ def __init__(self, dim, eps=1e-8):
337
+ super().__init__()
338
+ self.scale = dim ** -0.5
339
+ self.eps = eps
340
+ self.g = nn.Parameter(torch.ones(dim))
341
+
342
+ def forward(self, x):
343
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
344
+ return x / norm.clamp(min=self.eps) * self.g
345
+
346
+
347
+ class RMSScaleShiftNorm(nn.Module):
348
+ def __init__(self, dim, eps=1e-8):
349
+ super().__init__()
350
+ self.scale = dim ** -0.5
351
+ self.eps = eps
352
+ self.g = nn.Parameter(torch.ones(dim))
353
+ self.scale_shift_process = nn.Linear(dim * 2, dim * 2)
354
+
355
+ def forward(self, x, norm_scale_shift_inp):
356
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
357
+ norm = x / norm.clamp(min=self.eps) * self.g
358
+
359
+ ss_emb = self.scale_shift_process(norm_scale_shift_inp)
360
+ scale, shift = torch.chunk(ss_emb, 2, dim=1)
361
+ h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
362
+ return h
363
+
364
+
365
+ # residual and residual gates
366
+
367
+ class Residual(nn.Module):
368
+ def __init__(self, dim, scale_residual=False):
369
+ super().__init__()
370
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
371
+
372
+ def forward(self, x, residual):
373
+ if exists(self.residual_scale):
374
+ residual = residual * self.residual_scale
375
+
376
+ return x + residual
377
+
378
+
379
+ class GRUGating(nn.Module):
380
+ def __init__(self, dim, scale_residual=False):
381
+ super().__init__()
382
+ self.gru = nn.GRUCell(dim, dim)
383
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
384
+
385
+ def forward(self, x, residual):
386
+ if exists(self.residual_scale):
387
+ residual = residual * self.residual_scale
388
+
389
+ gated_output = self.gru(
390
+ rearrange(x, 'b n d -> (b n) d'),
391
+ rearrange(residual, 'b n d -> (b n) d')
392
+ )
393
+
394
+ return gated_output.reshape_as(x)
395
+
396
+
397
+ # token shifting
398
+
399
+ def shift(t, amount, mask=None):
400
+ if amount == 0:
401
+ return t
402
+
403
+ if exists(mask):
404
+ t = t.masked_fill(~mask[..., None], 0.)
405
+
406
+ return F.pad(t, (0, 0, amount, -amount), value=0.)
407
+
408
+
409
+ class ShiftTokens(nn.Module):
410
+ def __init__(self, shifts, fn):
411
+ super().__init__()
412
+ self.fn = fn
413
+ self.shifts = tuple(shifts)
414
+
415
+ def forward(self, x, **kwargs):
416
+ mask = kwargs.get('mask', None)
417
+ shifts = self.shifts
418
+ segments = len(shifts)
419
+ feats_per_shift = x.shape[-1] // segments
420
+ splitted = x.split(feats_per_shift, dim=-1)
421
+ segments_to_shift, rest = splitted[:segments], splitted[segments:]
422
+ segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)))
423
+ x = torch.cat((*segments_to_shift, *rest), dim=-1)
424
+ return self.fn(x, **kwargs)
425
+
426
+
427
+ # feedforward
428
+
429
+ class GLU(nn.Module):
430
+ def __init__(self, dim_in, dim_out, activation):
431
+ super().__init__()
432
+ self.act = activation
433
+ self.proj = nn.Linear(dim_in, dim_out * 2)
434
+
435
+ def forward(self, x):
436
+ x, gate = self.proj(x).chunk(2, dim=-1)
437
+ return x * self.act(gate)
438
+
439
+
440
+ class FeedForward(nn.Module):
441
+ def __init__(
442
+ self,
443
+ dim,
444
+ dim_out=None,
445
+ mult=4,
446
+ glu=False,
447
+ relu_squared=False,
448
+ post_act_ln=False,
449
+ dropout=0.,
450
+ zero_init_output=False
451
+ ):
452
+ super().__init__()
453
+ inner_dim = int(dim * mult)
454
+ dim_out = default(dim_out, dim)
455
+ activation = ReluSquared() if relu_squared else nn.GELU()
456
+
457
+ project_in = nn.Sequential(
458
+ nn.Linear(dim, inner_dim),
459
+ activation
460
+ ) if not glu else GLU(dim, inner_dim, activation)
461
+
462
+ self.net = nn.Sequential(
463
+ project_in,
464
+ nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
465
+ nn.Dropout(dropout),
466
+ nn.Linear(inner_dim, dim_out)
467
+ )
468
+
469
+ # init last linear layer to 0
470
+ if zero_init_output:
471
+ init_zero_(self.net[-1])
472
+
473
+ def forward(self, x):
474
+ return self.net(x)
475
+
476
+
477
+ # attention.
478
+
479
+ class Attention(nn.Module):
480
+ def __init__(
481
+ self,
482
+ dim,
483
+ dim_head=DEFAULT_DIM_HEAD,
484
+ heads=8,
485
+ causal=False,
486
+ talking_heads=False,
487
+ head_scale=False,
488
+ collab_heads=False,
489
+ collab_compression=.3,
490
+ sparse_topk=None,
491
+ use_entmax15=False,
492
+ num_mem_kv=0,
493
+ dropout=0.,
494
+ on_attn=False,
495
+ gate_values=False,
496
+ zero_init_output=False,
497
+ max_attend_past=None,
498
+ qk_norm=False,
499
+ scale_init_value=None,
500
+ rel_pos_bias=False,
501
+ rel_pos_num_buckets=32,
502
+ rel_pos_max_distance=128,
503
+ ):
504
+ super().__init__()
505
+ self.scale = dim_head ** -0.5
506
+
507
+ self.heads = heads
508
+ self.causal = causal
509
+ self.max_attend_past = max_attend_past
510
+
511
+ qk_dim = v_dim = dim_head * heads
512
+
513
+ # collaborative heads
514
+ self.collab_heads = collab_heads
515
+ if self.collab_heads:
516
+ qk_dim = int(collab_compression * qk_dim)
517
+ self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))
518
+
519
+ self.to_q = nn.Linear(dim, qk_dim, bias=False)
520
+ self.to_k = nn.Linear(dim, qk_dim, bias=False)
521
+ self.to_v = nn.Linear(dim, v_dim, bias=False)
522
+
523
+ self.dropout = nn.Dropout(dropout)
524
+
525
+ # add GLU gating for aggregated values, from alphafold2
526
+ self.to_v_gate = None
527
+ if gate_values:
528
+ self.to_v_gate = nn.Linear(dim, v_dim)
529
+ nn.init.constant_(self.to_v_gate.weight, 0)
530
+ nn.init.constant_(self.to_v_gate.bias, 1)
531
+
532
+ # cosine sim attention
533
+ self.qk_norm = qk_norm
534
+ if qk_norm:
535
+ scale_init_value = default(scale_init_value,
536
+ -3) # if not provided, initialize as though it were sequence length of 1024
537
+ self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)
538
+
539
+ # talking heads
540
+ self.talking_heads = talking_heads
541
+ if talking_heads:
542
+ self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
543
+ self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
544
+
545
+ # head scaling
546
+ self.head_scale = head_scale
547
+ if head_scale:
548
+ self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
549
+
550
+ # explicit topk sparse attention
551
+ self.sparse_topk = sparse_topk
552
+
553
+ # entmax
554
+ self.attn_fn = F.softmax
555
+
556
+ # add memory key / values
557
+ self.num_mem_kv = num_mem_kv
558
+ if num_mem_kv > 0:
559
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
560
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
561
+
562
+ # attention on attention
563
+ self.attn_on_attn = on_attn
564
+ self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim)
565
+
566
+ self.rel_pos_bias = rel_pos_bias
567
+ if rel_pos_bias:
568
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
569
+ self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads,
570
+ num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance)
571
+
572
+ # init output projection 0
573
+ if zero_init_output:
574
+ init_zero_(self.to_out)
575
+
576
+ def forward(
577
+ self,
578
+ x,
579
+ context=None,
580
+ mask=None,
581
+ context_mask=None,
582
+ attn_mask=None,
583
+ sinusoidal_emb=None,
584
+ rotary_pos_emb=None,
585
+ prev_attn=None,
586
+ mem=None,
587
+ layer_past=None,
588
+ ):
589
+ b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
590
+ context)
591
+ kv_input = default(context, x)
592
+
593
+ q_input = x
594
+ k_input = kv_input
595
+ v_input = kv_input
596
+
597
+ if exists(mem):
598
+ k_input = torch.cat((mem, k_input), dim=-2)
599
+ v_input = torch.cat((mem, v_input), dim=-2)
600
+
601
+ if exists(sinusoidal_emb):
602
+ # in shortformer, the query would start at a position offset depending on the past cached memory
603
+ offset = k_input.shape[-2] - q_input.shape[-2]
604
+ q_input = q_input + sinusoidal_emb(q_input, offset=offset)
605
+ k_input = k_input + sinusoidal_emb(k_input)
606
+
607
+ q = self.to_q(q_input)
608
+ k = self.to_k(k_input)
609
+ v = self.to_v(v_input)
610
+
611
+ if not collab_heads:
612
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
613
+ else:
614
+ q = einsum('b i d, h d -> b h i d', q, self.collab_mixing)
615
+ k = rearrange(k, 'b n d -> b () n d')
616
+ v = rearrange(v, 'b n (h d) -> b h n d', h=h)
617
+
618
+ if layer_past is not None:
619
+ past_key, past_value = layer_past
620
+ k = torch.cat([past_key, k], dim=-2)
621
+ v = torch.cat([past_value, v], dim=-2)
622
+ k_cache = k
623
+ v_cache = v
624
+
625
+ if exists(rotary_pos_emb) and not has_context:
626
+ l = rotary_pos_emb.shape[-1]
627
+ (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
628
+ ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
629
+ q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
630
+
631
+ input_mask = None
632
+ if any(map(exists, (mask, context_mask))):
633
+ q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
634
+ k_mask = q_mask if not exists(context) else context_mask
635
+ k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
636
+ q_mask = rearrange(q_mask, 'b i -> b () i ()')
637
+ k_mask = rearrange(k_mask, 'b j -> b () () j')
638
+ input_mask = q_mask * k_mask
639
+
640
+ if self.num_mem_kv > 0:
641
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
642
+ k = torch.cat((mem_k, k), dim=-2)
643
+ v = torch.cat((mem_v, v), dim=-2)
644
+ if exists(input_mask):
645
+ input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
646
+
647
+ if collab_heads:
648
+ k = k.expand(-1, h, -1, -1)
649
+
650
+ if self.qk_norm:
651
+ q, k = map(l2norm, (q, k))
652
+ scale = 1 / (self.scale.exp().clamp(min=1e-2))
653
+
654
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale
655
+ mask_value = max_neg_value(dots)
656
+
657
+ if exists(prev_attn):
658
+ dots = dots + prev_attn
659
+
660
+ pre_softmax_attn = dots.clone()
661
+
662
+ if talking_heads:
663
+ dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
664
+
665
+ if self.rel_pos_bias:
666
+ dots = self.rel_pos(dots)
667
+
668
+ if exists(input_mask):
669
+ dots.masked_fill_(~input_mask, mask_value)
670
+ del input_mask
671
+
672
+ if exists(attn_mask):
673
+ assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
674
+ if attn_mask.ndim == 2:
675
+ attn_mask = rearrange(attn_mask, 'i j -> () () i j')
676
+ elif attn_mask.ndim == 3:
677
+ attn_mask = rearrange(attn_mask, 'h i j -> () h i j')
678
+ dots.masked_fill_(~attn_mask, mask_value)
679
+
680
+ if exists(self.max_attend_past):
681
+ i, j = dots.shape[-2:]
682
+ range_q = torch.arange(j - i, j, device=device)
683
+ range_k = torch.arange(j, device=device)
684
+ dist = rearrange(range_q, 'i -> () () i ()') - rearrange(range_k, 'j -> () () () j')
685
+ mask = dist > self.max_attend_past
686
+ dots.masked_fill_(mask, mask_value)
687
+ del mask
688
+
689
+ if self.causal:
690
+ i, j = dots.shape[-2:]
691
+ r = torch.arange(i, device=device)
692
+ mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
693
+ mask = F.pad(mask, (j - i, 0), value=False)
694
+ dots.masked_fill_(mask, mask_value)
695
+ del mask
696
+
697
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
698
+ top, _ = dots.topk(self.sparse_topk, dim=-1)
699
+ vk = top[..., -1].unsqueeze(-1).expand_as(dots)
700
+ mask = dots < vk
701
+ dots.masked_fill_(mask, mask_value)
702
+ del mask
703
+
704
+ attn = self.attn_fn(dots, dim=-1)
705
+ post_softmax_attn = attn.clone()
706
+
707
+ attn = self.dropout(attn)
708
+
709
+ if talking_heads:
710
+ attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
711
+
712
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
713
+
714
+ if head_scale:
715
+ out = out * self.head_scale_params
716
+
717
+ out = rearrange(out, 'b h n d -> b n (h d)')
718
+
719
+ if exists(self.to_v_gate):
720
+ gates = self.to_v_gate(x)
721
+ out = out * gates.sigmoid()
722
+
723
+ intermediates = Intermediates(
724
+ pre_softmax_attn=pre_softmax_attn,
725
+ post_softmax_attn=post_softmax_attn
726
+ )
727
+
728
+ return self.to_out(out), intermediates, k_cache, v_cache
729
+
730
+
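A minimal usage sketch (editorial note, not part of this commit) of the Attention block above, assuming it can be imported from the module this diff adds; the dimensions are illustrative only:

import torch

attn = Attention(dim=512, heads=8, causal=False)   # self-attention, default dim_head=64
x = torch.randn(2, 10, 512)                         # (batch, seq_len, dim)
out, intermediates, k_cache, v_cache = attn(x)      # out: (2, 10, 512), plus attention maps and key/value caches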
731
+ class AttentionLayers(nn.Module):
732
+ def __init__(
733
+ self,
734
+ dim,
735
+ depth,
736
+ heads=8,
737
+ causal=False,
738
+ cross_attend=False,
739
+ only_cross=False,
740
+ use_scalenorm=False,
741
+ use_rms_scaleshift_norm=False,
742
+ use_rmsnorm=False,
743
+ use_rezero=False,
744
+ alibi_pos_bias=False,
745
+ alibi_num_heads=None,
746
+ alibi_learned=False,
747
+ position_infused_attn=False,
748
+ rotary_pos_emb=False,
749
+ rotary_emb_dim=None,
750
+ custom_layers=None,
751
+ sandwich_coef=None,
752
+ par_ratio=None,
753
+ residual_attn=False,
754
+ cross_residual_attn=False,
755
+ macaron=False,
756
+ pre_norm=True,
757
+ gate_residual=False,
758
+ scale_residual=False,
759
+ shift_tokens=0,
760
+ sandwich_norm=False,
761
+ use_qk_norm_attn=False,
762
+ qk_norm_attn_seq_len=None,
763
+ zero_init_branch_output=False,
764
+ **kwargs
765
+ ):
766
+ super().__init__()
767
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
768
+ attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
769
+
770
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
771
+
772
+ self.dim = dim
773
+ self.depth = depth
774
+ self.layers = nn.ModuleList([])
775
+ self.causal = causal
776
+
777
+ rel_pos_bias = 'rel_pos_bias' in attn_kwargs
778
+ self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
779
+ self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
780
+
781
+ rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
782
+ self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
783
+
784
+ assert not (
785
+ alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
786
+
787
+ if alibi_pos_bias:
788
+ alibi_num_heads = default(alibi_num_heads, heads)
789
+ assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
790
+ alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
791
+ self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal)
792
+ else:
793
+ self.rel_pos = None
794
+
795
+ assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
796
+ self.pre_norm = pre_norm
797
+ self.sandwich_norm = sandwich_norm
798
+
799
+ self.residual_attn = residual_attn
800
+ self.cross_residual_attn = cross_residual_attn
801
+ self.cross_attend = cross_attend
802
+
803
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
804
+ norm_class = RMSNorm if use_rmsnorm else norm_class
805
+ norm_class = RMSScaleShiftNorm if use_rms_scaleshift_norm else norm_class
806
+ norm_fn = partial(norm_class, dim)
807
+
808
+ norm_fn = nn.Identity if use_rezero else norm_fn
809
+ branch_fn = Rezero if use_rezero else None
810
+
811
+ if cross_attend and not only_cross:
812
+ default_block = ('a', 'c', 'f')
813
+ elif cross_attend and only_cross:
814
+ default_block = ('c', 'f')
815
+ else:
816
+ default_block = ('a', 'f')
817
+
818
+ if macaron:
819
+ default_block = ('f',) + default_block
820
+
821
+ # qk normalization
822
+
823
+ if use_qk_norm_attn:
824
+ attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(
825
+ qk_norm_attn_seq_len) else None
826
+ attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}
827
+
828
+ # zero init
829
+
830
+ if zero_init_branch_output:
831
+ attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
832
+ ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
833
+
834
+ # calculate layer block order
835
+
836
+ if exists(custom_layers):
837
+ layer_types = custom_layers
838
+ elif exists(par_ratio):
839
+ par_depth = depth * len(default_block)
840
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
841
+ default_block = tuple(filter(not_equals('f'), default_block))
842
+ par_attn = par_depth // par_ratio
843
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
844
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
845
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
846
+ par_block = default_block + ('f',) * (par_width - len(default_block))
847
+ par_head = par_block * par_attn
848
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
849
+ elif exists(sandwich_coef):
850
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
851
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
852
+ else:
853
+ layer_types = default_block * depth
854
+
855
+ self.layer_types = layer_types
856
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
857
+
858
+ # calculate token shifting
859
+
860
+ shift_tokens = cast_tuple(shift_tokens, len(layer_types))
861
+
862
+ # iterate and construct layers
863
+
864
+ for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
865
+ is_last_layer = ind == (len(self.layer_types) - 1)
866
+
867
+ if layer_type == 'a':
868
+ layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
869
+ elif layer_type == 'c':
870
+ layer = Attention(dim, heads=heads, **attn_kwargs)
871
+ elif layer_type == 'f':
872
+ layer = FeedForward(dim, **ff_kwargs)
873
+ layer = layer if not macaron else Scale(0.5, layer)
874
+ else:
875
+ raise Exception(f'invalid layer type {layer_type}')
876
+
877
+ if layer_shift_tokens > 0:
878
+ shift_range_upper = layer_shift_tokens + 1
879
+ shift_range_lower = -layer_shift_tokens if not causal else 0
880
+ layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
881
+
882
+ if exists(branch_fn):
883
+ layer = branch_fn(layer)
884
+
885
+ residual_fn = GRUGating if gate_residual else Residual
886
+ residual = residual_fn(dim, scale_residual=scale_residual)
887
+
888
+ layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c')
889
+
890
+ pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None
891
+ post_branch_norm = norm_fn() if sandwich_norm or layer_uses_qk_norm else None
892
+ post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None
893
+
894
+ norms = nn.ModuleList([
895
+ pre_branch_norm,
896
+ post_branch_norm,
897
+ post_main_norm
898
+ ])
899
+
900
+ self.layers.append(nn.ModuleList([
901
+ norms,
902
+ layer,
903
+ residual
904
+ ]))
905
+
906
+ def forward(
907
+ self,
908
+ x,
909
+ context=None,
910
+ full_context=None, # for passing a list of hidden states from an encoder
911
+ mask=None,
912
+ context_mask=None,
913
+ attn_mask=None,
914
+ mems=None,
915
+ return_hiddens=False,
916
+ norm_scale_shift_inp=None,
917
+ past_key_values=None,
918
+ expected_seq_len=None,
919
+ ):
920
+
921
+ assert not (self.cross_attend ^ (exists(context) or exists(
922
+ full_context))), 'context must be passed in if cross_attend is set to True'
923
+ assert context is None or full_context is None, 'only one of full_context or context can be provided'
924
+
925
+ hiddens = []
926
+ intermediates = []
927
+ prev_attn = None
928
+ prev_cross_attn = None
929
+
930
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
931
+ norm_args = {}
932
+ if exists(norm_scale_shift_inp):
933
+ norm_args['norm_scale_shift_inp'] = norm_scale_shift_inp
934
+
935
+ rotary_pos_emb = None
936
+ if exists(self.rotary_pos_emb):
937
+ if not self.training and self.causal:
938
+ assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
939
+ elif expected_seq_len is None:
940
+ expected_seq_len = 0
941
+ seq_len = x.shape[1]
942
+ if past_key_values is not None:
943
+ seq_len += past_key_values[0][0].shape[-2]
944
+ max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len])
945
+ rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
946
+
947
+ present_key_values = []
948
+ cross_attn_count = 0
949
+ for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
950
+ if layer_type == 'a':
951
+ layer_mem = mems.pop(0) if mems else None
952
+
953
+ residual = x
954
+
955
+ pre_branch_norm, post_branch_norm, post_main_norm = norm
956
+
957
+ if exists(pre_branch_norm):
958
+ x = pre_branch_norm(x, **norm_args)
959
+
960
+ if layer_type == 'a' or layer_type == 'c':
961
+ if past_key_values is not None:
962
+ layer_kv = past_key_values.pop(0)
963
+ layer_past = tuple(s.to(x.device) for s in layer_kv)
964
+ else:
965
+ layer_past = None
966
+
967
+ if layer_type == 'a':
968
+ out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
969
+ prev_attn, layer_mem, layer_past)
970
+ elif layer_type == 'c':
971
+ if exists(full_context):
972
+ out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None,
973
+ None, prev_attn, None, layer_past)
974
+ else:
975
+ out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
976
+ elif layer_type == 'f':
977
+ out = block(x)
978
+
979
+ if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
980
+ present_key_values.append((k.detach(), v.detach()))
981
+
982
+ if exists(post_branch_norm):
983
+ out = post_branch_norm(out, **norm_args)
984
+
985
+ x = residual_fn(out, residual)
986
+
987
+ if layer_type in ('a', 'c'):
988
+ intermediates.append(inter)
989
+
990
+ if layer_type == 'a' and self.residual_attn:
991
+ prev_attn = inter.pre_softmax_attn
992
+ elif layer_type == 'c' and self.cross_residual_attn:
993
+ prev_cross_attn = inter.pre_softmax_attn
994
+
995
+ if exists(post_main_norm):
996
+ x = post_main_norm(x, **norm_args)
997
+
998
+ if layer_type == 'c':
999
+ cross_attn_count += 1
1000
+
1001
+ if layer_type == 'f':
1002
+ hiddens.append(x)
1003
+
1004
+ if return_hiddens:
1005
+ intermediates = LayerIntermediates(
1006
+ hiddens=hiddens,
1007
+ attn_intermediates=intermediates,
1008
+ past_key_values=present_key_values
1009
+ )
1010
+
1011
+ return x, intermediates
1012
+
1013
+ return x
1014
+
1015
+
1016
+ class Encoder(AttentionLayers):
1017
+ def __init__(self, **kwargs):
1018
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
1019
+ super().__init__(causal=False, **kwargs)
1020
+
1021
+
1022
+ class Decoder(AttentionLayers):
1023
+ def __init__(self, **kwargs):
1024
+ assert 'causal' not in kwargs, 'cannot set causality on decoder'
1025
+ super().__init__(causal=True, **kwargs)
1026
+
1027
+
1028
+ class CrossAttender(AttentionLayers):
1029
+ def __init__(self, **kwargs):
1030
+ super().__init__(cross_attend=True, only_cross=True, **kwargs)
1031
+
1032
+
1033
+ class ViTransformerWrapper(nn.Module):
1034
+ def __init__(
1035
+ self,
1036
+ *,
1037
+ image_size,
1038
+ patch_size,
1039
+ attn_layers,
1040
+ num_classes=None,
1041
+ dropout=0.,
1042
+ emb_dropout=0.
1043
+ ):
1044
+ super().__init__()
1045
+ assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
1046
+ assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
1047
+ dim = attn_layers.dim
1048
+ num_patches = (image_size // patch_size) ** 2
1049
+ patch_dim = 3 * patch_size ** 2
1050
+
1051
+ self.patch_size = patch_size
1052
+
1053
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
1054
+ self.patch_to_embedding = nn.Linear(patch_dim, dim)
1055
+ self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
1056
+ self.dropout = nn.Dropout(emb_dropout)
1057
+
1058
+ self.attn_layers = attn_layers
1059
+ self.norm = nn.LayerNorm(dim)
1060
+ self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None
1061
+
1062
+ def forward(
1063
+ self,
1064
+ img,
1065
+ return_embeddings=False
1066
+ ):
1067
+ p = self.patch_size
1068
+
1069
+ x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
1070
+ x = self.patch_to_embedding(x)
1071
+ b, n, _ = x.shape
1072
+
1073
+ cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
1074
+ x = torch.cat((cls_tokens, x), dim=1)
1075
+ x = x + self.pos_embedding[:, :(n + 1)]
1076
+ x = self.dropout(x)
1077
+
1078
+ x = self.attn_layers(x)
1079
+ x = self.norm(x)
1080
+
1081
+ if not exists(self.mlp_head) or return_embeddings:
1082
+ return x
1083
+
1084
+ return self.mlp_head(x[:, 0])
1085
+
1086
+
1087
+ class TransformerWrapper(nn.Module):
1088
+ def __init__(
1089
+ self,
1090
+ *,
1091
+ num_tokens,
1092
+ max_seq_len,
1093
+ attn_layers,
1094
+ emb_dim=None,
1095
+ max_mem_len=0.,
1096
+ shift_mem_down=0,
1097
+ emb_dropout=0.,
1098
+ num_memory_tokens=None,
1099
+ tie_embedding=False,
1100
+ use_pos_emb=True
1101
+ ):
1102
+ super().__init__()
1103
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
1104
+
1105
+ dim = attn_layers.dim
1106
+ emb_dim = default(emb_dim, dim)
1107
+
1108
+ self.max_seq_len = max_seq_len
1109
+ self.max_mem_len = max_mem_len
1110
+ self.shift_mem_down = shift_mem_down
1111
+
1112
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
1113
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
1114
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
1115
+ self.emb_dropout = nn.Dropout(emb_dropout)
1116
+
1117
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
1118
+ self.attn_layers = attn_layers
1119
+ self.norm = nn.LayerNorm(dim)
1120
+
1121
+ self.init_()
1122
+
1123
+ self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
1124
+
1125
+ # memory tokens (like [cls]) from Memory Transformers paper
1126
+ num_memory_tokens = default(num_memory_tokens, 0)
1127
+ self.num_memory_tokens = num_memory_tokens
1128
+ if num_memory_tokens > 0:
1129
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
1130
+
1131
+ def init_(self):
1132
+ nn.init.kaiming_normal_(self.token_emb.weight)
1133
+
1134
+ def forward(
1135
+ self,
1136
+ x,
1137
+ return_embeddings=False,
1138
+ mask=None,
1139
+ return_hiddens=False,
1140
+ return_attn=False,
1141
+ mems=None,
1142
+ use_cache=False,
1143
+ **kwargs
1144
+ ):
1145
+ b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
1146
+ x = self.token_emb(x)
1147
+ x = x + self.pos_emb(x)
1148
+ x = self.emb_dropout(x)
1149
+
1150
+ x = self.project_emb(x)
1151
+
1152
+ if num_mem > 0:
1153
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
1154
+ x = torch.cat((mem, x), dim=1)
1155
+
1156
+ # auto-handle masking after appending memory tokens
1157
+ if exists(mask):
1158
+ mask = F.pad(mask, (num_mem, 0), value=True)
1159
+
1160
+ if self.shift_mem_down and exists(mems):
1161
+ mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
1162
+ mems = [*mems_r, *mems_l]
1163
+
1164
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
1165
+ x = self.norm(x)
1166
+
1167
+ mem, x = x[:, :num_mem], x[:, num_mem:]
1168
+
1169
+ out = self.to_logits(x) if not return_embeddings else x
1170
+
1171
+ if return_hiddens:
1172
+ hiddens = intermediates.hiddens
1173
+ return out, hiddens
1174
+
1175
+ res = [out]
1176
+ if return_attn:
1177
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
1178
+ res.append(attn_maps)
1179
+ if use_cache:
1180
+ res.append(intermediates.past_key_values)
1181
+
1182
+ if len(res) > 1:
1183
+ return tuple(res)
1184
+ return res[0]
1185
+
1186
+
1187
+ class ContinuousTransformerWrapper(nn.Module):
1188
+ def __init__(
1189
+ self,
1190
+ *,
1191
+ max_seq_len,
1192
+ attn_layers,
1193
+ dim_in=None,
1194
+ dim_out=None,
1195
+ emb_dim=None,
1196
+ emb_dropout=0.,
1197
+ use_pos_emb=True
1198
+ ):
1199
+ super().__init__()
1200
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
1201
+
1202
+ dim = attn_layers.dim
1203
+
1204
+ self.max_seq_len = max_seq_len
1205
+
1206
+ self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (
1207
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
1208
+ self.emb_dropout = nn.Dropout(emb_dropout)
1209
+
1210
+ self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
1211
+
1212
+ self.attn_layers = attn_layers
1213
+ self.norm = nn.LayerNorm(dim)
1214
+
1215
+ self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
1216
+
1217
+ def forward(
1218
+ self,
1219
+ x,
1220
+ return_embeddings=False,
1221
+ mask=None,
1222
+ return_attn=False,
1223
+ mems=None,
1224
+ use_cache=False,
1225
+ **kwargs
1226
+ ):
1227
+ b, n, _, device = *x.shape, x.device
1228
+
1229
+ x = self.project_in(x)
1230
+ x = x + self.pos_emb(x)
1231
+ x = self.emb_dropout(x)
1232
+
1233
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
1234
+ x = self.norm(x)
1235
+
1236
+ out = self.project_out(x) if not return_embeddings else x
1237
+
1238
+ res = [out]
1239
+ if return_attn:
1240
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
1241
+ res.append(attn_maps)
1242
+ if use_cache:
1243
+ res.append(intermediates.past_key_values)
1244
+
1245
+ if len(res) > 1:
1246
+ return tuple(res)
1247
+ return res[0]
1248
+
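A minimal sketch (editorial note, not part of this commit) of how the wrappers above can be combined for continuous inputs such as mel features; every dimension and variable name below is illustrative and not a value taken from this repo's configs:

import torch

encoder = ContinuousTransformerWrapper(
    max_seq_len=1024,
    attn_layers=Encoder(dim=512, depth=4, heads=8),  # non-causal stack of self-attention + feed-forward blocks
    dim_in=80,                                        # e.g. 80-band mel spectrogram frames
    dim_out=192,                                      # e.g. an embedding-sized projection
)
mel = torch.randn(2, 300, 80)
mask = torch.ones(2, 300, dtype=torch.bool)
hidden = encoder(mel, mask=mask)                      # (2, 300, 192)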
speech/tools/download_dataset.py ADDED
@@ -0,0 +1,371 @@
1
+ import os
2
+ import io
3
+ import time
4
+ from pathlib import Path
5
+ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
6
+ from multiprocessing import cpu_count, Queue, Process
7
+ from threading import Lock
8
+ from queue import Empty
9
+ import logging
10
+ from tqdm import tqdm
11
+ from functools import partial
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torchaudio
16
+ import soundfile as sf
17
+ from datasets import load_dataset
18
+
19
+ # Set up logging
20
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # Global settings
24
+ CHUNK_SIZE = 100 # Process this many samples before showing progress
25
+ MAX_WORKERS = cpu_count() - 1 # Leave one CPU free
26
+ PREFETCH_SIZE = 50 # How many samples to prefetch
27
+
28
+ def process_single_sample(data, root_save_path):
29
+ """Process a single sample from the dataset"""
30
+ try:
31
+ metadata = data['json']
32
+
33
+ # Prepare paths
34
+ out_wav_path = os.path.join(root_save_path, metadata['wav'].replace('/mp3', '').replace('.mp3', '.wav'))
35
+ os.makedirs(os.path.dirname(out_wav_path), exist_ok=True)
36
+ out_text_path = out_wav_path.replace('.wav', '.txt')
37
+ # skip if the files already exist
38
+ if os.path.exists(out_wav_path) and os.path.exists(out_text_path):
39
+ return metadata['id'], True, None
40
+
41
+ # Save text
42
+ with open(out_text_path, 'w') as f:
43
+ f.write(metadata['text'])
44
+
45
+ # Decode and save audio
46
+ raw_bytes = data['mp3']._hf_encoded['bytes']
47
+ audio_tensor, sample_rate = torchaudio.load(io.BytesIO(raw_bytes), format='mp3')
48
+ audio_array = audio_tensor.squeeze().numpy()
49
+ sf.write(out_wav_path, audio_array, sample_rate)
50
+
51
+ return metadata['id'], True, None
52
+ except Exception as e:
53
+ return data.get('json', {}).get('id', 'unknown'), False, str(e)
54
+
55
+ def process_batch(batch_data, root_save_path):
56
+ """Process a batch of samples"""
57
+ results = []
58
+ for data in batch_data:
59
+ result = process_single_sample(data, root_save_path)
60
+ results.append(result)
61
+ return results
62
+
63
+ class ParallelDatasetProcessor:
64
+ """Main class for parallel processing of the dataset"""
65
+
66
+ def __init__(self, language, root_save_path, num_workers=None):
67
+ self.language = language
68
+ self.root_save_path = root_save_path
69
+ self.num_workers = num_workers or MAX_WORKERS
70
+ self.processed_count = 0
71
+ self.error_count = 0
72
+ self.lock = Lock()
73
+
74
+ def process_with_multiprocessing(self):
75
+ """Process dataset using multiprocessing (fastest for CPU-bound tasks)"""
76
+ logger.info(f"Starting multiprocessing with {self.num_workers} workers")
77
+
78
+ # Load dataset
79
+ path = f"Emilia/{self.language.upper()}/*.tar"
80
+ dataset = load_dataset(
81
+ "amphion/Emilia-Dataset",
82
+ data_files={self.language: path},
83
+ split=self.language,
84
+ streaming=True
85
+ )
86
+
87
+ # Create process pool
88
+ with ProcessPoolExecutor(max_workers=self.num_workers) as executor:
89
+ # Submit jobs in batches
90
+ futures = []
91
+ batch = []
92
+
93
+ # Progress bar
94
+ pbar = tqdm(desc="Processing samples", unit="samples")
95
+
96
+ for data in dataset:
97
+ batch.append(data)
98
+
99
+ if len(batch) >= CHUNK_SIZE:
100
+ # Submit batch for processing
101
+ future = executor.submit(process_batch, batch, self.root_save_path)
102
+ futures.append(future)
103
+ batch = []
104
+
105
+ # Process completed futures
106
+ self._process_completed_futures(futures, pbar, max_pending=self.num_workers * 2)
107
+
108
+ # Submit remaining batch
109
+ if batch:
110
+ future = executor.submit(process_batch, batch, self.root_save_path)
111
+ futures.append(future)
112
+
113
+ # Wait for all remaining futures
114
+ for future in as_completed(futures):
115
+ results = future.result()
116
+ for sample_id, success, error in results:
117
+ if success:
118
+ self.processed_count += 1
119
+ else:
120
+ self.error_count += 1
121
+ logger.error(f"Error processing {sample_id}: {error}")
122
+ pbar.update(1)
123
+
124
+ pbar.close()
125
+
126
+ def process_with_threading(self):
127
+ """Process dataset using threading (good for I/O-bound tasks)"""
128
+ logger.info(f"Starting threading with {self.num_workers} workers")
129
+
130
+ # Load dataset
131
+ path = f"Emilia/{self.language.upper()}/*.tar"
132
+ dataset = load_dataset(
133
+ "amphion/Emilia-Dataset",
134
+ data_files={self.language: path},
135
+ split=self.language,
136
+ streaming=True
137
+ )
138
+
139
+ # Create thread pool
140
+ with ThreadPoolExecutor(max_workers=self.num_workers * 2) as executor:
141
+ futures = []
142
+ pbar = tqdm(desc="Processing samples", unit="samples")
143
+
144
+ for data in dataset:
145
+ # Submit individual samples
146
+ future = executor.submit(process_single_sample, data, self.root_save_path)
147
+ futures.append((future, data.get('json', {}).get('id', 'unknown')))
148
+
149
+ # Process completed futures
150
+ if len(futures) >= self.num_workers * 4:
151
+ completed = []
152
+ for i, (future, sample_id) in enumerate(futures):
153
+ if future.done():
154
+ try:
155
+ _, success, error = future.result()
156
+ if success:
157
+ self.processed_count += 1
158
+ else:
159
+ self.error_count += 1
160
+ logger.error(f"Error processing {sample_id}: {error}")
161
+ pbar.update(1)
162
+ completed.append(i)
163
+ except Exception as e:
164
+ logger.error(f"Exception processing {sample_id}: {e}")
165
+ self.error_count += 1
166
+ pbar.update(1)
167
+ completed.append(i)
168
+
169
+ # Remove completed futures
170
+ for i in reversed(completed):
171
+ futures.pop(i)
172
+
173
+ # Wait for remaining futures
174
+ for future, sample_id in futures:
175
+ try:
176
+ _, success, error = future.result()
177
+ if success:
178
+ self.processed_count += 1
179
+ else:
180
+ self.error_count += 1
181
+ logger.error(f"Error processing {sample_id}: {error}")
182
+ pbar.update(1)
183
+ except Exception as e:
184
+ logger.error(f"Exception processing {sample_id}: {e}")
185
+ self.error_count += 1
186
+ pbar.update(1)
187
+
188
+ pbar.close()
189
+
190
+ def process_with_producer_consumer(self):
191
+ """Process dataset using producer-consumer pattern"""
192
+ logger.info(f"Starting producer-consumer with {self.num_workers} workers")
193
+
194
+ # Create queues
195
+ work_queue = Queue(maxsize=PREFETCH_SIZE)
196
+ result_queue = Queue()
197
+
198
+ # Start producer
199
+ producer = Process(target=self._producer, args=(work_queue,))
200
+ producer.start()
201
+
202
+ # Start consumers
203
+ consumers = []
204
+ for i in range(self.num_workers):
205
+ consumer = Process(target=self._consumer, args=(work_queue, result_queue, i))
206
+ consumer.start()
207
+ consumers.append(consumer)
208
+
209
+ # Start result processor
210
+ result_processor = Process(target=self._result_processor, args=(result_queue,))
211
+ result_processor.start()
212
+
213
+ # Wait for completion
214
+ producer.join()
215
+
216
+ # Signal consumers to stop
217
+ for _ in range(self.num_workers):
218
+ work_queue.put(None)
219
+
220
+ for consumer in consumers:
221
+ consumer.join()
222
+
223
+ # Signal result processor to stop
224
+ result_queue.put(None)
225
+ result_processor.join()
226
+
227
+ def _producer(self, work_queue):
228
+ """Producer process that reads from dataset"""
229
+ path = f"Emilia/{self.language.upper()}/*.tar"
230
+ dataset = load_dataset(
231
+ "amphion/Emilia-Dataset",
232
+ data_files={self.language: path},
233
+ split=self.language,
234
+ streaming=True
235
+ )
236
+
237
+ for data in dataset:
238
+ work_queue.put(data)
239
+
240
+ def _consumer(self, work_queue, result_queue, worker_id):
241
+ """Consumer process that processes samples"""
242
+ while True:
243
+ try:
244
+ data = work_queue.get(timeout=1)
245
+ if data is None:
246
+ break
247
+
248
+ result = process_single_sample(data, self.root_save_path)
249
+ result_queue.put(result)
250
+ except Empty:
251
+ continue
252
+ except Exception as e:
253
+ logger.error(f"Worker {worker_id} error: {e}")
254
+
255
+ def _result_processor(self, result_queue):
256
+ """Process results and update progress"""
257
+ pbar = tqdm(desc="Processing samples", unit="samples")
258
+
259
+ while True:
260
+ try:
261
+ result = result_queue.get(timeout=1)
262
+ if result is None:
263
+ break
264
+
265
+ sample_id, success, error = result
266
+ if success:
267
+ self.processed_count += 1
268
+ else:
269
+ self.error_count += 1
270
+ logger.error(f"Error processing {sample_id}: {error}")
271
+ pbar.update(1)
272
+ except Empty:
273
+ continue
274
+
275
+ pbar.close()
276
+
277
+ def _process_completed_futures(self, futures, pbar, max_pending):
278
+ """Process completed futures to avoid memory buildup"""
279
+ while len(futures) > max_pending:
280
+ # Wait for at least one to complete
281
+ completed_futures = []
282
+ for i, future in enumerate(futures):
283
+ if future.done():
284
+ results = future.result()
285
+ for sample_id, success, error in results:
286
+ if success:
287
+ self.processed_count += 1
288
+ else:
289
+ self.error_count += 1
290
+ logger.error(f"Error processing {sample_id}: {error}")
291
+ pbar.update(1)
292
+ completed_futures.append(i)
293
+
294
+ # Remove completed futures
295
+ for i in reversed(completed_futures):
296
+ futures.pop(i)
297
+
298
+ if not completed_futures:
299
+ # If nothing completed, wait a bit
300
+ time.sleep(0.1)
301
+
302
+ def main():
303
+ """Main function with different processing options"""
304
+ language = "en"
305
+ root_save_path = f'/data/emilia/{language}'
306
+ os.makedirs(root_save_path, exist_ok=True)
307
+
308
+ # Choose processing method
309
+ processor = ParallelDatasetProcessor(language, root_save_path)
310
+
311
+ # Method 1: Multiprocessing (recommended for CPU-bound tasks)
312
+ logger.info("Starting parallel processing...")
313
+ start_time = time.time()
314
+
315
+ # You can choose one of these methods:
316
+ processor.process_with_multiprocessing() # Fastest for this use case
317
+ # processor.process_with_threading() # Good for I/O heavy tasks
318
+ # processor.process_with_producer_consumer() # Good for memory efficiency
319
+
320
+ elapsed_time = time.time() - start_time
321
+ logger.info(f"Processing completed in {elapsed_time:.2f} seconds")
322
+ logger.info(f"Successfully processed: {processor.processed_count} samples")
323
+ logger.info(f"Errors: {processor.error_count} samples")
324
+
325
+ # Simplified version for quick use
326
+ def simple_parallel_process():
327
+ """Simplified parallel processing function"""
328
+ from joblib import Parallel, delayed
329
+
330
+ language = "en"
331
+ root_save_path = f'/data/emilia/{language}'
332
+ os.makedirs(root_save_path, exist_ok=True)
333
+
334
+ # Load dataset
335
+ path = f"Emilia/{language.upper()}/*.tar"
336
+ dataset = load_dataset(
337
+ "amphion/Emilia-Dataset",
338
+ data_files={language: path},
339
+ split=language,
340
+ streaming=True
341
+ )
342
+
343
+ # Collect samples in batches
344
+ batch_size = 1000
345
+ batch = []
346
+
347
+ def process_batch_simple(batch_data):
348
+ for data in batch_data:
349
+ process_single_sample(data, root_save_path)
350
+
351
+ # Process in parallel batches
352
+ for i, data in enumerate(tqdm(dataset, desc="Loading samples")):
353
+ batch.append(data)
354
+
355
+ if len(batch) >= batch_size:
356
+ # Process batch in parallel
357
+ Parallel(n_jobs=-1, backend='multiprocessing')(
358
+ delayed(process_single_sample)(sample, root_save_path)
359
+ for sample in batch
360
+ )
361
+ batch = []
362
+
363
+ # Process remaining batch
364
+ if batch:
365
+ Parallel(n_jobs=-1, backend='multiprocessing')(
366
+ delayed(process_single_sample)(sample, root_save_path)
367
+ for sample in batch
368
+ )
369
+
370
+ if __name__ == "__main__":
371
+ main()
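A minimal sketch (editorial note, not part of this commit) of driving the downloader programmatically instead of through main(); the language code, output path, and worker count are placeholders:

processor = ParallelDatasetProcessor(language="ja", root_save_path="/data/emilia/ja", num_workers=8)
processor.process_with_threading()   # or process_with_multiprocessing() / process_with_producer_consumer()
# counters are tracked in the parent process for the threading and multiprocessing variants
print(processor.processed_count, processor.error_count)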
speech/tools/download_vivoice.py ADDED
@@ -0,0 +1,210 @@
1
+ import os
2
+ import io
3
+ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
4
+ from functools import partial
5
+ import multiprocessing
6
+ from typing import Dict, Any, Optional
7
+
8
+ from datasets import load_dataset
9
+ import soundfile as sf
10
+ import torchaudio
11
+ from tqdm import tqdm
12
+
13
+
14
+ def process_single_item(data: Dict[str, Any], root_save_path: str) -> Optional[str]:
15
+ """
16
+ Process a single audio item: extract audio, convert, and save with text.
17
+
18
+ Args:
19
+ data: Dictionary containing audio data and text
20
+ root_save_path: Root directory for saving files
21
+
22
+ Returns:
23
+ Audio file path if successful, None if error
24
+ """
25
+ try:
26
+ # Extract audio data
27
+ raw_bytes = data['audio']._hf_encoded['bytes']
28
+ audio_name = data['audio']._hf_encoded['path']
29
+ # Prepare output paths
30
+ out_wav_path = os.path.join(root_save_path, audio_name)
31
+ out_text_path = out_wav_path.replace('.wav', '.txt')
32
+
33
+ if os.path.exists(out_wav_path) and os.path.exists(out_text_path):
34
+ # print(f'skip {out_wav_path}')
35
+ return out_wav_path
36
+
37
+ text = data['text']
38
+
39
+ # Load audio from bytes
40
+ bytes_io = io.BytesIO(raw_bytes)
41
+ audio_tensor, sample_rate = torchaudio.load(bytes_io)
42
+ audio_array = audio_tensor.squeeze().numpy()
43
+
44
+
45
+
46
+ # Create directory if needed
47
+ os.makedirs(os.path.dirname(out_wav_path), exist_ok=True)
48
+
49
+ # Save audio and text
50
+ sf.write(out_wav_path, audio_array, sample_rate)
51
+ with open(out_text_path, 'w', encoding='utf-8') as f:
52
+ f.write(text)
53
+
54
+ return out_wav_path
55
+
56
+ except Exception as e:
57
+ print(f"Error processing {data.get('audio', {})._hf_encoded.get('path', 'unknown')}: {e}")
58
+ return None
59
+
60
+
61
+ def process_dataset_parallel(
62
+ dataset_name: str = "capleaf/viVoice",
63
+ root_save_path: str = '/mnt/nvme-temp/vivoice',
64
+ max_workers: Optional[int] = None,
65
+ batch_size: int = 100,
66
+ limit: Optional[int] = None,
67
+ use_threads: bool = False
68
+ ):
69
+ """
70
+ Process dataset in parallel with progress tracking.
71
+
72
+ Args:
73
+ dataset_name: Name of the HuggingFace dataset
74
+ root_save_path: Root directory for saving files
75
+ max_workers: Maximum number of parallel workers (None for CPU count)
76
+ batch_size: Number of items to process in each batch
77
+ limit: Maximum number of items to process (None for all)
78
+ use_threads: Use ThreadPoolExecutor instead of ProcessPoolExecutor
79
+ """
80
+ # Load dataset
81
+ ds = load_dataset(dataset_name, streaming=True)
82
+
83
+ # Set up executor
84
+ if max_workers is None:
85
+ max_workers = multiprocessing.cpu_count()
86
+
87
+ Executor = ThreadPoolExecutor if use_threads else ProcessPoolExecutor
88
+
89
+ # Process each split
90
+ for mode in ds:
91
+ print(f"\nProcessing '{mode}' split...")
92
+
93
+ # Create partial function with root_save_path
94
+ process_func = partial(process_single_item, root_save_path=root_save_path)
95
+
96
+ # Collect items in batches for better progress tracking
97
+ batch = []
98
+ processed_count = 0
99
+
100
+ with Executor(max_workers=max_workers) as executor:
101
+ # Create progress bar
102
+ pbar = tqdm(desc=f"Processing {mode}", unit="files")
103
+
104
+ for idx, data in enumerate(ds[mode]):
105
+ batch.append(data)
106
+
107
+ # Process batch when full or at limit
108
+ if len(batch) >= batch_size or (limit and idx + 1 >= limit):
109
+ # Submit batch for processing
110
+ futures = [executor.submit(process_func, item) for item in batch]
111
+
112
+ # Wait for completion and update progress
113
+ for future in futures:
114
+ result = future.result()
115
+ if result:
116
+ processed_count += 1
117
+ pbar.update(1)
118
+
119
+ batch = []
120
+
121
+ # Stop if limit reached
122
+ if limit and idx + 1 >= limit:
123
+ break
124
+
125
+ # Process remaining items
126
+ if batch:
127
+ futures = [executor.submit(process_func, item) for item in batch]
128
+ for future in futures:
129
+ result = future.result()
130
+ if result:
131
+ processed_count += 1
132
+ pbar.update(1)
133
+
134
+ pbar.close()
135
+
136
+ print(f"Completed {mode}: {processed_count} files processed successfully")
137
+
138
+
139
+ def process_dataset_streaming(
140
+ dataset_name: str = "capleaf/viVoice",
141
+ root_save_path: str = '/data/vivoice',
142
+ max_workers: Optional[int] = None,
143
+ limit: Optional[int] = None
144
+ ):
145
+ """
146
+ Alternative approach using streaming with thread pool for I/O bound operations.
147
+ More memory efficient for very large datasets.
148
+ """
149
+ ds = load_dataset(dataset_name, streaming=True)
150
+
151
+ if max_workers is None:
152
+ max_workers = min(32, multiprocessing.cpu_count() * 4) # More threads for I/O
153
+
154
+ for mode in ds:
155
+ print(f"\nProcessing '{mode}' split...")
156
+
157
+ process_func = partial(process_single_item, root_save_path=root_save_path)
158
+
159
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
160
+ # Submit items as they come from the stream
161
+ futures = []
162
+
163
+ for idx, data in enumerate(tqdm(ds[mode], desc=f"Submitting {mode}")):
164
+ future = executor.submit(process_func, data)
165
+ futures.append(future)
166
+
167
+ # Limit number of pending futures to control memory usage
168
+ if len(futures) >= max_workers * 2:
169
+ # Wait for some to complete
170
+ for f in futures[:max_workers]:
171
+ f.result()
172
+ futures = futures[max_workers:]
173
+
174
+ if limit and idx + 1 >= limit:
175
+ break
176
+
177
+ # Wait for remaining futures
178
+ for future in tqdm(futures, desc="Finishing"):
179
+ future.result()
180
+
181
+
182
+ if __name__ == "__main__":
183
+ # Example usage - choose one approach:
184
+
185
+ # Approach 1: Process with multiprocessing (good for CPU-bound operations)
186
+ process_dataset_parallel(
187
+ dataset_name="capleaf/viVoice",
188
+ root_save_path='/data/vivoice',
189
+ max_workers=24, # Adjust based on your system
190
+ batch_size=100,
191
+ limit=None, # None processes all data; set an integer to cap the number of samples
192
+ use_threads=False
193
+ )
194
+
195
+ # Approach 2: Process with threading (good for I/O-bound operations)
196
+ # process_dataset_parallel(
197
+ # dataset_name="capleaf/viVoice",
198
+ # root_save_path='/mnt/nvme-temp/vivoice',
199
+ # max_workers=16, # Can use more threads for I/O
200
+ # batch_size=100,
201
+ # limit=None,
202
+ # use_threads=True
203
+ # )
204
+
205
+ # Approach 3: Streaming approach (most memory efficient)
206
+ # process_dataset_streaming(
207
+ # dataset_name="capleaf/viVoice",
208
+ # root_save_path='/mnt/nvme-temp/vivoice',
209
+ # max_workers=16,
210
+ # limit=None