Spaces:
Paused
Paused
| import gradio as gr | |
| import spaces | |
| import os | |
| import shutil | |
| os.environ['SPCONV_ALGO'] = 'native' | |
| from typing import * | |
| from typing import Optional | |
| import torch | |
| import numpy as np | |
| import imageio | |
| from easydict import EasyDict as edict | |
| from PIL import Image | |
| from trellis.pipelines import TrellisImageTo3DPipeline | |
| from trellis.representations import Gaussian, MeshExtractResult | |
| from trellis.utils import render_utils, postprocessing_utils | |
| from fastapi import FastAPI, UploadFile, File, HTTPException, Form, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| import uvicorn | |
| from fastapi.staticfiles import StaticFiles | |
# Allow NumPy to load pickled object arrays by default.
# NOTE(review): the previous code mutated ``np.load.__defaults__`` in place,
# which depends on the exact positional-default layout of ``np.load`` and
# breaks across NumPy versions (it could even set ``encoding=None``, which
# ``np.load`` rejects).  An explicit wrapper is version-independent.
# This loosens a security default; it is acceptable only because the .npz
# files loaded here are produced by this same application.
np._no_npy2_warning = True  # best-effort warning suppression (no-op on most NumPy versions)

_np_load_original = np.load

def _np_load_allow_pickle(*args, **kwargs):
    """``np.load`` with ``allow_pickle=True`` as the default."""
    kwargs.setdefault('allow_pickle', True)
    return _np_load_original(*args, **kwargs)

np.load = _np_load_allow_pickle

# Upper bound for user-facing random seeds (int32 range, matches the UI slider).
MAX_SEED = np.iinfo(np.int32).max

# Per-session scratch space for generated previews/states, next to this file.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)
# FastAPI application; the Gradio UI is mounted onto it at "/" further below.
app = FastAPI(title="TRELLIS 3D API")
# CORS middleware: fully open policy so the API is callable from any origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# API request models
class GenerationParams(BaseModel):
    """Parameters controlling the two-stage 3D generation pipeline."""
    seed: int = 0
    # Stage 1: sparse structure sampling (guidance strength + step count)
    ss_guidance_strength: float = 7.5
    ss_sampling_steps: int = 12
    # Stage 2: structured latent sampling
    slat_guidance_strength: float = 3.0
    slat_sampling_steps: int = 12
    # How multiple input views are fused: "stochastic" or "multidiffusion"
    multiimage_algo: str = "stochastic"
class GLBParams(BaseModel):
    """Parameters for extracting a textured GLB from a generated state."""
    # Forwarded to postprocessing_utils.to_glb as ``simplify``; the UI
    # exposes the 0.9-0.98 range.
    mesh_simplify: float = 0.95
    # Baked texture resolution in pixels; the UI offers 512-2048.
    texture_size: int = 1024
