Spaces:
Sleeping
Sleeping
| import torch | |
| import transformers | |
# Base URL under which all released detoxify checkpoints are published.
DOWNLOAD_URL = "https://github.com/unitaryai/detoxify/releases/download/"

# Mapping of public model-type names to their downloadable checkpoint URLs.
# The "-small" variants are lightweight ALBERT-based versions.
MODEL_URLS = {
    "original": DOWNLOAD_URL + "v0.1-alpha/toxic_original-c1212f89.ckpt",
    "unbiased": DOWNLOAD_URL + "v0.3-alpha/toxic_debiased-c7548aa0.ckpt",
    "multilingual": DOWNLOAD_URL + "v0.4-alpha/multilingual_debiased-0b549669.ckpt",
    "original-small": DOWNLOAD_URL + "v0.1.2/original-albert-0e1d6498.ckpt",
    "unbiased-small": DOWNLOAD_URL + "v0.1.2/unbiased-albert-c8519128.ckpt",
}

# Default checkpoint path for Detoxify(); None means "download by model type".
PRETRAINED_MODEL = None
def get_model_and_tokenizer(
    model_type, model_name, tokenizer_name, num_classes, state_dict, huggingface_config_path=None
):
    """Build a transformers model and its tokenizer from a checkpoint state dict.

    Args:
        model_type: HF hub id (or local path) used to resolve the config/tokenizer.
        model_name: class name looked up on the ``transformers`` module.
        tokenizer_name: tokenizer class name looked up on ``transformers``.
        num_classes: number of output labels for the classification head.
        state_dict: weights to load into the instantiated model.
        huggingface_config_path: optional local path with config/tokenizer files;
            when given, loading is forced to be fully offline.

    Returns:
        Tuple of ``(model, tokenizer)``.
    """
    offline = huggingface_config_path is not None
    config_source = huggingface_config_path or model_type

    model_class = getattr(transformers, model_name)
    tokenizer_class = getattr(transformers, tokenizer_name)

    model = model_class.from_pretrained(
        pretrained_model_name_or_path=None,
        config=config_source,
        num_labels=num_classes,
        state_dict=state_dict,
        local_files_only=offline,
    )
    tokenizer = tokenizer_class.from_pretrained(
        config_source,
        local_files_only=offline,
        # TODO: may be needed to let it work with Kaggle competition
        # model_max_length=512,
    )
    return model, tokenizer
def load_checkpoint(model_type="original", checkpoint=None, device="cpu", huggingface_config_path=None):
    """Load a detoxify checkpoint and build its model, tokenizer and class names.

    Args:
        model_type: key into ``MODEL_URLS`` used when no explicit checkpoint is given.
        checkpoint: optional local checkpoint path; overrides ``model_type`` download.
        device: torch device (string or ``torch.device``) used as ``map_location``.
        huggingface_config_path: optional local path for offline config/tokenizer files.

    Returns:
        Tuple of ``(model, tokenizer, class_names)``.

    Raises:
        ValueError: if a local checkpoint lacks the ``config`` or ``state_dict`` keys.
    """
    if checkpoint is None:
        checkpoint_path = MODEL_URLS[model_type]
        loaded = torch.hub.load_state_dict_from_url(checkpoint_path, map_location=device)
    else:
        loaded = torch.load(checkpoint, map_location=device)
        if "config" not in loaded or "state_dict" not in loaded:
            # BUGFIX: the message previously used a backslash line-continuation
            # *inside* the string literal, embedding a long run of spaces into
            # the runtime error text; it is now a clean single-spaced sentence.
            raise ValueError(
                "Checkpoint needs to contain the config it was trained with as well as the state dict"
            )
    class_names = loaded["config"]["dataset"]["args"]["classes"]
    # standardise class names between models
    change_names = {
        "toxic": "toxicity",
        "identity_hate": "identity_attack",
        "severe_toxic": "severe_toxicity",
    }
    class_names = [change_names.get(cl, cl) for cl in class_names]
    model, tokenizer = get_model_and_tokenizer(
        **loaded["config"]["arch"]["args"],
        state_dict=loaded["state_dict"],
        huggingface_config_path=huggingface_config_path,
    )
    return model, tokenizer, class_names
def load_model(model_type, checkpoint=None):
    """Load only the model for ``model_type``, discarding tokenizer and class names.

    When ``checkpoint`` is given it takes precedence over the downloadable
    ``model_type`` checkpoint.
    """
    if checkpoint is None:
        loaded = load_checkpoint(model_type=model_type)
    else:
        loaded = load_checkpoint(checkpoint=checkpoint)
    return loaded[0]
| class Detoxify: | |
| """Detoxify | |
| Easily predict if a comment or list of comments is toxic. | |
| Can initialize 5 different model types from model type or checkpoint path: | |
| - original: | |
| model trained on data from the Jigsaw Toxic Comment | |
| Classification Challenge | |
| - unbiased: | |
| model trained on data from the Jigsaw Unintended Bias in | |
| Toxicity Classification Challenge | |
| - multilingual: | |
| model trained on data from the Jigsaw Multilingual | |
| Toxic Comment Classification Challenge | |
| - original-small: | |
| lightweight version of the original model | |
| - unbiased-small: | |
| lightweight version of the unbiased model | |
| Args: | |
| model_type(str): model type to be loaded, can be either original, | |
| unbiased or multilingual | |
| checkpoint(str): checkpoint path, defaults to None | |
| device(str or torch.device): accepts any torch.device input or | |
| torch.device object, defaults to cpu | |
| huggingface_config_path: path to HF config and tokenizer files needed for offline model loading | |
| Returns: | |
| results(dict): dictionary of output scores for each class | |
| """ | |
| def __init__(self, model_type="original", checkpoint=PRETRAINED_MODEL, device="cpu", huggingface_config_path=None): | |
| super().__init__() | |
| self.model, self.tokenizer, self.class_names = load_checkpoint( | |
| model_type=model_type, | |
| checkpoint=checkpoint, | |
| device=device, | |
| huggingface_config_path=huggingface_config_path, | |
| ) | |
| self.device = device | |
| self.model.to(self.device) | |
| def predict(self, text): | |
| self.model.eval() | |
| inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(self.model.device) | |
| out = self.model(**inputs)[0] | |
| scores = torch.sigmoid(out).cpu().detach().numpy() | |
| results = {} | |
| for i, cla in enumerate(self.class_names): | |
| results[cla] = ( | |
| scores[0][i] if isinstance(text, str) else [scores[ex_i][i].tolist() for ex_i in range(len(scores))] | |
| ) | |
| return results | |
def toxic_bert():
    """Return the ``original`` (BERT-based) toxicity model."""
    return load_model(model_type="original")
def toxic_albert():
    """Return the ``original-small`` (ALBERT-based) lightweight toxicity model."""
    return load_model(model_type="original-small")
def unbiased_toxic_roberta():
    """Return the ``unbiased`` (RoBERTa-based) debiased toxicity model."""
    return load_model(model_type="unbiased")
def unbiased_albert():
    """Return the ``unbiased-small`` (ALBERT-based) lightweight debiased model."""
    return load_model(model_type="unbiased-small")
def multilingual_toxic_xlm_r():
    """Return the ``multilingual`` (XLM-R-based) toxicity model."""
    return load_model(model_type="multilingual")