testy / infer.py
Koyd111's picture
Upload 8 files
32b6996 verified
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from preprocess_dataset import preprocess_text
from torch import Tensor
from torch.nn import Transformer
import math
import bitsandbytes as bnb
from invariants import get_data_pairs
from french_dataset import get_full_dataset
import numpy as np
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_SEQUENCE_LENGTH = 512
training_pairs = get_data_pairs(get_full_dataset())
PAD_IDX_SRC, PAD_IDX_TGT, BOS_IDX_SRC, BOS_IDX_TGT = 698, 256, 699, 257
class CipherDataset(Dataset):
def __init__(self, pairs):
self.encodings = []
self.symbol_indices = []
bulk_dataset = pairs[:3000]
for entry in bulk_dataset:
if len(entry[0]) < MAX_SEQUENCE_LENGTH:
self.encodings.append([BOS_IDX_SRC] + entry[0] + [PAD_IDX_SRC] * (MAX_SEQUENCE_LENGTH - len(entry[0]) - 1))
elif len(entry[0]) > MAX_SEQUENCE_LENGTH:
self.encodings.append([BOS_IDX_SRC] + entry[0][:MAX_SEQUENCE_LENGTH - 1])
else:
self.encodings.append([BOS_IDX_SRC] + entry[0][: - 1])
if len(entry[1]) < MAX_SEQUENCE_LENGTH:
self.symbol_indices.append([BOS_IDX_TGT] + entry[1] + [PAD_IDX_TGT] * (MAX_SEQUENCE_LENGTH - len(entry[1]) - 1))
elif len(entry[1]) > MAX_SEQUENCE_LENGTH:
self.symbol_indices.append([BOS_IDX_TGT] + entry[1][:MAX_SEQUENCE_LENGTH - 1])
else:
self.symbol_indices.append([BOS_IDX_TGT] + entry[1][: - 1])
def __len__(self):
return len(self.encodings)
def __getitem__(self, idx):
return torch.tensor(self.encodings[idx]), torch.tensor(self.symbol_indices[idx])
class PositionalEncoding(nn.Module):
def __init__(self, emb_size: int, dropout: float, maxlen: int = MAX_SEQUENCE_LENGTH):
super(PositionalEncoding, self).__init__()
den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
pos = torch.arange(0, maxlen).reshape(maxlen, 1)
pos_embedding = torch.zeros((maxlen, emb_size))
pos_embedding[:, 0::2] = torch.sin(pos * den)
pos_embedding[:, 1::2] = torch.cos(pos * den)
pos_embedding = pos_embedding.unsqueeze(0).transpose(0, 1)
self.dropout = nn.Dropout(dropout)
self.register_buffer('pos_embedding', pos_embedding)
def forward(self, token_embedding: Tensor):
return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])
class TokenEmbedding(nn.Module):
def __init__(self, vocab_size: int, emb_size):
super(TokenEmbedding, self).__init__()
self.embedding = nn.Embedding(vocab_size, emb_size)
self.pos_encoder = PositionalEncoding(emb_size, 0.2)
self.emb_size = emb_size
def forward(self, tokens: Tensor):
return self.pos_encoder(self.embedding(tokens.long()) * math.sqrt(self.emb_size))
class Seq2SeqTransformer(nn.Module):
def __init__(self,
num_encoder_layers: int,
num_decoder_layers: int,
emb_size: int,
nhead: int,
src_vocab_size: int,
tgt_vocab_size: int,
dim_feedforward: int = 512,
dropout: float = 0.1):
super(Seq2SeqTransformer, self).__init__()
self.transformer = Transformer(d_model=emb_size,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout)
self.generator = nn.Linear(emb_size, tgt_vocab_size)
self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
def forward(self,
src: Tensor,
trg: Tensor,
src_mask: Tensor,
tgt_mask: Tensor,
src_padding_mask: Tensor,
tgt_padding_mask: Tensor,
memory_key_padding_mask: Tensor):
src_emb = self.src_tok_emb(src)
tgt_emb = self.tgt_tok_emb(trg)
outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
return self.generator(outs)
def encode(self, src: Tensor, src_mask: Tensor):
return self.transformer.encoder(self.src_tok_emb(src), src_mask)
def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
return self.transformer.decoder(self.tgt_tok_emb(tgt), memory, tgt_mask)
def generate_square_subsequent_mask(sz):
mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
def create_mask(src, tgt):
src_seq_len = src.shape[0]
tgt_seq_len = tgt.shape[0]
tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)
src_padding_mask = (src == PAD_IDX_SRC).transpose(0, 1)
tgt_padding_mask = (tgt == PAD_IDX_TGT).transpose(0, 1)
return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
torch.manual_seed(22330)
GRAD_CLIP = 1.0
SRC_VOCAB_SIZE = 700
TGT_VOCAB_SIZE = 258
EMB_SIZE = 512
NHEAD = 16
FFN_HID_DIM = 768
BATCH_SIZE = 32
NUM_ENCODER_LAYERS = 4
NUM_DECODER_LAYERS = 4
GRAD_CLIP = 1.0
DROPOUT = 0.3 # Increased Dropout
TOTAL_STEPS = 400
model = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM,
dropout=DROPOUT)
model.load_state_dict(torch.load("cipher.pth", map_location=DEVICE))
model.to(DEVICE)
model.eval()
# function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):
src = src.to(DEVICE)
src_mask = src_mask.to(DEVICE)
memory = model.encode(src, src_mask)
ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
for i in range(max_len-1):
memory = memory.to(DEVICE)
tgt_mask = (generate_square_subsequent_mask(ys.size(0))
.type(torch.bool)).to(DEVICE)
out = model.decode(ys, memory, tgt_mask)
out = out.transpose(0, 1)
prob = model.generator(out[:, -1])
_, next_word = torch.max(prob, dim=1)
next_word = next_word.item()
ys = torch.cat([ys,
torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
return ys
from cipher_8bit import load_or_save_symbols, substitution_cipher, encode_text_with_indices
from preprocess_dataset import get_frequency_ranks, get_proximity_array
# actual function to translate input sentence into target language
def translate(model: torch.nn.Module, src_sentence: str):
model.eval()
symbols = load_or_save_symbols([])
rule = substitution_cipher(symbols, 1337)
encodings, _ = encode_text_with_indices(rule, symbols, src_sentence)
encodings = torch.tensor(encodings[:20]).view(-1, 1)
print(encodings.shape)
num_tokens = encodings.shape[0]
src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
tgt_tokens = greedy_decode(
model, encodings, src_mask, max_len=MAX_SEQUENCE_LENGTH, start_symbol=256).flatten()
tokens = tgt_tokens.cpu().numpy()[1:]
for i, x in enumerate(tokens):
if x > 255:
tokens[i] = 0
print(tokens)
return (''.join([symbols[x] for x in tokens]))
print(translate(model, """Bien sûr ! Voici un texte plus long en français: La beauté de la nature réside dans sa diversité et sa capacité à émerveiller à chaque saison. Chaque paysage, qu’il s’agisse de montagnes imposantes, de forêts mystérieuses ou de rivières sinueuses, possède une âme et une histoire à raconter. Lorsqu’on s’aventure au cœur d’une forêt, l’air frais, imprégné des parfums de bois et de terre humide, nous invite à ralentir et à savourer l’instant. Les feuilles qui bruissent sous nos pas, les oiseaux qui chantent et la lumière qui filtre à travers les arbres créent une atmosphère presque magique, propice à la contemplation. Au printemps, la nature se réveille lentement, offrant un spectacle de couleurs éclatantes : des fleurs qui éclorent, des bourgeons qui germent, des champs qui se parent de verts éclatants. L'été, quant à lui, emporte tout dans un tourbillon de chaleur, de lumière et de vie. Les journées longues et ensoleillées sont idéales pour profiter des bienfaits du plein air, que ce soit à la mer, à la montagne ou simplement dans le jardin. L’automne, avec ses nuances orangées et dorées, est une invitation à la réflexion et à la tranquillité. Les feuilles tombent en tourbillonnant, créant des tapis colorés qui habillent la terre. Puis vient l’hiver, avec son froid piquant et la neige qui transforme le monde en un paysage féerique, silencieux et apaisant. Au-delà de la beauté visuelle, la nature nous enseigne aussi l’humilité et la résilience. Elle nous rappelle que tout est en perpétuel mouvement et que chaque cycle a sa raison d’être. Nous, humains, sommes un petit maillon dans cet écosystème complexe, et il est de notre responsabilité de préserver ce précieux équilibre. En prenant soin de notre environnement, nous garantissons non seulement la survie des espèces, mais aussi notre propre bien-être. La nature, avec sa sagesse silencieuse, continue de nous offrir des leçons de vie, chaque jour, sous nos yeux émerveillés."""))