# memory-augmented-generation / projections.py
# Pavantej — "Upload folder using huggingface_hub" (commit dd5bee6, verified)
"""
Key and Value Projection Layers
Maps hidden states from the base language model into memory-compatible
representations for the MIRAS memory module.
"""
import torch.nn as nn
class KeyProjection(nn.Module):
    """
    Linear map from language-model hidden states to memory keys.

    Wraps a single bias-free ``nn.Linear`` stored as ``self.projection``,
    so its only parameter appears in checkpoints as ``projection.weight``.

    Args:
        hidden_dim: Width of the LM hidden states (e.g., 768 for distilgpt2).
        memory_dim: Width of the memory keys (e.g., 256).
    """

    def __init__(self, hidden_dim, memory_dim):
        super(KeyProjection, self).__init__()
        # bias=False: the key is a pure linear function of the hidden state.
        self.projection = nn.Linear(
            in_features=hidden_dim,
            out_features=memory_dim,
            bias=False,
        )

    def forward(self, hidden_state):
        """
        Project hidden states into key space.

        Args:
            hidden_state: (batch_size, hidden_dim). nn.Linear operates on the
                last axis, so additional leading dimensions are accepted too.

        Returns:
            key: (batch_size, memory_dim)
        """
        key = self.projection(hidden_state)
        return key
class ValueProjection(nn.Module):
    """
    Linear map from language-model hidden states to memory values.

    Wraps a single bias-free ``nn.Linear`` stored as ``self.projection``,
    so its only parameter appears in checkpoints as ``projection.weight``.

    Args:
        hidden_dim: Width of the LM hidden states (e.g., 768 for distilgpt2).
        memory_dim: Width of the memory values (e.g., 256).
    """

    def __init__(self, hidden_dim, memory_dim):
        super(ValueProjection, self).__init__()
        # bias=False: the value is a pure linear function of the hidden state.
        self.projection = nn.Linear(
            in_features=hidden_dim,
            out_features=memory_dim,
            bias=False,
        )

    def forward(self, hidden_state):
        """
        Project hidden states into value space.

        Args:
            hidden_state: (batch_size, hidden_dim). nn.Linear operates on the
                last axis, so additional leading dimensions are accepted too.

        Returns:
            value: (batch_size, memory_dim)
        """
        value = self.projection(hidden_state)
        return value
class OutputProjection(nn.Module):
    """
    Linear map from memory output back to the LM hidden dimension.

    Intended use is augmenting the hidden state with retrieved memory:
        h' = h + alpha * output_proj(memory(k))

    Wraps a single bias-free ``nn.Linear`` stored as ``self.projection``,
    so its only parameter appears in checkpoints as ``projection.weight``.

    Args:
        memory_dim: Width of the memory output (e.g., 256).
        hidden_dim: Width of the LM hidden states (e.g., 768).
    """

    def __init__(self, memory_dim, hidden_dim):
        super(OutputProjection, self).__init__()
        # bias=False: a zero memory output maps to a zero augmentation.
        self.projection = nn.Linear(
            in_features=memory_dim,
            out_features=hidden_dim,
            bias=False,
        )

    def forward(self, memory_output):
        """
        Lift memory output into the hidden-state space.

        Args:
            memory_output: (batch_size, memory_dim). nn.Linear operates on the
                last axis, so additional leading dimensions are accepted too.

        Returns:
            hidden_augmentation: (batch_size, hidden_dim)
        """
        hidden_augmentation = self.projection(memory_output)
        return hidden_augmentation