|
|
import json |
|
|
import re |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import gradio as gr |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EncoderGRU(nn.Module):
    """Character-level GRU encoder.

    Embeds source-token ids and runs them through a uni-directional GRU;
    the final hidden state summarizes the word for the decoder.
    """

    def __init__(self, input_dim, emb_dim, hid_dim, num_layers=1, dropout=0.1, pad_idx=0):
        super().__init__()
        # pad_idx keeps the padding embedding frozen at zero.
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=pad_idx)
        self.gru = nn.GRU(
            emb_dim, hid_dim, num_layers=num_layers, batch_first=True, bidirectional=False
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # src: (batch, src_len) int64 ids — TODO confirm with caller.
        # Returns (outputs, hidden) straight from the GRU.
        emb = self.dropout(self.embedding(src))
        return self.gru(emb)
|
|
|
|
|
|
|
|
class DecoderGRU(nn.Module):
    """Single-step GRU decoder over target characters.

    Each forward call consumes one token per batch element and returns
    un-normalized vocabulary logits plus the updated recurrent state.
    """

    def __init__(self, output_dim, emb_dim, hid_dim, num_layers=1, dropout=0.1, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim, padding_idx=pad_idx)
        self.gru = nn.GRU(emb_dim, hid_dim, num_layers=num_layers, batch_first=True)
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden):
        # input: (batch,) ids for the current step; hidden: prior GRU state.
        step = input.unsqueeze(1)                 # (batch, 1)
        emb = self.dropout(self.embedding(step))  # (batch, 1, emb_dim)
        out, new_hidden = self.gru(emb, hidden)
        logits = self.fc_out(out.squeeze(1))      # (batch, output_dim)
        return logits, new_hidden
