eigopop committed on
Commit
b5ee29a
·
verified ·
1 Parent(s): 402a4f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +330 -166
app.py CHANGED
@@ -1,122 +1,86 @@
1
  import spaces
2
  import os
3
- import sys
4
  import gradio as gr
5
  import numpy as np
6
  import torch
7
  from PIL import Image
8
  import trimesh
9
  import random
10
- import subprocess
11
- import shutil
12
-
13
  from transformers import AutoModelForImageSegmentation
14
  from torchvision import transforms
15
  from huggingface_hub import hf_hub_download, snapshot_download
 
 
16
 
17
- # --------------------
18
- # Runtime installs
19
- # --------------------
20
  subprocess.run("pip install spandrel==0.4.1 --no-deps", shell=True, check=True)
21
 
22
- # --------------------
23
- # Device
24
- # --------------------
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
  DTYPE = torch.float16
27
- print("DEVICE:", DEVICE)
28
 
29
- # --------------------
30
- # Constants
31
- # --------------------
32
  DEFAULT_FACE_NUMBER = 100000
33
  MAX_SEED = np.iinfo(np.int32).max
34
- NUM_VIEWS = 6
35
-
36
  TRIPOSG_REPO_URL = "https://github.com/VAST-AI-Research/TripoSG.git"
37
- MV_ADAPTER_REPO_URL = "https://huggingface.co/spaces/VAST-AI/MV-Adapter-Img2Texture"
38
 
39
  RMBG_PRETRAINED_MODEL = "checkpoints/RMBG-1.4"
40
  TRIPOSG_PRETRAINED_MODEL = "checkpoints/TripoSG"
41
 
42
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
43
- TMP_DIR = os.path.join(BASE_DIR, "tmp")
44
  os.makedirs(TMP_DIR, exist_ok=True)
45
 
46
- # --------------------
47
- # Clone repos
48
- # --------------------
49
- TRIPOSG_CODE_DIR = os.path.join(BASE_DIR, "triposg")
50
  if not os.path.exists(TRIPOSG_CODE_DIR):
51
  os.system(f"git clone {TRIPOSG_REPO_URL} {TRIPOSG_CODE_DIR}")
52
 
53
- MV_ADAPTER_CODE_DIR = os.path.join(BASE_DIR, "mv_adapter")
54
  if not os.path.exists(MV_ADAPTER_CODE_DIR):
55
- os.system(
56
- f"git clone {MV_ADAPTER_REPO_URL} {MV_ADAPTER_CODE_DIR} "
57
- f"&& cd {MV_ADAPTER_CODE_DIR} "
58
- f"&& git checkout 7d37a97e9bc223cdb8fd26a76bd8dd46504c7c3d"
59
- )
60
 
61
- # --------------------
62
- # Python path (CRITICAL)
63
- # --------------------
64
  sys.path.append(TRIPOSG_CODE_DIR)
 
65
  sys.path.append(MV_ADAPTER_CODE_DIR)
 
66
 
67
- # --------------------
68
- # UI header
69
- # --------------------
70
  HEADER = """
71
- # 🔮 Image to 3D with TripoSG
72
 
73
- State-of-the-art open-source 3D generation.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- Texture generation powered by
76
- [MV-Adapter HF Space](https://huggingface.co/spaces/VAST-AI/MV-Adapter-Img2Texture)
77
  """
78
 
79
- # --------------------
80
- # TripoSG setup
81
- # --------------------
82
  from image_process import prepare_image
83
  from briarmbg import BriaRMBG
84
-
85
  snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
86
  rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE)
87
  rmbg_net.eval()
88
-
89
  from triposg.pipelines.pipeline_triposg import TripoSGPipeline
90
-
91
  snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
92
- triposg_pipe = TripoSGPipeline.from_pretrained(
93
- TRIPOSG_PRETRAINED_MODEL
94
- ).to(DEVICE, DTYPE)
95
-
96
- # --------------------
97
- # MV-Adapter imports (CORRECT)
98
- # --------------------
99
- from mv_adapter.inference_ig2mv_sdxl import (
100
- prepare_pipeline,
101
- preprocess_image,
102
- remove_bg,
103
- )
104
 
