import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModel, AutoConfig


class Encoder(nn.Module):
    """Wraps a Hugging Face encoder and returns a mean-pooled, L2-normalized embedding."""

    def __init__(self, base_encoder):
        super().__init__()
        self.encoder = base_encoder

    def forward(self, inputs):
        outputs = self.encoder(**inputs, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]
        # Masked mean pooling over non-padding tokens.
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        pooled = (last_hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        return F.normalize(pooled, p=2, dim=1)


class Classifier(nn.Module):
    """Two-layer MLP head mapping pooled embeddings to class logits."""

    def __init__(self, input_dim=768, num_classes=28):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.25),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        return self.mlp(x)


class RobertaEmoPillars(PreTrainedModel):
    config_class = AutoConfig

    def __init__(self, config):
        super().__init__(config)
        # IMPORTANT: use from_config so the backbone is built from the config alone
        # (no pretrained weights are downloaded here).
        base_encoder = AutoModel.from_config(config)
        self.encoder = Encoder(base_encoder)
        self.classifier = Classifier(
            input_dim=base_encoder.config.hidden_size,
            num_classes=config.num_labels,
        )
        self.post_init()  # runs the standard HF weight initialization

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        inputs = {"input_ids": input_ids, "attention_mask": attention_mask, **kwargs}
        pooled = self.encoder(inputs)
        return self.classifier(pooled)  # (batch_size, num_labels) logits


# --- Registration for AutoModel ---
try:
    AutoModel.register(AutoConfig, RobertaEmoPillars)
except ValueError:
    # Already registered (e.g. the module was imported twice).
    pass
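

# Minimal usage sketch, assuming a "roberta-base" backbone and a 28-label setup;
# the checkpoint name and label count are illustrative assumptions, so substitute
# your own config and tokenizer as needed.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = AutoConfig.from_pretrained("roberta-base", num_labels=28)
    model = RobertaEmoPillars(config)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    batch = tokenizer(
        ["I can't believe how well this turned out!"],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    with torch.no_grad():
        logits = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )
    print(logits.shape)  # torch.Size([1, 28])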