from pathlib import Path
from model import build_transformer
from util import create_resources
import torch
import sys
import yaml
import sacrebleu


def translate(sentence: str):
    """Greedy-decode a translation of *sentence* with the trained transformer.

    Loads the config, tokenizers, and model checkpoint on every call
    (NOTE(review): this is expensive — consider caching config/tokenizers/model
    at module level if translate() is called repeatedly).

    Returns the decoded target-language string.
    """
    with open("config.yaml", "r") as file:
        config = yaml.safe_load(file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Only the tokenizers are needed here; the three dataloaders are unused.
    _, _, _, tokenizer_src, tokenizer_tgt = create_resources()
    src_vocab_size = tokenizer_src.get_vocab_size()
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()

    model = build_transformer(
        src_vocab_size,
        tgt_vocab_size,
        config["seq_len"],
        config["seq_len"],
        config["num_enc_dec_blocks"],
        config["num_of_heads"],
        config["d_model"],
    )
    model = model.to(device)

    model_filename = "models/model_epoch_15.pth"
    # Fix: map_location so a checkpoint saved on GPU still loads when this
    # machine falls back to CPU.
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    model.eval()

    with torch.no_grad():
        source = tokenizer_src.encode(sentence)
        print(source, source.ids)

        # Fix: truncate so [SOS] + ids + [EOS] never exceeds seq_len. The
        # original computed a negative pad count for long sentences, which
        # yields an empty pad tensor and an over-length model input.
        token_ids = source.ids[: config["seq_len"] - 2]
        pad_len = config["seq_len"] - len(token_ids) - 2
        source = torch.cat([
            torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64),
            torch.tensor(token_ids, dtype=torch.int64),
            torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64),
            torch.tensor([tokenizer_src.token_to_id('[PAD]')] * pad_len, dtype=torch.int64),
        ], dim=0)
        source = source.to(device)
        source = source.unsqueeze(0)  # (1, seq_len)
        print(source.shape)
        # 1 where the token is real, 0 on padding; broadcast to (1, 1, 1, seq_len).
        source_mask = (source != tokenizer_src.token_to_id('[PAD]')).unsqueeze(0).unsqueeze(0).int().to(device)

        encoder_output = model.encode(source, source_mask)

        # Greedy decoding: start from [SOS], append the argmax token each step
        # until [EOS] is produced or seq_len is reached.
        decoder_input = torch.full((1, 1), tokenizer_tgt.token_to_id('[SOS]'), dtype=torch.long, device=device)
        while decoder_input.size(1) < config["seq_len"]:
            n = decoder_input.size(1)
            # NOTE(review): mask is 1 strictly above the diagonal — assumes
            # model.decode treats nonzero entries as "blocked"; confirm against
            # the attention implementation's mask convention.
            decoder_mask = torch.triu(torch.ones((1, n, n)), diagonal=1).to(device, dtype=torch.int)
            out = model.decode(decoder_input, encoder_output, source_mask, decoder_mask)
            prob = model.project(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_token = torch.full((1, 1), next_word.item(), dtype=torch.long, device=device)
            decoder_input = torch.cat([decoder_input, next_token], dim=1)
            print(f"{tokenizer_tgt.decode([next_word.item()])}", end=' ')
            if next_word.item() == tokenizer_tgt.token_to_id('[EOS]'):
                break

    return tokenizer_tgt.decode(decoder_input[0].tolist())


# NOTE(review): the first call's result is overwritten before the file write,
# so only the second sentence's translation is saved — confirm this is intended.
a = translate("Defending champion Kolkata Knight Riders (KKR) hosts Royal Challengers Bengaluru (RCB) at the Eden Gardens in the Indian Premier League 2025 opener on Saturday.")
a = translate("Which South American country is home to the Amazon Rainforest and the Christ the Redeemer statue?")
with open("output.txt", "a") as w:
    w.write(f"\n{a}")