File size: 4,341 Bytes
2c76547
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import glob
import argparse
import hashlib
import json

from typing import Optional
from multiprocessing import Pool
from tqdm import tqdm


# Default checksum manifest: "dr_sha256.json" located one directory above the
# folder that contains this file. The path is made absolute before walking up,
# so a relative `__file__` (e.g. a script launched via a relative path) still
# resolves to the intended directory instead of a truncated path fragment.
DEFAULT_SHA256S_FILE = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "dr_sha256.json",
)
# Number of bytes read per chunk when hashing a file.
BLOCKSIZE = 65536


def main(
    download_folder: str,
    sha256s_file: str,
    dump: bool = False,
    n_sha256_workers: int = 4
):
    """Hash every ``*.zip`` in `download_folder` and verify against `sha256s_file`.

    Args:
        download_folder: Folder that is searched (non-recursively) for
            ``*.zip`` files.
        sha256s_file: JSON file mapping zip file names to their expected
            SHA256 hex digests.
        dump: If True, print the computed hashes and write them to
            `sha256s_file` (overwriting it). Verification afterwards still
            uses the expected hashes that were loaded before the overwrite.
        n_sha256_workers: Number of worker processes used for hashing.

    Raises:
        ValueError: If `sha256s_file` does not exist and `dump` is False, or
            if any expected file is missing from `download_folder` or has a
            non-matching checksum.
    """
    if os.path.isfile(sha256s_file):
        expected_sha256s = get_expected_sha256s(
            sha256s_file=sha256s_file
        )
    elif dump:
        # First-time dump: there is no manifest to verify against yet, so
        # start from an empty expectation instead of refusing to run.
        expected_sha256s = {}
    else:
        raise ValueError(f"The SHA256 file does not exist ({sha256s_file}).")

    zipfiles = sorted(glob.glob(os.path.join(download_folder, "*.zip")))
    print(f"Extracting SHA256 hashes for {len(zipfiles)} files in {download_folder}.")
    with Pool(processes=n_sha256_workers) as sha_pool:
        # imap yields results in input order, so the hashes line up with
        # `zipfiles` when they are paired below.
        extracted_sha256s_list = list(
            tqdm(
                sha_pool.imap(_sha256_file_and_print, zipfiles),
                total=len(zipfiles),
            )
        )

    # Key by bare file name, matching the keys used in the JSON manifest.
    extracted_sha256s = {
        os.path.basename(zip_path): digest
        for zip_path, digest in zip(zipfiles, extracted_sha256s_list)
    }

    if dump:
        print(extracted_sha256s)
        with open(sha256s_file, "w") as f:
            json.dump(extracted_sha256s, f, indent=2)

    # Only files listed in the manifest are checked; extra zips found in the
    # folder are ignored.
    missing_keys, invalid_keys = [], []
    for k in expected_sha256s:
        if k not in extracted_sha256s:
            print(f"{k} missing!")
            missing_keys.append(k)
        elif expected_sha256s[k] != extracted_sha256s[k]:
            print(
                f"'{k}' does not match!"
                + f" ({expected_sha256s[k]} != {extracted_sha256s[k]})"
            )
            invalid_keys.append(k)
    if invalid_keys or missing_keys:
        raise ValueError(
            "Checksum checker failed!"
            + f" Non-matching checksums: {str(invalid_keys)};"
            + f" missing files: {str(missing_keys)}."
        )


def get_expected_sha256s(
    sha256s_file: str
):
    """Load and return the parsed JSON content of `sha256s_file`.

    Args:
        sha256s_file: Path to a JSON file holding the expected checksums.

    Returns:
        The object decoded from the JSON file.
    """
    with open(sha256s_file, "r") as manifest:
        return json.load(manifest)


def check_dr_sha256(
    path: str,
    sha256s_file: str,
    expected_sha256s: Optional[dict] = None,
    do_assertion: bool = True,
):
    """Check the SHA256 digest of a single file against its expected value.

    Args:
        path: Path of the file to hash; its base name is the lookup key.
        sha256s_file: JSON file with the expected hashes. Only read when
            `expected_sha256s` is None.
        expected_sha256s: Optional pre-loaded mapping from file name to
            expected hex digest.
        do_assertion: If True, raise on mismatch and return None on success.
            If False, return whether the digests match.

    Returns:
        None when `do_assertion` is True; otherwise a bool telling whether
        the computed digest equals the expected one.

    Raises:
        AssertionError: If `do_assertion` is True and the digests differ.
        KeyError: If the file name has no entry in the expected hashes.
    """
    zipname = os.path.split(path)[-1]
    if expected_sha256s is None:
        expected_sha256s = get_expected_sha256s(
            sha256s_file=sha256s_file,
        )
    extracted_hash = sha256_file(path)
    expected_hash = expected_sha256s[zipname]
    hashes_match = extracted_hash == expected_hash
    if not do_assertion:
        return hashes_match
    if not hashes_match:
        # Raised explicitly rather than via `assert`, so the integrity check
        # is not stripped when Python runs with -O.
        raise AssertionError(
            f"{zipname}: ({extracted_hash} != {expected_hash})"
        )
    return None


def sha256_file(path: str):
    """Return the SHA256 hex digest of the file at `path`.

    The file is opened in binary mode and consumed in chunks of BLOCKSIZE
    bytes, so the whole file is never held in memory at once.
    """
    hasher = hashlib.sha256()
    with open(path, "rb") as stream:
        # iter() with a sentinel keeps calling read() until it returns b"",
        # i.e. until end of file.
        for chunk in iter(lambda: stream.read(BLOCKSIZE), b""):
            hasher.update(chunk)
    return hasher.hexdigest()


def _sha256_file_and_print(path: str):
    """Hash the file at `path`, print ``<path>: <digest>``, and return the digest.

    Used as the per-file worker function for the process pool in `main`.
    """
    file_digest = sha256_file(path)
    print(f"{path}: {file_digest}")
    return file_digest



if __name__ == "__main__":
    # Command-line entry point: parse the options and run the checksum check.
    parser = argparse.ArgumentParser(
        description="Check SHA256 hashes of the Dynamic Replica dataset."
    )
    parser.add_argument(
        "--download_folder",
        type=str,
        # Required: without it the folder used to become the literal string
        # "None", and --dump_sha256s would then overwrite the checksum file
        # with an empty mapping.
        required=True,
        help="The local folder containing the downloaded dataset zip files.",
    )
    parser.add_argument(
        "--sha256s_file",
        type=str,
        help="Path to the JSON file with the expected SHA256 hashes.",
        default=DEFAULT_SHA256S_FILE,
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="The number of sha256 extraction workers.",
    )
    parser.add_argument(
        "--dump_sha256s",
        action="store_true",
        help="Store sha256s hashes.",
    )

    args = parser.parse_args()
    main(
        args.download_folder,
        dump=args.dump_sha256s,
        n_sha256_workers=args.num_workers,
        sha256s_file=args.sha256s_file,
    )