Update SCBs.py
Browse files
SCBs.py
CHANGED
|
@@ -88,6 +88,8 @@ class CrossAttentionEnrollBlockNew(nn.Module):
|
|
| 88 |
nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.1)
|
| 89 |
)
|
| 90 |
|
|
|
|
|
|
|
| 91 |
|
| 92 |
|
| 93 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
@@ -97,25 +99,28 @@ class CrossAttentionEnrollBlockNew(nn.Module):
|
|
| 97 |
Returns:
|
| 98 |
Updated hidden states of same shape
|
| 99 |
"""
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
class SpeakerCommunicationBlock(nn.Module):
|
| 121 |
def __init__(self, config):
|
|
|
|
| 88 |
nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.1)
|
| 89 |
)
|
| 90 |
|
| 91 |
+
self.enabled = True
|
| 92 |
+
|
| 93 |
|
| 94 |
|
| 95 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
|
|
| 99 |
Returns:
|
| 100 |
Updated hidden states of same shape (the input is returned unchanged when self.enabled is False)
|
| 101 |
"""
|
| 102 |
+
if self.enabled:
|
| 103 |
+
q_channel = hidden_states[:, 0] # (B, T, F)
|
| 104 |
+
kv_channel = hidden_states[:, 1] # (B, T, F)
|
| 105 |
+
|
| 106 |
+
# Cross-attention
|
| 107 |
+
attn_output = self.cross_attn(
|
| 108 |
+
hidden_states=q_channel,
|
| 109 |
+
key_value_states=kv_channel,
|
| 110 |
+
output_attentions=False
|
| 111 |
+
)[0]
|
| 112 |
+
|
| 113 |
+
# Concatenate attention output with original normalized query
|
| 114 |
+
q_concat = torch.cat([attn_output, q_channel], dim=-1) # (B, T, 2*F)
|
| 115 |
+
|
| 116 |
+
# Gated residual feed-forward update: the FFN output is scaled by tanh(self.cross_gate) and added to the query channel (no normalization to preserve initialization)
|
| 117 |
+
# updated_q = self.ffn(q_concat) # (B, T, F)
|
| 118 |
+
updated_q = q_channel + torch.tanh(self.cross_gate) * self.ffn(q_concat)
|
| 119 |
+
|
| 120 |
+
# Return stacked result (only query channel is updated)
|
| 121 |
+
return torch.stack([updated_q, kv_channel], dim=1)
|
| 122 |
+
else:
|
| 123 |
+
return hidden_states
|
| 124 |
|
| 125 |
class SpeakerCommunicationBlock(nn.Module):
|
| 126 |
def __init__(self, config):
|