File size: 2,135 Bytes
24a7f55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env python3
"""
Script to download and cache models for offline use.
"""
import argparse
import os
import sys
from typing import Optional

from huggingface_hub import snapshot_download

from config.model_config import DEFAULT_MODELS

def download_model(model_name: str, cache_dir: Optional[str] = None) -> str:
    """Download a configured model snapshot from the Hugging Face Hub.

    Args:
        model_name: Key into ``DEFAULT_MODELS`` selecting the model to fetch.
        cache_dir: Optional directory to store the model in. When given, the
            model goes to ``<cache_dir>/<last component of config.model_path>``;
            otherwise the configured ``model_path`` is used unchanged.

    Returns:
        The local directory the snapshot was downloaded into.

    Raises:
        ValueError: If ``model_name`` is not a key of ``DEFAULT_MODELS``.
    """
    if model_name not in DEFAULT_MODELS:
        raise ValueError(f"Unknown model: {model_name}")

    config = DEFAULT_MODELS[model_name]
    model_path = config.model_path

    if cache_dir:
        # normpath drops a trailing separator first: basename("models/foo/")
        # is "", which would put every model directly into cache_dir and make
        # them overwrite each other.
        model_path = os.path.join(
            cache_dir, os.path.basename(os.path.normpath(model_path))
        )

    print(f"Downloading {model_name} to {model_path}...")

    os.makedirs(model_path, exist_ok=True)

    snapshot_download(
        repo_id=config.model_id,
        local_dir=model_path,
        # NOTE(review): on older huggingface_hub releases True makes local_dir
        # contain symlinks into the shared HF cache, so the files would not
        # actually live under model_path; newer releases deprecate/ignore it.
        local_dir_use_symlinks=True,
        # Skip weight formats matched by these patterns (typically TF .h5,
        # rust .ot, Flax .msgpack) — presumably unused here; confirm.
        ignore_patterns=["*.h5", "*.ot", "*.msgpack"],
    )

    print(f"Successfully downloaded {model_name} to {model_path}")
    return model_path

def main(argv=None):
    """Command-line entry point: download one configured model, or all of them.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps the normal command-line behaviour.

    Exits with status 1 when the requested model is unknown or when any
    download fails, so shell scripts and CI can detect the failure.
    """
    parser = argparse.ArgumentParser(description="Download and cache models for offline use")
    parser.add_argument(
        "--model",
        type=str,
        default="all",
        help="Model to download (default: all)"
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Directory to cache models (default: model's default path)"
    )

    args = parser.parse_args(argv)

    if args.model.lower() == "all":
        failed = []
        for model_name in DEFAULT_MODELS:
            # Keep going after a failure so one bad model does not block the
            # rest; failures are reported and reflected in the exit status.
            try:
                download_model(model_name, args.cache_dir)
            except Exception as e:
                print(f"Error downloading {model_name}: {e}", file=sys.stderr)
                failed.append(model_name)
        if failed:
            print(f"Failed to download: {', '.join(failed)}", file=sys.stderr)
            sys.exit(1)
        return

    if args.model not in DEFAULT_MODELS:
        print(f"Error: Unknown model {args.model}", file=sys.stderr)
        print(f"Available models: {', '.join(DEFAULT_MODELS.keys())}", file=sys.stderr)
        sys.exit(1)

    download_model(args.model, args.cache_dir)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import; main()'s
    # return value (None here) becomes the process exit status.
    raise SystemExit(main())