Spaces:
Paused
Paused
| import os | |
# Point the Hugging Face hub cache and the numba cache at writable directories
# under /tmp (creating them if needed). This must happen before Trellis /
# huggingface_hub is imported further down.
for _env_var, _cache_dir in (
    ("HUGGINGFACE_HUB_CACHE", "/tmp/huggingface_cache"),
    ("NUMBA_CACHE_DIR", "/tmp/numba_cache"),
):
    os.environ[_env_var] = _cache_dir
    os.makedirs(_cache_dir, exist_ok=True)
del _env_var, _cache_dir
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from fastapi import FastAPI, File, UploadFile, HTTPException, Header | |
| from fastapi.responses import FileResponse | |
| # Trellis pipeline imports | |
| from trellis.pipelines import TrellisImageTo3DPipeline | |
| from trellis.utils import postprocessing_utils | |
# Scratch space for per-request files. It lives under /tmp because /app is
# read-only in the Space container.
TMP_DIR = os.path.join("/tmp", "space_tmp")
os.makedirs(TMP_DIR, exist_ok=True)
import logging

# Read the HF_API_KEY secret (set in the Space settings). Checked before the
# model is loaded so a misconfigured Space fails immediately instead of after
# downloading/loading the weights.
HF_API_KEY = os.getenv("HF_API_KEY")
if not HF_API_KEY:
    raise RuntimeError("No HF_API_KEY found. Please set a secret in your Space settings.")

# Load the pipeline (no extra args like cache_dir) and move it to the GPU.
pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
pipeline.cuda()

# Warm up preprocessing on a blank 512x512 RGB image, intended to reduce
# cold-start latency on the first request. Best-effort: a failure must not stop
# the app from starting, but it is logged rather than silently swallowed (the
# previous bare `except:` also hid KeyboardInterrupt/SystemExit).
try:
    pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
except Exception:
    logging.getLogger(__name__).warning(
        "Pipeline warm-up failed; continuing without it.", exc_info=True
    )
# Optional cap on input size: uploads are shrunk to fit inside this
# (width, height) box before Trellis preprocessing.
MAX_IMAGE_SIZE = (1024, 1024)

# The ASGI application object (passed to uvicorn.run in the __main__ guard).
app = FastAPI()
def preprocess_image(image: Image.Image) -> Image.Image:
    """Resize large images to keep memory usage in check, then let Trellis do its own preprocessing.

    The image is shrunk (aspect ratio preserved) to fit within MAX_IMAGE_SIZE and
    then passed to ``pipeline.preprocess_image``, whose result is returned.

    The caller's image is not modified: ``Image.thumbnail`` resizes in place, so
    it is applied to a copy.
    """
    resized = image.copy()
    resized.thumbnail(MAX_IMAGE_SIZE)
    return pipeline.preprocess_image(resized)
import contextlib
import io
import secrets
import uuid

# Upload extensions accepted by generate_3d, compared case-insensitively.
# Including the dot means a name like "notpng" is rejected.
_ALLOWED_EXTENSIONS = (".png", ".jpg", ".jpeg")


def _is_authorized(authorization) -> bool:
    """Return True if *authorization* equals ``"Bearer <HF_API_KEY>"``.

    *authorization* is the raw ``Authorization`` header value, or None when the
    header is absent. The comparison is constant-time so response timing does
    not reveal how much of the key a caller guessed. Both sides are UTF-8
    encoded because ``secrets.compare_digest`` rejects non-ASCII ``str``.
    """
    if not authorization:
        return False
    expected = f"Bearer {HF_API_KEY}"
    return secrets.compare_digest(authorization.encode("utf-8"), expected.encode("utf-8"))


def _decode_image(data: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGBA image.

    Raises:
        HTTPException: 400 if Pillow cannot decode the data (corrupt or
            non-image content, or an image over Pillow's decompression-bomb
            limit).
    """
    try:
        with Image.open(io.BytesIO(data)) as img:
            # convert() loads the full image, so truncated files fail here too.
            return img.convert("RGBA")
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(status_code=400, detail="Could not decode the uploaded image.") from err


def _remove_file(path: str) -> None:
    """Delete *path* after the response has been sent; a missing file is ignored."""
    with contextlib.suppress(FileNotFoundError):
        os.remove(path)


# NOTE(review): this handler previously had no route decorator, so FastAPI never
# exposed it. The path below is an assumption -- confirm it matches what clients call.
@app.post("/generate_3d/")
async def generate_3d(
    image: UploadFile = File(...),
    authorization: str = Header(None),
):
    """Accept an image upload and return a .glb file of the 3D model.

    Args:
        image: Uploaded PNG/JPG file.
        authorization: ``Authorization`` header; must be ``Bearer <HF_API_KEY>``.

    Returns:
        FileResponse streaming the generated GLB (download name ``sample.glb``).

    Raises:
        HTTPException: 403 for a missing or wrong API key. 400 for a filename
            that is not PNG/JPG, or for image data that cannot be decoded.
    """
    from fastapi import BackgroundTasks  # only needed to schedule file cleanup

    # Enforce HF_API_KEY check
    if not _is_authorized(authorization):
        raise HTTPException(status_code=403, detail="Invalid API key")
    # Require PNG/JPG; filename may be None if the client omitted it.
    if not (image.filename or "").lower().endswith(_ALLOWED_EXTENSIONS):
        raise HTTPException(status_code=400, detail="Upload PNG or JPG images.")

    # Decode in memory rather than writing the upload to TMP_DIR under the
    # client-supplied filename, which allowed path traversal ("../") and let
    # concurrent uploads with the same name overwrite each other.
    img = _decode_image(await image.read())

    # NOTE(review): preprocessing and pipeline.run are synchronous calls inside
    # an async handler, so they block the event loop for the whole generation.
    try:
        processed = preprocess_image(img)
        outputs = pipeline.run(
            processed,
            seed=np.random.randint(0, np.iinfo(np.int32).max),
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={"steps": 12, "cfg_strength": 7.5},
            slat_sampler_params={"steps": 12, "cfg_strength": 3.0},
        )
        # Extract and save the GLB
        gs, mesh = outputs["gaussian"][0], outputs["mesh"][0]
        glb = postprocessing_utils.to_glb(gs, mesh, simplify=0.95, texture_size=1024, verbose=False)
        # Per-request filename: a shared "sample.glb" could be overwritten by
        # another request while this response is still being streamed.
        glb_path = os.path.join(TMP_DIR, f"{uuid.uuid4().hex}.glb")
        glb.export(glb_path)
    finally:
        # Release cached GPU memory even when generation fails.
        torch.cuda.empty_cache()

    # Delete the GLB once it has been sent so TMP_DIR does not grow unbounded.
    cleanup = BackgroundTasks()
    cleanup.add_task(_remove_file, glb_path)
    return FileResponse(
        glb_path,
        media_type="model/gltf-binary",
        filename="sample.glb",
        background=cleanup,
    )
# Entry point for running locally or when overriding the Docker CMD.
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))