# VLAarchtests / code / reveal_vla_bimanual / models / observation_memory.py
# Provenance: lsnu's model page — "2026-03-25 runpod handoff update" (commit e7d8e79, verified)
from __future__ import annotations
from dataclasses import dataclass
import torch
from torch import Tensor, nn
@dataclass
class ObservationMemoryConfig:
    """Configuration shared by every observation-memory variant in this module.

    Field order is part of the generated dataclass ``__init__`` signature;
    do not reorder fields.
    """

    # Embedding width used by every projection, encoder, and bank in this module.
    hidden_dim: int = 512
    # Size of one action vector fed to action_proj (14 — presumably 7-DoF per
    # arm for the bimanual setup; TODO confirm against the robot config).
    action_dim: int = 14
    # Default number of past observations kept by InteractionObservationMemory.
    history_steps: int = 2
    # Depth of the GRU / transformer sequence encoders.
    num_layers: int = 1
    # Dropout for recurrent/attention modules (the GRU only applies it when
    # num_layers > 1; see ObservationMemory.__init__).
    dropout: float = 0.1
    # Number of learned bank queries in InteractionObservationMemory.
    memory_bank_size: int = 4
    # Attention heads; nn.MultiheadAttention requires this to divide hidden_dim.
    num_heads: int = 4
    # Upper bound on history length — sizes the learned position embeddings.
    max_history_steps: int = 8
    # Bank slots for the SceneMemory / BeliefMemory selective banks.
    scene_bank_size: int = 2
    belief_bank_size: int = 2
    # History windows: scene memory looks at a short window, belief memory at
    # the full max_history_steps horizon.
    scene_history_steps: int = 3
    belief_history_steps: int = 8
    # Novelty level above which the selective banks write new content
    # (BeliefMemory adds 0.05 on top of this).
    memory_write_threshold: float = 0.45
    # NOTE(review): stored by _SelectiveMemoryBank but never read anywhere in
    # this file — presumably reserved for a future suppression rule; confirm.
    memory_suppression_margin: float = 0.05
