"""Image classification and risk-scoring helpers built on ``app.model_loader``."""
import torch
import numpy as np
from PIL import Image
from torchvision import transforms

from app import model_loader

# Resize to 224x224, convert to a float tensor and normalize per channel.
# The mean/std values are the standard ImageNet statistics (presumably the
# CNN backbone was ImageNet-pretrained -- TODO confirm).
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Positions of the two reported findings in the CNN's output vector. These
# must match the class order the model was trained with (not visible here --
# verify against the training label list).
CARDIOMEGALY_INDEX = 2
PNEUMOTHORAX_INDEX = 9
# Default minimum sigmoid probability required to report a finding.
THRESHOLD = 0.96


def predict_disease(image: Image.Image, *, threshold: float = THRESHOLD):
    """Classify *image* as Cardiomegaly, Pneumothorax or "No Finding".

    The image is converted to RGB, preprocessed and passed through
    ``model_loader.cnn_model``. A sigmoid turns each output logit into an
    independent per-class probability.

    Args:
        image: Input PIL image in any mode. It is converted to RGB.
        threshold: Keyword-only. A finding is reported only when its
            probability is at least this value. Defaults to ``THRESHOLD``.

    Returns:
        A ``(label, confidence, x, img)`` tuple:
        ``label`` is the predicted class name;
        ``confidence`` is a float in [0, 1];
        ``x`` is the preprocessed input tensor of shape ``(1, 3, 224, 224)``;
        ``img`` is the RGB-converted PIL image.
    """
    img = image.convert("RGB")
    x = preprocess(img).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        logits = model_loader.cnn_model(x)
        # Sigmoid (not softmax): each class gets its own probability.
        probs = torch.sigmoid(logits).squeeze()

    c = probs[CARDIOMEGALY_INDEX].item()
    p = probs[PNEUMOTHORAX_INDEX].item()

    if c < threshold and p < threshold:
        label = "No Finding"
        # NOTE(review): this equals 1 - min(c, p), so it ignores the more
        # likely finding. For example, c=0.95 and p=0.01 yields ~0.99 for
        # "No Finding". 1 - max(c, p) may be what was intended. However,
        # compute_ai_risk feeds this value into a fitted scaler/KMeans,
        # presumably fitted on the current formula, so confirm before
        # changing it.
        conf = max(1 - c, 1 - p)
    elif c > p:
        label = "Cardiomegaly"
        conf = c
    else:
        # Also reached on an exact tie (c == p) above the threshold.
        label = "Pneumothorax (Pleural Effusion)"
        conf = p

    return label, conf, x, img


def compute_ai_risk(label: str, confidence: float, cam) -> str:
    """Bucket a prediction into a LOW / MEDIUM / HIGH risk level.

    Builds the feature row ``[confidence, mean(cam), finding_flag]``, scales
    it with ``model_loader.scaler`` and assigns it to a cluster with
    ``model_loader.kmeans_model``.

    Args:
        label: Label from ``predict_disease``. Any value other than
            ``"No Finding"`` sets the finding flag to 1.
        confidence: Confidence from ``predict_disease``.
        cam: Array-like activation map (presumably a class-activation map for
            the prediction). Only its mean is used.

    Returns:
        ``"LOW 🟢"`` for cluster 0, ``"MEDIUM 🟡"`` for cluster 1, and
        ``"HIGH 🔴"`` for any other cluster id.

    Raises:
        ValueError: If the mean of *cam* is not finite, e.g. for an empty
            map or one containing NaN/inf. Such a value would otherwise be
            passed on to the scaler and clusterer.
    """
    cam_score = float(np.mean(cam))
    if not np.isfinite(cam_score):
        raise ValueError(
            f"CAM mean is not finite ({cam_score}); cannot compute risk"
        )

    flag = 0 if label == "No Finding" else 1
    vec = np.array([[confidence, cam_score, flag]])
    vec = model_loader.scaler.transform(vec)
    cluster = int(model_loader.kmeans_model.predict(vec)[0])

    # NOTE(review): KMeans cluster ids carry no inherent order. This mapping
    # assumes the model was fitted (or relabelled) so that 0/1/others rank by
    # increasing severity -- verify against the training code.
    if cluster == 0:
        return "LOW 🟢"
    if cluster == 1:
        return "MEDIUM 🟡"
    return "HIGH 🔴"