anas commited on
Commit
ef36c4f
·
1 Parent(s): 86c87ce

Remove detectron2 dependency for inference

Browse files

- Replace detectron2.utils.registry.Registry with minimal implementation
- Replace detectron2 transforms with pure numpy/cv2 resize+pad
- Make datasets/__init__.py lazy-import poly_data (training only)

Made-with: Cursor

Files changed (3) hide show
  1. app.py +13 -15
  2. datasets/__init__.py +2 -3
  3. util/bf_utils.py +20 -7
app.py CHANGED
@@ -11,8 +11,6 @@ from PIL import Image
11
  from shapely.geometry import Polygon
12
 
13
  from datasets.discrete_tokenizer import DiscreteTokenizer
14
- from datasets.transforms import ResizeAndPad
15
- from detectron2.data import transforms as T
16
  from models import build_model
17
  from util.plot_utils import plot_semantic_rich_floorplan_opencv
18
 
@@ -219,23 +217,23 @@ print("Loading model...")
219
  MODEL = load_model()
220
  print("Model loaded.")
221
 
222
- DATA_TRANSFORM = T.AugmentationList(
223
- [ResizeAndPad((MODEL_ARGS.image_size, MODEL_ARGS.image_size), pad_value=255)]
224
- )
225
-
226
-
227
  def preprocess_image(pil_image: Image.Image) -> torch.Tensor:
 
 
228
  image_np = np.array(pil_image.convert("RGB"))
229
- aug_input = T.AugInput(image_np)
230
- DATA_TRANSFORM(aug_input)
231
- image_np = aug_input.image
232
 
233
- if len(image_np.shape) == 2:
234
- tensor = np.expand_dims(image_np, 0)
235
- else:
236
- tensor = image_np.transpose((2, 0, 1))
 
 
 
 
237
 
238
- return (1 / 255) * torch.as_tensor(tensor, dtype=torch.float32)
 
239
 
240
 
241
  def predict_floorplan(image: Image.Image):
 
11
  from shapely.geometry import Polygon
12
 
13
  from datasets.discrete_tokenizer import DiscreteTokenizer
 
 
14
  from models import build_model
15
  from util.plot_utils import plot_semantic_rich_floorplan_opencv
16
 
 
217
  MODEL = load_model()
218
  print("Model loaded.")
219
 
 
 
 
 
 
220
  def preprocess_image(pil_image: Image.Image) -> torch.Tensor:
221
+ """Resize preserving aspect ratio + pad to (image_size, image_size)."""
222
+ target = MODEL_ARGS.image_size
223
  image_np = np.array(pil_image.convert("RGB"))
224
+ h, w = image_np.shape[:2]
 
 
225
 
226
+ scale = min(target / h, target / w)
227
+ new_h, new_w = int(h * scale), int(w * scale)
228
+ resized = cv2.resize(image_np, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
229
+
230
+ padded = np.full((target, target, 3), 255, dtype=np.uint8)
231
+ top = (target - new_h) // 2
232
+ left = (target - new_w) // 2
233
+ padded[top:top + new_h, left:left + new_w] = resized
234
 
235
+ tensor = padded.transpose((2, 0, 1)).astype(np.float32) / 255.0
236
+ return torch.as_tensor(tensor)
237
 
238
 
239
  def predict_floorplan(image: Image.Image):
datasets/__init__.py CHANGED
@@ -1,7 +1,6 @@
1
- from .poly_data import build as build_poly
2
-
3
-
4
  def build_dataset(image_set, args):
 
 
5
  if args.dataset_name in ["stru3d", "cubicasa", "waffle", "r2g"]:
6
  print(f"Build {args.dataset_name} {image_set} dataset")
7
  return build_poly(image_set, args)
 
 
 
 
1
  def build_dataset(image_set, args):
2
+ from .poly_data import build as build_poly
3
+
4
  if args.dataset_name in ["stru3d", "cubicasa", "waffle", "r2g"]:
5
  print(f"Build {args.dataset_name} {image_set} dataset")
6
  return build_poly(image_set, args)
util/bf_utils.py CHANGED
@@ -6,13 +6,26 @@ import numpy as np
6
  import torch
7
  from torch import nn
8
 
9
- from detectron2.utils.registry import Registry
10
-
11
- # need an easier place to avoid circular dependencies.
12
- POLY_LOSS_REGISTRY = Registry("POLY_LOSS")
13
- POLY_LOSS_REGISTRY.__doc__ = """
14
- Registry for loss computations on predicted polygons.
15
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  def box_cxcywh_to_xyxy(x):
 
6
  import torch
7
  from torch import nn
8
 
9
+ class _Registry:
10
+ """Minimal replacement for detectron2.utils.registry.Registry."""
11
+ def __init__(self, name):
12
+ self._name = name
13
+ self._obj_map = {}
14
+
15
+ def register(self, obj=None):
16
+ if obj is None:
17
+ def decorator(func_or_class):
18
+ self._obj_map[func_or_class.__name__] = func_or_class
19
+ return func_or_class
20
+ return decorator
21
+ self._obj_map[obj.__name__] = obj
22
+ return obj
23
+
24
+ def get(self, name):
25
+ return self._obj_map[name]
26
+
27
+
28
+ POLY_LOSS_REGISTRY = _Registry("POLY_LOSS")
29
 
30
 
31
  def box_cxcywh_to_xyxy(x):