File size: 3,559 Bytes
b383602
 
 
 
 
 
 
48b3884
 
b383602
48b3884
b383602
 
 
 
 
 
 
 
 
48b3884
b383602
48b3884
b383602
48b3884
b383602
48b3884
 
b383602
 
48b3884
 
 
 
b383602
48b3884
b383602
48b3884
b383602
 
48b3884
 
 
 
 
 
 
 
 
 
 
b383602
 
48b3884
b383602
 
 
48b3884
b383602
 
 
 
48b3884
 
 
 
cbf6689
 
 
 
 
 
 
 
 
48b3884
 
 
 
 
b383602
 
48b3884
 
 
 
cbf6689
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# app/prediction.py

import io
from pathlib import Path
from typing import List, Dict, Union, Any

import numpy as np
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification

from .image_utils import add_watermark

# Image inputs accepted by PredictionPipeline.predict: a path (str or Path),
# bytes, or an in-memory pixel array (np.ndarray).
ImageType = Union[
    str,
    Path,
    bytes,
    np.ndarray,
]

class PredictionPipeline:
    """Run a fine-tuned ViT image classifier over one or more images.

    The processor and model are loaded once in ``__init__``; each ``predict``
    call scores every image individually and combines them into a single
    ensemble prediction.
    """

    def __init__(self, model_path: Path = Path("artifacts/model_training/model")):
        """Load the image processor and classifier from *model_path*.

        Args:
            model_path: Location passed to ``from_pretrained`` for both the
                processor and the model.
        """
        # Prefer the GPU when one is visible; otherwise fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = ViTImageProcessor.from_pretrained(model_path)
        self.model = ViTForImageClassification.from_pretrained(model_path).to(self.device)
        # Inference only: switch off training-time behavior such as dropout.
        self.model.eval()
        # Class index -> human-readable label, taken from the model config.
        self.id2label = self.model.config.id2label

    @staticmethod
    def _load_image(source: ImageType) -> Image.Image:
        """Decode *source* into an RGB PIL image.

        ``np.ndarray`` inputs are treated as pixel arrays and ``bytes`` /
        ``bytearray`` as encoded image file contents. Anything else (a
        ``str``/``Path`` path or a file-like object) is handed to
        ``Image.open`` unchanged.
        """
        if isinstance(source, np.ndarray):
            return Image.fromarray(source).convert("RGB")
        if isinstance(source, (bytes, bytearray)):
            # Image.open treats a bare bytes object as a *filename*, so encoded
            # image data must be wrapped in a file-like buffer.
            source = io.BytesIO(source)
        # convert() returns a new, fully loaded image, so it stays valid after
        # the context manager closes the underlying file.
        with Image.open(source) as img:
            return img.convert("RGB")

    def _classify(self, image: Image.Image) -> torch.Tensor:
        """Return the raw logits the model produces for a single image."""
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            return self.model(**inputs).logits

    def _top_prediction(self, logits: torch.Tensor) -> Dict[str, Any]:
        """Softmax *logits* and return the most likely label with its probability."""
        probs = torch.nn.functional.softmax(logits, dim=-1)
        confidence, idx = torch.max(probs, dim=-1)
        return {
            "prediction": self.id2label[idx.item()],
            "confidence": confidence.item(),
        }

    def predict(
        self,
        image_sources: List[ImageType],
        confidence_threshold: float = 0.60,
    ) -> Dict[str, Any]:
        """Classify a set of images and return an aggregated prediction.

        Each image is scored independently. The ensemble prediction is the
        softmax of the mean of the per-image logits.

        Args:
            image_sources: Images to classify (see ``_load_image`` for the
                accepted forms).
            confidence_threshold: Minimum ensemble probability required to
                return a prediction instead of a low-confidence error.
                Defaults to 0.60.

        Returns:
            On success, a dict with these keys:

            - ``final_prediction``: the ensemble label.
            - ``final_confidence``: the ensemble probability.
            - ``individual_results``: one entry per input, in input order.
              Inputs that could not be processed get
              ``{"prediction": "Error", "confidence": 0}``.
            - ``watermarked_images``: one per successfully processed input,
              in input order, each stamped with that image's own prediction.

            On failure, a dict with an ``error`` key, plus ``details`` for
            low-confidence results.
        """
        if not image_sources:
            return {"error": "No images provided."}

        individual_results = []
        all_logits = []
        # (pixels, result) for every image that got through inference, so each
        # watermark is paired with the prediction for *its own* image even when
        # earlier inputs failed.
        scored_images = []

        for source in image_sources:
            try:
                image = self._load_image(source)
                logits = self._classify(image)
                result = self._top_prediction(logits)
                pixels = np.array(image)
            except Exception as e:
                # Best-effort: one unreadable file should not sink the batch.
                print(f"Skipping a corrupted or invalid image file. Error: {e}")
                individual_results.append({"prediction": "Error", "confidence": 0})
                continue

            # Record only fully processed images so the three lists stay aligned.
            all_logits.append(logits)
            individual_results.append(result)
            scored_images.append((pixels, result))

        if not all_logits:
            return {"error": "All images were invalid."}

        # Ensemble: average the raw logits across images, then take the softmax.
        avg_logits = torch.mean(torch.stack(all_logits), dim=0)
        ensemble = self._top_prediction(avg_logits)
        final_prediction = ensemble["prediction"]
        final_confidence = ensemble["confidence"]

        if final_confidence < confidence_threshold:
            return {
                "error": "Low Confidence Prediction",
                "details": f"The model's confidence of {final_confidence:.1%} is too low. "
                           "Please ensure the uploaded image is a clear, frontal chest X-ray."
            }

        watermarked_images = [
            add_watermark(pixels, res["prediction"], res["confidence"])
            for pixels, res in scored_images
        ]

        return {
            "final_prediction": final_prediction,
            "final_confidence": final_confidence,
            "individual_results": individual_results,
            "watermarked_images": watermarked_images
        }