import os
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
class TransformerBlock(nn.Module):
    """A single pre-norm Transformer block: self-attention followed by a feed-forward MLP."""

    def __init__(self, sizeVector=128, numHeads=4):
        super().__init__()
        self.sizeVector = sizeVector
        self.ln1 = nn.LayerNorm(sizeVector)
        self.attn = nn.MultiheadAttention(sizeVector, numHeads, batch_first=True)
        self.ln2 = nn.LayerNorm(sizeVector)
        # Position-wise feed-forward network with the usual 4x expansion.
        self.ff = nn.Sequential(
            nn.Linear(sizeVector, sizeVector * 4),
            nn.GELU(),
            nn.Linear(sizeVector * 4, sizeVector),
        )

    def forward(self, x, attMask=None):
        # Self-attention sub-layer with residual connection (pre-norm).
        h = self.ln1(x)
        z, _ = self.attn(h, h, h, attn_mask=attMask)
        x = x + z
        # Feed-forward sub-layer with residual connection (pre-norm).
        h = self.ln2(x)
        z1 = self.ff(h)
        x = x + z1
        return x


class TransformerRun(nn.Module):
    """A small decoder-only (causal) Transformer language model."""

    def __init__(self, vocabSize=120000, maxLong=256, sizeVector=128, block=4):
        super().__init__()
        self.maxLong = maxLong
        self.tokenEmbed = nn.Embedding(vocabSize, sizeVector)
        self.posEmbed = nn.Embedding(maxLong, sizeVector)  # learned positional embeddings
        self.ln_f = nn.LayerNorm(sizeVector)
        self.layers = nn.ModuleList([
            TransformerBlock(sizeVector=sizeVector, numHeads=4)
            for _ in range(block)
        ])
        self.lmHead = nn.Linear(sizeVector, vocabSize)

    def forward(self, x):
        # x: (B, T) batch of token-id sequences; T must not exceed maxLong.
        B, T = x.shape
        tok = self.tokenEmbed(x)                                            # (B, T, sizeVector)
        pos = self.posEmbed(torch.arange(T, device=x.device)).unsqueeze(0)  # (1, T, sizeVector)
        h = tok + pos
        # Causal mask: -inf above the diagonal so each position attends only to earlier tokens.
        attMask = torch.triu(
            torch.full((T, T), float('-inf'), device=x.device),
            diagonal=1
        )
        for layer in self.layers:
            h = layer(h, attMask=attMask)
        h = self.ln_f(h)
        return self.lmHead(h)  # (B, T, vocabSize) next-token logits
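

# --- Usage sketch (not part of the original file) ---
# A minimal smoke test showing how TransformerRun is called, assuming the model
# is run standalone. The batch size, sequence length, and random token ids
# below are illustrative values, not anything specified by this Space.
if __name__ == "__main__":
    model = TransformerRun(vocabSize=120000, maxLong=256, sizeVector=128, block=4)

    dummy = torch.randint(0, 120000, (2, 16))       # (B=2, T=16) random token ids
    logits = model(dummy)                           # (2, 16, 120000) next-token logits

    # Hypothetical language-modeling loss against random targets, just to show
    # the expected shapes for training.
    targets = torch.randint(0, 120000, (2, 16))
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    print(logits.shape, loss.item())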