import os
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from dotenv import load_dotenv
from PIL import Image
from transformers import (
    LayoutLMv2FeatureExtractor,
    LayoutLMv2ForSequenceClassification,
    LayoutLMv2Processor,
    LayoutLMv2Tokenizer,
)

from app.src.logger import setup_logger

logger = setup_logger("layout_loader")


class LayoutLMDocumentClassifier:
    """
    Classify document images with a LayoutLMv2 sequence-classification model.

    The class loads the model weights from a caller-supplied directory, builds
    a LayoutLMv2 processor, preprocesses a single image and maps the model's
    arg-max logit to one of the labels in ``id2label``.

    Error-handling contract:
        * Construction fails loudly (raises) when the path is missing or the
          artifacts cannot be loaded.
        * ``predict`` never raises for per-image problems; it logs the error
          and returns ``None``.
    """

    def __init__(self, model_path_str: Optional[str]) -> None:
        """
        Load the model and processor and set up the inference device.

        Args:
            model_path_str: Filesystem path of the LayoutLMv2 model directory.
                The log and error messages refer to ``LAYOUTLM_MODEL_PATH``;
                presumably callers read the value from that environment
                variable before passing it in -- this class itself never
                touches the environment.

        Raises:
            ValueError: If ``model_path_str`` is ``None`` or empty.
            FileNotFoundError: If the path does not exist.
            Exception: Anything raised by ``_load_artifacts`` is logged and
                re-raised unchanged.
        """
        logger.info("Initializing LayoutLMDocumentClassifier.")
        self.model_path_str: Optional[str] = model_path_str
        self.model: Optional[LayoutLMv2ForSequenceClassification] = None
        self.processor: Optional[LayoutLMv2Processor] = None
        # Prefer the GPU when one is visible to torch, otherwise run on CPU.
        self.device: torch.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        logger.info(f"Using device: {self.device}")

        # Mapping from model output index to label. It is hard-coded here and
        # must stay aligned with the classes the checkpoint was trained on.
        self.id2label: Dict[int, str] = {
            0: "invoice",
            1: "form",
            2: "note",
            3: "advertisement",
            4: "email",
        }
        logger.info(f"Defined id2label mapping: {self.id2label}")

        logger.info("Attempting to retrieve LAYOUTLM_MODEL_PATH from environment variables.")
        if not model_path_str:
            logger.critical("Critical Error: 'LAYOUTLM_MODEL_PATH' environment variable is not set.")
            raise ValueError("LAYOUTLM_MODEL_PATH environment variable is not set.")

        model_path: Path = Path(model_path_str)
        logger.info(f"Retrieved model path: {model_path}")

        if not model_path.exists():
            logger.critical(f"Critical Error: Model path from environment variable does not exist: {model_path}")
            raise FileNotFoundError(f"Model path not found: {model_path}")
        logger.info(f"Model path {model_path} exists.")

        try:
            logger.info("Calling _load_artifacts to load model and processor.")
            self._load_artifacts(model_path)
            if self.model is not None and self.processor is not None:
                logger.info("LayoutLMDocumentClassifier initialized successfully.")
            else:
                # _load_artifacts raises on failure, so this is only a safety
                # net; it deliberately logs without raising (predict() handles
                # a missing model/processor by returning None).
                logger.critical("LayoutLMDocumentClassifier failed to fully initialize due to artifact loading errors in _load_artifacts.")
        except Exception as e:
            logger.critical(f"An unhandled exception occurred during LayoutLMDocumentClassifier initialization: {e}", exc_info=True)
            raise

        logger.info("Initialization process completed.")

    def _load_artifacts(self, model_path: Path) -> None:
        """
        Build the processor and load the model, storing both on ``self``.

        The processor is NOT read from ``model_path``: it is assembled from a
        default-constructed ``LayoutLMv2FeatureExtractor`` and the tokenizer
        identified by ``"microsoft/layoutlmv2-base-uncased"``. Only the model
        weights come from ``model_path``. The model is moved to
        ``self.device`` after loading.

        Args:
            model_path: Directory (or file) passed to
                ``LayoutLMv2ForSequenceClassification.from_pretrained``.

        Raises:
            FileNotFoundError: Re-raised if model loading reports it.
            Exception: Any other error from processor or model loading is
                logged at critical level and re-raised.
        """
        logger.info(f"Starting artifact loading from {model_path} for LayoutLMv2.")

        # Processor. A failure here aborts initialization.
        try:
            logger.info(f"Attempting to load LayoutLMv2 processor from {model_path}")
            feature_extractor = LayoutLMv2FeatureExtractor()
            # NOTE(review): resolved by name, so this presumably needs the
            # tokenizer in the local Hugging Face cache or network access.
            tokenizer = LayoutLMv2Tokenizer.from_pretrained("microsoft/layoutlmv2-base-uncased")
            self.processor = LayoutLMv2Processor(feature_extractor, tokenizer)
            logger.info("LayoutLMv2 processor loaded successfully.")
        except Exception as e:
            logger.critical(f"Critical Error: An unexpected error occurred while loading the LayoutLMv2 processor from {model_path}: {e}", exc_info=True)
            raise

        # Model. A failure here aborts initialization as well.
        try:
            logger.info(f"Attempting to load LayoutLMv2 model from {model_path}")
            self.model = LayoutLMv2ForSequenceClassification.from_pretrained(model_path)
            self.model.to(self.device)
            logger.info(f"LayoutLMv2 model loaded successfully and moved to {self.device}.")
        except FileNotFoundError:
            logger.critical(f"Critical Error: LayoutLMv2 model file not found at {model_path}", exc_info=True)
            raise
        except Exception as e:
            logger.critical(f"Critical Error: An unexpected error occurred while loading the LayoutLMv2 model from {model_path}: {e}", exc_info=True)
            raise

        # Both steps re-raise on failure, so reaching this point means both
        # artifacts are loaded.
        logger.info("All required LayoutLMv2 artifacts loaded successfully from _load_artifacts.")
        logger.info("Artifact loading process completed.")

    def _load_rgb_image(self, image_path: Path) -> Optional[Image.Image]:
        """
        Open ``image_path`` and return an in-memory RGB copy of it.

        The file is opened in a ``with`` block so the underlying handle is
        closed before returning; the returned image is a detached copy and
        stays usable afterwards.

        Args:
            image_path: Path of the image file to read.

        Returns:
            The RGB image, or ``None`` if opening or converting failed (the
            error is logged).
        """
        try:
            logger.info(f"Attempting to load image from {image_path}")
            with Image.open(image_path) as source_image:
                logger.info(f"Image loaded successfully from {image_path}.")
                try:
                    logger.info(f"Attempting to convert image to RGB for {image_path}.")
                    if source_image.mode != "RGB":
                        rgb_image = source_image.convert("RGB")
                        logger.info(f"Image converted to RGB successfully for {image_path}.")
                    else:
                        # copy() forces the pixel data to be read now, so the
                        # result does not depend on the soon-to-be-closed file.
                        rgb_image = source_image.copy()
                        logger.info(f"Image is already in RGB format for {image_path}.")
                    return rgb_image
                except Exception as e:
                    logger.error(f"An error occurred while converting image {image_path} to RGB: {e}", exc_info=True)
                    return None
        except FileNotFoundError:
            logger.error(f"Error: Image file not found at {image_path}", exc_info=True)
            return None
        except Exception as e:
            logger.error(f"An unexpected error occurred while loading image {image_path}: {e}", exc_info=True)
            return None

    def _prepare_inputs(self, image_path: Path) -> Optional[Dict[str, torch.Tensor]]:
        """
        Turn an image file into model-ready tensors on ``self.device``.

        Steps: load the image as RGB, run it through ``self.processor`` with
        ``max_length=512`` padding/truncation, then move every tensor in the
        result to ``self.device``.

        Args:
            image_path: Path of the image file to process.

        Returns:
            The processor output with its tensors moved to ``self.device``,
            or ``None`` if any step failed or the processor is not loaded
            (the error is logged).
        """
        logger.info(f"Starting image loading and preprocessing for {image_path}.")

        image = self._load_rgb_image(image_path)
        if image is None:
            return None

        if self.processor is None:
            logger.error("LayoutLMv2 processor is not loaded. Cannot prepare inputs.")
            return None

        try:
            logger.info(f"Attempting to prepare inputs using processor for {image_path}.")
            # Only the image is supplied; words and boxes are left to the
            # processor. NOTE(review): presumably this relies on the feature
            # extractor's built-in OCR -- confirm the OCR backend is installed.
            encoded_inputs = self.processor(
                images=image,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=512,
            )
            logger.info(f"Inputs prepared successfully for {image_path}.")
        except Exception as e:
            logger.error(f"An error occurred during input preparation for {image_path}: {e}", exc_info=True)
            return None

        try:
            logger.info(f"Attempting to move inputs to device ({self.device}) for {image_path}.")
            # Snapshot the items so the mapping is not mutated mid-iteration.
            for key, value in list(encoded_inputs.items()):
                if isinstance(value, torch.Tensor):
                    encoded_inputs[key] = value.to(self.device)
            logger.info(f"Inputs moved to device ({self.device}) successfully for {image_path}.")
        except Exception as e:
            logger.error(f"An error occurred while moving inputs to device for {image_path}: {e}", exc_info=True)
            return None

        logger.info(f"Image loading and preprocessing completed successfully for {image_path}.")
        return encoded_inputs

    def predict(self, image_path: Path) -> Optional[str]:
        """
        Predict the document class of the image at ``image_path``.

        Prepares the inputs with ``_prepare_inputs``, runs the model under
        ``torch.no_grad()`` in eval mode, takes the arg-max over the last
        logits dimension and maps that index through ``id2label``.

        Args:
            image_path: Path of the image file to classify.

        Returns:
            The predicted label, or ``None`` if input preparation failed, the
            model is not loaded, inference raised, or the predicted index is
            not a key of ``id2label``. Failures are logged, never raised.
        """
        logger.info(f"Starting prediction process for image: {image_path}.")

        logger.info(f"Calling _prepare_inputs for {image_path}.")
        encoded_inputs: Optional[Dict[str, torch.Tensor]] = self._prepare_inputs(image_path)
        if encoded_inputs is None:
            logger.error(f"Input preparation failed for {image_path}. Cannot perform prediction.")
            logger.info(f"Prediction process failed for {image_path}.")
            return None
        logger.info(f"Input preparation successful for {image_path}.")

        if self.model is None:
            logger.error("LayoutLMv2 model is not loaded. Cannot perform prediction.")
            logger.info(f"Prediction process failed for {image_path}.")
            return None
        logger.info("LayoutLMv2 model is loaded. Proceeding with inference.")

        try:
            logger.info(f"Performing model inference for {image_path}.")
            self.model.eval()
            with torch.no_grad():
                outputs: Any = self.model(**encoded_inputs)
                logits = outputs.logits

            if not isinstance(logits, torch.Tensor):
                logger.error(f"Model output 'logits' is not a torch.Tensor for {image_path}. Cannot determine predicted index.")
                logger.info(f"Prediction process failed for {image_path} due to invalid model output.")
                return None

            # A single image is processed, so argmax yields one element and
            # .item() converts it to a plain Python int.
            predicted_class_idx: int = logits.argmax(-1).item()
            logger.info(f"Model inference completed for {image_path}. Predicted index: {predicted_class_idx}.")

            logger.info(f"Attempting to map predicted index {predicted_class_idx} to label.")
            predicted_label = self.id2label.get(predicted_class_idx)
            if predicted_label is None:
                logger.error(f"Predicted index {predicted_class_idx} not found in id2label mapping for {image_path}.")
                logger.info(f"Prediction process failed for {image_path} due to unknown predicted index.")
                return None
            logger.info(f"Mapped predicted index {predicted_class_idx} to label: {predicted_label}.")
        except Exception as e:
            logger.error(f"An error occurred during model inference or label mapping for {image_path}: {e}", exc_info=True)
            logger.info(f"Prediction process failed for {image_path} due to inference/mapping error.")
            return None

        logger.info(f"Prediction process completed successfully for {image_path}. Predicted label: {predicted_label}.")
        return predicted_label