""" Model Loader Module This module handles loading image classification models and their processors from the Hugging Face model hub. It is optimized for ViT-style models but can load a variety of architectures via Auto classes. For ViT models, it configures the model for explainability by enabling attention weights. Author: ViT-XAI-Dashboard Team License: MIT """ import torch from transformers import ( AutoModelForImageClassification, AutoImageProcessor, ) from types import SimpleNamespace import warnings def load_model_and_processor(model_name="google/vit-base-patch16-224"): """ Load an image classification model and its corresponding image processor from Hugging Face. This function uses the Transformers Auto classes to support multiple architectures (ViT, DeiT, Swin, ResNet, etc.). For ViT-like models, it enables attention weight outputs and prefers "eager" attention to make attention matrices accessible for explainability. Args: model_name (str, optional): Hugging Face model identifier. Defaults to "google/vit-base-patch16-224". Returns: tuple: (model, processor) - model (PreTrainedModel): The loaded model in eval mode - processor (ImageProcessor): The corresponding image processor Raises: Exception: If model loading fails due to network issues, invalid model name, or insufficient memory. Note: - Model is automatically set to evaluation mode - Attention outputs are enabled when the model supports them - For ViT-like models, we try to use the "eager" attention implementation - GPU is used automatically if available, otherwise falls back to CPU """ try: print(f"Loading model {model_name}...") # Load the image processor (handles image preprocessing and normalization) processor = AutoImageProcessor.from_pretrained(model_name) # Load the model using Auto classes (supports many architectures) model = AutoModelForImageClassification.from_pretrained(model_name) # Enable attention output in model config when available # This makes attention weights available in forward pass outputs if hasattr(model, "config"): try: model.config.output_attentions = True except Exception: pass # Prefer "eager" attention implementation when the config supports it # This is particularly relevant for ViT models to expose attention weights for attr in ("_attn_implementation", "attn_implementation"): if hasattr(model.config, attr): try: setattr(model.config, attr, "eager") except Exception: pass # Determine device (GPU if available, otherwise CPU) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) # Set model to evaluation mode # This disables dropout and sets batch normalization to eval mode model.eval() # Print success message with device info print(f"✅ Model and processor loaded successfully on {device}!") # Best-effort informational printout for attention implementation if available attn_impl = None if hasattr(model, "config"): for attr in ("_attn_implementation", "attn_implementation"): if hasattr(model.config, attr): attn_impl = getattr(model.config, attr) break if attn_impl is not None: print(f" Using attention implementation: {attn_impl}") return model, processor except Exception as e: # Handle known EfficientNet issue that requires torch>=2.6 for torch.load err_msg = str(e) print(f"⚠️ Primary load failed for {model_name}: {err_msg}") if "efficientnet" in model_name.lower() or "v2.6" in err_msg: try: print("Attempting fallback to timm for EfficientNet...") model, processor = _load_efficientnet_with_timm(model_name) # Move to device and eval as usual device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) model.eval() print(f"✅ Fallback loaded via timm on {device}!") return model, processor except Exception as ee: print(f"❌ Fallback via timm failed: {ee}") raise # Re-raise exception with context for debugging if not handled print(f"❌ Error loading model {model_name}: {str(e)}") raise class _SimpleImageProcessor: """Minimal image processor to mimic HF processor for non-HF models. Returns a dict with 'pixel_values' suitable for our predictor pipeline. """ def __init__(self, size=224, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): from torchvision import transforms self.size = size self.transform = transforms.Compose( [ transforms.Resize((size, size)), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std), ] ) def __call__(self, images, return_tensors="pt"): if return_tensors != "pt": warnings.warn("_SimpleImageProcessor only supports return_tensors='pt'") import torch as _torch # Expect a single PIL Image for our use-cases tensor = self.transform(images).unsqueeze(0) # (1, C, H, W) return {"pixel_values": tensor} class _HFLikeOutput: def __init__(self, logits): self.logits = logits class _HFLikeModelWrapper(torch.nn.Module): """Wrap a timm model to present an HF-like interface with config.id2label. Forward accepts pixel_values and returns an object with .logits """ def __init__(self, model, id2label): super().__init__() self.model = model self.config = SimpleNamespace(id2label=id2label) def forward(self, pixel_values): logits = self.model(pixel_values) return _HFLikeOutput(logits) def _load_efficientnet_with_timm(model_name: str): """Load EfficientNet via timm as a fallback, returning (model, processor).""" try: import timm except Exception as e: raise RuntimeError( "timm is required for EfficientNet fallback. Please install 'timm'." ) from e # Map HF name to a commonly available timm variant variant = "tf_efficientnet_b7_ns" if "b7" in model_name.lower() else "tf_efficientnet_b0" net = timm.create_model(variant, pretrained=True, num_classes=1000) net.eval() # Build ImageNet-1k id2label mapping if needed id2label = {i: f"class_{i}" for i in range(1000)} wrapped = _HFLikeModelWrapper(net, id2label) processor = _SimpleImageProcessor(size=224) return wrapped, processor # Dictionary of supported ViT models with their Hugging Face identifiers # Users can easily add more models by extending this dictionary SUPPORTED_MODELS = { # ViT family "ViT-Base": "google/vit-base-patch16-224", # 86M params, good balance of speed/accuracy "ViT-Large": "google/vit-large-patch16-224", # 304M params, higher accuracy but slower # New additions "ResNet-50": "microsoft/resnet-50", "Swin Transformer": "microsoft/swin-base-patch4-window7-224", "DeiT": "facebook/deit-base-patch16-224", "EfficientNet": "google/efficientnet-b7", # Note: may have limited attention-based XAI }