File size: 3,278 Bytes
b383602
 
 
 
 
 
 
48b3884
 
b383602
48b3884
b383602
 
 
 
 
 
 
 
 
48b3884
b383602
48b3884
b383602
48b3884
b383602
48b3884
 
b383602
 
48b3884
 
 
 
b383602
48b3884
b383602
48b3884
b383602
 
48b3884
 
 
 
 
 
 
 
 
 
 
b383602
 
48b3884
b383602
 
 
48b3884
b383602
48b3884
b383602
 
 
48b3884
 
 
 
 
 
 
 
 
 
b383602
 
48b3884
 
 
 
b383602
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# app/prediction.py

import io
from pathlib import Path
from typing import List, Dict, Union, Any

import numpy as np
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification

from .image_utils import add_watermark

# Accepted per-image input types for PredictionPipeline.predict:
# a str or Path, a bytes object, or a NumPy array.
ImageType = Union[
    str,
    Path,
    bytes,
    np.ndarray,
]

class PredictionPipeline:
    """Classify one or more images with a ViT model and aggregate the results.

    The image processor and classification model are loaded from a local
    directory. Each input image is classified on its own. The per-image
    logits are then averaged to produce one combined prediction.
    """

    def __init__(self, model_path: Path = Path("artifacts/model_training/model")):
        """Load the processor and model from *model_path*.

        The model is moved to CUDA when it is available (otherwise CPU) and
        switched to eval mode.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = ViTImageProcessor.from_pretrained(model_path)
        self.model = ViTForImageClassification.from_pretrained(model_path).to(self.device)
        self.model.eval()
        # Maps class index -> label name, as stored in the model config.
        self.id2label = self.model.config.id2label

    @staticmethod
    def _load_image(source: ImageType) -> Image.Image:
        """Decode one input into a standalone RGB PIL image."""
        if isinstance(source, np.ndarray):
            return Image.fromarray(source).convert("RGB")
        if isinstance(source, bytes):
            # Image.open treats a bytes object as a *filename*. Wrap it so
            # encoded image data is decoded instead of looked up on disk.
            source = io.BytesIO(source)
        # The context manager releases any file Image.open opened. convert()
        # returns a new, fully loaded image that does not depend on that file.
        with Image.open(source) as img:
            return img.convert("RGB")

    def _infer_logits(self, image: Image.Image) -> torch.Tensor:
        """Run the model on a single image and return its raw logits."""
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            return self.model(**inputs).logits

    def _top_prediction(self, logits: torch.Tensor) -> Dict[str, Any]:
        """Return the most probable label and its softmax probability."""
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        confidence, class_idx = torch.max(probabilities, dim=-1)
        return {
            "prediction": self.id2label[class_idx.item()],
            "confidence": confidence.item(),
        }

    def predict(self, image_sources: List[ImageType]) -> Dict[str, Any]:
        """Classify each image and produce an aggregated prediction.

        Args:
            image_sources: Images given as paths (str / Path), encoded image
                bytes, or pixel arrays.

        Returns:
            ``{"error": ...}`` when no images are given or none could be
            processed. Otherwise a dict with:

            - ``final_prediction`` / ``final_confidence``: the label and
              probability from the softmax of the logits averaged over all
              successfully processed images.
            - ``individual_results``: one entry per input, in input order.
              Inputs that failed become ``{"prediction": "Error", "confidence": 0}``.
            - ``watermarked_images``: ``add_watermark`` output for each
              successfully processed image, in input order. Each image is
              stamped with its own individual result.
        """
        if not image_sources:
            return {"error": "No images provided."}

        individual_results: List[Dict[str, Any]] = []
        all_logits: List[torch.Tensor] = []
        # Each successful image is kept next to its own result, so a failed
        # input can never shift which result gets stamped on which image.
        watermark_inputs: List[tuple] = []

        for source in image_sources:
            try:
                image = self._load_image(source)
                image_np = np.array(image)
                logits = self._infer_logits(image)
                result = self._top_prediction(logits)
            except Exception as e:
                # Best-effort: one unreadable input must not fail the whole
                # request. Nothing from this input is recorded except its
                # error entry.
                print(f"Skipping a corrupted or invalid image file. Error: {e}")
                individual_results.append({"prediction": "Error", "confidence": 0})
                continue

            all_logits.append(logits)
            individual_results.append(result)
            watermark_inputs.append((image_np, result))

        if not all_logits:
            return {"error": "All images were invalid."}

        # Aggregate by averaging raw logits across images, then applying the
        # same softmax/argmax step used for the individual predictions.
        avg_logits = torch.mean(torch.stack(all_logits), dim=0)
        aggregate = self._top_prediction(avg_logits)

        watermarked_images = [
            add_watermark(img_np, res["prediction"], res["confidence"])
            for img_np, res in watermark_inputs
        ]

        return {
            "final_prediction": aggregate["prediction"],
            "final_confidence": aggregate["confidence"],
            "individual_results": individual_results,
            "watermarked_images": watermarked_images,
        }