import os, json
import numpy as np
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer
def mean_pool(last_hidden_state, attention_mask):
    """Average token embeddings over valid (non-padding) positions.

    Args:
        last_hidden_state: (batch, seq, dim) token embeddings.
        attention_mask: (batch, seq) mask, 1 for real tokens, 0 for padding.

    Returns:
        (batch, dim) mean of the embeddings at masked-in positions; the
        denominator is clamped so an all-zero mask cannot divide by zero.
    """
    weights = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    total = torch.sum(last_hidden_state * weights, dim=1)
    valid = torch.sum(weights, dim=1).clamp(min=1e-9)
    return total / valid
def build_all_segments(ids, seg_len, overlap):
    """Split a token-id list into overlapping windows of up to seg_len tokens.

    Consecutive windows share `overlap` tokens (stride is clamped to >= 1,
    so overlap >= seg_len cannot stall the loop). An empty input yields a
    single empty segment; input shorter than seg_len is returned whole.
    """
    total = len(ids)
    if not ids:
        return [[]]
    if total <= seg_len:
        return [ids]
    stride = max(1, seg_len - overlap)
    windows = []
    pos = 0
    while True:
        windows.append(ids[pos:pos + seg_len])
        # Stop once this window reaches the end of the document.
        if pos + seg_len >= total:
            return windows
        pos += stride
def sample_segments_cover_whole_doc(all_segs, cap):
    """Pick at most `cap` segments spread evenly across the whole document.

    Indices are taken on an even grid from the first to the last segment
    (both endpoints always included), so coverage spans the full document
    rather than truncating the tail. Returns all_segs unchanged when it
    already fits within the cap.
    """
    if len(all_segs) <= cap:
        return all_segs
    grid = np.round(np.linspace(0, len(all_segs) - 1, cap)).astype(int)
    # Rounding can collide on short docs; dedupe, keep order, re-apply cap.
    chosen = sorted(set(grid.tolist()))[:cap]
    return [all_segs[j] for j in chosen]
class MultiEvalSumVietN(nn.Module):
    """Three-head summary-evaluation model (faithfulness / coherence / relevance).

    A pretrained transformer backbone encodes each document segment; segment
    vectors are mean-pooled into one document representation, which a shared
    MLP trunk feeds into three independent linear scoring heads. Trunk and
    head weights are loaded from ``trunk.pt`` / ``head_*.pt`` exported next
    to the backbone checkpoint in ``base_dir``.
    """

    def __init__(self, base_dir):
        # base_dir: directory holding the HF backbone checkpoint plus
        # arch_config.json and the exported trunk/head state dicts.
        super().__init__()
        with open(os.path.join(base_dir, "arch_config.json"), "r", encoding="utf-8") as f:
            cfg = json.load(f)
        self.cfg = cfg
        self.backbone = AutoModel.from_pretrained(base_dir)
        # Inference-only model: turn off the generation KV cache when the
        # backbone config exposes it.
        if hasattr(self.backbone.config, "use_cache"):
            self.backbone.config.use_cache = False
        hidden_in = cfg["trunk"]["hidden_in"]
        hidden_mid = cfg["trunk"]["hidden_mid"]
        dropout = cfg["trunk"]["dropout"]
        self.trunk = nn.Sequential(nn.Linear(hidden_in, hidden_mid), nn.GELU(), nn.Dropout(dropout))
        # One scalar regression head per evaluation dimension.
        self.head_faith = nn.Linear(hidden_mid, 1)
        self.head_coh = nn.Linear(hidden_mid, 1)
        self.head_rel = nn.Linear(hidden_mid, 1)
        self.trunk.load_state_dict(torch.load(os.path.join(base_dir, "trunk.pt"), map_location="cpu"))
        self.head_faith.load_state_dict(torch.load(os.path.join(base_dir, "head_faith.pt"), map_location="cpu"))
        self.head_coh.load_state_dict(torch.load(os.path.join(base_dir, "head_coh.pt"), map_location="cpu"))
        self.head_rel.load_state_dict(torch.load(os.path.join(base_dir, "head_rel.pt"), map_location="cpu"))
        self.agg_type = cfg.get("agg_type", "mean")
        # Attention-based segment aggregation would need extra learned weights
        # (seg_attn.pt) that this export does not include; only "mean" works.
        if self.agg_type == "attn":
            raise FileNotFoundError("agg_type='attn' requires seg_attn.pt export+upload. Set agg_type='mean' for now.")
        self.eval()

    @torch.no_grad()
    def forward(self, input_ids_3d, attention_mask_3d, seg_mask_2d):
        """Score a batch of (document, summary) pairs.

        Args:
            input_ids_3d: (B, K, T) token ids — K segments of T tokens each.
            attention_mask_3d: (B, K, T) token-level attention mask.
            seg_mask_2d: (B, K) mask, 1 for real segments, 0 for pad segments.

        Returns:
            (B, 3) tensor of [faithfulness, coherence, relevance] scores.
        """
        B, K, T = input_ids_3d.shape
        # Fold the segment axis into the batch axis for the backbone pass.
        x = input_ids_3d.view(B*K, T)
        a = attention_mask_3d.view(B*K, T)
        out = self.backbone(input_ids=x, attention_mask=a).last_hidden_state
        pooled = mean_pool(out, a).view(B, K, -1)
        # Zero out padding segments, then average over the real ones only.
        mask = seg_mask_2d.unsqueeze(-1).float()
        pooled = pooled * mask
        denom = mask.sum(dim=1).clamp_min(1e-6)
        doc_repr = pooled.sum(dim=1) / denom  # mean aggregation
        z = self.trunk(doc_repr)
        y = torch.cat([self.head_faith(z), self.head_coh(z), self.head_rel(z)], dim=1)
        return y
