satyanayak committed on
Commit
e243b3e
·
1 Parent(s): 51ffa64

<|endoftext|> token handled

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -31,7 +31,8 @@ model = load_model_from_hf()
31
  model.train(False)
32
 
33
  def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
34
- enc = tiktoken.get_encoding('gpt2')
 
35
  tokens = enc.encode(prompt)
36
  tokens = torch.tensor(tokens, dtype=torch.long)
37
  tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
@@ -54,7 +55,7 @@ def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
54
  tokens = torch.cat((tokens, next_token), dim=1)
55
 
56
  # Check for end of text token
57
- if next_token.item() == enc.encode('<|endoftext|>')[0]:
58
  break
59
 
60
  generated_texts = []
 
31
  model.train(False)
32
 
33
  def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
34
+ # Initialize encoder with allowed special tokens
35
+ enc = tiktoken.get_encoding('gpt2', allowed_special={'<|endoftext|>'})
36
  tokens = enc.encode(prompt)
37
  tokens = torch.tensor(tokens, dtype=torch.long)
38
  tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
 
55
  tokens = torch.cat((tokens, next_token), dim=1)
56
 
57
  # Check for end of text token
58
+ if next_token.item() == enc.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]:
59
  break
60
 
61
  generated_texts = []