import io

import numpy as np
from PIL import Image
from fastapi import HTTPException

# Class names indexed by model output position (presumably the diabetic
# retinopathy grades the models were trained on -- verify the order against
# the training label encoding).
LABELS = [
    "No DR",
    "Mild",
    "Moderate",
    "Severe",
    "Proliferative DR",
]


def preprocess_and_split(image_bytes: bytes):
    """Decode an uploaded image and build the test-time-augmentation batch.

    The image is decoded, converted to RGB and resized to 224x224. The batch
    holds two views: the original and its left-right mirror.

    Args:
        image_bytes: Raw encoded image bytes (any format Pillow can open).

    Returns:
        A tuple ``(imgs1, imgs2)`` of float32 arrays of shape
        ``(2, 224, 224, 3)`` with raw pixel values in 0-255 (no rescaling;
        the original author notes this matches training). ``imgs1`` is meant
        for ``model_1`` and ``imgs2`` for ``model_2``. Both are the *same*
        array object, so neither should be mutated in place.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        # ``Image.open`` is lazy; ``convert`` forces the full decode, so
        # truncated/corrupt payloads are caught here as well. ``convert``
        # returns a new image, so closing the source afterwards is safe.
        with Image.open(io.BytesIO(image_bytes)) as src:
            image = src.convert("RGB")
    except Exception as err:
        raise HTTPException(status_code=400, detail="Invalid image file") from err

    image = image.resize((224, 224))
    views = (
        image,  # original
        image.transpose(Image.FLIP_LEFT_RIGHT),  # horizontal mirror
    )

    # Cast uint8 pixels straight to float32: values stay in 0-255.
    imgs = np.stack([np.array(view, dtype=np.float32) for view in views])
    return imgs, imgs


def predict_split(
    model_1,
    model_2,
    imgs1,
    imgs2,
    debug=False,
    *,
    weights=(0.5, 0.5),
    uncertainty_threshold=0.35,
):
    """Ensemble two models over their augmented batches and pick a label.

    Each model's ``predict`` output is averaged over the batch axis (the
    augmented views), and the two averages are combined as a weighted sum.

    Args:
        model_1: Object exposing ``predict(x, verbose=0)`` (Keras-style --
            assumed to return an array of shape ``(batch, len(LABELS))``).
        model_2: Same contract as ``model_1``.
        imgs1: Input batch passed to ``model_1.predict``.
        imgs2: Input batch passed to ``model_2.predict``.
        debug: If true, print the combined prediction vector and the top-2
            labels (printed for every prediction, including uncertain ones).
        weights: ``(w1, w2)`` multipliers applied to model_1's and model_2's
            averaged predictions. Defaults to an equal 0.5/0.5 blend.
        uncertainty_threshold: If the highest combined score is below this
            value, ``"Uncertain"`` is returned instead of a class label.

    Returns:
        ``(label, confidence)`` where ``label`` is an entry of ``LABELS`` or
        ``"Uncertain"``, and ``confidence`` is the highest combined score as
        a Python float.

    Raises:
        HTTPException: 500 wrapping any error raised during inference or
            post-processing.
    """
    w1, w2 = weights
    try:
        preds1 = model_1.predict(imgs1, verbose=0)
        preds2 = model_2.predict(imgs2, verbose=0)

        # Average each model's outputs over the augmented views (axis 0).
        avg1 = np.mean(preds1, axis=0)
        avg2 = np.mean(preds2, axis=0)
        final_pred = w1 * avg1 + w2 * avg2

        confidence = float(np.max(final_pred))
        predicted_index = int(np.argmax(final_pred))

        # Emit diagnostics before the uncertainty cut-off so low-confidence
        # cases (the ones most worth inspecting) are visible too.
        if debug:
            top2 = np.argsort(final_pred)[-2:][::-1]
            print("Prediction vector:", final_pred)
            print("Top-1:", LABELS[top2[0]])
            print("Top-2:", LABELS[top2[1]])

        if confidence < uncertainty_threshold:
            return "Uncertain", confidence

        return LABELS[predicted_index], confidence

    except Exception as e:
        # NOTE(review): the exception text is sent to the client verbatim;
        # consider logging it server-side and returning a generic message.
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}",
        ) from e