KaushiGihan committed on
Commit
07fc447
·
verified ·
1 Parent(s): f9d20c4

Upload 17 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ artifacts/model/vgg_model/model.keras filter=lfs diff=lfs merge=lfs -text
37
+ artifacts/model/VIT_model/confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import tempfile

import gradio as gr
from pathlib import Path
from PIL import Image

# Import your model classes (adjust import paths as needed)
from app.src.vit_load import VITDocumentClassifier
from app.src.vgg16_load import VGGDocumentClassifier
from app.src.constant import vit_model_path, vit_mlb_path, vgg_model_path, vgg_mlb_path

# Load models once at startup so every request reuses the same instances.
vit_model = VITDocumentClassifier(vit_model_path, vit_mlb_path)
vgg_model = VGGDocumentClassifier(vgg_model_path, vgg_mlb_path)


def _save_to_temp(image: Image.Image) -> Path:
    """Write *image* to a unique temporary PNG file and return its path.

    A unique temp file (instead of a fixed name in the working directory)
    prevents two concurrent requests from clobbering each other's upload,
    and the caller deletes it after prediction so files do not accumulate.
    """
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        path = Path(tmp.name)
    image.save(path)
    return path


def predict_vit(image, cut_off):
    """Classify *image* with the ViT model.

    Args:
        image: PIL image from the Gradio upload widget (None if nothing uploaded).
        cut_off: sigmoid threshold from the slider, forwarded to the model.

    Returns:
        A human-readable prediction string.
    """
    if image is None:
        return "Please upload an image."
    temp_path = _save_to_temp(image)
    try:
        result = vit_model.predict(temp_path, cut_off)
    finally:
        temp_path.unlink(missing_ok=True)  # always clean up the temp file
    return f"ViT Prediction: {result}"


def predict_vgg(image):
    """Classify *image* with the VGG16 model.

    Args:
        image: PIL image from the Gradio upload widget (None if nothing uploaded).

    Returns:
        A human-readable prediction string.
    """
    if image is None:
        return "Please upload an image."
    temp_path = _save_to_temp(image)
    try:
        result = vgg_model.predict(temp_path)
    finally:
        temp_path.unlink(missing_ok=True)  # always clean up the temp file
    return f"VGG16 Prediction: {result}"


