# Multi_EvalSumViet2 / modeling_summary_evaluator.py
# Uploaded by: phuongntc
# Initial release: backbone (videberta), trunk + 3 heads, configs, loader, README
# Commit: e1c2520 (verified)
import os, json, torch
from torch import nn
from transformers import AutoModel, AutoTokenizer
def mean_pool(last_hidden_state, attention_mask):
    """Average token vectors over dim 1, counting only positions where the mask is non-zero.

    The mask gets a trailing singleton axis so it broadcasts against the hidden states.
    The token count is clamped to at least 1e-9, so a row whose mask is all zeros
    yields zeros instead of NaN.
    """
    weights = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    denom = torch.clamp(weights.sum(dim=1), min=1e-9)
    return torch.sum(last_hidden_state * weights, dim=1) / denom
class SummaryEvaluatorModule(nn.Module):
    """Backbone encoder + shared MLP trunk + three single-output scoring heads.

    Everything is read from ``base_model_dir``:
      * the backbone, via ``AutoModel.from_pretrained``;
      * ``arch_config.json``, whose ``trunk`` section provides ``hidden_in``,
        ``hidden_mid`` and ``dropout``;
      * the state dicts ``trunk.pt``, ``head_faith.pt``, ``head_coh.pt`` and
        ``head_rel.pt``.

    The module switches itself to eval mode at construction, and ``forward`` runs
    under ``torch.no_grad``, so it is meant for inference only.
    """

    def __init__(self, base_model_dir):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(base_model_dir)
        with open(os.path.join(base_model_dir, "arch_config.json"), "r", encoding="utf-8") as f:
            cfg = json.load(f)
        trunk_cfg = cfg["trunk"]
        hidden_in = trunk_cfg["hidden_in"]
        hidden_mid = trunk_cfg["hidden_mid"]
        # Keep this exact Sequential layout: the keys stored in trunk.pt
        # ("0.weight", "0.bias") are derived from the layer positions.
        self.trunk = nn.Sequential(
            nn.Linear(hidden_in, hidden_mid),
            nn.GELU(),
            nn.Dropout(trunk_cfg["dropout"]),
        )
        self.head_faith = nn.Linear(hidden_mid, 1)
        self.head_coh = nn.Linear(hidden_mid, 1)
        self.head_rel = nn.Linear(hidden_mid, 1)
        # Restore the trained trunk/head weights stored next to the backbone.
        for module, filename in (
            (self.trunk, "trunk.pt"),
            (self.head_faith, "head_faith.pt"),
            (self.head_coh, "head_coh.pt"),
            (self.head_rel, "head_rel.pt"),
        ):
            module.load_state_dict(self._load_state_dict(base_model_dir, filename))
        self.eval()

    @staticmethod
    def _load_state_dict(base_model_dir, filename):
        """Load the state dict stored at ``base_model_dir/filename`` onto the CPU."""
        # weights_only=True limits unpickling to tensors and plain containers, so a
        # tampered checkpoint in a downloaded model directory cannot run arbitrary code.
        return torch.load(
            os.path.join(base_model_dir, filename),
            map_location="cpu",
            weights_only=True,
        )

    @torch.no_grad()
    def forward(self, input_ids, attention_mask):
        """Score a batch of encoded (document, summary) pairs.

        The backbone states are mean-pooled with ``attention_mask``, passed through
        the trunk, and scored by each head. The head outputs are concatenated
        column-wise in the order faith, coh, rel (presumably faithfulness, coherence
        and relevance, going by the names). Assuming ``last_hidden_state`` is
        (B, T, H), the result is (B, 3).

        NOTE(review): the heads are plain ``nn.Linear`` layers and no activation is
        applied, so outputs are NOT constrained to [0, 1] (an earlier comment said
        they were). Confirm the training target range and whether callers should
        clamp or apply a sigmoid.
        """
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        pooled = mean_pool(out.last_hidden_state, attention_mask)
        z = self.trunk(pooled)
        return torch.cat([self.head_faith(z), self.head_coh(z), self.head_rel(z)], dim=1)
def load_for_inference(base_model_dir, device=None):
    """Build the fast tokenizer and the evaluator from ``base_model_dir``, ready for inference.

    When ``device`` is None, "cuda" is chosen if available, otherwise "cpu".
    The model is moved to that device and put in eval mode.

    Returns:
        (model, tokenizer, device) -- ``device`` is the resolved target.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model_dir, use_fast=True)
    model = SummaryEvaluatorModule(base_model_dir)
    target = device
    if target is None:
        target = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(target)
    model.eval()
    return model, tokenizer, target
def encode_pair(tokenizer, docs, sums, max_len=512, sum_max_len=256):
    """Tokenize (document, summary) pairs so that truncation only ever shortens the document.

    Step 1 cuts each summary to at most ``sum_max_len`` tokens (without special
    tokens) and decodes it back to text. Step 2 encodes the pairs with
    ``truncation="only_first"`` and ``max_length=max_len``, so only the document
    is trimmed.

    Args:
        tokenizer: a Hugging Face tokenizer (e.g. the one from ``load_for_inference``).
        docs, sums: either one string each, or equal-length lists of strings.
        max_len: maximum length of the encoded pair, special tokens included.
        sum_max_len: token budget for the summary before pairing.

    Returns:
        The tokenizer's encoding, with input ids and attention mask. No padding or
        ``return_tensors`` is requested here, so batches must be padded and
        converted to tensors before calling the model.

    NOTE(review): the decode -> re-encode round trip can change a summary's token
    count slightly. Keep ``sum_max_len`` well below ``max_len`` so that
    "only_first" truncation can always fit the pair.
    """
    # 1) Pre-trim the summary to its token budget.
    tok_sum = tokenizer(
        sums, truncation=True, max_length=sum_max_len,
        add_special_tokens=False, return_attention_mask=False,
    )
    decode_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
    if isinstance(sums, str):
        # A single string yields one flat id list; batch_decode would split it into
        # per-token strings, so decode it as a whole instead.
        trimmed = tokenizer.decode(tok_sum["input_ids"], **decode_kwargs)
    else:
        trimmed = tokenizer.batch_decode(tok_sum["input_ids"], **decode_kwargs)
    # 2) Pair-encode: truncate the document, keep the summary intact.
    return tokenizer(
        docs, trimmed, truncation="only_first", max_length=max_len,
        add_special_tokens=True, return_attention_mask=True,
    )