# Multi_EvalSumVietN_FullDoc / modeling_summary_evaluator.py
# Uploaded by phuongntc
# Commit message: Upload MultiEvalSumVietN Full-Doc evaluator (no-transformers packaging)
# Commit: 1705a46 (verified)
import os, json
import torch
import torch.nn as nn
# NOTE:
# This loader intentionally avoids importing transformers at module import time.
# It imports AutoModel lazily inside __init__.
def mean_pool(last_hidden_state, attention_mask):
    """Return the mask-weighted mean of ``last_hidden_state`` along dim 1.

    Args:
        last_hidden_state: Floating-point tensor whose dim 1 is the axis being
            pooled; ``FullDocEvaluator.forward`` passes the backbone output
            (presumably ``[N, T, H]`` -- confirm against the backbone).
        attention_mask: Tensor that broadcasts against ``last_hidden_state``
            once a trailing axis is appended (``[N, T]`` in ``forward``).
            Entries act as multiplicative weights, so a 0/1 mask selects the
            positions that contribute to the mean.

    Returns:
        Tensor with dim 1 reduced away, in the dtype of ``last_hidden_state``.
        Rows whose mask is entirely zero yield zeros instead of NaN.
    """
    # Accumulate in at least float32. In float16 the 1e-9 floor below rounds
    # to zero, so a fully masked row would compute 0/0 and return NaN (which
    # then survives a later multiplication by a zero segment mask).
    acc_dtype = torch.promote_types(last_hidden_state.dtype, torch.float32)
    mask = attention_mask.unsqueeze(-1).to(acc_dtype)
    summed = (last_hidden_state.to(acc_dtype) * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return (summed / counts).to(last_hidden_state.dtype)
class FullDocEvaluator(nn.Module):
    """Segment-level document evaluator producing three sigmoid scores.

    Load from a local HF snapshot directory containing:
      - config.json + model.safetensors (backbone)
      - trunk.pt, head_faith.pt, head_coh.pt, head_rel.pt
      - optional seg_attn.pt (required when agg_type == "attn")
      - arch_config.json (its "agg_type" key selects "mean", "max" or "attn";
        defaults to "mean" when absent)

    Forward expects:
      input_ids:      [B,K,T]
      attention_mask: [B,K,T] (0/1)
      seg_mask:       [B,K]   (1 real seg, 0 dummy seg)

    Returns sigmoid scores in [0,1] in order: [faith, coh, rel]
    """

    # Segment-aggregation modes that forward() implements.
    _AGG_TYPES = ("mean", "max", "attn")

    def __init__(self, base_dir: str):
        """Build the model from the snapshot directory ``base_dir``.

        Raises:
            ValueError: if arch_config.json names an unsupported ``agg_type``.
            OSError: if arch_config.json or a required weight file is missing.
        """
        super().__init__()
        self.base_dir = base_dir

        with open(os.path.join(base_dir, "arch_config.json"), "r", encoding="utf-8") as f:
            cfg = json.load(f)
        self.agg_type = cfg.get("agg_type", "mean")
        # Fail fast: an unknown value used to build a model without seg_attn
        # and then crash inside forward() with "'NoneType' is not callable".
        if self.agg_type not in self._AGG_TYPES:
            raise ValueError(
                f"Unsupported agg_type {self.agg_type!r} in arch_config.json; "
                f"expected one of {self._AGG_TYPES}"
            )

        # Lazy import to reduce environment-triggered import issues
        from transformers import AutoModel

        self.backbone = AutoModel.from_pretrained(base_dir)
        if hasattr(self.backbone.config, "use_cache"):
            self.backbone.config.use_cache = False

        hidden = int(self.backbone.config.hidden_size)
        self.trunk = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.GELU(),
            nn.Dropout(0.1),
        )
        self.head_faith = nn.Linear(256, 1)
        self.head_coh = nn.Linear(256, 1)
        self.head_rel = nn.Linear(256, 1)

        self._load_state(self.trunk, "trunk.pt")
        self._load_state(self.head_faith, "head_faith.pt")
        self._load_state(self.head_coh, "head_coh.pt")
        self._load_state(self.head_rel, "head_rel.pt")

        if self.agg_type == "attn":
            self.seg_attn = nn.Sequential(
                nn.Linear(hidden, hidden // 2),
                nn.Tanh(),
                nn.Linear(hidden // 2, 1),
            )
            self._load_state(self.seg_attn, "seg_attn.pt")
        else:
            self.seg_attn = None

        # Inference-only module: disable dropout.
        self.eval()

    def _load_state(self, module, filename):
        """Load the state dict stored at ``base_dir/filename`` into ``module``."""
        # NOTE(security): torch.load unpickles the file, so base_dir must be a
        # trusted snapshot. Consider weights_only=True where the installed
        # torch version supports it.
        state = torch.load(os.path.join(self.base_dir, filename), map_location="cpu")
        module.load_state_dict(state)

    @torch.no_grad()
    def forward(self, input_ids, attention_mask, seg_mask):
        """Score a batch of segmented documents.

        Args:
            input_ids: token ids, [B,K,T].
            attention_mask: token mask, [B,K,T] (0/1).
            seg_mask: segment mask, [B,K] (1 real seg, 0 dummy seg).

        Returns:
            Tensor [B,3] of sigmoid scores in order [faith, coh, rel].
        """
        B, K, T = input_ids.shape
        # reshape (not view) so non-contiguous inputs such as slices or
        # permuted tensors are accepted instead of raising.
        x = input_ids.reshape(B * K, T)
        a = attention_mask.reshape(B * K, T)

        out = self.backbone(input_ids=x, attention_mask=a).last_hidden_state
        pooled = mean_pool(out, a).reshape(B, K, -1)

        # Zero out dummy segments. The float32 mask also promotes a
        # half-precision `pooled` to float32 for the aggregation below.
        mask = seg_mask.unsqueeze(-1).float()
        pooled = pooled * mask

        if self.agg_type == "mean":
            denom = mask.sum(dim=1).clamp_min(1e-6)
            doc = pooled.sum(dim=1) / denom
        elif self.agg_type == "max":
            # Push dummy segments to the most negative finite value so they
            # never win the max.
            neg_inf = torch.finfo(pooled.dtype).min
            tmp = pooled + (1.0 - mask) * neg_inf
            doc = tmp.max(dim=1).values
        elif self.agg_type == "attn":
            seg_fp32 = pooled.float()
            # Feed seg_attn in its own parameter dtype (it may have been cast
            # with .half()), but keep softmax and the weighted sum in float32.
            attn_dtype = next(self.seg_attn.parameters()).dtype
            score = self.seg_attn(seg_fp32.to(attn_dtype)).squeeze(-1).float()
            score = score.masked_fill(seg_mask == 0, torch.finfo(score.dtype).min)
            w = torch.softmax(score, dim=1).unsqueeze(-1)
            doc = (w * seg_fp32).sum(dim=1).type_as(pooled)
        else:
            raise ValueError(f"Unsupported agg_type {self.agg_type!r}")

        # Match the trunk's parameter dtype; the float32 mask above would
        # otherwise cause a dtype mismatch when the module was cast to half.
        doc = doc.to(next(self.trunk.parameters()).dtype)
        z = self.trunk(doc)
        y = torch.cat([self.head_faith(z), self.head_coh(z), self.head_rel(z)], dim=1)
        return torch.sigmoid(y)