File size: 5,289 Bytes
0fddfa1 40b84df 0fddfa1 40b84df 0fddfa1 40b84df 0fddfa1 40b84df 0fddfa1 40b84df 0fddfa1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
"""Model and dataset loading, inference, and label extraction functions."""
from __future__ import annotations
import json
import os
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict, load_dataset
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as TF
from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
)
# Hugging Face Hub repo for the Curia image processor; each classification
# head is loaded from its own subfolder of this repo (see load_model).
HF_REPO_ID = "raidium/curia"
# Hugging Face Hub dataset repo holding the CuriaBench subsets (see load_curia_dataset).
HF_DATASET_ID = "raidium/CuriaBench"
class _NumpyToTensor:
    """Turn array-like input into a tensor with an extra leading axis.

    Values that are already ``torch.Tensor`` or ``PIL.Image.Image`` are handed
    back untouched; everything else goes through ``torch.tensor`` and gains a
    new dimension 0.
    """

    def __call__(self, value: Any) -> torch.Tensor:
        untouched_types = (torch.Tensor, Image.Image)
        if not isinstance(value, untouched_types):
            value = torch.tensor(value).unsqueeze(0)
        return value  # type: ignore[return-value]
class AdaptativeResizeMask(torch.nn.Module):
    """Resize a mask to a square and binarise it without losing the ROI.

    Args:
        target_size: Side length of the square output.
        initial_threshold: First cut-off applied to the resized values.
    """

    def __init__(self, target_size: int = 512, initial_threshold: float = 0.5) -> None:
        super().__init__()
        self.target_size = target_size
        self.initial_threshold = initial_threshold

    def forward(self, mask: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Return a float32 0/1 mask resized to ``target_size`` x ``target_size``.

        The mask is resized with antialiased bilinear interpolation and then
        thresholded at ``initial_threshold``. If no element survives, it is
        re-thresholded at half of the maximum resized value.
        """
        side = self.target_size
        smooth = TF.resize(
            mask.to(dtype=torch.float32),
            (side, side),
            interpolation=TF.InterpolationMode.BILINEAR,
            antialias=True,
        )
        keep = smooth > self.initial_threshold
        if not keep.any():
            # Downsampling can blur a small region below the fixed cut-off;
            # fall back to half of the peak response so the mask isn't empty.
            keep = smooth > smooth.max() * 0.5
        return keep.float()
@lru_cache(maxsize=1)
def make_mask_transform(crop_size: int = 512) -> transforms.Compose:
    """Build the mask preprocessing pipeline (memoised for the last ``crop_size``).

    Steps: convert array input to a tensor (``_NumpyToTensor``), then resize and
    binarise it to ``crop_size`` x ``crop_size`` (``AdaptativeResizeMask``).
    """
    to_tensor = _NumpyToTensor()
    resize = AdaptativeResizeMask(target_size=crop_size)
    return transforms.Compose([to_tensor, resize])
def prepare_mask_for_model(mask: Any) -> Optional[torch.Tensor]:
    """Apply Curia's mask preprocessing so heads get the ROI they expect.

    Args:
        mask: An array-like that ``np.array`` can convert, or ``None`` for
            "no mask".

    Returns:
        ``None`` when no usable mask was supplied. That covers ``None``, a
        value ``np.array`` cannot convert, and an empty array. Otherwise a
        float32 0/1 tensor, where ``S`` is 512 (the default crop size of
        ``make_mask_transform``):

        - 3-D input (last axis treated as channels) yields shape ``(1, S, S, C)``.
        - 2-D input ``(H, W)`` yields shape ``(1, 1, S, S)``.
    """
    if mask is None:
        return None
    try:
        mask_arr = np.array(mask)
    except Exception:
        # Best effort: an unconvertible mask is treated as "no mask" rather
        # than aborting inference.
        return None
    if mask_arr.size == 0:
        return None

    # Built only once we know there is a real mask to process.
    mask_transform = make_mask_transform()
    if mask_arr.ndim == 3:
        # Move the last (channel) axis first; the transform then prepends a
        # batch axis and resizes the last two dims -> (1, C, S, S).
        tensor = mask_transform(mask_arr.transpose(2, 0, 1))
        # Match the shape produced in simple_test.py so the model receives
        # (batch, height, width, channels) style tensors.
        tensor = tensor.transpose(1, 3).transpose(1, 2)
    else:
        # Same (1, H, W) tensor as torch.tensor([mask_arr]), without PyTorch's
        # slow list-of-ndarrays conversion path.
        tensor = mask_transform(torch.tensor(mask_arr).unsqueeze(0))
        tensor = tensor.unsqueeze(0)
    return tensor
@lru_cache(maxsize=1)
def load_id_to_labels() -> Dict[str, Dict[int, str]]:
    """Load the ``id_to_labels.json`` file that sits next to this module.

    The top-level JSON object maps each head name to a ``{class_id: label}``
    object. JSON object keys are always strings, so class ids are converted
    back to ``int`` here.

    The result is cached, and every caller receives the same dict object, so
    callers should not mutate it.

    Raises:
        FileNotFoundError: If ``id_to_labels.json`` is missing.
        ValueError: If the file is not valid JSON (``json.JSONDecodeError``)
            or a class id key is not an integer string.
    """
    json_path = os.path.join(os.path.dirname(__file__), "id_to_labels.json")
    # Decode explicitly as UTF-8. The default text encoding is
    # locale-dependent and can mis-decode non-ASCII label names.
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return {
        head: {int(class_id): label for class_id, label in mapping.items()}
        for head, mapping in data.items()
    }
@lru_cache(maxsize=1)
def load_processor() -> AutoImageProcessor:
    """Fetch the Curia image processor from ``HF_REPO_ID`` once and cache it.

    The ``HF_TOKEN`` environment variable, if set, is passed as ``token``.
    ``trust_remote_code=True`` allows the repo's custom processor code to run
    locally.
    """
    return AutoImageProcessor.from_pretrained(
        HF_REPO_ID,
        trust_remote_code=True,
        token=os.environ.get("HF_TOKEN"),
    )
@lru_cache(maxsize=None)
def load_model(head: str) -> AutoModelForImageClassification:
    """Load the classification head stored in subfolder ``head`` of ``HF_REPO_ID``.

    The model is put in eval mode before being returned. Results are memoised
    per ``head`` in an unbounded cache. ``HF_TOKEN`` from the environment, if
    set, is passed as ``token``.
    """
    hf_token = os.environ.get("HF_TOKEN")
    net = AutoModelForImageClassification.from_pretrained(
        HF_REPO_ID,
        subfolder=head,
        trust_remote_code=True,
        token=hf_token,
    )
    net.eval()
    return net
@lru_cache(maxsize=None)
def load_curia_dataset(subset: str) -> Any:
    """Load the ``test`` split of one CuriaBench ``subset`` (memoised per subset).

    ``split="test"`` is requested explicitly. If a ``DatasetDict`` comes back
    anyway, its ``"test"`` entry is returned instead, so callers get the split
    itself.
    """
    loaded = load_dataset(
        HF_DATASET_ID,
        subset,
        split="test",
        token=os.environ.get("HF_TOKEN"),
    )
    return loaded["test"] if isinstance(loaded, DatasetDict) else loaded
def to_numpy_image(image: Any) -> np.ndarray:
    """Return ``image`` as a float32 numpy array.

    Inputs that are not already ``np.ndarray`` (PIL images, nested lists, ...)
    are coerced with ``np.array``. A 3-D array whose last axis has length 3 is
    treated as RGB and collapsed to grayscale by averaging the channels. Any
    other shape is kept as-is.
    """
    pixels = image if isinstance(image, np.ndarray) else np.array(image)
    looks_rgb = pixels.ndim == 3 and pixels.shape[-1] == 3
    if looks_rgb:
        pixels = pixels.mean(axis=-1)
    return pixels.astype(np.float32)
def infer_image(
    image: np.ndarray,
    head: str,
    mask: Any | None = None,
) -> torch.Tensor:
    """Run one Curia classification head on ``image`` and return probabilities.

    Args:
        image: Image handed directly to the Curia processor.
        head: Name of the head to load via ``load_model``.
        mask: Optional ROI mask. It is preprocessed with
            ``prepare_mask_for_model`` and, when that yields a tensor, added to
            the model inputs under the ``"mask"`` key.

    Returns:
        Softmax over the last axis of the logits of the first batch item.
    """
    processor = load_processor()
    model = load_model(head)
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt")
        roi = prepare_mask_for_model(mask)
        if roi is not None:
            inputs["mask"] = roi
        logits = model(**inputs)["logits"]
        return torch.nn.functional.softmax(logits[0], dim=-1)
|