# arabic-EOU / modeling_eou.py
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel


class EOUConfig(PretrainedConfig):
    model_type = "eou_model"

    def __init__(
        self,
        base_model_name="HuggingFaceTB/SmolLM2-135M",
        hidden_size=576,
        dropout_rate=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate


class EOUModelForHF(PreTrainedModel):
    config_class = EOUConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Load the base causal LM backbone
        self.base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name)
        # Binary classification head applied to the last token's hidden state
        self.completion_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.hidden_size // 2, 1),
        )

    def forward(self, input_ids, attention_mask=None, labels=None):
        # Default to attending over every token when no mask is provided
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Run the base model and request its hidden states
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )

        # Last layer's hidden states: (batch_size, seq_len, hidden_size)
        last_hidden_states = outputs.hidden_states[-1]

        # Index of the last non-padding token in each sequence
        batch_size = last_hidden_states.shape[0]
        sequence_lengths = attention_mask.sum(dim=1) - 1

        # Gather the last-token representation for each sequence in the batch
        last_token_hidden = last_hidden_states[
            torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths
        ]

        # Raw end-of-utterance logit (apply sigmoid to obtain a probability)
        completion_logits = self.completion_head(last_token_hidden)

        loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(completion_logits.squeeze(-1), labels.float())

        return {
            'loss': loss,
            'logits': completion_logits,
            'hidden_states': outputs.hidden_states,
        }


AutoConfig.register("eou_model", EOUConfig)
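

# --- Usage sketch (illustrative, not part of the uploaded file) ---
# A minimal example of instantiating the model directly and converting the
# classification head's logit into an end-of-utterance probability. The sample
# sentence is a placeholder, and downloading HuggingFaceTB/SmolLM2-135M
# requires network access.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = EOUConfig()
    model = EOUModelForHF(config)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)
    inputs = tokenizer("طيب شكرا، مع السلامة", return_tensors="pt")

    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )

    # Sigmoid turns the raw logit into a probability that the utterance is complete
    eou_probability = torch.sigmoid(outputs["logits"]).item()
    print(f"End-of-utterance probability: {eou_probability:.3f}")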