"""bucket_memory_model.py (kok-baseV2): a transformer language model with a dynamic bucket-memory module."""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
# Configuration class with HuggingFace compatibility
class BucketMemoryConfig(PretrainedConfig):
    model_type = "bucket-memory-model3"

    def __init__(
        self, vocab_size=30000, d_model=512, num_layers=6, num_buckets=8,
        min_bucket_size=1, max_bucket_size=32, max_seq_length=1024, dropout=0.1,
        use_flash_attention=True, num_attention_heads=8, **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.min_bucket_size = min_bucket_size
        self.max_bucket_size = max_bucket_size
        self.max_seq_length = max_seq_length
        self.dropout = dropout
        self.use_flash_attention = use_flash_attention
        self.num_attention_heads = num_attention_heads
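
# An illustrative instantiation (values assumed for the example, not shipped defaults):
#
#   cfg = BucketMemoryConfig(vocab_size=32000, d_model=768, num_layers=12,
#                            num_attention_heads=12, num_buckets=8)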

class DynamicBucketMemory(nn.Module):
    def __init__(self, embedding_dim=512, num_buckets=8, min_bucket_size=1, max_bucket_size=32,
                 decay_rate=0.05):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_buckets = num_buckets
        self.min_bucket_size = min_bucket_size
        self.max_bucket_size = max_bucket_size
        self.decay_rate = decay_rate
        # Initialize bucket sizes logarithmically
        sizes = np.logspace(np.log10(min_bucket_size), np.log10(max_bucket_size), num_buckets).astype(int)
        self.bucket_sizes = np.maximum(sizes, min_bucket_size).tolist()
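        # With the defaults (8 buckets spanning sizes 1..32) this comes out to
        # approximately [1, 1, 2, 4, 7, 11, 19, 32] slots per bucket.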
        # Memory structures: plain tensors (not registered buffers), created
        # lazily once batch size and device are known
        self.memory_buckets = None
        self.memory_age = None
        self.bucket_importance = nn.Parameter(torch.ones(num_buckets))
        # Neural components
        self.query_proj = nn.Linear(embedding_dim, embedding_dim)
        self.key_proj = nn.Linear(embedding_dim, embedding_dim)
        self.value_proj = nn.Linear(embedding_dim, embedding_dim)
        self.output_proj = nn.Linear(embedding_dim, embedding_dim)
        self.input_norm = nn.LayerNorm(embedding_dim)
        self.output_norm = nn.LayerNorm(embedding_dim)
        # Soft routing: scores each input against the available buckets
        self.bucket_selector = nn.Sequential(
            nn.Linear(embedding_dim, num_buckets * 2),
            nn.GELU(),
            nn.Linear(num_buckets * 2, num_buckets),
            nn.Softmax(dim=-1)
        )
        self.apply(self._init_weights)
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def _initialize_memory(self, batch_size, device):
        # (Re)build zero-filled buckets unconditionally; the forward pass decides
        # when this is needed (first call, batch-size change, or device change)
        self.memory_buckets = [torch.zeros(batch_size, size, self.embedding_dim, device=device)
                               for size in self.bucket_sizes]
        self.memory_age = [torch.zeros(batch_size, size, device=device) for size in self.bucket_sizes]
    def forward(self, input_data, memory_update=True):
        # Coerce input to (batch, seq, dim). squeeze(0) is a no-op on
        # non-singleton dims, so guard the loop to avoid spinning forever.
        while input_data.dim() > 3 and input_data.size(0) == 1:
            input_data = input_data.squeeze(0)
        if input_data.dim() == 4 and input_data.size(-1) == 1:
            input_data = input_data.squeeze(-1)
        if input_data.dim() == 2:
            input_data = input_data.unsqueeze(-1)
            if self.embedding_dim > 1:
                input_data = input_data.expand(-1, -1, self.embedding_dim)
        batch_size, seq_len, _ = input_data.size()
        device = input_data.device
        normalized_input = self.input_norm(input_data)
        # (Re)initialize memory on the first call or when batch size/device changes
        if (self.memory_buckets is None
                or self.memory_buckets[0].size(0) != batch_size
                or self.memory_buckets[0].device != device):
            self._initialize_memory(batch_size, device)
        # Determine which buckets to use
        avg_input_features = normalized_input.mean(dim=1)
        bucket_weights = self.bucket_selector(avg_input_features)
        # Retrieve from memory (simplified attention over each bucket's slots)
        projected_query = self.query_proj(normalized_input)
        outputs = torch.zeros(batch_size, seq_len, self.embedding_dim, device=device)
        for b in range(self.num_buckets):
            # Skip buckets that no sequence in the batch routes to
            if bucket_weights[:, b].max() < 0.05:
                continue
            relevance = torch.bmm(
                projected_query,
                self.memory_buckets[b].transpose(1, 2)
            ) / (self.embedding_dim ** 0.5)
            # Down-weight stale slots before the softmax
            age_penalty = torch.exp(-self.memory_age[b] * 0.7).unsqueeze(1)
            relevance = relevance * age_penalty
            retrieval_weights = F.softmax(relevance, dim=-1)
            retrieved_values = torch.bmm(retrieval_weights, self.memory_buckets[b])
            importance_scale = torch.sigmoid(self.bucket_importance[b])
            outputs += retrieved_values * importance_scale * bucket_weights[:, b].view(batch_size, 1, 1)
        memory_output = self.output_proj(outputs)
        # Update memory when training (writes are detached from the graph)
        if memory_update and self.training:
            with torch.no_grad():
                # Note: key_proj is currently unused by the write path, which
                # stores projected values directly
                values = self.value_proj(normalized_input)
                for b in range(self.num_buckets):
                    bucket_size = self.bucket_sizes[b]
                    # Only sequences that route to this bucket get written
                    bucket_mask = (bucket_weights[:, b] > 0.1).float().view(-1, 1, 1)
                    if seq_len > bucket_size:
                        # Subsample the sequence at a uniform stride to fit the bucket
                        stride = max(1, seq_len // bucket_size)
                        indices = torch.arange(0, seq_len, stride, device=device)[:bucket_size]
                        selected_values = values[:, indices]
                    else:
                        padding = bucket_size - seq_len
                        selected_values = F.pad(values, (0, 0, 0, padding))
                    # Later (larger) buckets retain more of their old contents
                    alpha = torch.sigmoid(self.bucket_importance[b]) * (0.8 if b > self.num_buckets // 2 else 0.2)
                    update = alpha * self.memory_buckets[b] + (1 - alpha) * selected_values
                    self.memory_buckets[b] = self.memory_buckets[b] * (1 - bucket_mask) + update * bucket_mask
                    # Reset the age of freshly written slots; age the rest
                    age_mask = 1 - bucket_mask.squeeze(-1)
                    self.memory_age[b] = self.memory_age[b] * age_mask + self.decay_rate
        return self.output_norm(input_data + memory_output)
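
# Illustrative standalone use of the memory module (shapes assumed for the
# example, not prescriptive):
#
#   mem = DynamicBucketMemory(embedding_dim=64, num_buckets=4, max_bucket_size=8)
#   out = mem(torch.randn(2, 16, 64))   # -> (2, 16, 64): input + memory readout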

# Transformer layer with Flash Attention and bucket memory
class BucketMemoryTransformerLayer(nn.Module):
    def __init__(self, d_model=512, d_ff=2048, dropout=0.4, num_buckets=8,
                 min_bucket_size=1, max_bucket_size=32, use_flash_attention=True,
                 num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.use_flash_attention = use_flash_attention
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Self-attention projections (with Flash Attention support)
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        # Bucket memory sub-module
        self.bucket_memory = DynamicBucketMemory(
            embedding_dim=d_model, num_buckets=num_buckets,
            min_bucket_size=min_bucket_size, max_bucket_size=max_bucket_size
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, attention_mask=None):
        # Pre-norm self-attention
        residual = x
        x = self.norm1(x)
        batch_size, seq_len, _ = x.shape
        # Project to queries, keys, values: (batch, heads, seq, head_dim)
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Use PyTorch's fused scaled_dot_product_attention if available and enabled
        if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'):
            attn_mask = None
            if attention_mask is not None:
                # Convert a 0/1 padding mask to an additive float mask in q's dtype
                attn_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(q.dtype)
                attn_mask = (1.0 - attn_mask) * -10000.0
            attn_output = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=False
            )
        else:
            # Fallback to standard attention
            scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
            if attention_mask is not None:
                scores = scores.masked_fill(attention_mask.unsqueeze(1).unsqueeze(2) == 0, -1e9)
            attn_weights = F.softmax(scores, dim=-1)
            attn_weights = self.dropout(attn_weights)
            attn_output = torch.matmul(attn_weights, v)
        # Merge heads and project back
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        attn_output = self.out_proj(attn_output)
        x = residual + self.dropout(attn_output)
        # Bucket memory branch
        memory_out = self.bucket_memory(self.norm2(x))
        x = x + self.dropout(memory_out)
        # Feed-forward branch
        x = x + self.dropout(self.ff(self.norm3(x)))
        return x
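
# Illustrative standalone use of a single layer (shapes assumed for the example):
#
#   layer = BucketMemoryTransformerLayer(d_model=64, d_ff=256, num_heads=4)
#   y = layer(torch.randn(2, 16, 64))   # -> (2, 16, 64)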

# Model with HuggingFace compatibility
class BucketMemoryModel(PreTrainedModel):
    config_class = BucketMemoryConfig
    base_model_prefix = "bucket_memory_model"

    def __init__(self, config):
        super().__init__(config)
        self.d_model = config.d_model
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        # Learnable positional encoding, initialized sinusoidally
        self.pos_encoding = nn.Parameter(torch.zeros(1, config.max_seq_length, config.d_model))
        self._init_positional_encoding(config.max_seq_length, config.d_model)
        # Use config.num_attention_heads if available, otherwise derive from d_model
        num_heads = getattr(config, 'num_attention_heads', config.d_model // 64)
        num_heads = max(1, num_heads)  # ensure at least one head
        self.layers = nn.ModuleList([
            BucketMemoryTransformerLayer(
                d_model=config.d_model,
                d_ff=4 * config.d_model,
                dropout=config.dropout,
                num_buckets=config.num_buckets,
                min_bucket_size=config.min_bucket_size,
                max_bucket_size=config.max_bucket_size,
                use_flash_attention=getattr(config, 'use_flash_attention', True),
                num_heads=num_heads
            ) for _ in range(config.num_layers)
        ])
        self.norm = nn.LayerNorm(config.d_model)
        self.output_proj = nn.Linear(config.d_model, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)
    def _init_positional_encoding(self, max_len, d_model):
        # Standard sinusoidal encoding:
        #   PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
        #   PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
        pos_enc = torch.zeros(1, max_len, d_model)
        pos_enc[0, :, 0::2] = torch.sin(position * div_term)
        pos_enc[0, :, 1::2] = torch.cos(position * div_term)
        self.pos_encoding.data.copy_(pos_enc)
    def forward(self, input_ids, attention_mask=None, labels=None):
        seq_len = input_ids.size(1)
        x = self.token_embedding(input_ids) * np.sqrt(self.d_model)
        x = x + self.pos_encoding[:, :seq_len]
        x = self.dropout(x)
        # Process through transformer layers
        for layer in self.layers:
            x = layer(x, attention_mask)
        x = self.norm(x)
        logits = self.output_proj(x)
        if labels is not None:
            # Labels are compared position-for-position; callers wanting
            # next-token prediction should shift labels themselves
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            return CausalLMOutput(loss=loss, logits=logits)
        return logits
AutoConfig.register("bucket-memory-model3", BucketMemoryConfig)
AutoModel.register(BucketMemoryConfig, BucketMemoryModel)
BucketMemoryConfig.register_for_auto_class()
BucketMemoryModel.register_for_auto_class("AutoModel")
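
# A minimal smoke test, assuming only this file's dependencies; the tiny
# hyperparameters are illustrative, not the shipped configuration.
if __name__ == "__main__":
    config = BucketMemoryConfig(vocab_size=1000, d_model=64, num_layers=2,
                                num_attention_heads=4, max_seq_length=128)
    model = BucketMemoryModel(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    output = model(input_ids, labels=input_ids)
    print(output.loss.item(), output.logits.shape)  # scalar loss, torch.Size([2, 16, 1000])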