# rellow / prediction.py
# Commit 738d2af (rafacamargo): fix: relative file import paths
import torch
import tiktoken
from src.constants.tokens import special_tokens
from src.services.model import load_model, get_device
# Module-level tiktoken encoding ("cl100k_base"), created once at import time
# and used by generate_word for both encoding the input and decoding the output.
_tokenizer = tiktoken.get_encoding("cl100k_base")
def generate_word(words, model, vocab, inv_vocab, max_length=64):
    """Generate an imaginary word and its definition from the input words.

    The words are joined with commas, encoded with the module-level tiktoken
    encoding and mapped to model ids through ``vocab``. Decoding starts from
    the ``<sos>`` id and, at each step, appends the arg-max id of the model
    output's last position, stopping when ``<eos>`` is predicted or after
    ``max_length`` steps.

    Args:
        words: Iterable of strings to condition on. The original comment
            speaks of three words, but the count is not enforced here.
        model: Callable invoked as ``model(source_ids, target_ids)``; its
            result is indexed as ``[:, -1, :]``.
        vocab: Mapping from ``str(tiktoken_id)`` and from the special names
            ``"<pad>"``, ``"<sos>"`` and ``"<eos>"`` to model ids.
        inv_vocab: Mapping from a model id to a value that ``int`` converts
            into a tiktoken id.
        max_length: Maximum number of decoding steps.

    Returns:
        The text decoded from the generated ids, leaving out every id that
        appears among ``special_tokens.values()``.
    """
    device = get_device()

    # Encode the comma-joined prompt; ids absent from vocab fall back to <pad>.
    prompt = ",".join(words)
    source_ids = [
        vocab.get(str(piece), vocab["<pad>"])
        for piece in _tokenizer.encode(prompt)
    ]
    source = torch.tensor(source_ids).unsqueeze(0).to(device)

    # The decoder input begins as a single <sos> id (shape 1 x 1).
    generated = torch.tensor([[vocab["<sos>"]]]).to(device)

    with torch.no_grad():
        for _step in range(max_length):
            scores = model(source, generated)
            # Greedy choice: highest-scoring id at the last position.
            best_id = torch.argmax(scores[:, -1, :], dim=-1, keepdim=True)
            if best_id.item() == vocab["<eos>"]:
                break
            generated = torch.cat((generated, best_id), dim=1)

    # Map model ids back to tiktoken ids, skipping the special-token ids
    # (this also drops the leading <sos>, assuming it is one of them).
    reserved_ids = special_tokens.values()
    kept_ids = [
        int(inv_vocab[model_id])
        for model_id in generated[0].cpu().numpy()
        if model_id not in reserved_ids
    ]
    return _tokenizer.decode(kept_ids)
def main(words=None):
    """Load the model, generate text for ``words``, print and return it.

    Args:
        words: Input words passed to ``generate_word``. When ``None``, the
            built-in example ``["muito", "grande", "imenso"]`` is used.

    Returns:
        The text returned by ``generate_word``.
    """
    chosen_words = ["muito", "grande", "imenso"] if words is None else words

    # load_model supplies the model together with both vocabulary mappings.
    model, vocab, inv_vocab = load_model()
    generated = generate_word(chosen_words, model, vocab, inv_vocab)

    listing = ", ".join(chosen_words)
    print(f"Input words: {listing}")
    print(f"Generated: {generated}")
    return generated
# Run the built-in example only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()