RichardLu commited on
Commit
9d50300
·
verified ·
1 Parent(s): 00d54d9

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +23 -0
  2. app.py +162 -0
  3. requirements.txt +7 -0
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system dependencies
6
+ RUN apt-get update && apt-get install -y \
7
+ libgl1-mesa-glx \
8
+ libglib2.0-0 \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ # Copy requirements first for caching
12
+ COPY requirements.txt .
13
+ RUN pip install --no-cache-dir -r requirements.txt
14
+
15
+ # Copy app files
16
+ COPY app.py .
17
+ COPY models/ models/
18
+
19
+ # Expose port 7860 (HF Spaces default)
20
+ EXPOSE 7860
21
+
22
+ # Run the API
23
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI for Pneumonia Detection - Hugging Face Spaces Deployment
3
+ """
4
+
5
+ import io
6
+ import time
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torchvision import models, transforms
12
+ from PIL import Image
13
+ from fastapi import FastAPI, UploadFile, File, HTTPException
14
+ from fastapi.middleware.cors import CORSMiddleware
15
+ from pydantic import BaseModel
16
+
17
+ # =============================================================================
18
+ # Configuration
19
+ # =============================================================================
20
+
21
+ IMAGE_SIZE = 224
22
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
23
+ IMAGENET_STD = [0.229, 0.224, 0.225]
24
+ CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
25
+ MODEL_PATH = Path("models/best_model.pt")
26
+
27
+ # =============================================================================
28
+ # Model Definition
29
+ # =============================================================================
30
+
31
+ class PneumoniaClassifier(nn.Module):
32
+ def __init__(self):
33
+ super().__init__()
34
+ self.backbone = models.efficientnet_b0(weights=None)
35
+ in_features = self.backbone.classifier[1].in_features
36
+ self.backbone.classifier = nn.Sequential(
37
+ nn.Dropout(p=0.3, inplace=True),
38
+ nn.Linear(in_features, 1)
39
+ )
40
+
41
+ def forward(self, x):
42
+ return self.backbone(x)
43
+
44
+ # =============================================================================
45
+ # Response Models
46
+ # =============================================================================
47
+
48
+ class HealthResponse(BaseModel):
49
+ status: str
50
+ model_loaded: bool
51
+
52
+ class PredictionResponse(BaseModel):
53
+ prediction: str
54
+ confidence: float
55
+ probability: float
56
+ processing_time_ms: float
57
+
58
+ # =============================================================================
59
+ # App Setup
60
+ # =============================================================================
61
+
62
+ app = FastAPI(
63
+ title="Pneumonia Detection API",
64
+ description="Deep learning API for detecting pneumonia from chest X-rays",
65
+ version="1.0.0"
66
+ )
67
+
68
+ app.add_middleware(
69
+ CORSMiddleware,
70
+ allow_origins=["*"],
71
+ allow_credentials=True,
72
+ allow_methods=["*"],
73
+ allow_headers=["*"],
74
+ )
75
+
76
+ # =============================================================================
77
+ # Model Loading
78
+ # =============================================================================
79
+
80
+ model = None
81
+ device = None
82
+
83
+ @app.on_event("startup")
84
+ async def load_model():
85
+ global model, device
86
+
87
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88
+ print(f"Using device: {device}")
89
+
90
+ if not MODEL_PATH.exists():
91
+ print(f"Warning: Model not found at {MODEL_PATH}")
92
+ return
93
+
94
+ model = PneumoniaClassifier()
95
+ checkpoint = torch.load(MODEL_PATH, map_location=device)
96
+ model.load_state_dict(checkpoint['model_state_dict'])
97
+ model.to(device)
98
+ model.eval()
99
+ print("Model loaded successfully")
100
+
101
+ # =============================================================================
102
+ # Helper Functions
103
+ # =============================================================================
104
+
105
+ def get_transforms():
106
+ return transforms.Compose([
107
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
108
+ transforms.ToTensor(),
109
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
110
+ ])
111
+
112
+ async def read_image(file: UploadFile) -> Image.Image:
113
+ contents = await file.read()
114
+ return Image.open(io.BytesIO(contents)).convert("RGB")
115
+
116
+ def predict(image: Image.Image):
117
+ transform = get_transforms()
118
+ img_tensor = transform(image).unsqueeze(0).to(device)
119
+
120
+ with torch.no_grad():
121
+ output = model(img_tensor)
122
+ prob = torch.sigmoid(output).item()
123
+
124
+ pred_class = CLASS_NAMES[1] if prob > 0.5 else CLASS_NAMES[0]
125
+ confidence = prob if prob > 0.5 else 1 - prob
126
+ return pred_class, confidence, prob
127
+
128
+ # =============================================================================
129
+ # Endpoints
130
+ # =============================================================================
131
+
132
+ @app.get("/")
133
+ async def root():
134
+ return {"message": "Pneumonia Detection API", "docs": "/docs"}
135
+
136
+ @app.get("/health", response_model=HealthResponse)
137
+ async def health():
138
+ return HealthResponse(
139
+ status="healthy" if model else "model_not_loaded",
140
+ model_loaded=model is not None
141
+ )
142
+
143
+ @app.post("/predict", response_model=PredictionResponse)
144
+ async def predict_endpoint(file: UploadFile = File(...)):
145
+ if model is None:
146
+ raise HTTPException(status_code=503, detail="Model not loaded")
147
+
148
+ if not file.content_type.startswith("image/"):
149
+ raise HTTPException(status_code=400, detail="File must be an image")
150
+
151
+ image = await read_image(file)
152
+
153
+ start_time = time.time()
154
+ pred_class, confidence, prob = predict(image)
155
+ processing_time = (time.time() - start_time) * 1000
156
+
157
+ return PredictionResponse(
158
+ prediction=pred_class,
159
+ confidence=confidence,
160
+ probability=prob,
161
+ processing_time_ms=round(processing_time, 2)
162
+ )
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ fastapi>=0.100.0
4
+ uvicorn>=0.23.0
5
+ python-multipart>=0.0.6
6
+ pillow>=10.0.0
7
+ numpy>=1.24.0