# Multi_EvalSumVietN / modeling_summary_evaluator.py
# Uploaded by phuongntc -- commit 426ba72 (verified):
# "Add modeling_summary_evaluator.py (full-doc loader)"
import os, json
import numpy as np
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer
def mean_pool(last_hidden_state, attention_mask):
mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
summed = (last_hidden_state * mask).sum(dim=1)
counts = mask.sum(dim=1).clamp(min=1e-9)
return summed / counts
def build_all_segments(ids, seg_len, overlap):
    """Split *ids* into consecutive windows of ``seg_len`` items sharing ``overlap`` items.

    - An empty input yields a single empty window ``[[]]``.
    - An input no longer than ``seg_len`` is returned as the only window (same object).
    - Otherwise windows start every ``max(1, seg_len - overlap)`` items, and the
      first window that reaches the end of *ids* is the last one.
    """
    total = len(ids)
    if total == 0:
        return [[]]
    if total <= seg_len:
        return [ids]

    stride = max(1, seg_len - overlap)
    windows = []
    for offset in range(0, total, stride):
        end = offset + seg_len
        windows.append(ids[offset:end])
        if end >= total:
            break
    return windows
def sample_segments_cover_whole_doc(all_segs, cap):
    """Keep at most ``cap`` segments, spread evenly from the first to the last one.

    When there are already ``cap`` or fewer segments, the input list itself is
    returned. Otherwise ``cap`` evenly spaced positions over ``[0, len - 1]`` are
    rounded (NumPy rounding), de-duplicated and used to pick segments in order.
    """
    total = len(all_segs)
    if total <= cap:
        return all_segs

    positions = np.round(np.linspace(0, total - 1, cap)).astype(int)
    chosen = sorted(np.unique(positions).tolist())[:cap]
    return [all_segs[pos] for pos in chosen]
