File size: 3,674 Bytes
07fc447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import joblib
from sklearn.preprocessing import MultiLabelBinarizer
from pathlib import Path
import torch
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
from app.src.logger import setup_logger

logger = setup_logger("test_vit")

# Module-level inference state, created once at import time and used by
# VIT_model_prediction: the torch device, the artifact paths, the ViT image
# processor and the fine-tuned model. Relative paths are resolved against the
# current working directory, so run from the project root.
try:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Forward slashes work with pathlib on every OS. Backslash literals such as
    # "artifacts\model\..." contain invalid escape sequences ("\m", "\V") and
    # are not path separators on POSIX systems.
    mlb_file_path = Path("artifacts/model/VIT_model/mlb.joblib")
    model_file_path = Path("artifacts/model/VIT_model/model.pth")

    # Checkpoint used only for its image processor (resize/normalize settings);
    # presumably the base the fine-tuned model was trained from -- confirm.
    model_id = "google/vit-base-patch16-224-in21k"
    processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)

    # The file holds a whole pickled nn.Module (not a state_dict), which needs
    # weights_only=False. SECURITY: this unpickles arbitrary objects, so only
    # load model files from a trusted source.
    model = torch.load(model_file_path, map_location=device, weights_only=False)
    model.to(device)
    # Switch to inference mode so train-only layers (e.g. dropout) are
    # disabled and predictions are deterministic.
    model.eval()

except Exception as e:
    logger.error(str(e))
    raise




def mlb_load(file_path: Path) -> MultiLabelBinarizer:
    """Load the fitted MultiLabelBinarizer used to decode model outputs.

    Args:
        file_path: Path to a joblib dump of a fitted ``MultiLabelBinarizer``,
            resolved relative to the current working directory if relative.

    Returns:
        The loaded binarizer. If ``file_path`` does not exist, an error is
        logged and a placeholder binarizer fitted on a hard-coded label set is
        returned instead. Those placeholder labels may not match the labels
        the model was trained on.

    Raises:
        Any exception from ``joblib.load`` other than ``FileNotFoundError``
        (e.g. a corrupt file) propagates unchanged.
    """
    try:
        # joblib.load unpickles the file -- only load trusted artifacts.
        mlb = joblib.load(file_path)
    except FileNotFoundError:
        # Report the path that was actually requested, not a hard-coded one.
        logger.error(f"Error: '{file_path}' not found.")
        logger.error("Please make sure the path is correct. Using a placeholder binarizer.")
        # Best-effort fallback so callers can still decode predictions.
        mlb = MultiLabelBinarizer()
        mlb.fit([['advertisement', 'email', 'form', 'invoice', 'note']])
    return mlb






def VIT_model_prediction(image_path: Path, cut_off: float):
    """Run multi-label classification on one image with the loaded ViT model.

    Each class is scored independently. The logits go through a sigmoid, and
    every class whose probability is ``>= cut_off`` is selected.

    Args:
        image_path: Path of the image to classify. If the file does not exist,
            an error is logged and a 224x224 red placeholder image is
            classified instead (demo fallback inherited from the original).
        cut_off: Probability threshold; classes with probability
            ``>= cut_off`` are predicted.

    Returns:
        ``{"status": 1, "classe": labels}``, where ``labels`` is the result of
        ``MultiLabelBinarizer.inverse_transform`` on a single-row indicator
        matrix, i.e. a one-element list holding a tuple of label names.

    Raises:
        Any exception raised while preprocessing, running the model or
        decoding labels is logged and re-raised.
    """
    try:
        try:
            # The with-block closes the file handle; convert() forces the
            # pixel data to load first, so `image` stays usable afterwards.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
        except FileNotFoundError:
            logger.error(f"Error: Image not found at {image_path}")
            logger.error("Using a dummy image for demonstration.")
            image = Image.new('RGB', (224, 224), color='red')

        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

        with torch.no_grad():
            logits = model(pixel_values).logits

        # Sigmoid (not softmax): labels are independent in multi-label mode.
        # Flatten into a single (1, num_classes) row as inverse_transform
        # expects. Unlike squeeze(), this never collapses to a 0-d array when
        # there is only one class.
        probs = torch.sigmoid(logits).cpu().numpy().reshape(1, -1)
        predictions = (probs >= cut_off).astype(int)

        mlb = mlb_load(mlb_file_path)
        predicted_labels = mlb.inverse_transform(predictions)
        logger.info(f"Predicted labels: {predicted_labels}")
        return {"status": 1, "classe": predicted_labels}

    except Exception as e:
        logger.error(str(e))
        raise



#VIT_model_prediction(Path(r"dataset\sample_text_ds\test\email\2078379610a.jpg"),0.5)