| ''' |
| @author: Zhigang Jiang |
| @time: 2022/05/23 |
| @description: |
| ''' |
|
|
| import gradio as gr |
| import numpy as np |
| import os |
| import torch |
| from PIL import Image |
| import spaces |
|
|
| from utils.logger import get_logger |
| from config.defaults import get_config |
| |
| import cv2 |
| import matplotlib.pyplot as plt |
| from preprocessing.pano_lsd_align import panoEdgeDetection, rotatePanorama |
| from utils.boundary import corners2boundaries, layout2depth |
| from utils.conversion import depth2xyz |
| from utils.misc import tensor2np_d |
| from utils.writer import xyz2json |
| from visualization.boundary import draw_boundaries |
| from visualization.floorplan import draw_floorplan, draw_iou_floorplan |
| from visualization.obj3d import create_3d_obj |
| from models.build import build_model |
| from argparse import Namespace |
| import gdown |
| from utils.misc import tensor2np |
| from postprocessing.post_process import post_process |
|
|
|
|
def preprocess(img_ori, q_error=0.7, refine_iter=3, vp_cache_path=None):
    """Align a panorama with its vanishing points (moved from inference.py).

    Args:
        img_ori: panorama image as a numpy array — presumably
            equirectangular HxWxC; confirm against caller.
        q_error: quantization error passed to panoEdgeDetection.
        refine_iter: refinement iterations for the VP estimation.
        vp_cache_path: optional text file caching the 3x3 VP matrix;
            read when it exists, written after a fresh estimation.

    Returns:
        (aligned_image, vp): the rotated panorama and the 3x3 vanishing
        point array.
    """
    cached = vp_cache_path is not None and os.path.exists(vp_cache_path)
    if cached:
        # One vanishing point per line, three floats each.
        with open(vp_cache_path) as f:
            vp = np.array([[float(v) for v in line.split()] for line in f])
    else:
        # Estimate vanishing points from scratch (the slow step).
        _, vp, _, _, _, _, _ = panoEdgeDetection(img_ori,
                                                 qError=q_error,
                                                 refineIter=refine_iter)

    # Rotate the panorama so walls align with the image axes; vp[2::-1]
    # reverses the row order expected by rotatePanorama.
    i_img = rotatePanorama(img_ori, vp[2::-1])

    # Only write the cache when the VPs were freshly estimated — the
    # original rewrote the file even on a cache hit.
    if vp_cache_path is not None and not cached:
        with open(vp_cache_path, 'w') as f:
            for i in range(3):
                f.write('%.6f %.6f %.6f\n' % (vp[i, 0], vp[i, 1], vp[i, 2]))

    return i_img, vp
|
|
|
|
def show_depth_normal_grad(dt):
    """Render the horizontal depth gradient as a JET-colored strip.

    Simplified gradient visualization (moved from inference.py).

    Args:
        dt: model output dict; dt['depth'][0] is the predicted depth for
            the first batch item — assumed at least 2D so axis=1 exists;
            TODO confirm against the model's output shape.

    Returns:
        1024x60x3 float array in [0, 1] encoding |d(depth)/dx| with a JET
        colormap, matching the float scale visualize_2d concatenates with.
    """
    depth = tensor2np(dt['depth'][0])
    grad_img = np.abs(np.gradient(depth, axis=1))
    # Guard against a perfectly flat depth map: max()==0 previously caused
    # a division by zero and NaNs in the uint8 cast.
    peak = grad_img.max()
    if peak > 0:
        grad_img = grad_img / peak
    grad_img = (grad_img * 255).astype(np.uint8)
    grad_img = cv2.applyColorMap(grad_img, cv2.COLORMAP_JET)
    # Fixed-size strip so it can be stacked under the 1024-wide panorama.
    grad_img = cv2.resize(grad_img, (1024, 60), interpolation=cv2.INTER_NEAREST)
    # Return in [0, 1]: visualize_2d concatenates this with a float [0, 1]
    # panorama and multiplies by 255 on save — a raw uint8 strip would
    # overflow there and render as noise.
    return grad_img.astype(np.float32) / 255.0
