# Tiny-Translator / src/sample.py
# Uploaded by caixiaoshun (commit 5153277, "Upload 6 files")
import argparse
from tokenizers import Tokenizer
import torch
from src.config import Config
from src.model import TranslateModel
def get_args(argv=None):
    """Parse command-line arguments for the translation demo.

    Args:
        argv: Optional list of argument strings. When None (the default,
            preserving the original no-argument call), argparse falls back
            to ``sys.argv[1:]``. Passing an explicit list makes the parser
            testable without touching global state.

    Returns:
        argparse.Namespace with ``ckpt_path`` and ``zh`` attributes.
    """
    parser = argparse.ArgumentParser(description="Translate a Chinese sentence to English.")
    parser.add_argument(
        "--ckpt_path",
        default="checkpoints/translate-step=290000.ckpt",
        help="path to the trained Lightning checkpoint",
    )
    parser.add_argument("--zh", default="早上好", help="Chinese source sentence to translate")
    return parser.parse_args(argv)
class Inference:
def __init__(self,config:Config, ckpt_path):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.tokenizer:Tokenizer = Tokenizer.from_file(config.tokenizer_file)
self.model:TranslateModel = TranslateModel(config)
ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
state_dict = {}
for k, v in ckpt.items():
new_k = k[len("net._orig_mod."):]
state_dict[new_k] = v
self.model.load_state_dict(state_dict, strict=True)
self.model.eval()
self.model = self.model.to(self.device)
self.config = config
@torch.no_grad()
def sampler(self, src:str)->str:
src = self.tokenizer.encode(src).ids
tgt = [self.tokenizer.token_to_id("[SOS]")]
max_len = self.config.max_len
EOS = self.tokenizer.token_to_id("[EOS]")
src = torch.tensor(src, dtype=torch.long).to(self.device).unsqueeze(0)
tgt = torch.tensor(tgt, dtype=torch.long).to(self.device).unsqueeze(0)
for _ in range(1, max_len):
logits = self.model.forward(src, tgt) # [1, len, vocab]
logits = logits[:,-1,:]
logits = torch.softmax(logits, dim=-1)
index = torch.argmax(logits, dim=-1)
tgt = torch.cat((tgt, index.unsqueeze(0)), dim=-1)
if index.detach().cpu().item() == EOS:
break
tgt = tgt.detach().cpu().squeeze(0).tolist()
tgt_str = self.tokenizer.decode(tgt)
return tgt_str
def main():
    """CLI entry point: translate one Chinese sentence and print the pair."""
    args = get_args()
    inference = Inference(Config(), args.ckpt_path)
    translation = inference.sampler(args.zh)
    print(f"中文:{args.zh}")
    print(f"English:{translation}")


if __name__ == "__main__":
    main()