Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import math | |
| import json | |
| import sentencepiece as spm | |
| import gradio as gr | |
# =========================
# Load config
# =========================
# Model hyper-parameters and special-token ids. JSON is UTF-8 by spec, so the
# encoding is passed explicitly instead of relying on the host locale default.
with open("config.json", encoding="utf-8") as f:
    config = json.load(f)

padIndex = config["pad_id"]  # padding token id; stripped from decoded output
BOSIndex = config["bos_id"]  # beginning-of-sequence token id
EOSIndex = config["eos_id"]  # end-of-sequence token id; stops greedy decoding

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# =========================
# SentencePiece
# =========================
# Subword tokenizers: `sp_en` encodes the English source text and `sp_ar`
# decodes the generated Arabic token ids back into text.
sp_en, sp_ar = spm.SentencePieceProcessor(), spm.SentencePieceProcessor()
sp_en.load("sp_en.model")
sp_ar.load("sp_ar.model")
# =========================
# MODEL (EXACT TRAINING VERSION)
# =========================
# NOTE: submodule attribute names in these classes become state_dict keys, and
# the trained checkpoint is loaded with load_state_dict, so they must not change.
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Inputs are 3-D tensors whose last dimension is ``d_model``; they are
    projected, split into ``num_heads`` heads of width ``d_model // num_heads``,
    attended, merged back and projected by ``W_o``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Four projections created in q, k, v, o order (same init order as before).
        self.W_q, self.W_k, self.W_v, self.W_o = [
            nn.Linear(d_model, d_model) for _ in range(4)
        ]

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Return softmax(Q K^T / sqrt(d_k)) V, honoring an optional 0/False mask."""
        logits = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Masked positions get a huge negative logit -> ~zero softmax weight.
            logits = logits.masked_fill(mask == 0, -1e9)
        weights = logits.softmax(dim=-1)
        return weights @ V

    def split_heads(self, x):
        """(batch, seq, d_model) -> (batch, heads, seq, d_k)."""
        batch, length, _ = x.size()
        per_head = x.view(batch, length, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def combine_heads(self, x):
        """(batch, heads, seq, d_k) -> (batch, seq, d_model)."""
        batch, _, length, _ = x.size()
        merged = x.transpose(1, 2).contiguous()
        return merged.view(batch, length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Attend from queries ``Q`` over keys ``K`` / values ``V``."""
        q_heads = self.split_heads(self.W_q(Q))
        k_heads = self.split_heads(self.W_k(K))
        v_heads = self.split_heads(self.W_v(V))
        context = self.scaled_dot_product_attention(q_heads, k_heads, v_heads, mask)
        return self.W_o(self.combine_heads(context))
