# main.py — TRELLIS image-to-3D FastAPI service (Hugging Face Space)
# (upstream page metadata: turome-learning, "Update main.py", commit 7514f6b verified)
import os
# Cache locations must be configured BEFORE trellis / huggingface_hub import,
# since those libraries read the env vars at import time. /tmp is the only
# writable area in this Space, so every cache is pointed there.
for _env_var, _cache_dir in (
    ("HUGGINGFACE_HUB_CACHE", "/tmp/huggingface_cache"),
    ("NUMBA_CACHE_DIR", "/tmp/numba_cache"),
):
    os.environ[_env_var] = _cache_dir
    os.makedirs(_cache_dir, exist_ok=True)
import torch
import numpy as np
from PIL import Image
from fastapi import FastAPI, File, UploadFile, HTTPException, Header
from fastapi.responses import FileResponse
# Trellis pipeline imports
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.utils import postprocessing_utils
# Use /tmp/space_tmp for user data & avoid read-only /app
TMP_DIR = "/tmp/space_tmp"
os.makedirs(TMP_DIR, exist_ok=True)

# Load the pipeline (no extra args like cache_dir) and move it to the GPU.
pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
pipeline.cuda()

# Warm up the model with a dummy black 512x512 image (avoids cold-start
# latency on the first real request). This is best-effort — a failure here is
# non-fatal — but a bare `except:` would also swallow KeyboardInterrupt /
# SystemExit, so catch Exception only.
try:
    pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
except Exception:
    pass

# Read your HF_API_KEY from Secrets (set in Space settings); fail fast at
# startup if it is missing rather than rejecting every request later.
HF_API_KEY = os.getenv("HF_API_KEY")
if not HF_API_KEY:
    raise RuntimeError("No HF_API_KEY found. Please set a secret in your Space settings.")

# FastAPI App
app = FastAPI()

# (Optional) Limit max input image size, (width, height) bound for PIL thumbnail
MAX_IMAGE_SIZE = (1024, 1024)
def preprocess_image(image: Image.Image) -> Image.Image:
    """Cap the upload's resolution, then delegate to Trellis preprocessing.

    `thumbnail` shrinks in place (aspect-ratio preserving, never enlarges),
    which keeps memory usage bounded before the pipeline's own step runs.
    """
    image.thumbnail(MAX_IMAGE_SIZE)
    bounded = image
    return pipeline.preprocess_image(bounded)
@app.post("/generate_3d/")
async def generate_3d(
    image: UploadFile = File(...),
    authorization: str = Header(None)
):
    """Accept an image upload and return a .glb file of the 3D model.

    Requires ``Authorization: Bearer <HF_API_KEY>``; accepts PNG/JPG only.
    Returns 403 on a bad key and 400 on an unsupported file type.
    """
    import secrets
    import uuid

    # Enforce HF_API_KEY check with a constant-time comparison so the key
    # cannot be probed byte-by-byte via response timing.
    expected = f"Bearer {HF_API_KEY}"
    if authorization is None or not secrets.compare_digest(authorization, expected):
        raise HTTPException(status_code=403, detail="Invalid API key")

    # Require PNG/JPG (filename may be missing entirely on some clients)
    if not image.filename or not image.filename.lower().endswith(("png", "jpg", "jpeg")):
        raise HTTPException(status_code=400, detail="Upload PNG or JPG images.")

    # Sanitize the client-supplied filename (basename strips any "../"
    # path-traversal components) and prefix a per-request UUID so concurrent
    # requests never overwrite each other's files.
    req_id = uuid.uuid4().hex
    safe_name = os.path.basename(image.filename)
    image_path = os.path.join(TMP_DIR, f"{req_id}_{safe_name}")

    # `await image.read()` instead of `image.file.read()` so the upload does
    # not block the event loop.
    with open(image_path, "wb") as f:
        f.write(await image.read())

    try:
        # Preprocess the image (RGBA so alpha-based background removal works)
        img = Image.open(image_path).convert("RGBA")
        processed = preprocess_image(img)

        # Run Trellis pipeline
        outputs = pipeline.run(
            processed,
            seed=np.random.randint(0, np.iinfo(np.int32).max),
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={"steps": 12, "cfg_strength": 7.5},
            slat_sampler_params={"steps": 12, "cfg_strength": 3.0},
        )

        # Extract and save the GLB under a request-unique path
        gs, mesh = outputs["gaussian"][0], outputs["mesh"][0]
        glb = postprocessing_utils.to_glb(gs, mesh, simplify=0.95, texture_size=1024, verbose=False)
        glb_path = os.path.join(TMP_DIR, f"{req_id}.glb")
        glb.export(glb_path)
    finally:
        # Free VRAM even when generation fails mid-way.
        torch.cuda.empty_cache()

    # Return the GLB to the client (download name stays "sample.glb")
    return FileResponse(glb_path, media_type="model/gltf-binary", filename="sample.glb")
# If you want to run locally or override CMD in Docker:
if __name__ == "__main__":
    import uvicorn

    # Honor the platform-provided PORT, defaulting to the HF Spaces port.
    serve_port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=serve_port)