AbelGAlem · a65c9ed
feat(server): implement FastAPI application with model loading (HF Hub), CORS support, prediction endpoint, and Docker

import os
import torch
import numpy as np
from transformers import AutoImageProcessor, AutoConfig
from sklearn.preprocessing import OneHotEncoder
from huggingface_hub import hf_hub_download, list_repo_files
from fastapi import Request

from .state import app_state
from .models import SkinCancerConfig, SkinCancerViT
from .utils import load_json


def load_model():
    """Load and initialize the model and related components from Hugging Face."""
    print(f"Loading model from Hugging Face: {app_state.HF_REPO_ID}")
    try:
        # Download and load label maps from HF
        print("Loading label maps...")
        label2id_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="label2id.json")
        id2label_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="id2label.json")
        app_state.label2id = load_json(label2id_path)
        id2label_raw = load_json(id2label_path)
        app_state.id2label.update({int(k): v for k, v in id2label_raw.items()})
        print(f"Loaded {len(app_state.id2label)} classes")

        # Download and load encoder categories
        print("Loading encoder categories...")
        cats_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="loc_encoder_categories.npy")
        cats = np.load(cats_path, allow_pickle=True)
        app_state.loc_encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
        app_state.loc_encoder.fit(np.array(cats).reshape(-1, 1))
        app_state.valid_localizations[:] = list(cats.tolist())
        print(f"Loaded {len(app_state.valid_localizations)} localizations")

        # Tabular dim = one-hot length + 1 (age)
        app_state.tab_dim = app_state.loc_encoder.transform(np.array(["unknown"]).reshape(-1, 1)).shape[1] + 1
        print(f"Tabular dimension: {app_state.tab_dim}")

        # Download and load age stats
        age_stats_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="age_stats.json")
        app_state.age_stats.update(load_json(age_stats_path))
        print(f"Age stats: {app_state.age_stats}")

        # Download and read the HF config to get the vision backbone name
        print("Loading model config...")
        config_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="best_model/config.json")
        cfg_json = load_json(config_path)
        app_state.vision_ckpt = cfg_json.get("vision_model_checkpoint", app_state.vision_ckpt)
        print(f"Vision checkpoint: {app_state.vision_ckpt}")
        app_state.image_processor = AutoImageProcessor.from_pretrained(app_state.vision_ckpt)
        print("Image processor loaded")

        # Build model config
        print("Building model config...")
        sc_cfg = SkinCancerConfig(
            vision_model_checkpoint=app_state.vision_ckpt,
            tabular_dim=app_state.tab_dim,
            num_labels=len(app_state.id2label),
            id2label=app_state.id2label,
            label2id=app_state.label2id,
            age_min=app_state.age_stats["age_min"],
            age_max=app_state.age_stats["age_max"],
            age_mean=app_state.age_stats["age_mean"],
        )
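        # Passing age_min/age_max/age_mean into the config suggests the model
        # normalizes (and possibly imputes) the age feature internally at forward
        # time; this is an inference from the fields, not documented in this commit.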

        # Initialize empty model with our config
        print("Initializing model...")
        model_init = SkinCancerViT(sc_cfg)

        # Load weights from HF
        print("Loading model weights from Hugging Face...")
        try:
            # Try to load from safetensors first
            model_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="best_model/model.safetensors")
            from safetensors.torch import load_file as safe_load
            print(f"Loading from safetensors: {model_path}")
            state = safe_load(model_path)
        except Exception as e:
            print(f"Safetensors not found, trying pytorch_model.bin: {e}")
            model_path = hf_hub_download(repo_id=app_state.HF_REPO_ID, filename="best_model/pytorch_model.bin")
            state = torch.load(model_path, map_location="cpu")

        # Remove training-only keys like loss_fct.weight
        to_drop = [k for k in list(state.keys()) if k.startswith("loss_fct.")]
        for k in to_drop:
            state.pop(k, None)

        # Load with strict=False to ignore harmless mismatches
        missing, unexpected = model_init.load_state_dict(state, strict=False)
        if unexpected:
            print("Ignored unexpected keys:", unexpected)
        if missing:
            print("Missing keys:", missing)

        print(f"Using device: {app_state.device}")
        model_init.to(app_state.device)
        model_init.eval()
        app_state.model = model_init
        print("Model loaded successfully from Hugging Face!")

        # Patch size / grid (if available from vision config)
        try:
            app_state.vit_patch_size = getattr(model_init.vision.config, "patch_size", app_state.vit_patch_size)
            # For square inputs (224×224) with non-overlapping patches
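            # e.g., a patch size of 16 on 224x224 inputs gives 224 // 16 = 14, i.e. a 14x14 grid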
            size = app_state.image_processor.size
            if isinstance(size, dict):
                h = size.get("height", 224)
                w = size.get("width", 224)
            else:
                h = w = size
            app_state.vit_grid = (h // app_state.vit_patch_size, w // app_state.vit_patch_size)
            print(f"ViT grid: {app_state.vit_grid}")
        except Exception as e:
            print(f"Error setting ViT grid: {e}")
            app_state.vit_patch_size, app_state.vit_grid = app_state.DEFAULT_VIT_PATCH_SIZE, app_state.DEFAULT_VIT_GRID
    except Exception as e:
        print(f"Error loading model from Hugging Face: {e}")
        raise


async def get_client_ip(request: Request) -> str:
    # The first hop of X-Forwarded-For is the original client. Fall back to the direct socket IP.
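    # e.g., "X-Forwarded-For: 203.0.113.7, 198.51.100.2" -> returns "203.0.113.7"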
    xff = request.headers.get("x-forwarded-for")
    if xff:
        return xff.split(",")[0].strip()
    return request.client.host
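
The commit message also promises CORS support and a prediction endpoint, neither of which appears in this file. Below is a minimal sketch of one way this module could be wired into the FastAPI application. The module path app/hf_loader.py, the /predict route, its form fields, and the response shape are all illustrative assumptions, not taken from the commit.

```python
# Hypothetical app/main.py wiring (module path and names are assumptions):
# load_model() runs once at startup via the lifespan hook, CORS is enabled
# app-wide, and the prediction endpoint logs the caller via get_client_ip().
from contextlib import asynccontextmanager

from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware

from .hf_loader import get_client_ip, load_model  # assumed module name
from .state import app_state


@asynccontextmanager
async def lifespan(app: FastAPI):
    load_model()  # download artifacts and weights before serving traffic
    yield


app = FastAPI(lifespan=lifespan)

# CORS support per the commit message; the open allow-list is an assumption.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/predict")  # endpoint name is an assumption
async def predict(
    request: Request,
    image: UploadFile = File(...),
    age: float = Form(...),
    localization: str = Form("unknown"),
):
    client_ip = await get_client_ip(request)  # e.g. for logging or rate limiting
    # A real implementation would preprocess with app_state.image_processor and
    # app_state.loc_encoder, run app_state.model, and map logits via app_state.id2label.
    return {"client_ip": client_ip, "prediction": None}
```

Invoking load_model() inside the lifespan hook rather than at import time keeps the slow Hub downloads off the module import path and surfaces download failures as a clean startup error instead of a half-initialized server.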