# my_model/model.py
# Uploaded by shaqaqio ("Upload 3 files"), commit 91a459e.
import torch
import torch.nn as nn
from transformers import AutoModel
# Model id passed to ``AutoModel.from_pretrained`` when EOUClassifier is built
# without an explicit ``model_id``.
MODEL_ID = "Omartificial-Intelligence-Space/SA-BERT-V1"


class EOUClassifier(nn.Module):
    """Sequence classifier: pretrained transformer encoder + two-layer MLP head.

    The encoder output is pooled to one vector per sequence (first token, or
    attention-masked mean over tokens) and mapped to ``num_labels`` logits via
    ``dropout -> Linear(hidden_size, 384) -> GELU -> dropout -> Linear(384, num_labels)``.

    NOTE(review): the name suggests end-of-utterance (EOU) detection; the label
    meaning comes from the training data, not from this module.

    Args:
        model_id: Model id or path passed to ``AutoModel.from_pretrained``.
        num_labels: Number of output classes.
        use_class_weights: When true *and* ``class_weights`` is given, the
            cross-entropy loss is weighted per class. With the default
            ``class_weights=None`` the loss stays unweighted (the historical
            behaviour of this class).
        pooling: ``"cls"`` (first token) or ``"mean"`` (attention-masked mean).
        class_weights: Optional sequence/tensor of ``num_labels`` per-class
            loss weights. When used it is stored as the ``loss_fn.weight``
            buffer, so it follows ``.to(device)`` and appears in ``state_dict``.

    Raises:
        ValueError: If ``pooling`` is not ``"cls"``/``"mean"``, or if
            ``class_weights`` is used and is not a 1-D collection of
            ``num_labels`` values.
    """

    _POOLING_MODES = ("cls", "mean")

    def __init__(self, model_id=MODEL_ID, num_labels=2, use_class_weights=True,
                 pooling="cls", class_weights=None):
        super().__init__()

        # Validate cheap arguments before downloading / loading the encoder.
        if pooling not in self._POOLING_MODES:
            raise ValueError(
                f"pooling must be one of {self._POOLING_MODES}, got {pooling!r}"
            )
        weight = None
        if use_class_weights and class_weights is not None:
            weight = torch.as_tensor(class_weights, dtype=torch.float32)
            if weight.dim() != 1 or weight.numel() != num_labels:
                raise ValueError(
                    f"class_weights must contain exactly {num_labels} values, "
                    f"got shape {tuple(weight.shape)}"
                )

        self.num_labels = num_labels
        self.pooling = pooling

        # Load encoder
        self.bert = AutoModel.from_pretrained(model_id)
        # Size the head from the encoder config instead of hard-coding 768,
        # so encoders with other widths work too.
        hidden_size = self.bert.config.hidden_size

        self.dropout = nn.Dropout(0.15)
        self.layer_1 = nn.Linear(hidden_size, 384)
        self.act = nn.GELU()
        self.layer_2 = nn.Linear(384, num_labels)
        # weight=None gives plain (unweighted) cross-entropy.
        self.loss_fn = nn.CrossEntropyLoss(weight=weight)

    @staticmethod
    def _mean_pool(hidden, attention_mask):
        """Average ``hidden`` over the positions where ``attention_mask`` is set.

        Assumes ``hidden`` is (batch, seq, dim) and ``attention_mask`` is
        (batch, seq) with 0/1 entries — TODO confirm against the tokenizer.
        A row whose mask is all zeros yields a zero vector instead of NaN.
        """
        mask = attention_mask.unsqueeze(-1)
        summed = (hidden * mask).sum(dim=1)
        # Clamp the token count so an all-padding row divides by 1, not 0 (0/0 = NaN).
        counts = mask.sum(dim=1).clamp(min=1)
        return summed / counts

    def forward(self, input_ids, attention_mask, labels=None):
        """Encode a batch and return classification logits (plus loss if labelled).

        Args:
            input_ids: Token ids passed straight to the encoder.
            attention_mask: Encoder attention mask; also used for mean pooling.
            labels: Optional class indices for ``nn.CrossEntropyLoss``.

        Returns:
            ``{"loss": loss, "logits": logits}`` when ``labels`` is given,
            otherwise ``{"logits": logits}``.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state

        if self.pooling == "cls":
            pooled = hidden[:, 0]  # first token ([CLS] for BERT-style inputs)
        else:
            pooled = self._mean_pool(hidden, attention_mask)

        x = self.dropout(pooled)
        x = self.act(self.layer_1(x))
        logits = self.layer_2(self.dropout(x))

        if labels is None:
            return {"logits": logits}
        return {"loss": self.loss_fn(logits, labels), "logits": logits}