Daniel Huynh
Deploy FastAPI derm backend to Hugging Face Spaces
cb92718
import json
from pathlib import Path
import torch
from app.models.mlp_head import build_mlp_head_from_checkpoint
from app.services.derm_backbone import DermFoundationBackbone
def load_class_names() -> dict[int, str]:
project_root = Path(__file__).resolve().parents[2]
class_names_path = project_root / "class_names.json"
with open(class_names_path, "r", encoding="utf-8") as f:
raw_class_names = json.load(f)
return {int(index): name for index, name in raw_class_names.items()}
class TwoStageDermPredictor:
"""
Stage 1: Derm Foundation image -> embedding.
Stage 2: PyTorch MLP head embedding -> class probabilities.
"""
def __init__(
self,
derm_model_id: str,
head_checkpoint_path: str,
hf_token: str | None = None,
local_files_only: bool = False,
image_size: int = 448,
device_name: str = "auto",
) -> None:
if device_name == "auto":
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = torch.device(device_name)
self.class_names = load_class_names()
self.backbone = DermFoundationBackbone(
repo_id=derm_model_id,
token=hf_token,
local_files_only=local_files_only,
image_size=image_size,
)
self.head, _ = build_mlp_head_from_checkpoint(
checkpoint_path=head_checkpoint_path,
device=self.device,
)
output_dim = self.head[-1].out_features
if output_dim != len(self.class_names):
raise ValueError(
f"MLP output dimension is {output_dim}, "
f"but class_names.json contains {len(self.class_names)} classes."
)
def predict(self, image_bytes: bytes) -> dict:
embedding_np = self.backbone.image_to_embedding(image_bytes)
embedding = torch.from_numpy(embedding_np).float().to(self.device)
with torch.no_grad():
logits = self.head(embedding)
probs = torch.softmax(logits, dim=1)[0].cpu()
pred_idx = int(torch.argmax(probs).item())
confidence = float(probs[pred_idx].item())
print(self.class_names)
probabilities = [
{
"index": i,
"class_name": self.class_names[i],
"probability": float(prob),
}
for i, prob in enumerate(probs.tolist())
]
return {
"predicted_index": pred_idx,
"predicted_class": self.class_names[pred_idx],
"confidence": confidence,
"probabilities": probabilities,
}