# image_class / main.py
# Author: Tantawi - "Upload 3 files" (commit 65cb21e, verified)
import os
import tempfile
import uuid

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError
def _create_app():
    """Build the FastAPI application with its metadata and CORS policy."""
    application = FastAPI(
        title="Medical Image Classification API",
        description="AI-powered medical image classification service",
        version="1.0.0",
    )
    # Every origin, method and header is accepted, with credentials enabled.
    # NOTE(review): browsers do not accept a literal "*" origin together with
    # credentials, so CORSMiddleware may reflect the caller's Origin instead,
    # which would let any site make credentialed calls. Confirm this, and
    # restrict allow_origins before relying on cookies or auth headers.
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return application


app = _create_app()
# Cached Keras model; stays None until load_model() / predict_image() fills it.
model = None

# Short lesion code -> human-readable diagnosis name. The insertion order
# of this dict is significant: class_names below is derived from it.
full_names = {
    "AKIEC": "Actinic keratoses and intraepithelial carcinoma",
    "BCC": "Basal cell carcinoma",
    "BKL": "Benign keratosis-like lesions",
    "DF": "Dermatofibroma",
    "MEL": "Melanoma",
    "NV": "Melanocytic nevi",
    "VASC": "Vascular lesions",
}

# Model output index -> short lesion code (0 -> "AKIEC", ..., 6 -> "VASC").
# Deriving it from full_names keeps the two tables from drifting apart.
class_names = dict(enumerate(full_names))

# Scratch directory for uploaded images, created once when the module loads.
UPLOAD_DIR = tempfile.mkdtemp()
def load_model(model_path="efficientnetv2s.h5"):
    """Load the Keras classification model from disk and cache it globally.

    Args:
        model_path: Path to the saved Keras model file. Defaults to
            ``"efficientnetv2s.h5"``; a relative path is resolved against
            the current working directory.

    Returns:
        The loaded ``tf.keras`` model. It is also stored in the module-level
        ``model`` global.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    global model
    # Check first so the error names the missing file clearly, rather than
    # surfacing a less specific error from the loader.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    model = tf.keras.models.load_model(model_path)
    return model
def predict_image(image_path, target_size=(224, 224)):
    """Classify a single image file with the lazily loaded Keras model.

    Steps:
        1. Convert the image to RGB.
        2. Resize it to ``target_size``.
        3. Scale pixel values to ``[0, 1]`` as float32.
        4. Pass it to ``model.predict`` as a batch of one.

    Args:
        image_path: Path to an image file readable by Pillow.
        target_size: ``(width, height)`` that the image is resized to before
            inference. Defaults to ``(224, 224)`` and must match the input
            size the model expects.

    Returns:
        A ``(class_full_name, confidence, all_predictions)`` tuple:

        - ``class_full_name``: the human-readable name of the
          highest-scoring class.
        - ``confidence``: that class's score as a Python float.
        - ``all_predictions``: a list of
          ``{"label": <short code>, "confidence": <float>}`` dicts, sorted
          by descending confidence.

        The top result uses the full name, while ``all_predictions`` uses
        the short class codes.

    Raises:
        FileNotFoundError: If the model has to be loaded and its file is
            missing.
        PIL.UnidentifiedImageError: If ``image_path`` cannot be decoded as
            an image.
    """
    global model
    if model is None:
        model = load_model()

    # The context manager guarantees the file handle is closed. Pillow can
    # keep it open after load() for some formats, and callers delete the
    # file as soon as this function returns.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB').resize(target_size)
        pixels = np.asarray(rgb, dtype=np.float32)

    # Scale to [0, 1] and add a leading batch axis -> (1, height, width, 3).
    # NOTE(review): this assumes the model was trained on [0, 1] inputs.
    # Stock Keras EfficientNetV2 models rescale internally and expect
    # [0, 255]. Confirm against the training pipeline.
    batch = np.expand_dims(pixels / 255.0, axis=0)

    # verbose=0 stops Keras from printing a progress bar on every request.
    scores = model.predict(batch, verbose=0)[0]

    top_index = int(np.argmax(scores))
    confidence = float(scores[top_index])
    class_full_name = full_names[class_names[top_index]]

    # sorted() is stable, so tied scores keep the model's class order.
    all_predictions = sorted(
        (
            {"label": class_names[i], "confidence": float(score)}
            for i, score in enumerate(scores)
        ),
        key=lambda pred: pred["confidence"],
        reverse=True,
    )
    return class_full_name, confidence, all_predictions
@app.get("/health")
async def health_check():
    """Liveness probe: report that the service process is up.

    The handler does not touch the model, so it answers even before the
    model has been loaded.
    """
    payload = dict(status="healthy", service="medical-image-classifier")
    return payload
def _format_prediction(label, confidence):
    """Build the JSON-serializable dict for a single prediction entry.

    Args:
        label: Class label to report.
        confidence: Model score for ``label``.

    Returns:
        A dict with ``label``, ``confidence`` (as a float) and
        ``confidence_percent``. The percent string is ``confidence * 100``
        with two decimals, e.g. ``"87.50%"``.
    """
    confidence = float(confidence)
    return {
        "label": label,
        "confidence": confidence,
        "confidence_percent": f"{confidence * 100:.2f}%",
    }


@app.post("/api/classify")
async def classify_image_api(file: UploadFile = File(...)):
    """Classify an uploaded image and return ranked predictions as JSON.

    The upload is written to a uniquely named file in ``UPLOAD_DIR`` and
    passed to ``predict_image``. The file is always deleted afterwards.

    Args:
        file: Multipart upload. Its declared content type must start with
            ``image/``.

    Returns:
        A 200 ``JSONResponse`` with ``success`` and a ``prediction`` object.
        That object holds ``top_prediction`` and ``all_predictions``; each
        entry has ``label``, ``confidence`` and ``confidence_percent``.

    Raises:
        HTTPException: 400 if the content type is not an image, the upload is
            empty, or the bytes cannot be decoded as an image. 500 for any
            other failure during classification.
    """
    if not file.content_type or not file.content_type.startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")

    # Build the temp name from a random UUID only. The client-supplied
    # filename may contain path separators or be arbitrarily long, and
    # Pillow detects the format from the contents, not the extension.
    file_path = os.path.join(UPLOAD_DIR, uuid.uuid4().hex)
    try:
        content = await file.read()
        if not content:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")
        with open(file_path, "wb") as buffer:
            buffer.write(content)
        # NOTE(review): predict_image runs synchronous TensorFlow inference
        # inside this async handler, which blocks the event loop (including
        # /health) while it runs. Consider fastapi.concurrency.run_in_threadpool
        # once concurrent use of the shared model has been verified.
        label, confidence, all_predictions = predict_image(file_path)
    except HTTPException:
        # Client errors raised above keep their status code; without this,
        # the generic handler below would rewrap them as 500s.
        raise
    except UnidentifiedImageError as e:
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Classification failed: {str(e)}") from e
    finally:
        # Remove the temporary upload on every path. It may not exist if
        # reading failed or the upload was empty.
        try:
            os.remove(file_path)
        except FileNotFoundError:
            pass

    return JSONResponse(
        status_code=200,
        content={
            "success": True,
            "prediction": {
                "top_prediction": _format_prediction(label, confidence),
                "all_predictions": [
                    _format_prediction(pred["label"], pred["confidence"])
                    for pred in all_predictions
                ],
            },
        },
    )
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 8003.
    # uvicorn is imported here, so it is only needed when run directly.
    import uvicorn as _uvicorn

    _uvicorn.run(app, port=8003, host="0.0.0.0")