File size: 4,733 Bytes
0fddfa1 a8db175 0fddfa1 a8db175 0fddfa1 40b84df 0fddfa1 40b84df 882bd8f 40b84df 0fddfa1 40b84df 0fddfa1 40b84df de6c079 0fddfa1 40b84df 0fddfa1 de6c079 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
"""Model and dataset loading, inference, and label extraction functions."""
from __future__ import annotations
import json
import os
from functools import lru_cache
from typing import Any, Dict, Optional
import numpy as np
import torch
from datasets import DatasetDict, load_dataset
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as TF
from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
)
# Hub repository id passed to ``from_pretrained`` when loading the image
# processor and the per-head classification models.
HF_REPO_ID = "raidium/curia"
# Dataset id passed to ``load_dataset`` when loading a benchmark subset.
HF_DATASET_ID = "raidium/CuriaBench"
class _NumpyToTensor:
    """Turn array-like input into a tensor with one extra leading axis.

    Values that are already a ``torch.Tensor`` or a PIL ``Image`` are handed
    back untouched (no copy, no added axis).
    """

    def __call__(self, value: Any) -> torch.Tensor:
        passthrough = isinstance(value, (torch.Tensor, Image.Image))
        if not passthrough:
            # Copy the data into a tensor and prepend an axis of size 1.
            return torch.tensor(value).unsqueeze(0)
        return value  # type: ignore[return-value]
