Merge branch 'dyff-824-refactor-codebase' into 'main'
DYFF-824: Finalize the example submission repository
See merge request ul-dsri/sandbox/sachin-sharma-in/ml-inference-service!3
- .dockerignore +38 -0
- .env.example +11 -0
- Dockerfile +26 -0
- README.md +253 -326
- app/api/controllers.py +16 -55
- app/api/routes/prediction.py +13 -10
- app/api/routes/resnet_service_manager.py +0 -19
- app/core/app.py +55 -9
- app/core/config.py +0 -29
- app/core/dependencies.py +11 -13
- app/core/lifespan.py +0 -43
- app/core/logging.py +16 -39
- app/services/base.py +30 -0
- app/services/inference.py +56 -129
- test_main.http +3 -2
.dockerignore
ADDED
@@ -0,0 +1,38 @@
+__pycache__
+*.pyc
+*.pyo
+*.pyd
+.Python
+*.so
+*.egg
+*.egg-info
+dist
+build
+
+.venv
+venv
+ENV
+env
+
+.git
+.gitignore
+.idea
+.vscode
+.claude
+
+*.md
+README.md
+Dockerfile
+.dockerignore
+
+test_*.http
+test_results
+scripts/test_datasets
+
+.pytest_cache
+.coverage
+htmlcov
+
+*.log
+.DS_Store
+.python-version
.env.example
ADDED
@@ -0,0 +1,11 @@
+# App Configuration
+APP_NAME="ML Inference Service"
+APP_VERSION="0.1.0"
+DEBUG=false
+
+# Server
+HOST="0.0.0.0"
+PORT=8000
+
+# Model
+MODEL_NAME="microsoft/resnet-18"
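Note: these values are read through pydantic-settings. A minimal sketch of how they surface at runtime; the `model_name` field is an illustration only, since the `Settings` class committed in `app/core/app.py` below declares just the app and server fields:

```python
# Sketch: reading .env.example values with pydantic-settings.
# Matching is case-insensitive, so APP_NAME maps to app_name.
# model_name is illustrative; the committed Settings class omits it.
from pydantic import Field
from pydantic_settings import BaseSettings


class ExampleSettings(BaseSettings):
    app_name: str = Field(default="ML Inference Service")
    app_version: str = Field(default="0.1.0")
    debug: bool = Field(default=False)
    host: str = Field(default="0.0.0.0")
    port: int = Field(default=8000)
    model_name: str = Field(default="microsoft/resnet-18")  # maps to MODEL_NAME

    class Config:
        env_file = ".env"  # copy .env.example to .env to activate overrides


settings = ExampleSettings()
print(settings.port)  # 8000 unless PORT overrides it
```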
Dockerfile
ADDED
@@ -0,0 +1,26 @@
+FROM python:3.12-slim as builder
+
+WORKDIR /build
+
+COPY requirements.txt .
+RUN pip install --no-cache-dir --user -r requirements.txt
+
+FROM python:3.12-slim
+
+WORKDIR /app
+
+RUN useradd -m -u 1000 appuser
+
+COPY --from=builder --chown=appuser:appuser /root/.local /home/appuser/.local
+COPY --chown=appuser:appuser app ./app
+COPY --chown=appuser:appuser models ./models
+COPY --chown=appuser:appuser main.py .
+
+USER appuser
+
+ENV PATH=/home/appuser/.local/bin:$PATH \
+    PYTHONUNBUFFERED=1
+
+EXPOSE 8000
+
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
README.md
CHANGED
@@ -1,419 +1,323 @@
-# ML Inference Service
-
-This repo ships with a working example using **ResNet-18** (downloaded from Hugging Face) under `models/resnet-18/` and exposes a simple **REST** endpoint.
-
-- FastAPI application with clean layering (routes → controller → service)
-- Hot-loaded model on startup (single instance reused per request)
-- Hugging Face–compatible local model folder (`config.json`, weights, preprocessor, etc.)
-- Example endpoint: `POST /predict/resnet` that accepts a base64 image and returns:
-  - `prediction` (class label)
-  - `confidence` (softmax probability)
-  - `predicted_label` (class index)
-  - `model` (model id)
-  - `mediaType` (echoed)
-
----
-
-## 🧭 Project Layout
-
-```
-ml-inference-service/
-├─ main.py
-├─ app/
-│  ├─ __init__.py
-│  ├─ core/
-│  │  ├─ app.py            # App factory & router wiring
-│  │  ├─ config.py         # Settings (app name/version/debug)
-│  │  ├─ dependencies.py   # DI for model services
-│  │  ├─ lifespan.py       # Startup: load model & register service
-│  │  └─ logging.py        # Logger setup
-│  ├─ api/
-│  │  ├─ models.py         # Pydantic request/response
-│  │  ├─ controllers.py    # HTTP → service orchestration
-│  │  └─ routes/
-│  │     ├─ prediction.py  # `POST /predict/resnet`
-│  │     └─ resnet_service_manager.py (legacy, unused)
-│  └─ services/
-│     └─ inference.py      # ResNetInferenceService (load/predict)
-├─ models/
-│  └─ resnet-18/           # Sample HF-style model folder
-├─ scripts/
-│  ├─ model_download.bash        # One-liner to snapshot HF weights locally
-│  ├─ generate_test_datasets.py  # Generate PyArrow datasets for testing
-│  ├─ test_datasets.py           # Test generated datasets against API
-│  └─ test_datasets/             # Generated PyArrow test datasets (100 files)
-├─ requirements.in / requirements.txt
-└─ test_main.http          # Example request you can run from IDEs
-```
-
----
-
-## 🚀 Quickstart
-
-### 1) Install dependencies (Python 3.9+)
 ```bash
 python -m venv .venv
-source .venv/bin/activate
 pip install -r requirements.txt
-```
 
-```bash
 bash scripts/model_download.bash
 ```
-This populates `models/resnet-18/` with Hugging Face artifacts (`config.json`, weights, `preprocessor_config.json`, etc.).
 
 ```bash
 ```
-Server listens on `http://127.0.0.1:8000`.
 
-- Use `test_main.http` from your IDE (VSCode/IntelliJ) **or** curl:
 
 ```bash
 }'
 ```
 
 ```json
 {
   "prediction": "tiger cat",
-  "confidence": 0.
   "predicted_label": 282,
   "model": "microsoft/resnet-18",
   "mediaType": "image/jpeg"
 }
 ```
 
-## 🧩 Bring Your Own Model (BYOM)
-
-There are **two** ways to integrate your own model.
-
-### Option A — *Drop-in replacement (zero code changes)*
-
-If your model is a **Hugging Face image classification** model that works with
-`AutoImageProcessor` and `ResNetForImageClassification` **or** a compatible
-`*ForImageClassification` class from `transformers`, you can simply place the
-model folder alongside `resnet-18` and point the service at it.
-
-1. Put your HF-style folder under `models/<your-model-name>/` containing at least:
-   - `config.json`
-   - weights (e.g., `pytorch_model.bin` or `model.safetensors`)
-   - `preprocessor_config.json` / `image_processor` files
-
-2. **Choose one** of these approaches:
-   - **Simplest**: Replace the contents of `models/resnet-18/` with your model files *but keep the folder name*. The existing `/predict/resnet` endpoint will now serve your model.
-   - **Preferred**: Change the model id used at startup:
-     - Open `app/core/lifespan.py` and modify the service initialization:
-       ```python
-       resnet_service = ResNetInferenceService(
-           model_name="your-org/your-model",  # used for local folder name
-           use_local_model=True               # loads from models/your-model/
-       )
-       ```
-     - Ensure your local folder is `models/your-model/`.
-
-> How folder naming works: when `use_local_model=True`, the service derives the
-> local directory as `models/<last-segment-of-model_name>`. For
-> `"microsoft/resnet-18"` that becomes `models/resnet-18`. For
-> `"your-org/awesome-vit-base"`, it becomes `models/awesome-vit-base`.
-
-That's it. No code changes elsewhere if your model is a standard image classifier.
-
----
-
-### Option B — *New task/model type (minimal code: new service + route)*
-
-If you are **not** serving a Hugging Face image classifier (e.g., object detection,
-segmentation, text models), implement a small service class and a route mirroring
-the `ResNetInferenceService` flow.
-
-1. **Create your service** (copy and adapt `ResNetInferenceService`):
-   - File: `app/services/<your_model>_service.py`
-   - Responsibilities you must implement:
-     - `__init__(model_name: str, use_local_model: bool)` → set `self.model_path`
-     - `load_model()` → load weights & preprocessor
-     - `predict(image: PIL.Image.Image) -> Dict[str, Any]` → run inference and return a dict with:
-       ```python
-       {
-           "prediction": "<your label or structured result>",
-           "confidence": <float 0..1>,
-           "predicted_label": <int or meaningful code>,
-           "model": "<model id>"
-       }
-       ```
-     *Feel free to extend the payload; just update the API schema accordingly.*
-
-2. **Wire the dependency**:
-   - Register your service at startup in `app/core/lifespan.py` similar to ResNet:
-     ```python
-     from app.core.dependencies import set_resnet_service  # or create your own set/get
-     from app.services.your_model_service import YourModelService
-
-     svc = YourModelService(model_name="your-org/your-model", use_local_model=True)
-     svc.load_model()
-     set_resnet_service(svc)  # or create set_your_model_service(...)
-     ```
-   - Optionally create **new getters/setters** in `app/core/dependencies.py` if you serve multiple models in parallel (one getter per model).
-
-3. **Add a route**:
-   - Create `app/api/routes/your_model.py` analogous to `prediction.py`:
-     ```python
-     from fastapi import APIRouter, Depends
-     from app.api.controllers import PredictionController
-     from app.api.models import ImageRequest, PredictionResponse
-     from app.core.dependencies import get_resnet_service  # or your getter
-     from app.services.your_model_service import YourModelService
-
-     router = APIRouter()
-
-     @router.post("/predict/your-model", response_model=PredictionResponse)
-     async def predict_image(request: ImageRequest, service: YourModelService = Depends(get_resnet_service)):
-         controller = PredictionController(service)  # reuse the controller
-         return await controller.predict(request)
-     ```
-   - Register the router in `app/core/app.py`:
-     ```python
-     from app.api.routes import your_model as your_model_routes
-     app.include_router(your_model_routes.router)
-     ```
-
-4. **Adjust schemas if needed**:
-   - The default `PredictionResponse` in `app/api/models.py` is for single-label classification. For other tasks, either extend it or define a new response model and use it in your route's `response_model=`.
-
-> **Tip**: Keep your controller thin and push all model-specific logic into your service class. The server glue (DI + routes) stays identical across models.
-
----
-
-## 🧪 Validating your setup
-
-- **Startup logs** should include: `Initializing ResNet service with local model: models/<folder>` and `Model and processor loaded successfully`.
-- Hitting your endpoint should return a **200** with a JSON body like the example above.
-- If you see `Local model directory not found`, check your `models/<name>/` path and filenames.
-
----
-
-## 🔌 Request & Response Shapes
-
-### Request
-```json
-{
-  "image": {
-    "mediaType": "image/jpeg",
-    "data": "<base64-encoded image bytes>"
-  }
-}
-```
 
 ```
 
-##
 
-- `app_name`, `app_version`, `debug`
 
-Example:
 ```python
-# app/
-from
-from
 ```
 
-- **CPU vs GPU**: This example uses CPU by default. If you have CUDA, install a CUDA-enabled PyTorch build and set device placement in your service.
 
-### 🏗️ Why Both? `.parquet` + `_metadata.json`
 ```
 ```
 
-**Purpose**: Baseline functionality validation
 
-**
 
-- **Huge Images**: 2048x2048 (tests memory management and resizing)
-- **Extreme Aspect Ratios**: 1000x50 (tests preprocessing assumptions)
-- **Corrupted Data**: Invalid base64, malformed requests (tests error handling)
-- **Expected Behavior**: Graceful degradation, proper error responses
 
-**
 
-- **Latency Tracking**: Expected max response times based on batch size
-- **Throughput Metrics**: Requests per second under different loads
-- **Expected Behavior**: Consistent performance within acceptable bounds
 
-```python
-# Why synthetic images instead of real photos?
-# 1. Copyright-free for academic distribution
-# 3. Programmatically generated edge cases
-
-def create_synthetic_image(width, height, image_type):
-    if image_type == "random":
-        # RGB noise - tests model noise robustness
-        array = np.random.randint(0, 256, (height, width, 3))
-    elif image_type == "geometric":
-        # Shapes and patterns - tests feature detection
-        # ... geometric pattern generation
-    # ... other synthetic types
-```
 
-```
-# Matches exact API format for drop-in testing
 {
 }
 ```
 
-```
-# Realistic prediction responses with proper structure
 {
 }
 ```
 
-  "api_request": [...],        # JSON-serialized requests
-  "expected_response": [...],  # JSON-serialized expected responses
-  "test_category": [...],      # Category classification
-  "difficulty": [...],         # Complexity indicator
-  # ... additional metadata columns
-})
-```
 
-**1. Generate Test Datasets**
 ```bash
-# Create all 100 datasets (~2-5 minutes depending on hardware)
 python scripts/generate_test_datasets.py
-
-# What this creates:
-# - scripts/test_datasets/*.parquet (actual test data)
-# - scripts/test_datasets/*_metadata.json (human-readable info)
-# - scripts/test_datasets/datasets_summary.json (overview)
 ```
 
 ```bash
-# Start your
 uvicorn main:app --reload
 
 # Quick test (5 samples per dataset)
 python scripts/test_datasets.py --quick
 
-# Full validation
 python scripts/test_datasets.py
 
-#
 python scripts/test_datasets.py --category edge_case
-python scripts/test_datasets.py --category performance
 ```
 
-###
 
 ```
-
 ============================================================
 Datasets tested: 100
 Successful datasets: 95
@@ -425,7 +329,7 @@ Test duration: 45.2s
 Performance:
   Avg latency: 123.4ms
   Median latency: 98.7ms
-
   Max latency: 2,341.0ms
   Requests/sec: 27.6
 
@@ -434,6 +338,29 @@ Category breakdown:
   edge_case: 25 datasets, 76.8% avg success
   performance: 25 datasets, 91.1% avg success
   model_comparison: 25 datasets, 89.3% avg success
-
 ```
+# ML Inference Service
 
+FastAPI service for serving ML models over HTTP. Comes with ResNet-18 for image classification out of the box, but you can swap in any model you want.
 
+## Quick Start
 
+**Local development:**
 ```bash
+# Install dependencies
 python -m venv .venv
+source .venv/bin/activate
 pip install -r requirements.txt
 
+# Download the example model
 bash scripts/model_download.bash
+
+# Run it
+uvicorn main:app --reload
 ```
 
+Server runs on `http://127.0.0.1:8000`. Check `/docs` for the interactive API documentation.
+
+**Docker:**
 ```bash
+# Build
+docker build -t ml-inference-service:test .
+
+# Run
+docker run -d --name ml-inference-test -p 8000:8000 ml-inference-service:test
+
+# Check logs
+docker logs -f ml-inference-test
+
+# Stop
+docker stop ml-inference-test && docker rm ml-inference-test
 ```
 
+## Testing the API
 
 ```bash
+# Using curl
+curl -X POST http://localhost:8000/predict \
+  -H "Content-Type: application/json" \
+  -d '{
+    "image": {
+      "mediaType": "image/jpeg",
+      "data": "<base64-encoded-image>"
+    }
 }'
 ```
 
+Example response:
 ```json
 {
   "prediction": "tiger cat",
+  "confidence": 0.394,
   "predicted_label": 282,
   "model": "microsoft/resnet-18",
   "mediaType": "image/jpeg"
 }
 ```
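Note: the `data` field must be valid base64. A minimal sketch of building and sending the request from Python, assuming the `requests` package is installed; the file name `cat.jpg` is illustrative:

```python
# Sketch: encode an image file and call POST /predict.
# Assumes the server from the Quick Start is running locally.
import base64
import requests

with open("cat.jpg", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("ascii")

payload = {"image": {"mediaType": "image/jpeg", "data": encoded}}
resp = requests.post("http://127.0.0.1:8000/predict", json=payload)
print(resp.status_code, resp.json())
```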
 
+## Project Structure
 
+```
+ml-inference-service/
+├── main.py                    # Entry point
+├── app/
+│   ├── core/
+│   │   ├── app.py             # App factory, config, DI, lifecycle
+│   │   └── logging.py         # Logging setup
+│   ├── api/
+│   │   ├── models.py          # Request/response schemas
+│   │   ├── controllers.py     # Business logic
+│   │   └── routes/
+│   │       └── prediction.py  # POST /predict
+│   └── services/
+│       ├── base.py            # Abstract InferenceService class
+│       └── inference.py       # ResNet implementation
+├── models/
+│   └── microsoft/
+│       └── resnet-18/         # Model weights and config
+├── scripts/
+│   ├── model_download.bash
+│   ├── generate_test_datasets.py
+│   └── test_datasets.py
+├── Dockerfile                 # Multi-stage build
+├── .env.example               # Environment config template
+└── requirements.txt
+```
 
+The key design decision here is that `app/core/app.py` consolidates everything: config, dependency injection, lifecycle, and the app factory. This avoids the mess of managing global state across multiple files.
 
+## How to Plug In Your Own Model
 
+The whole service is built around one abstract base class: `InferenceService`. Implement it for your model, and everything else just works.
 
+### Step 1: Create Your Service Class
 
+```python
+# app/services/your_model_service.py
+from app.services.base import InferenceService
+from app.api.models import ImageRequest, PredictionResponse
+
+class YourModelService(InferenceService[ImageRequest, PredictionResponse]):
+    def __init__(self, model_name: str):
+        self.model_name = model_name
+        self.model_path = f"models/{model_name}"
+        self.model = None
+        self._is_loaded = False
+
+    def load_model(self) -> None:
+        """Load your model here. Called once at startup."""
+        self.model = load_your_model(self.model_path)
+        self._is_loaded = True
+
+    def predict(self, request: ImageRequest) -> PredictionResponse:
+        """Run inference. The controller offloads this call to a thread pool."""
+        image = decode_base64_image(request.image.data)
+        result = self.model(image)
+
+        return PredictionResponse(
+            prediction=result.label,
+            confidence=result.confidence,
+            predicted_label=result.class_id,
+            model=self.model_name,
+            mediaType=request.image.mediaType
+        )
+
+    @property
+    def is_loaded(self) -> bool:
+        return self._is_loaded
+```
 
+**Important:** Keep `predict()` synchronous; the controller already wraps it in `asyncio.to_thread()`, so CPU-heavy inference runs in a background thread and the server stays responsive while your model is working.
 
+### Step 2: Register Your Service
 
+Open `app/core/app.py` and find the lifespan function:
 
+```python
+# Change this line:
+service = ResNetInferenceService(model_name="microsoft/resnet-18")
 
+# To this:
+service = YourModelService(model_name="your-org/your-model")
+```
 
+That's it. The `/predict` endpoint now serves your model.
 
+### Model Files
+
+Put your model files under `models/` with the full org/model structure:
 
+```
+models/
+└── your-org/
+    └── your-model/
+        ├── config.json
+        ├── weights.bin
+        └── (other files)
+```
 
+No renaming, no dropping the org prefix: the layout just mirrors the Hugging Face structure.
+
+## Configuration
 
+Settings are managed via environment variables or a `.env` file. See `.env.example` for all available options.
 
+**Default values:**
+- `APP_NAME`: "ML Inference Service"
+- `APP_VERSION`: "0.1.0"
+- `DEBUG`: false
+- `HOST`: "0.0.0.0"
+- `PORT`: 8000
+- `MODEL_NAME`: "microsoft/resnet-18"
 
+**To customize:**
+```bash
+# Copy the example
+cp .env.example .env
 
+# Edit values
+vim .env
+```
 
+Or set environment variables directly:
+```bash
+export MODEL_NAME="google/vit-base-patch16-224"
+uvicorn main:app --reload
+```
 
+## Deployment
 
+**Development:**
+```bash
+uvicorn main:app --reload
+```
 
+**Production:**
+```bash
+gunicorn main:app -w 4 -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000
+```
 
+The service runs on CPU by default. For GPU inference, install CUDA-enabled PyTorch and modify your service to move tensors to the GPU device.
 
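A minimal sketch of what that device placement could look like, assuming a CUDA-enabled PyTorch build; the commented lines show where the calls would go in a service like `ResNetInferenceService`:

```python
# Sketch: GPU device placement for an inference service.
# Falls back to CPU when CUDA is unavailable.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# In load_model(), move the model to the device once:
#     self.model = ResNetForImageClassification.from_pretrained(self.model_path)
#     self.model.to(device)

# In predict(), move each input batch to the same device:
#     inputs = self.processor(image, return_tensors="pt")
#     inputs = {k: v.to(device) for k, v in inputs.items()}
#     with torch.no_grad():
#         logits = self.model(**inputs).logits
```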
 
+**Docker:**
+- Multi-stage build keeps the image small
+- Runs as non-root user (`appuser`)
+- Python dependencies installed in user site-packages
+- Model files baked into the image
 
+## What Happens When You Start the Server
 
+```
+INFO: Starting ML Inference Service...
+INFO: Initializing ResNet service: models/microsoft/resnet-18
+INFO: Loading model from models/microsoft/resnet-18
+INFO: Model loaded: 1000 classes
+INFO: Startup completed successfully
+INFO: Uvicorn running on http://0.0.0.0:8000
+```
 
+If you see "Model directory not found", check that your model files exist at the expected path with the full org/model structure.
 
+## API Reference
 
+**Endpoint:** `POST /predict`
 
+**Request:**
+```json
 {
+  "image": {
+    "mediaType": "image/jpeg",  // or "image/png"
+    "data": "<base64-encoded-image>"
+  }
 }
 ```
 
+**Response:**
+```json
 {
+  "prediction": "string",     // Human-readable label
+  "confidence": 0.0,          // Softmax probability
+  "predicted_label": 0,       // Numeric class index
+  "model": "org/model-name",  // Model identifier
+  "mediaType": "image/jpeg"   // Echoed from request
 }
 ```
 
+**Docs:**
+- Swagger UI: `http://localhost:8000/docs`
+- ReDoc: `http://localhost:8000/redoc`
+- OpenAPI JSON: `http://localhost:8000/openapi.json`
+
+## PyArrow Test Datasets
 
+We've included a test dataset system for validating your model. It generates 100 standardized test cases covering normal inputs, edge cases, performance benchmarks, and model comparisons.
 
+### Generate Datasets
 
 ```bash
 python scripts/generate_test_datasets.py
 ```
 
+This creates:
+- `scripts/test_datasets/*.parquet` - Test data (images, requests, expected responses)
+- `scripts/test_datasets/*_metadata.json` - Human-readable descriptions
+- `scripts/test_datasets/datasets_summary.json` - Overview of all datasets
+
+### Run Tests
+
 ```bash
+# Start your service first
 uvicorn main:app --reload
 
 # Quick test (5 samples per dataset)
 python scripts/test_datasets.py --quick
 
+# Full validation
 python scripts/test_datasets.py
 
+# Test specific category
 python scripts/test_datasets.py --category edge_case
 ```
 
+### Dataset Categories (25 datasets each)
+
+**1. Standard Tests** (`standard_test_*.parquet`)
+- Normal images: random patterns, shapes, gradients
+- Common sizes: 224x224, 256x256, 299x299, 384x384
+- Formats: JPEG, PNG
+- Purpose: Baseline validation
+
+**2. Edge Cases** (`edge_case_*.parquet`)
+- Tiny images (32x32, 1x1)
+- Huge images (2048x2048)
+- Extreme aspect ratios (1000x50)
+- Corrupted data, malformed requests
+- Purpose: Test error handling
 
+**3. Performance Benchmarks** (`performance_test_*.parquet`)
+- Batch sizes: 1, 5, 10, 25, 50, 100 images
+- Latency and throughput tracking
+- Purpose: Performance profiling
 
+**4. Model Comparisons** (`model_comparison_*.parquet`)
+- Same inputs across different architectures
+- Models: ResNet-18/50, ViT, ConvNext, Swin
+- Purpose: Cross-model benchmarking
+
| 317 |
+
### Test Output
|
| 318 |
|
| 319 |
```
|
| 320 |
+
DATASET TESTING SUMMARY
|
| 321 |
============================================================
|
| 322 |
Datasets tested: 100
|
| 323 |
Successful datasets: 95
|
|
|
|
| 329 |
Performance:
|
| 330 |
Avg latency: 123.4ms
|
| 331 |
Median latency: 98.7ms
|
| 332 |
+
p95 latency: 342.1ms
|
| 333 |
Max latency: 2,341.0ms
|
| 334 |
Requests/sec: 27.6
|
| 335 |
|
|
|
|
| 338 |
edge_case: 25 datasets, 76.8% avg success
|
| 339 |
performance: 25 datasets, 91.1% avg success
|
| 340 |
model_comparison: 25 datasets, 89.3% avg success
|
| 341 |
+
```
|
| 342 |
+
|
| 343 |
+
## Common Issues
|
| 344 |
+
|
| 345 |
+
**Port 8000 already in use:**
|
| 346 |
+
```bash
|
| 347 |
+
# Find what's using it
|
| 348 |
+
lsof -i :8000
|
| 349 |
|
| 350 |
+
# Or just use a different port
|
| 351 |
+
uvicorn main:app --port 8080
|
| 352 |
```
|
| 353 |
+
|
| 354 |
+
**Model not loading:**
|
| 355 |
+
- Check the path: models should be in `models/<org>/<model-name>/`
|
| 356 |
+
- Make sure you ran `bash scripts/model_download.bash`
|
| 357 |
+
- Check logs for the exact error
|
| 358 |
+
|
| 359 |
+
**Slow inference:**
|
| 360 |
+
- Inference runs on CPU by default
|
| 361 |
+
- For GPU: install CUDA PyTorch and modify service to use GPU device
|
| 362 |
+
- Consider using smaller models or quantization
|
| 363 |
+
|
| 364 |
+
## License
|
| 365 |
+
|
| 366 |
+
MIT
|
app/api/controllers.py
CHANGED
@@ -1,75 +1,36 @@
-"""
-Controllers for handling API business logic.
-"""
-import base64
-import io
+"""API controllers for request handling and validation."""
+
+import asyncio
 
 from fastapi import HTTPException
-from PIL import Image
 
 from app.core.logging import logger
-from app.services.
+from app.services.base import InferenceService
 from app.api.models import ImageRequest, PredictionResponse
 
 
 class PredictionController:
-    """Controller for
+    """Controller for prediction endpoints."""
 
     @staticmethod
-    async def
+    async def predict(
+        request: ImageRequest,
+        service: InferenceService
     ) -> PredictionResponse:
-        """
-        Classify an image using ResNet-18 from base64 encoded data.
-        """
+        """Run inference using the configured service."""
         try:
-            raise HTTPException(
-                status_code=503,
-                detail="Service not initialized"
-            )
+            if not service or not service.is_loaded:
+                raise HTTPException(503, "Service not available")
 
-            # Validate media type
             if not request.image.mediaType.startswith('image/'):
-                raise HTTPException(
-                    status_code=400,
-                    detail=f"Invalid media type: {request.image.mediaType}"
-                )
+                raise HTTPException(400, f"Invalid media type: {request.image.mediaType}")
 
-            try:
-                image_data = base64.b64decode(request.image.data)
-            except Exception as decode_error:
-                raise HTTPException(
-                    status_code=400,
-                    detail=f"Invalid base64 data: {str(decode_error)}"
-                )
-
-            # Load and validate image
-            try:
-                image = Image.open(io.BytesIO(image_data))
-            except Exception as img_error:
-                raise HTTPException(
-                    status_code=400,
-                    detail=f"Invalid image file: {str(img_error)}"
-                )
-
-            # Perform prediction
-            result = resnet_service.predict(image)
-
-            # Return structured response
-            return PredictionResponse(
-                prediction=result["prediction"],
-                confidence=result["confidence"],
-                model=result["model"],
-                predicted_label=result["predicted_label"],
-                mediaType=request.image.mediaType
-            )
+            return await asyncio.to_thread(service.predict, request)
 
         except HTTPException:
            raise
+        except ValueError as e:
+            logger.error(f"Invalid input: {e}")
+            raise HTTPException(400, str(e))
        except Exception as e:
            logger.error(f"Prediction failed: {e}")
-            raise HTTPException(
+            raise HTTPException(500, "Internal server error")
app/api/routes/prediction.py
CHANGED
@@ -1,20 +1,23 @@
-"""
-
-"""
+"""Prediction API routes."""
+
 from fastapi import APIRouter, Depends
 
 from app.api.controllers import PredictionController
 from app.api.models import ImageRequest, PredictionResponse
-from app.core.dependencies import
-from app.services.
+from app.core.dependencies import get_inference_service
+from app.services.base import InferenceService
 
 router = APIRouter()
 
 
-@router.post("/predict
-async def
+@router.post("/predict", response_model=PredictionResponse)
+async def predict(
     request: ImageRequest,
-
+    service: InferenceService = Depends(get_inference_service)
 ):
-    """
-
+    """
+    Run inference on base64-encoded image.
+
+    Returns prediction, confidence, predicted label, model name, and media type.
+    """
+    return await PredictionController.predict(request, service)
app/api/routes/resnet_service_manager.py
DELETED
@@ -1,19 +0,0 @@
-# """
-# Dependency injection for FastAPI.
-# """
-# from typing import Optional
-# from app.services.inference import ResNetInferenceService
-#
-# # Global service instance
-# _resnet_service: Optional[ResNetInferenceService] = None
-#
-#
-# def get_resnet_service() -> Optional[ResNetInferenceService]:
-#     """Get the ResNet service instance."""
-#     return _resnet_service
-#
-#
-# def set_resnet_service(service: ResNetInferenceService) -> None:
-#     """Set the global ResNet service instance."""
-#     global _resnet_service
-#     _resnet_service = service
app/core/app.py
CHANGED
@@ -1,16 +1,63 @@
-"""
-
-
+"""FastAPI application factory and core infrastructure."""
+
+import asyncio
+import warnings
+from contextlib import asynccontextmanager
+from typing import AsyncGenerator, Optional
+
 from fastapi import FastAPI
+from pydantic import Field
+from pydantic_settings import BaseSettings
 
-from app.core.
-from app.core.
+from app.core.logging import logger
+from app.core.dependencies import set_inference_service
+from app.services.inference import ResNetInferenceService
 from app.api.routes import prediction
 
 
-
-    """Application
+class Settings(BaseSettings):
+    """Application settings. Override via environment variables or .env file."""
+
+    app_name: str = Field(default="ML Inference Service")
+    app_version: str = Field(default="0.1.0")
+    debug: bool = Field(default=False)
+    host: str = Field(default="0.0.0.0")
+    port: int = Field(default=8000)
+
+    class Config:
+        env_file = ".env"
+
+
+settings = Settings()
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+    """Application lifecycle: startup/shutdown."""
+    logger.info("Starting ML Inference Service...")
+
+    try:
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=FutureWarning)
+
+            # Replace ResNetInferenceService with your own implementation
+            service = ResNetInferenceService(model_name="microsoft/resnet-18")
+            await asyncio.to_thread(service.load_model)
+            set_inference_service(service)
+
+        logger.info("Startup completed successfully")
+
+    except Exception as e:
+        logger.error(f"Startup failed: {e}")
+        raise
+
+    yield
+
+    logger.info("Shutting down...")
+
+
+def create_app() -> FastAPI:
+    """Create and configure FastAPI application."""
     app = FastAPI(
         title=settings.app_name,
         description="ML inference service for image classification",
@@ -19,7 +66,6 @@ def create_app() -> FastAPI:
         lifespan=lifespan
     )
 
-    # Include only prediction router
     app.include_router(prediction.router)
 
-    return app
+    return app
app/core/config.py
DELETED
@@ -1,29 +0,0 @@
-"""
-Basic configuration management.
-
-Starting simple - just app settings. We'll expand as needed.
-"""
-
-from pydantic import Field
-from pydantic_settings import BaseSettings  # Changed import
-
-
-class Settings(BaseSettings):
-    """Application settings with environment variable support."""
-
-    # Basic app settings
-    app_name: str = Field(default="ML Inference Service", description="Application name")
-    app_version: str = Field(default="0.1.0", description="Application version")
-    debug: bool = Field(default=False, description="Debug mode")
-
-    # Server settings
-    host: str = Field(default="0.0.0.0", description="Server host")
-    port: int = Field(default=8000, description="Server port")
-
-    class Config:
-        """Load from .env file if it exists."""
-        env_file = ".env"
-
-
-# Global settings instance
-settings = Settings()
app/core/dependencies.py
CHANGED
@@ -1,19 +1,17 @@
-"""
-
-"""
+"""Dependency injection for services."""
+
 from typing import Optional
-from app.services.
+from app.services.base import InferenceService
 
-_resnet_service: Optional[ResNetInferenceService] = None
+_inference_service: Optional[InferenceService] = None
 
 
-def
-    """Get
-    return
+def get_inference_service() -> Optional[InferenceService]:
+    """Get inference service for dependency injection."""
+    return _inference_service
 
 
-def
-    """Set
-    global
+def set_inference_service(service: InferenceService) -> None:
+    """Set inference service. Called internally during startup."""
+    global _inference_service
+    _inference_service = service
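Note: because `get_inference_service` is an ordinary FastAPI dependency, tests can swap in a stub without touching the module-level global. A minimal sketch; `StubService` is illustrative and only mimics the `is_loaded`/`predict` surface the controller checks:

```python
# Sketch: override the inference dependency in tests.
from fastapi.testclient import TestClient

from app.api.models import PredictionResponse
from app.core.app import create_app
from app.core.dependencies import get_inference_service


class StubService:
    is_loaded = True  # controller checks this before predicting

    def predict(self, request):
        return PredictionResponse(
            prediction="stub", confidence=1.0, predicted_label=0,
            model="stub/model", mediaType=request.image.mediaType,
        )


app = create_app()
app.dependency_overrides[get_inference_service] = lambda: StubService()

client = TestClient(app)  # no context manager, so the real model never loads
resp = client.post("/predict", json={"image": {"mediaType": "image/jpeg", "data": ""}})
print(resp.status_code, resp.json())
```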
app/core/lifespan.py
DELETED
@@ -1,43 +0,0 @@
-"""
-Application lifespan management.
-"""
-import warnings
-from contextlib import asynccontextmanager
-from typing import AsyncGenerator
-
-from fastapi import FastAPI
-
-from app.core.logging import logger
-from app.core.dependencies import set_resnet_service
-from app.services.inference import ResNetInferenceService
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
-    """Application lifespan manager."""
-
-    # Startup
-    logger.info("Starting ML Inference Service...")
-
-    try:
-        with warnings.catch_warnings():
-            warnings.filterwarnings("ignore", category=FutureWarning)
-
-            # Initialize and load ResNet service
-            resnet_service = ResNetInferenceService(
-                model_name="microsoft/resnet-18",
-                use_local_model=True
-            )
-            resnet_service.load_model()
-            set_resnet_service(resnet_service)
-
-        logger.info("Startup completed successfully")
-
-    except Exception as e:
-        logger.error(f"Startup failed: {e}")
-        raise
-
-    yield  # App runs here
-
-    # Shutdown
-    logger.info("Shutting down...")
app/core/logging.py
CHANGED
@@ -1,49 +1,26 @@
-"""
-Logging configuration for the application.
-"""
+"""Logging configuration."""
 
 import logging
 import sys
-from typing import Optional
 
-from app.core.config import settings
 
-
-    @staticmethod
-    def setup_logging(
-        logger_name: Optional[str] = None,
-        level: Optional[str] = None,
-        format_string: Optional[str] = None
-    ) -> logging.Logger:
-        """Set up and configure a logger."""
-        logger = logging.getLogger(logger_name or settings.app_name)
-
-        # Avoid duplicate handlers
-        if logger.handlers:
-            return logger
-
-        # Set level
-        log_level = getattr(logging, (level or "INFO").upper())
-        logger.setLevel(log_level)
-
-        # Create console handler
-        handler = logging.StreamHandler(sys.stdout)
-        handler.setLevel(log_level)
-
-        # Create formatter
-        formatter = logging.Formatter(
-            format_string or "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-        )
-        handler.setFormatter(formatter)
-
-logger = LoggerSetup.setup_logging()
+def setup_logging(logger_name: str = "ML Inference Service") -> logging.Logger:
+    """Setup and configure logger."""
+    logger = logging.getLogger(logger_name)
+
+    if logger.handlers:
+        return logger
+
+    logger.setLevel(logging.INFO)
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setLevel(logging.INFO)
+    formatter = logging.Formatter(
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+    )
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+
+    return logger
+
+
+logger = setup_logging()
app/services/base.py
ADDED
@@ -0,0 +1,30 @@
+"""Abstract base class for ML inference services."""
+
+from abc import ABC, abstractmethod
+from typing import Generic, TypeVar
+from pydantic import BaseModel
+
+TRequest = TypeVar('TRequest', bound=BaseModel)
+TResponse = TypeVar('TResponse', bound=BaseModel)
+
+
+class InferenceService(ABC, Generic[TRequest, TResponse]):
+    """
+    Base class for inference services. Subclass this to integrate your model.
+    """
+
+    @abstractmethod
+    def load_model(self) -> None:
+        """Load model weights and processors. Called once at startup."""
+        pass
+
+    @abstractmethod
+    def predict(self, request: TRequest) -> TResponse:
+        """Run inference and return typed response."""
+        pass
+
+    @property
+    @abstractmethod
+    def is_loaded(self) -> bool:
+        """Check if model is loaded and ready."""
+        pass
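For a quick smoke test of the contract before a real model exists, a minimal sketch of a trivial subclass; the `EchoService` name and behavior are illustrative, and the `PredictionResponse` fields follow `app/api/models.py`:

```python
# Sketch: smallest possible InferenceService implementation.
from app.api.models import ImageRequest, PredictionResponse
from app.services.base import InferenceService


class EchoService(InferenceService[ImageRequest, PredictionResponse]):
    def __init__(self) -> None:
        self._is_loaded = False

    def load_model(self) -> None:
        self._is_loaded = True  # nothing to load

    def predict(self, request: ImageRequest) -> PredictionResponse:
        return PredictionResponse(
            prediction="echo",
            confidence=1.0,
            predicted_label=0,
            model="echo/none",
            mediaType=request.image.mediaType,
        )

    @property
    def is_loaded(self) -> bool:
        return self._is_loaded
```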
app/services/inference.py
CHANGED
@@ -1,152 +1,79 @@
-"""
-Inference service for machine learning models.
-
-This service handles the business logic for ML inference,
-following the Single Responsibility Principle.
-"""
+"""ResNet inference service implementation."""
+
 import os
+import base64
+from io import BytesIO
 import torch
 from PIL import Image
 from transformers import AutoImageProcessor, ResNetForImageClassification
 
 from app.core.logging import logger
+from app.services.base import InferenceService
+from app.api.models import ImageRequest, PredictionResponse
 
 
-class ResNetInferenceService:
-    """
-    ResNet inference service.
-
-    Follows the Singleton pattern - loads model once.
-    """
-
-    def __init__(self, model_name: str = "microsoft/resnet-18", use_local_model: bool = True):
-        """
-        Initialize the ResNet service.
-
-        Args:
-            model_name: HuggingFace model identifier
-        """
+class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
+    """ResNet-18 inference service for image classification."""
+
+    def __init__(self, model_name: str = "microsoft/resnet-18"):
         self.model_name = model_name
-        self.use_local_model = use_local_model
         self.model = None
        self.processor = None
        self._is_loaded = False
-            self.model_path = os.path.join("models", model_name.split("/")[-1])
-            logger.info(f"Initializing ResNet service with local model: {self.model_path}")
-        else:
-            self.model_path = model_name
-            logger.info(f"Initializing ResNet service with remote model: {model_name}")
+        self.model_path = os.path.join("models", model_name)
+        logger.info(f"Initializing ResNet service: {self.model_path}")
 
     def load_model(self) -> None:
-        """
-        Load the ResNet model and processor.
-
-        This method loads the model once and reuses it for all requests.
-        """
         if self._is_loaded:
-            logger.debug("Model already loaded, skipping...")
             return
 
-        Args:
-            image: PIL Image to classify
-
-        Returns:
-            Dictionary containing prediction results
-
-        Raises:
-            RuntimeError: If model is not loaded
-            ValueError: If image processing fails
-        """
-        if not self._is_loaded:
-            logger.info("Model not loaded, loading now...")
-            self.load_model()
-
-        try:
-            logger.debug("Starting ResNet inference")
-
-            if image.mode != 'RGB':
-                image = image.convert('RGB')
-                logger.debug(f"Converted image from {image.mode} to RGB")
-
-            inputs = self.processor(image, return_tensors="pt")
-
-            # Perform inference
-            with torch.no_grad():
-                logits = self.model(**inputs).logits
-
-            # Get prediction
-            predicted_label = logits.argmax(-1).item()
-            predicted_class = self.model.config.id2label[predicted_label]
-
-            # Calculate confidence score
-            probabilities = torch.nn.functional.softmax(logits, dim=-1)
-            confidence = probabilities[0][predicted_label].item()
-
-            result = {
-                "prediction": predicted_class,
-                "confidence": round(confidence, 4),
-                "model": self.model_name,
-                "predicted_label": predicted_label
-            }
-
-            logger.debug(f"Inference completed: {predicted_class} (confidence: {confidence:.4f})")
-            return result
-
-        except Exception as e:
-            logger.error(f"Inference failed: {e}")
-            raise ValueError(f"Failed to process image: {str(e)}")
+        if not os.path.exists(self.model_path):
+            raise FileNotFoundError(f"Model not found: {self.model_path}")
+
+        config_path = os.path.join(self.model_path, "config.json")
+        if not os.path.exists(config_path):
+            raise FileNotFoundError(f"Config not found: {config_path}")
+
+        logger.info(f"Loading model from {self.model_path}")
+
+        import warnings
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=FutureWarning)
+            self.processor = AutoImageProcessor.from_pretrained(
+                self.model_path, local_files_only=True
+            )
+            self.model = ResNetForImageClassification.from_pretrained(
+                self.model_path, local_files_only=True
+            )
+
+        self._is_loaded = True
+        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")
+
+    def predict(self, request: ImageRequest) -> PredictionResponse:
+        image_data = base64.b64decode(request.image.data)
+        image = Image.open(BytesIO(image_data))
+
+        if image.mode != 'RGB':
+            image = image.convert('RGB')
+
+        inputs = self.processor(image, return_tensors="pt")
+
+        with torch.no_grad():
+            logits = self.model(**inputs).logits
+
+        predicted_label = logits.argmax(-1).item()
+        predicted_class = self.model.config.id2label[predicted_label]
+        probabilities = torch.nn.functional.softmax(logits, dim=-1)
+        confidence = probabilities[0][predicted_label].item()
+
+        return PredictionResponse(
+            prediction=predicted_class,
+            confidence=round(confidence, 4),
+            model=self.model_name,
+            predicted_label=predicted_label,
+            mediaType=request.image.mediaType
+        )
 
     @property
     def is_loaded(self) -> bool:
-        """Check if model is loaded."""
         return self._is_loaded
test_main.http
CHANGED
@@ -1,6 +1,7 @@
-# Test
+# Test Prediction Endpoint
+# Works with any model configured at startup (default: ResNet-18)
 
-POST http://127.0.0.1:8000/predict
+POST http://127.0.0.1:8000/predict
 Content-Type: application/json
 
 {