File size: 2,243 Bytes
9270c9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66

import torch
import torch.nn as nn
from transformers import AutoModel
from transformers.modeling_outputs import SequenceClassifierOutput

class SequenceClassificationWithFeatures(nn.Module):
    """Sequence classifier that concatenates hand-crafted features to a text encoding.

    Wraps a Hugging Face encoder; its pooled representation is concatenated
    with a dense feature vector of size ``feature_dim`` and passed through a
    small MLP head producing ``num_labels`` logits.
    """

    def __init__(self, encoder, hidden_size, feature_dim, num_labels):
        """
        Args:
            encoder: A Hugging Face ``AutoModel``-style encoder module.
            hidden_size: Dimensionality of the encoder's pooled output.
            feature_dim: Dimensionality of the extra feature vector.
            num_labels: Number of classification labels.
        """
        super().__init__()
        self.encoder = encoder
        self.feature_dim = feature_dim
        # Two-layer head over [pooled_text ; features].
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size + feature_dim, hidden_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_labels)
        )

    @classmethod
    def from_pretrained(cls, path, device="cpu"):
        """Load a saved model from ``path``.

        Expects ``model_metadata.json`` (keys ``feature_dim``, ``num_labels``),
        the encoder's own pretrained files, and ``model.safetensors`` holding
        the full module state dict.

        Args:
            path: Directory containing the saved artifacts.
            device: Device to move the loaded model to.

        Returns:
            The reconstructed model on ``device``.
        """
        import json
        import os
        from safetensors.torch import load_file

        # Load config (feature_dim / num_labels were saved alongside the weights).
        with open(os.path.join(path, "model_metadata.json"), "r") as f:
            meta = json.load(f)

        # Load encoder architecture + weights from the same directory.
        encoder = AutoModel.from_pretrained(path, trust_remote_code=True)

        # Rebuild the wrapper with the saved dimensions.
        model = cls(
            encoder=encoder,
            hidden_size=encoder.config.hidden_size,
            feature_dim=meta["feature_dim"],
            num_labels=meta["num_labels"]
        )

        # Overwrite with the full saved state (encoder + classifier head).
        state_dict = load_file(os.path.join(path, "model.safetensors"))
        model.load_state_dict(state_dict)

        return model.to(device)

    def forward(self, input_ids=None, attention_mask=None, features=None, labels=None):
        """Run the encoder, fuse ``features``, and classify.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Optional padding mask (1 = real token).
            features: Extra feature tensor of shape (batch, feature_dim). Required.
            labels: Optional class indices; if given, cross-entropy loss is computed.

        Returns:
            ``SequenceClassifierOutput`` with ``logits`` and optional ``loss``.

        Raises:
            ValueError: If ``features`` is not provided.
        """
        if features is None:
            # Fail loudly instead of the opaque AttributeError from features.to(...).
            raise ValueError(
                "`features` is required (shape: [batch, feature_dim])."
            )

        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)

        if hasattr(out, "pooler_output") and out.pooler_output is not None:
            pooled = out.pooler_output
        else:
            # Masked mean pooling over real tokens.
            last = out.last_hidden_state
            if attention_mask is None:
                # No mask available: plain mean over the sequence dimension.
                pooled = last.mean(dim=1)
            else:
                # Cast mask to float BEFORE clamp: clamping an integer tensor with
                # min=1e-9 either errors or truncates the bound to 0 (no div guard).
                mask = attention_mask.unsqueeze(-1).to(last.dtype)
                pooled = (last * mask).sum(1) / mask.sum(1).clamp(min=1e-9)

        # Cast features to match model dtype/device.
        features = features.to(pooled.device).to(pooled.dtype)
        x = torch.cat([pooled, features], dim=1)
        logits = self.classifier(x)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)

        return SequenceClassifierOutput(loss=loss, logits=logits)