105
- from mv_adapter.mvadapter.utils import (
106
- get_orthogonal_camera,
107
- tensor_to_image,
108
- make_image_grid,
109
- )
110
-
111
- from mv_adapter.mvadapter.utils.render import (
112
- NVDiffRastContextWrapper,
113
- load_mesh,
114
- render,
115
- )
116
-
117
- # --------------------
118
- # MV-Adapter pipeline
119
- # --------------------
120
  mv_adapter_pipe = prepare_pipeline(
121
  base_model="stabilityai/stable-diffusion-xl-base-1.0",
122
  vae_model="madebyollin/sdxl-vae-fp16-fix",
@@ -128,85 +92,77 @@ mv_adapter_pipe = prepare_pipeline(
128
  device=DEVICE,
129
  dtype=torch.float16,
130
  )
131
-
132
- # --------------------
133
- # BiRefNet background remover
134
- # --------------------
135
  birefnet = AutoModelForImageSegmentation.from_pretrained(
136
- "ZhengPeng7/BiRefNet",
137
- trust_remote_code=True,
138
- ).to(DEVICE)
139
-
140
  transform_image = transforms.Compose(
141
  [
142
  transforms.Resize((1024, 1024)),
143
  transforms.ToTensor(),
144
- transforms.Normalize(
145
- [0.485, 0.456, 0.406],
146
- [0.229, 0.224, 0.225],
147
- ),
148
  ]
149
  )
150
-
151
  remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, DEVICE)
152
 
153
- # --------------------
154
- # Texture models
155
- # --------------------
156
  if not os.path.exists("checkpoints/RealESRGAN_x2plus.pth"):
157
- hf_hub_download(
158
- "dtarnow/UPscaler",
159
- filename="RealESRGAN_x2plus.pth",
160
- local_dir="checkpoints",
161
- )
162
-
163
  if not os.path.exists("checkpoints/big-lama.pt"):
164
- subprocess.run(
165
- "wget -P checkpoints/ "
166
- "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
167
- shell=True,
168
- check=True,
169
- )
 
 
 
 
170
 
171
- # --------------------
172
- # Helpers
173
- # --------------------
174
  def get_random_hex():
175
- return os.urandom(8).hex()
 
 
 
 
 
 
 
176
 
177
- # --------------------
178
- # Main pipeline
179
- # --------------------
180
  @spaces.GPU(duration=180)
181
- def run_full(image: str):
182
  seed = 0
183
-
184
- image_seg = prepare_image(
185
- image,
186
- bg_color=np.array([1.0, 1.0, 1.0]),
187
- rmbg_net=rmbg_net,
188
- )
189
 
190
  outputs = triposg_pipe(
191
  image=image_seg,
192
- generator=torch.Generator(device=DEVICE).manual_seed(seed),
193
- num_inference_steps=50,
194
- guidance_scale=7.5,
195
  ).samples[0]
196
-
197
- mesh = trimesh.Trimesh(
198
- outputs[0].astype(np.float32),
199
- np.ascontiguousarray(outputs[1]),
200
- )
201
-
 
 
202
  save_dir = os.path.join(TMP_DIR, "examples")
203
  os.makedirs(save_dir, exist_ok=True)
204
- mesh_path = os.path.join(save_dir, f"mesh_{get_random_hex()}.glb")
205
  mesh.export(mesh_path)
 
206
 
207
  torch.cuda.empty_cache()
208
 
209
- # Cameras
 
210
  cameras = get_orthogonal_camera(
211
  elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
212
  distance=[1.8] * NUM_VIEWS,
@@ -217,88 +173,296 @@ def run_full(image: str):
217
  azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
218
  device=DEVICE,
219
  )
220
-
221
  ctx = NVDiffRastContextWrapper(device=DEVICE, context_type="cuda")
222
- mesh_gpu = load_mesh(mesh_path, rescale=True, device=DEVICE)
223
 
 
224
  render_out = render(
225
  ctx,
226
- mesh_gpu,
227
  cameras,
228
- height=768,
229
- width=768,
230
  render_attr=False,
231
  normal_background=0.0,
232
  )
 
 
 
 
 
 
 
 
 
 
 
