hazhu commited on
Commit
aebb6ba
·
verified ·
1 Parent(s): 5228ac6

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. infer.py +33 -21
  2. infer_multiprocessing.py +93 -0
infer.py CHANGED
@@ -1,13 +1,21 @@
 
1
  import time
 
2
  import numpy as np
3
  from PIL import Image, ImageDraw
4
 
5
  import mlx.core as mx
6
  from mlxDeepDanBooru.mlx_deep_danbooru_model import mlxDeepDanBooruModel
7
 
 
 
8
 
9
- model_path = "models/model-resnet_custom_v3_mlx.npz"
10
- tags_path = 'models/tags-resnet_custom_v3_mlx.npy'
 
 
 
 
11
 
12
  mlx_dan = mlxDeepDanBooruModel()
13
  mlx_dan.load_weights(model_path)
@@ -15,50 +23,54 @@ mx.eval(mlx_dan.parameters())
15
 
16
 
17
  model_tags = np.load(tags_path)
18
- print(f'total tags: {len(model_tags)}')
19
 
20
  def danbooru_tags(fpath):
 
21
  tags = []
 
22
  pic = Image.open(fpath).convert("RGB").resize((512, 512))
23
  a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
24
 
25
  x = mx.array(a)
26
  y = mlx_dan(x)[0]
27
 
28
- for n in range(10):
29
- mlx_dan(x)
30
- for i, p in enumerate(y):
31
- if p >= 0.5:
32
- #print(model_tags[i].item(), p)
33
- tags.append(model_tags[i].item())
 
 
 
 
 
 
34
 
35
- return tags
36
 
37
- image_count = 0
38
  def image_infer(fpath):
39
- global image_count
40
  tags = danbooru_tags(fpath)
41
- image_count += 1
42
  return tags
43
 
44
-
45
  t1 = time.time()
46
- tags_1 = image_infer("example/1.png")
47
- tags_2 = image_infer("example/2.png")
 
48
 
49
  t2 = time.time()
50
 
51
  print(tags_1)
52
  print(tags_2)
53
- # print(tags_3)
54
- # print(tags_4)
55
- # print(tags_5)
56
 
57
- print("-----------")
58
- print(f'infer speed(with mlx): {(t2 - t1)/image_count} seconds per image')
 
 
59
 
60
 
61
 
 
62
 
63
 
64
 
 
1
+ import os
2
  import time
3
+ import glob
4
  import numpy as np
5
  from PIL import Image, ImageDraw
6
 
7
  import mlx.core as mx
8
  from mlxDeepDanBooru.mlx_deep_danbooru_model import mlxDeepDanBooruModel
9
 
10
+ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed, wait, FIRST_COMPLETED
11
+ from copy import deepcopy
12
 
13
+ ROOTDIR = os.path.dirname(os.path.abspath(__file__))
14
+ IMAGEDIR = f'{ROOTDIR}/example'
15
+
16
+
17
+ model_path = f"{ROOTDIR}/models/model-resnet_custom_v3_mlx.npz"
18
+ tags_path = f'{ROOTDIR}/models/tags-resnet_custom_v3_mlx.npy'
19
 
20
  mlx_dan = mlxDeepDanBooruModel()
21
  mlx_dan.load_weights(model_path)
 
23
 
24
 
25
  model_tags = np.load(tags_path)
26
+ #print(f'total tags: {len(model_tags)}')
27
 
28
  def danbooru_tags(fpath):
29
+ results = {}
30
  tags = []
31
+
32
  pic = Image.open(fpath).convert("RGB").resize((512, 512))
33
  a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
34
 
35
  x = mx.array(a)
36
  y = mlx_dan(x)[0]
37
 
38
+ try:
39
+ for n in range(10):
40
+ mlx_dan(x)
41
+ for i, p in enumerate(y):
42
+ if p >= 0.55:
43
+ #print(model_tags[i].item(), p)
44
+ tags.append(model_tags[i].item())
45
+ except Exception as err:
46
+ print(err)
47
+
48
+ results[fpath] = tags
49
+ return results
50
 
 
51
 
 
52
  def image_infer(fpath):
 
