import torch
from torch import nn
from transformers import WhisperConfig
from transformers.activations import ACT2FN
from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES

from .layers import CustomLinear, CustomDiagonalLinear, CustomLinearInitialized


class LowRankApproxSelectFirst(nn.Module):
    """Low-rank factorization of the projection that selects the first d_out input features.

    Note: _init_weights() must be called explicitly to install the low-rank selection.
    """

    def __init__(self, d_in, d_out, rank):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.rank = rank
        self.proj_in = nn.Linear(d_in, rank)
        self.proj_out = nn.Linear(rank, d_out)

    def forward(self, x):
        return self.proj_out(self.proj_in(x))

    def _init_weights(self):
        # Build a rank-`rank` approximation of the projection that copies
        # the first d_out of the d_in input features
        eye = torch.eye(self.d_out, self.d_in)  # (d_out, d_in)
        # Truncated SVD of the rectangular identity
        U, S, Vh = torch.linalg.svd(eye, full_matrices=False)  # U: (d_out, k), Vh: (k, d_in), k = min(d_out, d_in)
        U_k = U[:, :self.rank]  # (d_out, rank)
        S_k = S[:self.rank]  # (rank,)
        V_k = Vh[:self.rank, :]  # (rank, d_in)
        A = V_k  # (rank, d_in)
        B = U_k @ torch.diag(S_k)  # (d_out, rank)
        # Install the factors; the singular values are folded into the output projection
        self.proj_in.weight.data.copy_(A)
        self.proj_in.bias.data.zero_()
        self.proj_out.weight.data.copy_(B)
        self.proj_out.bias.data.zero_()
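

# Hedged usage sketch (not part of the original training code): eye(d_out, d_in)
# has rank min(d_in, d_out), so with rank >= min(d_in, d_out) the SVD factors
# reproduce it exactly and the module just selects the first d_out input
# features after _init_weights(). The sizes below are illustrative assumptions.
def _demo_low_rank_select_first():
    proj = LowRankApproxSelectFirst(d_in=8, d_out=4, rank=4)
    proj._init_weights()
    x = torch.randn(2, 8)
    # With full rank, the output equals the first d_out input features
    assert torch.allclose(proj(x), x[:, :4], atol=1e-5)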


def first_init_fun(module):
    # Small random init so the layer can still learn beyond the identity path
    torch.nn.init.xavier_uniform_(module.weight, gain=0.1)
    # Add an identity mapping for the second half of the input
    # ([cross_attn_output, q] -> copy q into the first embed_dim hidden units)
    embed_dim = module.weight.shape[1] // 2
    module.weight.data[:embed_dim, embed_dim:] += torch.eye(embed_dim)
    module.bias.data.zero_()


def second_init_fun(module):
    torch.nn.init.xavier_uniform_(module.weight, gain=0.1)
    # Add an identity mapping from the first embed_dim hidden units to the output
    embed_dim = module.weight.shape[0]
    module.weight.data[:, :embed_dim] += torch.eye(embed_dim)
    module.bias.data.zero_()
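

# Hedged sanity check (assumes a plain nn.Linear stands in for
# CustomLinearInitialized, whose definition lives in .layers): composing the
# two init functions without the intermediate activation maps the concatenated
# input [attn_out, q] back to roughly q, up to the small gain=0.1 Xavier
# noise. The real FFN below has an activation and dropout in between; this
# check only isolates the linear path. All sizes are illustrative assumptions.
def _demo_ffn_near_identity(embed_dim=4, ffn_dim=8):
    lin1 = nn.Linear(embed_dim * 2, ffn_dim)
    lin2 = nn.Linear(ffn_dim, embed_dim)
    first_init_fun(lin1)
    second_init_fun(lin2)
    q = torch.randn(2, embed_dim)
    attn_out = torch.randn(2, embed_dim)
    h = lin2(lin1(torch.cat([attn_out, q], dim=-1)))
    print((h - q).abs().max())  # small, but not exactly zero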


class CrossAttentionEnrollBlockNew(nn.Module):
    """Cross-attention block that can easily learn to ignore cross attention initially."""

    def __init__(self, config, layer_norm_eps: float = 1e-5):
        super().__init__()
        self.embed_dim = config.d_model
        self.ffn_dim = config.encoder_ffn_dim
        self.cross_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )
        # No layer norm here: normalization is deliberately omitted so the
        # near-identity initialization of the FFN below is preserved.
        # Learnable scalar gate, zero-initialized so the block starts as an identity
        self.cross_gate = nn.Parameter(torch.zeros(1))
        # Feed-forward network that maps the concatenated space back to a single channel
        dropout_p = getattr(config, "dropout", 0.1)
        self.ffn = nn.Sequential(
            CustomLinearInitialized(self.embed_dim * 2, self.ffn_dim, init_fun=first_init_fun),
            ACT2FN[config.activation_function],
            nn.Dropout(dropout_p),
            CustomLinearInitialized(self.ffn_dim, self.embed_dim, init_fun=second_init_fun),
            nn.Dropout(dropout_p),
        )
        self.enabled = True
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: (B, 2, T, F) - batch, speaker channels (query, key/value), time, features
        Returns:
            Updated hidden states of the same shape; only the query channel is modified.
        """
        if not self.enabled:
            return hidden_states
        q_channel = hidden_states[:, 0]  # (B, T, F)
        kv_channel = hidden_states[:, 1]  # (B, T, F)
        # Cross-attention: the query channel attends to the key/value channel
        attn_output = self.cross_attn(
            hidden_states=q_channel,
            key_value_states=kv_channel,
            output_attentions=False,
        )[0]
        # Concatenate the attention output with the original query
        q_concat = torch.cat([attn_output, q_channel], dim=-1)  # (B, T, 2*F)
        # Gated residual update; tanh(cross_gate) == 0 at initialization, so the
        # block starts as an identity (no normalization, to preserve that)
        updated_q = q_channel + torch.tanh(self.cross_gate) * self.ffn(q_concat)
        # Re-stack; the key/value channel passes through unchanged
        return torch.stack([updated_q, kv_channel], dim=1)
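

# Hedged usage sketch: cross_gate starts at zero and tanh(0) == 0, so a freshly
# constructed block is an exact identity — it "ignores" cross attention until
# the gate is learned. WhisperConfig() defaults are used purely for illustration.
def _demo_cross_attention_enroll():
    config = WhisperConfig()
    block = CrossAttentionEnrollBlockNew(config).eval()
    x = torch.randn(2, 2, 10, config.d_model)  # (B, channels, T, F)
    with torch.no_grad():
        assert torch.allclose(block(x), x)  # gate closed at init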


class SpeakerCommunicationBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_speakers = getattr(config, "mt_num_speakers", 2)
        self.embed_dim = config.d_model
        self.scb_method = config.scb_method
        self.config = config
        if self.scb_method == "cross_attention_enroll_new":
            self.method = CrossAttentionEnrollBlockNew(config)
        elif self.scb_method == "identity":
            # Identity-initialized per-speaker transform: a plain bias, a diagonal
            # linear layer, or a full linear layer, depending on the config
            if config.fddt_bias_only:
                self.method = nn.Parameter(torch.zeros(self.embed_dim))
            elif config.fddt_is_diagonal:
                self.method = CustomDiagonalLinear(self.embed_dim, bias=True, init_eye_val=1.0)
            else:
                self.method = CustomLinear(self.embed_dim, self.embed_dim, bias=True, init_eye_val=1.0)
        else:
            raise ValueError(f"Unsupported scb_method: {self.scb_method}")

    def forward(self, x):
        # x: (B, T, F), with the S speaker streams of each conversation
        # laid out consecutively along the batch dimension
        B, T, F = x.shape
        S = self.num_speakers
        # Group the speaker streams: (B // S, S, T, F)
        x_reshaped = x.view(B // S, S, T, F)
        if isinstance(self.method, nn.Parameter):
            # Bias-only "identity" variant: a Parameter is not callable,
            # so apply it as an additive (zero-initialized) bias
            out = x_reshaped + self.method
        else:
            out = self.method(x_reshaped)
        # Merge the speaker axis back into the batch: (B, T, F)
        return out.view(B, T, F)
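

# Hedged usage sketch: the view() in forward() assumes consecutive batch rows
# belong to the same conversation, speaker by speaker, e.g.
# [conv0_spk0, conv0_spk1, conv1_spk0, conv1_spk1] — an assumption about the
# upstream dataloader, not stated in this file. The scb_method and sizes
# below are illustrative.
def _demo_speaker_communication():
    config = WhisperConfig()
    config.scb_method = "cross_attention_enroll_new"
    scb = SpeakerCommunicationBlock(config)
    x = torch.randn(4, 10, config.d_model)  # B=4 -> 2 conversations x 2 speakers
    assert scb(x).shape == x.shape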