File size: 6,826 Bytes
3dc2846 b49fb5f 3dc2846 b49fb5f 3dc2846 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 | import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
class SykoSLMConfig(PretrainedConfig):
    """Configuration container for :class:`SykoSLM`.

    Every constructor argument is stored verbatim as an attribute of the same
    name; any extra keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "sykollm"

    def __init__(
        self,
        vocab_size=32000,
        d_model=768,
        n_layers=24,
        n_heads=6,
        num_memory_tokens=16,
        chunk_size=128,
        context_size=1024,
        overlap_size=16,
        code_overlap_size=64,
        abstract_head_hidden=256,
        abstract_head_layers=2,
        intermediate_size=3072,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Transformer dimensions (vocab_size, d_model, n_layers, n_heads).
        self.vocab_size, self.d_model = vocab_size, d_model
        self.n_layers, self.n_heads = n_layers, n_heads
        # Memory slots carried between chunks and the chunk length used by generation.
        self.num_memory_tokens = num_memory_tokens
        self.chunk_size = chunk_size
        # NOTE(review): context_size, overlap_size, code_overlap_size and the
        # abstract_head_* fields are not read by the model code shown here;
        # presumably consumed by data-prep or other heads — confirm.
        self.context_size = context_size
        self.overlap_size, self.code_overlap_size = overlap_size, code_overlap_size
        self.abstract_head_hidden = abstract_head_hidden
        self.abstract_head_layers = abstract_head_layers
        # Hidden width of each transformer layer's MLP.
        self.intermediate_size = intermediate_size
def apply_rotary_emb(x, cos, sin):
    """Rotate ``x`` by the rotary tables ``cos``/``sin`` (rotate-half convention).

    The last dimension of ``x`` is split into two halves ``(a, b)``; the result is
    ``x * cos + concat(-b, a) * sin``. ``cos`` and ``sin`` are cast to ``x``'s
    dtype first so the output keeps the dtype of ``x``.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    rotated = torch.cat([-second, first], dim=-1)
    return (x * cos.to(x.dtype)) + (rotated * sin.to(x.dtype))
class SykoRoPE(nn.Module):
    """Builds rotary-embedding cos/sin tables for a given list of positions.

    Args:
        dim: per-head feature size the tables are built for.
        base: frequency base; frequencies are ``base ** (-2i / dim)``.
    """

    def __init__(self, dim, base=10000.0):
        super().__init__()
        self.dim = dim
        self.base = base

    def forward(self, positions):
        """Return ``(cos, sin)`` of shape ``(1, 1, len(positions), dim)``.

        The inverse frequencies are recomputed on every call on ``positions.device``.
        """
        even_idx = torch.arange(0, self.dim, 2, device=positions.device).float()
        inv_freq = 1.0 / (self.base ** (even_idx / self.dim))
        angles = torch.outer(positions.float(), inv_freq)
        # Duplicate the angles so both halves used by apply_rotary_emb share frequencies.
        angles = torch.cat((angles, angles), dim=-1)
        cos = angles.cos().unsqueeze(0).unsqueeze(0)
        sin = angles.sin().unsqueeze(0).unsqueeze(0)
        return cos, sin
class SykoAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        # Fail early with a clear message instead of a reshape error in forward().
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.n_heads, self.head_dim = n_heads, d_model // n_heads
        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, cos, sin):
        """Attend causally over ``x``.

        Args:
            x: (batch, seq_len, d_model) hidden states.
            cos, sin: rotary tables broadcastable to (batch, n_heads, seq_len, head_dim),
                e.g. as returned by SykoRoPE.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        B, L, D = x.shape
        # (B, L, 3*D) -> (3, B, n_heads, L, head_dim)
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Rotary embeddings are applied to queries and keys only.
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # (B, n_heads, L, head_dim) -> (B, L, D)
        return self.out(out.transpose(1, 2).reshape(B, L, D))
class SykoTransformerLayer(nn.Module):
    """Pre-norm transformer block: causal RoPE attention, then a GELU MLP.

    Both sub-blocks are wrapped in a residual connection and read a
    LayerNorm-ed copy of their input.
    """

    def __init__(self, d_model, n_heads, intermediate_size):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = SykoAttention(d_model, n_heads)
        self.norm2 = nn.LayerNorm(d_model)
        mlp_layers = [
            nn.Linear(d_model, intermediate_size),
            nn.GELU(),
            # p=0.0 is a no-op; it also keeps the second Linear at index 3 (mlp.3).
            nn.Dropout(0.0),
            nn.Linear(intermediate_size, d_model),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x, cos, sin):
        """Apply attention then MLP residually; ``cos``/``sin`` are passed to attention."""
        after_attn = x + self.attn(self.norm1(x), cos, sin)
        mlp_delta = self.mlp(self.norm2(after_attn))
        return after_attn + mlp_delta
class SykoMemoryGate(nn.Module):
    """Gated interpolation between the previous memory and a new candidate.

    The gate is computed from both inputs; the candidate is computed from the
    current context alone. The blended result is LayerNorm-ed.
    """

    def __init__(self, d_model):
        super().__init__()
        self.forget_linear = nn.Linear(2 * d_model, d_model)
        self.update_linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, current_context, prev_memory):
        """Return ``norm(g * prev_memory + (1 - g) * candidate)``.

        ``g = sigmoid(forget_linear([current_context, prev_memory]))`` and
        ``candidate = tanh(update_linear(current_context))``. Both inputs must
        share all leading dimensions and have last dimension ``d_model``.
        """
        gate_input = torch.cat((current_context, prev_memory), dim=-1)
        keep = self.forget_linear(gate_input).sigmoid()
        candidate = self.update_linear(current_context).tanh()
        blended = keep * prev_memory + (1 - keep) * candidate
        return self.norm(blended)
class SykoSLM(PreTrainedModel):
    """Decoder-only language model with a recurrent memory carried between chunks.

    Each forward pass prepends ``config.num_memory_tokens`` memory slots to the
    token embeddings of one chunk and runs the transformer stack. It returns both
    the next-token logits for the chunk and an updated memory produced by
    ``SykoMemoryGate``.
    """

    config_class = SykoSLMConfig

    def __init__(self, config):
        super().__init__(config)
        self.mem_tokens = config.num_memory_tokens
        self.d_model = config.d_model
        # A missing or None pad_token_id falls back to index 0.
        pad_idx = getattr(config, "pad_token_id", 0) or 0
        self.embedding = nn.Embedding(config.vocab_size, config.d_model, padding_idx=pad_idx)
        # Learned per-slot embedding added to the incoming memory.
        self.mem_pos_emb = nn.Embedding(config.num_memory_tokens, config.d_model)
        self.rope = SykoRoPE(config.d_model // config.n_heads)
        self.layers = nn.ModuleList([
            SykoTransformerLayer(config.d_model, config.n_heads, config.intermediate_size)
            for _ in range(config.n_layers)
        ])
        self.final_norm = nn.LayerNorm(config.d_model)
        self.memory_gate = SykoMemoryGate(config.d_model)
        self.fc_out = nn.Linear(config.d_model, config.vocab_size)

    def forward(self, input_ids, prev_memory=None, chunk_start_idx=0, **kwargs):
        """Run one chunk through the model.

        Args:
            input_ids: (batch, seq_len) token ids of the current chunk.
            prev_memory: (batch, num_memory_tokens, d_model) memory from the
                previous chunk, or None to start from zeros.
            chunk_start_idx: RoPE position given to the first token of the chunk.
            **kwargs: accepted and ignored.

        Returns:
            ``(logits, new_memory)``: logits of shape (batch, seq_len, vocab_size)
            and the gated memory to pass to the next chunk.
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        x = self.embedding(input_ids)
        if prev_memory is None:
            # Use the embedding's dtype/device so half/bf16 models get uniform inputs.
            prev_memory = x.new_zeros(batch_size, self.mem_tokens, self.d_model)
        mem_idx = torch.arange(self.mem_tokens, device=device)
        memory_with_pos = prev_memory + self.mem_pos_emb(mem_idx).unsqueeze(0)
        # NOTE(review): memory slots come first and SykoAttention uses a causal
        # mask, so memory outputs can only attend to other memory slots, not to
        # this chunk's tokens — confirm this is the intended design.
        hidden = torch.cat([memory_with_pos, x], dim=1)
        # Memory slots all get RoPE position 0 (cos=1, sin=0: no rotation);
        # tokens get consecutive positions starting at chunk_start_idx.
        mem_pos = torch.zeros(self.mem_tokens, dtype=torch.long, device=device)
        word_pos = torch.arange(chunk_start_idx, chunk_start_idx + seq_len, device=device)
        cos, sin = self.rope(torch.cat([mem_pos, word_pos]))
        for layer in self.layers:
            hidden = layer(hidden, cos, sin)
        hidden = self.final_norm(hidden)
        memory_output = hidden[:, :self.mem_tokens, :]
        token_outputs = hidden[:, self.mem_tokens:, :]
        return self.fc_out(token_outputs), self.memory_gate(memory_output, prev_memory)

    @staticmethod
    def _sample_next_token(next_logits, temperature, top_k):
        """Sample one token id per row from (batch, vocab) logits; returns (batch, 1).

        ``temperature <= 0`` picks the argmax (greedy). A ``top_k`` of None, <= 0,
        or >= vocab size disables top-k filtering.
        """
        # Sample in float32 for numerical stability with half/bf16 logits.
        next_logits = next_logits.float()
        if temperature <= 0:
            return next_logits.argmax(dim=-1, keepdim=True)
        next_logits = next_logits / temperature
        if top_k is not None and 0 < top_k < next_logits.size(-1):
            top_k_vals, top_k_idx = torch.topk(next_logits, k=top_k)
            filtered = torch.full_like(next_logits, float("-inf"))
            filtered.scatter_(1, top_k_idx, top_k_vals)
            next_logits = filtered
        return torch.multinomial(torch.softmax(next_logits, dim=-1), 1)

    def generate_text(self, input_ids, max_new_tokens=100, temperature=0.8, top_k=50):
        """Sample up to ``max_new_tokens`` tokens after ``input_ids``.

        Text is handled in chunks of ``config.chunk_size`` tokens. Complete chunks
        are folded into the recurrent memory exactly once. The current, partially
        filled chunk is re-read on each step with the memory from the preceding
        chunks. RoPE positions restart at 0 for each chunk (``chunk_start_idx``
        is left at its default, as in the previous implementation).

        Generation stops early once every row has emitted ``config.eos_token_id``
        (an int or a list of ids). Rows that finish earlier are filled with
        ``config.pad_token_id``, or with the first eos id if no pad id is set.
        The model's train/eval mode is restored on exit.

        Args:
            input_ids: (batch, seq_len) prompt token ids, seq_len >= 1.
            max_new_tokens: maximum number of tokens to append.
            temperature: softmax temperature; <= 0 means greedy decoding.
            top_k: keep only the k most likely tokens; None or <= 0 disables it.

        Returns:
            (batch, seq_len + n_generated) tensor of prompt plus generated ids.

        Raises:
            ValueError: if ``input_ids`` is not a non-empty 2-D tensor.
        """
        if input_ids.dim() != 2 or input_ids.size(1) == 0:
            raise ValueError("input_ids must be a non-empty (batch, seq_len) tensor")
        chunk_size = self.config.chunk_size
        device = input_ids.device
        generated = input_ids.clone()
        memory = None  # forward() builds a zero memory with the right batch size and dtype
        chunk_start = 0  # index in `generated` where the not-yet-committed chunk begins

        eos = getattr(self.config, "eos_token_id", None)
        eos_list = [] if eos is None else ([eos] if isinstance(eos, int) else list(eos))
        eos_ids = torch.tensor(eos_list, dtype=torch.long, device=device) if eos_list else None
        pad = getattr(self.config, "pad_token_id", None)
        filler_id = pad if pad is not None else (eos_list[0] if eos_list else None)
        finished = torch.zeros(generated.size(0), dtype=torch.bool, device=device)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Fold every complete prompt chunk except the last into memory.
                while generated.size(1) - chunk_start > chunk_size:
                    _, memory = self.forward(
                        generated[:, chunk_start:chunk_start + chunk_size], memory
                    )
                    chunk_start += chunk_size
                for _ in range(max_new_tokens):
                    chunk = generated[:, chunk_start:]
                    logits, chunk_memory = self.forward(chunk, memory)
                    next_token = self._sample_next_token(logits[:, -1, :], temperature, top_k)
                    if eos_ids is not None:
                        # Rows that already ended keep emitting the filler id.
                        filler = torch.full_like(next_token, filler_id)
                        next_token = torch.where(finished.unsqueeze(-1), filler, next_token)
                        finished |= (next_token == eos_ids).any(dim=-1)
                    generated = torch.cat([generated, next_token], dim=1)
                    if chunk.size(1) == chunk_size:
                        # The chunk just read was complete: commit its memory and
                        # start a new chunk at the token we just appended.
                        memory = chunk_memory
                        chunk_start += chunk_size
                    if eos_ids is not None and bool(finished.all()):
                        break
        finally:
            self.train(was_training)
        return generated
|