import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Feed-forward network with ReLU hidden layers, dropout, and an optional output activation."""

    def __init__(self, in_dim, hidden_dims, dropout=0.1, final_activation=None):
        super(MLP, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.has_final_activation = False
        # First layer maps in_dim -> hidden_dims[0]; the rest follow consecutive hidden_dims.
        layers = [nn.Linear(in_dim, hidden_dims[0])]
        for d1, d2 in zip(hidden_dims[:-1], hidden_dims[1:]):
            layers.append(nn.Linear(d1, d2))
        self.layers = nn.ModuleList(layers)
        if final_activation is not None:
            self.has_final_activation = True
            self.final_activation = {
                'relu': F.relu,
                'sigmoid': torch.sigmoid,
                'softmax': lambda x: F.softmax(x, dim=-1),
            }[final_activation]

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:
                # Hidden layers: ReLU activation followed by dropout.
                x = F.relu(x)
                x = self.dropout(x)
            elif self.has_final_activation:
                # Last layer: apply the optional output activation.
                x = self.final_activation(x)
        return x
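# A minimal usage sketch (the dimensions below are illustrative assumptions, not part of
# the original module): a 16-dim input mapped through hidden sizes [32, 8] with a softmax
# head, as one might use for a small classifier.
#
#     mlp = MLP(in_dim=16, hidden_dims=[32, 8], dropout=0.1, final_activation='softmax')
#     probs = mlp(torch.randn(4, 16))  # -> shape (4, 8), rows sum to 1
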
class CrossAttention(nn.Module):
    def __init__(self, embed_dim_q, embed_dim_kv, num_heads, dim_out, dropout=0.0):
        """
        Args:
            embed_dim_q (int): Dimension of query embeddings.
            embed_dim_kv (int): Dimension of key/value embeddings.
            num_heads (int): Number of attention heads.
            dim_out (int): Dimension of the output projection.
            dropout (float): Dropout probability for attention weights.
        """
        super(CrossAttention, self).__init__()
        # Ensure the embedding dimensions are divisible by the number of heads.
        assert embed_dim_q % num_heads == 0, "embed_dim_q must be divisible by num_heads"
        assert embed_dim_kv % num_heads == 0, "embed_dim_kv must be divisible by num_heads"
        self.query_proj = nn.Linear(embed_dim_q, embed_dim_q)
        self.key_proj = nn.Linear(embed_dim_kv, embed_dim_q)    # project keys to the query dimension
        self.value_proj = nn.Linear(embed_dim_kv, embed_dim_q)  # project values to the query dimension
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim_q, num_heads=num_heads, dropout=dropout, batch_first=True
        )
        self.out_proj = nn.Linear(embed_dim_q, dim_out)
    def forward(self, queries, keys, values, mask=None):
        """
        Args:
            queries (Tensor): Shape (batch_size, len_q, embed_dim_q)
            keys (Tensor): Shape (batch_size, len_k, embed_dim_kv)
            values (Tensor): Shape (batch_size, len_k, embed_dim_kv)
            mask (Tensor, optional): Key padding mask of shape (batch_size, len_k),
                1/True for valid positions and 0/False for padded positions.
        Returns:
            Tensor: Shape (batch_size, len_q, dim_out)
        """
        # Project inputs to the shared attention dimension.
        queries = self.query_proj(queries)  # (batch_size, len_q, embed_dim_q)
        keys = self.key_proj(keys)          # (batch_size, len_k, embed_dim_q)
        values = self.value_proj(values)    # (batch_size, len_k, embed_dim_q)
        # nn.MultiheadAttention expects key_padding_mask with True at positions to ignore,
        # so invert the "1 for valid" convention used above.
        key_padding_mask = None if mask is None else ~mask.bool()
        # Compute attention.
        attn_output, _ = self.attention(queries, keys, values, key_padding_mask=key_padding_mask)
        # Apply output projection.
        output = self.out_proj(attn_output)
        return output
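
# A minimal usage sketch (the dimensions below are illustrative assumptions, not part of
# the original module): 64-dim queries attend over 32-dim key/value sequences, projected
# to a 128-dim output.
#
#     cross_attn = CrossAttention(embed_dim_q=64, embed_dim_kv=32, num_heads=4, dim_out=128)
#     q = torch.randn(2, 10, 64)   # (batch, len_q, embed_dim_q)
#     kv = torch.randn(2, 20, 32)  # (batch, len_k, embed_dim_kv)
#     mask = torch.ones(2, 20, dtype=torch.bool)  # all key positions valid
#     out = cross_attn(q, kv, kv, mask=mask)      # -> shape (2, 10, 128)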