import torch
from torch import nn
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
class BathSalt1DaedalusPhi3Model(nn.Module):
    """Thin ``nn.Module`` wrapper that delegates to a seq2seq language model.

    ``AutoModelForSeq2SeqLM`` is a factory: its constructor always raises and
    it can only be used through ``from_config`` / ``from_pretrained``.  It is
    therefore not usable as a base class, so this wrapper *owns* a model built
    by the factory and forwards calls to it.

    NOTE(review): the class name mentions "Phi3"; confirm that the
    configuration passed in maps to an architecture registered for
    ``AutoModelForSeq2SeqLM``, otherwise ``from_config`` will reject it.

    Attributes:
        config: The configuration object this wrapper was created with.
        model: The wrapped model that ``forward`` delegates to.
    """

    def __init__(self, config, model=None):
        """Create the wrapper.

        Args:
            config: Model configuration.  When ``model`` is not supplied it is
                handed to ``AutoModelForSeq2SeqLM.from_config`` to build the
                wrapped model.
            model: Optional already-built model to wrap instead of creating a
                new one from ``config`` (e.g. one loaded with pretrained
                weights).
        """
        super().__init__()
        self.config = config
        if model is None:
            model = AutoModelForSeq2SeqLM.from_config(config)
        self.model = model

    @classmethod
    def from_config(cls, config, **kwargs):
        """Build a wrapper around a freshly initialised model for ``config``.

        Extra keyword arguments are passed to
        ``AutoModelForSeq2SeqLM.from_config`` unchanged.
        """
        return cls(config, model=AutoModelForSeq2SeqLM.from_config(config, **kwargs))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build a wrapper around a model loaded by the auto class.

        All arguments are passed to ``AutoModelForSeq2SeqLM.from_pretrained``
        unchanged; the wrapper's ``config`` is taken from the loaded model.
        """
        inner = AutoModelForSeq2SeqLM.from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )
        return cls(inner.config, model=inner)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the wrapped model and return its output unchanged.

        Args:
            input_ids: Passed positionally as the wrapped model's first
                argument.
            attention_mask: Forwarded as the ``attention_mask`` keyword;
                ``None`` when the caller has no mask.
            labels: Forwarded as the ``labels`` keyword; ``None`` for
                inference-only calls.

        Returns:
            Whatever the wrapped model returns for these arguments.
        """
        return self.model(input_ids, attention_mask=attention_mask, labels=labels)