eigopop commited on
Commit
8a19b8e
·
verified ·
1 Parent(s): 59d84d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -73
app.py CHANGED
@@ -4,94 +4,94 @@ import sys
4
  import gradio as gr
5
  import numpy as np
6
  import torch
 
7
  import trimesh
8
  import random
 
9
  import shutil
 
10
 
11
- from huggingface_hub import snapshot_download
12
- from PIL import Image
 
13
 
14
- # --------------------
15
- # Device
16
- # --------------------
17
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
  DTYPE = torch.float16
 
19
  print("DEVICE:", DEVICE)
20
 
21
- # --------------------
22
- # Constants
23
- # --------------------
24
  DEFAULT_FACE_NUMBER = 100000
25
  MAX_SEED = np.iinfo(np.int32).max
26
 
27
  TRIPOSG_REPO_URL = "https://github.com/VAST-AI-Research/TripoSG.git"
28
-
29
- RMBG_PRETRAINED_MODEL = "checkpoints/RMBG-1.4"
30
  TRIPOSG_PRETRAINED_MODEL = "checkpoints/TripoSG"
 
31
 
32
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
33
  TMP_DIR = os.path.join(BASE_DIR, "tmp")
34
  os.makedirs(TMP_DIR, exist_ok=True)
35
 
36
- # --------------------
37
- # Clone TripoSG
38
- # --------------------
 
39
  TRIPOSG_CODE_DIR = os.path.join(BASE_DIR, "triposg")
 
40
  if not os.path.exists(TRIPOSG_CODE_DIR):
41
  os.system(f"git clone {TRIPOSG_REPO_URL} {TRIPOSG_CODE_DIR}")
42
 
43
- sys.path.append(TRIPOSG_CODE_DIR)
 
 
44
 
45
- # --------------------
46
- # UI Header
47
- # --------------------
48
- HEADER = """
49
- # 🔮 Image to 3D with TripoSG
50
 
51
- Upload an image → get a clean 3D mesh (GLB).
 
 
52
 
53
- **Texture generation intentionally disabled.**
54
- """
55
-
56
- # --------------------
57
- # TripoSG + RMBG
58
- # --------------------
59
  from image_process import prepare_image
60
  from briarmbg import BriaRMBG
 
 
 
 
 
 
61
 
62
  snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
63
  rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE)
64
  rmbg_net.eval()
65
 
66
- from triposg.pipelines.pipeline_triposg import TripoSGPipeline
67
-
68
  snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
69
  triposg_pipe = TripoSGPipeline.from_pretrained(
70
  TRIPOSG_PRETRAINED_MODEL
71
  ).to(DEVICE, DTYPE)
72
 
73
- # --------------------
74
  # Helpers
75
- # --------------------
 
76
  def get_random_hex():
77
  return os.urandom(8).hex()
78
 
79
- def get_random_seed(randomize_seed, seed):
80
- if randomize_seed:
81
- seed = random.randint(0, MAX_SEED)
82
- return seed
83
 
84
  def start_session(req: gr.Request):
85
- save_dir = os.path.join(TMP_DIR, str(req.session_hash))
86
- os.makedirs(save_dir, exist_ok=True)
87
 
88
  def end_session(req: gr.Request):
89
- save_dir = os.path.join(TMP_DIR, str(req.session_hash))
90
- shutil.rmtree(save_dir, ignore_errors=True)
 
 
 
 
 
91
 
92
- # --------------------
93
- # GPU Functions
94
- # --------------------
95
  @spaces.GPU()
96
  @torch.no_grad()
97
  def run_segmentation(image_path: str):
@@ -106,27 +106,27 @@ def run_segmentation(image_path: str):
106
  def image_to_3d(
107
  image: Image.Image,
108
  seed: int,
109
- num_inference_steps: int,
110
- guidance_scale: float,
111
  simplify: bool,
112
- target_face_num: int,
113
  req: gr.Request,
114
  ):
115
  outputs = triposg_pipe(
116
  image=image,
117
  generator=torch.Generator(device=DEVICE).manual_seed(seed),
118
- num_inference_steps=num_inference_steps,
119
- guidance_scale=guidance_scale,
120
  ).samples[0]
121
 
122
  mesh = trimesh.Trimesh(
123
  outputs[0].astype(np.float32),
124
  np.ascontiguousarray(outputs[1]),
 
125
  )
126
 
127
  if simplify:
128
- from utils import simplify_mesh
129
- mesh = simplify_mesh(mesh, target_face_num)
130
 