233
 
234
- control_images = torch.cat(
235
- [
236
- (render_out.pos + 0.5).clamp(0, 1),
237
- (render_out.normal / 2 + 0.5).clamp(0, 1),
238
- ],
239
- dim=-1,
240
- ).permute(0, 3, 1, 2)
241
 
242
- image_ref = preprocess_image(
243
- remove_bg_fn(Image.open(image)),
244
- 768,
245
- 768,
246
- )
247
 
248
  images = mv_adapter_pipe(
249
  "high quality",
250
- height=768,
251
- width=768,
252
  num_inference_steps=15,
253
  guidance_scale=3.0,
254
  num_images_per_prompt=NUM_VIEWS,
255
  control_image=control_images,
256
- reference_image=image_ref,
257
- negative_prompt="watermark, ugly, blurry",
 
 
 
 
258
  ).images
259
 
260
- mv_image_path = os.path.join(save_dir, f"mv_{get_random_hex()}.png")
 
 
261
  make_image_grid(images, rows=1).save(mv_image_path)
262
 
263
  from texture import TexturePipeline, ModProcessConfig
 
 
 
 
 
264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  texture_pipe = TexturePipeline(
266
  upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
267
  inpaint_ckpt_path="checkpoints/big-lama.pt",
268
  device=DEVICE,
269
  )
270
 
271
- textured_path = texture_pipe(
272
  mesh_path=mesh_path,
273
  save_dir=save_dir,
274
- save_name=f"tex_{get_random_hex()}.glb",
275
  uv_unwarp=True,
276
  uv_size=4096,
277
  rgb_path=mv_image_path,
278
- rgb_process_config=ModProcessConfig(
279
- view_upscale=True,
280
- inpaint_mode="view",
281
- ),
282
  camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
283
  )
284
 
285
- return image_seg, mesh_path, textured_path
 
286
 
287
- # --------------------
288
- # UI
289
- # --------------------
290
  with gr.Blocks(title="TripoSG") as demo:
291
  gr.Markdown(HEADER)
292
 
293
- image_input = gr.Image(type="filepath")
294
- out_seg = gr.Image()
295
- out_mesh = gr.Model3D()
296
- out_tex = gr.Model3D()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
 
298
- gr.Button("Generate").click(
299
- run_full,
300
- inputs=image_input,
301
- outputs=[out_seg, out_mesh, out_tex],
302
  )
303
 
304
- demo.launch()
 
 
 
 
1
  import spaces
2
  import os
 
3
  import gradio as gr
4
  import numpy as np
5
  import torch
6
  from PIL import Image
7
  import trimesh
8
  import random
 
 
 
9
  from transformers import AutoModelForImageSegmentation
10
  from torchvision import transforms
11
  from huggingface_hub import hf_hub_download, snapshot_download
12
+ import subprocess
13
+ import shutil
14
 
15
+ # install others
 
 
16
  subprocess.run("pip install spandrel==0.4.1 --no-deps", shell=True, check=True)
17
 
 
 
 
18
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
  DTYPE = torch.float16
 
20
 
21
+ print("DEVICE: ", DEVICE)
22
+
 
23
  DEFAULT_FACE_NUMBER = 100000
24
  MAX_SEED = np.iinfo(np.int32).max
 
 
25
  TRIPOSG_REPO_URL = "https://github.com/VAST-AI-Research/TripoSG.git"
26
+ MV_ADAPTER_REPO_URL = "https://github.com/huanngzh/MV-Adapter.git"
27
 
28
  RMBG_PRETRAINED_MODEL = "checkpoints/RMBG-1.4"
29
  TRIPOSG_PRETRAINED_MODEL = "checkpoints/TripoSG"
30
 
31
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
 
32
  os.makedirs(TMP_DIR, exist_ok=True)
33
 
34
+ TRIPOSG_CODE_DIR = "./triposg"
 
 
 
35
  if not os.path.exists(TRIPOSG_CODE_DIR):
