hazhu commited on
Commit
370ff13
·
verified ·
1 Parent(s): 786764f

Upload folder using huggingface_hub

Browse files
infer.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import numpy as np
3
+ from PIL import Image, ImageDraw
4
+
5
+ import mlx.core as mx
6
+ from mlxDeepDanBooru.mlx_deep_danbooru_model import mlxDeepDanBooruModel
7
+
8
+
9
+ model_path = "models/model-resnet_custom_v3_mlx.npz"
10
+ tags_path = 'models/tags-resnet_custom_v3_mlx.npy'
11
+
12
+ mlx_dan = mlxDeepDanBooruModel()
13
+ mlx_dan.load_weights(model_path)
14
+ mx.eval(mlx_dan.parameters())
15
+
16
+
17
+ model_tags = np.load(tags_path)
18
+ print(f'total tags: {len(model_tags)}')
19
+
20
+ def danbooru_tags(fpath):
21
+ tags = []
22
+ pic = Image.open(fpath).convert("RGB").resize((512, 512))
23
+ a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
24
+
25
+ x = mx.array(a)
26
+ y = mlx_dan(x)[0]
27
+
28
+ for n in range(10):
29
+ mlx_dan(x)
30
+ for i, p in enumerate(y):
31
+ if p >= 0.5:
32
+ #print(model_tags[i].item(), p)
33
+ tags.append(model_tags[i].item())
34
+
35
+ return tags
36
+
37
+ image_count = 0
38
+ def image_infer(fpath):
39
+ global image_count
40
+ tags = danbooru_tags(fpath)
41
+ image_count += 1
42
+ return tags
43
+
44
+
45
+ t1 = time.time()
46
+ tags_1 = image_infer("example/1.png")
47
+ tags_2 = image_infer("example/2.png")
48
+
49
+ t2 = time.time()
50
+
51
+ print(tags_1)
52
+ print(tags_2)
53
+ # print(tags_3)
54
+ # print(tags_4)
55
+ # print(tags_5)
56
+
57
+ print("-----------")
58
+ print(f'infer speed(with mlx): {(t2 - t1)/image_count} seconds per image')
59
+
60
+
61
+
62
+
63
+
64
+
models/model-resnet_custom_v3_mlx.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d07dbe2ec2e219c93545467c6109d8818b402a59adf50c8dc441cb8ca366d11
3
+ size 643902466
models/tags-resnet_custom_v3_mlx.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8cc7848ad89cbfeef8aa935f2ce370f84c40ba61cfc928852fa7b6e034e04430
3
+ size 1761920