import json
import os
import numpy as np
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer

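def mean_pool(last_hidden_state, attention_mask):
    """Average token embeddings over attended positions: (N, T, H) -> (N, H)."""
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    summed = (last_hidden_state * mask).sum(dim=1)
    # Clamp so a row with no attended tokens does not divide by zero.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts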

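def build_all_segments(ids, seg_len, overlap):
    """Split token ids into windows of up to seg_len tokens that share `overlap` tokens."""
    n = len(ids)
    if n == 0:
        return [[]]
    if n <= seg_len:
        return [ids]
    # A step of at least 1 guarantees progress even if overlap >= seg_len.
    step = max(1, seg_len - overlap)
    segs, start = [], 0
    while start < n:
        segs.append(ids[start:start + seg_len])
        # Stop once the current window reaches the end of the document.
        if start + seg_len >= n:
            break
        start += step
    return segs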

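def sample_segments_cover_whole_doc(all_segs, cap):
    """Keep at most `cap` segments, evenly spaced from the first to the last."""
    if len(all_segs) <= cap:
        return all_segs
    idx = np.linspace(0, len(all_segs) - 1, cap)
    # np.unique returns sorted, de-duplicated indices, so at most `cap` remain.
    idx = np.unique(np.round(idx).astype(int)).tolist()
    return [all_segs[i] for i in idx]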

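class MultiEvalSumVietN(nn.Module):
    """Backbone encoder, shared trunk, and three single-output linear heads.

    Weights and arch_config.json are read from base_dir. Only mean
    aggregation over segments is supported.
    """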
    def __init__(self, base_dir):
        super().__init__()
        with open(os.path.join(base_dir, "arch_config.json"), "r", encoding="utf-8") as f:
            cfg = json.load(f)
        self.cfg = cfg
        self.backbone = AutoModel.from_pretrained(base_dir)
        if hasattr(self.backbone.config, "use_cache"):
            self.backbone.config.use_cache = False

        hidden_in = cfg["trunk"]["hidden_in"]
        hidden_mid = cfg["trunk"]["hidden_mid"]
        dropout = cfg["trunk"]["dropout"]

        self.trunk = nn.Sequential(nn.Linear(hidden_in, hidden_mid), nn.GELU(), nn.Dropout(dropout))
        self.head_faith = nn.Linear(hidden_mid, 1)
        self.head_coh   = nn.Linear(hidden_mid, 1)
        self.head_rel   = nn.Linear(hidden_mid, 1)

        self.trunk.load_state_dict(torch.load(os.path.join(base_dir, "trunk.pt"), map_location="cpu"))
        self.head_faith.load_state_dict(torch.load(os.path.join(base_dir, "head_faith.pt"), map_location="cpu"))
        self.head_coh.load_state_dict(torch.load(os.path.join(base_dir, "head_coh.pt"), map_location="cpu"))
        self.head_rel.load_state_dict(torch.load(os.path.join(base_dir, "head_rel.pt"), map_location="cpu"))

        self.agg_type = cfg.get("agg_type", "mean")
        if self.agg_type == "attn":
            raise FileNotFoundError("agg_type='attn' requires seg_attn.pt export+upload. Set agg_type='mean' for now.")

        self.eval()

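    @torch.no_grad()
    def forward(self, input_ids_3d, attention_mask_3d, seg_mask_2d):
        """Score a batch. Inputs are (B, K, T), (B, K, T), and (B, K); returns (B, 3)."""
        B, K, T = input_ids_3d.shape
        # Fold segments into the batch dimension so the backbone sees (B*K, T).
        x = input_ids_3d.reshape(B * K, T)
        a = attention_mask_3d.reshape(B * K, T)

        out = self.backbone(input_ids=x, attention_mask=a).last_hidden_state
        pooled = mean_pool(out, a).view(B, K, -1)  # one vector per segment

        # Zero out placeholder segments, then average over the real ones.
        mask = seg_mask_2d.unsqueeze(-1).float()
        pooled = pooled * mask

        denom = mask.sum(dim=1).clamp_min(1e-6)
        doc_repr = pooled.sum(dim=1) / denom  # mean aggregation

        z = self.trunk(doc_repr)
        # Column order of the output: head_faith, head_coh, head_rel.
        y = torch.cat([self.head_faith(z), self.head_coh(z), self.head_rel(z)], dim=1)
        return y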

def load_for_inference(base_dir, device=None):
    """Load tokenizer and model from base_dir; use CUDA when available unless device is given."""
    tok = AutoTokenizer.from_pretrained(base_dir, use_fast=True)
    mdl = MultiEvalSumVietN(base_dir)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    mdl.to(device).eval()
    return mdl, tok, device

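def encode_full_doc(tok, docs, sums):
    """Tokenize (document, summary) pairs into (B, K, T) ids and attention masks plus a (B, K) segment mask."""
    if len(docs) != len(sums):
        raise ValueError("docs and sums must have the same length")
    if len(docs) == 0:
        raise ValueError("docs must not be empty")
    # Assumes tok was loaded from a local directory that also holds arch_config.json.
    with open(os.path.join(tok.name_or_path, "arch_config.json"), "r", encoding="utf-8") as f:
        cfg = json.load(f)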

    max_len = int(cfg["max_len"])
    sum_max_len = int(cfg["sum_max_len"])
    seg_len = int(cfg["seg_len"])
    seg_overlap = int(cfg["seg_overlap"])
    cap = int(cfg["max_segs_cap"])

    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token if tok.eos_token is not None else "[PAD]"
    pad_id = tok.pad_token_id

    tok_s = tok(sums, truncation=True, max_length=sum_max_len, add_special_tokens=False, return_attention_mask=False)
    sum_ids_list = tok_s["input_ids"]
    tok_d = tok(docs, truncation=False, add_special_tokens=False, return_attention_mask=False)
    doc_ids_list = tok_d["input_ids"]

    segs_per_sample = []
    for ids in doc_ids_list:
        all_segs = build_all_segments(ids, seg_len, seg_overlap)
        segs = sample_segments_cover_whole_doc(all_segs, cap)
        segs_per_sample.append(segs)

    B = len(docs)
    K = max(len(s) for s in segs_per_sample)

    flat_ids, flat_attn, flat_segmask = [], [], []
    for i in range(B):
        segs = segs_per_sample[i]
        for seg in segs:
            pair = tok.prepare_for_model(seg, sum_ids_list[i], truncation="only_first",
                                         max_length=max_len, add_special_tokens=True,
                                         return_attention_mask=True)
            flat_ids.append(torch.tensor(pair["input_ids"], dtype=torch.long))
            flat_attn.append(torch.tensor(pair["attention_mask"], dtype=torch.long))
            flat_segmask.append(1)
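        # Pad this sample to K segments with one-token placeholders. The single
        # attended position keeps the row from being fully masked in the backbone;
        # the segment mask of 0 excludes it from the document average.
        for _ in range(len(segs), K):
            flat_ids.append(torch.tensor([pad_id], dtype=torch.long))
            flat_attn.append(torch.tensor([1], dtype=torch.long))
            flat_segmask.append(0)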

    T = max(x.numel() for x in flat_ids)
    T = ((T + 7)//8)*8

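    def pad_2d(xs, fill):
        # Right-pad every 1-D tensor to length T with `fill`, then stack to (B*K, T).
        out = []
        for x in xs:
            if x.numel() < T:
                out.append(torch.cat([x, torch.full((T - x.numel(),), fill, dtype=torch.long)]))
            else:
                out.append(x)
        return torch.stack(out, dim=0)

    # Token ids are padded with pad_id. Attention masks must be padded with 0;
    # padding them with pad_id marks padding as attended whenever pad_id != 0.
    ids = pad_2d(flat_ids, pad_id).view(B, K, T)
    attn = pad_2d(flat_attn, 0).view(B, K, T)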
    segm = torch.tensor(flat_segmask, dtype=torch.long).view(B, K)
    return ids, attn, segm