|
|
|
|
def show_alpha_floorplan(dt_xyz, side_l=512, border_color=None):
    """Draw a floorplan and composite it over a light-grey background,
    returning a float RGB image in [0, 1] (moved from inference.py)."""
    if border_color is None:
        border_color = [1, 0, 0, 1]
    fill_color = [0.2, 0.2, 0.2, 0.2]

    # Project corners onto the x/z plane and rasterize the plan.
    plan = draw_floorplan(xz=dt_xyz[..., ::2], fill_color=fill_color,
                          border_color=border_color, side_l=side_l, show=False,
                          center_color=[1, 0, 0, 1])
    plan_img = Image.fromarray((plan * 255).astype(np.uint8), mode='RGBA')

    # Solid grey backdrop with the same size and channel count.
    backdrop = np.zeros([side_l, side_l, len(fill_color)], dtype=np.float32)
    backdrop[..., :] = [0.8, 0.8, 0.8, 1]
    backdrop_img = Image.fromarray((backdrop * 255).astype(np.uint8), mode='RGBA')

    # Alpha-composite the semi-transparent plan onto the backdrop.
    composited = Image.alpha_composite(backdrop_img, plan_img).convert("RGB")
    return np.array(composited) / 255.0
|
|
|
|
def visualize_2d(img, dt, show_depth=True, show_floorplan=True, show=False, save_path=None):
    """Compose the 2D result panel: layout boundaries on the panorama, an
    optional depth-gradient strip underneath, and an optional floorplan on
    the right (moved from inference.py).

    Args:
        img: panorama — presumably float array in [0, 1], 512x1024x3;
            confirm against caller (greet normalizes before inference).
        dt: model output dict with 'depth' and 'ratio', plus optional
            'processed_xyz' added by post-processing.
        show_depth: append the depth-gradient strip below the panorama.
        show_floorplan: append the floorplan image to the right.
        show: display the result with matplotlib.
        save_path: optional path; when set, the panel is saved as a PNG.

    Returns:
        The composed visualization array.
    """
    dt_np = tensor2np_d(dt)
    dt_depth = dt_np['depth'][0]
    # Convert predicted depth to 3D corner coordinates.
    dt_xyz = depth2xyz(np.abs(dt_depth))
    dt_ratio = dt_np['ratio'][0][0]
    # Raw (un-post-processed) prediction drawn in green.
    dt_boundaries = corners2boundaries(dt_ratio, corners_xyz=dt_xyz, step=None, visible=False, length=img.shape[1])
    vis_img = draw_boundaries(img, boundary_list=dt_boundaries, boundary_color=[0, 1, 0])

    if 'processed_xyz' in dt:
        # Post-processed (e.g. Manhattan-aligned) layout drawn in red on top.
        dt_boundaries = corners2boundaries(dt_ratio, corners_xyz=dt['processed_xyz'][0], step=None, visible=False,
                                           length=img.shape[1])
        vis_img = draw_boundaries(vis_img, boundary_list=dt_boundaries, boundary_color=[1, 0, 0])

    if show_depth:
        dt_grad_img = show_depth_normal_grad(dt)
        grad_h = dt_grad_img.shape[0]
        # Crop the bottom of the panorama by the strip height so the total
        # height stays unchanged after stacking.
        vis_merge = [
            vis_img[0:-grad_h, :, :],
            dt_grad_img,
        ]
        vis_img = np.concatenate(vis_merge, axis=0)

    if show_floorplan:
        if 'processed_xyz' in dt:
            # Overlay post-processed (red) vs raw (green) floorplans.
            floorplan = draw_iou_floorplan(dt['processed_xyz'][0][..., ::2], dt_xyz[..., ::2],
                                           dt_board_color=[1, 0, 0, 1], gt_board_color=[0, 1, 0, 1])
        else:
            floorplan = show_alpha_floorplan(dt_xyz, border_color=[0, 1, 0, 1])

        # Trim 60px of margin from each side before placing it to the right.
        vis_img = np.concatenate([vis_img, floorplan[:, 60:-60, :]], axis=1)
    if show:
        plt.imshow(vis_img)
        plt.show()
    if save_path:
        # Scale assumes vis_img is in [0, 1].
        result = Image.fromarray((vis_img * 255).astype(np.uint8))
        result.save(save_path)
    return vis_img
