import torch
import torch.nn as nn
from transformers import AutoModel
from transformers.modeling_outputs import SequenceClassifierOutput


class SequenceClassificationWithFeatures(nn.Module):
    """Sequence classifier that concatenates the pooled encoder output with
    extra per-example features before the classification head."""

    def __init__(self, encoder, hidden_size, feature_dim, num_labels):
        super().__init__()
        self.encoder = encoder
        self.feature_dim = feature_dim
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size + feature_dim, hidden_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_labels),
        )

    @classmethod
    def from_pretrained(cls, path, device="cpu"):
        import json
        import os

        from safetensors.torch import load_file

        # Load config (feature_dim / num_labels live in a sidecar JSON file)
        with open(os.path.join(path, "model_metadata.json"), "r") as f:
            meta = json.load(f)

        # Load encoder (its weights are overwritten by load_state_dict below)
        encoder = AutoModel.from_pretrained(path, trust_remote_code=True)

        # Init model
        model = cls(
            encoder=encoder,
            hidden_size=encoder.config.hidden_size,
            feature_dim=meta["feature_dim"],
            num_labels=meta["num_labels"],
        )

        # Load weights for the full module, classifier head included
        state_dict = load_file(os.path.join(path, "model.safetensors"))
        model.load_state_dict(state_dict)
        return model.to(device)

    def forward(self, input_ids=None, attention_mask=None, features=None, labels=None):
        out = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        )
        if hasattr(out, "pooler_output") and out.pooler_output is not None:
            pooled = out.pooler_output
        else:
            # Mean-pool the last hidden state over non-padding tokens.
            # Cast the mask to the hidden-state dtype so the clamp/division
            # below stay in floating point (attention_mask is usually long).
            last = out.last_hidden_state
            mask = attention_mask.unsqueeze(-1).to(last.dtype)
            pooled = (last * mask).sum(1) / mask.sum(1).clamp(min=1e-9)

        # Cast features to match the pooled output's device and dtype
        features = features.to(device=pooled.device, dtype=pooled.dtype)
        x = torch.cat([pooled, features], dim=1)
        logits = self.classifier(x)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return SequenceClassifierOutput(loss=loss, logits=logits)
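
# --- Usage sketch ---
# A minimal save helper plus smoke test, assuming the checkpoint layout that
# from_pretrained() above reads back: config.json for AutoModel, the sidecar
# model_metadata.json, and model.safetensors holding the full module state.
# The encoder name "bert-base-uncased" and the feature_dim / num_labels values
# below are illustrative assumptions, not part of the original code.

def save_with_features(model, path):
    import json
    import os

    from safetensors.torch import save_file

    os.makedirs(path, exist_ok=True)
    # Encoder config so AutoModel.from_pretrained(path) can rebuild the backbone.
    model.encoder.config.save_pretrained(path)
    # Full state dict, encoder and classifier head alike. On reload,
    # AutoModel.from_pretrained(path) will warn about the "encoder."-prefixed
    # keys; load_state_dict() then restores every weight, so the warning is
    # harmless. If the encoder ties weights, prefer safetensors' save_model,
    # which handles shared tensors.
    save_file(model.state_dict(), os.path.join(path, "model.safetensors"))
    with open(os.path.join(path, "model_metadata.json"), "w") as f:
        json.dump(
            {
                "feature_dim": model.feature_dim,
                "num_labels": model.classifier[-1].out_features,
            },
            f,
        )


if __name__ == "__main__":
    # Smoke test on random inputs; batch/sequence sizes are arbitrary.
    encoder = AutoModel.from_pretrained("bert-base-uncased")
    model = SequenceClassificationWithFeatures(
        encoder=encoder,
        hidden_size=encoder.config.hidden_size,
        feature_dim=8,
        num_labels=3,
    )
    batch, seq_len = 4, 16
    out = model(
        input_ids=torch.randint(0, encoder.config.vocab_size, (batch, seq_len)),
        attention_mask=torch.ones(batch, seq_len, dtype=torch.long),
        features=torch.randn(batch, 8),
        labels=torch.randint(0, 3, (batch,)),
    )
    print(out.logits.shape, out.loss.item())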