# testy / test2.py
# Koyd111's picture
# Upload 8 files
# 32b6996 verified
import math
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from preprocess_dataset import preprocess_text
import bitsandbytes as bnb
from invariants import get_data_pairs
from french_dataset import get_full_dataset
from torch.utils.data import Dataset, DataLoader
from cipher_8bit import load_or_save_symbols, substitution_cipher, encode_text_with_indices
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Every source/target sequence built by CipherDataset holds exactly this many ids.
MAX_SEQUENCE_LENGTH = 512

# Special token ids. The source ids sit just below the 700-entry source
# embedding and the target ids just below the 259-entry target embedding
# declared in Transformer.__init__.
PAD_IDX_SRC = 698
PAD_IDX_TGT = 257
BOS_IDX_SRC = 699
BOS_IDX_TGT = 258

# Loaded once at import time. CipherDataset reads entry[0] as the source id
# list and entry[1] as the target id list of each pair.
# NOTE(review): the exact contents are defined by get_data_pairs — confirm there.
training_pairs = get_data_pairs(get_full_dataset())
def create_mask(src):
    """Return the padding mask of a batch of source token ids.

    Args:
        src: Integer tensor of source token ids with at least two dimensions.
            Callers in this file pass a ``(batch, seq)`` tensor.

    Returns:
        Boolean tensor equal to ``src == PAD_IDX_SRC`` with its first two
        dimensions swapped, i.e. ``(seq, batch)`` for ``(batch, seq)`` input.
        ``True`` marks padding positions.
    """
    # The previous implementation also allocated a square boolean tensor on
    # DEVICE that was never used; only the padding mask is needed.
    return (src == PAD_IDX_SRC).transpose(0, 1)
class CipherDataset(Dataset):
    """Fixed-length (source, target) token-id sequences.

    Each sequence is prefixed with its BOS id, then truncated or right-padded
    with its PAD id so that it holds exactly ``MAX_SEQUENCE_LENGTH`` ids.
    """

    def __init__(self, pairs):
        """Pre-compute the padded id lists.

        Args:
            pairs: Iterable of entries where ``entry[0]`` is the list of source
                ids and ``entry[1]`` is the list of target ids.
        """
        self.encodings = []
        self.symbol_indices = []
        for entry in pairs:
            self.encodings.append(
                self._fit_length(entry[0], BOS_IDX_SRC, PAD_IDX_SRC)
            )
            self.symbol_indices.append(
                self._fit_length(entry[1], BOS_IDX_TGT, PAD_IDX_TGT)
            )

    @staticmethod
    def _fit_length(tokens, bos_idx, pad_idx):
        """Return ``[bos_idx] + tokens`` cut or padded to MAX_SEQUENCE_LENGTH ids."""
        # One slot is reserved for BOS, so at most MAX_SEQUENCE_LENGTH - 1 ids survive.
        kept = tokens[:MAX_SEQUENCE_LENGTH - 1]
        filler = [pad_idx] * (MAX_SEQUENCE_LENGTH - 1 - len(kept))
        return [bos_idx] + kept + filler

    def __len__(self):
        """Return the number of (source, target) pairs."""
        return len(self.encodings)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as two tensors built from the id lists."""
        source_ids = torch.tensor(self.encodings[idx])
        target_ids = torch.tensor(self.symbol_indices[idx])
        return source_ids, target_ids
# Wrap data in the simplest possible way to enable PyTorch data fetching
# https://pytorch.org/docs/stable/data.html
BATCH_SIZE = 64
TRAIN_FRAC = 0.995  # share of the pairs used for training; the remainder is validation

dataset = CipherDataset(training_pairs)
N = len(training_pairs)
print(N)

# NOTE(review): S and C are not referenced anywhere else in the visible source;
# they match MAX_SEQUENCE_LENGTH and the source embedding size respectively.
S = 512
C = 700

# Split into train and val
num_train = int(N * TRAIN_FRAC)
num_val = N - num_train
data_train, data_val = torch.utils.data.random_split(dataset, [num_train, num_val])

# Neither loader shuffles; the random split above is the only randomisation.
dataloader_train = DataLoader(data_train, batch_size=BATCH_SIZE)
dataloader_val = DataLoader(data_val, batch_size=BATCH_SIZE)

# Sample batch
x, y = next(iter(dataloader_train))
class PositionalEncoding(nn.Module):
    """
    Classic Attention-is-all-you-need positional encoding.
    From PyTorch docs.

    Adds a fixed sinusoidal signal to a sequence-first input and applies
    dropout to the sum.
    """

    def __init__(self, d_model, dropout=0.1, max_len=512):
        """Pre-compute the sinusoid table.

        Args:
            d_model: Embedding width. May be odd; the cosine columns are then
                one fewer than the sine columns.
            dropout: Dropout probability applied after adding the encoding.
            max_len: Number of positions to pre-compute.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one more even column than odd columns, so
        # the cosine part is trimmed to fit instead of raising a shape error.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Stored as (max_len, 1, d_model) so it broadcasts over the batch axis.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Add the encoding to ``x`` and apply dropout.

        Args:
            x: Tensor of shape ``(seq_len, batch, d_model)`` with
                ``seq_len <= max_len``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
