Spaces:
Sleeping
Sleeping
| from functools import partial | |
| import gradio as gr | |
| import pdb | |
| from PIL import Image | |
| import numpy as np | |
| import gradio as gr | |
| import torch | |
| import os | |
| import fire | |
| import multiprocessing as mp | |
| import os, sys | |
| sys.path.append(os.path.join(os.path.dirname(__file__), "DermSynth3D")) | |
| sys.path.append(os.path.join(os.path.dirname(__file__), "DermSynth3D", "dermsynth3d")) | |
| sys.path.append(os.path.join(os.path.dirname(__file__), "DermSynth3D", "skin3d")) | |
| import pandas as pd | |
| import numpy as np | |
| from glob import glob | |
| from PIL import Image | |
| import torch | |
| import torch.nn as nn | |
| import trimesh | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| import math | |
| from trimesh import transformations as tf | |
| import os | |
| from math import pi | |
| import matplotlib.pyplot as plt | |
| import plotly | |
| import plotly.graph_objects as go | |
| from skimage import io | |
| view_width = 400 | |
| view_height = 400 | |
| import mediapy as mpy | |
| try: | |
| from pytorch3d.io import load_objs_as_meshes | |
| from pytorch3d.structures import Meshes | |
| from pytorch3d.renderer import ( | |
| look_at_view_transform, | |
| FoVPerspectiveCameras, | |
| PointLights, | |
| DirectionalLights, | |
| Materials, | |
| RasterizationSettings, | |
| MeshRenderer, | |
| MeshRasterizer, | |
| SoftPhongShader, | |
| TexturesUV, | |
| TexturesVertex, | |
| ) | |
| print("Pytorch3d compiled properly") | |
| except: | |
| print("Pytorch3d not compiled properly. Install pytorch3d with torch/cuda support") | |
| try: | |
| sys.path.append("./DermSynth3D/") | |
| sys.path.append("./DermSynth3D/dermsynth3d/") | |
| sys.path.append("./DermSynth3D/skin3d/") | |
| from dermsynth3d import BlendLesions, Generate2DViews, SelectAndPaste | |
| from dermsynth3d.tools.generate2d import Generate2DHelper | |
| from dermsynth3d.utils.utils import yaml_loader | |
| from dermsynth3d.utils.utils import random_bound, make_masks | |
| from dermsynth3d.tools.synthesize import Synthesize2D | |
| from dermsynth3d.datasets.synth_dataset import SynthesizeDataset | |
| from dermsynth3d.tools.renderer import ( | |
| MeshRendererPyTorch3D, | |
| camera_pos_from_normal, | |
| ) | |
| from dermsynth3d.deepblend.blend3d import Blended3d | |
| from dermsynth3d.utils.channels import Target | |
| from dermsynth3d.utils.tensor import ( | |
| pil_to_tensor, | |
| ) | |
| from dermsynth3d.utils.colorconstancy import shade_of_gray_cc | |
| from dermsynth3d.datasets.datasets import Fitz17KAnnotations, Background2d | |
| from skin3d.skin3d.bodytex import BodyTexDataset | |
| print("DermSynth3D compiled properly") | |
| except Exception as e: | |
| print(e) | |
| print("DermSynth3D not in the path. Make sure to add it to the path.") | |
| _TITLE = """DermSynth3D: A Framework for generating Synthetic Dermatological Images""" | |
| _DESCRIPTION = """ | |
| **Step 1**. Select the Mesh, texture map and number of lesions from the dropdown or select an example.</br> | |
| **Step 2**. Selct the number of views to render. </br> | |
| **Step 3** (optional). Randomize the view parameters by clicking on the checkbox.</br> | |
| **Step 4**. Click on the Render Views button to render the views. </br> | |
| """ | |
| deployed = True | |
| if deployed: | |
| print(f"Is CUDA available: {torch.cuda.is_available()}") | |
| global DEVICE | |
| DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
| if torch.cuda.is_available(): | |
| print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") | |
| else: | |
| print("Running on CPU") | |
| global mesh_paths, mesh_names, all_textures, dir_blended_textures, dir_anatomy | |
| global get_no_lesion_path, get_mesh_path, get_mask_path, get_dilated_lesion_path | |
| global get_blended_lesion_path, get_pasted_lesion_path, get_texture_module | |
| global dir_blended_textures, dir_anatomy, dir_background | |
| # File path of the bodytex CSV. | |
| bodytex_csv = "./DermSynth3D/skin3d/data/3dbodytex-1.1-highres/bodytex.csv" | |
| bodytex_df = pd.read_csv(bodytex_csv, converters={"scan_id": lambda x: str(x)}) | |
| bodytex = BodyTexDataset( | |
| df=bodytex_df, | |
| dir_textures="./DermSynth3D/data/3dbodytex-1.1-highres/", | |
| dir_annotate="./DermSynth3D/skin3d/data/3dbodytex-1.1-highres/annotations/", | |
| ) | |
| # True to use the blended lesions, False to use the pasted lesions. | |
| is_blend = True | |
| background_ds = Background2d( | |
| dir_images="./DermSynth3D/data/background/IndoorScene/", | |
| image_filenames=None, | |
| ) | |
| from dermsynth3d.utils.anatomy import SimpleAnatomy | |
| color_labels = { | |
| 0: (0., 0., 0.), # background | |
| 1: (174., 199., 232.), # head | |
| 2: (152., 223., 138.), # torso | |
| 3: (31., 119., 180.), # hips | |
| 4: (255., 187., 120.), # legs | |
| 5: (188., 189., 34.), # feet | |
| 6: (140., 86., 75.), # arms | |
| 7: (255., 152., 150.), # hands | |
| } | |
| def to_simple_anatomy(anatomy): | |
| for i in range(16+1): | |
| if i in [0,1]: | |
| continue | |
| if i in [2,3]: | |
| anatomy[anatomy==i] = 2 | |
| if i == 4: | |
| anatomy[anatomy==i] = 3 | |
| if i in [5,6,7,8]: | |
| anatomy[anatomy==i] = 4 | |
| if i in [9,10]: | |
| anatomy[anatomy==i] = 5 | |
| if i in [11,12,13,14]: | |
| anatomy[anatomy==i] = 6 | |
| if i in [15,16]: | |
| anatomy[anatomy==i] = 7 | |
| return anatomy | |
| def convert_anatomy_to_rgb(anatomy): | |
| anatomy = to_simple_anatomy(anatomy) | |
| anatomy_rgb = np.zeros((anatomy.shape[0], anatomy.shape[1], 3)) | |
| for k, v in color_labels.items(): | |
| anatomy_rgb[anatomy == k] = v | |
| return anatomy_rgb.astype(np.uint8) | |
| import PIL.Image as pil | |
| import numpy as np | |
| import matplotlib as mpl | |
| import matplotlib.cm as cm | |
| def convert_depth_to_rgb(depth): | |
| mask = depth != 0 | |
| disp_map = 1 / depth | |
| vmax = np.percentile(disp_map[mask], 95) | |
| vmin = np.percentile(disp_map[mask], 5) | |
| normalizer = mpl.colors.Normalize(vmin=vmin, vmax=vmax) | |
| mapper = cm.ScalarMappable(norm=normalizer, cmap='magma') | |
| mask = np.repeat(np.expand_dims(mask,-1), 3, -1) | |
| colormapped_im = (mapper.to_rgba(disp_map)[:, :, :3] * 255).astype(np.uint8) | |
| colormapped_im[~mask] = 255 | |
| return colormapped_im | |
| def prepare_ds_renderer( | |
| randomize, | |
| mesh_name, | |
| texture_name, | |
| num_lesion, | |
| num_views, | |
| dist, | |
| elev, | |
| azim, | |
| light_pos, | |
| light_ac, | |
| light_dc, | |
| light_sc, | |
| mat_sh, | |
| mat_sc, | |
| device=DEVICE, | |
| ): | |
| mesh_filename = get_mesh_path(mesh_name) | |
| mesh = load_mesh_and_texture(mesh_name, texture_name, num_lesion, device) | |
| gr.Info("Preparing for Rendering...") | |
| mesh_renderer = MeshRendererPyTorch3D(mesh, DEVICE, config=None) | |
| extension = f"lesion_{num_lesion}" | |
| # if texture_name != "No Lesion": | |
| # extension = f"{texture_name.lower().split(' ')[0]}_lesion_{num_lesion}" | |
| nevi_exists = os.path.exists(bodytex.annotation_filepath(mesh_name.split("_")[0])) | |
| gen2d = Generate2DHelper( | |
| mesh_filename=mesh_filename, | |
| dir_blended_textures="./hf_demo/lesions/", | |
| dir_anatomy="./DermSynth3D/data/bodytex_anatomy_labels/", | |
| fitz_ds=None, # fitz_ds, | |
| background_ds=background_ds, | |
| device=device, | |
| debug=True, | |
| bodytex=bodytex, | |
| blended_file_ext=extension, # if num_lesion > 0 else "demo", | |
| config=None, | |
| is_blended=is_blend, | |
| ) | |
| # blended3d = Blended3d( | |
| # mesh_filename=os.path.join( | |
| # "./DermSynth3D/data/3dbodytex-1.1-highres/", | |
| # mesh_name, | |
| # "model_highres_0_normalized.obj", | |
| # ), | |
| # device=DEVICE, | |
| # dir_blended_textures=dir_blended_textures, | |
| # dir_anatomy=dir_anatomy, | |
| # extension=extension , | |
| # ) | |
| # normal_texture = load_texture_map( | |
| # mesh, mesh_name, "No Lesion", 0, device | |
| # ).maps_padded() | |
| # if num_lesion > 0: | |
| # blended_texture_image = load_texture_map( | |
| # mesh, mesh_name, "Blended Lesion", num_lesion, device | |
| # ).maps_padded() | |
| # pasted_texture_image = load_texture_map( | |
| # mesh, mesh_name, "Pasted Lesion", num_lesion, device | |
| # ).maps_padded() | |
| # dilated_texture_image = load_texture_map( | |
| # mesh, mesh_name, "Dilated Lesion", num_lesion, device | |
| # ).maps_padded() | |
| # texture_lesion_mask = blended3d.lesion_texture_mask(astensor=True).to(device) | |
| # non_skin_texture_mask = blended3d.nonskin_texture_mask(astensor=True).to(device) | |
| # vertices_to_anatomy = blended3d.vertices_to_anatomy() | |
| # mesh_renderer.raster_settings = raster_settings | |
| renderer, cameras, lights, materials = set_rendering_params( | |
| randomize, | |
| 1, # num_views, | |
| dist, | |
| elev, | |
| azim, | |
| light_pos, | |
| light_ac, | |
| light_dc, | |
| light_sc, | |
| mat_sh, | |
| mat_sc, | |
| ) | |
| gr.Info("Successfully prepared renderer.") | |
| gr.Info("Rendering Images...") | |
| gr.Info(f"Rendering {num_views} views on {DEVICE}. Please wait...") | |
| img_count = 0 | |
| view2d = [] | |
| depth2d = [] | |
| anatomy2d = [] | |
| seg2d = [] | |
| view_size = (224, 224) | |
| gen2d.view_size = view_size | |
| while img_count < num_views: | |
| if randomize: | |
| gr.Info("Finding suitable parameters...") | |
| success = gen2d.randomize_parameters(config=None, view_size=view_size) | |
| if not success: | |
| gr.Info("Could not find suitable parameters. Trying again.") | |
| continue | |
| else: | |
| raster_settings = RasterizationSettings( | |
| image_size=view_size[0], | |
| blur_radius=0.0, | |
| faces_per_pixel=10, | |
| perspective_correct=True, | |
| ) | |
| gen2d.mesh_renderer.cameras = cameras | |
| gen2d.mesh_renderer.lights = lights | |
| gen2d.mesh_renderer.materials = materials | |
| gen2d.mesh_renderer.raster_settings = raster_settings | |
| gen2d.mesh_renderer.initialize_renderer() | |
| gr.Info("Rasterization in progress...") | |
| gen2d.mesh_renderer.compute_fragments() | |
| gr.Info("Successfully rasterized.") | |
| paste_img, target = gen2d.render_image_and_target(paste_lesion=False) | |
| if paste_img is None: | |
| gr.Info( | |
| "***Not enough skin or unable to paste lesion. Skipping and Retrying." | |
| ) | |
| print("***Not enough skin or unable to paste lesion. Skipping.") | |
| continue | |
| paste_img = (paste_img * 255).astype(np.uint8) | |
| anatomy_view = target[:, :, 3] | |
| depth_view = target[:, :, 4] | |
| depth_img = convert_depth_to_rgb(depth_view) | |
| view2d.append(paste_img) | |
| depth2d.append(depth_img) | |
| anatomy_img = convert_anatomy_to_rgb(anatomy_view) | |
| anatomy2d.append(anatomy_img) | |
| mask = target[:, :, 0] | |
| seg2d.append(mask) | |
| gr.Info(f"Successfully rendered {img_count+1}/{num_views} image+annotations.") | |
| img_count += 1 | |
| return view2d, depth2d, anatomy2d, seg2d | |
| # define the list of all the examples | |
| def get_examples(): | |
| # setup_paths() | |
| # get mesh names from here | |
| mesh_names = globals()["mesh_names"] | |
| # get the textures | |
| textures = ["No Lesion", "Pasted Lesion", "Blended Lesion", "Dilated Lesion"] | |
| lesions = [1, 2, 5, 10] | |
| examples = [] | |
| for mesh in mesh_names: | |
| for texture in textures: | |
| for lesion in lesions: | |
| if texture == "No Lesion": | |
| # examples.append([mesh, texture, 0, 4, True]) | |
| examples.append([mesh, texture, 0]) | |
| break | |
| # examples.append([mesh, texture, lesion, 4, True]) | |
| examples.append([mesh, texture, lesion]) | |
| return examples | |
| import tempfile | |
| def get_trimesh_attrs(mesh_name, tex_name, num_lesion=1): | |
| mesh_path = get_mesh_path(mesh_name) | |
| texture_path = get_texture_module(tex_name)(mesh_name, num_lesion) | |
| texture_img = Image.open(texture_path).convert("RGB") | |
| tri_mesh = trimesh.load(mesh_path) | |
| angle = -math.pi / 2 | |
| direction = [0, 1, 0] | |
| center = [0, 0, 0] | |
| rot_matrix = tf.rotation_matrix(angle, direction, center) | |
| tri_mesh = tri_mesh.apply_transform(rot_matrix) | |
| tri_mesh.apply_transform(tf.rotation_matrix(math.pi, [0, 0, 1], [-1, -1, -1])) | |
| verts, faces = tri_mesh.vertices, tri_mesh.faces | |
| uvs = tri_mesh.visual.uv | |
| material = trimesh.visual.texture.SimpleMaterial(image=texture_img) | |
| vis = trimesh.visual.TextureVisuals(uv=uvs, material=material, image=texture_img) | |
| tri_mesh.visual = vis | |
| colors = tri_mesh.visual.to_color() | |
| vc = colors.vertex_colors # / 255.0 | |
| # timg = tri_mesh.visual.material.image | |
| return verts, faces, vc, mesh_name | |
| def plotly_image(image): | |
| fig = go.Figure() | |
| fig.add_trace(go.Image(z=image)) | |
| fig.update_layout( | |
| width=view_width, | |
| height=view_height, | |
| margin=dict(l=0, r=0, b=0, t=0, pad=0), | |
| paper_bgcolor="rgba(0,0,0,0)", | |
| plot_bgcolor="rgba(0,0,0,0)", | |
| ) | |
| fig.update_xaxes(showticklabels=False) | |
| fig.update_yaxes(showticklabels=False) | |
| fig.update_traces(hoverinfo="none") | |
| return fig | |
| def plotly_mesh(verts, faces, vc, mesh_name): | |
| fig = go.Figure( | |
| data=[ | |
| go.Mesh3d( | |
| x=verts[:, 0], | |
| y=verts[:, 1], | |
| z=verts[:, 2], | |
| i=faces[:, 0], | |
| j=faces[:, 1], | |
| k=faces[:, 2], | |
| vertexcolor=vc, | |
| ) | |
| ] | |
| ) | |
| # fig.update_layout(scene_aspectmode="manual", scene_aspectratio=dict(x=1, y=1, z=1)) | |
| fig.update_layout(scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False))) | |
| fig.update_layout(scene=dict(zaxis=dict(visible=False))) | |
| fig.update_layout(scene=dict(camera=dict(up=dict(x=1, y=0, z=1)))) | |
| fig.update_layout(scene=dict(camera=dict(eye=dict(x=-2, y=-2, z=-1)))) | |
| # disable hover info | |
| fig.update_traces(hoverinfo="none") | |
| return fig | |
| def load_texture_map(mesh, mesh_name, texture_name, num_lesion, device=DEVICE): | |
| verts = mesh.verts_packed().detach().cpu().numpy() | |
| faces = mesh.faces_packed().detach().cpu().numpy() | |
| normals = mesh.verts_normals_packed().detach().cpu().numpy() | |
| texture_path = get_texture_module(texture_name)(mesh_name, num_lesion) | |
| texture_img = Image.open(texture_path).convert("RGB") | |
| texture_tensor = torch.from_numpy(np.array(texture_img)).to(DEVICE) | |
| tmap = TexturesUV( | |
| maps=texture_tensor.float().to(device=mesh.device).unsqueeze(0), | |
| verts_uvs=mesh.textures.verts_uvs_padded(), | |
| faces_uvs=mesh.textures.faces_uvs_padded(), | |
| ) | |
| return tmap | |
| def load_mesh_and_texture(mesh_name, texture_name, num_lesion=1, device=DEVICE): | |
| """ | |
| Load a mesh and its corresponding texture. | |
| Args: | |
| mesh_name (str): The name of the mesh. | |
| texture_name (str): The name of the texture module. | |
| num_lesion (int, optional): The number of lesions. Defaults to 1. | |
| device (torch.device, optional): The device to load the mesh and texture on. Defaults to DEVICE. | |
| Returns: | |
| new_mesh (Meshes): The loaded mesh with texture. | |
| """ | |
| mesh_path = get_mesh_path(mesh_name) | |
| texture_path = get_texture_module(texture_name)(mesh_name, num_lesion) | |
| gr.Info("Loading mesh and texture...") | |
| mesh = load_objs_as_meshes([mesh_path], device=device) | |
| tmap = load_texture_map(mesh, mesh_name, texture_name, num_lesion, device) | |
| new_mesh = Meshes( | |
| verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tmap | |
| ) | |
| return new_mesh | |
| def setup_cameras(dist, elev, azim, device=DEVICE): | |
| gr.Info("Setting up cameras...") | |
| R, T = look_at_view_transform(dist, elev, azim, degrees=True) | |
| cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=30.0, znear=0.01) | |
| return cameras | |
| def setup_lights( | |
| light_pos, ambient_color, diffuse_color, specular_color, device=DEVICE | |
| ): | |
| gr.Info("Setting up lights...") | |
| lights = PointLights( | |
| device=device, | |
| location=light_pos, | |
| ambient_color=ambient_color, | |
| diffuse_color=diffuse_color, | |
| specular_color=specular_color, | |
| ) | |
| return lights | |
| def setup_materials(shininess, specularity, device=DEVICE): | |
| gr.Info("Setting up materials...") | |
| materials = Materials( | |
| device=device, | |
| specular_color=specularity, # [[specularity, specularity, specularity]], | |
| shininess=shininess.reshape(-1), # [shininess], | |
| ) | |
| return materials | |
| def setup_renderer(cameras, lights, materials, device=DEVICE): | |
| global raster_settings | |
| raster_settings = RasterizationSettings( | |
| image_size=128, | |
| blur_radius=0.0, | |
| faces_per_pixel=1, | |
| # max_faces_per_bin=100, | |
| # bin_size=0, | |
| perspective_correct=True, | |
| ) | |
| renderer = MeshRenderer( | |
| rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings), | |
| shader=SoftPhongShader( | |
| device=device, cameras=cameras, lights=lights, materials=materials | |
| ), | |
| ) | |
| return renderer | |
| def render_images(renderer, mesh, lights, cameras, materials, nviews, device=DEVICE): | |
| meshes = mesh.extend(nviews) | |
| gr.Info("Rendering Images...") | |
| images = renderer(meshes, lights=lights, cameras=cameras, materials=materials) | |
| gr.Info("Successfully rendered images.") | |
| images = images[..., :3] | |
| images = (images - images.min()) / (images.max() - images.min()) | |
| return images | |
| fragments = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)(meshes) | |
| # print(images.shape) | |
| # breakpoint() | |
| return images | |
| def randomize_view_params(randomize, num_views): | |
| dist = torch.rand(num_views).uniform_(0.0, 10.0) | |
| elev = torch.rand(num_views).uniform_(-90, 90) | |
| azim = torch.rand(num_views).uniform_(-90, 90) | |
| light_pos = torch.rand(num_views, 3).uniform_(0.0, 2.0) | |
| light_ac = torch.rand(num_views, 3).uniform_(0.0, 1.0) | |
| light_dc = torch.rand(num_views, 3).uniform_(0.0, 1.0) | |
| light_sc = torch.rand(num_views, 3).uniform_(0.0, 1.0) | |
| mat_sh = torch.rand(num_views, 1).uniform_(0, 100) | |
| mat_sc = torch.rand(num_views, 3).uniform_(0.0, 1.0) | |
| gr.Info("Randomized view parameters...") | |
| return ( | |
| dist, | |
| elev, | |
| azim, | |
| light_pos, | |
| light_ac, | |
| light_dc, | |
| light_sc, | |
| mat_sh, | |
| mat_sc, | |
| ) | |
| def sample_camera_params(num_views, dist, elev, azim): | |
| gr.Info("Setting up cameras...") | |
| dist = torch.linspace(dist - num_views // 2, dist + num_views // 2, num_views) | |
| elev = torch.linspace(elev - num_views // 2, elev + num_views // 2, num_views) | |
| azim = torch.linspace(azim - num_views // 2, azim + num_views // 2, num_views) | |
| cameras = setup_cameras(dist, elev, azim) | |
| return cameras | |
| def sample_light_params(num_views, light_pos, light_ac, light_dc, light_sc): | |
| gr.Info("Setting up lights...") | |
| light_pos = ( | |
| torch.linspace( | |
| light_pos - num_views // 2, light_pos + num_views // 2, num_views | |
| ) | |
| .reshape(-1, 1) | |
| .repeat(1, 3) | |
| ) | |
| light_ac = ( | |
| torch.linspace(light_ac - num_views // 2, light_ac + num_views // 2, num_views) | |
| .reshape(-1, 1) | |
| .repeat(1, 3) | |
| ) | |
| light_dc = ( | |
| torch.linspace(light_dc - num_views // 2, light_dc + num_views // 2, num_views) | |
| .reshape(-1, 1) | |
| .repeat(1, 3) | |
| ) | |
| light_sc = ( | |
| torch.linspace(light_sc - num_views // 2, light_sc + num_views // 2, num_views) | |
| .reshape(-1, 1) | |
| .repeat(1, 3) | |
| ) | |
| lights = setup_lights(light_pos, light_ac, light_dc, light_sc) | |
| return lights | |
| def sample_material_params(num_views, mat_sh, mat_sc): | |
| gr.Info("Setting up materials...") | |
| mat_sh = ( | |
| torch.linspace(mat_sh - num_views // 2, mat_sh + num_views // 2, num_views) | |
| .reshape(-1, 1) | |
| .repeat(1, 1) | |
| ) | |
| mat_sc = ( | |
| torch.linspace(mat_sc - num_views // 2, mat_sc + num_views // 2, num_views) | |
| .reshape(-1, 1) | |
| .repeat(1, 3) | |
| ) | |
| materials = setup_materials(mat_sh, mat_sc) | |
| return materials | |
| def set_rendering_params( | |
| randomize, | |
| num_views, | |
| dist, | |
| elev, | |
| azim, | |
| light_pos, | |
| light_ac, | |
| light_dc, | |
| light_sc, | |
| mat_sh, | |
| mat_sc, | |
| ): | |
| if randomize: | |
| ( | |
| dist, | |
| elev, | |
| azim, | |
| light_pos, | |
| light_ac, | |
| light_dc, | |
| light_sc, | |
| mat_sh, | |
| mat_sc, | |
| ) = randomize_view_params(randomize, num_views) | |
| cameras = setup_cameras(dist, elev, azim) | |
| lights = setup_lights(light_pos, light_ac, light_dc, light_sc) | |
| materials = setup_materials(mat_sh, mat_sc) | |
| else: | |
| cameras = sample_camera_params(num_views, dist, elev, azim) | |
| lights = sample_light_params(num_views, light_pos, light_ac, light_dc, light_sc) | |
| materials = sample_material_params(num_views, mat_sh, mat_sc) | |
| renderer = setup_renderer(cameras, lights, materials) | |
| return renderer, cameras, lights, materials | |
| def process_examples(mesh_name, tex_name, n_lesion): | |
| mesh_path = get_mesh_path(mesh_name) | |
| texture_path = get_texture_module(tex_name)(mesh_name, n_lesion) | |
| mesh_to_view = plotly_mesh(*get_trimesh_attrs(mesh_name, tex_name, n_lesion)) | |
| return mesh_to_view, texture_path, n_lesion | |
| def update_plots(mesh_name, texture_name, num_lesion): | |
| if num_lesion > 0 and texture_name == "No Lesion": | |
| gr.Warning( | |
| f"Cannot display '{texture_name}' texture map with {num_lesion} lesions! Please change the texture. Meanwhile, not updating the display." | |
| ) | |
| return default_mesh_plot, default_texture, num_lesion | |
| elif num_lesion == 0 and texture_name != "No Lesion": | |
| go.Warning( | |
| f"Cannot display '{texture_name}' texture map with {num_lesion} lesions! Please increase the number of lesions." | |
| ) | |
| return default_mesh_plot, default_texture, num_lesion | |
| mesh_path = get_mesh_path(mesh_name) | |
| texture_path = Image.open(get_texture_module(texture_name)(mesh_name, num_lesion)).convert("RGB").resize((512, 512)) | |
| mesh_to_view = plotly_mesh(*get_trimesh_attrs(mesh_name, texture_name, num_lesion)) | |
| gr.Info("Successfully updated mesh and texture.") | |
| return mesh_to_view, texture_path, num_lesion | |
| def run_demo(): | |
| # get the defined examples | |
| all_examples = get_examples() | |
| mesh_block = gr.Plot( | |
| label="Selected Mesh", | |
| value=default_mesh_plot, | |
| # scale=1, | |
| ) | |
| texture_block = gr.Image( | |
| value=default_texture, | |
| type="pil", | |
| image_mode="RGB", | |
| height="auto", | |
| width="auto", | |
| label="Selected Texture", | |
| ) | |
| num_lesions = gr.Radio( | |
| choices=[0, 1, 2, 5, 10], | |
| label="Number of Lesions", | |
| value=0, | |
| interactive=True, | |
| ) | |
| num_views = gr.Slider(2, 32, 4, label="Number of Views", step=2, interactive=True) | |
| randomize = gr.Checkbox( | |
| label="Randomize View Parameters", value=True, interactive=True | |
| ) | |
| render_button = gr.Button("Render Views") | |
| select_mesh = gr.Dropdown( | |
| choices=mesh_names, | |
| value=mesh_names[0], | |
| interactive=True, | |
| label="Input Mesh", | |
| info="Select the mesh to render", | |
| ) | |
| select_texture = gr.Dropdown( | |
| choices=["No Lesion", "Pasted Lesion", "Blended Lesion", "Dilated Lesion"], | |
| value="No Lesion", | |
| interactive=True, | |
| label="Input Texture", | |
| info="Select the texture to use for the mesh.", | |
| ) | |
| # compose demo layout and data flow | |
| with gr.Blocks( | |
| title=_TITLE, analytics_enabled=True, theme=gr.themes.Base() | |
| ) as demo: | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown(f"# {_TITLE}") | |
| gr.Markdown(_DESCRIPTION) | |
| # User input panel | |
| with gr.Row(variant="panel"): | |
| with gr.Column(scale=1): | |
| select_mesh.render() | |
| select_texture.render() | |
| num_lesions.render() | |
| num_views.render() | |
| randomize.render() | |
| with gr.Column(scale=1): | |
| mesh_block.render() | |
| with gr.Column(scale=1): | |
| texture_block.render() | |
| gr.on( | |
| triggers=[ | |
| select_mesh.change, | |
| select_texture.change, | |
| num_lesions.change, | |
| ], | |
| inputs=[select_mesh, select_texture, num_lesions], | |
| outputs=[mesh_block, texture_block, num_lesions], | |
| fn=update_plots, | |
| ) | |
| # @gr.on( | |
| # inputs=[ | |
| # select_mesh, | |
| # select_texture, | |
| # num_lesions, | |
| # ], | |
| # outputs=[ | |
| # mesh_block, | |
| # texture_block, | |
| # num_lesions, | |
| # ], | |
| # triggers=[ | |
| # select_mesh.change, | |
| # select_texture.change, | |
| # num_lesions.change, | |
| # ], | |
| # ) | |
| # def update(m, t, l): | |
| # return update_plots(m, t, l) | |
| # rendering choices | |
| with gr.Row(variant="panel"): | |
| with gr.Column(scale=1): | |
| render_button.render() | |
| with gr.Column(scale=1): | |
| with gr.Accordion("Configure View Parameters", open=False): | |
| # setup cameras | |
| with gr.Accordion("Camera Parameters", open=False): | |
| dist = gr.Slider( | |
| minimum=0.0, | |
| maximum=10.0, | |
| value=0.5, | |
| step=0.5, | |
| interactive=True, | |
| label="Distance", | |
| ) | |
| elev = gr.Slider( | |
| label="Elevation", | |
| interactive=True, | |
| minimum=-90, | |
| maximum=90, | |
| value=0, | |
| step=10, | |
| ) | |
| azim = gr.Slider( | |
| label="Azimuth", | |
| interactive=True, | |
| minimum=-90, | |
| maximum=90, | |
| value=90, | |
| step=10, | |
| ) | |
| # setup lights | |
| with gr.Accordion("Lighting Parameters", open=False): | |
| light_pos = gr.Slider( | |
| label="Light Position", | |
| interactive=True, | |
| minimum=0.0, | |
| maximum=2.0, | |
| value=0.5, | |
| step=0.1, | |
| ) | |
| light_ac = gr.Slider( | |
| label="Ambient Color", | |
| minimum=0.0, | |
| maximum=1.0, | |
| interactive=True, | |
| value=0.5, | |
| step=0.1, | |
| ) | |
| light_dc = gr.Slider( | |
| label="Diffuse Color", | |
| minimum=0.0, | |
| maximum=1.0, | |
| interactive=True, | |
| value=0.5, | |
| step=0.1, | |
| ) | |
| light_sc = gr.Slider( | |
| label="Specular Color", | |
| minimum=0.0, | |
| maximum=1.0, | |
| interactive=True, | |
| value=0.5, | |
| step=0.1, | |
| ) | |
| # setup material parameters | |
| with gr.Accordion("Material Parameters", open=False): | |
| mat_sh = gr.Slider( | |
| label="Shininess", | |
| interactive=True, | |
| minimum=0, | |
| maximum=100, | |
| value=50, | |
| step=10, | |
| ) | |
| mat_sc = gr.Slider( | |
| label="Specularity", | |
| minimum=0.0, | |
| interactive=True, | |
| maximum=1.0, | |
| value=0.5, | |
| step=0.1, | |
| ) | |
| update_view_btn = gr.Button("Update View Parameters") | |
| gr.on( | |
| triggers=[ | |
| update_view_btn.click, | |
| dist.change, | |
| elev.change, | |
| azim.change, | |
| light_pos.change, | |
| light_ac.change, | |
| light_dc.change, | |
| light_sc.change, | |
| mat_sh.change, | |
| mat_sc.change, | |
| ], | |
| inputs=[randomize], | |
| outputs=[randomize], | |
| fn=lambda x: False, | |
| show_progress="hidden", | |
| queue=False, | |
| scroll_to_output=True, | |
| ) | |
| # rendered views panel | |
| with gr.Row(variant="panel"): | |
| with gr.Tab("Rendered RGB Views"): | |
| render_block = gr.Gallery( | |
| label="Rendered Views", columns=4, height="auto", object_fit="contain" | |
| ) | |
| with gr.Tab("Rendered Depth Views"): | |
| depth_block = gr.Gallery( | |
| label="Depth Maps", columns=4, height="auto", object_fit="contain" | |
| ) | |
| with gr.Tab("Rendered Anatomy Views"): | |
| anatomy_block = gr.Gallery( | |
| label="Anatomy Labels", columns=4, height="auto", object_fit="contain" | |
| ) | |
| with gr.Tab("Rendered Segmentation Views"): | |
| seg_block = gr.Gallery( | |
| label="Segmentation Masks", columns=4, height="auto", object_fit="contain" | |
| ) | |
| # | |
| # render_block = gr.Gallery( | |
| # label="Rendered Views", columns=4, height="auto", object_fit="contain" | |
| # ) | |
| def render_views( | |
| randomize, | |
| select_mesh, | |
| select_texture, | |
| num_lesions, | |
| num_views, | |
| dist, | |
| elev, | |
| azim, | |
| light_pos, | |
| light_ac, | |
| light_dc, | |
| light_sc, | |
| mat_sh, | |
| mat_sc, | |
| ): | |
| renderer, cameras, lights, materials = set_rendering_params( | |
| randomize, | |
| num_views, | |
| dist, | |
| elev, | |
| azim, | |
| light_pos, | |
| light_ac, | |
| light_dc, | |
| light_sc, | |
| mat_sh, | |
| mat_sc, | |
| ) | |
| # gr.Info("Loading mesh and texture...") | |
| # mesh = load_mesh_and_texture(select_mesh, select_texture, num_lesions) | |
| # cameras | |
| # images = render_images( | |
| # renderer, mesh, lights, cameras, materials, num_views | |
| # ) | |
| # return [_ for _ in images.detach().cpu().numpy()] | |
| view2d, depth, anatomy, segmentation = prepare_ds_renderer( | |
| randomize, | |
| select_mesh, | |
| select_texture, | |
| num_lesions, | |
| num_views, | |
| dist, | |
| elev, | |
| azim, | |
| light_pos, | |
| light_ac, | |
| light_dc, | |
| light_sc, | |
| mat_sh, | |
| mat_sc, | |
| ) | |
| return view2d, depth, anatomy, segmentation | |
| # examples panel when the iuser does not want to input | |
| with gr.Row(variant="panel"): | |
| with gr.Column(scale=1): | |
| gr.Examples( | |
| examples=all_examples, | |
| inputs=[ | |
| select_mesh, | |
| select_texture, | |
| num_lesions, | |
| ], | |
| outputs=[ | |
| mesh_block, | |
| texture_block, | |
| num_lesions, | |
| ], | |
| cache_examples=False, | |
| fn=update_plots, | |
| label="Meshes and Textures for Demo (Click to start)", | |
| ) | |
| demo.queue(max_size=10) | |
| demo.launch( | |
| share=True, | |
| max_threads=mp.cpu_count(), | |
| show_error=True, | |
| show_api=False, | |
| ) | |
| def get_texture_module(tex_type): | |
| if tex_type == "No Lesion": | |
| return get_no_lesion_path | |
| elif tex_type == "Pasted Lesion": | |
| return get_pasted_lesion_path | |
| elif tex_type == "Blended Lesion": | |
| return get_blended_lesion_path | |
| elif tex_type == "Dilated Lesion": | |
| return get_dilated_lesion_path | |
| else: | |
| raise ValueError(f"Texture type {tex_type} not supported!") | |
| if __name__ == "__main__": | |
| # setup_paths() | |
| mesh_paths = glob("./DermSynth3D/data/3dbodytex-1.1-highres/*/*.obj") | |
| mesh_names = [os.path.basename(os.path.dirname(x)) for x in mesh_paths] | |
| # get the textures | |
| all_textures = glob("./DermSynth3D/data/3dbodytex-1.1-highres/*/*.png") | |
| dir_blended_textures = "./hf_demo/lesions/" | |
| dir_anatomy = "./DermSynth3D/data/bodytex_anatomy_labels/" | |
| dir_background = "./DermSynth3D/data/background/IndoorScene/" | |
| get_no_lesion_path = lambda x, y: os.path.join( | |
| "./DermSynth3D/data/3dbodytex-1.1-highres", x, "model_highres_0_normalized.png" | |
| ) | |
| get_mesh_path = lambda x: os.path.join( | |
| "./DermSynth3D/data/3dbodytex-1.1-highres", x, "model_highres_0_normalized.obj" | |
| ) | |
| # get the textures with the lesions | |
| get_mask_path = lambda x: os.path.join( | |
| "./hf_demo/lesions/", x, "model_highres_0_normalized_mask.png" | |
| ) | |
| get_dilated_lesion_path = lambda x, y: os.path.join( | |
| "./hf_demo/lesions/", | |
| x, | |
| f"model_highres_0_normalized_dilated_lesion_{y}.png", | |
| ) | |
| get_blended_lesion_path = lambda x, y: os.path.join( | |
| "./hf_demo/lesions/", | |
| x, | |
| f"model_highres_0_normalized_blended_lesion_{y}.png", | |
| ) | |
| get_pasted_lesion_path = lambda x, y: os.path.join( | |
| "./hf_demo/lesions/", | |
| x, | |
| f"model_highres_0_normalized_pasted_lesion_{y}.png", | |
| ) | |
| default_mesh_plot = plotly_mesh(*get_trimesh_attrs(mesh_names[0], "No Lesion", 0)) | |
| default_texture = Image.open(all_textures[0]).convert("RGB").resize((512, 512)) | |
| new_values = { | |
| "default_mesh_plot": default_mesh_plot, | |
| "default_texture": default_texture, | |
| } | |
| globals().update(new_values) | |
| run_demo() | |