# GPA-v1.5 / modeling_audio.py
from typing import Any, Optional, Tuple
import torch
from torch import nn
from torch.nn.functional import scaled_dot_product_attention
from transformers import WhisperConfig
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperEncoderLayer
from transformers.utils import logging
logger = logging.get_logger(__name__)
# ==========================================
# 1. Core Rotary Embedding components
# ==========================================
class RotaryEmbedding(nn.Module):
def __init__(self, dim, rope_ratio=1):
super().__init__()
self.dim = dim
self.rope_ratio = rope_ratio
@torch.no_grad()
def get_emb(self, seq_len: int, dtype: torch.dtype, device: torch.device, base: int = 10000):
"""Generate the cached RoPE table."""
base = base * self.rope_ratio
# Compute the theta frequencies.
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim))
# Build the position indices.
t = torch.arange(seq_len, device=device, dtype=torch.float)
freqs = torch.outer(t, inv_freq) # [seq_len, dim/2]
# Construct the cos/sin cache.
# Shape: [seq_len, dim/2, 2]
emb = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
if dtype in (torch.float16, torch.bfloat16):
emb = emb.to(dtype)
return emb
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
"""
x: [batch, num_heads, seq_len, head_dim]
rope_cache: [1, seq_len, dim/2, 2]
"""
b, nh, sq, hd = x.shape
rot_dim = rope_cache.shape[-2] * 2
# Split x into the rotated and pass-through portions.
x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
# Reshape x_rot to match rope_cache: [b, nh, sq, rot_dim/2, 2]
x_shaped = x_rot.reshape(b, nh, sq, rot_dim // 2, 2)
# Apply the complex rotation: (a+bi)(c+di) = (ac-bd) + (ad+bc)i
cos = rope_cache[..., 0] # [1, sq, rot_dim/2]
sin = rope_cache[..., 1] # [1, sq, rot_dim/2]
# Add the head dimension.
cos = cos.unsqueeze(1) # [1, 1, sq, rot_dim/2]
sin = sin.unsqueeze(1) # [1, 1, sq, rot_dim/2]
x_out = torch.stack([
x_shaped[..., 0] * cos - x_shaped[..., 1] * sin,
x_shaped[..., 1] * cos + x_shaped[..., 0] * sin
], dim=-1)
x_out = x_out.flatten(3) # Merge the final two dimensions into rot_dim.
return torch.cat([x_out, x_pass], dim=-1)
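# Illustrative usage sketch (not called anywhere in the model): shows how the cache
# produced by RotaryEmbedding.get_emb pairs with apply_rotary_pos_emb. All sizes below
# are made-up example values; only the shape conventions come from the docstrings above.
def _demo_rotary_embedding() -> torch.Tensor:
    head_dim = 64                                  # hypothetical per-head size
    rope = RotaryEmbedding(dim=head_dim // 2)      # rotate half of each head, as the encoder below does
    cache = rope.get_emb(seq_len=10, dtype=torch.float32, device=torch.device("cpu"))
    cache = cache.unsqueeze(0)                     # [1, seq_len, dim/2, 2] for broadcasting
    q = torch.randn(2, 8, 10, head_dim)            # [batch, num_heads, seq_len, head_dim]
    q_rot = apply_rotary_pos_emb(q, cache)
    assert q_rot.shape == q.shape                  # the rotation preserves the input shape
    return q_rot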
# ==========================================
# 2. RoPE attention built on SDPA
# ==========================================
class WhisperRoPESdpaAttention(nn.Module):
"""
Replace WhisperFlashAttention2 with PyTorch's native scaled_dot_product_attention.
"""
def __init__(self, config: WhisperConfig, embed_dim: int, num_heads: int, dropout: float = 0.0):
super().__init__()
self.config = config
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
# Standard Whisper projection layers.
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
rotary_pos_emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], None]:
bsz, q_len, _ = hidden_states.size()
# 1. Project to queries, keys, and values.
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# 2. Reshape to [batch, heads, seq, dim] and keep memory contiguous.
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
# 3. Apply RoPE.
if rotary_pos_emb is not None:
query_states = apply_rotary_pos_emb(query_states, rotary_pos_emb)
key_states = apply_rotary_pos_emb(key_states, rotary_pos_emb)
# 4. Align dtypes to avoid mismatches introduced by fp32 LayerNorm.
target_dtype = self.q_proj.weight.dtype
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
# 5. Run SDPA. Do not apply manual scaling; SDPA handles it internally.
# If a 4D attention_mask is provided, SDPA applies it correctly.
attn_output = scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=self.is_causal,
)
# 6. Restore shape and apply the output projection.
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None, None
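# Minimal sketch (illustrative only): running WhisperRoPESdpaAttention on random features.
# The WhisperConfig values are small example numbers, not the shipped GPA-v1.5 configuration.
def _demo_rope_sdpa_attention() -> torch.Tensor:
    cfg = WhisperConfig(d_model=64, encoder_attention_heads=8)
    attn = WhisperRoPESdpaAttention(cfg, embed_dim=cfg.d_model, num_heads=cfg.encoder_attention_heads)
    head_dim = cfg.d_model // cfg.encoder_attention_heads
    cache = RotaryEmbedding(head_dim // 2).get_emb(seq_len=10, dtype=torch.float32, device=torch.device("cpu"))
    hidden = torch.randn(2, 10, cfg.d_model)       # [batch, seq_len, d_model]
    out, _, _ = attn(hidden, rotary_pos_emb=cache.unsqueeze(0))
    return out                                     # same [2, 10, 64] shape as the input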
# ==========================================
# 3. Wrapped encoder layer and encoder
# ==========================================
class WhisperSpecialEncoderLayer(WhisperEncoderLayer):
def __init__(self, config: WhisperConfig):
super().__init__(config)
# Replace self-attention with the RoPE + SDPA implementation.
self.self_attn = WhisperRoPESdpaAttention(
config=config,
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Any]:
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
rotary_pos_emb=rotary_pos_emb,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
        return (hidden_states, None)  # SDPA does not materialize attention weights; the second slot mirrors Whisper's attn_weights output.
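# Minimal sketch (illustrative only): one pre-norm encoder layer driven with a RoPE cache.
# encoder_ffn_dim and the other sizes are placeholders, not the real GPA-v1.5 settings.
def _demo_special_encoder_layer() -> torch.Tensor:
    cfg = WhisperConfig(d_model=64, encoder_attention_heads=8, encoder_ffn_dim=128)
    layer = WhisperSpecialEncoderLayer(cfg)
    head_dim = cfg.d_model // cfg.encoder_attention_heads
    cache = RotaryEmbedding(head_dim // 2).get_emb(seq_len=10, dtype=torch.float32, device=torch.device("cpu"))
    hidden = torch.randn(2, 10, cfg.d_model)
    out, _ = layer(hidden, rotary_pos_emb=cache.unsqueeze(0))  # second element is always None (no attention weights)
    return out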
class WhisperSpecialEncoder(WhisperEncoder):
def __init__(self, config: WhisperConfig, use_rope=True, rope_ratio=1):
super().__init__(config)
self.use_rope = use_rope
# Override the parent layer stack.
self.layers = nn.ModuleList(
[WhisperSpecialEncoderLayer(config) for _ in range(config.encoder_layers)]
)
if use_rope:
            # RoPE rotates the first half of each head's dimensions (head_dim // 2); the rest pass through unchanged.
head_dim = config.d_model // config.encoder_attention_heads
self.rotary_embedding = RotaryEmbedding(head_dim // 2, rope_ratio)
def forward(
self,
input_features,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
position_ids=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Whisper convolutional feature extraction.
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
inputs_embeds = inputs_embeds.permute(0, 2, 1) # [B, T_down, D]
if self.use_rope:
# Build the rotary embedding cache.
rotary_embs = self.rotary_embedding.get_emb(
seq_len=inputs_embeds.shape[1],
dtype=inputs_embeds.dtype,
device=inputs_embeds.device
)
# Reshape to [1, seq_len, dim/2, 2] for broadcasting.
rotary_embs = rotary_embs.unsqueeze(0)
hidden_states = inputs_embeds
else:
rotary_embs = None
# Fall back to absolute positional embeddings.
embed_pos = self.embed_positions.weight[:inputs_embeds.shape[1]]
hidden_states = inputs_embeds + embed_pos
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
None, # attention_mask
(head_mask[idx] if head_mask is not None else None),
output_attentions,
rotary_embs,
position_ids,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask=None,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
rotary_pos_emb=rotary_embs,
position_ids=position_ids,
)
hidden_states = layer_outputs[0]
if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)  # index 1 is the attention-weights slot; always None with SDPA
hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
)
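# Minimal end-to-end sketch (illustrative only): running the RoPE encoder on a fake log-mel
# spectrogram. The tiny config below is a placeholder, not the shipped GPA-v1.5 configuration.
def _demo_special_encoder() -> torch.Tensor:
    cfg = WhisperConfig(
        d_model=64,
        encoder_layers=2,
        encoder_attention_heads=8,
        encoder_ffn_dim=128,
        num_mel_bins=80,
        max_source_positions=1500,
    )
    encoder = WhisperSpecialEncoder(cfg, use_rope=True)
    features = torch.randn(1, cfg.num_mel_bins, 200)   # [batch, mel_bins, frames]
    with torch.no_grad():
        out = encoder(features, return_dict=True)
    return out.last_hidden_state                       # [1, 100, 64]: the stride-2 conv halves the time axis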