File size: 2,188 Bytes
2bcedff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import json
from pathlib import Path

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer


def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over the positions selected by ``attention_mask``.

    Args:
        last_hidden_state: Token embeddings, expected as ``(batch, seq, hidden)``;
            pooling is performed over dim 1.
        attention_mask: Mask that, after ``unsqueeze(-1)``, broadcasts against
            ``last_hidden_state``. Non-zero entries mark tokens to include.
            Integer, bool or float masks are accepted.

    Returns:
        The masked mean over dim 1, in the same dtype as ``last_hidden_state``.
        Rows whose mask is entirely zero yield zeros rather than NaN.
    """
    # Accumulate in at least float32. Half-precision inputs are summed at
    # full precision, and float64 stays float64 (same dtype the original math used).
    acc_dtype = torch.promote_types(last_hidden_state.dtype, torch.float32)
    hidden = last_hidden_state.to(acc_dtype)
    mask = attention_mask.unsqueeze(-1).to(acc_dtype)

    summed = (hidden * mask).sum(dim=1)
    # Clamp so fully padded rows divide 0 by a tiny number instead of by zero.
    counts = mask.sum(dim=1).clamp(min=1e-6)

    # Cast back so the pooled vector can be concatenated with other features of
    # the input dtype (e.g. an fp16 first-token vector) without a dtype mismatch.
    return (summed / counts).to(last_hidden_state.dtype)


class MultiEvalVietSumModel(nn.Module):
    """Transformer backbone with a shared MLP trunk and three scalar heads.

    The backbone's token states are pooled two ways: the first-token vector and
    the masked mean from ``mean_pool``. The two are concatenated into a
    ``2 * hidden_size`` feature and projected by a shared ``trunk``. Three
    independent linear heads (``head_faith``, ``head_coh``, ``head_rel`` —
    presumably faithfulness / coherence / relevance; confirm with the training
    code) each produce one score.

    Args:
        backbone_name: Model name or path passed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, backbone_name: str):
        super().__init__()
        self.backbone_name = backbone_name
        self.model = AutoModel.from_pretrained(backbone_name)
        hidden = self.model.config.hidden_size

        # Input width is hidden * 2 because forward concatenates the
        # first-token vector with the masked-mean vector.
        self.trunk = nn.Sequential(
            nn.Linear(hidden * 2, 256),
            nn.GELU(),
            nn.Dropout(0.1),
        )
        self.head_faith = nn.Linear(256, 1)
        self.head_coh = nn.Linear(256, 1)
        self.head_rel = nn.Linear(256, 1)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Score a batch of encoded inputs.

        Args:
            input_ids: Token ids for the backbone.
            attention_mask: Backbone attention mask. It also selects which
                tokens are averaged by ``mean_pool``.
            token_type_ids: Optional segment ids. They are forwarded only when
                given, since not every backbone accepts them.

        Returns:
            The faith, coh and rel head outputs concatenated along dim 1, in
            that column order. This is ``(batch, 3)`` for the usual
            ``(batch, seq, hidden)`` backbone output.
        """
        kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        if token_type_ids is not None:
            kwargs["token_type_ids"] = token_type_ids

        out = self.model(**kwargs)
        cls_vec = out.last_hidden_state[:, 0]
        # Match the first-token vector's dtype. A pooling step that upcasts
        # (e.g. under fp16/bf16 weights) would otherwise feed a mixed-dtype
        # tensor into the trunk's Linear layer. This is a no-op for float32.
        mean_vec = mean_pool(out.last_hidden_state, attention_mask).to(cls_vec.dtype)
        pooled = torch.cat([cls_vec, mean_vec], dim=-1)
        z = self.trunk(pooled)

        faith = self.head_faith(z)
        coh = self.head_coh(z)
        rel = self.head_rel(z)
        return torch.cat([faith, coh, rel], dim=1)

    @classmethod
    def from_pretrained_local(cls, model_dir: str):
        """Rebuild a saved model from ``model_dir`` and switch it to eval mode.

        ``model_dir`` must contain ``multievalvietsum_config.json`` (a JSON
        object with a ``"backbone_name"`` key) and ``pytorch_model.bin`` (a
        state dict matching this architecture exactly). Checkpoint tensors are
        mapped to CPU when loaded.

        Returns:
            ``(model, cfg)``: the loaded model in eval mode and the parsed
            config dict.

        Raises:
            FileNotFoundError: If either file is missing.
            KeyError: If the config has no ``"backbone_name"`` entry.
            pickle.UnpicklingError: If the checkpoint holds anything other than
                tensors / plain containers (rejected by ``weights_only=True``).
            RuntimeError: If the state dict keys or shapes do not match.
        """
        model_dir = Path(model_dir)
        with open(model_dir / "multievalvietsum_config.json", "r", encoding="utf-8") as f:
            cfg = json.load(f)

        # NOTE(review): cls(...) calls AutoModel.from_pretrained(backbone_name).
        # The backbone's original pretrained weights are therefore resolved and
        # loaded here, only to be overwritten by load_state_dict below.
        model = cls(backbone_name=cfg["backbone_name"])
        # weights_only=True restricts unpickling to tensors and primitive
        # containers. A tampered checkpoint therefore cannot execute code on
        # load, and behaviour is the same across torch versions.
        state_dict = torch.load(
            model_dir / "pytorch_model.bin",
            map_location="cpu",
            weights_only=True,
        )
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        return model, cfg

    @staticmethod
    def load_tokenizer_local(model_dir: str):
        """Load a tokenizer from ``model_dir`` via ``AutoTokenizer``, requesting the fast implementation."""
        return AutoTokenizer.from_pretrained(model_dir, use_fast=True)