File size: 5,582 Bytes
131eab2
b383602
 
3726151
b383602
 
 
48b3884
 
b383602
48b3884
b383602
131eab2
 
 
 
 
 
 
 
 
b383602
 
 
3726151
 
 
 
 
 
 
 
 
 
131eab2
3726151
131eab2
 
3726151
 
 
 
 
 
131eab2
3726151
131eab2
3726151
131eab2
 
 
 
 
3726151
131eab2
 
b383602
48b3884
b383602
48b3884
b383602
48b3884
b383602
48b3884
 
b383602
 
48b3884
 
 
 
b383602
131eab2
 
 
3726151
48b3884
b383602
131eab2
3726151
b383602
3726151
48b3884
 
131eab2
 
48b3884
b383602
3726151
48b3884
b383602
 
 
3726151
b383602
131eab2
b383602
 
 
48b3884
 
131eab2
 
48b3884
 
 
 
 
b383602
 
48b3884
 
 
 
cbf6689
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# app/prediction.py (Final Version with Relaxed Sanity Check)

import io
from pathlib import Path
from typing import List, Dict, Union, Any

import numpy as np
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification, AutoImageProcessor, ResNetForImageClassification

from .image_utils import add_watermark

# Image sources accepted by PredictionPipeline.predict.
ImageType = Union[str, Path, bytes, np.ndarray]

# Everyday, clearly non-medical terms. The sanity check lowercases each of the
# general-purpose classifier's top-5 labels and rejects the image when any of
# these terms occurs anywhere inside a label. Matching is by substring, so a
# term also hits longer labels that contain it (e.g. "car" in "sports car").
FORBIDDEN_LABELS = [
    # vehicles
    *"car truck van motorcycle bicycle bus train boat airplane".split(),
    # animals
    *"cat dog bird horse sheep cow bear zebra giraffe".split(),
    # outdoor scenes and structures
    *"landscape mountain beach forest building house road street".split(),
    # electronics
    *"computer keyboard mouse laptop cellphone television".split(),
    # food and tableware
    *"food plate bowl cup fork knife spoon".split(),
]