class ObservationMemory(nn.Module):
    """GRU-based temporal memory over pooled scene observations.

    Mean-pools the current scene tokens, optionally prepends pooled history
    observations (augmented with projected past actions), rolls the sequence
    through a GRU, and projects the final hidden state into a single memory
    token plus a non-negative uncertainty scalar.
    """

    def __init__(self, config: ObservationMemoryConfig) -> None:
        super().__init__()
        self.config = config
        self.gru = nn.GRU(
            input_size=config.hidden_dim,
            hidden_size=config.hidden_dim,
            num_layers=config.num_layers,
            batch_first=True,
            # nn.GRU warns when dropout is set on a single-layer GRU.
            dropout=config.dropout if config.num_layers > 1 else 0.0,
        )
        self.token_proj = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.action_proj = nn.Sequential(
            nn.LayerNorm(config.action_dim),
            nn.Linear(config.action_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.uncertainty_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, 1),
        )

    def forward(
        self,
        scene_tokens: Tensor,
        history_scene_tokens: Tensor | None = None,
        history_actions: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Run the recurrent memory over [history..., current].

        Args:
            scene_tokens: current observation tokens, (batch, seq, hidden).
            history_scene_tokens: optional stacked past observations — assumed
                (batch, steps, seq, hidden) given the dim=2 pooling; TODO
                confirm with callers.
            history_actions: optional past actions, (batch, steps, action_dim).

        Returns:
            Dict with the encoded sequence, final GRU state, the projected
            memory token (exposed under both "memory_token" and
            "memory_tokens"), and a softplus uncertainty per batch element.
        """
        # (B, S, H) -> (B, 1, H): pool the current observation to one token.
        pooled_current = scene_tokens.mean(dim=1, keepdim=True)
        if history_scene_tokens is not None and history_scene_tokens.numel() > 0:
            # (B, T, S, H) -> (B, T, H): pool each past observation.
            history_pooled = history_scene_tokens.mean(dim=2)
            if history_actions is not None and history_actions.numel() > 0:
                # Align actions with the (possibly shorter) pooled history by
                # taking the most recent steps.
                history_action_tokens = self.action_proj(history_actions[:, -history_pooled.shape[1] :])
                history_pooled = history_pooled + history_action_tokens
            sequence = torch.cat([history_pooled, pooled_current], dim=1)
        else:
            sequence = pooled_current
        memory_sequence, hidden = self.gru(sequence)
        # hidden is (num_layers, B, H); take the last layer's final state.
        final_state = hidden[-1]
        # Fix: project the final state once — the original ran token_proj twice
        # to produce two identical-valued outputs.
        memory_token = self.token_proj(final_state).unsqueeze(1)
        return {
            "memory_sequence": memory_sequence,
            "memory_state": final_state,
            "memory_token": memory_token,
            "memory_tokens": memory_token,
            "memory_uncertainty": torch.nn.functional.softplus(self.uncertainty_head(final_state)).squeeze(-1),
        }
class InteractionObservationMemory(nn.Module):
    """Transformer-based observation memory with a compact attention bank.

    Encodes the pooled current observation plus a recency-weighted pooled
    history with a transformer encoder, then lets a small set of learned
    bank queries cross-attend over the encoded sequence to produce a few
    compressed memory tokens and an uncertainty scalar.
    """

    def __init__(self, config: ObservationMemoryConfig) -> None:
        super().__init__()
        self.config = config
        # NOTE: module construction order determines parameter init order and
        # checkpoint layout — keep it stable.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_dim,
            nhead=config.num_heads,
            dim_feedforward=config.hidden_dim * 4,
            dropout=config.dropout,
            batch_first=True,
            norm_first=True,
        )
        self.sequence_encoder = nn.TransformerEncoder(encoder_layer, num_layers=max(1, config.num_layers))
        # One learned position per history step plus one for the current frame.
        self.position_embedding = nn.Parameter(
            torch.randn(1, config.max_history_steps + 1, config.hidden_dim) * 0.02
        )
        self.bank_queries = nn.Parameter(torch.randn(config.memory_bank_size, config.hidden_dim) * 0.02)
        self.bank_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_dim,
            num_heads=config.num_heads,
            dropout=config.dropout,
            batch_first=True,
        )
        self.bank_mlp = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        self.token_proj = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.action_proj = nn.Sequential(
            nn.LayerNorm(config.action_dim),
            nn.Linear(config.action_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.uncertainty_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, 1),
        )

    def _recency_weights(self, length: int, device: torch.device, dtype: torch.dtype) -> Tensor:
        """Exponentially decayed weights favouring recent steps; sums to ~1."""
        if length <= 0:
            return torch.zeros((0,), device=device, dtype=dtype)
        steps_from_end = (length - 1) - torch.arange(length, device=device, dtype=dtype)
        unnormalized = torch.exp(-0.5 * steps_from_end)
        return unnormalized / unnormalized.sum().clamp_min(1e-6)

    def _truncate_history(self, history_scene_tokens: Tensor | None) -> Tensor | None:
        """Keep only the most recent ``config.history_steps`` observations."""
        if history_scene_tokens is None or history_scene_tokens.numel() == 0:
            return history_scene_tokens
        keep = self.config.history_steps
        if history_scene_tokens.shape[1] > keep:
            return history_scene_tokens[:, -keep:]
        return history_scene_tokens

    def forward(
        self,
        scene_tokens: Tensor,
        history_scene_tokens: Tensor | None = None,
        history_actions: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Build compressed memory tokens from the current scene and history.

        Args:
            scene_tokens: current observation tokens, (batch, seq, hidden).
            history_scene_tokens: optional past observations — assumed
                (batch, steps, seq, hidden); TODO confirm with callers.
            history_actions: optional past actions, (batch, steps, action_dim).

        Raises:
            ValueError: if the assembled sequence is longer than the learned
                position-embedding table.
        """
        current = scene_tokens.mean(dim=1, keepdim=True)
        history_scene_tokens = self._truncate_history(history_scene_tokens)
        if history_scene_tokens is None or history_scene_tokens.numel() == 0:
            sequence = current
        else:
            past = history_scene_tokens.mean(dim=2)
            if history_actions is not None and history_actions.numel() > 0:
                past = past + self.action_proj(history_actions[:, -past.shape[1] :])
            weights = self._recency_weights(past.shape[1], device=past.device, dtype=past.dtype)
            # Re-scale by the step count so the weighted history keeps a
            # magnitude comparable to the unweighted tokens.
            past = past * weights.view(1, -1, 1) * float(past.shape[1])
            sequence = torch.cat([past, current], dim=1)
        seq_len = sequence.shape[1]
        if seq_len > self.position_embedding.shape[1]:
            raise ValueError(
                f"Sequence length {seq_len} exceeds configured maximum {self.position_embedding.shape[1]}"
            )
        encoded = self.sequence_encoder(sequence + self.position_embedding[:, :seq_len])
        batch = encoded.shape[0]
        # Summarize the most recent few encoded steps; the window scales with
        # the bank size but never exceeds the sequence.
        recent_window = min(max(1, self.config.memory_bank_size // 2), encoded.shape[1])
        recent_summary = encoded[:, -recent_window:].mean(dim=1, keepdim=True)
        queries = self.bank_queries.unsqueeze(0).expand(batch, -1, -1) + recent_summary
        bank_tokens, _ = self.bank_attention(queries, encoded, encoded)
        bank_tokens = bank_tokens + self.bank_mlp(bank_tokens)
        projected_bank = self.token_proj(bank_tokens + recent_summary)
        pooled_bank = projected_bank.mean(dim=1) + 0.25 * recent_summary.squeeze(1)
        return {
            "memory_sequence": encoded,
            "memory_state": encoded[:, -1],
            "memory_token": pooled_bank.unsqueeze(1),
            "memory_tokens": projected_bank,
            "memory_uncertainty": torch.nn.functional.softplus(self.uncertainty_head(pooled_bank)).squeeze(-1),
        }
class _SelectiveMemoryBank(nn.Module):
    """Write-gated memory bank shared by SceneMemory and BeliefMemory.

    Compresses each observation's tokens into a fixed number of bank slots,
    summarizes truncated history with recency weighting, and blends the
    current bank into the recency-weighted prior bank through a
    novelty-driven write gate.
    """

    def __init__(
        self,
        hidden_dim: int,
        action_dim: int,
        num_heads: int,
        dropout: float,
        bank_size: int,
        history_steps: int,
        max_history_steps: int,
        write_threshold: float,
        suppression_margin: float,
    ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        # How many past observations _truncate keeps.
        self.history_steps = history_steps
        self.max_history_steps = max_history_steps
        # Novelty level (mean |current - prior| per slot) around which the
        # soft write gate switches on; see forward().
        self.write_threshold = write_threshold
        # NOTE(review): stored but never read in this class — presumably
        # reserved for a future suppression rule; confirm before relying on it.
        self.suppression_margin = suppression_margin
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.sequence_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
        # One learned position per history step plus one for the current frame.
        self.position_embedding = nn.Parameter(torch.randn(1, max_history_steps + 1, hidden_dim) * 0.02)
        self.bank_queries = nn.Parameter(torch.randn(bank_size, hidden_dim) * 0.02)
        self.bank_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.action_proj = nn.Sequential(
            nn.LayerNorm(action_dim),
            nn.Linear(action_dim, hidden_dim),
            nn.GELU(),
        )
        # Gate input is [current_bank, prior_bank, novelty] concatenated on the
        # feature axis, hence hidden_dim * 3.
        self.write_gate = nn.Sequential(
            nn.LayerNorm(hidden_dim * 3),
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        )
        self.token_proj = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
        )

    def _recency_weights(self, length: int, device: torch.device, dtype: torch.dtype) -> Tensor:
        """Exponentially decayed weights favouring recent steps; sums to ~1."""
        if length <= 0:
            return torch.zeros((0,), device=device, dtype=dtype)
        positions = torch.arange(length, device=device, dtype=dtype)
        # Distance from the most recent step (0 for the last position).
        distances = (length - 1) - positions
        weights = torch.exp(-0.5 * distances)
        return weights / weights.sum().clamp_min(1e-6)

    def _truncate(self, history: Tensor | None) -> Tensor | None:
        """Keep only the most recent ``history_steps`` observations."""
        if history is None or history.numel() == 0:
            return history
        if history.shape[1] <= self.history_steps:
            return history
        return history[:, -self.history_steps :]

    def _chunk_pool(self, tokens: Tensor) -> Tensor:
        """Mean-pool (B, S, H) tokens into (B, bank_size, H) contiguous chunks.

        When there are fewer tokens than slots, the surplus slots are filled
        with a copy of the last token.
        """
        batch_size, seq_len, hidden_dim = tokens.shape
        # Ceil division so bank_size chunks cover the whole sequence.
        chunk_size = max(1, (seq_len + self.bank_queries.shape[0] - 1) // self.bank_queries.shape[0])
        slots = []
        for slot_idx in range(self.bank_queries.shape[0]):
            start = slot_idx * chunk_size
            end = min(seq_len, start + chunk_size)
            if start >= seq_len:
                # Past the end of the sequence: duplicate the final token.
                pooled = tokens[:, -1]
            else:
                pooled = tokens[:, start:end].mean(dim=1)
            slots.append(pooled)
        return torch.stack(slots, dim=1)

    def _compress_tokens(self, tokens: Tensor) -> Tensor:
        """Compress tokens into bank slots: chunk pooling plus a small
        cross-attention refinement (weighted 0.1)."""
        base_slots = self._chunk_pool(tokens)
        queries = self.bank_queries.unsqueeze(0).expand(tokens.shape[0], -1, -1) + base_slots
        attended, _ = self.bank_attention(queries, tokens, tokens)
        return base_slots + 0.1 * attended

    def forward(
        self,
        current_tokens: Tensor,
        history_scene_tokens: Tensor | None = None,
        history_actions: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Compress current/history observations and apply the write gate.

        Args:
            current_tokens: current observation tokens, (batch, seq, hidden).
            history_scene_tokens: optional past observations,
                (batch, steps, seq, hidden) — established by the reshape below.
            history_actions: optional past actions, (batch, steps, action_dim).

        Raises:
            ValueError: if the pooled sequence exceeds the position table.
        """
        history_scene_tokens = self._truncate(history_scene_tokens)
        # (B, bank_size, H) compressed view of the current observation.
        current_bank = self._compress_tokens(history_scene_tokens) if False else self._compress_tokens(current_tokens)
        pooled_current = current_bank.mean(dim=1, keepdim=True)
        if history_scene_tokens is not None and history_scene_tokens.numel() > 0:
            batch_size, history_steps = history_scene_tokens.shape[:2]
            # Fold steps into the batch so each past observation is compressed
            # independently, then restore (B, T, bank_size, H).
            flat_history = history_scene_tokens.reshape(batch_size * history_steps, history_scene_tokens.shape[2], history_scene_tokens.shape[3])
            history_bank = self._compress_tokens(flat_history).view(batch_size, history_steps, self.bank_queries.shape[0], self.hidden_dim)
            history_pooled = history_bank.mean(dim=2)
            if history_actions is not None and history_actions.numel() > 0:
                # Align actions with the truncated history, broadcast the
                # projected action over every slot, and re-pool.
                history_actions = history_actions[:, -history_pooled.shape[1] :]
                history_action_tokens = self.action_proj(history_actions).unsqueeze(2)
                history_bank = history_bank + history_action_tokens
                history_pooled = history_bank.mean(dim=2)
            sequence = torch.cat([history_pooled, pooled_current], dim=1)
        else:
            # Empty (zero-length) placeholders keep downstream shape logic
            # uniform when there is no history.
            history_bank = current_bank.unsqueeze(1)[:, :0]
            history_pooled = pooled_current[:, :0]
            sequence = pooled_current
        if sequence.shape[1] > self.position_embedding.shape[1]:
            raise ValueError(
                f"Sequence length {sequence.shape[1]} exceeds configured maximum {self.position_embedding.shape[1]}"
            )
        encoded = self.sequence_encoder(sequence + self.position_embedding[:, : sequence.shape[1]])
        # Last encoded position corresponds to the current observation.
        current_token = encoded[:, -1]
        if history_bank.shape[1] > 0:
            # Recency-weighted average over history steps -> prior bank
            # (weights sum to ~1, so this is a convex combination).
            recency = self._recency_weights(
                history_bank.shape[1],
                device=history_bank.device,
                dtype=history_bank.dtype,
            ).view(1, -1, 1, 1)
            prior_bank = (history_bank * recency).sum(dim=1)
        else:
            prior_bank = torch.zeros_like(current_bank)
        # Per-slot novelty of the current bank relative to the prior.
        novelty = torch.abs(current_bank - prior_bank)
        gate_logit = self.write_gate(torch.cat([current_bank, prior_bank, novelty], dim=-1))
        novelty_score = novelty.mean(dim=-1, keepdim=True)
        # Sharp (slope 12) soft threshold on novelty; the learned gate is kept
        # in [0.25, 1] so a write can be damped but never fully learned away.
        novelty_gate = torch.sigmoid(12.0 * (novelty_score - self.write_threshold))
        gate = (0.25 + 0.75 * torch.sigmoid(gate_logit)) * novelty_gate
        # Convex blend: gate -> 0 keeps the prior, gate -> 1 writes the current.
        bank_tokens = prior_bank * (1.0 - gate) + current_bank * gate
        bank_tokens = self.token_proj(bank_tokens)
        return {
            "memory_tokens": bank_tokens,
            "memory_token": bank_tokens.mean(dim=1, keepdim=True),
            "memory_sequence": encoded,
            "memory_state": current_token,
            "write_gate": gate.squeeze(-1),
            "saturation": bank_tokens.abs().mean(dim=(1, 2)),
        }
class SceneMemory(_SelectiveMemoryBank):
    """Selective memory bank configured for short-horizon scene evidence."""

    def __init__(self, config: ObservationMemoryConfig) -> None:
        # Clamp the scene-specific sizes to at least one slot / one step.
        slot_count = max(1, config.scene_bank_size)
        window = max(1, config.scene_history_steps)
        super().__init__(
            hidden_dim=config.hidden_dim,
            action_dim=config.action_dim,
            num_heads=config.num_heads,
            dropout=config.dropout,
            bank_size=slot_count,
            history_steps=window,
            max_history_steps=config.max_history_steps,
            write_threshold=config.memory_write_threshold,
            suppression_margin=config.memory_suppression_margin,
        )
class BeliefMemory(_SelectiveMemoryBank):
    """Selective memory bank configured for long-horizon belief evidence.

    Uses a stricter write threshold (+0.05) than SceneMemory so beliefs
    update less eagerly than raw scene evidence.
    """

    def __init__(self, config: ObservationMemoryConfig) -> None:
        # Clamp the belief-specific sizes to at least one slot / one step.
        slot_count = max(1, config.belief_bank_size)
        window = max(1, config.belief_history_steps)
        super().__init__(
            hidden_dim=config.hidden_dim,
            action_dim=config.action_dim,
            num_heads=config.num_heads,
            dropout=config.dropout,
            bank_size=slot_count,
            history_steps=window,
            max_history_steps=config.max_history_steps,
            write_threshold=config.memory_write_threshold + 0.05,
            suppression_margin=config.memory_suppression_margin,
        )
class DualObservationMemory(nn.Module):
    """Runs SceneMemory and BeliefMemory in parallel and fuses their outputs.

    Both banks consume the same observation stream; their tokens and states
    are concatenated, and a shared head maps the pooled fused tokens to a
    non-negative uncertainty scalar.
    """

    def __init__(self, config: ObservationMemoryConfig) -> None:
        super().__init__()
        self.scene_memory = SceneMemory(config)
        self.belief_memory = BeliefMemory(config)
        self.uncertainty_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, 1),
        )

    def forward(
        self,
        scene_tokens: Tensor,
        history_scene_tokens: Tensor | None = None,
        history_actions: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Return fused memory tensors plus per-bank diagnostics.

        Args:
            scene_tokens: current observation tokens, (batch, seq, hidden).
            history_scene_tokens: optional past observations — assumed
                (batch, steps, seq, hidden); TODO confirm with callers.
            history_actions: optional past actions, (batch, steps, action_dim).
        """
        scene_output = self.scene_memory(
            current_tokens=scene_tokens,
            history_scene_tokens=history_scene_tokens,
            history_actions=history_actions,
        )
        belief_output = self.belief_memory(
            current_tokens=scene_tokens,
            history_scene_tokens=history_scene_tokens,
            history_actions=history_actions,
        )
        memory_tokens = torch.cat([scene_output["memory_tokens"], belief_output["memory_tokens"]], dim=1)
        # Fix: pool once and reuse — the original computed the same mean twice
        # (keepdim for "memory_token", again for the uncertainty input).
        pooled_memory = memory_tokens.mean(dim=1)
        memory_state = torch.cat([scene_output["memory_state"], belief_output["memory_state"]], dim=-1)
        return {
            "scene_memory_tokens": scene_output["memory_tokens"],
            "belief_memory_tokens": belief_output["memory_tokens"],
            "memory_tokens": memory_tokens,
            "memory_token": pooled_memory.unsqueeze(1),
            "memory_sequence": torch.cat(
                [scene_output["memory_sequence"], belief_output["memory_sequence"]],
                dim=1,
            ),
            "memory_state": memory_state,
            "memory_uncertainty": torch.nn.functional.softplus(self.uncertainty_head(pooled_memory)).squeeze(-1),
            "memory_write_rate": 0.5 * (scene_output["write_gate"] + belief_output["write_gate"]),
            "memory_saturation": 0.5 * (scene_output["saturation"] + belief_output["saturation"]),
            "scene_write_gate": scene_output["write_gate"],
            "belief_write_gate": belief_output["write_gate"],
            "memory_scene_state": scene_output["memory_state"],
            "memory_belief_state": belief_output["memory_state"],
        }