File size: 4,234 Bytes
65cb21e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import tensorflow as tf
import numpy as np
from PIL import Image
import os
import uuid
import tempfile

# Application object for the classification service.
app = FastAPI(
    version="1.0.0",
    title="Medical Image Classification API",
    description="AI-powered medical image classification service",
)

# Fully permissive CORS: any origin, method and header is accepted.
# NOTE(review): browsers refuse a literal "*" Access-Control-Allow-Origin on
# credentialed requests, so allow_credentials=True together with wildcard
# origins is either ineffective or (depending on the Starlette version)
# reflects any caller's origin. Consider an explicit origin list — confirm
# the intended deployment.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Keras model instance; stays None until load_model() populates it.
model = None

# Output index -> short class code. Position i in this tuple corresponds to
# output unit i of the model's prediction vector.
class_names = dict(enumerate((
    "AKIEC",  # Actinic keratoses and intraepithelial carcinoma
    "BCC",    # Basal cell carcinoma
    "BKL",    # Benign keratosis-like lesions
    "DF",     # Dermatofibroma
    "MEL",    # Melanoma
    "NV",     # Melanocytic nevi
    "VASC",   # Vascular lesions
)))

# Short class code -> human-readable diagnosis name returned to clients.
full_names = {
    "AKIEC": "Actinic keratoses and intraepithelial carcinoma",
    "BCC": "Basal cell carcinoma",
    "BKL": "Benign keratosis-like lesions",
    "DF": "Dermatofibroma",
    "MEL": "Melanoma",
    "NV": "Melanocytic nevi",
    "VASC": "Vascular lesions",
}

# Scratch directory for uploads, created once at import time.
# NOTE(review): nothing in this file removes the directory itself.
UPLOAD_DIR = tempfile.mkdtemp()

def load_model(model_path="efficientnetv2s.h5"):
    """Load the Keras classifier from disk and store it in the module global.

    Args:
        model_path: Path to the saved Keras model file. The default,
            ``"efficientnetv2s.h5"``, is a relative path and is therefore
            resolved against the current working directory, not this file.

    Returns:
        The loaded model. It is also assigned to the module-level ``model``.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    global model
    # Fail with a clear message instead of whatever error the loader raises.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    model = tf.keras.models.load_model(model_path)
    return model

def predict_image(image_path):
    """Classify one image file with the lazily loaded Keras model.

    The image is converted to RGB, resized to 224x224, and scaled from
    [0, 255] to [0, 1]. It is then passed to the model as a batch of one.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        A ``(full_name, confidence, all_predictions)`` tuple:

        - ``full_name``: the long name of the highest-scoring class.
        - ``confidence``: that class's model output, as a Python float.
        - ``all_predictions``: one ``{"label": <class code>,
          "confidence": <float>}`` dict per class, sorted by confidence,
          highest first.

    Raises:
        FileNotFoundError: Propagated from ``load_model`` if the model file
            is missing. Errors from PIL (unreadable image) or from the model
            are not caught here.
    """
    global model
    if model is None:
        model = load_model()
    # Open the image as a context manager so the file handle is released
    # deterministically; the caller deletes the file right after this returns.
    with Image.open(image_path) as img:
        resized = img.convert('RGB').resize((224, 224))
    # NOTE(review): this [0, 1] scaling must match the model's training
    # pipeline; Keras EfficientNetV2 application models can embed their own
    # input rescaling and expect [0, 255] — confirm against training code.
    img_array = np.expand_dims(
        np.array(resized, dtype=np.float32) / 255.0, axis=0
    )
    # verbose=0 keeps Keras from printing a progress bar on every request.
    probs = model.predict(img_array, verbose=0)[0]
    # Cast to a plain int so the dict lookups don't rely on NumPy int hashing.
    predicted_class = int(np.argmax(probs))
    confidence = float(probs[predicted_class])
    class_full_name = full_names[class_names[predicted_class]]
    all_predictions = sorted(
        (
            {"label": class_names[i], "confidence": float(prob)}
            for i, prob in enumerate(probs)
        ),
        key=lambda entry: entry["confidence"],
        reverse=True,
    )
    return class_full_name, confidence, all_predictions

@app.get("/health")
async def health_check():
    """Return a static status payload.

    The payload is constant. It does not check whether the model is loaded.
    """
    payload = {"status": "healthy"}
    payload["service"] = "medical-image-classifier"
    return payload

@app.post("/api/classify")
async def classify_image_api(file: UploadFile = File(...)):
    """Classify an uploaded image and return the ranked predictions.

    The upload is written to a randomly named file in ``UPLOAD_DIR`` and
    passed to ``predict_image``. The file is removed afterwards whether or
    not classification succeeded.

    Args:
        file: Multipart upload whose declared ``content_type`` must start
            with ``image/``.

    Returns:
        ``JSONResponse`` (status 200) with ``success`` and a ``prediction``
        object holding ``top_prediction`` and ``all_predictions``. Every
        entry has ``label``, ``confidence`` (float) and
        ``confidence_percent`` (string with two decimals and a ``%`` sign).

    Raises:
        HTTPException: 400 if the upload is not declared as an image; 500 if
            saving or classifying the file fails.
    """
    # Validate outside the try block so this 400 is not swallowed by the
    # broad ``except Exception`` below and turned into a 500.
    if not file.content_type or not file.content_type.startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")
    # The client-supplied filename is untrusted (it may contain path
    # separators), so only a random name is used on disk.
    file_path = os.path.join(UPLOAD_DIR, uuid.uuid4().hex)
    try:
        content = await file.read()
        with open(file_path, "wb") as buffer:
            buffer.write(content)
        # NOTE(review): predict_image is synchronous, CPU-bound work called
        # directly from this coroutine, so it blocks the event loop while
        # inference runs.
        label, confidence, all_predictions = predict_image(file_path)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Classification failed: {str(e)}"
        ) from e
    finally:
        # Clean up the temp file on both the success and failure paths.
        if os.path.exists(file_path):
            os.remove(file_path)
    formatted_predictions = [
        {
            "label": pred["label"],
            "confidence": float(pred["confidence"]),
            "confidence_percent": f"{pred['confidence'] * 100:.2f}%",
        }
        for pred in all_predictions
    ]
    return JSONResponse(
        status_code=200,
        content={
            "success": True,
            "prediction": {
                "top_prediction": {
                    "label": label,
                    "confidence": float(confidence),
                    "confidence_percent": f"{confidence * 100:.2f}%",
                },
                "all_predictions": formatted_predictions,
            },
        },
    )

if __name__ == "__main__":
    # uvicorn is imported here so it is only needed when this module is
    # executed directly as a script.
    import uvicorn

    # NOTE(review): host 0.0.0.0 listens on all network interfaces.
    uvicorn.run(app, port=8003, host="0.0.0.0")