# Uploaded by KaushiGihan -- "Upload 17 files" (commit 07fc447, verified)
import joblib
from sklearn.preprocessing import MultiLabelBinarizer
from pathlib import Path
import torch
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
from app.src.logger import setup_logger
logger = setup_logger("test_vit")

# Module-level inference setup. Runs once at import time and publishes
# `device`, `mlb_file_path`, `model_file_path`, `model_id`, `processor` and
# `model` as module globals for the functions below. Any failure is logged
# and re-raised, so importing this module fails if the artifacts are missing.
try:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the artifact paths from components rather than from a
    # backslash-separated literal: "\m" and "\V" are invalid escape sequences,
    # and a backslash is not a path separator outside Windows.
    _artifact_dir = Path("artifacts") / "model" / "VIT_model"
    mlb_file_path = _artifact_dir / "mlb.joblib"
    model_file_path = _artifact_dir / "model.pth"

    # Checkpoint id used here only to obtain the image processor.
    model_id = "google/vit-base-patch16-224-in21k"
    processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)

    # Load the entire pickled model object (not just a state dict).
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # model.pth must come from a trusted source.
    model = torch.load(model_file_path, map_location=device, weights_only=False)
    model.to(device)
    # This module only runs inference: leave training mode so layers such as
    # dropout behave deterministically regardless of how the model was saved.
    model.eval()
except Exception as e:
    logger.error(str(e))
    raise
def mlb_load(file_path: Path) -> MultiLabelBinarizer:
    """Load the fitted ``MultiLabelBinarizer`` stored at *file_path*.

    If the file does not exist, the problem is logged and a placeholder
    binarizer fitted on a fixed five-label set is returned instead, so the
    caller always receives a usable object.

    Args:
        file_path: Location of the joblib dump of the binarizer.

    Returns:
        The binarizer loaded from disk, or the placeholder when the file is
        missing.
    """
    try:
        return joblib.load(file_path)
    except FileNotFoundError:
        # Report the path that was actually requested, not a hard-coded one.
        logger.error(f"Error: '{file_path}' not found.")
        logger.error("Please make sure the path is correct. Using a placeholder binarizer.")

    placeholder = MultiLabelBinarizer()
    # NOTE(review): these labels are a stand-in; they must match the label set
    # the model was trained on for decoded predictions to be meaningful.
    placeholder.fit([["advertisement", "email", "form", "invoice", "note"]])
    return placeholder
def VIT_model_prediction(image_path: Path, cut_off: float) -> dict:
    """Predict the labels of one image with the module-level ViT model.

    The image is converted to RGB, run through the module-level ``processor``
    and ``model``, and each logit is passed through a sigmoid. Every label
    whose probability is at least *cut_off* is reported.

    Args:
        image_path: Path of the image file to classify.
        cut_off: Probability threshold (inclusive) for a label to be kept.

    Returns:
        ``{"status": 1, "classe": labels}`` where ``labels`` is the result of
        ``MultiLabelBinarizer.inverse_transform`` for the single input row.
        The ``"classe"`` key spelling is kept as-is for existing callers.

    Raises:
        Exception: Any error other than a missing image file is logged and
            re-raised unchanged.
    """
    try:
        try:
            # The context manager closes the underlying file handle;
            # convert() returns an independent, fully loaded RGB image.
            with Image.open(image_path) as opened:
                image = opened.convert("RGB")
        except FileNotFoundError:
            logger.error(f"Error: Image not found at {image_path}")
            logger.error("Using a dummy image for demonstration.")
            # NOTE(review): a missing file still yields a prediction (for a
            # solid red square); callers cannot tell it apart from a real one.
            image = Image.new("RGB", (224, 224), color="red")

        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

        with torch.no_grad():
            logits = model(pixel_values).logits

        # Index the single batch row instead of squeeze(): squeeze() would
        # also collapse the label axis when the model has only one label.
        probs = torch.sigmoid(logits[0].cpu())

        # Multi-hot vector: 1 where the probability reaches the threshold.
        predictions = (probs >= cut_off).numpy().astype(int)

        # NOTE(review): the binarizer is re-read from disk on every call.
        mlb = mlb_load(mlb_file_path)
        # inverse_transform expects a 2D indicator array: (1, num_classes).
        predicted_labels = mlb.inverse_transform(predictions.reshape(1, -1))

        logger.info(f"Predicted labels: {predicted_labels}")
        return {"status": 1, "classe": predicted_labels}
    except Exception as e:
        logger.error(str(e))
        raise
#VIT_model_prediction(Path(r"dataset\sample_text_ds\test\email\2078379610a.jpg"),0.5)