# ScanServer/app/services.py
import os
import torch
import numpy as np
from transformers import AutoImageProcessor, AutoConfig
from sklearn.preprocessing import OneHotEncoder
from huggingface_hub import hf_hub_download, list_repo_files
from fastapi import Request

from .state import app_state
from .models import SkinCancerConfig, SkinCancerViT
from .utils import load_json


def load_model():
    """Load and initialize the model and related components from Hugging Face."""
    print(f"Loading model from Hugging Face: {app_state.HF_REPO_ID}")
    try:
        # Download and load label maps from HF
        print("Loading label maps...")
        label2id_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="label2id.json")
        id2label_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="id2label.json")
        app_state.label2id = load_json(label2id_path)
        id2label_raw = load_json(id2label_path)
        # JSON keys arrive as strings; the model expects integer class ids
        app_state.id2label.update({int(k): v for k, v in id2label_raw.items()})
        print(f"Loaded {len(app_state.id2label)} classes")

        # Download and load encoder categories
        print("Loading encoder categories...")
        cats_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="loc_encoder_categories.npy")
        cats = np.load(cats_path, allow_pickle=True)
        app_state.loc_encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
        app_state.loc_encoder.fit(np.array(cats).reshape(-1, 1))
        app_state.valid_localizations[:] = list(cats.tolist())
        print(f"Loaded {len(app_state.valid_localizations)} localizations")

        # Tabular dim = one-hot length + 1 (age)
        app_state.tab_dim = app_state.loc_encoder.transform(np.array(["unknown"]).reshape(-1, 1)).shape[1] + 1
        print(f"Tabular dimension: {app_state.tab_dim}")

        # Download and load age stats
        age_stats_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="age_stats.json")
        app_state.age_stats.update(load_json(age_stats_path))
        print(f"Age stats: {app_state.age_stats}")

        # Download and read the HF config to get the vision backbone name
        print("Loading model config...")
        config_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="best_model/config.json")
        cfg_json = load_json(config_path)
        app_state.vision_ckpt = cfg_json.get("vision_model_checkpoint", app_state.vision_ckpt)
        print(f"Vision checkpoint: {app_state.vision_ckpt}")
        app_state.image_processor = AutoImageProcessor.from_pretrained(app_state.vision_ckpt)
        print("Image processor loaded")

        # Build model config
        print("Building model config...")
        sc_cfg = SkinCancerConfig(
            vision_model_checkpoint=app_state.vision_ckpt,
            tabular_dim=app_state.tab_dim,
            num_labels=len(app_state.id2label),
            id2label=app_state.id2label,
            label2id=app_state.label2id,
            age_min=app_state.age_stats["age_min"],
            age_max=app_state.age_stats["age_max"],
            age_mean=app_state.age_stats["age_mean"],
        )

        # Initialize empty model with our config
        print("Initializing model...")
        model_init = SkinCancerViT(sc_cfg)

        # Load weights from HF
        print("Loading model weights from Hugging Face...")
        try:
            # Try to load from safetensors first
            model_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="best_model/model.safetensors")
            from safetensors.torch import load_file as safe_load
            print(f"Loading from safetensors: {model_path}")
            state = safe_load(model_path)
        except Exception as e:
            print(f"Safetensors not found, trying pytorch_model.bin: {e}")
            model_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="best_model/pytorch_model.bin")
            state = torch.load(model_path, map_location="cpu")

        # Remove training-only keys like loss_fct.weight
        to_drop = [k for k in list(state.keys()) if k.startswith("loss_fct.")]
        for k in to_drop:
            state.pop(k, None)

        # Load with strict=False to ignore harmless mismatches
        missing, unexpected = model_init.load_state_dict(state, strict=False)
        if unexpected:
            print("Ignored unexpected keys:", unexpected)
        if missing:
            print("Missing keys:", missing)

        print(f"Using device: {app_state.device}")
        model_init.to(app_state.device)
        model_init.eval()
        app_state.model = model_init
        print("Model loaded successfully from Hugging Face!")

        # Patch size / grid (if available from vision config)
        try:
            app_state.vit_patch_size = getattr(model_init.vision.config, "patch_size", app_state.vit_patch_size)
            # For square inputs (224×224) with non-overlapping patches
            size = app_state.image_processor.size
            if isinstance(size, dict):
                h = size.get("height", 224)
                w = size.get("width", 224)
            else:
                h = w = size
            app_state.vit_grid = (h // app_state.vit_patch_size, w // app_state.vit_patch_size)
            print(f"ViT grid: {app_state.vit_grid}")
        except Exception as e:
            print(f"Error setting ViT grid: {e}")
            app_state.vit_patch_size, app_state.vit_grid = app_state.DEFAULT_VIT_PATCH_SIZE, app_state.DEFAULT_VIT_GRID

    except Exception as e:
        print(f"Error loading model from Hugging Face: {e}")
        raise
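

# Typical wiring (illustrative assumption; the actual startup hook lives in the FastAPI
# application module, not in this file):
#
#   from contextlib import asynccontextmanager
#   from fastapi import FastAPI
#   from .services import load_model
#
#   @asynccontextmanager
#   async def lifespan(app: FastAPI):
#       load_model()  # populate app_state once, before the first request is served
#       yield
#
#   app = FastAPI(lifespan=lifespan)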


async def get_client_ip(request: Request) -> str:
    # The first hop in X-Forwarded-For is the original client; fall back to the direct socket IP.
    xff = request.headers.get("x-forwarded-for")
    if xff:
        return xff.split(",")[0].strip()
    return request.client.host
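
# Illustrative usage (assumption; the real prediction endpoint lives in the routes module):
#
#   from fastapi import Depends, FastAPI
#   from .services import get_client_ip
#
#   app = FastAPI()
#
#   @app.get("/whoami")
#   async def whoami(client_ip: str = Depends(get_client_ip)) -> dict:
#       return {"client_ip": client_ip}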