Spaces:
Runtime error
Runtime error
Upload 17 files
Browse files- .gitattributes +2 -0
- app.py +46 -0
- app/__init__.py +0 -0
- app/app.py +98 -0
- app/src/__init__.py +0 -0
- app/src/constant.py +11 -0
- app/src/layout_loader.py +323 -0
- app/src/logger.py +91 -0
- app/src/model_loader.py +25 -0
- app/src/test_vit.py +102 -0
- app/src/vgg16_load.py +381 -0
- app/src/vit_load.py +281 -0
- artifacts/model/VIT_model/confusion_matrix.png +3 -0
- artifacts/model/VIT_model/mlb.joblib +3 -0
- artifacts/model/VIT_model/model.pth +3 -0
- artifacts/model/vgg_model/mlb.joblib +3 -0
- artifacts/model/vgg_model/model.keras +3 -0
- requirements.txt +29 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
artifacts/model/vgg_model/model.keras filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
artifacts/model/VIT_model/confusion_matrix.png filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from PIL import Image
|
| 4 |
+
|
| 5 |
+
# Import your model classes (adjust import paths as needed)
|
| 6 |
+
from app.src.vit_load import VITDocumentClassifier
|
| 7 |
+
from app.src.vgg16_load import VGGDocumentClassifier
|
| 8 |
+
from app.src.constant import vit_model_path, vit_mlb_path, vgg_model_path, vgg_mlb_path
|
| 9 |
+
|
| 10 |
+
# Load models once at startup
|
| 11 |
+
vit_model = VITDocumentClassifier(vit_model_path, vit_mlb_path)
|
| 12 |
+
vgg_model = VGGDocumentClassifier(vgg_model_path, vgg_mlb_path)
|
| 13 |
+
|
| 14 |
+
def predict_vit(image, cut_off):
    """Classify an uploaded image with the ViT model.

    Args:
        image: PIL image from the Gradio input, or None if nothing was uploaded.
        cut_off: Probability threshold forwarded to the ViT classifier.

    Returns:
        A human-readable prediction string, or a prompt to upload an image.
    """
    if image is None:
        return "Please upload an image."
    import os
    import tempfile

    # Fix: the original wrote to a fixed filename in the CWD, so concurrent
    # requests clobbered each other's upload and the file was never removed.
    fd, temp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        image.save(temp_path)
        result = vit_model.predict(Path(temp_path), cut_off)
    finally:
        os.remove(temp_path)  # always clean up the scratch file
    return f"ViT Prediction: {result}"
|
| 21 |
+
|
| 22 |
+
def predict_vgg(image):
    """Classify an uploaded image with the VGG16 model.

    Args:
        image: PIL image from the Gradio input, or None if nothing was uploaded.

    Returns:
        A human-readable prediction string, or a prompt to upload an image.
    """
    if image is None:
        return "Please upload an image."
    import os
    import tempfile

    # Fix: the original wrote to a fixed filename in the CWD, so concurrent
    # requests clobbered each other's upload and the file was never removed.
    fd, temp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        image.save(temp_path)
        result = vgg_model.predict(Path(temp_path))
    finally:
        os.remove(temp_path)  # always clean up the scratch file
    return f"VGG16 Prediction: {result}"
