Spaces:
Sleeping
Sleeping
| """ | |
| English β Malay Translator β Gradio Interface | |
| ================================================ | |
| A live inference demo for the 6+2 Tied Transformer (16K shared BPE). | |
| Downloads model weights and tokenizer from the HuggingFace model repo | |
| on startup, then serves translations via a Gradio UI. | |
| """ | |
| import math | |
| import re | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| from tokenizers import Tokenizer | |
# ──────────────────────────────────────────────────────────────────────
# Config
# ──────────────────────────────────────────────────────────────────────
MODEL_REPO = "AstralPotato/en-ms-transformer"  # HF repo holding weights + tokenizer
DEVICE = torch.device("cpu")  # Spaces free tier = CPU only
# Architecture hyperparameters; must match the checkpoint's training config
# exactly or load_state_dict() will fail on shape mismatches.
MODEL_KWARGS = dict(
    vocab_size=16000, d_model=512, n_head=8,
    num_encoder_layers=6, num_decoder_layers=2,
    d_ff=2048, dropout=0.0, max_len=144, pad_idx=0,
)
# Special-token ids of the shared 16K BPE tokenizer.
BOS_ID, EOS_ID, PAD_ID = 5, 6, 0
MAX_DECODE_LEN = 128  # hard cap on generated length (tokens, incl. BOS)
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Model architecture (inlined to avoid import issues on Spaces) | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding (Vaswani et al., 2017).

    Precomputes a (1, max_len, d_model) table of interleaved sin/cos
    waves once, then adds the first ``seq_len`` rows to the input
    embeddings on every forward pass.
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        # Geometric frequency ladder: exp(-2i * ln(10000) / d_model).
        freqs = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        # Buffer (not a Parameter): follows .to(device), saved in state_dict,
        # excluded from gradient updates.
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to ``x`` of shape (batch, seq, d_model)."""
        return self.dropout(x + self.pe[:, : x.size(1)])
class TransformerTranslator(nn.Module):
    """6-encoder / 2-decoder pre-norm Transformer with tied embeddings.

    A single embedding matrix is shared between the encoder input, the
    decoder input, and the output projection (weight tying), so the
    output head is simply a linear map using ``shared_embedding.weight``
    plus a learned bias.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_head: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 2,
        d_ff: int = 2048,
        dropout: float = 0.1,
        max_len: int = 512,
        pad_idx: int = 0,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.d_model = d_model
        self.shared_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.embed_scale = math.sqrt(d_model)  # scale embeddings as in the original paper
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=n_head,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=d_ff, dropout=dropout,
            batch_first=True, norm_first=True,  # (batch, seq, dim); pre-norm variant
        )
        # Bias of the tied output projection (weights come from the embedding).
        self.output_bias = nn.Parameter(torch.zeros(vocab_size))
        nn.init.normal_(self.shared_embedding.weight, mean=0, std=d_model ** -0.5)
        with torch.no_grad():
            self.shared_embedding.weight[pad_idx].zero_()  # keep PAD embedding at zero

    def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed token ids, scale by sqrt(d_model), add positional encoding."""
        return self.pos_encoding(self.shared_embedding(tokens) * self.embed_scale)

    @staticmethod
    def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
        """Boolean causal mask: True strictly above the diagonal (blocked).

        BUG FIX: this was defined without ``self`` yet called as an
        instance method in ``decode`` (``self.generate_square_subsequent_mask(...)``),
        which would pass the instance as ``sz`` and raise TypeError.
        Declared as a ``@staticmethod`` so both call styles work.
        """
        return torch.triu(torch.ones(sz, sz, device=device, dtype=torch.bool), diagonal=1)

    def _make_pad_mask(self, x: torch.Tensor) -> torch.Tensor:
        """True where ``x`` equals the PAD id (positions attention ignores)."""
        return x == self.pad_idx

    def encode(self, src, src_key_padding_mask=None):
        """Run the encoder; returns memory of shape (batch, src_len, d_model)."""
        if src_key_padding_mask is None:
            src_key_padding_mask = self._make_pad_mask(src)
        return self.transformer.encoder(self._embed(src), src_key_padding_mask=src_key_padding_mask)

    def decode(self, tgt, memory, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Run the decoder over ``tgt`` given encoder ``memory``.

        Returns vocabulary logits of shape (batch, tgt_len, vocab_size)
        from the tied output projection.
        """
        if tgt_key_padding_mask is None:
            tgt_key_padding_mask = self._make_pad_mask(tgt)
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1), tgt.device)
        out = self.transformer.decoder(
            self._embed(tgt), memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        # Tied projection: logits = out @ E^T + b
        return torch.nn.functional.linear(out, self.shared_embedding.weight, self.output_bias)
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Decoding & post-processing | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def postprocess(text: str) -> str:
    """Clean up detokenized model output for display.

    Removes stray spaces that BPE detokenization leaves around
    punctuation, brackets and hyphens, collapses repeated whitespace,
    and capitalizes the first character (the tokenizer lowercases all
    input, so the model only emits lowercase text).
    """
    # No space *before* closing punctuation. (BUG FIX: the ellipsis in
    # the character class had been corrupted by a bad encoding
    # round-trip; restored to the intended '…'.)
    text = re.sub(r'\s+([.,?!;:)\]}"\'…])', r'\1', text)
    # No space *after* opening brackets/quotes.
    text = re.sub(r'([(\[{"\'])\s+', r'\1', text)
    # Tight hyphenation: "a - b" -> "a-b".
    text = re.sub(r'\s*-\s*', '-', text)
    # Collapse runs of whitespace, then trim the ends.
    text = re.sub(r'\s{2,}', ' ', text)
    text = text.strip()
    if text:
        text = text[0].upper() + text[1:]
    return text
def greedy_decode(model, src, max_len=MAX_DECODE_LEN):
    """Autoregressive argmax decoding.

    Starts from BOS and repeatedly appends the single most-likely next
    token until EOS appears or ``max_len`` tokens have been produced.
    Returns the full id sequence (including BOS, and EOS if reached) as
    a (1, T) long tensor.
    """
    pad_mask = src.eq(PAD_ID)
    # Encode the source once; reuse the memory for every decode step.
    memory = model.encode(src, src_key_padding_mask=pad_mask)
    generated = torch.tensor([[BOS_ID]], dtype=torch.long, device=DEVICE)
    for _ in range(max_len - 1):
        logits = model.decode(generated, memory, memory_key_padding_mask=pad_mask)
        # Greedy choice from the distribution at the last position.
        nxt = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, nxt], dim=1)
        if nxt.item() == EOS_ID:
            break
    return generated