|
|
|
|
def save_pred_json(xyz, ration, save_path):
    """Serialize the predicted layout to a JSON file (moved from inference.py).

    Args:
        xyz: corner coordinates of the predicted layout.
        ration: ceiling/floor distance ratio. The misspelling of "ratio"
            is kept so keyword callers keep working.
        save_path: destination .json path.

    Returns:
        The dict that was written to disk.
    """
    import json  # kept function-local, but hoisted out of the file-write block

    json_data = xyz2json(xyz, ration)
    with open(save_path, 'w') as f:
        f.write(json.dumps(json_data, indent=4) + '\n')
    return json_data
|
|
|
|
@torch.no_grad()
def run_one_inference(img, model, args, name, logger=None, show=True, show_depth=True,
                      show_floorplan=True, mesh_format='.gltf', mesh_resolution=512):
    """Run the full single-image pipeline: forward pass, optional
    post-processing, 2D visualization, layout JSON and 3D mesh export
    (moved from inference.py).

    Args:
        img: aligned panorama — presumably float [0, 1] HxWx3; confirm
            against caller.
        model: layout network; called on an NCHW tensor.
        args: namespace providing device, post_processing, output_dir,
            visualize_3d and output_3d.
        name: basename used for all output files in args.output_dir.
        logger: optional progress logger.
        show / show_depth / show_floorplan: passed through to visualize_2d.
        mesh_format: file extension of the exported mesh.
        mesh_resolution: boundary sampling length used for the mesh.
    """
    model.eval()
    if logger:
        logger.info("model inference...")
    # HWC -> NCHW, add batch dim, move to the configured device.
    dt = model(torch.from_numpy(img.transpose(2, 0, 1)[None]).to(args.device))
    if args.post_processing != 'original':
        if logger:
            logger.info(f"post-processing, type:{args.post_processing}...")
        dt['processed_xyz'] = post_process(tensor2np(dt['depth']), type_name=args.post_processing)

    # Writes {name}_pred.png into the output dir.
    visualize_2d(img, dt,
                 show_depth=show_depth,
                 show_floorplan=show_floorplan,
                 show=show,
                 save_path=os.path.join(args.output_dir, f"{name}_pred.png"))
    # Prefer the post-processed corners when available.
    output_xyz = dt['processed_xyz'][0] if 'processed_xyz' in dt else depth2xyz(tensor2np(dt['depth'][0]))

    if logger:
        logger.info(f"saving predicted layout json...")
    json_data = save_pred_json(output_xyz, tensor2np(dt['ratio'][0])[0],
                               save_path=os.path.join(args.output_dir, f"{name}_pred.json"))

    if args.visualize_3d or args.output_3d:
        # Post-processed layouts get a fixed sampling length and
        # visible-only boundaries; raw layouts keep the defaults.
        dt_boundaries = corners2boundaries(tensor2np(dt['ratio'][0])[0], corners_xyz=output_xyz, step=None,
                                           length=mesh_resolution if 'processed_xyz' in dt else None,
                                           visible=True if 'processed_xyz' in dt else False)
        dt_layout_depth = layout2depth(dt_boundaries, show=False)

        if logger:
            logger.info(f"creating 3d mesh ...")
        # cv2.resize takes (width, height); shape[::-1] flips (h, w).
        create_3d_obj(cv2.resize(img, dt_layout_depth.shape[::-1]), dt_layout_depth,
                      save_path=os.path.join(args.output_dir, f"{name}_3d{mesh_format}") if args.output_3d else None,
                      mesh=True, show=args.visualize_3d)
|
|
|
|
def down_ckpt(model_cfg, ckpt_dir, logger=None):
    """Download the pretrained MP3D checkpoint if it is not already present.

    Args:
        model_cfg: unused; kept for backward compatibility with callers.
        ckpt_dir: directory that should contain 'best.pkl'.
        logger: optional logger; falls back to print().
    """
    # Google Drive file id of the MP3D best.pkl checkpoint.
    model_id = '1o97oAmd-yEP5bQrM0eAWFPLq27FjUDbh'
    path = os.path.join(ckpt_dir, 'best.pkl')
    if os.path.exists(path):
        return  # already downloaded

    msg = "Downloading MP3D model"
    if logger:
        logger.info(msg)
    else:
        print(msg)
    os.makedirs(ckpt_dir, exist_ok=True)
    # Third positional arg is quiet=False: show gdown's progress bar.
    gdown.download(f"https://drive.google.com/uc?id={model_id}", path, False)
