Omnia-cy commited on
Commit
37129bc
·
verified ·
1 Parent(s): 285b6af

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +227 -0
app.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ import json
5
+ import sentencepiece as spm
6
+ import gradio as gr
7
+
8
+ # =========================
9
+ # Load config
10
+ # =========================
11
+ with open("config.json") as f:
12
+ config = json.load(f)
13
+
14
+ pad_id = config["pad_id"]
15
+ bos_id = config["bos_id"]
16
+ eos_id = config["eos_id"]
17
+
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+
21
+ # =========================
22
+ # SentencePiece
23
+ # =========================
24
+ sp_en = spm.SentencePieceProcessor()
25
+ sp_en.load("sp_en.model")
26
+
27
+ sp_ar = spm.SentencePieceProcessor()
28
+ sp_ar.load("sp_ar.model")
29
+
30
+
31
+ # =========================
32
+ # Model Classes (same as training)
33
+ # =========================
34
+ class MultiHeadAttention(nn.Module):
35
+ def __init__(self, d_model, num_heads):
36
+ super().__init__()
37
+ assert d_model % num_heads == 0
38
+ self.d_model = d_model
39
+ self.num_heads = num_heads
40
+ self.d_k = d_model // num_heads
41
+
42
+ self.W_q = nn.Linear(d_model, d_model)
43
+ self.W_k = nn.Linear(d_model, d_model)
44
+ self.W_v = nn.Linear(d_model, d_model)
45
+ self.W_o = nn.Linear(d_model, d_model)
46
+
47
+ def split(self, x):
48
+ B, T, D = x.size()
49
+ return x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
50
+
51
+ def combine(self, x):
52
+ B, H, T, D = x.size()
53
+ return x.transpose(1, 2).contiguous().view(B, T, self.d_model)
54
+
55
+ def attention(self, q, k, v, mask=None):
56
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
57
+ if mask is not None:
58
+ scores = scores.masked_fill(mask == 0, -1e9)
59
+ return torch.softmax(scores, dim=-1) @ v
60
+
61
+ def forward(self, q, k, v, mask=None):
62
+ q = self.split(self.W_q(q))
63
+ k = self.split(self.W_k(k))
64
+ v = self.split(self.W_v(v))
65
+
66
+ out = self.attention(q, k, v, mask)
67
+ return self.W_o(self.combine(out))
68
+
69
+
70
+ class FFN(nn.Module):
71
+ def __init__(self, d_model, d_ff):
72
+ super().__init__()
73
+ self.net = nn.Sequential(
74
+ nn.Linear(d_model, d_ff),
75
+ nn.ReLU(),
76
+ nn.Linear(d_ff, d_model)
77
+ )
78
+
79
+ def forward(self, x):
80
+ return self.net(x)
81
+
82
+
83
+ class PosEnc(nn.Module):
84
+ def __init__(self, d_model, max_len):
85
+ super().__init__()
86
+ pe = torch.zeros(max_len, d_model)
87
+ pos = torch.arange(0, max_len).unsqueeze(1)
88
+ div = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
89
+
90
+ pe[:, 0::2] = torch.sin(pos * div)
91
+ pe[:, 1::2] = torch.cos(pos * div)
92
+
93
+ self.pe = pe.unsqueeze(0)
94
+
95
+ def forward(self, x):
96
+ return x + self.pe[:, :x.size(1)].to(x.device)
97
+
98
+
99
+ class EncoderLayer(nn.Module):
100
+ def __init__(self, d_model, heads, d_ff):
101
+ super().__init__()
102
+ self.attn = MultiHeadAttention(d_model, heads)
103
+ self.ffn = FFN(d_model, d_ff)
104
+ self.norm1 = nn.LayerNorm(d_model)
105
+ self.norm2 = nn.LayerNorm(d_model)
106
+
107
+ def forward(self, x, mask):
108
+ x = self.norm1(x + self.attn(x, x, x, mask))
109
+ x = self.norm2(x + self.ffn(x))
110
+ return x
111
+
112
+
113
+ class DecoderLayer(nn.Module):
114
+ def __init__(self, d_model, heads, d_ff):
115
+ super().__init__()
116
+ self.self_attn = MultiHeadAttention(d_model, heads)
117
+ self.cross_attn = MultiHeadAttention(d_model, heads)
118
+ self.ffn = FFN(d_model, d_ff)
119
+
120
+ self.n1 = nn.LayerNorm(d_model)
121
+ self.n2 = nn.LayerNorm(d_model)
122
+ self.n3 = nn.LayerNorm(d_model)
123
+
124
+ def forward(self, x, enc, src_mask, tgt_mask):
125
+ x = self.n1(x + self.self_attn(x, x, x, tgt_mask))
126
+ x = self.n2(x + self.cross_attn(x, enc, enc, src_mask))
127
+ x = self.n3(x + self.ffn(x))
128
+ return x
129
+
130
+
131
+ class Transformer(nn.Module):
132
+ def __init__(self):
133
+ super().__init__()
134
+
135
+ self.d_model = config["d_model"]
136
+
137
+ self.enc_emb = nn.Embedding(config["src_vocab_size"], self.d_model, padding_idx=0)
138
+ self.dec_emb = nn.Embedding(config["tgt_vocab_size"], self.d_model, padding_idx=0)
139
+
140
+ self.pos = PosEnc(self.d_model, config["max_src_len"])
141
+
142
+ self.enc_layers = nn.ModuleList([
143
+ EncoderLayer(self.d_model, config["num_heads"], config["d_ff"])
144
+ for _ in range(config["num_layers"])
145
+ ])
146
+
147
+ self.dec_layers = nn.ModuleList([
148
+ DecoderLayer(self.d_model, config["num_heads"], config["d_ff"])
149
+ for _ in range(config["num_layers"])
150
+ ])
151
+
152
+ self.fc = nn.Linear(self.d_model, config["tgt_vocab_size"])
153
+
154
+ def masks(self, src, tgt):
155
+ src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
156
+ tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
157
+ size = tgt.size(1)
158
+ causal = torch.tril(torch.ones(size, size)).bool().to(tgt.device)
159
+ return src_mask, tgt_mask & causal
160
+
161
+ def forward(self, src, tgt):
162
+ src_mask, tgt_mask = self.masks(src, tgt)
163
+
164
+ src = self.pos(self.enc_emb(src))
165
+ tgt = self.pos(self.dec_emb(tgt))
166
+
167
+ enc = src
168
+ for layer in self.enc_layers:
169
+ enc = layer(enc, src_mask)
170
+
171
+ dec = tgt
172
+ for layer in self.dec_layers:
173
+ dec = layer(dec, enc, src_mask, tgt_mask)
174
+
175
+ return self.fc(dec)
176
+
177
+
178
+ # =========================
179
+ # Load model
180
+ # =========================
181
+ model = Transformer().to(device)
182
+ model.load_state_dict(torch.load("best_model.pt", map_location=device))
183
+ model.eval()
184
+
185
+
186
+ # =========================
187
+ # Inference
188
+ # =========================
189
+ def translate(sentence):
190
+
191
+ tokens = sp_en.encode(sentence)
192
+ tokens = [bos_id] + tokens + [eos_id]
193
+
194
+ src = torch.tensor(tokens).unsqueeze(0).to(device)
195
+
196
+ out = [bos_id]
197
+
198
+ for _ in range(50):
199
+
200
+ tgt = torch.tensor(out).unsqueeze(0).to(device)
201
+
202
+ with torch.no_grad():
203
+ pred = model(src, tgt)
204
+
205
+ next_token = pred[0, -1].argmax().item()
206
+
207
+ out.append(next_token)
208
+
209
+ if next_token == eos_id:
210
+ break
211
+
212
+ result = sp_ar.decode([t for t in out if t not in [bos_id, eos_id, pad_id]])
213
+ return result
214
+
215
+
216
+ # =========================
217
+ # UI
218
+ # =========================
219
+ demo = gr.Interface(
220
+ fn=translate,
221
+ inputs="text",
222
+ outputs="text",
223
+ title="Arabic ↔ English Translator (Transformer)",
224
+ description="Enter English sentence and get Arabic translation"
225
+ )
226
+
227
+ demo.launch()