def generate_square_subsequent_mask(size: int):
    """Build a (size, size) float mask that hides future positions.

    Entries on and below the diagonal are 0.0; entries strictly above the
    diagonal are ``-inf``.
    """
    # True exactly where column > row, i.e. where a position would look ahead.
    future = torch.ones(size, size, dtype=torch.bool).triu(diagonal=1)
    return torch.zeros(size, size).masked_fill(future, float('-inf'))
class Transformer(nn.Module):
    """
    Classic Transformer that both encodes and decodes.

    Prediction-time inference is done greedily.

    NOTE: decoding starts from ``BOS_IDX_TGT``. There is no end token, so
    predict() always emits ``max_output_length`` tokens.
    """

    def __init__(self, num_classes: int, max_output_length: int, dim: int = 192):
        """
        Args:
            num_classes: Currently unused; kept for interface compatibility.
                The vocabulary sizes are fixed at 700 (source) and 259 (target).
            max_output_length: Longest target sequence the causal mask supports
                and the number of tokens predict() produces.
            dim: Embedding / model width. Must be divisible by the 16 heads.
        """
        super().__init__()

        # Parameters
        self.dim = dim
        self.max_output_length = max_output_length
        nhead = 16
        num_layers = 8
        dim_feedforward = dim

        # Encoder part
        self.x_embedding = nn.Embedding(700, dim)
        self.y_embedding = nn.Embedding(259, dim)
        self.pos_encoder = PositionalEncoding(d_model=self.dim)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=self.dim, nhead=nhead, dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )

        # Decoder part
        # Registered as a non-persistent buffer so it follows the module across
        # devices without being written to (or expected in) the state dict.
        self.register_buffer(
            "y_mask",
            generate_square_subsequent_mask(self.max_output_length),
            persistent=False,
        )
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer=nn.TransformerDecoderLayer(d_model=self.dim, nhead=nhead, dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )
        self.fc = nn.Linear(self.dim, 259)

        # It is empirically important to initialize weights properly
        self.init_weights()

    def init_weights(self):
        """Uniformly initialise both embeddings and the output layer; zero its bias."""
        initrange = 0.1
        self.x_embedding.weight.data.uniform_(-initrange, initrange)
        self.y_embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Input
            x: (B, Sx) source ids, valid indices 0..699
            y: (B, Sy) target ids, valid indices 0..258
        Output
            (B, C, Sy) logits with C = 259
        """
        x_pad_mask = create_mask(x)  # (Sx, B)
        encoded_x = self.encode(x, x_pad_mask)  # (Sx, B, E)
        output = self.decode(y, encoded_x, x_pad_mask)  # (Sy, B, C)
        return output.permute(1, 2, 0)  # (B, C, Sy)

    def encode(self, x: torch.Tensor, x_pad_mask: torch.Tensor) -> torch.Tensor:
        """
        Input
            x: (B, Sx) source ids
            x_pad_mask: (Sx, B) boolean mask, True at padding positions
        Output
            (Sx, B, E) embedding
        """
        x = x.permute(1, 0)  # (Sx, B)
        x = self.x_embedding(x) * math.sqrt(self.dim)  # (Sx, B, E)
        x = self.pos_encoder(x)  # (Sx, B, E)
        x = self.transformer_encoder(
            x, mask=None, src_key_padding_mask=x_pad_mask.transpose(0, 1)
        )  # (Sx, B, E)
        return x

    def decode(self, y: torch.Tensor, encoded_x: torch.Tensor, x_pad_mask=None) -> torch.Tensor:
        """
        Input
            y: (B, Sy) target ids
            encoded_x: (Sx, B, E)
            x_pad_mask: optional (Sx, B) boolean mask, True at source padding
                positions. When given, cross-attention ignores those positions.
                When omitted, every source position is attended to.
        Output
            (Sy, B, C) logits
        """
        y = y.permute(1, 0)  # (Sy, B)
        y = self.y_embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
        y = self.pos_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(encoded_x)  # (Sy, Sy)
        memory_pad_mask = None if x_pad_mask is None else x_pad_mask.transpose(0, 1)  # (B, Sx)
        output = self.transformer_decoder(
            y, encoded_x, tgt_mask=y_mask, memory_key_padding_mask=memory_pad_mask
        )  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, C)
        return output

    @torch.no_grad()
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Method to use at inference time. Predict y from x one token at a time. This method is greedy
        decoding. Beam search can be used instead for a potential accuracy boost.

        Runs without gradient tracking. Call ``eval()`` on the model first if
        dropout should be disabled during decoding.

        Input
            x: (B, Sx) source ids
        Output
            (B, max_output_length) predicted target ids; column 0 is BOS_IDX_TGT
        """
        x_pad_mask = create_mask(x)  # (Sx, B)
        encoded_x = self.encode(x, x_pad_mask)

        output_tokens = (torch.ones((x.shape[0], self.max_output_length))).type_as(x).long()  # (B, max_length)
        output_tokens[:, 0] = BOS_IDX_TGT  # Set start token
        for Sy in range(1, self.max_output_length):
            y = output_tokens[:, :Sy]  # (B, Sy)
            output = self.decode(y, encoded_x, x_pad_mask)  # (Sy, B, C)
            output = torch.argmax(output, dim=-1)  # (Sy, B)
            output_tokens[:, Sy] = output[-1]  # Set the last output token
        return output_tokens
