| | """ |
| | Cache Management and SAM2 Loading Utilities |
| | Comprehensive cache cleaning system to resolve model loading issues on HF Spaces |
| | """ |
| |
|
| | import os |
| | import gc |
| | import sys |
| | import shutil |
| | import tempfile |
| | import logging |
| | import traceback |
| | from pathlib import Path |
| | from typing import Optional, Dict, Any, Tuple |
| |
|
# Module-level logger, named after this module, shared by every helper below.
logger = logging.getLogger(name=__name__)
| |
|
class HardCacheCleaner:
    """
    Best-effort cache cleaner run before (re)loading SAM2.

    Clears SAM2 entries from ``sys.modules``, ``__pycache__`` directories under
    the working directory, SAM2-named files/directories in the HuggingFace,
    torch and project caches and in temp directories, the CUDA allocator cache,
    and finally forces a garbage-collection pass. Every step logs and swallows
    its own failures so one broken step never prevents the others from running.
    """

    # Case-insensitive substrings identifying SAM2-related cache/temp entries.
    # Shared by the cache and temp cleanups so both delete the same things
    # (a bare "segment" would also hit unrelated files such as mask outputs).
    SAM2_NAME_PATTERNS = ("sam2", "segment-anything-2")

    @staticmethod
    def clean_all_caches(verbose: bool = True) -> None:
        """Run every cleanup step in order.

        Args:
            verbose: When True, log each step and every removed item.
        """
        if verbose:
            logger.info("Starting comprehensive cache cleanup...")

        HardCacheCleaner._clean_python_cache(verbose)
        HardCacheCleaner._clean_huggingface_cache(verbose)
        HardCacheCleaner._clean_pytorch_cache(verbose)
        HardCacheCleaner._clean_temp_directories(verbose)
        HardCacheCleaner._clear_import_cache(verbose)
        HardCacheCleaner._force_gc_cleanup(verbose)

        if verbose:
            logger.info("Cache cleanup completed")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _is_sam2_name(name: str) -> bool:
        """Return True if *name* contains any SAM2 pattern (case-insensitive)."""
        lowered = name.lower()
        return any(pattern in lowered for pattern in HardCacheCleaner.SAM2_NAME_PATTERNS)

    @staticmethod
    def _config_paths(*attr_names: str) -> list:
        """Return the non-empty path attributes *attr_names* of the app config.

        The app config is optional for cleanup: if it cannot be imported or
        built, a warning is logged and ``[]`` is returned so callers can still
        clean their built-in locations.
        """
        try:
            from config.app_config import get_config
            config = get_config()
        except Exception as e:
            logger.warning(f"App config unavailable for cache cleanup: {e}")
            return []

        paths = []
        for attr_name in attr_names:
            value = getattr(config, attr_name, None)
            if value:  # skip attributes that are missing, None or empty
                paths.append(str(value))
        return paths

    @staticmethod
    def _unique_existing_dirs(paths) -> list:
        """Keep only existing directories from *paths*, deduplicated by real path, in order."""
        seen = set()
        unique = []
        for path in paths:
            real_path = os.path.realpath(path)
            if real_path in seen or not os.path.isdir(real_path):
                continue
            seen.add(real_path)
            unique.append(path)
        return unique

    @staticmethod
    def _purge_matching_tree(root_path: str, verbose: bool = True) -> None:
        """Recursively delete SAM2-named files and directories below *root_path*."""
        for root, dirs, files in os.walk(root_path):
            for file_name in files:
                if not HardCacheCleaner._is_sam2_name(file_name):
                    continue
                file_path = os.path.join(root, file_name)
                try:
                    os.remove(file_path)
                except OSError as e:
                    logger.debug(f"Could not remove cached file {file_path}: {e}")
                    continue
                if verbose:
                    logger.info(f"Removed cached file: {file_path}")

            # Iterate over a copy: matching dirs are pruned from `dirs` so that
            # os.walk does not try to descend into trees that were just deleted.
            for dir_name in list(dirs):
                if not HardCacheCleaner._is_sam2_name(dir_name):
                    continue
                dir_path = os.path.join(root, dir_name)
                shutil.rmtree(dir_path, ignore_errors=True)
                if verbose:
                    logger.info(f"Removed cached directory: {dir_path}")
                dirs.remove(dir_name)

    @staticmethod
    def _purge_matching_entries(directory: str, verbose: bool = True) -> None:
        """Delete SAM2-named top-level entries of *directory* (non-recursive)."""
        try:
            entries = os.listdir(directory)
        except OSError as e:
            logger.warning(f"Cannot list temp directory {directory}: {e}")
            return

        for entry in entries:
            if not HardCacheCleaner._is_sam2_name(entry):
                continue
            entry_path = os.path.join(directory, entry)
            try:
                # Symlinks (even to directories) are unlinked rather than rmtree'd.
                if os.path.isdir(entry_path) and not os.path.islink(entry_path):
                    shutil.rmtree(entry_path, ignore_errors=True)
                else:
                    os.remove(entry_path)
            except OSError as e:
                # Shared temp dirs often hold other users' files; don't be noisy.
                logger.debug(f"Could not remove temp item {entry_path}: {e}")
                continue
            if verbose:
                logger.info(f"Removed temp item: {entry_path}")

    # ------------------------------------------------------------- cleanup steps

    @staticmethod
    def _clean_python_cache(verbose: bool = True) -> None:
        """Drop SAM2 modules from ``sys.modules`` and delete ``__pycache__`` dirs under the CWD."""
        try:
            # Snapshot the names first; sys.modules must not change while iterated.
            # NOTE(review): the substring match also evicts modules such as
            # ``transformers.models.sam2`` (and this module, if its name contains
            # "sam2") — confirm that is intended.
            sam2_modules = [name for name in list(sys.modules) if "sam2" in name.lower()]
            for module_name in sam2_modules:
                if verbose:
                    logger.info(f"Removing cached module: {module_name}")
                sys.modules.pop(module_name, None)

            for root, dirs, _files in os.walk("."):
                if "__pycache__" in dirs:
                    cache_path = os.path.join(root, "__pycache__")
                    if verbose:
                        logger.info(f"Removing __pycache__: {cache_path}")
                    shutil.rmtree(cache_path, ignore_errors=True)
                    dirs.remove("__pycache__")  # don't descend into the deleted dir

        except Exception as e:
            logger.warning(f"Python cache cleanup failed: {e}")

    @staticmethod
    def _clean_huggingface_cache(verbose: bool = True) -> None:
        """Delete SAM2-named files/dirs from the HuggingFace, torch and project caches."""
        try:
            cache_paths = [
                os.path.expanduser("~/.cache/huggingface/"),
                os.path.expanduser("~/.cache/torch/"),
                *HardCacheCleaner._config_paths("model_cache_dir"),
                "./checkpoints/",
                "./.cache/",
            ]
            # Honour relocated caches (e.g. HF_HOME set on Spaces); otherwise the
            # default locations above may not be where the SAM2 files live.
            for env_var in ("HF_HOME", "HF_HUB_CACHE", "TORCH_HOME"):
                env_path = os.environ.get(env_var)
                if env_path:
                    cache_paths.append(env_path)

            for cache_path in HardCacheCleaner._unique_existing_dirs(cache_paths):
                if verbose:
                    logger.info(f"Cleaning cache directory: {cache_path}")
                HardCacheCleaner._purge_matching_tree(cache_path, verbose)

        except Exception as e:
            logger.warning(f"HuggingFace cache cleanup failed: {e}")

    @staticmethod
    def _clean_pytorch_cache(verbose: bool = True) -> None:
        """Release cached CUDA allocator memory when CUDA is available."""
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                if verbose:
                    logger.info("Cleared PyTorch CUDA cache")
        except Exception as e:
            logger.warning(f"PyTorch cache cleanup failed: {e}")

    @staticmethod
    def _clean_temp_directories(verbose: bool = True) -> None:
        """Delete SAM2-named top-level entries from the app and system temp directories."""
        try:
            temp_dirs = [
                *HardCacheCleaner._config_paths("temp_dir"),
                tempfile.gettempdir(),
                "/tmp",
                "./tmp",
                "./temp",
            ]
            # Dedup matters: tempfile.gettempdir() is usually "/tmp" itself.
            for temp_dir in HardCacheCleaner._unique_existing_dirs(temp_dirs):
                HardCacheCleaner._purge_matching_entries(temp_dir, verbose)

        except Exception as e:
            logger.warning(f"Temp directory cleanup failed: {e}")

    @staticmethod
    def _clear_import_cache(verbose: bool = True) -> None:
        """Invalidate importlib finder caches so freshly written modules are found."""
        try:
            import importlib

            importlib.invalidate_caches()

            if verbose:
                logger.info("Cleared Python import cache")

        except Exception as e:
            logger.warning(f"Import cache cleanup failed: {e}")

    @staticmethod
    def _force_gc_cleanup(verbose: bool = True) -> None:
        """Run a full garbage-collection pass and report how many objects were freed."""
        try:
            collected = gc.collect()
            if verbose:
                logger.info(f"Garbage collection freed {collected} objects")
        except Exception as e:
            logger.warning(f"Garbage collection failed: {e}")
