# Hugging Face Spaces page residue ("Spaces: Sleeping") — not Python code;
# kept only as a comment so the module remains importable.
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from model import ModifiedMobileNetV2 | |
| import numpy as np | |
| from PIL import Image | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from io import BytesIO | |
| import logging | |
| import os | |
# --- Logging ---------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Application -----------------------------------------------------------
app = FastAPI(
    title="Gallbladder Classification API",
    description="API for gallbladder condition classification using ModifiedMobileNetV2",
)

# CORS is wide-open for testing; restrict `allow_origins` in production
# (e.g. ["https://prasanta4.github.io"]).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST"],  # GET for / and /health; POST for /predict
    allow_headers=["*"],
)
# Class names provided by user.
# NOTE(review): index order must match the model's output logits — confirm
# against the training label encoding.
class_names = ['Gallstones', 'Cholecystitis', 'Gangrenous_Cholecystitis', 'Perforation', 'Polyps&Cholesterol_Crystal', 'WallThickening', 'Adenomyomatosis', 'Carcinoma', 'Intra-abdominal&Retroperitoneum', 'Normal']
# Device setup: prefer GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")
# Model initialization: populated by load_model() at startup; stays None on failure.
model = None
def load_model() -> bool:
    """Load the ModifiedMobileNetV2 checkpoint into the module-level ``model``.

    Returns:
        True when the checkpoint was loaded and the model switched to eval
        mode; False when the file is missing or loading fails (the API can
        then still boot and report "unhealthy" via /health).
    """
    global model
    try:
        model_path = 'GB_stu_mob.pth'
        if not os.path.exists(model_path):
            logger.error(f"Model file {model_path} not found!")
            raise FileNotFoundError(f"Model file {model_path} not found!")
        model = ModifiedMobileNetV2(num_classes=len(class_names)).to(device)
        # map_location keeps CPU-only hosts working. weights_only=True
        # (torch >= 1.13) prevents arbitrary-code execution from a tampered
        # pickle; the file holds a plain state_dict, so tensor-only loading
        # is sufficient.
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint)
        model.eval()
        logger.info("Model loaded successfully")
        return True
    except Exception as e:
        # Deliberately swallow and signal via the return value so module
        # import never crashes; /health exposes the failure.
        logger.error(f"Error loading model: {str(e)}")
        return False
# Load model at startup (module import time); endpoints consult model_loaded.
model_loaded = load_model()
# Preprocessing: resize to the 224x224 input MobileNetV2 expects and
# normalize with the standard ImageNet mean/std.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Inference helper shared by the /predict endpoint.
def predict(image):
    """Classify one image and return ``(class_name, confidence_score)``.

    Accepts either a PIL image (preprocessed here) or an already-batched
    tensor. Raises HTTPException 500 when the model is absent or inference
    fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    try:
        with torch.no_grad():
            # Raw PIL images get the resize/normalize pipeline plus a batch dim.
            if not torch.is_tensor(image):
                image = preprocess(image).unsqueeze(0)
            logits = model(image.to(device))
            probs = torch.softmax(logits, dim=1)
            # Single max() call yields both the top probability and its index.
            top_prob, top_idx = probs.max(dim=1)
            return class_names[top_idx.item()], top_prob.item()
    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
@app.post("/predict")  # route registration was missing; CORS config expects POST /predict
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and return its class and confidence.

    Raises:
        HTTPException 400: non-image upload, empty file, or undecodable image.
        HTTPException 500: model not loaded or unexpected processing error.
    """
    if not model_loaded:
        raise HTTPException(status_code=500, detail="Model not properly loaded")
    try:
        # content_type can be None when the client omits the header; guard it
        # so we return 400 instead of crashing with AttributeError (-> 500).
        if not file.content_type or not file.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")
        # Read image file
        contents = await file.read()
        if len(contents) == 0:
            raise HTTPException(status_code=400, detail="Empty file")
        try:
            # Force RGB so grayscale/RGBA uploads match the 3-channel input.
            image = Image.open(BytesIO(contents)).convert('RGB')
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image file: {str(e)}")
        # Run prediction
        class_name, confidence_score = predict(image)
        return {
            "filename": file.filename,
            "predicted_class": class_name,
            "confidence_score": round(confidence_score, 4)
        }
    except HTTPException:
        # Re-raise intentional HTTP errors untouched.
        raise
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.get("/")  # route registration was missing; CORS config expects GET /
async def root():
    """Landing endpoint confirming the API is reachable."""
    return {
        "message": "Welcome to the Gallbladder Classification API"
    }
@app.get("/health")  # route registration was missing; CORS config expects GET /health
async def health_check():
    """Health probe reporting model-load status and the inference device."""
    return {
        "status": "healthy" if model_loaded else "unhealthy",
        "model_loaded": model_loaded,
        "device": str(device)
    }
# Local entry point: serve on all interfaces, port 7860 (the Hugging Face
# Spaces convention for exposed web apps).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)