# flare/models/encoders.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Feed-forward network: a stack of Linear layers with ReLU and dropout between them."""

    def __init__(self, in_dim, hidden_dims, dropout=0.1, final_activation=None):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.has_final_activation = final_activation is not None
        # First layer maps the input dimension; the rest chain the hidden dims.
        layers = [nn.Linear(in_dim, hidden_dims[0])]
        for d1, d2 in zip(hidden_dims[:-1], hidden_dims[1:]):
            layers.append(nn.Linear(d1, d2))
        self.layers = nn.ModuleList(layers)
        if self.has_final_activation:
            self.final_activation = {
                'relu': F.relu,
                'sigmoid': torch.sigmoid,  # F.sigmoid is deprecated
                'softmax': lambda x: F.softmax(x, dim=-1),  # explicit dim avoids a warning
            }[final_activation]

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:
                # Hidden layers: ReLU followed by dropout.
                x = F.relu(x)
                x = self.dropout(x)
            elif self.has_final_activation:
                # Last layer: optional output activation, no dropout.
                x = self.final_activation(x)
        return x
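
# Usage sketch (illustrative only, not from the original file; all sizes are
# hypothetical): a two-hidden-layer MLP with a softmax head.
#
#   mlp = MLP(in_dim=128, hidden_dims=[256, 64], final_activation='softmax')
#   probs = mlp(torch.randn(32, 128))  # -> (32, 64), rows sum to 1

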
class CrossAttention(nn.Module):
    def __init__(self, embed_dim_q, embed_dim_kv, num_heads, dim_out, dropout=0.0):
        """
        Args:
            embed_dim_q (int): Dimension of query embeddings.
            embed_dim_kv (int): Dimension of key/value embeddings.
            num_heads (int): Number of attention heads.
            dim_out (int): Dimension of the output projection.
            dropout (float): Dropout probability for attention weights.
        """
        super().__init__()
        # Only the query dimension must be divisible by the number of heads:
        # keys and values are projected to embed_dim_q before attention.
        assert embed_dim_q % num_heads == 0, "embed_dim_q must be divisible by num_heads"
        self.query_proj = nn.Linear(embed_dim_q, embed_dim_q)
        self.key_proj = nn.Linear(embed_dim_kv, embed_dim_q)  # Match dimensions with queries
        self.value_proj = nn.Linear(embed_dim_kv, embed_dim_q)
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim_q, num_heads=num_heads, dropout=dropout, batch_first=True
        )
        self.out_proj = nn.Linear(embed_dim_q, dim_out)

    def forward(self, queries, keys, values, mask=None):
        """
        Args:
            queries (Tensor): Shape (batch_size, len_q, embed_dim_q)
            keys (Tensor): Shape (batch_size, len_k, embed_dim_kv)
            values (Tensor): Shape (batch_size, len_k, embed_dim_kv)
            mask (Tensor, optional): Shape (batch_size, len_k); 1/True for valid
                key positions, 0/False for padded positions.
        Returns:
            Tensor: Shape (batch_size, len_q, dim_out)
        """
        # Project inputs to the query dimension expected by the attention module
        queries = self.query_proj(queries)  # (batch_size, len_q, embed_dim_q)
        keys = self.key_proj(keys)          # (batch_size, len_k, embed_dim_q)
        values = self.value_proj(values)    # (batch_size, len_k, embed_dim_q)
        # nn.MultiheadAttention expects a key_padding_mask where True marks
        # positions to ignore, so invert the 1-for-valid convention.
        key_padding_mask = ~mask.bool() if mask is not None else None
        attn_output, _ = self.attention(queries, keys, values, key_padding_mask=key_padding_mask)
        # Apply output projection
        return self.out_proj(attn_output)
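

if __name__ == "__main__":
    # Minimal smoke test (an added sketch, not from the original file); the
    # shapes below are hypothetical, with embed_dim_q divisible by num_heads.
    attn = CrossAttention(embed_dim_q=64, embed_dim_kv=32, num_heads=4, dim_out=16)
    q = torch.randn(2, 5, 64)   # (batch, len_q, embed_dim_q)
    kv = torch.randn(2, 7, 32)  # (batch, len_k, embed_dim_kv)
    mask = torch.ones(2, 7)     # all key positions valid
    print(attn(q, kv, kv, mask=mask).shape)  # torch.Size([2, 5, 16])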