eigopop committed on
Commit
3082037
·
verified ·
1 Parent(s): 5281259

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +383 -101
app.py CHANGED
@@ -1,186 +1,468 @@
1
  import spaces
2
  import os
3
- import sys
4
  import gradio as gr
5
  import numpy as np
6
  import torch
7
  from PIL import Image
8
  import trimesh
9
  import random
10
- from huggingface_hub import snapshot_download
11
- import shutil
 
12
  import subprocess
 
13
 
14
- # ---------------------------------------------------------------------
15
- # Basic setup
16
- # ---------------------------------------------------------------------
17
 
18
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
  DTYPE = torch.float16
20
 
21
- print("DEVICE:", DEVICE)
22
 
23
  DEFAULT_FACE_NUMBER = 100000
24
  MAX_SEED = np.iinfo(np.int32).max
25
-
26
  TRIPOSG_REPO_URL = "https://github.com/VAST-AI-Research/TripoSG.git"
27
- TRIPOSG_PRETRAINED_MODEL = "checkpoints/TripoSG"
 
28
  RMBG_PRETRAINED_MODEL = "checkpoints/RMBG-1.4"
 
29
 
30
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
31
- TMP_DIR = os.path.join(BASE_DIR, "tmp")
32
  os.makedirs(TMP_DIR, exist_ok=True)
33
 
34
- # ---------------------------------------------------------------------
35
- # Clone TripoSG code (runtime-safe)
36
- # ---------------------------------------------------------------------
37
-
38
- TRIPOSG_CODE_DIR = os.path.join(BASE_DIR, "triposg")
39
-
40
  if not os.path.exists(TRIPOSG_CODE_DIR):
41
  os.system(f"git clone {TRIPOSG_REPO_URL} {TRIPOSG_CODE_DIR}")
42
 
43
- # ---------------------------------------------------------------------
44
- # 🔑 CRITICAL FIX: make TripoSG imports visible BEFORE importing
45
- # ---------------------------------------------------------------------
46
 
47
- sys.path.insert(0, TRIPOSG_CODE_DIR)
48
- sys.path.insert(0, os.path.join(TRIPOSG_CODE_DIR, "scripts"))
 
 
 
49
 
50
- # ---------------------------------------------------------------------
51
- # Now imports work
52
- # ---------------------------------------------------------------------
53
 
54
- from image_process import prepare_image
55
- from briarmbg import BriaRMBG
56
- from triposg.pipelines.pipeline_triposg import TripoSGPipeline
57
- from utils import simplify_mesh
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- # ---------------------------------------------------------------------
60
- # Load models
61
- # ---------------------------------------------------------------------
62
 
 
 
 
63
  snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
64
  rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE)
65
  rmbg_net.eval()
66
-
67
  snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