def beam_search_decode(model, src, beam_width=5, length_penalty=0.6, max_len=MAX_DECODE_LEN):
    """Length-normalized beam search.

    Keeps the ``beam_width`` best partial hypotheses at each step,
    ranking by cumulative log-probability divided by len**length_penalty
    (GNMT-style normalization), and stops early once every live beam has
    emitted EOS. Returns the best id sequence as a (1, T) long tensor.
    """
    src_pad_mask = (src == PAD_ID)
    memory = model.encode(src, src_key_padding_mask=src_pad_mask)

    def norm_score(item):
        # Length-normalized log-probability, used both for pruning and
        # for the final hypothesis selection.
        score, tokens = item
        return score / (len(tokens) ** length_penalty)

    beams = [(0.0, [BOS_ID])]  # (cumulative log-prob, token ids)
    completed = []             # hypotheses that have produced EOS
    for _ in range(max_len - 1):
        candidates = []
        for score, tokens in beams:
            if tokens[-1] == EOS_ID:
                # Finished hypotheses are retired, not expanded further.
                completed.append((score, tokens))
                continue
            ys = torch.tensor([tokens], dtype=torch.long, device=DEVICE)
            logits = model.decode(ys, memory, memory_key_padding_mask=src_pad_mask)
            log_probs = torch.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)
            topk_lp, topk_ids = log_probs.topk(beam_width)
            for lp, tok in zip(topk_lp.tolist(), topk_ids.tolist()):
                candidates.append((score + lp, tokens + [tok]))
        if not candidates:  # every beam had already finished
            break
        candidates.sort(key=norm_score, reverse=True)
        beams = candidates[:beam_width]
        if all(tokens[-1] == EOS_ID for _, tokens in beams):
            completed.extend(beams)
            break
    # Flush the final beams, skipping any already retired to ``completed``.
    # (BUG FIX: the original unconditionally extended here, so after an
    # early break the same hypotheses appeared twice in ``completed`` —
    # harmless for max(), but still incorrect bookkeeping.)
    for beam in beams:
        if beam not in completed:
            completed.append(beam)
    best = max(completed, key=norm_score)
    return torch.tensor([best[1]], dtype=torch.long, device=DEVICE)
