Spaces:
Runtime error
Runtime error
anas commited on
Commit ·
ef36c4f
1
Parent(s): 86c87ce
Remove detectron2 dependency for inference
Browse files- Replace detectron2.utils.registry.Registry with minimal implementation
- Replace detectron2 transforms with pure numpy/cv2 resize+pad
- Make datasets/__init__.py lazy-import poly_data (training only)
Made-with: Cursor
- app.py +13 -15
- datasets/__init__.py +2 -3
- util/bf_utils.py +20 -7
app.py
CHANGED
|
@@ -11,8 +11,6 @@ from PIL import Image
|
|
| 11 |
from shapely.geometry import Polygon
|
| 12 |
|
| 13 |
from datasets.discrete_tokenizer import DiscreteTokenizer
|
| 14 |
-
from datasets.transforms import ResizeAndPad
|
| 15 |
-
from detectron2.data import transforms as T
|
| 16 |
from models import build_model
|
| 17 |
from util.plot_utils import plot_semantic_rich_floorplan_opencv
|
| 18 |
|
|
@@ -219,23 +217,23 @@ print("Loading model...")
|
|
| 219 |
MODEL = load_model()
|
| 220 |
print("Model loaded.")
|
| 221 |
|
| 222 |
-
DATA_TRANSFORM = T.AugmentationList(
|
| 223 |
-
[ResizeAndPad((MODEL_ARGS.image_size, MODEL_ARGS.image_size), pad_value=255)]
|
| 224 |
-
)
|
| 225 |
-
|
| 226 |
-
|
| 227 |
def preprocess_image(pil_image: Image.Image) -> torch.Tensor:
|
|
|
|
|
|
|
| 228 |
image_np = np.array(pil_image.convert("RGB"))
|
| 229 |
-
|
| 230 |
-
DATA_TRANSFORM(aug_input)
|
| 231 |
-
image_np = aug_input.image
|
| 232 |
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
-
|
|
|
|
| 239 |
|
| 240 |
|
| 241 |
def predict_floorplan(image: Image.Image):
|
|
|
|
| 11 |
from shapely.geometry import Polygon
|
| 12 |
|
| 13 |
from datasets.discrete_tokenizer import DiscreteTokenizer
|
|
|
|
|
|
|
| 14 |
from models import build_model
|
| 15 |
from util.plot_utils import plot_semantic_rich_floorplan_opencv
|
| 16 |
|
|
|
|
| 217 |
MODEL = load_model()
|
| 218 |
print("Model loaded.")
|
| 219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
def preprocess_image(pil_image: Image.Image) -> torch.Tensor:
|
| 221 |
+
"""Resize preserving aspect ratio + pad to (image_size, image_size)."""
|
| 222 |
+
target = MODEL_ARGS.image_size
|
| 223 |
image_np = np.array(pil_image.convert("RGB"))
|
| 224 |
+
h, w = image_np.shape[:2]
|
|
|
|
|
|
|
| 225 |
|
| 226 |
+
scale = min(target / h, target / w)
|
| 227 |
+
new_h, new_w = int(h * scale), int(w * scale)
|
| 228 |
+
resized = cv2.resize(image_np, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
|
| 229 |
+
|
| 230 |
+
padded = np.full((target, target, 3), 255, dtype=np.uint8)
|
| 231 |
+
top = (target - new_h) // 2
|
| 232 |
+
left = (target - new_w) // 2
|
| 233 |
+
padded[top:top + new_h, left:left + new_w] = resized
|
| 234 |
|
| 235 |
+
tensor = padded.transpose((2, 0, 1)).astype(np.float32) / 255.0
|
| 236 |
+
return torch.as_tensor(tensor)
|
| 237 |
|
| 238 |
|
| 239 |
def predict_floorplan(image: Image.Image):
|
datasets/__init__.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
-
from .poly_data import build as build_poly
|
| 2 |
-
|
| 3 |
-
|
| 4 |
def build_dataset(image_set, args):
|
|
|
|
|
|
|
| 5 |
if args.dataset_name in ["stru3d", "cubicasa", "waffle", "r2g"]:
|
| 6 |
print(f"Build {args.dataset_name} {image_set} dataset")
|
| 7 |
return build_poly(image_set, args)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
def build_dataset(image_set, args):
|
| 2 |
+
from .poly_data import build as build_poly
|
| 3 |
+
|
| 4 |
if args.dataset_name in ["stru3d", "cubicasa", "waffle", "r2g"]:
|
| 5 |
print(f"Build {args.dataset_name} {image_set} dataset")
|
| 6 |
return build_poly(image_set, args)
|
util/bf_utils.py
CHANGED
|
@@ -6,13 +6,26 @@ import numpy as np
|
|
| 6 |
import torch
|
| 7 |
from torch import nn
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
def box_cxcywh_to_xyxy(x):
|
|
|
|
| 6 |
import torch
|
| 7 |
from torch import nn
|
| 8 |
|
| 9 |
+
class _Registry:
|
| 10 |
+
"""Minimal replacement for detectron2.utils.registry.Registry."""
|
| 11 |
+
def __init__(self, name):
|
| 12 |
+
self._name = name
|
| 13 |
+
self._obj_map = {}
|
| 14 |
+
|
| 15 |
+
def register(self, obj=None):
|
| 16 |
+
if obj is None:
|
| 17 |
+
def decorator(func_or_class):
|
| 18 |
+
self._obj_map[func_or_class.__name__] = func_or_class
|
| 19 |
+
return func_or_class
|
| 20 |
+
return decorator
|
| 21 |
+
self._obj_map[obj.__name__] = obj
|
| 22 |
+
return obj
|
| 23 |
+
|
| 24 |
+
def get(self, name):
|
| 25 |
+
return self._obj_map[name]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
POLY_LOSS_REGISTRY = _Registry("POLY_LOSS")
|
| 29 |
|
| 30 |
|
| 31 |
def box_cxcywh_to_xyxy(x):
|