68
- triposg_pipe = TripoSGPipeline.from_pretrained(
69
- TRIPOSG_PRETRAINED_MODEL
70
- ).to(DEVICE, DTYPE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- # ---------------------------------------------------------------------
73
- # Helpers
74
- # ---------------------------------------------------------------------
 
 
 
 
 
75
 
76
  def get_random_hex():
77
- return os.urandom(8).hex()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- def get_random_seed(randomize, seed):
80
- return random.randint(0, MAX_SEED) if randomize else seed
 
 
 
 
 
 
81
 
82
- def start_session(req: gr.Request):
83
- path = os.path.join(TMP_DIR, str(req.session_hash))
84
- os.makedirs(path, exist_ok=True)
 
 
 
 
 
 
 
85
 
86
- def end_session(req: gr.Request):
87
- path = os.path.join(TMP_DIR, str(req.session_hash))
88
- if os.path.exists(path):
89
- shutil.rmtree(path)
90
 
91
- # ---------------------------------------------------------------------
92
- # GPU functions
93
- # ---------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  @spaces.GPU()
96
  @torch.no_grad()
97
- def run_segmentation(image_path: str):
98
- return prepare_image(
99
- image_path,
100
- bg_color=np.array([1.0, 1.0, 1.0]),
101
- rmbg_net=rmbg_net,
102
- )
103
 
104
- @spaces.GPU(duration=120)
105
  @torch.no_grad()
106
  def image_to_3d(
107
  image: Image.Image,
108
  seed: int,
109
- steps: int,
110
- guidance: float,
111
  simplify: bool,
112
- target_faces: int,
113
- req: gr.Request,
114
  ):
115
  outputs = triposg_pipe(
116
  image=image,
117
- generator=torch.Generator(device=DEVICE).manual_seed(seed),
118
- num_inference_steps=steps,
119
- guidance_scale=guidance,
120
  ).samples[0]
121
-
122
- mesh = trimesh.Trimesh(
123
- outputs[0].astype(np.float32),
124
- np.ascontiguousarray(outputs[1]),
125
- process=False,
126
- )
127
 
128
  if simplify:
129
- mesh = simplify_mesh(mesh, target_faces)
130
-
 
 
131
  save_dir = os.path.join(TMP_DIR, str(req.session_hash))
132
  mesh_path = os.path.join(save_dir, f"triposg_{get_random_hex()}.glb")
133
  mesh.export(mesh_path)
 
134
 
135
  torch.cuda.empty_cache()
 
136
  return mesh_path
137
 
138
- # ---------------------------------------------------------------------
139
- # UI
140
- # ---------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
- HEADER = """
143
- # 🔮 Image → 3D (TripoSG)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
- Mesh-only demo (no texture, no MV-Adapter).
146
- """
147
 
148
  with gr.Blocks(title="TripoSG") as demo:
149
  gr.Markdown(HEADER)
150
 
151
  with gr.Row():
152
  with gr.Column():
153
- input_image = gr.Image(label="Input Image", type="filepath")
154
- seg_image = gr.Image(label="Segmentation", type="pil")
155
-
156
- seed = gr.Slider(0, MAX_SEED, value=0, label="Seed")
157
- randomize_seed = gr.Checkbox(value=True, label="Randomize Seed")
158
- steps = gr.Slider(8, 50, value=50, step=1, label="Inference Steps")
159
- guidance = gr.Slider(0.0, 20.0, value=7.5, step=0.1, label="CFG Scale")
160
-
161
- simplify = gr.Checkbox(value=True, label="Simplify Mesh")
162
- target_faces = gr.Slider(10_000, 1_000_000, value=DEFAULT_FACE_NUMBER)
163
-
164
- gen_btn = gr.Button("Generate 3D", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
  with gr.Column():
167
- model_out = gr.Model3D(label="Generated GLB")
 
168
 
169
- gen_btn.click(
 
 
 
 
 
 
 
 
 
 
 
 
170
  run_segmentation,
171
- inputs=input_image,
172
- outputs=seg_image,
173
  ).then(
174
  get_random_seed,
175
  inputs=[randomize_seed, seed],
176
- outputs=seed,
177
  ).then(
178
  image_to_3d,
179
- inputs=[seg_image, seed, steps, guidance, simplify, target_faces],
180
- outputs=model_out,
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  )
182
 
183
  demo.load(start_session)
184
  demo.unload(end_session)
185
 
186
- demo.launch()
 
1
  import spaces
2
  import os
 
3
  import gradio as gr
4
  import numpy as np
5
  import torch
6
  from PIL import Image
7
  import trimesh
8
  import random
9
+ from transformers import AutoModelForImageSegmentation
10
+ from torchvision import transforms
11
+ from huggingface_hub import hf_hub_download, snapshot_download
12
  import subprocess
13
+ import shutil
14
 
15
# Startup-time install of spandrel, pinned to 0.4.1 and installed without its
# dependency tree (--no-deps). The command is passed as an argument list so no
# shell is involved; check=True aborts startup if pip fails.
# NOTE(review): the reason for --no-deps is not visible in this file.
subprocess.run(
    ["pip", "install", "spandrel==0.4.1", "--no-deps"],
    check=True,
)

# Inference device and the dtype the TripoSG pipeline is moved to.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16

print("DEVICE: ", DEVICE)

# Default target face count for mesh simplification, and the upper bound used
# for the seed slider / random seed draw (largest signed 32-bit integer).
DEFAULT_FACE_NUMBER = 100000
MAX_SEED = np.iinfo(np.int32).max

# Source repositories cloned at startup.
TRIPOSG_REPO_URL = "https://github.com/VAST-AI-Research/TripoSG.git"
MV_ADAPTER_REPO_URL = "https://github.com/huanngzh/MV-Adapter.git"

# Local directories the pretrained weights are downloaded into.
RMBG_PRETRAINED_MODEL = "checkpoints/RMBG-1.4"
TRIPOSG_PRETRAINED_MODEL = "checkpoints/TripoSG"

# Scratch directory next to this file; per-session and example outputs are
# written below it.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_DIR, exist_ok=True)
33
 
34
import sys


def _ensure_repo(repo_url, target_dir, revision=None):
    """Clone *repo_url* into *target_dir* unless that path already exists.

    Args:
        repo_url: Git URL to clone from.
        target_dir: Destination directory; if it exists nothing is done.
        revision: Optional commit to check out after cloning.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` or ``git checkout``
            exits with a non-zero status.
    """
    if os.path.exists(target_dir):
        return
    try:
        subprocess.run(["git", "clone", repo_url, target_dir], check=True)
        if revision is not None:
            subprocess.run(
                ["git", "checkout", revision], cwd=target_dir, check=True
            )
    except subprocess.CalledProcessError:
        # Do not leave a partial or unpinned checkout behind: the existence
        # check above would treat it as a finished clone on the next start.
        shutil.rmtree(target_dir, ignore_errors=True)
        raise


TRIPOSG_CODE_DIR = "./triposg"
_ensure_repo(TRIPOSG_REPO_URL, TRIPOSG_CODE_DIR)

# MV-Adapter is pinned to a fixed commit.
MV_ADAPTER_CODE_DIR = "./mv_adapter"
_ensure_repo(
    MV_ADAPTER_REPO_URL,
    MV_ADAPTER_CODE_DIR,
    revision="7d37a97e9bc223cdb8fd26a76bd8dd46504c7c3d",
)

# Make both checkouts and their scripts/ folders importable. Entries are
# appended, so already-installed modules of the same name take precedence.
sys.path.append(TRIPOSG_CODE_DIR)
sys.path.append(os.path.join(TRIPOSG_CODE_DIR, "scripts"))
sys.path.append(MV_ADAPTER_CODE_DIR)
sys.path.append(os.path.join(MV_ADAPTER_CODE_DIR, "scripts"))
47
 
48
+ HEADER = """
 
 
49
 
50
+ # 🔮 Image to 3D with [TripoSG](https://github.com/VAST-AI-Research/TripoSG)
51
+
52
+ ## State-of-the-art Open Source 3D Generation Using Large-Scale Rectified Flow Transformers
53
+
54
+ <p style="font-size: 1.1em;">By <a href="https://www.tripo3d.ai/" style="color: #1E90FF; text-decoration: none; font-weight: bold;">Tripo</a></p>
55
+
56
+ ## 📋 Quick Start Guide:
57
+ 1. **Upload an image** (single object works best)
58
+ 2. Click **Generate Shape** to create the 3D mesh
59
+ 3. Click **Apply Texture** to add textures
60
+ 4. Use **Download GLB** to save your 3D model
61
+ 5. Adjust parameters under **Generation Settings** for fine-tuning
62
+
63
+ Best results come from clean, well-lit images with clear subject isolation. Try it now!
64
+
65
+ <p style="font-size: 0.9em; margin-top: 10px;">Texture generation powered by <a href="https://github.com/huanngzh/MV-Adapter" style="color: #1E90FF; text-decoration: none;">MV-Adapter</a> - a versatile multi-view adapter for consistent texture generation. Try the <a href="https://huggingface.co/spaces/VAST-AI/MV-Adapter-I2MV-SDXL" style="color: #1E90FF; text-decoration: none;">MV-Adapter demo</a> for multi-view image generation.</p>
66
 
67
+ """
 
 
68
 
69
# --- TripoSG -----------------------------------------------------------------
# These imports resolve against the cloned checkouts added to sys.path, so
# they must stay below the clone / sys.path setup.
from image_process import prepare_image
from briarmbg import BriaRMBG

# Background-removal network used by prepare_image().
snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE)
rmbg_net.eval()

from triposg.pipelines.pipeline_triposg import TripoSGPipeline

# Image-to-shape pipeline, moved to the inference device in DTYPE.
snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL).to(DEVICE, DTYPE)

