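# Page header captured from the Hugging Face file view (kept as comments so the module parses):
# dkatz2391's picture
# custom endpoints
# 1391e5a verified
# raw
# history blame
# 16.9 kB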
import gradio as gr
import spaces
import os
import shutil
os.environ['SPCONV_ALGO'] = 'native'
from typing import *
import torch
import numpy as np
import imageio
from easydict import EasyDict as edict
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils
import requests
import base64
import io
# Backend endpoint that receives the exported GLB file.
NODE_SERVER_UPLOAD_URL = "https://viverse-backend.onrender.com/api/upload-rigged-model"

# Upper bound used for the seed slider and for random seed draws (max signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max

# Scratch directory located next to this file; created eagerly at import time.
TMP_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    'tmp',
)
os.makedirs(TMP_DIR, exist_ok=True)
# Helper functions
def start_session(req: gr.Request):
    """Create the scratch directory for the session identified by ``req.session_hash``.

    The directory lives under ``TMP_DIR``; an already existing directory is left untouched.
    """
    session_name = str(req.session_hash)
    session_dir = os.path.join(TMP_DIR, session_name)
    os.makedirs(session_dir, exist_ok=True)
def end_session(req: gr.Request):
    """Delete the scratch directory of the session identified by ``req.session_hash``.

    Cleanup is best-effort: the directory may never have been created (or may
    already be gone), and that must not raise from the unload handler.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # ignore_errors=True: a missing directory is not a failure during teardown.
    shutil.rmtree(user_dir, ignore_errors=True)
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """Run ``pipeline.preprocess_image`` on the first element of every gallery entry.

    Args:
        images: Gallery entries; only element ``[0]`` of each entry is used.

    Returns:
        The preprocessed images, in the same order as the input entries.
    """
    # Pull out the image of each entry first, then preprocess them one by one.
    raw_images = [entry[0] for entry in images]
    result = []
    for raw_image in raw_images:
        result.append(pipeline.preprocess_image(raw_image))
    return result
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Convert a Gaussian and a mesh into a plain dict of NumPy arrays.

    Args:
        gs: Gaussian whose ``init_params`` and tensor attributes are stored.
        mesh: Mesh result whose ``vertices`` and ``faces`` are stored.

    Returns:
        A dict with a ``'gaussian'`` entry (init params plus tensor data) and a
        ``'mesh'`` entry (vertices and faces), every tensor moved to CPU and
        converted with ``.numpy()``.
    """
    gaussian_state = {**gs.init_params}
    # Tensor attributes are stored under their attribute names.
    for attr_name in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        gaussian_state[attr_name] = getattr(gs, attr_name).cpu().numpy()

    mesh_state = {}
    mesh_state['vertices'] = mesh.vertices.cpu().numpy()
    mesh_state['faces'] = mesh.faces.cpu().numpy()

    return {'gaussian': gaussian_state, 'mesh': mesh_state}
def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
    """Rebuild a Gaussian and a mesh container from a dict shaped like ``pack_state``'s output.

    Args:
        state: Dict with ``'gaussian'`` and ``'mesh'`` entries.

    Returns:
        ``(gs, mesh)`` where ``gs`` is a ``Gaussian`` and ``mesh`` is an ``edict``
        with ``vertices`` and ``faces``; all tensors are created on ``'cuda'``.
    """
    gaussian_state = state['gaussian']

    # Constructor arguments are read back by name from the stored state.
    ctor_keys = (
        'aabb',
        'sh_degree',
        'mininum_kernel_size',
        'scaling_bias',
        'opacity_bias',
        'scaling_activation',
    )
    gs = Gaussian(**{key: gaussian_state[key] for key in ctor_keys})

    # Restore the tensor attributes on the GPU.
    for attr_name in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        setattr(gs, attr_name, torch.tensor(gaussian_state[attr_name], device='cuda'))

    mesh_state = state['mesh']
    mesh = edict(
        vertices=torch.tensor(mesh_state['vertices'], device='cuda'),
        faces=torch.tensor(mesh_state['faces'], device='cuda'),
    )
    return gs, mesh
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return the seed to use for generation.

    Args:
        randomize_seed: When truthy, draw a random seed from ``[0, MAX_SEED)``.
        seed: Seed returned unchanged when ``randomize_seed`` is falsy.
    """
    if randomize_seed:
        return np.random.randint(0, MAX_SEED)
    return seed
@spaces.GPU
def image_to_3d(
    multiimages: List[Tuple[Image.Image, str]],
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    req: gr.Request,
) -> Tuple[dict, str]:
    """Generate a 3D asset from several images and render a preview video.

    Args:
        multiimages: Gallery entries; element ``[0]`` of each entry is passed to the pipeline.
        seed: Seed forwarded to the pipeline.
        ss_guidance_strength: ``cfg_strength`` for the sparse-structure sampler.
        ss_sampling_steps: ``steps`` for the sparse-structure sampler.
        slat_guidance_strength: ``cfg_strength`` for the structured-latent sampler.
        slat_sampling_steps: ``steps`` for the structured-latent sampler.
        multiimage_algo: Value forwarded as the pipeline's ``mode``.
        req: Gradio request; its ``session_hash`` selects the output directory.

    Returns:
        ``(state, video_path)``: the packed gaussian/mesh state and the path of
        the written ``sample.mp4``.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # The session directory is normally created on page load, but callers that
    # never triggered that event would otherwise fail when the video is written.
    os.makedirs(user_dir, exist_ok=True)

    outputs = pipeline.run_multi_image(
        [image[0] for image in multiimages],
        seed=seed,
        formats=["gaussian", "mesh"],
        # Images are passed through as-is; no preprocessing is requested here.
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "cfg_strength": ss_guidance_strength,
        },
        slat_sampler_params={
            "steps": slat_sampling_steps,
            "cfg_strength": slat_guidance_strength,
        },
        mode=multiimage_algo,
    )

    gaussian = outputs['gaussian'][0]
    mesh = outputs['mesh'][0]

    # Pair each gaussian colour frame with the matching mesh normal frame and
    # join them along axis 1 (side by side, assuming H x W x C frames).
    color_frames = render_utils.render_video(gaussian, num_frames=120)['color']
    normal_frames = render_utils.render_video(mesh, num_frames=120)['normal']
    video = [
        np.concatenate([color, normal], axis=1)
        for color, normal in zip(color_frames, normal_frames)
    ]

    video_path = os.path.join(user_dir, 'sample.mp4')
    imageio.mimsave(video_path, video, fps=15)

    state = pack_state(gaussian, mesh)
    torch.cuda.empty_cache()
    return state, video_path
@spaces.GPU(duration=90)
def extract_glb(
    state: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """Export the generated asset stored in ``state`` as a GLB file.

    Args:
        state: Packed gaussian/mesh state (as produced by ``pack_state``).
        mesh_simplify: Forwarded as ``simplify`` to ``postprocessing_utils.to_glb``.
        texture_size: Forwarded as ``texture_size`` to ``postprocessing_utils.to_glb``.
        req: Gradio request; its ``session_hash`` selects the output directory.

    Returns:
        The path of the written ``sample.glb``, twice (one value per output component).
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # Make sure the target directory exists before exporting into it.
    os.makedirs(user_dir, exist_ok=True)

    gs, mesh = unpack_state(state)
    glb = postprocessing_utils.to_glb(
        gs,
        mesh,
        simplify=mesh_simplify,
        texture_size=texture_size,
        verbose=False,
    )
    glb_path = os.path.join(user_dir, 'sample.glb')
    glb.export(glb_path)
    torch.cuda.empty_cache()
    return glb_path, glb_path
@spaces.GPU(duration=180)
def generate_model_from_images_and_upload(
    image_inputs: List[str],
    input_type: str,
    seed_val: int,
    ss_guidance_strength_val: float,
    ss_sampling_steps_val: int,
    slat_guidance_strength_val: float,
    slat_sampling_steps_val: int,
    multiimage_algo_val: str,
    mesh_simplify_val: float,
    texture_size_val: int,
    model_description: str,
    req: gr.Request
) -> str:
    """Load images, generate a GLB model from them and upload it to the backend.

    Args:
        image_inputs: One string per image; interpreted according to ``input_type``.
        input_type: ``"url"``, ``"base64"`` or ``"filepath"``.
        seed_val: Seed forwarded to the pipeline.
        ss_guidance_strength_val: ``cfg_strength`` for the sparse-structure sampler.
        ss_sampling_steps_val: ``steps`` for the sparse-structure sampler.
        slat_guidance_strength_val: ``cfg_strength`` for the structured-latent sampler.
        slat_sampling_steps_val: ``steps`` for the structured-latent sampler.
        multiimage_algo_val: Value forwarded as the pipeline's ``mode``.
        mesh_simplify_val: ``simplify`` argument for GLB extraction.
        texture_size_val: ``texture_size`` argument for GLB extraction.
        model_description: Preferred name sent as the upload ``prompt``; when
            empty, a name is derived from the image base names.
        req: Gradio request; its ``session_hash`` selects the working directory.

    Returns:
        The ``persistentUrl`` reported by the backend for the uploaded model.

    Raises:
        gr.Error: If an image cannot be loaded or preprocessed, if no image was
            supplied, or if the upload fails or yields no persistent URL.
    """
    import traceback  # only needed on the error paths below

    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)

    image_inputs = image_inputs or []
    pil_images = []
    image_basenames = []
    # Log a summary only: base64 inputs can be very large, and every item is
    # logged (truncated where needed) inside the loop anyway.
    print(f"Received {len(image_inputs)} image input(s), input_type: {input_type}")

    for i, img_data in enumerate(image_inputs):
        try:
            print(f"Processing image {i+1}/{len(image_inputs)} with type '{input_type}'")
            if input_type == "url":
                print(f"Fetching image from URL: {img_data}")
                response_img = requests.get(img_data, timeout=30)
                response_img.raise_for_status()
                # Decode from the downloaded body rather than response.raw: the raw
                # stream is not seekable and is not content-decoded (e.g. gzip).
                img = Image.open(io.BytesIO(response_img.content))
                image_basenames.append(os.path.basename(img_data).split('.')[0] or f"image_{i}")
            elif input_type == "base64":
                print(f"Decoding base64 image data (first 30 chars): {img_data[:30]}...")
                b64_payload = img_data
                # Accept data URIs ("data:image/png;base64,....") by dropping the header.
                if b64_payload.startswith("data:") and "," in b64_payload:
                    b64_payload = b64_payload.split(",", 1)[1]
                # Restore any '=' padding that was stripped from the payload.
                b64_payload += "=" * (-len(b64_payload) % 4)
                img_bytes = base64.b64decode(b64_payload)
                img = Image.open(io.BytesIO(img_bytes))
                image_basenames.append(f"base64_image_{i}")
            elif input_type == "filepath":
                print(f"Opening image from filepath: {img_data}")
                img = Image.open(img_data)
                image_basenames.append(os.path.basename(img_data).split('.')[0] or f"image_{i}")
            else:
                print(f"Unsupported input_type: {input_type}")
                raise ValueError(f"Unsupported input_type: {input_type}")

            print(f"Image {i+1} loaded, mode: {img.mode}, size: {img.size}. Preprocessing...")
            # Ensure image is in RGB format if it's not, e.g. RGBA or P
            if img.mode in ('RGBA', 'P'):
                print(f"Converting image {i+1} from {img.mode} to RGB")
                img = img.convert('RGB')
            processed_img = pipeline.preprocess_image(img)
            pil_images.append(processed_img)
            print(f"Image {i+1} processed and added.")
        except Exception as e:
            print(f"Error processing image {i} ('{str(img_data)[:50]}...'): {e}")
            traceback.print_exc()
            raise gr.Error(f"Failed to load or process input image {i} ({input_type}): {e}") from e

    if not pil_images:
        print("No valid images could be processed.")
        raise gr.Error("No valid images could be processed.")

    print(f"Total PIL images for pipeline: {len(pil_images)}")
    print("Running multi-image pipeline...")
    outputs = pipeline.run_multi_image(
        pil_images,
        seed=seed_val,
        formats=["gaussian", "mesh"],
        # Images were already preprocessed in the loading loop above.
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps_val,
            "cfg_strength": ss_guidance_strength_val,
        },
        slat_sampler_params={
            "steps": slat_sampling_steps_val,
            "cfg_strength": slat_guidance_strength_val,
        },
        mode=multiimage_algo_val,
    )
    print("Multi-image pipeline completed.")

    gs_result = outputs['gaussian'][0]
    mesh_result = outputs['mesh'][0]
    print(f"Extracting GLB with simplify: {mesh_simplify_val}, texture_size: {texture_size_val}")
    glb_data = postprocessing_utils.to_glb(
        gs_result,
        mesh_result,
        simplify=mesh_simplify_val,
        texture_size=texture_size_val,
        verbose=False,
    )

    temp_glb_filename = 'temp_output_image_model.glb'
    temp_glb_path = os.path.join(user_dir, temp_glb_filename)
    print(f"Exporting GLB to temporary path: {temp_glb_path}")
    glb_data.export(temp_glb_path)
    torch.cuda.empty_cache()
    print("CUDA cache cleared.")

    print(f"Uploading GLB from {temp_glb_path} to {NODE_SERVER_UPLOAD_URL}")
    persistent_url = None
    upload_prompt_name = model_description or "_".join(filter(None, image_basenames)) or "imagen_generated_model"
    # Sanitize upload_prompt_name further for safety: keep alphanumerics, '_' and '-'
    # only, and cap the length at 50 characters.
    upload_prompt_name = "".join(
        c if c.isalnum() or c in ('_', '-') else '_' for c in upload_prompt_name
    )[:50]

    try:
        with open(temp_glb_path, "rb") as f:
            files = {"modelFile": (temp_glb_filename, f, "model/gltf-binary")}
            payload = {
                "clientType": "playcanvas",
                "prompt": upload_prompt_name,
                "modelStage": "imagen_trellis_tpose"
            }
            print(f"Upload payload to Node.js: {payload}")
            response = requests.post(NODE_SERVER_UPLOAD_URL, files=files, data=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
        persistent_url = result.get("persistentUrl")
        if not persistent_url:
            print(f"No persistent URL in Node.js server response: {result}")
            raise ValueError("Upload successful, but no persistent URL returned from Node.js server")
        print(f"Successfully uploaded to Node server. Persistent URL: {persistent_url}")
    except requests.exceptions.RequestException as upload_err:
        print(f"FAILED to upload GLB to Node server: {upload_err}")
        if getattr(upload_err, 'response', None) is not None:
            print(f"Node server response status: {upload_err.response.status_code}")
            print(f"Node server response text: {upload_err.response.text}")
        raise gr.Error(f"Failed to upload result to backend server: {upload_err}") from upload_err
    except Exception as e:
        # print() has no exc_info parameter (that belongs to the logging API);
        # emit the traceback explicitly so the original error is not masked.
        print(f"UNEXPECTED error during upload: {e}")
        traceback.print_exc()
        raise gr.Error(f"Unexpected error during upload: {e}") from e
    finally:
        # The temporary GLB is removed whether or not the upload succeeded.
        if os.path.exists(temp_glb_path):
            print(f"Cleaning up temporary GLB: {temp_glb_path}")
            os.remove(temp_glb_path)

    # Defensive guard: every failure path above raises, so this should not trigger.
    if not persistent_url:
        print("Failed to obtain a persistent URL for the generated model.")
        raise gr.Error("Failed to obtain a persistent URL for the generated model.")

    print(f"Returning persistent URL: {persistent_url}")
    return persistent_url
# Gradio interface
# UI definition and event wiring. The user-facing Markdown below is intentionally
# in Spanish; its accented characters are written as proper Unicode.
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    # UTPL - Conversión de Multiples Imágenes a objetos 3D usando IA
    ### Tesis: *"Objetos tridimensionales creados por IA: Innovación en entornos virtuales"*
    **Autor:** Carlos Vargas
    **Base técnica:** Adaptación de [TRELLIS](https://trellis3d.github.io/) (herramienta de código abierto para generación 3D)
    **Propósito educativo:** Demostraciones académicas e Investigación en modelado 3D automático
    """)

    with gr.Row():
        # Left column: inputs and settings.
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
                    multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")

            generate_btn = gr.Button("Generate")

            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)

            extract_glb_btn = gr.Button("Extract GLB", interactive=False)

        # Right column: outputs.
        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            model_output = gr.Model3D(label="Extracted GLB", height=300)
            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

    # Holds the packed generation state between "Generate" and "Extract GLB".
    output_buf = gr.State()

    # Event handlers
    demo.load(start_session)
    demo.unload(end_session)

    multiimage_prompt.upload(
        preprocess_images,
        inputs=[multiimage_prompt],
        outputs=[multiimage_prompt],
    )

    # Generate: resolve the seed, run the pipeline, then enable GLB extraction.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[multiimage_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
        outputs=[output_buf, video_output],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[extract_glb_btn],
    )

    video_output.clear(
        lambda: gr.Button(interactive=False),
        outputs=[extract_glb_btn],
    )

    # Extract GLB: export the model, then enable the download button.
    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_glb],
    )

    model_output.clear(
        lambda: gr.Button(interactive=False),
        outputs=[download_glb],
    )

    # --- Add this section to explicitly register the API function for image to 3D ---
    # These State components are placeholders for API-only inputs
    api_image_inputs_state = gr.State(value=[])  # For List[str] of image_inputs
    api_input_type_state = gr.State(value="url")  # For input_type: "url", "filepath", or "base64"
    api_model_description_state = gr.State(value="ImagenModel")  # For model_description

    with gr.Row(visible=False):  # Hide this row in the UI
        api_image_gen_trigger_btn = gr.Button("API Image-to-3D Trigger")

    # Output for the API call (can be a dummy Textbox)
    api_image_gen_output_url = gr.Textbox(label="Generated Model URL (API)", visible=False)

    api_image_gen_trigger_btn.click(
        generate_model_from_images_and_upload,
        inputs=[  # Order must match the Python function's parameters
            api_image_inputs_state,
            api_input_type_state,
            seed,  # UI component
            ss_guidance_strength,  # UI component
            ss_sampling_steps,  # UI component
            slat_guidance_strength,  # UI component
            slat_sampling_steps,  # UI component
            multiimage_algo,  # UI component
            mesh_simplify,  # UI component
            texture_size,  # UI component
            api_model_description_state,
        ],
        outputs=[api_image_gen_output_url],
        api_name="generate_model_from_images_and_upload"  # Critical: Register the API name
    )
    # --- End API registration section ---
# Launch the Gradio application
if __name__ == "__main__":
    # Load the pretrained pipeline and move it to the GPU before serving requests.
    pipeline = TrellisImageTo3DPipeline.from_pretrained("cavargas10/TRELLIS")
    pipeline.cuda()
    try:
        # Warm-up call on a blank image to preload rembg.
        pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
    except Exception as warmup_err:
        # Best-effort only: a failed warm-up must not prevent startup, but it
        # should be visible in the logs instead of being silently swallowed.
        print(f"Warning: preprocessing warm-up failed: {warmup_err}")
    demo.launch(show_error=True)