|
|
|
|
@torch.no_grad()
def create_high_res_floorplan(img, model, args, img_name, resolution):
    """Create a high-resolution floorplan that matches the mesh resolution.

    Runs its own forward pass, draws the floorplan at resolution x
    resolution pixels, composites it over a grey backdrop and saves it as
    a PNG.

    Args:
        img: aligned panorama — presumably float [0, 1] HxWx3.
        model: layout network.
        args: namespace with device, post_processing and output_dir.
        img_name: basename for the output file.
        resolution: side length in pixels of the rendered floorplan.

    Returns:
        Path of the saved {img_name}_floorplan_highres.png.
    """
    model.eval()

    # NOTE(review): this repeats the forward pass already performed by
    # run_one_inference for the same image — acceptable for a demo, but a
    # candidate for caching the prediction instead.
    dt = model(torch.from_numpy(img.transpose(2, 0, 1)[None]).to(args.device))
    if args.post_processing != 'original':
        dt['processed_xyz'] = post_process(tensor2np(dt['depth']), type_name=args.post_processing)

    # Prefer the post-processed corners when available.
    output_xyz = dt['processed_xyz'][0] if 'processed_xyz' in dt else depth2xyz(tensor2np(dt['depth'][0]))

    # Semi-transparent fill, red border/center — same palette as
    # show_alpha_floorplan.
    fill_color = [0.2, 0.2, 0.2, 0.2]
    border_color = [1, 0, 0, 1]

    floorplan = draw_floorplan(xz=output_xyz[..., ::2], fill_color=fill_color,
                               border_color=border_color, side_l=resolution, show=False,
                               center_color=[1, 0, 0, 1])

    floorplan_path = os.path.join(args.output_dir, f"{img_name}_floorplan_highres.png")
    floorplan_img = Image.fromarray((floorplan * 255).astype(np.uint8), mode='RGBA')

    # Composite onto a solid light-grey backdrop and flatten to RGB.
    back = np.zeros([resolution, resolution, 4], dtype=np.float32)
    back[..., :] = [0.8, 0.8, 0.8, 1]
    back_img = Image.fromarray((back * 255).astype(np.uint8), mode='RGBA')
    final_img = Image.alpha_composite(back_img, floorplan_img).convert("RGB")
    final_img.save(floorplan_path)

    return floorplan_path