with gr.Blocks() as demo:
    gr.Markdown("# Document Classification Demo\nUpload an image and choose a model to classify it.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            cut_off = gr.Slider(0, 1, value=0.5, label="ViT Cutoff Threshold")
        with gr.Column():
            result_output = gr.Textbox(label="Prediction Result", interactive=False)
    with gr.Row():
        vit_btn = gr.Button("Predict with ViT Model")
        vgg_btn = gr.Button("Predict with VGG16 Model")

    vit_btn.click(fn=predict_vit, inputs=[image_input, cut_off], outputs=result_output)
    vgg_btn.click(fn=predict_vgg, inputs=image_input, outputs=result_output)

if __name__ == "__main__":
    demo.launch()
app/__init__.py ADDED
File without changes
app/app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse, JSONResponse, HTMLResponse
from pydantic import BaseModel

import uvicorn
import cv2
import tempfile
import shutil
import os
import warnings
import base64
import numpy as np
from pathlib import Path

from app.src.model_loader import vit_loader, vgg_loader
from app.src.logger import setup_logger


warnings.filterwarnings("ignore")


app = FastAPI(title="Document_Classifier",
              description="FastAPI",
              version="0.115.4")

# Allow all origins (replace * with specific origins if needed)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    """Liveness probe: confirms the API process is up."""
    return {"Fast API": "API is working"}


# Suppress TF / warning noise in server logs.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'  # 0 = all logs, 1 = filter info, 2 = filter warnings, 3 = filter errors
warnings.filterwarnings("ignore")

logger = setup_logger()


def _save_upload_to_temp(image_file: UploadFile) -> tuple[str, Path]:
    """Persist an uploaded file into a fresh temporary directory.

    Returns:
        (temp_dir, temp_file_path). The caller MUST remove *temp_dir*
        (e.g. ``shutil.rmtree``) once the file has been consumed.
    """
    temp_dir = tempfile.mkdtemp()
    temp_file_path = os.path.join(temp_dir, image_file.filename)
    with open(temp_file_path, "wb") as temp_file:
        shutil.copyfileobj(image_file.file, temp_file)
    return temp_dir, Path(temp_file_path)


@app.post("/vit_model")
async def vit_clf(cut_off: float = 0.5, image_file: UploadFile = File(...)):
    """Classify an uploaded document image with the ViT model.

    Args:
        cut_off: sigmoid threshold for the multi-label prediction.
        image_file: the uploaded image.

    Returns:
        JSON {"status": 1, "document_classe": <labels>} on success,
        {"status": 0, ...} on failure. (The key "document_classe" is kept
        as-is for backward compatibility with existing clients.)
    """
    temp_dir = None
    try:
        temp_dir, temp_file_path = _save_upload_to_temp(image_file)
        result = vit_loader().predict(image_path=temp_file_path, cut_off=cut_off)
        logger.info(result)

        if result is not None:
            return JSONResponse(content={"status": 1, "document_classe": result})
        return JSONResponse(content={"status": 0, "document_classe": None})
    except Exception as e:
        logger.error(str(e))
        return JSONResponse(content={"status": 0, "error_message": str(e)})
    finally:
        # Previously the mkdtemp() directory was never removed, leaking one
        # directory (plus the uploaded file) per request.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)


@app.post("/vgg_model")
async def vgg_clf(image_file: UploadFile = File(...)):
    """Classify an uploaded document image with the VGG16 model.

    Returns:
        Same JSON shape as /vit_model. The error payload now uses
        "error_message" (it previously returned the error text under
        "document_classe", inconsistently with /vit_model).
    """
    temp_dir = None
    try:
        temp_dir, temp_file_path = _save_upload_to_temp(image_file)
        result = vgg_loader().predict(image_path=temp_file_path)
        logger.info(result)

        if result is not None:
            return JSONResponse(content={"status": 1, "document_classe": result})
        return JSONResponse(content={"status": 0, "document_classe": None})
    except Exception as e:
        logger.error(str(e))
        return JSONResponse(content={"status": 0, "error_message": str(e)})
    finally:
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
app/src/__init__.py ADDED
File without changes
app/src/constant.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
from pathlib import Path

# Artifact locations for the trained models.
#
# Paths are built with pathlib's "/" operator so they resolve correctly on
# both POSIX and Windows. The previous raw backslash literals
# (e.g. r"artifacts\model\VIT_model\model.pth") were a single opaque
# filename on Linux, which is what deployment targets typically run.
_MODEL_DIR = Path("artifacts") / "model"

vit_model_path = _MODEL_DIR / "VIT_model" / "model.pth"
vit_mlb_path = _MODEL_DIR / "VIT_model" / "mlb.joblib"

vgg_model_path = _MODEL_DIR / "vgg_model" / "model.keras"
vgg_mlb_path = _MODEL_DIR / "vgg_model" / "mlb.joblib"

layout_model_path = _MODEL_DIR / "layout_model" / "model.pth"
app/src/layout_loader.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from PIL import Image
import numpy as np
import torch
from typing import Optional, List, Dict, Any
from pathlib import Path
from transformers import LayoutLMv2ForSequenceClassification, LayoutLMv2Processor, LayoutLMv2FeatureExtractor, LayoutLMv2Tokenizer
import os
from dotenv import load_dotenv
from app.src.logger import setup_logger


logger = setup_logger("layout_loader")


class LayoutLMDocumentClassifier:
    """Document classifier backed by a fine-tuned LayoutLMv2 model.

    Loads a LayoutLMv2 sequence-classification model and a matching
    processor, prepares image inputs, and maps model outputs to one of the
    labels in ``id2label``. Errors during artifact loading raise; errors
    during prediction are logged and surfaced as a ``None`` return.

    NOTE(review): several log/error messages mention a
    'LAYOUTLM_MODEL_PATH' environment variable, but the path is actually
    supplied as the ``model_path_str`` constructor argument — the messages
    predate that change.
    """

    def __init__(self,model_path_str) -> None:
        """Initialize the classifier and load model + processor.

        Args:
            model_path_str: Path (string) to the LayoutLMv2 model directory
                or file; also expected to be loadable by
                ``LayoutLMv2ForSequenceClassification.from_pretrained``.

        Raises:
            ValueError: If ``model_path_str`` is empty/None.
            FileNotFoundError: If the path does not exist, or a required
                artifact is missing during loading.
            Exception: Any other error raised while loading artifacts.
        """
        logger.info("Initializing LayoutLMDocumentClassifier.")
        self.model_path_str: Optional[str]=model_path_str
        self.model: Optional[LayoutLMv2ForSequenceClassification] = None
        self.processor: Optional[LayoutLMv2Processor] = None
        # Prefer GPU when available; all tensors/models are moved here.
        self.device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")
        # Index -> label mapping; must align with the model's output classes.
        self.id2label: Dict[int, str] = {0:'invoice', 1: 'form', 2:'note', 3:'advertisement', 4: 'email'}
        logger.info(f"Defined id2label mapping: {self.id2label}")

        # Validate the model path supplied by the caller (despite the log
        # text, it comes from the constructor argument, not the environment).
        model_path_str: Optional[str] = self.model_path_str
        logger.info(f"Attempting to retrieve LAYOUTLM_MODEL_PATH from environment variables.")
        if not model_path_str:
            logger.critical("Critical Error: 'LAYOUTLM_MODEL_PATH' environment variable is not set.")
            raise ValueError("LAYOUTLM_MODEL_PATH environment variable is not set.")

        model_path: Path = Path(model_path_str)
        logger.info(f"Retrieved model path: {model_path}")
        if not model_path.exists():
            logger.critical(f"Critical Error: Model path from environment variable does not exist: {model_path}")
            raise FileNotFoundError(f"Model path not found: {model_path}")
        logger.info(f"Model path {model_path} exists.")


        try:
            logger.info("Calling _load_artifacts to load model and processor.")
            self._load_artifacts(model_path)
            if self.model is not None and self.processor is not None:
                logger.info("LayoutLMDocumentClassifier initialized successfully.")
            else:
                # _load_artifacts raises on critical failure, so this branch
                # is defensive only; no need to raise again here.
                logger.critical("LayoutLMDocumentClassifier failed to fully initialize due to artifact loading errors in _load_artifacts.")
        except Exception as e:
            # Log any exception not already handled in _load_artifacts, then
            # propagate it — initialization must not silently half-succeed.
            logger.critical(f"An unhandled exception occurred during LayoutLMDocumentClassifier initialization: {e}", exc_info=True)
            raise  # Re-raise the exception after logging
        logger.info("Initialization process completed.")


    def _load_artifacts(self, model_path: Path) -> None:
        """Load the LayoutLMv2 model and build its processor.

        The processor is assembled from a fresh ``LayoutLMv2FeatureExtractor``
        and the tokenizer of ``microsoft/layoutlmv2-base-uncased``; the model
        weights are loaded from ``model_path`` and moved to ``self.device``.

        Args:
            model_path: Path to the LayoutLMv2 model directory or file.

        Raises:
            FileNotFoundError: If the model file is not found.
            Exception: Any other loading error (corrupt files, version
                incompatibility, network failure fetching the tokenizer, ...).
        """
        logger.info(f"Starting artifact loading from {model_path} for LayoutLMv2.")
        processor_loaded: bool = False
        model_loaded: bool = False

        # Load Processor
        try:
            logger.info(f"Attempting to load LayoutLMv2 processor from {model_path}")
            # Load feature extractor and tokenizer separately to create the processor
            feature_extractor = LayoutLMv2FeatureExtractor()
            tokenizer = LayoutLMv2Tokenizer.from_pretrained("microsoft/layoutlmv2-base-uncased")
            self.processor = LayoutLMv2Processor(feature_extractor, tokenizer)
            logger.info("LayoutLMv2 processor loaded successfully.")
            processor_loaded = True
        except Exception as e:
            logger.critical(f"Critical Error: An unexpected error occurred while loading the LayoutLMv2 processor from {model_path}: {e}", exc_info=True)
            raise  # Re-raise to indicate a critical initialization failure

        # Load Model
        try:
            logger.info(f"Attempting to load LayoutLMv2 model from {model_path}")
            self.model = LayoutLMv2ForSequenceClassification.from_pretrained(model_path)
            self.model.to(self.device)  # Ensure model is on the correct device
            logger.info(f"LayoutLMv2 model loaded successfully and moved to {self.device}.")
            model_loaded = True
        except FileNotFoundError:
            logger.critical(f"Critical Error: LayoutLMv2 model file not found at {model_path}", exc_info=True)
            raise  # Re-raise to indicate a critical initialization failure
        except Exception as e:
            logger.critical(f"Critical Error: An unexpected error occurred while loading the LayoutLMv2 model from {model_path}: {e}", exc_info=True)
            raise  # Re-raise to indicate a critical initialization failure

        # Summary logging. NOTE: since both try-blocks above re-raise on
        # failure, only the first (all-success) branch is actually reachable;
        # the others are defensive.
        if model_loaded and processor_loaded:
            logger.info("All required LayoutLMv2 artifacts loaded successfully from _load_artifacts.")
        elif model_loaded and not processor_loaded:
            logger.error("LayoutLMv2 model loaded successfully, but processor loading failed in _load_artifacts.")
        elif not model_loaded and processor_loaded:
            logger.error("LayoutLMv2 processor loaded successfully, but model loading failed in _load_artifacts.")
        else:
            logger.error("Both LayoutLMv2 model and processor failed to load during _load_artifacts.")
        logger.info("Artifact loading process completed.")


    def _prepare_inputs(self, image_path: Path) -> Optional[Dict[str, torch.Tensor]]:
        """Load an image and turn it into LayoutLMv2 input tensors.

        Opens the image, converts it to RGB if needed, runs it through the
        processor (truncation + max_length=512 padding), and moves every
        resulting tensor to ``self.device``.

        Args:
            image_path: Path to the image file (e.g. PNG, JPG).

        Returns:
            The processor's encoding (tensor mapping) on success, or ``None``
            if any step fails (missing file, corrupt image, processor error).
        """
        logger.info(f"Starting image loading and preprocessing for {image_path}.")
        image: Optional[Image.Image] = None

        # Load image
        try:
            logger.info(f"Attempting to load image from {image_path}")
            image = Image.open(image_path)
            logger.info(f"Image loaded successfully from {image_path}.")
        except FileNotFoundError:
            logger.error(f"Error: Image file not found at {image_path}", exc_info=True)
            return None
        except Exception as e:
            logger.error(f"An unexpected error occurred while loading image {image_path}: {e}", exc_info=True)
            return None

        # Convert image to RGB (the processor expects 3-channel input)
        try:
            logger.info(f"Attempting to convert image to RGB for {image_path}.")
            if image is None:
                logger.error(f"Image is None after loading for {image_path}. Cannot convert to RGB.")
                return None
            if image.mode != "RGB":
                image = image.convert("RGB")
                logger.info(f"Image converted to RGB successfully for {image_path}.")
            else:
                logger.info(f"Image is already in RGB format for {image_path}.")

        except Exception as e:
            logger.error(f"An error occurred while converting image {image_path} to RGB: {e}", exc_info=True)
            return None


        # Prepare inputs using the processor
        if self.processor is None:
            logger.error("LayoutLMv2 processor is not loaded. Cannot prepare inputs.")
            return None

        encoded_inputs: Optional[Dict[str, torch.Tensor]] = None
        try:
            logger.info(f"Attempting to prepare inputs using processor for {image_path}.")
            # The processor expects a PIL Image or a list of PIL Images
            if image is None:
                logger.error(f"Image is None before preprocessing for {image_path}. Cannot prepare inputs.")
                return None

            encoded_inputs = self.processor(
                images=image,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=512
            )
            logger.info(f"Inputs prepared successfully for {image_path}.")
        except Exception as e:
            logger.error(f"An error occurred during input preparation for {image_path}: {e}", exc_info=True)
            return None

        # Move every tensor in the encoding to the inference device
        if encoded_inputs is not None:
            try:
                logger.info(f"Attempting to move inputs to device ({self.device}) for {image_path}.")
                for k, v in encoded_inputs.items():
                    if isinstance(v, torch.Tensor):
                        encoded_inputs[k] = v.to(self.device)
                logger.info(f"Inputs moved to device ({self.device}) successfully for {image_path}.")
            except Exception as e:
                logger.error(f"An error occurred while moving inputs to device for {image_path}: {e}", exc_info=True)
                return None
        else:
            logger.error(f"Encoded inputs are None after processing for {image_path}. Cannot move to device.")
            return None


        logger.info(f"Image loading and preprocessing completed successfully for {image_path}.")
        return encoded_inputs


    def predict(self, image_path: Path) -> Optional[str]:
        """Predict the document class label for an image.

        Pipeline: ``_prepare_inputs`` -> model forward pass (no grad, eval
        mode) -> argmax over logits -> ``id2label`` lookup.

        Args:
            image_path: Path to the image file to classify.

        Returns:
            The predicted label string, or ``None`` if input preparation,
            inference, or label mapping fails.
        """
        logger.info(f"Starting prediction process for image: {image_path}.")

        # Prepare inputs
        logger.info(f"Calling _prepare_inputs for {image_path}.")
        encoded_inputs: Optional[Dict[str, torch.Tensor]] = self._prepare_inputs(image_path)
        if encoded_inputs is None:
            logger.error(f"Input preparation failed for {image_path}. Cannot perform prediction.")
            logger.info(f"Prediction process failed for {image_path}.")
            return None
        logger.info(f"Input preparation successful for {image_path}.")


        # Check if model is loaded
        if self.model is None:
            logger.error("LayoutLMv2 model is not loaded. Cannot perform prediction.")
            logger.info(f"Prediction process failed for {image_path}.")
            return None
        logger.info("LayoutLMv2 model is loaded. Proceeding with inference.")

        predicted_label: Optional[str] = None

        try:
            logger.info(f"Performing model inference for {image_path}.")
            self.model.eval()  # Set model to evaluation mode
            with torch.no_grad():
                outputs: Any = self.model(**encoded_inputs)
                logits: torch.Tensor = outputs.logits

            # Determine predicted class index.
            # Ensure logits is a tensor before calling argmax.
            if not isinstance(logits, torch.Tensor):
                logger.error(f"Model output 'logits' is not a torch.Tensor for {image_path}. Cannot determine predicted index.")
                logger.info(f"Prediction process failed for {image_path} due to invalid model output.")
                return None

            predicted_class_idx: int = logits.argmax(-1).item()
            logger.info(f"Model inference completed for {image_path}. Predicted index: {predicted_class_idx}.")

            # Map index to label
            logger.info(f"Attempting to map predicted index {predicted_class_idx} to label.")
            if predicted_class_idx in self.id2label:
                predicted_label = self.id2label[predicted_class_idx]
                logger.info(f"Mapped predicted index {predicted_class_idx} to label: {predicted_label}.")
            else:
                logger.error(f"Predicted index {predicted_class_idx} not found in id2label mapping for {image_path}.")
                logger.info(f"Prediction process failed for {image_path} due to unknown predicted index.")
                return None  # Return None if index is not in mapping

        except Exception as e:
            logger.error(f"An error occurred during model inference or label mapping for {image_path}: {e}", exc_info=True)
            logger.info(f"Prediction process failed for {image_path} due to inference/mapping error.")
            return None

        logger.info(f"Prediction process completed successfully for {image_path}. Predicted label: {predicted_label}.")
        return predicted_label
app/src/logger.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
import os
from logging.handlers import RotatingFileHandler
from datetime import datetime
from typing import Optional

# One timestamped log directory per process run, e.g. logs/10_31_2024_12_00_00/
LOG_FILE = f"{datetime.now().strftime('%m_%d_%Y_%H_%M_%S')}"
logs_path = os.path.join(os.getcwd(), "logs", LOG_FILE)
os.makedirs(logs_path, exist_ok=True)

# Shared format including filename, function name, and line number.
_LOG_FORMATTER = logging.Formatter(
    "%(asctime)s - %(filename)s - %(funcName)s - Line %(lineno)d - %(levelname)s - %(message)s"
)
_MAX_BYTES = 5 * 1024 * 1024  # 5 MB per log file


def _make_file_handler(path: str) -> RotatingFileHandler:
    """Return an INFO-level rotating file handler (5 MB, 3 backups) for *path*."""
    handler = RotatingFileHandler(filename=path, maxBytes=_MAX_BYTES, backupCount=3)
    handler.setFormatter(_LOG_FORMATTER)
    handler.setLevel(logging.INFO)
    return handler


def setup_logger(file_name: Optional[str] = None, api_app=None) -> logging.Logger:
    """Configure and return a logger writing into this run's log directory.

    Args:
        file_name: When given, the logger is named *file_name* and also writes
            to ``<logs_path>/<file_name>.log``; otherwise the root logger is
            configured. All loggers additionally write to ``global.log``.
        api_app: When not None, the handlers are attached to the
            ``uvicorn.access`` logger (request logging) instead, and that
            logger is returned.

    Returns:
        The configured ``logging.Logger``.

    The two previous copy-pasted branches were collapsed into one path, and
    handlers are only attached once per logger: re-attaching on every call
    (the old behavior) duplicated each log line for every extra call.
    """
    handlers = []
    if file_name is not None:
        handlers.append(_make_file_handler(os.path.join(logs_path, f"{file_name}.log")))
    handlers.append(_make_file_handler(os.path.join(logs_path, "global.log")))

    if api_app is not None:
        target = logging.getLogger("uvicorn.access")  # FastAPI/uvicorn request log
        target.setLevel(logging.INFO)
    else:
        target = logging.getLogger(file_name)
        target.setLevel(logging.DEBUG)

    # Idempotency guard: don't stack duplicate handlers on repeated calls.
    if not target.handlers:
        for handler in handlers:
            target.addHandler(handler)
    return target
app/src/model_loader.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from functools import lru_cache
from pathlib import Path
from app.src.vgg16_load import VGGDocumentClassifier
from app.src.vit_load import VITDocumentClassifier
from app.src.constant import *
from app.src.logger import setup_logger

logger = setup_logger("model_loader")


@lru_cache(maxsize=1)
def vit_loader() -> VITDocumentClassifier:
    """Return the ViT document classifier, constructing it on first call.

    The instance is cached so repeated calls (one per API request in
    app/app.py) do not reload the heavy model artifacts from disk.

    Raises:
        Exception: Propagates any construction error after logging it
            (lru_cache does not cache raised exceptions, so a transient
            failure is retried on the next call).
    """
    try:
        return VITDocumentClassifier(vit_model_path, vit_mlb_path)
    except Exception as e:
        logger.error(str(e))
        raise  # bare raise preserves the original traceback


@lru_cache(maxsize=1)
def vgg_loader() -> VGGDocumentClassifier:
    """Return the VGG16 document classifier, constructing it on first call.

    Cached for the same reason as vit_loader; return annotation added for
    consistency with vit_loader.
    """
    try:
        return VGGDocumentClassifier(vgg_model_path, vgg_mlb_path)
    except Exception as e:
        logger.error(str(e))
        raise
app/src/test_vit.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import joblib
from sklearn.preprocessing import MultiLabelBinarizer
from pathlib import Path
import torch
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
from app.src.logger import setup_logger

logger = setup_logger("test_vit")

try:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Forward-slash paths work on every OS via pathlib; the previous
    # non-raw backslash literals ("artifacts\model\...") were Windows-only
    # and contained invalid escape sequences such as "\m".
    mlb_file_path = Path("artifacts/model/VIT_model/mlb.joblib")
    model_file_path = Path("artifacts/model/VIT_model/model.pth")
    # Base checkpoint used only for its image processor configuration.
    model_id = "google/vit-base-patch16-224-in21k"
    processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)

    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects;
    # only load model.pth files from a trusted source.
    model = torch.load(model_file_path, map_location=device, weights_only=False)
    model.to(device)