131
  save_dir = os.path.join(TMP_DIR, str(req.session_hash))
132
  mesh_path = os.path.join(save_dir, f"triposg_{get_random_hex()}.glb")
@@ -135,35 +135,40 @@ def image_to_3d(
135
  torch.cuda.empty_cache()
136
  return mesh_path
137
 
138
- # --------------------
139
  # UI
140
- # --------------------
 
 
 
 
 
 
 
141
  with gr.Blocks(title="TripoSG") as demo:
142
  gr.Markdown(HEADER)
143
 
144
  with gr.Row():
145
  with gr.Column():
146
- image_input = gr.Image(label="Input Image", type="filepath")
147
  seg_image = gr.Image(label="Segmentation", type="pil")
148
 
149
- with gr.Accordion("Generation Settings", open=True):
150
- seed = gr.Slider(0, MAX_SEED, value=0, label="Seed")
151
- randomize_seed = gr.Checkbox(value=True, label="Randomize Seed")
152
- num_inference_steps = gr.Slider(8, 50, value=50, step=1, label="Inference Steps")
153
- guidance_scale = gr.Slider(0, 20, value=7.5, step=0.1, label="CFG Scale")
154
- simplify = gr.Checkbox(value=True, label="Simplify Mesh")
155
- target_face_num = gr.Slider(
156
- 10_000, 1_000_000, value=DEFAULT_FACE_NUMBER, label="Target Face Count"
157
- )
158
 
159
- gen_button = gr.Button("Generate 3D", variant="primary")
160
 
161
  with gr.Column():
162
- model_output = gr.Model3D(label="Generated GLB")
163
 
164
- gen_button.click(
165
  run_segmentation,
166
- inputs=image_input,
167
  outputs=seg_image,
168
  ).then(
169
  get_random_seed,
@@ -171,15 +176,8 @@ with gr.Blocks(title="TripoSG") as demo:
171
  outputs=seed,
172
  ).then(
173
  image_to_3d,
174
- inputs=[
175
- seg_image,
176
- seed,
177
- num_inference_steps,
178
- guidance_scale,
179
- simplify,
180
- target_face_num,
181
- ],
182
- outputs=model_output,
183
  )
184
 
185
  demo.load(start_session)
 
4
  import gradio as gr
5
  import numpy as np
6
  import torch
7
+ from PIL import Image
8
  import trimesh
9
  import random
10
+ from huggingface_hub import snapshot_download
11
  import shutil
12
+ import subprocess
13
 
14
+ # ---------------------------------------------------------------------
15
+ # Basic setup
16
+ # ---------------------------------------------------------------------
17
 
 
 
 
18
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
  DTYPE = torch.float16
20
+
21
  print("DEVICE:", DEVICE)
22
 
 
 
 
23
  DEFAULT_FACE_NUMBER = 100000
24
  MAX_SEED = np.iinfo(np.int32).max
25
 
26
  TRIPOSG_REPO_URL = "https://github.com/VAST-AI-Research/TripoSG.git"
 
 
27
  TRIPOSG_PRETRAINED_MODEL = "checkpoints/TripoSG"
28
+ RMBG_PRETRAINED_MODEL = "checkpoints/RMBG-1.4"
29
 
30
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
31
  TMP_DIR = os.path.join(BASE_DIR, "tmp")
32
  os.makedirs(TMP_DIR, exist_ok=True)
33
 
34
+ # ---------------------------------------------------------------------
35
+ # Clone TripoSG code (runtime-safe)
36
+ # ---------------------------------------------------------------------
37
+
38
  TRIPOSG_CODE_DIR = os.path.join(BASE_DIR, "triposg")
39
+
40
  if not os.path.exists(TRIPOSG_CODE_DIR):
41
  os.system(f"git clone {TRIPOSG_REPO_URL} {TRIPOSG_CODE_DIR}")
42
 
43
+ # ---------------------------------------------------------------------
44
+ # 🔑 CRITICAL FIX: make TripoSG imports visible BEFORE importing
45
+ # ---------------------------------------------------------------------
46
 
47
+ sys.path.insert(0, TRIPOSG_CODE_DIR)
48
+ sys.path.insert(0, os.path.join(TRIPOSG_CODE_DIR, "scripts"))
 
 
 
49
 
50
+ # ---------------------------------------------------------------------
51
+ # Now imports work
52
+ # ---------------------------------------------------------------------
53
 
 
 
 
 
 
 
54
  from image_process import prepare_image
55
  from briarmbg import BriaRMBG
56
+ from triposg.pipelines.pipeline_triposg import TripoSGPipeline
57
+ from utils import simplify_mesh
58
+
59
+ # ---------------------------------------------------------------------
60
+ # Load models
61
+ # ---------------------------------------------------------------------
62
 
63
  snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
64
  rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE)
