File size: 1,380 Bytes
d4f7659
 
 
 
95cca76
d4f7659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f199b95
d4f7659
 
 
 
 
 
 
95cca76
 
d4f7659
95cca76
f199b95
95cca76
 
 
 
d4f7659
95cca76
 
d4f7659
95cca76
d4f7659
 
95cca76
d4f7659
 
95cca76
d4f7659
95cca76
d4f7659
95cca76
 
d4f7659
 
95cca76
d4f7659
95cca76
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from app import model_loader


# Per-channel mean/std used by Normalize below (the standard ImageNet values).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# PIL image -> resized (224x224), scaled and channel-normalized float tensor.
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Columns of the model's sigmoid output that predict_disease reads.
# NOTE(review): these positions match the CheXpert label order
# (2 = Cardiomegaly, 9 = Pneumothorax) — confirm against the checkpoint.
CARDIOMEGALY_INDEX: int = 2
PNEUMOTHORAX_INDEX: int = 9

# predict_disease reports "No Finding" unless at least one of the two
# probabilities reaches this value.
THRESHOLD: float = 0.96


def predict_disease(image: Image.Image):
    """Classify an image as "No Finding", "Cardiomegaly" or "Pneumothorax".

    The image is converted to RGB, passed through ``preprocess`` and
    ``model_loader.cnn_model``, and the sigmoid probabilities at
    ``CARDIOMEGALY_INDEX`` and ``PNEUMOTHORAX_INDEX`` are compared against
    ``THRESHOLD``.

    Args:
        image: Input PIL image in any mode.

    Returns:
        A tuple ``(label, conf, x, img)``:
        ``label`` -- "No Finding", "Cardiomegaly" or
        "Pneumothorax (Pleural Effusion)";
        ``conf`` -- float confidence for ``label``;
        ``x`` -- the preprocessed input tensor with a leading batch dim of 1;
        ``img`` -- the RGB-converted image.
    """
    img = image.convert("RGB")
    x = preprocess(img).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        logits = model_loader.cnn_model(x)
        probs = torch.sigmoid(logits).squeeze()

    c = probs[CARDIOMEGALY_INDEX].item()
    p = probs[PNEUMOTHORAX_INDEX].item()

    if c < THRESHOLD and p < THRESHOLD:
        label = "No Finding"
        # Certainty that *neither* finding is present is limited by the more
        # likely one, i.e. 1 - max(c, p). The former max(1 - c, 1 - p) equals
        # 1 - min(c, p) and overstated it (c=0.95, p=0.01 gave 0.99).
        conf = min(1 - c, 1 - p)
    elif c > p:
        label = "Cardiomegaly"
        conf = c
    else:
        # NOTE(review): pneumothorax and pleural effusion are distinct findings;
        # the display string is kept unchanged because callers may match on it.
        label = "Pneumothorax (Pleural Effusion)"
        conf = p

    return label, conf, x, img


def compute_ai_risk(label, confidence, cam):
    """Map a prediction and its activation map to a coarse risk tier.

    Builds the single-row feature matrix ``[[confidence, mean(cam), abnormal]]``,
    where ``abnormal`` is 0 for "No Finding" and 1 otherwise. It scales that
    row with ``model_loader.scaler`` and assigns it to a cluster with
    ``model_loader.kmeans_model``.

    Args:
        label: Predicted label string; anything but "No Finding" is abnormal.
        confidence: Confidence value paired with ``label``.
        cam: Activation map; only its mean is used.

    Returns:
        "LOW 🟢" for cluster 0, "MEDIUM 🟡" for cluster 1, and "HIGH 🔴" for
        any other cluster id.
    """
    abnormal = int(label != "No Finding")
    features = np.array([[confidence, float(np.mean(cam)), abnormal]])

    scaled = model_loader.scaler.transform(features)
    cluster_id = model_loader.kmeans_model.predict(scaled)[0]

    # NOTE(review): KMeans cluster ids carry no inherent order; this mapping
    # assumes the fitted model's clusters 0/1/2 mean low/medium/high — confirm.
    tiers = {0: "LOW 🟢", 1: "MEDIUM 🟡"}
    return tiers.get(cluster_id, "HIGH 🔴")