import torch import tiktoken from src.constants.tokens import special_tokens from src.services.model import load_model, get_device # Initialize tokenizer _tokenizer = tiktoken.get_encoding("cl100k_base") # Generate an imaginary word and its definition from three input words. def generate_word(words, model, vocab, inv_vocab, max_length=64): device = get_device() # Tokenize input words input_text = ",".join(words) input_tokens = _tokenizer.encode(input_text) input_tensor = torch.tensor([vocab.get(str(tok), vocab[""]) for tok in input_tokens]).unsqueeze(0).to(device) # Initialize target with SOS token target = torch.tensor([[vocab[""]]]).to(device) # Generate output with torch.no_grad(): for _ in range(max_length): output = model(input_tensor, target) next_token = output[:, -1, :].argmax(dim=-1, keepdim=True) # Stop if we predict EOS token if next_token.item() == vocab[""]: break target = torch.cat([target, next_token], dim=1) # Convert output tokens to text output_tokens = target[0].cpu().numpy() output_text = _tokenizer.decode([int(inv_vocab[tok]) for tok in output_tokens if tok not in special_tokens.values()]) return output_text def main(words=None): # Load model and vocabulary model, vocab, inv_vocab = load_model() # Use provided words or default example if words is None: words = ["muito", "grande", "imenso"] result = generate_word(words, model, vocab, inv_vocab) print(f"Input words: {', '.join(words)}") print(f"Generated: {result}") return result if __name__ == "__main__": main()