Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from torchvision import transforms | |
| from app import model_loader | |
# Positions of the two findings of interest in the model's output vector.
# NOTE(review): 2 = Cardiomegaly and 9 = Pneumothorax match the CheXpert
# 14-observation label ordering; confirm against the label list cnn_model
# was actually trained on.
CARDIOMEGALY_INDEX = 2
PNEUMOTHORAX_INDEX = 9

# Sigmoid probability that either finding must reach before the image is
# reported as abnormal; below it on both indices the result is "No Finding".
THRESHOLD = 0.96

# Input pipeline applied to every image before it reaches the CNN.
preprocess = transforms.Compose(
    [
        # Fixed 224x224 input; a (h, w) tuple does not preserve aspect ratio.
        transforms.Resize((224, 224)),
        # PIL image -> channels-first float tensor scaled to [0, 1].
        transforms.ToTensor(),
        # Per-channel standardisation (mean, std) using the widely used
        # ImageNet statistics.
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)
def predict_disease(image: Image.Image):
    """Classify an image as "Cardiomegaly", "Pneumothorax ..." or "No Finding".

    The image is converted to RGB, passed through ``preprocess`` and fed to
    ``model_loader.cnn_model``. The model output goes through a sigmoid, and
    only two entries are inspected: ``CARDIOMEGALY_INDEX`` and
    ``PNEUMOTHORAX_INDEX``.

    Decision rule:
      * both probabilities below ``THRESHOLD`` -> "No Finding";
      * otherwise the finding with the larger probability wins (an exact tie
        goes to the pneumothorax label).

    Args:
        image: Input PIL image in any mode.

    Returns:
        ``(label, conf, x, img)`` where
          * ``label`` is "No Finding", "Cardiomegaly" or
            "Pneumothorax (Pleural Effusion)";
          * ``conf`` is a score in [0, 1] for ``label``;
          * ``x`` is the preprocessed tensor with a leading batch dim of 1
            (presumably reused by callers, e.g. for CAM; confirm);
          * ``img`` is the RGB-converted PIL image.
    """
    img = image.convert("RGB")
    # Add the leading batch dimension the model expects.
    x = preprocess(img).unsqueeze(0)

    # NOTE(review): this assumes cnn_model is already in eval() mode and lives
    # on the same device as ``x`` (CPU here); confirm in model_loader.
    with torch.no_grad():
        logits = model_loader.cnn_model(x)

    # Independent per-class probabilities (sigmoid, not softmax); batch dim
    # squeezed away.
    probs = torch.sigmoid(logits).squeeze()
    c = probs[CARDIOMEGALY_INDEX].item()
    p = probs[PNEUMOTHORAX_INDEX].item()

    if c < THRESHOLD and p < THRESHOLD:
        label = "No Finding"
        # "No finding" needs BOTH findings to be absent, so confidence is
        # bounded by the less certain absence. max() would report 0.99 for
        # c=0.95, p=0.01 although the model rates cardiomegaly at 95%.
        # NOTE(review): if this value feeds compute_ai_risk, its scaler/kmeans
        # may have been fit on the old max()-based scores; re-check calibration.
        conf = min(1 - c, 1 - p)
    elif c > p:
        label = "Cardiomegaly"
        conf = c
    else:
        # NOTE(review): pneumothorax and pleural effusion are distinct
        # findings; the label text is kept unchanged because callers may
        # match on it.
        label = "Pneumothorax (Pleural Effusion)"
        conf = p
    return label, conf, x, img
def compute_ai_risk(label, confidence, cam):
    """Assign a coarse risk tier to a prediction using a fitted k-means model.

    A single 3-feature row ``[confidence, mean(cam), abnormal]`` is built,
    where ``abnormal`` is 0 for "No Finding" and 1 for any other label. The
    row is transformed with ``model_loader.scaler`` and assigned to a cluster
    by ``model_loader.kmeans_model``.

    Args:
        label: Predicted label string.
        confidence: Score associated with ``label``.
        cam: Class-activation map; only its global mean is used (assumed to
            be array-like, e.g. a NumPy array; confirm with the caller).

    Returns:
        "LOW 🟢" for cluster 0, "MEDIUM 🟡" for cluster 1, and "HIGH 🔴" for
        any other cluster id.
    """
    mean_activation = float(np.mean(cam))
    abnormal = int(label != "No Finding")

    features = model_loader.scaler.transform(
        np.array([[confidence, mean_activation, abnormal]])
    )
    cluster_id = model_loader.kmeans_model.predict(features)[0]

    # NOTE(review): k-means cluster ids have no inherent ordering; this
    # mapping is only meaningful if cluster 0/1 were verified to be the
    # low/medium-risk centroids when the model was fit.
    tiers = {0: "LOW 🟢", 1: "MEDIUM 🟡"}
    return tiers.get(cluster_id, "HIGH 🔴")