Update README.md
Browse files
README.md
CHANGED
|
@@ -24,12 +24,15 @@ from transformers import (
|
|
| 24 |
|
| 25 |
|
| 26 |
def summarize(text, tokenizer, model, num_beams=4, temperature=1, max_new_tokens=512):
    """Generate a summary of *text* by prompting a causal LM.

    Parameters
    ----------
    text : str
        Input text to summarize; must be at least 20 characters long.
    tokenizer : transformers tokenizer
        Encodes the prompt and decodes the generated token ids.
    model : transformers model
        Model exposing ``generate``; assumed to already live on CUDA.
    num_beams : int
        Beam-search width forwarded to ``generate``.
    temperature : float
        Sampling temperature forwarded to ``generate``.
    max_new_tokens : int
        Upper bound on the number of generated tokens.

    Returns
    -------
    str
        Decoded model output, special tokens included.

    Raises
    ------
    ValueError
        If *text* is shorter than 20 characters.
    """
    if len(text) < 20:
        raise ValueError('Text must be at least 20 characters long.')
    # This text template is important: the prompt suffix below is the cue
    # the model expects before producing the summary.
    inputs = tokenizer(f'{text}\n### 住讬讻讜诐:', return_tensors="pt")
    # NOTE(review): device is hard-coded — confirm a CUDA GPU is always available.
    in_data = inputs.input_ids.to('cuda')
    # Bug fix: output_ids was decoded below but never produced (the generate
    # call was missing), so the function raised NameError. Run generation here.
    output_ids = model.generate(
        input_ids=in_data,
        num_beams=num_beams,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        early_stopping=True,
        use_cache=True,
        temperature=temperature,
    )
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
    return generated_text
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
def summarize(text, tokenizer, model, num_beams=4, temperature=1, max_new_tokens=512):
    """Summarize *text* with a causal LM, prompting it with a summary cue.

    Parameters
    ----------
    text : str
        Input text to summarize; must be at least 20 characters long.
    tokenizer : transformers tokenizer
        Encodes the prompt and decodes the generated token ids.
    model : transformers model
        Model exposing ``generate``; assumed to already live on CUDA.
    num_beams : int
        Beam-search width forwarded to ``generate``.
    temperature : float
        Sampling temperature forwarded to ``generate``.
    max_new_tokens : int
        Upper bound on the number of generated tokens.

    Returns
    -------
    str
        Decoded model output, special tokens included.

    Raises
    ------
    ValueError
        If *text* is shorter than 20 characters.
    """
    # Models without a dedicated pad token fall back to EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if len(text) < 20:
        raise ValueError('Text must be at least 20 characters long.')

    # This text template is important.
    encoded = tokenizer(f'{text}\n### 住讬讻讜诐:', return_tensors="pt")
    prompt_ids = encoded.input_ids.to('cuda')
    prompt_mask = encoded.attention_mask.to('cuda')

    generation = model.generate(
        input_ids=prompt_ids,
        attention_mask=prompt_mask,
        num_beams=num_beams,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        early_stopping=True,
        use_cache=True,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(generation[0], skip_special_tokens=False)
|