Spaces:
Sleeping
Sleeping
| import torch | |
| import os | |
| import numpy as np | |
| import cv2 | |
| import PIL.Image as pil_img | |
| import subprocess | |
import sys

# Pin networkx at start-up (presumably required by a downstream dependency —
# confirm). Invoking pip as `python -m pip` through the running interpreter
# guarantees the package is installed into *this* environment rather than into
# whichever `pip` executable happens to be first on PATH. Deliberately
# best-effort, as before: a failed install does not abort start-up.
subprocess.run([sys.executable, '-m', 'pip', 'install', 'networkx==2.5'])
| import gradio as gr | |
| import trimesh | |
| import plotly.graph_objects as go | |
| from models.deco import DECO | |
| from common import constants | |
# Inference device: the GPU when torch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Markdown/HTML header rendered at the top of the Gradio page (passed to
# gr.Markdown below): project links, Colab badge, GitHub star button, BibTeX
# citation and acknowledgements.
description = '''
### DECO: Dense Estimation of 3D Human-Scene Contact in the Wild (ICCV 2023, Oral)
<table>
<th width="20%">
<ul>
<li><strong><a href="https://deco.is.tue.mpg.de/">Homepage</a></strong>
<li><strong><a href="https://github.com/sha2nkt/deco">Code</a></strong>
<li><strong><a href="https://openaccess.thecvf.com/content/ICCV2023/html/Tripathi_DECO_Dense_Estimation_of_3D_Human-Scene_Contact_In_The_Wild_ICCV_2023_paper.html">Paper</a></strong>
</ul>
<br>
<ul>
<li><strong>Colab Notebook</strong> <a href='https://colab.research.google.com/drive/1fTQdI2AHEKlwYG9yIb2wqicIMhAa067_?usp=sharing'><img style="display: inline-block;" src='https://colab.research.google.com/assets/colab-badge.svg' alt='Google Colab'></a></li>
</ul>
<br>
<iframe src="https://ghbtns.com/github-btn.html?user=sha2nkt&repo=deco&type=star&count=true&v=2&size=small" frameborder="0" scrolling="0" width="100" height="20"></iframe>
</th>
</table>
#### Citation
```
@InProceedings{tripathi2023deco,
author = {Tripathi, Shashank and Chatterjee, Agniv and Passy, Jean-Claude and Yi, Hongwei and Tzionas, Dimitrios and Black, Michael J.},
title = {{DECO}: Dense Estimation of {3D} Human-Scene Contact In The Wild},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
month = {October},
year = {2023},
pages = {8001-8013}
}
```
<details>
<summary>More</summary>
#### Acknowledgments:
- [ECON](https://huggingface.co/spaces/Yuliang/ECON)
</details>
'''
# Lighting coefficients handed to plotly's Mesh3d `lighting` argument.
DEFAULT_LIGHTING = {
    'ambient': 0.6,
    'diffuse': 0.5,
    'fresnel': 0.01,
    'specular': 0.1,
    'roughness': 0.001,
}

# Light source position handed to plotly's Mesh3d `lightposition` argument.
DEFAULT_LIGHT_POSITION = {'x': 6, 'y': 0, 'z': 10}
def initiate_model(model_path):
    """Build a DECO network and load trained weights from *model_path*.

    Args:
        model_path: path to a checkpoint whose ``'deco'`` entry holds the
            model's state dict.

    Returns:
        The DECO model with weights loaded and switched to eval mode.
    """
    # NOTE(review): meaning of the positional args ('hrnet', True) is defined by
    # models.deco.DECO — presumably backbone name and a flag; confirm there.
    deco_model = DECO('hrnet', True, device)
    print(f'Loading weights from {model_path}')
    # map_location remaps tensors saved on a GPU onto the active device, so a
    # CUDA-trained checkpoint still loads on a CPU-only host.
    checkpoint = torch.load(model_path, map_location=device)
    deco_model.load_state_dict(checkpoint['deco'], strict=True)
    deco_model.eval()
    return deco_model
