dkatz2391 committed on
Commit
1391e5a
verified
1 Parent(s): 701c397

custom endpoints

Browse files
Files changed (1) hide show
  1. app.py +184 -0
app.py CHANGED
@@ -12,10 +12,15 @@ from PIL import Image
12
  from trellis.pipelines import TrellisImageTo3DPipeline
13
  from trellis.representations import Gaussian, MeshExtractResult
14
  from trellis.utils import render_utils, postprocessing_utils
 
 
 
15
  MAX_SEED = np.iinfo(np.int32).max
16
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
17
  os.makedirs(TMP_DIR, exist_ok=True)
18
 
 
 
19
  # Funciones auxiliares
20
  def start_session(req: gr.Request):
21
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
@@ -120,6 +125,153 @@ def extract_glb(
120
  torch.cuda.empty_cache()
121
  return glb_path, glb_path
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  # Interfaz Gradio
124
  with gr.Blocks(delete_cache=(600, 600)) as demo:
125
  gr.Markdown("""
@@ -206,6 +358,38 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
206
  outputs=[download_glb],
207
  )
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  # Lanzar la aplicación Gradio
210
  if __name__ == "__main__":
211
  pipeline = TrellisImageTo3DPipeline.from_pretrained("cavargas10/TRELLIS")
 
12
  from trellis.pipelines import TrellisImageTo3DPipeline
13
  from trellis.representations import Gaussian, MeshExtractResult
14
  from trellis.utils import render_utils, postprocessing_utils
15
+ import requests
16
+ import base64
17
+ import io
18
  MAX_SEED = np.iinfo(np.int32).max
19
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
20
  os.makedirs(TMP_DIR, exist_ok=True)
21
 
22
+ NODE_SERVER_UPLOAD_URL = "https://viverse-backend.onrender.com/api/upload-rigged-model"
23
+
24
  # Funciones auxiliares
25
  def start_session(req: gr.Request):
26
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
125
  torch.cuda.empty_cache()
126
  return glb_path, glb_path
127
 
128
@spaces.GPU(duration=180)
def generate_model_from_images_and_upload(
    image_inputs: List[str],
    input_type: str,
    seed_val: int,
    ss_guidance_strength_val: float,
    ss_sampling_steps_val: int,
    slat_guidance_strength_val: float,
    slat_sampling_steps_val: int,
    multiimage_algo_val: str,
    mesh_simplify_val: float,
    texture_size_val: int,
    model_description: str,
    req: gr.Request
) -> str:
    """Generate a GLB model from one or more images and upload it to the Node server.

    Each entry of ``image_inputs`` is loaded according to ``input_type``,
    preprocessed with ``pipeline.preprocess_image`` and fed to
    ``pipeline.run_multi_image``. The resulting gaussian/mesh pair is converted
    to a GLB, written to a temporary file in the session directory, POSTed to
    ``NODE_SERVER_UPLOAD_URL`` and then removed from disk.

    Args:
        image_inputs: Image sources; interpretation depends on ``input_type``.
        input_type: One of ``"url"``, ``"base64"`` or ``"filepath"``.
        seed_val: Seed forwarded to the pipeline.
        ss_guidance_strength_val: ``cfg_strength`` of the sparse-structure sampler.
        ss_sampling_steps_val: ``steps`` of the sparse-structure sampler.
        slat_guidance_strength_val: ``cfg_strength`` of the SLAT sampler.
        slat_sampling_steps_val: ``steps`` of the SLAT sampler.
        multiimage_algo_val: Forwarded to the pipeline as ``mode``.
        mesh_simplify_val: ``simplify`` argument of ``postprocessing_utils.to_glb``.
        texture_size_val: ``texture_size`` argument of ``postprocessing_utils.to_glb``.
        model_description: Preferred name sent as the upload ``prompt``; when
            empty, the joined image basenames (or a fixed fallback) are used.
        req: Gradio request; its ``session_hash`` selects the temp directory.

    Returns:
        The ``persistentUrl`` value from the Node server's JSON response.

    Raises:
        gr.Error: If any image cannot be loaded or preprocessed, if no image
            was supplied, or if the upload fails or returns no persistent URL.
    """
    # Local import: only needed on the error paths below.
    import traceback

    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)

    # Treat a missing list like an empty one so the caller gets the explicit
    # "No valid images" error instead of a TypeError from enumerate(None).
    image_inputs = image_inputs or []

    pil_images = []
    image_basenames = []
    print(f"Received image_inputs: {image_inputs}, input_type: {input_type}")

    # NOTE(review): "url" and "filepath" inputs are fetched/opened as given.
    # If this endpoint is reachable by untrusted callers, consider restricting
    # allowed hosts and paths -- confirm against the deployment.
    for i, img_data in enumerate(image_inputs):
        try:
            print(f"Processing image {i+1}/{len(image_inputs)} with type '{input_type}'")
            if input_type == "url":
                print(f"Fetching image from URL: {img_data}")
                # Read the decoded body instead of `.raw`: the raw stream is
                # not content-decoded (gzip/deflate) and, with stream=True,
                # leaves the connection open until explicitly closed.
                response_img = requests.get(img_data, timeout=30)
                response_img.raise_for_status()
                img = Image.open(io.BytesIO(response_img.content))
                image_basenames.append(os.path.basename(img_data).split('.')[0] or f"image_{i}")
            elif input_type == "base64":
                print(f"Decoding base64 image data (first 30 chars): {img_data[:30]}...")
                b64_payload = img_data
                # Accept data URIs ("data:image/png;base64,...."): the header
                # is not base64 and would otherwise corrupt the decoded bytes.
                if b64_payload.startswith("data:") and "," in b64_payload:
                    b64_payload = b64_payload.split(",", 1)[1]
                # Drop whitespace/newlines so the padding computation below
                # only counts real base64 characters.
                b64_payload = "".join(b64_payload.split())
                # Ensure correct padding for base64
                missing_padding = len(b64_payload) % 4
                if missing_padding:
                    b64_payload += '=' * (4 - missing_padding)
                img_bytes = base64.b64decode(b64_payload)
                img = Image.open(io.BytesIO(img_bytes))
                image_basenames.append(f"base64_image_{i}")
            elif input_type == "filepath":
                print(f"Opening image from filepath: {img_data}")
                img = Image.open(img_data)
                image_basenames.append(os.path.basename(img_data).split('.')[0] or f"image_{i}")
            else:
                print(f"Unsupported input_type: {input_type}")
                raise ValueError(f"Unsupported input_type: {input_type}")

            print(f"Image {i+1} loaded, mode: {img.mode}, size: {img.size}. Preprocessing...")
            # Ensure image is in RGB format if it's not, e.g. RGBA or P
            # NOTE(review): this discards any alpha channel before
            # preprocess_image sees it -- confirm that is intended.
            if img.mode in ('RGBA', 'P'):
                print(f"Converting image {i+1} from {img.mode} to RGB")
                img = img.convert('RGB')

            processed_img = pipeline.preprocess_image(img)
            pil_images.append(processed_img)
            print(f"Image {i+1} processed and added.")

        except Exception as e:
            print(f"Error processing image {i} ('{str(img_data)[:50]}...'): {e}")
            traceback.print_exc()
            raise gr.Error(f"Failed to load or process input image {i} ({input_type}): {e}") from e

    if not pil_images:
        print("No valid images could be processed.")
        raise gr.Error("No valid images could be processed.")

    print(f"Total PIL images for pipeline: {len(pil_images)}")

    print("Running multi-image pipeline...")
    outputs = pipeline.run_multi_image(
        pil_images,
        seed=seed_val,
        formats=["gaussian", "mesh"],
        preprocess_image=False,  # images were already preprocessed above
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps_val,
            "cfg_strength": ss_guidance_strength_val,
        },
        slat_sampler_params={
            "steps": slat_sampling_steps_val,
            "cfg_strength": slat_guidance_strength_val,
        },
        mode=multiimage_algo_val,
    )
    print("Multi-image pipeline completed.")

    gs_result = outputs['gaussian'][0]
    mesh_result = outputs['mesh'][0]

    print(f"Extracting GLB with simplify: {mesh_simplify_val}, texture_size: {texture_size_val}")
    glb_data = postprocessing_utils.to_glb(
        gs_result,
        mesh_result,
        simplify=mesh_simplify_val,
        texture_size=texture_size_val,
        verbose=False,
    )

    temp_glb_filename = 'temp_output_image_model.glb'
    temp_glb_path = os.path.join(user_dir, temp_glb_filename)
    print(f"Exporting GLB to temporary path: {temp_glb_path}")
    glb_data.export(temp_glb_path)

    torch.cuda.empty_cache()
    print("CUDA cache cleared.")

    print(f"Uploading GLB from {temp_glb_path} to {NODE_SERVER_UPLOAD_URL}")
    upload_prompt_name = model_description or "_".join(filter(None, image_basenames)) or "imagen_generated_model"
    # Sanitize upload_prompt_name further for safety
    upload_prompt_name = "".join(
        c if c.isalnum() or c in ('_', '-') else '_' for c in upload_prompt_name
    )[:50]

    try:
        with open(temp_glb_path, "rb") as f:
            files = {"modelFile": (temp_glb_filename, f, "model/gltf-binary")}
            payload = {
                "clientType": "playcanvas",
                "prompt": upload_prompt_name,
                "modelStage": "imagen_trellis_tpose"
            }
            print(f"Upload payload to Node.js: {payload}")
            response = requests.post(NODE_SERVER_UPLOAD_URL, files=files, data=payload, timeout=120)
            response.raise_for_status()
            result = response.json()
            persistent_url = result.get("persistentUrl")
            if not persistent_url:
                print(f"No persistent URL in Node.js server response: {result}")
                raise ValueError("Upload successful, but no persistent URL returned from Node.js server")
            print(f"Successfully uploaded to Node server. Persistent URL: {persistent_url}")
    except requests.exceptions.RequestException as upload_err:
        print(f"FAILED to upload GLB to Node server: {upload_err}")
        err_response = getattr(upload_err, 'response', None)
        if err_response is not None:
            print(f"Node server response status: {err_response.status_code}")
            print(f"Node server response text: {err_response.text}")
        raise gr.Error(f"Failed to upload result to backend server: {upload_err}") from upload_err
    except Exception as e:
        # print() has no `exc_info` keyword (that belongs to the logging API);
        # passing it raised TypeError here and hid the real error.
        print(f"UNEXPECTED error during upload: {e}")
        traceback.print_exc()
        raise gr.Error(f"Unexpected error during upload: {e}") from e
    finally:
        # Always remove the temporary GLB, whether or not the upload succeeded.
        if os.path.exists(temp_glb_path):
            print(f"Cleaning up temporary GLB: {temp_glb_path}")
            os.remove(temp_glb_path)

    # Reaching this point means the try block completed, so persistent_url is
    # a non-empty value (every failure path above raises gr.Error).
    print(f"Returning persistent URL: {persistent_url}")
    return persistent_url
274
+
275
  # Interfaz Gradio
276
  with gr.Blocks(delete_cache=(600, 600)) as demo:
277
  gr.Markdown("""
 
358
  outputs=[download_glb],
359
  )
360
 
361
+ # --- Add this section to explicitly register the API function for image to 3D ---
362
+ # These State components are placeholders for API-only inputs
363
+ api_image_inputs_state = gr.State(value=[]) # For List[str] of image_inputs
364
+ api_input_type_state = gr.State(value="url") # For input_type: "url", "filepath", or "base64"
365
+ api_model_description_state = gr.State(value="ImagenModel") # For model_description
366
+
367
+ with gr.Row(visible=False): # Hide this row in the UI
368
+ api_image_gen_trigger_btn = gr.Button("API Image-to-3D Trigger")
369
+
370
+ # Output for the API call (can be a dummy Textbox)
371
+ api_image_gen_output_url = gr.Textbox(label="Generated Model URL (API)", visible=False)
372
+
373
+ api_image_gen_trigger_btn.click(
374
+ generate_model_from_images_and_upload,
375
+ inputs=[ # Order must match the Python function's parameters
376
+ api_image_inputs_state,
377
+ api_input_type_state,
378
+ seed, # UI component
379
+ ss_guidance_strength, # UI component
380
+ ss_sampling_steps, # UI component
381
+ slat_guidance_strength, # UI component
382
+ slat_sampling_steps, # UI component
383
+ multiimage_algo, # UI component
384
+ mesh_simplify, # UI component
385
+ texture_size, # UI component
386
+ api_model_description_state,
387
+ ],
388
+ outputs=[api_image_gen_output_url],
389
+ api_name="generate_model_from_images_and_upload" # Critical: Register the API name
390
+ )
391
+ # --- End API registration section ---
392
+
393
  # Lanzar la aplicación Gradio
394
  if __name__ == "__main__":
395
  pipeline = TrellisImageTo3DPipeline.from_pretrained("cavargas10/TRELLIS")