# Helper functions
def start_session(req: gr.Request):
    """Create the per-session scratch directory when the Gradio page loads."""
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
def end_session(req: gr.Request):
    """Remove the per-session scratch directory when the Gradio page unloads.

    Uses ``ignore_errors=True`` so an unload event for a session whose
    directory was never created (or was already cleaned up) does not raise.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(user_dir, ignore_errors=True)
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """Run the pipeline's image preprocessing over a gallery of (image, caption) items."""
    return [pipeline.preprocess_image(item[0]) for item in images]
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Serialize a Gaussian + mesh pair into a plain dict of numpy arrays.

    The Gaussian's initialization parameters are preserved alongside its
    tensors so that ``unpack_state`` can rebuild an equivalent object.
    """
    # Normalize init params: lists/tuples become numpy arrays; everything
    # else (existing arrays, scalars, strings, ...) is kept as-is.
    packed_gaussian = {
        key: (np.array(value) if isinstance(value, (list, tuple)) else value)
        for key, value in gs.init_params.items()
    }
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        packed_gaussian[attr] = getattr(gs, attr).cpu().numpy()
    return {
        'gaussian': packed_gaussian,
        'mesh': {
            'vertices': mesh.vertices.cpu().numpy(),
            'faces': mesh.faces.cpu().numpy(),
        },
    }
def unpack_state(state: dict, device: str = 'cuda') -> Tuple[Gaussian, edict]:
    """Rebuild a Gaussian and a mesh from a dict produced by ``pack_state``.

    Args:
        state: Nested dict with 'gaussian' and 'mesh' entries.
        device: Torch device the rebuilt tensors are placed on. Defaults to
            'cuda', matching the previously hard-coded behavior.

    Returns:
        (Gaussian, edict) pair ready for GLB postprocessing.
    """
    gdata = state['gaussian']
    gs = Gaussian(
        aabb=gdata['aabb'],
        sh_degree=gdata['sh_degree'],
        mininum_kernel_size=gdata['mininum_kernel_size'],  # (sic) keyword spelling from trellis
        scaling_bias=gdata['scaling_bias'],
        opacity_bias=gdata['opacity_bias'],
        scaling_activation=gdata['scaling_activation'],
    )
    # Restore the raw parameter tensors saved by pack_state.
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        setattr(gs, attr, torch.tensor(gdata[attr], device=device))
    mesh = edict(
        vertices=torch.tensor(state['mesh']['vertices'], device=device),
        faces=torch.tensor(state['mesh']['faces'], device=device),
    )
    return gs, mesh
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Resolve the seed for a run: the user's value, or a fresh random one."""
    if not randomize_seed:
        return seed
    return np.random.randint(0, MAX_SEED)
def image_to_3d(
    multiimages: List[Tuple[Image.Image, str]],
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    req: gr.Request,
) -> Tuple[dict, str]:
    """Generate a 3D asset from multiple input views (Gradio handler).

    Returns the packed generation state and the path of a side-by-side
    preview video (gaussian color render | mesh normal render).
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    outputs = pipeline.run_multi_image(
        [item[0] for item in multiimages],
        seed=seed,
        formats=["gaussian", "mesh"],
        preprocess_image=False,  # gallery items were already preprocessed on upload
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "cfg_strength": ss_guidance_strength,
        },
        slat_sampler_params={
            "steps": slat_sampling_steps,
            "cfg_strength": slat_guidance_strength,
        },
        mode=multiimage_algo,
    )
    # Render both representations and stitch the frames horizontally.
    color_frames = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
    normal_frames = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
    combined = [np.concatenate(pair, axis=1) for pair in zip(color_frames, normal_frames)]
    video_path = os.path.join(user_dir, 'sample.mp4')
    imageio.mimsave(video_path, combined, fps=15)
    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
    torch.cuda.empty_cache()
    return state, video_path
def extract_glb(
    state: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """Extract a textured GLB from a packed state (Gradio handler).

    Returns the GLB path twice: once for the Model3D viewer and once for
    the download button.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    gs, mesh = unpack_state(state)
    glb = postprocessing_utils.to_glb(
        gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False
    )
    glb_path = os.path.join(session_dir, 'sample.glb')
    glb.export(glb_path)
    torch.cuda.empty_cache()
    return glb_path, glb_path