|
|
|
|
|
|
|
|
class Seq2Seq(nn.Module):
    """Container tying an encoder and a decoder together with vocab indices.

    The training-style forward pass is intentionally disabled; inference in
    this file is driven externally (greedy decoding, word by word).
    """

    def __init__(self, encoder, decoder, pad_idx, sos_idx, eos_idx, device):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder
        self.pad_idx, self.sos_idx, self.eos_idx = pad_idx, sos_idx, eos_idx
        self.device = device

    def forward(self, src, src_lens, tgt=None, teacher_forcing_ratio=0.5):
        # Deliberately unsupported: this module only carries components and
        # indices; decoding happens outside the module.
        raise NotImplementedError("Use transliterate_word for inference.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Runtime setup: pick a device and fetch model artifacts from the Hugging
# Face Hub (character vocabularies + trained checkpoint).
# ---------------------------------------------------------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hub repository hosting the vocab JSON files and the GRU checkpoint.
MODEL_REPO = "harishwar017/hindi-roman-gru"

# hf_hub_download caches locally and returns the on-disk file path.
src_json_path = hf_hub_download(repo_id=MODEL_REPO, filename="src_stoi.json")
tgt_json_path = hf_hub_download(repo_id=MODEL_REPO, filename="tgt_stoi.json")
model_path = hf_hub_download(repo_id=MODEL_REPO, filename="best_hindi_roman_gru.pt")

# Source (Devanagari) character -> id map.
with open(src_json_path, "r", encoding="utf-8") as f:
    src_stoi = json.load(f)

# Target (Roman) character -> id map.
with open(tgt_json_path, "r", encoding="utf-8") as f:
    tgt_stoi = json.load(f)

# Inverse target map (id -> character) used when decoding model output.
# int() cast because JSON values may deserialize as strings.
tgt_itos = {int(v): k for k, v in tgt_stoi.items()}

# Special tokens; their indices are read from the *target* vocabulary.
PAD_TOKEN = "<pad>"
SOS_TOKEN = "<sos>"
EOS_TOKEN = "<eos>"

# NOTE(review): PAD_IDX is also passed as padding_idx to the *encoder*
# below, which assumes src and tgt vocabs agree on the <pad> index —
# confirm against how the checkpoint was trained.
PAD_IDX = tgt_stoi[PAD_TOKEN]
SOS_IDX = tgt_stoi[SOS_TOKEN]
EOS_IDX = tgt_stoi[EOS_TOKEN]

# Vocabulary sizes drive the embedding / output-layer dimensions.
INPUT_DIM = len(src_stoi)
OUTPUT_DIM = len(tgt_stoi)
|
|
|
|
|
# Hyperparameters; these must match the values used when the checkpoint was
# trained, or load_state_dict below will fail on a shape mismatch.
ENC_EMB_DIM = 128
DEC_EMB_DIM = 128
HID_DIM = 256
NUM_LAYERS = 1
ENC_DROPOUT = 0.2
DEC_DROPOUT = 0.2

encoder = EncoderGRU(
    input_dim=INPUT_DIM,
    emb_dim=ENC_EMB_DIM,
    hid_dim=HID_DIM,
    num_layers=NUM_LAYERS,
    dropout=ENC_DROPOUT,
    pad_idx=PAD_IDX,
)

decoder = DecoderGRU(
    output_dim=OUTPUT_DIM,
    emb_dim=DEC_EMB_DIM,
    hid_dim=HID_DIM,
    num_layers=NUM_LAYERS,
    dropout=DEC_DROPOUT,
    pad_idx=PAD_IDX,
)

model = Seq2Seq(
    encoder=encoder,
    decoder=decoder,
    pad_idx=PAD_IDX,
    sos_idx=SOS_IDX,
    eos_idx=EOS_IDX,
    device=device,
).to(device)

# Load the trained weights and switch to inference mode (disables dropout).
# NOTE(review): consider torch.load(..., weights_only=True) so the
# downloaded checkpoint cannot unpickle arbitrary objects — confirm the
# file is a plain state_dict first.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def indices_to_string(indices):
    """Convert decoder output ids to text, stopping at the first EOS or PAD."""
    pieces = []
    for token_id in indices:
        if token_id in (EOS_IDX, PAD_IDX):
            break
        pieces.append(tgt_itos[token_id])
    return "".join(pieces)
|
|
|
|
|
|
|
|
@torch.no_grad()
def transliterate_word(word: str, max_len: int = 30) -> str:
    """Greedily transliterate a single word with the loaded seq2seq model.

    Characters absent from the source vocabulary are silently skipped;
    returns "" when nothing in ``word`` is encodable.

    Args:
        word: a single source-script (Hindi) word.
        max_len: hard cap on generated characters (guards against loops).

    Returns:
        The romanized word as a plain string.
    """
    # Encode only the characters the model knows about.
    src_ids = [src_stoi[ch] for ch in word if ch in src_stoi]
    if not src_ids:
        return ""

    # (1, src_len) batch of ids on the model's device.
    src_tensor = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(0)

    # Encoder output sequence is unused; greedy decoding conditions only on
    # the final hidden state.  (Removed an unused src_lens tensor here.)
    _, hidden = model.encoder(src_tensor)

    # Greedy decode: start from <sos>, feed back the argmax token each step.
    input_token = torch.tensor([SOS_IDX], dtype=torch.long, device=device)
    decoded_indices = []

    for _ in range(max_len):
        output, hidden = model.decoder(input_token, hidden)
        top1 = output.argmax(1)
        idx = top1.item()
        if idx == EOS_IDX:
            break
        decoded_indices.append(idx)
        input_token = top1

    return indices_to_string(decoded_indices)
|
|
|
|
|
|
|
|
def tokenize_with_punct(text: str):
    """Split text into word tokens and individual non-space symbols."""
    token_pattern = re.compile(r'\w+|\S', re.UNICODE)
    return token_pattern.findall(text)
|
|
|
|
|
|
|
|
def is_punctuation_token(tok: str) -> bool:
    """Return True when no character of ``tok`` is alphanumeric."""
    return not any(ch.isalnum() for ch in tok)
|
|
|
|
|
|
|
|
def map_punctuation(tok: str) -> str:
    """Map the Devanagari danda to a Latin full stop; pass others through."""
    return "." if tok == "।" else tok
|
|
|
|
|
import regex as re |
|
|
|
|
|
def tokenize_with_correct_unicode(text: str):
    r"""Split ``text`` into word tokens and single punctuation characters.

    The character class combines \w with Unicode letter (\p{L}),
    combining-mark (\p{M}) and number (\p{N}) properties so Devanagari
    matras stay attached to their word.  The \p{...} escapes require the
    third-party ``regex`` module, which this file binds to the name ``re``.
    """
    word_or_symbol = r'[\w\p{L}\p{M}\p{N}]+|\S'
    return re.findall(word_or_symbol, text, flags=re.UNICODE)
|
|
|
|
|
|
|
|
def transliterate_sentence(sentence: str, max_word_len: int = 30) -> str:
    """Transliterate a whole sentence, keeping punctuation in place.

    Tokenizes the input, transliterates word tokens with the GRU model,
    maps punctuation (danda -> "."), and rejoins with sensible spacing.
    """
    if not sentence.strip():
        return ""

    pieces = []
    for tok in tokenize_with_correct_unicode(sentence):
        if is_punctuation_token(tok):
            pieces.append(map_punctuation(tok))
        else:
            pieces.append(transliterate_word(tok, max_len=max_word_len))

    # Rejoin: closing marks hug the previous token, and no space is
    # inserted right after an opening bracket/quote.
    no_space_before = (".", ",", "!", "?", ";", ":", ")", "”")
    no_space_after = ("(", "“")
    result = ""
    for position, piece in enumerate(pieces):
        if position == 0 or piece in no_space_before:
            result += piece
        elif result and result[-1] in no_space_after:
            result += piece
        else:
            result += " " + piece
    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradio_fn(text):
    """Gradio callback: transliterate the full contents of the input box."""
    romanized = transliterate_sentence(text)
    return romanized
|
|
|
|
|
# Gradio UI: one text box in, one text box out, wired to gradio_fn.
demo = gr.Interface(
    fn=gradio_fn,
    inputs=gr.Textbox(lines=3, label="Hindi sentence"),
    outputs=gr.Textbox(lines=3, label="Romanized (Latin script)"),
    title="Hindi → Roman Transliteration (Char-level GRU)",
    description="Paste a Hindi sentence; the model splits it into words, transliterates each with a GRU, and rejoins the output.",
)

# Launch the web app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()
|
|
|