Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -29,10 +29,15 @@ model = load_model_from_huggingface()
|
|
| 29 |
# Force model to stay in eval mode
|
| 30 |
model.train(False)
|
| 31 |
|
| 32 |
-
def generate_text(prompt, max_length=
|
| 33 |
enc = tiktoken.get_encoding('gpt2')
|
| 34 |
tokens = enc.encode(prompt)
|
| 35 |
tokens = torch.tensor(tokens, dtype=torch.long)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
|
| 37 |
tokens = tokens.to(device)
|
| 38 |
|
|
|
|
| 29 |
# Force model to stay in eval mode
|
| 30 |
model.train(False)
|
| 31 |
|
| 32 |
+
def generate_text(prompt, max_length="25", num_samples="1"):
|
| 33 |
enc = tiktoken.get_encoding('gpt2')
|
| 34 |
tokens = enc.encode(prompt)
|
| 35 |
tokens = torch.tensor(tokens, dtype=torch.long)
|
| 36 |
+
|
| 37 |
+
# Convert inputs to integers
|
| 38 |
+
max_length = int(max_length)
|
| 39 |
+
num_samples = int(num_samples)
|
| 40 |
+
|
| 41 |
tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
|
| 42 |
tokens = tokens.to(device)
|
| 43 |
|