# FastAPI routes
# NOTE(review): no @app.post(...) decorator is visible here, so this handler
# is never registered with the FastAPI app — and the Gradio app mounted at
# "/" would shadow unregistered paths anyway.  Confirm whether registration
# happens elsewhere or was lost.
async def api_generate_3d(
    files: List[UploadFile] = File(...),
    params: str = Form(None)
):
    """Generate a 3D asset from uploaded images; returns preview/state URLs."""
    if params:
        # NOTE(review): parse_raw is the deprecated pydantic-v1 API;
        # model_validate_json is the v2 equivalent.
        params = GenerationParams.parse_raw(params)
    else:
        params = GenerationParams()
    # Per-request scratch directory keyed by a random numeric session id.
    session_id = str(np.random.randint(0, MAX_SEED))
    user_dir = os.path.join(TMP_DIR, session_id)
    os.makedirs(user_dir, exist_ok=True)
    try:
        # Read the uploaded images.
        images = []
        for file in files:
            image = Image.open(file.file)
            images.append(image)
        # Run the two-stage generation pipeline.
        outputs = pipeline.run_multi_image(
            images,
            seed=params.seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": params.ss_sampling_steps,
                "cfg_strength": params.ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": params.slat_sampling_steps,
                "cfg_strength": params.slat_guidance_strength,
            },
            mode=params.multiimage_algo,
        )
        # Render a side-by-side preview video (color render | normal render).
        video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
        video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
        video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
        video_path = os.path.join(user_dir, 'preview.mp4')
        imageio.mimsave(video_path, video, fps=15)
        # Persist the packed state so GLB extraction can run in a later request.
        state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
        state_path = os.path.join(user_dir, 'state.npz')
        # Save the nested dicts as 0-d pickled object arrays so the structure
        # can be restored verbatim (requires allow_pickle=True on load).
        np.savez(
            state_path,
            gaussian=np.array(state['gaussian'], dtype=object),
            mesh=np.array(state['mesh'], dtype=object)
        )
        return {
            "session_id": session_id,
            "preview_url": f"/api/preview/{session_id}",
            "state_url": f"/api/state/{session_id}"
        }
    except Exception as e:
        # Clean up the scratch directory on any failure before reporting it.
        shutil.rmtree(user_dir)
        raise HTTPException(status_code=500, detail=str(e))
# NOTE(review): no @app.post(...) decorator is visible, so this handler is
# never registered with the FastAPI app — confirm registration elsewhere.
async def api_extract_glb(request: Request):
    """Extract a GLB for a previously generated session; returns its URL."""
    try:
        # Parse the JSON request body.
        data = await request.json()
        session_id = data.get("session_id")
        if not session_id:
            # NOTE(review): these intentional HTTPExceptions (422/404) are
            # raised inside the try block and therefore get caught by the
            # broad ``except Exception`` below and re-wrapped as a 500.
            raise HTTPException(status_code=422, detail="Missing session_id in request body")
        params_data = data.get("params", {})
        params_obj = GLBParams(**params_data) if params_data else GLBParams()
        # NOTE(review): session_id is user-supplied and joined into a path
        # without sanitization — path-traversal risk; validate before use.
        user_dir = os.path.join(TMP_DIR, session_id)
        if not os.path.exists(user_dir):
            raise HTTPException(status_code=404, detail="Session not found")
        # Load the saved state (pickled object arrays => allow_pickle=True).
        state_path = os.path.join(user_dir, 'state.npz')
        state_file = np.load(state_path, allow_pickle=True, encoding='latin1')
        # Log the archive keys to ease debugging of format mismatches.
        import logging
        logging.warning(f"Available keys in state file: {state_file.files}")
        # New-style save: two 0-d object arrays named 'gaussian' and 'mesh'.
        if 'gaussian' in state_file.files and 'mesh' in state_file.files:
            # Unwrap the 0-d object arrays back into plain dicts.
            gaussian_obj = state_file['gaussian'].item()
            mesh_obj = state_file['mesh'].item()
            state = {
                'gaussian': gaussian_obj,
                'mesh': mesh_obj
            }
            logging.warning(f"Successfully loaded state with new format")
        else:
            # Legacy fallback: flat keys prefixed with 'gaussian.'/'mesh.'
            # or 'gaussian/'/'mesh/' (0-d entries are unwrapped via .item()).
            state = {'gaussian': {}, 'mesh': {}}
            for k in state_file.files:
                if k.startswith('gaussian.'):
                    subkey = k.replace('gaussian.', '')
                    state['gaussian'][subkey] = state_file[k].item() if state_file[k].ndim == 0 else state_file[k]
                elif k.startswith('mesh.'):
                    subkey = k.replace('mesh.', '')
                    state['mesh'][subkey] = state_file[k].item() if state_file[k].ndim == 0 else state_file[k]
                elif k.startswith('gaussian/'):
                    subkey = k.replace('gaussian/', '')
                    state['gaussian'][subkey] = state_file[k].item() if state_file[k].ndim == 0 else state_file[k]
                elif k.startswith('mesh/'):
                    subkey = k.replace('mesh/', '')
                    state['mesh'][subkey] = state_file[k].item() if state_file[k].ndim == 0 else state_file[k]
            logging.warning(f"Loaded state with legacy format")
        # Make sure both halves of the state were recovered.
        if not state['gaussian'] or not state['mesh']:
            raise ValueError("无法正确加载状态数据,缺少关键字段")
        logging.warning(f"State gaussian keys: {state['gaussian'].keys()}")
        logging.warning(f"State mesh keys: {state['mesh'].keys()}")
        # Rebuild tensors and run GLB postprocessing.
        gs, mesh = unpack_state(state)
        glb = postprocessing_utils.to_glb(
            gs,
            mesh,
            simplify=params_obj.mesh_simplify,
            texture_size=params_obj.texture_size,
            verbose=False
        )
        glb_path = os.path.join(user_dir, 'model.glb')
        glb.export(glb_path)
        return {"glb_url": f"/api/glb/{session_id}"}
    except Exception as e:
        import traceback
        error_detail = f"{str(e)}\n{traceback.format_exc()}"
        raise HTTPException(status_code=500, detail=error_detail)
