"""End-of-utterance (EOU) detection model for the HuggingFace Auto* API.

Wraps a causal-LM backbone (default: SmolLM2-135M) with a small MLP head
that emits one binary "utterance is complete" logit per sequence, read off
the hidden state of the last non-padding token.
"""

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
)


class EOUConfig(PretrainedConfig):
    """Configuration for :class:`EOUModelForHF`.

    Args:
        base_model_name: HF hub id of the causal-LM backbone.
        hidden_size: Hidden size of the backbone's final layer. Must match
            the loaded base model (576 is SmolLM2-135M's hidden size).
        dropout_rate: Dropout probability inside the classification head.
    """

    model_type = "eou_model"

    def __init__(
        self,
        base_model_name="HuggingFaceTB/SmolLM2-135M",
        hidden_size=576,
        dropout_rate=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate


class EOUModelForHF(PreTrainedModel):
    """Causal-LM backbone + MLP head producing one EOU logit per sequence."""

    config_class = EOUConfig

    def __init__(self, config):
        super().__init__(config)
        # NOTE: PreTrainedModel.__init__ already stores config; this explicit
        # assignment is kept for backward compatibility with callers that
        # relied on it.
        self.config = config

        # Backbone: a full causal LM, used here only for its hidden states.
        self.base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name)

        # Binary classification head: hidden -> hidden/2 -> 1 logit.
        self.completion_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.hidden_size // 2, 1),
        )

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Score each sequence for utterance completion.

        Args:
            input_ids: Token ids, shape (batch, seq_len).
            attention_mask: Optional 0/1 mask, shape (batch, seq_len).
                When omitted, all positions are treated as real tokens.
            labels: Optional binary targets, shape (batch,); enables loss.

        Returns:
            Dict with 'loss' (None when labels is None), 'logits' of shape
            (batch, 1), and the backbone's 'hidden_states' tuple.
        """
        # FIX: the original crashed with AttributeError when attention_mask
        # was omitted, despite its None default. No mask == no padding.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )

        # Final layer's hidden states: (batch, seq_len, hidden).
        last_hidden_states = outputs.hidden_states[-1]

        # Index of the last real (non-padding) token in each sequence.
        batch_size = last_hidden_states.shape[0]
        sequence_lengths = attention_mask.sum(dim=1) - 1

        # FIX: build the batch-index tensor on the same device as the hidden
        # states so advanced indexing works off-CPU.
        last_token_hidden = last_hidden_states[
            torch.arange(batch_size, device=last_hidden_states.device),
            sequence_lengths,
        ]

        # One raw logit per sequence; callers apply sigmoid themselves.
        completion_prob = self.completion_head(last_token_hidden)

        loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            # FIX: squeeze only the trailing dim. A bare .squeeze() collapsed
            # a batch of 1 to a 0-d tensor, breaking BCEWithLogitsLoss's
            # shape match against (1,)-shaped labels.
            loss = loss_fct(completion_prob.squeeze(-1), labels.float())

        return {
            'loss': loss,
            'logits': completion_prob,
            'hidden_states': outputs.hidden_states,
        }


# Make the config discoverable via AutoConfig for model_type "eou_model".
AutoConfig.register("eou_model", EOUConfig)