except Exception as e:
    logger.error(str(e))
    raise


def mlb_load(file_path: Path) -> MultiLabelBinarizer:
    """Load the fitted MultiLabelBinarizer from *file_path*.

    Falls back to a placeholder binarizer fitted on the known label set if
    the file is missing, so the script stays runnable for demonstration.
    """
    try:
        mlb = joblib.load(file_path)
    except FileNotFoundError:
        logger.error("Error: 'artifacts/model/VIT_model/mlb.joblib' not found.")
        logger.error("Please make sure the path is correct. Using a placeholder binarizer.")
        # Placeholder with the actual label set so inverse_transform still works.
        mlb = MultiLabelBinarizer()
        mlb.fit([['advertisement', 'email', 'form', 'invoice', 'note']])
    return mlb


def VIT_model_prediction(image_path: Path, cut_off: float):
    """Run a multi-label ViT prediction for one image.

    Args:
        image_path: Path to the image to classify.
        cut_off: sigmoid probability threshold; labels at or above it are kept.

    Returns:
        {"status": 1, "classe": <inverse-transformed labels>} on success.

    Raises:
        Exception: Any inference error, after logging.
    """
    try:
        # Load and convert the image; fall back to a dummy image so the
        # demonstration path never hard-fails on a missing file.
        try:
            image = Image.open(image_path)
            if image.mode != "RGB":
                image = image.convert("RGB")
        except FileNotFoundError:
            logger.error(f"Error: Image not found at {image_path}")
            logger.error("Using a dummy image for demonstration.")
            image = Image.new('RGB', (224, 224), color='red')

        # Preprocess image
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

        # Forward pass
        with torch.no_grad():
            outputs = model(pixel_values)
            logits = outputs.logits

        # Sigmoid (not softmax): multi-label classification.
        sigmoid = torch.nn.Sigmoid()
        probs = sigmoid(logits.squeeze().cpu())

        # Threshold probabilities into a binary indicator vector.
        predictions = np.zeros(probs.shape)
        predictions[np.where(probs >= cut_off)] = 1

        # Map indicator vector back to label names; inverse_transform needs
        # a 2D array of shape (1, num_classes).
        mlb = mlb_load(mlb_file_path)
        predicted_labels = mlb.inverse_transform(predictions.reshape(1, -1))
        logger.info(f"Predicted labels: {predicted_labels}")
        return {"status": 1, "classe": predicted_labels}

    except Exception as e:
        logger.error(str(e))
        raise