65
  rmbg_net.eval()
66
 
 
 
67
  snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
68
  triposg_pipe = TripoSGPipeline.from_pretrained(
69
  TRIPOSG_PRETRAINED_MODEL
70
  ).to(DEVICE, DTYPE)
71
 
72
+ # ---------------------------------------------------------------------
73
  # Helpers
74
+ # ---------------------------------------------------------------------
75
+
76
  def get_random_hex():
77
  return os.urandom(8).hex()
78
 
79
+ def get_random_seed(randomize, seed):
80
+ return random.randint(0, MAX_SEED) if randomize else seed
 
 
81
 
82
  def start_session(req: gr.Request):
83
+ path = os.path.join(TMP_DIR, str(req.session_hash))
84
+ os.makedirs(path, exist_ok=True)
85
 
86
  def end_session(req: gr.Request):
87
+ path = os.path.join(TMP_DIR, str(req.session_hash))
88
+ if os.path.exists(path):
89
+ shutil.rmtree(path)
90
+
91
+ # ---------------------------------------------------------------------
92
+ # GPU functions
93
+ # ---------------------------------------------------------------------
94
 
 
 
 
95
  @spaces.GPU()
96
  @torch.no_grad()
97
  def run_segmentation(image_path: str):
 
106
  def image_to_3d(
107
  image: Image.Image,
108
  seed: int,
109
+ steps: int,
110
+ guidance: float,
111
  simplify: bool,
112
+ target_faces: int,
113
  req: gr.Request,
114
  ):
115
  outputs = triposg_pipe(
116
  image=image,
117
  generator=torch.Generator(device=DEVICE).manual_seed(seed),
118
+ num_inference_steps=steps,
119
+ guidance_scale=guidance,
120
  ).samples[0]
121
 
122
  mesh = trimesh.Trimesh(
123
  outputs[0].astype(np.float32),
124
  np.ascontiguousarray(outputs[1]),
125
+ process=False,
126
  )
127
 
128
  if simplify:
129
+ mesh = simplify_mesh(mesh, target_faces)
 
130
 
131
  save_dir = os.path.join(TMP_DIR, str(req.session_hash))
132
  mesh_path = os.path.join(save_dir, f"triposg_{get_random_hex()}.glb")
 
135
  torch.cuda.empty_cache()
136
  return mesh_path
137
 
138
+ # ---------------------------------------------------------------------
139
  # UI
140
+ # ---------------------------------------------------------------------
141
+
142
+ HEADER = """
143
+ # 🔮 Image → 3D (TripoSG)
144
+
145
+ Mesh-only demo (no texture, no MV-Adapter).
146
+ """
147
+
148
  with gr.Blocks(title="TripoSG") as demo:
149
  gr.Markdown(HEADER)
150
 
151
  with gr.Row():
152
  with gr.Column():
153
+ input_image = gr.Image(label="Input Image", type="filepath")
154
  seg_image = gr.Image(label="Segmentation", type="pil")
155
 
156
+ seed = gr.Slider(0, MAX_SEED, value=0, label="Seed")
157
+ randomize_seed = gr.Checkbox(value=True, label="Randomize Seed")
158
+ steps = gr.Slider(8, 50, value=50, step=1, label="Inference Steps")
159
+ guidance = gr.Slider(0.0, 20.0, value=7.5, step=0.1, label="CFG Scale")
160
+
161
+ simplify = gr.Checkbox(value=True, label="Simplify Mesh")
162
+ target_faces = gr.Slider(10_000, 1_000_000, value=DEFAULT_FACE_NUMBER)
 
 
163
 
164
+ gen_btn = gr.Button("Generate 3D", variant="primary")
165
 
166
  with gr.Column():
167
+ model_out = gr.Model3D(label="Generated GLB")
168
 
169
+ gen_btn.click(
170
  run_segmentation,
171
+ inputs=input_image,
172
  outputs=seg_image,
173
  ).then(
174
  get_random_seed,
 
176
  outputs=seed,
177
  ).then(
178
  image_to_3d,
179
+ inputs=[seg_image, seed, steps, guidance, simplify, target_faces],
180
+ outputs=model_out,
 
 
 
 
 
 
 
181
  )
182
 
183
  demo.load(start_session)