Spaces:
Running
Running
AbelGAlem commited on
Commit ·
a65c9ed
1
Parent(s): 53ec08f
feat(server): implement FastAPI application with model loading(HF HUB), CORS support, prediction endpoint and Docker
Browse files- .dockerignore +46 -0
- Dockerfile +35 -0
- app/api/routes.py +72 -0
- app/config.py +14 -0
- app/models.py +57 -0
- app/services.py +132 -0
- app/state.py +50 -0
- app/utils.py +22 -0
- main.py +40 -4
- requirements.txt +0 -0
.dockerignore
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
.venv/
|
| 8 |
+
venv/
|
| 9 |
+
ENV/
|
| 10 |
+
env/
|
| 11 |
+
|
| 12 |
+
# Development
|
| 13 |
+
.env
|
| 14 |
+
.env.*
|
| 15 |
+
.git/
|
| 16 |
+
.gitignore
|
| 17 |
+
|
| 18 |
+
# IDE
|
| 19 |
+
.vscode/
|
| 20 |
+
.idea/
|
| 21 |
+
*.swp
|
| 22 |
+
*.swo
|
| 23 |
+
|
| 24 |
+
# OS
|
| 25 |
+
.DS_Store
|
| 26 |
+
Thumbs.db
|
| 27 |
+
|
| 28 |
+
# Logs
|
| 29 |
+
*.log
|
| 30 |
+
|
| 31 |
+
# Temporary files
|
| 32 |
+
*.tmp
|
| 33 |
+
*.temp
|
| 34 |
+
|
| 35 |
+
# Test files
|
| 36 |
+
tests/
|
| 37 |
+
test_*
|
| 38 |
+
*_test.py
|
| 39 |
+
|
| 40 |
+
# Coverage
|
| 41 |
+
.coverage
|
| 42 |
+
htmlcov/
|
| 43 |
+
|
| 44 |
+
# Documentation
|
| 45 |
+
README.md
|
| 46 |
+
*.md
|
Dockerfile
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# apps/server/Dockerfile
|
| 2 |
+
|
| 3 |
+
FROM python:3.11-slim AS base
|
| 4 |
+
ENV PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1
|
| 5 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 6 |
+
build-essential \
|
| 7 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 8 |
+
WORKDIR /app
|
| 9 |
+
|
| 10 |
+
FROM base AS builder
|
| 11 |
+
COPY requirements.txt /app/requirements.txt
|
| 12 |
+
RUN pip install --upgrade pip && \
|
| 13 |
+
pip wheel --no-cache-dir --wheel-dir /app/wheels -r /app/requirements.txt
|
| 14 |
+
|
| 15 |
+
FROM python:3.11-slim AS runtime
|
| 16 |
+
ENV PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1
|
| 17 |
+
# add runtime libs you actually need; xgboost often needs libgomp1
|
| 18 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 19 |
+
libgomp1 \
|
| 20 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 21 |
+
WORKDIR /app
|
| 22 |
+
|
| 23 |
+
# deps
|
| 24 |
+
COPY --from=builder /app/wheels /wheels
|
| 25 |
+
RUN pip install --no-cache-dir /wheels/* && rm -rf /wheels
|
| 26 |
+
|
| 27 |
+
# app source
|
| 28 |
+
COPY . /app
|
| 29 |
+
|
| 30 |
+
# non-root
|
| 31 |
+
RUN useradd -m appuser
|
| 32 |
+
USER appuser
|
| 33 |
+
|
| 34 |
+
EXPOSE 8000
|
| 35 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
|
app/api/routes.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from fastapi import APIRouter, Depends, File, UploadFile, Form, HTTPException
|
| 6 |
+
from fastapi.responses import JSONResponse
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
from ..state import app_state
|
| 10 |
+
from ..utils import normalize_age
|
| 11 |
+
|
| 12 |
+
from fastapi_limiter.depends import RateLimiter
|
| 13 |
+
|
| 14 |
+
from app.config import RATE_TIMES, RATE_SECONDS
|
| 15 |
+
|
| 16 |
+
router = APIRouter()
|
| 17 |
+
|
| 18 |
+
@router.get("/health")
|
| 19 |
+
def health():
|
| 20 |
+
return {
|
| 21 |
+
"status": "ok",
|
| 22 |
+
"device": str(app_state.device),
|
| 23 |
+
"classes": app_state.id2label,
|
| 24 |
+
"model_loaded": app_state.is_model_loaded()
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@router.post("/predict", dependencies=[Depends(RateLimiter(times=RATE_TIMES, seconds=RATE_SECONDS))],)
|
| 29 |
+
async def predict(
|
| 30 |
+
file: UploadFile = File(..., description="RGB lesion image"),
|
| 31 |
+
age: Optional[float] = Form(None),
|
| 32 |
+
localization: Optional[str] = Form("unknown"),
|
| 33 |
+
top_k: Optional[int] = Form(3),
|
| 34 |
+
):
|
| 35 |
+
if not app_state.is_model_loaded():
|
| 36 |
+
raise HTTPException(status_code=503, detail="Model not loaded yet")
|
| 37 |
+
|
| 38 |
+
# Read image
|
| 39 |
+
try:
|
| 40 |
+
img_bytes = await file.read()
|
| 41 |
+
img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
| 42 |
+
except Exception as e:
|
| 43 |
+
raise HTTPException(status_code=400, detail=f"Invalid image: {e}")
|
| 44 |
+
|
| 45 |
+
# Preprocess image
|
| 46 |
+
px = app_state.image_processor(img, return_tensors="pt")["pixel_values"].to(app_state.device)
|
| 47 |
+
|
| 48 |
+
# Tabular vector
|
| 49 |
+
loc = (localization or "unknown").strip().lower()
|
| 50 |
+
loc_oh = app_state.loc_encoder.transform(np.array([loc]).reshape(-1, 1)) # (1, L)
|
| 51 |
+
norm_age = normalize_age(age, app_state.age_stats["age_min"], app_state.age_stats["age_max"], app_state.age_stats["age_mean"])
|
| 52 |
+
tab = np.concatenate([loc_oh, np.array([[norm_age]])], axis=1).astype("float32")
|
| 53 |
+
tab_t = torch.tensor(tab, dtype=torch.float32, device=app_state.device)
|
| 54 |
+
|
| 55 |
+
# Forward
|
| 56 |
+
with torch.no_grad():
|
| 57 |
+
logits = app_state.model(pixel_values=px, tabular_features=tab_t)
|
| 58 |
+
probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
|
| 59 |
+
|
| 60 |
+
# Top-k
|
| 61 |
+
k = max(1, min(int(top_k or 3), len(probs)))
|
| 62 |
+
idxs = np.argsort(-probs)[:k]
|
| 63 |
+
top = [{"label": app_state.id2label[int(i)], "probability": float(probs[i])} for i in idxs]
|
| 64 |
+
dist = {app_state.id2label[int(i)]: float(p) for i, p in enumerate(probs)}
|
| 65 |
+
|
| 66 |
+
payload = {
|
| 67 |
+
"top": top
|
| 68 |
+
# "distribution": dist,
|
| 69 |
+
# "accepted_localizations_example": app_state.valid_localizations[:10]
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
return JSONResponse(content=payload)
|
app/config.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
|
| 4 |
+
# Load environment variables from .env file
|
| 5 |
+
load_dotenv()
|
| 6 |
+
|
| 7 |
+
# Redis config
|
| 8 |
+
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")
|
| 9 |
+
|
| 10 |
+
# Rate limiting: requests per time window per IP
|
| 11 |
+
RATE_TIMES = int(os.getenv("RATE_TIMES", "60"))
|
| 12 |
+
RATE_SECONDS = int(os.getenv("RATE_SECONDS", "60"))
|
| 13 |
+
|
| 14 |
+
TRUSTED_HOSTS = os.getenv("TRUSTED_HOSTS", "*").split(",")
|
app/models.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from transformers import PreTrainedModel, PretrainedConfig, AutoModel
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SkinCancerConfig(PretrainedConfig):
|
| 7 |
+
model_type = "vit_tabular_skin_cancer"
|
| 8 |
+
|
| 9 |
+
def __init__(self,
|
| 10 |
+
vision_model_checkpoint="google/vit-base-patch16-224-in21k",
|
| 11 |
+
tabular_dim=0,
|
| 12 |
+
num_labels=7,
|
| 13 |
+
id2label=None,
|
| 14 |
+
label2id=None,
|
| 15 |
+
age_min=0.0,
|
| 16 |
+
age_max=100.0,
|
| 17 |
+
age_mean=50.0,
|
| 18 |
+
**kwargs):
|
| 19 |
+
super().__init__(**kwargs)
|
| 20 |
+
self.vision_model_checkpoint = vision_model_checkpoint
|
| 21 |
+
self.tabular_dim = tabular_dim
|
| 22 |
+
self.num_labels = num_labels
|
| 23 |
+
self.id2label = id2label
|
| 24 |
+
self.label2id = label2id
|
| 25 |
+
self.age_min = age_min
|
| 26 |
+
self.age_max = age_max
|
| 27 |
+
self.age_mean = age_mean
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class SkinCancerViT(PreTrainedModel):
|
| 31 |
+
config_class = SkinCancerConfig
|
| 32 |
+
|
| 33 |
+
def __init__(self, config):
|
| 34 |
+
super().__init__(config)
|
| 35 |
+
self.vision = AutoModel.from_pretrained(config.vision_model_checkpoint)
|
| 36 |
+
hdim = self.vision.config.hidden_size
|
| 37 |
+
|
| 38 |
+
self.tabular = nn.Sequential(
|
| 39 |
+
nn.Linear(config.tabular_dim, 128),
|
| 40 |
+
nn.ReLU(),
|
| 41 |
+
nn.Dropout(0.1),
|
| 42 |
+
nn.Linear(128, 64),
|
| 43 |
+
nn.ReLU()
|
| 44 |
+
)
|
| 45 |
+
self.classifier = nn.Linear(hdim + 64, config.num_labels)
|
| 46 |
+
self.post_init()
|
| 47 |
+
|
| 48 |
+
def forward(self, pixel_values, tabular_features):
|
| 49 |
+
vout = self.vision(pixel_values=pixel_values, output_hidden_states=False, return_dict=True)
|
| 50 |
+
if getattr(vout, "pooler_output", None) is not None:
|
| 51 |
+
vfeat = vout.pooler_output
|
| 52 |
+
else:
|
| 53 |
+
vfeat = vout.last_hidden_state[:, 0, :] # CLS
|
| 54 |
+
tfeat = self.tabular(tabular_features.float())
|
| 55 |
+
feats = torch.cat([vfeat, tfeat], dim=-1)
|
| 56 |
+
logits = self.classifier(feats)
|
| 57 |
+
return logits
|
app/services.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from transformers import AutoImageProcessor, AutoConfig
|
| 5 |
+
from sklearn.preprocessing import OneHotEncoder
|
| 6 |
+
from huggingface_hub import hf_hub_download, list_repo_files
|
| 7 |
+
|
| 8 |
+
from fastapi import Request
|
| 9 |
+
|
| 10 |
+
from .state import app_state
|
| 11 |
+
from .models import SkinCancerConfig, SkinCancerViT
|
| 12 |
+
from .utils import load_json
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_model():
|
| 16 |
+
"""Load and initialize the model and related components from Hugging Face."""
|
| 17 |
+
print(f"Loading model from Hugging Face: {app_state.HF_REPO_ID}")
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
# Download and load label maps from HF
|
| 21 |
+
print("Loading label maps...")
|
| 22 |
+
label2id_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="label2id.json")
|
| 23 |
+
id2label_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="id2label.json")
|
| 24 |
+
|
| 25 |
+
app_state.label2id = load_json(label2id_path)
|
| 26 |
+
id2label_raw = load_json(id2label_path)
|
| 27 |
+
app_state.id2label.update({int(k): v for k, v in id2label_raw.items()})
|
| 28 |
+
print(f"Loaded {len(app_state.id2label)} classes")
|
| 29 |
+
|
| 30 |
+
# Download and load encoder categories
|
| 31 |
+
print("Loading encoder categories...")
|
| 32 |
+
cats_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="loc_encoder_categories.npy")
|
| 33 |
+
cats = np.load(cats_path, allow_pickle=True)
|
| 34 |
+
app_state.loc_encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
|
| 35 |
+
app_state.loc_encoder.fit(np.array(cats).reshape(-1, 1))
|
| 36 |
+
app_state.valid_localizations[:] = list(cats.tolist())
|
| 37 |
+
print(f"Loaded {len(app_state.valid_localizations)} localizations")
|
| 38 |
+
|
| 39 |
+
# Tabular dim = one-hot length + 1 (age)
|
| 40 |
+
app_state.tab_dim = app_state.loc_encoder.transform(np.array(["unknown"]).reshape(-1, 1)).shape[1] + 1
|
| 41 |
+
print(f"Tabular dimension: {app_state.tab_dim}")
|
| 42 |
+
|
| 43 |
+
# Download and load age stats
|
| 44 |
+
age_stats_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="age_stats.json")
|
| 45 |
+
app_state.age_stats.update(load_json(age_stats_path))
|
| 46 |
+
print(f"Age stats: {app_state.age_stats}")
|
| 47 |
+
|
| 48 |
+
# Download and read the HF config to get the vision backbone name
|
| 49 |
+
print("Loading model config...")
|
| 50 |
+
config_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="best_model/config.json")
|
| 51 |
+
cfg_json = load_json(config_path)
|
| 52 |
+
app_state.vision_ckpt = cfg_json.get("vision_model_checkpoint", app_state.vision_ckpt)
|
| 53 |
+
print(f"Vision checkpoint: {app_state.vision_ckpt}")
|
| 54 |
+
|
| 55 |
+
app_state.image_processor = AutoImageProcessor.from_pretrained(app_state.vision_ckpt)
|
| 56 |
+
print("Image processor loaded")
|
| 57 |
+
|
| 58 |
+
# Build model config
|
| 59 |
+
print("Building model config...")
|
| 60 |
+
sc_cfg = SkinCancerConfig(
|
| 61 |
+
vision_model_checkpoint=app_state.vision_ckpt,
|
| 62 |
+
tabular_dim=app_state.tab_dim,
|
| 63 |
+
num_labels=len(app_state.id2label),
|
| 64 |
+
id2label=app_state.id2label,
|
| 65 |
+
label2id=app_state.label2id,
|
| 66 |
+
age_min=app_state.age_stats["age_min"],
|
| 67 |
+
age_max=app_state.age_stats["age_max"],
|
| 68 |
+
age_mean=app_state.age_stats["age_mean"]
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Initialize empty model with our config
|
| 72 |
+
print("Initializing model...")
|
| 73 |
+
model_init = SkinCancerViT(sc_cfg)
|
| 74 |
+
|
| 75 |
+
# Load weights from HF
|
| 76 |
+
print("Loading model weights from Hugging Face...")
|
| 77 |
+
try:
|
| 78 |
+
# Try to load from safetensors first
|
| 79 |
+
model_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="best_model/model.safetensors")
|
| 80 |
+
from safetensors.torch import load_file as safe_load
|
| 81 |
+
print(f"Loading from safetensors: {model_path}")
|
| 82 |
+
state = safe_load(model_path)
|
| 83 |
+
except Exception as e:
|
| 84 |
+
print(f"Safetensors not found, trying pytorch_model.bin: {e}")
|
| 85 |
+
model_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="best_model/pytorch_model.bin")
|
| 86 |
+
state = torch.load(model_path, map_location="cpu")
|
| 87 |
+
|
| 88 |
+
# Remove training-only keys like loss_fct.weight
|
| 89 |
+
to_drop = [k for k in list(state.keys()) if k.startswith("loss_fct.")]
|
| 90 |
+
for k in to_drop:
|
| 91 |
+
state.pop(k, None)
|
| 92 |
+
|
| 93 |
+
# Load with strict=False to ignore harmless mismatches
|
| 94 |
+
missing, unexpected = model_init.load_state_dict(state, strict=False)
|
| 95 |
+
if unexpected:
|
| 96 |
+
print("Ignored unexpected keys:", unexpected)
|
| 97 |
+
if missing:
|
| 98 |
+
print("Missing keys:", missing)
|
| 99 |
+
|
| 100 |
+
print(f"Using device: {app_state.device}")
|
| 101 |
+
model_init.to(app_state.device)
|
| 102 |
+
model_init.eval()
|
| 103 |
+
app_state.model = model_init
|
| 104 |
+
print("Model loaded successfully from Hugging Face!")
|
| 105 |
+
|
| 106 |
+
# Patch size / grid (if available from vision config)
|
| 107 |
+
try:
|
| 108 |
+
app_state.vit_patch_size = getattr(model_init.vision.config, "patch_size", app_state.vit_patch_size)
|
| 109 |
+
# For square inputs (224×224) with non-overlapping patches
|
| 110 |
+
size = app_state.image_processor.size
|
| 111 |
+
if isinstance(size, dict):
|
| 112 |
+
h = size.get("height", 224)
|
| 113 |
+
w = size.get("width", 224)
|
| 114 |
+
else:
|
| 115 |
+
h = w = size
|
| 116 |
+
app_state.vit_grid = (h // app_state.vit_patch_size, w // app_state.vit_patch_size)
|
| 117 |
+
print(f"ViT grid: {app_state.vit_grid}")
|
| 118 |
+
except Exception as e:
|
| 119 |
+
print(f"Error setting ViT grid: {e}")
|
| 120 |
+
app_state.vit_patch_size, app_state.vit_grid = app_state.DEFAULT_VIT_PATCH_SIZE, app_state.DEFAULT_VIT_GRID
|
| 121 |
+
|
| 122 |
+
except Exception as e:
|
| 123 |
+
print(f"Error loading model from Hugging Face: {e}")
|
| 124 |
+
raise
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
async def get_client_ip(request: Request) -> str:
|
| 128 |
+
# First hop of X-Forwarded-For is original client. Fall back to direct socket IP.
|
| 129 |
+
xff = request.headers.get("x-forwarded-for")
|
| 130 |
+
if xff:
|
| 131 |
+
return xff.split(",")[0].strip()
|
| 132 |
+
return request.client.host
|
app/state.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Dict, List, Optional
|
| 3 |
+
import torch
|
| 4 |
+
from transformers import AutoImageProcessor
|
| 5 |
+
from sklearn.preprocessing import OneHotEncoder
|
| 6 |
+
|
| 7 |
+
from .models import SkinCancerViT
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class AppState:
|
| 11 |
+
"""Centralized state management for the application."""
|
| 12 |
+
|
| 13 |
+
def __init__(self):
|
| 14 |
+
# Hugging Face configuration
|
| 15 |
+
# Hi there human. What you looking at.
|
| 16 |
+
self.HF_REPO_ID = os.environ.get("HF_REPO_ID", "HelloWorld47474747/skin_vit_tabular")
|
| 17 |
+
|
| 18 |
+
# Default settings
|
| 19 |
+
self.DEFAULT_AGE_STATS = {"age_min": 0.0, "age_max": 100.0, "age_mean": 50.0}
|
| 20 |
+
self.DEFAULT_VIT_PATCH_SIZE = 16
|
| 21 |
+
self.DEFAULT_VIT_GRID = (14, 14)
|
| 22 |
+
self.DEFAULT_VISION_CKPT = "google/vit-base-patch16-224-in21k"
|
| 23 |
+
|
| 24 |
+
# Model state
|
| 25 |
+
self.image_processor: Optional[AutoImageProcessor] = None
|
| 26 |
+
self.model: Optional[SkinCancerViT] = None
|
| 27 |
+
self.label2id: Dict[str, int] = {}
|
| 28 |
+
self.id2label: Dict[int, str] = {}
|
| 29 |
+
self.loc_encoder: Optional[OneHotEncoder] = None
|
| 30 |
+
self.age_stats = self.DEFAULT_AGE_STATS.copy()
|
| 31 |
+
self.tab_dim = 0
|
| 32 |
+
self.valid_localizations: List[str] = []
|
| 33 |
+
self.vit_patch_size = self.DEFAULT_VIT_PATCH_SIZE
|
| 34 |
+
self.vit_grid = self.DEFAULT_VIT_GRID
|
| 35 |
+
self.vision_ckpt = self.DEFAULT_VISION_CKPT
|
| 36 |
+
|
| 37 |
+
# Device
|
| 38 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 39 |
+
|
| 40 |
+
def is_model_loaded(self) -> bool:
|
| 41 |
+
"""Check if the model is loaded."""
|
| 42 |
+
return self.model is not None and self.image_processor is not None
|
| 43 |
+
|
| 44 |
+
def get_device(self) -> torch.device:
|
| 45 |
+
"""Get the current device."""
|
| 46 |
+
return self.device
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Global state instance
|
| 50 |
+
app_state = AppState()
|
app/utils.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def load_json(path: str) -> dict:
|
| 7 |
+
"""Load JSON file."""
|
| 8 |
+
with open(path, "r") as f:
|
| 9 |
+
return json.load(f)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def normalize_age(age: Optional[float], amin: float, amax: float, amean: float) -> float:
|
| 13 |
+
"""Normalize age to [0, 1] range."""
|
| 14 |
+
if age is None:
|
| 15 |
+
age = amean
|
| 16 |
+
try:
|
| 17 |
+
age = float(age)
|
| 18 |
+
except Exception:
|
| 19 |
+
age = amean
|
| 20 |
+
if amax == amin:
|
| 21 |
+
return 0.0
|
| 22 |
+
return (age - amin) / (amax - amin)
|
main.py
CHANGED
|
@@ -1,7 +1,43 @@
|
|
|
|
|
|
|
|
| 1 |
from fastapi import FastAPI
|
|
|
|
| 2 |
|
| 3 |
-
app
|
|
|
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from contextlib import asynccontextmanager
|
| 3 |
from fastapi import FastAPI
|
| 4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
|
| 6 |
+
from app.services import load_model, get_client_ip
|
| 7 |
+
from app.api.routes import router
|
| 8 |
|
| 9 |
+
from redis import asyncio as redis
|
| 10 |
+
from fastapi_limiter import FastAPILimiter
|
| 11 |
+
from app.config import REDIS_URL, TRUSTED_HOSTS
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@asynccontextmanager
|
| 15 |
+
async def lifespan(app: FastAPI):
|
| 16 |
+
# Startup
|
| 17 |
+
print("Loading model...")
|
| 18 |
+
load_model()
|
| 19 |
+
print("Model loaded successfully!")
|
| 20 |
+
redis_client = redis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
|
| 21 |
+
await FastAPILimiter.init(redis_client, identifier=get_client_ip)
|
| 22 |
+
yield
|
| 23 |
+
# Shutdown (if needed)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
app = FastAPI(title="Skin Cancer ViT+Tabular API", lifespan=lifespan)
|
| 27 |
+
|
| 28 |
+
# CORS middleware
|
| 29 |
+
app.add_middleware(
|
| 30 |
+
CORSMiddleware,
|
| 31 |
+
allow_origins=TRUSTED_HOSTS if "*" not in TRUSTED_HOSTS else ["*"],
|
| 32 |
+
allow_credentials=True,
|
| 33 |
+
allow_methods=["*"],
|
| 34 |
+
allow_headers=["*"],
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# Include API routes
|
| 38 |
+
app.include_router(router, prefix="/api")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
if __name__ == "__main__":
|
| 42 |
+
import uvicorn
|
| 43 |
+
uvicorn.run("main:app", host="0.0.0.0", port=int(os.environ.get("PORT", 8000)))
|
requirements.txt
CHANGED
|
Binary files a/requirements.txt and b/requirements.txt differ
|
|
|