class MultiEvalSumVietN(nn.Module):
    """Score (document, summary) inputs on faithfulness, coherence and relevance.

    A Hugging Face ``AutoModel`` backbone encodes each segment. Segment vectors
    are mean-pooled over tokens, then averaged over the valid segments. A shared
    trunk (Linear -> GELU -> Dropout) feeds three 1-unit linear heads.

    Everything is loaded from ``base_dir``:

    * ``arch_config.json``: reads ``trunk.hidden_in``, ``trunk.hidden_mid``,
      ``trunk.dropout`` and the optional ``agg_type`` (default ``"mean"``).
    * The backbone, via ``AutoModel.from_pretrained(base_dir)``.
    * ``trunk.pt``, ``head_faith.pt``, ``head_coh.pt``, ``head_rel.pt``: state
      dicts for the trunk and the three heads.

    The module is switched to eval mode at the end of construction.

    Raises:
        FileNotFoundError: if ``agg_type == "attn"`` (the attention aggregator
            weights are not part of this export). Missing files also surface as
            FileNotFoundError/OSError from the loaders.
    """

    def __init__(self, base_dir):
        super().__init__()
        with open(os.path.join(base_dir, "arch_config.json"), "r", encoding="utf-8") as f:
            cfg = json.load(f)
        self.cfg = cfg

        # Check the aggregation mode before the (large) backbone is loaded so an
        # unsupported export fails fast. The exception type is kept unchanged
        # for existing callers.
        # NOTE(review): any agg_type other than "attn" silently falls back to the
        # mean aggregation implemented in forward().
        self.agg_type = cfg.get("agg_type", "mean")
        if self.agg_type == "attn":
            raise FileNotFoundError("agg_type='attn' requires seg_attn.pt export+upload. Set agg_type='mean' for now.")

        self.backbone = AutoModel.from_pretrained(base_dir)
        # Only last_hidden_state is used, so disable the cache when the config has one.
        if hasattr(self.backbone.config, "use_cache"):
            self.backbone.config.use_cache = False

        trunk_cfg = cfg["trunk"]
        hidden_in = trunk_cfg["hidden_in"]
        hidden_mid = trunk_cfg["hidden_mid"]
        dropout = trunk_cfg["dropout"]
        self.trunk = nn.Sequential(nn.Linear(hidden_in, hidden_mid), nn.GELU(), nn.Dropout(dropout))
        self.head_faith = nn.Linear(hidden_mid, 1)
        self.head_coh = nn.Linear(hidden_mid, 1)
        self.head_rel = nn.Linear(hidden_mid, 1)

        # SECURITY NOTE(review): torch.load unpickles these files. If base_dir can
        # contain untrusted content, pass weights_only=True (the default since
        # torch 2.6, available since 1.13); plain state dicts load fine that way.
        for name, module in (
            ("trunk", self.trunk),
            ("head_faith", self.head_faith),
            ("head_coh", self.head_coh),
            ("head_rel", self.head_rel),
        ):
            state = torch.load(os.path.join(base_dir, f"{name}.pt"), map_location="cpu")
            module.load_state_dict(state)

        self.eval()

    @torch.no_grad()
    def forward(self, input_ids_3d, attention_mask_3d, seg_mask_2d):
        """Return scores of shape ``(B, 3)`` ordered ``[faith, coh, rel]``.

        Args:
            input_ids_3d: ``(B, K, T)`` token ids: K segments of T tokens per item.
            attention_mask_3d: token mask with the same ``(B, K, T)`` layout.
            seg_mask_2d: ``(B, K)`` mask; non-zero marks a real segment, 0 a
                padding slot that is excluded from the document average.
        """
        B, K, T = input_ids_3d.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        flat_ids = input_ids_3d.reshape(B * K, T)
        flat_attn = attention_mask_3d.reshape(B * K, T)
        hidden = self.backbone(input_ids=flat_ids, attention_mask=flat_attn).last_hidden_state
        seg_vecs = mean_pool(hidden, flat_attn).view(B, K, -1)

        # Cast the segment mask to the backbone's dtype. A float32 mask would
        # upcast fp16/bf16 activations and break the trunk's Linear layers with a
        # dtype mismatch.
        seg_weights = seg_mask_2d.unsqueeze(-1).to(seg_vecs.dtype)
        denom = seg_weights.sum(dim=1).clamp_min(1e-6)
        doc_repr = (seg_vecs * seg_weights).sum(dim=1) / denom  # mean aggregation

        z = self.trunk(doc_repr)
        return torch.cat([self.head_faith(z), self.head_coh(z), self.head_rel(z)], dim=1)