class PositionWiseFeedForward(nn.Module):
    """Per-position MLP: Linear(d_model, d_ff) -> ReLU -> Dropout -> Linear(d_ff, d_model)."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Kept as a single nn.Sequential so checkpoint keys stay net.0.* / net.3.*.
        layers = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        hidden_to_out = self.net
        return hidden_to_out(x)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    Args:
        d_model: embedding width (last dimension of the input). Odd widths are
            supported: the final column then receives only a sine term.
        max_len: longest sequence length the encoding table covers.
        dropout: dropout probability applied after the addition.
    """

    def __init__(self, d_model, max_len, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 odd columns, one fewer than even columns when
        # d_model is odd; slice div_term so the assignment shapes always match.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Persistent buffer of shape (1, max_len, d_model): moves with .to(device)
        # and is part of the state_dict.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add encodings for the first ``x.size(1)`` positions to ``x``.

        Raises:
            ValueError: if the sequence is longer than the table (``max_len``).
        """
        seq_len = x.size(1)
        capacity = self.pe.size(1)
        if seq_len > capacity:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding "
                f"capacity {capacity}"
            )
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)
class EncoderLayer(nn.Module):
    """Post-norm encoder layer: self-attention, then feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode ``x``; ``mask`` is passed to self-attention (0/False = blocked)."""
        attended = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attended))
        transformed = self.feed_forward(x)
        return self.norm2(x + self.dropout(transformed))
class DecoderLayer(nn.Module):
    """Post-norm decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, sublayer_out):
        # Residual connection followed by the given LayerNorm (post-norm).
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, x, enc_out, src_mask, tgt_mask):
        """Decode ``x`` against encoder output ``enc_out``.

        ``tgt_mask`` gates self-attention and ``src_mask`` gates cross-attention
        (0/False = blocked).
        """
        x = self._add_and_norm(self.norm1, x, self.self_attn(x, x, x, tgt_mask))
        x = self._add_and_norm(
            self.norm2, x, self.cross_attn(x, enc_out, enc_out, src_mask)
        )
        return self._add_and_norm(self.norm3, x, self.feed_forward(x))
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing target-vocabulary logits.

    Args:
        src_vocab: source vocabulary size (encoder embedding rows).
        tgt_vocab: target vocabulary size (decoder embedding rows and logits).
        d_model: model width.
        num_heads: attention heads per attention block.
        num_layers: layers in each of the encoder and decoder stacks.
        d_ff: hidden width of the feed-forward sub-layers.
        max_len: length covered by the shared positional-encoding table.
        pad_idx: token id treated as padding by the embeddings and masks.
            Defaults to 0, the value previously hard-coded here; only change it
            if the checkpoint was trained with a different padding id.
    """

    def __init__(self, src_vocab, tgt_vocab,
                 d_model=256, num_heads=4, num_layers=3,
                 d_ff=512, max_len=100, pad_idx=0):
        super().__init__()
        self.d_model = d_model
        self.pad_idx = pad_idx  # plain attribute, not part of the state_dict
        self.encoder_embedding = nn.Embedding(src_vocab, d_model, padding_idx=pad_idx)
        self.decoder_embedding = nn.Embedding(tgt_vocab, d_model, padding_idx=pad_idx)
        self.positional_encoding = PositionalEncoding(d_model, max_len)
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(d_model, tgt_vocab)

    def generate_mask(self, src, tgt):
        """Build boolean attention masks (True = may attend).

        Returns:
            src_mask: (batch, 1, 1, src_len); hides padded source keys.
            tgt_mask: (batch, 1, tgt_len, tgt_len); causal mask combined with a
                padding mask over query rows. For right-padded targets the
                causal part already hides padded keys from every real query.
        """
        src_mask = (src != self.pad_idx).unsqueeze(1).unsqueeze(2)
        tgt_pad = (tgt != self.pad_idx).unsqueeze(1).unsqueeze(3)
        T = tgt.size(1)
        # Allocate directly on the target device instead of CPU + copy.
        causal = torch.tril(torch.ones(T, T, device=tgt.device)).bool()
        tgt_mask = tgt_pad & causal
        return src_mask, tgt_mask

    def _embed(self, tokens, embedding):
        # Scale embeddings by sqrt(d_model) before adding positional encodings.
        return self.positional_encoding(embedding(tokens) * math.sqrt(self.d_model))

    def encode(self, src, src_mask):
        """Run the encoder stack over source ids; returns encoder states."""
        enc = self._embed(src, self.encoder_embedding)
        for layer in self.encoder_layers:
            enc = layer(enc, src_mask)
        return enc

    def decode(self, tgt, enc, src_mask, tgt_mask):
        """Run the decoder stack against encoder states; returns logits."""
        dec = self._embed(tgt, self.decoder_embedding)
        for layer in self.decoder_layers:
            dec = layer(dec, enc, src_mask, tgt_mask)
        return self.fc(dec)

    def forward(self, src, tgt):
        """Return logits of shape (batch, tgt_len, tgt_vocab)."""
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        enc = self.encode(src, src_mask)
        return self.decode(tgt, enc, src_mask, tgt_mask)
# =========================
# Load model
# =========================
# Build the architecture from the saved hyper-parameters. Keyword arguments make
# the config-key -> constructor-parameter mapping explicit.
model = Transformer(
    src_vocab=config["src_vocab_size"],
    tgt_vocab=config["tgt_vocab_size"],
    d_model=config["d_model"],
    num_heads=config["num_heads"],
    num_layers=config["num_layers"],
    d_ff=config["d_ff"],
    # One positional-encoding table serves both encoder and decoder, so it
    # must cover the longer of the two configured lengths.
    max_len=max(config["max_src_len"], config["max_tgt_len"]),
).to(device)
# The loaded object goes straight into load_state_dict, so it only needs tensor
# data; weights_only=True refuses to unpickle arbitrary Python objects.
model.load_state_dict(
    torch.load("best_model.pt", map_location=device, weights_only=True)
)
model.eval()  # inference mode for dropout
# =========================
# Translation
# =========================
def translate(text, max_new_tokens=50):
    """Translate English ``text`` to Arabic using greedy decoding.

    Args:
        text: English input string (from the Gradio textbox).
        max_new_tokens: upper bound on generated target tokens (default 50,
            the previously fixed limit). It is further capped by the model's
            positional-encoding length.

    Returns:
        The decoded Arabic string with BOS/EOS/padding ids removed; an empty
        string for empty or whitespace-only input.
    """
    if not text or not text.strip():
        return ""

    # The positional-encoding table bounds every sequence fed to the model;
    # anything longer fails inside PositionalEncoding.
    limit = model.positional_encoding.pe.size(1)

    # Keep two slots for BOS/EOS and truncate over-long input instead of crashing.
    src_ids = [BOSIndex] + sp_en.encode(text)[: max(limit - 2, 0)] + [EOSIndex]
    src = torch.tensor(src_ids).unsqueeze(0).to(device)

    out = [BOSIndex]
    with torch.no_grad():
        # The decoder input grows by one token per step; stop before it
        # would exceed `limit`.
        for _ in range(min(max_new_tokens, limit)):
            tgt = torch.tensor(out).unsqueeze(0).to(device)
            # Greedy choice: highest-scoring id at the last decoder position.
            next_token = model(src, tgt)[0, -1].argmax().item()
            out.append(next_token)
            if next_token == EOSIndex:
                break

    special = {BOSIndex, EOSIndex, padIndex}
    return sp_ar.decode([t for t in out if t not in special])
# =========================
# UI
# =========================
# Text-in / text-out Gradio app around `translate`. The pipeline is one-way
# (English tokenizer in, Arabic tokenizer out), so the title uses a one-way
# arrow. Binding the interface to `demo` follows Gradio's convention.
demo = gr.Interface(
    fn=translate,
    inputs="text",
    outputs="text",
    title="English → Arabic Transformer",
)
demo.launch()