File size: 3,154 Bytes
24a7f55
 
 
 
 
f3e3305
24a7f55
f3e3305
 
24a7f55
f3e3305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24a7f55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/env python3
"""
Script to download and cache models for offline use.
"""
import os
import sys
import argparse
import logging
from pathlib import Path
from huggingface_hub import snapshot_download

# Set up logging: timestamped, level-tagged messages at INFO and above.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Add the project root (the directory two levels above this file) to the front
# of the import path. This must happen before the config imports below.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

# Import DEFAULT_MODELS, trying the top-level ``config`` package first and
# falling back to ``src.config`` if that fails.
try:
    from config.model_config import DEFAULT_MODELS
except ImportError as e:
    # Log the import environment to help diagnose a wrong path or layout.
    logger.error(f"Failed to import config: {e}")
    logger.info(f"Current working directory: {os.getcwd()}")
    logger.info(f"Project root: {project_root}")
    logger.info(f"Python path: {sys.path}")

    # Try alternative import approach
    try:
        from src.config.model_config import DEFAULT_MODELS
        logger.info("Successfully imported config from src.config.model_config")
    except ImportError as e2:
        logger.error(f"Also failed to import from src.config.model_config: {e2}")
        # Chain explicitly so the traceback names the fallback failure as the
        # direct cause of this error.
        raise ImportError(
            "Could not import model configuration. Please check your Python path and module structure."
        ) from e2

def download_model(model_name: str, cache_dir: str = None):
    """Download a model snapshot from the Hugging Face Hub into a local directory.

    This function does not check whether the model already exists locally; it
    always calls ``snapshot_download``.

    Args:
        model_name: Key into ``DEFAULT_MODELS`` selecting the model to fetch.
        cache_dir: Optional base directory. When given (and non-empty), the
            snapshot is stored in ``<cache_dir>/<last component of the
            configured model path>``; otherwise the configured model path is
            used as-is.

    Raises:
        ValueError: If ``model_name`` is not a key of ``DEFAULT_MODELS``.
    """
    if model_name not in DEFAULT_MODELS:
        raise ValueError(f"Unknown model: {model_name}")

    config = DEFAULT_MODELS[model_name]
    model_path = config.model_path

    # Use cache_dir if provided, otherwise use the model's path
    if cache_dir:
        # Normalise first: basename() of a path ending in a separator is "",
        # which would place every model directly in cache_dir.
        leaf_name = os.path.basename(os.path.normpath(model_path))
        model_path = os.path.join(cache_dir, leaf_name)

    print(f"Downloading {model_name} to {model_path}...")

    # Create model directory if it doesn't exist
    os.makedirs(model_path, exist_ok=True)

    # Download the model.
    # NOTE(review): `local_dir_use_symlinks` may be deprecated in recent
    # huggingface_hub releases -- confirm against the installed version.
    snapshot_download(
        repo_id=config.model_id,
        local_dir=model_path,
        local_dir_use_symlinks=True,
        # Skip files matching these patterns (presumably weight formats this
        # project does not load -- verify).
        ignore_patterns=["*.h5", "*.ot", "*.msgpack"],
    )

    print(f"Successfully downloaded {model_name} to {model_path}")

def main():
    """Parse command-line arguments and download the requested model(s).

    With ``--model all`` (the default, matched case-insensitively) every entry
    in ``DEFAULT_MODELS`` is attempted. A failure for one model is reported
    and the remaining models are still tried. Any other ``--model`` value must
    be an exact key of ``DEFAULT_MODELS``.

    Returns:
        int: Process exit status. ``0`` when everything requested was
        downloaded, ``1`` when the model name is unknown or at least one
        download failed in "all" mode.
    """
    parser = argparse.ArgumentParser(description="Download and cache models for offline use")
    parser.add_argument(
        "--model",
        type=str,
        default="all",
        help="Model to download (default: all)"
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Directory to cache models (default: model's default path)"
    )

    args = parser.parse_args()

    if args.model.lower() == "all":
        failed = []
        for model_name in DEFAULT_MODELS:
            try:
                download_model(model_name, args.cache_dir)
            except Exception as e:
                # Best effort: keep going, but remember the failure so the
                # exit status reflects it.
                print(f"Error downloading {model_name}: {e}")
                failed.append(model_name)
        return 1 if failed else 0

    if args.model not in DEFAULT_MODELS:
        print(f"Error: Unknown model {args.model}")
        print(f"Available models: {', '.join(DEFAULT_MODELS.keys())}")
        return 1

    # Errors for a single explicit model propagate to the caller unchanged.
    download_model(args.model, args.cache_dir)
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shells and CI can detect failed downloads.
    sys.exit(main())