|
|
import os |
|
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
|
|
|
|
def download_from_huggingface(repo_name, token):
    """
    Download model checkpoints, embeddings, and all intermediary files from Hugging Face Hub.

    Everything is written relative to the current working directory, keeping
    the repository's own layout (``cbow/checkpoints/``, ``checkpoints/``,
    ``data/``, ``cbow/``, ``config/`` and ``src/``).

    Each download step is best-effort: an exception raised by a step is
    printed and the remaining steps still run. A summary line at the end
    reports either success or the list of steps that failed.

    Args:
        repo_name (str): Name of the repository on Hugging Face
        token (str): Hugging Face API token

    Returns:
        None
    """
    # Create the destination directories up front so they exist even when a
    # download step fails or matches no files.
    for directory in ('cbow/checkpoints', 'checkpoints', 'data', 'config', 'src'):
        os.makedirs(directory, exist_ok=True)

    # Labels of the steps that raised, used for the final summary.
    failures = []

    def fetch_pattern(pattern, label):
        # Download every repo file matching `pattern`.
        # local_dir is "." because the Hub reproduces each file's repo-relative
        # path underneath local_dir; pointing local_dir at the sub-folder would
        # nest the files one level too deep (e.g. checkpoints/checkpoints/x.pth).
        try:
            snapshot_download(
                repo_id=repo_name,
                repo_type="model",
                token=token,
                local_dir=".",
                allow_patterns=pattern,
            )
            print(f"Downloaded {label}")
        except Exception as e:
            print(f"Error downloading {label}: {e}")
            failures.append(label)

    def fetch_files(folder, names):
        # Download individually named files that live under `folder` in the repo.
        for name in names:
            try:
                hf_hub_download(
                    repo_id=repo_name,
                    repo_type="model",
                    token=token,
                    filename=f"{folder}/{name}",
                    local_dir=".",
                )
                print(f"Downloaded {name}")
            except Exception as e:
                print(f"Error downloading {name}: {e}")
                failures.append(name)

    # Model checkpoints
    fetch_pattern("cbow/checkpoints/*.pth", "CBOW checkpoints")
    fetch_pattern("checkpoints/*.pth", "main checkpoints")

    # Datasets and intermediary data files
    fetch_files("data", [
        'tokenized_triples.json',
        'triples_small.json',
        'extracted_data.json',
        'corpus.pkl',
        'text8',
    ])

    # CBOW tokenizer pickles stored directly under cbow/
    fetch_pattern("cbow/*.pkl", "CBOW tokenizer files")

    # Configuration files
    fetch_files("config", ['sweep.yaml', 'requirements.txt'])

    # Source code
    fetch_pattern("src/*.py", "source code files")

    if failures:
        # Do not claim readiness when one or more steps failed.
        print(f"\nDownload finished with {len(failures)} error(s): {', '.join(failures)}")
    else:
        print("\nDownload complete! Files are ready for training.")
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: both options are mandatory string flags.
    cli = argparse.ArgumentParser(description='Download model files from Hugging Face Hub')
    required_options = (
        ('--repo_name', 'Name of the repository on Hugging Face'),
        ('--token', 'Hugging Face API token'),
    )
    for flag, description in required_options:
        cli.add_argument(flag, type=str, required=True, help=description)
    options = cli.parse_args()

    # Hand the parsed values straight to the downloader.
    download_from_huggingface(repo_name=options.repo_name, token=options.token)