from transformers import PreTrainedModel

# NOTE(review): the config module is named "cetacean_classifier" while the
# classes are "Template*"; confirm this is intentional (e.g. a renamed template).
from .configuration_cetacean_classifier import TemplateClassifierConfig
from .model import TemplateClassifier


class TemplateClassifierModelForImageClassification(PreTrainedModel):
    """Expose ``TemplateClassifier`` through the ``transformers`` ``PreTrainedModel`` API.

    The wrapped classifier is stored as ``self.model`` and built from a plain
    dictionary form of the ``TemplateClassifierConfig``.
    """

    # Tells PreTrainedModel which configuration class belongs to this model.
    config_class = TemplateClassifierConfig

    def __init__(self, config):
        """Build the wrapped classifier from *config* and put it in eval mode.

        Args:
            config: A ``TemplateClassifierConfig`` instance; it is handed to
                ``PreTrainedModel.__init__`` and, serialized with
                ``to_dict()``, to ``TemplateClassifier``.
        """
        super().__init__(config)
        # TemplateClassifier takes its settings as a plain dict, not the config object.
        config_dict = config.to_dict()
        self.model = TemplateClassifier(config=config_dict)
        # Presumably TemplateClassifier is a torch.nn.Module; eval() switches it
        # to inference behavior. Verify against .model.
        self.model.eval()

    def forward(self, model_input):
        """Run the wrapped classifier on *model_input*.

        Args:
            model_input: Passed unchanged to ``TemplateClassifier``; its expected
                type/shape is defined there (not visible here).

        Returns:
            dict: ``{"predictions": <output of the wrapped classifier>}``.
        """
        return {"predictions": self.model(model_input)}