Plana-Archive commited on
Commit
9a50d38
·
verified ·
1 Parent(s): 9a91581

Upload monochrome_detection/monochrome.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. monochrome_detection/monochrome.py +60 -0
monochrome_detection/monochrome.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+ from typing import Optional, Tuple
3
+
4
+ import numpy as np
5
+ from PIL import Image
6
+ from PIL.Image import Resampling
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ from encode import rgb_encode
10
+ from image import ImageTyping, load_image
11
+ from onnxruntime_ import open_onnx_model
12
+
13
+ __all__ = [
14
+ 'get_monochrome_score',
15
+ 'is_monochrome',
16
+ ]
17
+
18
+ # _DEFAULT_MONOCHROME_CKPT = 'monochrome-resnet18-safe2-450.onnx'
19
+ _MONOCHROME_CKPTS = [
20
+ 'mobilenetv3_large_100_safe2',
21
+ 'mobilenetv3_large_100',
22
+ 'caformer_s36',
23
+ ]
24
+ _DEFAULT_MONOCHROME_CKPT = _MONOCHROME_CKPTS[0]
25
+
26
+
27
+ @lru_cache()
28
+ def _monochrome_validate_model(model):
29
+ return open_onnx_model(hf_hub_download(
30
+ f'deepghs/monochrome_detect',
31
+ f'{model}/model.onnx'
32
+ ))
33
+
34
+
35
+ def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
36
+ normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
37
+ if image.mode != 'RGB':
38
+ image = image.convert('RGB')
39
+ image = image.resize(size, Resampling.BILINEAR)
40
+ data = rgb_encode(image, order_='CHW')
41
+
42
+ if normalize is not None:
43
+ mean_, std_ = normalize
44
+ mean = np.asarray([mean_]).reshape((-1, 1, 1))
45
+ std = np.asarray([std_]).reshape((-1, 1, 1))
46
+ data = (data - mean) / std
47
+
48
+ return data
49
+
50
+
51
+ def get_monochrome_score(image: ImageTyping, model: str = _DEFAULT_MONOCHROME_CKPT):
52
+ image = load_image(image, mode='RGB')
53
+ input_data = _2d_encode(image).astype(np.float32)
54
+ input_data = np.stack([input_data])
55
+ output_data, = _monochrome_validate_model(model).run(['output'], {'input': input_data})
56
+ return {name: v.item() for name, v in zip(['monochrome', 'normal'], output_data[0])}
57
+
58
+
59
+ def is_monochrome(image: ImageTyping, threshold: float = 0.5, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> bool:
60
+ return get_monochrome_score(image, ckpt) >= threshold