|
| 29 |
+
|
| 30 |
+
# Assemble the Gradio UI: image + threshold inputs on the left, the
# prediction textbox on the right, and one trigger button per model.
with gr.Blocks() as demo:
    gr.Markdown("# Document Classification Demo\nUpload an image and choose a model to classify it.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            cut_off = gr.Slider(minimum=0, maximum=1, value=0.5, label="ViT Cutoff Threshold")
        with gr.Column():
            result_output = gr.Textbox(label="Prediction Result", interactive=False)
    with gr.Row():
        vit_btn = gr.Button("Predict with ViT Model")
        vgg_btn = gr.Button("Predict with VGG16 Model")

    # Wire each button to its model-specific prediction callback.
    vit_btn.click(fn=predict_vit, inputs=[image_input, cut_off], outputs=result_output)
    vgg_btn.click(fn=predict_vgg, inputs=image_input, outputs=result_output)

if __name__ == "__main__":
    demo.launch()
|
app/__init__.py
ADDED
|
File without changes
|
app/app.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, File, UploadFile
|
| 2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
from fastapi.responses import StreamingResponse,FileResponse , JSONResponse,HTMLResponse
|
| 4 |
+
from pydantic import BaseModel
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import uvicorn
|
| 8 |
+
import cv2
|
| 9 |
+
import tempfile
|
| 10 |
+
import shutil
|
| 11 |
+
import os
|
| 12 |
+
import warnings
|
| 13 |
+
import base64
|
| 14 |
+
import numpy as np
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
from app.src.model_loader import vit_loader,vgg_loader
|
| 18 |
+
from app.src.logger import setup_logger
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
warnings.filterwarnings("ignore")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# FastAPI application instance for the document-classification service.
app = FastAPI(
    title="Document_Classifire",
    description="FastAPI",
    version="0.115.4",
)

# Allow all origins (replace * with specific origins if needed)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 36 |
+
|
| 37 |
+
@app.get("/")
async def root():
    """Health-check endpoint confirming the API is up.

    Returns:
        A small JSON dict indicating the service is alive.
    """
    # Fix: corrected the "woorking" typo in the user-visible status message.
    return {"Fast API": "API is working"}
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Suppress warnings
|
| 43 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' # 0 = all logs, 1 = filter out info, 2 = filter out warnings, 3 = filter out errors
|
| 44 |
+
warnings.filterwarnings("ignore")
|
| 45 |
+
|
| 46 |
+
logger = setup_logger()
|
| 47 |
+
|
| 48 |
+
@app.post("/vit_model")
async def vit_clf(cut_off: float = 0.5, image_file: UploadFile = File(...)):
    """Classify an uploaded document image with the ViT model.

    Args:
        cut_off: Probability threshold used by the ViT classifier.
        image_file: The uploaded image file.

    Returns:
        JSONResponse with ``status`` 1 and the predicted class on success,
        ``status`` 0 with a null class when prediction returns nothing, or
        ``status`` 0 with an ``error_message`` on exception.
    """
    temp_dir = None
    try:
        # Persist the upload to a scratch directory so the model can read a path.
        temp_dir = tempfile.mkdtemp()
        temp_file_path = os.path.join(temp_dir, image_file.filename)
        with open(temp_file_path, "wb") as temp_file:
            shutil.copyfileobj(image_file.file, temp_file)
        # NOTE(review): the model is re-loaded on every request via
        # vit_loader(); consider loading once at startup if latency matters.
        result = vit_loader().predict(image_path=Path(temp_file_path), cut_off=cut_off)
        logger.info(result)

        if result is not None:
            return JSONResponse(content={"status": 1, "document_classe": result})
        else:
            return JSONResponse(content={"status": 0, "document_classe": None})

    except Exception as e:
        logger.error(str(e))
        return JSONResponse(content={"status": 0, "error_message": str(e)})
    finally:
        # Fix: the temporary directory used to leak on every request.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@app.post("/vgg_model")
async def vgg_clf(image_file: UploadFile = File(...)):
    """Classify an uploaded document image with the VGG16 model.

    Args:
        image_file: The uploaded image file.

    Returns:
        JSONResponse with ``status`` 1 and the predicted class on success,
        ``status`` 0 with a null class when prediction returns nothing, or
        ``status`` 0 with an ``error_message`` on exception.
    """
    temp_dir = None
    try:
        # Persist the upload to a scratch directory so the model can read a path.
        temp_dir = tempfile.mkdtemp()
        temp_file_path = os.path.join(temp_dir, image_file.filename)
        with open(temp_file_path, "wb") as temp_file:
            shutil.copyfileobj(image_file.file, temp_file)
        # NOTE(review): the model is re-loaded on every request via
        # vgg_loader(); consider loading once at startup if latency matters.
        result = vgg_loader().predict(image_path=Path(temp_file_path))
        logger.info(result)

        if result is not None:
            return JSONResponse(content={"status": 1, "document_classe": result})
        else:
            return JSONResponse(content={"status": 0, "document_classe": None})

    except Exception as e:
        logger.error(str(e))
        # Fix: use "error_message" like the /vit_model endpoint instead of
        # stuffing the error text into "document_classe".
        return JSONResponse(content={"status": 0, "error_message": str(e)})
    finally:
        # Fix: the temporary directory used to leak on every request.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
app/src/__init__.py
ADDED
|
File without changes
|
app/src/constant.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path

# Model artifact locations, relative to the repository root.
# Fix: the original used raw Windows paths (r"artifacts\model\...") which do
# not resolve on POSIX hosts such as Hugging Face Spaces; forward slashes
# work on every platform with pathlib.
vit_model_path = Path("artifacts/model/VIT_model/model.pth")
vit_mlb_path = Path("artifacts/model/VIT_model/mlb.joblib")

vgg_model_path = Path("artifacts/model/vgg_model/model.keras")
vgg_mlb_path = Path("artifacts/model/vgg_model/mlb.joblib")

layout_model_path = Path("artifacts/model/layout_model/model.pth")
|
app/src/layout_loader.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from typing import Optional, List, Dict, Any
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from transformers import LayoutLMv2ForSequenceClassification, LayoutLMv2Processor, LayoutLMv2FeatureExtractor, LayoutLMv2Tokenizer
|
| 7 |
+
import os
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
from app.src.logger import setup_logger
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
logger = setup_logger("layout_loader")
|
| 14 |
+
|
| 15 |
+
class LayoutLMDocumentClassifier:
    """Classify document images with a LayoutLMv2 model.

    Loads the LayoutLMv2 weights from the path supplied to the constructor
    and builds a processor from a default feature extractor plus the
    pretrained 'microsoft/layoutlmv2-base-uncased' tokenizer. Inference runs
    on GPU when available, otherwise CPU.
    """

    def __init__(self, model_path_str) -> None:
        """Initialize the classifier: pick a device and load the artifacts.

        Args:
            model_path_str: Filesystem path to the LayoutLMv2 model directory
                or file containing the weights.

        Raises:
            ValueError: If no model path is supplied.
            FileNotFoundError: If the supplied path does not exist or a
                required artifact file is missing during loading.
            Exception: For any other unexpected error during loading.
        """
        logger.info("Initializing LayoutLMDocumentClassifier.")
        self.model_path_str: Optional[str] = model_path_str
        self.model: Optional[LayoutLMv2ForSequenceClassification] = None
        self.processor: Optional[LayoutLMv2Processor] = None
        self.device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")
        # Mapping from model output indices to human-readable class labels;
        # must stay aligned with the model's training label order.
        self.id2label: Dict[int, str] = {0: 'invoice', 1: 'form', 2: 'note', 3: 'advertisement', 4: 'email'}
        logger.info(f"Defined id2label mapping: {self.id2label}")

        # Fix: the original docstrings and log messages claimed the path came
        # from the LAYOUTLM_MODEL_PATH environment variable, but it is in
        # fact supplied by the caller through the constructor argument.
        if not self.model_path_str:
            logger.critical("Critical Error: no LayoutLMv2 model path was supplied.")
            raise ValueError("A LayoutLMv2 model path must be supplied.")

        model_path: Path = Path(self.model_path_str)
        logger.info(f"Retrieved model path: {model_path}")
        if not model_path.exists():
            logger.critical(f"Critical Error: Model path does not exist: {model_path}")
            raise FileNotFoundError(f"Model path not found: {model_path}")

        try:
            logger.info("Calling _load_artifacts to load model and processor.")
            self._load_artifacts(model_path)
            logger.info("LayoutLMDocumentClassifier initialized successfully.")
        except Exception as e:
            logger.critical(f"An unhandled exception occurred during LayoutLMDocumentClassifier initialization: {e}", exc_info=True)
            raise

    def _load_artifacts(self, model_path: Path) -> None:
        """Load the LayoutLMv2 model and processor.

        Args:
            model_path: Path to the LayoutLMv2 model weights. Only the model
                is read from this path; the processor is built from defaults.

        Raises:
            FileNotFoundError: If a required model file is not found.
            Exception: For any other unexpected loading error.
        """
        logger.info(f"Starting artifact loading from {model_path} for LayoutLMv2.")

        # Build the processor from a default feature extractor plus the
        # pretrained base tokenizer (downloaded from the hub on first use).
        try:
            feature_extractor = LayoutLMv2FeatureExtractor()
            tokenizer = LayoutLMv2Tokenizer.from_pretrained("microsoft/layoutlmv2-base-uncased")
            self.processor = LayoutLMv2Processor(feature_extractor, tokenizer)
            logger.info("LayoutLMv2 processor loaded successfully.")
        except Exception as e:
            logger.critical(f"Critical Error: failed to build the LayoutLMv2 processor: {e}", exc_info=True)
            raise

        try:
            self.model = LayoutLMv2ForSequenceClassification.from_pretrained(model_path)
            self.model.to(self.device)  # ensure the model lives on the inference device
            logger.info(f"LayoutLMv2 model loaded successfully and moved to {self.device}.")
        except FileNotFoundError:
            logger.critical(f"Critical Error: LayoutLMv2 model file not found at {model_path}", exc_info=True)
            raise
        except Exception as e:
            logger.critical(f"Critical Error: failed to load the LayoutLMv2 model from {model_path}: {e}", exc_info=True)
            raise

        # Fix: removed the unreachable partial-failure branches — every
        # failure above re-raises, so reaching this point means both
        # artifacts loaded.
        logger.info("All required LayoutLMv2 artifacts loaded successfully.")

    def _prepare_inputs(self, image_path: Path) -> Optional[Dict[str, torch.Tensor]]:
        """Load an image and encode it into model-ready tensors.

        Args:
            image_path: Path to the image file (e.g. PNG, JPG) to process.

        Returns:
            Dict of input tensors moved to ``self.device``, or None if any
            step fails (missing file, corrupt image, processor error).
        """
        logger.info(f"Preparing inputs for {image_path}.")

        try:
            image = Image.open(image_path)
        except FileNotFoundError:
            logger.error(f"Error: Image file not found at {image_path}", exc_info=True)
            return None
        except Exception as e:
            logger.error(f"An unexpected error occurred while loading image {image_path}: {e}", exc_info=True)
            return None

        # LayoutLMv2 expects 3-channel input, so normalize the color mode.
        try:
            if image.mode != "RGB":
                image = image.convert("RGB")
        except Exception as e:
            logger.error(f"An error occurred while converting image {image_path} to RGB: {e}", exc_info=True)
            return None

        if self.processor is None:
            logger.error("LayoutLMv2 processor is not loaded. Cannot prepare inputs.")
            return None

        try:
            encoded_inputs = self.processor(
                images=image,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=512,
            )
        except Exception as e:
            logger.error(f"An error occurred during input preparation for {image_path}: {e}", exc_info=True)
            return None

        # Move every tensor to the inference device.
        try:
            for key, value in encoded_inputs.items():
                if isinstance(value, torch.Tensor):
                    encoded_inputs[key] = value.to(self.device)
        except Exception as e:
            logger.error(f"An error occurred while moving inputs to device for {image_path}: {e}", exc_info=True)
            return None

        logger.info(f"Inputs prepared successfully for {image_path}.")
        return encoded_inputs

    def predict(self, image_path: Path) -> Optional[str]:
        """Predict the document class label for an image.

        Args:
            image_path: Path to the image file to classify.

        Returns:
            The predicted class label as a string, or None if any step fails
            (input preparation, inference) or the predicted index is not in
            the ``id2label`` mapping.
        """
        logger.info(f"Starting prediction process for image: {image_path}.")

        encoded_inputs: Optional[Dict[str, torch.Tensor]] = self._prepare_inputs(image_path)
        if encoded_inputs is None:
            logger.error(f"Input preparation failed for {image_path}. Cannot perform prediction.")
            return None

        if self.model is None:
            logger.error("LayoutLMv2 model is not loaded. Cannot perform prediction.")
            return None

        try:
            self.model.eval()  # disable dropout etc. for deterministic inference
            with torch.no_grad():
                outputs: Any = self.model(**encoded_inputs)
                logits = outputs.logits

            if not isinstance(logits, torch.Tensor):
                logger.error(f"Model output 'logits' is not a torch.Tensor for {image_path}.")
                return None

            predicted_class_idx: int = logits.argmax(-1).item()
            logger.info(f"Model inference completed for {image_path}. Predicted index: {predicted_class_idx}.")

            if predicted_class_idx not in self.id2label:
                logger.error(f"Predicted index {predicted_class_idx} not found in id2label mapping for {image_path}.")
                return None
            predicted_label = self.id2label[predicted_class_idx]
        except Exception as e:
            logger.error(f"An error occurred during model inference or label mapping for {image_path}: {e}", exc_info=True)
            return None

        logger.info(f"Prediction completed for {image_path}. Predicted label: {predicted_label}.")
        return predicted_label
|
| 322 |
+
|
| 323 |
+
|
app/src/logger.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from logging.handlers import RotatingFileHandler
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
# Per-run log directory: logs/<timestamp> under the current working directory.
LOG_FILE=f"{datetime.now().strftime('%m_%d_%Y_%H_%M_%S')}"
logs_path=os.path.join(os.getcwd(),"logs",LOG_FILE)
os.makedirs(logs_path,exist_ok=True)


def setup_logger(file_name: str = None, api_app=None):
    """Configure and return a logger writing into the per-run log directory.

    Every logger gets a rotating "global.log" handler; when *file_name* is
    given, an additional rotating "<file_name>.log" handler is attached and
    the logger is registered under that name (otherwise the root logger is
    used). When *api_app* is not None, the same handlers are attached to the
    "uvicorn.access" logger instead, so HTTP request logs land in the same
    files.

    Args:
        file_name: Optional logger name and per-module log file stem.
        api_app: When not None, configure and return the uvicorn access logger.

    Returns:
        logging.Logger: the configured logger.
    """
    # Single formatter shared by all handlers (the original duplicated this
    # entire configuration block verbatim in both branches).
    log_formatter = logging.Formatter(
        "%(asctime)s - %(filename)s - %(funcName)s - Line %(lineno)d - %(levelname)s - %(message)s"
    )

    def _file_handler(path):
        # Rotating file handler: 5 MB per file with 3 backups.
        handler = RotatingFileHandler(filename=path, maxBytes=5 * 1024 * 1024, backupCount=3)
        handler.setFormatter(log_formatter)
        handler.setLevel(logging.INFO)
        return handler

    handlers = []
    if file_name is not None:
        handlers.append(_file_handler(os.path.join(logs_path, f"{file_name}.log")))
    handlers.append(_file_handler(os.path.join(logs_path, "global.log")))

    root_logger = logging.getLogger(file_name)
    root_logger.setLevel(logging.DEBUG)
    # Fix: guard against attaching duplicate handlers when setup_logger is
    # called more than once for the same logger name (this previously caused
    # every log line to be written repeatedly).
    if not root_logger.handlers:
        for handler in handlers:
            root_logger.addHandler(handler)

    if api_app is not None:
        # Route uvicorn's request logs into the same files.
        uvicorn_access_logger = logging.getLogger("uvicorn.access")
        uvicorn_access_logger.setLevel(logging.INFO)
        if not uvicorn_access_logger.handlers:
            for handler in handlers:
                uvicorn_access_logger.addHandler(handler)
        return uvicorn_access_logger
    return root_logger
|
app/src/model_loader.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from app.src.vgg16_load import VGGDocumentClassifier
|
| 3 |
+
from app.src.vit_load import VITDocumentClassifier
|
| 4 |
+
from app.src.constant import *
|
| 5 |
+
from app.src.logger import setup_logger
|
| 6 |
+
|
| 7 |
+
logger = setup_logger("model_loader")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def vit_loader() -> VITDocumentClassifier:
    """Build and return the ViT document classifier.

    Loads the classifier from the artifact paths declared in
    ``app.src.constant`` (``vit_model_path`` / ``vit_mlb_path``).

    Returns:
        A ready-to-use ``VITDocumentClassifier`` instance.

    Raises:
        Exception: Any error raised while loading the model artifacts is
            logged (with traceback) and re-raised unchanged.
    """
    try:
        vit = VITDocumentClassifier(vit_model_path, vit_mlb_path)
        return vit
    except Exception as e:
        # Log with the full traceback, then re-raise with bare `raise`
        # so the original traceback is preserved for the caller.
        logger.error(str(e), exc_info=True)
        raise
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def vgg_loader() -> VGGDocumentClassifier:
    """Build and return the VGG16 document classifier.

    Loads the classifier from the artifact paths declared in
    ``app.src.constant`` (``vgg_model_path`` / ``vgg_mlb_path``).

    Returns:
        A ready-to-use ``VGGDocumentClassifier`` instance.

    Raises:
        Exception: Any error raised while loading the model artifacts is
            logged (with traceback) and re-raised unchanged.
    """
    try:
        vgg = VGGDocumentClassifier(vgg_model_path, vgg_mlb_path)
        return vgg
    except Exception as e:
        # Log with the full traceback, then re-raise with bare `raise`
        # so the original traceback is preserved for the caller.
        logger.error(str(e), exc_info=True)
        raise
|
app/src/test_vit.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import joblib
|
| 2 |
+
from sklearn.preprocessing import MultiLabelBinarizer
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
| 8 |
+
from app.src.logger import setup_logger
|
| 9 |
+
|
| 10 |
+
logger = setup_logger("test_vit")
|
| 11 |
+
|
| 12 |
+
try:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Use forward slashes: pathlib normalizes them on every OS. The previous
    # backslash literals ("artifacts\model\VIT_model\mlb.joblib") are path
    # separators only on Windows — on a Linux runtime they are a single file
    # name containing backslashes, so loading fails at startup.
    mlb_file_path = Path("artifacts/model/VIT_model/mlb.joblib")
    model_file_path = Path("artifacts/model/VIT_model/model.pth")
    # Select model
    model_id = "google/vit-base-patch16-224-in21k"
    # Load processor for the base ViT checkpoint (resize/normalize config).
    processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)

    # Load the entire fine-tuned model object (not just a state_dict).
    # weights_only=False is required because the checkpoint pickles the full
    # nn.Module — only safe for trusted artifact files.
    model = torch.load(model_file_path, map_location=device, weights_only=False)
    # Move the model to the selected device before inference.
    model.to(device)

