# File size: 1,984 bytes
# 7b7527a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import glob
import random
import fnmatch
import re
import sys

# Class name -> integer label; get_list looks each video's class name up here
# and writes the id into the "<frame count> <class id>" label string.
class_id = dict(nofight=0, fight=1)


def get_list(path, key_func=lambda x: x[-11:], rgb_prefix='img_', level=1):
    if level == 1:
        frame_folders = glob.glob(os.path.join(path, '*'))
    elif level == 2:
        frame_folders = glob.glob(os.path.join(path, '*', '*'))
    else:
        raise ValueError('level can be only 1 or 2')

    def count_files(directory):
        lst = os.listdir(directory)
        cnt = len(fnmatch.filter(lst, rgb_prefix + '*'))
        return cnt

    # check RGB
    video_dict = {}
    for f in frame_folders:
        cnt = count_files(f)
        k = key_func(f)
        if level == 2:
            k = k.split("/")[0]

        video_dict[f] = str(cnt) + " " + str(class_id[k])

    return video_dict


def fight_splits(video_dict, train_percent=0.8):
    videos = list(video_dict.keys())

    train_num = int(len(videos) * train_percent)

    train_list = []
    val_list = []

    random.shuffle(videos)

    for i in range(train_num):
        train_list.append(videos[i] + " " + str(video_dict[videos[i]]))
    for i in range(train_num, len(videos)):
        val_list.append(videos[i] + " " + str(video_dict[videos[i]]))

    print("train:", len(train_list), ",val:", len(val_list))

    with open("fight_train_list.txt", "w") as f:
        for item in train_list:
            f.write(item + "\n")

    with open("fight_val_list.txt", "w") as f:
        for item in val_list:
            f.write(item + "\n")


if __name__ == "__main__":
    frame_dir = sys.argv[1]  # "rawframes"
    level = sys.argv[2]  # 2
    train_percent = sys.argv[3]  # 0.8

    if level == 2:

        def key_func(x):
            return '/'.join(x.split('/')[-2:])
    else:

        def key_func(x):
            return x.split('/')[-1]

    video_dict = get_list(frame_dir, key_func=key_func, level=level)
    print("number:", len(video_dict))

    fight_splits(video_dict, train_percent)