Commit 74d656b (parent: d91835f) — "change infer function"
Files changed: infer_concat.py (+6, −6)
infer_concat.py
CHANGED
|
@@ -63,7 +63,7 @@ def processing_data_infer(input_file):
|
|
| 63 |
tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base-vietnews-summarization")
|
| 64 |
model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-base-vietnews-summarization")
|
| 65 |
|
| 66 |
-
device = torch.device('
|
| 67 |
model.to(device)
|
| 68 |
|
| 69 |
model.load_state_dict(torch.load("./weight_cp19_model.pth", map_location=torch.device('cpu')))
|
|
@@ -90,12 +90,12 @@ def infer_2_hier(model, data_loader, device, tokenizer):
|
|
| 90 |
summary = model.generate(inputs[i].to(device),
|
| 91 |
attention_mask=att_mask[i].to(device),
|
| 92 |
max_length=128,
|
| 93 |
-
num_beams=
|
| 94 |
-
num_return_sequences=1)
|
| 95 |
summaries.append(summary)
|
| 96 |
summaries = torch.cat(summaries, dim = 1)
|
| 97 |
-
|
| 98 |
-
|
| 99 |
|
| 100 |
|
| 101 |
end = time.time()
|
|
@@ -104,6 +104,6 @@ def infer_2_hier(model, data_loader, device, tokenizer):
|
|
| 104 |
|
| 105 |
def vit5_infer(data):
|
| 106 |
dataset = Dataset4Summarization(data, tokenizer)
|
| 107 |
-
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1
|
| 108 |
result = infer_2_hier(model, data_loader, device, tokenizer)
|
| 109 |
return result
|
|
|
|
| 63 |
tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base-vietnews-summarization")
|
| 64 |
model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-base-vietnews-summarization")
|
| 65 |
|
| 66 |
+
device = torch.device('cpu')
|
| 67 |
model.to(device)
|
| 68 |
|
| 69 |
model.load_state_dict(torch.load("./weight_cp19_model.pth", map_location=torch.device('cpu')))
|
|
|
|
| 90 |
summary = model.generate(inputs[i].to(device),
|
| 91 |
attention_mask=att_mask[i].to(device),
|
| 92 |
max_length=128,
|
| 93 |
+
num_beams=4,
|
| 94 |
+
num_return_sequences=1, no_repeat_ngram_size=3)
|
| 95 |
summaries.append(summary)
|
| 96 |
summaries = torch.cat(summaries, dim = 1)
|
| 97 |
+
|
| 98 |
+
all_summaries.append(tokenizer.decode(summaries, skip_special_tokens=True))
|
| 99 |
|
| 100 |
|
| 101 |
end = time.time()
|
|
|
|
def vit5_infer(data, batch_size=1):
    """Run hierarchical ViT5 summarization inference over ``data``.

    Wraps ``data`` in a ``Dataset4Summarization`` (tokenizing with the
    module-level ``tokenizer``), feeds it through a ``DataLoader``, and
    delegates generation to ``infer_2_hier`` with the module-level
    ``model`` and ``device``.

    Parameters
    ----------
    data : input documents accepted by Dataset4Summarization
        # NOTE(review): exact expected structure is defined by
        # Dataset4Summarization elsewhere in this file — confirm there.
    batch_size : int, optional
        DataLoader batch size. Defaults to 1, preserving the original
        hard-coded behavior; exposed as a parameter so callers can
        batch larger inputs.

    Returns
    -------
    Whatever ``infer_2_hier`` produces (decoded summaries).
    """
    dataset = Dataset4Summarization(data, tokenizer)
    # batch_size was hard-coded to 1; parameterized for flexibility,
    # default keeps existing callers unchanged.
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    result = infer_2_hier(model, data_loader, device, tokenizer)
    return result
|