"""Model and dataset loading, inference, and label extraction functions.""" from __future__ import annotations import json import os from functools import lru_cache from typing import Any, Dict, List, Optional, Tuple import numpy as np import pandas as pd import torch from datasets import Dataset, DatasetDict, load_dataset from PIL import Image from torchvision import transforms from torchvision.transforms import functional as TF from transformers import ( AutoImageProcessor, AutoModelForImageClassification, ) HF_REPO_ID = "raidium/curia" HF_DATASET_ID = "raidium/CuriaBench" class _NumpyToTensor: """Convert numpy arrays to tensors while preserving tensors/images.""" def __call__(self, value: Any) -> torch.Tensor: if isinstance(value, (torch.Tensor, Image.Image)): return value # type: ignore[return-value] return torch.tensor(value).unsqueeze(0) class AdaptativeResizeMask(torch.nn.Module): """Resize binary masks with a fallback threshold to avoid empty masks.""" def __init__(self, target_size: int = 512, initial_threshold: float = 0.5) -> None: super().__init__() self.target_size = target_size self.initial_threshold = initial_threshold def forward(self, mask: torch.Tensor) -> torch.Tensor: # type: ignore[override] mask = mask.to(dtype=torch.float32) resized = TF.resize( mask, (self.target_size, self.target_size), interpolation=TF.InterpolationMode.BILINEAR, antialias=True, ) binary = resized > self.initial_threshold if binary.sum() == 0: new_threshold = torch.max(resized) * 0.5 binary = resized > new_threshold return binary.to(dtype=torch.float32) @lru_cache(maxsize=1) def make_mask_transform(crop_size: int = 512) -> transforms.Compose: """Return the resize transform used during training/inference.""" return transforms.Compose( [ _NumpyToTensor(), AdaptativeResizeMask(target_size=crop_size), ] ) def prepare_mask_for_model(mask: Any) -> Optional[torch.Tensor]: """Apply Curia's mask preprocessing so heads get the ROI they expect.""" if mask is None: return None mask_transform = make_mask_transform() try: mask_arr = np.array(mask) except Exception: return None if mask_arr.size == 0: return None if mask_arr.ndim == 3: tensor = mask_transform(mask_arr.transpose(2, 0, 1)) # Match the shape produced in simple_test.py so the model receives # (batch, height, width, channels) style tensors. tensor = tensor.transpose(1, 3).transpose(1, 2) else: tensor = mask_transform(torch.tensor([mask_arr])) tensor = tensor.unsqueeze(0) if isinstance(tensor, np.ndarray): tensor = torch.from_numpy(tensor) return tensor @lru_cache(maxsize=1) def load_id_to_labels() -> Dict[str, Dict[str, str]]: """Load the id_to_labels.json mapping file.""" json_path = os.path.join(os.path.dirname(__file__), "id_to_labels.json") with open(json_path, "r") as f: data = json.load(f) # convert string keys to integers for head in data: data[head] = {int(k): v for k, v in data[head].items()} return data @lru_cache(maxsize=1) def load_processor() -> AutoImageProcessor: token = os.environ.get("HF_TOKEN") return AutoImageProcessor.from_pretrained( HF_REPO_ID, trust_remote_code=True, token=token ) @lru_cache(maxsize=None) def load_model(head: str) -> AutoModelForImageClassification: token = os.environ.get("HF_TOKEN") model = AutoModelForImageClassification.from_pretrained( HF_REPO_ID, trust_remote_code=True, subfolder=head, token=token, ) model.eval() return model @lru_cache(maxsize=None) def load_curia_dataset(subset: str) -> Any: token = os.environ.get("HF_TOKEN") ds = load_dataset( HF_DATASET_ID, subset, split="test", token=token, ) if isinstance(ds, DatasetDict): return ds["test"] return ds def to_numpy_image(image: Any) -> np.ndarray: """Convert dataset or user-provided imagery to a float32 numpy array.""" if isinstance(image, np.ndarray): arr = image elif isinstance(image, Image.Image): arr = np.array(image) else: # Some datasets provide nested dicts or lists – attempt to coerce. arr = np.array(image) if arr.ndim == 3 and arr.shape[-1] == 3: # Convert RGB to grayscale by averaging channels arr = arr.mean(axis=-1) return arr.astype(np.float32) def infer_image( image: np.ndarray, head: str, mask: Any | None = None, ) -> torch.Tensor: processor = load_processor() model = load_model(head) with torch.no_grad(): processed = processor(images=image, return_tensors="pt") mask_tensor = prepare_mask_for_model(mask) if mask_tensor is not None: processed["mask"] = mask_tensor outputs = model(**processed) logits = outputs["logits"] probs = torch.nn.functional.softmax(logits[0], dim=-1) return probs