# Qalb-Pro/modeling_qalb.py
import torch
from torch import nn

from transformers import OPTConfig, OPTModel, OPTPreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

class QalbConfig(OPTConfig):
    model_type = "qalb"

    def __init__(self, table_size=500000, **kwargs):
        super().__init__(**kwargs)
        self.table_size = table_size


class DeepSeekEngramModule(nn.Module):
    """Gated hash-table memory: each position looks up an embedding keyed by a
    rolling hash of its token prefix and blends it into the hidden state."""

    def __init__(self, config):
        super().__init__()
        self.table_size = getattr(config, "table_size", 500000)
        self.dim = config.word_embed_proj_dim
        self.memory_table = nn.Embedding(self.table_size, self.dim)
        self.gate = nn.Linear(self.dim, 1)  # per-position scalar mixing gate
        self.polynomial_base = 31
    def forward(self, input_ids, hidden_states):
        # Hash each token prefix into the memory table. The cumulative sum
        # computes sum(input_ids[:, :t+1]) for every t in one pass, replacing
        # the original O(seq_len^2) Python loop with identical results; token
        # ids are non-negative, so no abs() is needed before the modulo.
        hashes = (input_ids.cumsum(dim=1) * self.polynomial_base) % self.table_size
        memory_features = self.memory_table(hashes)
        # Sigmoid gate decides, per position, how much memory to blend in.
        g = torch.sigmoid(self.gate(hidden_states))
        return g * hidden_states + (1 - g) * memory_features.to(hidden_states.dtype)
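

# Worked example of the prefix hash above (illustrative token ids, not real data):
#   input_ids = [[3, 5, 2]], polynomial_base = 31, table_size = 500000
#   t=0: (3)         * 31 % 500000 = 93
#   t=1: (3 + 5)     * 31 % 500000 = 248
#   t=2: (3 + 5 + 2) * 31 % 500000 = 310
# Each position therefore indexes the memory table by a hash of its full prefix,
# so identical prefixes map to the same memory slot across sequences.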


class FinalPerfectQalb(OPTPreTrainedModel):
    """OPT backbone followed by the engram memory module, with output logits
    tied to the backbone's input embedding matrix."""

    config_class = QalbConfig

    def __init__(self, config):
        super().__init__(config)
        self.backbone = OPTModel(config)
        self.engram = DeepSeekEngramModule(config)
        self.post_init()
    def forward(self, input_ids, attention_mask=None, **kwargs):
        outputs = self.backbone(input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        enhanced_states = self.engram(input_ids, hidden_states)
        # Project to vocabulary logits by tying to the backbone's input embeddings.
        logits = torch.matmul(enhanced_states, self.backbone.decoder.embed_tokens.weight.T)
        return CausalLMOutput(logits=logits)
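

# --- Usage sketch (illustrative, not the shipped Qalb-Pro configuration) -----
# A minimal smoke test assuming a tiny OPT-style config; every hyperparameter
# value below is a placeholder chosen so the model builds quickly on CPU.
if __name__ == "__main__":
    config = QalbConfig(
        vocab_size=1000,
        hidden_size=64,
        word_embed_proj_dim=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        ffn_dim=128,
        max_position_embeddings=128,
        table_size=10000,
    )
    model = FinalPerfectQalb(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids)
    print(out.logits.shape)  # expected: torch.Size([2, 16, 1000])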