AbelGAlem committed on
Commit
a65c9ed
·
1 Parent(s): 53ec08f

feat(server): implement FastAPI application with model loading(HF HUB), CORS support, prediction endpoint and Docker

Browse files
Files changed (10) hide show
  1. .dockerignore +46 -0
  2. Dockerfile +35 -0
  3. app/api/routes.py +72 -0
  4. app/config.py +14 -0
  5. app/models.py +57 -0
  6. app/services.py +132 -0
  7. app/state.py +50 -0
  8. app/utils.py +22 -0
  9. main.py +40 -4
  10. requirements.txt +0 -0
.dockerignore ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ .venv/
8
+ venv/
9
+ ENV/
10
+ env/
11
+
12
+ # Development
13
+ .env
14
+ .env.*
15
+ .git/
16
+ .gitignore
17
+
18
+ # IDE
19
+ .vscode/
20
+ .idea/
21
+ *.swp
22
+ *.swo
23
+
24
+ # OS
25
+ .DS_Store
26
+ Thumbs.db
27
+
28
+ # Logs
29
+ *.log
30
+
31
+ # Temporary files
32
+ *.tmp
33
+ *.temp
34
+
35
+ # Test files
36
+ tests/
37
+ test_*
38
+ *_test.py
39
+
40
+ # Coverage
41
+ .coverage
42
+ htmlcov/
43
+
44
+ # Documentation
45
+ README.md
46
+ *.md
Dockerfile ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# apps/server/Dockerfile

FROM python:3.11-slim AS base
ENV PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*
WORKDIR /app

FROM base AS builder
COPY requirements.txt /app/requirements.txt
RUN pip install --upgrade pip && \
    pip wheel --no-cache-dir --wheel-dir /app/wheels -r /app/requirements.txt

