| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| from torch import Tensor, nn |
|
|
|
|
@dataclass
class ObservationMemoryConfig:
    """Shared hyper-parameters for the observation-memory modules in this file."""

    # Embedding width used as d_model by the GRU, transformer layers, and banks.
    hidden_dim: int = 512
    # Size of the raw action vectors consumed by the action projections.
    action_dim: int = 14
    # History truncation length used by InteractionObservationMemory._truncate_history.
    history_steps: int = 2
    # GRU depth (ObservationMemory) / transformer depth (InteractionObservationMemory).
    num_layers: int = 1
    # Dropout for the transformer layers and bank attention; applied to the GRU
    # only when num_layers > 1 (see the guard in ObservationMemory.__init__).
    dropout: float = 0.1
    # Number of learned bank query slots in InteractionObservationMemory.
    memory_bank_size: int = 4
    # Attention heads for the transformer encoder layers and bank attention.
    num_heads: int = 4
    # Upper bound on the encoded sequence length; sizes the position embeddings
    # (max_history_steps + 1 positions, history plus the current step).
    max_history_steps: int = 8
    # Bank slot counts for SceneMemory / BeliefMemory respectively.
    scene_bank_size: int = 2
    belief_bank_size: int = 2
    # History truncation lengths for SceneMemory / BeliefMemory respectively.
    scene_history_steps: int = 3
    belief_history_steps: int = 8
    # Novelty level at which the selective banks open the write gate
    # (BeliefMemory adds 0.05 on top of this value).
    memory_write_threshold: float = 0.45
    # NOTE(review): stored on _SelectiveMemoryBank but never read in this file —
    # presumably consumed by external callers; confirm before removing.
    memory_suppression_margin: float = 0.05
