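# NOTE: the three lines below were captured from the hosting page together
# with the code (apparently author, commit message and revision). They are
# commented out so the module parses as Python.
# Rahul-Samedavar's picture
# s
# a11bc14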
import io

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from huggingface_hub import hf_hub_download
from PIL import Image, UnidentifiedImageError
from tensorflow import keras
app = FastAPI(
    title="Medical Image Classification API",
    version="1.0.0",
)

# Cross-origin access is fully open: any origin, any method, any header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True
# presumably makes Starlette's CORSMiddleware echo the caller's Origin, so any
# site could send credentialed requests. Narrow allow_origins before
# production use (verify against the Starlette CORS docs).
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Fetch the model weights file from the Hugging Face Hub (hf_hub_download
# reuses its local cache when the file is already present); the result is the
# local filesystem path of the file.
weights_path = hf_hub_download(
    filename="model.weights.h5",
    repo_id="Rahul-Samedavar/OralCancer_Predictor",
)
# Rebuild the architecture the weights file was saved from, then load it:
# a frozen DenseNet121 feature extractor (no classification top) followed by
# BatchNorm -> Flatten -> Dense(1024, relu, L2) -> Dropout -> Dense(2, softmax).
# NOTE(review): load_weights() below presumably restores every layer,
# including the DenseNet backbone. If so, weights='imagenet' only adds an
# extra download at startup and could become weights=None. Confirm before
# changing.
base_model = tf.keras.applications.DenseNet121(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3),
)
base_model.trainable = False

model = keras.models.Sequential([
    base_model,
    keras.layers.BatchNormalization(),
    keras.layers.Flatten(),
    keras.layers.Dense(
        1024,
        activation=tf.nn.relu,
        kernel_regularizer=keras.regularizers.l2(0.01),
    ),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(2, activation=tf.nn.softmax),
])
model.load_weights(weights_path)
# Output labels: position i names output unit i of the model's final
# Dense(2) softmax layer, so the order must not change.
class_names = ["Normal", "OSCC"]
def preprocess_image(image_bytes: bytes) -> np.ndarray:
    """Decode raw image bytes into a model-ready batch containing one image.

    The image is converted to RGB, resized to 224x224 (the input size the
    DenseNet121 backbone is built with) and scaled from [0, 255] to [0, 1].

    Args:
        image_bytes: Encoded image file contents, in any format Pillow can open.

    Returns:
        float32 array of shape (1, 224, 224, 3).

    Raises:
        PIL.UnidentifiedImageError: if Pillow cannot recognise the data
            (other Pillow decoding errors may also propagate).
    """
    # NOTE(review): only /255 scaling is applied, not
    # keras.applications.densenet.preprocess_input. Presumably this matches
    # how the weights were trained; confirm before changing either side.
    with Image.open(io.BytesIO(image_bytes)) as source:
        # convert() always returns a new, fully loaded image, so it stays
        # valid after the source image is closed.
        rgb = source.convert("RGB")
    resized = rgb.resize((224, 224))
    # Build float32 directly (the dtype the model consumes). Dividing in
    # float32 yields the same correctly rounded values as the old float64
    # path followed by a cast, at half the memory.
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    return np.expand_dims(pixels, axis=0)
@app.post("/predict")
async def predict(image: UploadFile = File(...)):
    """Classify an uploaded image with the loaded model.

    Returns:
        ``{"predicted_class": <label>, "confidence_scores": {<label>: <score>}}``
        where labels come from ``class_names``, ``predicted_class`` is the
        label with the highest score, and each score is the model output for
        that label rounded to 4 decimals.

    Raises:
        HTTPException: 400 if the upload is not declared as ``image/*`` or
            cannot be decoded as an image; 500 for any other failure.
    """
    # UploadFile.content_type is None when the client sends no Content-Type
    # for the part. Treat that as "not an image" instead of crashing with an
    # AttributeError outside the try block (an unhandled 500).
    if not (image.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    try:
        img_bytes = await image.read()
        processed_image = preprocess_image(img_bytes)
        # verbose=0: predict() otherwise prints a progress bar to stdout on
        # every request.
        predictions = model.predict(processed_image, verbose=0)[0]
        predicted_class = class_names[int(np.argmax(predictions))]
        # round(float(x), 4) gives the same value as the previous
        # float(f"{x:.4f}") round-trip, without the string detour.
        confidence = {
            name: round(float(score), 4)
            for name, score in zip(class_names, predictions)
        }
        return {
            "predicted_class": predicted_class,
            "confidence_scores": confidence,
        }
    except (UnidentifiedImageError, Image.DecompressionBombError) as e:
        # Unrecognisable or oversized uploads are the client's problem, not a
        # server fault.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file could not be decoded as an image",
        ) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
@app.get("/health")
async def health():
    """Health check endpoint; always reports the service as healthy."""
    return {"status": "healthy"}


# Static front-ends. Starlette tries routes in registration order, and a mount
# at "/" matches every path, so it must be registered LAST. Mounted first (as
# it was before), it swallowed /cnn/... and /health, making both unreachable.
# Keep any further routes above the "/" mount.
app.mount("/cnn/", StaticFiles(directory="static/cnn", html=True), name="cnn")
app.mount("/", StaticFiles(directory="static/main", html=True), name="main")
if __name__ == "__main__":
    # Runs only when this file is executed directly. Importing the module
    # (e.g. via an ASGI server pointed at `app`) skips this branch, so
    # uvicorn is imported here rather than at the top of the file.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")