# --- MV-Adapter --------------------------------------------------------------
# Number of views generated per mesh; the camera lists used for texturing
# contain exactly this many entries.
NUM_VIEWS = 6
from inference_ig2mv_sdxl import prepare_pipeline, preprocess_image, remove_bg
from mvadapter.utils import get_orthogonal_camera, tensor_to_image, make_image_grid
from mvadapter.utils.render import NVDiffRastContextWrapper, load_mesh, render

mv_adapter_pipe = prepare_pipeline(
    base_model="stabilityai/stable-diffusion-xl-base-1.0",
    vae_model="madebyollin/sdxl-vae-fp16-fix",
    unet_model=None,
    lora_model=None,
    adapter_path="huanngzh/mv-adapter",
    scheduler=None,
    num_views=NUM_VIEWS,
    device=DEVICE,
    dtype=torch.float16,
)

# Segmentation model used to strip the background from the reference image
# before texturing.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(DEVICE)

# Preprocessing handed to remove_bg(): resize to 1024x1024, convert to a
# tensor, normalise with the commonly used ImageNet mean/std values.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def remove_bg_fn(x):
    """Remove the background of *x* with BiRefNet on the inference device."""
    return remove_bg(x, birefnet, transform_image, DEVICE)


# Checkpoints required by the texture pipeline (upscaler and inpainter);
# each is fetched only when the file is not already present.
if not os.path.exists("checkpoints/RealESRGAN_x2plus.pth"):
    hf_hub_download("dtarnow/UPscaler", filename="RealESRGAN_x2plus.pth", local_dir="checkpoints")