async def api_get_preview(session_id: str):
    """Serve a session's preview video, or 404 if it does not exist.

    ``session_id`` comes from the request and is joined into a filesystem
    path, so it is validated first to prevent path traversal out of TMP_DIR.
    """
    if session_id != os.path.basename(session_id) or session_id in ('.', '..'):
        raise HTTPException(status_code=404, detail="Preview not found")
    preview_path = os.path.join(TMP_DIR, session_id, 'preview.mp4')
    if not os.path.exists(preview_path):
        raise HTTPException(status_code=404, detail="Preview not found")
    return FileResponse(preview_path)
async def api_get_glb(session_id: str):
    """Serve a session's extracted GLB model, or 404 if it does not exist.

    ``session_id`` comes from the request and is joined into a filesystem
    path, so it is validated first to prevent path traversal out of TMP_DIR.
    """
    if session_id != os.path.basename(session_id) or session_id in ('.', '..'):
        raise HTTPException(status_code=404, detail="GLB not found")
    glb_path = os.path.join(TMP_DIR, session_id, 'model.glb')
    if not os.path.exists(glb_path):
        raise HTTPException(status_code=404, detail="GLB not found")
    return FileResponse(glb_path)
async def api_get_state(session_id: str):
    """Serve a session's saved state archive, or 404 if it does not exist.

    ``session_id`` comes from the request and is joined into a filesystem
    path, so it is validated first to prevent path traversal out of TMP_DIR.
    """
    if session_id != os.path.basename(session_id) or session_id in ('.', '..'):
        raise HTTPException(status_code=404, detail="State not found")
    state_path = os.path.join(TMP_DIR, session_id, 'state.npz')
    if not os.path.exists(state_path):
        raise HTTPException(status_code=404, detail="State not found")
    return FileResponse(state_path)
