import torch

from modeling_multievalvietsum import MultiEvalVietSumModel


def build_pair_feature(tokenizer, document, summary, max_len=2048, summary_max_len=192):
    """Encode a (document, summary) pair as one cross-encoder input.

    The summary is tokenized first and truncated to ``summary_max_len``
    tokens; the document is then cut so the combined sequence plus special
    tokens fits within ``max_len``.

    Args:
        tokenizer: HuggingFace-style tokenizer (assumed -- TODO confirm; must
            support ``__call__``, ``num_special_tokens_to_add`` and ideally
            ``prepare_for_model``).
        document: source text.
        summary: candidate summary to score against the document.
        max_len: total token budget for the encoded pair.
        summary_max_len: token budget reserved for the summary side.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and, when the tokenizer
        declares them in ``model_input_names``, ``token_type_ids`` — all as
        plain unpadded Python lists.
    """
    sum_ids = tokenizer(
        summary,
        truncation=True,
        max_length=summary_max_len,
        add_special_tokens=False,
        return_attention_mask=False,
    )["input_ids"]
    doc_ids = tokenizer(
        document,
        truncation=False,
        add_special_tokens=False,
        return_attention_mask=False,
    )["input_ids"]

    # Reserve room for the summary and the pair's special tokens; keep at
    # least 16 document tokens so the document side is never empty (this can
    # slightly exceed max_len for a very long summary — accepted trade-off).
    special_pair_tokens = tokenizer.num_special_tokens_to_add(pair=True)
    doc_budget = max(16, max_len - len(sum_ids) - special_pair_tokens)
    doc_ids = doc_ids[:doc_budget]

    model_inputs = getattr(tokenizer, "model_input_names", [])
    return_token_type_ids = "token_type_ids" in model_inputs

    try:
        feat = tokenizer.prepare_for_model(
            doc_ids,
            pair_ids=sum_ids,
            add_special_tokens=True,
            padding=False,
            truncation=False,
            return_attention_mask=True,
            return_token_type_ids=return_token_type_ids,
        )
        # Keep only the keys the model consumes.
        return {
            k: v
            for k, v in feat.items()
            if k in {"input_ids", "attention_mask", "token_type_ids"}
        }
    except Exception:
        # Deliberate best-effort fallback for tokenizers where
        # prepare_for_model is unavailable or rejects these arguments:
        # hand-build a BERT-style [CLS] doc [SEP] summary [SEP] sequence.
        # NOTE(review): assumes cls_token_id / sep_token_id are set; if the
        # tokenizer lacks them this produces invalid (None) ids — confirm
        # which tokenizers actually reach this branch.
        cls_id = tokenizer.cls_token_id
        sep_id = tokenizer.sep_token_id
        input_ids = [cls_id] + doc_ids + [sep_id] + sum_ids + [sep_id]
        feat = {
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids),
        }
        if return_token_type_ids:
            # Segment 0 covers [CLS] doc [SEP]; segment 1 covers summary [SEP].
            feat["token_type_ids"] = (
                [0] * (len(doc_ids) + 2) + [1] * (len(sum_ids) + 1)
            )
        return feat


@torch.no_grad()
def predict_scores(document: str, summary: str, model_dir: str = "."):
    """Score ``summary`` against ``document`` on three quality dimensions.

    Args:
        document: source text.
        summary: candidate summary.
        model_dir: directory holding the locally saved model + tokenizer.

    Returns:
        dict mapping ``faithfulness``, ``coherence`` and ``relevance`` to
        floats (model output head order, per the indexing below).
    """
    model, cfg = MultiEvalVietSumModel.from_pretrained_local(model_dir)
    # FIX: the original ran inference in training mode; disable dropout etc.
    model.eval()
    tokenizer = MultiEvalVietSumModel.load_tokenizer_local(model_dir)

    feat = build_pair_feature(
        tokenizer,
        document=document,
        summary=summary,
        max_len=cfg["max_len"],
        summary_max_len=cfg["summary_max_len"],
    )

    batch = {
        "input_ids": torch.tensor([feat["input_ids"]], dtype=torch.long),
        "attention_mask": torch.tensor([feat["attention_mask"]], dtype=torch.long),
    }
    if "token_type_ids" in feat:
        batch["token_type_ids"] = torch.tensor(
            [feat["token_type_ids"]], dtype=torch.long
        )

    # FIX: place inputs on the model's device — the original built CPU
    # tensors unconditionally and would crash for a model loaded onto GPU
    # (the existing .cpu() on the output suggests GPU use is anticipated).
    try:
        device = next(model.parameters()).device
        batch = {k: v.to(device) for k, v in batch.items()}
    except StopIteration:
        pass  # parameter-less model; leave tensors on CPU

    scores = model(**batch)[0].cpu().tolist()
    return {
        "faithfulness": float(scores[0]),
        "coherence": float(scores[1]),
        "relevance": float(scores[2]),
    }


if __name__ == "__main__":
    doc = "Văn bản gốc mẫu."
    summ = "Bản tóm tắt mẫu."
    print(predict_scores(doc, summ))