if not os.path.exists("checkpoints/big-lama.pt"):
    # Argument list instead of a shell string: no shell parsing involved.
    subprocess.run(
        [
            "wget",
            "-P",
            "checkpoints/",
            "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
        ],
        check=True,
    )
112
 
113
def start_session(req: gr.Request):
    """Create the scratch directory for the session that just connected.

    The directory lives under TMP_DIR and is named after the request's
    session hash; an already existing directory is left as is.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
    print(f"start session, mkdir {session_dir}")
117
+
118
def end_session(req: gr.Request):
    """Delete the scratch directory of the session that just disconnected.

    A directory that is already gone (or was never created) is not an error;
    any other failure from ``shutil.rmtree`` still propagates.
    """
    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    try:
        shutil.rmtree(save_dir)
    except FileNotFoundError:
        # Nothing to clean up for this session.
        pass
121
 
122
def get_random_hex():
    """Return 8 bytes from os.urandom rendered as a 16-character hex string."""
    return os.urandom(8).hex()
126
+
127
def get_random_seed(randomize_seed, seed):
    """Return a fresh seed in [0, MAX_SEED] if requested, else *seed* unchanged."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
131
+
132
def _generate_shape(
    image,
    seed,
    num_inference_steps,
    guidance_scale,
    simplify,
    target_face_num,
    save_dir,
):
    """Run the TripoSG pipeline on *image* and export the mesh as a GLB file.

    Args:
        image: Segmented input image handed to the pipeline.
        seed: Seed for the pipeline's random generator.
        num_inference_steps: Forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline.
        simplify: Whether to run ``simplify_mesh`` on the result.
        target_face_num: Face budget passed to ``simplify_mesh``.
        save_dir: Directory the GLB is written to (created if missing).

    Returns:
        Path of the exported ``triposg_<hex>.glb`` file.
    """
    outputs = triposg_pipe(
        image=image,
        generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).samples[0]
    print("mesh extraction done")
    # outputs[0] / outputs[1] are passed positionally, i.e. as the vertices
    # and faces arguments of trimesh.Trimesh.
    mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))

    if simplify:
        print("start simplify")
        # Imported lazily, as in the original code; resolved via the
        # sys.path entries added at startup.
        from utils import simplify_mesh
        mesh = simplify_mesh(mesh, target_face_num)

    # Do not rely on the session-start hook having created the directory.
    os.makedirs(save_dir, exist_ok=True)
    mesh_path = os.path.join(save_dir, f"triposg_{get_random_hex()}.glb")
    mesh.export(mesh_path)
    print("save to ", mesh_path)

    torch.cuda.empty_cache()

    return mesh_path


