| from transformers import TextClassificationPipeline, AutoTokenizer |
|
|
class CustomTextClassificationPipeline(TextClassificationPipeline):
    """Text-classification pipeline returning one class index per category.

    The pipeline tokenizes the input, runs the model, takes the argmax over
    the last logits axis and maps the resulting indices onto ``CATEGORIES``.

    NOTE(review): ``postprocess`` needs exactly ``len(CATEGORIES)`` predictions
    per input, i.e. logits that look like ``(1, len(CATEGORIES), num_classes)``
    -- confirm against the model head this pipeline is used with.
    """

    # Output keys; the i-th prediction is reported under the i-th name.
    CATEGORIES = (
        "Race/Origin",
        "Gender/Sex",
        "Religion",
        "Ability",
        "Violence",
        "Other",
    )

    # Call-time keyword arguments that are handed to the tokenizer.
    _TOKENIZER_KWARGS = ("truncation", "max_length", "padding")

    def __init__(self, model, tokenizer=None, **kwargs):
        """Build the pipeline, loading a tokenizer when none is supplied.

        Args:
            model: The model to run; its ``config.name_or_path`` is used to
                locate a tokenizer when ``tokenizer`` is ``None``.
            tokenizer: Optional tokenizer instance.
            **kwargs: Forwarded unchanged to ``TextClassificationPipeline``.
        """
        if tokenizer is None:
            # Public accessor for the checkpoint the model was loaded from.
            tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)

        # The base initialiser stores the tokenizer on the instance, so it is
        # not assigned separately here.
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Split keyword arguments into preprocess/forward/postprocess dicts.

        Only the tokenizer options listed in ``_TOKENIZER_KWARGS`` are
        recognised; they are routed to ``preprocess``. Everything else is
        ignored, and the forward and postprocess dicts are always empty.
        """
        preprocess_kwargs = {
            name: kwargs[name] for name in self._TOKENIZER_KWARGS if name in kwargs
        }
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs, **tokenizer_kwargs):
        """Tokenize ``inputs`` into PyTorch tensors.

        Truncation stays disabled unless the caller passes ``truncation``
        explicitly, so inputs longer than the model's limit are not shortened
        by default.
        """
        options = {"truncation": False, **tokenizer_kwargs}
        return self.tokenizer(inputs, return_tensors="pt", **options)

    def _forward(self, model_inputs):
        """Run the model on the tokenized inputs and return its raw output."""
        input_ids = model_inputs["input_ids"]

        # Prefer the tokenizer's own mask: comparing ids against 0 is only
        # right when the pad token id is 0, and otherwise hides a real token.
        attention_mask = model_inputs.get("attention_mask")
        if attention_mask is None:
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                # No pad token defined, so every position is a real token.
                attention_mask = input_ids.new_ones(input_ids.shape)
            else:
                attention_mask = input_ids.ne(pad_token_id).long()

        return self.model(input_ids=input_ids, attention_mask=attention_mask)

    def postprocess(self, model_outputs):
        """Map the per-category argmax of the logits onto ``CATEGORIES``.

        Returns:
            A dict from category name to predicted class index.

        Raises:
            ValueError: If the logits do not yield exactly one prediction per
                category (for example a plain ``(1, num_labels)`` head or a
                batch of more than one item).
        """
        # Flatten instead of squeeze so a single prediction stays a list and
        # the length check below can report the mismatch clearly.
        predictions = model_outputs.logits.argmax(dim=-1).reshape(-1).tolist()
        if len(predictions) != len(self.CATEGORIES):
            raise ValueError(
                f"Expected {len(self.CATEGORIES)} per-category predictions, "
                f"got {len(predictions)} from logits of shape "
                f"{tuple(model_outputs.logits.shape)}."
            )
        return dict(zip(self.CATEGORIES, predictions))
|
|
|
|