#VIT_model_prediction(Path("dataset/sample_text_ds/test/email/2078379610a.jpg"), 0.5)
app/src/vgg16_load.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Deduplicated imports: logging, cv2, Path and tensorflow were each
# imported twice in the original file. Grouped stdlib / third-party /
# local per convention; every originally bound name is preserved.
import logging
from pathlib import Path
from typing import Optional, Tuple, List

import cv2
import joblib
import keras
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import MultiLabelBinarizer

from app.src.logger import setup_logger


# Configure logging
logger = setup_logger("vgg16_load")
19
+
20
def load_vgg_artifacts(model_path: Path, mlb_path: Path) -> tuple[tf.keras.Model, MultiLabelBinarizer]:
    """
    Load the VGG classifier and its fitted MultiLabelBinarizer from disk.

    Args:
        model_path: Path to the VGG model file (.keras).
        mlb_path: Path to the MultiLabelBinarizer file (.joblib).

    Returns:
        A tuple of (Keras model, MultiLabelBinarizer).

    Raises:
        FileNotFoundError: If either artifact file is missing.
        Exception: Any other error raised while loading an artifact.
    """

    def _load(kind: str, path: Path, loader):
        # Shared load routine: log the attempt, load, log the outcome.
        logger.info(f"Attempting to load {kind} from {path}")
        try:
            artifact = loader(path)
        except FileNotFoundError:
            logger.error(f"Error: {kind} file not found at {path}")
            raise
        except Exception as e:
            logger.error(f"An error occurred while loading the {kind}: {e}")
            raise
        logger.info(f"{kind} loaded successfully.")
        return artifact

    model = _load("VGG model", model_path, tf.keras.models.load_model)
    mlb = _load("MultiLabelBinarizer", mlb_path, joblib.load)

    logger.info("Both VGG model and MultiLabelBinarizer loaded successfully.")
    return model, mlb
