File size: 3,377 Bytes
dd90875
eddf346
 
 
7514f6b
 
 
 
 
eddf346
 
dd90875
 
eddf346
 
 
 
dd90875
eddf346
dd90875
eddf346
b09bc29
dd90875
 
eddf346
 
dd90875
 
eddf346
dd90875
 
 
 
 
eddf346
 
 
 
dd90875
eddf346
dd90875
 
eddf346
 
 
dd90875
eddf346
 
dd90875
 
 
 
 
 
 
eddf346
 
dd90875
 
 
eddf346
dd90875
eddf346
dd90875
eddf346
dd90875
 
 
 
eddf346
dd90875
eddf346
dd90875
eddf346
dd90875
eddf346
dd90875
 
 
 
 
 
 
eddf346
dd90875
 
 
 
 
eddf346
dd90875
 
eddf346
dd90875
eddf346
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
# Redirect the Hugging Face hub cache and the Numba cache to writable
# directories under /tmp, creating each directory up front.
# Must happen before Trellis / huggingface_hub is imported (see imports below).
for _env_var, _cache_dir in (
    ("HUGGINGFACE_HUB_CACHE", "/tmp/huggingface_cache"),
    ("NUMBA_CACHE_DIR", "/tmp/numba_cache"),
):
    os.environ[_env_var] = _cache_dir
    os.makedirs(_cache_dir, exist_ok=True)
# Do not leak the loop variables into the module namespace.
del _env_var, _cache_dir

import hmac
import io
import logging
import shutil
import tempfile

import torch
import numpy as np
from PIL import Image
from fastapi import FastAPI, File, UploadFile, HTTPException, Header
from fastapi import BackgroundTasks
from fastapi.responses import FileResponse

# Trellis pipeline imports
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.utils import postprocessing_utils

# Use /tmp/space_tmp for user data & avoid read-only /app
TMP_DIR = "/tmp/space_tmp"
os.makedirs(TMP_DIR, exist_ok=True)

# Read your HF_API_KEY from Secrets (set in Space settings).
# Checked *before* loading the model so that a misconfigured Space fails
# immediately rather than after the (presumably slow) weight download/load.
HF_API_KEY = os.getenv("HF_API_KEY")
if not HF_API_KEY:
    raise RuntimeError("No HF_API_KEY found. Please set a secret in your Space settings.")

# Load the pipeline (no extra args like cache_dir) and move it to the GPU.
pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
pipeline.cuda()

# Warm-up: run Trellis' preprocessing once on a blank 512x512 RGB image at
# startup. The intent is to avoid cold-start latency on the first request
# (presumably by triggering lazy initialization inside preprocess_image).
# This is best-effort. A failure must not stop the server, but it is logged
# instead of silently swallowed. Ctrl-C / SystemExit are no longer caught.
try:
    pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
except Exception:
    logging.getLogger(__name__).warning(
        "Pipeline warm-up failed; continuing without it.", exc_info=True
    )

# FastAPI App
app = FastAPI()

# (Optional) Limit max input image size: (max_width, max_height) in pixels.
MAX_IMAGE_SIZE = (1024, 1024)

def preprocess_image(image: Image.Image) -> Image.Image:
    """Bound the image size, then delegate to Trellis' own preprocessing.

    The image is downscaled with ``thumbnail`` so that it fits within
    ``MAX_IMAGE_SIZE`` while keeping its aspect ratio. This keeps memory
    usage in check. The value returned by ``pipeline.preprocess_image`` is
    passed back unchanged.

    Note: ``thumbnail`` works in place, so the caller's ``image`` object is
    resized as a side effect.
    """
    image.thumbnail(MAX_IMAGE_SIZE)  # in-place, aspect-preserving downscale
    prepared = pipeline.preprocess_image(image)
    return prepared

# Accepted upload extensions. The leading dot matters: without it a name
# such as "notapng" would pass the check.
_ALLOWED_EXTENSIONS = (".png", ".jpg", ".jpeg")


