Spaces:
Build error
Build error
Commit
·
e243b3e
1
Parent(s):
51ffa64
|endoftext| token handled
Browse files
app.py
CHANGED
|
@@ -31,7 +31,8 @@ model = load_model_from_hf()
|
|
| 31 |
model.train(False)
|
| 32 |
|
| 33 |
def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
|
| 34 |
-
|
|
|
|
| 35 |
tokens = enc.encode(prompt)
|
| 36 |
tokens = torch.tensor(tokens, dtype=torch.long)
|
| 37 |
tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
|
|
@@ -54,7 +55,7 @@ def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
|
|
| 54 |
tokens = torch.cat((tokens, next_token), dim=1)
|
| 55 |
|
| 56 |
# Check for end of text token
|
| 57 |
-
if next_token.item() == enc.encode('<|endoftext|>')[0]:
|
| 58 |
break
|
| 59 |
|
| 60 |
generated_texts = []
|
|
|
|
| 31 |
model.train(False)
|
| 32 |
|
| 33 |
def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
|
| 34 |
+
# Initialize encoder with allowed special tokens
|
| 35 |
+
enc = tiktoken.get_encoding('gpt2', allowed_special={'<|endoftext|>'})
|
| 36 |
tokens = enc.encode(prompt)
|
| 37 |
tokens = torch.tensor(tokens, dtype=torch.long)
|
| 38 |
tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
|
|
|
|
| 55 |
tokens = torch.cat((tokens, next_token), dim=1)
|
| 56 |
|
| 57 |
# Check for end of text token
|
| 58 |
+
if next_token.item() == enc.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]:
|
| 59 |
break
|
| 60 |
|
| 61 |
generated_texts = []
|