Spaces:
Running
Running
import functools
import os
import random
import subprocess
import sys

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from huggingface_hub import hf_hub_download
# Fetch the upstream SDF-StyleGAN sources (they are not a pip package) and make
# them importable. Code reference: https://github.com/Zhengxinyang/SDF-StyleGAN
SDF_STYLEGAN_REPO = "https://github.com/Zhengxinyang/SDF-StyleGAN.git"
SDF_STYLEGAN_DIR = "SDF-StyleGAN"

# Only clone when the checkout is missing: on a restart the directory already
# exists and `git clone` would just fail noisily into it.
if not os.path.isdir(SDF_STYLEGAN_DIR):
    # Argument list (no shell). check=True reports a failed clone right here
    # instead of as a confusing ModuleNotFoundError on the imports below.
    subprocess.run(
        ["git", "clone", SDF_STYLEGAN_REPO, SDF_STYLEGAN_DIR], check=True
    )
sys.path.append(SDF_STYLEGAN_DIR)

from utils.utils import evaluate_in_chunks, scale_to_unit_sphere  # noqa: E402
from network.model import StyleGAN2_3D  # noqa: E402
def noise(batch_size, latent_dim, device):
    """Draw standard-normal latent codes of shape (batch_size, latent_dim).

    The tensor is allocated directly on *device*.
    """
    shape = (batch_size, latent_dim)
    return torch.randn(*shape, device=device)
def noise_list(batch_size, layers, latent_dim, device):
    """Return the latent spec for the generator: a one-element list of (latents, layers).

    ``latents`` is a fresh ``noise(batch_size, latent_dim, device)`` batch.
    A single pair presumably means the same latents feed all *layers* style
    inputs (no style mixing) — confirm against StyleGAN2_3D.
    """
    latents = noise(batch_size, latent_dim, device)
    pair = (latents, layers)
    return [pair]
def volume_noise(n, vol_size, device):
    """Return *n* cubic volumes of uniform(0, 1) noise.

    Args:
        n: number of volumes (leading batch dimension).
        vol_size: edge length of each cubic volume.
        device: target device; anything ``Tensor.to`` accepts ("cpu",
            "cuda", "cuda:0", a ``torch.device``, ...).

    Returns:
        float32 tensor of shape (n, vol_size, vol_size, vol_size, 1) on
        *device*.
    """
    # Sample with the CPU generator (as the original code did) so a given
    # seed produces identical noise on every device, then move the result.
    # `.to(device)` also handles torch.device / "cuda:0", which the old
    # `device == "cuda"` string comparison silently left on the CPU.
    volume = torch.empty(n, vol_size, vol_size, vol_size, 1, dtype=torch.float32)
    return volume.uniform_(0.0, 1.0).to(device)
class StyleGAN2_3D_not_cuda(StyleGAN2_3D):
    """StyleGAN2_3D whose feature-volume sampling follows ``self.device``.

    ``generate`` loads this subclass when CUDA is not available. The override
    draws latent and volume noise with this module's ``noise_list`` /
    ``volume_noise`` helpers on ``self.device`` (presumably the upstream
    implementation assumes a CUDA device — TODO confirm).
    """

    def generate_feature_volume(self, ema=False, trunc_psi=0.75):
        """Sample a single feature volume through ``generate_truncated``.

        Args:
            ema: if true, use ``self.SE`` / ``self.GE`` (presumably the EMA
                copies) instead of ``self.S`` / ``self.G``.
            trunc_psi: truncation factor forwarded to ``generate_truncated``.

        Returns:
            Whatever ``generate_truncated`` produces for a batch of one.
        """
        dev = self.device
        latent_spec = noise_list(1, self.num_layers, self.latent_dim, device=dev)
        vol = volume_noise(1, self.G_vol_size, device=dev)
        style_net, gen_net = (self.SE, self.GE) if ema else (self.S, self.G)
        return self.generate_truncated(style_net, gen_net, latent_spec, vol, trunc_psi)
# Default model: the car checkpoint, fetched from the Hugging Face Hub at the
# "main" revision. generate() also falls back to it for unknown model names.
cars = hf_hub_download(
    "SerdarHelli/SDF-StyleGAN-3D",
    filename="cars.ckpt",
    revision="main",
)

