Daniel Huynh committed
Commit cb92718 · 0 parent(s)
Deploy FastAPI derm backend to Hugging Face Spaces
- .dockerignore +5 -0
- .gitattributes +1 -0
- .gitignore +24 -0
- Dockerfile +17 -0
- README.md +125 -0
- app/__init__.py +0 -0
- app/config.py +26 -0
- app/main.py +71 -0
- app/models/__init__.py +0 -0
- app/models/mlp_head.py +68 -0
- app/schemas.py +15 -0
- app/services/__init__.py +0 -0
- app/services/derm_backbone.py +54 -0
- app/services/predictor.py +89 -0
- app/services/preprocessing.py +51 -0
- class_names.json +25 -0
- derm_foundation_mlp_head.pt +3 -0
- requirements.txt +9 -0
- scripts/test_request.py +15 -0
.dockerignore
ADDED
|
@@ -0,0 +1,5 @@
| 1 |
+
.env
|
| 2 |
+
.git/
|
| 3 |
+
__pycache__/
|
| 4 |
+
*.pyc
|
| 5 |
+
.DS_Store
|
.gitattributes
ADDED
|
@@ -0,0 +1 @@
| 1 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,24 @@
| 1 |
+
# Secrets
|
| 2 |
+
.env
|
| 3 |
+
|
| 4 |
+
# Python
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.pyc
|
| 7 |
+
*.pyo
|
| 8 |
+
*.pyd
|
| 9 |
+
|
| 10 |
+
# Virtual environments
|
| 11 |
+
venv/
|
| 12 |
+
.venv/
|
| 13 |
+
env/
|
| 14 |
+
|
| 15 |
+
# Jupyter
|
| 16 |
+
.ipynb_checkpoints/
|
| 17 |
+
|
| 18 |
+
# Mac / Windows
|
| 19 |
+
.DS_Store
|
| 20 |
+
Thumbs.db
|
| 21 |
+
|
| 22 |
+
# Hugging Face / model cache
|
| 23 |
+
.cache/
|
| 24 |
+
huggingface/
|
Dockerfile
ADDED
|
@@ -0,0 +1,17 @@
| 1 |
+
FROM python:3.13.13-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
ENV PYTHONUNBUFFERED=1
|
| 6 |
+
ENV PIP_NO_CACHE_DIR=1
|
| 7 |
+
|
| 8 |
+
COPY requirements.txt .
|
| 9 |
+
|
| 10 |
+
RUN pip install --upgrade pip
|
| 11 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 12 |
+
|
| 13 |
+
COPY . .
|
| 14 |
+
|
| 15 |
+
EXPOSE 7860
|
| 16 |
+
|
| 17 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,125 @@
| 1 |
+
---
|
| 2 |
+
title: Basic Docker SDK Space
|
| 3 |
+
emoji: 🐳
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: gray
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# Derm Foundation FastAPI Two-Stage Classifier
|
| 11 |
+
|
| 12 |
+
This project deploys a two-stage inference pipeline:
|
| 13 |
+
|
| 14 |
+
```text
|
| 15 |
+
image -> Google Derm Foundation SavedModel -> embedding -> PyTorch MLP head -> class probabilities
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
The preprocessing follows the notebook pipeline:
|
| 19 |
+
|
| 20 |
+
```text
|
| 21 |
+
RGB -> resize 448x448 -> PNG bytes -> tf.train.Example with key image/encoded
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
## Project structure
|
| 25 |
+
|
| 26 |
+
```text
|
| 27 |
+
derm_fastapi_project/
|
| 28 |
+
app/
|
| 29 |
+
main.py # FastAPI app and endpoints
|
| 30 |
+
config.py # Environment settings
|
| 31 |
+
schemas.py # API response models
|
| 32 |
+
services/
|
| 33 |
+
preprocessing.py # image -> serialized tf.train.Example
|
| 34 |
+
derm_backbone.py # Derm Foundation wrapper
|
| 35 |
+
predictor.py # two-stage sequential forward pass
|
| 36 |
+
models/
|
| 37 |
+
mlp_head.py # load model_state_dict from .pt checkpoint
|
| 38 |
+
scripts/
|
| 39 |
+
test_request.py # local API test client
|
| 40 |
+
requirements.txt
|
| 41 |
+
class_names.json # replace with your real class order
|
| 42 |
+
.env.example
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## Setup
|
| 46 |
+
|
| 47 |
+
Create a virtual environment, then install dependencies:
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
pip install -r requirements.txt
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
Put your PyTorch checkpoint in the project root:
|
| 54 |
+
|
| 55 |
+
```text
|
| 56 |
+
derm_foundation_mlp_head.pt
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
The checkpoint should contain:
|
| 60 |
+
|
| 61 |
+
```python
|
| 62 |
+
{
|
| 63 |
+
"model_state_dict": mlp_head.state_dict(),
|
| 64 |
+
# optional but recommended:
|
| 65 |
+
"class_names": [...]
|
| 66 |
+
}
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
If the checkpoint does not contain `class_names`, edit `class_names.json` so the order exactly matches your training label order.
|
| 70 |
+
|
| 71 |
+
## Hugging Face token
|
| 72 |
+
|
| 73 |
+
Do not put your token in the source code.
|
| 74 |
+
|
| 75 |
+
Use an environment variable:
|
| 76 |
+
|
| 77 |
+
```bash
|
| 78 |
+
export HF_TOKEN="hf_your_token_here"
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
You must already have access to `google/derm-foundation` on Hugging Face.
|
| 82 |
+
|
| 83 |
+
## Run the API
|
| 84 |
+
|
| 85 |
+
```bash
|
| 86 |
+
uvicorn app.main:app --host 0.0.0.0 --port 8000
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
Test in browser:
|
| 90 |
+
|
| 91 |
+
```text
|
| 92 |
+
http://127.0.0.1:8000/docs
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
Test with Python:
|
| 96 |
+
|
| 97 |
+
```bash
|
| 98 |
+
python scripts/test_request.py path/to/image.jpg
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
## API endpoint
|
| 102 |
+
|
| 103 |
+
### POST `/predict`
|
| 104 |
+
|
| 105 |
+
Input: multipart image upload named `file`.
|
| 106 |
+
|
| 107 |
+
Output:
|
| 108 |
+
|
| 109 |
+
```json
|
| 110 |
+
{
|
| 111 |
+
"predicted_index": 0,
|
| 112 |
+
"predicted_class": "class_0",
|
| 113 |
+
"confidence": 0.91,
|
| 114 |
+
"probabilities": [
|
| 115 |
+
{"index": 0, "class_name": "class_0", "probability": 0.91},
|
| 116 |
+
{"index": 1, "class_name": "class_1", "probability": 0.02}
|
| 117 |
+
]
|
| 118 |
+
}
|
| 119 |
+
```
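A client may want to sanity-check a `/predict` response before trusting it. The sketch below is a hypothetical helper (`check_prediction` is not part of this repository) that checks the invariants implied by the output shape above: the top-level fields should describe the highest-probability entry, and a full softmax output should sum to roughly 1. The tolerance value is an assumption.

```python
import math


def check_prediction(payload: dict, tol: float = 1e-3) -> None:
    """Raise ValueError if a /predict response is internally inconsistent."""
    probs = payload["probabilities"]
    if not probs:
        raise ValueError("probabilities is empty")

    # The top-level fields should agree with the highest-probability entry.
    best = max(probs, key=lambda p: p["probability"])
    if best["index"] != payload["predicted_index"]:
        raise ValueError("predicted_index is not the argmax entry")
    if best["class_name"] != payload["predicted_class"]:
        raise ValueError("predicted_class does not match the argmax entry")
    if not math.isclose(best["probability"], payload["confidence"], abs_tol=tol):
        raise ValueError("confidence does not match the argmax probability")

    # A complete softmax output sums to 1 up to floating-point error.
    total = sum(p["probability"] for p in probs)
    if not math.isclose(total, 1.0, abs_tol=tol):
        raise ValueError(f"probabilities sum to {total}, expected 1.0")
```

The JSON example above is abbreviated to two entries, so it would fail the sum check; a real response carries one entry per class.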
|
| 120 |
+
|
| 121 |
+
## Important note about the MLP head
|
| 122 |
+
|
| 123 |
+
`app/models/mlp_head.py` defines `DermFoundationMLPHead` and loads `model_state_dict` into it with `strict=True`.
|
| 124 |
+
It assumes four Linear layers (hidden sizes 512, 256, 128) with ReLU and Dropout(0.6) after each hidden layer.
|
| 125 |
+
If your original head used a different activation, BatchNorm, or a more complex custom architecture, replace `DermFoundationMLPHead` with the exact same class used during training.
|
app/__init__.py
ADDED
|
File without changes
|
app/config.py
ADDED
|
@@ -0,0 +1,26 @@
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass(frozen=True)
|
| 8 |
+
class Settings:
|
| 9 |
+
# Hugging Face / Derm Foundation
|
| 10 |
+
derm_model_id: str = os.getenv("DERM_MODEL_ID", "google/derm-foundation")
|
| 11 |
+
hf_token: Optional[str] = os.getenv("HF_TOKEN")
|
| 12 |
+
local_files_only: bool = os.getenv("HF_LOCAL_FILES_ONLY", "false").lower() == "true"
|
| 13 |
+
|
| 14 |
+
# Model artifacts
|
| 15 |
+
head_checkpoint_path: Path = Path(os.getenv("HEAD_CHECKPOINT_PATH", "derm_foundation_mlp_head.pt"))
|
| 16 |
+
class_names_path: Path = Path(os.getenv("CLASS_NAMES_PATH", "class_names.json"))
|
| 17 |
+
|
| 18 |
+
# Inference
|
| 19 |
+
image_size: int = int(os.getenv("DERM_IMAGE_SIZE", "448"))
|
| 20 |
+
device: str = os.getenv("TORCH_DEVICE", "auto")
|
| 21 |
+
|
| 22 |
+
# API
|
| 23 |
+
cors_origins: str = os.getenv("CORS_ORIGINS", "*")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
settings = Settings()
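One property of the `Settings` dataclass above is worth noting: the `os.getenv(...)` defaults sit in the class body, so they are evaluated once, when `app/config.py` is first imported. Environment variables changed after that point are not picked up, even by a fresh `Settings()`. The following is a minimal sketch of an alternative, assuming you want the environment read at instantiation time instead; `LazySettings` is a hypothetical name and only a subset of the fields is shown.

```python
import os
from dataclasses import dataclass, field


@dataclass(frozen=True)
class LazySettings:
    # Hypothetical variant: each default_factory runs when LazySettings() is
    # called, so the environment is read at instantiation, not at import.
    derm_model_id: str = field(
        default_factory=lambda: os.getenv("DERM_MODEL_ID", "google/derm-foundation")
    )
    image_size: int = field(
        default_factory=lambda: int(os.getenv("DERM_IMAGE_SIZE", "448"))
    )
    local_files_only: bool = field(
        default_factory=lambda: os.getenv("HF_LOCAL_FILES_ONLY", "false").lower() == "true"
    )
```

For a Space whose variables are fixed before the container starts, the import-time behaviour is harmless; the difference mostly matters in tests that patch the environment.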
|
app/main.py
ADDED
|
@@ -0,0 +1,71 @@
| 1 |
+
from fastapi import FastAPI, File, HTTPException, UploadFile
|
| 2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
|
| 4 |
+
from app.config import settings
|
| 5 |
+
from app.schemas import PredictionResponse
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
app = FastAPI(
|
| 9 |
+
title="Derm Foundation Classifier API",
|
| 10 |
+
description="Derm Foundation embedding backbone + PyTorch MLP head.",
|
| 11 |
+
version="1.0.0",
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
app.add_middleware(
|
| 16 |
+
CORSMiddleware,
|
| 17 |
+
allow_origins=[origin.strip() for origin in settings.cors_origins.split(",")],
|
| 18 |
+
allow_credentials=False,
|
| 19 |
+
allow_methods=["*"],
|
| 20 |
+
allow_headers=["*"],
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
app.state.predictor = None
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_predictor():
|
| 28 |
+
if app.state.predictor is None:
|
| 29 |
+
print("Loading TwoStageDermPredictor...", flush=True)
|
| 30 |
+
|
| 31 |
+
from app.services.predictor import TwoStageDermPredictor
|
| 32 |
+
|
| 33 |
+
app.state.predictor = TwoStageDermPredictor(
|
| 34 |
+
derm_model_id=settings.derm_model_id,
|
| 35 |
+
head_checkpoint_path=str(settings.head_checkpoint_path),
|
| 36 |
+
hf_token=settings.hf_token,
|
| 37 |
+
local_files_only=settings.local_files_only,
|
| 38 |
+
image_size=settings.image_size,
|
| 39 |
+
device_name=settings.device,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
print("TwoStageDermPredictor loaded.", flush=True)
|
| 43 |
+
|
| 44 |
+
return app.state.predictor
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@app.get("/")
|
| 48 |
+
def root():
|
| 49 |
+
return {"message": "Derm Foundation API is running"}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@app.get("/health")
|
| 53 |
+
def health():
|
| 54 |
+
return {"status": "ok"}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@app.post("/predict", response_model=PredictionResponse)
|
| 58 |
+
async def predict(file: UploadFile = File(...)):
|
| 59 |
+
if file.content_type is not None and not file.content_type.startswith("image/"):
|
| 60 |
+
raise HTTPException(status_code=400, detail="Uploaded file must be an image.")
|
| 61 |
+
|
| 62 |
+
image_bytes = await file.read()
|
| 63 |
+
|
| 64 |
+
if not image_bytes:
|
| 65 |
+
raise HTTPException(status_code=400, detail="Uploaded image is empty.")
|
| 66 |
+
|
| 67 |
+
try:
|
| 68 |
+
predictor = get_predictor()
|
| 69 |
+
return predictor.predict(image_bytes)
|
| 70 |
+
except Exception as exc:
|
| 71 |
+
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
app/models/__init__.py
ADDED
|
File without changes
|
app/models/mlp_head.py
ADDED
|
@@ -0,0 +1,68 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
DROPOUT = 0.6
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DermFoundationMLPHead(nn.Sequential):
|
| 9 |
+
"""
|
| 10 |
+
Exact MLP head used after Derm Foundation embeddings.
|
| 11 |
+
|
| 12 |
+
Architecture:
|
| 13 |
+
Linear(input_dim, 512) -> ReLU -> Dropout(0.6)
|
| 14 |
+
Linear(512, 256) -> ReLU -> Dropout(0.6)
|
| 15 |
+
Linear(256, 128) -> ReLU -> Dropout(0.6)
|
| 16 |
+
Linear(128, num_classes)
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(self, input_dim: int, num_classes: int):
|
| 20 |
+
super().__init__(
|
| 21 |
+
nn.Linear(input_dim, 512),
|
| 22 |
+
nn.ReLU(),
|
| 23 |
+
nn.Dropout(DROPOUT),
|
| 24 |
+
|
| 25 |
+
nn.Linear(512, 256),
|
| 26 |
+
nn.ReLU(),
|
| 27 |
+
nn.Dropout(DROPOUT),
|
| 28 |
+
|
| 29 |
+
nn.Linear(256, 128),
|
| 30 |
+
nn.ReLU(),
|
| 31 |
+
nn.Dropout(DROPOUT),
|
| 32 |
+
|
| 33 |
+
nn.Linear(128, num_classes),
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def build_mlp_head_from_checkpoint(
|
| 38 |
+
checkpoint_path: str,
|
| 39 |
+
device: torch.device,
|
| 40 |
+
) -> tuple[nn.Module, dict]:
|
| 41 |
+
"""
|
| 42 |
+
Load derm_foundation_mlp_head.pt.
|
| 43 |
+
|
| 44 |
+
Expected checkpoint format:
|
| 45 |
+
{
|
| 46 |
+
"model_state_dict": model.state_dict(),
|
| 47 |
+
...
|
| 48 |
+
}
|
| 49 |
+
"""
|
| 50 |
+
checkpoint = torch.load(
|
| 51 |
+
checkpoint_path,
|
| 52 |
+
map_location=device,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
state_dict = checkpoint["model_state_dict"]
|
| 56 |
+
|
| 57 |
+
input_dim = int(state_dict["0.weight"].shape[1])
|
| 58 |
+
num_classes = int(state_dict["9.weight"].shape[0])
|
| 59 |
+
|
| 60 |
+
head = DermFoundationMLPHead(
|
| 61 |
+
input_dim=input_dim,
|
| 62 |
+
num_classes=num_classes,
|
| 63 |
+
).to(device)
|
| 64 |
+
|
| 65 |
+
head.load_state_dict(state_dict, strict=True)
|
| 66 |
+
head.eval()
|
| 67 |
+
|
| 68 |
+
return head, checkpoint
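The lookups `state_dict["0.weight"]` and `state_dict["9.weight"]` rely on how `nn.Sequential` names parameters: each key is prefixed with the layer's position, and only layers that own parameters contribute keys. The sketch below reproduces that naming in plain Python, without importing torch, to show why the last Linear layer lands at index 9. `LAYERS` and `expected_state_dict_keys` are illustrative names, not part of the repository.

```python
# Layer kinds in the same order as DermFoundationMLPHead's nn.Sequential.
LAYERS = [
    "Linear", "ReLU", "Dropout",
    "Linear", "ReLU", "Dropout",
    "Linear", "ReLU", "Dropout",
    "Linear",
]


def expected_state_dict_keys(layers: list) -> list:
    """List the parameter keys a Sequential with this layout would expose."""
    keys = []
    for position, kind in enumerate(layers):
        # ReLU and Dropout hold no parameters, so only Linear contributes keys.
        if kind == "Linear":
            keys.append(f"{position}.weight")
            keys.append(f"{position}.bias")
    return keys
```

If the training-time head had a different layout (for example, no Dropout layers), the positions would shift and `load_state_dict(..., strict=True)` would report missing or unexpected keys.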
|
app/schemas.py
ADDED
|
@@ -0,0 +1,15 @@
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ClassProbability(BaseModel):
|
| 6 |
+
index: int
|
| 7 |
+
class_name: str
|
| 8 |
+
probability: float
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PredictionResponse(BaseModel):
|
| 12 |
+
predicted_index: int
|
| 13 |
+
predicted_class: str
|
| 14 |
+
confidence: float
|
| 15 |
+
probabilities: List[ClassProbability]
|
app/services/__init__.py
ADDED
|
File without changes
|
app/services/derm_backbone.py
ADDED
|
@@ -0,0 +1,54 @@
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import tensorflow as tf
|
| 6 |
+
from huggingface_hub import snapshot_download
|
| 7 |
+
from app.services.preprocessing import image_bytes_to_tf_string_tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DermFoundationBackbone:
|
| 11 |
+
"""
|
| 12 |
+
Thin wrapper around the Google Derm Foundation SavedModel.
|
| 13 |
+
It converts image bytes into the model's serialized tf.Example input
|
| 14 |
+
and returns the 6144-d embedding.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
repo_id: str = "google/derm-foundation",
|
| 20 |
+
token: str | None = None,
|
| 21 |
+
local_files_only: bool = False,
|
| 22 |
+
image_size: int = 448,
|
| 23 |
+
) -> None:
|
| 24 |
+
|
| 25 |
+
self.repo_id = repo_id
|
| 26 |
+
self.image_size: Tuple[int, int] = (image_size, image_size)
|
| 27 |
+
|
| 28 |
+
model_path = snapshot_download(
|
| 29 |
+
repo_id=repo_id,
|
| 30 |
+
token=token,
|
| 31 |
+
local_files_only=local_files_only,
|
| 32 |
+
)
|
| 33 |
+
self.model_path = Path(model_path)
|
| 34 |
+
self.model = tf.saved_model.load(str(self.model_path))
|
| 35 |
+
self.infer = self.model.signatures["serving_default"]
|
| 36 |
+
|
| 37 |
+
def image_to_embedding(self, image_bytes: bytes) -> np.ndarray:
|
| 38 |
+
"""
|
| 39 |
+
Return embedding with shape [1, embedding_dim].
|
| 40 |
+
Derm Foundation normally returns key: "embedding".
|
| 41 |
+
"""
|
| 42 |
+
tf_inputs = image_bytes_to_tf_string_tensor(image_bytes, img_size=self.image_size)
|
| 43 |
+
|
| 44 |
+
# Your notebook used infer(inputs=tf_inputs). Keep that first.
|
| 45 |
+
try:
|
| 46 |
+
output = self.infer(inputs=tf_inputs)
|
| 47 |
+
except TypeError:
|
| 48 |
+
output = self.infer(tf_inputs)
|
| 49 |
+
|
| 50 |
+
if "embedding" not in output:
|
| 51 |
+
available = ", ".join(output.keys())
|
| 52 |
+
raise KeyError(f"Expected output key 'embedding'. Available keys: {available}")
|
| 53 |
+
|
| 54 |
+
return output["embedding"].numpy().astype("float32")
|
app/services/predictor.py
ADDED
|
@@ -0,0 +1,89 @@
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from app.models.mlp_head import build_mlp_head_from_checkpoint
|
| 7 |
+
from app.services.derm_backbone import DermFoundationBackbone
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def load_class_names() -> dict[int, str]:
|
| 11 |
+
project_root = Path(__file__).resolve().parents[2]
|
| 12 |
+
class_names_path = project_root / "class_names.json"
|
| 13 |
+
|
| 14 |
+
with open(class_names_path, "r", encoding="utf-8") as f:
|
| 15 |
+
raw_class_names = json.load(f)
|
| 16 |
+
|
| 17 |
+
return {int(index): name for index, name in raw_class_names.items()}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TwoStageDermPredictor:
|
| 21 |
+
"""
|
| 22 |
+
Stage 1: Derm Foundation image -> embedding.
|
| 23 |
+
Stage 2: PyTorch MLP head embedding -> class probabilities.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
derm_model_id: str,
|
| 29 |
+
head_checkpoint_path: str,
|
| 30 |
+
hf_token: str | None = None,
|
| 31 |
+
local_files_only: bool = False,
|
| 32 |
+
image_size: int = 448,
|
| 33 |
+
device_name: str = "auto",
|
| 34 |
+
) -> None:
|
| 35 |
+
if device_name == "auto":
|
| 36 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 37 |
+
else:
|
| 38 |
+
self.device = torch.device(device_name)
|
| 39 |
+
|
| 40 |
+
self.class_names = load_class_names()
|
| 41 |
+
|
| 42 |
+
self.backbone = DermFoundationBackbone(
|
| 43 |
+
repo_id=derm_model_id,
|
| 44 |
+
token=hf_token,
|
| 45 |
+
local_files_only=local_files_only,
|
| 46 |
+
image_size=image_size,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
self.head, _ = build_mlp_head_from_checkpoint(
|
| 50 |
+
checkpoint_path=head_checkpoint_path,
|
| 51 |
+
device=self.device,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
output_dim = self.head[-1].out_features
|
| 55 |
+
|
| 56 |
+
if output_dim != len(self.class_names):
|
| 57 |
+
raise ValueError(
|
| 58 |
+
f"MLP output dimension is {output_dim}, "
|
| 59 |
+
f"but class_names.json contains {len(self.class_names)} classes."
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
def predict(self, image_bytes: bytes) -> dict:
|
| 63 |
+
embedding_np = self.backbone.image_to_embedding(image_bytes)
|
| 64 |
+
embedding = torch.from_numpy(embedding_np).float().to(self.device)
|
| 65 |
+
|
| 66 |
+
with torch.no_grad():
|
| 67 |
+
logits = self.head(embedding)
|
| 68 |
+
probs = torch.softmax(logits, dim=1)[0].cpu()
|
| 69 |
+
|
| 70 |
+
pred_idx = int(torch.argmax(probs).item())
|
| 71 |
+
confidence = float(probs[pred_idx].item())
|
| 72 |
+
|
| 73 |
+
print(self.class_names)
|
| 74 |
+
|
| 75 |
+
probabilities = [
|
| 76 |
+
{
|
| 77 |
+
"index": i,
|
| 78 |
+
"class_name": self.class_names[i],
|
| 79 |
+
"probability": float(prob),
|
| 80 |
+
}
|
| 81 |
+
for i, prob in enumerate(probs.tolist())
|
| 82 |
+
]
|
| 83 |
+
|
| 84 |
+
return {
|
| 85 |
+
"predicted_index": pred_idx,
|
| 86 |
+
"predicted_class": self.class_names[pred_idx],
|
| 87 |
+
"confidence": confidence,
|
| 88 |
+
"probabilities": probabilities,
|
| 89 |
+
}
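The post-processing in `predict` (softmax over the logits, then argmax, then one entry per class) can be illustrated without torch. This is a plain-Python sketch of the same arithmetic, not the repository's implementation; `softmax` and `summarize` are hypothetical helper names.

```python
import math


def softmax(logits: list) -> list:
    # Subtract the max before exponentiating for numerical stability.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def summarize(logits: list, class_names: dict) -> dict:
    """Build a response dict shaped like the one returned by predict()."""
    probs = softmax(logits)
    # Index of the largest probability (first one wins on ties).
    pred_idx = max(range(len(probs)), key=probs.__getitem__)
    return {
        "predicted_index": pred_idx,
        "predicted_class": class_names[pred_idx],
        "confidence": probs[pred_idx],
        "probabilities": [
            {"index": i, "class_name": class_names[i], "probability": p}
            for i, p in enumerate(probs)
        ],
    }
```

Because `class_names` is indexed by position, a mapping whose order differs from the training label order would still produce a well-formed response with the wrong labels attached.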
|
app/services/preprocessing.py
ADDED
|
@@ -0,0 +1,51 @@
| 1 |
+
import io
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
import tensorflow as tf
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
DERM_FOUNDATION_INPUT_SIZE = (448, 448)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def pil_to_serialized_example(
|
| 12 |
+
img: Image.Image,
|
| 13 |
+
img_size: Tuple[int, int] = DERM_FOUNDATION_INPUT_SIZE,
|
| 14 |
+
) -> bytes:
|
| 15 |
+
"""
|
| 16 |
+
Convert one PIL image into the serialized tf.train.Example format
|
| 17 |
+
expected by Google Derm Foundation.
|
| 18 |
+
|
| 19 |
+
Pipeline:
|
| 20 |
+
RGB -> resize -> PNG bytes -> tf.train.Example with key image/encoded
|
| 21 |
+
"""
|
| 22 |
+
img = img.convert("RGB")
|
| 23 |
+
img = img.resize(img_size, resample=Image.BILINEAR)
|
| 24 |
+
|
| 25 |
+
buffer = io.BytesIO()
|
| 26 |
+
img.save(buffer, format="PNG")
|
| 27 |
+
image_bytes = buffer.getvalue()
|
| 28 |
+
|
| 29 |
+
example = tf.train.Example(
|
| 30 |
+
features=tf.train.Features(
|
| 31 |
+
feature={
|
| 32 |
+
"image/encoded": tf.train.Feature(
|
| 33 |
+
bytes_list=tf.train.BytesList(value=[image_bytes])
|
| 34 |
+
)
|
| 35 |
+
}
|
| 36 |
+
)
|
| 37 |
+
)
|
| 38 |
+
return example.SerializeToString()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def image_bytes_to_tf_string_tensor(
|
| 42 |
+
image_bytes: bytes,
|
| 43 |
+
img_size: Tuple[int, int] = DERM_FOUNDATION_INPUT_SIZE,
|
| 44 |
+
) -> tf.Tensor:
|
| 45 |
+
"""
|
| 46 |
+
Convert uploaded image bytes into a batch of one tf.string input.
|
| 47 |
+
"""
|
| 48 |
+
with Image.open(io.BytesIO(image_bytes)) as img:
|
| 49 |
+
serialized = pil_to_serialized_example(img, img_size=img_size)
|
| 50 |
+
|
| 51 |
+
return tf.constant([serialized], dtype=tf.string)
|
class_names.json
ADDED
|
@@ -0,0 +1,25 @@
| 1 |
+
{
|
| 2 |
+
"0": "Acne and Rosacea Photos",
|
| 3 |
+
"1": "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
|
| 4 |
+
"2": "Atopic Dermatitis Photos",
|
| 5 |
+
"3": "Bullous Disease Photos",
|
| 6 |
+
"4": "Cellulitis Impetigo and other Bacterial Infections",
|
| 7 |
+
"5": "Eczema Photos",
|
| 8 |
+
"6": "Exanthems and Drug Eruptions",
|
| 9 |
+
"7": "Hair Loss Photos Alopecia and other Hair Diseases",
|
| 10 |
+
"8": "Herpes HPV and other STDs Photos",
|
| 11 |
+
"9": "Light Diseases and Disorders of Pigmentation",
|
| 12 |
+
"10": "Lupus and other Connective Tissue diseases",
|
| 13 |
+
"11": "Melanoma Skin Cancer Nevi and Moles",
|
| 14 |
+
"12": "Nail Fungus and other Nail Disease",
|
| 15 |
+
"13": "Poison Ivy Photos and other Contact Dermatitis",
|
| 16 |
+
"14": "Psoriasis pictures Lichen Planus and related diseases",
|
| 17 |
+
"15": "Scabies Lyme Disease and other Infestations and Bites",
|
| 18 |
+
"16": "Seborrheic Keratoses and other Benign Tumors",
|
| 19 |
+
"17": "Systemic Disease",
|
| 20 |
+
"18": "Tinea Ringworm Candidiasis and other Fungal Infections",
|
| 21 |
+
"19": "Urticaria Hives",
|
| 22 |
+
"20": "Vascular Tumors",
|
| 23 |
+
"21": "Vasculitis Photos",
|
| 24 |
+
"22": "Warts Molluscum and other Viral Infections"
|
| 25 |
+
}
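Since the file maps string indices to names, a loader has to convert the keys to integers, and it is cheap to verify at the same time that the indices run from 0 to N-1 with no gaps. The sketch below is one possible check, assuming the index-keyed layout shown above; `parse_class_names` is a hypothetical helper, not the repository's `load_class_names`.

```python
import json


def parse_class_names(raw_json: str) -> dict:
    """Parse an index-keyed class map and verify the indices are 0..N-1."""
    mapping = {int(k): v for k, v in json.loads(raw_json).items()}
    expected = set(range(len(mapping)))
    if set(mapping) != expected:
        missing = sorted(expected - set(mapping))
        raise ValueError(f"class indices are not contiguous; missing {missing}")
    return mapping
```

A gap in the indices would otherwise surface only at request time, as a `KeyError` when the predictor looks up a missing index.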
|
derm_foundation_mlp_head.pt
ADDED
|
@@ -0,0 +1,3 @@
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3800ab71987278a5d6885e97742be350e4f063e4f19cff20126385be0ab8c25c
|
| 3 |
+
size 13257851
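The three lines above are a Git LFS pointer, not the checkpoint itself. If a clone is made without LFS support, the file on disk contains only this small text stub, and loading it as a PyTorch checkpoint would fail. The following sketch shows one way to detect that case before loading; `parse_lfs_pointer` is a hypothetical helper, and the check is an assumption about how you might guard startup, not something the repository does.

```python
from typing import Optional

LFS_HEADER = b"version https://git-lfs.github.com/spec/v1"


def parse_lfs_pointer(data: bytes) -> Optional[dict]:
    """Return the pointer fields if data is a Git LFS pointer, else None."""
    if not data.startswith(LFS_HEADER):
        return None
    fields = {}
    for line in data.decode("utf-8").splitlines():
        # Each pointer line has the form "<key> <value>".
        key, _, value = line.partition(" ")
        if key:
            fields[key] = value
    return fields
```

Reading the first few hundred bytes of the checkpoint path and passing them to this function is enough, since pointer files are only a few lines long.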
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
| 1 |
+
fastapi
|
| 2 |
+
uvicorn[standard]
|
| 3 |
+
python-multipart
|
| 4 |
+
numpy
|
| 5 |
+
tensorflow-cpu
|
| 6 |
+
huggingface_hub==0.36.2
|
| 7 |
+
Pillow
|
| 8 |
+
torch
|
| 9 |
+
requests
|
scripts/test_request.py
ADDED
|
@@ -0,0 +1,15 @@
| 1 |
+
import sys
|
| 2 |
+
import requests
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
if len(sys.argv) != 2:
|
| 6 |
+
raise SystemExit("Usage: python scripts/test_request.py path/to/image.jpg")
|
| 7 |
+
|
| 8 |
+
image_path = sys.argv[1]
|
| 9 |
+
url = "http://127.0.0.1:8000/predict"
|
| 10 |
+
|
| 11 |
+
with open(image_path, "rb") as f:
|
| 12 |
+
response = requests.post(url, files={"file": f})
|
| 13 |
+
|
| 14 |
+
print(response.status_code)
|
| 15 |
+
print(response.json())
|