Spaces:
Running
Running
File size: 2,703 Bytes
cb92718 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 | import json
from pathlib import Path
import torch
from app.models.mlp_head import build_mlp_head_from_checkpoint
from app.services.derm_backbone import DermFoundationBackbone
def load_class_names() -> dict[int, str]:
project_root = Path(__file__).resolve().parents[2]
class_names_path = project_root / "class_names.json"
with open(class_names_path, "r", encoding="utf-8") as f:
raw_class_names = json.load(f)
return {int(index): name for index, name in raw_class_names.items()}
class TwoStageDermPredictor:
"""
Stage 1: Derm Foundation image -> embedding.
Stage 2: PyTorch MLP head embedding -> class probabilities.
"""
def __init__(
self,
derm_model_id: str,
head_checkpoint_path: str,
hf_token: str | None = None,
local_files_only: bool = False,
image_size: int = 448,
device_name: str = "auto",
) -> None:
if device_name == "auto":
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = torch.device(device_name)
self.class_names = load_class_names()
self.backbone = DermFoundationBackbone(
repo_id=derm_model_id,
token=hf_token,
local_files_only=local_files_only,
image_size=image_size,
)
self.head, _ = build_mlp_head_from_checkpoint(
checkpoint_path=head_checkpoint_path,
device=self.device,
)
output_dim = self.head[-1].out_features
if output_dim != len(self.class_names):
raise ValueError(
f"MLP output dimension is {output_dim}, "
f"but class_names.json contains {len(self.class_names)} classes."
)
def predict(self, image_bytes: bytes) -> dict:
embedding_np = self.backbone.image_to_embedding(image_bytes)
embedding = torch.from_numpy(embedding_np).float().to(self.device)
with torch.no_grad():
logits = self.head(embedding)
probs = torch.softmax(logits, dim=1)[0].cpu()
pred_idx = int(torch.argmax(probs).item())
confidence = float(probs[pred_idx].item())
print(self.class_names)
probabilities = [
{
"index": i,
"class_name": self.class_names[i],
"probability": float(prob),
}
for i, prob in enumerate(probs.tolist())
]
return {
"predicted_index": pred_idx,
"predicted_class": self.class_names[pred_idx],
"confidence": confidence,
"probabilities": probabilities,
} |