"""
BitMar Model for Hugging Face Transformers
BitNet-quantized Vision-Language Episodic Memory Transformer
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
import math
import os
import pickle
import gzip
from typing import Dict, List, Optional, Tuple, Union
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutput
import time
logger = logging.getLogger(__name__)
class BitMarConfig(PretrainedConfig):
"""Configuration class for BitMar model"""
model_type = "bitmar"
def __init__(
self,
vocab_size: int = 50257,
text_encoder_dim: int = 128,
text_encoder_layers: int = 4,
text_encoder_heads: int = 4,
text_decoder_dim: int = 128,
text_decoder_layers: int = 4,
text_decoder_heads: int = 4,
vision_encoder_dim: int = 768,
vision_latent_size: int = 128,
vision_hidden_size: int = 64,
vision_compression_method: str = "learned_compression",
vision_spatial_pooling: bool = True,
vision_pool_size: int = 2,
fusion_hidden_size: int = 128,
fusion_num_heads: int = 4,
fusion_num_layers: int = 2,
memory_size: int = 32,
episode_dim: int = 128,
memory_alpha: float = 0.2,
direct_writing: bool = True,
memory_compression: bool = True,
max_seq_len: int = 256,
dropout: float = 0.15,
initializer_range: float = 0.02,
layer_norm_epsilon: float = 1e-5,
use_cache: bool = True,
tie_word_embeddings: bool = True,
pad_token_id: int = 50256,
bos_token_id: int = 50256,
eos_token_id: int = 50256,
**kwargs
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs
)
self.vocab_size = vocab_size
self.text_encoder_dim = text_encoder_dim
self.text_encoder_layers = text_encoder_layers
self.text_encoder_heads = text_encoder_heads
self.text_decoder_dim = text_decoder_dim
self.text_decoder_layers = text_decoder_layers
self.text_decoder_heads = text_decoder_heads
self.vision_encoder_dim = vision_encoder_dim
self.vision_latent_size = vision_latent_size
self.vision_hidden_size = vision_hidden_size
self.vision_compression_method = vision_compression_method
self.vision_spatial_pooling = vision_spatial_pooling
self.vision_pool_size = vision_pool_size
self.fusion_hidden_size = fusion_hidden_size
self.fusion_num_heads = fusion_num_heads
self.fusion_num_layers = fusion_num_layers
self.memory_size = memory_size
self.episode_dim = episode_dim
self.memory_alpha = memory_alpha
self.direct_writing = direct_writing
self.memory_compression = memory_compression
self.max_seq_len = max_seq_len
self.dropout = dropout
self.initializer_range = initializer_range
self.layer_norm_epsilon = layer_norm_epsilon
self.use_cache = use_cache
self.tie_word_embeddings = tie_word_embeddings
class BitNetLinear(nn.Module):
"""1.58-bit Linear layer following BitNet b1.58 architecture - FIXED VERSION"""
def __init__(self, in_features: int, out_features: int, bias: bool = True):
super().__init__()
self.in_features = in_features
self.out_features = out_features
# Weight parameters (full precision for training)
self.weight = nn.Parameter(torch.randn(out_features, in_features))
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        # Quantization scale buffers (persisted with the module state)
self.register_buffer('weight_scale', torch.tensor(1.0))
self.register_buffer('input_scale', torch.tensor(1.0))
def quantize_weights_1_58_bit(self, weight: torch.Tensor) -> torch.Tensor:
"""BitNet b1.58 weight quantization: {-1, 0, +1}"""
# Handle empty tensors
if weight.numel() == 0:
return weight
# Compute scaling factor with numerical stability
scale = weight.abs().mean()
# Handle case where all weights are zero
if scale < 1e-8:
scale = torch.tensor(1e-5, device=weight.device, dtype=weight.dtype)
self.weight_scale.data = scale.clamp(min=1e-5, max=1e3)
# Normalize weights with gradient clipping
weight_norm = torch.clamp(weight / self.weight_scale, min=-10.0, max=10.0)
# 1.58-bit quantization with threshold
threshold = 2.0 / 3.0 # Optimal threshold for ternary quantization
# Create ternary weights
quantized = torch.zeros_like(weight_norm)
quantized[weight_norm > threshold] = 1.0
quantized[weight_norm < -threshold] = -1.0
# Values between -threshold and threshold remain 0
return quantized
def quantize_activations_8bit(self, x: torch.Tensor) -> torch.Tensor:
"""8-bit activation quantization with numerical stability"""
# Handle empty tensors
if x.numel() == 0:
return x
# Clamp extreme values to prevent overflow
x_clamped = torch.clamp(x, min=-1e6, max=1e6)
# Handle scalar tensors
if x_clamped.numel() == 1:
return x_clamped
# Compute quantization parameters
x_min, x_max = x_clamped.min(), x_clamped.max()
# Prevent division by zero
range_val = x_max - x_min
if range_val < 1e-8:
return x_clamped
scale = range_val / 255.0
self.input_scale.data = scale.clamp(min=1e-8, max=1e3)
# Quantize to 8-bit
zero_point = (-x_min / scale).round().clamp(0, 255)
quantized = ((x_clamped / scale) + zero_point).round().clamp(0, 255)
# Dequantize
dequantized = scale * (quantized - zero_point)
return dequantized
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.training:
# Full precision training with straight-through estimator
# Forward pass with quantized weights but gradients flow through original weights
weight_q = self.quantize_weights_1_58_bit(self.weight)
weight_forward = weight_q * self.weight_scale
# Use original weight for gradient computation
weight_forward = weight_forward + (self.weight - self.weight.detach())
return F.linear(x, weight_forward, self.bias)
else:
# Inference with full quantization
weight_q = self.quantize_weights_1_58_bit(self.weight) * self.weight_scale
x_q = self.quantize_activations_8bit(x)
return F.linear(x_q, weight_q, self.bias)
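# Illustrative sanity check (not part of the original model): verifies that the
# 1.58-bit quantizer above only emits values from {-1, 0, +1}. The layer sizes
# are arbitrary assumptions; call this manually, it never runs at import time.
def _bitnet_linear_quantization_demo() -> None:
    layer = BitNetLinear(in_features=16, out_features=8)
    ternary = layer.quantize_weights_1_58_bit(layer.weight)
    assert set(ternary.unique().tolist()) <= {-1.0, 0.0, 1.0}
    # At inference time the effective weight is the ternary pattern rescaled
    # by the per-tensor scale stored in `weight_scale`
    effective_weight = ternary * layer.weight_scale
    print("ternary values:", ternary.unique().tolist(),
          "scale:", float(layer.weight_scale),
          "effective shape:", tuple(effective_weight.shape))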
class BitNetMLP(nn.Module):
"""BitNet MLP block with 1.58-bit quantization"""
def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
super().__init__()
self.fc1 = BitNetLinear(dim, hidden_dim)
self.fc2 = BitNetLinear(hidden_dim, dim)
self.activation = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.norm = nn.LayerNorm(dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.fc1(x)
x = self.activation(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return self.norm(x + residual)
class BitNetAttention(nn.Module):
"""Multi-head attention with BitNet quantization"""
def __init__(
self,
dim: int,
num_heads: int,
dropout: float = 0.1,
bias: bool = True
):
super().__init__()
assert dim % num_heads == 0
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
# BitNet quantized projections
self.q_proj = BitNetLinear(dim, dim, bias=bias)
self.k_proj = BitNetLinear(dim, dim, bias=bias)
self.v_proj = BitNetLinear(dim, dim, bias=bias)
self.out_proj = BitNetLinear(dim, dim, bias=bias)
self.dropout = nn.Dropout(dropout)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, seq_len = query.shape[:2]
# Validate input dimensions
if query.size(-1) != self.dim:
raise ValueError(f"Query dimension {query.size(-1)} doesn't match expected {self.dim}")
if key.size(-1) != self.dim:
raise ValueError(f"Key dimension {key.size(-1)} doesn't match expected {self.dim}")
if value.size(-1) != self.dim:
raise ValueError(f"Value dimension {value.size(-1)} doesn't match expected {self.dim}")
# Linear projections
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
# Get key/value sequence length (handle different shapes)
key_seq_len = key.size(1)
# Reshape for multi-head attention with proper dimension checking
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, key_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, key_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# Attention computation
attention_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
if mask is not None:
# Handle mask shape: expand to match attention scores shape
if mask.dim() == 2: # [batch_size, seq_len]
mask = mask.unsqueeze(1).unsqueeze(1) # [batch_size, 1, 1, seq_len]
elif mask.dim() == 3: # [batch_size, seq_len, seq_len]
mask = mask.unsqueeze(1) # [batch_size, 1, seq_len, seq_len]
# Expand mask to match attention scores shape [batch_size, num_heads, seq_len, key_seq_len]
if mask.size(-1) != key_seq_len:
# Adjust mask if needed
if mask.size(-1) == seq_len:
# Pad or trim mask to match key_seq_len
if key_seq_len > seq_len:
pad_size = key_seq_len - seq_len
mask = torch.cat([mask, torch.zeros(*mask.shape[:-1], pad_size, device=mask.device, dtype=mask.dtype)], dim=-1)
else:
mask = mask[..., :key_seq_len]
mask = mask.expand(batch_size, self.num_heads, seq_len, key_seq_len)
attention_scores.masked_fill_(mask == 0, float('-inf'))
attention_weights = F.softmax(attention_scores, dim=-1)
attention_weights = self.dropout(attention_weights)
# Apply attention to values
attended = torch.matmul(attention_weights, v)
# Reshape and project output
attended = attended.transpose(1, 2).contiguous().view(
batch_size, seq_len, self.dim
)
output = self.out_proj(attended)
return output, attention_weights.mean(dim=1) # Average across heads
class BitNetTransformerBlock(nn.Module):
"""BitNet Transformer block with quantized components"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
dropout: float = 0.1
):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = BitNetAttention(dim, num_heads, dropout)
self.norm2 = nn.LayerNorm(dim)
self.mlp = BitNetMLP(dim, int(dim * mlp_ratio), dropout)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self-attention with residual connection
normed_x = self.norm1(x)
attn_out, attn_weights = self.attn(normed_x, normed_x, normed_x, mask)
x = x + attn_out
# MLP with residual connection
x = x + self.mlp(self.norm2(x))
return x, attn_weights
class BitNetTextEncoder(nn.Module):
"""BitNet-based text encoder"""
def __init__(
self,
vocab_size: int,
dim: int,
num_layers: int,
num_heads: int,
max_seq_len: int = 512,
dropout: float = 0.1
):
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
# Token embeddings (kept full precision)
self.token_embedding = nn.Embedding(vocab_size, dim)
self.position_embedding = nn.Embedding(max_seq_len, dim)
# BitNet transformer layers
self.layers = nn.ModuleList([
BitNetTransformerBlock(dim, num_heads, dropout=dropout)
for _ in range(num_layers)
])
self.dropout = nn.Dropout(dropout)
self.norm = nn.LayerNorm(dim)
# Initialize embeddings
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.position_embedding.weight, std=0.02)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
batch_size, seq_len = input_ids.shape
# Embeddings
positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
x = self.token_embedding(input_ids) + \
self.position_embedding(positions)
x = self.dropout(x)
# Transform through BitNet layers
attention_patterns = []
for layer in self.layers:
# Convert attention mask to the right format for the layer
layer_mask = None
if attention_mask is not None:
# Create a mask where 1 means attend, 0 means don't attend
                layer_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # [batch_size, 1, 1, seq_len]
x, attn_weights = layer(x, layer_mask)
attention_patterns.append(attn_weights)
x = self.norm(x)
return x, attention_patterns
class BitNetTextDecoder(nn.Module):
"""BitNet-based text decoder with causal masking"""
def __init__(
self,
vocab_size: int,
dim: int,
num_layers: int,
num_heads: int,
max_seq_len: int = 512,
dropout: float = 0.1
):
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
# Token embeddings
self.token_embedding = nn.Embedding(vocab_size, dim)
self.position_embedding = nn.Embedding(max_seq_len, dim)
# BitNet transformer layers
self.layers = nn.ModuleList([
BitNetTransformerBlock(dim, num_heads, dropout=dropout)
for _ in range(num_layers)
])
self.dropout = nn.Dropout(dropout)
self.norm = nn.LayerNorm(dim)
# Output projection to vocabulary
self.lm_head = BitNetLinear(dim, vocab_size, bias=False)
# Initialize embeddings
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.position_embedding.weight, std=0.02)
# Register causal mask
self.register_buffer(
'causal_mask',
torch.tril(torch.ones(max_seq_len, max_seq_len)
).unsqueeze(0).unsqueeze(0)
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None
) -> Dict[str, torch.Tensor]:
if input_ids is not None:
batch_size, seq_len = input_ids.shape
positions = torch.arange(
seq_len, device=input_ids.device).unsqueeze(0)
x = self.token_embedding(input_ids) + \
self.position_embedding(positions)
elif inputs_embeds is not None:
batch_size, seq_len = inputs_embeds.shape[:2]
positions = torch.arange(
seq_len, device=inputs_embeds.device).unsqueeze(0)
x = inputs_embeds + self.position_embedding(positions)
else:
raise ValueError(
"Either input_ids or inputs_embeds must be provided")
x = self.dropout(x)
# Create causal mask
causal_mask = self.causal_mask[:, :, :seq_len, :seq_len]
if attention_mask is not None:
# Combine causal mask with padding mask
mask = attention_mask.unsqueeze(1).unsqueeze(2) * causal_mask
else:
mask = causal_mask
# Transform through BitNet layers
attention_patterns = []
for layer in self.layers:
x, attn_weights = layer(x, mask)
attention_patterns.append(attn_weights)
x = self.norm(x)
logits = self.lm_head(x)
loss = None
if labels is not None:
# Shift labels for causal LM
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100
)
return {
'logits': logits,
'loss': loss,
'attention_patterns': attention_patterns
}
class EpisodicMemory(nn.Module):
"""Episodic Memory mechanism inspired by Larimar with performance optimizations and external storage support"""
def __init__(
self,
memory_size: int,
episode_dim: int,
alpha: float = 0.1,
direct_writing: bool = True,
observation_noise_std: float = 1e-6,
external_storage: bool = False,
memory_storage_path: str = None,
compression_enabled: bool = True,
lazy_loading: bool = False
):
super().__init__()
self.memory_size = memory_size
self.episode_dim = episode_dim
self.alpha = alpha
self.direct_writing = direct_writing
self.observation_noise_std = observation_noise_std
# External storage configuration
self.external_storage = external_storage
self.memory_storage_path = memory_storage_path
self.compression_enabled = compression_enabled
self.lazy_loading = lazy_loading
self._memory_loaded = False
self._memory_version = 1
# Memory storage with improved initialization
if external_storage and lazy_loading:
# For lazy loading, we'll initialize empty and load when needed
self._memory_data = None
self._metadata = None
else:
# Standard initialization for compatibility
self.register_buffer('memory', torch.randn(memory_size, episode_dim) * 0.02)
self.register_buffer('memory_age', torch.zeros(memory_size))
self.register_buffer('memory_usage', torch.zeros(memory_size))
# Always initialize these for proper functioning
self.register_buffer('memory_quality', torch.zeros(memory_size))
self.register_buffer('memory_importance', torch.ones(memory_size))
self.register_buffer('memory_mean', torch.zeros(episode_dim))
self.register_buffer('memory_std', torch.ones(episode_dim))
self.register_buffer('update_count', torch.tensor(0))
# Enhanced memory access networks with residual connections
self.query_net = nn.Sequential(
BitNetLinear(episode_dim, episode_dim),
nn.LayerNorm(episode_dim),
nn.GELU(),
BitNetLinear(episode_dim, episode_dim)
)
self.key_net = nn.Sequential(
BitNetLinear(episode_dim, episode_dim),
nn.LayerNorm(episode_dim),
nn.GELU(),
BitNetLinear(episode_dim, episode_dim)
)
self.value_net = nn.Sequential(
BitNetLinear(episode_dim, episode_dim),
nn.LayerNorm(episode_dim),
nn.GELU(),
BitNetLinear(episode_dim, episode_dim)
)
# Add temperature parameter for attention sharpening
self.register_parameter('attention_temperature', nn.Parameter(torch.tensor(1.0)))
# Memory consolidation network for better episode encoding
self.consolidation_net = nn.Sequential(
BitNetLinear(episode_dim, episode_dim * 2),
nn.LayerNorm(episode_dim * 2),
nn.GELU(),
nn.Dropout(0.1),
BitNetLinear(episode_dim * 2, episode_dim),
nn.LayerNorm(episode_dim)
)
def _ensure_memory_loaded(self):
"""Ensure memory is loaded into device memory"""
if self.external_storage and self.lazy_loading and not self._memory_loaded:
self.load_external_memory()
elif not hasattr(self, 'memory'):
# Initialize if not present (compatibility mode)
self.register_buffer('memory', torch.randn(self.memory_size, self.episode_dim) * 0.02)
self.register_buffer('memory_age', torch.zeros(self.memory_size))
self.register_buffer('memory_usage', torch.zeros(self.memory_size))
def save_external_memory(self, path: str = None, compress: bool = None) -> str:
"""Save episodic memory to external storage"""
import os
import json
from pathlib import Path
# Use provided path or default
save_path = path or self.memory_storage_path or "episodic_memory.pt"
save_path = Path(save_path)
save_path.parent.mkdir(parents=True, exist_ok=True)
# Use provided compression setting or default
use_compression = compress if compress is not None else self.compression_enabled
# Prepare memory data
memory_data = {
'memory': self.memory.cpu() if hasattr(self, 'memory') else torch.randn(self.memory_size, self.episode_dim) * 0.02,
'memory_age': self.memory_age.cpu() if hasattr(self, 'memory_age') else torch.zeros(self.memory_size),
'memory_usage': self.memory_usage.cpu() if hasattr(self, 'memory_usage') else torch.zeros(self.memory_size),
'memory_quality': self.memory_quality.cpu(),
'memory_importance': self.memory_importance.cpu(),
'memory_mean': self.memory_mean.cpu(),
'memory_std': self.memory_std.cpu(),
'update_count': self.update_count.cpu(),
'version': self._memory_version,
'metadata': {
'memory_size': self.memory_size,
'episode_dim': self.episode_dim,
'alpha': self.alpha,
'creation_timestamp': torch.tensor(time.time()),
'compression_enabled': use_compression
}
}
# Apply compression if enabled
if use_compression:
# Quantize memory to reduce storage size
memory_data['memory'] = self._compress_memory_tensor(memory_data['memory'])
memory_data['compressed'] = True
else:
memory_data['compressed'] = False
# Save to file
torch.save(memory_data, save_path)
# Also save metadata separately for quick access
metadata_path = save_path.with_suffix('.json')
with open(metadata_path, 'w') as f:
json.dump({
'memory_size': self.memory_size,
'episode_dim': self.episode_dim,
'version': self._memory_version,
'compressed': use_compression,
'file_size_mb': save_path.stat().st_size / (1024 * 1024),
'creation_timestamp': time.time()
}, f, indent=2)
logger.info(f"💾 Episodic memory saved to: {save_path}")
logger.info(f"📊 Memory size: {save_path.stat().st_size / 1024:.1f} KB")
return str(save_path)
def load_external_memory(self, path: str = None, device: str = None) -> bool:
"""Load episodic memory from external storage"""
import json
from pathlib import Path
# Use provided path or default
load_path = path or self.memory_storage_path or "episodic_memory.pt"
load_path = Path(load_path)
if not load_path.exists():
logger.warning(f"⚠️ External memory file not found: {load_path}")
return False
try:
# Load memory data
memory_data = torch.load(load_path, map_location='cpu')
# Validate compatibility
if memory_data['metadata']['memory_size'] != self.memory_size:
logger.error(f"❌ Memory size mismatch: expected {self.memory_size}, got {memory_data['metadata']['memory_size']}")
return False
if memory_data['metadata']['episode_dim'] != self.episode_dim:
logger.error(f"❌ Episode dimension mismatch: expected {self.episode_dim}, got {memory_data['metadata']['episode_dim']}")
return False
# Set device
device = device or next(self.parameters()).device
# Decompress if needed
if memory_data.get('compressed', False):
memory_tensor = self._decompress_memory_tensor(memory_data['memory'])
else:
memory_tensor = memory_data['memory']
# Load memory tensors
if hasattr(self, 'memory'):
self.memory.copy_(memory_tensor.to(device))
self.memory_age.copy_(memory_data['memory_age'].to(device))
self.memory_usage.copy_(memory_data['memory_usage'].to(device))
else:
# Register buffers if not present (lazy loading case)
self.register_buffer('memory', memory_tensor.to(device))
self.register_buffer('memory_age', memory_data['memory_age'].to(device))
self.register_buffer('memory_usage', memory_data['memory_usage'].to(device))
self.memory_quality.copy_(memory_data['memory_quality'].to(device))
self.memory_importance.copy_(memory_data['memory_importance'].to(device))
self.memory_mean.copy_(memory_data['memory_mean'].to(device))
self.memory_std.copy_(memory_data['memory_std'].to(device))
self.update_count.copy_(memory_data['update_count'].to(device))
self._memory_version = memory_data.get('version', 1)
self._memory_loaded = True
logger.info(f"✅ Episodic memory loaded from: {load_path}")
logger.info(f"📊 Memory version: {self._memory_version}")
return True
except Exception as e:
logger.error(f"❌ Failed to load external memory: {e}")
return False
    def _compress_memory_tensor(self, tensor: torch.Tensor) -> Union[torch.Tensor, Dict]:
"""Compress memory tensor for storage"""
# Quantize to int8 to reduce storage size
tensor_min = tensor.min()
tensor_max = tensor.max()
# Avoid division by zero
tensor_range = tensor_max - tensor_min
if tensor_range < 1e-8:
return tensor
# Quantize to int8 range
quantized = ((tensor - tensor_min) / tensor_range * 255).round().clamp(0, 255).to(torch.uint8)
# Store quantization parameters
return {
'data': quantized,
'min': tensor_min,
'max': tensor_max,
'original_shape': tensor.shape
}
def _decompress_memory_tensor(self, compressed_data) -> torch.Tensor:
"""Decompress memory tensor"""
if isinstance(compressed_data, dict):
quantized = compressed_data['data'].float()
tensor_min = compressed_data['min']
tensor_max = compressed_data['max']
# Dequantize
tensor_range = tensor_max - tensor_min
dequantized = (quantized / 255.0) * tensor_range + tensor_min
return dequantized.view(compressed_data['original_shape'])
else:
# Not compressed, return as-is
return compressed_data
def _update_memory_statistics(self, episodes: torch.Tensor):
"""Update running statistics for memory normalization"""
with torch.no_grad():
batch_mean = episodes.mean(dim=0)
batch_var = episodes.var(dim=0, unbiased=False)
# Exponential moving average
momentum = 0.1
self.memory_mean = (1 - momentum) * self.memory_mean + momentum * batch_mean
self.memory_std = torch.sqrt((1 - momentum) * self.memory_std**2 + momentum * batch_var)
self.update_count += 1
def _normalize_episodes(self, episodes: torch.Tensor) -> torch.Tensor:
"""Normalize episodes using running statistics"""
if self.update_count > 10: # Only normalize after some updates
return (episodes - self.memory_mean) / (self.memory_std + 1e-8)
return episodes
def _compute_episode_quality(self, episode: torch.Tensor, retrieved: torch.Tensor) -> torch.Tensor:
"""Compute quality score for memory episodes"""
# Quality based on diversity and relevance
similarity_to_memory = torch.cosine_similarity(
episode.unsqueeze(1), self.memory.unsqueeze(0), dim=-1
).max(dim=1)[0]
# Encourage diversity - lower similarity = higher quality
diversity_score = 1.0 - similarity_to_memory
# Relevance score based on retrieval quality
retrieval_quality = torch.cosine_similarity(episode, retrieved, dim=-1)
# Combined quality score
return 0.7 * diversity_score + 0.3 * retrieval_quality
def write_memory(self, episode: torch.Tensor) -> torch.Tensor:
"""Optimized memory writing with intelligent slot selection"""
batch_size = episode.size(0)
# Apply consolidation to improve episode representation
consolidated_episode = self.consolidation_net(episode) + episode # Residual connection
# Update statistics
self._update_memory_statistics(consolidated_episode)
# Normalize episodes
normalized_episode = self._normalize_episodes(consolidated_episode)
if self.direct_writing:
# Enhanced slot selection combining age, usage, and quality
if batch_size <= self.memory_size:
# Compute composite scores for slot selection
age_scores = -self.memory_age # Prefer older slots
usage_scores = -self.memory_usage # Prefer less used slots
quality_scores = -self.memory_quality # Prefer lower quality slots
importance_scores = -self.memory_importance # Prefer less important slots
# Weighted combination
composite_scores = (
0.4 * age_scores +
0.3 * usage_scores +
0.2 * quality_scores +
0.1 * importance_scores
)
_, best_indices = composite_scores.topk(batch_size, largest=True)
# Update memory slots with momentum-based updates
momentum = self.alpha
self.memory[best_indices] = (
(1 - momentum) * self.memory[best_indices] +
momentum * normalized_episode.detach()
)
# Update metadata
self.memory_age[best_indices] = self.memory_age.max() + 1
self.memory_usage[best_indices] += 1
# Update quality scores (will be computed during read)
with torch.no_grad():
# Temporary quality estimation based on internal consistency
temp_quality = torch.norm(normalized_episode, dim=-1)
self.memory_quality[best_indices] = temp_quality.detach()
else:
# Handle large batches efficiently
for i in range(0, batch_size, self.memory_size):
end_idx = min(i + self.memory_size, batch_size)
chunk_size = end_idx - i
# Apply same logic for chunks
age_scores = -self.memory_age
usage_scores = -self.memory_usage
quality_scores = -self.memory_quality
importance_scores = -self.memory_importance
composite_scores = (
0.4 * age_scores +
0.3 * usage_scores +
0.2 * quality_scores +
0.1 * importance_scores
)
_, chunk_indices = composite_scores.topk(chunk_size, largest=True)
momentum = self.alpha
self.memory[chunk_indices] = (
(1 - momentum) * self.memory[chunk_indices] +
momentum * normalized_episode[i:end_idx].detach()
)
self.memory_age[chunk_indices] = self.memory_age.max() + 1 + i
self.memory_usage[chunk_indices] += 1
temp_quality = torch.norm(normalized_episode[i:end_idx], dim=-1)
self.memory_quality[chunk_indices] = temp_quality.detach()
return consolidated_episode
def read_memory(self, query: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Optimized memory reading with enhanced attention"""
batch_size = query.size(0)
# Validate query dimensions
if query.size(-1) != self.episode_dim:
raise ValueError(f"Query dimension {query.size(-1)} doesn't match memory episode_dim {self.episode_dim}")
# Normalize query
normalized_query = self._normalize_episodes(query)
# Enhanced query, key, value computation with residual connections
q = self.query_net(normalized_query) + normalized_query # Residual
k = self.key_net(self.memory) + self.memory # Residual
v = self.value_net(self.memory) + self.memory # Residual
# Scaled dot-product attention with learnable temperature
attention_scores = torch.matmul(q, k.transpose(0, 1)) / (
math.sqrt(self.episode_dim) * self.attention_temperature.clamp(min=0.1, max=10.0)
)
# Add importance weighting to attention scores
importance_weights = self.memory_importance.unsqueeze(0).expand(batch_size, -1)
attention_scores = attention_scores + torch.log(importance_weights + 1e-8)
# Apply attention with improved stability
attention_weights = F.softmax(attention_scores, dim=-1)
# Add attention dropout for regularization during training
if self.training:
attention_weights = F.dropout(attention_weights, p=0.1)
# Weighted memory retrieval
retrieved = torch.matmul(attention_weights, v)
# Update memory access statistics and importance
with torch.no_grad():
access_counts = attention_weights.sum(0)
self.memory_usage += access_counts
# Update importance based on usage frequency
self.memory_importance = 0.9 * self.memory_importance + 0.1 * (access_counts + 1e-8)
# Update quality scores based on retrieval effectiveness
if hasattr(self, '_last_query_quality'):
quality_update = self._compute_episode_quality(query, retrieved)
# Update quality for attended slots
attended_indices = attention_weights.max(0)[1] # Most attended slots
self.memory_quality[attended_indices] = (
0.8 * self.memory_quality[attended_indices] +
0.2 * quality_update.mean()
)
return retrieved, attention_weights
    def forward(self, episode: torch.Tensor, mode: str = "read_write") -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Enhanced forward pass with memory consolidation"""
if mode == "write":
return self.write_memory(episode), None
elif mode == "read":
return self.read_memory(episode)
else: # read_write
# Write episode to memory with consolidation
consolidated_episode = self.write_memory(episode)
# Read from memory using consolidated episode as query
retrieved, attention_weights = self.read_memory(consolidated_episode)
# Memory-augmented output combining input and retrieved memory
output = 0.7 * consolidated_episode + 0.3 * retrieved
return output, attention_weights
def get_memory_statistics(self) -> Dict[str, torch.Tensor]:
"""Get comprehensive memory statistics for monitoring"""
return {
'memory_usage_distribution': self.memory_usage,
'memory_age_distribution': self.memory_age,
'memory_quality_scores': self.memory_quality,
'memory_importance': self.memory_importance,
'attention_temperature': self.attention_temperature,
'memory_utilization': (self.memory_usage > 0).float().mean(),
'memory_diversity': torch.std(self.memory, dim=0).mean(),
'update_count': self.update_count
}
def consolidate_memory(self):
"""Explicit memory consolidation for improved organization"""
with torch.no_grad():
# Sort memory by importance and quality
importance_quality_score = 0.6 * self.memory_importance + 0.4 * self.memory_quality
sorted_indices = torch.argsort(importance_quality_score, descending=True)
# Reorganize memory to group similar episodes
sorted_memory = self.memory[sorted_indices]
self.memory.copy_(sorted_memory)
# Update corresponding metadata
self.memory_age[:] = self.memory_age[sorted_indices]
self.memory_usage[:] = self.memory_usage[sorted_indices]
self.memory_quality[:] = self.memory_quality[sorted_indices]
self.memory_importance[:] = self.memory_importance[sorted_indices]
def get_memory_info(self) -> Dict:
"""Get comprehensive memory information"""
info = {
'memory_size': self.memory_size,
'episode_dim': self.episode_dim,
'external_storage': self.external_storage,
'compression_enabled': self.compression_enabled,
'lazy_loading': self.lazy_loading,
'memory_loaded': self._memory_loaded if self.external_storage else True,
'version': self._memory_version,
'storage_path': self.memory_storage_path
}
if hasattr(self, 'memory'):
info.update({
'memory_utilization': (self.memory_usage > 0).float().mean().item(),
'memory_diversity': torch.std(self.memory, dim=0).mean().item(),
'update_count': self.update_count.item(),
'memory_device': str(self.memory.device)
})
return info
def create_memory_snapshot(self, snapshot_name: str = None) -> str:
"""Create a named snapshot of the current memory state"""
import time
from pathlib import Path
timestamp = int(time.time())
snapshot_name = snapshot_name or f"memory_snapshot_{timestamp}"
# Create snapshots directory
snapshots_dir = Path("memory_snapshots")
snapshots_dir.mkdir(exist_ok=True)
snapshot_path = snapshots_dir / f"{snapshot_name}.pt"
# Save current memory state
saved_path = self.save_external_memory(str(snapshot_path), compress=True)
logger.info(f"📸 Memory snapshot created: {saved_path}")
return saved_path
def load_memory_snapshot(self, snapshot_name: str) -> bool:
"""Load a named memory snapshot"""
from pathlib import Path
snapshots_dir = Path("memory_snapshots")
snapshot_path = snapshots_dir / f"{snapshot_name}.pt"
if not snapshot_path.exists():
logger.warning(f"⚠️ Snapshot not found: {snapshot_path}")
return False
success = self.load_external_memory(str(snapshot_path))
if success:
logger.info(f"📸 Memory snapshot loaded: {snapshot_name}")
return success
def enable_external_storage(self, storage_path: str = None, compress: bool = True, lazy: bool = False):
"""Enable external storage mode for edge deployment"""
self.external_storage = True
self.memory_storage_path = storage_path or "episodic_memory.pt"
self.compression_enabled = compress
self.lazy_loading = lazy
logger.info(f"🔄 External storage enabled: {self.memory_storage_path}")
logger.info(f" Compression: {compress}, Lazy loading: {lazy}")
def disable_external_storage(self):
"""Disable external storage and return to integrated mode"""
# Ensure memory is loaded before disabling external storage
self._ensure_memory_loaded()
self.external_storage = False
self.lazy_loading = False
self._memory_loaded = True
logger.info("🔄 External storage disabled, using integrated mode")
# ...existing code for other methods...
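# Illustrative round-trip of the external-storage helpers above (not part of the
# original model). The path, memory size, and episode dimension are assumptions;
# call this manually, it never runs at import time and it writes files to disk.
def _episodic_memory_storage_demo(path: str = "demo_episodic_memory.pt") -> None:
    memory = EpisodicMemory(memory_size=32, episode_dim=128)
    memory.write_memory(torch.randn(4, 128))        # populate a few slots
    saved_path = memory.save_external_memory(path)  # compressed by default
    restored = EpisodicMemory(memory_size=32, episode_dim=128)
    assert restored.load_external_memory(saved_path)  # sizes must match
    print(restored.get_memory_info())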
class CrossModalFusion(nn.Module):
"""Cross-modal fusion module for text and vision features"""
def __init__(
self,
text_dim: int,
vision_dim: int,
hidden_dim: int,
num_heads: int = 8,
num_layers: int = 2
):
super().__init__()
self.text_dim = text_dim
self.vision_dim = vision_dim
self.hidden_dim = hidden_dim
# Projection layers
self.text_proj = BitNetLinear(text_dim, hidden_dim)
self.vision_proj = BitNetLinear(vision_dim, hidden_dim)
# Cross-attention layers
self.cross_attention_layers = nn.ModuleList([
BitNetAttention(
dim=hidden_dim,
num_heads=num_heads
) for _ in range(num_layers)
])
# Layer normalization
self.layer_norms = nn.ModuleList([
nn.LayerNorm(hidden_dim) for _ in range(num_layers)
])
# Output projection
self.output_proj = BitNetLinear(hidden_dim, hidden_dim)
def forward(
self,
text_features: torch.Tensor,
vision_features: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
text_features: [batch_size, seq_len, text_dim]
vision_features: [batch_size, vision_dim]
Returns:
fused_features: [batch_size, seq_len, hidden_dim]
attention_weights: Dict of attention patterns
"""
batch_size, seq_len = text_features.shape[:2]
# Validate input dimensions
if text_features.size(-1) != self.text_dim:
raise ValueError(f"Text features dimension {text_features.size(-1)} doesn't match expected {self.text_dim}")
if vision_features.size(-1) != self.vision_dim:
raise ValueError(f"Vision features dimension {vision_features.size(-1)} doesn't match expected {self.vision_dim}")
# Project to common dimension
# [batch_size, seq_len, hidden_dim]
text_proj = self.text_proj(text_features)
vision_proj = self.vision_proj(vision_features).unsqueeze(1) # [batch_size, 1, hidden_dim]
# Cross-attention fusion
fused = text_proj
attention_weights = {}
for i, (attn_layer, norm_layer) in enumerate(zip(self.cross_attention_layers, self.layer_norms)):
# Text-to-vision cross-attention
attn_output, attn_weights = attn_layer(
query=fused,
key=vision_proj,
value=vision_proj
)
# Residual connection and normalization
fused = norm_layer(fused + attn_output)
attention_weights[f'layer_{i}'] = attn_weights
# Output projection
output = self.output_proj(fused)
return output, attention_weights
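# Illustrative shape check for the fusion module above (not part of the original
# model). Dimensions mirror the BitMarConfig defaults but are assumptions here;
# call this manually, it never runs at import time.
def _cross_modal_fusion_shape_demo() -> None:
    fusion = CrossModalFusion(text_dim=128, vision_dim=128, hidden_dim=128,
                              num_heads=4, num_layers=2)
    text_features = torch.randn(2, 16, 128)   # [batch_size, seq_len, text_dim]
    vision_features = torch.randn(2, 128)     # [batch_size, vision_dim]
    fused, attention = fusion(text_features, vision_features)
    assert fused.shape == (2, 16, 128)        # [batch_size, seq_len, hidden_dim]
    print("fusion attention layers:", list(attention.keys()))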
class VisionEncoder(nn.Module):
"""Quantized Vision Encoder for DiNOv2 features"""
def __init__(
self,
input_dim: int = 768,
hidden_dim: int = 512,
output_dim: int = 768,
num_layers: int = 2
):
super().__init__()
# Quantized layers
self.layers = nn.ModuleList([
BitNetLinear(input_dim if i == 0 else hidden_dim, hidden_dim)
for i in range(num_layers)
])
# Output projection
self.output_proj = BitNetLinear(hidden_dim, output_dim)
# Activation and normalization
self.activation = nn.GELU()
self.layer_norms = nn.ModuleList([
nn.LayerNorm(hidden_dim) for _ in range(num_layers)
])
self.dropout = nn.Dropout(0.1)
def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
"""
Args:
vision_features: [batch_size, input_dim] - DiNOv2 features
Returns:
encoded_features: [batch_size, output_dim]
"""
# Handle potential extra dimensions
if vision_features.dim() > 2:
# Flatten any extra dimensions except batch
original_shape = vision_features.shape
vision_features = vision_features.view(original_shape[0], -1)
# Ensure we have the expected input dimension
if vision_features.size(-1) != self.layers[0].in_features:
# Take only the first input_dim features if we have more
if vision_features.size(-1) > self.layers[0].in_features:
vision_features = vision_features[:, :self.layers[0].in_features]
else:
raise ValueError(f"Vision features dimension {vision_features.size(-1)} is smaller than expected {self.layers[0].in_features}")
x = vision_features
for layer, norm in zip(self.layers, self.layer_norms):
x = layer(x)
x = norm(x)
x = self.activation(x)
x = self.dropout(x)
# Output projection
output = self.output_proj(x)
return output
class BitMarModel(PreTrainedModel):
"""
BitMar: BitNet-quantized Vision-Language Episodic Memory Transformer
Compatible with Hugging Face Transformers
"""
config_class = BitMarConfig
base_model_prefix = "bitmar"
supports_gradient_checkpointing = True
_no_split_modules = ["BitNetTransformerBlock", "EpisodicMemory"]
def __init__(self, config: BitMarConfig):
super().__init__(config)
self.config = config
# Loss balancing parameters
self.cross_modal_loss_weight = getattr(config, 'cross_modal_loss_weight', 0.1)
self.text_loss_weight = getattr(config, 'text_loss_weight', 1.0)
self.vision_loss_weight = getattr(config, 'vision_loss_weight', 0.1)
self.memory_loss_weight = getattr(config, 'memory_loss_weight', 0.05)
# Dynamic loss scaling
self.adaptive_loss_scaling = getattr(config, 'adaptive_loss_scaling', True)
self.loss_scale_temperature = getattr(config, 'loss_scale_temperature', 0.07)
# Encoder freezing parameters
self.freeze_text_encoder_steps = getattr(config, 'freeze_text_encoder_steps', 0)
self.freeze_vision_encoder_steps = getattr(config, 'freeze_vision_encoder_steps', 0)
self.current_step = 0
# BitNet text encoder/decoder
self.text_encoder = BitNetTextEncoder(
vocab_size=config.vocab_size,
dim=config.text_encoder_dim,
num_layers=config.text_encoder_layers,
num_heads=config.text_encoder_heads,
max_seq_len=config.max_seq_len,
dropout=config.dropout
)
self.text_decoder = BitNetTextDecoder(
vocab_size=config.vocab_size,
dim=config.text_decoder_dim,
num_layers=config.text_decoder_layers,
num_heads=config.text_decoder_heads,
max_seq_len=config.max_seq_len,
dropout=config.dropout
)
# Vision processing with BitNet quantization
self.vision_encoder = VisionEncoder(
input_dim=config.vision_encoder_dim,
hidden_dim=config.vision_hidden_size,
output_dim=config.vision_latent_size
)
# Cross-modal fusion with BitNet
self.fusion = CrossModalFusion(
text_dim=config.text_encoder_dim,
vision_dim=config.vision_latent_size,
hidden_dim=config.fusion_hidden_size,
num_heads=config.fusion_num_heads,
num_layers=config.fusion_num_layers
)
# Episodic memory with BitNet quantization
self.memory = EpisodicMemory(
memory_size=config.memory_size,
episode_dim=config.episode_dim,
alpha=config.memory_alpha,
direct_writing=config.direct_writing
)
# Additional BitNet projection layers
self.text_to_episode = BitNetLinear(
config.text_encoder_dim,
config.episode_dim
)
self.vision_to_episode = BitNetLinear(
config.vision_latent_size,
config.episode_dim
)
self.memory_to_decoder = BitNetLinear(
config.episode_dim,
config.fusion_hidden_size
)
# Projection to decoder dimension
self.decoder_input_proj = BitNetLinear(
config.fusion_hidden_size,
config.text_decoder_dim
)
# Initialize tokenizer (for compatibility)
try:
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained('gpt2')
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
        except Exception:
            # Tokenizer is optional; fall back gracefully if it cannot be loaded
            self.tokenizer = None
self.post_init()
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, BitNetLinear)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if hasattr(module, 'bias') and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
if hasattr(module, 'bias') and module.bias is not None:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Encode text using BitNet encoder"""
text_features, attention_patterns = self.text_encoder(
input_ids=input_ids, attention_mask=attention_mask)
return text_features, attention_patterns
def encode_vision(self, vision_features: torch.Tensor) -> torch.Tensor:
"""Encode vision features using quantized vision encoder"""
vision_latent = self.vision_encoder(vision_features)
return vision_latent
def create_episode(
self,
text_features: torch.Tensor,
vision_latent: torch.Tensor,
attention_weights: Dict[str, torch.Tensor]
) -> torch.Tensor:
"""Create multimodal episode for memory storage"""
# Pool text features (mean pooling)
text_pooled = text_features.mean(dim=1)
# Project both text and vision to episode dimension
text_projected = self.text_to_episode(text_pooled)
vision_projected = self.vision_to_episode(vision_latent)
# Combine text and vision features
episode = text_projected + vision_projected
return episode
def create_episode_mixed(
self,
text_features: torch.Tensor,
vision_latent: torch.Tensor,
attention_weights: Dict[str, torch.Tensor],
has_vision: torch.Tensor
) -> torch.Tensor:
"""Create episodes with different handling for vision vs text-only samples"""
batch_size = text_features.size(0)
# Pool text features
text_pooled = text_features.mean(dim=1)
# Project to episode dimension
text_episode = self.text_to_episode(text_pooled)
vision_episode = self.vision_to_episode(vision_latent)
# For text-only samples, use only text features
# For multimodal samples, combine text and vision
episode = torch.zeros_like(text_episode)
# Text-only samples (has_vision == False)
text_only_mask = ~has_vision
if text_only_mask.any():
episode[text_only_mask] = text_episode[text_only_mask]
# Multimodal samples (has_vision == True)
multimodal_mask = has_vision
if multimodal_mask.any():
# Combine text and vision for multimodal samples
combined = text_episode[multimodal_mask] + vision_episode[multimodal_mask]
episode[multimodal_mask] = combined
return episode
def compute_cross_modal_contrastive_loss(
self,
text_features: torch.Tensor,
vision_features: torch.Tensor,
temperature: float = 0.07
) -> torch.Tensor:
"""Compute cross-modal contrastive loss similar to CLIP"""
batch_size = text_features.shape[0]
# Handle dimension mismatch between text and vision features
text_dim = text_features.shape[-1]
vision_dim = vision_features.shape[-1]
        if text_dim != vision_dim:
            # Align by truncating the larger feature to the smaller dimension
            # (no learned projection is applied here)
            target_dim = min(text_dim, vision_dim)
            if text_dim > vision_dim:
                # Truncate text features to the vision dimension
                text_features = text_features[:, :target_dim]
            else:
                # Truncate vision features to the text dimension
                vision_features = vision_features[:, :target_dim]
# Normalize features
text_features = F.normalize(text_features, dim=-1)
vision_features = F.normalize(vision_features, dim=-1)
# Compute similarity matrix
logits = torch.matmul(text_features, vision_features.T) / temperature
# Create labels (diagonal should be positive pairs)
labels = torch.arange(batch_size, device=logits.device)
# Compute cross-entropy loss for both directions
text_to_vision_loss = F.cross_entropy(logits, labels)
vision_to_text_loss = F.cross_entropy(logits.T, labels)
return (text_to_vision_loss + vision_to_text_loss) / 2
def compute_vision_reconstruction_loss(
self,
original_vision: torch.Tensor,
reconstructed_vision: torch.Tensor
) -> torch.Tensor:
"""Compute vision reconstruction loss to prevent vision encoder collapse"""
return F.mse_loss(reconstructed_vision, original_vision)
def compute_memory_consistency_loss(
self,
episode: torch.Tensor,
retrieved_memory: torch.Tensor
) -> torch.Tensor:
"""Compute memory consistency loss to encourage meaningful memory usage"""
# L2 regularization on memory difference
memory_diff = episode - retrieved_memory
return torch.mean(torch.norm(memory_diff, dim=-1))
def compute_balanced_loss(
self,
decoder_loss: torch.Tensor,
cross_modal_loss: torch.Tensor,
vision_loss: Optional[torch.Tensor] = None,
memory_loss: Optional[torch.Tensor] = None,
step: int = 0,
adaptive_controller=None
) -> Dict[str, torch.Tensor]:
"""Compute balanced multi-objective loss with adaptive scaling"""
losses = {'decoder_loss': decoder_loss, 'cross_modal_loss': cross_modal_loss}
if vision_loss is not None:
losses['vision_loss'] = vision_loss
if memory_loss is not None:
losses['memory_loss'] = memory_loss
if self.adaptive_loss_scaling:
# Adaptive scaling based on loss magnitudes
with torch.no_grad():
# Compute relative loss scales
decoder_scale = decoder_loss.detach()
cross_modal_scale = cross_modal_loss.detach()
# Prevent division by zero
if decoder_scale > 1e-8:
adaptive_cross_modal_weight = (decoder_scale / cross_modal_scale.clamp(min=1e-8)) * self.cross_modal_loss_weight
else:
adaptive_cross_modal_weight = self.cross_modal_loss_weight
# Clamp adaptive weights
adaptive_cross_modal_weight = torch.clamp(adaptive_cross_modal_weight, 0.01, 1.0)
else:
adaptive_cross_modal_weight = self.cross_modal_loss_weight
# Apply loss scheduling (increase cross-modal importance over time)
cross_modal_schedule = min(1.0, step / 50000) # Ramp up over 50k steps
scheduled_cross_modal_weight = adaptive_cross_modal_weight * cross_modal_schedule
# Compute weighted total loss
total_loss = (
self.text_loss_weight * decoder_loss +
scheduled_cross_modal_weight * cross_modal_loss
)
if vision_loss is not None:
total_loss += self.vision_loss_weight * vision_loss
if memory_loss is not None:
total_loss += self.memory_loss_weight * memory_loss
losses.update({
'total_loss': total_loss,
'cross_modal_weight': scheduled_cross_modal_weight,
'adaptive_weight': adaptive_cross_modal_weight if self.adaptive_loss_scaling else torch.tensor(0.0)
})
return losses
def apply_encoder_freezing(self, step: int):
"""Apply temporary encoder freezing based on training step"""
self.current_step = step
# Freeze text encoder if within freezing window
freeze_text = step < self.freeze_text_encoder_steps
for param in self.text_encoder.parameters():
param.requires_grad = not freeze_text
# Freeze vision encoder if within freezing window
freeze_vision = step < self.freeze_vision_encoder_steps
for param in self.vision_encoder.parameters():
param.requires_grad = not freeze_vision
return {
'text_encoder_frozen': freeze_text,
'vision_encoder_frozen': freeze_vision
}
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
vision_features: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
mode: str = "train",
step: int = 0,
has_vision: Optional[torch.Tensor] = None,
**kwargs
) -> Union[Tuple, CausalLMOutput]:
"""
Forward pass through BitMar model with mixed vision/text batch support
Args:
has_vision: Boolean tensor [batch_size] indicating which samples have real vision features
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        # Ensure input_ids are integer token indices
        if input_ids.dtype != torch.long:
            input_ids = input_ids.long()
        # Ensure labels are integer token indices if provided
        if labels is not None and labels.dtype != torch.long:
            labels = labels.long()
batch_size, seq_len = input_ids.shape
# Handle missing attention mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.float)
# Ensure attention_mask is float
if attention_mask.dtype != torch.float:
attention_mask = attention_mask.float()
# Handle missing vision features
if vision_features is None:
vision_features = torch.zeros(batch_size, self.config.vision_encoder_dim,
device=input_ids.device, dtype=torch.float32)
# Validate input tensor dimensions
expected_vision_dim = self.config.vision_encoder_dim
if vision_features.dim() != 2 or vision_features.size(-1) != expected_vision_dim:
if vision_features.dim() > 2:
vision_features = vision_features.view(batch_size, -1)
if vision_features.size(-1) != expected_vision_dim:
# Pad or trim to expected dimension
if vision_features.size(-1) > expected_vision_dim:
vision_features = vision_features[:, :expected_vision_dim]
else:
padding = expected_vision_dim - vision_features.size(-1)
vision_features = F.pad(vision_features, (0, padding))
# Default has_vision to all True if not provided (backward compatibility)
if has_vision is None:
has_vision = torch.ones(batch_size, dtype=torch.bool, device=input_ids.device)
# Apply encoder freezing
freezing_status = {}
if mode == "train":
freezing_status = self.apply_encoder_freezing(step)
# Encode text (always available)
text_features, text_attention = self.encode_text(input_ids, attention_mask)
# Encode vision (with masking for text-only samples)
vision_latent = self.encode_vision(vision_features)
# Mask vision features for text-only samples
vision_mask = has_vision.float().unsqueeze(-1)
vision_latent_masked = vision_latent * vision_mask
# Cross-modal fusion
fused_features, cross_attention = self.fusion(text_features, vision_latent_masked)
# Create episodes
if has_vision.any() and (~has_vision).any():
# Mixed batch - use mixed episode creation
episode = self.create_episode_mixed(
text_features, vision_latent_masked, cross_attention, has_vision
)
else:
# Uniform batch - use standard episode creation
episode = self.create_episode(
text_features, vision_latent_masked, cross_attention
)
# Episodic memory interaction
if mode == "train":
retrieved_memory, memory_attention = self.memory(episode, mode="read_write")
else:
retrieved_memory, memory_attention = self.memory(episode, mode="read")
# Prepare decoder input
memory_context = self.memory_to_decoder(retrieved_memory)
memory_context_expanded = memory_context.unsqueeze(1).expand(-1, seq_len, -1)
fused_with_memory = fused_features + memory_context_expanded
decoder_input = self.decoder_input_proj(fused_with_memory)
# Generate text using BitNet decoder
decoder_outputs = self.text_decoder(
inputs_embeds=decoder_input,
attention_mask=attention_mask,
labels=labels
)
# Compute losses if in training mode
final_loss = None
loss_dict = {}
if mode == "train" and labels is not None:
# Primary decoder loss
decoder_loss = decoder_outputs['loss']
# Cross-modal contrastive loss (only for samples with vision)
cross_modal_loss = torch.tensor(0.0, device=input_ids.device)
if has_vision.any():
vision_indices = has_vision.nonzero(as_tuple=True)[0]
if len(vision_indices) > 0:
text_pooled = text_features[vision_indices].mean(dim=1)
vision_for_loss = vision_latent[vision_indices]
cross_modal_loss = self.compute_cross_modal_contrastive_loss(
text_pooled, vision_for_loss, temperature=self.loss_scale_temperature
)
# Optional additional losses
vision_loss = None
memory_loss = self.compute_memory_consistency_loss(episode, retrieved_memory)
# Compute balanced loss
loss_dict = self.compute_balanced_loss(
decoder_loss, cross_modal_loss, vision_loss, memory_loss, step
)
final_loss = loss_dict['total_loss']
elif decoder_outputs.get('loss') is not None:
final_loss = decoder_outputs['loss']
# Prepare outputs
if return_dict:
output = CausalLMOutput(
loss=final_loss,
logits=decoder_outputs['logits'],
hidden_states=fused_features if output_hidden_states else None,
attentions=text_attention if output_attentions else None,
)
# Add additional outputs for analysis
if mode == "train":
for key, value in loss_dict.items():
setattr(output, key, value)
for key, value in freezing_status.items():
setattr(output, key, value)
return output
else:
outputs = (decoder_outputs['logits'],)
if final_loss is not None:
outputs = (final_loss,) + outputs
if output_hidden_states:
outputs = outputs + (fused_features,)
if output_attentions:
outputs = outputs + (text_attention,)
return outputs
def generate(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
vision_features: Optional[torch.Tensor] = None,
max_length: int = 100,
temperature: float = 0.7,
top_p: float = 0.9,
do_sample: bool = True,
**kwargs
) -> torch.LongTensor:
"""Generate text given input text and vision features"""
self.eval()
batch_size = input_ids.size(0)
device = input_ids.device
# Handle missing vision features
if vision_features is None:
vision_features = torch.zeros(batch_size, self.config.vision_encoder_dim,
device=device, dtype=torch.float32)
# Handle attention mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
generated_ids = input_ids.clone()
current_attention_mask = attention_mask.clone()
with torch.no_grad():
for _ in range(max_length - input_ids.size(1)):
# Get model outputs
outputs = self.forward(
input_ids=generated_ids,
attention_mask=current_attention_mask,
vision_features=vision_features,
mode="inference",
return_dict=True
)
# Get next token logits
next_token_logits = outputs.logits[:, -1, :] / temperature
if do_sample:
# Apply top-p sampling
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
next_token_logits[indices_to_remove] = float('-inf')
# Sample from the filtered distribution
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
# Greedy decoding
next_token = next_token_logits.argmax(dim=-1, keepdim=True)
# Append to generated sequence
generated_ids = torch.cat([generated_ids, next_token], dim=-1)
# Update attention mask
                current_attention_mask = torch.cat([
                    current_attention_mask,
                    torch.ones(batch_size, 1, device=device, dtype=current_attention_mask.dtype)
                ], dim=-1)
# Stop if EOS token is generated
if (next_token == self.config.eos_token_id).all():
break
return generated_ids
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
vision_features=None,
**kwargs
):
"""Prepare inputs for generation"""
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"vision_features": vision_features,
"use_cache": kwargs.get("use_cache", True),
}
# Register the model with transformers
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
AutoConfig.register("bitmar", BitMarConfig)
AutoModel.register(BitMarConfig, BitMarModel)
AutoModelForCausalLM.register(BitMarConfig, BitMarModel)
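if __name__ == "__main__":
    # Minimal smoke test (illustrative only): builds the model with the default
    # configuration and runs a single forward pass on random inputs. Shapes are
    # assumptions, not a recommended training setup.
    config = BitMarConfig()
    model = BitMarModel(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16)
    vision_features = torch.randn(2, config.vision_encoder_dim)
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        vision_features=vision_features,
        labels=input_ids.clone(),
        mode="train",
        return_dict=True,
    )
    print("loss:", float(outputs.loss), "logits shape:", tuple(outputs.logits.shape))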