Upload app.py
#3
by KinetoLabs - opened
app.py
CHANGED
|
@@ -8,77 +8,365 @@ import gradio as gr
|
|
| 8 |
import numpy as np
|
| 9 |
import os
|
| 10 |
import torch
|
| 11 |
-
os.system('pip install --upgrade --no-cache-dir gdown')
|
| 12 |
-
|
| 13 |
-
|
| 14 |
from PIL import Image
|
|
|
|
| 15 |
|
| 16 |
from utils.logger import get_logger
|
| 17 |
from config.defaults import get_config
|
| 18 |
-
from inference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
from models.build import build_model
|
| 20 |
from argparse import Namespace
|
| 21 |
import gdown
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
-
def
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
['src/config/ablation_study/full.yaml', '1U16TxUkvZlRwJNaJnq9nAUap-BhCVIha']
|
| 31 |
-
]
|
| 32 |
-
|
| 33 |
-
for model_id in model_ids:
|
| 34 |
-
if model_id[0] != model_cfg:
|
| 35 |
-
continue
|
| 36 |
-
path = os.path.join(ckpt_dir, 'best.pkl')
|
| 37 |
-
if not os.path.exists(path):
|
| 38 |
-
logger.info(f"Downloading {model_id}")
|
| 39 |
-
os.makedirs(ckpt_dir, exist_ok=True)
|
| 40 |
-
gdown.download(f"https://drive.google.com/uc?id={model_id[1]}", path, False)
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
def greet(img_path, pre_processing, weight_name, post_processing, visualization, mesh_format, mesh_resolution):
|
| 44 |
-
args.pre_processing = pre_processing
|
| 45 |
-
args.post_processing = post_processing
|
| 46 |
-
if weight_name == 'mp3d':
|
| 47 |
-
model = mp3d_model
|
| 48 |
-
elif weight_name == 'zind':
|
| 49 |
-
model = zind_model
|
| 50 |
else:
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
vp_cache_path = 'src/demo/default_vp.txt'
|
| 58 |
-
if args.pre_processing:
|
| 59 |
-
vp_cache_path = os.path.join('src/output', f'{img_name}_vp.txt')
|
| 60 |
-
logger.info("pre-processing ...")
|
| 61 |
-
img, vp = preprocess(img, vp_cache_path=vp_cache_path)
|
| 62 |
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
run_one_inference(img, model, args, img_name,
|
| 65 |
logger=logger, show=False,
|
| 66 |
-
show_depth=
|
| 67 |
-
show_floorplan=
|
| 68 |
-
mesh_format=
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
vp_cache_path,
|
| 74 |
-
os.path.join(args.output_dir, f"{img_name}_pred.json")]
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
config = get_config(args)
|
| 79 |
-
down_ckpt(args.cfg, config.CKPT.DIR)
|
| 80 |
if ('cuda' in args.device or 'cuda' in config.TRAIN.DEVICE) and not torch.cuda.is_available():
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
| 82 |
config.defrost()
|
| 83 |
args.device = "cpu"
|
| 84 |
config.TRAIN.DEVICE = "cpu"
|
|
@@ -88,54 +376,58 @@ def get_model(args):
|
|
| 88 |
|
| 89 |
|
| 90 |
if __name__ == '__main__':
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
import os
|
| 10 |
import torch
|
|
|
|
|
|
|
|
|
|
| 11 |
from PIL import Image
|
| 12 |
+
import spaces
|
| 13 |
|
| 14 |
from utils.logger import get_logger
|
| 15 |
from config.defaults import get_config
|
| 16 |
+
# Moved from inference.py - preprocessing and inference functions
|
| 17 |
+
import cv2
|
| 18 |
+
import matplotlib.pyplot as plt
|
| 19 |
+
from preprocessing.pano_lsd_align import panoEdgeDetection, rotatePanorama
|
| 20 |
+
from utils.boundary import corners2boundaries, layout2depth
|
| 21 |
+
from utils.conversion import depth2xyz
|
| 22 |
+
from utils.misc import tensor2np_d
|
| 23 |
+
from utils.writer import xyz2json
|
| 24 |
+
from visualization.boundary import draw_boundaries
|
| 25 |
+
from visualization.floorplan import draw_floorplan, draw_iou_floorplan
|
| 26 |
+
from visualization.obj3d import create_3d_obj
|
| 27 |
from models.build import build_model
|
| 28 |
from argparse import Namespace
|
| 29 |
import gdown
|
| 30 |
+
from utils.misc import tensor2np
|
| 31 |
+
from postprocessing.post_process import post_process
|
| 32 |
|
| 33 |
|
| 34 |
+
def preprocess(img_ori, q_error=0.7, refine_iter=3, vp_cache_path=None):
    """Align a panorama with its vanishing points (moved from inference.py).

    Args:
        img_ori: panorama image as an HxWx3 array.
        q_error: quantization error passed to panoEdgeDetection.
        refine_iter: number of VP refinement iterations.
        vp_cache_path: optional text file caching the detected VPs, so
            repeated runs on the same image skip the expensive detection.

    Returns:
        (aligned_image, vp) where vp is the (3, 3) vanishing-point array.
    """
    if vp_cache_path and os.path.exists(vp_cache_path):
        # Reuse cached vanishing points.
        with open(vp_cache_path) as f:
            vp = [[float(v) for v in line.rstrip().split(' ')] for line in f.readlines()]
            vp = np.array(vp)
    else:
        # VP detection and line segment extraction.
        _, vp, _, _, _, _, _ = panoEdgeDetection(img_ori,
                                                 qError=q_error,
                                                 refineIter=refine_iter)
        # Persist the freshly-detected VPs for the next run.
        if vp_cache_path is not None:
            with open(vp_cache_path, 'w') as f:
                for i in range(3):
                    f.write('%.6f %.6f %.6f\n' % (vp[i, 0], vp[i, 1], vp[i, 2]))

    # BUGFIX: rotate in BOTH branches. Previously the rotation only ran in
    # the detection branch, so a cache hit returned an undefined `i_img`
    # (NameError at the return below).
    i_img = rotatePanorama(img_ori, vp[2::-1])

    return i_img, vp
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def show_depth_normal_grad(dt):
    """Render the horizontal depth gradient as a 60x1024 color strip.

    Args:
        dt: model output dict; dt['depth'][0] is the predicted depth tensor.

    Returns:
        A 60x1024 uint8 BGR image (JET colormap of |d depth / d column|).
    """
    depth = tensor2np(dt['depth'][0])
    grad_img = np.abs(np.gradient(depth, axis=1))
    # BUGFIX: a constant depth row makes max() == 0, and the original
    # unconditional division produced NaNs that corrupted the cast below.
    peak = grad_img.max()
    if peak > 0:
        grad_img = grad_img / peak
    grad_img = (grad_img * 255).astype(np.uint8)
    grad_img = cv2.applyColorMap(grad_img, cv2.COLORMAP_JET)
    grad_img = cv2.resize(grad_img, (1024, 60), interpolation=cv2.INTER_NEAREST)
    return grad_img
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def show_alpha_floorplan(dt_xyz, side_l=512, border_color=None):
    """Render a semi-transparent floorplan composited onto a grey backdrop."""
    border_color = [1, 0, 0, 1] if border_color is None else border_color
    fill_color = [0.2, 0.2, 0.2, 0.2]

    # Translucent plan rendered from the (x, z) floor-plane coordinates.
    plan = draw_floorplan(xz=dt_xyz[..., ::2], fill_color=fill_color,
                          border_color=border_color, side_l=side_l, show=False,
                          center_color=[1, 0, 0, 1])
    plan_rgba = Image.fromarray((plan * 255).astype(np.uint8), mode='RGBA')

    # Opaque light-grey background the plan is alpha-composited over.
    backdrop = np.zeros([side_l, side_l, len(fill_color)], dtype=np.float32)
    backdrop[..., :] = [0.8, 0.8, 0.8, 1]
    backdrop_rgba = Image.fromarray((backdrop * 255).astype(np.uint8), mode='RGBA')

    composite = Image.alpha_composite(backdrop_rgba, plan_rgba).convert("RGB")
    return np.array(composite) / 255.0
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def visualize_2d(img, dt, show_depth=True, show_floorplan=True, show=False, save_path=None):
    """2D visualization - moved from inference.py.

    Draws predicted boundaries on the panorama, optionally appending a
    depth-gradient strip below it and a floorplan to its right.

    Args:
        img: panorama image, HxWx3 (float, 0-1 range assumed — TODO confirm).
        dt: model output dict with 'depth', 'ratio' and optionally
            'processed_xyz' (post-processed corners).
        show_depth: append the depth-gradient strip at the bottom.
        show_floorplan: append the floorplan on the right.
        show: display the result with matplotlib.
        save_path: optional PNG destination.

    Returns:
        The composed visualization array.
    """
    dt_np = tensor2np_d(dt)
    dt_depth = dt_np['depth'][0]
    dt_xyz = depth2xyz(np.abs(dt_depth))
    dt_ratio = dt_np['ratio'][0][0]
    # Raw-prediction boundaries drawn in green.
    dt_boundaries = corners2boundaries(dt_ratio, corners_xyz=dt_xyz, step=None, visible=False, length=img.shape[1])
    vis_img = draw_boundaries(img, boundary_list=dt_boundaries, boundary_color=[0, 1, 0])

    if 'processed_xyz' in dt:
        # Post-processed boundaries overlaid in red for comparison.
        dt_boundaries = corners2boundaries(dt_ratio, corners_xyz=dt['processed_xyz'][0], step=None, visible=False,
                                           length=img.shape[1])
        vis_img = draw_boundaries(vis_img, boundary_list=dt_boundaries, boundary_color=[1, 0, 0])

    if show_depth:
        dt_grad_img = show_depth_normal_grad(dt)
        grad_h = dt_grad_img.shape[0]
        # Replace the bottom grad_h rows with the gradient strip so the
        # overall image height is unchanged.
        # NOTE(review): dt_grad_img is uint8 (0-255) while vis_img appears to
        # be float (0-1) — verify the scales are intentionally mixed here.
        vis_merge = [
            vis_img[0:-grad_h, :, :],
            dt_grad_img,
        ]
        vis_img = np.concatenate(vis_merge, axis=0)

    if show_floorplan:
        if 'processed_xyz' in dt:
            # Red = post-processed layout, green = raw prediction.
            floorplan = draw_iou_floorplan(dt['processed_xyz'][0][..., ::2], dt_xyz[..., ::2],
                                           dt_board_color=[1, 0, 0, 1], gt_board_color=[0, 1, 0, 1])
        else:
            floorplan = show_alpha_floorplan(dt_xyz, border_color=[0, 1, 0, 1])

        # Trim 60px side margins so the floorplan height matches the panorama.
        vis_img = np.concatenate([vis_img, floorplan[:, 60:-60, :]], axis=1)
    if show:
        plt.imshow(vis_img)
        plt.show()
    if save_path:
        result = Image.fromarray((vis_img * 255).astype(np.uint8))
        result.save(save_path)
    return vis_img
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def save_pred_json(xyz, ration, save_path):
    """Save the predicted layout as a JSON file (moved from inference.py).

    Args:
        xyz: corner coordinates of the predicted layout.
        ration: ceiling/floor ratio of the prediction (parameter name kept
            as-is for backward compatibility; presumably "ratio").
        save_path: destination path for the JSON file.

    Returns:
        The JSON-serializable dict that was written.
    """
    # Hoisted out of the `with` block where the original imported it; the
    # import does not belong inside the file-handle scope.
    import json

    json_data = xyz2json(xyz, ration)
    with open(save_path, 'w') as f:
        f.write(json.dumps(json_data, indent=4) + '\n')
    return json_data
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@torch.no_grad()
def run_one_inference(img, model, args, name, logger=None, show=True, show_depth=True,
                      show_floorplan=True, mesh_format='.gltf', mesh_resolution=512):
    """Main inference function - moved from inference.py.

    Runs the model on one aligned panorama, optionally post-processes the
    layout, saves the 2D visualization and layout JSON, and (when
    args.visualize_3d / args.output_3d is set) builds a 3D mesh.

    Args:
        img: panorama as an HxWx3 array (float, 0-1 assumed — TODO confirm).
        model: layout model already placed on args.device.
        args: namespace with device, output_dir, post_processing,
            visualize_3d and output_3d attributes.
        name: basename used for every output file.
        logger: optional logger for progress messages.
        show, show_depth, show_floorplan: 2D visualization toggles.
        mesh_format: file extension of the exported mesh.
        mesh_resolution: boundary sampling length used for the mesh.
    """
    model.eval()
    if logger:
        logger.info("model inference...")
    # HWC -> NCHW, add batch dim, move to the inference device.
    dt = model(torch.from_numpy(img.transpose(2, 0, 1)[None]).to(args.device))
    if args.post_processing != 'original':
        if logger:
            logger.info(f"post-processing, type:{args.post_processing}...")
        dt['processed_xyz'] = post_process(tensor2np(dt['depth']), type_name=args.post_processing)

    visualize_2d(img, dt,
                 show_depth=show_depth,
                 show_floorplan=show_floorplan,
                 show=show,
                 save_path=os.path.join(args.output_dir, f"{name}_pred.png"))
    # Prefer post-processed corners; otherwise derive xyz from raw depth.
    output_xyz = dt['processed_xyz'][0] if 'processed_xyz' in dt else depth2xyz(tensor2np(dt['depth'][0]))

    if logger:
        logger.info(f"saving predicted layout json...")
    # NOTE(review): json_data is unused after this call.
    json_data = save_pred_json(output_xyz, tensor2np(dt['ratio'][0])[0],
                               save_path=os.path.join(args.output_dir, f"{name}_pred.json"))

    if args.visualize_3d or args.output_3d:
        # Denser, visible boundaries only when post-processing supplied corners.
        dt_boundaries = corners2boundaries(tensor2np(dt['ratio'][0])[0], corners_xyz=output_xyz, step=None,
                                           length=mesh_resolution if 'processed_xyz' in dt else None,
                                           visible=True if 'processed_xyz' in dt else False)
        dt_layout_depth = layout2depth(dt_boundaries, show=False)

        if logger:
            logger.info(f"creating 3d mesh ...")
        # cv2.resize expects (width, height); shape[::-1] flips (h, w).
        create_3d_obj(cv2.resize(img, dt_layout_depth.shape[::-1]), dt_layout_depth,
                      save_path=os.path.join(args.output_dir, f"{name}_3d{mesh_format}") if args.output_3d else None,
                      mesh=True, show=args.visualize_3d)
