| | import os, json |
| | import torch |
| | import torch.nn as nn |
| |
|
| | |
| | |
| | |
| |
|
def mean_pool(last_hidden_state, attention_mask):
    """Average token embeddings along the sequence axis, skipping padding.

    Args:
        last_hidden_state: [N, T, H] token embeddings.
        attention_mask:    [N, T] 0/1 mask (1 = real token).

    Returns:
        [N, H] per-sequence mean over unmasked positions. The token count is
        clamped to a tiny epsilon so an all-zero mask yields zeros, not NaN.
    """
    weights = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    total = torch.sum(last_hidden_state * weights, dim=1)
    n_tokens = torch.clamp(torch.sum(weights, dim=1), min=1e-9)
    return total / n_tokens
| |
|
class FullDocEvaluator(nn.Module):
    """Frozen multi-head document evaluator restored from a local HF snapshot.

    The snapshot directory must contain:
      - config.json + model.safetensors (backbone)
      - trunk.pt, head_faith.pt, head_coh.pt, head_rel.pt
      - optional seg_attn.pt (required when agg_type == "attn")
      - arch_config.json  (holds "agg_type": "mean" | "max" | "attn")

    Forward expects:
      input_ids:      [B, K, T]
      attention_mask: [B, K, T] (0/1)
      seg_mask:       [B, K]    (1 = real segment, 0 = dummy segment)
    Returns sigmoid scores in [0, 1] in order: [faith, coh, rel].
    """

    def __init__(self, base_dir: str):
        super().__init__()
        self.base_dir = base_dir

        # Architecture metadata saved alongside the weights.
        with open(os.path.join(base_dir, "arch_config.json"), "r", encoding="utf-8") as f:
            cfg = json.load(f)
        self.agg_type = cfg.get("agg_type", "mean")

        # Local import keeps transformers optional until a model is actually built.
        from transformers import AutoModel
        self.backbone = AutoModel.from_pretrained(base_dir)

        # KV caching is pointless for single-pass encoding; disable when supported.
        if hasattr(self.backbone.config, "use_cache"):
            self.backbone.config.use_cache = False

        hidden = int(self.backbone.config.hidden_size)

        # Shared trunk feeding three independent scalar heads.
        self.trunk = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.GELU(),
            nn.Dropout(0.1),
        )
        self.head_faith = nn.Linear(256, 1)
        self.head_coh = nn.Linear(256, 1)
        self.head_rel = nn.Linear(256, 1)

        self.trunk.load_state_dict(self._load_state("trunk.pt"))
        self.head_faith.load_state_dict(self._load_state("head_faith.pt"))
        self.head_coh.load_state_dict(self._load_state("head_coh.pt"))
        self.head_rel.load_state_dict(self._load_state("head_rel.pt"))

        if self.agg_type == "attn":
            # Learned attention pooling over segment embeddings.
            self.seg_attn = nn.Sequential(
                nn.Linear(hidden, hidden // 2),
                nn.Tanh(),
                nn.Linear(hidden // 2, 1),
            )
            self.seg_attn.load_state_dict(self._load_state("seg_attn.pt"))
        else:
            self.seg_attn = None

        # Inference-only module: freeze dropout / any batch-norm style layers.
        self.eval()

    def _load_state(self, filename: str):
        """Load a state_dict from *filename* inside the snapshot directory.

        Uses ``weights_only=True`` so that opening checkpoint files never
        executes arbitrary pickled code (tensor state_dicts load fine under
        this restriction). Falls back for torch versions predating the kwarg
        (< 1.13).
        """
        path = os.path.join(self.base_dir, filename)
        try:
            return torch.load(path, map_location="cpu", weights_only=True)
        except TypeError:
            # Older torch without the weights_only parameter.
            return torch.load(path, map_location="cpu")

    @torch.no_grad()
    def forward(self, input_ids, attention_mask, seg_mask):
        """Score a batch of segmented documents.

        Args:
            input_ids:      [B, K, T] token ids (K segments of T tokens each).
            attention_mask: [B, K, T] 0/1 token mask.
            seg_mask:       [B, K] 1 for real segments, 0 for dummy padding.

        Returns:
            [B, 3] sigmoid scores ordered [faith, coh, rel].
        """
        B, K, T = input_ids.shape
        # Flatten segments into the batch dimension for one backbone pass.
        x = input_ids.view(B * K, T)
        a = attention_mask.view(B * K, T)

        out = self.backbone(input_ids=x, attention_mask=a).last_hidden_state
        pooled = mean_pool(out, a).view(B, K, -1)  # [B, K, H] per-segment embeddings

        # Zero out embeddings belonging to dummy segments.
        mask = seg_mask.unsqueeze(-1).float()
        pooled = pooled * mask

        if self.agg_type == "mean":
            # Mean over real segments only; clamp guards all-dummy documents.
            denom = mask.sum(dim=1).clamp_min(1e-6)
            doc = pooled.sum(dim=1) / denom
        elif self.agg_type == "max":
            # Push dummy segments toward -inf so they can never win the max.
            neg_inf = torch.finfo(pooled.dtype).min
            tmp = pooled + (1.0 - mask) * neg_inf
            doc = tmp.max(dim=1).values
        else:
            if self.seg_attn is None:
                # Unknown agg_type without attention weights: fail loudly instead
                # of crashing with an opaque "NoneType is not callable".
                raise RuntimeError(
                    f"agg_type {self.agg_type!r} requires seg_attn weights (seg_attn.pt)"
                )
            # Score/aggregate in fp32 — presumably for numerical stability with
            # reduced-precision backbones; cast back afterwards.
            seg_fp32 = pooled.float()
            score = self.seg_attn(seg_fp32).squeeze(-1)
            score = score.masked_fill(seg_mask == 0, torch.finfo(score.dtype).min)
            w = torch.softmax(score, dim=1).unsqueeze(-1)
            doc = (w * seg_fp32).sum(dim=1).type_as(pooled)

        z = self.trunk(doc)
        y = torch.cat([self.head_faith(z), self.head_coh(z), self.head_rel(z)], dim=1)
        return torch.sigmoid(y)
| |
|