import io
import numpy as np
from PIL import Image
from fastapi import HTTPException


LABELS = [
    "No DR",
    "Mild",
    "Moderate",
    "Severe",
    "Proliferative DR"
]


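def preprocess_and_split(image_bytes: bytes):
    """
    Decode an uploaded image and build the test-time-augmentation batch.

    Input: raw image bytes.

    Output: a tuple (imgs1, imgs2)
        imgs1 → batch for model_1
        imgs2 → batch for model_2

    Both entries are the same float32 array of shape (2, 224, 224, 3):
    the resized RGB image and its left-right mirror, with pixel values
    kept in the 0-255 range.

    Raises HTTPException(400) if the bytes cannot be decoded as an image.
    """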

    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image file") from exc

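    # Resize to the fixed 224x224 input size.
    image = image.resize((224, 224))

    # Test-time augmentation: the original image plus its horizontal mirror.
    images = [
        image,  # original
        image.transpose(Image.FLIP_LEFT_RIGHT),  # left-right mirror
    ]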

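    def to_array(img_list):
        # Stack the PIL images into one float32 batch of shape (N, 224, 224, 3).
        return np.array([
            np.array(img, dtype=np.float32)  # matches training (0–255)
            for img in img_list
        ])

    imgs = to_array(images)

    # Both models receive the same batch.
    return imgs, imgs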
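

# Illustrative sketch (hypothetical helper, not called by the pipeline):
# the batch built by preprocess_and_split is equivalent to stacking an
# H x W x 3 pixel array with a copy whose width axis is reversed, because
# Image.FLIP_LEFT_RIGHT mirrors columns and np.array(img) is laid out as
# (height, width, channels). The name _mirror_batch_sketch is an
# assumption made for this example only.
import numpy as np


def _mirror_batch_sketch(pixels):
    """Return a float32 batch [pixels, left-right mirrored pixels]."""
    arr = np.asarray(pixels, dtype=np.float32)
    # arr[:, ::-1] reverses axis 1 (width), i.e. a left-right mirror.
    return np.stack([arr, arr[:, ::-1]])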


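def predict_split(model_1, model_2, imgs1, imgs2, debug=False):
    """
    Run both models on their batches and combine the results.

    Each model's predictions are averaged over the augmented views, and
    the two per-model averages are blended with equal weights.

    Returns:
        label (str): one of LABELS, or "Uncertain" when the top averaged
            score is below 0.35
        confidence (float): the top averaged score

    Raises HTTPException(500) if inference fails.
    """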

    try:
        preds1 = model_1.predict(imgs1, verbose=0)
        preds2 = model_2.predict(imgs2, verbose=0)

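        # Average each model's predictions over the augmented views (axis 0).
        avg1 = np.mean(preds1, axis=0)
        avg2 = np.mean(preds2, axis=0)

        # Equal-weight ensemble of the two models.
        final_pred = 0.50 * avg1 + 0.50 * avg2

        confidence = float(np.max(final_pred))
        predicted_index = int(np.argmax(final_pred))

        # Print diagnostics before the threshold check so that
        # low-confidence ("Uncertain") cases are also visible.
        if debug:
            top2 = np.argsort(final_pred)[-2:][::-1]
            print("Prediction vector:", final_pred)
            print("Top-1:", LABELS[top2[0]])
            print("Top-2:", LABELS[top2[1]])

        # Reject predictions whose top score is too low.
        if confidence < 0.35:
            return "Uncertain", confidence

        return LABELS[predicted_index], confidence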

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {e}"
        ) from e
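

# Illustrative sketch (hypothetical helper, not called by the pipeline):
# the decision rule of predict_split as a pure NumPy function, so it can be
# checked without loading any model. Each input has shape
# (n_views, n_classes). Views are averaged per model, the two averages are
# blended 50/50, and a top score below the threshold (0.35 above) yields
# "Uncertain". The name _ensemble_decision_sketch and its parameters are
# assumptions made for this example only.
import numpy as np


def _ensemble_decision_sketch(preds1, preds2, labels, threshold=0.35):
    """Return (label, confidence) from two models' per-view scores."""
    avg1 = np.mean(preds1, axis=0)  # mean over augmented views, model 1
    avg2 = np.mean(preds2, axis=0)  # mean over augmented views, model 2
    final_pred = 0.5 * avg1 + 0.5 * avg2
    confidence = float(np.max(final_pred))
    if confidence < threshold:
        return "Uncertain", confidence
    # np.argmax returns the first index on ties.
    return labels[int(np.argmax(final_pred))], confidence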