|
| 167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
+
def down_ckpt(model_cfg, ckpt_dir, logger=None):
    """Fetch the MP3D checkpoint from Google Drive unless it already exists."""
    # Only MP3D model needed
    drive_id = '1o97oAmd-yEP5bQrM0eAWFPLq27FjUDbh'
    ckpt_path = os.path.join(ckpt_dir, 'best.pkl')

    # Guard clause: nothing to do when the checkpoint is already on disk.
    if os.path.exists(ckpt_path):
        return

    if logger:
        logger.info(f"Downloading MP3D model")
    else:
        print(f"Downloading MP3D model")
    os.makedirs(ckpt_dir, exist_ok=True)
    gdown.download(f"https://drive.google.com/uc?id={drive_id}", ckpt_path, False)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
@torch.no_grad()
def create_high_res_floorplan(img, model, args, img_name, resolution):
    """Create a high-resolution floorplan that matches the mesh resolution.

    NOTE(review): this re-runs the full model forward pass (and
    post-processing) that run_one_inference already performed on the same
    image; caching that prediction would avoid duplicate GPU work.

    Args:
        img: panorama as an HxWx3 array (float, 0-1 assumed — TODO confirm).
        model: layout model on args.device.
        args: namespace with device, output_dir and post_processing.
        img_name: basename for the saved file.
        resolution: side length in pixels of the rendered floorplan.

    Returns:
        Path of the saved high-resolution floorplan PNG.
    """
    model.eval()

    # Run inference to get layout data
    dt = model(torch.from_numpy(img.transpose(2, 0, 1)[None]).to(args.device))
    if args.post_processing != 'original':
        dt['processed_xyz'] = post_process(tensor2np(dt['depth']), type_name=args.post_processing)

    # Get the processed layout coordinates
    output_xyz = dt['processed_xyz'][0] if 'processed_xyz' in dt else depth2xyz(tensor2np(dt['depth'][0]))

    # Create high-resolution floorplan
    fill_color = [0.2, 0.2, 0.2, 0.2]
    border_color = [1, 0, 0, 1]

    # Use the same resolution as the mesh for consistency
    floorplan = draw_floorplan(xz=output_xyz[..., ::2], fill_color=fill_color,
                               border_color=border_color, side_l=resolution, show=False,
                               center_color=[1, 0, 0, 1])

    # Save high-res floorplan
    floorplan_path = os.path.join(args.output_dir, f"{img_name}_floorplan_highres.png")
    floorplan_img = Image.fromarray((floorplan * 255).astype(np.uint8), mode='RGBA')

    # Create background and composite (same grey as show_alpha_floorplan).
    back = np.zeros([resolution, resolution, 4], dtype=np.float32)
    back[..., :] = [0.8, 0.8, 0.8, 1]
    back_img = Image.fromarray((back * 255).astype(np.uint8), mode='RGBA')
    final_img = Image.alpha_composite(back_img, floorplan_img).convert("RGB")
    final_img.save(floorplan_path)

    return floorplan_path
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def calculate_measurements(layout_json, camera_height):
    """Calculate comprehensive room measurements from layout data.

    Args:
        layout_json: path to a layout JSON file, or an already-parsed dict.
        camera_height: camera height above the floor in meters; the layout's
            scale anchor.

    Returns:
        (measurements, quality_report) display strings; on any failure both
        elements carry the same error message (this feeds UI textboxes, so it
        must never raise).
    """
    # Hoisted: the original re-ran `import json` inside the try on every call.
    import json

    try:
        if isinstance(layout_json, str):
            with open(layout_json, 'r') as f:
                data = json.load(f)
        else:
            data = layout_json

        # Extract wall lengths (the layout writer stores length as 'width').
        walls = data.get('layoutWalls', {}).get('walls', [])
        wall_lengths = [wall.get('width', 0) for wall in walls if 'width' in wall]

        perimeter = sum(wall_lengths) if wall_lengths else 0

        # Floor area via the shoelace formula over the (x, z) floor plane.
        points = data.get('layoutPoints', {}).get('points', [])
        if len(points) >= 3:
            area = 0
            n = len(points)
            for i in range(n):
                j = (i + 1) % n
                if 'xyz' in points[i] and 'xyz' in points[j]:
                    x1, _, z1 = points[i]['xyz']
                    x2, _, z2 = points[j]['xyz']
                    area += x1 * z2 - x2 * z1
            area = abs(area) / 2
        else:
            area = 0

        # Ceiling height = total layout height minus the camera's height.
        layout_height = data.get('layoutHeight', camera_height + 1.0)
        ceiling_height = layout_height - camera_height

        measurements = f"""📐 ROOM MEASUREMENTS (Camera Height: {camera_height:.2f}m)

📏 Floor Area: {area:.1f} m² ({area * 10.764:.1f} ft²)
📏 Room Perimeter: {perimeter:.1f} m ({perimeter * 3.281:.1f} ft)
📏 Ceiling Height: {ceiling_height:.1f} m ({ceiling_height * 3.281:.1f} ft)
📦 Room Volume: {area * ceiling_height:.1f} m³

🧱 Wall Lengths: {', '.join([f'{w:.1f}m' for w in wall_lengths])}

💡 ACCURACY NOTES:
• All measurements scaled from camera height
• ±5cm height error = ±3-8% measurement error
• Best accuracy in center-captured, well-lit rooms"""

        # Heuristic sanity checks that flag implausible reconstructions.
        quality_notes = []
        if area < 5:
            quality_notes.append("⚠️ Very small room - verify scale")
        elif area > 200:
            quality_notes.append("⚠️ Very large room - verify scale")

        if ceiling_height < 2.0:
            quality_notes.append("⚠️ Low ceiling - check camera height")
        elif ceiling_height > 4.0:
            quality_notes.append("⚠️ High ceiling - verify measurements")

        if len(wall_lengths) < 4:
            quality_notes.append("⚠️ Simplified room shape detected")

        quality_report = "✅ Processing completed successfully\n\n"
        if quality_notes:
            quality_report += "📋 QUALITY NOTES:\n" + "\n".join(quality_notes)
        else:
            quality_report += "👍 Room measurements appear reasonable"

        return measurements, quality_report

    except Exception as e:
        # Broad catch is deliberate: errors surface as text, never crash the UI.
        error_msg = f"❌ Error calculating measurements: {str(e)}"
        return error_msg, error_msg
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@spaces.GPU
def gpu_inference(img, model, args, img_name, mesh_resolution, logger):
    """GPU-intensive work bundled in one @spaces.GPU-decorated entry point.

    Runs the main layout inference (writes visualization/JSON/mesh), then
    renders a floorplan at the mesh resolution and returns its path.
    """
    run_one_inference(
        img, model, args, img_name,
        logger=logger, show=False,
        show_depth=True, show_floorplan=True,
        mesh_format='.obj', mesh_resolution=mesh_resolution,
    )
    # Second pass: floorplan rendered at the same resolution as the mesh.
    return create_high_res_floorplan(img, model, args, img_name, mesh_resolution)
