| from transformers import PreTrainedModel | |
| from .configuration_cetacean_classifier import TemplateClassifierConfig | |
| from .model import TemplateClassifier | |
class TemplateClassifierModelForImageClassification(PreTrainedModel):
    """Expose ``TemplateClassifier`` through the ``PreTrainedModel`` interface.

    The wrapped classifier is built from the configuration's dictionary form
    and stored on ``self.model``; ``forward`` delegates to it and packages
    the result in a dictionary.
    """

    # Configuration class associated with this model.
    config_class = TemplateClassifierConfig

    def __init__(self, config):
        """Build the wrapped classifier from *config*.

        Args:
            config: Configuration object; it is handed to the base class
                unchanged and to ``TemplateClassifier`` as a plain dict
                obtained from ``config.to_dict()``.
        """
        super().__init__(config)
        config_dict = config.to_dict()
        classifier = TemplateClassifier(config=config_dict)
        self.model = classifier
        # NOTE(review): presumably this puts the classifier in inference
        # mode (torch.nn.Module.eval semantics) -- confirm that
        # TemplateClassifier is an nn.Module.
        self.model.eval()

    def forward(self, model_input):
        """Run the wrapped classifier on *model_input*.

        Args:
            model_input: Passed through to ``self.model`` unchanged.

        Returns:
            A dict with the single key ``"predictions"`` mapping to whatever
            the wrapped classifier returned.
        """
        return {"predictions": self.model(model_input)}