53
  tags = danbooru_tags(fpath)
 
54
  return tags
55
 
 
56
  t1 = time.time()
57
+
58
+ tags_1 = image_infer(f'{IMAGEDIR}/1.png')
59
+ tags_2 = image_infer(f'{IMAGEDIR}/2.png')
60
 
61
  t2 = time.time()
62
 
63
  print(tags_1)
64
  print(tags_2)
 
 
 
65
 
66
+ print(f'2 images: infer speed(with mlx): {(t2 - t1)/2} seconds per image')
67
+
68
+
69
+
70
 
71
 
72
 
73
+
74
 
75
 
76
 
infer_multiprocessing.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import glob
4
+ import numpy as np
5
+ from PIL import Image, ImageDraw
6
+
7
+ import mlx.core as mx
8
+ from mlxDeepDanBooru.mlx_deep_danbooru_model import mlxDeepDanBooruModel
9
+
10
+ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed, wait, FIRST_COMPLETED
11
+ from copy import deepcopy
12
+
13
+ ROOTDIR = os.path.dirname(os.path.abspath(__file__))
14
+ IMAGEDIR = f'{ROOTDIR}/example'
15
+
16
+ worker_count = os.cpu_count()
17
+ # worker_count depends on your unified-memory size
18
+ # if oom, decrease the number
19
+
20
+ model_path = f"{ROOTDIR}/models/model-resnet_custom_v3_mlx.npz"
21
+ tags_path = f'{ROOTDIR}/models/tags-resnet_custom_v3_mlx.npy'
22
+
23
+ mlx_dan = mlxDeepDanBooruModel()
24
+ mlx_dan.load_weights(model_path)
25
+ mx.eval(mlx_dan.parameters())
26
+
27
+
28
+ model_tags = np.load(tags_path)
29
+ #print(f'total tags: {len(model_tags)}')
30
+
31
+ def danbooru_tags(fpath):
32
+ results = {}
33
+ tags = []
34
+
35
+ pic = Image.open(fpath).convert("RGB").resize((512, 512))
36
+ a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
37
+
38
+ x = mx.array(a)
39
+ y = mlx_dan(x)[0]
40
+
41
+ try:
42
+ for n in range(10):
43
+ mlx_dan(x)
44
+ for i, p in enumerate(y):
45
+ if p >= 0.55:
46
+
47
+ #print(model_tags[i].item(), p)
48
+ tags.append(model_tags[i].item())
49
+ except Exception as err:
50
+ print(err)
51
+
52
+ results[fpath] = tags
53
+ return results
54
+
55
+
56
+ def image_infer(fpath):
57
+ tags = danbooru_tags(fpath)
58
+ return tags
59
+
60
+
61
+ def batch_infer(image_list):
62
+ workers = min(len(image_list), worker_count)
63
+ print(f'workers: {workers}: {os.cpu_count()}')
64
+ with ProcessPoolExecutor(max_workers=workers) as executor:
65
+ process_results = list(executor.map(image_infer, image_list))
66
+ return process_results
67
+
68
+
69
+
70
+ if __name__ == '__main__':
71
+ image_list = []
72
+ for root, dirs, files in os.walk(IMAGEDIR, True):
73
+ for file in files:
74
+ if not file[-4:].lower() in [".png", ".jpg", "jpeg"]:
75
+ continue
76
+ fpath = os.path.join(root, file).replace("\\","/")
77
+ image_list.append(fpath)
78
+
79
+ #print(image_list)
80
+
81
+
82
+ t1 = time.time()
83
+ lines = batch_infer(image_list)
84
+ t2 = time.time()
85
+
86
+ for line in lines:
87
+ print(line)
88
+ print("-----------")
89
+
90
+ print(f'{len(image_list)} images: infer speed(with mlx): {(t2 - t1)/len(image_list)} seconds per image')
91
+
92
+
93
+