File size: 378 Bytes
8faa42b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 | import torch
import torch.nn as nn
from model import TransliterationTransformer
class HFTransliterator(nn.Module):
    """``nn.Module`` wrapper that owns a ``TransliterationTransformer``.

    The wrapped transformer is built from two vocabulary sizes read out of
    ``config`` and is stored as the submodule ``self.model``. ``forward``
    delegates straight to it without any extra processing.
    """

    def __init__(self, config) -> None:
        """Construct the wrapped transformer.

        Args:
            config: Subscriptable mapping that must contain the keys
                ``"src_vocab_size"`` and ``"tgt_vocab_size"``. Their values
                are passed positionally, in that order, to
                ``TransliterationTransformer``. A missing key surfaces as the
                mapping's own lookup error (``KeyError`` for a ``dict``).
        """
        super().__init__()
        # Read the source size first, then the target size, matching the
        # positional order the transformer constructor receives them in.
        source_vocab = config["src_vocab_size"]
        target_vocab = config["tgt_vocab_size"]
        self.model = TransliterationTransformer(source_vocab, target_vocab)

    def forward(self, src, tgt):
        """Run the wrapped transformer on ``src`` and ``tgt``.

        Args:
            src: Source-side input, handed to ``self.model`` unchanged.
            tgt: Target-side input, handed to ``self.model`` unchanged.

        Returns:
            Whatever ``self.model(src, tgt)`` returns. Its type and shape are
            defined by ``TransliterationTransformer`` (not visible here).
        """
        output = self.model(src, tgt)
        return output