# --- START OF FILE ALM.py ---
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class AttentionLinkedMemory(nn.Module):
"""
Implements an Attention-Linked Memory Layer.
This layer retrieves context from a structured memory based on a query.
It uses a two-level attention mechanism:
1. Within-Bucket Attention: Summarizes each memory bucket based on relevance
to the query.
2. Between-Bucket Attention: Weights the summarized buckets to find the most
relevant buckets for the query, producing a final aggregated context.
Args:
query_dim (int): Dimension of the input query vector.
memory_dim (int): Dimension of the memory item vectors.
embed_dim (int): Internal embedding dimension for attention. Keys, Queries,
and Values will be projected to this dimension.
num_heads (int): Number of attention heads for both levels. Must divide embed_dim.
output_dim (int, optional): Dimension of the final output context vector.
If None, defaults to embed_dim.
dropout_rate (float, optional): Dropout probability. Defaults to 0.1.
"""
    def __init__(self,
                 query_dim: int,
                 memory_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 output_dim: int = None,
                 dropout_rate: float = 0.1):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")

        self.query_dim = query_dim
        self.memory_dim = memory_dim
        self.embed_dim = embed_dim
        self.output_dim = output_dim if output_dim is not None else embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = math.sqrt(self.head_dim)

        # --- Level 1 (Within-Bucket) Projections ---
        self.q_proj_l1 = nn.Linear(query_dim, embed_dim)
        self.k_proj_l1 = nn.Linear(memory_dim, embed_dim)
        self.v_proj_l1 = nn.Linear(memory_dim, embed_dim)
        self.dropout_l1 = nn.Dropout(dropout_rate)
        self.norm_l1_out = nn.LayerNorm(embed_dim)

        # --- Level 2 (Between-Bucket) Projections ---
        self.q_proj_l2 = nn.Linear(query_dim, embed_dim)
        self.k_proj_l2 = nn.Linear(embed_dim, embed_dim)
        self.v_proj_l2 = nn.Linear(embed_dim, embed_dim)
        self.dropout_l2 = nn.Dropout(dropout_rate)

        # --- Final Output Projection ---
        self.out_proj = nn.Linear(embed_dim, self.output_dim)
        self.norm_final = nn.LayerNorm(self.output_dim)
    def forward(self,
                query: torch.Tensor,
                memory_buckets: torch.Tensor,
                memory_mask: torch.Tensor = None):
        """
        Forward pass through the Attention-Linked Memory layer.

        Args:
            query (torch.Tensor): Input query tensor.
                Shape: (batch_size, query_dim)
            memory_buckets (torch.Tensor): Tensor containing memory items organized
                into buckets, padded to a consistent size.
                Shape: (batch_size, num_buckets, max_items_per_bucket, memory_dim)
            memory_mask (torch.Tensor, optional): Boolean mask indicating valid
                memory items (True) vs. padding (False).
                Shape: (batch_size, num_buckets, max_items_per_bucket)

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - aggregated_context (torch.Tensor): The final context vector.
                  Shape: (batch_size, output_dim)
                - bucket_attention_weights_l2 (torch.Tensor): Attention weights for buckets (L2).
                  Shape: (batch_size, num_buckets)
                - item_attention_weights_l1 (torch.Tensor): Attention weights for items within
                  buckets (L1), averaged over heads.
                  Shape: (batch_size, num_buckets, max_items_per_bucket)
        """
        batch_size, num_buckets, max_items, _ = memory_buckets.shape

        # ========================= Level 1: Within-Bucket Attention =========================
        # Project the query once and broadcast it against every item in every bucket.
        q1 = self.q_proj_l1(query).view(batch_size, self.num_heads, 1, 1, self.head_dim)
        k1 = self.k_proj_l1(memory_buckets)
        v1 = self.v_proj_l1(memory_buckets)
        k1 = k1.view(batch_size, num_buckets, max_items, self.num_heads, self.head_dim).permute(0, 3, 1, 2, 4)
        v1 = v1.view(batch_size, num_buckets, max_items, self.num_heads, self.head_dim).permute(0, 3, 1, 2, 4)

        # Scaled dot-product scores between the query and each item, per bucket and head.
        attn_scores_l1 = torch.einsum('bhqid,bhnkd->bhnqk', q1, k1) / self.scale
        attn_scores_l1 = attn_scores_l1.squeeze(3)  # (B, H, N_b, N_i)

        if memory_mask is not None:
            expanded_mask_l1 = memory_mask.unsqueeze(1)  # (B, 1, N_b, N_i)
            attn_scores_l1 = attn_scores_l1.masked_fill(expanded_mask_l1 == 0, -1e9)

        attn_weights_l1_raw = F.softmax(attn_scores_l1, dim=-1)  # (B, H, N_b, N_i)
        # For returning, average the item attention weights over heads.
        item_attention_weights_l1_for_output = attn_weights_l1_raw.mean(dim=1)  # (B, N_b, N_i)

        attn_weights_l1_dropout = self.dropout_l1(attn_weights_l1_raw)
        # Each bucket summary is the attention-weighted sum of its item values.
        bucket_summaries = torch.einsum('bhnk,bhnkd->bhnd', attn_weights_l1_dropout, v1)
        bucket_summaries = bucket_summaries.permute(0, 2, 1, 3).contiguous().view(batch_size, num_buckets, self.embed_dim)
        bucket_summaries = self.norm_l1_out(bucket_summaries)

        # ========================= Level 2: Between-Bucket Attention ========================
        q2 = self.q_proj_l2(query).view(batch_size, self.num_heads, 1, self.head_dim)
        k2 = self.k_proj_l2(bucket_summaries)
        v2 = self.v_proj_l2(bucket_summaries)
        k2 = k2.view(batch_size, num_buckets, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v2 = v2.view(batch_size, num_buckets, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        attn_scores_l2 = torch.einsum('bhqd,bhnd->bhqn', q2, k2) / self.scale
        attn_scores_l2 = attn_scores_l2.squeeze(2)  # (B, H, N_b)

        # Optional: mask entire buckets that contained only padding in L1 (advanced).
        # For now, assume valid buckets have non-zero summaries; if memory_mask marks every
        # item in a bucket as padding, that bucket's summary collapses to a padding-dependent
        # constant and L2 attention should naturally give it low weight.
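        # A minimal sketch of that optional bucket-level masking, kept commented out so the
        # current behaviour is unchanged; `bucket_valid` is a hypothetical name used only for
        # illustration and is not part of the original code:
        #   if memory_mask is not None:
        #       bucket_valid = memory_mask.any(dim=-1)  # (B, N_b): True if the bucket has any real item
        #       attn_scores_l2 = attn_scores_l2.masked_fill(~bucket_valid.unsqueeze(1), -1e9)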
        attn_weights_l2_raw = F.softmax(attn_scores_l2, dim=-1)  # (B, H, N_b)
        bucket_attention_weights_l2_for_output = attn_weights_l2_raw.mean(dim=1)  # (B, N_b)

        attn_weights_l2_dropout = self.dropout_l2(attn_weights_l2_raw)
        # Aggregate the bucket summaries into one context vector per head, then merge the heads.
        aggregated_context_heads = torch.einsum('bhn,bhnd->bhd', attn_weights_l2_dropout, v2)
        aggregated_context = aggregated_context_heads.contiguous().view(batch_size, self.embed_dim)

        # ========================= Final Output Projection ========================
        final_output = self.out_proj(aggregated_context)
        final_output = self.norm_final(final_output)

        return final_output, bucket_attention_weights_l2_for_output, item_attention_weights_l1_for_output
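
# --- Illustrative helper (not part of the original module) ------------------------------
# A minimal sketch, assuming memory arrives as nested Python lists: one list of buckets per
# batch element, each bucket a list of 1-D tensors of length `memory_dim`. It pads everything
# into the (batch, num_buckets, max_items, memory_dim) tensor and the boolean mask that
# AttentionLinkedMemory.forward expects. The name `build_memory_buckets` and its signature
# are assumptions made for demonstration, not an API shipped with this package.
def build_memory_buckets(batch_of_buckets, memory_dim):
    batch_size = len(batch_of_buckets)
    num_buckets = max(len(buckets) for buckets in batch_of_buckets)
    max_items = max(len(items) for buckets in batch_of_buckets for items in buckets)
    memory = torch.zeros(batch_size, num_buckets, max_items, memory_dim)
    mask = torch.zeros(batch_size, num_buckets, max_items, dtype=torch.bool)
    for b, buckets in enumerate(batch_of_buckets):
        for n, items in enumerate(buckets):
            for i, item in enumerate(items):
                memory[b, n, i] = item  # copy the real item vector
                mask[b, n, i] = True    # mark it as valid (not padding)
    return memory, mask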
# ========================= Example Usage (ALM Standalone) =========================
if __name__ == '__main__':
    # Quick standalone test of the ALM layer.
    _batch_size = 2
    _query_dim = 32
    _memory_dim = 24
    _embed_dim = 64
    _num_heads = 4
    _output_dim = 32
    _num_buckets = 3
    _max_items_per_bucket = 5

    alm_layer_test = AttentionLinkedMemory(
        query_dim=_query_dim, memory_dim=_memory_dim, embed_dim=_embed_dim,
        num_heads=_num_heads, output_dim=_output_dim
    )

    query_test = torch.randn(_batch_size, _query_dim)
    memory_buckets_test = torch.randn(_batch_size, _num_buckets, _max_items_per_bucket, _memory_dim)
    memory_mask_test = torch.ones(_batch_size, _num_buckets, _max_items_per_bucket, dtype=torch.bool)
    memory_mask_test[:, :, -1] = 0  # Mask the last item in each bucket

    agg_ctx, buck_att, item_att = alm_layer_test(query_test, memory_buckets_test, memory_mask_test)

    print("--- ALM Standalone Test ---")
    print(f"Aggregated Context Shape: {agg_ctx.shape}")      # Expected: (B, output_dim)
    print(f"Bucket Attention (L2) Shape: {buck_att.shape}")  # Expected: (B, num_buckets)
    print(f"Item Attention (L1) Shape: {item_att.shape}")    # Expected: (B, num_buckets, max_items_per_bucket)
    print("Item Attention for first batch, first bucket:\n", item_att[0, 0, :])
    print("Sum of item attentions (should be ~1.0 over unmasked items):", item_att[0, 0, :-1].sum())
    print("Attention on masked item (should be ~0.0):", item_att[0, 0, -1])
# --- END OF FILE ALM.py ---