except Exception as e:
    logger.error(str(e))
    raise e
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def mlb_load(file_path: Path) -> MultiLabelBinarizer:
    """Load the fitted MultiLabelBinarizer from *file_path*.

    Falls back to a placeholder binarizer fitted on the known label set when
    the file is missing, so downstream prediction code can still run.

    Args:
        file_path: Path to the joblib-serialized MultiLabelBinarizer.

    Returns:
        The loaded (or placeholder) MultiLabelBinarizer.
    """
    try:
        mlb = joblib.load(file_path)
    except FileNotFoundError:
        # Report the path that was actually requested, not a hard-coded one,
        # so the log points at the real missing file.
        logger.error(f"Error: '{file_path}' not found.")
        logger.error("Please make sure the path is correct. Using a placeholder binarizer.")
        # As a placeholder, create a dummy mlb if the file is not found.
        mlb = MultiLabelBinarizer()
        # This should be the set of your actual labels.
        mlb.fit([['advertisement', 'email', 'form', 'invoice', 'note']])
    return mlb
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def VIT_model_prediction(image_path: Path, cut_off: float):
    """Classify one document image with the module-level ViT model.

    Args:
        image_path: Path of the image to classify. If the file is missing,
            a solid-red dummy image is used so the pipeline can be exercised.
        cut_off: Sigmoid probability threshold above which a label counts as
            predicted (multi-label classification).

    Returns:
        dict: ``{"status": 1, "classe": predicted_labels}`` where
        ``predicted_labels`` is the tuple list returned by
        ``MultiLabelBinarizer.inverse_transform``.

    Raises:
        Exception: Any unexpected error is logged and re-raised.
    """
    try:
        # Load and convert image to RGB (ViT processor expects 3 channels).
        try:
            image = Image.open(image_path)
            if image.mode != "RGB":
                image = image.convert("RGB")
        except FileNotFoundError:
            logger.error(f"Error: Image not found at {image_path}")
            logger.error("Using a dummy image for demonstration.")
            # Create a dummy image for demonstration if image not found.
            image = Image.new('RGB', (224, 224), color='red')

        # Preprocess image into a (1, 3, H, W) tensor on the model's device.
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

        # Forward pass — no gradients needed at inference time.
        with torch.no_grad():
            outputs = model(pixel_values)
            logits = outputs.logits

        # Sigmoid (not softmax): labels are independent in a multi-label setup.
        probs = torch.sigmoid(logits.squeeze().cpu())

        # Build the binary indicator vector directly from the boolean
        # comparison, as an integer array — the indicator format that
        # MultiLabelBinarizer.inverse_transform expects.
        predictions = (probs.numpy() >= cut_off).astype(int)

        # Get label names using the loaded MultiLabelBinarizer.
        mlb = mlb_load(mlb_file_path)
        # inverse_transform needs a 2D (1, num_classes) array.
        predicted_labels = mlb.inverse_transform(predictions.reshape(1, -1))
        logger.info(f"Predicted labels: {predicted_labels}")
        return {"status": 1, "classe": predicted_labels}

    except Exception as e:
        logger.error(str(e))
        raise e


