File size: 5,487 Bytes
248c98e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import torch
import cv2
import numpy as np
import os
class TBClassifier:
    """
    Tuberculosis classifier using Vision Transformer

    Wraps a TorchScript model whose output for one image is decoded into a
    probability for ``class_names[1]`` and mapped to one of two class names.
    """

    def __init__(self, model_path="model.pt", config_path="config.json"):
        """
        Initialize the classifier

        Args:
            model_path: Path to the TorchScript model file
            config_path: Path to an optional JSON configuration file. If it
                exists, ``class_names`` and ``input_size`` are read from it;
                otherwise (or if the keys are absent) defaults are used.

        Raises:
            RuntimeError: if the TorchScript model cannot be loaded.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load configuration if available
        self.config = None
        if os.path.exists(config_path):
            import json
            with open(config_path, 'r', encoding='utf-8') as f:
                self.config = json.load(f)

        # Load model
        try:
            self.model = torch.jit.load(model_path, map_location=self.device)
            self.model.eval()
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Failed to load model: {str(e)}") from e

        # Model configuration; a missing or empty config falls back to defaults.
        config = self.config or {}
        self.class_names = config.get('class_names', ['Normal', 'Tuberculosis'])
        self.img_size = config.get('input_size', 224)
        print(f"Classifier initialized with classes: {self.class_names}")

    @staticmethod
    def _load_image(image):
        """Return ``(array, is_rgb)`` for a path, numpy array, or array-like image.

        ``is_rgb`` is True only for non-numpy inputs (e.g. PIL images), whose
        channel order is RGB(A); numpy arrays are assumed to be in OpenCV's
        BGR(A) order, matching what ``cv2.imread`` returns.

        Raises:
            ValueError: if the input is None or the path could not be read.
        """
        is_rgb = False
        if isinstance(image, (str, os.PathLike)):
            # cv2.imread signals an unreadable file by returning None.
            image = cv2.imread(os.fspath(image))
        elif image is not None and not isinstance(image, np.ndarray):
            # Array-protocol inputs such as PIL images; copy so OpenCV gets a
            # contiguous, writable buffer.
            image = np.array(image)
            is_rgb = True
        if image is None:
            raise ValueError("Invalid image input")
        return image, is_rgb

    @staticmethod
    def _to_grayscale(image, is_rgb=False):
        """Collapse a 2-D or (H, W, C) image with C in {1, 3, 4} to 2-D grayscale.

        Raises:
            ValueError: for any other array shape.
        """
        if image.ndim == 2:
            return image
        if image.ndim == 3:
            channels = image.shape[2]
            if channels == 1:
                return image[:, :, 0]
            if channels == 3:
                code = cv2.COLOR_RGB2GRAY if is_rgb else cv2.COLOR_BGR2GRAY
                return cv2.cvtColor(image, code)
            if channels == 4:
                code = cv2.COLOR_RGBA2GRAY if is_rgb else cv2.COLOR_BGRA2GRAY
                return cv2.cvtColor(image, code)
        raise ValueError(f"Unsupported image shape: {image.shape}")

    @staticmethod
    def _to_probability(output):
        """Convert the raw model output for one image into P(class_names[1]).

        A single logit is passed through a sigmoid. A pair of logits is passed
        through a softmax and the second entry is returned.

        Raises:
            ValueError: if the output holds neither one nor two values.
        """
        logits = output.reshape(-1)
        if logits.numel() == 1:
            return torch.sigmoid(logits[0]).item()
        if logits.numel() == 2:
            return torch.softmax(logits, dim=0)[1].item()
        raise ValueError(f"Unexpected model output shape: {tuple(output.shape)}")

    def preprocess(self, image):
        """
        Preprocess input image for model inference

        Pipeline: grayscale -> CLAHE (clip 2.0, 8x8 tiles) -> 5x5 Gaussian
        blur -> resize to ``img_size`` x ``img_size`` -> replicate to three
        channels -> per-image standardization (zero mean, unit std).

        Args:
            image: File path (str or os.PathLike), numpy array (assumed
                BGR/BGRA order, as from cv2), or an array-like such as a PIL
                image (assumed RGB/RGBA order).

        Returns:
            Tensor of shape (1, 3, img_size, img_size) on ``self.device``.

        Raises:
            RuntimeError: wrapping any failure (unreadable input, unsupported
                shape, OpenCV errors).
        """
        try:
            image, is_rgb = self._load_image(image)
            image = self._to_grayscale(image, is_rgb)
            # Apply CLAHE for contrast enhancement
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            image = clahe.apply(image)
            # Apply Gaussian blur for noise reduction
            image = cv2.GaussianBlur(image, (5, 5), 0)
            # Resize to model input size
            image = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR)
            # The model expects three channels; replicate the gray plane.
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            # Convert to tensor format (C, H, W)
            image = np.moveaxis(image, -1, 0).astype(np.float32)
            # Per-image standardization; epsilon guards constant images.
            image = (image - image.mean()) / (image.std() + 1e-8)
            # Add batch dimension and move to device
            return torch.tensor(image).unsqueeze(0).to(self.device)
        except Exception as e:
            raise RuntimeError(f"Preprocessing failed: {str(e)}") from e

    def predict(self, image, return_probs=False, threshold=0.5):
        """
        Predict tuberculosis from chest X-ray image

        Args:
            image: Input image (file path, numpy array, or PIL image)
            return_probs: Whether to also return raw probabilities
            threshold: Probabilities strictly above this value select
                ``class_names[1]``; otherwise ``class_names[0]``.

        Returns:
            Dictionary with ``prediction``, ``confidence`` and ``class_id``.
            It also contains ``raw_probability`` and ``probabilities`` when
            ``return_probs`` is true. On any failure, the dictionary has
            ``error`` and those three keys set to None; nothing is raised.
        """
        try:
            processed_image = self.preprocess(image)
            with torch.no_grad():
                output = self.model(processed_image)
            prob = self._to_probability(output)

            class_id = 1 if prob > threshold else 0
            confidence = prob if class_id == 1 else 1 - prob
            result = {
                "prediction": self.class_names[class_id],
                "confidence": float(confidence),
                "class_id": class_id,
            }
            if return_probs:
                result["raw_probability"] = float(prob)
                result["probabilities"] = {
                    self.class_names[0]: float(1 - prob),
                    self.class_names[1]: float(prob),
                }
            return result
        except Exception as e:
            # Best-effort API: report the failure in the result instead of raising.
            return {
                "error": str(e),
                "prediction": None,
                "confidence": None,
                "class_id": None,
            }

    def batch_predict(self, images, return_probs=False, threshold=0.5):
        """
        Predict on multiple images, one at a time

        Args:
            images: Iterable of images accepted by ``predict``
            return_probs: Whether to return raw probabilities
            threshold: Decision threshold forwarded to ``predict``

        Returns:
            List of prediction results, in input order
        """
        return [
            self.predict(img, return_probs=return_probs, threshold=threshold)
            for img in images
        ]
|