"""
Key, Value, and Output Projection Layers

Maps hidden states from the base language model into memory-compatible key and
value representations for the MIRAS memory module, and projects memory outputs
back into the language model's hidden dimension.
"""
| import torch.nn as nn | |
class KeyProjection(nn.Module):
    """
    Projects hidden states to memory keys with a single linear map.

    Args:
        hidden_dim: Dimension of LM hidden states (e.g., 768 for distilgpt2)
        memory_dim: Dimension of memory keys (e.g., 256)
        bias: If True, the linear map also learns an additive bias. Defaults
            to False (a pure linear projection, so a zero hidden state maps
            to a zero key). Keyword-only.
    """

    def __init__(self, hidden_dim: int, memory_dim: int, *, bias: bool = False) -> None:
        super().__init__()
        self.projection = nn.Linear(hidden_dim, memory_dim, bias=bias)

    def forward(self, hidden_state):
        """
        Args:
            hidden_state: tensor of shape (..., hidden_dim). nn.Linear acts on
                the last dimension, so e.g. (batch_size, hidden_dim) and
                (batch_size, seq_len, hidden_dim) are both accepted.

        Returns:
            key: tensor of shape (..., memory_dim), with the same leading
                dimensions as the input.
        """
        return self.projection(hidden_state)
class ValueProjection(nn.Module):
    """
    Projects hidden states to memory values with a single linear map.

    Args:
        hidden_dim: Dimension of LM hidden states (e.g., 768 for distilgpt2)
        memory_dim: Dimension of memory values (e.g., 256)
        bias: If True, the linear map also learns an additive bias. Defaults
            to False (a pure linear projection, so a zero hidden state maps
            to a zero value). Keyword-only.
    """

    def __init__(self, hidden_dim: int, memory_dim: int, *, bias: bool = False) -> None:
        super().__init__()
        self.projection = nn.Linear(hidden_dim, memory_dim, bias=bias)

    def forward(self, hidden_state):
        """
        Args:
            hidden_state: tensor of shape (..., hidden_dim). nn.Linear acts on
                the last dimension, so e.g. (batch_size, hidden_dim) and
                (batch_size, seq_len, hidden_dim) are both accepted.

        Returns:
            value: tensor of shape (..., memory_dim), with the same leading
                dimensions as the input.
        """
        return self.projection(hidden_state)
class OutputProjection(nn.Module):
    """
    Projects memory output back to hidden dimension for augmentation.

    This enables memory to influence generation:
        h' = h + alpha * output_proj(memory(k))

    Args:
        memory_dim: Dimension of memory output (e.g., 256)
        hidden_dim: Dimension of LM hidden states (e.g., 768)
        bias: If True, the linear map also learns an additive bias. Defaults
            to False (a pure linear projection, so a zero memory output adds
            nothing to h). Keyword-only.
    """

    def __init__(self, memory_dim: int, hidden_dim: int, *, bias: bool = False) -> None:
        super().__init__()
        self.projection = nn.Linear(memory_dim, hidden_dim, bias=bias)

    def forward(self, memory_output):
        """
        Args:
            memory_output: tensor of shape (..., memory_dim). nn.Linear acts on
                the last dimension, so e.g. (batch_size, memory_dim) and
                (batch_size, seq_len, memory_dim) are both accepted.

        Returns:
            hidden_augmentation: tensor of shape (..., hidden_dim), with the
                same leading dimensions as the input.
        """
        return self.projection(memory_output)