Spaces:
Build error
Build error
Commit
·
9c1b483
1
Parent(s):
c121a67
Special token removed
Browse files
app.py
CHANGED
|
@@ -32,8 +32,6 @@ model.train(False)
|
|
| 32 |
|
| 33 |
def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
|
| 34 |
enc = tiktoken.get_encoding('gpt2')
|
| 35 |
-
# Modify encoding behavior to allow special tokens
|
| 36 |
-
enc._special_tokens.add('<|endoftext|>')
|
| 37 |
tokens = enc.encode(prompt)
|
| 38 |
tokens = torch.tensor(tokens, dtype=torch.long)
|
| 39 |
tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
|
|
@@ -55,10 +53,8 @@ def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
|
|
| 55 |
|
| 56 |
tokens = torch.cat((tokens, next_token), dim=1)
|
| 57 |
|
| 58 |
-
#
|
| 59 |
-
|
| 60 |
-
if next_token.item() == endoftext_token:
|
| 61 |
-
break
|
| 62 |
|
| 63 |
generated_texts = []
|
| 64 |
for i in range(num_samples):
|
|
|
|
| 32 |
|
| 33 |
def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
|
| 34 |
enc = tiktoken.get_encoding('gpt2')
|
|
|
|
|
|
|
| 35 |
tokens = enc.encode(prompt)
|
| 36 |
tokens = torch.tensor(tokens, dtype=torch.long)
|
| 37 |
tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
|
|
|
|
| 53 |
|
| 54 |
tokens = torch.cat((tokens, next_token), dim=1)
|
| 55 |
|
| 56 |
+
# Remove special token check entirely
|
| 57 |
+
# Just generate for the specified length or until context limit
|
|
|
|
|
|
|
| 58 |
|
| 59 |
generated_texts = []
|
| 60 |
for i in range(num_samples):
|