import json
import os

import cv2
import numpy as np
import torch


class TBClassifier:
    """
    Tuberculosis classifier using Vision Transformer

    Wraps a TorchScript model. Input images are converted to grayscale,
    contrast-enhanced with CLAHE, blurred, resized to ``img_size``,
    replicated into 3 identical channels and z-score normalized before
    being fed to the model.
    """

    def __init__(self, model_path="model.pt", config_path="config.json"):
        """
        Initialize the classifier

        Args:
            model_path: Path to the TorchScript model file
            config_path: Path to an optional JSON configuration file. The
                keys 'class_names' and 'input_size' are read from it when
                present; defaults are used otherwise.

        Raises:
            RuntimeError: If the TorchScript model cannot be loaded.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load configuration if available
        self.config = None
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                self.config = json.load(f)

        # Load model
        try:
            self.model = torch.jit.load(model_path, map_location=self.device)
            self.model.eval()
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {str(e)}") from e

        # Model configuration; a missing or empty config falls back to defaults
        config = self.config or {}
        self.class_names = config.get('class_names', ['Normal', 'Tuberculosis'])
        self.img_size = config.get('input_size', 224)

        print(f"Classifier initialized with classes: {self.class_names}")

    @staticmethod
    def _to_gray(image):
        """Load/convert ``image`` into a 2-D single-channel numpy array."""
        if isinstance(image, (str, os.PathLike)):
            # cv2.imread returns None (not an exception) on failure
            image = cv2.imread(os.fspath(image))
        elif image is not None and not isinstance(image, np.ndarray):
            # PIL-style images hold RGB data, so let the image object do its
            # own grayscale conversion instead of assuming BGR channel order.
            if hasattr(image, "convert"):
                image = image.convert("L")
            image = np.asarray(image)

        if image is None:
            raise ValueError("Invalid image input")

        if image.ndim == 3:
            channels = image.shape[2]
            if channels == 1:
                image = np.ascontiguousarray(image[:, :, 0])
            elif channels == 3:
                # Assume BGR format from cv2
                image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            elif channels == 4:
                # Assume BGRA format from cv2 (e.g. PNG with alpha)
                image = cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
        return image

    def preprocess(self, image):
        """
        Preprocess input image for model inference

        Args:
            image: File path (str or os.PathLike), numpy array (grayscale
                (H, W) / (H, W, 1), BGR (H, W, 3) or BGRA (H, W, 4)), or a
                PIL-style image object exposing ``convert``.

        Returns:
            Float32 tensor of shape (1, 3, img_size, img_size) on self.device

        Raises:
            RuntimeError: If the image cannot be read or any processing step
                fails (the original error is chained as the cause).
        """
        try:
            image = self._to_gray(image)

            # Apply CLAHE for contrast enhancement (OpenCV requires an 8-bit
            # or 16-bit single-channel image here)
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            image = clahe.apply(image)

            # Apply Gaussian blur for noise reduction
            image = cv2.GaussianBlur(image, (5, 5), 0)

            # Resize to model input size
            image = cv2.resize(image, (self.img_size, self.img_size),
                               interpolation=cv2.INTER_LINEAR)

            # Replicate the grayscale channel into 3 identical channels
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

            # Convert to tensor format (C, H, W)
            image = np.moveaxis(image, -1, 0).astype(np.float32)

            # Per-image z-score normalization; epsilon avoids division by
            # zero on a constant image
            image = (image - image.mean()) / (image.std() + 1e-8)

            # Add batch dimension and move to device
            tensor = torch.tensor(image).unsqueeze(0).to(self.device)
            return tensor

        except Exception as e:
            raise RuntimeError(f"Preprocessing failed: {str(e)}") from e

    @staticmethod
    def _positive_probability(output):
        """Return P(class 1) from a single-logit or two-logit model output."""
        logits = output.reshape(-1)
        if logits.numel() == 1:
            # Single logit: sigmoid maps it into [0, 1]
            return torch.sigmoid(logits[0]).item()
        if logits.numel() == 2:
            # One logit per class: softmax and take the positive class
            return torch.softmax(logits, dim=0)[1].item()
        raise ValueError(f"Unexpected model output shape: {tuple(output.shape)}")

    def predict(self, image, return_probs=False):
        """
        Predict tuberculosis from chest X-ray image

        Args:
            image: Input image (file path, numpy array, or PIL image)
            return_probs: Whether to return raw probabilities

        Returns:
            Dictionary with "prediction", "confidence" and "class_id"; when
            return_probs is True also "raw_probability" and "probabilities".
            On any failure no exception is raised; instead a dictionary with
            "error" (the message), "prediction": None and "confidence": None
            is returned.
        """
        try:
            # Preprocess image
            processed_image = self.preprocess(image)

            # Model inference
            with torch.no_grad():
                output = self.model(processed_image)

            prob = self._positive_probability(output)

            # Determine class (ties at exactly 0.5 go to class 0)
            class_id = 1 if prob > 0.5 else 0
            confidence = prob if class_id == 1 else 1 - prob
            prediction = self.class_names[class_id]

            result = {
                "prediction": prediction,
                "confidence": float(confidence),
                "class_id": class_id
            }

            if return_probs:
                result["raw_probability"] = float(prob)
                result["probabilities"] = {
                    self.class_names[0]: float(1 - prob),
                    self.class_names[1]: float(prob)
                }

            return result

        except Exception as e:
            return {
                "error": str(e),
                "prediction": None,
                "confidence": None
            }

    def batch_predict(self, images, return_probs=False):
        """
        Predict on multiple images

        Args:
            images: Iterable of images (any input accepted by ``predict``)
            return_probs: Whether to return raw probabilities

        Returns:
            List of prediction results, one per image, in input order. A
            failing image yields an error dictionary rather than aborting
            the batch.
        """
        return [self.predict(img, return_probs=return_probs) for img in images]