File size: 5,487 Bytes
248c98e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160

import torch
import cv2
import numpy as np
import os

class TBClassifier:
    """
    Tuberculosis classifier using Vision Transformer.

    Wraps a TorchScript model loaded with ``torch.jit.load``. Input images are
    converted to grayscale, contrast-enhanced with CLAHE, blurred, resized,
    replicated to 3 channels and standardized per image before inference.
    The model output is read as a binary score: either a single logit
    (sigmoid) or two per-class logits (softmax over the last dimension).
    """

    def __init__(self, model_path="model.pt", config_path="config.json"):
        """
        Initialize the classifier.

        Args:
            model_path: Path to the TorchScript model file.
            config_path: Path to an optional JSON configuration file. When the
                file exists, its ``class_names`` and ``input_size`` keys
                override the defaults (``['Normal', 'Tuberculosis']`` and
                ``224``); a missing file or empty config keeps the defaults.

        Raises:
            RuntimeError: If the model cannot be loaded (original exception
                chained as ``__cause__``).
        """
        import json

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load configuration if available
        self.config = None
        if os.path.exists(config_path):
            # JSON is UTF-8 by spec; don't depend on the platform default encoding.
            with open(config_path, 'r', encoding='utf-8') as f:
                self.config = json.load(f)

        # Load model
        try:
            self.model = torch.jit.load(model_path, map_location=self.device)
            self.model.eval()
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {str(e)}") from e

        # Model configuration; a missing or empty config falls back to defaults.
        cfg = self.config or {}
        # predict() indexes class_names[0] and [1], so a binary label list is
        # assumed here. NOTE(review): a config with fewer than two names will
        # only surface as an error result from predict().
        self.class_names = cfg.get('class_names', ['Normal', 'Tuberculosis'])
        self.img_size = cfg.get('input_size', 224)

        print(f"Classifier initialized with classes: {self.class_names}")

    def preprocess(self, image):
        """
        Preprocess an input image into a model-ready tensor.

        Args:
            image: One of
                - a file path (``str`` or ``os.PathLike``), read with
                  ``cv2.imread`` (BGR channel order);
                - a numpy array: 2-D grayscale, or 3-D with 1, 3 or 4
                  channels assumed to be in OpenCV BGR / BGRA order;
                - another array-like such as a PIL image, converted with
                  ``np.asarray`` and assumed to be RGB / RGBA.
                Pixel dtype must be one CLAHE accepts (e.g. uint8).

        Returns:
            float32 tensor of shape ``(1, 3, img_size, img_size)`` on
            ``self.device``, standardized by the image's own mean and std.

        Raises:
            RuntimeError: If the image cannot be read or processed (original
                exception chained as ``__cause__``).
        """
        try:
            # OpenCV inputs (files, numpy arrays) are BGR; array-likes such as
            # PIL images come out of np.asarray in RGB order.
            rgb_order = False
            if isinstance(image, (str, os.PathLike)):
                path = os.fspath(image)
                image = cv2.imread(path)
                if image is None:
                    raise ValueError(f"Invalid image input: could not read {path!r}")
            elif image is not None and not isinstance(image, np.ndarray):
                image = np.asarray(image)
                rgb_order = True

            if image is None:
                raise ValueError("Invalid image input")

            # Collapse (H, W, 1) to (H, W); CLAHE needs a single-channel 2-D image.
            if image.ndim == 3 and image.shape[2] == 1:
                image = np.ascontiguousarray(image[:, :, 0])

            # Convert colour input (with or without alpha) to grayscale
            if image.ndim == 3 and image.shape[2] == 3:
                code = cv2.COLOR_RGB2GRAY if rgb_order else cv2.COLOR_BGR2GRAY
                image = cv2.cvtColor(image, code)
            elif image.ndim == 3 and image.shape[2] == 4:
                code = cv2.COLOR_RGBA2GRAY if rgb_order else cv2.COLOR_BGRA2GRAY
                image = cv2.cvtColor(image, code)

            if image.ndim != 2:
                raise ValueError(f"Unsupported image shape: {image.shape}")

            # Apply CLAHE for contrast enhancement
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            image = clahe.apply(image)

            # Apply Gaussian blur for noise reduction
            image = cv2.GaussianBlur(image, (5, 5), 0)

            # Resize to model input size
            image = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR)

            # Replicate the gray channel into 3 channels for the model input
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

            # Convert to tensor format (C, H, W)
            image = np.moveaxis(image, -1, 0).astype(np.float32)

            # Per-image standardization; epsilon guards against a constant image
            image = (image - image.mean()) / (image.std() + 1e-8)

            # Add batch dimension and move to device
            tensor = torch.tensor(image).unsqueeze(0).to(self.device)

            return tensor

        except Exception as e:
            raise RuntimeError(f"Preprocessing failed: {str(e)}") from e

    def predict(self, image, return_probs=False, threshold=0.5):
        """
        Predict tuberculosis from a chest X-ray image.

        Args:
            image: Input image (file path, numpy array, or PIL image); see
                ``preprocess`` for the accepted formats.
            return_probs: Whether to also return raw probabilities.
            threshold: ``class_names[1]`` is predicted when its probability is
                strictly greater than this value. Defaults to 0.5.

        Returns:
            Dictionary with ``prediction`` (class name), ``confidence``
            (probability of the predicted class) and ``class_id``; with
            ``return_probs`` also ``raw_probability`` and ``probabilities``.
            On any failure this does not raise; it returns
            ``{"error": message, "prediction": None, "confidence": None,
            "class_id": None}``.
        """
        try:
            # Preprocess image
            processed_image = self.preprocess(image)

            # Model inference
            with torch.no_grad():
                output = self.model(processed_image).float()

                if output.dim() > 0 and output.shape[-1] == 2:
                    # Two per-class logits: probability of the positive class.
                    prob = torch.softmax(output, dim=-1)[..., 1].item()
                else:
                    # Single logit, e.g. shape (1, 1), (1,) or ().
                    if output.dim() > 1:
                        output = output.squeeze(-1)
                    # NOTE(review): assumes the model emits a logit; if it already
                    # applies a sigmoid, this squashes twice -- confirm.
                    prob = torch.sigmoid(output).item()

            # Determine class
            class_id = 1 if prob > threshold else 0
            confidence = prob if class_id == 1 else 1 - prob
            prediction = self.class_names[class_id]

            result = {
                "prediction": prediction,
                "confidence": float(confidence),
                "class_id": class_id,
            }

            if return_probs:
                result["raw_probability"] = float(prob)
                result["probabilities"] = {
                    self.class_names[0]: float(1 - prob),
                    self.class_names[1]: float(prob),
                }

            return result

        except Exception as e:
            # Deliberate best-effort: report the failure instead of raising,
            # with the same keys as a successful result.
            return {
                "error": str(e),
                "prediction": None,
                "confidence": None,
                "class_id": None,
            }

    def batch_predict(self, images, return_probs=False, threshold=0.5):
        """
        Predict on multiple images, one at a time.

        Args:
            images: Iterable of images accepted by ``predict``.
            return_probs: Whether to return raw probabilities.
            threshold: Decision threshold forwarded to ``predict``.

        Returns:
            List of ``predict`` result dictionaries, in input order. A failing
            image yields an error dictionary rather than aborting the batch.
        """
        return [
            self.predict(img, return_probs=return_probs, threshold=threshold)
            for img in images
        ]