import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import bitsandbytes as bnb
from torch.utils.data import Dataset, DataLoader

from preprocess_dataset import preprocess_text
from invariants import get_data_pairs
from french_dataset import get_full_dataset
from cipher_8bit import load_or_save_symbols, substitution_cipher, encode_text_with_indices

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MAX_SEQUENCE_LENGTH = 512
PAD_IDX_SRC, PAD_IDX_TGT, BOS_IDX_SRC, BOS_IDX_TGT = 698, 257, 699, 258
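# The special-token indices above sit at the end of the vocabularies used by the embeddings
# below (700 source symbols, 259 target classes): PAD/BOS occupy the last slots of each.
# A minimal sanity check, assuming those vocabulary sizes, would be:
#
#   assert PAD_IDX_SRC < 700 and BOS_IDX_SRC < 700
#   assert PAD_IDX_TGT < 259 and BOS_IDX_TGT < 259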

training_pairs = get_data_pairs(get_full_dataset())


def create_mask(src):
    """Return a boolean key-padding mask of shape (Sx, B); True marks PAD positions in src of shape (B, Sx)."""
    src_padding_mask = (src == PAD_IDX_SRC).transpose(0, 1)
    return src_padding_mask
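# Example (illustrative): for a batch of shape (B, Sx) = (2, 5) where 698 marks padding,
# create_mask returns a (5, 2) boolean tensor; encode() transposes it back to (B, Sx) before
# handing it to nn.TransformerEncoder as the key-padding mask.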
class CipherDataset(Dataset):
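    """Wraps (cipher_indices, symbol_indices) pairs as fixed-length tensors.

    Every source and target sequence is prefixed with its BOS index and then padded or
    truncated so that it is exactly MAX_SEQUENCE_LENGTH tokens long.
    """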

    def __init__(self, pairs):
        self.encodings = []
        self.symbol_indices = []
        bulk_dataset = pairs
        for entry in bulk_dataset:
            if len(entry[0]) < MAX_SEQUENCE_LENGTH:
                self.encodings.append([BOS_IDX_SRC] + entry[0] + [PAD_IDX_SRC] * (MAX_SEQUENCE_LENGTH - len(entry[0]) - 1))
            elif len(entry[0]) > MAX_SEQUENCE_LENGTH:
                self.encodings.append([BOS_IDX_SRC] + entry[0][:MAX_SEQUENCE_LENGTH - 1])
            else:
                self.encodings.append([BOS_IDX_SRC] + entry[0][:-1])

            if len(entry[1]) < MAX_SEQUENCE_LENGTH:
                self.symbol_indices.append([BOS_IDX_TGT] + entry[1] + [PAD_IDX_TGT] * (MAX_SEQUENCE_LENGTH - len(entry[1]) - 1))
            elif len(entry[1]) > MAX_SEQUENCE_LENGTH:
                self.symbol_indices.append([BOS_IDX_TGT] + entry[1][:MAX_SEQUENCE_LENGTH - 1])
            else:
                self.symbol_indices.append([BOS_IDX_TGT] + entry[1][:-1])

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        return torch.tensor(self.encodings[idx]), torch.tensor(self.symbol_indices[idx])
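# Each item is a pair of LongTensors of shape (MAX_SEQUENCE_LENGTH,) = (512,): the
# cipher-symbol indices (source) and the plaintext symbol indices (target).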
BATCH_SIZE = 64
TRAIN_FRAC = 0.995

dataset = CipherDataset(training_pairs)
N = len(training_pairs)
print(f"Number of training pairs: {N}")
S = 512   # sequence length (matches MAX_SEQUENCE_LENGTH)
C = 700   # source vocabulary size

num_train = int(N * TRAIN_FRAC)
num_val = N - num_train
data_train, data_val = torch.utils.data.random_split(dataset, (num_train, num_val))

dataloader_train = DataLoader(data_train, batch_size=BATCH_SIZE)
dataloader_val = DataLoader(data_val, batch_size=BATCH_SIZE)

x, y = next(iter(dataloader_train))
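# Sanity check: x and y should each come back as LongTensors of shape
# (BATCH_SIZE, MAX_SEQUENCE_LENGTH) = (64, 512), since CipherDataset pads every sequence
# to a fixed length.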
class PositionalEncoding(nn.Module):
    """
    Classic Attention-is-all-you-need positional encoding.
    From PyTorch docs.
    """

    def __init__(self, d_model, dropout=0.1, max_len=512):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
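# The table built in __init__ follows the sinusoid from "Attention Is All You Need":
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# `pe` is stored as (max_len, 1, d_model) so it broadcasts over the batch dimension of
# (S, B, E) inputs in forward().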
def generate_square_subsequent_mask(size: int):
    """Generate a triangular (size, size) mask. From PyTorch docs."""
    mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
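# For size=3 the additive mask looks like (rows = target positions, cols = attendable positions):
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
# so each decoder position can only attend to itself and earlier positions.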
class Transformer(nn.Module):
    """
    Classic Transformer that both encodes and decodes.

    Prediction-time inference is done greedily.

    NOTE: generation in predict() is seeded with BOS_IDX_TGT and runs for a fixed
    max_output_length; there is no explicit end-of-sequence token. If the special
    indices change, update predict() accordingly.
    """

    def __init__(self, num_classes: int, max_output_length: int, dim: int = 192):
        super().__init__()

        self.dim = dim
        self.max_output_length = max_output_length
        nhead = 16
        num_layers = 8
        dim_feedforward = dim

        self.x_embedding = nn.Embedding(700, dim)
        self.y_embedding = nn.Embedding(259, dim)
        self.pos_encoder = PositionalEncoding(d_model=self.dim)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=self.dim, nhead=nhead, dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )

        self.y_mask = generate_square_subsequent_mask(self.max_output_length)
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer=nn.TransformerDecoderLayer(d_model=self.dim, nhead=nhead, dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )
        self.fc = nn.Linear(self.dim, 259)

        self.init_weights()
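        # Architecture summary: 8 encoder + 8 decoder layers, d_model = 192 split across
        # 16 heads (12 dims each), feed-forward width equal to d_model, a source vocabulary
        # of 700 cipher symbols and a target vocabulary of 259 classes.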

    def init_weights(self):
        initrange = 0.1
        self.x_embedding.weight.data.uniform_(-initrange, initrange)
        self.y_embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
            y: (B, Sy) with elements in (0, C) where C is num_classes
        Output
            (B, C, Sy) logits
        """
        x_pad_mask = create_mask(x)
        encoded_x = self.encode(x, x_pad_mask)
        output = self.decode(y, encoded_x)
        return output.permute(1, 2, 0)
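    # Shape flow through forward(): x (B, Sx) -> encode -> (Sx, B, E); y (B, Sy) -> decode with
    # the encoder memory -> (Sy, B, C); the final permute gives (B, C, Sy), the layout
    # nn.CrossEntropyLoss expects when the targets are (B, Sy).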

    def encode(self, x: torch.Tensor, x_pad_mask: torch.Tensor) -> torch.Tensor:
        """
        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
            x_pad_mask: (Sx, B) boolean key-padding mask from create_mask()
        Output
            (Sx, B, E) embedding
        """
        x = x.permute(1, 0)
        x = self.x_embedding(x) * math.sqrt(self.dim)
        x = self.pos_encoder(x)
        x = self.transformer_encoder(x, None, x_pad_mask.transpose(0, 1))  # mask=None, src_key_padding_mask=(B, Sx)
        return x

    def decode(self, y: torch.Tensor, encoded_x: torch.Tensor) -> torch.Tensor:
        """
        Input
            encoded_x: (Sx, B, E)
            y: (B, Sy) with elements in (0, C) where C is num_classes
        Output
            (Sy, B, C) logits
        """
        y = y.permute(1, 0)
        y = self.y_embedding(y) * math.sqrt(self.dim)
        y = self.pos_encoder(y)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(encoded_x)
        output = self.transformer_decoder(y, encoded_x, y_mask)
        output = self.fc(output)
        return output
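    # self.y_mask is built once at construction time; type_as() above copies the sliced causal
    # mask to the same dtype/device as the encoder output, so decoding works on CPU and GPU alike.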

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Method to use at inference time. Predict y from x one token at a time. This method uses
        greedy decoding. Beam search could be used instead for a potential accuracy boost.

        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
        Output
            (B, max_output_length) predicted token indices
        """
        x_pad_mask = create_mask(x)
        encoded_x = self.encode(x, x_pad_mask)

        output_tokens = (torch.ones((x.shape[0], self.max_output_length))).type_as(x).long()  # (B, max_output_length)
        output_tokens[:, 0] = BOS_IDX_TGT
        for Sy in range(1, self.max_output_length):
            y = output_tokens[:, :Sy]               # (B, Sy)
            output = self.decode(y, encoded_x)      # (Sy, B, C)
            output = torch.argmax(output, dim=-1)   # (Sy, B)
            output_tokens[:, Sy] = output[-1]       # keep the newest position's prediction per batch element
        return output_tokens
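    # Example use (illustrative): given a batch of cipher sequences `x` of shape (B, 512),
    # `model.predict(x)` returns a (B, 512) LongTensor whose column 0 is BOS_IDX_TGT and whose
    # remaining columns are greedily decoded plaintext indices.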
class LitModel(pl.LightningModule):
    """Simple PyTorch-Lightning model to train our Transformer."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.loss = torch.nn.CrossEntropyLoss(label_smoothing=0.2, reduction='mean')

    def training_step(self, batch, batch_ind):
        x, y = batch
        logits = self.model(x, y[:, :-1])
        loss = self.loss(logits, y[:, 1:])
        self.log("train_loss", loss)
        return loss
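    # Teacher forcing: the decoder input is y[:, :-1] (starting at BOS) and the targets are
    # y[:, 1:], so each position is trained to predict the following token. The same shift is
    # applied in validation_step below.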

    def validation_step(self, batch, batch_ind):
        x, y = batch
        logits = self.model(x, y[:, :-1])
        loss = self.loss(logits, y[:, 1:])
        self.log("val_loss", loss, prog_bar=True)

    def configure_optimizers(self):
        return bnb.optim.AdamW8bit(self.parameters(), lr=0.0005, betas=(0.9, 0.99), eps=1e-8, weight_decay=0.01)
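    # bnb.optim.AdamW8bit (bitsandbytes) keeps optimizer state in 8-bit, cutting optimizer memory
    # roughly 4x versus torch.optim.AdamW; the same hyperparameters should work with
    # torch.optim.AdamW if bitsandbytes is unavailable.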
x, y = next(iter(dataloader_val))

model = Transformer(num_classes=700, max_output_length=y.shape[1])
lit_model = LitModel(model)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(lit_model, dataloader_train, dataloader_val)
torch.save(model.state_dict(), "cipher.pth")
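# To reuse the trained weights later (illustrative sketch, same constructor arguments assumed):
#   model = Transformer(num_classes=700, max_output_length=512)
#   model.load_state_dict(torch.load("cipher.pth"))
#   model.eval()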

print('Input:', x[:1])
pred = lit_model.model.predict(x[:1])
print('Truth/Pred:')
tokens = torch.cat((y[:1], pred)).cpu().numpy()
symbols = load_or_save_symbols([])
print(tokens)
print(''.join([symbols[x] if x < 256 else "#" for x in tokens[0][1:]]))
print(''.join([symbols[x] if x < 256 else "#" for x in tokens[1][1:]]))
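# Indices at or above 256 fall outside the printable symbol table (PAD_IDX_TGT and BOS_IDX_TGT
# are 257 and 258), so they are rendered as "#".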