|
|
|
|
|
import torch |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import os |
|
|
|
|
|
class TBClassifier:
    """
    Tuberculosis classifier using Vision Transformer

    Wraps a TorchScript model. Every input goes through the same pipeline:
    grayscale -> CLAHE -> Gaussian blur -> resize -> 3-channel replication ->
    per-image standardization. The model output is then interpreted as either
    a single logit (sigmoid) or two logits (softmax). In both cases index 1
    corresponds to ``class_names[1]`` ("Tuberculosis" by default).
    """

    _DEFAULT_CLASS_NAMES = ('Normal', 'Tuberculosis')
    _DEFAULT_IMG_SIZE = 224

    def __init__(self, model_path="model.pt", config_path="config.json"):
        """
        Initialize the classifier

        Args:
            model_path: Path to the TorchScript model file
            config_path: Path to an optional JSON configuration file. If the
                file exists, its ``class_names`` and ``input_size`` keys
                override the defaults (['Normal', 'Tuberculosis'] and 224).

        Raises:
            RuntimeError: If the TorchScript model cannot be loaded.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # The config file is optional; when it is missing, defaults are used.
        self.config = None
        if os.path.exists(config_path):
            import json

            with open(config_path, 'r', encoding='utf-8') as f:
                self.config = json.load(f)

        try:
            self.model = torch.jit.load(model_path, map_location=self.device)
            self.model.eval()
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            # Chain the cause so the real load error stays in the traceback.
            raise RuntimeError(f"Failed to load model: {str(e)}") from e

        # An empty/missing config falls back to the defaults, as before.
        settings = self.config or {}
        self.class_names = settings.get('class_names', list(self._DEFAULT_CLASS_NAMES))
        self.img_size = settings.get('input_size', self._DEFAULT_IMG_SIZE)

        print(f"Classifier initialized with classes: {self.class_names}")

    @staticmethod
    def _load_grayscale(image):
        """
        Normalize any supported input into a 2-D uint8/uint16 grayscale array.

        Args:
            image: One of the following:
                - a file path (str or os.PathLike), read with ``cv2.imread``;
                - a PIL-style image exposing ``convert``, converted via its "L" mode;
                - a numpy array, either 2-D or H x W x 1/3/4 in OpenCV BGR/BGRA order.

        Returns:
            2-D numpy array with dtype uint8 or uint16.

        Raises:
            ValueError: If the image cannot be read or has an unsupported shape.
        """
        if isinstance(image, (str, os.PathLike)):
            image = cv2.imread(os.fspath(image))
        elif image is not None and not isinstance(image, np.ndarray) and hasattr(image, "convert"):
            # PIL images are RGB(A); let PIL compute luminance instead of
            # feeding RGB data into a BGR conversion.
            image = np.asarray(image.convert("L"))

        if image is None:
            raise ValueError("Invalid image input")

        image = np.asarray(image)

        if image.dtype not in (np.uint8, np.uint16):
            # CLAHE only accepts 8- or 16-bit input, so min-max rescale
            # anything else (floats, bools, wider ints) into uint8.
            image = image.astype(np.float32)
            lo, hi = float(image.min()), float(image.max())
            scale = 255.0 / (hi - lo) if hi > lo else 0.0
            image = np.clip((image - lo) * scale, 0, 255).astype(np.uint8)

        if image.ndim == 3:
            channels = image.shape[2]
            if channels == 1:
                image = np.ascontiguousarray(image[:, :, 0])
            elif channels == 3:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            elif channels == 4:
                # e.g. PNGs loaded with an alpha channel
                image = cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)

        if image.ndim != 2:
            raise ValueError(f"Unsupported image shape: {image.shape}")
        return image

    def preprocess(self, image):
        """
        Preprocess input image for model inference

        Args:
            image: Input image. This can be a file path (str or os.PathLike),
                a PIL-style image, or a numpy array in one of these layouts:
                2-D grayscale, or H x W x 1/3/4 in BGR/BGRA order.

        Returns:
            Float32 tensor of shape (1, 3, img_size, img_size) on
            ``self.device``, standardized to zero mean / unit variance per image.

        Raises:
            RuntimeError: If the image cannot be loaded or processed.
        """
        try:
            image = self._load_grayscale(image)

            # Contrast-limited adaptive histogram equalization (local contrast boost).
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            image = clahe.apply(image)

            image = cv2.GaussianBlur(image, (5, 5), 0)
            image = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR)

            # Replicate the single gray channel to 3 channels for the model input.
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

            # HWC -> CHW, then standardize with this image's own statistics.
            image = np.moveaxis(image, -1, 0).astype(np.float32)
            image = (image - image.mean()) / (image.std() + 1e-8)

            return torch.tensor(image).unsqueeze(0).to(self.device)
        except Exception as e:
            raise RuntimeError(f"Preprocessing failed: {str(e)}") from e

    @staticmethod
    def _positive_probability(output):
        """
        Convert the raw model output for one image into P(class_id == 1).

        Args:
            output: Model output tensor holding either one logit (sigmoid
                head) or two logits (softmax head), with any batch/extra
                dimensions of size 1.

        Returns:
            Probability of class 1 as a Python float.

        Raises:
            ValueError: If the output holds neither 1 nor 2 values.
        """
        logits = output.reshape(-1).float()
        if logits.numel() == 1:
            return torch.sigmoid(logits).item()
        if logits.numel() == 2:
            return torch.softmax(logits, dim=0)[1].item()
        raise ValueError(f"Unexpected model output shape: {tuple(output.shape)}")

    def predict(self, image, return_probs=False):
        """
        Predict tuberculosis from chest X-ray image

        Args:
            image: Input image (file path, numpy array, or PIL image)
            return_probs: Whether to return raw probabilities

        Returns:
            On success, a dictionary with these keys:
                - "prediction": the class name;
                - "confidence": the probability of the predicted class;
                - "class_id": 0 or 1.
            When ``return_probs`` is true, it also includes
            "raw_probability" and a per-class "probabilities" mapping.
            On failure, a dictionary with "error" set, and "prediction" and
            "confidence" set to None.
        """
        try:
            processed_image = self.preprocess(image)

            with torch.no_grad():
                output = self.model(processed_image)

            prob = self._positive_probability(output)

            class_id = 1 if prob > 0.5 else 0
            confidence = prob if class_id == 1 else 1 - prob

            result = {
                "prediction": self.class_names[class_id],
                "confidence": float(confidence),
                "class_id": class_id,
            }

            if return_probs:
                result["raw_probability"] = float(prob)
                result["probabilities"] = {
                    self.class_names[0]: float(1 - prob),
                    self.class_names[1]: float(prob),
                }

            return result

        except Exception as e:
            # Failures are reported in the returned dict rather than raised,
            # so batch_predict keeps going past a bad input.
            return {
                "error": str(e),
                "prediction": None,
                "confidence": None,
            }

    def batch_predict(self, images, return_probs=False):
        """
        Predict on multiple images

        Args:
            images: Iterable of images, each accepted by ``predict``
            return_probs: Whether to return raw probabilities

        Returns:
            List of prediction result dicts, in input order. A failed image
            yields an error dict from ``predict`` instead of raising.
        """
        return [self.predict(img, return_probs=return_probs) for img in images]
|
|
|