# MultiEvalVietSum / inference_example.py
# Uploaded by phuongntc: weights, tokenizer, config, code, and model card
# Commit: 2bcedff (verified)
import torch
from modeling_multievalvietsum import MultiEvalVietSumModel
def build_pair_feature(tokenizer, document, summary, max_len=2048, summary_max_len=192):
    """Tokenize a (document, summary) pair into model-ready features.

    The summary is truncated to ``summary_max_len`` tokens first, then the
    document is clipped so that the combined sequence plus the tokenizer's
    pair special tokens fits within ``max_len`` (always keeping at least 16
    document tokens).

    Returns a dict with ``input_ids``, ``attention_mask`` and — when the
    tokenizer declares them in ``model_input_names`` — ``token_type_ids``.
    """
    summary_ids = tokenizer(
        summary,
        truncation=True,
        max_length=summary_max_len,
        add_special_tokens=False,
        return_attention_mask=False,
    )["input_ids"]
    document_ids = tokenizer(
        document,
        truncation=False,
        add_special_tokens=False,
        return_attention_mask=False,
    )["input_ids"]

    # Reserve room for the summary and the pair's special tokens.
    overhead = tokenizer.num_special_tokens_to_add(pair=True)
    doc_keep = max(16, max_len - len(summary_ids) - overhead)
    document_ids = document_ids[:doc_keep]

    wants_token_types = "token_type_ids" in getattr(tokenizer, "model_input_names", [])

    try:
        encoded = tokenizer.prepare_for_model(
            document_ids,
            pair_ids=summary_ids,
            add_special_tokens=True,
            padding=False,
            truncation=False,
            return_attention_mask=True,
            return_token_type_ids=wants_token_types,
        )
    except Exception:
        # Fallback for tokenizers where prepare_for_model fails: build a
        # BERT-style pair by hand — [CLS] doc [SEP] summary [SEP].
        cls_id, sep_id = tokenizer.cls_token_id, tokenizer.sep_token_id
        ids = [cls_id, *document_ids, sep_id, *summary_ids, sep_id]
        fallback = {"input_ids": ids, "attention_mask": [1] * len(ids)}
        if wants_token_types:
            segment_a = [0] * (len(document_ids) + 2)  # [CLS] doc [SEP]
            segment_b = [1] * (len(summary_ids) + 1)   # summary [SEP]
            fallback["token_type_ids"] = segment_a + segment_b
        return fallback

    # Keep only the keys the model consumes.
    wanted = {"input_ids", "attention_mask", "token_type_ids"}
    return {key: value for key, value in encoded.items() if key in wanted}
@torch.no_grad()
def predict_scores(document: str, summary: str, model_dir: str = "."):
    """Score a (document, summary) pair on three summary-quality dimensions.

    Loads the MultiEvalVietSum model and tokenizer from ``model_dir``, builds
    a tokenized document/summary pair, and runs a single forward pass.

    Args:
        document: Source document text.
        summary: Candidate summary text to evaluate against the document.
        model_dir: Directory holding the pretrained weights, config, and
            tokenizer files.

    Returns:
        Dict with float scores for "faithfulness", "coherence", "relevance"
        (model output is assumed to be ordered this way — matches the
        original example; confirm against the model card).
    """
    model, cfg = MultiEvalVietSumModel.from_pretrained_local(model_dir)
    tokenizer = MultiEvalVietSumModel.load_tokenizer_local(model_dir)

    # Inference mode: disable dropout / batch-norm updates so scores are
    # deterministic (the original example skipped this).
    model.eval()
    try:
        device = next(model.parameters()).device
    except StopIteration:  # parameterless model — keep inputs on CPU
        device = torch.device("cpu")

    feat = build_pair_feature(
        tokenizer,
        document=document,
        summary=summary,
        max_len=cfg["max_len"],
        summary_max_len=cfg["summary_max_len"],
    )

    # Batch of one; tensors must live on the same device as the weights.
    batch = {
        "input_ids": torch.tensor([feat["input_ids"]], dtype=torch.long, device=device),
        "attention_mask": torch.tensor([feat["attention_mask"]], dtype=torch.long, device=device),
    }
    if "token_type_ids" in feat:
        batch["token_type_ids"] = torch.tensor(
            [feat["token_type_ids"]], dtype=torch.long, device=device
        )

    scores = model(**batch)[0].cpu().tolist()
    return {
        "faithfulness": float(scores[0]),
        "coherence": float(scores[1]),
        "relevance": float(scores[2]),
    }
if __name__ == "__main__":
    # Smoke-test the scoring pipeline on a tiny Vietnamese example.
    sample_document = "Văn bản gốc mẫu."
    sample_summary = "Bản tóm tắt mẫu."
    print(predict_scores(sample_document, sample_summary))