Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| from sklearn.preprocessing import MultiLabelBinarizer | |
| import joblib | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple, Any | |
| from app.src.logger import setup_logger | |
| logger = setup_logger("vit_load") | |
class VITDocumentClassifier:
    """Multi-label document classifier built on a Vision Transformer (ViT).

    Encapsulates three artifacts loaded at construction time:

    * a fine-tuned ViT model (a fully pickled ``torch.nn.Module`` ``.pth`` file),
    * the Hugging Face image processor matching ``model_id``,
    * a fitted ``MultiLabelBinarizer`` mapping model outputs back to label names.

    ``predict`` preprocesses an input image, runs a forward pass, applies a
    sigmoid activation, thresholds the probabilities at a confidence cutoff,
    and converts the resulting binary vector to a list of label strings.
    """

    def __init__(self, model_path: Path, mlb_path: Path, model_id: str = "google/vit-base-patch16-224-in21k") -> None:
        """
        Initializes the VITDocumentClassifier by loading the model, processor, and MLB.

        Args:
            model_path: Path to the ViT model file (.pth). This is expected to be
                        a pre-trained or fine-tuned PyTorch model file.
            mlb_path: Path to the MultiLabelBinarizer file (.joblib). This file
                      should contain the fitted binarizer object corresponding
                      to the model's output classes.
            model_id: The Hugging Face model ID for the processor. This is used
                      to load the appropriate image processor for the ViT model.
                      Defaults to "google/vit-base-patch16-224-in21k".

        Raises:
            FileNotFoundError: If either the model file or the MLB file is not found
                               at the specified paths during artifact loading.
            RuntimeError: If artifact loading fails for critical components
                          (model, processor, or MLB).
            Exception: If any other unexpected error occurs during loading.
        """
        logger.info("Initializing VITDocumentClassifier.")
        # Artifacts start as None so a partially-failed load is detectable below.
        self.model: Optional[torch.nn.Module] = None
        self.processor: Optional[AutoImageProcessor] = None
        self.mlb: Optional[MultiLabelBinarizer] = None
        self.device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")
        self.model_id: str = model_id
        try:
            self._load_artifacts(model_path, mlb_path)
            if self.model and self.processor and self.mlb:
                logger.info("VITDocumentClassifier initialized successfully.")
            else:
                # Reached when the processor failed to load (that failure is
                # logged but not re-raised inside _load_artifacts).
                logger.critical("VITDocumentClassifier failed to fully initialize due to artifact loading errors.")
                raise RuntimeError("Failed to load all required artifacts for VITDocumentClassifier.")
        except Exception as e:
            logger.critical(f"Failed to initialize VITDocumentClassifier: {e}", exc_info=True)
            # Re-raise the exception after logging so callers see the failure.
            raise

    def _load_artifacts(self, model_path: Path, mlb_path: Path) -> None:
        """
        Loads the ViT model, processor, and MultiLabelBinarizer with error handling and logging.

        Internal helper called during initialization. A processor failure is
        logged but deferred (so the model/MLB errors can surface too); model and
        MLB failures are re-raised immediately as critical.

        Args:
            model_path: Path to the ViT model file (.pth).
            mlb_path: Path to the MultiLabelBinarizer file (.joblib).

        Raises:
            FileNotFoundError: If either the model file or the MLB file is not found.
            Exception: If any other unexpected error occurs during loading.
        """
        logger.info("Starting artifact loading.")
        processor_loaded: bool = False
        model_loaded: bool = False
        mlb_loaded: bool = False

        # Load Processor — failure here is deferred to __init__'s completeness check.
        try:
            logger.info(f"Attempting to load ViT processor for model ID: {self.model_id}")
            self.processor = AutoImageProcessor.from_pretrained(self.model_id, use_fast=True)
            logger.info("ViT processor loaded successfully.")
            processor_loaded = True
        except Exception as e:
            logger.error(f"An error occurred while loading the ViT processor for model ID {self.model_id}: {e}", exc_info=True)
            # Do not re-raise here; continue loading other artifacts.

        # Load Model.
        try:
            logger.info(f"Attempting to load ViT model from {model_path}")
            # SECURITY NOTE: weights_only=False deserializes arbitrary pickle data.
            # Only load model files from trusted sources. Required here because the
            # artifact is a full pickled nn.Module, not a state_dict.
            self.model = torch.load(model_path, map_location=self.device, weights_only=False)
            self.model.to(self.device)  # Ensure model is on the correct device
            logger.info(f"ViT model loaded successfully and moved to {self.device}.")
            model_loaded = True
        except FileNotFoundError:
            logger.critical(f"Critical Error: ViT model file not found at {model_path}", exc_info=True)
            raise  # Re-raise to indicate a critical initialization failure
        except Exception as e:
            logger.critical(f"Critical Error: An unexpected error occurred while loading the ViT model from {model_path}: {e}", exc_info=True)
            raise  # Re-raise to indicate a critical initialization failure

        # Load MLB.
        try:
            logger.info(f"Attempting to load MultiLabelBinarizer from {mlb_path}")
            self.mlb = joblib.load(mlb_path)
            logger.info("MultiLabelBinarizer loaded successfully.")
            mlb_loaded = True
        except FileNotFoundError:
            logger.critical(f"Critical Error: MultiLabelBinarizer file not found at {mlb_path}", exc_info=True)
            raise  # Re-raise to indicate a critical initialization failure
        except Exception as e:
            logger.critical(f"Critical Error: An unexpected error occurred while loading the MultiLabelBinarizer from {mlb_path}: {e}", exc_info=True)
            raise  # Re-raise to indicate a critical initialization failure

        if processor_loaded and model_loaded and mlb_loaded:
            logger.info("All required ViT artifacts loaded successfully.")
        else:
            logger.error("One or more required ViT artifacts failed to load during _load_artifacts.")

    def predict(self, image_path: Path, cut_off: float = 0.5) -> Optional[List[str]]:
        """
        Predicts the class labels for a given image using the loaded ViT model.

        Pipeline: load image -> convert to RGB -> processor preprocessing ->
        model forward pass -> sigmoid -> threshold at ``cut_off`` ->
        MultiLabelBinarizer inverse transform to label names.

        Args:
            image_path: Path to the image file to classify. The image is expected
                        to be in a format compatible with PIL (Pillow).
            cut_off: The threshold for converting predicted probabilities into
                     binary labels. Probabilities greater than or equal to this
                     value are considered positive predictions (1), otherwise 0.
                     Defaults to 0.5.

        Returns:
            A list of predicted class labels (strings) on success. None if any
            critical step (image loading, preprocessing, inference, or inverse
            transform) fails. An empty list if prediction succeeds but no labels
            meet the cutoff threshold.
        """
        logger.info(f"Starting prediction process for image: {image_path} with cutoff {cut_off}.")
        if self.model is None or self.processor is None or self.mlb is None:
            logger.error("Model, processor, or MultiLabelBinarizer not loaded. Cannot perform prediction.")
            return None

        # Load the image.
        image: Optional[Image.Image] = None
        try:
            logger.info(f"Attempting to load image from {image_path}")
            image = Image.open(image_path)
            logger.info(f"Image loaded successfully from {image_path}.")
        except FileNotFoundError:
            logger.error(f"Error: Image file not found at {image_path}", exc_info=True)
            return None
        except Exception as e:
            logger.error(f"An unexpected error occurred while loading image {image_path}: {e}", exc_info=True)
            return None

        # Normalize to RGB (the ViT processor expects 3-channel input).
        try:
            logger.info(f"Attempting to convert image to RGB for {image_path}.")
            if image.mode != "RGB":
                image = image.convert("RGB")
                logger.info(f"Image converted to RGB successfully for {image_path}.")
            else:
                logger.info(f"Image is already in RGB format for {image_path}.")
        except Exception as e:
            logger.error(f"An error occurred while converting image {image_path} to RGB: {e}", exc_info=True)
            return None

        # Preprocess image using the loaded processor.
        try:
            logger.info(f"Attempting to preprocess image using processor for {image_path}.")
            # The processor expects a PIL Image or a list of PIL Images.
            pixel_values: torch.Tensor = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)
            logger.info(f"Image preprocessed and moved to device ({self.device}).")
        except Exception as e:
            logger.error(f"An error occurred during image preprocessing for {image_path}: {e}", exc_info=True)
            return None

        # Forward pass.
        try:
            logger.info(f"Starting model forward pass for {image_path}.")
            self.model.eval()  # Disable dropout/batchnorm updates for inference
            with torch.no_grad():
                outputs: Any = self.model(pixel_values)  # Output type varies by model class
                logits: torch.Tensor = outputs.logits
            logger.info(f"Model forward pass completed for {image_path}.")
        except Exception as e:
            logger.error(f"An error occurred during model forward pass for {image_path}: {e}", exc_info=True)
            return None

        # Apply sigmoid and thresholding.
        try:
            logger.info(f"Applying sigmoid and thresholding for {image_path}.")
            probs: torch.Tensor = torch.sigmoid(logits.squeeze().cpu())
            # Boolean comparison then cast: 1 where prob >= cut_off, else 0.
            # (Replaces a zeros-array + np.where two-step and a stray debug print.)
            predictions: np.ndarray = (probs >= cut_off).numpy().astype(int)
            logger.info(f"Applied sigmoid and thresholding with cutoff {cut_off} for {image_path}. Binary predictions shape: {predictions.shape}")
        except Exception as e:
            logger.error(f"An error occurred during probability processing for {image_path}: {e}", exc_info=True)
            return None

        # Map the binary vector back to label names via the MultiLabelBinarizer.
        try:
            logger.info(f"Performing inverse transform using MultiLabelBinarizer for {image_path}.")
            # inverse_transform requires a 2D array of shape (n_samples, n_classes);
            # we process one image at a time, so the expected shape is (1, n_classes).
            binary_prediction: np.ndarray
            expected_shape: Tuple[int, int] = (1, len(self.mlb.classes_))
            if predictions.ndim == 1 and predictions.shape[0] == len(self.mlb.classes_):
                binary_prediction = predictions.reshape(expected_shape)
                logger.info(f"Reshaped 1D prediction to 2D ({expected_shape}) for inverse transform.")
            elif predictions.ndim == 2 and predictions.shape == expected_shape:
                binary_prediction = predictions
                logger.info(f"Prediction already in correct 2D shape ({expected_shape}) for inverse transform.")
            else:
                logger.error(f"Cannot inverse transform prediction shape {predictions.shape} with MLB classes {len(self.mlb.classes_)} for {image_path}. Expected shape: {expected_shape}")
                return None

            predicted_labels_tuple_list: List[Tuple[str, ...]] = self.mlb.inverse_transform(binary_prediction)
            logger.info(f"Prediction processed for {image_path}. Predicted labels (raw tuple list): {predicted_labels_tuple_list}")
            # inverse_transform returns a list of tuples, even for a single sample;
            # take the first tuple as this image's labels.
            if predicted_labels_tuple_list and len(predicted_labels_tuple_list) > 0:
                final_labels: List[str] = list(predicted_labels_tuple_list[0])
                logger.info(f"Final predicted labels for {image_path}: {final_labels}")
                return final_labels
            else:
                logger.warning(f"MLB inverse_transform returned an empty list for {image_path}. No labels predicted.")
                return []
        except Exception as e:
            logger.error(f"An error occurred during inverse transform for {image_path}: {e}", exc_info=True)
            return None