| |
|
| |
|
class WorkingSAM2Loader:
    """
    SAM2 loader built on the HuggingFace Transformers integration, with the
    official ``sam2`` package's ``from_pretrained`` as a last resort.

    Each strategy is tried in turn; failures are logged and the next strategy
    is attempted. The public loaders return ``None`` instead of raising.
    """

    # User-facing size name -> HuggingFace Hub model id.
    MODEL_IDS = {
        "tiny": "facebook/sam2.1-hiera-tiny",
        "small": "facebook/sam2.1-hiera-small",
        "base": "facebook/sam2.1-hiera-base-plus",
        "large": "facebook/sam2.1-hiera-large",
    }

    @staticmethod
    def _pipeline_device(device: str) -> int:
        """Translate a torch-style device string into a ``pipeline(device=...)`` index.

        ``"cuda"`` -> 0, ``"cuda:N"`` -> N, anything else -> -1 (CPU), matching
        the device the direct-class path moves the model to.
        """
        device = str(device)
        if device == "cuda":
            return 0
        if device.startswith("cuda:"):
            index = device.partition(":")[2]
            return int(index) if index.isdigit() else 0
        return -1

    @staticmethod
    def load_sam2_transformers_approach(device: str = "cuda", model_size: str = "large") -> Optional[Any]:
        """
        Load SAM2, trying three strategies in order.

        Args:
            device: Torch device string (e.g. ``"cuda"``, ``"cuda:1"``, ``"cpu"``).
            model_size: One of ``MODEL_IDS``' keys; unknown values fall back to
                ``"large"`` with a warning.

        Returns:
            A Transformers ``mask-generation`` pipeline, or a dict
            ``{"model": Sam2Model, "processor": Sam2Processor}``, or an official
            ``SAM2ImagePredictor`` — whichever strategy succeeds first — or
            ``None`` if all of them fail.
        """
        try:
            logger.info("Loading SAM2 via HuggingFace Transformers...")

            model_id = WorkingSAM2Loader.MODEL_IDS.get(model_size)
            if model_id is None:
                logger.warning(f"Unknown SAM2 model_size {model_size!r}; using 'large'")
                model_id = WorkingSAM2Loader.MODEL_IDS["large"]
            logger.info(f"Using model: {model_id}")

            # 1. High-level Transformers pipeline.
            try:
                from transformers import pipeline

                sam2_pipeline = pipeline(
                    "mask-generation",
                    model=model_id,
                    device=WorkingSAM2Loader._pipeline_device(device),
                )

                logger.info("SAM2 loaded successfully via Transformers pipeline")
                return sam2_pipeline

            except Exception as e:
                logger.warning(f"Pipeline approach failed: {e}")

            # 2. Transformers model + processor classes.
            try:
                from transformers import Sam2Processor, Sam2Model

                processor = Sam2Processor.from_pretrained(model_id)
                model = Sam2Model.from_pretrained(model_id).to(device)

                logger.info("SAM2 loaded successfully via Transformers classes")
                return {"model": model, "processor": processor}

            except Exception as e:
                logger.warning(f"Direct class approach failed: {e}")

            # 3. Official sam2 package. `device` must be forwarded: the underlying
            # build defaults to CUDA and would fail on CPU-only hosts.
            try:
                from sam2.sam2_image_predictor import SAM2ImagePredictor

                predictor = SAM2ImagePredictor.from_pretrained(model_id, device=device)

                logger.info("SAM2 loaded successfully via official from_pretrained")
                return predictor

            except Exception as e:
                logger.warning(f"Official from_pretrained approach failed: {e}")

            logger.error("All SAM2 loading methods failed")
            return None

        except Exception as e:
            logger.error(f"All SAM2 loading methods failed: {e}")
            return None

    @staticmethod
    def load_sam2_fallback_approach(device: str = "cuda") -> Optional[Any]:
        """
        Fallback: load the large SAM2.1 checkpoint as a bare Transformers ``Sam2Model``.

        No separate checkpoint download is performed: ``from_pretrained`` fetches
        the weights itself. (A previous manual download requested a filename not
        present in the ``sam2.1`` repo and its result was never used.)

        Args:
            device: Torch device string the model is moved to.

        Returns:
            The ``Sam2Model`` on *device* (without a processor), or ``None``.
        """
        try:
            logger.info("Trying fallback SAM2 loading approach...")

            from transformers import Sam2Model

            model = Sam2Model.from_pretrained(WorkingSAM2Loader.MODEL_IDS["large"])
            return model.to(device)

        except Exception as e:
            logger.warning(f"Transformers fallback failed: {e}")
            return None
