# File size: 16,633 Bytes
# 07fc447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
from PIL import Image
import numpy as np
import torch
from typing import Optional, List, Dict, Any
from pathlib import Path
from transformers import LayoutLMv2ForSequenceClassification, LayoutLMv2Processor, LayoutLMv2FeatureExtractor, LayoutLMv2Tokenizer
import os
from dotenv import load_dotenv
from app.src.logger import setup_logger



# Module-level logger, created under the name "layout_loader" and used by every
# method of LayoutLMDocumentClassifier below.
# NOTE(review): setup_logger is project-local; the calls in this file pass
# exc_info=True, so it presumably returns a stdlib-compatible logging.Logger — confirm.
logger = setup_logger("layout_loader")

class LayoutLMDocumentClassifier:
    """Classify document images with a LayoutLMv2 sequence-classification model.

    The classifier loads the model weights from a local path, builds a
    ``LayoutLMv2Processor`` (default feature extractor plus the tokenizer named
    by ``TOKENIZER_CHECKPOINT``), and exposes :meth:`predict`, which maps an
    image file to one of the labels in ``id2label``.

    The model path is taken from the constructor argument; when that is empty
    the ``LAYOUTLM_MODEL_PATH`` environment variable is used instead.

    Attributes:
        model_path_str: The model path that was actually resolved.
        model: The loaded model, or ``None`` until loading succeeded.
        processor: The processor used to turn images into model inputs.
        device: ``cuda`` when ``torch.cuda.is_available()``, otherwise ``cpu``.
        id2label: Mapping from output index to class label.
    """

    # Checkpoint that supplies the tokenizer. Only the model weights come from
    # the local model path; the tokenizer is always taken from this checkpoint.
    TOKENIZER_CHECKPOINT: str = "microsoft/layoutlmv2-base-uncased"

    # Environment variable consulted when no path is passed to the constructor.
    MODEL_PATH_ENV_VAR: str = "LAYOUTLM_MODEL_PATH"

    # Default index -> label mapping. It must line up with the output classes
    # of the model being loaded.
    # NOTE(review): this is hard-coded rather than read from the model config —
    # confirm it matches the fine-tuned checkpoint.
    DEFAULT_ID2LABEL: Dict[int, str] = {
        0: 'invoice',
        1: 'form',
        2: 'note',
        3: 'advertisement',
        4: 'email',
    }

    def __init__(self, model_path_str: Optional[str] = None) -> None:
        """Resolve the model path and load the model and processor.

        Args:
            model_path_str: Path to the LayoutLMv2 model directory. When empty
                or ``None``, the ``LAYOUTLM_MODEL_PATH`` environment variable
                is used instead.

        Raises:
            ValueError: If no path was given and ``LAYOUTLM_MODEL_PATH`` is
                not set.
            FileNotFoundError: If the resolved model path does not exist.
            Exception: Any error raised while loading the processor or model
                is logged by ``_load_artifacts`` and propagated unchanged.
        """
        logger.info("Initializing LayoutLMDocumentClassifier.")
        self.model: Optional[LayoutLMv2ForSequenceClassification] = None
        self.processor: Optional[LayoutLMv2Processor] = None
        self.device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Copy so that per-instance edits never leak into the class default.
        self.id2label: Dict[int, str] = dict(self.DEFAULT_ID2LABEL)
        logger.info(f"Defined id2label mapping: {self.id2label}")

        # Fall back to the environment only when the caller supplied nothing,
        # which makes the documented env-var contract (and the error below) true.
        if not model_path_str:
            logger.info(f"Attempting to retrieve LAYOUTLM_MODEL_PATH from environment variables.")
            model_path_str = os.getenv(self.MODEL_PATH_ENV_VAR)
        self.model_path_str: Optional[str] = model_path_str

        if not model_path_str:
            logger.critical("Critical Error: 'LAYOUTLM_MODEL_PATH' environment variable is not set.")
            raise ValueError("LAYOUTLM_MODEL_PATH environment variable is not set.")

        model_path: Path = Path(model_path_str)
        logger.info(f"Retrieved model path: {model_path}")
        if not model_path.exists():
            logger.critical(f"Critical Error: Model path does not exist: {model_path}")
            raise FileNotFoundError(f"Model path not found: {model_path}")
        logger.info(f"Model path {model_path} exists.")

        # _load_artifacts logs every failure at critical level and re-raises,
        # so there is nothing to add here by catching and logging it again.
        logger.info("Calling _load_artifacts to load model and processor.")
        self._load_artifacts(model_path)
        logger.info("LayoutLMDocumentClassifier initialized successfully.")
        logger.info("Initialization process completed.")

    def _load_artifacts(self, model_path: Path) -> None:
        """Build the processor and load the model weights.

        The processor is assembled from a default ``LayoutLMv2FeatureExtractor``
        and the tokenizer from ``TOKENIZER_CHECKPOINT``; it is *not* read from
        ``model_path``. Only the classification model is loaded from
        ``model_path``; it is then moved to ``self.device`` and switched to
        evaluation mode.

        Args:
            model_path: Path passed to
                ``LayoutLMv2ForSequenceClassification.from_pretrained``.

        Raises:
            FileNotFoundError: If loading the model raises it (logged, then
                re-raised).
            Exception: Any other error from building the processor or loading
                the model is logged at critical level and re-raised.
        """
        logger.info(f"Starting artifact loading from {model_path} for LayoutLMv2.")

        # Processor: default feature extractor + tokenizer from the checkpoint.
        try:
            logger.info(
                f"Attempting to build LayoutLMv2 processor with tokenizer from {self.TOKENIZER_CHECKPOINT}"
            )
            feature_extractor = LayoutLMv2FeatureExtractor()
            tokenizer = LayoutLMv2Tokenizer.from_pretrained(self.TOKENIZER_CHECKPOINT)
            self.processor = LayoutLMv2Processor(feature_extractor, tokenizer)
        except Exception as e:
            logger.critical(
                f"Critical Error: An unexpected error occurred while building the LayoutLMv2 processor "
                f"with tokenizer from {self.TOKENIZER_CHECKPOINT}: {e}",
                exc_info=True,
            )
            raise  # Initialization cannot continue without a processor.
        logger.info("LayoutLMv2 processor loaded successfully.")

        # Model: publish it on self only once it is fully loaded and on the
        # device, so a failure halfway never leaves a half-initialised model.
        try:
            logger.info(f"Attempting to load LayoutLMv2 model from {model_path}")
            model = LayoutLMv2ForSequenceClassification.from_pretrained(model_path)
            model.to(self.device)
            model.eval()  # Inference only.
            self.model = model
        except FileNotFoundError:
            logger.critical(f"Critical Error: LayoutLMv2 model file not found at {model_path}", exc_info=True)
            raise
        except Exception as e:
            logger.critical(
                f"Critical Error: An unexpected error occurred while loading the LayoutLMv2 model from {model_path}: {e}",
                exc_info=True,
            )
            raise
        logger.info(f"LayoutLMv2 model loaded successfully and moved to {self.device}.")

        # Both steps raise on failure, so reaching this point means both succeeded.
        logger.info("All required LayoutLMv2 artifacts loaded successfully from _load_artifacts.")
        logger.info("Artifact loading process completed.")

    def _prepare_inputs(self, image_path: Path) -> Optional[Dict[str, torch.Tensor]]:
        """Load an image and turn it into model inputs on ``self.device``.

        The image is opened, converted to RGB (or copied when it already is
        RGB), passed through ``self.processor`` with truncation and
        ``max_length=512`` padding, and every tensor in the result is moved to
        ``self.device``.

        Args:
            image_path: Path to the image file to process.

        Returns:
            The processor output with its tensors on ``self.device``, or
            ``None`` if the processor is not loaded or any step fails. Failures
            are logged, not raised.
        """
        logger.info(f"Starting image loading and preprocessing for {image_path}.")

        # Cheap precondition first: no point touching the disk without a processor.
        if self.processor is None:
            logger.error("LayoutLMv2 processor is not loaded. Cannot prepare inputs.")
            return None

        # Open the image (lazy: the file handle stays open until we are done).
        try:
            logger.info(f"Attempting to load image from {image_path}")
            source_image = Image.open(image_path)
        except FileNotFoundError:
            logger.error(f"Error: Image file not found at {image_path}", exc_info=True)
            return None
        except Exception as e:
            logger.error(f"An unexpected error occurred while loading image {image_path}: {e}", exc_info=True)
            return None
        logger.info(f"Image loaded successfully from {image_path}.")

        # Decode inside the context manager so the file handle is always
        # released, including on errors. convert()/copy() both return an
        # in-memory image that no longer depends on the open file.
        try:
            with source_image:
                logger.info(f"Attempting to convert image to RGB for {image_path}.")
                if source_image.mode != "RGB":
                    image = source_image.convert("RGB")
                    logger.info(f"Image converted to RGB successfully for {image_path}.")
                else:
                    image = source_image.copy()
                    logger.info(f"Image is already in RGB format for {image_path}.")
        except Exception as e:
            logger.error(f"An error occurred while converting image {image_path} to RGB: {e}", exc_info=True)
            return None

        # Run the processor on the image.
        try:
            logger.info(f"Attempting to prepare inputs using processor for {image_path}.")
            encoded_inputs = self.processor(
                images=image,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=512,
            )
        except Exception as e:
            logger.error(f"An error occurred during input preparation for {image_path}: {e}", exc_info=True)
            return None
        logger.info(f"Inputs prepared successfully for {image_path}.")

        # Move every tensor to the inference device; non-tensor entries stay as-is.
        try:
            logger.info(f"Attempting to move inputs to device ({self.device}) for {image_path}.")
            for key, value in encoded_inputs.items():
                if isinstance(value, torch.Tensor):
                    encoded_inputs[key] = value.to(self.device)
        except Exception as e:
            logger.error(f"An error occurred while moving inputs to device for {image_path}: {e}", exc_info=True)
            return None
        logger.info(f"Inputs moved to device ({self.device}) successfully for {image_path}.")

        logger.info(f"Image loading and preprocessing completed successfully for {image_path}.")
        return encoded_inputs

    def predict(self, image_path: Path) -> Optional[str]:
        """Predict the class label of a document image.

        Prepares inputs with ``_prepare_inputs``, runs the model under
        ``torch.no_grad()``, takes the argmax of the logits and maps that index
        through ``id2label``.

        Args:
            image_path: Path to the image file to classify.

        Returns:
            The predicted label, or ``None`` if the model is not loaded, input
            preparation fails, inference raises, or the predicted index is
            missing from ``id2label``. Failures are logged, not raised.
        """
        logger.info(f"Starting prediction process for image: {image_path}.")

        # Fail fast before doing any image I/O if there is no model to run.
        if self.model is None:
            logger.error("LayoutLMv2 model is not loaded. Cannot perform prediction.")
            logger.info(f"Prediction process failed for {image_path}.")
            return None
        logger.info("LayoutLMv2 model is loaded. Proceeding with inference.")

        logger.info(f"Calling _prepare_inputs for {image_path}.")
        encoded_inputs = self._prepare_inputs(image_path)
        if encoded_inputs is None:
            logger.error(f"Input preparation failed for {image_path}. Cannot perform prediction.")
            logger.info(f"Prediction process failed for {image_path}.")
            return None
        logger.info(f"Input preparation successful for {image_path}.")

        try:
            logger.info(f"Performing model inference for {image_path}.")
            # Cheap, and guards against the model having been put in train mode.
            self.model.eval()
            with torch.no_grad():
                logits = self.model(**encoded_inputs).logits
            predicted_class_idx = int(logits.argmax(-1).item())
        except Exception as e:
            logger.error(
                f"An error occurred during model inference or label mapping for {image_path}: {e}",
                exc_info=True,
            )
            logger.info(f"Prediction process failed for {image_path} due to inference/mapping error.")
            return None
        logger.info(f"Model inference completed for {image_path}. Predicted index: {predicted_class_idx}.")

        logger.info(f"Attempting to map predicted index {predicted_class_idx} to label.")
        predicted_label = self.id2label.get(predicted_class_idx)
        if predicted_label is None:
            logger.error(f"Predicted index {predicted_class_idx} not found in id2label mapping for {image_path}.")
            logger.info(f"Prediction process failed for {image_path} due to unknown predicted index.")
            return None
        logger.info(f"Mapped predicted index {predicted_class_idx} to label: {predicted_label}.")

        logger.info(f"Prediction process completed successfully for {image_path}. Predicted label: {predicted_label}.")
        return predicted_label