Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Download checkpoints from HuggingFace Hub. | |
| Downloads pre-trained model checkpoints for all categories and models. | |
| Checkpoints are stored on HuggingFace Hub to keep the Git repository lightweight. | |
| Usage: | |
| python scripts/download_checkpoints.py # Download all | |
| python scripts/download_checkpoints.py --model patchcore # Specific model | |
| python scripts/download_checkpoints.py --category bottle # Specific category | |
| """ | |
| import argparse | |
| import os | |
| import sys | |
| from pathlib import Path | |
| # Add parent directory to path for imports | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from huggingface_hub import hf_hub_download, list_repo_files, HfApi | |
| from tqdm import tqdm | |
| from core.config import MVTEC_CATEGORIES, DIR_RESULTS, get_available_models | |
| # ============================================================================= | |
| # CONFIGURATION | |
| # ============================================================================= | |
# TODO: Replace with your HuggingFace Hub repository
# Identifier of the Hub repository that the checkpoints are fetched from.
HF_REPO_ID = "micguida1/mvtec-anomaly-checkpoints"

# Lower-case model identifier -> directory name used for that model, both in
# the Hub repository layout and under the local results directory.
MODEL_TO_DIRNAME = dict(
    patchcore="Patchcore",
    efficientad="EfficientAd",
    fastflow="Fastflow",
    stfpm="Stfpm",
    padim="Padim",
)
| # ============================================================================= | |
| # DOWNLOAD FUNCTIONS | |
| # ============================================================================= | |
def get_checkpoint_hf_path(model_name: str, category: str) -> str:
    """
    Build the repository-relative location of a model checkpoint.

    Args:
        model_name: Name of the model
        category: MVTec category

    Returns:
        Path string relative to HF repo root
    """
    # Known models use their mapped directory name; anything else falls back
    # to the capitalized model name.
    if model_name in MODEL_TO_DIRNAME:
        model_dir = MODEL_TO_DIRNAME[model_name]
    else:
        model_dir = model_name.capitalize()

    checkpoint_suffix = "/latest/weights/lightning/model.ckpt"
    return f"{model_dir}/MVTecAD/{category}" + checkpoint_suffix
def get_metrics_hf_path(model_name: str, category: str) -> str:
    """
    Build the repository-relative location of a model's metrics.json file.

    Args:
        model_name: Name of the model
        category: MVTec category

    Returns:
        Path string relative to HF repo root
    """
    # Known models use their mapped directory name; anything else falls back
    # to the capitalized model name.
    if model_name in MODEL_TO_DIRNAME:
        model_dir = MODEL_TO_DIRNAME[model_name]
    else:
        model_dir = model_name.capitalize()

    metrics_suffix = "/latest/metrics.json"
    return f"{model_dir}/MVTecAD/{category}" + metrics_suffix
def get_local_checkpoint_path(model_name: str, category: str) -> Path:
    """
    Build the on-disk location of a model checkpoint under DIR_RESULTS.

    Args:
        model_name: Name of the model
        category: MVTec category

    Returns:
        Path object for local checkpoint
    """
    # Known models use their mapped directory name; anything else falls back
    # to the capitalized model name.
    if model_name in MODEL_TO_DIRNAME:
        model_dir = MODEL_TO_DIRNAME[model_name]
    else:
        model_dir = model_name.capitalize()

    run_dir = DIR_RESULTS / model_dir / "MVTecAD" / category / "latest"
    return run_dir / "weights" / "lightning" / "model.ckpt"
def get_local_metrics_path(model_name: str, category: str) -> Path:
    """
    Build the on-disk location of a model's metrics.json under DIR_RESULTS.

    Args:
        model_name: Name of the model
        category: MVTec category

    Returns:
        Path object for local metrics file
    """
    # Known models use their mapped directory name; anything else falls back
    # to the capitalized model name.
    if model_name in MODEL_TO_DIRNAME:
        model_dir = MODEL_TO_DIRNAME[model_name]
    else:
        model_dir = model_name.capitalize()

    run_dir = DIR_RESULTS / model_dir / "MVTecAD" / category / "latest"
    return run_dir / "metrics.json"
def download_checkpoint(model_name: str, category: str, force: bool = False) -> bool:
    """
    Downloads a single checkpoint from HuggingFace Hub.

    The file is fetched from HF_REPO_ID into DIR_RESULTS, keeping the
    repository-relative path of the checkpoint.

    Args:
        model_name: Name of the model
        category: MVTec category
        force: If True, re-download even if exists

    Returns:
        True if downloaded/exists, False if failed
    """
    local_path = get_local_checkpoint_path(model_name, category)

    # Skip if already exists
    if local_path.exists() and not force:
        return True

    hf_path = get_checkpoint_hf_path(model_name, category)

    try:
        # Create parent directories
        local_path.parent.mkdir(parents=True, exist_ok=True)

        # Download from HF Hub. `force` is forwarded so that a forced run
        # really re-fetches the file instead of reusing an earlier download.
        hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=hf_path,
            local_dir=DIR_RESULTS,
            local_dir_use_symlinks=False,
            force_download=force,
        )
    except Exception as e:
        # Any failure (network, missing file, auth, filesystem) is reported
        # and turned into a False result rather than aborting the caller.
        print(f" β Failed to download {model_name}/{category}: {e}")
        return False

    return True
def download_metrics(model_name: str, category: str, force: bool = False) -> bool:
    """
    Downloads metrics.json for a model/category from HuggingFace Hub.

    The metrics file is treated as optional: a failed download is reported
    only through the return value, without printing anything.

    Args:
        model_name: Name of the model
        category: MVTec category
        force: If True, re-download even if exists

    Returns:
        True if downloaded/exists, False if failed
    """
    local_path = get_local_metrics_path(model_name, category)

    # Skip if already exists
    if local_path.exists() and not force:
        return True

    hf_path = get_metrics_hf_path(model_name, category)

    try:
        # Create parent directories
        local_path.parent.mkdir(parents=True, exist_ok=True)

        # Download from HF Hub. `force` is forwarded so that a forced run
        # really re-fetches the file instead of reusing an earlier download.
        hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=hf_path,
            local_dir=DIR_RESULTS,
            local_dir_use_symlinks=False,
            force_download=force,
        )
    except Exception:
        # Metrics file is optional, don't print error
        return False

    return True
def download_all_checkpoints(
    models: list[str] = None,
    categories: list[str] = None,
    force: bool = False
) -> dict:
    """
    Downloads checkpoints and metrics for specified models and categories.

    Args:
        models: List of model names (None = all available)
        categories: List of categories (None = all MVTec categories)
        force: If True, re-download even if exists

    Returns:
        Dict with download statistics, with keys:
            "downloaded": checkpoints fetched during this call
            "existed": checkpoints already present locally (not re-fetched)
            "failed": checkpoints that could not be downloaded
            "metrics_downloaded": metrics files fetched during this call
                (files that were already present and not forced are not counted)
    """
    if models is None:
        models = get_available_models()
    if categories is None:
        categories = MVTEC_CATEGORIES

    stats = {"downloaded": 0, "existed": 0, "failed": 0, "metrics_downloaded": 0}
    total = len(models) * len(categories)

    print(f"π¦ Downloading checkpoints and metrics from: {HF_REPO_ID}")
    print(f" Models: {', '.join(models)}")
    print(f" Categories: {len(categories)} total")
    print()

    with tqdm(total=total, desc="Downloading") as pbar:
        for model in models:
            for category in categories:
                local_path = get_local_checkpoint_path(model, category)

                if local_path.exists() and not force:
                    stats["existed"] += 1
                elif download_checkpoint(model, category, force):
                    stats["downloaded"] += 1
                else:
                    stats["failed"] += 1

                # Also download metrics.json if available. The presence check
                # must happen BEFORE the download call: download_metrics()
                # returns True for files that were already on disk, and those
                # must not be reported as downloaded.
                metrics_present = get_local_metrics_path(model, category).exists()
                metrics_ok = download_metrics(model, category, force)
                if metrics_ok and (force or not metrics_present):
                    stats["metrics_downloaded"] += 1

                pbar.update(1)

    return stats
def check_checkpoint_exists(model_name: str, category: str) -> bool:
    """
    Report whether the checkpoint for a model/category is present on disk.

    Args:
        model_name: Name of the model
        category: MVTec category

    Returns:
        True if checkpoint exists locally
    """
    checkpoint_file = get_local_checkpoint_path(model_name, category)
    return checkpoint_file.exists()
def ensure_checkpoint(model_name: str, category: str) -> Path:
    """
    Return the local checkpoint path, fetching the checkpoint first if needed.

    metrics.json is fetched alongside the checkpoint when it is available.
    Intended as the entry point for inference/app code.

    Args:
        model_name: Name of the model
        category: MVTec category

    Returns:
        Path to the checkpoint

    Raises:
        FileNotFoundError: If checkpoint cannot be found or downloaded
    """
    checkpoint_file = get_local_checkpoint_path(model_name, category)

    if not checkpoint_file.exists():
        print(f"β¬ Checkpoint not found locally. Downloading {model_name}/{category}...")

        fetched = download_checkpoint(model_name, category)
        if fetched:
            # Metrics are best-effort; their result does not affect the outcome.
            download_metrics(model_name, category)
            if checkpoint_file.exists():
                print("β Downloaded successfully")
                return checkpoint_file

        raise FileNotFoundError(
            f"Checkpoint not found: {checkpoint_file}\n"
            f"Please train the model first with: python train.py --model {model_name} --category {category}\n"
            f"Or download from HuggingFace Hub: {HF_REPO_ID}"
        )

    # Checkpoint already on disk: still try to obtain metrics if they are missing.
    download_metrics(model_name, category)
    return checkpoint_file
| # ============================================================================= | |
| # CLI | |
| # ============================================================================= | |
def parse_args(argv=None):
    """
    Parse command line arguments.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads the arguments from sys.argv.

    Returns:
        argparse.Namespace with the attributes:
            model: model name, or "all"
            category: category name, or "all"
            force: True when --force was given
            list: True when --list was given
    """
    parser = argparse.ArgumentParser(
        description="Download checkpoints from HuggingFace Hub"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="all",
        help="Model to download (default: all)",
    )
    parser.add_argument(
        "--category",
        type=str,
        default="all",
        help="Category to download (default: all)",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Force re-download even if exists",
    )
    parser.add_argument(
        "--list",
        action="store_true",
        help="List available checkpoints on HF Hub",
    )
    return parser.parse_args(argv)
def main():
    """
    CLI entry point.

    With --list, prints the checkpoint files present in the HF_REPO_ID
    repository and returns without downloading anything. Otherwise downloads
    the requested checkpoints and metrics and prints a summary.
    """
    args = parse_args()

    # --list only inspects the remote repository; nothing is downloaded.
    if args.list:
        try:
            repo_files = list_repo_files(HF_REPO_ID)
        except Exception as e:
            print(f"Failed to list files in {HF_REPO_ID}: {e}")
            sys.exit(1)

        checkpoints = sorted(
            path for path in repo_files
            if path.endswith("weights/lightning/model.ckpt")
        )
        print(f"Checkpoints available in {HF_REPO_ID}: {len(checkpoints)}")
        for path in checkpoints:
            print(f" {path}")
        return

    # Determine models and categories
    models = get_available_models() if args.model == "all" else [args.model]
    categories = MVTEC_CATEGORIES if args.category == "all" else [args.category]

    # Download
    stats = download_all_checkpoints(models, categories, args.force)

    # Report
    print()
    print("=" * 50)
    print(f"β Checkpoints downloaded: {stats['downloaded']}")
    print(f"β Already existed: {stats['existed']}")
    print(f"π Metrics downloaded: {stats['metrics_downloaded']}")
    if stats['failed'] > 0:
        print(f"β Failed: {stats['failed']}")
    print("=" * 50)


if __name__ == "__main__":
    main()