import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutput
# =========================
# Config
# =========================
class TinyWayConfig(PretrainedConfig):
    model_type = "tinyway"

    def __init__(
        self,
        vocab_size=50257,
        n_positions=256,
        n_embd=384,
        n_layer=8,
        n_head=8,
        dropout=0.1,
        **kwargs
    ):
        super().__init__(**kwargs)
        # --- original fields ---
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.dropout = dropout
        # --- HF standard aliases (CRITICAL) ---
        # Generic transformers utilities read the canonical attribute names
        # (hidden_size, num_attention_heads, ...), not the GPT-style ones above.
        self.hidden_size = n_embd
        self.num_hidden_layers = n_layer
        self.num_attention_heads = n_head
        self.max_position_embeddings = n_positions

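# Quick config round-trip sanity check (a sketch; "tinyway-ckpt" is a
# hypothetical local path, not part of this module):
#   cfg = TinyWayConfig(n_layer=4)
#   cfg.save_pretrained("tinyway-ckpt")                  # writes config.json
#   cfg = TinyWayConfig.from_pretrained("tinyway-ckpt")
#   assert cfg.n_layer == cfg.num_hidden_layers == 4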
# =========================
# Attention
# =========================
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head
        self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.n_positions, config.n_positions))
        )

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        out = att @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(out)

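# Note: since torch 2.0 the mask/softmax sequence above can be replaced by the
# fused kernel, which builds the same causal mask internally (there is no
# attention dropout here, so the two are equivalent):
#   out = F.scaled_dot_product_attention(q, k, v, is_causal=True)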
# =========================
# Transformer Block
# =========================
class DecoderBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.ffn = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd)
        )
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = x + self.dropout(self.attn(self.ln1(x)))
        x = x + self.dropout(self.ffn(self.ln2(x)))
        return x

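# Shape sanity check for one block (illustrative): the pre-LN residual layout
# (normalize, transform, add back) keeps (batch, seq, n_embd) unchanged.
#   blk = DecoderBlock(TinyWayConfig())
#   x = torch.randn(2, 16, 384)
#   assert blk(x).shape == x.shape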
# =========================
# Model
# =========================
class TinyWayForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = TinyWayConfig

    def __init__(self, config):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.n_positions, config.n_embd)
        self.blocks = nn.ModuleList(
            [DecoderBlock(config) for _ in range(config.n_layer)]
        )
        self.ln = nn.LayerNorm(config.n_embd)
        # MUST match training
        self.head = nn.Linear(config.n_embd, config.vocab_size)
        self.post_init()

    # ---- HF REQUIRED METHODS ----
    def get_input_embeddings(self):
        return self.token_emb

    def set_input_embeddings(self, value):
        self.token_emb = value

    # ---- Forward ----
    def forward(self, input_ids, **kwargs):
        # Extra kwargs (attention_mask, use_cache, ...) are accepted so that
        # generate() can call forward(), but are intentionally ignored here.
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device)
        x = self.token_emb(input_ids) + self.pos_emb(pos)
        for block in self.blocks:
            x = block(x)
        x = self.ln(x)
        logits = self.head(x)
        return CausalLMOutput(logits=logits)
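# =========================
# Usage sketch
# =========================
# A minimal smoke test, assuming a recent transformers release (one where
# GenerationMixin must be inherited explicitly, as above). Auto-class
# registration is optional; it lets AutoModelForCausalLM resolve "tinyway".
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModelForCausalLM

    AutoConfig.register("tinyway", TinyWayConfig)
    AutoModelForCausalLM.register(TinyWayConfig, TinyWayForCausalLM)

    model = TinyWayForCausalLM(TinyWayConfig())
    model.eval()

    # Greedy decoding from a random prompt; use_cache=False since forward()
    # returns no past_key_values (the full prefix is re-encoded each step).
    prompt = torch.randint(0, model.config.vocab_size, (1, 8))
    out = model.generate(prompt, max_new_tokens=16, do_sample=False, use_cache=False)
    print(out.shape)  # torch.Size([1, 24])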