|
|
|
|
|
|
|
|
def calculate_measurements(layout_json, camera_height):
    """Calculate comprehensive room measurements from layout data.

    Args:
        layout_json: path to a layout JSON file, or an already-parsed dict
            with the same schema (layoutWalls / layoutPoints / layoutHeight).
        camera_height: camera height above the floor in metres; layout
            distances are scaled relative to it.

    Returns:
        (measurements, quality_report): two human-readable strings. On any
        failure both slots carry the same error message.
    """
    import json  # kept function-local; hoisted out of the try block

    try:
        if isinstance(layout_json, str):
            with open(layout_json, 'r') as f:
                data = json.load(f)
        else:
            data = layout_json

        # Wall lengths come straight from the layout's 'width' fields.
        walls = data.get('layoutWalls', {}).get('walls', [])
        wall_lengths = [wall.get('width', 0) for wall in walls if 'width' in wall]
        perimeter = sum(wall_lengths)

        # Floor area via the shoelace formula over the (x, z) plan
        # coordinates of the corner points; degenerate layouts give 0.
        points = data.get('layoutPoints', {}).get('points', [])
        area = 0
        if len(points) >= 3:
            n = len(points)
            for i in range(n):
                j = (i + 1) % n
                if 'xyz' in points[i] and 'xyz' in points[j]:
                    x1, _, z1 = points[i]['xyz']
                    x2, _, z2 = points[j]['xyz']
                    area += x1 * z2 - x2 * z1
            area = abs(area) / 2

        # Ceiling height = total layout height minus camera height.
        layout_height = data.get('layoutHeight', camera_height + 1.0)
        ceiling_height = layout_height - camera_height

        measurements = f"""📏 ROOM MEASUREMENTS (Camera Height: {camera_height:.2f}m)

🏠 Floor Area: {area:.1f} m² ({area * 10.764:.1f} ft²)
📐 Room Perimeter: {perimeter:.1f} m ({perimeter * 3.281:.1f} ft)
📊 Ceiling Height: {ceiling_height:.1f} m ({ceiling_height * 3.281:.1f} ft)
📦 Room Volume: {area * ceiling_height:.1f} m³

🧱 Wall Lengths: {', '.join([f'{w:.1f}m' for w in wall_lengths])}

💡 ACCURACY NOTES:
• All measurements scaled from camera height
• ±5cm height error = ±3-8% measurement error
• Best accuracy in center-captured, well-lit rooms"""

        # Sanity heuristics: flag values that usually indicate a bad
        # camera-height setting or a failed layout estimation.
        quality_notes = []
        if area < 5:
            quality_notes.append("⚠️ Very small room - verify scale")
        elif area > 200:
            quality_notes.append("⚠️ Very large room - verify scale")

        if ceiling_height < 2.0:
            quality_notes.append("⚠️ Low ceiling - check camera height")
        elif ceiling_height > 4.0:
            quality_notes.append("⚠️ High ceiling - verify measurements")

        if len(wall_lengths) < 4:
            quality_notes.append("⚠️ Simplified room shape detected")

        quality_report = "✅ Processing completed successfully\n\n"
        if quality_notes:
            quality_report += "📊 QUALITY NOTES:\n" + "\n".join(quality_notes)
        else:
            quality_report += "📊 Room measurements appear reasonable"

        return measurements, quality_report

    except Exception as e:
        # Broad catch is deliberate: the result feeds a UI textbox and
        # must never raise back into the Gradio handler.
        error_msg = f"❌ Error calculating measurements: {str(e)}"
        return error_msg, error_msg
|
|
|
|
@spaces.GPU
def gpu_inference(img, model, args, img_name, mesh_resolution, logger):
    """Run the GPU-heavy steps: the full inference pipeline plus a
    high-resolution floorplan render; returns the floorplan path."""
    # Full pipeline: 2D visualization, layout JSON and the 3D .obj mesh.
    run_one_inference(
        img, model, args, img_name,
        logger=logger,
        show=False,
        show_depth=True,
        show_floorplan=True,
        mesh_format='.obj',
        mesh_resolution=mesh_resolution,
    )

    # Separate pass that renders the floorplan at the mesh resolution.
    return create_high_res_floorplan(img, model, args, img_name, mesh_resolution)
|
|
def greet(img_path, camera_height, units):
    """Gradio handler: run the full pipeline on one uploaded panorama.

    Relies on module globals set at startup: args, mp3d_model, logger.

    Args:
        img_path: path of the uploaded image (gr.Image type='filepath').
        camera_height: camera height in metres from the UI slider.
        units: 'Metric' or 'Imperial' from the UI radio — currently
            unused; the measurement text always shows both unit systems.

    Returns:
        List matching the Interface outputs: [2D viz path, floorplan path,
        mesh path (Model3D), mesh path (File), VP data path, JSON path,
        measurements text, quality text]; Nones plus an error message on
        failure.
    """
    try:
        # Fixed processing configuration for the demo.
        args.pre_processing = True
        args.post_processing = 'manhattan'

        os.makedirs(args.output_dir, exist_ok=True)

        # Model was loaded once at startup.
        model = mp3d_model

        img_name = os.path.basename(img_path).split('.')[0]
        # Resize to the 1024x512 panorama the network expects; drop alpha.
        img = np.array(Image.open(img_path).resize((1024, 512), Image.Resampling.BICUBIC))[..., :3]

        # Align with vanishing points (cached per image name).
        vp_cache_path = os.path.join(args.output_dir, f'{img_name}_vp.txt')
        logger.info("pre-processing ...")
        img, vp = preprocess(img, vp_cache_path=vp_cache_path)

        # Normalize to [0, 1] float for the network.
        img = (img / 255.0).astype(np.float32)

        mesh_resolution = 2048

        floorplan_path = gpu_inference(img, model, args, img_name, mesh_resolution, logger)

        # Measurements are derived from the JSON written by the pipeline.
        json_path = os.path.join(args.output_dir, f"{img_name}_pred.json")
        measurements, quality_report = calculate_measurements(json_path, camera_height)

        # The .obj path appears twice: once for the Model3D viewer, once
        # for the downloadable File output.
        return [os.path.join(args.output_dir, f"{img_name}_pred.png"),
                floorplan_path,
                os.path.join(args.output_dir, f"{img_name}_3d.obj"),
                os.path.join(args.output_dir, f"{img_name}_3d.obj"),
                vp_cache_path,
                os.path.join(args.output_dir, f"{img_name}_pred.json"),
                measurements,
                quality_report]

    except Exception as e:
        # Surface failures in the two textbox outputs instead of raising.
        error_msg = f"❌ Error processing image: {str(e)}"
        logger.error(error_msg)

        return [None, None, None, None, None, None, error_msg, error_msg]
