import torch
from torch import nn
from transformers import OPTPreTrainedModel, OPTModel, OPTConfig
from transformers.modeling_outputs import CausalLMOutputWithPast


class QalbConfig(OPTConfig):
    """OPT configuration extended with the engram hash-table size."""

    model_type = "qalb"

    def __init__(self, table_size=500000, **kwargs):
        super().__init__(**kwargs)
        self.table_size = table_size


class DeepSeekEngramModule(nn.Module):
    """Gated lookup into a hashed memory table keyed on token-prefix hashes."""

    def __init__(self, config):
        super().__init__()
        self.table_size = getattr(config, "table_size", 500000)
        self.dim = config.word_embed_proj_dim
        self.memory_table = nn.Embedding(self.table_size, self.dim)
        self.gate = nn.Linear(self.dim, 1)
        self.polynomial_base = 31

    def forward(self, input_ids, hidden_states):
        # Hash every prefix of the sequence into the memory table. The running
        # prefix sums are computed with a single cumsum instead of a Python loop
        # over positions; the modulo keeps every index in [0, table_size).
        hashes = (input_ids.cumsum(dim=1) * self.polynomial_base) % self.table_size

        memory_features = self.memory_table(hashes)
        # Per-position sigmoid gate mixes retrieved memory into the hidden states.
        g = torch.sigmoid(self.gate(hidden_states))
        return g * hidden_states + (1 - g) * memory_features.to(hidden_states.dtype)


class FinalPerfectQalb(OPTPreTrainedModel):
    config_class = QalbConfig

    def __init__(self, config):
        super().__init__(config)
        self.backbone = OPTModel(config)
        self.engram = DeepSeekEngramModule(config)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, **kwargs):
        outputs = self.backbone(input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        enhanced_states = self.engram(input_ids, hidden_states)

        # Project onto the tied input embedding matrix to obtain vocabulary logits.
        logits = torch.matmul(enhanced_states, self.backbone.decoder.embed_tokens.weight.T)
        return CausalLMOutputWithPast(logits=logits)
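# Minimal usage sketch: build a small config, run a forward pass on dummy token
# ids, and check the logits shape. The configuration values below are
# illustrative assumptions for a quick smoke test, not tuned settings.
if __name__ == "__main__":
    config = QalbConfig(
        table_size=10000,
        vocab_size=1000,
        hidden_size=64,
        word_embed_proj_dim=64,  # keep equal to hidden_size so the tied output projection lines up
        ffn_dim=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        max_position_embeddings=128,
    )
    model = FinalPerfectQalb(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    with torch.no_grad():
        out = model(input_ids)
    print(out.logits.shape)  # expected: torch.Size([2, 16, 1000])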