36
  os.system(f"git clone {TRIPOSG_REPO_URL} {TRIPOSG_CODE_DIR}")
37
 
38
+ MV_ADAPTER_CODE_DIR = "./mv_adapter"
39
  if not os.path.exists(MV_ADAPTER_CODE_DIR):
40
+ os.system(f"git clone {MV_ADAPTER_REPO_URL} {MV_ADAPTER_CODE_DIR} && cd {MV_ADAPTER_CODE_DIR} && git checkout 7d37a97e9bc223cdb8fd26a76bd8dd46504c7c3d")
 
 
 
 
41
 
42
+ import sys
 
 
43
  sys.path.append(TRIPOSG_CODE_DIR)
44
+ sys.path.append(os.path.join(TRIPOSG_CODE_DIR, "scripts"))
45
  sys.path.append(MV_ADAPTER_CODE_DIR)
46
+ sys.path.append(os.path.join(MV_ADAPTER_CODE_DIR, "scripts"))
47
 
 
 
 
48
  HEADER = """
 
49
 
50
+ # 🔮 Image to 3D with [TripoSG](https://github.com/VAST-AI-Research/TripoSG)
51
+
52
+ ## State-of-the-art Open Source 3D Generation Using Large-Scale Rectified Flow Transformers
53
+
54
+ <p style="font-size: 1.1em;">By <a href="https://www.tripo3d.ai/" style="color: #1E90FF; text-decoration: none; font-weight: bold;">Tripo</a></p>
55
+
56
+ ## 📋 Quick Start Guide:
57
+ 1. **Upload an image** (single object works best)
58
+ 2. Click **Generate Shape** to create the 3D mesh
59
+ 3. Click **Apply Texture** to add textures
60
+ 4. Use **Download GLB** to save your 3D model
61
+ 5. Adjust parameters under **Generation Settings** for fine-tuning
62
+
63
+ Best results come from clean, well-lit images with clear subject isolation. Try it now!
64
+
65
+ <p style="font-size: 0.9em; margin-top: 10px;">Texture generation powered by <a href="https://github.com/huanngzh/MV-Adapter" style="color: #1E90FF; text-decoration: none;">MV-Adapter</a> - a versatile multi-view adapter for consistent texture generation. Try the <a href="https://huggingface.co/spaces/VAST-AI/MV-Adapter-I2MV-SDXL" style="color: #1E90FF; text-decoration: none;">MV-Adapter demo</a> for multi-view image generation.</p>
66
 
 
 
67
  """
68
 
69
+ # # triposg
 
 
70
  from image_process import prepare_image
71
  from briarmbg import BriaRMBG
 
72
  snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
73
  rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE)
74
  rmbg_net.eval()
 
75
  from triposg.pipelines.pipeline_triposg import TripoSGPipeline
 
76
  snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
77
+ triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL).to(DEVICE, DTYPE)
 
 
 
 
 
 
 
 
 
 
 
78
 
79
+ # mv adapter
80
+ NUM_VIEWS = 6
81
+ from inference_ig2mv_sdxl import prepare_pipeline, preprocess_image, remove_bg
82
+ from mvadapter.utils import get_orthogonal_camera, tensor_to_image, make_image_grid
83
+ from mvadapter.utils.render import NVDiffRastContextWrapper, load_mesh, render
 
 
 
 
 
 
 
 
 
 
84
  mv_adapter_pipe = prepare_pipeline(
85
  base_model="stabilityai/stable-diffusion-xl-base-1.0",
86
  vae_model="madebyollin/sdxl-vae-fp16-fix",
 
92
  device=DEVICE,
93
  dtype=torch.float16,
94
  )
 
 
 
 
95
  birefnet = AutoModelForImageSegmentation.from_pretrained(
96
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
97
+ )
98
+ birefnet.to(DEVICE)
 
99
  transform_image = transforms.Compose(
100
  [
101
  transforms.Resize((1024, 1024)),
102
  transforms.ToTensor(),
103
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
 
 
 
104
  ]
105
  )
 
106
  remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, DEVICE)
107
 
 
 
 
108
  if not os.path.exists("checkpoints/RealESRGAN_x2plus.pth"):
