import os, json
import numpy as np
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer

def mean_pool(last_hidden_state, attention_mask):
    # Mean-pool token embeddings, counting only positions where the attention mask is 1.
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts

def build_all_segments(ids, seg_len, overlap):
    # Split a token-id list into overlapping windows of length seg_len,
    # stepping by (seg_len - overlap) so consecutive windows share `overlap` tokens.
    n = len(ids)
    if n == 0:
        return [[]]
    if n <= seg_len:
        return [ids]
    step = max(1, seg_len - overlap)
    segs, start = [], 0
    while start < n:
        segs.append(ids[start:start + seg_len])
        if start + seg_len >= n:
            break
        start += step
    return segs

def sample_segments_cover_whole_doc(all_segs, cap):
    # If there are more windows than the cap allows, keep `cap` windows at
    # evenly spaced positions so the kept windows still span the whole document.
    if len(all_segs) <= cap:
        return all_segs
    idx = np.linspace(0, len(all_segs) - 1, cap)
    idx = np.unique(np.round(idx).astype(int)).tolist()
    idx = sorted(idx)[:cap]
    return [all_segs[i] for i in idx]

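# Quick illustration (toy numbers, not part of the module): with seg_len=4 and
# overlap=2, a 10-token document yields windows [0:4], [2:6], [4:8], [6:10];
# capping at 2 keeps the first and last window, so both ends of the document
# remain covered.
#
#   >>> segs = build_all_segments(list(range(10)), seg_len=4, overlap=2)
#   >>> sample_segments_cover_whole_doc(segs, cap=2)
#   [[0, 1, 2, 3], [6, 7, 8, 9]]
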
class MultiEvalSumVietN(nn.Module):
    def __init__(self, base_dir):
        super().__init__()
        with open(os.path.join(base_dir, "arch_config.json"), "r", encoding="utf-8") as f:
            cfg = json.load(f)
        self.cfg = cfg
        self.backbone = AutoModel.from_pretrained(base_dir)
        if hasattr(self.backbone.config, "use_cache"):
            self.backbone.config.use_cache = False

        hidden_in = cfg["trunk"]["hidden_in"]
        hidden_mid = cfg["trunk"]["hidden_mid"]
        dropout = cfg["trunk"]["dropout"]

        # Shared trunk feeding three scalar regression heads.
        self.trunk = nn.Sequential(nn.Linear(hidden_in, hidden_mid), nn.GELU(), nn.Dropout(dropout))
        self.head_faith = nn.Linear(hidden_mid, 1)
        self.head_coh = nn.Linear(hidden_mid, 1)
        self.head_rel = nn.Linear(hidden_mid, 1)

        # Load the trained trunk/head weights exported alongside the backbone.
        self.trunk.load_state_dict(torch.load(os.path.join(base_dir, "trunk.pt"), map_location="cpu"))
        self.head_faith.load_state_dict(torch.load(os.path.join(base_dir, "head_faith.pt"), map_location="cpu"))
        self.head_coh.load_state_dict(torch.load(os.path.join(base_dir, "head_coh.pt"), map_location="cpu"))
        self.head_rel.load_state_dict(torch.load(os.path.join(base_dir, "head_rel.pt"), map_location="cpu"))

        self.agg_type = cfg.get("agg_type", "mean")
        if self.agg_type == "attn":
            raise FileNotFoundError("agg_type='attn' requires seg_attn.pt export+upload. Set agg_type='mean' for now.")

        self.eval()

    @torch.no_grad()
    def forward(self, input_ids_3d, attention_mask_3d, seg_mask_2d):
        # input_ids_3d / attention_mask_3d: [B, K, T] (batch, segments, tokens);
        # seg_mask_2d: [B, K], 1 for real segments and 0 for padding segments.
        B, K, T = input_ids_3d.shape
        x = input_ids_3d.view(B * K, T)
        a = attention_mask_3d.view(B * K, T)

        out = self.backbone(input_ids=x, attention_mask=a).last_hidden_state
        pooled = mean_pool(out, a).view(B, K, -1)

        # Average the segment representations, ignoring padding segments.
        mask = seg_mask_2d.unsqueeze(-1).float()
        pooled = pooled * mask
        denom = mask.sum(dim=1).clamp_min(1e-6)
        doc_repr = pooled.sum(dim=1) / denom

        z = self.trunk(doc_repr)
        y = torch.cat([self.head_faith(z), self.head_coh(z), self.head_rel(z)], dim=1)
        return y

def load_for_inference(base_dir, device=None):
    tok = AutoTokenizer.from_pretrained(base_dir, use_fast=True)
    mdl = MultiEvalSumVietN(base_dir)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    mdl.to(device).eval()
    return mdl, tok, device

def encode_full_doc(tok, docs, sums):
    # The tokenizer is expected to have been loaded from a local directory that
    # also contains arch_config.json (as in load_for_inference above).
    with open(os.path.join(tok.name_or_path, "arch_config.json"), "r", encoding="utf-8") as f:
        cfg = json.load(f)

    max_len = int(cfg["max_len"])
    sum_max_len = int(cfg["sum_max_len"])
    seg_len = int(cfg["seg_len"])
    seg_overlap = int(cfg["seg_overlap"])
    cap = int(cfg["max_segs_cap"])

    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token if tok.eos_token is not None else "[PAD]"
    pad_id = tok.pad_token_id

    # Tokenize summaries (truncated) and full documents (untruncated) without special tokens;
    # special tokens are added later by prepare_for_model for each (segment, summary) pair.
    tok_s = tok(sums, truncation=True, max_length=sum_max_len, add_special_tokens=False, return_attention_mask=False)
    sum_ids_list = tok_s["input_ids"]
    tok_d = tok(docs, truncation=False, add_special_tokens=False, return_attention_mask=False)
    doc_ids_list = tok_d["input_ids"]

    # Split each document into overlapping segments and cap their number.
    segs_per_sample = []
    for ids in doc_ids_list:
        all_segs = build_all_segments(ids, seg_len, seg_overlap)
        segs = sample_segments_cover_whole_doc(all_segs, cap)
        segs_per_sample.append(segs)

    B = len(docs)
    K = max(len(s) for s in segs_per_sample)

    # Build one (segment, summary) pair per segment; pad every sample to K segments
    # with dummy single-token entries that are masked out via the segment mask.
    flat_ids, flat_attn, flat_segmask = [], [], []
    for i in range(B):
        segs = segs_per_sample[i]
        for seg in segs:
            pair = tok.prepare_for_model(seg, sum_ids_list[i], truncation="only_first",
                                         max_length=max_len, add_special_tokens=True,
                                         return_attention_mask=True)
            flat_ids.append(torch.tensor(pair["input_ids"], dtype=torch.long))
            flat_attn.append(torch.tensor(pair["attention_mask"], dtype=torch.long))
            flat_segmask.append(1)
        for _ in range(len(segs), K):
            flat_ids.append(torch.tensor([pad_id], dtype=torch.long))
            flat_attn.append(torch.tensor([1], dtype=torch.long))
            flat_segmask.append(0)

    # Pad all sequences to a common length, rounded up to a multiple of 8.
    T = max(x.numel() for x in flat_ids)
    T = ((T + 7) // 8) * 8

    def pad_2d(xs, pad_value):
        out = []
        for x in xs:
            if x.numel() < T:
                out.append(torch.cat([x, torch.full((T - x.numel(),), pad_value, dtype=torch.long)]))
            else:
                out.append(x)
        return torch.stack(out, dim=0)

    # Pad token ids with pad_id and attention masks with 0 so padding is not attended to.
    ids = pad_2d(flat_ids, pad_id).view(B, K, T)
    attn = pad_2d(flat_attn, 0).view(B, K, T)
    segm = torch.tensor(flat_segmask, dtype=torch.long).view(B, K)
    return ids, attn, segm
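
# Minimal end-to-end sketch (illustrative only; "./checkpoint" is a placeholder for
# wherever the exported files, i.e. arch_config.json, trunk.pt, head_*.pt and the
# backbone weights, actually live):
if __name__ == "__main__":
    mdl, tok, device = load_for_inference("./checkpoint")
    docs = ["<full Vietnamese source document>"]
    sums = ["<candidate summary to score>"]
    ids, attn, segm = encode_full_doc(tok, docs, sums)
    scores = mdl(ids.to(device), attn.to(device), segm.to(device))
    # One row per (document, summary) pair; columns follow the head order:
    # head_faith, head_coh, head_rel.
    print(scores.cpu().numpy())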