# AstralPotato's picture
# Initial Gradio demo: English → Malay live inference
# 655f379 verified
"""
English β†’ Malay Translator β€” Gradio Interface
================================================
A live inference demo for the 6+2 Tied Transformer (16K shared BPE).
Downloads model weights and tokenizer from the HuggingFace model repo
on startup, then serves translations via a Gradio UI.
"""
import math
import re
from typing import Optional
import torch
import torch.nn as nn
import gradio as gr
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer
# ──────────────────────────────────────────────────────────────────────
# Config
# ──────────────────────────────────────────────────────────────────────
MODEL_REPO = "AstralPotato/en-ms-transformer"
DEVICE = torch.device("cpu") # Spaces free tier = CPU only
MODEL_KWARGS = dict(
vocab_size=16000, d_model=512, n_head=8,
num_encoder_layers=6, num_decoder_layers=2,
d_ff=2048, dropout=0.0, max_len=144, pad_idx=0,
)
BOS_ID, EOS_ID, PAD_ID = 5, 6, 0
MAX_DECODE_LEN = 128
# ──────────────────────────────────────────────────────────────────────
# Model architecture (inlined to avoid import issues on Spaces)
# ──────────────────────────────────────────────────────────────────────
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.pe[:, : x.size(1)]
return self.dropout(x)
class TransformerTranslator(nn.Module):
def __init__(
self,
vocab_size: int,
d_model: int = 512,
n_head: int = 8,
num_encoder_layers: int = 6,
num_decoder_layers: int = 2,
d_ff: int = 2048,
dropout: float = 0.1,
max_len: int = 512,
pad_idx: int = 0,
):
super().__init__()
self.pad_idx = pad_idx
self.d_model = d_model
self.shared_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
self.embed_scale = math.sqrt(d_model)
self.transformer = nn.Transformer(
d_model=d_model, nhead=n_head,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=d_ff, dropout=dropout,
batch_first=True, norm_first=True,
)
self.output_bias = nn.Parameter(torch.zeros(vocab_size))
nn.init.normal_(self.shared_embedding.weight, mean=0, std=d_model ** -0.5)
with torch.no_grad():
self.shared_embedding.weight[pad_idx].zero_()
def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
return self.pos_encoding(self.shared_embedding(tokens) * self.embed_scale)
@staticmethod
def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
return torch.triu(torch.ones(sz, sz, device=device, dtype=torch.bool), diagonal=1)
def _make_pad_mask(self, x: torch.Tensor) -> torch.Tensor:
return x == self.pad_idx
def encode(self, src, src_key_padding_mask=None):
if src_key_padding_mask is None:
src_key_padding_mask = self._make_pad_mask(src)
return self.transformer.encoder(self._embed(src), src_key_padding_mask=src_key_padding_mask)
def decode(self, tgt, memory, tgt_key_padding_mask=None, memory_key_padding_mask=None):
if tgt_key_padding_mask is None:
tgt_key_padding_mask = self._make_pad_mask(tgt)
tgt_mask = self.generate_square_subsequent_mask(tgt.size(1), tgt.device)
out = self.transformer.decoder(
self._embed(tgt), memory,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
)
return torch.nn.functional.linear(out, self.shared_embedding.weight, self.output_bias)
# ──────────────────────────────────────────────────────────────────────
# Decoding & post-processing
# ──────────────────────────────────────────────────────────────────────
def postprocess(text: str) -> str:
text = re.sub(r'\s+([.,?!;:)\]}"\'…])', r'\1', text)
text = re.sub(r'([(\[{"\'])\s+', r'\1', text)
text = re.sub(r'\s*-\s*', '-', text)
text = re.sub(r'\s{2,}', ' ', text)
text = text.strip()
if text:
text = text[0].upper() + text[1:]
return text
@torch.no_grad()
def greedy_decode(model, src, max_len=MAX_DECODE_LEN):
src_pad_mask = (src == PAD_ID)
memory = model.encode(src, src_key_padding_mask=src_pad_mask)
ys = torch.tensor([[BOS_ID]], dtype=torch.long, device=DEVICE)
for _ in range(max_len - 1):
logits = model.decode(ys, memory, memory_key_padding_mask=src_pad_mask)
next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
ys = torch.cat([ys, next_token], dim=1)
if next_token.item() == EOS_ID:
break
return ys
@torch.no_grad()
def beam_search_decode(model, src, beam_width=5, length_penalty=0.6, max_len=MAX_DECODE_LEN):
src_pad_mask = (src == PAD_ID)
memory = model.encode(src, src_key_padding_mask=src_pad_mask)
beams = [(0.0, [BOS_ID])]
completed = []
for _ in range(max_len - 1):
candidates = []
for score, tokens in beams:
if tokens[-1] == EOS_ID:
completed.append((score, tokens))
continue
ys = torch.tensor([tokens], dtype=torch.long, device=DEVICE)
logits = model.decode(ys, memory, memory_key_padding_mask=src_pad_mask)
log_probs = torch.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)
topk_lp, topk_ids = log_probs.topk(beam_width)
for k in range(beam_width):
candidates.append((score + topk_lp[k].item(), tokens + [topk_ids[k].item()]))
if not candidates:
break
candidates.sort(key=lambda x: x[0] / (len(x[1]) ** length_penalty), reverse=True)
beams = candidates[:beam_width]
if all(b[1][-1] == EOS_ID for b in beams):
completed.extend(beams)
break
completed.extend(beams)
best = max(completed, key=lambda x: x[0] / (len(x[1]) ** length_penalty))
return torch.tensor([best[1]], dtype=torch.long, device=DEVICE)
def translate(text: str, method: str = "Greedy", beam_width: int = 5) -> str:
if not text or not text.strip():
return ""
src_ids = tokenizer.encode(text).ids
src = torch.tensor([src_ids], dtype=torch.long, device=DEVICE)
if method == "Beam Search":
out_ids = beam_search_decode(model, src, beam_width=beam_width)
else:
out_ids = greedy_decode(model, src)
raw = tokenizer.decode(out_ids.squeeze(0).tolist(), skip_special_tokens=True)
return postprocess(raw)
# ──────────────────────────────────────────────────────────────────────
# Load model & tokenizer on startup
# ──────────────────────────────────────────────────────────────────────
print("Downloading model and tokenizer...")
weights_path = hf_hub_download(MODEL_REPO, "best_model.pt")
tok_path = hf_hub_download(MODEL_REPO, "tokenizer_shared_16k.json")
tokenizer = Tokenizer.from_file(tok_path)
model = TransformerTranslator(**MODEL_KWARGS).to(DEVICE)
state = torch.load(weights_path, map_location=DEVICE, weights_only=True)
model.load_state_dict(state)
model.eval()
print("Model loaded!")
# ──────────────────────────────────────────────────────────────────────
# Gradio UI
# ──────────────────────────────────────────────────────────────────────
EXAMPLES = [
["Hello, how are you?"],
["What time is the meeting tomorrow?"],
["I love learning new languages."],
["The weather is beautiful today."],
["Can you help me find the nearest hospital?"],
["Thank you very much for your help."],
["She may be dying and it's all my fault."],
["We'll be ready for the shipment."],
]
with gr.Blocks(
title="English β†’ Malay Translator",
theme=gr.themes.Soft(),
) as demo:
gr.Markdown(
"""
# πŸ‡¬πŸ‡§ β†’ πŸ‡²πŸ‡Ύ English to Malay Translator
A custom **6+2 Tied Transformer** trained from scratch on 2M OpenSubtitles sentence pairs.
**Model:** [AstralPotato/en-ms-transformer](https://huggingface.co/AstralPotato/en-ms-transformer)
 |  **chrF:** 52.14 (case-normalized, cleaned, beam=5)
 |  **Params:** ~27M
"""
)
with gr.Row():
with gr.Column():
input_text = gr.Textbox(
label="English",
placeholder="Type an English sentence...",
lines=3,
)
with gr.Row():
method = gr.Radio(
choices=["Greedy", "Beam Search"],
value="Greedy",
label="Decoding Method",
)
beam_width = gr.Slider(
minimum=2, maximum=10, value=5, step=1,
label="Beam Width",
visible=True,
)
translate_btn = gr.Button("Translate", variant="primary")
with gr.Column():
output_text = gr.Textbox(
label="Malay",
lines=3,
interactive=False,
)
gr.Examples(
examples=EXAMPLES,
inputs=[input_text],
outputs=[output_text],
fn=lambda text: translate(text, "Greedy"),
cache_examples=True,
)
gr.Markdown(
"""
---
<details>
<summary>ℹ️ About this model</summary>
- **Architecture:** 6-layer encoder + 2-layer decoder, pre-norm Transformer
- **Tokenizer:** 16K shared BPE
- **Training data:** 2M filtered OpenSubtitles en-ms pairs
- **Trained for:** IT3103 Advanced Topics in AI β€” Assignment 2, 2025 Semester 2
The model outputs **lowercase** text (tokenizer normalizes to lowercase), with post-processing to capitalize the first letter.
Beam search is slower but may produce higher-quality translations.
</details>
"""
)
# Wire up events
translate_btn.click(
fn=translate,
inputs=[input_text, method, beam_width],
outputs=[output_text],
)
input_text.submit(
fn=translate,
inputs=[input_text, method, beam_width],
outputs=[output_text],
)
demo.launch()