# Compute device chosen once at import time.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Dropdown label -> checkpoint path. Non-car checkpoints are relative paths,
# resolved against the working directory.
models = dict(
    Car=cars,
    Airplane="./planes.ckpt",
    Chair="./chairs.ckpt",
    Rifle="./rifles.ckpt",
    Table="./tables.ckpt",
)
def seed_all(seed):
    """Seed PyTorch's, NumPy's and Python's global RNGs with *seed* (in that order)."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
def predict(seed, model, trunc_psi):
    """Sample one mesh from *model* and return it as Plotly Mesh3d arrays.

    Args:
        seed: RNG seed; ``None`` falls back to 777. Cast to ``int`` because
            ``np.random.seed`` rejects floats (UI sliders may deliver them).
        model: loaded model exposing ``latent_dim``, ``device``, ``SE`` and
            ``generate_mesh``; its ``av`` attribute is overwritten.
        trunc_psi: truncation factor forwarded to ``generate_mesh``; ``None``
            falls back to 1.

    Returns:
        Tuple ``(x, y, z, i, j, k)``: the vertex coordinate arrays and the
        three per-face vertex-index arrays.
    """
    seed = 777 if seed is None else int(seed)
    seed_all(seed)
    if trunc_psi is None:
        trunc_psi = 1

    # Average of model.SE over 100k random latents (presumably the centre that
    # truncation pulls towards). Pure inference: model.eval() does not
    # disable autograd, so without no_grad this may retain a graph for all
    # 100k samples unless evaluate_in_chunks already disables it.
    with torch.no_grad():
        latents = noise(100000, model.latent_dim, device=model.device)
        samples = evaluate_in_chunks(1000, model.SE, latents)
        model.av = torch.mean(samples, dim=0, keepdim=True)

    mesh = model.generate_mesh(
        ema=True, mc_vol_size=64, level=-0.015, trunc_psi=trunc_psi)
    mesh = scale_to_unit_sphere(mesh)

    # Transposing (N, 3) arrays yields one row per coordinate / face corner.
    x, y, z = np.asarray(mesh.vertices).T
    i, j, k = np.asarray(mesh.faces).T
    return x, y, z, i, j, k
@functools.lru_cache(maxsize=1)
def _load_model(ckpt):
    """Load checkpoint *ckpt* for the module-level ``device``, in eval mode.

    One cache entry: repeated requests for the same model type skip the
    checkpoint reload, and switching types evicts the previous model so at
    most one is held.
    """
    if device == "cuda":
        model = StyleGAN2_3D.load_from_checkpoint(ckpt).cuda(0)
    else:
        model = StyleGAN2_3D_not_cuda.load_from_checkpoint(ckpt)
    model.eval()
    return model


def _mesh_figure(x, y, z, i, j, k):
    """Wrap mesh vertex/face arrays in a flat-shaded Plotly Mesh3d figure."""
    # NOTE(review): no `intensity` is passed, so `colorscale` / `colorbar_len`
    # likely have no visible effect — confirm against the plotly Mesh3d docs.
    return go.Figure(go.Mesh3d(
        x=x, y=y, z=z,
        i=i, j=j, k=k,
        colorscale="Viridis",
        colorbar_len=0.75,
        flatshading=True,
        lighting=dict(ambient=0.5,
                      diffuse=1,
                      fresnel=4,
                      specular=0.5,
                      roughness=0.05,
                      facenormalsepsilon=0,
                      vertexnormalsepsilon=0),
        lightposition=dict(x=100,
                           y=100,
                           z=1000)))


def generate(seed, model_name, trunc_psi):
    """Gradio callback: sample a shape of the chosen type and return a figure.

    Args:
        seed: RNG seed forwarded to ``predict``.
        model_name: key into ``models``; unknown names (including ``None``,
            e.g. before the dropdown has a value) fall back to the car model.
        trunc_psi: truncation factor forwarded to ``predict``.

    Returns:
        A ``plotly.graph_objects.Figure`` containing the generated mesh.
    """
    print(model_name)
    ckpt = models.get(model_name, cars)
    model = _load_model(ckpt)
    x, y, z, i, j, k = predict(seed, model, trunc_psi)
    return _mesh_figure(x, y, z, i, j, k)
# Page header rendered by gr.Markdown; interpolates the active compute device.
# Blank lines matter in Markdown: without them the two link lines merge into
# one paragraph, and the device line is folded into the last bullet as a
# lazy continuation line.
markdown = f'''
# SDF-StyleGAN: Implicit SDF-Based StyleGAN for 3D Shape Generation

[The space demo for the SGP 2022 paper "SDF-StyleGAN: Implicit SDF-Based StyleGAN for 3D Shape Generation".](https://arxiv.org/abs/2206.12055)

[For the official implementation.](https://github.com/Zhengxinyang/SDF-StyleGAN)

### Future Work based on interest

- Adding new models for new type objects
- New Customization

It is running on {device}
'''
# Gradio UI: header, input controls, a Generate button and the 3D plot. The
# figure is rendered once on page load and again on every button click.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            gr.Markdown(markdown)
        with gr.Row():
            # Defaults mirror the fallbacks in predict()/generate() (seed 777,
            # psi 1, car model). Previously every slider started at its
            # minimum, so the on-load render used trunc_psi=0; with the usual
            # StyleGAN truncation that maps every seed to the average shape
            # (confirm against generate_truncated).
            seed = gr.Slider(minimum=0, maximum=2**16, value=777, step=1,
                             label='Seed')
            # Choices come from `models` so the list cannot drift from the
            # checkpoints generate() knows about.
            model_name = gr.Dropdown(choices=list(models), value="Car",
                                     label="Choose Model Type")
            trunc_psi = gr.Slider(minimum=0, maximum=2, value=1,
                                  label='Truncate PSI')
        btn = gr.Button(value="Generate")
        mesh = gr.Plot()
    inputs = [seed, model_name, trunc_psi]
    demo.load(generate, inputs, mesh)
    btn.click(generate, inputs, mesh)
demo.launch(debug=True)