import os
import json

import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer


def mean_pool(last_hidden_state, attention_mask):
    # Average token embeddings over the sequence, ignoring padding positions via the attention mask.
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts


class SummaryEvaluatorModule(nn.Module):
    def __init__(self, base_model_dir):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(base_model_dir)
        # Trunk and head sizes come from the arch_config.json saved alongside the backbone.
        with open(os.path.join(base_model_dir, "arch_config.json"), "r", encoding="utf-8") as f:
            cfg = json.load(f)
        hidden = cfg["trunk"]["hidden_in"]
        self.trunk = nn.Sequential(
            nn.Linear(hidden, cfg["trunk"]["hidden_mid"]),
            nn.GELU(),
            nn.Dropout(cfg["trunk"]["dropout"]),
        )
        H = cfg["trunk"]["hidden_mid"]
        # One regression head per quality dimension: faithfulness, coherence, relevance.
        self.head_faith = nn.Linear(H, 1)
        self.head_coh = nn.Linear(H, 1)
        self.head_rel = nn.Linear(H, 1)

        # Load the fine-tuned trunk and head weights saved next to the backbone.
        self.trunk.load_state_dict(torch.load(os.path.join(base_model_dir, "trunk.pt"), map_location="cpu"))
        self.head_faith.load_state_dict(torch.load(os.path.join(base_model_dir, "head_faith.pt"), map_location="cpu"))
        self.head_coh.load_state_dict(torch.load(os.path.join(base_model_dir, "head_coh.pt"), map_location="cpu"))
        self.head_rel.load_state_dict(torch.load(os.path.join(base_model_dir, "head_rel.pt"), map_location="cpu"))
        self.eval()

    @torch.no_grad()
    def forward(self, input_ids, attention_mask):
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        pooled = mean_pool(out.last_hidden_state, attention_mask)
        z = self.trunk(pooled)
        # Concatenate the three head outputs into a (batch, 3) score tensor.
        y = torch.cat([self.head_faith(z), self.head_coh(z), self.head_rel(z)], dim=1)
        return y


def load_for_inference(base_model_dir, device=None):
    tok = AutoTokenizer.from_pretrained(base_model_dir, use_fast=True)
    mdl = SummaryEvaluatorModule(base_model_dir)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    mdl.to(device).eval()
    return mdl, tok, device


def encode_pair(tokenizer, docs, sums, max_len=512, sum_max_len=256):
    # First truncate each summary to sum_max_len tokens so it cannot crowd out the document.
    tok_sum = tokenizer(
        sums, truncation=True, max_length=sum_max_len,
        add_special_tokens=False, return_attention_mask=False
    )
    trimmed = tokenizer.batch_decode(
        tok_sum["input_ids"], skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    # Then encode (document, truncated summary) pairs; any further truncation removes document tokens only.
    enc = tokenizer(
        docs, trimmed, truncation="only_first", max_length=max_len,
        add_special_tokens=True, return_attention_mask=True
    )
    return enc
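

# Example usage (a minimal sketch, not part of the module above): assumes a model directory
# containing the backbone, arch_config.json, and the trunk/head checkpoints. The directory
# path, document, and summary below are placeholders, and the heads return raw regression
# scores whose scale depends on how the checkpoints were trained.
if __name__ == "__main__":
    model_dir = "path/to/summary_evaluator"  # placeholder path
    mdl, tok, device = load_for_inference(model_dir)

    docs = ["The quarterly report shows revenue grew 12% year over year, driven by cloud services."]
    sums = ["Revenue rose 12%, led by cloud services."]

    # encode_pair returns un-padded token id lists, so pad to a rectangular tensor batch here.
    enc = encode_pair(tok, docs, sums)
    batch = tok.pad(enc, padding=True, return_tensors="pt").to(device)

    scores = mdl(batch["input_ids"], batch["attention_mask"])
    faith, coh, rel = scores[0].tolist()
    print(f"faithfulness={faith:.3f} coherence={coh:.3f} relevance={rel:.3f}")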