"""Medical image classification API.

Builds a DenseNet121-based binary classifier, loads its weights from the
Hugging Face Hub, and serves it through a FastAPI application together with
two static front-ends.
"""

import io

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from huggingface_hub import hf_hub_download
from PIL import Image
from tensorflow import keras

# Spatial size (width, height) every uploaded image is resized to.
INPUT_SIZE = (224, 224)

app = FastAPI(title="Medical Image Classification API", version="1.0.0")

# Add CORS middleware.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; restrict allow_origins to the known front-end origins before
# deploying to production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure this appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

weights_path = hf_hub_download(
    repo_id="Rahul-Samedavar/OralCancer_Predictor",
    filename="model.weights.h5",
)

# Build the network: a frozen DenseNet121 feature extractor followed by a
# small classification head, then load the weights fetched above.
base_model = tf.keras.applications.DenseNet121(
    input_shape=(*INPUT_SIZE, 3),
    include_top=False,
    weights='imagenet',
)
base_model.trainable = False

model = keras.models.Sequential([
    base_model,
    keras.layers.BatchNormalization(),
    keras.layers.Flatten(),
    keras.layers.Dense(
        1024,
        activation=tf.nn.relu,
        kernel_regularizer=keras.regularizers.l2(0.01),
    ),
    keras.layers.Dropout(.3),
    keras.layers.Dense(2, activation=tf.nn.softmax),
])
model.load_weights(weights_path)

# Labels in the order of the two softmax outputs.
class_names = ['Normal', 'OSCC']


def preprocess_image(image_bytes: bytes) -> np.ndarray:
    """Decode raw image bytes into a single-image batch.

    The image is converted to RGB, resized to ``INPUT_SIZE`` and scaled from
    0-255 to the 0-1 range.

    Args:
        image_bytes: Encoded image file contents (any format PIL can open).

    Returns:
        Array of shape ``(1, 224, 224, 3)`` with values in ``[0, 1]``.

    Raises:
        OSError: If PIL cannot identify or decode the data.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    image = image.resize(INPUT_SIZE)
    img_array = np.array(image) / 255.0
    return np.expand_dims(img_array, axis=0)


@app.post("/predict")
async def predict(image: UploadFile = File(...)):
    """Predict medical condition from uploaded image.

    Returns a JSON object with the predicted class name and the per-class
    confidence scores rounded to four decimals.

    Raises:
        HTTPException: 400 when the upload is not declared as an image or
            cannot be decoded; 500 for any other processing failure.
    """
    # The content type may be missing entirely, so guard against None.
    content_type = image.content_type or ""
    if not content_type.startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        img_bytes = await image.read()
        processed_image = preprocess_image(img_bytes)
    except (OSError, ValueError) as exc:
        # Decode failures are the client's fault, not a server error.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid or unreadable image: {exc}",
        ) from exc
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing image: {str(exc)}",
        ) from exc

    try:
        # verbose=0 keeps Keras from printing a progress bar per request.
        predictions = model.predict(processed_image, verbose=0)[0]
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing image: {str(exc)}",
        ) from exc

    predicted_class = class_names[int(np.argmax(predictions))]

    # Format confidence scores
    confidence = {
        label: round(float(score), 4)
        for label, score in zip(class_names, predictions)
    }

    return {
        "predicted_class": predicted_class,
        "confidence_scores": confidence,
    }


@app.get("/health")
async def health():
    """Health check endpoint"""
    return {"status": "healthy"}


# Routes are matched in registration order and a mount at "/" matches every
# path, so the catch-all static site must be registered after every other
# route and mount or it would shadow them.
app.mount("/cnn/", StaticFiles(directory="static/cnn", html=True), name="cnn")
app.mount("/", StaticFiles(directory="static/main", html=True), name="main")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)