def create_layout(dummy, camera=None):
    """Return the plotly layout dict used to display *dummy*.

    All three axes are hidden and clipped to the mesh's coordinate range.

    Args:
        dummy: object exposing per-vertex coordinate arrays ``x``, ``y``, ``z``.
        camera: optional plotly ``scene_camera`` dict. When omitted, a
            perspective camera is placed at z=3 looking at the origin,
            horizontally centred on the mean x coordinate.

    Returns:
        A dict suitable for ``go.Figure(layout=...)``.
    """
    def _hidden_axis(values):
        # Invisible axis (no grid, no zero line) spanning the data range.
        return {
            'showgrid': False,
            'zeroline': False,
            'visible': False,
            'range': [values.min(), values.max()],
        }

    if camera is None:
        camera = {
            'up': {'x': 0, 'y': 1, 'z': 0},
            'center': {'x': 0, 'y': 0, 'z': 0},
            'eye': {'x': dummy.x.mean(), 'y': 0, 'z': 3},
            'projection': {'type': 'perspective'},
        }

    return {
        'scene': {
            'xaxis': _hidden_axis(dummy.x),
            'yaxis': _hidden_axis(dummy.y),
            'zaxis': _hidden_axis(dummy.z),
        },
        'autosize': False,
        'width': 750,
        'height': 1000,
        'scene_camera': camera,
        'scene_aspectmode': 'data',
        'clickmode': 'event+select',
        'margin': {'l': 0, 't': 0},
    }
def create_fig(dummy, camera=None):
    """Return a plotly figure showing *dummy*'s Mesh3d trace.

    Args:
        dummy: mesh wrapper providing ``mesh_3d()`` and coordinate arrays.
        camera: optional plotly camera dict forwarded to ``create_layout``.
    """
    mesh_trace = dummy.mesh_3d()
    figure_layout = create_layout(dummy, camera)
    return go.Figure(data=mesh_trace, layout=figure_layout)
class Dummy:
    """A triangle mesh loaded from disk, exposed in the form plotly's Mesh3d expects.

    Coordinate and face-index arrays are read-only properties, so callers can
    write ``dummy.x.mean()`` or pass ``x=dummy.x`` straight to plotly.
    """

    def __init__(self, mesh_path):
        """Load the mesh stored at *mesh_path*."""
        self._load_trimesh(mesh_path)

    def _load_trimesh(self, path):
        """Load a mesh file (any format trimesh reads, e.g. .ply/.obj).

        ``process=False`` stops trimesh from merging or reordering vertices,
        so vertex indices and per-vertex colours match the file.
        """
        self._trimesh = trimesh.load(path, process=False)
        self._vertices = np.array(self._trimesh.vertices)
        self._faces = np.array(self._trimesh.faces)
        self.colors = self._trimesh.visual.vertex_colors

    @property
    def vertices(self):
        """All mesh vertices, one row per vertex."""
        return self._vertices

    @property
    def faces(self):
        """All mesh faces, one row of vertex indices per face."""
        return self._faces

    @property
    def n_vertices(self):
        """Number of vertices in the mesh."""
        return self._vertices.shape[0]

    @property
    def n_faces(self):
        """Number of faces in the mesh."""
        return self._faces.shape[0]

    @property
    def x(self):
        """Array of vertex x coordinates."""
        return self._vertices[:, 0]

    @property
    def y(self):
        """Array of vertex y coordinates."""
        return self._vertices[:, 1]

    @property
    def z(self):
        """Array of vertex z coordinates."""
        return self._vertices[:, 2]

    @property
    def i(self):
        """Index of each face's first vertex."""
        return self._faces[:, 0]

    @property
    def j(self):
        """Index of each face's second vertex."""
        return self._faces[:, 1]

    @property
    def k(self):
        """Index of each face's third vertex."""
        return self._faces[:, 2]

    @property
    def default_selection(self):
        """Default (empty) patch selection."""
        return dict(vertices=[])

    def mesh_3d(self, lighting=None, light_position=None):
        """Build a plotly ``Mesh3d`` trace coloured with the mesh's vertex colours.

        Args:
            lighting: plotly lighting dict; ``None`` uses DEFAULT_LIGHTING.
            light_position: plotly light position dict; ``None`` uses
                DEFAULT_LIGHT_POSITION.
        """
        # Resolve defaults at call time instead of binding shared module dicts
        # as default argument values.
        if lighting is None:
            lighting = DEFAULT_LIGHTING
        if light_position is None:
            light_position = DEFAULT_LIGHT_POSITION
        return go.Mesh3d(
            x=self.x, y=self.y, z=self.z,
            i=self.i, j=self.j, k=self.k,
            vertexcolor=self.colors,
            lighting=lighting,
            lightposition=light_position,
            hoverinfo='none')
