import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer, PretrainedConfig, AutoConfig, AutoModel, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from torch.optim import AdamW
import os
import time
import numpy as np
import json


# Enhanced configuration class with HuggingFace compatibility
class BucketMemoryConfig(PretrainedConfig):
    model_type = "bucket-memory-model3"

    def __init__(
        self,
        vocab_size=30000,
        d_model=512,
        num_layers=6,
        num_buckets=8,
        min_bucket_size=1,
        max_bucket_size=32,
        max_seq_length=1024,
        dropout=0.1,
        use_flash_attention=True,
        num_attention_heads=8,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.min_bucket_size = min_bucket_size
        self.max_bucket_size = max_bucket_size
        self.max_seq_length = max_seq_length
        self.dropout = dropout
        self.use_flash_attention = use_flash_attention
        self.num_attention_heads = num_attention_heads


# Dynamic bucket memory: attention-based reads over fixed-size memory buckets with age decay
class DynamicBucketMemory(nn.Module):
    def __init__(self, embedding_dim=512, num_buckets=8, min_bucket_size=1, max_bucket_size=32,
                 compression_factor=0.8, decay_rate=0.05):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_buckets = num_buckets
        self.min_bucket_size = min_bucket_size
        self.max_bucket_size = max_bucket_size
        self.decay_rate = decay_rate
        # compression_factor is accepted for API compatibility but is not used in this implementation

        # Initialize bucket sizes logarithmically between min_bucket_size and max_bucket_size
        sizes = np.logspace(np.log10(min_bucket_size), np.log10(max_bucket_size), num_buckets).astype(int)
        self.bucket_sizes = np.maximum(sizes, min_bucket_size).tolist()

        # Memory structures (lazily allocated per batch size in _initialize_memory)
        self.memory_buckets = None
        self.memory_age = None
        self.bucket_importance = nn.Parameter(torch.ones(num_buckets))

        # Neural components
        self.query_proj = nn.Linear(embedding_dim, embedding_dim)
        self.key_proj = nn.Linear(embedding_dim, embedding_dim)
        self.value_proj = nn.Linear(embedding_dim, embedding_dim)
        self.output_proj = nn.Linear(embedding_dim, embedding_dim)

        self.input_norm = nn.LayerNorm(embedding_dim)
        self.output_norm = nn.LayerNorm(embedding_dim)

        self.bucket_selector = nn.Sequential(
            nn.Linear(embedding_dim, num_buckets * 2),
            nn.GELU(),
            nn.Linear(num_buckets * 2, num_buckets),
            nn.Softmax(dim=-1)
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def _initialize_memory(self, batch_size, device):
        # (Re)allocate the buckets; the caller decides when memory is missing or mismatched
        self.memory_buckets = [torch.zeros(batch_size, size, self.embedding_dim, device=device)
                               for size in self.bucket_sizes]
        self.memory_age = [torch.zeros(batch_size, size, device=device)
                           for size in self.bucket_sizes]

    def forward(self, input_data, memory_update=True):
        # Normalize the input shape to (batch, seq_len, embedding_dim)
        while input_data.dim() > 3:
            input_data = input_data.squeeze(0)
        if input_data.dim() == 2:
            input_data = input_data.unsqueeze(-1)
            if self.embedding_dim > 1:
                input_data = input_data.expand(-1, -1, self.embedding_dim)

        batch_size, seq_len, _ = input_data.size()
        device = input_data.device
        normalized_input = self.input_norm(input_data)

        # Initialize (or reallocate) memory if it is missing or the batch size changed
        if self.memory_buckets is None or self.memory_buckets[0].size(0) != batch_size:
            self._initialize_memory(batch_size, device)

        # Determine which buckets to use from the sequence-averaged features
        avg_input_features = normalized_input.mean(dim=1)
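        # bucket_selector maps the pooled features to a softmax distribution over buckets;
        # buckets whose weight stays below a small threshold are skipped during retrieval below.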
        bucket_weights = self.bucket_selector(avg_input_features)

        # Retrieve from memory (simplified): scaled dot-product attention over each bucket's slots
        projected_query = self.query_proj(normalized_input)
        outputs = torch.zeros(batch_size, seq_len, self.embedding_dim, device=device)

        for b in range(self.num_buckets):
            # Skip buckets that no example in the batch selected
            if bucket_weights[:, b].max() < 0.05:
                continue

            relevance = torch.bmm(
                projected_query, self.memory_buckets[b].transpose(1, 2)
            ) / (self.embedding_dim ** 0.5)

            # Older memory entries are down-weighted exponentially
            age_penalty = torch.exp(-self.memory_age[b] * 0.7).unsqueeze(1)
            relevance = relevance * age_penalty

            retrieval_weights = F.softmax(relevance, dim=-1)
            retrieved_values = torch.bmm(retrieval_weights, self.memory_buckets[b])

            importance_scale = torch.sigmoid(self.bucket_importance[b])
            outputs += retrieved_values * importance_scale * bucket_weights[:, b].view(batch_size, 1, 1)

        memory_output = self.output_proj(outputs)

        # Update memory if training
        if memory_update and self.training:
            with torch.no_grad():
                keys = self.key_proj(normalized_input)
                values = self.value_proj(normalized_input)

                for b in range(self.num_buckets):
                    bucket_size = self.bucket_sizes[b]
                    # Only update buckets that this example actually selected
                    bucket_mask = (bucket_weights[:, b] > 0.1).float().view(-1, 1, 1)

                    if seq_len > bucket_size:
                        # Subsample the sequence with a fixed stride to fit the bucket
                        stride = max(1, seq_len // bucket_size)
                        indices = torch.arange(0, seq_len, stride, device=device)[:bucket_size]
                        selected_values = values[:, indices]
                    else:
                        # Pad the sequence dimension up to the bucket size
                        padding = bucket_size - seq_len
                        selected_values = F.pad(values, (0, 0, 0, padding))

                    # Later (larger) buckets retain more of their old contents when blended
                    alpha = torch.sigmoid(self.bucket_importance[b]) * (0.8 if b > self.num_buckets // 2 else 0.2)
                    update = alpha * self.memory_buckets[b] + (1 - alpha) * selected_values
                    self.memory_buckets[b] = self.memory_buckets[b] * (1 - bucket_mask) + update * bucket_mask

                    # Refreshed entries restart near zero age; untouched ones accumulate decay_rate per step
                    age_mask = 1 - bucket_mask.squeeze(-1)
                    self.memory_age[b] = self.memory_age[b] * age_mask + self.decay_rate

        return self.output_norm(input_data + memory_output)


# Modified transformer layer with Flash Attention
class BucketMemoryTransformerLayer(nn.Module):
    def __init__(self, d_model=512, d_ff=2048, dropout=0.4, num_buckets=8,
                 min_bucket_size=1, max_bucket_size=32, use_flash_attention=True, num_heads=8):
        super().__init__()
        self.use_flash_attention = use_flash_attention
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Self-attention components with Flash Attention support
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Keep the bucket memory as is
        self.bucket_memory = DynamicBucketMemory(
            embedding_dim=d_model,
            num_buckets=num_buckets,
            min_bucket_size=min_bucket_size,
            max_bucket_size=max_bucket_size
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        # Self-attention with Flash Attention
        residual = x
        x = self.norm1(x)

        batch_size, seq_len, _ = x.shape

        # Project to queries, keys, values and split into heads
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Use Flash Attention if available and enabled
        if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'):
            # Convert attention mask if provided
            attn_mask = None
            if attention_mask is not None:
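                # HF-style masks use 1 for real tokens and 0 for padding; convert to an
                # additive float mask (0 keeps a position, a large negative value hides it)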
                attn_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(q.dtype)
                attn_mask = (1.0 - attn_mask) * -10000.0

            # Use PyTorch's native flash attention
            attn_output = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=False
            )
        else:
            # Fallback to standard attention
            scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
            if attention_mask is not None:
                scores = scores.masked_fill(attention_mask.unsqueeze(1).unsqueeze(2) == 0, -1e9)
            attn_weights = F.softmax(scores, dim=-1)
            attn_weights = self.dropout(attn_weights)
            attn_output = torch.matmul(attn_weights, v)

        # Reshape and project back
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        attn_output = self.out_proj(attn_output)

        x = residual + self.dropout(attn_output)

        # Bucket memory (unchanged)
        memory_out = self.bucket_memory(self.norm2(x))
        x = x + self.dropout(memory_out)

        # Feed-forward
        x = x + self.dropout(self.ff(self.norm3(x)))

        return x


# Updated model with HuggingFace compatibility
class BucketMemoryModel(PreTrainedModel):
    config_class = BucketMemoryConfig
    base_model_prefix = "bucket-memory-model2"

    def __init__(self, config, adapter_kwargs=None):
        super().__init__(config)
        self.d_model = config.d_model

        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_encoding = nn.Parameter(torch.zeros(1, config.max_seq_length, config.d_model))
        self._init_positional_encoding(config.max_seq_length, config.d_model)

        # Use config.num_attention_heads if available, otherwise derive from d_model
        num_heads = getattr(config, 'num_attention_heads', config.d_model // 64)
        num_heads = max(1, num_heads)  # Ensure at least one head

        self.layers = nn.ModuleList([
            BucketMemoryTransformerLayer(
                d_model=config.d_model,
                d_ff=4 * config.d_model,
                dropout=config.dropout,
                num_buckets=config.num_buckets,
                min_bucket_size=config.min_bucket_size,
                max_bucket_size=config.max_bucket_size,
                use_flash_attention=getattr(config, 'use_flash_attention', True),
                num_heads=num_heads
            )
            for _ in range(config.num_layers)
        ])

        self.norm = nn.LayerNorm(config.d_model)
        self.output_proj = nn.Linear(config.d_model, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

    def _init_positional_encoding(self, max_len, d_model):
        # Standard sinusoidal positional encodings, copied into the learnable parameter
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
        pos_enc = torch.zeros(1, max_len, d_model)
        pos_enc[0, :, 0::2] = torch.sin(position * div_term)
        pos_enc[0, :, 1::2] = torch.cos(position * div_term)
        self.pos_encoding.data.copy_(pos_enc)

    def forward(self, input_ids, attention_mask=None, labels=None):
        batch_size, seq_len = input_ids.size()

        x = self.token_embedding(input_ids) * np.sqrt(self.d_model)
        x = x + self.pos_encoding[:, :seq_len]
        x = self.dropout(x)

        # Process through transformer layers
        for layer in self.layers:
            x = layer(x, attention_mask)

        x = self.norm(x)
        logits = self.output_proj(x)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            return CausalLMOutput(loss=loss, logits=logits)

        return logits


AutoConfig.register("bucket-memory-model3", BucketMemoryConfig)
AutoModel.register(BucketMemoryConfig, BucketMemoryModel)
BucketMemoryConfig.register_for_auto_class()
BucketMemoryModel.register_for_auto_class("AutoModel")
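
# --- Minimal usage sketch (illustrative only, not part of the original script) ---
# Builds a deliberately small config, runs one forward pass on random token ids, and prints
# the loss and logits shape. All sizes below are arbitrary assumptions chosen to keep the
# check fast; call _smoke_test() manually to run it. The tokenizer, dataset, and optimizer
# imports above are assumed to be used elsewhere in the full script.
def _smoke_test():
    config = BucketMemoryConfig(vocab_size=1000, d_model=128, num_layers=2,
                                num_attention_heads=4, max_seq_length=64)
    model = BucketMemoryModel(config)

    # Random batch of 2 sequences, 16 tokens each, with a trivial all-ones attention mask
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)

    output = model(input_ids, attention_mask=attention_mask, labels=input_ids.clone())
    print("loss:", output.loss.item())
    print("logits shape:", tuple(output.logits.shape))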