File size: 1,653 Bytes
e663138
 
738d2af
 
e663138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import tiktoken
from src.constants.tokens import special_tokens
from src.services.model import load_model, get_device

# Module-level tiktoken encoder ("cl100k_base"), created once at import time
# and reused by generate_word() for both encoding the prompt and decoding the
# generated ids.
_tokenizer = tiktoken.get_encoding("cl100k_base")

def generate_word(words, model, vocab, inv_vocab, max_length=64):
  """Generate an imaginary word and its definition from the input words.

  The words are joined with commas, encoded with the module-level tiktoken
  encoder and mapped through ``vocab``. The model is then decoded greedily
  (argmax at every step) starting from ``<sos>`` until it predicts ``<eos>``
  or ``max_length`` steps have been taken.

  Args:
    words: Iterable of strings used as the prompt.
    model: Callable invoked as ``model(input_tensor, target)``; its result is
      indexed as ``output[:, -1, :]`` to read the scores for the next token.
    vocab: Mapping from ``str(tiktoken_id)`` and from the special-token names
      ``"<pad>"``, ``"<sos>"`` and ``"<eos>"`` to model token ids. Prompt
      tokens missing from it are replaced by the ``<pad>`` id.
    inv_vocab: Lookup from model token id to a value that ``int()`` turns
      back into a tiktoken id.
    max_length: Maximum number of decoding steps.

  Returns:
    The decoded text with special tokens removed.

  Raises:
    KeyError: If ``vocab`` lacks one of the special-token entries.
  """
  device = get_device()

  # Resolve the special ids once instead of on every loop iteration.
  pad_id = vocab["<pad>"]
  sos_id = vocab["<sos>"]
  eos_id = vocab["<eos>"]

  # Tokenize the prompt and translate tiktoken ids into model ids.
  prompt_tokens = _tokenizer.encode(",".join(words))
  prompt_ids = [vocab.get(str(tok), pad_id) for tok in prompt_tokens]
  # Explicit dtype: without it an empty id list would become a float tensor,
  # which is not a valid tensor of token ids.
  input_tensor = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0).to(device)

  # The decoder input starts with the SOS token only.
  target = torch.tensor([[sos_id]], dtype=torch.long).to(device)

  # NOTE(review): the model's train/eval mode is left untouched here --
  # confirm that load_model() returns the model in eval mode.
  with torch.no_grad():
    for _ in range(max_length):
      output = model(input_tensor, target)
      # Greedy choice: highest-scoring id at the last position.
      next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)

      # Stop as soon as EOS is predicted; EOS itself is not appended.
      if next_token.item() == eos_id:
        break

      target = torch.cat([target, next_token], dim=1)

  # Build the exclusion set once. The ids looked up in vocab are included as
  # well because the sequence always begins with sos_id, which must never
  # reach int(inv_vocab[...]) below.
  special_ids = set(special_tokens.values()) | {pad_id, sos_id, eos_id}

  # tolist() yields plain Python ints, so no NumPy round-trip is needed.
  generated_ids = target[0].tolist()
  return _tokenizer.decode(
    [int(inv_vocab[tok]) for tok in generated_ids if tok not in special_ids]
  )

def main(words=None):
  """Load the model, generate text for ``words`` and print the outcome.

  Args:
    words: List of prompt words; when ``None`` a built-in example is used.

  Returns:
    The text produced by ``generate_word``.
  """
  # The loader hands back the model together with both vocabulary mappings.
  model, vocab, inv_vocab = load_model()

  # Fall back to the default example only when no words were supplied.
  words = ["muito", "grande", "imenso"] if words is None else words

  generated = generate_word(words, model, vocab, inv_vocab)

  print(f"Input words: {', '.join(words)}")
  print(f"Generated: {generated}")
  return generated

# When executed as a script, run main() with no arguments so that it uses its
# built-in example words.
if __name__ == "__main__":
  main()