61
+
62
+
63
+
64
+
65
def preprocess_image(image_path: Path, target_size: tuple[int, int] = (224, 224)) -> np.ndarray | None:
    """
    Load an image, force it to RGB, resize it and scale pixels to [0, 1].

    Every step logs its progress; any failure is logged and reported by
    returning None rather than raising.

    Args:
        image_path: Location of the image file on disk.
        target_size: Desired (width, height) of the output image.

    Returns:
        A float32 NumPy array with values in [0, 1], or None on failure.
    """
    try:
        logger.info(f"Attempting to load image from {image_path}")
        frame = cv2.imread(str(image_path))  # cv2.imread needs a plain string path
        if frame is None:
            logger.error(f"Error: Could not load image from {image_path}. cv2.imread returned None.")
            return None
        logger.info("Image loaded successfully.")

        logger.info("Attempting to convert image to RGB.")
        # cv2.imread yields BGR for colour images; models expect RGB.
        if frame.ndim == 3 and frame.shape[2] == 3:
            try:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            except cv2.error as e:
                logger.error(f"Error during BGR to RGB conversion for image {image_path}: {e}")
                return None
            logger.info("Image converted to RGB successfully.")
        elif frame.ndim == 2:
            # Single-channel input: replicate to three channels.
            try:
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
            except cv2.error as e:
                logger.error(f"Error during Grayscale to RGB conversion for image {image_path}: {e}")
                return None
            logger.info("Grayscale image converted to RGB successfully.")
        else:
            # Unusual channel layout: warn and continue with the data as-is.
            logger.warning(f"Unexpected image format for {image_path}. Attempting to proceed.")

        logger.info(f"Attempting to resize image to {target_size}.")
        try:
            frame = cv2.resize(frame, target_size)
        except cv2.error as e:
            logger.error(f"Error during image resizing for image {image_path} to size {target_size}: {e}")
            return None
        if frame is None or frame.size == 0:
            logger.error(f"Error: cv2.resize returned None or empty array for image {image_path}.")
            return None
        logger.info("Image resized successfully.")

        logger.info("Attempting to normalize pixel values.")
        try:
            frame = frame.astype("float32") / 255.0
            # Sanity check: result must be non-empty and inside [0, 1].
            if frame is None or frame.size == 0 or np.max(frame) > 1.0 or np.min(frame) < 0.0:
                logger.error(f"Error: Image normalization failed or resulted in unexpected values for image {image_path}.")
                return None
            logger.info("Pixel values normalized successfully.")
        except Exception as e:
            logger.error(f"Error during pixel normalization for image {image_path}: {e}")
            return None

        logger.info(f"Image preprocessing completed successfully for {image_path}.")
        return frame

    except Exception as e:
        logger.error(f"An unexpected error occurred during image preprocessing for {image_path}: {e}")
        return None
