Lakoc committed on
Commit
57fe226
·
verified ·
1 Parent(s): 7e5f173

Update SCBs.py

Browse files
Files changed (1) hide show
  1. SCBs.py +24 -19
SCBs.py CHANGED
@@ -88,6 +88,8 @@ class CrossAttentionEnrollBlockNew(nn.Module):
88
  nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.1)
89
  )
90
 
 
 
91
 
92
 
93
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -97,25 +99,28 @@ class CrossAttentionEnrollBlockNew(nn.Module):
97
  Returns:
98
  Updated hidden states of same shape
99
  """
100
- q_channel = hidden_states[:, 0] # (B, T, F)
101
- kv_channel = hidden_states[:, 1] # (B, T, F)
102
-
103
- # Cross-attention
104
- attn_output = self.cross_attn(
105
- hidden_states=q_channel,
106
- key_value_states=kv_channel,
107
- output_attentions=False
108
- )[0]
109
-
110
- # Concatenate attention output with original normalized query
111
- q_concat = torch.cat([attn_output, q_channel], dim=-1) # (B, T, 2*F)
112
-
113
- # Feed-forward processing (no normalization to preserve initialization)
114
- # updated_q = self.ffn(q_concat) # (B, T, F)
115
- updated_q = q_channel + torch.tanh(self.cross_gate) * self.ffn(q_concat)
116
-
117
- # Return stacked result (only query channel is updated)
118
- return torch.stack([updated_q, kv_channel], dim=1)
 
 
 
119
 
120
  class SpeakerCommunicationBlock(nn.Module):
121
  def __init__(self, config):
 
88
  nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.1)
89
  )
90
 
91
+ self.enabled = True
92
+
93
 
94
 
95
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
99
  Returns:
100
  Updated hidden states of same shape
101
  """
102
+ if self.enabled:
103
+ q_channel = hidden_states[:, 0] # (B, T, F)
104
+ kv_channel = hidden_states[:, 1] # (B, T, F)
105
+
106
+ # Cross-attention
107
+ attn_output = self.cross_attn(
108
+ hidden_states=q_channel,
109
+ key_value_states=kv_channel,
110
+ output_attentions=False
111
+ )[0]
112
+
113
+ # Concatenate attention output with original normalized query
114
+ q_concat = torch.cat([attn_output, q_channel], dim=-1) # (B, T, 2*F)
115
+
116
+ # Feed-forward processing (no normalization to preserve initialization)
117
+ # updated_q = self.ffn(q_concat) # (B, T, F)
118
+ updated_q = q_channel + torch.tanh(self.cross_gate) * self.ffn(q_concat)
119
+
120
+ # Return stacked result (only query channel is updated)
121
+ return torch.stack([updated_q, kv_channel], dim=1)
122
+ else:
123
+ return hidden_states
124
 
125
  class SpeakerCommunicationBlock(nn.Module):
126
  def __init__(self, config):