FROM python:3.11-slim AS runtime
ENV PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1
# add runtime libs you actually need; xgboost often needs libgomp1
RUN apt-get update && apt-get install -y --no-install-recommends \
    libgomp1 \
    && rm -rf /var/lib/apt/lists/*
WORKDIR /app

# deps
COPY --from=builder /app/wheels /wheels
RUN pip install --no-cache-dir /wheels/* && rm -rf /wheels

# app source
COPY . /app

# non-root (useradd -m gives appuser a home dir, needed for the HF Hub cache)
RUN useradd -m appuser
USER appuser

EXPOSE 8000
# FIX: the FastAPI instance is defined in main.py ("app = FastAPI(...)");
# "app" is a package directory with no module-level app attribute, so the
# original import string "app:app" would fail at container start.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
app/api/routes.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+ from fastapi import APIRouter, Depends, File, UploadFile, Form, HTTPException
6
+ from fastapi.responses import JSONResponse
7
+ from typing import Optional
8
+
9
+ from ..state import app_state
10
+ from ..utils import normalize_age
11
+
12
+ from fastapi_limiter.depends import RateLimiter
13
+
14
+ from app.config import RATE_TIMES, RATE_SECONDS
15
+
16
+ router = APIRouter()
17
+
18
@router.get("/health")
def health():
    """Liveness/readiness probe: reports device, class map, and model load state."""
    payload = {
        "status": "ok",
        "device": str(app_state.device),
        "classes": app_state.id2label,
        "model_loaded": app_state.is_model_loaded(),
    }
    return payload
26
+
27
+
28
@router.post("/predict", dependencies=[Depends(RateLimiter(times=RATE_TIMES, seconds=RATE_SECONDS))])
async def predict(
    file: UploadFile = File(..., description="RGB lesion image"),
    age: Optional[float] = Form(None),
    localization: Optional[str] = Form("unknown"),
    top_k: Optional[int] = Form(3),
):
    """Classify a skin-lesion image, optionally conditioned on age and body localization.

    Returns ``{"top": [{"label": ..., "probability": ...}, ...]}`` with the
    ``top_k`` most probable classes (clamped to the number of classes).
    Raises 503 while the model is still loading and 400 on an undecodable image.
    """
    if not app_state.is_model_loaded():
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # Read and decode the upload; any decode failure is the client's fault (400).
    try:
        img_bytes = await file.read()
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}")

    # Preprocess the image into model pixel values on the serving device.
    px = app_state.image_processor(img, return_tensors="pt")["pixel_values"].to(app_state.device)

    # Tabular vector: one-hot localization (unknown values encode as all-zeros
    # because the encoder was fit with handle_unknown="ignore") plus normalized age.
    loc = (localization or "unknown").strip().lower()
    loc_oh = app_state.loc_encoder.transform(np.array([loc]).reshape(-1, 1))  # (1, L)
    norm_age = normalize_age(
        age,
        app_state.age_stats["age_min"],
        app_state.age_stats["age_max"],
        app_state.age_stats["age_mean"],
    )
    tab = np.concatenate([loc_oh, np.array([[norm_age]])], axis=1).astype("float32")
    tab_t = torch.tensor(tab, dtype=torch.float32, device=app_state.device)

    # Inference-only forward pass.
    with torch.no_grad():
        logits = app_state.model(pixel_values=px, tabular_features=tab_t)
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]

    # Top-k selection, clamped to [1, num_classes].
    k = max(1, min(int(top_k or 3), len(probs)))
    idxs = np.argsort(-probs)[:k]
    top = [{"label": app_state.id2label[int(i)], "probability": float(probs[i])} for i in idxs]

    # (Removed: an unused full class-probability distribution was computed here
    # but never returned — dead code.)
    return JSONResponse(content={"top": top})
app/config.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from dotenv import load_dotenv

# Load environment variables from a local .env file (no-op when the file is absent).
load_dotenv()

# Redis connection string used as the fastapi-limiter backend.
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")

# Rate limiting: RATE_TIMES requests per RATE_SECONDS window per client IP.
# NOTE(review): int() raises ValueError on a malformed env value — startup
# fails fast rather than running with a silently wrong limit.
RATE_TIMES = int(os.getenv("RATE_TIMES", "60"))
RATE_SECONDS = int(os.getenv("RATE_SECONDS", "60"))

# Comma-separated CORS origin allowlist; the default "*" allows every origin.
TRUSTED_HOSTS = os.getenv("TRUSTED_HOSTS", "*").split(",")
app/models.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel, PretrainedConfig, AutoModel
4
+
5
+
6
class SkinCancerConfig(PretrainedConfig):
    """HF config for the ViT + tabular skin-cancer classifier."""

    model_type = "vit_tabular_skin_cancer"

    def __init__(
        self,
        vision_model_checkpoint="google/vit-base-patch16-224-in21k",
        tabular_dim=0,
        num_labels=7,
        id2label=None,
        label2id=None,
        age_min=0.0,
        age_max=100.0,
        age_mean=50.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Vision backbone and tabular-branch sizing.
        self.vision_model_checkpoint = vision_model_checkpoint
        self.tabular_dim = tabular_dim
        # Classification head and label maps.
        self.num_labels = num_labels
        self.id2label = id2label
        self.label2id = label2id
        # Age-normalization statistics captured at training time.
        self.age_min = age_min
        self.age_max = age_max
        self.age_mean = age_mean
28
+
29
+
30
class SkinCancerViT(PreTrainedModel):
    """ViT image encoder fused with a small MLP over tabular features."""

    config_class = SkinCancerConfig

    def __init__(self, config):
        super().__init__(config)
        # Pretrained vision backbone; its hidden size drives the fusion head width.
        self.vision = AutoModel.from_pretrained(config.vision_model_checkpoint)
        vision_dim = self.vision.config.hidden_size

        # Tabular branch: tabular_dim -> 128 -> 64 with ReLU and light dropout.
        self.tabular = nn.Sequential(
            nn.Linear(config.tabular_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
        )
        # Joint classifier over concatenated [vision | tabular] features.
        self.classifier = nn.Linear(vision_dim + 64, config.num_labels)
        self.post_init()

    def forward(self, pixel_values, tabular_features):
        """Return raw class logits for a batch of images plus tabular rows."""
        vision_out = self.vision(pixel_values=pixel_values, output_hidden_states=False, return_dict=True)
        pooled = getattr(vision_out, "pooler_output", None)
        if pooled is None:
            pooled = vision_out.last_hidden_state[:, 0, :]  # fall back to the CLS token
        fused = torch.cat([pooled, self.tabular(tabular_features.float())], dim=-1)
        return self.classifier(fused)
app/services.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoImageProcessor, AutoConfig
5
+ from sklearn.preprocessing import OneHotEncoder
6
+ from huggingface_hub import hf_hub_download, list_repo_files
7
+
8
+ from fastapi import Request
9
+
10
+ from .state import app_state
11
+ from .models import SkinCancerConfig, SkinCancerViT
12
+ from .utils import load_json
13
+
14
+
15
def load_model():
    """Load and initialize the model and related components from Hugging Face.

    Populates the shared ``app_state`` in-place: label maps, the localization
    one-hot encoder, age-normalization stats, the image processor, and the
    SkinCancerViT model (moved to the serving device, set to eval mode).
    Any download/load failure is logged and re-raised so startup can abort.
    """
    print(f"Loading model from Hugging Face: {app_state.HF_REPO_ID}")

    try:
        # Download and load label maps from HF
        print("Loading label maps...")
        label2id_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="label2id.json")
        id2label_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="id2label.json")

        app_state.label2id = load_json(label2id_path)
        id2label_raw = load_json(id2label_path)
        # JSON object keys arrive as strings; convert back to int class ids.
        app_state.id2label.update({int(k): v for k, v in id2label_raw.items()})
        print(f"Loaded {len(app_state.id2label)} classes")

        # Download and load encoder categories
        print("Loading encoder categories...")
        cats_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="loc_encoder_categories.npy")
        cats = np.load(cats_path, allow_pickle=True)
        # Rebuild the training-time encoder; unknown localizations encode as all-zeros.
        app_state.loc_encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
        app_state.loc_encoder.fit(np.array(cats).reshape(-1, 1))
        # Slice assignment mutates the existing list so prior references stay valid.
        app_state.valid_localizations[:] = list(cats.tolist())
        print(f"Loaded {len(app_state.valid_localizations)} localizations")

        # Tabular dim = one-hot length + 1 (age)
        app_state.tab_dim = app_state.loc_encoder.transform(np.array(["unknown"]).reshape(-1, 1)).shape[1] + 1
        print(f"Tabular dimension: {app_state.tab_dim}")

        # Download and load age stats
        age_stats_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="age_stats.json")
        app_state.age_stats.update(load_json(age_stats_path))
        print(f"Age stats: {app_state.age_stats}")

        # Download and read the HF config to get the vision backbone name
        print("Loading model config...")
        config_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="best_model/config.json")
        cfg_json = load_json(config_path)
        app_state.vision_ckpt = cfg_json.get("vision_model_checkpoint", app_state.vision_ckpt)
        print(f"Vision checkpoint: {app_state.vision_ckpt}")

        app_state.image_processor = AutoImageProcessor.from_pretrained(app_state.vision_ckpt)
        print("Image processor loaded")

        # Build model config
        print("Building model config...")
        sc_cfg = SkinCancerConfig(
            vision_model_checkpoint=app_state.vision_ckpt,
            tabular_dim=app_state.tab_dim,
            num_labels=len(app_state.id2label),
            id2label=app_state.id2label,
            label2id=app_state.label2id,
            age_min=app_state.age_stats["age_min"],
            age_max=app_state.age_stats["age_max"],
            age_mean=app_state.age_stats["age_mean"]
        )

        # Initialize empty model with our config
        print("Initializing model...")
        model_init = SkinCancerViT(sc_cfg)

        # Load weights from HF
        print("Loading model weights from Hugging Face...")
        try:
            # Try to load from safetensors first
            model_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="best_model/model.safetensors")
            from safetensors.torch import load_file as safe_load
            print(f"Loading from safetensors: {model_path}")
            state = safe_load(model_path)
        except Exception as e:
            # NOTE(review): torch.load without weights_only unpickles arbitrary
            # objects — safe only because the repo is assumed trusted; confirm.
            print(f"Safetensors not found, trying pytorch_model.bin: {e}")
            model_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="best_model/pytorch_model.bin")
            state = torch.load(model_path, map_location="cpu")

        # Remove training-only keys like loss_fct.weight
        to_drop = [k for k in list(state.keys()) if k.startswith("loss_fct.")]
        for k in to_drop:
            state.pop(k, None)

        # Load with strict=False to ignore harmless mismatches
        missing, unexpected = model_init.load_state_dict(state, strict=False)
        if unexpected:
            print("Ignored unexpected keys:", unexpected)
        if missing:
            print("Missing keys:", missing)

        print(f"Using device: {app_state.device}")
        model_init.to(app_state.device)
        model_init.eval()
        app_state.model = model_init
        print("Model loaded successfully from Hugging Face!")

        # Patch size / grid (if available from vision config)
        try:
            app_state.vit_patch_size = getattr(model_init.vision.config, "patch_size", app_state.vit_patch_size)
            # For square inputs (224×224) with non-overlapping patches
            size = app_state.image_processor.size
            if isinstance(size, dict):
                h = size.get("height", 224)
                w = size.get("width", 224)
            else:
                h = w = size
            app_state.vit_grid = (h // app_state.vit_patch_size, w // app_state.vit_patch_size)
            print(f"ViT grid: {app_state.vit_grid}")
        except Exception as e:
            # Fall back to the compile-time defaults rather than failing startup.
            print(f"Error setting ViT grid: {e}")
            app_state.vit_patch_size, app_state.vit_grid = app_state.DEFAULT_VIT_PATCH_SIZE, app_state.DEFAULT_VIT_GRID

    except Exception as e:
        print(f"Error loading model from Hugging Face: {e}")
        raise
125
+
126
+
127
async def get_client_ip(request: Request) -> str:
    """Rate-limit identifier: best-effort client IP for this request.

    The first hop of X-Forwarded-For is the original client when the app sits
    behind a trusted reverse proxy. NOTE(review): the header is client-spoofable
    if the app is ever exposed directly — confirm deployment always proxies.
    Falls back to the direct socket IP; some ASGI servers/test clients provide
    no client at all, in which case "unknown" is returned instead of crashing.
    """
    xff = request.headers.get("x-forwarded-for")
    if xff:
        return xff.split(",")[0].strip()
    # FIX: request.client may be None (e.g. under TestClient) — the original
    # would raise AttributeError here and break rate limiting.
    client = request.client
    return client.host if client is not None else "unknown"
app/state.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List, Optional
3
+ import torch
4
+ from transformers import AutoImageProcessor
5
+ from sklearn.preprocessing import OneHotEncoder
6
+
7
+ from .models import SkinCancerViT
8
+
9
+
10
class AppState:
    """Centralized state management for the application.

    Holds the loaded model, preprocessing artifacts, and serving device.
    Constructed once at import time; load_model() fills in the real values.
    """

    def __init__(self):
        # Hugging Face repo that hosts the trained model + preprocessing artifacts.
        self.HF_REPO_ID = os.environ.get("HF_REPO_ID", "HelloWorld47474747/skin_vit_tabular")

        # Default settings used until (or unless) load_model() overrides them.
        self.DEFAULT_AGE_STATS = {"age_min": 0.0, "age_max": 100.0, "age_mean": 50.0}
        self.DEFAULT_VIT_PATCH_SIZE = 16
        self.DEFAULT_VIT_GRID = (14, 14)
        self.DEFAULT_VISION_CKPT = "google/vit-base-patch16-224-in21k"

        # Model state (all populated by services.load_model()).
        self.image_processor: Optional[AutoImageProcessor] = None
        self.model: Optional[SkinCancerViT] = None
        self.label2id: Dict[str, int] = {}
        self.id2label: Dict[int, str] = {}
        self.loc_encoder: Optional[OneHotEncoder] = None
        # .copy() so updates never mutate the shared default dict.
        self.age_stats = self.DEFAULT_AGE_STATS.copy()
        self.tab_dim = 0
        self.valid_localizations: List[str] = []
        self.vit_patch_size = self.DEFAULT_VIT_PATCH_SIZE
        self.vit_grid = self.DEFAULT_VIT_GRID
        self.vision_ckpt = self.DEFAULT_VISION_CKPT

        # Inference device: CUDA when available, otherwise CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def is_model_loaded(self) -> bool:
        """True once both the model and its image processor are in place."""
        return self.model is not None and self.image_processor is not None

    def get_device(self) -> torch.device:
        """Return the device inference runs on."""
        return self.device


# Global state instance shared across the app (imported by routes/services).
app_state = AppState()
app/utils.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ from typing import Optional
4
+
5
+
6
def load_json(path: str) -> dict:
    """Load and parse a JSON file.

    FIX: read with an explicit UTF-8 encoding so parsing does not depend on
    the platform's locale default (the HF artifacts are UTF-8 JSON).
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
10
+
11
+
12
def normalize_age(age: Optional[float], amin: float, amax: float, amean: float) -> float:
    """Min-max scale *age* using training-set statistics.

    Missing or unparseable ages fall back to the training mean. A degenerate
    range (amax == amin) yields 0.0. Values outside [amin, amax] are not
    clamped, matching training-time preprocessing.
    """
    if age is None:
        value = amean
    else:
        try:
            value = float(age)
        except Exception:
            value = amean
    span = amax - amin
    if span == 0:
        return 0.0
    return (value - amin) / span
main.py CHANGED
@@ -1,7 +1,43 @@
 
 
1
  from fastapi import FastAPI
 
2
 
3
- app = FastAPI()
 
4
 
5
- @app.get("/health")
6
- def read_root():
7
- return {"message": "Hello?"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from contextlib import asynccontextmanager
3
  from fastapi import FastAPI
4
+ from fastapi.middleware.cors import CORSMiddleware
5
 
6
+ from app.services import load_model, get_client_ip
7
+ from app.api.routes import router
8
 
9
+ from redis import asyncio as redis
10
+ from fastapi_limiter import FastAPILimiter
11
+ from app.config import REDIS_URL, TRUSTED_HOSTS
12
+
13
+
14
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook: load the model and wire up Redis rate limiting."""
    # Startup: pull model + artifacts from the HF Hub (blocking, runs once).
    print("Loading model...")
    load_model()
    print("Model loaded successfully!")
    # Rate limiter backend: Redis, keyed by best-effort client IP.
    redis_client = redis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
    await FastAPILimiter.init(redis_client, identifier=get_client_ip)
    yield
    # FIX: release the Redis connection held by the limiter on shutdown
    # (the original left the connection open).
    await FastAPILimiter.close()


app = FastAPI(title="Skin Cancer ViT+Tabular API", lifespan=lifespan)

# CORS middleware: allow only configured origins ("*" disables the allowlist).
# NOTE(review): browsers ignore Access-Control-Allow-Credentials with a
# wildcard origin — confirm TRUSTED_HOSTS is set in credentialed deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=TRUSTED_HOSTS if "*" not in TRUSTED_HOSTS else ["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount API routes under /api (e.g. /api/health, /api/predict).
app.include_router(router, prefix="/api")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=int(os.environ.get("PORT", 8000)))
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