145
+
146
+
147
+
148
+
149
+
150
class VGGDocumentClassifier:
    """
    Document classifier backed by a fine-tuned VGG16 Keras model.

    Encapsulates loading of the model and its fitted MultiLabelBinarizer,
    image preprocessing (delegated to the module-level ``preprocess_image``
    helper, which this class previously duplicated verbatim), and
    prediction of a document class for an input image.
    """

    def __init__(self, model_path: Path, mlb_path: Path, target_size: Tuple[int, int] = (224, 224)) -> None:
        """
        Initialize the classifier by loading the model and binarizer.

        Args:
            model_path: Path to the VGG model file (.keras).
            mlb_path: Path to the MultiLabelBinarizer file (.joblib).
            target_size: Target (width, height) for image preprocessing.
                Defaults to (224, 224).

        Raises:
            FileNotFoundError: If either artifact file is not found.
            RuntimeError: If the artifacts could not be fully loaded.
            Exception: Any other error raised during artifact loading.
        """
        logger.info("Initializing VGGDocumentClassifier.")
        self.model: Optional[tf.keras.Model] = None
        self.mlb: Optional[MultiLabelBinarizer] = None
        self.target_size: Tuple[int, int] = target_size

        try:
            self._load_artifacts(model_path, mlb_path)
            if self.model and self.mlb:
                logger.info("VGGDocumentClassifier initialized successfully.")
            else:
                logger.critical("VGGDocumentClassifier failed to fully initialize due to artifact loading errors.")
                raise RuntimeError("Failed to load all required artifacts for VGGDocumentClassifier.")
        except Exception as e:
            logger.critical(f"Failed to initialize VGGDocumentClassifier: {e}", exc_info=True)
            raise  # Re-raise the exception after logging

    def _load_artifacts(self, model_path: Path, mlb_path: Path) -> None:
        """
        Load the VGG model and MultiLabelBinarizer, re-raising any failure.

        Args:
            model_path: Path to the VGG model file (.keras).
            mlb_path: Path to the MultiLabelBinarizer file (.joblib).

        Raises:
            FileNotFoundError: If either artifact file is not found.
            Exception: Any other unexpected error during loading.
        """
        logger.info("Starting artifact loading.")

        # Load Model
        try:
            logger.info(f"Attempting to load VGG model from {model_path}")
            self.model = tf.keras.models.load_model(model_path)
            logger.info("VGG model loaded successfully.")
        except FileNotFoundError:
            logger.critical(f"Critical Error: VGG model file not found at {model_path}", exc_info=True)
            raise
        except Exception as e:
            logger.critical(f"Critical Error: An unexpected error occurred while loading the VGG model from {model_path}: {e}", exc_info=True)
            raise

        # Load MLB
        try:
            logger.info(f"Attempting to load MultiLabelBinarizer from {mlb_path}")
            self.mlb = joblib.load(mlb_path)
            logger.info("MultiLabelBinarizer loaded successfully.")
        except FileNotFoundError:
            logger.critical(f"Critical Error: MultiLabelBinarizer file not found at {mlb_path}", exc_info=True)
            raise
        except Exception as e:
            logger.critical(f"Critical Error: An unexpected error occurred while loading the MultiLabelBinarizer from {mlb_path}: {e}", exc_info=True)
            raise

        # Reaching this point implies both loads succeeded (failures above
        # re-raise), so the success-flag bookkeeping the original kept here
        # was dead code and has been removed.
        logger.info("All required VGG artifacts loaded successfully.")

    def preprocess_image(self, image_path: Path) -> Optional[np.ndarray]:
        """
        Preprocess an image for VGG model prediction.

        Delegates to the module-level ``preprocess_image`` helper with this
        classifier's configured target size; the previous implementation was
        a line-for-line duplicate of that helper.

        Args:
            image_path: Path to the image file.

        Returns:
            A float32 NumPy array with pixel values in [0, 1], or None if
            any preprocessing step failed.
        """
        return preprocess_image(image_path, self.target_size)

    def predict(self, image_path: Path) -> Optional[List[str]]:
        """
        Predict the class label(s) for an image using the loaded VGG model.

        Loads and preprocesses the image, runs model inference, takes the
        arg-max class, one-hot encodes it and decodes it back to a label
        name via the MultiLabelBinarizer.

        Args:
            image_path: Path to the image file to classify.

        Returns:
            A list of predicted label strings on success; an empty list if
            the binarizer decoded no labels; None if any step failed.
        """
        logger.info(f"Starting prediction process for image: {image_path}.")

        if self.model is None or self.mlb is None:
            logger.error("Model or MultiLabelBinarizer not loaded. Cannot perform prediction.")
            return None

        # Preprocess image
        image = self.preprocess_image(image_path)
        if image is None:
            logger.error(f"Image preprocessing failed for {image_path}. Cannot perform prediction.")
            return None

        try:
            logger.info(f"Performing model inference for {image_path}.")
            # Add a batch dimension: the model expects (1, H, W, 3).
            batch = np.expand_dims(image, axis=0)
            prd = self.model.predict(batch)
            logger.info(f"Model inference completed for {image_path}. Prediction shape: {prd.shape}")
        except Exception as e:
            logger.error(f"An error occurred during model inference for {image_path}: {e}", exc_info=True)
            return None

        # Convert the prediction to a binary indicator format and get labels
        try:
            logger.info(f"Converting prediction to labels for {image_path}.")
            # Single-label head: take the arg-max class and one-hot encode it
            # so the MultiLabelBinarizer can decode it to a label name.
            # (For a true multi-label head this should become sigmoid +
            # thresholding instead.)
            pred_id = np.argmax(prd, axis=1)
            binary_prediction = np.zeros((1, len(self.mlb.classes_)))
            binary_prediction[0, pred_id] = 1

            predicted_labels_tuple_list: List[Tuple[str, ...]] = self.mlb.inverse_transform(binary_prediction)
            logger.info(f"Prediction processed for {image_path}. Predicted labels (raw tuple list): {predicted_labels_tuple_list}")

            # A truthy check suffices; `and len(...) > 0` was redundant.
            if predicted_labels_tuple_list:
                final_labels: List[str] = list(predicted_labels_tuple_list[0])
                logger.info(f"Final predicted labels for {image_path}: {final_labels}")
                return final_labels

            logger.warning(f"MLB inverse_transform returned an empty list for {image_path}. No labels predicted.")
            return []

        except Exception as e:
            logger.error(f"An error occurred during inverse transform or label processing for {image_path}: {e}", exc_info=True)
            return None
