# sukhmani1303
# Upload tuberculosis ViT model with complete configuration
# 248c98e verified
import torch
import cv2
import numpy as np
import os
class TBClassifier:
    """
    Tuberculosis classifier using Vision Transformer

    Wraps a TorchScript model together with the image preprocessing it is fed
    with (grayscale -> CLAHE -> Gaussian blur -> resize -> 3-channel ->
    per-image standardisation) and turns the single model output into a
    binary prediction.
    """

    # Fallbacks used when the configuration file is missing or lacks the key.
    DEFAULT_CLASS_NAMES = ("Normal", "Tuberculosis")
    DEFAULT_IMG_SIZE = 224

    # PIL modes that np.array() can consume directly; every other mode
    # (palette, bilevel, CMYK, ...) is first converted to RGB.
    _PIL_PASSTHROUGH_MODES = ("L", "RGB", "RGBA", "I", "F")

    def __init__(self, model_path="model.pt", config_path="config.json"):
        """
        Initialize the classifier

        Args:
            model_path: Path to the TorchScript model file
            config_path: Path to the configuration file (optional; defaults
                are used when the file does not exist)

        Raises:
            RuntimeError: If the TorchScript model cannot be loaded.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load configuration if available
        self.config = None
        if os.path.exists(config_path):
            import json

            with open(config_path, "r", encoding="utf-8") as f:
                self.config = json.load(f)

        # Load model
        try:
            self.model = torch.jit.load(model_path, map_location=self.device)
            self.model.eval()
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Failed to load model: {str(e)}") from e

        # Model configuration (a non-dict or missing config means defaults)
        settings = self.config if isinstance(self.config, dict) else {}
        self.class_names = settings.get("class_names", list(self.DEFAULT_CLASS_NAMES))
        self.img_size = settings.get("input_size", self.DEFAULT_IMG_SIZE)
        print(f"Classifier initialized with classes: {self.class_names}")

    @staticmethod
    def _to_grayscale(image):
        """Turn a path, PIL-style image or numpy array into a 2-D grayscale array."""
        rgb_order = False  # cv2 arrays are BGR; PIL-style images are RGB

        if isinstance(image, (str, os.PathLike)):
            # cv2.imread returns None (no exception) for unreadable files.
            image = cv2.imread(os.fspath(image))
        elif (
            not isinstance(image, np.ndarray)
            and hasattr(image, "mode")
            and hasattr(image, "convert")
        ):
            # Duck-typed PIL image: avoids importing PIL in this module.
            mode = image.mode
            if mode not in TBClassifier._PIL_PASSTHROUGH_MODES and not mode.startswith("I;16"):
                image = image.convert("RGB")
            image = np.array(image)
            rgb_order = True

        if image is None:
            raise ValueError("Invalid image input")

        image = np.asarray(image)
        if image.size == 0:
            raise ValueError("Invalid image input: empty image")
        # CLAHE only accepts 8-bit or 16-bit unsigned single-channel data,
        # so reject anything else up front with a readable message.
        if image.dtype not in (np.uint8, np.uint16):
            raise ValueError(
                f"Unsupported image dtype {image.dtype}; expected uint8 or uint16"
            )

        if image.ndim == 2:
            return image
        if image.ndim == 3:
            channels = image.shape[2]
            if channels == 1:
                return image[:, :, 0]
            if channels == 3:
                code = cv2.COLOR_RGB2GRAY if rgb_order else cv2.COLOR_BGR2GRAY
                return cv2.cvtColor(image, code)
            if channels == 4:
                code = cv2.COLOR_RGBA2GRAY if rgb_order else cv2.COLOR_BGRA2GRAY
                return cv2.cvtColor(image, code)
        raise ValueError(f"Unsupported image shape {image.shape}")

    def preprocess(self, image):
        """
        Preprocess input image for model inference

        Args:
            image: File path (str or os.PathLike), numpy array (grayscale,
                BGR or BGRA, uint8/uint16) or PIL-style image

        Returns:
            Float32 tensor of shape (1, 3, img_size, img_size) on self.device

        Raises:
            RuntimeError: If the image cannot be read or converted.
        """
        try:
            gray = self._to_grayscale(image)

            # Apply CLAHE for contrast enhancement
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            gray = clahe.apply(gray)

            # Apply Gaussian blur for noise reduction
            gray = cv2.GaussianBlur(gray, (5, 5), 0)

            # Resize to model input size
            gray = cv2.resize(
                gray, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR
            )

            # Replicate the single channel into three channels
            rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)

            # Convert to tensor format (C, H, W)
            chw = np.moveaxis(rgb, -1, 0).astype(np.float32)

            # Per-image standardisation; epsilon guards constant images
            chw = (chw - chw.mean()) / (chw.std() + 1e-8)

            # Add batch dimension and move to device
            return torch.tensor(chw).unsqueeze(0).to(self.device)
        except Exception as e:
            raise RuntimeError(f"Preprocessing failed: {str(e)}") from e

    def predict(self, image, return_probs=False):
        """
        Predict tuberculosis from chest X-ray image

        Args:
            image: Input image (file path, numpy array, or PIL image)
            return_probs: Whether to return raw probabilities

        Returns:
            Dictionary with "prediction", "confidence" and "class_id"; with
            return_probs also "raw_probability" and "probabilities". On any
            failure the three main keys are None and "error" holds the
            message (this method does not raise).
        """
        try:
            # Preprocess image
            processed_image = self.preprocess(image)

            # Model inference
            with torch.no_grad():
                output = self.model(processed_image)

            # The decision rule below is binary, so exactly one value is
            # required regardless of how many singleton dims surround it.
            if output.numel() != 1:
                raise ValueError(
                    "Expected a single-value model output, got shape "
                    f"{tuple(output.shape)}"
                )

            # NOTE(review): assumes the model emits a raw logit - confirm
            # against the exported model before relying on calibration.
            prob = torch.sigmoid(output).item()

            # Determine class
            class_id = 1 if prob > 0.5 else 0
            confidence = prob if class_id == 1 else 1 - prob

            result = {
                "prediction": self.class_names[class_id],
                "confidence": float(confidence),
                "class_id": class_id,
            }
            if return_probs:
                result["raw_probability"] = float(prob)
                result["probabilities"] = {
                    self.class_names[0]: float(1 - prob),
                    self.class_names[1]: float(prob),
                }
            return result
        except Exception as e:
            # Deliberate best-effort: report the failure instead of raising,
            # with the same keys as a successful result.
            return {
                "error": str(e),
                "prediction": None,
                "confidence": None,
                "class_id": None,
            }

    def batch_predict(self, images, return_probs=False):
        """
        Predict on multiple images

        Args:
            images: Iterable of images (each as accepted by predict)
            return_probs: Whether to return raw probabilities

        Returns:
            List of prediction results, one per input and in input order
        """
        return [self.predict(img, return_probs=return_probs) for img in images]