import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModel, AutoConfig

class Encoder(nn.Module):
    """Wraps a Hugging Face base model and returns mean-pooled, L2-normalized embeddings."""

    def __init__(self, base_encoder):
        super().__init__()
        self.encoder = base_encoder

    def forward(self, inputs):
        outputs = self.encoder(**inputs, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]
        # Attention-mask-weighted mean pooling over the token dimension.
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        pooled = (last_hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        return F.normalize(pooled, p=2, dim=1)

class Classifier(nn.Module):
    """Two-layer MLP head mapping pooled embeddings to class logits."""

    def __init__(self, input_dim=768, num_classes=28):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.25),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        return self.mlp(x)

class RobertaEmoPillars(PreTrainedModel):
    config_class = AutoConfig

    def __init__(self, config):
        super().__init__(config)
        base_encoder = AutoModel.from_config(config)  # IMPORTANT: use from_config
        self.encoder = Encoder(base_encoder)
        self.classifier = Classifier(
            input_dim=base_encoder.config.hidden_size,
            num_classes=config.num_labels,
        )
        self.post_init()  # ensures HF weight initialization
    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            **kwargs,
        }
        # Mean-pooled, L2-normalized sentence embedding from the base encoder.
        embeddings = self.encoder(inputs)
        # Per-class logits from the classification head.
        return self.classifier(embeddings)

# --- Registration for AutoModel ---
try:
    AutoModel.register(AutoConfig, RobertaEmoPillars)
except Exception:
    # Registration can fail if this config/model pair is already registered.
    pass
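
# Minimal usage sketch (not part of the original file): shows how the model could be
# instantiated and run end to end. Assumptions: a RoBERTa-style checkpoint such as
# "roberta-base", 28 target labels, and a sigmoid for a multi-label reading of the
# logits; none of these are confirmed by this file.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = AutoConfig.from_pretrained("roberta-base", num_labels=28)  # assumed checkpoint
    model = RobertaEmoPillars(config)
    tokenizer = AutoTokenizer.from_pretrained("roberta-base")

    batch = tokenizer(["I am thrilled about this!"], return_tensors="pt", padding=True)
    logits = model(**batch)
    probs = torch.sigmoid(logits)  # per-class probabilities, assuming a multi-label setup
    print(probs.shape)  # torch.Size([1, 28])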