Spaces:
Sleeping
Sleeping
Commit
·
29c8c51
1
Parent(s):
1a241d5
fix: generation paramenters
Browse files
app.py
CHANGED
|
@@ -23,12 +23,13 @@ def generate(prompts: list[str]) -> tuple[list[str], list[dict[str, float]]]:
|
|
| 23 |
tokenize=False,
|
| 24 |
add_generation_prompt=True
|
| 25 |
)
|
| 26 |
-
model_inputs = tokenizer(texts, padding=True,
|
| 27 |
generated_ids = model.generate(
|
| 28 |
**model_inputs,
|
| 29 |
do_sample=False,
|
| 30 |
temperature=0,
|
| 31 |
repetition_penalty=1.0,
|
|
|
|
| 32 |
)
|
| 33 |
generated_ids = [
|
| 34 |
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
|
|
|
| 23 |
tokenize=False,
|
| 24 |
add_generation_prompt=True
|
| 25 |
)
|
| 26 |
+
model_inputs = tokenizer(texts, padding=True, return_tensors="pt").to(model.device)
|
| 27 |
generated_ids = model.generate(
|
| 28 |
**model_inputs,
|
| 29 |
do_sample=False,
|
| 30 |
temperature=0,
|
| 31 |
repetition_penalty=1.0,
|
| 32 |
+
max_new_tokens=512,
|
| 33 |
)
|
| 34 |
generated_ids = [
|
| 35 |
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|