AhmedRabie01's picture
Rename Sentiment/ml/model/multitask_bert.pyimport torch import torch.nn as nn from transformers import AutoConfig, AutoModel class MultiTaskBert(nn.Module): def __init__( self, model_name: str, n_sentiment: int, n_intent: int, n_topic: int, dropout: float = 0.2, init_from_pretrained: bool = False, # IMPORTANT ): super().__init__() self.model_name = model_name if init_from_pretrained: # trust_remote_code lets us load architectures (e.g., ModernBERT/mmBERT) # that may not be bundled with the installed transformers version. self.encoder = AutoModel.from_pretrained( model_name, trust_remote_code=True, ) else: config = AutoConfig.from_pretrained( model_name, trust_remote_code=True, ) self.encoder = AutoModel.from_config( config, trust_remote_code=True, ) hidden_size = self.encoder.config.hidden_size self.dropout = nn.Dropout(dropout) self.sentiment_head = nn.Linear(hidden_size, n_sentiment) self.intent_head = nn.Linear(hidden_size, n_intent) self.topic_head = nn.Linear(hidden_size, n_topic) def forward(self, input_ids, attention_mask): outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, ) pooled = outputs.last_hidden_state[:, 0] # CLS pooled = self.dropout(pooled) return ( self.sentiment_head(pooled), self.intent_head(pooled), self.topic_head(pooled), ) to Sentiment/ml/model/multitask_bert.py
da59a51 verified
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel
class MultiTaskBert(nn.Module):
    """Shared transformer encoder with three independent classification heads.

    The encoder's first-token ("CLS") hidden state is passed through dropout
    and then fed to three separate linear heads: sentiment, intent and topic.

    Args:
        model_name: Hugging Face model id or local path used to resolve the
            encoder config (and its weights when ``init_from_pretrained``).
        n_sentiment: Output size of the sentiment head.
        n_intent: Output size of the intent head.
        n_topic: Output size of the topic head.
        dropout: Dropout probability applied to the pooled representation.
        init_from_pretrained: If True, load the encoder's pretrained weights.
            If False (the default), only the config is loaded and the encoder
            is built with freshly initialised weights.
            NOTE(review): the False default presumably exists because a full
            fine-tuned ``state_dict`` is loaded into this module afterwards
            (e.g. at inference time) - confirm with callers.
        trust_remote_code: Forwarded to every ``transformers`` loading call.
            It lets architectures that ship their own modelling code (e.g.
            ModernBERT/mmBERT on a transformers version that does not bundle
            them) be loaded. Enabling it executes Python code from the model
            repository, so pass False for sources you do not trust.
            Defaults to True.
    """

    def __init__(
        self,
        model_name: str,
        n_sentiment: int,
        n_intent: int,
        n_topic: int,
        dropout: float = 0.2,
        init_from_pretrained: bool = False,
        *,
        trust_remote_code: bool = True,
    ) -> None:
        super().__init__()
        self.model_name = model_name

        if init_from_pretrained:
            self.encoder = AutoModel.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
            )
        else:
            # Only the config is fetched; the encoder weights start from the
            # architecture's random initialisation.
            config = AutoConfig.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
            )
            self.encoder = AutoModel.from_config(
                config,
                trust_remote_code=trust_remote_code,
            )

        hidden_size = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.sentiment_head = nn.Linear(hidden_size, n_sentiment)
        self.intent_head = nn.Linear(hidden_size, n_intent)
        self.topic_head = nn.Linear(hidden_size, n_topic)

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch and return the raw output of each task head.

        Args:
            input_ids: Token ids forwarded to the encoder. Assumed to be a
                (batch, seq_len) integer tensor - TODO confirm with callers.
            attention_mask: Optional padding mask forwarded unchanged to the
                encoder. ``None`` lets the encoder fall back to its own
                default; for Hugging Face encoders, that means attending to
                every token.

        Returns:
            Tuple ``(sentiment, intent, topic)`` of the unnormalised outputs
            of the three linear heads (no softmax is applied here).
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Take the first position along the sequence axis as the pooled
        # ("CLS") representation. This assumes last_hidden_state is
        # (batch, seq_len, hidden), as Hugging Face encoders return.
        pooled = self.dropout(outputs.last_hidden_state[:, 0])
        return (
            self.sentiment_head(pooled),
            self.intent_head(pooled),
            self.topic_head(pooled),
        )