# utils/classifier.py from __future__ import annotations import torch import torch.nn as nn from torchvision import models, transforms from PIL import Image from pathlib import Path from utils.analysis import CLASS_NAMES # ------------------------------------------------- # DEVICE # ------------------------------------------------- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # ------------------------------------------------- # TRANSFORMS # ------------------------------------------------- clf_transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ), ]) # ------------------------------------------------- # MODEL LOADING # ------------------------------------------------- def load_wbc_classifier(weights_path: str | Path): """ Load your trained ResNet50 classifier. Expected checkpoint format: {"model_state_dict": ..., ...} """ weights_path = Path(weights_path) if not weights_path.exists(): raise FileNotFoundError(f"Classifier weights not found: {weights_path}") # Base model (ImageNet pre-trained) model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) # Replace FC layer with your 8-class head model.fc = nn.Sequential( nn.Dropout(0.3), nn.Linear(model.fc.in_features, len(CLASS_NAMES)) ) # Load checkpoint ckpt = torch.load(weights_path, map_location=DEVICE,weights_only=False) if "model_state_dict" in ckpt: model.load_state_dict(ckpt["model_state_dict"]) else: model.load_state_dict(ckpt) model.to(DEVICE) model.eval() return model # ------------------------------------------------- # SINGLE-CROP CLASSIFICATION # ------------------------------------------------- def classify_wbc_crop( model: nn.Module, pil_img: Image.Image, ) -> str: """ Run classification on a single crop and return predicted class name. """ x = clf_transform(pil_img).unsqueeze(0).to(DEVICE) with torch.no_grad(): logits = model(x) pred_idx = int(torch.argmax(logits, dim=1).item()) return CLASS_NAMES[pred_idx]