resolve a error
Browse files- generate.py +2 -2
generate.py
CHANGED
|
@@ -11,7 +11,7 @@ def generate_text(model_data, input_text, max_new_token):
|
|
| 11 |
model_pipeline = model_data["pipeline"]
|
| 12 |
generated_text = model_pipeline(
|
| 13 |
input_text,
|
| 14 |
-
max_length=max_new_token,
|
| 15 |
do_sample=False, # غیرفعال کردن نمونهگیری (حالت حریصانه)
|
| 16 |
truncation=True # فعال کردن truncation
|
| 17 |
)[0]["generated_text"]
|
|
@@ -40,7 +40,7 @@ def generate_text(model_data, input_text, max_new_token):
|
|
| 40 |
outputs = model.generate(
|
| 41 |
input_ids=input_ids,
|
| 42 |
attention_mask=attention_mask,
|
| 43 |
-
max_new_tokens=max_new_token,
|
| 44 |
do_sample=False, # غیرفعال کردن نمونهگیری (حالت حریصانه)
|
| 45 |
pad_token_id=tokenizer.eos_token_id,
|
| 46 |
repetition_penalty=1.2,
|
|
|
|
| 11 |
model_pipeline = model_data["pipeline"]
|
| 12 |
generated_text = model_pipeline(
|
| 13 |
input_text,
|
| 14 |
+
max_length=max_new_token + len(input_text.split()), # افزایش max_length
|
| 15 |
do_sample=False, # غیرفعال کردن نمونهگیری (حالت حریصانه)
|
| 16 |
truncation=True # فعال کردن truncation
|
| 17 |
)[0]["generated_text"]
|
|
|
|
| 40 |
outputs = model.generate(
|
| 41 |
input_ids=input_ids,
|
| 42 |
attention_mask=attention_mask,
|
| 43 |
+
max_new_tokens=max_new_token, # استفاده از max_new_tokens
|
| 44 |
do_sample=False, # غیرفعال کردن نمونهگیری (حالت حریصانه)
|
| 45 |
pad_token_id=tokenizer.eos_token_id,
|
| 46 |
repetition_penalty=1.2,
|