import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple


class AttentionLinkedMemory(nn.Module):
    """
    Implements an Attention-Linked Memory layer.

    This layer retrieves context from a structured memory based on a query.
    It uses a two-level attention mechanism:

    1. Within-bucket attention: summarizes each memory bucket based on its
       relevance to the query.
    2. Between-bucket attention: weights the summarized buckets to find the
       buckets most relevant to the query, producing a final aggregated
       context.

    Args:
        query_dim (int): Dimension of the input query vector.
        memory_dim (int): Dimension of the memory item vectors.
        embed_dim (int): Internal embedding dimension for attention. Queries,
            keys, and values are projected to this dimension.
        num_heads (int): Number of attention heads for both levels. Must
            divide embed_dim evenly.
        output_dim (int, optional): Dimension of the final output context
            vector. Defaults to embed_dim if None.
        dropout_rate (float, optional): Dropout probability. Defaults to 0.1.
    """

    def __init__(self,
                 query_dim: int,
                 memory_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 output_dim: Optional[int] = None,
                 dropout_rate: float = 0.1):
        super().__init__()

        if embed_dim % num_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")

        self.query_dim = query_dim
        self.memory_dim = memory_dim
        self.embed_dim = embed_dim
        self.output_dim = output_dim if output_dim is not None else embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = math.sqrt(self.head_dim)

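        # Level 1: within-bucket attention projections.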
        self.q_proj_l1 = nn.Linear(query_dim, embed_dim)
        self.k_proj_l1 = nn.Linear(memory_dim, embed_dim)
        self.v_proj_l1 = nn.Linear(memory_dim, embed_dim)
        self.dropout_l1 = nn.Dropout(dropout_rate)
        self.norm_l1_out = nn.LayerNorm(embed_dim)

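        # Level 2: between-bucket attention projections.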
        self.q_proj_l2 = nn.Linear(query_dim, embed_dim)
        self.k_proj_l2 = nn.Linear(embed_dim, embed_dim)
        self.v_proj_l2 = nn.Linear(embed_dim, embed_dim)
        self.dropout_l2 = nn.Dropout(dropout_rate)

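        # Final output projection and normalization.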
        self.out_proj = nn.Linear(embed_dim, self.output_dim)
        self.norm_final = nn.LayerNorm(self.output_dim)

    def forward(self,
                query: torch.Tensor,
                memory_buckets: torch.Tensor,
                memory_mask: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass through the Attention-Linked Memory layer.

        Args:
            query (torch.Tensor): Input query tensor.
                Shape: (batch_size, query_dim)
            memory_buckets (torch.Tensor): Memory items organized into
                buckets, padded to a consistent size.
                Shape: (batch_size, num_buckets, max_items_per_bucket, memory_dim)
            memory_mask (torch.Tensor, optional): Boolean mask indicating
                valid memory items (True) vs. padding (False).
                Shape: (batch_size, num_buckets, max_items_per_bucket)

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - aggregated_context (torch.Tensor): The final context vector.
                    Shape: (batch_size, output_dim)
                - bucket_attention_weights_l2 (torch.Tensor): Attention
                    weights for buckets (L2, averaged over heads).
                    Shape: (batch_size, num_buckets)
                - item_attention_weights_l1 (torch.Tensor): Attention weights
                    for items within buckets (L1, averaged over heads).
                    Shape: (batch_size, num_buckets, max_items_per_bucket)
        """
        batch_size, num_buckets, max_items, _ = memory_buckets.shape

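        # --- Level 1: within-bucket attention ---
        # q1: (batch, heads, 1, 1, head_dim);
        # k1, v1: (batch, heads, num_buckets, max_items, head_dim).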
        q1 = self.q_proj_l1(query).view(batch_size, self.num_heads, 1, 1, self.head_dim)
        k1 = self.k_proj_l1(memory_buckets)
        v1 = self.v_proj_l1(memory_buckets)

        k1 = k1.view(batch_size, num_buckets, max_items, self.num_heads, self.head_dim).permute(0, 3, 1, 2, 4)
        v1 = v1.view(batch_size, num_buckets, max_items, self.num_heads, self.head_dim).permute(0, 3, 1, 2, 4)

        # Scores over items in each bucket: (batch, heads, num_buckets, max_items)
        # after squeezing the singleton query dimension.
        attn_scores_l1 = torch.einsum('bhqid,bhnkd->bhnqk', q1, k1) / self.scale
        attn_scores_l1 = attn_scores_l1.squeeze(3)

        if memory_mask is not None:
            # Broadcast the item mask over heads: (batch, 1, num_buckets, max_items).
            expanded_mask_l1 = memory_mask.unsqueeze(1)
            attn_scores_l1 = attn_scores_l1.masked_fill(expanded_mask_l1 == 0, -1e9)

        attn_weights_l1_raw = F.softmax(attn_scores_l1, dim=-1)

        # Head-averaged item weights, returned for inspection.
        item_attention_weights_l1_for_output = attn_weights_l1_raw.mean(dim=1)

        attn_weights_l1_dropout = self.dropout_l1(attn_weights_l1_raw)

        # Summarize each bucket: (batch, num_buckets, embed_dim).
        bucket_summaries = torch.einsum('bhnk,bhnkd->bhnd', attn_weights_l1_dropout, v1)
        bucket_summaries = bucket_summaries.permute(0, 2, 1, 3).contiguous().view(batch_size, num_buckets, self.embed_dim)
        bucket_summaries = self.norm_l1_out(bucket_summaries)

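        # --- Level 2: between-bucket attention ---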
        q2 = self.q_proj_l2(query).view(batch_size, self.num_heads, 1, self.head_dim)
        k2 = self.k_proj_l2(bucket_summaries)
        v2 = self.v_proj_l2(bucket_summaries)

        k2 = k2.view(batch_size, num_buckets, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v2 = v2.view(batch_size, num_buckets, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Scores over buckets: (batch, heads, num_buckets) after squeezing
        # the singleton query dimension.
        attn_scores_l2 = torch.einsum('bhqd,bhnd->bhqn', q2, k2) / self.scale
        attn_scores_l2 = attn_scores_l2.squeeze(2)

        if memory_mask is not None:
            # Mirror the level-1 masking at the bucket level: a bucket that
            # contains no valid items should receive (near-)zero attention.
            bucket_valid = memory_mask.any(dim=-1).unsqueeze(1)  # (batch, 1, num_buckets)
            attn_scores_l2 = attn_scores_l2.masked_fill(bucket_valid == 0, -1e9)

        attn_weights_l2_raw = F.softmax(attn_scores_l2, dim=-1)

        # Head-averaged bucket weights, returned for inspection.
        bucket_attention_weights_l2_for_output = attn_weights_l2_raw.mean(dim=1)

        attn_weights_l2_dropout = self.dropout_l2(attn_weights_l2_raw)

        # Aggregate bucket summaries and concatenate heads: (batch, embed_dim).
        aggregated_context_heads = torch.einsum('bhn,bhnd->bhd', attn_weights_l2_dropout, v2)
        aggregated_context = aggregated_context_heads.contiguous().view(batch_size, self.embed_dim)

        final_output = self.out_proj(aggregated_context)
        final_output = self.norm_final(final_output)

        return final_output, bucket_attention_weights_l2_for_output, item_attention_weights_l1_for_output


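# A minimal sketch (not part of the original layer) of one way to build the
# padded `memory_buckets` tensor and boolean `memory_mask` that `forward`
# expects. `build_memory_buckets` is a hypothetical helper name; any padding
# scheme that produces the documented shapes works equally well.
def build_memory_buckets(buckets, memory_dim):
    """Pad variable-length buckets (a list of lists of 1-D tensors of size
    memory_dim) into a (1, num_buckets, max_items, memory_dim) tensor plus a
    matching boolean validity mask for a single batch element."""
    num_buckets = len(buckets)
    max_items = max(len(bucket) for bucket in buckets)
    memory = torch.zeros(1, num_buckets, max_items, memory_dim)
    mask = torch.zeros(1, num_buckets, max_items, dtype=torch.bool)
    for b, bucket in enumerate(buckets):
        for i, item in enumerate(bucket):
            memory[0, b, i] = item  # copy the real item into its padded slot
            mask[0, b, i] = True    # padding positions stay False
    return memory, mask

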
if __name__ == '__main__':
    _batch_size = 2
    _query_dim = 32
    _memory_dim = 24
    _embed_dim = 64
    _num_heads = 4
    _output_dim = 32
    _num_buckets = 3
    _max_items_per_bucket = 5

    alm_layer_test = AttentionLinkedMemory(
        query_dim=_query_dim, memory_dim=_memory_dim, embed_dim=_embed_dim,
        num_heads=_num_heads, output_dim=_output_dim
    )
    alm_layer_test.eval()  # disable dropout so the check is deterministic

    query_test = torch.randn(_batch_size, _query_dim)
    memory_buckets_test = torch.randn(_batch_size, _num_buckets, _max_items_per_bucket, _memory_dim)
    memory_mask_test = torch.ones(_batch_size, _num_buckets, _max_items_per_bucket, dtype=torch.bool)
    memory_mask_test[:, :, -1] = False  # mark the last item of every bucket as padding

    agg_ctx, buck_att, item_att = alm_layer_test(query_test, memory_buckets_test, memory_mask_test)

    print("--- ALM Standalone Test ---")
    print(f"Aggregated Context Shape: {agg_ctx.shape}")
    print(f"Bucket Attention (L2) Shape: {buck_att.shape}")
    print(f"Item Attention (L1) Shape: {item_att.shape}")
    print("Item Attention for first batch, first bucket:\n", item_att[0, 0, :])
    print("Sum of item attentions (should be ~1.0 for unmasked items):", item_att[0, 0, :-1].sum())
    print("Attention on masked item (should be ~0.0):", item_att[0, 0, -1])
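    # The mask argument is optional; omitting it treats every item in every
    # bucket as valid.
    agg_ctx_nomask, _, _ = alm_layer_test(query_test, memory_buckets_test)
    print(f"Aggregated Context Shape (no mask): {agg_ctx_nomask.shape}")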