marcoyang commited on
Commit
9324520
·
1 Parent(s): 43a8860
Files changed (2) hide show
  1. spear_model.py +2 -3
  2. spear_modules.py +52 -0
spear_model.py CHANGED
@@ -237,11 +237,10 @@ class SpearEncoder(nn.Module):
237
  self.distillation_delta = distillation_delta
238
 
239
  if num_codebooks > 0:
240
- from multi_quantization.prediction import JointCodebookLoss
241
  self.codebook_loss_net = JointCodebookLoss(
242
- predictor_channels=encoder_dim,
243
  num_codebooks=num_codebooks * self.teacher_frame_ratio,
244
- is_joint=False,
245
  reduction="none",
246
  )
247
  else:
 
237
  self.distillation_delta = distillation_delta
238
 
239
  if num_codebooks > 0:
240
+ from .spear_modules import JointCodebookLoss
241
  self.codebook_loss_net = JointCodebookLoss(
242
+ input_dim=encoder_dim,
243
  num_codebooks=num_codebooks * self.teacher_frame_ratio,
 
244
  reduction="none",
245
  )
246
  else:
spear_modules.py CHANGED
@@ -32,6 +32,58 @@ def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
32
  diff = torch.abs(x - y)
33
  return max_value + torch.log1p(torch.exp(-diff))
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  # RuntimeError: Exporting the operator logaddexp to ONNX opset version
37
  # 14 is not supported. Please feel free to request support or submit
 
32
  diff = torch.abs(x - y)
33
  return max_value + torch.log1p(torch.exp(-diff))
34
 
35
+ class JointCodebookLoss(torch.nn.Module):
36
+ def __init__(
37
+ self,
38
+ input_dim: int = 512,
39
+ num_codebooks: int = 16,
40
+ codebook_size: int = 256,
41
+ ignore_index: int = -100,
42
+ reduction: str = "none"
43
+ ):
44
+ super().__init__()
45
+ self.input_dim = input_dim
46
+ self.num_codebooks = num_codebooks
47
+ self.codebook_size = codebook_size
48
+ self.reduction = reduction
49
+ self.ignore_index = ignore_index
50
+
51
+ self.proj = nn.Linear(input_dim, num_codebooks * codebook_size)
52
+
53
+ def forward_logprobs(self, input: torch.Tensor):
54
+ B,T,_ = input.shape
55
+ logits = self.proj(input)
56
+ logits = logits.view(B, T, self.num_codebooks, self.codebook_size) # (B,T,N,256)
57
+ log_probs = F.log_softmax(logits, dim=-1) # (B,T,N,256)
58
+ return log_probs
59
+
60
+
61
+ def forward(self, input, target, return_log_probs: bool = False):
62
+ # input: (B,T,C)
63
+ # target: (B,T,num_codebooks)
64
+
65
+ B,T,_ = input.shape
66
+ logits = self.proj(input)
67
+ logits = logits.view(B, T, self.num_codebooks, self.codebook_size) # (B,T,N,256)
68
+
69
+ loss = F.cross_entropy(
70
+ logits.reshape(-1, self.codebook_size),
71
+ target.reshape(-1),
72
+ ignore_index=self.ignore_index,
73
+ reduction=self.reduction
74
+ )
75
+ log_probs = None
76
+ if return_log_probs:
77
+ log_probs = F.log_softmax(logits, dim=-1)
78
+
79
+ if self.reduction == "none":
80
+ loss = loss.view(B, T, self.num_codebooks)
81
+
82
+ if return_log_probs:
83
+ return loss, log_probs
84
+
85
+ return loss
86
+
87
 
88
  # RuntimeError: Exporting the operator logaddexp to ONNX opset version
89
  # 14 is not supported. Please feel free to request support or submit