File size: 10,125 Bytes
32b6996
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from preprocess_dataset import preprocess_text
from torch import Tensor
from torch.nn import Transformer
import math
import bitsandbytes as bnb
from invariants import get_data_pairs
from french_dataset import get_full_dataset
import numpy as np

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_SEQUENCE_LENGTH = 512
training_pairs = get_data_pairs(get_full_dataset())
PAD_IDX_SRC, PAD_IDX_TGT, BOS_IDX_SRC, BOS_IDX_TGT = 698, 256, 699, 257

class CipherDataset(Dataset):
    def __init__(self, pairs):
        self.encodings = []
        self.symbol_indices = []
        bulk_dataset = pairs[:3000]
        for entry in bulk_dataset:

            if len(entry[0]) < MAX_SEQUENCE_LENGTH:
                self.encodings.append([BOS_IDX_SRC] + entry[0] + [PAD_IDX_SRC] * (MAX_SEQUENCE_LENGTH - len(entry[0]) - 1))
            elif len(entry[0]) > MAX_SEQUENCE_LENGTH:
                self.encodings.append([BOS_IDX_SRC] + entry[0][:MAX_SEQUENCE_LENGTH - 1])
            else:
                self.encodings.append([BOS_IDX_SRC] + entry[0][: - 1])

            if len(entry[1]) < MAX_SEQUENCE_LENGTH:
                self.symbol_indices.append([BOS_IDX_TGT] + entry[1] + [PAD_IDX_TGT] * (MAX_SEQUENCE_LENGTH - len(entry[1]) - 1))
            elif len(entry[1]) > MAX_SEQUENCE_LENGTH:
                self.symbol_indices.append([BOS_IDX_TGT] + entry[1][:MAX_SEQUENCE_LENGTH - 1])
            else:
                self.symbol_indices.append([BOS_IDX_TGT] + entry[1][: - 1])
    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        return torch.tensor(self.encodings[idx]), torch.tensor(self.symbol_indices[idx])

class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, dropout: float, maxlen: int = MAX_SEQUENCE_LENGTH):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(0).transpose(0, 1)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.pos_encoder = PositionalEncoding(emb_size, 0.2)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.pos_encoder(self.embedding(tokens.long()) * math.sqrt(self.emb_size))

class Seq2SeqTransformer(nn.Module):
    def __init__(self,

                 num_encoder_layers: int,

                 num_decoder_layers: int,

                 emb_size: int,

                 nhead: int,

                 src_vocab_size: int,

                 tgt_vocab_size: int,

                 dim_feedforward: int = 512,

                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)

    def forward(self,

                src: Tensor,

                trg: Tensor,

                src_mask: Tensor,

                tgt_mask: Tensor,

                src_padding_mask: Tensor,

                tgt_padding_mask: Tensor,

                memory_key_padding_mask: Tensor):
        src_emb = self.src_tok_emb(src)
        tgt_emb = self.tgt_tok_emb(trg)
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None, 
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.src_tok_emb(src), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.tgt_tok_emb(tgt), memory, tgt_mask)

def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX_SRC).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX_TGT).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

torch.manual_seed(22330)

GRAD_CLIP = 1.0 

SRC_VOCAB_SIZE = 700
TGT_VOCAB_SIZE = 258
EMB_SIZE = 512
NHEAD = 16
FFN_HID_DIM = 768
BATCH_SIZE = 32
NUM_ENCODER_LAYERS = 4
NUM_DECODER_LAYERS = 4
GRAD_CLIP = 1.0
DROPOUT = 0.3  # Increased Dropout
TOTAL_STEPS = 400

model = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, 
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM,
                                 dropout=DROPOUT)
model.load_state_dict(torch.load("cipher.pth", map_location=DEVICE))
model.to(DEVICE)
model.eval()
# function to generate output sequence using greedy algorithm 
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len-1):
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
    return ys

from cipher_8bit import load_or_save_symbols, substitution_cipher, encode_text_with_indices
from preprocess_dataset import get_frequency_ranks, get_proximity_array
# actual function to translate input sentence into target language
def translate(model: torch.nn.Module, src_sentence: str):
    model.eval()
    symbols = load_or_save_symbols([])
    rule = substitution_cipher(symbols, 1337)
    encodings, _ = encode_text_with_indices(rule, symbols, src_sentence)
    encodings = torch.tensor(encodings[:20]).view(-1, 1)
    print(encodings.shape)
    num_tokens = encodings.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model,  encodings, src_mask, max_len=MAX_SEQUENCE_LENGTH, start_symbol=256).flatten()
    tokens = tgt_tokens.cpu().numpy()[1:]
    for i, x in enumerate(tokens):
        if x > 255:
            tokens[i] = 0
    print(tokens)
    return (''.join([symbols[x] for x in tokens]))


print(translate(model, """Bien sûr ! Voici un texte plus long en français: La beauté de la nature réside dans sa diversité et sa capacité à émerveiller à chaque saison. Chaque paysage, qu’il s’agisse de montagnes imposantes, de forêts mystérieuses ou de rivières sinueuses, possède une âme et une histoire à raconter. Lorsqu’on s’aventure au cœur d’une forêt, l’air frais, imprégné des parfums de bois et de terre humide, nous invite à ralentir et à savourer l’instant. Les feuilles qui bruissent sous nos pas, les oiseaux qui chantent et la lumière qui filtre à travers les arbres créent une atmosphère presque magique, propice à la contemplation. Au printemps, la nature se réveille lentement, offrant un spectacle de couleurs éclatantes : des fleurs qui éclorent, des bourgeons qui germent, des champs qui se parent de verts éclatants. L'été, quant à lui, emporte tout dans un tourbillon de chaleur, de lumière et de vie. Les journées longues et ensoleillées sont idéales pour profiter des bienfaits du plein air, que ce soit à la mer, à la montagne ou simplement dans le jardin. L’automne, avec ses nuances orangées et dorées, est une invitation à la réflexion et à la tranquillité. Les feuilles tombent en tourbillonnant, créant des tapis colorés qui habillent la terre. Puis vient l’hiver, avec son froid piquant et la neige qui transforme le monde en un paysage féerique, silencieux et apaisant. Au-delà de la beauté visuelle, la nature nous enseigne aussi l’humilité et la résilience. Elle nous rappelle que tout est en perpétuel mouvement et que chaque cycle a sa raison d’être. Nous, humains, sommes un petit maillon dans cet écosystème complexe, et il est de notre responsabilité de préserver ce précieux équilibre. En prenant soin de notre environnement, nous garantissons non seulement la survie des espèces, mais aussi notre propre bien-être. La nature, avec sa sagesse silencieuse, continue de nous offrir des leçons de vie, chaque jour, sous nos yeux émerveillés."""))