marcoyang commited on
Commit
5d423ae
·
verified ·
1 Parent(s): a820d32

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. modeling_spear.py +3 -61
  2. spear_model.py +4 -5
  3. spear_modules.py +52 -0
  4. zipformer.py +1 -1
modeling_spear.py CHANGED
@@ -1,10 +1,8 @@
1
  # modeling_spear.py
2
 
3
- import torch
4
-
5
  from transformers import PreTrainedModel
6
- from configuration_spear import SpearConfig
7
- from spear_model import SpearModel as model
8
 
9
 
10
  class SpearModel(PreTrainedModel):
@@ -18,60 +16,4 @@ class SpearModel(PreTrainedModel):
18
  return self.model(*args, **kwargs)
19
 
20
  def load_audio(self, audio_path):
21
- return self.model.load_audio(audio_path)
22
-
23
- @classmethod
24
- def from_legacy_checkpoint(cls, path, config):
25
- model = cls(config)
26
- ckpt = torch.load(path)["model"]
27
- info = model.model.model.load_state_dict(ckpt, strict=False)
28
- print(info)
29
- return model
30
-
31
- def export_to_hf():
32
- ckpt = "/mnt/shared-storage-user/housiyuan/xiaoyu/models/spear_encoders/94m-uni-v2-dual-domain-mvq/iter-400000-avg-4.pt"
33
- config = SpearConfig()
34
- my_model = SpearModel.from_legacy_checkpoint(ckpt, config)
35
- my_model.save_pretrained("/mnt/shared-storage-user/housiyuan/xiaoyu/models/spear_encoders_hf/spear_base_speech_audio")
36
-
37
-
38
- def _test_from_pretrained():
39
- device = torch.device("cpu")
40
- if torch.cuda.is_available():
41
- device = torch.device("cuda")
42
-
43
- audio_file = [
44
- "common_voice_af_39597042.wav",
45
- # "1284-1180-0027.flac",
46
- ]
47
-
48
-
49
- config= SpearConfig.from_pretrained("/mnt/shared-storage-user/housiyuan/xiaoyu/models/spear_encoders_hf/spear_base_speech_audio")
50
- my_model= SpearModel.from_pretrained("/mnt/shared-storage-user/housiyuan/xiaoyu/models/spear_encoders_hf/spear_base_speech_audio", config=config)
51
- my_model.eval()
52
- my_model.to(device)
53
-
54
- num_params = sum([p.numel() for p in my_model.parameters()])
55
- print(f"A total of {num_params} model parameters.")
56
-
57
- audio, audio_len = my_model.load_audio(audio_file)
58
- audio = audio.to(device)
59
- audio_len = audio_len.to(device)
60
- with torch.no_grad():
61
- outputs = my_model(audio, audio_len)
62
-
63
- encoder_out = outputs["encoder_out"] # (N,T,C)
64
- encoder_out_lens = outputs["encoder_out_lens"] # (N)
65
- middle_out = outputs["hidden_states"] # list of (N,T,C)
66
-
67
- print(encoder_out)
68
- print(encoder_out_lens)
69
- print(middle_out[0].shape)
70
-
71
-
72
- if __name__=="__main__":
73
- export_to_hf()
74
- _test_from_pretrained()
75
-
76
-
77
-
 
1
  # modeling_spear.py
2
 
 
 
3
  from transformers import PreTrainedModel
4
+ from .configuration_spear import SpearConfig
5
+ from .spear_model import SpearModel as model
6
 
7
 
8
  class SpearModel(PreTrainedModel):
 
16
  return self.model(*args, **kwargs)
17
 
18
  def load_audio(self, audio_path):
