Spaces:
Runtime error
Runtime error
| import os | |
| import tyro | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from core.options import AllConfigs, Options | |
| from core.gs import GaussianRenderer | |
| import dearpygui.dearpygui as dpg | |
| import kiui | |
| from kiui.cam import OrbitCamera | |
class GUI:
    """Interactive dearpygui viewer for a Gaussian-splatting ``.ply`` file.

    The Gaussians are loaded once via ``GaussianRenderer.load_ply`` and the
    scene is re-rendered from an ``OrbitCamera`` whenever a camera or
    rendering control changes. The rendered image occupies the left window
    (W x H at the origin); a 600 px wide control window sits to its right.

    Camera controls (active only while the image window is focused):
    left-drag orbits, mouse wheel scales, middle-drag pans.
    """

    def __init__(self, opt: Options):
        """Create camera, renderer and UI, then render the first frame.

        Args:
            opt: configuration. This class reads ``output_size``,
                ``cam_radius``, ``fovy`` (degrees, see the ``np.deg2rad``
                conversion below), ``znear``, ``zfar`` and ``test_path``.
        """
        # Set first so __del__ knows whether a dearpygui context exists even
        # if something below (e.g. loading the .ply) raises.
        self._dpg_context_created = False

        self.opt = opt
        self.W = opt.output_size
        self.H = opt.output_size
        self.cam = OrbitCamera(self.W, self.H, r=opt.cam_radius, fovy=opt.fovy)

        self.device = torch.device("cuda")

        # Projection used by the rasterizer; rebuilt whenever the FoV changes.
        self._update_proj_matrix(np.deg2rad(opt.fovy))

        self.mode = "image"
        # H x W x 3 float image that backs the dearpygui texture.
        self.buffer_image = np.ones((self.H, self.W, 3), dtype=np.float32)
        self.need_update = True  # update buffer_image

        # renderer
        self.renderer = GaussianRenderer(opt)
        self.gaussain_scale_factor = 1
        self.gaussians = self.renderer.load_ply(opt.test_path).to(self.device)

        dpg.create_context()
        self._dpg_context_created = True
        self.register_dpg()
        self.test_step()

    def __del__(self):
        """Tear down the dearpygui context, if this instance created one."""
        if getattr(self, "_dpg_context_created", False):
            dpg.destroy_context()

    def _update_proj_matrix(self, fovy_rad):
        """(Re)build ``self.proj_matrix`` for a vertical FoV given in radians.

        Also refreshes ``self.tan_half_fov``. The same focal term is used for
        both axes, which relies on the square viewport set up in ``__init__``
        (W == H). The matrix is right-multiplied onto the transposed view
        matrix in ``test_step``.
        """
        opt = self.opt
        self.tan_half_fov = np.tan(0.5 * fovy_rad)
        depth_range = opt.zfar - opt.znear
        proj = torch.zeros(4, 4, dtype=torch.float32, device=self.device)
        proj[0, 0] = 1 / self.tan_half_fov
        proj[1, 1] = 1 / self.tan_half_fov
        proj[2, 2] = (opt.zfar + opt.znear) / depth_range
        proj[3, 2] = -(opt.zfar * opt.znear) / depth_range
        proj[2, 3] = 1
        self.proj_matrix = proj

    def test_step(self):
        """Re-render into the texture if a control changed since last frame.

        Also reports the render time (measured with CUDA events) in the
        ``_log_infer_time`` text widget.
        """
        # ignore if no need to update
        if not self.need_update:
            return

        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        # render image
        cam_poses = torch.from_numpy(self.cam.pose).unsqueeze(0).to(self.device)
        cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction

        # cameras needed by gaussian rasterizer
        cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
        cam_view_proj = cam_view @ self.proj_matrix  # [V, 4, 4]
        cam_pos = -cam_poses[:, :3, 3]  # [V, 3]

        buffer_image = self.renderer.render(
            self.gaussians.unsqueeze(0),
            cam_view.unsqueeze(0),
            cam_view_proj.unsqueeze(0),
            cam_pos.unsqueeze(0),
            scale_modifier=self.gaussain_scale_factor,
        )[self.mode]
        buffer_image = buffer_image.squeeze(1)  # [B, C, H, W]

        # the alpha channel is single-channel; replicate it to RGB for display
        if self.mode in ['alpha']:
            buffer_image = buffer_image.repeat(1, 3, 1, 1)

        buffer_image = F.interpolate(
            buffer_image,
            size=(self.H, self.W),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)

        # buffer must be contiguous, else seg fault!
        self.buffer_image = (
            buffer_image.permute(1, 2, 0)
            .clamp(0, 1)
            .contiguous()
            .detach()
            .cpu()
            .numpy()
        )
        self.need_update = False

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        # guard against a zero elapsed time, which would divide by zero
        fps_text = str(int(1000 / t)) if t > 0 else "inf"
        dpg.set_value("_log_infer_time", f"{t:.4f}ms ({fps_text} FPS)")
        dpg.set_value("_texture", self.buffer_image)

    def register_dpg(self):
        """Create the texture, windows, controls, mouse handlers and viewport."""
        ### register texture

        with dpg.texture_registry(show=False):
            dpg.add_raw_texture(
                self.W,
                self.H,
                self.buffer_image,
                format=dpg.mvFormat_Float_rgb,
                tag="_texture",
            )

        ### register window

        # the rendered image, as the primary window
        with dpg.window(
            tag="_primary_window",
            width=self.W,
            height=self.H,
            pos=[0, 0],
            no_move=True,
            no_title_bar=True,
            no_scrollbar=True,
        ):
            # add the texture
            dpg.add_image("_texture")

        # dpg.set_primary_window("_primary_window", True)

        # control window
        with dpg.window(
            label="Control",
            tag="_control_window",
            width=600,
            height=self.H,
            pos=[self.W, 0],
            no_move=True,
            no_title_bar=True,
        ):
            # button theme
            # NOTE(review): this theme is created but never bound to an item.
            with dpg.theme() as theme_button:
                with dpg.theme_component(dpg.mvButton):
                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)

            # timer stuff
            with dpg.group(horizontal=True):
                dpg.add_text("Infer time: ")
                dpg.add_text("no data", tag="_log_infer_time")

            # rendering options
            with dpg.collapsing_header(label="Rendering", default_open=True):

                # mode combo
                def callback_change_mode(sender, app_data):
                    self.mode = app_data
                    self.need_update = True

                dpg.add_combo(
                    ("image", "alpha"),
                    label="mode",
                    default_value=self.mode,
                    callback=callback_change_mode,
                )

                # fov slider (degrees in the UI, radians on the camera)
                def callback_set_fovy(sender, app_data):
                    self.cam.fovy = np.deg2rad(app_data)
                    # test_step projects with self.proj_matrix, not cam.fovy,
                    # so the projection must be rebuilt for the change to show.
                    self._update_proj_matrix(self.cam.fovy)
                    self.need_update = True

                dpg.add_slider_int(
                    label="FoV (vertical)",
                    min_value=1,
                    max_value=120,
                    format="%d deg",
                    # int slider: pass a plain int, not a numpy float
                    default_value=int(round(np.rad2deg(self.cam.fovy))),
                    callback=callback_set_fovy,
                )

                def callback_set_gaussain_scale(sender, app_data):
                    self.gaussain_scale_factor = app_data
                    self.need_update = True

                dpg.add_slider_float(
                    label="gaussain scale",
                    min_value=0,
                    max_value=1,
                    format="%.2f",
                    default_value=self.gaussain_scale_factor,
                    callback=callback_set_gaussain_scale,
                )

        ### register camera handler
        # Each handler only acts while the image window has focus, so
        # dragging inside the control window does not move the camera.

        def callback_camera_drag_rotate(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return
            dx = app_data[1]
            dy = app_data[2]
            self.cam.orbit(dx, dy)
            self.need_update = True

        def callback_camera_wheel_scale(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return
            delta = app_data
            self.cam.scale(delta)
            self.need_update = True

        def callback_camera_drag_pan(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return
            dx = app_data[1]
            dy = app_data[2]
            self.cam.pan(dx, dy)
            self.need_update = True

        with dpg.handler_registry():
            # for camera moving
            dpg.add_mouse_drag_handler(
                button=dpg.mvMouseButton_Left,
                callback=callback_camera_drag_rotate,
            )
            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
            dpg.add_mouse_drag_handler(
                button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan
            )

        dpg.create_viewport(
            title="Gaussian3D",
            width=self.W + 600,
            # extra height on Windows, presumably for the native title bar
            height=self.H + (45 if os.name == "nt" else 0),
            resizable=False,
        )

        ### global theme
        with dpg.theme() as theme_no_padding:
            with dpg.theme_component(dpg.mvAll):
                # set all padding to 0 to avoid scroll bar
                dpg.add_theme_style(
                    dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core
                )
                dpg.add_theme_style(
                    dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core
                )
                dpg.add_theme_style(
                    dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core
                )

        dpg.bind_item_theme("_primary_window", theme_no_padding)

        dpg.setup_dearpygui()

        ### register a larger font
        # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf
        if os.path.exists("LXGWWenKai-Regular.ttf"):
            with dpg.font_registry():
                with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font:
                    dpg.bind_font(default_font)

        # dpg.show_metrics()

        dpg.show_viewport()

    def render(self):
        """Run the UI loop until the window closes, re-rendering on demand."""
        while dpg.is_dearpygui_running():
            # test_step is a no-op unless a control set need_update
            self.test_step()
            dpg.render_dearpygui_frame()
opt = tyro.cli(AllConfigs)

# load a saved ply and visualize
# Explicit check rather than `assert`, which is stripped under `python -O`.
# Also rejects an empty/unset --test_path (presumably None when omitted —
# confirm against Options) instead of crashing on `.endswith`.
if not opt.test_path or not opt.test_path.endswith('.ply'):
    raise SystemExit('--test_path must be a .ply file saved by infer.py')

gui = GUI(opt)
gui.render()