def load_for_inference(base_dir, device=None):
    """Load the tokenizer and scoring model from `base_dir` for inference.

    When `device` is None, CUDA is used if available, else CPU. Returns the
    tuple (model, tokenizer, device) with the model in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_dir, use_fast=True)
    model = MultiEvalSumVietN(base_dir)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()
    return model, tokenizer, device
def encode_full_doc(tok, docs, sums):
    """Tokenize (document, summary) pairs into fixed-shape batched tensors.

    Each document is segmented into overlapping token windows (config read
    from arch_config.json next to the tokenizer files), capped to
    max_segs_cap evenly-spread segments, and every segment is paired with
    the truncated summary via `prepare_for_model`.

    Args:
        tok: a HF tokenizer whose `name_or_path` directory also holds
            arch_config.json.
        docs: list of document strings.
        sums: list of summary strings (parallel to `docs`).

    Returns:
        (ids, attn, segm) — (B, K, T) input ids, (B, K, T) attention mask,
        (B, K) segment mask where padding segments are 0.
    """
    with open(os.path.join(tok.name_or_path, "arch_config.json"), "r", encoding="utf-8") as f:
        cfg = json.load(f)
    max_len = int(cfg["max_len"])
    sum_max_len = int(cfg["sum_max_len"])
    seg_len = int(cfg["seg_len"])
    seg_overlap = int(cfg["seg_overlap"])
    cap = int(cfg["max_segs_cap"])

    # Make sure a pad token exists before we rely on pad_token_id below.
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token if tok.eos_token is not None else "[PAD]"
    pad_id = tok.pad_token_id

    tok_s = tok(sums, truncation=True, max_length=sum_max_len,
                add_special_tokens=False, return_attention_mask=False)
    sum_ids_list = tok_s["input_ids"]
    tok_d = tok(docs, truncation=False, add_special_tokens=False, return_attention_mask=False)
    doc_ids_list = tok_d["input_ids"]

    segs_per_sample = []
    for ids in doc_ids_list:
        all_segs = build_all_segments(ids, seg_len, seg_overlap)
        segs_per_sample.append(sample_segments_cover_whole_doc(all_segs, cap))

    B = len(docs)
    K = max(len(s) for s in segs_per_sample)
    flat_ids, flat_attn, flat_segmask = [], [], []
    for i in range(B):
        segs = segs_per_sample[i]
        for seg in segs:
            pair = tok.prepare_for_model(seg, sum_ids_list[i], truncation="only_first",
                                         max_length=max_len, add_special_tokens=True,
                                         return_attention_mask=True)
            flat_ids.append(torch.tensor(pair["input_ids"], dtype=torch.long))
            flat_attn.append(torch.tensor(pair["attention_mask"], dtype=torch.long))
            flat_segmask.append(1)
        # Pad every sample out to K segments; these rows are neutralized
        # downstream by the segment mask.
        for _ in range(len(segs), K):
            flat_ids.append(torch.tensor([pad_id], dtype=torch.long))
            flat_attn.append(torch.tensor([1], dtype=torch.long))
            flat_segmask.append(0)

    T = max(x.numel() for x in flat_ids)
    T = ((T + 7) // 8) * 8  # round up to a multiple of 8 (hardware friendly)

    def _pad_2d(rows, fill):
        # Right-pad each 1-D tensor to length T with `fill`, then stack.
        padded = []
        for r in rows:
            if r.numel() < T:
                r = torch.cat([r, torch.full((T - r.numel(),), fill, dtype=torch.long)])
            padded.append(r)
        return torch.stack(padded, dim=0)

    # BUG FIX: attention masks must be padded with 0 (ignore), not pad_id.
    # The previous code padded both tensors with pad_id; tokenizers such as
    # PhoBERT/XLM-R use pad_token_id == 1, which made the model attend to
    # padding positions.
    ids = _pad_2d(flat_ids, pad_id).view(B, K, T)
    attn = _pad_2d(flat_attn, 0).view(B, K, T)
    segm = torch.tensor(flat_segmask, dtype=torch.long).view(B, K)
    return ids, attn, segm