class LitModel(pl.LightningModule):
    """Simple PyTorch-Lightning model to train our Transformer."""

    def __init__(self, model):
        """Store the wrapped model and create the label-smoothed cross-entropy loss."""
        super().__init__()
        self.model = model
        self.loss = torch.nn.CrossEntropyLoss(label_smoothing=0.2, reduction='mean')

    def _teacher_forced_loss(self, batch):
        """Return the loss of one (source, target) batch under teacher forcing."""
        source, target = batch
        # Teacher forcing: the model sees the target up to the last token,
        # while the ground truth is the target from the second token onward.
        decoder_input = target[:, :-1]
        expected = target[:, 1:]
        return self.loss(self.model(source, decoder_input), expected)

    def training_step(self, batch, batch_ind):
        """Compute, log as ``train_loss`` and return the training loss."""
        loss = self._teacher_forced_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_ind):
        """Compute the validation loss and log it as ``val_loss`` on the progress bar."""
        self.log("val_loss", self._teacher_forced_loss(batch), prog_bar=True)

    def configure_optimizers(self):
        """Return the bitsandbytes ``AdamW8bit`` optimizer over all parameters."""
        return bnb.optim.AdamW8bit(
            self.parameters(),
            lr=0.0005,
            betas=(0.9, 0.99),
            eps=1e-8,
            weight_decay=0.01,
        )
# Train for one epoch, save the weights, then greedily decode one validation
# sample as a sanity check of the decoding path.
x, y = next(iter(dataloader_val))
model = Transformer(num_classes=700, max_output_length=y.shape[1])
lit_model = LitModel(model)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(lit_model, dataloader_train, dataloader_val)
torch.save(model.state_dict(), "cipher.pth")

print('Input:', x[:1])
# Disable dropout for decoding, feed the sample on whichever device the
# trainer left the model on, and bring the result back to the CPU so it can be
# concatenated with the CPU-resident target below.
lit_model.model.eval()
model_device = next(lit_model.model.parameters()).device
with torch.no_grad():
    pred = lit_model.model.predict(x[:1].to(model_device)).cpu()
print('Truth/Pred:')
tokens = torch.cat((y[:1], pred)).cpu().numpy()
symbols = load_or_save_symbols([])
print(tokens)
# Ids of 256 and above have no printable symbol (257 and 258 are PAD and BOS) and are shown as '#'.
print(''.join([symbols[tok] if tok < 256 else "#" for tok in tokens[0][1:]]))
print(''.join([symbols[tok] if tok < 256 else "#" for tok in tokens[1][1:]]))