SovanK committed on
Commit
40b84df
·
1 Parent(s): 0fddfa1

fix: inference, add mask

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. inference.py +84 -2
app.py CHANGED
@@ -415,7 +415,7 @@ def run_inference(
415
 
416
  try:
417
  image = image_state["image"]
418
- probs = infer_image(image, head)
419
 
420
  # Use id_to_labels.json mapping, fall back to model config if not available
421
  id2label = load_id_to_labels().get(head, {})
 
415
 
416
  try:
417
  image = image_state["image"]
418
+ probs = infer_image(image, head, image_state.get("mask"))
419
 
420
  # Use id_to_labels.json mapping, fall back to model config if not available
421
  id2label = load_id_to_labels().get(head, {})
inference.py CHANGED
@@ -12,16 +12,92 @@ import pandas as pd
12
  import torch
13
  from datasets import Dataset, DatasetDict, load_dataset
14
  from PIL import Image
 
 
15
  from transformers import (
16
  AutoImageProcessor,
17
  AutoModelForImageClassification,
18
  )
19
 
20
-
21
  HF_REPO_ID = "raidium/curia"
22
  HF_DATASET_ID = "raidium/CuriaBench"
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  @lru_cache(maxsize=1)
26
  def load_id_to_labels() -> Dict[str, Dict[str, str]]:
27
  """Load the id_to_labels.json mapping file."""
@@ -37,7 +113,9 @@ def load_id_to_labels() -> Dict[str, Dict[str, str]]:
37
  @lru_cache(maxsize=1)
38
  def load_processor() -> AutoImageProcessor:
39
  token = os.environ.get("HF_TOKEN")
40
- return AutoImageProcessor.from_pretrained(HF_REPO_ID, trust_remote_code=True, token=token)
 
 
41
 
42
 
43
  @lru_cache(maxsize=None)
@@ -88,11 +166,15 @@ def to_numpy_image(image: Any) -> np.ndarray:
88
  def infer_image(
89
  image: np.ndarray,
90
  head: str,
 
91
  ) -> torch.Tensor:
92
  processor = load_processor()
93
  model = load_model(head)
94
  with torch.no_grad():
95
  processed = processor(images=image, return_tensors="pt")
 
 
 
96
  outputs = model(**processed)
97
  logits = outputs["logits"]
98
  probs = torch.nn.functional.softmax(logits[0], dim=-1)
 
12
  import torch
13
  from datasets import Dataset, DatasetDict, load_dataset
14
  from PIL import Image
15
+ from torchvision import transforms
16
+ from torchvision.transforms import functional as TF
17
  from transformers import (
18
  AutoImageProcessor,
19
  AutoModelForImageClassification,
20
  )
21
 
 
22
  HF_REPO_ID = "raidium/curia"
23
  HF_DATASET_ID = "raidium/CuriaBench"
24
 
25
 
26
class _NumpyToTensor:
    """Transform step that turns array-like input into a batched tensor.

    Values that are already a ``torch.Tensor`` or a PIL ``Image`` are handed
    back untouched; everything else is converted with ``torch.tensor`` and
    given a new leading axis of size 1.
    """

    def __call__(self, value: Any) -> torch.Tensor:
        passthrough_types = (torch.Tensor, Image.Image)
        if not isinstance(value, passthrough_types):
            # Array-like input: copy into a tensor and prepend a singleton axis.
            return torch.tensor(value).unsqueeze(0)
        # Tensors and PIL images are returned as-is (so the declared return
        # type is not strictly accurate for images).
        return value  # type: ignore[return-value]
33
+
34
+
35
class AdaptativeResizeMask(torch.nn.Module):
    """Resize a mask to a square and re-binarise it, retrying if it vanishes.

    The mask is cast to float32, resized to ``target_size x target_size`` with
    antialiased bilinear interpolation, and thresholded at
    ``initial_threshold``. If no element survives that threshold, the mask is
    thresholded again at half of the resized maximum.
    """

    def __init__(self, target_size: int = 512, initial_threshold: float = 0.5) -> None:
        super().__init__()
        self.target_size = target_size
        self.initial_threshold = initial_threshold

    def forward(self, mask: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Return the resized mask as a float32 tensor of 0.0 / 1.0 values."""
        side = self.target_size
        scaled = TF.resize(
            mask.to(dtype=torch.float32),
            (side, side),
            interpolation=TF.InterpolationMode.BILINEAR,
            antialias=True,
        )
        foreground = scaled > self.initial_threshold
        if not foreground.any():
            # Interpolation smeared every value below the threshold; fall
            # back to a cut-off relative to the brightest resized value.
            foreground = scaled > torch.max(scaled) * 0.5
        return foreground.to(dtype=torch.float32)
56
+
57
+
58
@lru_cache(maxsize=1)
def make_mask_transform(crop_size: int = 512) -> transforms.Compose:
    """Build the mask preprocessing pipeline for a given square size.

    The pipeline converts array input to a tensor and then resizes and
    re-binarises it to ``crop_size x crop_size``. The result is memoised with
    a single-entry cache, so repeated calls with the same ``crop_size`` reuse
    one ``Compose`` object.
    """
    steps = [
        _NumpyToTensor(),
        AdaptativeResizeMask(target_size=crop_size),
    ]
    return transforms.Compose(steps)
68
+
69
+
70
def prepare_mask_for_model(mask: Any) -> Optional[torch.Tensor]:
    """Convert a user-supplied mask into the tensor handed to the model.

    Args:
        mask: Anything ``np.array`` accepts (array, nested list, image
            object, ...), or ``None`` when no mask was provided.

    Returns:
        ``None`` when ``mask`` is ``None``, cannot be converted to an array,
        or is empty. Otherwise the output of ``make_mask_transform()``:

        * a 3-D input is read as ``(H, W, C)`` and returned channels-last
          with a leading batch axis, i.e. ``(1, S, S, C)``;
        * any other rank is treated as one mask and returned with two
          leading singleton axes, e.g. ``(1, 1, S, S)`` for a 2-D input,

        where ``S`` is the transform's target size.
    """
    if mask is None:
        return None

    try:
        mask_arr = np.array(mask)
    except Exception:
        # Deliberately best-effort: a mask that cannot be read is treated
        # the same as "no mask" instead of aborting inference.
        return None

    if mask_arr.size == 0:
        return None

    # Only fetch the transform once the mask is known to be usable.
    mask_transform = make_mask_transform()

    if mask_arr.ndim == 3:
        # (H, W, C) -> (C, H, W); the transform adds the batch axis and
        # resizes the two trailing spatial axes.
        resized = mask_transform(mask_arr.transpose(2, 0, 1))
        # Match the shape produced in simple_test.py so the model receives
        # (batch, height, width, channels) style tensors.
        return resized.permute(0, 2, 3, 1)

    # Add the channel axis directly on the array-backed tensor; wrapping the
    # ndarray in a Python list first would take torch's slow conversion path.
    single_channel = torch.as_tensor(mask_arr).unsqueeze(0)
    return mask_transform(single_channel).unsqueeze(0)
99
+
100
+
101
  @lru_cache(maxsize=1)
102
  def load_id_to_labels() -> Dict[str, Dict[str, str]]:
103
  """Load the id_to_labels.json mapping file."""
 
113
  @lru_cache(maxsize=1)
114
  def load_processor() -> AutoImageProcessor:
115
  token = os.environ.get("HF_TOKEN")
116
+ return AutoImageProcessor.from_pretrained(
117
+ HF_REPO_ID, trust_remote_code=True, token=token
118
+ )
119
 
120
 
121
  @lru_cache(maxsize=None)
 
166
  def infer_image(
167
  image: np.ndarray,
168
  head: str,
169
+ mask: Any | None = None,
170
  ) -> torch.Tensor:
171
  processor = load_processor()
172
  model = load_model(head)
173
  with torch.no_grad():
174
  processed = processor(images=image, return_tensors="pt")
175
+ mask_tensor = prepare_mask_for_model(mask)
176
+ if mask_tensor is not None:
177
+ processed["mask"] = mask_tensor
178
  outputs = model(**processed)
179
  logits = outputs["logits"]
180
  probs = torch.nn.functional.softmax(logits[0], dim=-1)