# Uploaded by KaushiGihan ("Upload 17 files", revision 07fc447, verified).
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
from sklearn.preprocessing import MultiLabelBinarizer
import joblib
from pathlib import Path
from typing import List, Optional, Tuple, Any
from app.src.logger import setup_logger
# Module-level logger (registered under the name "vit_load") shared by every
# method of VITDocumentClassifier below.
logger = setup_logger("vit_load")
class VITDocumentClassifier:
    """
    Multi-label document classifier built on a Vision Transformer (ViT) model.

    The class owns three artifacts that are loaded once at construction time:

    * ``model``     - a pickled ``torch.nn.Module`` whose forward output exposes
                      a ``logits`` attribute,
    * ``processor`` - the Hugging Face image processor identified by ``model_id``,
    * ``mlb``       - a fitted ``MultiLabelBinarizer`` mapping output columns
                      back to label names.

    ``predict`` turns an image file into a list of label names by running the
    processor, the model, a sigmoid and a confidence cutoff, in that order.
    """

    def __init__(self, model_path: Path, mlb_path: Path, model_id: str = "google/vit-base-patch16-224-in21k") -> None:
        """
        Initializes the VITDocumentClassifier by loading the model, processor, and MLB.

        Args:
            model_path: Path to the ViT model file (.pth), a pickled PyTorch module.
            mlb_path: Path to the MultiLabelBinarizer file (.joblib) holding the
                fitted binarizer that matches the model's output classes.
            model_id: The Hugging Face model ID used to load the image processor.
                Defaults to "google/vit-base-patch16-224-in21k".

        Raises:
            FileNotFoundError: If the model file or the MLB file does not exist.
            RuntimeError: If loading returned without error but at least one
                artifact is still missing (e.g. the processor failed to load).
            Exception: Any other error raised while loading the model or MLB is
                logged and re-raised unchanged.
        """
        logger.info("Initializing VITDocumentClassifier.")
        self.model: Optional[torch.nn.Module] = None
        self.processor: Optional[AutoImageProcessor] = None
        self.mlb: Optional[MultiLabelBinarizer] = None
        self.device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")
        self.model_id: str = model_id

        try:
            self._load_artifacts(model_path, mlb_path)
        except Exception as e:
            logger.critical(f"Failed to initialize VITDocumentClassifier: {e}", exc_info=True)
            # Re-raise the exception after logging
            raise

        # Compare against None explicitly: nn.Module containers define __len__,
        # so a plain truthiness test could reject a successfully loaded object.
        if self.model is None or self.processor is None or self.mlb is None:
            logger.critical("VITDocumentClassifier failed to fully initialize due to artifact loading errors.")
            raise RuntimeError("Failed to load all required artifacts for VITDocumentClassifier.")

        logger.info("VITDocumentClassifier initialized successfully.")

    def _load_artifacts(self, model_path: Path, mlb_path: Path) -> None:
        """
        Loads the ViT processor, model, and MultiLabelBinarizer, in that order.

        A processor failure is logged but not raised here, so that the model
        and MLB are still attempted; ``__init__`` turns a missing processor
        into a ``RuntimeError`` afterwards. Model and MLB failures are logged
        and re-raised immediately.

        Args:
            model_path: Path to the ViT model file (.pth).
            mlb_path: Path to the MultiLabelBinarizer file (.joblib).

        Raises:
            FileNotFoundError: If either the model file or the MLB file is not found.
            Exception: If any other unexpected error occurs while loading the
                model or the MLB.
        """
        logger.info("Starting artifact loading.")
        processor_loaded: bool = False

        # Load Processor
        try:
            logger.info(f"Attempting to load ViT processor for model ID: {self.model_id}")
            self.processor = AutoImageProcessor.from_pretrained(self.model_id, use_fast=True)
            logger.info("ViT processor loaded successfully.")
            processor_loaded = True
        except Exception as e:
            # Deliberately not re-raised: keep loading so every failure is reported.
            logger.error(f"An error occurred while loading the ViT processor for model ID {self.model_id}: {e}", exc_info=True)

        # Load Model
        try:
            logger.info(f"Attempting to load ViT model from {model_path}")
            # NOTE(security): weights_only=False unpickles arbitrary Python
            # objects. Only load model files from a trusted source.
            self.model = torch.load(model_path, map_location=self.device, weights_only=False)
            self.model.to(self.device)  # Ensure model is on the correct device
            logger.info(f"ViT model loaded successfully and moved to {self.device}.")
        except FileNotFoundError:
            logger.critical(f"Critical Error: ViT model file not found at {model_path}", exc_info=True)
            raise
        except Exception as e:
            logger.critical(f"Critical Error: An unexpected error occurred while loading the ViT model from {model_path}: {e}", exc_info=True)
            raise

        # Load MLB
        try:
            logger.info(f"Attempting to load MultiLabelBinarizer from {mlb_path}")
            # NOTE(security): joblib.load is pickle-based; trusted files only.
            self.mlb = joblib.load(mlb_path)
            logger.info("MultiLabelBinarizer loaded successfully.")
        except FileNotFoundError:
            logger.critical(f"Critical Error: MultiLabelBinarizer file not found at {mlb_path}", exc_info=True)
            raise
        except Exception as e:
            logger.critical(f"Critical Error: An unexpected error occurred while loading the MultiLabelBinarizer from {mlb_path}: {e}", exc_info=True)
            raise

        # Reaching this point means the model and the MLB loaded (their
        # failures re-raise above), so only the processor can be missing.
        if processor_loaded:
            logger.info("All required ViT artifacts loaded successfully.")
        else:
            logger.error("One or more required ViT artifacts failed to load during _load_artifacts.")

    def _load_rgb_image(self, image_path: Path) -> "Optional[Image.Image]":
        """
        Opens ``image_path`` and returns a fully loaded RGB copy of it.

        The underlying file is closed before returning. Returns None (after
        logging the error) if the file cannot be opened or converted.
        """
        try:
            logger.info(f"Attempting to load image from {image_path}")
            source = Image.open(image_path)
            logger.info(f"Image loaded successfully from {image_path}.")
        except FileNotFoundError:
            logger.error(f"Error: Image file not found at {image_path}", exc_info=True)
            return None
        except Exception as e:
            logger.error(f"An unexpected error occurred while loading image {image_path}: {e}", exc_info=True)
            return None

        try:
            # Image.open is lazy and keeps the file open; the context manager
            # closes it on every path, including decode errors.
            with source:
                logger.info(f"Attempting to convert image to RGB for {image_path}.")
                if source.mode != "RGB":
                    rgb_image = source.convert("RGB")
                    logger.info(f"Image converted to RGB successfully for {image_path}.")
                else:
                    # copy() forces the pixel data to be read so the result
                    # stays usable once the file handle is closed.
                    rgb_image = source.copy()
                    logger.info(f"Image is already in RGB format for {image_path}.")
                return rgb_image
        except Exception as e:
            logger.error(f"An error occurred while converting image {image_path} to RGB: {e}", exc_info=True)
            return None

    @staticmethod
    def _binarize_logits(logits: torch.Tensor, cut_off: float) -> np.ndarray:
        """
        Converts raw logits into a 0/1 integer array of shape (1, n_values).

        A sigmoid is applied element-wise and every probability greater than
        or equal to ``cut_off`` becomes 1. The result is always 2-D with a
        single row, so a one-class model is handled like any other.
        """
        probabilities = torch.sigmoid(logits.detach().float().cpu()).numpy().reshape(1, -1)
        return (probabilities >= cut_off).astype(int)

    def _decode_labels(self, binary_prediction: np.ndarray, image_path: Path) -> Optional[List[str]]:
        """
        Maps a (1, n_classes) binary array to label names via ``self.mlb``.

        Returns None if the array's shape does not match the binarizer's
        classes, and an empty list if the binarizer yields no rows.
        """
        n_classes = len(self.mlb.classes_)
        expected_shape: Tuple[int, int] = (1, n_classes)
        if binary_prediction.shape != expected_shape:
            logger.error(f"Cannot inverse transform prediction shape {binary_prediction.shape} with MLB classes {n_classes} for {image_path}. Expected shape: {expected_shape}")
            return None

        label_rows = self.mlb.inverse_transform(binary_prediction)
        logger.info(f"Prediction processed for {image_path}. Predicted labels (raw tuple list): {label_rows}")

        # inverse_transform returns one tuple per sample; only one sample is processed.
        if not label_rows:
            logger.warning(f"MLB inverse_transform returned an empty list for {image_path}. No labels predicted.")
            return []

        final_labels: List[str] = list(label_rows[0])
        logger.info(f"Final predicted labels for {image_path}: {final_labels}")
        return final_labels

    def predict(self, image_path: Path, cut_off: float = 0.5) -> Optional[List[str]]:
        """
        Predicts the class labels for a given image using the loaded ViT model.

        Steps: load the image as RGB, run the image processor, run the model
        without gradients, apply sigmoid + cutoff, and map the resulting
        binary vector to label names with the MultiLabelBinarizer.

        Args:
            image_path: Path to the image file to classify (any format PIL can open).
            cut_off: Probabilities greater than or equal to this value count as
                positive predictions. Defaults to 0.5.

        Returns:
            A list of predicted label names; an empty list when no label
            reaches the cutoff; None if any step (artifact check, image
            loading, preprocessing, inference, thresholding, or inverse
            transform) fails. Failures are logged, never raised.
        """
        logger.info(f"Starting prediction process for image: {image_path} with cutoff {cut_off}.")

        if self.model is None or self.processor is None or self.mlb is None:
            logger.error("Model, processor, or MultiLabelBinarizer not loaded. Cannot perform prediction.")
            return None

        # Load and preprocess image
        image = self._load_rgb_image(image_path)
        if image is None:
            return None

        try:
            logger.info(f"Attempting to preprocess image using processor for {image_path}.")
            pixel_values: torch.Tensor = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)
            logger.info(f"Image preprocessed and moved to device ({self.device}).")
        except Exception as e:
            logger.error(f"An error occurred during image preprocessing for {image_path}: {e}", exc_info=True)
            return None

        # Forward pass
        try:
            logger.info(f"Starting model forward pass for {image_path}.")
            self.model.eval()  # Set model to evaluation mode
            with torch.no_grad():
                outputs: Any = self.model(pixel_values)
            logits: torch.Tensor = outputs.logits
            logger.info(f"Model forward pass completed for {image_path}.")
        except Exception as e:
            logger.error(f"An error occurred during model forward pass for {image_path}: {e}", exc_info=True)
            return None

        # Apply sigmoid and thresholding
        try:
            logger.info(f"Applying sigmoid and thresholding for {image_path}.")
            binary_prediction = self._binarize_logits(logits, cut_off)
            logger.info(f"Applied sigmoid and thresholding with cutoff {cut_off} for {image_path}. Binary predictions shape: {binary_prediction.shape}")
        except Exception as e:
            logger.error(f"An error occurred during probability processing for {image_path}: {e}", exc_info=True)
            return None

        # Get label names using the loaded MultiLabelBinarizer
        try:
            logger.info(f"Performing inverse transform using MultiLabelBinarizer for {image_path}.")
            return self._decode_labels(binary_prediction, image_path)
        except Exception as e:
            logger.error(f"An error occurred during inverse transform for {image_path}: {e}", exc_info=True)
            return None