# emo-detector/model.py
import torch
import torch.nn as nn
from transformers import AutoModel


class BERT_FFNN(nn.Module):
"""
BERT_FFNN: BERT + feed-forward network for text classification tasks.
"""
    def __init__(
        self,
        bert_model_name="microsoft/deberta-v3-base",
        hidden_dims=(192, 96),  # tuple avoids a shared mutable default
        output_dim=5,
        dropout=0.2,
        pooling='attention',
        freeze_bert=False,
        freeze_layers=0,
        use_layer_norm=True,
    ):
super().__init__()
        # Load the pretrained transformer encoder
        self.bert = AutoModel.from_pretrained(bert_model_name)
self.use_layer_norm = use_layer_norm
self.pooling = pooling
if pooling == 'attention':
self.attention_pool = AttentionPooling(self.bert.config.hidden_size)
        # Optionally freeze the whole encoder, or only its first `freeze_layers`
        # transformer blocks (assumes a BERT/DeBERTa-style `encoder.layer` list)
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
        elif freeze_layers > 0:
            for layer in self.bert.encoder.layer[:freeze_layers]:
                for param in layer.parameters():
                    param.requires_grad = False
# Build FFNN layers
fc_input_dim = self.bert.config.hidden_size
layers = []
in_dim = fc_input_dim
for h_dim in hidden_dims:
layers.append(nn.Linear(in_dim, h_dim))
layers.append(nn.ReLU())
if use_layer_norm:
layers.append(nn.LayerNorm(h_dim))
layers.append(nn.Dropout(dropout))
in_dim = h_dim
layers.append(nn.Linear(in_dim, output_dim))
self.classifier = nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask):
        # Encode the batch; input_ids / attention_mask are [batch, seq_len]
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        if self.pooling == 'mean':
            # Mean of token embeddings, ignoring padding positions
            mask = attention_mask.unsqueeze(-1).float()
            sum_emb = (outputs.last_hidden_state * mask).sum(1)
            features = sum_emb / mask.sum(1).clamp(min=1e-9)
        elif self.pooling == 'max':
            # Max over token embeddings, with padding masked to -inf
            mask = attention_mask.unsqueeze(-1).float()
            masked_emb = outputs.last_hidden_state.masked_fill(mask == 0, float('-inf'))
            features, _ = masked_emb.max(dim=1)
elif self.pooling == 'attention':
features = self.attention_pool(outputs.last_hidden_state, attention_mask)
        else:
            # CLS pooling: use pooler_output when the model provides one
            # (DeBERTa, for instance, does not), else the [CLS] token embedding
            pooler = getattr(outputs, 'pooler_output', None)
            features = pooler if pooler is not None else outputs.last_hidden_state[:, 0]
logits = self.classifier(features)
return logits


class AttentionPooling(nn.Module):
    """Learned attention pooling: a linear layer scores each token, and the
    softmax-normalized scores weight a sum over the sequence."""

    def __init__(self, hidden_size):
        super().__init__()
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, attention_mask):
        # hidden_states: [batch, seq_len, hidden]
        # attention_mask: [batch, seq_len]
        scores = self.attention(hidden_states).squeeze(-1)  # [batch, seq_len]
        # Mask padding positions before the softmax; finfo.min stays safe
        # under reduced precision, unlike a hard-coded -1e9
        scores = scores.masked_fill(attention_mask == 0, torch.finfo(scores.dtype).min)
weights = torch.softmax(scores, dim=-1) # [batch, seq_len]
weighted_sum = torch.sum(hidden_states * weights.unsqueeze(-1), dim=1)
return weighted_sum
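

# Minimal usage sketch, assuming the default checkpoint above and a 5-way
# label set (the example sentences and shapes are illustrative, not from the
# original repo): tokenize a small batch with the matching tokenizer and run
# one forward pass to check the output shape.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
    model = BERT_FFNN(output_dim=5, pooling='attention')
    model.eval()

    batch = tokenizer(
        ["I am thrilled about this!", "This is so frustrating."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # expected: torch.Size([2, 5])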