class PredictionPipeline:
    """Pneumonia classifier guarded by a general-purpose "is this an everyday object?" check.

    Two Hugging Face models are loaded at construction time:

    * a fine-tuned ViT classifier read from ``model_path``, whose
      ``config.id2label`` mapping supplies the prediction labels, and
    * ``microsoft/resnet-50``, used only by :meth:`sanity_check` to reject
      images whose top-5 labels contain a term from ``FORBIDDEN_LABELS``.
    """

    def __init__(self, model_path: Path = Path("artifacts/model_training/model")):
        """Load both models (in eval mode) onto CUDA when available, else the CPU.

        Args:
            model_path: Directory holding the fine-tuned ViT model and its
                image-processor configuration.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.pneumonia_processor = ViTImageProcessor.from_pretrained(model_path)
        self.pneumonia_model = ViTForImageClassification.from_pretrained(model_path).to(self.device)
        self.pneumonia_model.eval()
        self.id2label = self.pneumonia_model.config.id2label

        self.sanity_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
        self.sanity_model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50").to(self.device)
        self.sanity_model.eval()

    def sanity_check(self, image: Image.Image) -> bool:
        """Reject images the general-purpose model recognises as common objects.

        The ResNet-50 model's top-5 predicted labels are lowercased and
        searched for any term in ``FORBIDDEN_LABELS``. Matching is by
        substring, so ``"car"`` also matches ``"sports car"``. Passing only
        means the image is not obviously an everyday object; it does not
        prove the image is a medical scan.

        Args:
            image: RGB image to check.

        Returns:
            ``False`` if a forbidden term appears in any top-5 label,
            ``True`` otherwise.
        """
        id2label = self.sanity_model.config.id2label
        with torch.no_grad():
            inputs = self.sanity_processor(images=image, return_tensors="pt").to(self.device)
            logits = self.sanity_model(**inputs).logits

        for idx in torch.topk(logits, 5).indices[0]:
            label = id2label[idx.item()].lower()
            for forbidden in FORBIDDEN_LABELS:
                if forbidden in label:
                    print(f"Sanity check FAILED: Image classified as '{label}', which contains a forbidden term '{forbidden}'.")
                    return False  # It's definitely not an X-ray

        print("Sanity check PASSED: Image does not appear to be a common non-medical object.")
        return True  # Plausible enough to proceed

    @staticmethod
    def _load_image(source: ImageType) -> Image.Image:
        """Decode one image source into an RGB PIL image.

        Arrays go through ``Image.fromarray``. ``bytes``/``bytearray`` are
        treated as encoded image content and wrapped in ``io.BytesIO``,
        because PIL's ``Image.open`` would otherwise interpret ``bytes`` as a
        file name. Anything else (path or file-like object) is passed to
        ``Image.open``.
        """
        if isinstance(source, np.ndarray):
            return Image.fromarray(source).convert("RGB")
        if isinstance(source, (bytes, bytearray)):
            source = io.BytesIO(source)
        # The context manager closes files PIL opened itself; caller-supplied
        # file objects are left open. convert() returns a fully loaded copy.
        with Image.open(source) as img:
            return img.convert("RGB")

    def _classify(self, image: Image.Image) -> torch.Tensor:
        """Return the pneumonia model's raw logits for a single RGB image."""
        inputs = self.pneumonia_processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            return self.pneumonia_model(**inputs).logits

    def predict(self, image_sources: List[ImageType]) -> Dict[str, Any]:
        """Classify one or more chest X-ray images and aggregate the results.

        Each source is decoded to RGB, screened by :meth:`sanity_check`, and
        classified on its own. A source that fails any step is recorded as
        ``{"prediction": "Error", "confidence": 0}`` and skipped. The final
        prediction is the softmax of the mean of the per-image logits.

        Args:
            image_sources: File paths, file-like objects, encoded image
                ``bytes``, or image arrays accepted by ``Image.fromarray``.

        Returns:
            On success, a dict with these keys:

            * ``final_prediction`` and ``final_confidence``;
            * ``individual_results``: one entry per source, in input order;
            * ``watermarked_images``: one per successfully classified source,
              in input order, each watermarked with that image's own result.

            If no sources were given, or none could be classified, a dict
            with an ``error`` key instead.
        """
        if not image_sources:
            return {"error": "No images provided."}

        individual_results = []
        all_logits = []
        # (image array, result) for every successfully classified image. Kept
        # as pairs so each watermark always uses its own image's prediction,
        # even when earlier sources fail.
        classified = []

        for source in image_sources:
            # Best-effort per image: any failure marks just this source as an
            # error and moves on. Shared state is only updated after the image
            # has been fully processed, so a mid-way failure leaves no partial
            # entries behind.
            try:
                image = self._load_image(source)
                if not self.sanity_check(image):
                    raise ValueError("Image appears to be a common object, not a medical scan.")
                image_np = np.array(image)
                logits = self._classify(image)
                ind_probs = torch.nn.functional.softmax(logits, dim=-1)
                ind_conf, ind_idx = torch.max(ind_probs, dim=-1)
                result = {"prediction": self.id2label[ind_idx.item()], "confidence": ind_conf.item()}
            except Exception as e:
                print(f"Skipping an invalid image file. Error: {e}")
                individual_results.append({"prediction": "Error", "confidence": 0})
                continue

            all_logits.append(logits)
            individual_results.append(result)
            classified.append((image_np, result))

        if not all_logits:
            return {"error": "Invalid Image", "details": "All uploaded files were invalid or did not appear to be chest X-rays. Please upload a clear, frontal chest X-ray image."}

        # Aggregate by averaging logits across images, then applying softmax once.
        avg_logits = torch.mean(torch.stack(all_logits), dim=0)
        probabilities = torch.nn.functional.softmax(avg_logits, dim=-1)
        confidence_score, predicted_class_idx = torch.max(probabilities, dim=-1)
        final_prediction = self.id2label[predicted_class_idx.item()]
        final_confidence = confidence_score.item()
        # NOTE: No low-confidence rejection here; invalid inputs are filtered by sanity_check.

        watermarked_images = [
            add_watermark(img_np, res["prediction"], res["confidence"])
            for img_np, res in classified
        ]

        return {
            "final_prediction": final_prediction,
            "final_confidence": final_confidence,
            "individual_results": individual_results,
            "watermarked_images": watermarked_images
        }