jessehostetler committed
Commit 1d3b709 · 2 Parent(s): ebbcd26 7fb83c4

Merge branch 'dyff-824-refactor-codebase' into 'main'


DYFF-824: Finalize the example submission repository

See merge request ul-dsri/sandbox/sachin-sharma-in/ml-inference-service!3

.dockerignore ADDED
@@ -0,0 +1,38 @@
+ __pycache__
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+ *.so
+ *.egg
+ *.egg-info
+ dist
+ build
+
+ .venv
+ venv
+ ENV
+ env
+
+ .git
+ .gitignore
+ .idea
+ .vscode
+ .claude
+
+ *.md
+ README.md
+ Dockerfile
+ .dockerignore
+
+ test_*.http
+ test_results
+ scripts/test_datasets
+
+ .pytest_cache
+ .coverage
+ htmlcov
+
+ *.log
+ .DS_Store
+ .python-version
.env.example ADDED
@@ -0,0 +1,11 @@
+ # App Configuration
+ APP_NAME="ML Inference Service"
+ APP_VERSION="0.1.0"
+ DEBUG=false
+
+ # Server
+ HOST="0.0.0.0"
+ PORT=8000
+
+ # Model
+ MODEL_NAME="microsoft/resnet-18"
Dockerfile ADDED
@@ -0,0 +1,26 @@
+ FROM python:3.12-slim as builder
+
+ WORKDIR /build
+
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir --user -r requirements.txt
+
+ FROM python:3.12-slim
+
+ WORKDIR /app
+
+ RUN useradd -m -u 1000 appuser
+
+ COPY --from=builder --chown=appuser:appuser /root/.local /home/appuser/.local
+ COPY --chown=appuser:appuser app ./app
+ COPY --chown=appuser:appuser models ./models
+ COPY --chown=appuser:appuser main.py .
+
+ USER appuser
+
+ ENV PATH=/home/appuser/.local/bin:$PATH \
+     PYTHONUNBUFFERED=1
+
+ EXPOSE 8000
+
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
README.md CHANGED
@@ -1,419 +1,323 @@
- # ML Inference Service (FastAPI)

- A production-ready **FastAPI** web service that serves **image classification** models.
- This repo ships with a working example using **ResNet-18** (downloaded from Hugging Face) under `models/resnet-18/` and exposes a simple **REST** endpoint.

- ---

- ## ✨ What you get
-
- - FastAPI application with clean layering (routes → controller → service)
- - Hot-loaded model on startup (single instance reused per request)
- - Hugging Face–compatible local model folder (`config.json`, weights, preprocessor, etc.)
- - Example endpoint: `POST /predict/resnet` that accepts a base64 image and returns:
-   - `prediction` (class label)
-   - `confidence` (softmax probability)
-   - `predicted_label` (class index)
-   - `model` (model id)
-   - `mediaType` (echoed)
-
- ---
-
- ## 🧭 Project Layout
-
- ```
- ml-inference-service/
- ├─ main.py
- ├─ app/
- │  ├─ __init__.py
- │  ├─ core/
- │  │  ├─ app.py              # App factory & router wiring
- │  │  ├─ config.py           # Settings (app name/version/debug)
- │  │  ├─ dependencies.py     # DI for model services
- │  │  ├─ lifespan.py         # Startup: load model & register service
- │  │  └─ logging.py          # Logger setup
- │  ├─ api/
- │  │  ├─ models.py           # Pydantic request/response
- │  │  ├─ controllers.py      # HTTP → service orchestration
- │  │  └─ routes/
- │  │     ├─ prediction.py    # `POST /predict/resnet`
- │  │     └─ resnet_service_manager.py (legacy, unused)
- │  └─ services/
- │     └─ inference.py        # ResNetInferenceService (load/predict)
- ├─ models/
- │  └─ resnet-18/             # Sample HF-style model folder
- ├─ scripts/
- │  ├─ model_download.bash    # One-liner to snapshot HF weights locally
- │  ├─ generate_test_datasets.py  # Generate PyArrow datasets for testing
- │  ├─ test_datasets.py       # Test generated datasets against API
- │  └─ test_datasets/         # Generated PyArrow test datasets (100 files)
- ├─ requirements.in / requirements.txt
- └─ test_main.http            # Example request you can run from IDEs
- ```
-
- ---
-
- ## 🚀 Quickstart
-
- ### 1) Install dependencies (Python 3.9+)
  ```bash
  python -m venv .venv
- source .venv/bin/activate   # Windows: .venv\Scripts\activate
  pip install -r requirements.txt
- ```

- ### 2) Download the sample model (ResNet‑18) locally
- ```bash
  bash scripts/model_download.bash
  ```
- This populates `models/resnet-18/` with Hugging Face artifacts (`config.json`, weights, `preprocessor_config.json`, etc.).

- ### 3) Run the server
  ```bash
- uvicorn main:app --reload
  ```
- Server listens on `http://127.0.0.1:8000`.

- ### 4) Call the API
- - Use `test_main.http` from your IDE (VSCode/IntelliJ) **or** curl:

  ```bash
- curl -X POST http://127.0.0.1:8000/predict/resnet -H "Content-Type: application/json" -d '{
-   "image": { "mediaType": "image/jpeg", "data": "<base64-encoded-bytes>" }
  }'
  ```

- **Response (example):**
  ```json
  {
    "prediction": "tiger cat",
-   "confidence": 0.9971,
    "predicted_label": 282,
    "model": "microsoft/resnet-18",
    "mediaType": "image/jpeg"
  }
  ```

- ---
-
- ## 🧩 Bring Your Own Model (BYOM)
-
- There are **two** ways to integrate your own model.
-
- ### Option A — *Drop-in replacement (zero code changes)*
-
- If your model is a **Hugging Face image classification** model that works with
- `AutoImageProcessor` and `ResNetForImageClassification` **or** a compatible
- `*ForImageClassification` class from `transformers`, you can simply place the
- model folder alongside `resnet-18` and point the service at it.
-
- 1. Put your HF-style folder under `models/<your-model-name>/` containing at least:
-    - `config.json`
-    - weights (e.g., `pytorch_model.bin` or `model.safetensors`)
-    - `preprocessor_config.json` / `image_processor` files
-
- 2. **Choose one** of these approaches:
-    - **Simplest**: Replace the contents of `models/resnet-18/` with your model files *but keep the folder name*. The existing `/predict/resnet` endpoint will now serve your model.
-    - **Preferred**: Change the model id used at startup:
-      - Open `app/core/lifespan.py` and modify the service initialization:
-        ```python
-        resnet_service = ResNetInferenceService(
-            model_name="your-org/your-model",  # used for local folder name
-            use_local_model=True               # loads from models/your-model/
-        )
-        ```
-      - Ensure your local folder is `models/your-model/`.
-
- > How folder naming works: when `use_local_model=True`, the service derives the
- > local directory as `models/<last-segment-of-model_name>`. For
- > `"microsoft/resnet-18"` that becomes `models/resnet-18`. For
- > `"your-org/awesome-vit-base"`, it becomes `models/awesome-vit-base`.
-
- That’s it. No code changes elsewhere if your model is a standard image classifier.
-
- ---
-
- ### Option B — *New task/model type (minimal code: new service + route)*
-
- If you are **not** serving a Hugging Face image classifier (e.g., object detection,
- segmentation, text models), implement a small service class and a route mirroring
- the `ResNetInferenceService` flow.
-
- 1. **Create your service** (copy and adapt `ResNetInferenceService`):
-    - File: `app/services/<your_model>_service.py`
-    - Responsibilities you must implement:
-      - `__init__(model_name: str, use_local_model: bool)` → set `self.model_path`
-      - `load_model()` → load weights & preprocessor
-      - `predict(image: PIL.Image.Image) -> Dict[str, Any]` → run inference and return a dict with:
-        ```python
-        {
-            "prediction": "<your label or structured result>",
-            "confidence": <float 0..1>,
-            "predicted_label": <int or meaningful code>,
-            "model": "<model id>"
-        }
-        ```
-    *Feel free to extend the payload; just update the API schema accordingly.*
-
- 2. **Wire the dependency**:
-    - Register your service at startup in `app/core/lifespan.py` similar to ResNet:
-      ```python
-      from app.core.dependencies import set_resnet_service  # or create your own set/get
-      from app.services.your_model_service import YourModelService
-
-      svc = YourModelService(model_name="your-org/your-model", use_local_model=True)
-      svc.load_model()
-      set_resnet_service(svc)  # or create set_your_model_service(...)
-      ```
-    - Optionally create **new getters/setters** in `app/core/dependencies.py` if you serve multiple models in parallel (one getter per model).
-
- 3. **Add a route**:
-    - Create `app/api/routes/your_model.py` analogous to `prediction.py`:
-      ```python
-      from fastapi import APIRouter, Depends
-      from app.api.controllers import PredictionController
-      from app.api.models import ImageRequest, PredictionResponse
-      from app.core.dependencies import get_resnet_service  # or your getter
-      from app.services.your_model_service import YourModelService
-
-      router = APIRouter()
-
-      @router.post("/predict/your-model", response_model=PredictionResponse)
-      async def predict_image(request: ImageRequest, service: YourModelService = Depends(get_resnet_service)):
-          controller = PredictionController(service)  # reuse the controller
-          return await controller.predict(request)
-      ```
-    - Register the router in `app/core/app.py`:
-      ```python
-      from app.api.routes import your_model as your_model_routes
-      app.include_router(your_model_routes.router)
-      ```
-
- 4. **Adjust schemas if needed**:
-    - The default `PredictionResponse` in `app/api/models.py` is for single-label classification. For other tasks, either extend it or define a new response model and use it in your route's `response_model=`.
-
- > **Tip**: Keep your controller thin and push all model-specific logic into your service class. The server glue (DI + routes) stays identical across models.
-
- ---
-
- ## 🧪 Validating your setup
-
- - **Startup logs** should include: `Initializing ResNet service with local model: models/<folder>` and `Model and processor loaded successfully`.
- - Hitting your endpoint should return a **200** with a JSON body like the example above.
- - If you see `Local model directory not found`, check your `models/<name>/` path and filenames.
-
- ---
-
- ## 🔌 Request & Response Shapes
-
- ### Request
- ```json
- {
-   "image": {
-     "mediaType": "image/jpeg",
-     "data": "<base64-encoded image bytes>"
-   }
- }
- ```

- ### Response
- ```json
- {
-   "prediction": "string label",
-   "confidence": 0.0,
-   "predicted_label": 0,
-   "model": "your-org/your-model",
-   "mediaType": "image/jpeg"
- }
  ```

- ---

- ## ⚙️ Configuration

- Basic settings live in `app/core/config.py`. Out of the box we keep it simple:
- - `app_name`, `app_version`, `debug`

- If you want to make the **model** configurable without touching code, extend `Settings` with a `model_name` env var and consume it in `lifespan.py` when creating your service instance.

- Example:
  ```python
- # app/core/config.py
- from pydantic_settings import BaseSettings
- from pydantic import Field
-
- class Settings(BaseSettings):
-     app_name: str = Field("ML Inference Service")
-     app_version: str = Field("0.1.0")
-     debug: bool = Field(False)
-     model_name: str = Field("microsoft/resnet-18", description="HF model id used at startup")
-
- settings = Settings()
-
- # app/core/lifespan.py
- from app.core.config import settings
- svc = ResNetInferenceService(model_name=settings.model_name, use_local_model=True)
  ```

- Then set `MODEL_NAME=your-org/your-model` in your environment (Pydantic will map `model_name` from `MODEL_NAME`).

- ---

- ## 📦 Packaging & Deployment

- - **Dev**: `uvicorn main:app --reload`
- - **Prod**: Use a process manager (e.g., `gunicorn -k uvicorn.workers.UvicornWorker`) and add health checks.
- - **Containerize**: Copy only `requirements.txt` and source, install wheels, and bake the `models/` folder into the image or mount it as a volume.
- - **CPU vs GPU**: This example uses CPU by default. If you have CUDA, install a CUDA-enabled PyTorch build and set device placement in your service.

- ---

- ## 🧪 PyArrow Test Datasets

- This project includes a comprehensive **PyArrow-based dataset generation system** designed specifically for academic challenges and ML model validation. The system generates **100 standardized test datasets** that allow participants to validate their models against consistent, reproducible test cases.

- ### 🏗️ Why Both? `.parquet` + `_metadata.json`
  ```
- standard_test_001.parquet        # Actual test data (images, requests, responses)
- standard_test_001_metadata.json  # Human-readable description and stats
  ```

- ### 📊 Dataset Categories (25 each = 100 total)

- #### 1. **Standard Test Cases** (`standard_test_*.parquet`)
- **Purpose**: Baseline functionality validation

- **Content**: Normal images with expected successful predictions

- - **Image Types**: Random patterns, geometric shapes, gradients, text overlays, solid colors
- - **Formats**: JPEG, PNG with proper MIME types
- - **Sizes**: 224x224, 256x256, 299x299, 384x384 (common ML input sizes)
- - **Expected Behavior**: HTTP 200 responses with valid prediction structure

- #### 2. **Edge Case Tests** (`edge_case_*.parquet`)
- **Purpose**: Robustness and error handling validation

- **Content**: Challenging scenarios that test model resilience

- - **Tiny Images**: 32x32, 1x1 pixels (tests preprocessing robustness)
- - **Huge Images**: 2048x2048 (tests memory management and resizing)
- - **Extreme Aspect Ratios**: 1000x50 (tests preprocessing assumptions)
- - **Corrupted Data**: Invalid base64, malformed requests (tests error handling)
- - **Expected Behavior**: Graceful degradation, proper error responses

- #### 3. **Performance Benchmarks** (`performance_test_*.parquet`)
- **Purpose**: Latency and throughput measurement

- **Content**: Varying batch sizes for performance profiling

- - **Batch Sizes**: 1, 5, 10, 25, 50, 100 images per test
- - **Latency Tracking**: Expected max response times based on batch size
- - **Throughput Metrics**: Requests per second under different loads
- - **Expected Behavior**: Consistent performance within acceptable bounds

- #### 4. **Model Comparison** (`model_comparison_*.parquet`)
- **Purpose**: Cross-model validation and benchmarking

- **Content**: Identical inputs tested across different model architectures

- - **Model Types**: ResNet-18/50, ViT, ConvNext, Swin Transformer
- - **Consistent Inputs**: Same 10 base images per dataset
- - **Comparative Analysis**: Enables direct performance comparison between models
- - **Expected Behavior**: Architecture-specific but structurally consistent responses

- ### 🛠️ Generation Process

- The dataset generation follows a **deterministic, reproducible approach**:

- #### Step 1: Synthetic Image Creation
- ```python
- # Why synthetic images instead of real photos?
- # 1. Copyright-free for academic distribution
- # 3. Programmatically generated edge cases
-
- def create_synthetic_image(width, height, image_type):
-     if image_type == "random":
-         # RGB noise - tests model noise robustness
-         array = np.random.randint(0, 256, (height, width, 3))
-     elif image_type == "geometric":
-         # Shapes and patterns - tests feature detection
-         # ... geometric pattern generation
-     # ... other synthetic types
- ```

- #### Step 2: API Request Structure Generation
- ```python
- # Matches exact API format for drop-in testing
  {
-   "image": {
-     "mediaType": "image/jpeg",        # Proper MIME types
-     "data": "<base64-encoded-image>"  # Standard encoding
-   }
  }
  ```

- #### Step 3: Expected Response Generation
- ```python
- # Realistic prediction responses with proper structure
  {
-   "prediction": "tiger_cat",       # ImageNet-style labels
-   "confidence": 0.8742,            # Realistic confidence scores
-   "predicted_label": 282,          # Numeric label indices
-   "model": "microsoft/resnet-18",  # Model identification
-   "mediaType": "image/jpeg"        # Echo input format
  }
  ```

- #### Step 4: PyArrow Table Creation
- ```python
- # Columnar storage for efficient querying
- table = pa.table({
-     "dataset_id": [...],         # Unique dataset identifier
-     "image_id": [...],           # Individual image identifier
-     "api_request": [...],        # JSON-serialized requests
-     "expected_response": [...],  # JSON-serialized expected responses
-     "test_category": [...],      # Category classification
-     "difficulty": [...],         # Complexity indicator
-     # ... additional metadata columns
- })
- ```

- ### 🚀 Usage Guide

- **1. Generate Test Datasets**
  ```bash
- # Create all 100 datasets (~2-5 minutes depending on hardware)
  python scripts/generate_test_datasets.py
-
- # What this creates:
- # - scripts/test_datasets/*.parquet (actual test data)
- # - scripts/test_datasets/*_metadata.json (human-readable info)
- # - scripts/test_datasets/datasets_summary.json (overview)
  ```

- **2. Validate API**
  ```bash
- # Start your ML service
  uvicorn main:app --reload

  # Quick test (5 samples per dataset)
  python scripts/test_datasets.py --quick

- # Full validation (all samples)
  python scripts/test_datasets.py

- # Category-specific testing
  python scripts/test_datasets.py --category edge_case
- python scripts/test_datasets.py --category performance
  ```

- ### 📈 Testing Output and Metrics

- The test runner provides comprehensive validation metrics:

  ```
- 🏁 DATASET TESTING SUMMARY
  ============================================================
  Datasets tested: 100
  Successful datasets: 95
@@ -425,7 +329,7 @@ Test duration: 45.2s
  Performance:
    Avg latency: 123.4ms
    Median latency: 98.7ms
-   Min latency: 45.2ms
    Max latency: 2,341.0ms
    Requests/sec: 27.6

@@ -434,6 +338,29 @@ Category breakdown:
    edge_case: 25 datasets, 76.8% avg success
    performance: 25 datasets, 91.1% avg success
    model_comparison: 25 datasets, 89.3% avg success

- Failed datasets: edge_case_023, edge_case_019, performance_012
  ```
+ # ML Inference Service
+
+ FastAPI service for serving ML models over HTTP. Comes with ResNet-18 for image classification out of the box, but you can swap in any model you want.
+
+ ## Quick Start
+
+ **Local development:**
  ```bash
+ # Install dependencies
  python -m venv .venv
+ source .venv/bin/activate
  pip install -r requirements.txt
+
+ # Download the example model
  bash scripts/model_download.bash
+
+ # Run it
+ uvicorn main:app --reload
  ```

+ Server runs on `http://127.0.0.1:8000`. Check `/docs` for the interactive API documentation.
+
+ **Docker:**
  ```bash
+ # Build
+ docker build -t ml-inference-service:test .
+
+ # Run
+ docker run -d --name ml-inference-test -p 8000:8000 ml-inference-service:test
+
+ # Check logs
+ docker logs -f ml-inference-test
+
+ # Stop
+ docker stop ml-inference-test && docker rm ml-inference-test
  ```

+ ## Testing the API

  ```bash
+ # Using curl
+ curl -X POST http://localhost:8000/predict \
+   -H "Content-Type: application/json" \
+   -d '{
+     "image": {
+       "mediaType": "image/jpeg",
+       "data": "<base64-encoded-image>"
+     }
  }'
  ```

+ Example response:
  ```json
  {
    "prediction": "tiger cat",
+   "confidence": 0.394,
    "predicted_label": 282,
    "model": "microsoft/resnet-18",
    "mediaType": "image/jpeg"
  }
  ```
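
To produce the `<base64-encoded-image>` payload, a minimal client sketch in Python (the `requests` dependency and the `cat.jpg` file name are illustrative assumptions, not part of the repo):

```python
# Hypothetical client: encode an image and call POST /predict.
import base64
import requests

with open("cat.jpg", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

payload = {"image": {"mediaType": "image/jpeg", "data": encoded}}
resp = requests.post("http://127.0.0.1:8000/predict", json=payload)
print(resp.status_code, resp.json())
```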

+ ## Project Structure

+ ```
+ ml-inference-service/
+ ├── main.py                      # Entry point
+ ├── app/
+ │   ├── core/
+ │   │   ├── app.py               # App factory, config, DI, lifecycle
+ │   │   └── logging.py           # Logging setup
+ │   ├── api/
+ │   │   ├── models.py            # Request/response schemas
+ │   │   ├── controllers.py       # Business logic
+ │   │   └── routes/
+ │   │       └── prediction.py    # POST /predict
+ │   └── services/
+ │       ├── base.py              # Abstract InferenceService class
+ │       └── inference.py         # ResNet implementation
+ ├── models/
+ │   └── microsoft/
+ │       └── resnet-18/           # Model weights and config
+ ├── scripts/
+ │   ├── model_download.bash
+ │   ├── generate_test_datasets.py
+ │   └── test_datasets.py
+ ├── Dockerfile                   # Multi-stage build
+ ├── .env.example                 # Environment config template
+ └── requirements.txt
  ```

+ The key design decision here is that `app/core/app.py` consolidates everything: config, dependency injection, lifecycle, and the app factory. This avoids the mess of managing global state across multiple files.

+ ## How to Plug In Your Own Model

+ The whole service is built around one abstract base class: `InferenceService`. Implement it for your model, and everything else just works.

+ ### Step 1: Create Your Service Class

  ```python
+ # app/services/your_model_service.py
+ from app.services.base import InferenceService
+ from app.api.models import ImageRequest, PredictionResponse
+
+ class YourModelService(InferenceService[ImageRequest, PredictionResponse]):
+     def __init__(self, model_name: str):
+         self.model_name = model_name
+         self.model_path = f"models/{model_name}"
+         self.model = None
+         self._is_loaded = False
+
+     def load_model(self) -> None:
+         """Load your model here. Called once at startup."""
+         self.model = load_your_model(self.model_path)
+         self._is_loaded = True
+
+     def predict(self, request: ImageRequest) -> PredictionResponse:
+         """Actual inference happens here. The controller offloads this to a thread pool."""
+         image = decode_base64_image(request.image.data)
+         result = self.model(image)
+
+         return PredictionResponse(
+             prediction=result.label,
+             confidence=result.confidence,
+             predicted_label=result.class_id,
+             model=self.model_name,
+             mediaType=request.image.mediaType
+         )
+
+     @property
+     def is_loaded(self) -> bool:
+         return self._is_loaded
  ```

+ **Important:** Keep `load_model()` and `predict()` synchronous, as the base class declares them. The server already wraps both in `asyncio.to_thread()` (at startup and in the controller, respectively), which runs CPU-heavy inference in a background thread and keeps the server responsive while your model is working.

+ ### Step 2: Register Your Service

+ Open `app/core/app.py` and find the lifespan function:

+ ```python
+ # Change this line:
+ service = ResNetInferenceService(model_name="microsoft/resnet-18")

+ # To this:
+ service = YourModelService(model_name="your-org/your-model")
+ ```

+ That's it. The `/predict` endpoint now serves your model.

+ ### Model Files
+
+ Put your model files under `models/` with the full org/model structure:

  ```
+ models/
+ └── your-org/
+     └── your-model/
+         ├── config.json
+         ├── weights.bin
+         └── (other files)
  ```

+ No renaming, no dropping the org prefix: it just mirrors the Hugging Face structure.
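
If your model lives on the Hugging Face Hub, one way to populate that layout is a snapshot download (a sketch using `huggingface_hub`, roughly what `scripts/model_download.bash` does for ResNet-18; the model id here is hypothetical):

```python
# Sketch: fetch Hub artifacts into models/<org>/<name>/.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="your-org/your-model",           # hypothetical model id
    local_dir="models/your-org/your-model",  # matches the layout above
)
```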
+
+ ## Configuration

+ Settings are managed via environment variables or a `.env` file. See `.env.example` for all available options.

+ **Default values:**
+ - `APP_NAME`: "ML Inference Service"
+ - `APP_VERSION`: "0.1.0"
+ - `DEBUG`: false
+ - `HOST`: "0.0.0.0"
+ - `PORT`: 8000
+ - `MODEL_NAME`: "microsoft/resnet-18"

+ **To customize:**
+ ```bash
+ # Copy the example
+ cp .env.example .env

+ # Edit values
+ vim .env
+ ```

+ Or set environment variables directly:
+ ```bash
+ export MODEL_NAME="google/vit-base-patch16-224"
+ uvicorn main:app --reload
+ ```

+ ## Deployment

+ **Development:**
+ ```bash
+ uvicorn main:app --reload
+ ```

+ **Production:**
+ ```bash
+ gunicorn main:app -w 4 -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000
+ ```

+ The service runs on CPU by default. For GPU inference, install CUDA-enabled PyTorch and modify your service to move the model and tensors to the GPU device, as sketched below.
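
A rough sketch of that device handling (illustrative only; this helper is not part of the repo's code):

```python
# Sketch: prefer CUDA when available, fall back to CPU.
import torch

def pick_device() -> torch.device:
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

# In load_model():  self.model.to(pick_device())
# In predict():     inputs = {k: v.to(pick_device()) for k, v in inputs.items()}
```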
 
 
 

+ **Docker:**
+ - Multi-stage build keeps the image small
+ - Runs as non-root user (`appuser`)
+ - Python dependencies installed in user site-packages
+ - Model files baked into the image

+ ## What Happens When You Start the Server

+ ```
+ INFO: Starting ML Inference Service...
+ INFO: Initializing ResNet service: models/microsoft/resnet-18
+ INFO: Loading model from models/microsoft/resnet-18
+ INFO: Model loaded: 1000 classes
+ INFO: Startup completed successfully
+ INFO: Uvicorn running on http://0.0.0.0:8000
+ ```

+ If you see "Model not found", check that your model files exist at the expected path with the full org/model structure.

+ ## API Reference

+ **Endpoint:** `POST /predict`

+ **Request:**
+ ```json
  {
+   "image": {
+     "mediaType": "image/jpeg",        // or "image/png"
+     "data": "<base64-encoded-image>"
+   }
  }
  ```

+ **Response:**
+ ```json
  {
+   "prediction": "string",        // Human-readable label
+   "confidence": 0.0,             // Softmax probability
+   "predicted_label": 0,          // Numeric class index
+   "model": "org/model-name",     // Model identifier
+   "mediaType": "image/jpeg"      // Echoed from request
  }
  ```

+ **Docs:**
+ - Swagger UI: `http://localhost:8000/docs`
+ - ReDoc: `http://localhost:8000/redoc`
+ - OpenAPI JSON: `http://localhost:8000/openapi.json`
+
+ ## PyArrow Test Datasets

+ We've included a test dataset system for validating your model. It generates 100 standardized test datasets covering normal inputs, edge cases, performance benchmarks, and model comparisons.

+ ### Generate Datasets

  ```bash
  python scripts/generate_test_datasets.py
  ```

+ This creates:
+ - `scripts/test_datasets/*.parquet` - Test data (images, requests, expected responses)
+ - `scripts/test_datasets/*_metadata.json` - Human-readable descriptions
+ - `scripts/test_datasets/datasets_summary.json` - Overview of all datasets
+
+ ### Run Tests
+
  ```bash
+ # Start your service first
  uvicorn main:app --reload

  # Quick test (5 samples per dataset)
  python scripts/test_datasets.py --quick

+ # Full validation
  python scripts/test_datasets.py

+ # Test specific category
  python scripts/test_datasets.py --category edge_case
  ```
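
To inspect or replay a dataset by hand, a rough sketch (assumes the `pyarrow` and `requests` packages, and the `api_request`/`expected_response` columns from the generator's table schema):

```python
# Sketch: replay a few test cases from one parquet dataset against the API.
import json
import pyarrow.parquet as pq
import requests

table = pq.read_table("scripts/test_datasets/standard_test_001.parquet")
for row in table.to_pylist()[:5]:
    request_body = json.loads(row["api_request"])
    expected = json.loads(row["expected_response"])
    resp = requests.post("http://localhost:8000/predict", json=request_body)
    print(resp.status_code, resp.json().get("prediction"),
          "expected:", expected.get("prediction"))
```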

+ ### Dataset Categories (25 datasets each)
+
+ **1. Standard Tests** (`standard_test_*.parquet`)
+ - Normal images: random patterns, shapes, gradients
+ - Common sizes: 224x224, 256x256, 299x299, 384x384
+ - Formats: JPEG, PNG
+ - Purpose: Baseline validation
+
+ **2. Edge Cases** (`edge_case_*.parquet`)
+ - Tiny images (32x32, 1x1)
+ - Huge images (2048x2048)
+ - Extreme aspect ratios (1000x50)
+ - Corrupted data, malformed requests
+ - Purpose: Test error handling

+ **3. Performance Benchmarks** (`performance_test_*.parquet`)
+ - Batch sizes: 1, 5, 10, 25, 50, 100 images
+ - Latency and throughput tracking
+ - Purpose: Performance profiling
+
+ **4. Model Comparisons** (`model_comparison_*.parquet`)
+ - Same inputs across different architectures
+ - Models: ResNet-18/50, ViT, ConvNext, Swin
+ - Purpose: Cross-model benchmarking
+
+ ### Test Output

  ```
+ DATASET TESTING SUMMARY
  ============================================================
  Datasets tested: 100
  Successful datasets: 95

  Performance:
    Avg latency: 123.4ms
    Median latency: 98.7ms
+   p95 latency: 342.1ms
    Max latency: 2,341.0ms
    Requests/sec: 27.6

    edge_case: 25 datasets, 76.8% avg success
    performance: 25 datasets, 91.1% avg success
    model_comparison: 25 datasets, 89.3% avg success
+ ```
+
+ ## Common Issues
+
+ **Port 8000 already in use:**
+ ```bash
+ # Find what's using it
+ lsof -i :8000

+ # Or just use a different port
+ uvicorn main:app --port 8080
  ```
+
+ **Model not loading:**
+ - Check the path: models should be in `models/<org>/<model-name>/`
+ - Make sure you ran `bash scripts/model_download.bash`
+ - Check logs for the exact error
+
+ **Slow inference:**
+ - Inference runs on CPU by default
+ - For GPU: install CUDA PyTorch and modify service to use GPU device
+ - Consider using smaller models or quantization
+
+ ## License
+
+ MIT
app/api/controllers.py CHANGED
@@ -1,75 +1,36 @@
- """
- Controllers for handling API business logic.
- """
- import base64
- import io
  from fastapi import HTTPException
- from PIL import Image
  from app.core.logging import logger
- from app.services.inference import ResNetInferenceService
  from app.api.models import ImageRequest, PredictionResponse


  class PredictionController:
-     """Controller for ML prediction endpoints."""

      @staticmethod
-     async def predict_resnet(
-         request: ImageRequest,
-         resnet_service: ResNetInferenceService
      ) -> PredictionResponse:
-         """
-         Classify an image using ResNet-18 from base64 encoded data.
-         """
          try:
-             # Validate service availability
-             if not resnet_service:
-                 raise HTTPException(
-                     status_code=503,
-                     detail="Service not initialized"
-                 )

-             # Validate media type
              if not request.image.mediaType.startswith('image/'):
-                 raise HTTPException(
-                     status_code=400,
-                     detail=f"Invalid media type: {request.image.mediaType}"
-                 )

-             # Decode base64 image data
-             try:
-                 image_data = base64.b64decode(request.image.data)
-             except Exception as decode_error:
-                 raise HTTPException(
-                     status_code=400,
-                     detail=f"Invalid base64 data: {str(decode_error)}"
-                 )
-
-             # Load and validate image
-             try:
-                 image = Image.open(io.BytesIO(image_data))
-             except Exception as img_error:
-                 raise HTTPException(
-                     status_code=400,
-                     detail=f"Invalid image file: {str(img_error)}"
-                 )
-
-             # Perform prediction
-             result = resnet_service.predict(image)
-
-             # Return structured response
-             return PredictionResponse(
-                 prediction=result["prediction"],
-                 confidence=result["confidence"],
-                 model=result["model"],
-                 predicted_label=result["predicted_label"],
-                 mediaType=request.image.mediaType
-             )

          except HTTPException:
              raise
          except Exception as e:
              logger.error(f"Prediction failed: {e}")
-             raise HTTPException(status_code=500, detail=str(e))
+ """API controllers for request handling and validation."""
+
+ import asyncio
  from fastapi import HTTPException

  from app.core.logging import logger
+ from app.services.base import InferenceService
  from app.api.models import ImageRequest, PredictionResponse


  class PredictionController:
+     """Controller for prediction endpoints."""

      @staticmethod
+     async def predict(
+         request: ImageRequest,
+         service: InferenceService
      ) -> PredictionResponse:
+         """Run inference using the configured service."""
          try:
+             if not service or not service.is_loaded:
+                 raise HTTPException(503, "Service not available")

              if not request.image.mediaType.startswith('image/'):
+                 raise HTTPException(400, f"Invalid media type: {request.image.mediaType}")

+             return await asyncio.to_thread(service.predict, request)

          except HTTPException:
              raise
+         except ValueError as e:
+             logger.error(f"Invalid input: {e}")
+             raise HTTPException(400, str(e))
          except Exception as e:
              logger.error(f"Prediction failed: {e}")
+             raise HTTPException(500, "Internal server error")
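
A quick way to exercise that validation path end to end is FastAPI's `TestClient` (a sketch; assumes the local model files are present so lifespan startup succeeds):

```python
# Sketch: an invalid media type should yield HTTP 400 from the controller.
from fastapi.testclient import TestClient
from app.core.app import create_app

with TestClient(create_app()) as client:  # context manager runs lifespan startup
    resp = client.post("/predict", json={"image": {"mediaType": "text/plain", "data": ""}})
    assert resp.status_code == 400
```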
app/api/routes/prediction.py CHANGED
@@ -1,20 +1,23 @@
- """
- ML Prediction routes.
- """
  from fastapi import APIRouter, Depends

  from app.api.controllers import PredictionController
  from app.api.models import ImageRequest, PredictionResponse
- from app.core.dependencies import get_resnet_service
- from app.services.inference import ResNetInferenceService

  router = APIRouter()


- @router.post("/predict/resnet", response_model=PredictionResponse)
- async def predict_image(
      request: ImageRequest,
-     resnet_service: ResNetInferenceService = Depends(get_resnet_service)
  ):
-     """Classify an image using ResNet-18 from base64 encoded data."""
-     return await PredictionController.predict_resnet(request, resnet_service)

+ """Prediction API routes."""
+
  from fastapi import APIRouter, Depends

  from app.api.controllers import PredictionController
  from app.api.models import ImageRequest, PredictionResponse
+ from app.core.dependencies import get_inference_service
+ from app.services.base import InferenceService

  router = APIRouter()


+ @router.post("/predict", response_model=PredictionResponse)
+ async def predict(
      request: ImageRequest,
+     service: InferenceService = Depends(get_inference_service)
  ):
+     """
+     Run inference on base64-encoded image.
+
+     Returns prediction, confidence, predicted label, model name, and media type.
+     """
+     return await PredictionController.predict(request, service)
app/api/routes/resnet_service_manager.py DELETED
@@ -1,19 +0,0 @@
- # """
- # Dependency injection for FastAPI.
- # """
- # from typing import Optional
- # from app.services.inference import ResNetInferenceService
- #
- # # Global service instance
- # _resnet_service: Optional[ResNetInferenceService] = None
- #
- #
- # def get_resnet_service() -> Optional[ResNetInferenceService]:
- #     """Get the ResNet service instance."""
- #     return _resnet_service
- #
- #
- # def set_resnet_service(service: ResNetInferenceService) -> None:
- #     """Set the global ResNet service instance."""
- #     global _resnet_service
- #     _resnet_service = service
app/core/app.py CHANGED
@@ -1,16 +1,63 @@
- """
- FastAPI application factory.
- """
  from fastapi import FastAPI

- from app.core.config import settings
- from app.core.lifespan import lifespan
  from app.api.routes import prediction


- def create_app() -> FastAPI:
-     """Application factory."""

      app = FastAPI(
          title=settings.app_name,
          description="ML inference service for image classification",
@@ -19,7 +66,6 @@ def create_app() -> FastAPI:
          lifespan=lifespan
      )

-     # Include only prediction router
      app.include_router(prediction.router)

-     return app

+ """FastAPI application factory and core infrastructure."""
+
+ import asyncio
+ import warnings
+ from contextlib import asynccontextmanager
+ from typing import AsyncGenerator, Optional
+
  from fastapi import FastAPI
+ from pydantic import Field
+ from pydantic_settings import BaseSettings

+ from app.core.logging import logger
+ from app.core.dependencies import set_inference_service
+ from app.services.inference import ResNetInferenceService
  from app.api.routes import prediction


+ class Settings(BaseSettings):
+     """Application settings. Override via environment variables or .env file."""
+
+     app_name: str = Field(default="ML Inference Service")
+     app_version: str = Field(default="0.1.0")
+     debug: bool = Field(default=False)
+     host: str = Field(default="0.0.0.0")
+     port: int = Field(default=8000)
+
+     class Config:
+         env_file = ".env"
+
+
+ settings = Settings()
+
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+     """Application lifecycle: startup/shutdown."""
+     logger.info("Starting ML Inference Service...")

+     try:
+         with warnings.catch_warnings():
+             warnings.filterwarnings("ignore", category=FutureWarning)
+
+             # Replace ResNetInferenceService with your own implementation
+             service = ResNetInferenceService(model_name="microsoft/resnet-18")
+             await asyncio.to_thread(service.load_model)
+             set_inference_service(service)
+
+         logger.info("Startup completed successfully")
+
+     except Exception as e:
+         logger.error(f"Startup failed: {e}")
+         raise
+
+     yield
+
+     logger.info("Shutting down...")
+
+
+ def create_app() -> FastAPI:
+     """Create and configure FastAPI application."""
      app = FastAPI(
          title=settings.app_name,
          description="ML inference service for image classification",

          lifespan=lifespan
      )

      app.include_router(prediction.router)

+     return app
app/core/config.py DELETED
@@ -1,29 +0,0 @@
- """
- Basic configuration management.
-
- Starting simple - just app settings. We'll expand as needed.
- """
-
- from pydantic import Field
- from pydantic_settings import BaseSettings  # Changed import
-
-
- class Settings(BaseSettings):
-     """Application settings with environment variable support."""
-
-     # Basic app settings
-     app_name: str = Field(default="ML Inference Service", description="Application name")
-     app_version: str = Field(default="0.1.0", description="Application version")
-     debug: bool = Field(default=False, description="Debug mode")
-
-     # Server settings
-     host: str = Field(default="0.0.0.0", description="Server host")
-     port: int = Field(default=8000, description="Server port")
-
-     class Config:
-         """Load from .env file if it exists."""
-         env_file = ".env"
-
-
- # Global settings instance
- settings = Settings()
app/core/dependencies.py CHANGED
@@ -1,19 +1,17 @@
- """
- Dependency injection for FastAPI.
- """
  from typing import Optional
- from app.services.inference import ResNetInferenceService

- # Global service instance
- _resnet_service: Optional[ResNetInferenceService] = None


- def get_resnet_service() -> Optional[ResNetInferenceService]:
-     """Get the ResNet service instance."""
-     return _resnet_service


- def set_resnet_service(service: ResNetInferenceService) -> None:
-     """Set the global ResNet service instance."""
-     global _resnet_service
-     _resnet_service = service

+ """Dependency injection for services."""
+
  from typing import Optional
+ from app.services.base import InferenceService

+ _inference_service: Optional[InferenceService] = None


+ def get_inference_service() -> Optional[InferenceService]:
+     """Get inference service for dependency injection."""
+     return _inference_service


+ def set_inference_service(service: InferenceService) -> None:
+     """Set inference service. Called internally during startup."""
+     global _inference_service
+     _inference_service = service
app/core/lifespan.py DELETED
@@ -1,43 +0,0 @@
- """
- Application lifespan management.
- """
- import warnings
- from contextlib import asynccontextmanager
- from typing import AsyncGenerator
-
- from fastapi import FastAPI
-
- from app.core.logging import logger
- from app.core.dependencies import set_resnet_service
- from app.services.inference import ResNetInferenceService
-
-
- @asynccontextmanager
- async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
-     """Application lifespan manager."""
-
-     # Startup
-     logger.info("Starting ML Inference Service...")
-
-     try:
-         with warnings.catch_warnings():
-             warnings.filterwarnings("ignore", category=FutureWarning)
-
-             # Initialize and load ResNet service
-             resnet_service = ResNetInferenceService(
-                 model_name="microsoft/resnet-18",
-                 use_local_model=True
-             )
-             resnet_service.load_model()
-             set_resnet_service(resnet_service)
-
-         logger.info("Startup completed successfully")
-
-     except Exception as e:
-         logger.error(f"Startup failed: {e}")
-         raise
-
-     yield  # App runs here
-
-     # Shutdown
-     logger.info("Shutting down...")
app/core/logging.py CHANGED
@@ -1,49 +1,26 @@
- """
- Logging configuration for the application.
- """
  import logging
  import sys
- from typing import Optional

- from app.core.config import settings


- class LoggerSetup:
-     """Logger setup utility class."""
-
-     @staticmethod
-     def setup_logging(
-         logger_name: Optional[str] = None,
-         level: Optional[str] = None,
-         format_string: Optional[str] = None
-     ) -> logging.Logger:
-         """Set up and configure a logger."""
-         logger = logging.getLogger(logger_name or settings.app_name)
-
-         # Avoid duplicate handlers
-         if logger.handlers:
-             return logger
-
-         # Set level
-         log_level = getattr(logging, (level or "INFO").upper())
-         logger.setLevel(log_level)
-
-         # Create console handler
-         handler = logging.StreamHandler(sys.stdout)
-         handler.setLevel(log_level)
-
-         # Create formatter
-         formatter = logging.Formatter(
-             format_string or "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-         )
-         handler.setFormatter(formatter)

-         # Add handler to logger
-         logger.addHandler(handler)

-         return logger


- # Create application logger
- logger = LoggerSetup.setup_logging()

+ """Logging configuration."""

  import logging
  import sys


+ def setup_logging(logger_name: str = "ML Inference Service") -> logging.Logger:
+     """Setup and configure logger."""
+     logger = logging.getLogger(logger_name)

+     if logger.handlers:
+         return logger

+     logger.setLevel(logging.INFO)
+     handler = logging.StreamHandler(sys.stdout)
+     handler.setLevel(logging.INFO)
+     formatter = logging.Formatter(
+         "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+     )
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)

+     return logger


+ logger = setup_logging()
app/services/base.py ADDED
@@ -0,0 +1,30 @@
+ """Abstract base class for ML inference services."""
+
+ from abc import ABC, abstractmethod
+ from typing import Generic, TypeVar
+ from pydantic import BaseModel
+
+ TRequest = TypeVar('TRequest', bound=BaseModel)
+ TResponse = TypeVar('TResponse', bound=BaseModel)
+
+
+ class InferenceService(ABC, Generic[TRequest, TResponse]):
+     """
+     Base class for inference services. Subclass this to integrate your model.
+     """
+
+     @abstractmethod
+     def load_model(self) -> None:
+         """Load model weights and processors. Called once at startup."""
+         pass
+
+     @abstractmethod
+     def predict(self, request: TRequest) -> TResponse:
+         """Run inference and return typed response."""
+         pass
+
+     @property
+     @abstractmethod
+     def is_loaded(self) -> bool:
+         """Check if model is loaded and ready."""
+         pass
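
For reference, a minimal stub that satisfies this interface (hypothetical, e.g. for unit-testing the wiring without loading any weights):

```python
# Hypothetical stub implementation of InferenceService.
from app.api.models import ImageRequest, PredictionResponse
from app.services.base import InferenceService

class EchoService(InferenceService[ImageRequest, PredictionResponse]):
    def load_model(self) -> None:
        self._loaded = True

    def predict(self, request: ImageRequest) -> PredictionResponse:
        return PredictionResponse(
            prediction="stub",
            confidence=1.0,
            predicted_label=0,
            model="stub/echo",
            mediaType=request.image.mediaType,
        )

    @property
    def is_loaded(self) -> bool:
        return getattr(self, "_loaded", False)
```

In tests, such a stub can replace the real service via FastAPI's `app.dependency_overrides[get_inference_service]`.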
app/services/inference.py CHANGED
@@ -1,152 +1,79 @@
- """
- Inference service for machine learning models.
-
- This service handles the business logic for ML inference,
- following the Single Responsibility Principle.
- """
  import os
- from typing import Dict, Any
  import torch
  from PIL import Image
  from transformers import AutoImageProcessor, ResNetForImageClassification

  from app.core.logging import logger


- class ResNetInferenceService:
-     """
-     ResNet inference service.
-
-     Handles loading and inference for ResNet models.
-     Follows the Singleton pattern - loads model once.
-     """
-
-     def __init__(self, model_name: str = "microsoft/resnet-18", use_local_model: bool = True):
-         """
-         Initialize the ResNet service.
-
-         Args:
-             model_name: HuggingFace model identifier
-         """
          self.model_name = model_name
-         self.use_local_model = use_local_model
          self.model = None
          self.processor = None
          self._is_loaded = False
-
-         if use_local_model:
-             self.model_path = os.path.join("models", model_name.split("/")[-1])
-             logger.info(f"Initializing ResNet service with local model: {self.model_path}")
-         else:
-             self.model_path = model_name
-             logger.info(f"Initializing ResNet service with remote model: {model_name}")

      def load_model(self) -> None:
-         """
-         Load the ResNet model and processor.
-
-         This method loads the model once and reuses it for all requests.
-         """
          if self._is_loaded:
-             logger.debug("Model already loaded, skipping...")
              return

-         try:
-             if self.use_local_model:
-                 if not os.path.exists(self.model_path):
-                     raise FileNotFoundError(f"Local model directory not found: {self.model_path}")
-
-                 config_path = os.path.join(self.model_path, "config.json")
-                 if not os.path.exists(config_path):
-                     raise FileNotFoundError(f"Model config not found: {config_path}")
-
-                 logger.info(f"Loading ResNet model from local directory: {self.model_path}")
-             else:
-                 logger.info(f"Loading ResNet model from HuggingFace Hub: {self.model_name}")
-
-             # Suppress warnings during model loading
-             import warnings
-             with warnings.catch_warnings():
-                 warnings.filterwarnings("ignore", category=FutureWarning)
-                 warnings.filterwarnings("ignore", message="Could not find image processor class")
-
-                 # Load processor and model from local directory or remote
-                 self.processor = AutoImageProcessor.from_pretrained(
-                     self.model_path,
-                     local_files_only=self.use_local_model
-                 )
-                 self.model = ResNetForImageClassification.from_pretrained(
-                     self.model_path,
-                     local_files_only=self.use_local_model
-                 )
-
-             self._is_loaded = True
-             logger.info("ResNet model loaded successfully")
-             logger.info(f"Model architecture: {self.model.config.architectures}")
-             logger.info(f"Model has {len(self.model.config.id2label)} classes")
-
-         except Exception as e:
-             logger.error(f"Failed to load ResNet model: {e}")
-             if self.use_local_model:
-                 logger.error("Hint: Make sure the model was downloaded correctly with dwl.bash")
-             raise
-
-
-     def predict(self, image: Image.Image) -> Dict[str, Any]:
-         """
-         Perform inference on an image.
-
-         Args:
-             image: PIL Image to classify
-
-         Returns:
-             Dictionary containing prediction results
-
-         Raises:
-             RuntimeError: If model is not loaded
-             ValueError: If image processing fails
-         """
-         if not self._is_loaded:
-             logger.info("Model not loaded, loading now...")
-             self.load_model()
-
-         try:
-             logger.debug("Starting ResNet inference")
-
-             if image.mode != 'RGB':
-                 image = image.convert('RGB')
-                 logger.debug(f"Converted image from {image.mode} to RGB")
-
-             inputs = self.processor(image, return_tensors="pt")
-
-             # Perform inference
-             with torch.no_grad():
-                 logits = self.model(**inputs).logits
-
-             # Get prediction
-             predicted_label = logits.argmax(-1).item()
-             predicted_class = self.model.config.id2label[predicted_label]
-
-             # Calculate confidence score
-             probabilities = torch.nn.functional.softmax(logits, dim=-1)
-             confidence = probabilities[0][predicted_label].item()
-
-             result = {
-                 "prediction": predicted_class,
-                 "confidence": round(confidence, 4),
-                 "model": self.model_name,
-                 "predicted_label": predicted_label
-             }
-
-             logger.debug(f"Inference completed: {predicted_class} (confidence: {confidence:.4f})")
-             return result
-
-         except Exception as e:
-             logger.error(f"Inference failed: {e}")
-             raise ValueError(f"Failed to process image: {str(e)}")

      @property
      def is_loaded(self) -> bool:
-         """Check if model is loaded."""
          return self._is_loaded

+ """ResNet inference service implementation."""

  import os
+ import base64
+ from io import BytesIO
  import torch
  from PIL import Image
  from transformers import AutoImageProcessor, ResNetForImageClassification

  from app.core.logging import logger
+ from app.services.base import InferenceService
+ from app.api.models import ImageRequest, PredictionResponse


+ class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
+     """ResNet-18 inference service for image classification."""

+     def __init__(self, model_name: str = "microsoft/resnet-18"):
          self.model_name = model_name
          self.model = None
          self.processor = None
          self._is_loaded = False
+         self.model_path = os.path.join("models", model_name)
+         logger.info(f"Initializing ResNet service: {self.model_path}")

      def load_model(self) -> None:
          if self._is_loaded:
              return

+         if not os.path.exists(self.model_path):
+             raise FileNotFoundError(f"Model not found: {self.model_path}")
+
+         config_path = os.path.join(self.model_path, "config.json")
+         if not os.path.exists(config_path):
+             raise FileNotFoundError(f"Config not found: {config_path}")
+
+         logger.info(f"Loading model from {self.model_path}")
+
+         import warnings
+         with warnings.catch_warnings():
+             warnings.filterwarnings("ignore", category=FutureWarning)
+             self.processor = AutoImageProcessor.from_pretrained(
+                 self.model_path, local_files_only=True
+             )
+             self.model = ResNetForImageClassification.from_pretrained(
+                 self.model_path, local_files_only=True
+             )
+
+         self._is_loaded = True
+         logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")
+
+     def predict(self, request: ImageRequest) -> PredictionResponse:
+         image_data = base64.b64decode(request.image.data)
+         image = Image.open(BytesIO(image_data))
+
+         if image.mode != 'RGB':
+             image = image.convert('RGB')
+
+         inputs = self.processor(image, return_tensors="pt")
+
+         with torch.no_grad():
+             logits = self.model(**inputs).logits
+
+         predicted_label = logits.argmax(-1).item()
+         predicted_class = self.model.config.id2label[predicted_label]
+         probabilities = torch.nn.functional.softmax(logits, dim=-1)
+         confidence = probabilities[0][predicted_label].item()
+
+         return PredictionResponse(
+             prediction=predicted_class,
+             confidence=round(confidence, 4),
+             model=self.model_name,
+             predicted_label=predicted_label,
+             mediaType=request.image.mediaType
+         )

      @property
      def is_loaded(self) -> bool:
          return self._is_loaded
test_main.http CHANGED
@@ -1,6 +1,7 @@
- # Test ResNet Prediction Endpoint

- POST http://127.0.0.1:8000/predict/resnet
  Content-Type: application/json

  {

+ # Test Prediction Endpoint
+ # Works with any model configured at startup (default: ResNet-18)

+ POST http://127.0.0.1:8000/predict
  Content-Type: application/json

  {