|
|
|
|
def get_model(args, logger=None):
    """Build the layout model: resolve config, fetch the checkpoint, and
    fall back to CPU when CUDA is requested but unavailable."""
    config = get_config(args)
    down_ckpt(args.cfg, config.CKPT.DIR, logger)

    wants_cuda = 'cuda' in args.device or 'cuda' in config.TRAIN.DEVICE
    if wants_cuda and not torch.cuda.is_available():
        notify = logger.info if logger else print
        notify(f'The {args.device} is not available, will use cpu...')
        # The yacs config is frozen: unlock, patch the device, re-lock.
        config.defrost()
        args.device = "cpu"
        config.TRAIN.DEVICE = "cpu"
        config.freeze()

    model, _, _, _ = build_model(config, logger)
    return model
|
|
|
|
if __name__ == '__main__':
    # One-time app initialization: logger, device selection, model load.
    try:
        logger = get_logger()
        logger.info("Starting 3D Room Layout Estimation App...")

        # Prefer GPU when available; all inference runs on this device.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Using device: {device}")

        # Namespace mimics the CLI args the original inference.py expected;
        # greet() mutates it further per request.
        args = Namespace(device=device, output_dir='output', visualize_3d=False, output_3d=True)
        os.makedirs(args.output_dir, exist_ok=True)

        args.cfg = 'config/mp3d.yaml'
        logger.info("Loading model...")
        mp3d_model = get_model(args, logger)
        logger.info("Model loaded successfully!")

    except Exception as e:
        # Startup failure is fatal: report and re-raise.
        print(f"Error during initialization: {e}")
        raise
|
|
# UI description shown above the interface (HTML allowed by Gradio).
description = "Upload a panoramic image to generate a 3D room layout using " \
              "<a href='https://github.com/zhigangjiang/LGT-Net' target='_blank'>LGT-Net</a>. " \
              "The model automatically processes your image and outputs visualization, 3D mesh, and layout data."

# NOTE(review): this block sits OUTSIDE the __main__ guard above, so it
# also executes at import time; `logger` (and the model/args greet needs)
# only exist when the file runs as a script, so importing this module would
# raise NameError. Confirm how the deployment invokes the file before
# restructuring.
try:
    demo = gr.Interface(
        fn=greet,
        inputs=[
            gr.Image(type='filepath', label='Upload Panoramic Image'),
            gr.Slider(minimum=1.0, maximum=3.0, value=1.6, label='Camera Height (meters)'),
            gr.Radio(choices=['Metric', 'Imperial'], value='Metric', label='Units')
        ],
        outputs=[
            gr.Image(label='2D Layout Visualization', type='filepath'),
            gr.Image(label='High-Res Floorplan', type='filepath'),
            gr.Model3D(label='3D Room Layout', clear_color=[1.0, 1.0, 1.0, 1.0]),
            gr.File(label='3D Mesh (.obj)'),
            gr.File(label='Vanishing Point Data'),
            gr.File(label='Layout JSON'),
            gr.Textbox(label='Room Measurements'),
            gr.Textbox(label='Quality Report')
        ],
        title='3D Room Layout Estimation',
        description=description,
        allow_flagging="never",
        cache_examples=False
    )

    logger.info("Gradio interface created successfully")
    demo.launch(debug=True)

except Exception as e:
    logger.error(f"Failed to create or launch Gradio interface: {e}")
    print(f"Error: {e}")
    raise