import json
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader

from datasets import load_dataset
from transformers import AutoTokenizer, PretrainedConfig, AutoConfig, AutoModel, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput


class BucketMemoryConfig(PretrainedConfig):
    model_type = "bucket-memory-model3"

    def __init__(
        self, vocab_size=30000, d_model=512, num_layers=6, num_buckets=8,
        min_bucket_size=1, max_bucket_size=32, max_seq_length=1024, dropout=0.1,
        use_flash_attention=True, num_attention_heads=8, **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.min_bucket_size = min_bucket_size
        self.max_bucket_size = max_bucket_size
        self.max_seq_length = max_seq_length
        self.dropout = dropout
        self.use_flash_attention = use_flash_attention
        self.num_attention_heads = num_attention_heads


class DynamicBucketMemory(nn.Module):
    def __init__(self, embedding_dim=512, num_buckets=8, min_bucket_size=1, max_bucket_size=32,
                 compression_factor=0.8, decay_rate=0.05):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_buckets = num_buckets
        self.min_bucket_size = min_bucket_size
        self.max_bucket_size = max_bucket_size
        self.decay_rate = decay_rate
        # compression_factor is accepted for API compatibility but is not used
        # by the current implementation.

        # Bucket capacities grow log-uniformly from the min to the max size.
        sizes = np.logspace(np.log10(min_bucket_size), np.log10(max_bucket_size), num_buckets).astype(int)
        self.bucket_sizes = np.maximum(sizes, min_bucket_size).tolist()
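        # Example: with the defaults (min=1, max=32, 8 buckets) the schedule
        # comes out to roughly [1, 1, 2, 4, 7, 11, 19, 32]; exact values can
        # shift by one slot because of the int cast after logspace.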

        # Per-batch memory state, allocated lazily on the first forward pass.
        self.memory_buckets = None
        self.memory_age = None
        self.bucket_importance = nn.Parameter(torch.ones(num_buckets))

        # Attention-style projections for reading from and writing to memory.
        self.query_proj = nn.Linear(embedding_dim, embedding_dim)
        self.key_proj = nn.Linear(embedding_dim, embedding_dim)
        self.value_proj = nn.Linear(embedding_dim, embedding_dim)
        self.output_proj = nn.Linear(embedding_dim, embedding_dim)
        self.input_norm = nn.LayerNorm(embedding_dim)
        self.output_norm = nn.LayerNorm(embedding_dim)

        # Soft router: maps mean input features to a distribution over buckets.
        self.bucket_selector = nn.Sequential(
            nn.Linear(embedding_dim, num_buckets * 2),
            nn.GELU(),
            nn.Linear(num_buckets * 2, num_buckets),
            nn.Softmax(dim=-1)
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def _initialize_memory(self, batch_size, device):
        # (Re)allocate unconditionally: the caller invokes this both on first
        # use and whenever the batch size changes, so guarding on
        # `self.memory_buckets is None` here would leave stale state behind.
        self.memory_buckets = [torch.zeros(batch_size, size, self.embedding_dim, device=device)
                               for size in self.bucket_sizes]
        self.memory_age = [torch.zeros(batch_size, size, device=device) for size in self.bucket_sizes]

    def forward(self, input_data, memory_update=True):
        # Normalize the input to (batch, seq_len, embedding_dim). Extra leading
        # dims are only squeezed while they are singleton, so this cannot loop
        # forever on a genuinely higher-rank input.
        while input_data.dim() > 3 and input_data.size(0) == 1:
            input_data = input_data.squeeze(0)
        if input_data.dim() == 2:
            input_data = input_data.unsqueeze(-1)
            if self.embedding_dim > 1:
                input_data = input_data.expand(-1, -1, self.embedding_dim)

        batch_size, seq_len, _ = input_data.size()
        device = input_data.device

        normalized_input = self.input_norm(input_data)

        # (Re)initialize memory when missing or when the batch size changed.
        if self.memory_buckets is None or self.memory_buckets[0].size(0) != batch_size:
            self._initialize_memory(batch_size, device)

        # Route: average over the sequence, then score each bucket.
        avg_input_features = normalized_input.mean(dim=1)
        bucket_weights = self.bucket_selector(avg_input_features)

        # Read phase: attend from the input queries into each memory bucket.
        projected_query = self.query_proj(normalized_input)
        outputs = torch.zeros(batch_size, seq_len, self.embedding_dim, device=device)

        for b in range(self.num_buckets):
            # Skip buckets that no example in the batch routes to.
            if bucket_weights[:, b].max() < 0.05:
                continue

            # Scaled dot-product relevance between queries and bucket slots.
            relevance = torch.bmm(
                projected_query,
                self.memory_buckets[b].transpose(1, 2)
            ) / (self.embedding_dim ** 0.5)

            # Down-weight older slots exponentially.
            age_penalty = torch.exp(-self.memory_age[b] * 0.7).unsqueeze(1)
            relevance = relevance * age_penalty

            retrieval_weights = F.softmax(relevance, dim=-1)
            retrieved_values = torch.bmm(retrieval_weights, self.memory_buckets[b])

            # Blend by learned bucket importance and the routing weight.
            importance_scale = torch.sigmoid(self.bucket_importance[b])
            outputs = outputs + retrieved_values * importance_scale * bucket_weights[:, b].view(batch_size, 1, 1)

        memory_output = self.output_proj(outputs)

        # Write phase: update memory in place, outside the autograd graph.
        if memory_update and self.training:
            with torch.no_grad():
                # Only value projections are written; key_proj is unused by
                # this update rule.
                values = self.value_proj(normalized_input)

                for b in range(self.num_buckets):
                    bucket_size = self.bucket_sizes[b]
                    # Write only for examples that route to this bucket.
                    bucket_mask = (bucket_weights[:, b] > 0.1).float().view(-1, 1, 1)

                    if seq_len > bucket_size:
                        # Subsample the sequence at a uniform stride.
                        stride = max(1, seq_len // bucket_size)
                        indices = torch.arange(0, seq_len, stride, device=device)[:bucket_size]
                        selected_values = values[:, indices]
                    else:
                        # Pad short sequences up to the bucket capacity.
                        padding = bucket_size - seq_len
                        selected_values = F.pad(values, (0, 0, 0, padding))

                    # alpha is the retention rate: buckets above the midpoint
                    # keep more of their old contents.
                    alpha = torch.sigmoid(self.bucket_importance[b]) * (0.8 if b > self.num_buckets // 2 else 0.2)

                    update = alpha * self.memory_buckets[b] + (1 - alpha) * selected_values
                    self.memory_buckets[b] = self.memory_buckets[b] * (1 - bucket_mask) + update * bucket_mask

                    # Freshly written slots have their age reset; every slot
                    # then ages by decay_rate.
                    age_mask = 1 - bucket_mask.squeeze(-1)
                    self.memory_age[b] = self.memory_age[b] * age_mask + self.decay_rate

        # Residual connection around the memory read.
        return self.output_norm(input_data + memory_output)
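    # Illustrative standalone use (toy sizes, assumed for the example):
    #   mem = DynamicBucketMemory(embedding_dim=32, num_buckets=4, max_bucket_size=8)
    #   out = mem(torch.randn(2, 10, 32))   # -> shape (2, 10, 32)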


class BucketMemoryTransformerLayer(nn.Module):
    def __init__(self, d_model=512, d_ff=2048, dropout=0.4, num_buckets=8,
                 min_bucket_size=1, max_bucket_size=32, use_flash_attention=True,
                 num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.use_flash_attention = use_flash_attention
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Standard multi-head self-attention projections.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Bucketed external memory, applied after self-attention.
        self.bucket_memory = DynamicBucketMemory(
            embedding_dim=d_model, num_buckets=num_buckets,
            min_bucket_size=min_bucket_size, max_bucket_size=max_bucket_size
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        # Pre-norm self-attention block.
        residual = x
        x = self.norm1(x)

        batch_size, seq_len, _ = x.shape

        # Project and reshape to (batch, heads, seq, head_dim).
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'):
            # Convert the 0/1 padding mask to an additive float mask in the
            # query dtype, as scaled_dot_product_attention expects.
            attn_mask = None
            if attention_mask is not None:
                attn_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(q.dtype)
                attn_mask = (1.0 - attn_mask) * -10000.0

            attn_output = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=False
            )
        else:
            # Fallback: explicit scaled dot-product attention.
            scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

            if attention_mask is not None:
                scores = scores.masked_fill(attention_mask.unsqueeze(1).unsqueeze(2) == 0, -1e9)

            attn_weights = F.softmax(scores, dim=-1)
            attn_weights = self.dropout(attn_weights)
            attn_output = torch.matmul(attn_weights, v)

        # Merge heads and project back to d_model.
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        attn_output = self.out_proj(attn_output)
        x = residual + self.dropout(attn_output)

        # Bucket-memory read/write with its own residual connection.
        memory_out = self.bucket_memory(self.norm2(x))
        x = x + self.dropout(memory_out)

        # Position-wise feed-forward block.
        x = x + self.dropout(self.ff(self.norm3(x)))
        return x
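    # Illustrative standalone use (toy sizes, assumed for the example):
    #   layer = BucketMemoryTransformerLayer(d_model=64, num_heads=4)
    #   out = layer(torch.randn(2, 10, 64))   # -> shape (2, 10, 64)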


class BucketMemoryModel(PreTrainedModel):
    config_class = BucketMemoryConfig
    base_model_prefix = "bucket_memory_model"

    def __init__(self, config, adapter_kwargs=None):
        super().__init__(config)
        self.d_model = config.d_model
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        # Sinusoidal positional encodings, stored as a learnable parameter.
        self.pos_encoding = nn.Parameter(torch.zeros(1, config.max_seq_length, config.d_model))
        self._init_positional_encoding(config.max_seq_length, config.d_model)

        # Derive the head count from the config, defaulting to 64-dim heads.
        num_heads = getattr(config, 'num_attention_heads', config.d_model // 64)
        num_heads = max(1, num_heads)

        self.layers = nn.ModuleList([
            BucketMemoryTransformerLayer(
                d_model=config.d_model,
                d_ff=4 * config.d_model,
                dropout=config.dropout,
                num_buckets=config.num_buckets,
                min_bucket_size=config.min_bucket_size,
                max_bucket_size=config.max_bucket_size,
                use_flash_attention=getattr(config, 'use_flash_attention', True),
                num_heads=num_heads
            ) for _ in range(config.num_layers)
        ])

        self.norm = nn.LayerNorm(config.d_model)
        self.output_proj = nn.Linear(config.d_model, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

    def _init_positional_encoding(self, max_len, d_model):
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
        pos_enc = torch.zeros(1, max_len, d_model)
        pos_enc[0, :, 0::2] = torch.sin(position * div_term)
        pos_enc[0, :, 1::2] = torch.cos(position * div_term)
        self.pos_encoding.data.copy_(pos_enc)

    def forward(self, input_ids, attention_mask=None, labels=None):
        batch_size, seq_len = input_ids.size()
        x = self.token_embedding(input_ids) * np.sqrt(self.d_model)
        x = x + self.pos_encoding[:, :seq_len]
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x, attention_mask)

        x = self.norm(x)
        logits = self.output_proj(x)

        if labels is not None:
            # Labels are aligned position-for-position with the logits; no
            # next-token shift is applied here.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            return CausalLMOutput(loss=loss, logits=logits)
        return logits


AutoConfig.register("bucket-memory-model3", BucketMemoryConfig)
AutoModel.register(BucketMemoryConfig, BucketMemoryModel)
BucketMemoryConfig.register_for_auto_class()
BucketMemoryModel.register_for_auto_class("AutoModel")
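

# Minimal smoke test: a sketch with illustrative (not tuned) hyperparameters,
# checking that a forward pass returns the expected shapes.
def _smoke_test():
    cfg = BucketMemoryConfig(vocab_size=100, d_model=64, num_layers=2,
                             num_buckets=4, max_bucket_size=8,
                             max_seq_length=64, num_attention_heads=4)
    model = BucketMemoryModel(cfg)
    input_ids = torch.randint(0, cfg.vocab_size, (2, 16))
    output = model(input_ids, labels=input_ids)
    assert output.logits.shape == (2, 16, cfg.vocab_size)
    return output.loss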