#VIT_model_prediction(Path(r"dataset\sample_text_ds\test\email\2078379610a.jpg"),0.5)
|
app/src/vgg16_load.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import joblib
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from sklearn.preprocessing import MultiLabelBinarizer
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
import logging
|
| 9 |
+
import cv2
|
| 10 |
+
import keras
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import tensorflow as tf
|
| 13 |
+
from typing import Optional, Tuple, List
|
| 14 |
+
from app.src.logger import setup_logger
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Configure logging
|
| 18 |
+
logger = setup_logger("vgg16_load")
|
| 19 |
+
|
| 20 |
+
def load_vgg_artifacts(model_path: Path, mlb_path: Path) -> tuple[tf.keras.Model, MultiLabelBinarizer]:
    """
    Loads the VGG model and the MultiLabelBinarizer from specified paths.

    Args:
        model_path: Path to the VGG model file (.keras).
        mlb_path: Path to the MultiLabelBinarizer file (.joblib).

    Returns:
        A tuple containing the loaded Keras model and MultiLabelBinarizer object.

    Raises:
        FileNotFoundError: If either the model file or the MLB file is not found.
        Exception: If any other error occurs during loading.
    """
    # No None pre-initialization needed: every failure path below re-raises,
    # so `model` / `mlb` are only ever read after a successful assignment.
    try:
        logger.info(f"Attempting to load VGG model from {model_path}")
        model = tf.keras.models.load_model(model_path)
        logger.info("VGG model loaded successfully.")
    except FileNotFoundError:
        logger.error(f"Error: VGG model file not found at {model_path}")
        raise
    except Exception as e:
        logger.error(f"An error occurred while loading the VGG model: {e}")
        raise

    try:
        logger.info(f"Attempting to load MultiLabelBinarizer from {mlb_path}")
        mlb = joblib.load(mlb_path)
        logger.info("MultiLabelBinarizer loaded successfully.")
    except FileNotFoundError:
        logger.error(f"Error: MultiLabelBinarizer file not found at {mlb_path}")
        raise
    except Exception as e:
        logger.error(f"An error occurred while loading the MultiLabelBinarizer: {e}")
        raise

    logger.info("Both VGG model and MultiLabelBinarizer loaded successfully.")
    return model, mlb
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def preprocess_image(image_path: Path, target_size: tuple[int, int] = (224, 224)) -> np.ndarray | None:
    """
    Preprocesses an image for VGG model prediction.

    Loads an image from the specified path, converts it to RGB, resizes it,
    and normalizes pixel values. Includes robust error handling and logging
    at each step.

    Args:
        image_path: Path to the image file.
        target_size: A tuple (width, height) specifying the desired output size.

    Returns:
        A preprocessed NumPy array representing the image with pixel values
        scaled between 0 and 1, or None if an error occurred during processing.
    """
    try:
        logger.info(f"Attempting to load image from {image_path}")
        img = cv2.imread(str(image_path))  # cv2.imread expects a string path

        if img is None:
            logger.error(f"Error: Could not load image from {image_path}. cv2.imread returned None.")
            return None
        logger.info("Image loaded successfully.")

        logger.info("Attempting to convert image to RGB.")
        # cv2.imread loads color images in BGR (or BGRA for images with an
        # alpha channel). The model expects 3-channel RGB, so convert
        # explicitly based on the channel layout that was actually loaded.
        if len(img.shape) == 3 and img.shape[2] == 3:  # color image (BGR)
            try:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                logger.info("Image converted to RGB successfully.")
            except cv2.error as e:
                logger.error(f"Error during BGR to RGB conversion for image {image_path}: {e}")
                return None
        elif len(img.shape) == 3 and img.shape[2] == 4:  # BGRA (alpha channel)
            # Previously this fell through with 4 channels, which breaks any
            # 3-channel model downstream; drop alpha while converting.
            try:
                img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
                logger.info("BGRA image converted to RGB successfully.")
            except cv2.error as e:
                logger.error(f"Error during BGRA to RGB conversion for image {image_path}: {e}")
                return None
        elif len(img.shape) == 2:  # grayscale image
            try:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
                logger.info("Grayscale image converted to RGB successfully.")
            except cv2.error as e:
                logger.error(f"Error during Grayscale to RGB conversion for image {image_path}: {e}")
                return None
        else:
            logger.warning(f"Unexpected image format for {image_path}. Attempting to proceed.")
            # If it's not a recognized layout, proceed but log a warning.

        logger.info(f"Attempting to resize image to {target_size}.")
        try:
            img = cv2.resize(img, target_size)
            if img is None or img.size == 0:
                logger.error(f"Error: cv2.resize returned None or empty array for image {image_path}.")
                return None
            logger.info("Image resized successfully.")
        except cv2.error as e:
            logger.error(f"Error during image resizing for image {image_path} to size {target_size}: {e}")
            return None

        logger.info("Attempting to normalize pixel values.")
        try:
            # Ensure the image is float before division; scale to [0, 1].
            img = img.astype("float32") / 255.0
            if img is None or img.size == 0 or np.max(img) > 1.0 or np.min(img) < 0.0:
                logger.error(f"Error: Image normalization failed or resulted in unexpected values for image {image_path}.")
                return None
            logger.info("Pixel values normalized successfully.")
        except Exception as e:
            logger.error(f"Error during pixel normalization for image {image_path}: {e}")
            return None

        logger.info(f"Image preprocessing completed successfully for {image_path}.")
        return img

    except Exception as e:
        logger.error(f"An unexpected error occurred during image preprocessing for {image_path}: {e}")
        return None
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class VGGDocumentClassifier:
    """
    A class for classifying documents using a VGG16 model.

    This class encapsulates the loading of the VGG16 model and its associated
    MultiLabelBinarizer, provides a method to preprocess input images, and
    performs document classification predictions.
    """

    def __init__(self, model_path: Path, mlb_path: Path, target_size: Tuple[int, int] = (224, 224)) -> None:
        """
        Initializes the VGGDocumentClassifier by loading the model and MLB.

        Args:
            model_path: Path to the VGG model file (.keras).
            mlb_path: Path to the MultiLabelBinarizer file (.joblib).
            target_size: The target size (width, height) for image preprocessing.
                Defaults to (224, 224).

        Raises:
            FileNotFoundError: If either the model file or the MLB file is not found.
            RuntimeError: If artifact loading left the classifier incomplete.
            Exception: If any other error occurs during loading.
        """
        logger.info("Initializing VGGDocumentClassifier.")
        self.model: Optional[tf.keras.Model] = None
        self.mlb: Optional[MultiLabelBinarizer] = None
        self.target_size: Tuple[int, int] = target_size

        try:
            self._load_artifacts(model_path, mlb_path)
            if self.model and self.mlb:
                logger.info("VGGDocumentClassifier initialized successfully.")
            else:
                logger.critical("VGGDocumentClassifier failed to fully initialize due to artifact loading errors.")
                raise RuntimeError("Failed to load all required artifacts for VGGDocumentClassifier.")
        except Exception as e:
            logger.critical(f"Failed to initialize VGGDocumentClassifier: {e}", exc_info=True)
            raise  # Re-raise the exception after logging

    def _load_artifacts(self, model_path: Path, mlb_path: Path) -> None:
        """
        Loads the VGG model and MultiLabelBinarizer with error handling and logging.

        Every failure path raises, so there is no partial-success state to
        track (the previous `model_loaded` / `mlb_loaded` flags were dead
        code: their failure branch was unreachable).

        Args:
            model_path: Path to the VGG model file (.keras).
            mlb_path: Path to the MultiLabelBinarizer file (.joblib).

        Raises:
            FileNotFoundError: If either the model file or the MLB file is not found.
            Exception: If any other unexpected error occurs during loading.
        """
        logger.info("Starting artifact loading.")

        # Load Model
        try:
            logger.info(f"Attempting to load VGG model from {model_path}")
            self.model = tf.keras.models.load_model(model_path)
            logger.info("VGG model loaded successfully.")
        except FileNotFoundError:
            logger.critical(f"Critical Error: VGG model file not found at {model_path}", exc_info=True)
            raise  # Critical initialization failure
        except Exception as e:
            logger.critical(f"Critical Error: An unexpected error occurred while loading the VGG model from {model_path}: {e}", exc_info=True)
            raise  # Critical initialization failure

        # Load MLB
        try:
            logger.info(f"Attempting to load MultiLabelBinarizer from {mlb_path}")
            self.mlb = joblib.load(mlb_path)
            logger.info("MultiLabelBinarizer loaded successfully.")
        except FileNotFoundError:
            logger.critical(f"Critical Error: MultiLabelBinarizer file not found at {mlb_path}", exc_info=True)
            raise  # Critical initialization failure
        except Exception as e:
            logger.critical(f"Critical Error: An unexpected error occurred while loading the MultiLabelBinarizer from {mlb_path}: {e}", exc_info=True)
            raise  # Critical initialization failure

        logger.info("All required VGG artifacts loaded successfully.")

    def preprocess_image(self, image_path: Path) -> Optional[np.ndarray]:
        """
        Preprocesses an image for VGG model prediction.

        Loads an image from the specified path, converts it to RGB, resizes it
        to ``self.target_size``, and normalizes pixel values to [0, 1].

        Args:
            image_path: Path to the image file.

        Returns:
            A preprocessed NumPy array representing the image with pixel values
            scaled between 0 and 1, or None if an error occurred during processing.
        """
        try:
            logger.info(f"Attempting to load image from {image_path}")
            img = cv2.imread(str(image_path))  # cv2.imread expects a string path

            if img is None:
                logger.error(f"Error: Could not load image from {image_path}. cv2.imread returned None.")
                return None
            logger.info("Image loaded successfully.")

            logger.info("Attempting to convert image to RGB.")
            # cv2 loads BGR (or BGRA with alpha); convert to the 3-channel
            # RGB layout the model was trained on.
            if len(img.shape) == 3 and img.shape[2] == 3:  # color image (BGR)
                try:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    logger.info("Image converted to RGB successfully.")
                except cv2.error as e:
                    logger.error(f"Error during BGR to RGB conversion for image {image_path}: {e}")
                    return None
            elif len(img.shape) == 3 and img.shape[2] == 4:  # BGRA (alpha)
                # Drop the alpha channel; passing 4 channels through would
                # break the 3-channel model input downstream.
                try:
                    img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
                    logger.info("BGRA image converted to RGB successfully.")
                except cv2.error as e:
                    logger.error(f"Error during BGRA to RGB conversion for image {image_path}: {e}")
                    return None
            elif len(img.shape) == 2:  # grayscale image
                try:
                    img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
                    logger.info("Grayscale image converted to RGB successfully.")
                except cv2.error as e:
                    logger.error(f"Error during Grayscale to RGB conversion for image {image_path}: {e}")
                    return None
            else:
                logger.warning(f"Unexpected image format for {image_path}. Attempting to proceed.")

            logger.info(f"Attempting to resize image to {self.target_size}.")
            try:
                img = cv2.resize(img, self.target_size)
                if img is None or img.size == 0:
                    logger.error(f"Error: cv2.resize returned None or empty array for image {image_path}.")
                    return None
                logger.info("Image resized successfully.")
            except cv2.error as e:
                logger.error(f"Error during image resizing for image {image_path} to size {self.target_size}: {e}")
                return None

            logger.info("Attempting to normalize pixel values.")
            try:
                img = img.astype("float32") / 255.0
                if img is None or img.size == 0 or np.max(img) > 1.0 or np.min(img) < 0.0:
                    logger.error(f"Error: Image normalization failed or resulted in unexpected values for image {image_path}.")
                    return None
                logger.info("Pixel values normalized successfully.")
            except Exception as e:
                logger.error(f"Error during pixel normalization for image {image_path}: {e}")
                return None

            logger.info(f"Image preprocessing completed successfully for {image_path}.")
            return img

        except Exception as e:
            logger.error(f"An unexpected error occurred during image preprocessing for {image_path}: {e}")
            return None

    def predict(self, image_path: Path) -> Optional[List[str]]:
        """
        Predicts the class labels for a given image using the loaded VGG model.

        The process involves loading and preprocessing the image, performing
        inference with the model, and converting the prediction to class labels
        using the MultiLabelBinarizer.

        Args:
            image_path: Path to the image file to classify.

        Returns:
            A list of predicted class labels (strings) if the prediction process
            is successful. Returns None if any critical step (image loading,
            preprocessing, model inference, or inverse transform) fails.
            Returns an empty list if the prediction process is successful but
            no labels are predicted.
        """
        logger.info(f"Starting prediction process for image: {image_path}.")

        if self.model is None or self.mlb is None:
            logger.error("Model or MultiLabelBinarizer not loaded. Cannot perform prediction.")
            return None

        # Preprocess image
        image = self.preprocess_image(image_path)
        if image is None:
            logger.error(f"Image preprocessing failed for {image_path}. Cannot perform prediction.")
            return None

        try:
            logger.info(f"Performing model inference for {image_path}.")
            # Add batch dimension to the image: (H, W, 3) -> (1, H, W, 3)
            image = np.expand_dims(image, axis=0)
            prd = self.model.predict(image)
            logger.info(f"Model inference completed for {image_path}. Prediction shape: {prd.shape}")
        except Exception as e:
            logger.error(f"An error occurred during model inference for {image_path}: {e}", exc_info=True)
            return None

        # Convert the prediction to a binary indicator format and get labels
        try:
            logger.info(f"Converting prediction to labels for {image_path}.")
            # Assuming multi-class classification for now, taking the argmax.
            # For true multi-label output, apply sigmoid + thresholding here.
            pred_id = np.argmax(prd, axis=1)

            # One-hot indicator of shape (1, num_classes) for inverse_transform.
            binary_prediction = np.zeros((1, len(self.mlb.classes_)))
            binary_prediction[0, pred_id] = 1

            predicted_labels_tuple_list: List[Tuple[str, ...]] = self.mlb.inverse_transform(binary_prediction)
            logger.info(f"Prediction processed for {image_path}. Predicted labels (raw tuple list): {predicted_labels_tuple_list}")

            if predicted_labels_tuple_list and len(predicted_labels_tuple_list) > 0:
                final_labels: List[str] = list(predicted_labels_tuple_list[0])
                logger.info(f"Final predicted labels for {image_path}: {final_labels}")
                return final_labels
            else:
                logger.warning(f"MLB inverse_transform returned an empty list for {image_path}. No labels predicted.")
                return []

        except Exception as e:
            logger.error(f"An error occurred during inverse transform or label processing for {image_path}: {e}", exc_info=True)
            return None
