File size: 4,686 Bytes
84d6927
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch
import torch.nn as nn
from torchvision import transforms
from model import ModifiedMobileNetV2
import numpy as np
from PIL import Image
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from io import BytesIO
import logging
import os

# Set up module-level logging; INFO so startup/model-load events are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Gallbladder Classification API", description="API for gallbladder condition classification using ModifiedMobileNetV2")

# Add CORS middleware so browser front-ends on other origins can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for testing; specify domains for production (e.g., ["https://prasanta4.github.io"])
    allow_credentials=True,
    allow_methods=["GET", "POST"],  # Allow GET for /, /health; POST for /predict
    allow_headers=["*"],  # Allow all headers
)

# Class names provided by user.
# NOTE(review): order must match the output-index order the checkpoint was
# trained with — cannot be verified from this file alone.
class_names = ['Gallstones', 'Cholecystitis', 'Gangrenous_Cholecystitis', 'Perforation', 'Polyps&Cholesterol_Crystal', 'WallThickening', 'Adenomyomatosis', 'Carcinoma', 'Intra-abdominal&Retroperitoneum', 'Normal']

# Device setup: prefer CUDA when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

# Model initialization: populated by load_model() below; None means "not loaded".
model = None

def load_model():
    """Load the ModifiedMobileNetV2 checkpoint into the module-level ``model``.

    Reads the state dict from ``GB_stu_mob.pth`` in the working directory,
    moves the model to the configured ``device`` and switches it to eval mode.

    Returns:
        bool: True on success, False if the file is missing or loading fails
        (the failure is logged with its traceback; the app stays up so
        /health can report the unhealthy state).
    """
    global model
    model_path = 'GB_stu_mob.pth'
    try:
        if not os.path.exists(model_path):
            # Raise instead of logging here: the except handler below logs
            # the failure exactly once (the original logged it twice).
            raise FileNotFoundError(f"Model file {model_path} not found!")

        model = ModifiedMobileNetV2(num_classes=len(class_names)).to(device)

        # map_location keeps CPU-only hosts working with GPU-saved checkpoints.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint)
        model.eval()  # inference mode: freezes dropout / batch-norm statistics
        logger.info("Model loaded successfully")
        return True
    except Exception:
        # logger.exception records the full traceback, which the original
        # f-string message discarded.
        logger.exception("Error loading model")
        return False

# Load model at startup; /predict and /health consult this flag.
model_loaded = load_model()

# Preprocessing pipeline matching standard ImageNet-trained input:
# 224x224 resize, tensor conversion, per-channel ImageNet mean/std normalize.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def predict(image):
    """Run inference on one image and return ``(class_name, confidence)``.

    Accepts either a PIL image (preprocessed here) or an already-batched
    tensor. Raises HTTPException(500) when the model is unavailable or
    inference fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        # Normalize the input into a batched tensor on the right device.
        batch = image if torch.is_tensor(image) else preprocess(image).unsqueeze(0)
        with torch.no_grad():
            logits = model(batch.to(device))
            probs = torch.softmax(logits, dim=1)
            top_idx = int(torch.argmax(probs, dim=1).item())
            top_score = probs[0, top_idx].item()
        return class_names[top_idx], top_score
    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")

@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded gallbladder image.

    Returns:
        dict: original filename, predicted class name, and softmax confidence
        rounded to 4 decimal places.

    Raises:
        HTTPException 400: non-image upload, empty body, or undecodable image.
        HTTPException 500: model not loaded or unexpected processing failure.
    """
    if not model_loaded:
        raise HTTPException(status_code=500, detail="Model not properly loaded")

    try:
        # content_type can be None for some clients; guard before startswith()
        # so a missing header yields a 400 instead of an AttributeError → 500.
        if not file.content_type or not file.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read the whole upload into memory.
        contents = await file.read()
        if len(contents) == 0:
            raise HTTPException(status_code=400, detail="Empty file")

        try:
            # Force RGB so grayscale/RGBA uploads match the 3-channel model input.
            image = Image.open(BytesIO(contents)).convert('RGB')
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image file: {str(e)}")

        # Run prediction
        class_name, confidence_score = predict(image)

        return {
            "filename": file.filename,
            "predicted_class": class_name,
            "confidence_score": round(confidence_score, 4)
        }
    except HTTPException:
        # Re-raise deliberate HTTP errors untouched (don't wrap them as 500s).
        raise
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@app.get("/")
async def root():
    """Landing endpoint: confirms the service is up and reachable."""
    return {"message": "Welcome to the Gallbladder Classification API"}

@app.get("/health")
async def health_check():
    """Report readiness: model-load status and the compute device in use."""
    status = "healthy" if model_loaded else "unhealthy"
    return {
        "status": status,
        "model_loaded": model_loaded,
        "device": str(device),
    }

if __name__ == "__main__":
    import uvicorn
    # Bind on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)