|
|
|
|
|
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoTokenizer, AutoModelForCausalLM |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
class EOUConfig(PretrainedConfig):
    """Configuration for the end-of-utterance (EOU) classification model.

    Holds the name of the causal-LM backbone checkpoint plus the hidden
    size and dropout rate consumed by the classification head.
    """

    model_type = "eou_model"

    def __init__(
        self,
        base_model_name="HuggingFaceTB/SmolLM2-135M",
        hidden_size=576,
        dropout_rate=0.1,
        **kwargs
    ):
        # Hand any remaining HF config fields to the base class first,
        # then record our model-specific settings.
        super().__init__(**kwargs)
        self.base_model_name = base_model_name  # backbone checkpoint id
        self.hidden_size = hidden_size          # backbone hidden dim (head input)
        self.dropout_rate = dropout_rate        # dropout inside the head MLP
|
|
|
|
|
class EOUModelForHF(PreTrainedModel):
    """HF-compatible end-of-utterance model: causal-LM backbone + binary head.

    The head reads the hidden state of the last non-padding token of each
    sequence and emits a single logit indicating utterance completion.
    """

    config_class = EOUConfig

    def __init__(self, config):
        """Build the backbone LM and the two-layer classification head.

        Args:
            config: ``EOUConfig`` providing ``base_model_name``,
                ``hidden_size`` and ``dropout_rate``.
        """
        super().__init__(config)
        # NOTE: PreTrainedModel.__init__ already stores `config` as
        # `self.config`; the original's extra `self.config = config`
        # assignment was redundant and has been dropped.

        # Backbone LM; only its hidden states are consumed, not its LM logits.
        self.base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name)

        # Two-layer MLP producing one completion logit per sequence.
        self.completion_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.hidden_size // 2, 1),
        )

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Score each sequence for end-of-utterance.

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: optional (batch, seq) mask; 1 for real tokens,
                0 for padding. Defaults to all-ones when omitted.
            labels: optional (batch,) binary targets for BCE training.

        Returns:
            Dict with 'loss' (None when no labels), 'logits'
            ((batch, 1) raw completion logits) and 'hidden_states'
            (all backbone layer outputs).
        """
        # Fix: the original crashed with AttributeError when attention_mask
        # was None despite declaring it optional. Treat "no mask" as
        # "every token is real".
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )

        # Final transformer layer activations: (batch, seq, hidden).
        last_hidden_states = outputs.hidden_states[-1]

        # Index of the last non-padding token per sequence.
        batch_size = last_hidden_states.shape[0]
        sequence_lengths = attention_mask.sum(dim=1) - 1

        # Fix: build the gather index on the activations' device — the
        # original's bare torch.arange(batch_size) lived on CPU and broke
        # GPU execution.
        batch_idx = torch.arange(batch_size, device=last_hidden_states.device)
        last_token_hidden = last_hidden_states[batch_idx, sequence_lengths]

        completion_prob = self.completion_head(last_token_hidden)

        loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            # Fix: squeeze(-1), not squeeze() — a bare squeeze() also drops
            # the batch dimension when batch_size == 1, producing a scalar
            # that no longer matches `labels` of shape (1,).
            loss = loss_fct(completion_prob.squeeze(-1), labels.float())

        return {
            'loss': loss,
            'logits': completion_prob,
            'hidden_states': outputs.hidden_states
        }
|
|
|
|
|
AutoConfig.register("eou_model", EOUConfig) |
|
|
|