|
| 380 |
+
|
| 381 |
+
|
app/src/vit_load.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
| 5 |
+
from sklearn.preprocessing import MultiLabelBinarizer
|
| 6 |
+
import joblib
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import List, Optional, Tuple, Any
|
| 9 |
+
from app.src.logger import setup_logger
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
logger = setup_logger("vit_load")
|
| 14 |
+
|
| 15 |
+
class VITDocumentClassifier:
|
| 16 |
+
"""
|
| 17 |
+
A class for classifying documents using a Vision Transformer (ViT) model.
|
| 18 |
+
|
| 19 |
+
This class encapsulates the loading of the ViT model, its associated processor,
|
| 20 |
+
and a MultiLabelBinarizer for converting model outputs to meaningful labels.
|
| 21 |
+
It provides a method to preprocess input images and perform multi-label
|
| 22 |
+
classification predictions with a specified confidence cutoff threshold.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, model_path: Path, mlb_path: Path, model_id: str = "google/vit-base-patch16-224-in21k") -> None:
|
| 26 |
+
"""
|
| 27 |
+
Initializes the VITDocumentClassifier by loading the model, processor, and MLB.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
model_path: Path to the ViT model file (.pth). This is expected to be
|
| 31 |
+
a pre-trained or fine-tuned PyTorch model file.
|
| 32 |
+
mlb_path: Path to the MultiLabelBinarizer file (.joblib). This file
|
| 33 |
+
should contain the fitted binarizer object corresponding
|
| 34 |
+
to the model's output classes.
|
| 35 |
+
model_id: The Hugging Face model ID for the processor. This is used
|
| 36 |
+
to load the appropriate image processor for the ViT model.
|
| 37 |
+
Defaults to "google/vit-base-patch16-224-in21k".
|
| 38 |
+
|
| 39 |
+
Raises:
|
| 40 |
+
FileNotFoundError: If either the model file or the MLB file is not found
|
| 41 |
+
at the specified paths during artifact loading.
|
| 42 |
+
Exception: If any other unexpected error occurs during the loading
|
| 43 |
+
of the model, processor, or MultiLabelBinarizer.
|
| 44 |
+
RuntimeError: If artifact loading fails for critical components
|
| 45 |
+
(model or MLB).
|
| 46 |
+
"""
|
| 47 |
+
logger.info("Initializing VITDocumentClassifier.")
|
| 48 |
+
self.model: Optional[torch.nn.Module] = None
|
| 49 |
+
self.processor: Optional[AutoImageProcessor] = None
|
| 50 |
+
self.mlb: Optional[MultiLabelBinarizer] = None
|
| 51 |
+
self.device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 52 |
+
logger.info(f"Using device: {self.device}")
|
| 53 |
+
self.model_id: str = model_id
|
| 54 |
+
|
| 55 |
+
try:
|
| 56 |
+
self._load_artifacts(model_path, mlb_path)
|
| 57 |
+
if self.model and self.processor and self.mlb:
|
| 58 |
+
logger.info("VITDocumentClassifier initialized successfully.")
|
| 59 |
+
else:
|
| 60 |
+
# This case should ideally be caught and re-raised in _load_artifacts
|
| 61 |
+
# but adding a check here for robustness.
|
| 62 |
+
logger.critical("VITDocumentClassifier failed to fully initialize due to artifact loading errors.")
|
| 63 |
+
raise RuntimeError("Failed to load all required artifacts for VITDocumentClassifier.")
|
| 64 |
+
|
| 65 |
+
except Exception as e:
|
| 66 |
+
logger.critical(f"Failed to initialize VITDocumentClassifier: {e}", exc_info=True)
|
| 67 |
+
# Re-raise the exception after logging
|
| 68 |
+
raise
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _load_artifacts(self, model_path: Path, mlb_path: Path) -> None:
    """
    Load the ViT image processor, the model weights, and the fitted
    MultiLabelBinarizer, populating ``self.processor``, ``self.model``
    and ``self.mlb``.

    Internal helper invoked from ``__init__``.

    Args:
        model_path: Path to the serialized ViT model file (.pth).
        mlb_path: Path to the fitted MultiLabelBinarizer file (.joblib).

    Raises:
        FileNotFoundError: If the model or MLB file is missing.
        Exception: For any other unexpected failure while loading the
            model or the MLB. A processor failure is only logged here;
            ``__init__`` detects it (``self.processor`` stays None) and
            raises RuntimeError itself.
    """
    logger.info("Starting artifact loading.")

    ok_processor: bool = False
    ok_model: bool = False
    ok_mlb: bool = False

    # --- Processor ---------------------------------------------------
    # Non-fatal on failure: we keep going so the remaining artifacts are
    # attempted, and let the caller decide whether initialization failed.
    try:
        logger.info(f"Attempting to load ViT processor for model ID: {self.model_id}")
        self.processor = AutoImageProcessor.from_pretrained(self.model_id, use_fast=True)
        logger.info("ViT processor loaded successfully.")
        ok_processor = True
    except Exception as e:
        logger.error(f"An error occurred while loading the ViT processor for model ID {self.model_id}: {e}", exc_info=True)

    # --- Model ---------------------------------------------------------
    try:
        logger.info(f"Attempting to load ViT model from {model_path}")
        # NOTE(review): weights_only=False unpickles arbitrary Python
        # objects — only load model files from trusted sources.
        # map_location keeps loading working when the save-time device
        # differs from the current one.
        self.model = torch.load(model_path, map_location=self.device, weights_only=False)
        self.model.to(self.device)  # make sure the model sits on the target device
        logger.info(f"ViT model loaded successfully and moved to {self.device}.")
        ok_model = True
    except FileNotFoundError:
        # Fatal: without the model nothing else can work.
        logger.critical(f"Critical Error: ViT model file not found at {model_path}", exc_info=True)
        raise
    except Exception as e:
        logger.critical(f"Critical Error: An unexpected error occurred while loading the ViT model from {model_path}: {e}", exc_info=True)
        raise

    # --- MultiLabelBinarizer --------------------------------------------
    try:
        logger.info(f"Attempting to load MultiLabelBinarizer from {mlb_path}")
        self.mlb = joblib.load(mlb_path)
        logger.info("MultiLabelBinarizer loaded successfully.")
        ok_mlb = True
    except FileNotFoundError:
        # Fatal: predictions cannot be mapped back to label names without it.
        logger.critical(f"Critical Error: MultiLabelBinarizer file not found at {mlb_path}", exc_info=True)
        raise
    except Exception as e:
        logger.critical(f"Critical Error: An unexpected error occurred while loading the MultiLabelBinarizer from {mlb_path}: {e}", exc_info=True)
        raise

    # Summarize the outcome for the log; __init__ performs the hard check.
    if ok_processor and ok_model and ok_mlb:
        logger.info("All required ViT artifacts loaded successfully.")
    else:
        logger.error("One or more required ViT artifacts failed to load during _load_artifacts.")
def predict(self, image_path: Path, cut_off: float = 0.5) -> Optional[List[str]]:
    """
    Predict multi-label class names for a single image with the loaded ViT model.

    Pipeline: load the image with PIL, convert to RGB if needed, preprocess
    with the Hugging Face image processor, run a no-grad forward pass, apply
    a sigmoid to the logits, threshold the probabilities at ``cut_off``, and
    map the resulting binary vector back to label names with the
    MultiLabelBinarizer.

    Args:
        image_path: Path to the image file to classify. Must be readable
            by PIL (Pillow).
        cut_off: Probability threshold; classes whose sigmoid probability
            is >= cut_off are predicted. Defaults to 0.5.

    Returns:
        A list of predicted label strings on success; an empty list when
        no class reaches the threshold; None when any step (artifact
        availability, image loading, preprocessing, inference, or inverse
        transform) fails.
    """
    logger.info(f"Starting prediction process for image: {image_path} with cutoff {cut_off}.")

    # All three artifacts must have been loaded during __init__.
    if self.model is None or self.processor is None or self.mlb is None:
        logger.error("Model, processor, or MultiLabelBinarizer not loaded. Cannot perform prediction.")
        return None

    # --- Load image --------------------------------------------------
    image: Optional[Image.Image] = None
    try:
        logger.info(f"Attempting to load image from {image_path}")
        image = Image.open(image_path)
        logger.info(f"Image loaded successfully from {image_path}.")
    except FileNotFoundError:
        logger.error(f"Error: Image file not found at {image_path}", exc_info=True)
        return None
    except Exception as e:
        logger.error(f"An unexpected error occurred while loading image {image_path}: {e}", exc_info=True)
        return None

    # --- Ensure 3-channel RGB (ViT processor expects RGB input) ------
    try:
        logger.info(f"Attempting to convert image to RGB for {image_path}.")
        if image.mode != "RGB":
            image = image.convert("RGB")
            logger.info(f"Image converted to RGB successfully for {image_path}.")
        else:
            logger.info(f"Image is already in RGB format for {image_path}.")
    except Exception as e:
        logger.error(f"An error occurred while converting image {image_path} to RGB: {e}", exc_info=True)
        return None

    # --- Preprocess with the HF processor -----------------------------
    try:
        logger.info(f"Attempting to preprocess image using processor for {image_path}.")
        if image is None:
            logger.error(f"Image is None after loading/conversion for {image_path}. Cannot preprocess.")
            return None
        # The processor accepts a PIL Image; move the tensor to the model device.
        pixel_values: torch.Tensor = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)
        logger.info(f"Image preprocessed and moved to device ({self.device}).")
    except Exception as e:
        logger.error(f"An error occurred during image preprocessing for {image_path}: {e}", exc_info=True)
        return None

    # --- Forward pass --------------------------------------------------
    try:
        logger.info(f"Starting model forward pass for {image_path}.")
        self.model.eval()  # disable dropout/batchnorm training behavior
        with torch.no_grad():
            outputs: Any = self.model(pixel_values)  # output type varies by model class
            logits: torch.Tensor = outputs.logits
        logger.info(f"Model forward pass completed for {image_path}.")
    except Exception as e:
        logger.error(f"An error occurred during model forward pass for {image_path}: {e}", exc_info=True)
        return None

    # --- Sigmoid + threshold -> binary multi-label vector ---------------
    try:
        logger.info(f"Applying sigmoid and thresholding for {image_path}.")
        sigmoid: torch.nn.Sigmoid = torch.nn.Sigmoid()
        probs: torch.Tensor = sigmoid(logits.squeeze().cpu())

        predictions: np.ndarray = np.zeros(probs.shape, dtype=int)
        # FIX: removed leftover debug print(predictions) that polluted stdout.
        predictions[np.where(probs >= cut_off)] = 1
        logger.info(f"Applied sigmoid and thresholding with cutoff {cut_off} for {image_path}. Binary predictions shape: {predictions.shape}")
    except Exception as e:
        logger.error(f"An error occurred during probability processing for {image_path}: {e}", exc_info=True)
        return None

    # --- Map binary vector back to label names --------------------------
    try:
        logger.info(f"Performing inverse transform using MultiLabelBinarizer for {image_path}.")
        # Defensive re-check (already verified at the top of predict).
        if self.mlb is None:
            logger.error(f"MultiLabelBinarizer is None. Cannot perform inverse transform for {image_path}.")
            return None

        binary_prediction: np.ndarray

        # inverse_transform requires a 2D array (n_samples, n_classes);
        # we classify one image at a time, so the expected shape is (1, n_classes).
        expected_shape: Tuple[int, int] = (1, len(self.mlb.classes_))

        if predictions.ndim == 1 and predictions.shape[0] == len(self.mlb.classes_):
            binary_prediction = predictions.reshape(expected_shape)
            logger.info(f"Reshaped 1D prediction to 2D ({expected_shape}) for inverse transform.")
        elif predictions.ndim == 2 and predictions.shape == expected_shape:
            binary_prediction = predictions
            logger.info(f"Prediction already in correct 2D shape ({expected_shape}) for inverse transform.")
        else:
            logger.error(f"Cannot inverse transform prediction shape {predictions.shape} with MLB classes {len(self.mlb.classes_)} for {image_path}. Expected shape: {expected_shape}")
            return None

        predicted_labels_tuple_list: List[Tuple[str, ...]] = self.mlb.inverse_transform(binary_prediction)
        logger.info(f"Prediction processed for {image_path}. Predicted labels (raw tuple list): {predicted_labels_tuple_list}")

        # inverse_transform returns a list of tuples, one per sample;
        # we take the first (and only) tuple.
        if predicted_labels_tuple_list:
            final_labels: List[str] = list(predicted_labels_tuple_list[0])
            logger.info(f"Final predicted labels for {image_path}: {final_labels}")
            return final_labels
        else:
            logger.warning(f"MLB inverse_transform returned an empty list for {image_path}. No labels predicted.")
            return []

    except Exception as e:
        logger.error(f"An error occurred during inverse transform for {image_path}: {e}", exc_info=True)
        return None
artifacts/model/VIT_model/confusion_matrix.png
ADDED
|
Git LFS Details
|
artifacts/model/VIT_model/mlb.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4754cb9555905cbeb8a008ac90b2bb81ab076fbc272510a17c40abea32aa5d16
|
| 3 |
+
size 571
|
artifacts/model/VIT_model/model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:223b9f3ccbe55b37f66ed7dd4c832116c17bec3229693a679da41351e9361a82
|
| 3 |
+
size 343310666
|
artifacts/model/vgg_model/mlb.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4754cb9555905cbeb8a008ac90b2bb81ab076fbc272510a17c40abea32aa5d16
|
| 3 |
+
size 571
|
artifacts/model/vgg_model/model.keras
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ad1f9fbf700dfac83efd97f5cc4f944ea5a628de9c0ba26d440abdd4b4426ef2
|
| 3 |
+
size 183090331
|
requirements.txt
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers==4.53.0
|
| 2 |
+
efficientnet==1.1.1
|
| 3 |
+
seaborn==0.13.2
|
| 4 |
+
libfinder==0.1.7
|
| 5 |
+
pathlib==1.0.1
|
| 6 |
+
requests==2.32.3
|
| 7 |
+
tensorflow==2.18.0
|
| 8 |
+
dagshub==0.5.10
|
| 9 |
+
google==2.0.3
|
| 10 |
+
torch==2.7.1
|
| 11 |
+
numpy==2.0.2
|
| 12 |
+
pandas==2.2.2
|
| 13 |
+
opencv-python
|
| 14 |
+
mlflow==3.1.1
|
| 15 |
+
keras==3.8.0
|
| 16 |
+
scikit-learn==1.6.1
|
| 17 |
+
ensure==1.0.4
|
| 18 |
+
joblib==1.5.1
|
| 19 |
+
matplotlib==3.10.0
|
| 20 |
+
ensure==1.0.4
|
| 21 |
+
python-box
|
| 22 |
+
pydot
|
| 23 |
+
graphviz
|
| 24 |
+
#'git+https://github.com/facebookresearch/detectron2.git'
|
| 25 |
+
gradio
|
| 26 |
+
fastapi==0.115.4
|
| 27 |
+
uvicorn==0.34.0
|
| 28 |
+
python-multipart==0.0.19
|
| 29 |
+
-e .
|