| """ | |
| Script file used for performing inference with an existing model. | |
| """ | |
| import torch | |
| import nltk | |
| from nltk.tokenize import sent_tokenize | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification | |
| ) | |
| from scripts.config import ( | |
| BIN_REPO, | |
| ML_REPO | |
| ) | |
class InferenceHandler:
    """Runs inference with the trained binary classification and multilabel regression models.

    Input text is split into sentences. Each sentence is first classified as
    discriminatory or not by the binary model; only sentences predicted as
    discriminatory are passed to the multilabel regression model, which yields
    a score per discrimination category.
    """

    # Category names, in the order of the regression model's output values.
    # NOTE(review): assumes this matches the label order used in training — confirm.
    _CATEGORIES = ("Gender", "Race", "Sexuality", "Disability", "Religion", "Unspecified")

    # Binary model class index -> human-readable label.
    _BIN_LABEL_MAP = {0: "Non-Discriminatory", 1: "Discriminatory"}

    def __init__(self, api_token: str):
        """Constructor for instantiating an InferenceHandler object.

        Loads the binary classification model and the multilabel regression
        model (each with its tokenizer) from ``BIN_REPO`` and ``ML_REPO``, and
        makes sure the NLTK ``punkt_tab`` data used for sentence splitting is
        available.

        Parameters
        ----------
        api_token : str
            A Hugging Face token that is passed to ``from_pretrained`` to access
            the model repositories.
        """
        self.api_token = api_token
        self.bin_tokenizer, self.bin_model = self._init_model_and_tokenizer(BIN_REPO)
        self.ml_regr_tokenizer, self.ml_regr_model = self._init_model_and_tokenizer(ML_REPO)
        self._ensure_sentence_tokenizer()

    @staticmethod
    def _ensure_sentence_tokenizer():
        """Download the NLTK ``punkt_tab`` data used by ``sent_tokenize`` only if it is missing."""
        try:
            nltk.data.find('tokenizers/punkt_tab')
        except LookupError:
            # Only hit the network when the resource is not installed locally.
            nltk.download('punkt_tab')

    def _init_model_and_tokenizer(self, repo_id: str):
        """Load a tokenizer and sequence-classification model from a Hugging Face repository.

        Parameters
        ----------
        repo_id : str
            The repository id (i.e., <owner username>/<repository name>).

        Returns
        -------
        tuple[PreTrainedTokenizer | PreTrainedTokenizerFast, PreTrainedModel]
            A tuple containing the tokenizer and the model; the model is put in
            evaluation mode.
        """
        tokenizer = AutoTokenizer.from_pretrained(repo_id, token=self.api_token)
        model = AutoModelForSequenceClassification.from_pretrained(repo_id, token=self.api_token)
        # Inference only: disable training-time behavior such as dropout.
        model.eval()
        return tokenizer, model

    def _encode_binary(self, text: str):
        """Tokenize the input text for the binary classification model.

        Parameters
        ----------
        text : str
            The input text to be tokenized.

        Returns
        -------
        BatchEncoding
            PyTorch tensors for the text, truncated to at most 512 tokens.
        """
        return self.bin_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    def _encode_multilabel(self, text: str):
        """Tokenize the input text for the multilabel regression model.

        Parameters
        ----------
        text : str
            The input text to be tokenized.

        Returns
        -------
        BatchEncoding
            PyTorch tensors for the text, truncated to at most 512 tokens.
        """
        return self.ml_regr_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    def _encode_input(self, text: str):
        """Tokenize the input text for both models.

        Parameters
        ----------
        text : str
            The input text to be tokenized.

        Returns
        -------
        tuple[BatchEncoding, BatchEncoding]
            The encodings for the binary model and the multilabel regression model, in that order.
        """
        return self._encode_binary(text), self._encode_multilabel(text)

    def classify_text(self, input: str):
        """Classify each sentence of the input text and score discriminatory ones per category.

        The text is split into sentences with ``sent_tokenize``. Each sentence
        is classified by the binary model; if it is predicted discriminatory
        (class 1), the regression model is run on it to obtain the assessed
        score for each category, clamped to [0.0, 1.0].

        Parameters
        ----------
        input : str
            The input text to be classified.

        Returns
        -------
        dict[str, Any]
            ``{'text_input': input, 'results': [...]}`` where each result has the
            keys ``'sentence'``, ``'binary_classification'`` (with
            ``'classification'`` and ``'prediction_class'``) and
            ``'multilabel_regression'`` (a category -> score dict, or ``None``
            for non-discriminatory sentences).
        """
        # NOTE: `input` shadows the builtin, but the name is kept because callers may pass it by keyword.
        sent_res_arr = []
        for sent in sent_tokenize(input):
            text_prediction, pred_class = self.discriminatory_inference(sent)
            sent_result = {
                'sentence': sent,
                'binary_classification': {
                    'classification': text_prediction,
                    'prediction_class': pred_class,
                },
                'multilabel_regression': None,
            }
            if pred_class == 1:
                scores = self.category_inference(sent)
                # category_inference already floors scores at 0.0; also cap them at 1.0.
                # Indexing (rather than zip) keeps the IndexError if the model returns
                # fewer values than there are categories.
                sent_result['multilabel_regression'] = {
                    category: min(max(scores[idx], 0.0), 1.0)
                    for idx, category in enumerate(self._CATEGORIES)
                }
            sent_res_arr.append(sent_result)
        return {'text_input': input, 'results': sent_res_arr}

    def discriminatory_inference(self, text: str):
        """Run the binary classification model on the input text.

        Parameters
        ----------
        text : str
            The input text to be classified.

        Returns
        -------
        tuple[str, int]
            The string classification ("Discriminatory" or "Non-Discriminatory")
            and the numeric prediction class (1 or 0).
        """
        bin_inputs = self._encode_binary(text)
        with torch.no_grad():
            bin_logits = self.bin_model(**bin_inputs).logits
        # The tokenizer encodes a single text, so take row 0 of the batch. Softmax is
        # monotonic, so the argmax of the logits equals the argmax of the probabilities.
        pred_class = torch.argmax(bin_logits[0], dim=-1).item()
        return self._BIN_LABEL_MAP[pred_class], pred_class

    def category_inference(self, text: str):
        """Run the multilabel regression model on the input text.

        Parameters
        ----------
        text : str
            The input text to be scored.

        Returns
        -------
        list[float]
            One regression value per model output, floored at 0.0 (no upper bound is applied here).
        """
        ml_inputs = self._encode_multilabel(text)
        with torch.no_grad():
            ml_outputs = self.ml_regr_model(**ml_inputs).logits
        # Take row 0 of the batch instead of squeeze(): squeeze() collapses a
        # single-label output to a scalar whose tolist() is a float, not a list.
        return [max(0.0, value) for value in ml_outputs[0].tolist()]