"""Generate an imaginary word and its definition from three input words."""
import torch
import tiktoken
from src.constants.tokens import special_tokens
from src.services.model import load_model, get_device
# Initialize tokenizer
_tokenizer = tiktoken.get_encoding("cl100k_base")
# Generate an imaginary word and its definition from three input words.
def generate_word(words, model, vocab, inv_vocab, max_length=64):
    """Autoregressively generate an invented word/definition from seed words.

    Args:
        words: iterable of seed words; joined with "," before tokenization.
        model: seq2seq model invoked as ``model(source, target_so_far)`` and
            returning per-position logits of shape (batch, seq, vocab).
        vocab: mapping from token string to model id; must contain the
            "<pad>", "<sos>" and "<eos>" entries.
        inv_vocab: reverse mapping from model id back to a tokenizer token id.
        max_length: hard cap on the number of decoding steps.

    Returns:
        The decoded text with special tokens stripped out.
    """
    device = get_device()
    pad_id = vocab["<pad>"]
    sos_id = vocab["<sos>"]
    eos_id = vocab["<eos>"]

    # Encode the comma-joined seeds, then map tokenizer ids into model vocab
    # ids (unknown tokens fall back to the pad id).
    encoded = _tokenizer.encode(",".join(words))
    src = torch.tensor([[vocab.get(str(t), pad_id) for t in encoded]]).to(device)

    # Greedy decoding: start from <sos> and append the argmax token each step.
    generated = torch.tensor([[sos_id]]).to(device)
    with torch.no_grad():
        for _ in range(max_length):
            logits = model(src, generated)
            candidate = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            if candidate.item() == eos_id:
                # Model signalled end-of-sequence; stop before appending it.
                break
            generated = torch.cat((generated, candidate), dim=1)

    # Translate model ids back to tokenizer ids, dropping special tokens,
    # and let the tokenizer render the final text.
    ids = generated[0].cpu().numpy()
    keep = [int(inv_vocab[i]) for i in ids if i not in special_tokens.values()]
    return _tokenizer.decode(keep)
def main(words=None):
    """Load the model, generate from *words* (or a default trio), and print it.

    Args:
        words: optional list of seed words; when None a Portuguese example
            ("muito", "grande", "imenso") is used.

    Returns:
        The generated text.
    """
    model, vocab, inv_vocab = load_model()
    if words is None:
        # Default example: Portuguese for "very", "big", "immense".
        words = ["muito", "grande", "imenso"]
    result = generate_word(words, model, vocab, inv_vocab)
    print(f"Input words: {', '.join(words)}")
    print(f"Generated: {result}")
    return result
# Run the default example when executed as a script (no effect on import).
if __name__ == "__main__":
    main()