# --- START OF FILE ALM.py ---
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionLinkedMemory(nn.Module):
    """
    Implements an Attention-Linked Memory layer.

    This layer retrieves context from a structured memory based on a query.
    It uses a two-level attention mechanism:

    1. Within-Bucket Attention: summarizes each memory bucket based on its
       items' relevance to the query.
    2. Between-Bucket Attention: weights the summarized buckets to find the
       most relevant buckets for the query, producing a final aggregated
       context.

    Args:
        query_dim (int): Dimension of the input query vector.
        memory_dim (int): Dimension of the memory item vectors.
        embed_dim (int): Internal embedding dimension for attention. Queries,
            keys, and values are projected to this dimension.
        num_heads (int): Number of attention heads for both levels.
            Must divide embed_dim.
        output_dim (int, optional): Dimension of the final output context
            vector. If None, defaults to embed_dim.
        dropout_rate (float, optional): Dropout probability. Defaults to 0.1.
    """

    def __init__(self, query_dim: int, memory_dim: int, embed_dim: int,
                 num_heads: int, output_dim: int = None,
                 dropout_rate: float = 0.1):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )

        self.query_dim = query_dim
        self.memory_dim = memory_dim
        self.embed_dim = embed_dim
        self.output_dim = output_dim if output_dim is not None else embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = math.sqrt(self.head_dim)

        # --- Level 1 (Within-Bucket) Projections ---
        self.q_proj_l1 = nn.Linear(query_dim, embed_dim)
        self.k_proj_l1 = nn.Linear(memory_dim, embed_dim)
        self.v_proj_l1 = nn.Linear(memory_dim, embed_dim)
        self.dropout_l1 = nn.Dropout(dropout_rate)
        self.norm_l1_out = nn.LayerNorm(embed_dim)

        # --- Level 2 (Between-Bucket) Projections ---
        self.q_proj_l2 = nn.Linear(query_dim, embed_dim)
        self.k_proj_l2 = nn.Linear(embed_dim, embed_dim)
        self.v_proj_l2 = nn.Linear(embed_dim, embed_dim)
        self.dropout_l2 = nn.Dropout(dropout_rate)

        # --- Final Output Projection ---
        self.out_proj = nn.Linear(embed_dim, self.output_dim)
        self.norm_final = nn.LayerNorm(self.output_dim)

    def forward(self, query: torch.Tensor, memory_buckets: torch.Tensor,
                memory_mask: torch.Tensor = None):
        """
        Forward pass through the Attention-Linked Memory layer.

        Args:
            query (torch.Tensor): Input query tensor.
                Shape: (batch_size, query_dim)
            memory_buckets (torch.Tensor): Memory items organized into
                buckets, padded for consistent size.
                Shape: (batch_size, num_buckets, max_items_per_bucket, memory_dim)
            memory_mask (torch.Tensor, optional): Boolean mask indicating
                valid memory items (True) vs. padding (False).
                Shape: (batch_size, num_buckets, max_items_per_bucket)

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - aggregated_context (torch.Tensor): The final context vector.
                  Shape: (batch_size, output_dim)
                - bucket_attention_weights_l2 (torch.Tensor): Attention
                  weights for buckets (L2, averaged over heads).
                  Shape: (batch_size, num_buckets)
                - item_attention_weights_l1 (torch.Tensor): Attention weights
                  for items within buckets (L1, averaged over heads).
                  Shape: (batch_size, num_buckets, max_items_per_bucket)
        """
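        # Shape notation used in the comments below (matching the docstring):
        # B = batch_size, H = num_heads, N_b = num_buckets,
        # N_i = max_items_per_bucket, D_h = head_dim (embed_dim // num_heads).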
        batch_size, num_buckets, max_items, _ = memory_buckets.shape

        # ==================== Level 1: Within-Bucket Attention ====================
        # Project the single query and all memory items, then split into heads.
        q1 = self.q_proj_l1(query).view(batch_size, self.num_heads, 1, self.head_dim)
        k1 = self.k_proj_l1(memory_buckets)
        v1 = self.v_proj_l1(memory_buckets)

        # (B, N_b, N_i, embed_dim) -> (B, H, N_b, N_i, D_h)
        k1 = k1.view(batch_size, num_buckets, max_items, self.num_heads,
                     self.head_dim).permute(0, 3, 1, 2, 4)
        v1 = v1.view(batch_size, num_buckets, max_items, self.num_heads,
                     self.head_dim).permute(0, 3, 1, 2, 4)

        # Scaled dot-product scores of the query against every item in every bucket.
        attn_scores_l1 = torch.einsum('bhqd,bhnkd->bhnqk', q1, k1) / self.scale
        attn_scores_l1 = attn_scores_l1.squeeze(3)  # (B, H, N_b, N_i)

        if memory_mask is not None:
            expanded_mask_l1 = memory_mask.unsqueeze(1)  # (B, 1, N_b, N_i)
            attn_scores_l1 = attn_scores_l1.masked_fill(expanded_mask_l1 == 0, -1e9)

        attn_weights_l1_raw = F.softmax(attn_scores_l1, dim=-1)  # (B, H, N_b, N_i)
        # For the returned weights, average over heads.
        item_attention_weights_l1_for_output = attn_weights_l1_raw.mean(dim=1)  # (B, N_b, N_i)

        attn_weights_l1_dropout = self.dropout_l1(attn_weights_l1_raw)

        # Weighted sum of item values -> one summary vector per bucket per head.
        bucket_summaries = torch.einsum('bhnk,bhnkd->bhnd', attn_weights_l1_dropout, v1)
        bucket_summaries = bucket_summaries.permute(0, 2, 1, 3).contiguous().view(
            batch_size, num_buckets, self.embed_dim)
        bucket_summaries = self.norm_l1_out(bucket_summaries)

        # ==================== Level 2: Between-Bucket Attention ===================
        q2 = self.q_proj_l2(query).view(batch_size, self.num_heads, 1, self.head_dim)
        k2 = self.k_proj_l2(bucket_summaries)
        v2 = self.v_proj_l2(bucket_summaries)

        # (B, N_b, embed_dim) -> (B, H, N_b, D_h)
        k2 = k2.view(batch_size, num_buckets, self.num_heads,
                     self.head_dim).permute(0, 2, 1, 3)
        v2 = v2.view(batch_size, num_buckets, self.num_heads,
                     self.head_dim).permute(0, 2, 1, 3)

        attn_scores_l2 = torch.einsum('bhqd,bhnd->bhqn', q2, k2) / self.scale
        attn_scores_l2 = attn_scores_l2.squeeze(2)  # (B, H, N_b)

        # Note: if every item in a bucket is padding, the L1 softmax sees uniform
        # -1e9 scores and returns uniform weights, so that bucket's summary is an
        # average over padding values (not ~0, especially after LayerNorm). The
        # optional guard below masks such buckets out of the L2 attention explicitly.
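        # Optional guard (a minimal sketch of the "advanced" handling noted
        # above; not part of the original layer): drop fully padded buckets
        # from the L2 softmax. Buckets with at least one valid item are
        # unaffected.
        if memory_mask is not None:
            bucket_valid = memory_mask.any(dim=-1)  # (B, N_b): any real item?
            attn_scores_l2 = attn_scores_l2.masked_fill(
                ~bucket_valid.unsqueeze(1), -1e9)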
        attn_weights_l2_raw = F.softmax(attn_scores_l2, dim=-1)  # (B, H, N_b)
        bucket_attention_weights_l2_for_output = attn_weights_l2_raw.mean(dim=1)  # (B, N_b)

        attn_weights_l2_dropout = self.dropout_l2(attn_weights_l2_raw)

        # Weighted sum of bucket summaries, then merge heads back together.
        aggregated_context_heads = torch.einsum('bhn,bhnd->bhd', attn_weights_l2_dropout, v2)
        aggregated_context = aggregated_context_heads.contiguous().view(batch_size, self.embed_dim)

        # ==================== Final Output Projection ====================
        final_output = self.out_proj(aggregated_context)
        final_output = self.norm_final(final_output)

        return final_output, bucket_attention_weights_l2_for_output, item_attention_weights_l1_for_output


# ==================== Example Usage (ALM Standalone) ====================
if __name__ == '__main__':
    # Quick standalone test of the ALM layer.
    _batch_size = 2
    _query_dim = 32
    _memory_dim = 24
    _embed_dim = 64
    _num_heads = 4
    _output_dim = 32
    _num_buckets = 3
    _max_items_per_bucket = 5

    alm_layer_test = AttentionLinkedMemory(
        query_dim=_query_dim,
        memory_dim=_memory_dim,
        embed_dim=_embed_dim,
        num_heads=_num_heads,
        output_dim=_output_dim,
    )

    query_test = torch.randn(_batch_size, _query_dim)
    memory_buckets_test = torch.randn(_batch_size, _num_buckets,
                                      _max_items_per_bucket, _memory_dim)
    memory_mask_test = torch.ones(_batch_size, _num_buckets,
                                  _max_items_per_bucket, dtype=torch.bool)
    memory_mask_test[:, :, -1] = False  # Mask the last item in each bucket

    agg_ctx, buck_att, item_att = alm_layer_test(query_test, memory_buckets_test, memory_mask_test)

    print("--- ALM Standalone Test ---")
    print(f"Aggregated Context Shape: {agg_ctx.shape}")      # Expected: (B, output_dim)
    print(f"Bucket Attention (L2) Shape: {buck_att.shape}")  # Expected: (B, num_buckets)
    print(f"Item Attention (L1) Shape: {item_att.shape}")    # Expected: (B, num_buckets, max_items_per_bucket)
    print("Item Attention for first batch, first bucket:\n", item_att[0, 0, :])
    print("Sum of item attentions (should be ~1.0 for unmasked items):", item_att[0, 0, :-1].sum())
    print("Attention on masked item (should be ~0.0):", item_att[0, 0, -1])
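    # --- Hedged sketch: assembling padded buckets from ragged inputs ---
    # The helper below is illustrative only (its name and interface are not
    # part of the ALM API): it shows one way to build the padded
    # (B, N_b, N_i, memory_dim) `memory_buckets` tensor and the boolean
    # `memory_mask` that forward() expects from variable-length item lists.
    def pad_memory_buckets(bucket_lists, memory_dim):
        batch_size = len(bucket_lists)
        num_buckets = max(len(sample) for sample in bucket_lists)
        max_items = max(items.shape[0] for sample in bucket_lists for items in sample)
        buckets = torch.zeros(batch_size, num_buckets, max_items, memory_dim)
        mask = torch.zeros(batch_size, num_buckets, max_items, dtype=torch.bool)
        for b, sample in enumerate(bucket_lists):
            for n, items in enumerate(sample):
                buckets[b, n, :items.shape[0]] = items
                mask[b, n, :items.shape[0]] = True
        return buckets, mask

    ragged = [
        [torch.randn(3, _memory_dim), torch.randn(5, _memory_dim)],
        [torch.randn(2, _memory_dim), torch.randn(4, _memory_dim)],
    ]
    ragged_buckets, ragged_mask = pad_memory_buckets(ragged, _memory_dim)
    agg_ctx2, _, _ = alm_layer_test(query_test, ragged_buckets, ragged_mask)
    print("Ragged-input context shape:", agg_ctx2.shape)  # Expected: (2, output_dim)

# --- END OF FILE ALM.py ---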