""" Key and Value Projection Layers Maps hidden states from the base language model into memory-compatible representations for the MIRAS memory module. """ import torch.nn as nn class KeyProjection(nn.Module): """ Projects hidden states to memory keys. Args: hidden_dim: Dimension of LM hidden states (e.g., 768 for distilgpt2) memory_dim: Dimension of memory keys (e.g., 256) """ def __init__(self, hidden_dim, memory_dim): super().__init__() self.projection = nn.Linear(hidden_dim, memory_dim, bias=False) def forward(self, hidden_state): """ Args: hidden_state: (batch_size, hidden_dim) Returns: key: (batch_size, memory_dim) """ return self.projection(hidden_state) class ValueProjection(nn.Module): """ Projects hidden states to memory values. Args: hidden_dim: Dimension of LM hidden states (e.g., 768 for distilgpt2) memory_dim: Dimension of memory values (e.g., 256) """ def __init__(self, hidden_dim, memory_dim): super().__init__() self.projection = nn.Linear(hidden_dim, memory_dim, bias=False) def forward(self, hidden_state): """ Args: hidden_state: (batch_size, hidden_dim) Returns: value: (batch_size, memory_dim) """ return self.projection(hidden_state) class OutputProjection(nn.Module): """ Projects memory output back to hidden dimension for augmentation. This enables memory to influence generation: h' = h + alpha * output_proj(memory(k)) Args: memory_dim: Dimension of memory output (e.g., 256) hidden_dim: Dimension of LM hidden states (e.g., 768) """ def __init__(self, memory_dim, hidden_dim): super().__init__() self.projection = nn.Linear(memory_dim, hidden_dim, bias=False) def forward(self, memory_output): """ Args: memory_output: (batch_size, memory_dim) Returns: hidden_augmentation: (batch_size, hidden_dim) """ return self.projection(memory_output)