def _is_authorized(authorization: str) -> bool:
    """Return True when *authorization* equals ``"Bearer <HF_API_KEY>"``.

    *authorization* is None when the client sent no ``Authorization`` header.
    The comparison goes through ``hmac.compare_digest``, so the response time
    does not reveal how much of the key matched.
    """
    if authorization is None:
        return False
    expected = f"Bearer {HF_API_KEY}"
    # Compare as bytes: compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(authorization.encode("utf-8"), expected.encode("utf-8"))


@app.post("/generate_3d/")
async def generate_3d(
    image: UploadFile = File(...),
    authorization: str = Header(None)
):
    """Accept an image upload and return a .glb file of the 3D model.

    Args:
        image: Uploaded PNG or JPG image. The type is checked by filename
            extension and then by actually decoding it.
        authorization: ``Authorization`` header. It must be
            ``"Bearer <HF_API_KEY>"``.

    Returns:
        A ``FileResponse`` that streams the generated model as ``sample.glb``
        (``model/gltf-binary``). Its per-request scratch directory is deleted
        after the response has been sent.

    Raises:
        HTTPException: 403 for a missing or invalid API key. 400 when the file
            is not named like a PNG/JPG or cannot be decoded as an image.
    """
    # NOTE(review): this handler is ``async`` but everything below is blocking
    # (image decoding, GPU inference, GLB export). It therefore occupies the
    # event loop for the whole request, and requests run one at a time.

    # Enforce HF_API_KEY check
    if not _is_authorized(authorization):
        raise HTTPException(status_code=403, detail="Invalid API key")

    # Require PNG/JPG. The client-supplied filename may be None and is used
    # ONLY for this check, never as a filesystem path (path traversal).
    filename = (image.filename or "").lower()
    if not filename.endswith(_ALLOWED_EXTENSIONS):
        raise HTTPException(status_code=400, detail="Upload PNG or JPG images.")

    # Decode the upload straight from memory; it does not need to be on disk.
    data = await image.read()
    try:
        img = Image.open(io.BytesIO(data)).convert("RGBA")
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError and truncated-file errors are OSError subclasses.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid PNG or JPG image."
        ) from exc

    # Per-request scratch directory. FileResponse reads the file only after
    # this handler returns, so a shared fixed path could be overwritten by
    # another request before it is sent.
    work_dir = tempfile.mkdtemp(dir=TMP_DIR)
    try:
        processed = preprocess_image(img)

        # Run Trellis pipeline
        outputs = pipeline.run(
            processed,
            seed=np.random.randint(0, np.iinfo(np.int32).max),
            formats=["gaussian", "mesh"],
            preprocess_image=False,  # already done by preprocess_image() above
            sparse_structure_sampler_params={"steps": 12, "cfg_strength": 7.5},
            slat_sampler_params={"steps": 12, "cfg_strength": 3.0},
        )

        # Extract and save the GLB
        gs, mesh = outputs["gaussian"][0], outputs["mesh"][0]
        glb = postprocessing_utils.to_glb(gs, mesh, simplify=0.95, texture_size=1024, verbose=False)
        glb_path = os.path.join(work_dir, "sample.glb")
        glb.export(glb_path)
    except BaseException:
        # Nothing will be returned, so remove the scratch directory now.
        shutil.rmtree(work_dir, ignore_errors=True)
        raise
    finally:
        # Release cached GPU memory whether or not generation succeeded.
        torch.cuda.empty_cache()

    # Delete the scratch directory only after the file has been streamed.
    cleanup = BackgroundTasks()
    cleanup.add_task(shutil.rmtree, work_dir, ignore_errors=True)

    # Return the GLB to the client
    return FileResponse(
        glb_path,
        media_type="model/gltf-binary",
        filename="sample.glb",
        background=cleanup,
    )

# Entry point for running the service directly, e.g. locally or when the
# Docker CMD is overridden. Listens on all interfaces; the port comes from the
# PORT environment variable and defaults to 7860.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))