def translate(text: str, method: str = "Greedy", beam_width: int = 5) -> str:
    """Translate an English sentence to Malay.

    Tokenizes the input, dispatches to beam search or greedy decoding
    based on ``method``, then detokenizes and post-processes the result.
    Blank or empty input yields an empty string.
    """
    if not text or not text.strip():
        return ""
    ids = tokenizer.encode(text).ids
    src = torch.tensor([ids], dtype=torch.long, device=DEVICE)
    decoded = (
        beam_search_decode(model, src, beam_width=beam_width)
        if method == "Beam Search"
        else greedy_decode(model, src)
    )
    raw = tokenizer.decode(decoded.squeeze(0).tolist(), skip_special_tokens=True)
    return postprocess(raw)
# ──────────────────────────────────────────────────────────────────────
# Load model & tokenizer on startup
# ──────────────────────────────────────────────────────────────────────
# Runs once at import time so the Space is warm before the UI launches.
print("Downloading model and tokenizer...")
weights_path = hf_hub_download(MODEL_REPO, "best_model.pt")
tok_path = hf_hub_download(MODEL_REPO, "tokenizer_shared_16k.json")
tokenizer = Tokenizer.from_file(tok_path)
model = TransformerTranslator(**MODEL_KWARGS).to(DEVICE)
# weights_only=True restricts unpickling to tensors (safer checkpoint load).
state = torch.load(weights_path, map_location=DEVICE, weights_only=True)
model.load_state_dict(state)
model.eval()  # disable dropout for deterministic inference
print("Model loaded!")
# ──────────────────────────────────────────────────────────────────────
# Gradio UI
# ──────────────────────────────────────────────────────────────────────
# Example sentences shown under the input box; each is wrapped in a list
# because gr.Examples expects one inner list per input component.
EXAMPLES = [
    ["Hello, how are you?"],
    ["What time is the meeting tomorrow?"],
    ["I love learning new languages."],
    ["The weather is beautiful today."],
    ["Can you help me find the nearest hospital?"],
    ["Thank you very much for your help."],
    ["She may be dying and it's all my fault."],
    ["We'll be ready for the shipment."],
]
with gr.Blocks(
    title="English β Malay Translator",
    theme=gr.themes.Soft(),
) as demo:
    # Header / model-card summary.
    gr.Markdown(
        """
# π¬π§ β π²πΎ English to Malay Translator
A custom **6+2 Tied Transformer** trained from scratch on 2M OpenSubtitles sentence pairs.
**Model:** [AstralPotato/en-ms-transformer](https://huggingface.co/AstralPotato/en-ms-transformer)
| **chrF:** 52.14 (case-normalized, cleaned, beam=5)
| **Params:** ~27M
"""
    )
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="English",
                placeholder="Type an English sentence...",
                lines=3,
            )
            with gr.Row():
                method = gr.Radio(
                    choices=["Greedy", "Beam Search"],
                    value="Greedy",
                    label="Decoding Method",
                )
                # NOTE(review): the slider is always visible — nothing
                # toggles `visible` when "Greedy" is selected; confirm
                # whether hiding it for Greedy was intended.
                beam_width = gr.Slider(
                    minimum=2, maximum=10, value=5, step=1,
                    label="Beam Width",
                    visible=True,
                )
            translate_btn = gr.Button("Translate", variant="primary")
        with gr.Column():
            output_text = gr.Textbox(
                label="Malay",
                lines=3,
                interactive=False,
            )
    # cache_examples=True runs fn once per example at startup and stores
    # the outputs, so clicking an example is instant on CPU.
    gr.Examples(
        examples=EXAMPLES,
        inputs=[input_text],
        outputs=[output_text],
        fn=lambda text: translate(text, "Greedy"),
        cache_examples=True,
    )
    # Collapsible "about" section.
    gr.Markdown(
        """
---
<details>
<summary>βΉοΈ About this model</summary>
- **Architecture:** 6-layer encoder + 2-layer decoder, pre-norm Transformer
- **Tokenizer:** 16K shared BPE
- **Training data:** 2M filtered OpenSubtitles en-ms pairs
- **Trained for:** IT3103 Advanced Topics in AI β Assignment 2, 2025 Semester 2
The model outputs **lowercase** text (tokenizer normalizes to lowercase), with post-processing to capitalize the first letter.
Beam search is slower but may produce higher-quality translations.
</details>
"""
    )
    # Wire up events: button click and Enter-in-textbox run the same fn
    # with the same inputs/outputs.
    translate_btn.click(
        fn=translate,
        inputs=[input_text, method, beam_width],
        outputs=[output_text],
    )
    input_text.submit(
        fn=translate,
        inputs=[input_text, method, beam_width],
        outputs=[output_text],
    )
demo.launch()