class AdaptativeResizeMask(torch.nn.Module):
    """Bilinearly resize a mask to a square and binarize it again.

    The resized mask is thresholded at ``initial_threshold``. If that leaves
    no foreground at all, it is thresholded a second time at half of the
    resized mask's maximum value.
    """

    def __init__(self, target_size: int = 512, initial_threshold: float = 0.5) -> None:
        super().__init__()
        self.target_size = target_size
        self.initial_threshold = initial_threshold

    def forward(self, mask: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        side = self.target_size
        scaled = TF.resize(
            mask.to(dtype=torch.float32),
            (side, side),
            interpolation=TF.InterpolationMode.BILINEAR,
            antialias=True,
        )
        foreground = scaled > self.initial_threshold
        if foreground.sum() == 0:
            # Nothing survived the first cut: retry at half of the peak value.
            foreground = scaled > torch.max(scaled) * 0.5
        return foreground.to(dtype=torch.float32)
@lru_cache(maxsize=1)
def make_mask_transform(crop_size: int = 512) -> transforms.Compose:
    """Build the mask pipeline: tensor conversion, then resize to ``crop_size``.

    The pipeline is memoized; only the most recently requested one is kept.
    """
    to_tensor = _NumpyToTensor()
    resize_and_binarize = AdaptativeResizeMask(target_size=crop_size)
    return transforms.Compose([to_tensor, resize_and_binarize])
def prepare_mask_for_model(mask: Any) -> Optional[torch.Tensor]:
    """Convert a raw mask into the resized, binarized tensor passed to the model.

    Args:
        mask: Anything ``np.array`` can convert, or ``None``.

    Returns:
        A float32 tensor produced by ``make_mask_transform()`` with an extra
        leading axis, or ``None`` when ``mask`` is ``None``, cannot be
        converted to an array, or is empty.
    """
    if mask is None:
        return None
    try:
        mask_arr = np.array(mask)
    except Exception:
        # Best effort: a mask that cannot be converted is treated as "no mask".
        return None
    if mask_arr.size == 0:
        return None

    # Looked up only once a usable mask is known to exist.
    mask_transform = make_mask_transform()
    if mask_arr.ndim == 3:  # (H, W, slices)
        tensor = mask_transform(mask_arr.transpose(2, 0, 1))  # (1, slices, S, S)
        # Move the slice axis back to the end: (1, S, S, slices).
        tensor = tensor.permute(0, 2, 3, 1)
    else:
        # Add the leading axis on the array itself; building a tensor from a
        # Python list of ndarrays is slow and makes PyTorch emit a warning.
        tensor = mask_transform(torch.tensor(mask_arr[np.newaxis]))
    # NOTE(review): the original source lost its indentation, so it is unclear
    # whether this extra leading axis applied to both branches (assumed here)
    # or only to the non-3D branch -- confirm against the model's expected
    # mask shape.
    return tensor.unsqueeze(0)
@lru_cache(maxsize=1)
def load_id_to_labels() -> Dict[str, Dict[int, str]]:
    """Load ``id_to_labels.json`` from the directory containing this module.

    Returns:
        A mapping from head name to that head's ``{class_id: label}`` table.
        JSON object keys are strings, so the class ids are converted to
        ``int``. The result is cached; every call returns the same dict
        object.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file is not valid JSON or a class id is not an
            integer string.
    """
    json_path = os.path.join(os.path.dirname(__file__), "id_to_labels.json")
    # JSON is UTF-8; do not depend on the platform's default encoding.
    with open(json_path, "r", encoding="utf-8") as f:
        raw = json.load(f)
    return {
        head: {int(class_id): label for class_id, label in labels.items()}
        for head, labels in raw.items()
    }
@lru_cache(maxsize=1)
def load_processor() -> AutoImageProcessor:
    """Load the image processor from ``HF_REPO_ID`` once and memoize it.

    The access token is read from the ``HF_TOKEN`` environment variable and
    is ``None`` when that variable is unset. ``trust_remote_code=True`` is
    passed to ``from_pretrained``.
    """
    load_kwargs = {
        "trust_remote_code": True,
        "token": os.environ.get("HF_TOKEN"),
    }
    return AutoImageProcessor.from_pretrained(HF_REPO_ID, **load_kwargs)
@lru_cache(maxsize=None)
def load_model(head: str) -> AutoModelForImageClassification:
    """Load the classifier stored under subfolder ``head`` and set it to eval mode.

    Loaded models are memoized per ``head`` with no eviction. The access
    token comes from the ``HF_TOKEN`` environment variable (``None`` if unset).
    """
    hf_token = os.environ.get("HF_TOKEN")
    classifier = AutoModelForImageClassification.from_pretrained(
        HF_REPO_ID, subfolder=head, trust_remote_code=True, token=hf_token
    )
    classifier.eval()
    return classifier
@lru_cache(maxsize=None)
def load_curia_dataset(subset: str) -> Any:
    """Load the ``test`` split of configuration ``subset`` from ``HF_DATASET_ID``.

    If ``load_dataset`` hands back a ``DatasetDict``, its ``"test"`` entry is
    returned; otherwise the loaded object is returned as is. Results are
    memoized per ``subset`` with no eviction.
    """
    loaded = load_dataset(
        HF_DATASET_ID,
        subset,
        split="test",
        token=os.environ.get("HF_TOKEN"),
    )
    return loaded["test"] if isinstance(loaded, DatasetDict) else loaded
def infer_image(
    image: np.ndarray,
    head: str,
    mask: Any | None = None,
    return_probs: bool = True,
) -> torch.Tensor:
    """Run the ``head`` classifier on one image and return its first output row.

    Args:
        image: Image handed to the processor as ``images=``.
        head: Name of the classification head to load.
        mask: Optional mask; when it preprocesses to a tensor, that tensor is
            passed to the model as the ``mask`` input.
        return_probs: If true, return the softmax over the last axis of the
            first logits row; otherwise return that row after ``squeeze()``.

    Returns:
        Probabilities or squeezed raw logits for the first batch item.
    """
    processor = load_processor()
    model = load_model(head)

    with torch.no_grad():
        model_inputs = processor(images=image, return_tensors="pt")
        roi = prepare_mask_for_model(mask)
        if roi is not None:
            model_inputs["mask"] = roi
        first_logits = model(**model_inputs)["logits"][0]

    if not return_probs:
        return first_logits.squeeze()
    return torch.nn.functional.softmax(first_logits, dim=-1)
|