# Label table: string index -> class name.
# NOTE(review): presumably keyed by the classifier's output index - confirm.
classes = {
    "0": "nsfw_gore",
    "1": "nsfw_suggestive",
    "2": "safe",
}

# Local file name of the weights checkpoint; SafeSearchModel downloads to
# this path when the file is missing and loads its state dict from it.
model_path = "safesearch_mini_v2.bin"
| from transformers import PretrainedConfig, PreTrainedModel | |
class SafeSearchConfig(PretrainedConfig):
    """Configuration container for the ``safesearch_mini_v2`` model.

    Every argument is stored on the instance under the same name; this class
    does not interpret or validate any of them. Remaining keyword arguments
    are forwarded to ``PretrainedConfig.__init__``.

    Args:
        model_name: Name of the model.
        input_channels: Number of input channels.
        num_classes: Number of output classes.
        input_size: Input size as a list (default ``[3, 299, 299]``).
        pool_size: Pool size as a list (default ``[8, 8]``).
        crop_pct: Crop fraction.
        interpolation: Interpolation mode name.
        mean: Per-channel mean values.
        std: Per-channel standard-deviation values.
        first_conv: Name of the first convolution layer.
        classifier: Name of the classifier.
        has_aux: Auxiliary-head flag.
        label_offset: Label offset.
        classes: Label table; defaults to the module-level ``classes`` dict.
        output_channels: Number of output channels.
        device: Device string (default ``"cpu"``).
        **kwargs: Passed through to ``PretrainedConfig``.
    """

    model_type = "safesearch_mini_v2"

    def __init__(
        self,
        model_name: str = "safesearch_mini_v2",
        input_channels: int = 3,
        num_classes: int = 3,
        input_size: list = [3, 299, 299],
        pool_size: list = [8, 8],
        crop_pct: float = 0.875,
        interpolation: str = "bicubic",
        mean: list = [0.5, 0.5, 0.5],
        std: list = [0.5, 0.5, 0.5],
        first_conv: str = "conv2d_1a.conv",
        classifier: str = "default",
        has_aux: bool = False,
        label_offset: int = 0,
        classes: object = classes,
        output_channels: int = 1536,
        device: str = "cpu",
        **kwargs,
    ):
        # Local import keeps this block self-contained.
        import copy

        self.model_name = model_name
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.crop_pct = crop_pct
        self.interpolation = interpolation
        self.first_conv = first_conv
        self.classifier = classifier
        self.has_aux = has_aux
        self.label_offset = label_offset
        self.output_channels = output_channels
        self.device = device

        # The list defaults in the signature and the module-level ``classes``
        # dict are single shared objects. Store private copies so that
        # mutating one config (e.g. ``cfg.mean.append(...)``) cannot silently
        # change the defaults seen by every config created afterwards.
        self.input_size = copy.deepcopy(input_size)
        self.pool_size = copy.deepcopy(pool_size)
        self.mean = copy.deepcopy(mean)
        self.std = copy.deepcopy(std)
        self.classes = copy.deepcopy(classes)

        # Called last, as in the original, so the base class handles the
        # remaining keyword arguments after the custom attributes are set.
        super().__init__(**kwargs)
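# One-off export snippet, intentionally disabled (kept for reference only).
# Running it builds a SafeSearchConfig with default values and calls
# save_pretrained with the target "safesearch_config":
#
#     safesearch_config = SafeSearchConfig()
#     safesearch_config.save_pretrained("safesearch_config")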
| import torch, os, timm | |
class SafeSearchModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a timm ``inception_resnet_v2`` network.

    On construction the weights file at ``model_path`` is downloaded if it
    does not exist yet, then loaded into the timm model as a state dict.
    """

    config_class = SafeSearchConfig

    def __init__(self, config: SafeSearchConfig):
        """Build the network and load its weights.

        Args:
            config: Model configuration. Only ``config.device`` is read here,
                as the ``map_location`` for loading the checkpoint.

        Raises:
            Whatever ``urlretrieve``, ``torch.load`` or ``load_state_dict``
            raise on download or load failure; nothing is swallowed.
        """
        super().__init__(config)

        if not os.path.exists(model_path):
            from urllib.request import urlretrieve

            # Download to a sibling ".part" file and rename it only once the
            # transfer has finished. Writing straight to ``model_path`` would
            # leave a truncated file behind after an interrupted download,
            # which then passes the existence check on the next run and
            # fails later inside ``torch.load``.
            partial_path = f"{model_path}.part"
            try:
                urlretrieve(
                    "https://huggingface.co/FredZhang7/google-safesearch-mini-v2/resolve/main/pytorch_model.bin",
                    partial_path,
                )
                os.replace(partial_path, model_path)
            finally:
                # After a successful rename the partial file is gone; this
                # only removes leftovers from a failed or aborted download.
                if os.path.exists(partial_path):
                    os.remove(partial_path)

        # num_classes is fixed at 3 here, independent of config.num_classes.
        self.model = timm.create_model("inception_resnet_v2", pretrained=False, num_classes=3)

        # NOTE(security): torch.load unpickles the file, and the file comes
        # from the network. Consider ``weights_only=True`` if the minimum
        # supported torch version allows it - left unchanged here to avoid
        # breaking older installs.
        state_dict = torch.load(model_path, map_location=torch.device(config.device))
        self.model.load_state_dict(state_dict)

    def forward(self, input_ids: torch.Tensor):
        """Run the wrapped timm model on ``input_ids`` and return its output.

        NOTE(review): despite the name, the tensor is passed unchanged to the
        image network - presumably an image batch; confirm with callers.
        """
        return self.model(input_ids)