AnikS22 commited on
Commit
d1fb167
·
verified ·
1 Parent(s): 28564a3

Upload scripts/download_cem500k.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/download_cem500k.py +79 -0
scripts/download_cem500k.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Download CEM500K MoCoV2 ResNet-50 pretrained weights from Zenodo.
3
+
4
+ Usage:
5
+ python scripts/download_cem500k.py
6
+ python scripts/download_cem500k.py --output weights/cem500k_mocov2_resnet50.pth.tar
7
+ """
8
+
9
+ import argparse
10
+ import hashlib
11
+ import os
12
+ import sys
13
+ import urllib.request
14
+ from pathlib import Path
15
+
16
+
17
+ CEM500K_URL = "https://zenodo.org/records/6453140/files/cem500k_mocov2_resnet50_200ep.pth.tar?download=1"
18
+ DEFAULT_OUTPUT = "weights/cem500k_mocov2_resnet50.pth.tar"
19
+
20
+
21
+ def download_with_progress(url: str, output_path: str):
22
+ """Download file with progress indicator."""
23
+ print(f"Downloading from {url}")
24
+ print(f"Saving to {output_path}")
25
+
26
+ def _progress_hook(block_num, block_size, total_size):
27
+ downloaded = block_num * block_size
28
+ if total_size > 0:
29
+ pct = min(100, downloaded * 100 / total_size)
30
+ mb = downloaded / (1024 * 1024)
31
+ total_mb = total_size / (1024 * 1024)
32
+ sys.stdout.write(f"\r {mb:.1f}/{total_mb:.1f} MB ({pct:.0f}%)")
33
+ sys.stdout.flush()
34
+
35
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
36
+ urllib.request.urlretrieve(url, output_path, reporthook=_progress_hook)
37
+ print() # newline after progress
38
+
39
+
40
+ def verify_file(path: str):
41
+ """Verify the downloaded file is a valid PyTorch checkpoint."""
42
+ import torch
43
+ try:
44
+ ckpt = torch.load(path, map_location="cpu", weights_only=False)
45
+ keys = list(ckpt.keys()) if isinstance(ckpt, dict) else []
46
+ print(f"Checkpoint keys: {keys}")
47
+ if "state_dict" in ckpt:
48
+ n_params = len(ckpt["state_dict"])
49
+ print(f"State dict entries: {n_params}")
50
+ print("Verification PASSED")
51
+ return True
52
+ except Exception as e:
53
+ print(f"Verification FAILED: {e}")
54
+ return False
55
+
56
+
57
+ def main():
58
+ parser = argparse.ArgumentParser(description="Download CEM500K weights")
59
+ parser.add_argument("--output", default=DEFAULT_OUTPUT)
60
+ parser.add_argument("--force", action="store_true", help="Re-download even if exists")
61
+ args = parser.parse_args()
62
+
63
+ if os.path.exists(args.output) and not args.force:
64
+ size_mb = os.path.getsize(args.output) / (1024 * 1024)
65
+ print(f"File already exists: {args.output} ({size_mb:.1f} MB)")
66
+ print("Use --force to re-download")
67
+ verify_file(args.output)
68
+ return
69
+
70
+ download_with_progress(CEM500K_URL, args.output)
71
+
72
+ size_mb = os.path.getsize(args.output) / (1024 * 1024)
73
+ print(f"Downloaded: {size_mb:.1f} MB")
74
+
75
+ verify_file(args.output)
76
+
77
+
78
+ if __name__ == "__main__":
79
+ main()