109
+ hf_hub_download("dtarnow/UPscaler", filename="RealESRGAN_x2plus.pth", local_dir="checkpoints")
 
 
 
 
 
110
  if not os.path.exists("checkpoints/big-lama.pt"):
111
+ subprocess.run("wget -P checkpoints/ https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", shell=True, check=True)
112
+
113
def start_session(req: gr.Request):
    """Create this client's private temp directory when a Gradio session opens."""
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
    print("start session, mkdir", session_dir)
117
+
118
def end_session(req: gr.Request):
    """Delete this client's temp directory when the Gradio session closes.

    ignore_errors=True: the directory may not exist (e.g. start_session never
    ran or cleanup already happened), and an unload hook must not raise.
    """
    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(save_dir, ignore_errors=True)
121
 
 
 
 
122
def get_random_hex():
    """Return 16 hex characters derived from 8 cryptographically random bytes."""
    return os.urandom(8).hex()
126
+
127
def get_random_seed(randomize_seed, seed):
    """Return a fresh random seed when *randomize_seed* is set, else *seed* unchanged."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
131
 
 
 
 
132
@spaces.GPU(duration=180)
def run_full(image: str, req: gr.Request):
    """Full pipeline for cached examples: segment -> TripoSG mesh -> MV-Adapter
    views -> textured GLB.

    Outputs go under TMP_DIR/examples (shared, not per-session).
    Returns (segmented image, raw mesh path, textured mesh path).
    """
    # Fixed generation settings for the example pipeline.
    seed = 0
    num_inference_steps = 50
    guidance_scale = 7.5
    simplify = True
    target_face_num = DEFAULT_FACE_NUMBER

    # Background removal + composite onto white.
    image_seg = prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)

    # Shape generation with TripoSG.
    shape_out = triposg_pipe(
        image=image_seg,
        generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).samples[0]
    print("mesh extraction done")
    # shape_out[0]/shape_out[1] are passed as vertices/faces — TODO confirm vs TripoSG docs.
    mesh = trimesh.Trimesh(shape_out[0].astype(np.float32), np.ascontiguousarray(shape_out[1]))

    if simplify:
        print("start simplify")
        from utils import simplify_mesh  # project-local module
        mesh = simplify_mesh(mesh, target_face_num)

    save_dir = os.path.join(TMP_DIR, "examples")
    os.makedirs(save_dir, exist_ok=True)
    mesh_path = os.path.join(save_dir, f"triposg_{get_random_hex()}.glb")
    mesh.export(mesh_path)
    print("save to ", mesh_path)

    torch.cuda.empty_cache()

    height, width = 768, 768
    # Six orthogonal views: four side views plus near-top and near-bottom.
    cameras = get_orthogonal_camera(
        elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
        distance=[1.8] * NUM_VIEWS,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
        device=DEVICE,
    )
    ctx = NVDiffRastContextWrapper(device=DEVICE, context_type="cuda")

    # Reload the exported GLB (rescaled) and render geometry buffers per view.
    render_mesh = load_mesh(mesh_path, rescale=True, device=DEVICE)
    render_out = render(
        ctx,
        render_mesh,
        cameras,
        height=height,
        width=width,
        render_attr=False,
        normal_background=0.0,
    )
    # Position and normal maps, normalized into [0, 1], concatenated channel-wise
    # and moved to NCHW for the adapter pipeline.
    pos_maps = (render_out.pos + 0.5).clamp(0, 1)
    normal_maps = (render_out.normal / 2 + 0.5).clamp(0, 1)
    control_images = torch.cat([pos_maps, normal_maps], dim=-1).permute(0, 3, 1, 2).to(DEVICE)

    # Reference image: original input, background removed, resized.
    ref_image = Image.open(image)
    ref_image = remove_bg_fn(ref_image)
    ref_image = preprocess_image(ref_image, height, width)

    pipe_kwargs = {}
    if seed != -1 and isinstance(seed, int):
        pipe_kwargs["generator"] = torch.Generator(device=DEVICE).manual_seed(seed)

    images = mv_adapter_pipe(
        "high quality",
        height=height,
        width=width,
        num_inference_steps=15,
        guidance_scale=3.0,
        num_images_per_prompt=NUM_VIEWS,
        control_image=control_images,
        control_conditioning_scale=1.0,
        reference_image=ref_image,
        reference_conditioning_scale=1.0,
        negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
        cross_attention_kwargs={"scale": 1.0},
        **pipe_kwargs,
    ).images

    torch.cuda.empty_cache()

    mv_image_path = os.path.join(save_dir, f"mv_adapter_{get_random_hex()}.png")
    make_image_grid(images, rows=1).save(mv_image_path)

    # Texture baking; imported lazily (heavy module).
    from texture import TexturePipeline, ModProcessConfig
    texture_pipe = TexturePipeline(
        upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
        inpaint_ckpt_path="checkpoints/big-lama.pt",
        device=DEVICE,
    )

    textured_glb_path = texture_pipe(
        mesh_path=mesh_path,
        save_dir=save_dir,
        save_name=f"texture_mesh_{get_random_hex()}.glb",
        uv_unwarp=True,
        uv_size=4096,
        rgb_path=mv_image_path,
        rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
        camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
    )

    return image_seg, mesh_path, textured_glb_path
248
+
249
+
250
@spaces.GPU()
@torch.no_grad()
def run_segmentation(image: str):
    """Remove the background from *image* and composite it onto white."""
    return prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
255
+
256
@spaces.GPU(duration=90)
@torch.no_grad()
def image_to_3d(
    image: Image.Image,
    seed: int,
    num_inference_steps: int,
    guidance_scale: float,
    simplify: bool,
    target_face_num: int,
    req: gr.Request
):
    """Generate a GLB mesh from a pre-segmented image with TripoSG.

    The mesh is written into this session's temp directory; returns its path.
    """
    sample = triposg_pipe(
        image=image,
        generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).samples[0]
    print("mesh extraction done")
    # sample[0]/sample[1] are passed as vertices/faces — TODO confirm vs TripoSG docs.
    mesh = trimesh.Trimesh(sample[0].astype(np.float32), np.ascontiguousarray(sample[1]))

    if simplify:
        print("start simplify")
        from utils import simplify_mesh  # project-local module
        mesh = simplify_mesh(mesh, target_face_num)

    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    mesh_path = os.path.join(session_dir, f"triposg_{get_random_hex()}.glb")
    mesh.export(mesh_path)
    print("save to ", mesh_path)

    torch.cuda.empty_cache()

    return mesh_path
289
+
290
@spaces.GPU(duration=120)
@torch.no_grad()
def run_texture(image: Image, mesh_path: str, seed: int, req: gr.Request):
    """Texture an existing GLB with MV-Adapter.

    Renders six orthogonal position/normal maps of the mesh, generates matching
    RGB views conditioned on the reference image, then bakes them onto the mesh.
    Returns the path of the textured GLB in this session's temp directory.
    """
    height, width = 768, 768
    # Six orthogonal views: four side views plus near-top and near-bottom.
    cameras = get_orthogonal_camera(
        elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
        distance=[1.8] * NUM_VIEWS,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
        device=DEVICE,
    )
    ctx = NVDiffRastContextWrapper(device=DEVICE, context_type="cuda")

    scene_mesh = load_mesh(mesh_path, rescale=True, device=DEVICE)
    render_out = render(
        ctx,
        scene_mesh,
        cameras,
        height=height,
        width=width,
        render_attr=False,
        normal_background=0.0,
    )
    # Position and normal maps, normalized into [0, 1], concatenated channel-wise
    # and moved to NCHW for the adapter pipeline.
    pos_maps = (render_out.pos + 0.5).clamp(0, 1)
    normal_maps = (render_out.normal / 2 + 0.5).clamp(0, 1)
    control_images = torch.cat([pos_maps, normal_maps], dim=-1).permute(0, 3, 1, 2).to(DEVICE)

    # Reference image: original input, background removed, resized.
    ref_image = Image.open(image)
    ref_image = remove_bg_fn(ref_image)
    ref_image = preprocess_image(ref_image, height, width)

    pipe_kwargs = {}
    if seed != -1 and isinstance(seed, int):
        pipe_kwargs["generator"] = torch.Generator(device=DEVICE).manual_seed(seed)

    images = mv_adapter_pipe(
        "high quality",
        height=height,
        width=width,
        num_inference_steps=15,
        guidance_scale=3.0,
        num_images_per_prompt=NUM_VIEWS,
        control_image=control_images,
        control_conditioning_scale=1.0,
        reference_image=ref_image,
        reference_conditioning_scale=1.0,
        negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
        cross_attention_kwargs={"scale": 1.0},
        **pipe_kwargs,
    ).images

    torch.cuda.empty_cache()

    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    mv_image_path = os.path.join(session_dir, f"mv_adapter_{get_random_hex()}.png")
    make_image_grid(images, rows=1).save(mv_image_path)

    # Texture baking; imported lazily (heavy module).
    from texture import TexturePipeline, ModProcessConfig
    texture_pipe = TexturePipeline(
        upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
        inpaint_ckpt_path="checkpoints/big-lama.pt",
        device=DEVICE,
    )

    textured_glb_path = texture_pipe(
        mesh_path=mesh_path,
        save_dir=session_dir,
        save_name=f"texture_mesh_{get_random_hex()}.glb",
        uv_unwarp=True,
        uv_size=4096,
        rgb_path=mv_image_path,
        rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
        camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
    )

    return textured_glb_path
378
+
379
 
 
 
 
380
with gr.Blocks(title="TripoSG") as demo:
    gr.Markdown(HEADER)

    with gr.Row():
        # Left column: inputs and settings.
        with gr.Column():
            with gr.Row():
                image_prompts = gr.Image(label="Input Image", type="filepath")
                seg_image = gr.Image(
                    label="Segmentation Result", type="pil", format="png", interactive=False
                )

            with gr.Accordion("Generation Settings", open=True):
                # NOTE(review): step=0 looks unintended for a slider — confirm.
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=0,
                    value=0
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=8,
                    maximum=50,
                    step=1,
                    value=50,
                )
                guidance_scale = gr.Slider(
                    label="CFG scale",
                    minimum=0.0,
                    maximum=20.0,
                    step=0.1,
                    value=7.0,
                )

            with gr.Row():
                reduce_face = gr.Checkbox(label="Simplify Mesh", value=True)
                target_face_num = gr.Slider(
                    maximum=1000000,
                    minimum=10000,
                    value=DEFAULT_FACE_NUMBER,
                    label="Target Face Number",
                )

            gen_button = gr.Button("Generate Shape", variant="primary")
            # Texturing is enabled only after a shape has been generated.
            gen_texture_button = gr.Button("Apply Texture", interactive=False)

        # Right column: 3D outputs.
        with gr.Column():
            model_output = gr.Model3D(label="Generated GLB", interactive=False)
            textured_model_output = gr.Model3D(label="Textured GLB", interactive=False)

    with gr.Row():
        examples = gr.Examples(
            examples=[
                f"{TRIPOSG_CODE_DIR}/assets/example_data/{example_file}"
                for example_file in os.listdir(f"{TRIPOSG_CODE_DIR}/assets/example_data")
            ],
            fn=run_full,
            inputs=[image_prompts],
            outputs=[seg_image, model_output, textured_model_output],
            cache_examples=True,
        )

    # Chain: segmentation -> seed resolution -> shape generation -> enable texturing.
    gen_button.click(
        run_segmentation,
        inputs=[image_prompts],
        outputs=[seg_image]
    ).then(
        get_random_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[
            seg_image,
            seed,
            num_inference_steps,
            guidance_scale,
            reduce_face,
            target_face_num
        ],
        outputs=[model_output]
    ).then(lambda: gr.Button(interactive=True), outputs=[gen_texture_button])

    gen_texture_button.click(
        run_texture,
        inputs=[image_prompts, model_output, seed],
        outputs=[textured_model_output]
    )

    # Per-session temp-dir lifecycle hooks.
    demo.load(start_session)
    demo.unload(end_session)

demo.launch()