| |
|
| |
|
def load_sam2_with_cache_cleanup(
    device: str = "cuda",
    model_size: str = "large",
    force_cache_clean: bool = True,
    verbose: bool = True
) -> Tuple[Optional[Any], str]:
    """
    Optionally purge stale SAM2 caches, then try each SAM2 loader in turn.

    Args:
        device: Torch device string forwarded to the loaders.
        model_size: Size key forwarded to the primary loader only.
        force_cache_clean: Run ``HardCacheCleaner.clean_all_caches`` first.
        verbose: Forwarded to the cache cleaner.

    Returns:
        ``(model, status)``: *model* is the first non-``None`` loader result (or
        ``None``), *status* is a newline-joined log of the steps taken.
    """
    progress = []

    def summary() -> str:
        return "\n".join(progress)

    try:
        if force_cache_clean:
            progress.append("Cleaning caches...")
            HardCacheCleaner.clean_all_caches(verbose=verbose)
            progress.append("Cache cleanup completed")

        # (message before attempt, message on success, loader thunk)
        strategies = (
            (
                "Loading SAM2 (primary method)...",
                "SAM2 loaded successfully!",
                lambda: WorkingSAM2Loader.load_sam2_transformers_approach(device, model_size),
            ),
            (
                "Trying fallback loading method...",
                "SAM2 loaded successfully (fallback)!",
                lambda: WorkingSAM2Loader.load_sam2_fallback_approach(device),
            ),
        )

        for before_msg, success_msg, attempt in strategies:
            progress.append(before_msg)
            loaded = attempt()
            if loaded is not None:
                progress.append(success_msg)
                return loaded, summary()

        progress.append("All SAM2 loading methods failed")
        return None, summary()

    except Exception as exc:
        error_msg = f"Critical error in SAM2 loading: {exc}"
        logger.error(f"{error_msg}\n{traceback.format_exc()}")
        progress.append(error_msg)
        return None, summary()