File size: 3,800 Bytes
eabf707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c09b32b
eabf707
c09b32b
eabf707
 
 
 
c09b32b
eabf707
c09b32b
eabf707
c09b32b
eabf707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
from huggingface_hub import hf_hub_download, snapshot_download

def download_from_huggingface(repo_name, token):
    """
    Download model checkpoints, data, tokenizer files, configs and source code
    from a Hugging Face Hub model repository into the current working directory.

    Files keep their repository-relative paths: the repo file
    ``cbow/checkpoints/<name>.pth`` ends up at ``./cbow/checkpoints/<name>.pth``.
    Each download is attempted independently. A failure is printed and
    recorded, and the remaining downloads still run.

    Args:
        repo_name (str): Name (repo id) of the repository on Hugging Face.
        token (str | None): Hugging Face API token, passed straight through to
            ``huggingface_hub``.

    Returns:
        list[str]: Labels of the downloads that raised an error; empty when
        everything succeeded.
    """
    # Create the expected local layout up front, so these folders exist even
    # if a download fails or its pattern matches no files.
    for directory in ('cbow/checkpoints', 'checkpoints', 'data', 'config', 'src'):
        os.makedirs(directory, exist_ok=True)

    failures = []

    # Download CBOW checkpoints
    _download_pattern(repo_name, token, "cbow/checkpoints/*.pth",
                      "CBOW checkpoints", failures)

    # Download main checkpoints
    _download_pattern(repo_name, token, "checkpoints/*.pth",
                      "main checkpoints", failures)

    # Download raw and intermediary data files
    data_files = [
        'tokenized_triples.json',
        'triples_small.json',
        'extracted_data.json',
        'corpus.pkl',
        'text8',
    ]
    for data_file in data_files:
        _download_file(repo_name, token, f"data/{data_file}", data_file, failures)

    # Download all tokenizer files from cbow directory
    _download_pattern(repo_name, token, "cbow/*.pkl",
                      "CBOW tokenizer files", failures)

    # Download configuration files
    for config_file in ('sweep.yaml', 'requirements.txt'):
        _download_file(repo_name, token, f"config/{config_file}", config_file, failures)

    # Download source code files
    _download_pattern(repo_name, token, "src/*.py", "source code files", failures)

    if failures:
        print(f"\nDownload finished with {len(failures)} error(s): {', '.join(failures)}")
    else:
        print("\nDownload complete! Files are ready for training.")
    return failures


def _download_pattern(repo_name, token, pattern, label, failures):
    """Download every repo file matching ``pattern``, appending ``label`` to ``failures`` on error."""
    try:
        # snapshot_download recreates the repo-relative path under local_dir,
        # so the target must be the project root; a subdirectory such as
        # "cbow/checkpoints" would nest files as cbow/checkpoints/cbow/checkpoints/.
        snapshot_download(
            repo_id=repo_name,
            repo_type="model",
            token=token,
            local_dir=".",
            allow_patterns=pattern,
        )
    except Exception as e:  # best-effort: report, record, keep going
        print(f"Error downloading {label}: {e}")
        failures.append(label)
    else:
        print(f"Downloaded {label}")


def _download_file(repo_name, token, filename, label, failures):
    """Download one repo file into ``.``, appending ``label`` to ``failures`` on error."""
    try:
        hf_hub_download(
            repo_id=repo_name,
            repo_type="model",
            token=token,
            filename=filename,
            local_dir=".",
        )
    except Exception as e:  # best-effort: report, record, keep going
        print(f"Error downloading {label}: {e}")
        failures.append(label)
    else:
        print(f"Downloaded {label}")

if __name__ == "__main__":
    import argparse
    import sys

    parser = argparse.ArgumentParser(description='Download model files from Hugging Face Hub')
    parser.add_argument('--repo_name', type=str, required=True,
                        help='Name of the repository on Hugging Face')
    # Optional so the secret need not appear on the command line (shell history,
    # process listings); falls back to the HF_TOKEN environment variable.
    parser.add_argument('--token', type=str, default=os.environ.get('HF_TOKEN'),
                        help='Hugging Face API token (defaults to the HF_TOKEN '
                             'environment variable; if unset, huggingface_hub '
                             'uses its cached login token)')
    args = parser.parse_args()

    failures = download_from_huggingface(args.repo_name, args.token)
    # Signal partial downloads to shell scripts / CI; a None or empty result
    # is treated as success.
    if failures:
        sys.exit(1)