import torch
from transformers import Pipeline


class MyPipeline(Pipeline):
    """Custom ``transformers.Pipeline``: ``input_ids`` -> model -> softmax.

    Overrides the ``Pipeline`` hooks ``_sanitize_parameters``, ``preprocess``,
    ``_forward`` and ``postprocess``.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keyword arguments into per-stage kwargs.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)`` tuple.
            Only ``maybe_arg`` is recognised; it is routed to ``preprocess``.
            The other two dicts are always empty.
        """
        preprocess_kwargs = {}
        if "maybe_arg" in kwargs:
            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
        forward_kwargs = {}
        postprocess_kwargs = {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, inputs, maybe_arg=2):
        """Convert ``inputs["input_ids"]`` into a tensor for the model.

        Args:
            inputs: Mapping with an ``"input_ids"`` entry. It can be anything
                ``torch.as_tensor`` accepts, e.g. a (nested) list of ints.
            maybe_arg: Optional argument routed here by
                ``_sanitize_parameters``. It is currently unused.

        Returns:
            ``{"model_input": tensor}``, which ``_forward`` unpacks as keyword
            arguments to the model.
        """
        # as_tensor keeps integer ids as an integer dtype (torch.Tensor(...)
        # would cast them to float32). It also avoids copying an input that is
        # already a tensor of the right dtype/device.
        model_input = torch.as_tensor(inputs["input_ids"])
        return {"model_input": model_input}

    def _forward(self, model_inputs):
        """Run ``self.model`` on the dict produced by ``preprocess``.

        Returns:
            Whatever ``self.model`` returns. ``postprocess`` indexes it with
            ``["logits"]``.
        """
        # NOTE(review): the dict keys become keyword arguments, so this calls
        # self.model(model_input=...). Standard transformers models expect
        # ``input_ids``, so confirm that self.model accepts ``model_input``.
        outputs = self.model(**model_inputs)
        return outputs

    def postprocess(self, model_outputs):
        """Return the softmax over the last axis of ``model_outputs["logits"]``.

        NOTE(review): this returns the full probability distribution, not the
        index of the best class. Append ``.argmax(-1)`` if a single label is
        wanted.
        """
        probabilities = model_outputs["logits"].softmax(-1)
        return probabilities