import torch import torch.nn.functional as F import albumentations as A from albumentations.pytorch import ToTensorV2 import cv2 import numpy as np from models.classification_models.ResNet import * from models.segmentation_models.ResnetUnet import * class Pipeline: def __init__(self, img_size=256): self.transform = self._get_transforms(img_size) self.classification_model, self.segmentation_model = self._load_models() self.class_names = ['COVID', 'Non-COVID', 'Healthy'] def _get_transforms(self, img_size): return A.Compose([ A.LongestMaxSize(max_size=img_size), A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ToTensorV2(), ]) def _load_models(self): classification_model = resnet_model classification_model.load_state_dict(torch.load('weights/classification_models/resnet50.pt', map_location=torch.device('cpu'))) classification_model.eval() segmentation_model = ResNetUnet() checkpoint = torch.load('weights/segmentation_models/ResNetUnet_best.pt', map_location=torch.device('cpu')) segmentation_model.load_state_dict(checkpoint['model_state_dict']) segmentation_model.eval() return classification_model, segmentation_model def process_image(self, image, overlay_opacity=0.4): if image is None: return None, None, None, None image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) transformed = self.transform(image=image) input_tensor = transformed['image'].unsqueeze(0) with torch.inference_mode(): outputs = self.classification_model(input_tensor) probs = F.softmax(outputs, dim=1) pred_class = torch.argmax(probs, dim=1).item() confidence = probs[0][pred_class].item() * 100 prediction = self.class_names[pred_class] if prediction == 'COVID': with torch.inference_mode(): output = self.segmentation_model(input_tensor) output = torch.sigmoid(output) output = output.squeeze().cpu().numpy() binary_mask = (output > 0.5).astype(np.uint8) * 255 mask_resized = cv2.resize(binary_mask, (image.shape[1], image.shape[0])) overlay = np.zeros_like(image) overlay[mask_resized > 0] = [255, 0, 0] blended = cv2.addWeighted(image, 1, overlay, overlay_opacity, 0) analysis_text = ( f"COVID-19 Detection Results:\n" f"• Infection detected with {confidence:.1f}% confidence\n" f"• Red overlay indicates areas of potential COVID-19 infection\n" f"• Recommended: Seek immediate medical attention" ) return prediction, confidence, blended, analysis_text elif prediction == 'Non-COVID': analysis_text = ( f"Non-COVID Lung Condition Detected:\n" f"• Confidence: {confidence:.1f}%\n" f"• Other lung abnormalities as pneumonia or lungs enlargement should be considered for further treatment\n" f"• Recommended: Consult with healthcare provider for proper diagnosis" ) return prediction, confidence, None, analysis_text else: analysis_text = ( f"Healthy Lung Scan Results:\n" f"• Confidence: {confidence:.1f}%\n" f"• No significant abnormalities detected :)\n" f"• Regular check-ups and an apple a day is recommended" ) return prediction, confidence, None, analysis_text