async def api_debug_state(session_id: str):
    """Debug endpoint: summarize the structure of a session's state file."""
    # Reject path-traversal attempts: session_id comes straight from the request.
    if session_id != os.path.basename(session_id) or session_id in ('.', '..'):
        raise HTTPException(status_code=404, detail="State not found")
    state_path = os.path.join(TMP_DIR, session_id, 'state.npz')
    if not os.path.exists(state_path):
        raise HTTPException(status_code=404, detail="State not found")
    try:
        # Load the archive; entries are pickled object arrays.
        state_file = np.load(state_path, allow_pickle=True, encoding='latin1')
        debug_info = {
            "keys": list(state_file.files),
            "shapes": {},
            "dtypes": {},
            "sample_values": {}
        }
        # Record shape/dtype for every entry, plus a shallow description of
        # 0-d object arrays (the format produced by the new-style save).
        for key in state_file.files:
            arr = state_file[key]
            debug_info["shapes"][key] = str(arr.shape)
            debug_info["dtypes"][key] = str(arr.dtype)
            if arr.ndim == 0 and arr.dtype == np.dtype('O'):
                obj = arr.item()
                if isinstance(obj, dict):
                    debug_info["sample_values"][key] = {"type": "dict", "keys": list(obj.keys())}
                else:
                    debug_info["sample_values"][key] = {"type": str(type(obj))}
        return debug_info
    except Exception as e:
        import traceback
        return {"error": str(e), "traceback": traceback.format_exc()}
# Gradio UI: cached outputs are deleted after 600 seconds.
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    # UTPL - Conversión de Multiples Imágenes a objetos 3D usando IA
    ### Tesis: *"Objetos tridimensionales creados por IA: Innovación en entornos virtuales"*
    **Autor:** Carlos Vargas
    **Base técnica:** Adaptación de [TRELLIS](https://trellis3d.github.io/) (herramienta de código abierto para generación 3D)
    **Propósito educativo:** Demostraciones académicas e Investigación en modelado 3D automático
    """)
    with gr.Row():
        with gr.Column():
            with gr.Tabs() as input_tabs:
                with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
                    multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
            generate_btn = gr.Button("Generate")
            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
            extract_glb_btn = gr.Button("Extract GLB", interactive=False)
        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            model_output = gr.Model3D(label="Extracted GLB", height=300)
            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
    # Holds the packed generation state between the generate and extract steps.
    output_buf = gr.State()
    # Event handlers
    demo.load(start_session)
    demo.unload(end_session)
    # Preprocess gallery images as soon as they are uploaded.
    multiimage_prompt.upload(
        preprocess_images,
        inputs=[multiimage_prompt],
        outputs=[multiimage_prompt],
    )
    # Generate: resolve the seed, run the pipeline, then enable GLB extraction.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[multiimage_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
        outputs=[output_buf, video_output],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[extract_glb_btn],
    )
    # Clearing the preview disables extraction until the next generation.
    video_output.clear(
        lambda: gr.Button(interactive=False),
        outputs=[extract_glb_btn],
    )
    # Extract GLB, then enable the download button.
    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_glb],
    )
    model_output.clear(
        lambda: gr.Button(interactive=False),
        outputs=[download_glb],
    )
# Mount the Gradio UI onto the FastAPI app at the root path.
app = gr.mount_gradio_app(app, demo, path="/")

# Entry point
if __name__ == "__main__":
    # NOTE(review): ``pipeline`` is created only under this guard, yet the
    # handlers above reference it as a global.  If this module is ever served
    # by an external runner (e.g. ``uvicorn module:app``) the pipeline will
    # be undefined — confirm the intended launch mode.
    pipeline = TrellisImageTo3DPipeline.from_pretrained("cavargas10/TRELLIS")
    pipeline.cuda()
    try:
        # Warm up the background-removal (rembg) path so the first real
        # request is fast; failures here only cost first-request latency.
        pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
    except Exception:
        # Was a bare ``except:``; narrowed so KeyboardInterrupt/SystemExit
        # still propagate during startup.
        pass
    # Launch via uvicorn with a single worker — the pipeline holds GPU state
    # that cannot be shared across processes.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        workers=1,
        loop="uvloop",
        http="httptools"
    )