# test-docker / app.py
# Author: Prasanta4 — commit 84d6927 ("Create app.py", verified)
import logging
import os
from io import BytesIO

import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from torchvision import transforms

from model import ModifiedMobileNetV2
# Set up logging: the root logger emits INFO and above, and this module
# writes through its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
app = FastAPI(
    title="Gallbladder Classification API",
    description="API for gallbladder condition classification using ModifiedMobileNetV2",
)

# CORS: any origin may call the API (testing setup). For production, narrow
# allow_origins to the real front-end, e.g. ["https://prasanta4.github.io"].
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; confirm credentials are actually needed by the front-end.
# Only GET (/, /health) and POST (/predict) are exposed.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
# Output labels, indexed by the model's predicted class index; the model is
# also built with len(class_names) outputs. Presumably this order matches the
# label order used during training — verify before reordering.
class_names = [
    'Gallstones',
    'Cholecystitis',
    'Gangrenous_Cholecystitis',
    'Perforation',
    'Polyps&Cholesterol_Crystal',
    'WallThickening',
    'Adenomyomatosis',
    'Carcinoma',
    'Intra-abdominal&Retroperitoneum',
    'Normal',
]
# Device setup: use the GPU when PyTorch reports CUDA is available,
# otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info("Using device: %s", device)
# Model initialization: populated by load_model(); None until a load succeeds.
model = None
def load_model(model_path='GB_stu_mob.pth'):
    """Load the ModifiedMobileNetV2 checkpoint into the module-level ``model``.

    The network is built on ``device`` with one output per entry in
    ``class_names``. Its state dict is read from ``model_path``, and it is
    switched to eval mode. The global ``model`` is replaced only after every
    step has succeeded, so a failed load never publishes a half-initialised
    (randomly weighted) network.

    Args:
        model_path: Path to the saved state dict. Defaults to
            ``'GB_stu_mob.pth'`` (resolved against the working directory).

    Returns:
        True if the model was loaded. False otherwise; in that case the error
        is logged with its traceback and ``model`` is left unchanged.
    """
    global model
    try:
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file {model_path} not found!")
        net = ModifiedMobileNetV2(num_classes=len(class_names)).to(device)
        # map_location lets weights saved on a GPU load on a CPU-only host.
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=device)
        net.load_state_dict(checkpoint)
        net.eval()
    except Exception as e:
        # Deliberate best-effort: the service still starts and reports
        # itself unhealthy via model_loaded.
        logger.exception("Error loading model: %s", e)
        return False
    model = net
    logger.info("Model loaded successfully")
    return True
# Load model at startup
model_loaded = load_model()
# Preprocessing applied to PIL images before inference:
# resize to 224x224, convert to a float tensor, then normalise per channel.
# The mean/std values are the commonly used ImageNet statistics; presumably
# they match the training pipeline — verify if the model is retrained.
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Inference function
def predict(image):
    """Classify a single image with the loaded model.

    Args:
        image: A PIL image, or an already-preprocessed tensor. PIL images
            are run through ``preprocess`` and given a batch dimension. A
            3-D tensor is treated as one unbatched image (assumed (C, H, W)
            — TODO confirm with callers) and also gets a batch dimension.
            Higher-rank tensors are used as-is.

    Returns:
        ``(class_name, confidence)`` for the single item in the batch, where
        ``confidence`` is the softmax probability of the predicted class.

    Raises:
        HTTPException: status 500 if the model is not loaded or inference
            fails (including batches with more than one item).
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    try:
        with torch.no_grad():
            if not torch.is_tensor(image):
                image = preprocess(image).unsqueeze(0)
            elif image.dim() == 3:
                # The 2-D conv model expects a batch axis; add it for a lone image.
                image = image.unsqueeze(0)
            image = image.to(device)
            output = model(image)
            probabilities = torch.softmax(output, dim=1)
            # Highest probability and its index in one reduction.
            confidence, predicted_class = torch.max(probabilities, dim=1)
            return class_names[predicted_class.item()], confidence.item()
    except Exception as e:
        logger.exception("Error during prediction: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") from e
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image file.

    Returns:
        ``{"filename", "predicted_class", "confidence_score"}``, with the
        confidence rounded to 4 decimal places.

    Raises:
        HTTPException: 400 if the upload is not declared as an image, is
            empty, or cannot be decoded. 500 if the model did not load or
            any other error occurs.
    """
    if not model_loaded:
        raise HTTPException(status_code=500, detail="Model not properly loaded")
    try:
        # content_type may be None when the client omits it. Treat that as
        # "not an image" (400) rather than crashing on None.startswith (500).
        if not (file.content_type or "").startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")
        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="Empty file")
        try:
            image = Image.open(BytesIO(contents)).convert('RGB')
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image file: {str(e)}") from e
        # predict() runs synchronous model inference. Running it in the
        # threadpool keeps the event loop free for other requests (e.g. /health).
        class_name, confidence_score = await run_in_threadpool(predict, image)
        return {
            "filename": file.filename,
            "predicted_class": class_name,
            "confidence_score": round(confidence_score, 4),
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error processing image: %s", e)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
@app.get("/")
async def root():
    """Return a static welcome payload for the API root."""
    greeting = "Welcome to the Gallbladder Classification API"
    return {"message": greeting}
@app.get("/health")
async def health_check():
    """Report whether the startup model load succeeded and which device is in use."""
    status = "healthy" if model_loaded else "unhealthy"
    return {"status": status, "model_loaded": model_loaded, "device": str(device)}
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on every interface,
    # port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")