# Loaded DECO models keyed by checkpoint path, so weights are read from disk
# once per process instead of on every "Infer" click.
_MODEL_CACHE = {}


def _get_model(model_path):
    """Return the DECO model for *model_path*, loading it on first use only."""
    model = _MODEL_CACHE.get(model_path)
    if model is None:
        model = initiate_model(model_path)
        _MODEL_CACHE[model_path] = model
    return model


def main(pil_img, out_dir='demo_out', model_path='checkpoint/deco_best.pth',
         mesh_colour=(130, 130, 130, 255), annot_colour=(0, 255, 0, 255)):
    """Predict vertex contact for an image and paint it onto the SMPL T-pose mesh.

    Args:
        pil_img: input PIL image from the Gradio ``Image`` component
            (``None`` when nothing has been uploaded).
        out_dir: directory under which ``Preds/pred.obj`` is written.
        model_path: path to the DECO checkpoint.
        mesh_colour: RGBA colour for non-contact vertices.
        annot_colour: RGBA colour for vertices whose contact score is >= 0.5
            (presumably a probability — confirm against the model head).

    Returns:
        ``(figure, mesh_path)``: a plotly figure of the coloured mesh and the
        path of the exported ``.obj`` file.

    Raises:
        gr.Error: if no image was provided.
    """
    if pil_img is None:
        raise gr.Error('Please upload an input image before clicking "Infer".')

    deco_model = _get_model(model_path)
    smpl_path = os.path.join(constants.SMPL_MODEL_DIR, 'smpl_neutral_tpose.ply')

    # PIL images are already RGB (the model's expected channel order), so no
    # BGR->RGB swap is needed; convert() also normalises RGBA/greyscale input.
    img = np.array(pil_img.convert('RGB'))
    # cv2.resize's third positional parameter is `dst`, so the interpolation
    # flag must be passed by keyword to take effect.
    img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_CUBIC)
    img = img.transpose(2, 0, 1) / 255.0  # HWC in [0, 255] -> CHW in [0, 1]
    img = torch.tensor(img[np.newaxis], dtype=torch.float32).to(device)

    with torch.no_grad():
        cont, _, _ = deco_model(img)
    cont = cont.detach().cpu().numpy().squeeze()
    contact_idx = np.flatnonzero(cont >= 0.5)

    # Paint every vertex with the base colour, then highlight contact vertices.
    body_model_smpl = trimesh.load(smpl_path, process=False)
    colors = np.tile(np.asarray(mesh_colour, dtype=np.uint8),
                     (len(body_model_smpl.vertices), 1))
    colors[contact_idx] = annot_colour
    body_model_smpl.visual.vertex_colors = colors

    mesh_out_dir = os.path.join(out_dir, 'Preds')
    os.makedirs(mesh_out_dir, exist_ok=True)
    mesh_path = os.path.join(mesh_out_dir, 'pred.obj')
    print(f'Saving mesh to {mesh_out_dir}')
    body_model_smpl.export(mesh_path)

    fig = create_fig(Dummy(mesh_path))
    return fig, mesh_path
# Resolve example images relative to this script rather than a hard-coded
# container path, so the demo also runs outside /home/user/app.
APP_DIR = os.path.dirname(os.path.abspath(__file__))
EXAMPLE_IMAGES = [
    '213.jpg',
    'pexels-photo-207569.webp',
    'pexels-photo-3622517.webp',
    'pexels-photo-15732209.jpeg',
]

with gr.Blocks(title="DECO", css=".gradio-container") as demo:
    gr.Markdown(description)
    gr.HTML("""<h1 style="text-align:center; color:#10768c">DECO Demo</h1>""")

    # Input image on the left; rendered mesh and downloadable file on the right.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input image", type="pil")
        with gr.Column():
            output_image = gr.Plot(label="Renders")
            output_meshes = gr.File(label="3D meshes")

    gr.HTML("""<br/>""")

    with gr.Row():
        send_btn = gr.Button("Infer")
    send_btn.click(fn=main, inputs=[input_image], outputs=[output_image, output_meshes])

    gr.Examples(
        [[os.path.join(APP_DIR, 'example_images', name)] for name in EXAMPLE_IMAGES],
        inputs=[input_image],
    )

demo.launch(debug=True)