import math
import random
import fire
import numpy as np
import torch
import torch.utils.checkpoint
from flash_attn import flash_attn_varlen_func
from torch import nn
from torch.nn import functional as F
from transformers.activations import ACT2FN
def sinusoids(length, channels, max_timescale=10000):
"""Returns sinusoids for positional embedding"""
assert channels % 2 == 0
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
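# Shape sketch (illustrative, matching Whisper-style positional embeddings):
#   sinusoids(1500, 512) -> (1500, 512); the first 256 channels hold the sin
#   terms and the last 256 the cos terms of the same scaled angles.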
class OceanWhisperAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
    def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
        # hidden_states arrives packed: (total_tokens, embed_dim), where
        # total_tokens == seq_len.sum() over the batch
        total_tokens, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states).view(total_tokens, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(total_tokens, self.num_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(total_tokens, self.num_heads, self.head_dim)
        # cu_len marks the cumulative sequence boundaries: [0, len_0, len_0 + len_1, ...]
        cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
        max_seqlen = torch.max(seq_len).to(torch.int32).detach()
        attn_output = flash_attn_varlen_func(
            query_states, key_states, value_states,
            cu_len, cu_len, max_seqlen, max_seqlen, causal=False,
        )  # (total_tokens, num_heads, head_dim)
        attn_output = attn_output.reshape(total_tokens, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output
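# Worked example of the varlen layout (illustrative values, not model defaults):
# for two packed sequences with seq_len = [3, 5], hidden_states is (8, embed_dim)
# and cu_len = [0, 3, 8], so flash_attn_varlen_func attends within rows [0:3) and
# [3:8) independently; nothing leaks across the packed samples.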
class OceanWhisperEncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = OceanWhisperAttention(self.embed_dim, config.encoder_attention_heads)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.activation_fn = ACT2FN[config.activation_function]
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(hidden_states, seq_len)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.fc2(hidden_states)
hidden_states = residual + hidden_states
        # fp16/bf16 can overflow to +/-inf in the residual sums; clamp pulls values
        # back just inside the representable range (torch.finfo(torch.float16).max
        # is 65504). Note that torch.clamp bounds infs but does not remove NaNs.
        if (hidden_states.dtype == torch.float16 or hidden_states.dtype == torch.bfloat16) and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
return hidden_states
class OceanAudioEncoder(nn.Module):
def __init__(self, config):
super().__init__()
        config._attn_implementation = 'flash_attention_2'  # this encoder's attention is written against flash_attn_varlen_func
self.config = config
self.max_source_positions = (config.max_audio_seconds * config.sampling_rate // config.hop_length) // config.stride_size
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
        # NOTE: these modules must be registered during the LLM's initialization
        self.conv1 = nn.Conv1d(config.num_mel_bins, config.d_model, kernel_size=config.kernel_size, padding=1)
        self.conv2 = nn.Conv1d(config.d_model, config.d_model, kernel_size=config.kernel_size, stride=config.stride_size, padding=1)
        self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, config.d_model))  # (max_source_positions, d_model), e.g. 1500 x d
self.layers = nn.ModuleList([OceanWhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
self.layer_norm = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = True
@torch.no_grad()
def fake_input(self, device):
input_features = torch.rand([2, self.config.num_mel_bins, 10], dtype=torch.float32, device=device)
encoder_length = torch.ones([2], dtype=torch.int32, device=device) * 3
bridge_length = torch.ones([2], dtype=torch.int32, device=device)
return input_features, encoder_length, bridge_length
def forward(
self,
input_features,
        output_length,  # NOTE: must be the sequence length *after* both conv layers, not the raw frame count
):
input_features = input_features.to(self.conv1.weight.dtype)
inputs_embeds = nn.functional.gelu(self.conv1(input_features)) # (bs, channels, frames)
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) # (bs, channels, frames // 2)
        inputs_embeds = inputs_embeds.permute(0, 2, 1)  # (bs, frames, channels)
        bsz, tgt_len, _ = inputs_embeds.size()  # tgt_len = longest sequence in the current batch
if tgt_len < self.positional_embedding.shape[0]:
current_positional_embedding = self.positional_embedding[:tgt_len]
else:
current_positional_embedding = self.positional_embedding
hidden_states = (inputs_embeds.to(torch.float32) + current_positional_embedding).to(inputs_embeds.dtype)
        # pack hidden states: keep only the valid (non-padding) positions
        attention_mask = torch.arange(0, tgt_len).to(hidden_states.device)
        attention_mask = torch.lt(attention_mask, output_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
        unpacking_index = torch.cumsum(attention_mask.to(torch.int32).view(-1), dim=0) - 1  # mask -> gather indices for unpacking
        hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length), self.config.d_model)
for idx, encoder_layer in enumerate(self.layers):
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states,
output_length
)
else:
hidden_states = encoder_layer(hidden_states, output_length)
hidden_states = self.layer_norm(hidden_states)
        # unpack: scatter the packed rows back to (bsz, tgt_len, d_model) and zero out the padding
        hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, tgt_len, self.config.d_model)
        hidden_states = torch.where(attention_mask, hidden_states, 0)
return hidden_states
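# Worked example of the pack/unpack round trip (illustrative values):
# with bsz=2, tgt_len=3, output_length=[2, 3], the flattened mask is
# [1, 1, 0, 1, 1, 1] and unpacking_index = cumsum - 1 = [0, 1, 1, 2, 3, 4].
# index_select re-expands the 5 packed rows to (2, 3, d_model); the padded
# slot briefly duplicates a valid row and is then zeroed by torch.where.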
class OceanAudioBridge(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config.audio_config
if self.config.avg_pooler > 1:
            self.avg_pooler = nn.AvgPool1d(self.config.avg_pooler, stride=2)  # stride hard-coded to 2: halves the frame count
else:
self.avg_pooler = None
self.proj1 = nn.Linear(self.config.d_model, config.hidden_size)
self.proj2 = nn.Linear(config.hidden_size, config.hidden_size)
def forward(self, x, output_length):
if self.avg_pooler is not None:
x = x.permute(0, 2, 1)
x = self.avg_pooler(x)
x = x.permute(0, 2, 1)
batch_size, sl, _ = x.shape
output_length = output_length.to(x.device)
valid_mask = torch.arange(0, sl).to(x.device)
valid_mask = torch.lt(valid_mask, output_length.reshape(batch_size, 1)).reshape(batch_size, sl, 1)
x = torch.masked_select(x, valid_mask).reshape(-1, self.config.d_model) # (sum(valid_sequence_length), d)
x = ACT2FN[self.config.activation_function](self.proj1(x))
x = self.proj2(x)
return x
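# Length bookkeeping for the pooler (assuming avg_pooler == 2 as in the tests):
# nn.AvgPool1d(2, stride=2) maps sl -> floor((sl - 2) / 2) + 1 = floor(sl / 2),
# so the output_length passed into forward() must already be the post-pooling
# ("bridge") length, as computed by the processor's inference_output_length.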
def test_audio():
from transformers import AutoConfig
from processor_ocean import OceanAudioProcessor
# from ..configuration_ocean import OceanConfig
config = AutoConfig.from_pretrained("./", trust_remote_code=True)
config.audio_config.d_model = 24
config.audio_config.encoder_layers = 2
config.audio_config.encoder_attention_heads = 4
config.audio_config.encoder_ffn_dim = 48
ae = OceanAudioEncoder(config.audio_config).cuda().to(torch.bfloat16)
bg = OceanAudioBridge(config).cuda().to(torch.bfloat16)
    max_len = random.randint(10, 30)
    bs = 3
    input_length = torch.tensor([random.randint(1, max_len) for _ in range(bs)])
    encoder_length, bridge_length = OceanAudioProcessor.inference_output_length(config.audio_config, input_length)
    print("max_len={}, input_valid_length={},\nencoder_valid_length={}, bridge_valid_length={}".format(max_len, input_length, encoder_length, bridge_length))
    wave_features = torch.rand((bs, config.audio_config.num_mel_bins, max_len))
a = ae(wave_features.to('cuda'), encoder_length.to('cuda'))
b = bg(a, bridge_length.to('cuda'))
print('encoder output={}, bridge output={}'.format(a.shape, b.shape))
print(a)
print(b)
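def test_bridge_cpu():
    # Minimal CPU smoke test for OceanAudioBridge in isolation: a hedged sketch,
    # not part of the original test flow. The config objects are hypothetical
    # SimpleNamespace stand-ins (not the real OceanConfig), with arbitrary small dims.
    from types import SimpleNamespace
    audio_config = SimpleNamespace(avg_pooler=2, d_model=24, activation_function="gelu")
    config = SimpleNamespace(audio_config=audio_config, hidden_size=32)
    bg = OceanAudioBridge(config)
    x = torch.rand(3, 10, 24)                # (bs, sl, d_model) fake encoder output
    bridge_length = torch.tensor([2, 3, 5])  # valid lengths after the 2x pooling
    b = bg(x, bridge_length)
    print('bridge output={}'.format(b.shape))  # (sum(bridge_length), hidden_size) = (10, 32)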
if __name__ == '__main__':
fire.Fire()