"""Gallbladder condition classification API.

Serves a ModifiedMobileNetV2 classifier over HTTP with FastAPI:
``POST /predict`` classifies an uploaded image, ``GET /`` returns a welcome
message and ``GET /health`` reports whether the model weights were loaded.
"""
import logging
import os
from io import BytesIO

import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from torchvision import transforms

from model import ModifiedMobileNetV2

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Gallbladder Classification API",
    description="API for gallbladder condition classification using ModifiedMobileNetV2",
)

# Add CORS middleware.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; restrict allow_origins to known domains before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for testing; specify domains for production (e.g., ["https://prasanta4.github.io"])
    allow_credentials=True,
    allow_methods=["GET", "POST"],  # GET for /, /health; POST for /predict
    allow_headers=["*"],  # Allow all headers
)

# Class names provided by user. The model's output index i is reported as
# class_names[i], so this order presumably has to match the label order used
# at training time -- do not reorder without retraining/confirming.
class_names = ['Gallstones', 'Cholecystitis', 'Gangrenous_Cholecystitis', 'Perforation',
               'Polyps&Cholesterol_Crystal', 'WallThickening', 'Adenomyomatosis',
               'Carcinoma', 'Intra-abdominal&Retroperitoneum', 'Normal']

# Device setup: use the GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

# The loaded network; stays None until load_model() succeeds.
model = None


def load_model():
    """Build the network, load its weights and publish it as the global ``model``.

    Returns:
        True if the weights were loaded, False otherwise. On failure the error
        is logged and the global ``model`` is left untouched, so the API still
        starts and ``/health`` reports it as unhealthy instead of crashing.
    """
    global model
    model_path = 'GB_stu_mob.pth'
    try:
        if not os.path.exists(model_path):
            logger.error(f"Model file {model_path} not found!")
            raise FileNotFoundError(f"Model file {model_path} not found!")

        # Build into a local first so a failure in load_state_dict() can never
        # leave a half-initialised network published in the global ``model``.
        net = ModifiedMobileNetV2(num_classes=len(class_names)).to(device)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints; since a plain state_dict is expected here, consider
        # weights_only=True if the installed torch version supports it.
        checkpoint = torch.load(model_path, map_location=device)
        net.load_state_dict(checkpoint)
        net.eval()  # inference mode (e.g. dropout/batch-norm behave deterministically)
    except Exception as e:
        logger.exception("Error loading model: %s", e)
        return False

    model = net
    logger.info("Model loaded successfully")
    return True


# Load model at startup
model_loaded = load_model()

# Preprocessing: resize to 224x224, convert to a float tensor and normalise
# per channel (the mean/std values are the commonly used ImageNet statistics).
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(image):
    """Classify a single image with the loaded model.

    Args:
        image: either a non-tensor image (the endpoint passes a PIL RGB image),
            which is run through ``preprocess`` and given a batch dimension, or
            a tensor that is sent to the model as-is (presumably already
            preprocessed and batched -- TODO confirm for any tensor callers).

    Returns:
        ``(class_name, confidence_score)`` where ``confidence_score`` is the
        softmax probability of the predicted class. Only a batch of one is
        supported (``.item()`` fails on larger batches).

    Raises:
        HTTPException: 500 if the model is not loaded or inference fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    try:
        with torch.no_grad():
            if not torch.is_tensor(image):
                image = preprocess(image).unsqueeze(0)
            image = image.to(device)
            output = model(image)
            probabilities = torch.softmax(output, dim=1)
            index = torch.argmax(probabilities, dim=1).item()
            confidence_score = probabilities[0, index].item()
            return class_names[index], confidence_score
    except Exception as e:
        logger.exception("Error during prediction: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") from e


@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image file.

    Returns:
        A JSON object with the upload's ``filename``, the ``predicted_class``
        and the ``confidence_score`` rounded to 4 decimal places.

    Raises:
        HTTPException: 400 when the upload has a missing or non-image content
            type, is empty, or cannot be decoded by PIL; 500 when the model is
            not loaded or any other error occurs.
    """
    if not model_loaded:
        raise HTTPException(status_code=500, detail="Model not properly loaded")

    try:
        # content_type may be None (no Content-Type on the part); treat that as
        # "not an image" rather than crashing with AttributeError -> 500.
        if not (file.content_type or "").startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")

        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="Empty file")

        try:
            image = Image.open(BytesIO(contents)).convert('RGB')
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image file: {str(e)}") from e

        # NOTE(review): predict() runs synchronously inside this async handler,
        # so it blocks the event loop while inference runs; consider
        # fastapi.concurrency.run_in_threadpool if concurrent requests matter.
        class_name, confidence_score = predict(image)

        return {
            "filename": file.filename,
            "predicted_class": class_name,
            "confidence_score": round(confidence_score, 4),
        }
    except HTTPException:
        # Already carries the intended status code and message.
        raise
    except Exception as e:
        logger.exception("Error processing image: %s", e)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e


@app.get("/")
async def root():
    """Return a static welcome message."""
    return {
        "message": "Welcome to the Gallbladder Classification API"
    }


@app.get("/health")
async def health_check():
    """Report whether the model loaded at startup and which device is in use."""
    return {
        "status": "healthy" if model_loaded else "unhealthy",
        "model_loaded": model_loaded,
        "device": str(device),
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)