import copy
import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention:
      1) Linear projections for Q, K, V
      2) Scaled dot-product attention per head
      3) Concatenate heads and apply a final linear projection

    Reference: "Attention Is All You Need", https://arxiv.org/pdf/1706.03762
    """

    def __init__(self, embed_dim: int, key_dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.key_dim = key_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = math.sqrt(self.head_dim)  # sqrt(d_k), used to scale the attention scores

        # Three separate linear projections for query, key, and value. The key and
        # value inputs may have a different width (key_dim) than the queries,
        # which lets this module be reused for cross-attention.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(key_dim, embed_dim)
        self.v_proj = nn.Linear(key_dim, embed_dim)

        # Output projection after concatenating heads (embed_dim -> embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Dropout on the attention weights and on the projected output
        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """
        Args:
            query: [batch, seq_q, embed_dim]
            key:   [batch, seq_k, key_dim]
            value: [batch, seq_k, key_dim]
        Returns:
            out: [batch, seq_q, embed_dim]
        """
        B, seq_q, _ = query.size()
        _, seq_k, _ = key.size()

        # 1) Project inputs and split into heads
        q = self.q_proj(query).view(B, seq_q, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, seq_q, head_dim]
        k = self.k_proj(key).view(B, seq_k, self.num_heads, self.head_dim).transpose(1, 2)    # [B, heads, seq_k, head_dim]
        v = self.v_proj(value).view(B, seq_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, seq_k, head_dim]

        # 2) Scaled dot-product attention: swap the last two dims of k so that
        #    q @ k^T gives pairwise scores, then scale by sqrt(d_k)
        scores = (q @ k.transpose(-1, -2)) / self.scale  # [B, heads, seq_q, seq_k]
        weights = torch.softmax(scores, dim=-1)
        weights = self.attn_dropout(weights)
        attn = weights @ v  # [B, heads, seq_q, head_dim]

        # 3) Concatenate heads
        attn = attn.transpose(1, 2).contiguous().view(B, seq_q, self.embed_dim)  # [B, seq_q, embed_dim]

        # 4) Final projection
        out = self.out_proj(attn)
        out = self.proj_dropout(out)
        return out
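

# Minimal shape sanity check for MultiHeadAttention (a sketch: the dims and the
# helper name below are illustrative assumptions, not part of the original).
# Passing key/value inputs with key_dim != embed_dim exercises the cross-attention path.
def _demo_multi_head_attention() -> None:
    mha = MultiHeadAttention(embed_dim=64, key_dim=32, num_heads=8)
    q = torch.randn(2, 10, 64)   # [batch, seq_q, embed_dim]
    kv = torch.randn(2, 15, 32)  # [batch, seq_k, key_dim]
    out = mha(q, kv, kv)
    assert out.shape == (2, 10, 64)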


class TransformerEncoderLayer(nn.Module):
    """
    Transformer Encoder Layer (pre-norm):
      1) Multi-head self-attention
      2) Feed-forward network
      3) Residual connections + LayerNorm
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        mlp_dim: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        # 1) Self-attention
        self.self_attn = MultiHeadAttention(
            embed_dim=embed_dim,
            key_dim=embed_dim,  # self-attention uses the same width for Q, K, and V
            num_heads=num_heads,
            dropout=dropout,
        )
        # 2) Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
        )
        # 3) LayerNorms and dropouts for the residual branches
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.attn_dropout = nn.Dropout(dropout)
        self.ff_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape [batch, seq_len, embed_dim]
        Returns:
            Tensor of the same shape
        """
        # 1) Self-attention block (pre-norm residual)
        x_norm = self.norm1(x)
        attn_out = self.self_attn(x_norm, x_norm, x_norm)
        x = x + self.attn_dropout(attn_out)  # residual connection

        # 2) Feed-forward block (pre-norm residual)
        x_norm_ff = self.norm2(x)
        ff = self.ffn(x_norm_ff)
        x = x + self.ff_dropout(ff)  # residual connection
        return x
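

# Minimal usage sketch for TransformerEncoderLayer (the dims and the helper name
# are illustrative assumptions). The pre-norm residual blocks preserve the input
# shape, so layers can be stacked freely.
def _demo_encoder_layer() -> None:
    layer = TransformerEncoderLayer(embed_dim=64, num_heads=8, mlp_dim=256, dropout=0.1)
    x = torch.randn(2, 10, 64)  # [batch, seq_len, embed_dim]
    assert layer(x).shape == x.shape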


class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer: TransformerEncoderLayer, num_layers: int, embed_dim: int):
        super().__init__()
        # Clone the prototype encoder_layer num_layers times so that each
        # layer gets its own independent parameters
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pass the input through each encoder layer in sequence
        for layer in self.layers:
            x = layer(x)
        # Final normalization (needed because the layers are pre-norm)
        x = self.norm(x)
        return x
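

# Minimal usage sketch for TransformerEncoder (dims and helper name are
# illustrative assumptions). Note that the prototype layer is deep-copied,
# so the stacked layers do not share weights.
def _demo_encoder() -> None:
    proto = TransformerEncoderLayer(embed_dim=64, num_heads=8, mlp_dim=256)
    encoder = TransformerEncoder(proto, num_layers=6, embed_dim=64)
    x = torch.randn(2, 10, 64)  # [batch, seq_len, embed_dim]
    assert encoder(x).shape == x.shape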


class TransformerDecoderLayer(nn.Module):
    """
    Transformer Decoder Layer (pre-norm):
      1) Self-attention
      2) Cross-attention (over encoder features)
      3) Feed-forward network
      4) Residual connections + LayerNorm
    """

    def __init__(
        self,
        encoder_embed_dim: int,
        decoder_embed_dim: int,
        num_heads: int,
        mlp_dim: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        # 1) Self-attention on the decoder input
        self.self_attn = MultiHeadAttention(decoder_embed_dim, decoder_embed_dim, num_heads, dropout)
        # 2) Cross-attention over encoder features (keys/values come from the encoder)
        self.cross_attn = MultiHeadAttention(decoder_embed_dim, encoder_embed_dim, num_heads, dropout)
        # 3) Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(decoder_embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, decoder_embed_dim),
        )
        # 4) LayerNorms and dropouts for the residual branches
        self.norm1 = nn.LayerNorm(decoder_embed_dim)
        self.norm2 = nn.LayerNorm(decoder_embed_dim)
        self.norm3 = nn.LayerNorm(decoder_embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, encoder_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch, tgt_len, decoder_embed_dim]
            encoder_features: [batch, src_len, encoder_embed_dim]
        Returns:
            Tensor of shape [batch, tgt_len, decoder_embed_dim]
        """
        # 1) Self-attention block (pre-norm residual)
        x_norm = self.norm1(x)
        sa = self.self_attn(x_norm, x_norm, x_norm)
        x = x + self.dropout1(sa)

        # 2) Cross-attention block: queries come from the decoder, keys and
        #    values from the encoder features
        x_norm_ca = self.norm2(x)
        ca = self.cross_attn(x_norm_ca, encoder_features, encoder_features)
        x = x + self.dropout2(ca)

        # 3) Feed-forward block (pre-norm residual)
        x_norm_ff = self.norm3(x)
        ff = self.ffn(x_norm_ff)
        x = x + self.dropout3(ff)
        return x
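

# Minimal usage sketch for TransformerDecoderLayer (dims and helper name are
# illustrative assumptions). Cross-attention lets decoder tokens of width
# decoder_embed_dim attend over encoder features of a different width.
def _demo_decoder_layer() -> None:
    layer = TransformerDecoderLayer(encoder_embed_dim=32, decoder_embed_dim=64,
                                    num_heads=8, mlp_dim=256)
    x = torch.randn(2, 10, 64)       # [batch, tgt_len, decoder_embed_dim]
    memory = torch.randn(2, 15, 32)  # [batch, src_len, encoder_embed_dim]
    assert layer(x, memory).shape == x.shape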


class TransformerDecoder(nn.Module):
    def __init__(
        self,
        decoder_layer: TransformerDecoderLayer,
        num_layers: int,
        embed_dim: int,
    ):
        super().__init__()
        # Clone the prototype decoder_layer num_layers times so that each
        # layer gets its own independent parameters
        self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor, encoder_features: torch.Tensor) -> torch.Tensor:
        # Pass the input through each decoder layer, attending over the same
        # encoder features at every layer
        for layer in self.layers:
            x = layer(x, encoder_features)
        # Final normalization (needed because the layers are pre-norm)
        x = self.norm(x)
        return x
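

# End-to-end sketch wiring the encoder and decoder together (all dims and the
# helper name are illustrative assumptions). The decoder consumes the encoder's
# output as its cross-attention memory.
def _demo_encoder_decoder() -> None:
    enc_layer = TransformerEncoderLayer(embed_dim=32, num_heads=4, mlp_dim=128)
    encoder = TransformerEncoder(enc_layer, num_layers=2, embed_dim=32)
    dec_layer = TransformerDecoderLayer(encoder_embed_dim=32, decoder_embed_dim=64,
                                        num_heads=8, mlp_dim=256)
    decoder = TransformerDecoder(dec_layer, num_layers=2, embed_dim=64)

    src = torch.randn(2, 15, 32)  # encoder input
    tgt = torch.randn(2, 10, 64)  # decoder input
    memory = encoder(src)
    assert decoder(tgt, memory).shape == tgt.shape


if __name__ == "__main__":
    _demo_multi_head_attention()
    _demo_encoder_layer()
    _demo_encoder()
    _demo_decoder_layer()
    _demo_encoder_decoder()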