380
+
381
+
app/src/vit_load.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image
4
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
5
+ from sklearn.preprocessing import MultiLabelBinarizer
6
+ import joblib
7
+ from pathlib import Path
8
+ from typing import List, Optional, Tuple, Any
9
+ from app.src.logger import setup_logger
10
+
11
+
12
+
13
+ logger = setup_logger("vit_load")
14
+
15
class VITDocumentClassifier:
    """
    Multi-label document classifier backed by a Vision Transformer (ViT).

    Loads a fine-tuned ViT model, its Hugging Face image processor and a
    fitted MultiLabelBinarizer, then predicts document labels for an image
    using a sigmoid + threshold decision rule.
    """

    def __init__(self, model_path: Path, mlb_path: Path, model_id: str = "google/vit-base-patch16-224-in21k") -> None:
        """
        Initialize the classifier by loading the model, processor and MLB.

        Args:
            model_path: Path to the ViT model file (.pth), a pre-trained or
                fine-tuned PyTorch model file.
            mlb_path: Path to the MultiLabelBinarizer file (.joblib) fitted
                on the model's output classes.
            model_id: Hugging Face model ID used to load the matching image
                processor. Defaults to "google/vit-base-patch16-224-in21k".

        Raises:
            FileNotFoundError: If the model or MLB file is not found.
            RuntimeError: If any required artifact failed to load.
            Exception: Any other unexpected error during loading.
        """
        logger.info("Initializing VITDocumentClassifier.")
        self.model: Optional[torch.nn.Module] = None
        self.processor: Optional[AutoImageProcessor] = None
        self.mlb: Optional[MultiLabelBinarizer] = None
        self.device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")
        self.model_id: str = model_id

        try:
            self._load_artifacts(model_path, mlb_path)
            if self.model and self.processor and self.mlb:
                logger.info("VITDocumentClassifier initialized successfully.")
            else:
                # The processor load is non-fatal in _load_artifacts, so a
                # missing processor is caught here.
                logger.critical("VITDocumentClassifier failed to fully initialize due to artifact loading errors.")
                raise RuntimeError("Failed to load all required artifacts for VITDocumentClassifier.")
        except Exception as e:
            logger.critical(f"Failed to initialize VITDocumentClassifier: {e}", exc_info=True)
            raise

    def _load_artifacts(self, model_path: Path, mlb_path: Path) -> None:
        """
        Load the ViT processor, model and MultiLabelBinarizer.

        Processor failure is logged but not raised (``__init__`` raises if
        the processor is still missing afterwards); model and MLB failures
        are critical and re-raised immediately.

        Args:
            model_path: Path to the ViT model file (.pth).
            mlb_path: Path to the MultiLabelBinarizer file (.joblib).

        Raises:
            FileNotFoundError: If the model or MLB file is not found.
            Exception: Any other unexpected error during loading.
        """
        logger.info("Starting artifact loading.")

        # Load Processor (non-fatal: continue and let __init__ decide)
        try:
            logger.info(f"Attempting to load ViT processor for model ID: {self.model_id}")
            self.processor = AutoImageProcessor.from_pretrained(self.model_id, use_fast=True)
            logger.info("ViT processor loaded successfully.")
        except Exception as e:
            logger.error(f"An error occurred while loading the ViT processor for model ID {self.model_id}: {e}", exc_info=True)

        # Load Model
        try:
            logger.info(f"Attempting to load ViT model from {model_path}")
            # SECURITY NOTE: weights_only=False unpickles arbitrary objects;
            # only load model files from a trusted source.
            self.model = torch.load(model_path, map_location=self.device, weights_only=False)
            self.model.to(self.device)  # Ensure model is on the correct device
            logger.info(f"ViT model loaded successfully and moved to {self.device}.")
        except FileNotFoundError:
            logger.critical(f"Critical Error: ViT model file not found at {model_path}", exc_info=True)
            raise
        except Exception as e:
            logger.critical(f"Critical Error: An unexpected error occurred while loading the ViT model from {model_path}: {e}", exc_info=True)
            raise

        # Load MLB
        try:
            logger.info(f"Attempting to load MultiLabelBinarizer from {mlb_path}")
            self.mlb = joblib.load(mlb_path)
            logger.info("MultiLabelBinarizer loaded successfully.")
        except FileNotFoundError:
            logger.critical(f"Critical Error: MultiLabelBinarizer file not found at {mlb_path}", exc_info=True)
            raise
        except Exception as e:
            logger.critical(f"Critical Error: An unexpected error occurred while loading the MultiLabelBinarizer from {mlb_path}: {e}", exc_info=True)
            raise

        # Model and MLB failures re-raise above, so only the processor can
        # still be missing here; the original's three success flags were
        # redundant bookkeeping.
        if self.processor is not None:
            logger.info("All required ViT artifacts loaded successfully.")
        else:
            logger.error("One or more required ViT artifacts failed to load during _load_artifacts.")

    def predict(self, image_path: Path, cut_off: float = 0.5) -> Optional[List[str]]:
        """
        Predict the class labels for an image using the loaded ViT model.

        Loads the image, preprocesses it with the HF processor, runs the
        model, applies a sigmoid, thresholds the probabilities at
        ``cut_off`` and decodes the resulting binary vector back to label
        names with the MultiLabelBinarizer.

        Args:
            image_path: Path to the image file to classify (any format
                PIL can open).
            cut_off: Probability threshold; labels with probability
                >= cut_off are predicted as present. Defaults to 0.5.

        Returns:
            A list of predicted label strings on success; an empty list if
            no label clears the threshold; None if any step failed.
        """
        logger.info(f"Starting prediction process for image: {image_path} with cutoff {cut_off}.")

        if self.model is None or self.processor is None or self.mlb is None:
            logger.error("Model, processor, or MultiLabelBinarizer not loaded. Cannot perform prediction.")
            return None

        # Load the image
        try:
            logger.info(f"Attempting to load image from {image_path}")
            image = Image.open(image_path)
            logger.info(f"Image loaded successfully from {image_path}.")
        except FileNotFoundError:
            logger.error(f"Error: Image file not found at {image_path}", exc_info=True)
            return None
        except Exception as e:
            logger.error(f"An unexpected error occurred while loading image {image_path}: {e}", exc_info=True)
            return None

        # Force RGB
        try:
            logger.info(f"Attempting to convert image to RGB for {image_path}.")
            if image.mode != "RGB":
                image = image.convert("RGB")
                logger.info(f"Image converted to RGB successfully for {image_path}.")
            else:
                logger.info(f"Image is already in RGB format for {image_path}.")
        except Exception as e:
            logger.error(f"An error occurred while converting image {image_path} to RGB: {e}", exc_info=True)
            return None

        # Preprocess image using the loaded processor
        try:
            logger.info(f"Attempting to preprocess image using processor for {image_path}.")
            pixel_values: torch.Tensor = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)
            logger.info(f"Image preprocessed and moved to device ({self.device}).")
        except Exception as e:
            logger.error(f"An error occurred during image preprocessing for {image_path}: {e}", exc_info=True)
            return None

        # Forward pass
        try:
            logger.info(f"Starting model forward pass for {image_path}.")
            self.model.eval()  # disable dropout etc. for inference
            with torch.no_grad():
                outputs: Any = self.model(pixel_values)
                logits: torch.Tensor = outputs.logits
            logger.info(f"Model forward pass completed for {image_path}.")
        except Exception as e:
            logger.error(f"An error occurred during model forward pass for {image_path}: {e}", exc_info=True)
            return None

        # Apply sigmoid and thresholding
        try:
            logger.info(f"Applying sigmoid and thresholding for {image_path}.")
            sigmoid: torch.nn.Sigmoid = torch.nn.Sigmoid()
            probs: torch.Tensor = sigmoid(logits.squeeze().cpu())

            predictions: np.ndarray = np.zeros(probs.shape, dtype=int)
            # (A stray debug `print(predictions)` was removed here.)
            predictions[np.where(probs >= cut_off)] = 1
            logger.info(f"Applied sigmoid and thresholding with cutoff {cut_off} for {image_path}. Binary predictions shape: {predictions.shape}")
        except Exception as e:
            logger.error(f"An error occurred during probability processing for {image_path}: {e}", exc_info=True)
            return None

        # Get label names using the loaded MultiLabelBinarizer
        try:
            logger.info(f"Performing inverse transform using MultiLabelBinarizer for {image_path}.")
            # inverse_transform needs a 2D array of shape (n_samples, n_classes);
            # we process one image, so the expected shape is (1, n_classes).
            expected_shape: Tuple[int, int] = (1, len(self.mlb.classes_))

            binary_prediction: np.ndarray
            if predictions.ndim == 1 and predictions.shape[0] == len(self.mlb.classes_):
                binary_prediction = predictions.reshape(expected_shape)
                logger.info(f"Reshaped 1D prediction to 2D ({expected_shape}) for inverse transform.")
            elif predictions.ndim == 2 and predictions.shape == expected_shape:
                binary_prediction = predictions
                logger.info(f"Prediction already in correct 2D shape ({expected_shape}) for inverse transform.")
            else:
                logger.error(f"Cannot inverse transform prediction shape {predictions.shape} with MLB classes {len(self.mlb.classes_)} for {image_path}. Expected shape: {expected_shape}")
                return None

            predicted_labels_tuple_list: List[Tuple[str, ...]] = self.mlb.inverse_transform(binary_prediction)
            logger.info(f"Prediction processed for {image_path}. Predicted labels (raw tuple list): {predicted_labels_tuple_list}")

            # inverse_transform returns a list of tuples, one per sample; a
            # truthy check covers the empty case.
            if predicted_labels_tuple_list:
                final_labels: List[str] = list(predicted_labels_tuple_list[0])
                logger.info(f"Final predicted labels for {image_path}: {final_labels}")
                return final_labels

            logger.warning(f"MLB inverse_transform returned an empty list for {image_path}. No labels predicted.")
            return []

        except Exception as e:
            logger.error(f"An error occurred during inverse transform for {image_path}: {e}", exc_info=True)
            return None
