roshbeed commited on
Commit
eabf707
·
verified ·
1 Parent(s): 7e73a05

Upload src/download_from_hf.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/download_from_hf.py +118 -0
src/download_from_hf.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from huggingface_hub import hf_hub_download, snapshot_download
3
+
4
+ def download_from_huggingface(repo_name, token):
5
+ """
6
+ Download model checkpoints, embeddings, and all intermediary files from Hugging Face Hub.
7
+
8
+ Args:
9
+ repo_name (str): Name of the repository on Hugging Face
10
+ token (str): Hugging Face API token
11
+ """
12
+ # Create necessary directories
13
+ os.makedirs('cbow/checkpoints', exist_ok=True)
14
+ os.makedirs('checkpoints', exist_ok=True)
15
+ os.makedirs('vocabulary', exist_ok=True)
16
+ os.makedirs('data', exist_ok=True)
17
+ os.makedirs('config', exist_ok=True)
18
+ os.makedirs('src', exist_ok=True)
19
+
20
+ # Download CBOW checkpoints
21
+ try:
22
+ cbow_files = snapshot_download(
23
+ repo_id=repo_name,
24
+ repo_type="model",
25
+ token=token,
26
+ local_dir="cbow/checkpoints",
27
+ allow_patterns="cbow/checkpoints/*.pth"
28
+ )
29
+ print("Downloaded CBOW checkpoints")
30
+ except Exception as e:
31
+ print(f"Error downloading CBOW checkpoints: {e}")
32
+
33
+ # Download main checkpoints
34
+ try:
35
+ main_files = snapshot_download(
36
+ repo_id=repo_name,
37
+ repo_type="model",
38
+ token=token,
39
+ local_dir="checkpoints",
40
+ allow_patterns="checkpoints/*.pth"
41
+ )
42
+ print("Downloaded main checkpoints")
43
+ except Exception as e:
44
+ print(f"Error downloading main checkpoints: {e}")
45
+
46
+ # Download raw and intermediary data files
47
+ data_files = [
48
+ 'tokenized_triples.json',
49
+ 'triples_small.json',
50
+ 'extracted_data.json',
51
+ 'corpus.pkl',
52
+ 'text8'
53
+ ]
54
+
55
+ for data_file in data_files:
56
+ try:
57
+ hf_hub_download(
58
+ repo_id=repo_name,
59
+ repo_type="model",
60
+ token=token,
61
+ filename=f"data/{data_file}",
62
+ local_dir="."
63
+ )
64
+ print(f"Downloaded {data_file}")
65
+ except Exception as e:
66
+ print(f"Error downloading {data_file}: {e}")
67
+
68
+ # Download vocabulary files
69
+ try:
70
+ vocab_files = snapshot_download(
71
+ repo_id=repo_name,
72
+ repo_type="model",
73
+ token=token,
74
+ local_dir="cbow",
75
+ allow_patterns="vocabulary/*.pkl"
76
+ )
77
+ print("Downloaded vocabulary files")
78
+ except Exception as e:
79
+ print(f"Error downloading vocabulary files: {e}")
80
+
81
+ # Download configuration files
82
+ config_files = ['sweep.yaml', 'requirements.txt']
83
+ for config_file in config_files:
84
+ try:
85
+ hf_hub_download(
86
+ repo_id=repo_name,
87
+ repo_type="model",
88
+ token=token,
89
+ filename=f"config/{config_file}",
90
+ local_dir="."
91
+ )
92
+ print(f"Downloaded {config_file}")
93
+ except Exception as e:
94
+ print(f"Error downloading {config_file}: {e}")
95
+
96
+ # Download source code files
97
+ try:
98
+ code_files = snapshot_download(
99
+ repo_id=repo_name,
100
+ repo_type="model",
101
+ token=token,
102
+ local_dir=".",
103
+ allow_patterns="src/*.py"
104
+ )
105
+ print("Downloaded source code files")
106
+ except Exception as e:
107
+ print(f"Error downloading source code files: {e}")
108
+
109
+ print("\nDownload complete! Files are ready for training.")
110
+
111
+ if __name__ == "__main__":
112
+ import argparse
113
+ parser = argparse.ArgumentParser(description='Download model files from Hugging Face Hub')
114
+ parser.add_argument('--repo_name', type=str, required=True, help='Name of the repository on Hugging Face')
115
+ parser.add_argument('--token', type=str, required=True, help='Hugging Face API token')
116
+ args = parser.parse_args()
117
+
118
+ download_from_huggingface(args.repo_name, args.token)