import torch
import torch.nn as nn
from transformers import AutoModel
from transformers.modeling_outputs import SequenceClassifierOutput


class SequenceClassificationWithFeatures(nn.Module):
    """Classification head over a pretrained encoder, with extra
    per-example numeric features concatenated onto the pooled text
    representation before the classifier."""

    def __init__(self, encoder, hidden_size, feature_dim, num_labels):
        super().__init__()
        self.encoder = encoder
        self.feature_dim = feature_dim
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size + feature_dim, hidden_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_labels),
        )

    @classmethod
    def from_pretrained(cls, path, device="cpu"):
        import json
        import os
        from safetensors.torch import load_file

        # Head dimensions are stored next to the encoder checkpoint.
        with open(os.path.join(path, "model_metadata.json"), "r") as f:
            meta = json.load(f)

        encoder = AutoModel.from_pretrained(path, trust_remote_code=True)

        model = cls(
            encoder=encoder,
            hidden_size=encoder.config.hidden_size,
            feature_dim=meta["feature_dim"],
            num_labels=meta["num_labels"],
        )

        # model.safetensors holds the full wrapper state dict
        # ("encoder.*" and "classifier.*" keys), so this restores both
        # the encoder and the classification head.
        state_dict = load_file(os.path.join(path, "model.safetensors"))
        model.load_state_dict(state_dict)

        return model.to(device)

    def forward(self, input_ids=None, attention_mask=None, features=None, labels=None):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)

        # Prefer the encoder's pooler output when it exists; otherwise
        # mean-pool the last hidden state over non-padding tokens.
        if hasattr(out, "pooler_output") and out.pooler_output is not None:
            pooled = out.pooler_output
        else:
            last = out.last_hidden_state
            # Cast the mask to the hidden-state dtype so the clamp and
            # division below stay in floating point (an integer mask sum
            # would truncate the 1e-9 epsilon to zero).
            mask = attention_mask.unsqueeze(-1).to(last.dtype)
            pooled = (last * mask).sum(1) / mask.sum(1).clamp(min=1e-9)

        if features is None:
            raise ValueError("`features` is required, shape (batch, feature_dim).")
        features = features.to(device=pooled.device, dtype=pooled.dtype)
        x = torch.cat([pooled, features], dim=1)
        logits = self.classifier(x)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)

        return SequenceClassifierOutput(loss=loss, logits=logits)
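

# A possible save-side counterpart to `from_pretrained` above: a sketch,
# not part of the original class. `save_with_features` is a hypothetical
# helper name; it writes the encoder config, the head metadata, and the
# full wrapper state dict in the layout `from_pretrained` expects.
def save_with_features(model, path):
    import json
    import os
    from safetensors.torch import save_file

    os.makedirs(path, exist_ok=True)

    # config.json lets AutoModel.from_pretrained(path) rebuild the encoder.
    model.encoder.config.save_pretrained(path)

    with open(os.path.join(path, "model_metadata.json"), "w") as f:
        json.dump(
            {
                "feature_dim": model.feature_dim,
                "num_labels": model.classifier[-1].out_features,
            },
            f,
        )

    # Full wrapper weights ("encoder.*" + "classifier.*"). Assumes the
    # model has no tied/shared tensors; `.contiguous()` satisfies the
    # safetensors layout requirement.
    tensors = {k: v.contiguous() for k, v in model.state_dict().items()}
    save_file(tensors, os.path.join(path, "model.safetensors"))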
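

# A minimal usage sketch. The checkpoint name ("bert-base-uncased") and the
# 4-dim feature vector are illustrative assumptions, not part of the class.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    encoder = AutoModel.from_pretrained("bert-base-uncased")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    model = SequenceClassificationWithFeatures(
        encoder=encoder,
        hidden_size=encoder.config.hidden_size,
        feature_dim=4,
        num_labels=2,
    )

    batch = tokenizer(["an example sentence"], return_tensors="pt")
    features = torch.randn(1, 4)  # placeholder numeric features
    labels = torch.tensor([0])

    out = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        features=features,
        labels=labels,
    )
    print(out.loss.item(), out.logits.shape)  # scalar loss, (1, 2) logits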