Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -78,6 +78,12 @@ def ensure_complete_output(output, context, max_length, temperature, top_k, top_
|
|
| 78 |
|
| 79 |
# Text generation function for Gradio interface
|
| 80 |
def generate_text(prompt, num_samples, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
with torch.no_grad():
|
| 82 |
with ctx:
|
| 83 |
start_ids = encode(prompt)
|
|
|
|
| 78 |
|
| 79 |
# Text generation function for Gradio interface
|
| 80 |
def generate_text(prompt, num_samples, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id):
|
| 81 |
+
# Add input validation
|
| 82 |
+
if num_samples is None:
|
| 83 |
+
num_samples = 1
|
| 84 |
+
elif not isinstance(num_samples, int):
|
| 85 |
+
raise ValueError("Number of samples must be an integer")
|
| 86 |
+
|
| 87 |
with torch.no_grad():
|
| 88 |
with ctx:
|
| 89 |
start_ids = encode(prompt)
|