19
+ return self.model.load_audio(audio_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spear_model.py CHANGED
@@ -26,8 +26,8 @@ import torch.nn as nn
26
  from torch.nn.utils.rnn import pad_sequence
27
  from torchaudio.compliance.kaldi import fbank as torch_fbank
28
 
29
- from configuration_spear import SpearConfig
30
- from zipformer import Zipformer2, Conv2dSubsampling
31
 
32
  LOG_EPS=math.log(1e-10)
33
  SAMPLING_RATE=16000
@@ -237,11 +237,10 @@ class SpearEncoder(nn.Module):
237
  self.distillation_delta = distillation_delta
238
 
239
  if num_codebooks > 0:
240
- from multi_quantization.prediction import JointCodebookLoss
241
  self.codebook_loss_net = JointCodebookLoss(
242
- predictor_channels=encoder_dim,
243
  num_codebooks=num_codebooks * self.teacher_frame_ratio,
244
- is_joint=False,
245
  reduction="none",
246
  )
247
  else:
 
26
  from torch.nn.utils.rnn import pad_sequence
27
  from torchaudio.compliance.kaldi import fbank as torch_fbank
28
 
29
+ from .configuration_spear import SpearConfig
30
+ from .zipformer import Zipformer2, Conv2dSubsampling
31
 
32
  LOG_EPS=math.log(1e-10)
33
  SAMPLING_RATE=16000
 
237
  self.distillation_delta = distillation_delta
238
 
239
  if num_codebooks > 0:
240
+ from .spear_modules import JointCodebookLoss
241
  self.codebook_loss_net = JointCodebookLoss(
242
+ input_dim=encoder_dim,
243
  num_codebooks=num_codebooks * self.teacher_frame_ratio,
 
244
  reduction="none",
245
  )
246
  else:
spear_modules.py CHANGED
@@ -32,6 +32,58 @@ def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
32
  diff = torch.abs(x - y)
33
  return max_value + torch.log1p(torch.exp(-diff))
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  # RuntimeError: Exporting the operator logaddexp to ONNX opset version
37
  # 14 is not supported. Please feel free to request support or submit
 
32
  diff = torch.abs(x - y)
33
  return max_value + torch.log1p(torch.exp(-diff))
34
 
35
+ class JointCodebookLoss(torch.nn.Module):
36
+ def __init__(
37
+ self,
38
+ input_dim: int = 512,
39
+ num_codebooks: int = 16,
40
+ codebook_size: int = 256,
41
+ ignore_index: int = -100,
42
+ reduction: str = "none"
43
+ ):
44
+ super().__init__()
45
+ self.input_dim = input_dim
46
+ self.num_codebooks = num_codebooks
47
+ self.codebook_size = codebook_size
48
+ self.reduction = reduction
49
+ self.ignore_index = ignore_index
50
+
51
+ self.proj = nn.Linear(input_dim, num_codebooks * codebook_size)
52
+
53
+ def forward_logprobs(self, input: torch.Tensor):
54
+ B,T,_ = input.shape
55
+ logits = self.proj(input)
56
+ logits = logits.view(B, T, self.num_codebooks, self.codebook_size) # (B,T,N,256)
57
+ log_probs = F.log_softmax(logits, dim=-1) # (B,T,N,256)
58
+ return log_probs
59
+
60
+
61
+ def forward(self, input, target, return_log_probs: bool = False):
62
+ # input: (B,T,C)
63
+ # target: (B,T,num_codebooks)
64
+
65
+ B,T,_ = input.shape
66
+ logits = self.proj(input)
67
+ logits = logits.view(B, T, self.num_codebooks, self.codebook_size) # (B,T,N,256)
68
+
69
+ loss = F.cross_entropy(
70
+ logits.reshape(-1, self.codebook_size),
71
+ target.reshape(-1),
72
+ ignore_index=self.ignore_index,
73
+ reduction=self.reduction
74
+ )
75
+ log_probs = None
76
+ if return_log_probs:
77
+ log_probs = F.log_softmax(logits, dim=-1)
78
+
79
+ if self.reduction == "none":
80
+ loss = loss.view(B, T, self.num_codebooks)
81
+
82
+ if return_log_probs:
83
+ return loss, log_probs
84
+
85
+ return loss
86
+
87
 
88
  # RuntimeError: Exporting the operator logaddexp to ONNX opset version
89
  # 14 is not supported. Please feel free to request support or submit
zipformer.py CHANGED
@@ -23,7 +23,7 @@ from typing import List, Optional, Tuple, Union
23
  import logging
24
  import torch
25
  import random
26
- from spear_modules import (
27
  Balancer,
28
  BiasNorm,
29
  Dropout2,
 
23
  import logging
24
  import torch
25
  import random
26
+ from .spear_modules import (
27
  Balancer,
28
  BiasNorm,
29
  Dropout2,