def load_for_inference(base_dir, device=None):
    """Load the tokenizer and evaluator from *base_dir* and place the model on a device.

    Args:
        base_dir: directory passed to ``AutoTokenizer.from_pretrained`` (fast
            tokenizer) and to ``MultiEvalSumVietN``.
        device: target device. When ``None``, ``"cuda"`` is chosen if CUDA is
            available, otherwise ``"cpu"``.

    Returns:
        ``(model, tokenizer, device)``, with the model moved to *device* and in
        eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_dir, use_fast=True)
    model = MultiEvalSumVietN(base_dir)

    if device is not None:
        target = device
    elif torch.cuda.is_available():
        target = "cuda"
    else:
        target = "cpu"

    model.to(target)
    model.eval()
    return model, tokenizer, target
def _read_arch_config(base_dir):
    """Load and return the ``arch_config.json`` dict stored in *base_dir*."""
    with open(os.path.join(base_dir, "arch_config.json"), "r", encoding="utf-8") as f:
        return json.load(f)


def _pad_rows(rows, length, fill):
    """Right-pad each 1-D long tensor in *rows* to *length* with *fill*; stack to ``(N, length)``."""
    padded = []
    for row in rows:
        missing = length - row.numel()
        if missing > 0:
            row = torch.cat([row, torch.full((missing,), fill, dtype=torch.long)])
        padded.append(row)
    return torch.stack(padded, dim=0)


def encode_full_doc(tok, docs, sums):
    """Tokenize (document, summary) pairs into padded segment tensors for the model.

    Settings come from ``arch_config.json`` in ``tok.name_or_path``, so *tok* must
    have been loaded from a local directory containing that file. Keys read:
    ``max_len``, ``sum_max_len``, ``seg_len``, ``seg_overlap``, ``max_segs_cap``.

    Each summary is truncated to ``sum_max_len`` tokens. Each document is
    tokenized in full, split into overlapping windows of ``seg_len`` tokens
    (``seg_overlap`` shared), and at most ``max_segs_cap`` evenly spread windows
    are kept. Every window is paired with its summary via
    ``tok.prepare_for_model`` (special tokens added, only the document side
    truncated to ``max_len``).

    Args:
        tok: Hugging Face tokenizer. Its ``pad_token`` is set (to EOS or
            ``"[PAD]"``) when missing, which is a side effect on *tok*.
        docs: list of document strings (a single string is also accepted).
        sums: list of summary strings, aligned with *docs* (or a single string).

    Returns:
        ``(ids, attn, segm)``: ``ids`` and ``attn`` are ``(B, K, T)`` long tensors
        and ``segm`` is a ``(B, K)`` long tensor. K is the largest segment count
        in the batch; T is the longest pair rounded up to a multiple of 8.
        ``segm`` is 1 for real segments and 0 for placeholder slots.

    Raises:
        ValueError: if *docs* and *sums* differ in length or are empty.
    """
    if isinstance(docs, str):
        docs = [docs]
    if isinstance(sums, str):
        sums = [sums]
    if len(docs) != len(sums):
        raise ValueError(f"docs and sums must have the same length, got {len(docs)} and {len(sums)}")
    if not docs:
        raise ValueError("encode_full_doc needs at least one (document, summary) pair")

    cfg = _read_arch_config(tok.name_or_path)
    max_len = int(cfg["max_len"])
    sum_max_len = int(cfg["sum_max_len"])
    seg_len = int(cfg["seg_len"])
    seg_overlap = int(cfg["seg_overlap"])
    cap = int(cfg["max_segs_cap"])

    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token if tok.eos_token is not None else "[PAD]"
    pad_id = tok.pad_token_id
    if pad_id is None:
        # "[PAD]" may not resolve to an id. Padded positions are masked out below,
        # so any valid id works as filler.
        pad_id = 0

    sum_ids_list = tok(sums, truncation=True, max_length=sum_max_len,
                       add_special_tokens=False, return_attention_mask=False)["input_ids"]
    doc_ids_list = tok(docs, truncation=False, add_special_tokens=False,
                       return_attention_mask=False)["input_ids"]

    segs_per_sample = [
        sample_segments_cover_whole_doc(build_all_segments(ids, seg_len, seg_overlap), cap)
        for ids in doc_ids_list
    ]
    B = len(docs)
    K = max(len(s) for s in segs_per_sample)

    flat_ids, flat_attn, flat_segmask = [], [], []
    for segs, sum_ids in zip(segs_per_sample, sum_ids_list):
        for seg in segs:
            pair = tok.prepare_for_model(seg, sum_ids, truncation="only_first",
                                         max_length=max_len, add_special_tokens=True,
                                         return_attention_mask=True)
            flat_ids.append(torch.tensor(pair["input_ids"], dtype=torch.long))
            flat_attn.append(torch.tensor(pair["attention_mask"], dtype=torch.long))
            flat_segmask.append(1)
        # Placeholder slots bring every item up to K segments. They keep one
        # attended token (presumably so no attention row is fully masked) and are
        # dropped from the document average by segm == 0.
        for _ in range(len(segs), K):
            flat_ids.append(torch.tensor([pad_id], dtype=torch.long))
            flat_attn.append(torch.tensor([1], dtype=torch.long))
            flat_segmask.append(0)

    T = max(x.numel() for x in flat_ids)
    T = ((T + 7) // 8) * 8
    ids = _pad_rows(flat_ids, T, pad_id).view(B, K, T)
    # Padding must be masked out (0) whatever pad_id is. Padding the mask with
    # pad_id made tokenizers with pad_id == 1 attend to and mean-pool padding.
    attn = _pad_rows(flat_attn, T, 0).view(B, K, T)
    segm = torch.tensor(flat_segmask, dtype=torch.long).view(B, K)
    return ids, attn, segm