|
|
|
|
|
|
|
| 313 |
|
| 314 |
+
def greet(img_path, camera_height, units):
    """Gradio handler: run room-layout estimation on one uploaded panorama.

    Args:
        img_path: filesystem path of the uploaded panorama.
        camera_height: camera height in meters used to scale measurements.
        units: unit system chosen in the UI. NOTE(review): currently unused;
            kept so the signature matches the Gradio inputs list.

    Returns:
        An 8-element list matching the Gradio outputs: [pred_png,
        floorplan_png, mesh_for_viewer, mesh_file, vp_txt, pred_json,
        measurements, quality_report]; on failure, Nones with the error
        message in the two text slots.
    """
    try:
        # Hardcoded settings for optimal UX
        args.pre_processing = True
        args.post_processing = 'manhattan'

        # Ensure output directory exists
        os.makedirs(args.output_dir, exist_ok=True)

        # Use the global model loaded at startup.
        model = mp3d_model

        # BUGFIX: splitext keeps dotted basenames intact ('room.v2.jpg' ->
        # 'room.v2'); the previous split('.')[0] truncated them to 'room',
        # so outputs of different files could collide.
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        img = np.array(Image.open(img_path).resize((1024, 512), Image.Resampling.BICUBIC))[..., :3]

        vp_cache_path = os.path.join(args.output_dir, f'{img_name}_vp.txt')
        logger.info("pre-processing ...")
        img, vp = preprocess(img, vp_cache_path=vp_cache_path)

        # Model expects float input in [0, 1].
        img = (img / 255.0).astype(np.float32)

        # High resolution mesh generation
        mesh_resolution = 2048

        # Run GPU inference in single decorated function
        floorplan_path = gpu_inference(img, model, args, img_name, mesh_resolution, logger)

        # Calculate measurements (CPU operation)
        json_path = os.path.join(args.output_dir, f"{img_name}_pred.json")
        measurements, quality_report = calculate_measurements(json_path, camera_height)

        mesh_path = os.path.join(args.output_dir, f"{img_name}_3d.obj")
        return [os.path.join(args.output_dir, f"{img_name}_pred.png"),
                floorplan_path,
                mesh_path,   # shown in the Model3D viewer
                mesh_path,   # same file offered for download
                vp_cache_path,
                json_path,
                measurements,
                quality_report]

    except Exception as e:
        error_msg = f"❌ Error processing image: {str(e)}"
        logger.error(error_msg)

        # Return error placeholders matching the 8 Gradio outputs.
        return [None, None, None, None, None, None, error_msg, error_msg]
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def get_model(args, logger=None):
|
| 363 |
config = get_config(args)
|
| 364 |
+
down_ckpt(args.cfg, config.CKPT.DIR, logger)
|
| 365 |
if ('cuda' in args.device or 'cuda' in config.TRAIN.DEVICE) and not torch.cuda.is_available():
|
| 366 |
+
if logger:
|
| 367 |
+
logger.info(f'The {args.device} is not available, will use cpu...')
|
| 368 |
+
else:
|
| 369 |
+
print(f'The {args.device} is not available, will use cpu...')
|
| 370 |
config.defrost()
|
| 371 |
args.device = "cpu"
|
| 372 |
config.TRAIN.DEVICE = "cpu"
|
|
|
|
| 376 |
|
| 377 |
|
| 378 |
if __name__ == '__main__':
    try:
        logger = get_logger()
        logger.info("Starting 3D Room Layout Estimation App...")

        # Use GPU if available (A10G on HF Spaces)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Using device: {device}")

        # Globals read by the Gradio handler (greet) and gpu_inference.
        args = Namespace(device=device, output_dir='output', visualize_3d=False, output_3d=True)
        os.makedirs(args.output_dir, exist_ok=True)

        args.cfg = 'config/mp3d.yaml'
        logger.info("Loading model...")
        mp3d_model = get_model(args, logger)
        logger.info("Model loaded successfully!")

    except Exception as e:
        # Fail fast: the app is useless without the model, so re-raise.
        print(f"Error during initialization: {e}")
        raise

    description = "Upload a panoramic image to generate a 3D room layout using " \
                  "<a href='https://github.com/zhigangjiang/LGT-Net' target='_blank'>LGT-Net</a>. " \
                  "The model automatically processes your image and outputs visualization, 3D mesh, and layout data."

    try:
        # inputs must stay aligned with greet's parameters, outputs with its
        # 8-element return list.
        demo = gr.Interface(
            fn=greet,
            inputs=[
                gr.Image(type='filepath', label='Upload Panoramic Image'),
                gr.Slider(minimum=1.0, maximum=3.0, value=1.6, label='Camera Height (meters)'),
                gr.Radio(choices=['Metric', 'Imperial'], value='Metric', label='Units')
            ],
            outputs=[
                gr.Image(label='2D Layout Visualization', type='filepath'),
                gr.Image(label='High-Res Floorplan', type='filepath'),
                gr.Model3D(label='3D Room Layout', clear_color=[1.0, 1.0, 1.0, 1.0]),
                gr.File(label='3D Mesh (.obj)'),
                gr.File(label='Vanishing Point Data'),
                gr.File(label='Layout JSON'),
                gr.Textbox(label='Room Measurements'),
                gr.Textbox(label='Quality Report')
            ],
            title='3D Room Layout Estimation',
            description=description,
            allow_flagging="never",
            cache_examples=False
        )

        logger.info("Gradio interface created successfully")
        demo.launch(debug=True)

    except Exception as e:
        logger.error(f"Failed to create or launch Gradio interface: {e}")
        print(f"Error: {e}")
        raise
|