#!/usr/bin/env python3
"""
Script to download and cache models for offline use.
"""
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Optional

from huggingface_hub import snapshot_download
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Put the directory two levels above this file at the front of sys.path so the
# config package can be imported when this file is executed directly.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

# Import the model registry. Two layouts are tried in order:
#   1. config.model_config
#   2. src.config.model_config
# If both fail, an ImportError is raised that is explicitly chained to the
# second failure so the traceback shows the direct cause.
try:
    from config.model_config import DEFAULT_MODELS
except ImportError as e:
    # Log enough context to diagnose a wrong working directory or sys.path.
    logger.error("Failed to import config: %s", e)
    logger.info("Current working directory: %s", os.getcwd())
    logger.info("Project root: %s", project_root)
    logger.info("Python path: %s", sys.path)

    # Try alternative import approach
    try:
        from src.config.model_config import DEFAULT_MODELS
    except ImportError as e2:
        logger.error("Also failed to import from src.config.model_config: %s", e2)
        raise ImportError(
            "Could not import model configuration. Please check your Python path and module structure."
        ) from e2
    else:
        logger.info("Successfully imported config from src.config.model_config")
def download_model(model_name: str, cache_dir: Optional[str] = None) -> None:
    """Download a model snapshot from the Hugging Face Hub to a local directory.

    ``snapshot_download`` is always invoked; this function itself does not
    check whether the target directory is already populated.

    Args:
        model_name: Key into ``DEFAULT_MODELS`` selecting the model to fetch.
        cache_dir: Optional parent directory. When given, the snapshot is
            written to ``<cache_dir>/<last component of the configured model
            path>`` instead of the configured model path itself.

    Raises:
        ValueError: If ``model_name`` is not a key of ``DEFAULT_MODELS``.
    """
    if model_name not in DEFAULT_MODELS:
        raise ValueError(f"Unknown model: {model_name}")

    config = DEFAULT_MODELS[model_name]
    model_path = config.model_path

    # Use cache_dir if provided, otherwise use the model's path
    if cache_dir:
        # Normalise first: basename("models/foo/") is "", which would place
        # every model with a trailing separator directly in cache_dir.
        leaf_name = os.path.basename(os.path.normpath(model_path))
        model_path = os.path.join(cache_dir, leaf_name)

    print(f"Downloading {model_name} to {model_path}...")

    # Create model directory if it doesn't exist
    os.makedirs(model_path, exist_ok=True)

    # Download the model, skipping files that match the ignore patterns
    # (presumably weight formats this project does not load -- confirm).
    # NOTE(review): recent huggingface_hub releases deprecate
    # `local_dir_use_symlinks`; confirm against the pinned version.
    snapshot_download(
        repo_id=config.model_id,
        local_dir=model_path,
        local_dir_use_symlinks=True,
        ignore_patterns=["*.h5", "*.ot", "*.msgpack"],
    )

    print(f"Successfully downloaded {model_name} to {model_path}")
def main() -> int:
    """Parse command-line arguments and download the requested model(s).

    Options:
        --model      Name of a model in ``DEFAULT_MODELS``, or "all"
                     (case-insensitive, the default) to download every one.
        --cache-dir  Optional directory passed through to ``download_model``.

    Returns:
        Process exit status: 0 when every requested download succeeded,
        1 when the model name is unknown or at least one download in "all"
        mode failed.
    """
    parser = argparse.ArgumentParser(description="Download and cache models for offline use")
    parser.add_argument(
        "--model",
        type=str,
        default="all",
        help="Model to download (default: all)"
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Directory to cache models (default: model's default path)"
    )
    args = parser.parse_args()

    if args.model.lower() == "all":
        # Best effort: one failing model must not prevent the others from
        # being downloaded, but the failure is still reported via exit status.
        failed = []
        for model_name in DEFAULT_MODELS:
            try:
                download_model(model_name, args.cache_dir)
            except Exception as e:
                print(f"Error downloading {model_name}: {e}")
                failed.append(model_name)
        return 1 if failed else 0

    if args.model not in DEFAULT_MODELS:
        print(f"Error: Unknown model {args.model}")
        print(f"Available models: {', '.join(DEFAULT_MODELS)}")
        return 1

    # A single explicit model: let any download error propagate to the caller.
    download_model(args.model, args.cache_dir)
    return 0


if __name__ == "__main__":
    # Propagate the status so shells, CI jobs and image builds see failures.
    sys.exit(main())