def _apply_texture(image_path, mesh_path, seed, save_dir):
    """Generate multi-view images with MV-Adapter and bake them onto the mesh.

    Args:
        image_path: File path of the original input image.
        mesh_path: Path of the untextured GLB to texture.
        seed: Seed for the multi-view pipeline; a generator is only created
            when this is an ``int`` other than -1.
        save_dir: Directory for the multi-view grid image and the textured
            GLB (created if missing).

    Returns:
        Whatever ``TexturePipeline.__call__`` returns; the callers treat it as
        the path of the textured GLB.
    """
    height, width = 768, 768
    # One azimuth per view; the same values go to the cameras and to the
    # texture pipeline.
    view_azimuths = [x - 90 for x in [0, 90, 180, 270, 180, 180]]

    # Prepare cameras
    cameras = get_orthogonal_camera(
        elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
        distance=[1.8] * NUM_VIEWS,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=list(view_azimuths),
        device=DEVICE,
    )
    ctx = NVDiffRastContextWrapper(device=DEVICE, context_type="cuda")

    mesh = load_mesh(mesh_path, rescale=True, device=DEVICE)
    render_out = render(
        ctx,
        mesh,
        cameras,
        height=height,
        width=width,
        render_attr=False,
        normal_background=0.0,
    )
    # Position and normal renders are shifted/clamped into [0, 1], joined on
    # the last axis and permuted with (0, 3, 1, 2).
    # NOTE(review): presumably channels-last -> channels-first; confirm
    # against mvadapter's render output layout.
    control_images = (
        torch.cat(
            [
                (render_out.pos + 0.5).clamp(0, 1),
                (render_out.normal / 2 + 0.5).clamp(0, 1),
            ],
            dim=-1,
        )
        .permute(0, 3, 1, 2)
        .to(DEVICE)
    )

    reference = Image.open(image_path)
    reference = remove_bg_fn(reference)
    reference = preprocess_image(reference, height, width)

    pipe_kwargs = {}
    if seed != -1 and isinstance(seed, int):
        pipe_kwargs["generator"] = torch.Generator(device=DEVICE).manual_seed(seed)

    images = mv_adapter_pipe(
        "high quality",
        height=height,
        width=width,
        num_inference_steps=15,
        guidance_scale=3.0,
        num_images_per_prompt=NUM_VIEWS,
        control_image=control_images,
        control_conditioning_scale=1.0,
        reference_image=reference,
        reference_conditioning_scale=1.0,
        negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
        cross_attention_kwargs={"scale": 1.0},
        **pipe_kwargs,
    ).images

    torch.cuda.empty_cache()

    os.makedirs(save_dir, exist_ok=True)
    mv_image_path = os.path.join(save_dir, f"mv_adapter_{get_random_hex()}.png")
    make_image_grid(images, rows=1).save(mv_image_path)

    # Imported lazily, as in the original code.
    from texture import TexturePipeline, ModProcessConfig
    texture_pipe = TexturePipeline(
        upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
        inpaint_ckpt_path="checkpoints/big-lama.pt",
        device=DEVICE,
    )

    return texture_pipe(
        mesh_path=mesh_path,
        save_dir=save_dir,
        save_name=f"texture_mesh_{get_random_hex()}.glb",
        uv_unwarp=True,
        uv_size=4096,
        rgb_path=mv_image_path,
        rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
        camera_azimuth_deg=list(view_azimuths),
    )


@spaces.GPU(duration=180)
@torch.no_grad()
def run_full(image: str, req: gr.Request):
    """Segment, generate the shape and texture it in one call.

    Used as the callback for the cached examples. It runs with fixed settings
    (seed 0, 50 steps, guidance 7.5, simplification to DEFAULT_FACE_NUMBER)
    and writes to the shared ``examples`` folder under TMP_DIR rather than a
    per-session one.

    Args:
        image: File path of the input image.
        req: Gradio request; not used here.

    Returns:
        Tuple ``(segmented image, mesh GLB path, textured GLB path)``.
    """
    image_seg = prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)

    save_dir = os.path.join(TMP_DIR, "examples")
    mesh_path = _generate_shape(
        image_seg,
        seed=0,
        num_inference_steps=50,
        guidance_scale=7.5,
        simplify=True,
        target_face_num=DEFAULT_FACE_NUMBER,
        save_dir=save_dir,
    )
    textured_glb_path = _apply_texture(image, mesh_path, seed=0, save_dir=save_dir)

    return image_seg, mesh_path, textured_glb_path