|
|
|
|
class ObservationMemory(nn.Module):
    """GRU-based observation memory.

    Pools each step's scene tokens to one vector, optionally prepends pooled
    history steps (augmented with projected actions), and rolls the resulting
    per-step sequence through a GRU. The final GRU state is projected into a
    single memory token and an uncertainty scalar.
    """

    def __init__(self, config: ObservationMemoryConfig) -> None:
        super().__init__()
        self.config = config
        # GRU over the pooled per-step sequence. Inter-layer dropout is only
        # defined for stacked GRUs, hence the num_layers guard.
        self.gru = nn.GRU(
            input_size=config.hidden_dim,
            hidden_size=config.hidden_dim,
            num_layers=config.num_layers,
            batch_first=True,
            dropout=config.dropout if config.num_layers > 1 else 0.0,
        )
        # Projects the final GRU state into the exported memory token.
        self.token_proj = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
        )
        # Lifts raw action vectors into the hidden space so they can be added
        # to the pooled history observations.
        self.action_proj = nn.Sequential(
            nn.LayerNorm(config.action_dim),
            nn.Linear(config.action_dim, config.hidden_dim),
            nn.GELU(),
        )
        # Scalar head; softplus in forward() keeps the uncertainty non-negative.
        self.uncertainty_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, 1),
        )

    def forward(
        self,
        scene_tokens: Tensor,
        history_scene_tokens: Tensor | None = None,
        history_actions: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Run the memory over the current (and optional past) observations.

        Args:
            scene_tokens: current observation tokens; assumed layout
                (batch, num_tokens, hidden_dim) — TODO confirm with callers.
            history_scene_tokens: optional past tokens, assumed
                (batch, steps, num_tokens, hidden_dim).
            history_actions: optional past actions, assumed
                (batch, steps, action_dim).

        Returns:
            Dict with the encoded sequence, final state, memory token(s), and a
            non-negative per-sample uncertainty.
        """
        pooled_current = scene_tokens.mean(dim=1, keepdim=True)
        if history_scene_tokens is not None and history_scene_tokens.numel() > 0:
            history_pooled = history_scene_tokens.mean(dim=2)
            if history_actions is not None and history_actions.numel() > 0:
                # Keep only the most recent actions so they align step-for-step
                # with the pooled history.
                history_action_tokens = self.action_proj(history_actions[:, -history_pooled.shape[1] :])
                history_pooled = history_pooled + history_action_tokens
            sequence = torch.cat([history_pooled, pooled_current], dim=1)
        else:
            sequence = pooled_current
        memory_sequence, hidden = self.gru(sequence)
        # Last layer's final hidden state, shape (batch, hidden_dim).
        final_state = hidden[-1]
        # Fix: project once — the original ran token_proj twice to produce two
        # identical outputs for "memory_token" and "memory_tokens".
        memory_token = self.token_proj(final_state).unsqueeze(1)
        return {
            "memory_sequence": memory_sequence,
            "memory_state": final_state,
            "memory_token": memory_token,
            "memory_tokens": memory_token,
            "memory_uncertainty": torch.nn.functional.softplus(self.uncertainty_head(final_state)).squeeze(-1),
        }
|
|
|
|
class InteractionObservationMemory(nn.Module):
    """Transformer-based observation memory with a learned memory bank.

    Pools the current scene tokens and a (recency-weighted) pooled history
    into a short sequence, encodes it with a transformer, then compresses the
    encoding into a fixed number of bank tokens via learned attention queries.
    """

    def __init__(self, config: ObservationMemoryConfig) -> None:
        super().__init__()
        self.config = config
        dim = config.hidden_dim
        # NOTE: module construction order matches the original so parameter
        # initialisation consumes the RNG in the same order.
        layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=config.num_heads,
            dim_feedforward=dim * 4,
            dropout=config.dropout,
            batch_first=True,
            norm_first=True,
        )
        self.sequence_encoder = nn.TransformerEncoder(layer, num_layers=max(1, config.num_layers))
        # One position slot per history step plus one for the current step.
        self.position_embedding = nn.Parameter(
            torch.randn(1, config.max_history_steps + 1, dim) * 0.02
        )
        self.bank_queries = nn.Parameter(torch.randn(config.memory_bank_size, dim) * 0.02)
        self.bank_attention = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=config.num_heads,
            dropout=config.dropout,
            batch_first=True,
        )
        self.bank_mlp = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )
        self.token_proj = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
        )
        self.action_proj = nn.Sequential(
            nn.LayerNorm(config.action_dim),
            nn.Linear(config.action_dim, dim),
            nn.GELU(),
        )
        self.uncertainty_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, 1),
        )

    def _recency_weights(self, length: int, device: torch.device, dtype: torch.dtype) -> Tensor:
        """Exponentially decaying weights, newest step largest, summing to one."""
        if length <= 0:
            return torch.zeros((0,), device=device, dtype=dtype)
        # Distance from the newest step: [length-1, ..., 1, 0].
        distances = torch.arange(length - 1, -1, -1, device=device, dtype=dtype)
        weights = torch.exp(-0.5 * distances)
        return weights / weights.sum().clamp_min(1e-6)

    def _truncate_history(self, history_scene_tokens: Tensor | None) -> Tensor | None:
        """Keep only the most recent ``config.history_steps`` history entries."""
        if history_scene_tokens is None or history_scene_tokens.numel() == 0:
            return history_scene_tokens
        keep = self.config.history_steps
        if history_scene_tokens.shape[1] > keep:
            return history_scene_tokens[:, -keep:]
        return history_scene_tokens

    def forward(
        self,
        scene_tokens: Tensor,
        history_scene_tokens: Tensor | None = None,
        history_actions: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Encode current + history observations and emit bank memory tokens.

        Assumes scene_tokens is (batch, num_tokens, hidden_dim) and history is
        (batch, steps, num_tokens, hidden_dim) — TODO confirm with callers.
        """
        current_summary = scene_tokens.mean(dim=1, keepdim=True)
        past = self._truncate_history(history_scene_tokens)
        if past is None or past.numel() == 0:
            sequence = current_summary
        else:
            past_summary = past.mean(dim=2)
            steps = past_summary.shape[1]
            if history_actions is not None and history_actions.numel() > 0:
                # Align actions to the truncated history before injecting them.
                past_summary = past_summary + self.action_proj(history_actions[:, -steps:])
            recency = self._recency_weights(
                steps,
                device=past_summary.device,
                dtype=past_summary.dtype,
            )
            # Rescale by the step count so the weighted history keeps its
            # overall magnitude.
            past_summary = past_summary * recency.view(1, -1, 1) * float(steps)
            sequence = torch.cat([past_summary, current_summary], dim=1)

        seq_len = sequence.shape[1]
        if seq_len > self.position_embedding.shape[1]:
            raise ValueError(
                f"Sequence length {seq_len} exceeds configured maximum {self.position_embedding.shape[1]}"
            )
        encoded = self.sequence_encoder(sequence + self.position_embedding[:, :seq_len])
        batch_size = encoded.shape[0]
        # Summarise the most recent half of the bank-sized window of steps.
        window = min(max(1, self.config.memory_bank_size // 2), encoded.shape[1])
        recent_summary = encoded[:, -window:].mean(dim=1, keepdim=True)
        queries = self.bank_queries.unsqueeze(0).expand(batch_size, -1, -1) + recent_summary
        attended, _ = self.bank_attention(queries, encoded, encoded)
        attended = attended + self.bank_mlp(attended)
        projected_bank = self.token_proj(attended + recent_summary)
        pooled_bank = projected_bank.mean(dim=1) + 0.25 * recent_summary.squeeze(1)
        return {
            "memory_sequence": encoded,
            "memory_state": encoded[:, -1],
            "memory_token": pooled_bank.unsqueeze(1),
            "memory_tokens": projected_bank,
            "memory_uncertainty": torch.nn.functional.softplus(self.uncertainty_head(pooled_bank)).squeeze(-1),
        }
|
|
|
|
class _SelectiveMemoryBank(nn.Module):
    """Fixed-size memory bank with novelty-gated writes.

    Compresses a token stream into ``bank_size`` slots, builds a recency-
    weighted prior bank from history, and blends current and prior slots with
    a write gate that opens only when the current content is novel enough.
    """

    def __init__(
        self,
        hidden_dim: int,
        action_dim: int,
        num_heads: int,
        dropout: float,
        bank_size: int,
        history_steps: int,
        max_history_steps: int,
        write_threshold: float,
        suppression_margin: float,
    ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.history_steps = history_steps
        self.max_history_steps = max_history_steps
        self.write_threshold = write_threshold
        # NOTE(review): stored but never read in this class — presumably
        # consumed by external callers; confirm before removing.
        self.suppression_margin = suppression_margin
        # Module construction order is preserved so RNG-based initialisation
        # stays reproducible.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.sequence_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
        # One position per history step plus one for the current step.
        self.position_embedding = nn.Parameter(torch.randn(1, max_history_steps + 1, hidden_dim) * 0.02)
        self.bank_queries = nn.Parameter(torch.randn(bank_size, hidden_dim) * 0.02)
        self.bank_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        # Lifts raw actions into the hidden space for injection into the bank.
        self.action_proj = nn.Sequential(
            nn.LayerNorm(action_dim),
            nn.Linear(action_dim, hidden_dim),
            nn.GELU(),
        )
        # Gate over [current, prior, novelty] per slot -> scalar logit.
        self.write_gate = nn.Sequential(
            nn.LayerNorm(hidden_dim * 3),
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        )
        self.token_proj = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
        )

    def _recency_weights(self, length: int, device: torch.device, dtype: torch.dtype) -> Tensor:
        """Exponentially decaying weights, newest step largest, summing to one."""
        if length <= 0:
            return torch.zeros((0,), device=device, dtype=dtype)
        positions = torch.arange(length, device=device, dtype=dtype)
        distances = (length - 1) - positions
        weights = torch.exp(-0.5 * distances)
        return weights / weights.sum().clamp_min(1e-6)

    def _truncate(self, history: Tensor | None) -> Tensor | None:
        """Keep only the most recent ``history_steps`` history entries."""
        if history is None or history.numel() == 0:
            return history
        if history.shape[1] <= self.history_steps:
            return history
        return history[:, -self.history_steps :]

    def _chunk_pool(self, tokens: Tensor) -> Tensor:
        """Mean-pool ``tokens`` into one slot per bank query.

        When there are fewer tokens than slots, surplus slots reuse the last
        token.
        """
        batch_size, seq_len, hidden_dim = tokens.shape
        # Ceiling division so every token lands in exactly one chunk.
        chunk_size = max(1, (seq_len + self.bank_queries.shape[0] - 1) // self.bank_queries.shape[0])
        slots = []
        for slot_idx in range(self.bank_queries.shape[0]):
            start = slot_idx * chunk_size
            end = min(seq_len, start + chunk_size)
            if start >= seq_len:
                pooled = tokens[:, -1]
            else:
                pooled = tokens[:, start:end].mean(dim=1)
            slots.append(pooled)
        return torch.stack(slots, dim=1)

    def _compress_tokens(self, tokens: Tensor) -> Tensor:
        """Compress tokens into bank slots: chunk means plus a small attention residual."""
        base_slots = self._chunk_pool(tokens)
        queries = self.bank_queries.unsqueeze(0).expand(tokens.shape[0], -1, -1) + base_slots
        attended, _ = self.bank_attention(queries, tokens, tokens)
        return base_slots + 0.1 * attended

    def forward(
        self,
        current_tokens: Tensor,
        history_scene_tokens: Tensor | None = None,
        history_actions: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Compress current/history tokens and write novel content to the bank.

        Assumes current_tokens is (batch, num_tokens, hidden_dim) and history
        is (batch, steps, num_tokens, hidden_dim) — TODO confirm with callers.
        """
        history_scene_tokens = self._truncate(history_scene_tokens)
        current_bank = self._compress_tokens(current_tokens)
        pooled_current = current_bank.mean(dim=1, keepdim=True)
        if history_scene_tokens is not None and history_scene_tokens.numel() > 0:
            batch_size, history_steps = history_scene_tokens.shape[:2]
            flat_history = history_scene_tokens.reshape(batch_size * history_steps, history_scene_tokens.shape[2], history_scene_tokens.shape[3])
            history_bank = self._compress_tokens(flat_history).view(batch_size, history_steps, self.bank_queries.shape[0], self.hidden_dim)
            if history_actions is not None and history_actions.numel() > 0:
                # Align actions to the (possibly truncated) history steps.
                history_actions = history_actions[:, -history_steps:]
                history_bank = history_bank + self.action_proj(history_actions).unsqueeze(2)
            # Fix: pool once after the optional action injection — the original
            # computed this same mean both before and after adding actions.
            history_pooled = history_bank.mean(dim=2)
            sequence = torch.cat([history_pooled, pooled_current], dim=1)
        else:
            # Empty (batch, 0, bank, hidden) placeholder keeps the
            # history_bank.shape[1] check below uniform. (A dead
            # history_pooled assignment was removed here.)
            history_bank = current_bank.unsqueeze(1)[:, :0]
            sequence = pooled_current
        if sequence.shape[1] > self.position_embedding.shape[1]:
            raise ValueError(
                f"Sequence length {sequence.shape[1]} exceeds configured maximum {self.position_embedding.shape[1]}"
            )
        encoded = self.sequence_encoder(sequence + self.position_embedding[:, : sequence.shape[1]])
        current_token = encoded[:, -1]
        if history_bank.shape[1] > 0:
            # Recency-weighted sum over history forms the prior bank.
            recency = self._recency_weights(
                history_bank.shape[1],
                device=history_bank.device,
                dtype=history_bank.dtype,
            ).view(1, -1, 1, 1)
            prior_bank = (history_bank * recency).sum(dim=1)
        else:
            prior_bank = torch.zeros_like(current_bank)
        novelty = torch.abs(current_bank - prior_bank)
        gate_logit = self.write_gate(torch.cat([current_bank, prior_bank, novelty], dim=-1))
        novelty_score = novelty.mean(dim=-1, keepdim=True)
        # Sharp (slope 12) sigmoid around write_threshold acts as a soft
        # novelty switch for the write.
        novelty_gate = torch.sigmoid(12.0 * (novelty_score - self.write_threshold))
        # Learned gate is floored at 0.25 so novel content is never fully blocked.
        gate = (0.25 + 0.75 * torch.sigmoid(gate_logit)) * novelty_gate
        bank_tokens = prior_bank * (1.0 - gate) + current_bank * gate
        bank_tokens = self.token_proj(bank_tokens)
        return {
            "memory_tokens": bank_tokens,
            "memory_token": bank_tokens.mean(dim=1, keepdim=True),
            "memory_sequence": encoded,
            "memory_state": current_token,
            "write_gate": gate.squeeze(-1),
            "saturation": bank_tokens.abs().mean(dim=(1, 2)),
        }
|
|
|
|
class SceneMemory(_SelectiveMemoryBank):
    """Selective memory bank configured for scene observations."""

    def __init__(self, config: ObservationMemoryConfig) -> None:
        bank_kwargs = dict(
            hidden_dim=config.hidden_dim,
            action_dim=config.action_dim,
            num_heads=config.num_heads,
            dropout=config.dropout,
            bank_size=max(1, config.scene_bank_size),
            history_steps=max(1, config.scene_history_steps),
            max_history_steps=config.max_history_steps,
            write_threshold=config.memory_write_threshold,
            suppression_margin=config.memory_suppression_margin,
        )
        super().__init__(**bank_kwargs)
|
|
|
|
class BeliefMemory(_SelectiveMemoryBank):
    """Selective memory bank for belief state, with a stricter write threshold."""

    def __init__(self, config: ObservationMemoryConfig) -> None:
        # Belief writes require slightly more novelty than scene writes.
        elevated_threshold = config.memory_write_threshold + 0.05
        super().__init__(
            hidden_dim=config.hidden_dim,
            action_dim=config.action_dim,
            num_heads=config.num_heads,
            dropout=config.dropout,
            bank_size=max(1, config.belief_bank_size),
            history_steps=max(1, config.belief_history_steps),
            max_history_steps=config.max_history_steps,
            write_threshold=elevated_threshold,
            suppression_margin=config.memory_suppression_margin,
        )
|
|
|
|
class DualObservationMemory(nn.Module):
    """Runs scene and belief memory banks in parallel over the same inputs.

    Concatenates the two banks' tokens/sequences/states and derives a shared
    uncertainty estimate from the pooled combined bank.
    """

    def __init__(self, config: ObservationMemoryConfig) -> None:
        super().__init__()
        self.scene_memory = SceneMemory(config)
        self.belief_memory = BeliefMemory(config)
        # Shared uncertainty head over the pooled combined memory tokens.
        self.uncertainty_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, 1),
        )

    def forward(
        self,
        scene_tokens: Tensor,
        history_scene_tokens: Tensor | None = None,
        history_actions: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Run both banks on the same observations and merge their outputs.

        Input layouts are assumed to match _SelectiveMemoryBank.forward —
        TODO confirm with callers.
        """
        scene_output = self.scene_memory(
            current_tokens=scene_tokens,
            history_scene_tokens=history_scene_tokens,
            history_actions=history_actions,
        )
        belief_output = self.belief_memory(
            current_tokens=scene_tokens,
            history_scene_tokens=history_scene_tokens,
            history_actions=history_actions,
        )
        memory_tokens = torch.cat([scene_output["memory_tokens"], belief_output["memory_tokens"]], dim=1)
        # Fix: pool once — the original computed memory_tokens.mean(dim=1)
        # twice (with and without keepdim); unsqueeze yields identical values.
        pooled_memory = memory_tokens.mean(dim=1)
        memory_token = pooled_memory.unsqueeze(1)
        # Note: concatenated state is 2 * hidden_dim wide.
        memory_state = torch.cat([scene_output["memory_state"], belief_output["memory_state"]], dim=-1)
        return {
            "scene_memory_tokens": scene_output["memory_tokens"],
            "belief_memory_tokens": belief_output["memory_tokens"],
            "memory_tokens": memory_tokens,
            "memory_token": memory_token,
            "memory_sequence": torch.cat(
                [scene_output["memory_sequence"], belief_output["memory_sequence"]],
                dim=1,
            ),
            "memory_state": memory_state,
            "memory_uncertainty": torch.nn.functional.softplus(self.uncertainty_head(pooled_memory)).squeeze(-1),
            "memory_write_rate": 0.5 * (scene_output["write_gate"] + belief_output["write_gate"]),
            "memory_saturation": 0.5 * (scene_output["saturation"] + belief_output["saturation"]),
            "scene_write_gate": scene_output["write_gate"],
            "belief_write_gate": belief_output["write_gate"],
            "memory_scene_state": scene_output["memory_state"],
            "memory_belief_state": belief_output["memory_state"],
        }
|
|