Spaces:
Runtime error
Runtime error
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| from typing import Optional, List, Dict, Any | |
| from pathlib import Path | |
| from transformers import LayoutLMv2ForSequenceClassification, LayoutLMv2Processor, LayoutLMv2FeatureExtractor, LayoutLMv2Tokenizer | |
| import os | |
| from dotenv import load_dotenv | |
| from app.src.logger import setup_logger | |
# Module-level logger; the "layout_loader" name appears in emitted log records.
logger = setup_logger("layout_loader")
class LayoutLMDocumentClassifier:
    """Classify document images with a fine-tuned LayoutLMv2 model.

    Encapsulates loading of the LayoutLMv2 classification model and its
    processor, image preprocessing, and prediction. The model weights are
    loaded from the path passed to the constructor; the processor is built
    from a default feature extractor plus the base
    ``microsoft/layoutlmv2-base-uncased`` tokenizer. Inference runs on GPU
    when available, otherwise CPU.
    """

    def __init__(self, model_path_str: Optional[str]) -> None:
        """Initialize the classifier and load all inference artifacts.

        Args:
            model_path_str: Filesystem path to the fine-tuned LayoutLMv2
                model weights (directory or file). Note: despite what the
                previous docstring claimed, this value is a constructor
                argument, not read from the LAYOUTLM_MODEL_PATH
                environment variable.

        Raises:
            ValueError: If ``model_path_str`` is empty or ``None``.
            FileNotFoundError: If the given path does not exist, or a
                required artifact file is missing during loading.
            Exception: Any other unexpected error while loading the model
                or processor (propagated from ``_load_artifacts``).
        """
        logger.info("Initializing LayoutLMDocumentClassifier.")
        self.model_path_str: Optional[str] = model_path_str
        self.model: Optional[LayoutLMv2ForSequenceClassification] = None
        self.processor: Optional[LayoutLMv2Processor] = None
        self.device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")
        # Index -> label mapping; must align with the classes the model was
        # fine-tuned on (order matters).
        self.id2label: Dict[int, str] = {0: 'invoice', 1: 'form', 2: 'note', 3: 'advertisement', 4: 'email'}
        logger.info(f"Defined id2label mapping: {self.id2label}")
        if not self.model_path_str:
            logger.critical("Critical Error: no model path was provided to the constructor.")
            raise ValueError("Model path is not set (model_path_str is empty or None).")
        model_path: Path = Path(self.model_path_str)
        logger.info(f"Using model path: {model_path}")
        if not model_path.exists():
            logger.critical(f"Critical Error: Model path does not exist: {model_path}")
            raise FileNotFoundError(f"Model path not found: {model_path}")
        try:
            self._load_artifacts(model_path)
        except Exception as e:
            # _load_artifacts logs the specific failure; record the overall
            # initialization failure and propagate to the caller.
            logger.critical(
                f"An unhandled exception occurred during LayoutLMDocumentClassifier initialization: {e}",
                exc_info=True,
            )
            raise
        logger.info("LayoutLMDocumentClassifier initialized successfully.")

    def _load_artifacts(self, model_path: Path) -> None:
        """Load the LayoutLMv2 processor and classification model.

        Internal helper called from ``__init__``. On success,
        ``self.processor`` and ``self.model`` are populated and the model is
        moved to ``self.device``. Every failure path re-raises, so reaching
        the end of this method guarantees both artifacts loaded (the
        original partial-success bookkeeping branches were unreachable and
        have been removed).

        Args:
            model_path: Path to the fine-tuned model weights.

        Raises:
            FileNotFoundError: If the model weights are not found.
            Exception: Any other loading failure (corrupt files,
                version incompatibility, download errors, ...).
        """
        logger.info(f"Starting artifact loading from {model_path} for LayoutLMv2.")
        # Build the processor from its two components. NOTE(review): the
        # feature extractor defaults to apply_ocr=True, so the processor runs
        # OCR (Tesseract) on every image at inference time — confirm that is
        # intended for this deployment.
        try:
            feature_extractor = LayoutLMv2FeatureExtractor()
            tokenizer = LayoutLMv2Tokenizer.from_pretrained("microsoft/layoutlmv2-base-uncased")
            self.processor = LayoutLMv2Processor(feature_extractor, tokenizer)
            logger.info("LayoutLMv2 processor loaded successfully.")
        except Exception as e:
            logger.critical(
                f"Critical Error: failed to build the LayoutLMv2 processor: {e}",
                exc_info=True,
            )
            raise
        try:
            self.model = LayoutLMv2ForSequenceClassification.from_pretrained(model_path)
            self.model.to(self.device)  # ensure weights live on the inference device
            logger.info(f"LayoutLMv2 model loaded successfully and moved to {self.device}.")
        except FileNotFoundError:
            logger.critical(
                f"Critical Error: LayoutLMv2 model file not found at {model_path}",
                exc_info=True,
            )
            raise
        except Exception as e:
            logger.critical(
                f"Critical Error: failed to load the LayoutLMv2 model from {model_path}: {e}",
                exc_info=True,
            )
            raise
        logger.info("All required LayoutLMv2 artifacts loaded successfully.")

    def _prepare_inputs(self, image_path: Path) -> Optional[Dict[str, torch.Tensor]]:
        """Load an image and build LayoutLMv2 model inputs from it.

        Opens the image (closing the file handle via a context manager —
        the original leaked it), converts to RGB, runs it through the
        processor, and moves every resulting tensor to ``self.device``.

        Args:
            image_path: Path to the image file (e.g. PNG, JPG).

        Returns:
            A dict of input tensors (e.g. 'image'/'pixel_values',
            'input_ids', 'attention_mask', 'bbox') on ``self.device``, or
            ``None`` if any step fails (missing file, corrupt image,
            processor not loaded, processor error).
        """
        logger.info(f"Starting image loading and preprocessing for {image_path}.")
        if self.processor is None:
            logger.error("LayoutLMv2 processor is not loaded. Cannot prepare inputs.")
            return None
        # Load and normalize the image. convert("RGB") forces a full decode,
        # so the file can be closed before preprocessing.
        try:
            with Image.open(image_path) as raw:
                image = raw.convert("RGB")
            logger.info(f"Image loaded and converted to RGB for {image_path}.")
        except FileNotFoundError:
            logger.error(f"Error: Image file not found at {image_path}", exc_info=True)
            return None
        except Exception as e:
            logger.error(
                f"An unexpected error occurred while loading image {image_path}: {e}",
                exc_info=True,
            )
            return None
        # Tokenize/featurize via the processor.
        try:
            encoded = self.processor(
                images=image,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=512,
            )
            logger.info(f"Inputs prepared successfully for {image_path}.")
        except Exception as e:
            logger.error(
                f"An error occurred during input preparation for {image_path}: {e}",
                exc_info=True,
            )
            return None
        # Move tensors to the device. Build a new dict rather than mutating
        # the BatchEncoding while iterating its items (safer and clearer).
        try:
            device_inputs: Dict[str, torch.Tensor] = {
                key: (value.to(self.device) if isinstance(value, torch.Tensor) else value)
                for key, value in encoded.items()
            }
        except Exception as e:
            logger.error(
                f"An error occurred while moving inputs to device for {image_path}: {e}",
                exc_info=True,
            )
            return None
        logger.info(f"Image loading and preprocessing completed successfully for {image_path}.")
        return device_inputs

    def predict(self, image_path: Path) -> Optional[str]:
        """Predict the class label for a document image.

        Orchestrates input preparation (``_prepare_inputs``), forward pass
        under ``torch.no_grad()``, argmax over the logits, and mapping of
        the predicted index through ``self.id2label``.

        Args:
            image_path: Path to the image file to classify.

        Returns:
            The predicted class label string, or ``None`` if input
            preparation fails, the model is not loaded, inference errors
            out, or the predicted index is not in ``id2label``.
        """
        logger.info(f"Starting prediction process for image: {image_path}.")
        encoded_inputs: Optional[Dict[str, torch.Tensor]] = self._prepare_inputs(image_path)
        if encoded_inputs is None:
            logger.error(f"Input preparation failed for {image_path}. Cannot perform prediction.")
            return None
        if self.model is None:
            logger.error("LayoutLMv2 model is not loaded. Cannot perform prediction.")
            return None
        try:
            self.model.eval()  # disable dropout etc. for deterministic inference
            with torch.no_grad():
                outputs: Any = self.model(**encoded_inputs)
            logits = outputs.logits
            if not isinstance(logits, torch.Tensor):
                logger.error(
                    f"Model output 'logits' is not a torch.Tensor for {image_path}. "
                    "Cannot determine predicted index."
                )
                return None
            predicted_class_idx: int = int(logits.argmax(-1).item())
            logger.info(f"Model inference completed for {image_path}. Predicted index: {predicted_class_idx}.")
            predicted_label = self.id2label.get(predicted_class_idx)
            if predicted_label is None:
                logger.error(
                    f"Predicted index {predicted_class_idx} not found in id2label mapping for {image_path}."
                )
                return None
        except Exception as e:
            logger.error(
                f"An error occurred during model inference or label mapping for {image_path}: {e}",
                exc_info=True,
            )
            return None
        logger.info(
            f"Prediction process completed successfully for {image_path}. Predicted label: {predicted_label}."
        )
        return predicted_label