280
+
281
+
artifacts/model/VIT_model/confusion_matrix.png ADDED

Git LFS Details

  • SHA256: 6fa92b894adacd89239ccb67dcaa37c8e84a4d6d8987924e2e5d6a913f70c415
  • Pointer size: 131 Bytes
  • Size of remote file: 138 kB
artifacts/model/VIT_model/mlb.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4754cb9555905cbeb8a008ac90b2bb81ab076fbc272510a17c40abea32aa5d16
3
+ size 571
artifacts/model/VIT_model/model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:223b9f3ccbe55b37f66ed7dd4c832116c17bec3229693a679da41351e9361a82
3
+ size 343310666
artifacts/model/vgg_model/mlb.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4754cb9555905cbeb8a008ac90b2bb81ab076fbc272510a17c40abea32aa5d16
3
+ size 571
artifacts/model/vgg_model/model.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad1f9fbf700dfac83efd97f5cc4f944ea5a628de9c0ba26d440abdd4b4426ef2
3
+ size 183090331
requirements.txt ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers==4.53.0
2
+ efficientnet==1.1.1
3
+ seaborn==0.13.2
4
+ libfinder==0.1.7
5
+ pathlib==1.0.1
6
+ requests==2.32.3
7
+ tensorflow==2.18.0
8
+ dagshub==0.5.10
9
+ google==2.0.3
10
+ torch==2.7.1
11
+ numpy==2.0.2
12
+ pandas==2.2.2
13
+ opencv-python
14
+ mlflow==3.1.1
15
+ keras==3.8.0
16
+ scikit-learn==1.6.1
17
+ ensure==1.0.4
18
+ joblib==1.5.1
19
+ matplotlib==3.10.0
20
+ ensure==1.0.4
21
+ python-box
22
+ pydot
23
+ graphviz
24
+ #'git+https://github.com/facebookresearch/detectron2.git'
25
+ gradio
26
+ fastapi==0.115.4
27
+ uvicorn==0.34.0
28
+ python-multipart== 0.0.19
29
+ -e .