import torch
import torch.nn as nn
from transformers import AutoModel

MODEL_ID = "Omartificial-Intelligence-Space/SA-BERT-V1"


class EOUClassifier(nn.Module):
    """Sequence classifier: pretrained encoder + pooling + 2-layer MLP head.

    The encoder is loaded with ``AutoModel.from_pretrained(model_id)``. Its
    last hidden state is pooled either by taking the first token ("cls") or
    by a mask-weighted mean over tokens ("mean"). The pooled vector is fed
    through ``dropout -> Linear -> GELU -> dropout -> Linear`` to produce
    ``num_labels`` logits.

    Args:
        model_id: Hugging Face model identifier or local path for the encoder.
        num_labels: Number of output classes.
        use_class_weights: If True *and* ``class_weights`` is given, the
            cross-entropy loss is weighted per class. With no weights
            supplied the loss is unweighted (the historical behaviour).
        pooling: ``"cls"`` or ``"mean"``.
        class_weights: Optional sequence/tensor of ``num_labels`` per-class
            loss weights.

    Raises:
        ValueError: If ``pooling`` is not a supported mode, or if
            ``class_weights`` does not have exactly ``num_labels`` entries.
    """

    POOLING_MODES = ("cls", "mean")

    def __init__(
        self,
        model_id=MODEL_ID,
        num_labels=2,
        use_class_weights=True,
        pooling="cls",
        class_weights=None,
    ):
        super().__init__()
        # Validate before the (potentially slow) model download.
        if pooling not in self.POOLING_MODES:
            raise ValueError(
                f"pooling must be one of {self.POOLING_MODES}, got {pooling!r}"
            )
        self.num_labels = num_labels
        self.pooling = pooling

        # Load encoder.
        self.bert = AutoModel.from_pretrained(model_id)
        # Size the head from the encoder config rather than assuming 768,
        # so checkpoints of other widths also work.
        hidden_size = self.bert.config.hidden_size

        self.dropout = nn.Dropout(0.15)
        self.layer_1 = nn.Linear(hidden_size, 384)
        self.act = nn.GELU()
        self.layer_2 = nn.Linear(384, num_labels)

        weight = None
        if use_class_weights and class_weights is not None:
            weight = torch.as_tensor(class_weights, dtype=torch.float32).flatten()
            if weight.numel() != num_labels:
                raise ValueError(
                    f"class_weights must have {num_labels} entries, "
                    f"got {weight.numel()}"
                )
        # CrossEntropyLoss registers `weight` as a buffer, so it follows
        # the module across `.to(device)` calls.
        self.loss_fn = nn.CrossEntropyLoss(weight=weight)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and head; optionally compute the loss.

        Args:
            input_ids: Token ids passed to the encoder (presumably
                ``(batch, seq_len)`` — confirm against the tokenizer/caller).
            attention_mask: Mask with the same shape as ``input_ids``, where
                non-zero marks real tokens. It is used by the encoder and
                by mean pooling.
            labels: Optional class indices for ``nn.CrossEntropyLoss``.

        Returns:
            ``{"loss": loss, "logits": logits}`` when ``labels`` is given,
            otherwise ``{"logits": logits}``.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state

        if self.pooling == "cls":
            pooled = hidden[:, 0]  # first ([CLS]) token
        else:
            # Mask-weighted mean over tokens. Cast the mask to the hidden
            # dtype, and clamp the count to >= 1 so a fully padded row
            # yields zeros instead of NaN (0 / 0).
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            summed = (hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1.0)
            pooled = summed / counts

        x = self.dropout(pooled)
        x = self.layer_1(x)
        x = self.act(x)
        x = self.dropout(x)
        logits = self.layer_2(x)

        if labels is not None:
            loss = self.loss_fn(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}