# Hindi -> Roman transliteration Gradio Space
# (author: harishwar017; upstream commit: 3c5c50b, message: "minor")
import json
import re
import torch
import torch.nn as nn
import gradio as gr
from huggingface_hub import hf_hub_download
########################################
# Model definitions (same as in notebook)
########################################
class EncoderGRU(nn.Module):
    """Character-level encoder: embedding + unidirectional GRU.

    Consumes a padded batch of source-character indices and produces the
    per-step GRU outputs plus the final hidden state used to seed decoding.
    """

    def __init__(self, input_dim, emb_dim, hid_dim, num_layers=1, dropout=0.1, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=pad_idx)
        self.gru = nn.GRU(
            emb_dim,
            hid_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """src: (B, src_len) -> (outputs (B, src_len, H), hidden (layers, B, H))."""
        emb = self.dropout(self.embedding(src))
        outputs, hidden = self.gru(emb)
        # Only `hidden` is consumed downstream; outputs returned for completeness.
        return outputs, hidden
class DecoderGRU(nn.Module):
    """Single-step character decoder: embedding + GRU + vocab projection.

    Called one token at a time during greedy decoding; carries the GRU
    hidden state between calls.
    """

    def __init__(self, output_dim, emb_dim, hid_dim, num_layers=1, dropout=0.1, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim, padding_idx=pad_idx)
        self.gru = nn.GRU(
            emb_dim,
            hid_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden):
        """input: (B,) token ids; hidden: (layers, B, H) -> (logits (B, V), hidden)."""
        step = input.unsqueeze(1)                       # (B, 1): one decoding step
        emb = self.dropout(self.embedding(step))        # (B, 1, emb_dim)
        out, hidden = self.gru(emb, hidden)             # (B, 1, H)
        logits = self.fc_out(out.squeeze(1))            # (B, vocab_size)
        return logits, hidden
class Seq2Seq(nn.Module):
    """Container bundling the encoder, decoder, special-token indices and device.

    Inference in this app goes through `transliterate_word`, which drives the
    encoder/decoder directly; `forward` (training loop) lives in the notebook.
    """

    def __init__(self, encoder, decoder, pad_idx, sos_idx, eos_idx, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx
        self.sos_idx = sos_idx
        self.eos_idx = eos_idx
        self.device = device

    def forward(self, src, src_lens, tgt=None, teacher_forcing_ratio=0.5):
        # Deliberately unimplemented here; training logic stays in the notebook.
        raise NotImplementedError("Use transliterate_word for inference.")
########################################
# Load vocab + model
########################################
# NOTE: everything below runs at import time (standard for a Gradio Space):
# it downloads the trained artifacts from the HF Hub and builds the model once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 🔴 CHANGE THIS: your actual model repo id
MODEL_REPO = "harishwar017/hindi-roman-gru"
# Download files from HF Hub into the Space’s local cache
src_json_path = hf_hub_download(repo_id=MODEL_REPO, filename="src_stoi.json")
tgt_json_path = hf_hub_download(repo_id=MODEL_REPO, filename="tgt_stoi.json")
model_path = hf_hub_download(repo_id=MODEL_REPO, filename="best_hindi_roman_gru.pt")
# Load vocabularies (char -> index maps produced by the training notebook)
with open(src_json_path, "r", encoding="utf-8") as f:
    src_stoi = json.load(f)
with open(tgt_json_path, "r", encoding="utf-8") as f:
    tgt_stoi = json.load(f)
# Build inverse mapping for target
tgt_itos = {int(v): k for k, v in tgt_stoi.items()}  # keys might be strings in JSON
PAD_TOKEN = "<pad>"
SOS_TOKEN = "<sos>"
EOS_TOKEN = "<eos>"
# NOTE(review): these indices come from the TARGET vocab but are also used as
# padding_idx for the SOURCE embedding below — assumes both vocabs assign the
# same indices to the special tokens; verify against the training notebook.
PAD_IDX = tgt_stoi[PAD_TOKEN]
SOS_IDX = tgt_stoi[SOS_TOKEN]
EOS_IDX = tgt_stoi[EOS_TOKEN]
INPUT_DIM = len(src_stoi)
OUTPUT_DIM = len(tgt_stoi)
# Architecture hyperparameters — must match the checkpoint being loaded.
ENC_EMB_DIM = 128  # must match training
DEC_EMB_DIM = 128  # must match training
HID_DIM = 256  # must match training
NUM_LAYERS = 1
ENC_DROPOUT = 0.2
DEC_DROPOUT = 0.2
encoder = EncoderGRU(
    input_dim=INPUT_DIM,
    emb_dim=ENC_EMB_DIM,
    hid_dim=HID_DIM,
    num_layers=NUM_LAYERS,
    dropout=ENC_DROPOUT,
    pad_idx=PAD_IDX,
)
decoder = DecoderGRU(
    output_dim=OUTPUT_DIM,
    emb_dim=DEC_EMB_DIM,
    hid_dim=HID_DIM,
    num_layers=NUM_LAYERS,
    dropout=DEC_DROPOUT,
    pad_idx=PAD_IDX,
)
model = Seq2Seq(
    encoder=encoder,
    decoder=decoder,
    pad_idx=PAD_IDX,
    sos_idx=SOS_IDX,
    eos_idx=EOS_IDX,
    device=device,
).to(device)
# Load weights that you saved from training: torch.save(model.state_dict(), "best_hindi_roman_gru.pt")
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()  # inference only — disables dropout
########################################
# Inference helpers
########################################
def indices_to_string(indices):
    """Map decoder output indices back to characters, stopping at EOS or PAD."""
    pieces = []
    for i in indices:
        if i in (EOS_IDX, PAD_IDX):
            break
        pieces.append(tgt_itos[i])
    return "".join(pieces)
@torch.no_grad()
def transliterate_word(word: str, max_len: int = 30) -> str:
    """Greedily decode one Hindi word into Roman script.

    Characters absent from the source vocab are silently dropped; returns ""
    when no character of `word` is known. Decoding stops at EOS or after
    `max_len` steps.
    """
    src_ids = [src_stoi[ch] for ch in word if ch in src_stoi]
    if not src_ids:
        return ""
    src = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(0)  # (1, L)
    # Encode the whole word; only the final hidden state seeds the decoder.
    _, hidden = model.encoder(src)
    # Greedy decode, one character per step, starting from <sos>.
    token = torch.tensor([SOS_IDX], dtype=torch.long, device=device)
    decoded = []
    for _ in range(max_len):
        logits, hidden = model.decoder(token, hidden)  # (1, vocab)
        token = logits.argmax(1)                       # (1,) best next char
        best = token.item()
        if best == EOS_IDX:
            break
        decoded.append(best)
    return indices_to_string(decoded)
def tokenize_with_punct(text: str):
    """Split text into word tokens and individual punctuation/symbol tokens."""
    pattern = re.compile(r'\w+|\S', flags=re.UNICODE)
    return pattern.findall(text)
def is_punctuation_token(tok: str) -> bool:
    """True when the token has no alphanumeric character (pure punctuation).

    Note: vacuously True for the empty string.
    """
    return not any(ch.isalnum() for ch in tok)
def map_punctuation(tok: str) -> str:
    """Replace the Devanagari danda with an ASCII full stop; pass others through."""
    return "." if tok == "।" else tok
import regex as re
def tokenize_with_correct_unicode(text: str):
    """Tokenize text into word tokens (Devanagari matras included) and punctuation.

    Groups Letters (\\p{L}), combining Marks (\\p{M}) and Numbers (\\p{N}) into a
    single token so vowel signs stay attached to their consonant; anything else
    that is not whitespace becomes its own token. Requires the third-party
    `regex` module (bound to `re` above) for the \\p{...} property classes.
    """
    token_re = re.compile(r'[\w\p{L}\p{M}\p{N}]+|\S', flags=re.UNICODE)
    return token_re.findall(text)
def transliterate_sentence(sentence: str, max_word_len: int = 30) -> str:
    """Transliterate a Hindi sentence word-by-word and rejoin the pieces.

    Punctuation tokens are mapped (danda -> '.') and re-attached without a
    leading space; opening brackets/quotes get no trailing space.
    """
    if not sentence.strip():
        return ""
    pieces = []
    for tok in tokenize_with_correct_unicode(sentence):
        if is_punctuation_token(tok):
            pieces.append(map_punctuation(tok))
        else:
            pieces.append(transliterate_word(tok, max_len=max_word_len))
    # Simple detokenizer: space before words, none around attaching punctuation.
    closing = (".", ",", "!", "?", ";", ":", ")", "”")
    opening = ("(", "“")
    result = ""
    for i, tok in enumerate(pieces):
        if i == 0:
            result = tok
        elif tok in closing:
            result += tok
        elif result and result[-1] in opening:
            result += tok
        else:
            result += " " + tok
    return result
########################################
# Gradio Interface
########################################
def gradio_fn(text):
    # Thin adapter: Gradio passes the textbox value straight through.
    return transliterate_sentence(text)
demo = gr.Interface(
    fn=gradio_fn,
    inputs=gr.Textbox(lines=3, label="Hindi sentence"),
    outputs=gr.Textbox(lines=3, label="Romanized (Latin script)"),
    title="Hindi → Roman Transliteration (Char-level GRU)",
    description="Paste a Hindi sentence; the model splits it into words, transliterates each with a GRU, and rejoins the output.",
)
# Launch only when run as a script; HF Spaces imports `demo` directly.
if __name__ == "__main__":
    demo.launch()