Prasanta4 commited on
Commit
84d6927
·
verified ·
1 Parent(s): 4ce1f64

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -0
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms
4
+ from model import ModifiedMobileNetV2
5
+ import numpy as np
6
+ from PIL import Image
7
+ from fastapi import FastAPI, File, UploadFile, HTTPException
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from io import BytesIO
10
+ import logging
11
+ import os
12
+
13
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application instance; title/description show up in the auto-generated docs.
app = FastAPI(title="Gallbladder Classification API", description="API for gallbladder condition classification using ModifiedMobileNetV2")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for testing; specify domains for production (e.g., ["https://prasanta4.github.io"])
    # NOTE(review): per the CORS spec, a wildcard origin combined with
    # allow_credentials=True does not permit credentialed browser requests —
    # confirm whether credentials are actually needed here.
    allow_credentials=True,
    allow_methods=["GET", "POST"],  # Allow GET for /, /health; POST for /predict
    allow_headers=["*"],  # Allow all headers
)

# Class names provided by user.
# predict() indexes this list with the model's argmax, so the order is
# assumed to match the checkpoint's output layer — TODO confirm.
class_names = ['Gallstones', 'Cholecystitis', 'Gangrenous_Cholecystitis', 'Perforation', 'Polyps&Cholesterol_Crystal', 'WallThickening', 'Adenomyomatosis', 'Carcinoma', 'Intra-abdominal&Retroperitoneum', 'Normal']

# Device setup: prefer GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

# Model initialization: populated by load_model(); stays None if loading fails.
model = None
37
+
38
def load_model():
    """Load the ModifiedMobileNetV2 checkpoint into the module-level ``model``.

    Returns:
        bool: True on success, False on any failure (missing checkpoint file,
        incompatible state dict, ...). Failures are logged rather than
        propagated so the API can still start and report status via /health.
    """
    global model
    try:
        model_path = 'GB_stu_mob.pth'
        if not os.path.exists(model_path):
            logger.error(f"Model file {model_path} not found!")
            raise FileNotFoundError(f"Model file {model_path} not found!")

        model = ModifiedMobileNetV2(num_classes=len(class_names)).to(device)

        # map_location keeps CPU-only hosts working; weights_only=True loads
        # only tensor data instead of unpickling arbitrary objects, which is
        # the recommended safe way to load a plain state_dict checkpoint.
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint)
        model.eval()  # inference mode: disables dropout / freezes batch-norm stats
        logger.info("Model loaded successfully")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        return False
57
+
58
# Load model at startup; /predict refuses requests when this is False.
model_loaded = load_model()

# Preprocessing: resize to the 224x224 input MobileNetV2-style models expect,
# convert to a CHW float tensor, and normalize with the standard ImageNet
# mean/std (matching torchvision pretrained-model conventions).
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
67
+
68
# Inference function
def predict(image):
    """Classify a single image with the loaded model.

    Accepts either a PIL image (preprocessed here) or an already-batched
    tensor. Returns a ``(class_name, confidence)`` pair, where confidence is
    the softmax probability of the winning class. Raises HTTPException(500)
    when the model is unavailable or inference fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        with torch.no_grad():
            # PIL input -> normalized tensor with a batch dimension of 1.
            batch = image if torch.is_tensor(image) else preprocess(image).unsqueeze(0)
            batch = batch.to(device)
            probabilities = torch.softmax(model(batch), dim=1)
            top_idx = torch.argmax(probabilities, dim=1).item()
            return class_names[top_idx], probabilities[0, top_idx].item()
    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
86
+
87
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded gallbladder image.

    Returns JSON with the original filename, the predicted class name, and the
    softmax confidence rounded to 4 decimals. Responds 400 for non-image,
    empty, or corrupt uploads; 500 for model/inference failures.
    """
    if not model_loaded:
        raise HTTPException(status_code=500, detail="Model not properly loaded")

    try:
        # Validate file type. content_type is None when the client omits the
        # header; treat that as "not an image" (400) instead of letting
        # .startswith raise AttributeError and surface as a 500.
        if not file.content_type or not file.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read image file
        contents = await file.read()
        if len(contents) == 0:
            raise HTTPException(status_code=400, detail="Empty file")

        # Decode the bytes; convert('RGB') normalizes grayscale/RGBA inputs
        # to the 3-channel layout the preprocessing pipeline expects.
        try:
            image = Image.open(BytesIO(contents)).convert('RGB')
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image file: {str(e)}")

        # Run prediction
        class_name, confidence_score = predict(image)

        return {
            "filename": file.filename,
            "predicted_class": class_name,
            "confidence_score": round(confidence_score, 4)
        }
    except HTTPException:
        # Re-raise deliberate HTTP errors (400s and predict()'s 500) unchanged.
        raise
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
120
+
121
@app.get("/")
async def root():
    """Landing endpoint: returns a static welcome payload."""
    greeting = "Welcome to the Gallbladder Classification API"
    return {"message": greeting}
126
+
127
@app.get("/health")
async def health_check():
    """Health probe: reports model-load status and the active torch device."""
    status = "healthy" if model_loaded else "unhealthy"
    return {
        "status": status,
        "model_loaded": model_loaded,
        "device": str(device),
    }
134
+
135
# Script entry point: serve the app with uvicorn on all interfaces.
# NOTE(review): port 7860 is the default expected by Hugging Face Spaces —
# confirm the deployment target before changing it.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)