Spaces:
Runtime error
Runtime error
| import math | |
| import os | |
| import random | |
| from functools import lru_cache | |
| import numpy as np | |
| from PIL import Image | |
| from huggingface_hub import HfFileSystem, HfApi | |
| from imgutils.utils import open_onnx_model | |
| from natsort import natsorted | |
# Optional Hub credentials: os.environ.get yields None when HF_TOKEN is unset,
# and that None is handed to both clients unchanged.
hf_token = os.environ.get('HF_TOKEN')
hf_fs = HfFileSystem(token=hf_token)
hf_client = HfApi(token=hf_token)

REPOSITORY = 'mf666/shit-checker'

# Natural-sorted file stems (path relative to REPOSITORY, ".onnx" stripped) of
# every ONNX file at the top level of REPOSITORY. This listing runs at import
# time, so importing the module needs access to the Hub.
MODELS = natsorted(
    os.path.splitext(os.path.relpath(onnx_path, REPOSITORY))[0]
    for onnx_path in hf_fs.glob(f'{REPOSITORY}/*.onnx')
)
DEFAULT_MODEL = 'mobilenet.xs.v2'
@lru_cache()
def _open_model(model_name):
    """Fetch ``<model_name>.onnx`` from ``REPOSITORY`` and open it with ``open_onnx_model``.

    Memoized per ``model_name`` so that repeated predictions reuse the same opened
    model object instead of calling ``hf_hub_download`` and ``open_onnx_model``
    again on every request. Exceptions (e.g. an unknown model name) are not
    cached, so a failed lookup does not poison later calls.

    :param model_name: File stem of the model inside ``REPOSITORY``.
    :return: Whatever ``open_onnx_model`` returns (used by callers via ``.run``).
    """
    return open_onnx_model(hf_client.hf_hub_download(REPOSITORY, f'{model_name}.onnx'))
| _DEFAULT_ORDER = 'HWC' | |
| def _get_hwc_map(order_): | |
| return tuple(_DEFAULT_ORDER.index(c) for c in order_.upper()) | |
def _encode_channels(image, channels_order='CHW', is_float=True):
    """Convert a PIL image to an RGB numpy array laid out in ``channels_order``.

    :param image: PIL image; it is converted to RGB before encoding.
    :param channels_order: Target axis order, a permutation of ``'HWC'``
        (case-insensitive), resolved via ``_get_hwc_map``.
    :param is_float: If true, scale pixel values into ``[0, 1]`` as float32;
        otherwise keep the raw uint8 values.
    :return: The transposed pixel array.
    """
    hwc = np.asarray(image.convert('RGB'))
    pixels = hwc.transpose(_get_hwc_map(channels_order))
    if is_float:
        pixels = (pixels / 255.0).astype(np.float32)
        assert pixels.dtype == np.float32
    else:
        # RGB mode stores 8-bit channels, so no conversion is needed here.
        assert pixels.dtype == np.uint8
    return pixels
def _img_encode(image, size=(384, 384), normalize=(0.5, 0.5)):
    """Resize a PIL image and encode it as a float32 CHW array.

    :param image: PIL image to encode.
    :param size: ``(width, height)`` passed to ``Image.resize`` (bilinear).
    :param normalize: ``(mean, std)`` applied as ``(x - mean) / std`` to the
        ``[0, 1]`` pixel values; each entry may be a scalar or a per-channel
        sequence (reshaped to broadcast over the channel axis). ``None`` skips it.
    :return: float32 array in CHW order.
    """
    pixels = _encode_channels(image.resize(size, Image.BILINEAR), channels_order='CHW')
    if normalize is None:
        return pixels.astype(np.float32)

    mean_value, std_value = normalize
    shift = np.asarray([mean_value]).reshape((-1, 1, 1))
    scale = np.asarray([std_value]).reshape((-1, 1, 1))
    return ((pixels - shift) / scale).astype(np.float32)
def _raw_predict(images, model_name=DEFAULT_MODEL):
    """Run a model over a batch of images and average its output across the batch.

    :param images: Iterable of PIL images; each is converted to RGB and encoded
        with ``_img_encode`` before stacking. Must be non-empty (``np.stack``
        rejects an empty sequence).
    :param model_name: Model file stem inside ``REPOSITORY``.
    :return: The model's ``'output'`` tensor averaged over axis 0.
    """
    batch = np.stack([_img_encode(img.convert('RGB')) for img in images])
    session = _open_model(model_name)
    (output,) = session.run(['output'], {'input': batch})
    return output.mean(axis=0)
def predict(image, model_name=DEFAULT_MODEL, max_batch_size=8, *, rng=None):
    """Score an image by averaging model outputs over random square crops.

    The number of crops is ``ceil(area / 384**2) + 1``, clamped to
    ``[1, max_batch_size]``. Each crop is 384x384 at a random offset. On an
    axis where the image is shorter than 384 px, the crop spans the full extent
    (``_img_encode`` resizes every crop afterwards).

    :param image: PIL image to score.
    :param model_name: Model file stem inside ``REPOSITORY``.
    :param max_batch_size: Upper bound on the number of crops evaluated.
    :param rng: Optional ``random.Random`` instance for choosing crop offsets,
        which makes results reproducible. Defaults to the module-level ``random``
        functions (the original behavior).
    :return: ``{'shat': float, 'normal': float}`` taken from the averaged output.
    """
    crop = 384  # crop side length; equals _img_encode's default input size
    rng = random if rng is None else rng

    width, height = image.width, image.height
    batch_size = int(max(min(math.ceil(width * height / (crop * crop)) + 1, max_batch_size), 1))

    blocks = []
    for _ in range(batch_size):
        # Offsets are inclusive upper bounds, so the crop never exceeds the image.
        x0 = rng.randint(0, max(0, width - crop))
        y0 = rng.randint(0, max(0, height - crop))
        blocks.append(image.crop((x0, y0, min(x0 + crop, width), min(y0 + crop, height))))

    scores = _raw_predict(blocks, model_name)
    return {label: score.item() for label, score in zip(('shat', 'normal'), scores)}