# Trainin_Transformers / inference.py
# (Hugging Face repo header residue — commit 39a7504, "Added models and code"
#  by Urfavghost — converted to comments so the file parses as Python.)
from pathlib import Path
from model import build_transformer
from util import create_resources
import torch
import sys
import yaml
import sacrebleu
def translate(sentence: str) -> str:
    """Greedily translate ``sentence`` into the target language.

    Loads hyper-parameters from ``config.yaml``, rebuilds the transformer,
    restores a saved checkpoint, then decodes one token at a time until
    ``[EOS]`` is produced or ``seq_len`` tokens have been generated.

    Args:
        sentence: Raw source-language text.

    Returns:
        The decoded target-language string (detokenized).
    """
    with open("config.yaml", "r") as file:
        config = yaml.safe_load(file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Only the tokenizers are needed for inference; the dataloaders returned
    # by create_resources() are unused here.
    train_dataloader, valid_dataloader, test_dataloader, tokenizer_src, tokenizer_tgt = create_resources()
    src_vocab_size = tokenizer_src.get_vocab_size()
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()
    model = build_transformer(
        src_vocab_size,
        tgt_vocab_size,
        config["seq_len"],
        config["seq_len"],
        config["num_enc_dec_blocks"],
        config["num_of_heads"],
        config["d_model"]
    )
    model = model.to(device)

    model_filename = "models/model_epoch_15.pth"
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    # (the original call crashed there with a CUDA deserialization error).
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    model.eval()

    with torch.no_grad():
        source = tokenizer_src.encode(sentence)
        # Truncate so [SOS] + tokens + [EOS] fits in seq_len.  The original
        # computed a negative pad count for long inputs, silently emitting a
        # sequence longer than the model's positional encoding supports.
        ids = source.ids[: config["seq_len"] - 2]
        pad_len = config["seq_len"] - len(ids) - 2
        source = torch.cat([
            torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64),
            torch.tensor(ids, dtype=torch.int64),
            torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64),
            torch.tensor([tokenizer_src.token_to_id('[PAD]')] * pad_len, dtype=torch.int64)
        ], dim=0)
        source = source.unsqueeze(0).to(device)  # shape (1, seq_len)
        # 1 where the position holds a real token, 0 at padding.
        source_mask = (source != tokenizer_src.token_to_id('[PAD]')).unsqueeze(0).unsqueeze(0).int().to(device)
        encoder_output = model.encode(source, source_mask)

        decoder_input = torch.full((1, 1), tokenizer_tgt.token_to_id('[SOS]'),
                                   dtype=torch.long, device=device)
        while decoder_input.size(1) < config["seq_len"]:
            # NOTE(review): triu(diagonal=1) places 1s on *future* positions
            # only — the opposite of the source_mask convention above
            # (1 = visible).  Kept as-is because the trained model presumably
            # expects this exact mask; confirm against model.decode's
            # masked_fill convention before changing.
            decoder_mask = torch.triu(torch.ones((1, decoder_input.size(1), decoder_input.size(1))),
                                      diagonal=1).to(device, dtype=torch.int)
            out = model.decode(decoder_input, encoder_output, source_mask, decoder_mask)
            prob = model.project(out[:, -1])        # logits for the next token
            _, next_word = torch.max(prob, dim=1)   # greedy argmax decoding
            next_token = torch.full((1, 1), next_word.item(), dtype=torch.long, device=device)
            decoder_input = torch.cat([decoder_input, next_token], dim=1)
            print(f"{tokenizer_tgt.decode([next_word.item()])}", end=' ')
            if next_word.item() == tokenizer_tgt.token_to_id('[EOS]'):
                break
        return tokenizer_tgt.decode(decoder_input[0].tolist())
# Example sentences — uncomment to try other inputs.
# translate("Why does Earth have only one moon, while other planets have many?")
# translate("Imagine you are an astronaut stepping onto Mars for the first time. Write a monologue expressing your emotions and observations.")
# translate("The theory of evolution, proposed by Charles Darwin, explains the process by which species of organisms change over time through natural selection. While the theory has been widely accepted in the scientific community, it continues to spark debates in various social and religious contexts. Discuss how the theory of evolution has shaped our understanding of human origins and the controversies that surround it.")

# Collect every translation: the original reassigned `a`, so the first
# result was computed and then silently discarded — only the last sentence
# ever reached output.txt.
translations = [
    translate("Defending champion Kolkata Knight Riders (KKR) hosts Royal Challengers Bengaluru (RCB) at the Eden Gardens in the Indian Premier League 2025 opener on Saturday."),
    translate("Which South American country is home to the Amazon Rainforest and the Christ the Redeemer statue?"),
]

# Append each result on its own line, preserving the original file format.
with open("output.txt", "a") as w:
    for a in translations:
        w.write(f"\n{a}")