Sriomdash's picture
Upload 11 files
33a3d61 verified
import io
import numpy as np
from PIL import Image
from fastapi import HTTPException
# Class names in model-output order: predict_split() looks a label up by the
# argmax position of the blended prediction vector, so the order here matters.
# NOTE(review): "DR" presumably stands for diabetic retinopathy -- confirm.
LABELS = [
    "No DR",
    "Mild",
    "Moderate",
    "Severe",
    "Proliferative DR",
]
def preprocess_and_split(image_bytes: bytes):
    """
    Decode raw image bytes and build the input batch used by both models.

    The image is decoded as RGB, resized to 224x224, and paired with a
    horizontally mirrored copy. Pixels are converted to float32 without any
    rescaling (the original author notes this matches training, 0-255).

    Input:  raw image bytes
    Output: (imgs1, imgs2)
        imgs1 -> for model_1
        imgs2 -> for model_2
        Both entries are the very same array object.

    Raises:
        HTTPException: status 400 when the bytes cannot be decoded as an image.
    """
    try:
        decoded = Image.open(io.BytesIO(image_bytes))
        rgb = decoded.convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file")

    resized = rgb.resize((224, 224))
    mirrored = resized.transpose(Image.FLIP_LEFT_RIGHT)

    # Batch order: the resized original first, then its left-right flip.
    batch = np.stack(
        [np.asarray(view, dtype=np.float32) for view in (resized, mirrored)]
    )
    return batch, batch
def predict_split(model_1, model_2, imgs1, imgs2, debug=False):
    """
    Run both models and blend their outputs into a single prediction.

    Each model's predictions are averaged over axis 0, and the two per-model
    averages are then combined with equal (0.50 / 0.50) weight.

    Args:
        model_1: object exposing ``predict(batch, verbose=0)``.
        model_2: object exposing ``predict(batch, verbose=0)``.
        imgs1: batch handed to ``model_1.predict``.
        imgs2: batch handed to ``model_2.predict``.
        debug: when true, print the blended vector and the top-2 labels.

    Returns:
        ``(label, confidence)``: ``label`` is an entry of ``LABELS``, or
        ``"Uncertain"`` when the highest blended score is below 0.35;
        ``confidence`` is that highest score as a float.

    Raises:
        HTTPException: status 500 when inference or post-processing fails,
            including when the blended scores contain NaN or infinity.
    """
    try:
        preds1 = model_1.predict(imgs1, verbose=0)
        preds2 = model_2.predict(imgs2, verbose=0)

        # Average per model, then blend the two models equally.
        avg1 = np.mean(preds1, axis=0)
        avg2 = np.mean(preds2, axis=0)
        final_pred = 0.50 * avg1 + 0.50 * avg2

        # NaN is never "below" the threshold and argmax picks the first NaN
        # position, so without this guard a broken model would be reported as
        # an arbitrary label carrying a NaN confidence.
        if not np.all(np.isfinite(final_pred)):
            raise ValueError("model produced non-finite scores")

        confidence = float(np.max(final_pred))
        predicted_index = int(np.argmax(final_pred))

        # Emit diagnostics before the threshold check so that low-confidence
        # ("Uncertain") results can be inspected too.
        if debug:
            top2 = np.argsort(final_pred)[-2:][::-1]
            print("Prediction vector:", final_pred)
            print("Top-1:", LABELS[top2[0]])
            print("Top-2:", LABELS[top2[1]])

        if confidence < 0.35:
            return "Uncertain", confidence

        return LABELS[predicted_index], confidence
    except Exception as e:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}"
        ) from e