@spaces.GPU()
@torch.no_grad()
def run_segmentation(image: str):
    """Return the input image at path *image* composited on a white background."""
    image = prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
    return image


@spaces.GPU(duration=90)
@torch.no_grad()
def image_to_3d(
    image: Image.Image,
    seed: int,
    num_inference_steps: int,
    guidance_scale: float,
    simplify: bool,
    target_face_num: int,
    req: gr.Request
):
    """Generate a mesh from the segmented image and return the GLB path.

    The file is written to this session's directory under TMP_DIR.
    """
    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    return _generate_shape(
        image,
        seed=seed,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        simplify=simplify,
        target_face_num=target_face_num,
        save_dir=save_dir,
    )


@spaces.GPU(duration=120)
@torch.no_grad()
def run_texture(image: str, mesh_path: str, seed: int, req: gr.Request):
    """Texture the mesh at *mesh_path* using the input image at path *image*.

    Outputs are written to this session's directory under TMP_DIR.

    Returns:
        Path of the textured GLB.
    """
    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    return _apply_texture(image, mesh_path, seed, save_dir)
378
 
 
 
379
 
380
# UI definition. NOTE(review): the nesting below was reconstructed from a
# listing without indentation; in particular, placing the "Simplify Mesh" row
# inside the settings accordion is an assumption to confirm.
with gr.Blocks(title="TripoSG") as demo:
    gr.Markdown(HEADER)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                # The input is passed around as a file path; the segmentation
                # preview is a read-only PIL image.
                image_prompts = gr.Image(label="Input Image", type="filepath")
                seg_image = gr.Image(
                    label="Segmentation Result", type="pil", format="png", interactive=False
                )

            with gr.Accordion("Generation Settings", open=True):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=0,
                    value=0
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=8,
                    maximum=50,
                    step=1,
                    value=50,
                )
                guidance_scale = gr.Slider(
                    label="CFG scale",
                    minimum=0.0,
                    maximum=20.0,
                    step=0.1,
                    value=7.0,
                )

                with gr.Row():
                    reduce_face = gr.Checkbox(label="Simplify Mesh", value=True)
                    target_face_num = gr.Slider(maximum=1000000, minimum=10000, value=DEFAULT_FACE_NUMBER, label="Target Face Number")

            gen_button = gr.Button("Generate Shape", variant="primary")
            # Enabled only after a shape has been generated (see the click
            # chain below).
            gen_texture_button = gr.Button("Apply Texture", interactive=False)

        with gr.Column():
            model_output = gr.Model3D(label="Generated GLB", interactive=False)
            textured_model_output = gr.Model3D(label="Textured GLB", interactive=False)

    # os.listdir() returns entries in arbitrary order; sort them so the
    # example list (and the cached results tied to it) is the same on every
    # start.
    example_dir = f"{TRIPOSG_CODE_DIR}/assets/example_data"
    with gr.Row():
        examples = gr.Examples(
            examples=[
                f"{example_dir}/{image}"
                for image in sorted(os.listdir(example_dir))
            ],
            fn=run_full,
            inputs=[image_prompts],
            outputs=[seg_image, model_output, textured_model_output],
            cache_examples=True,
        )

    # Generate Shape: segment -> pick the seed -> build the mesh -> unlock
    # the texture button.
    gen_button.click(
        run_segmentation,
        inputs=[image_prompts],
        outputs=[seg_image]
    ).then(
        get_random_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[
            seg_image,
            seed,
            num_inference_steps,
            guidance_scale,
            reduce_face,
            target_face_num
        ],
        outputs=[model_output]
    ).then(lambda: gr.Button(interactive=True), outputs=[gen_texture_button])

    # Apply Texture works from the original input image (file path), the
    # generated mesh and the current seed.
    gen_texture_button.click(
        run_texture,
        inputs=[image_prompts, model_output, seed],
        outputs=[textured_model_output]
    )

    # Per-session scratch directory lifecycle.
    demo.load(start_session)
    demo.unload(end_session)

demo.launch()