turome-learning committed on
Commit
eddf346
·
verified ·
1 Parent(s): 7e0992b

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +42 -32
main.py CHANGED
@@ -1,44 +1,46 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException, Header
2
- from fastapi.responses import FileResponse
3
- import torch
4
  import os
 
 
 
 
 
5
  import numpy as np
6
- import cv2
7
  from PIL import Image
8
- from typing import List
 
 
 
9
  from trellis.pipelines import TrellisImageTo3DPipeline
10
- from trellis.utils import render_utils, postprocessing_utils
11
- from trellis.representations import Gaussian, MeshExtractResult
12
- import imageio
13
 
14
- # Define working directories
15
  TMP_DIR = "/tmp/space_tmp"
16
  os.makedirs(TMP_DIR, exist_ok=True)
17
 
18
- # Define a writable cache directory
19
- cache_dir = "/tmp/huggingface_cache"
20
- os.makedirs(cache_dir, exist_ok=True)
21
-
22
- # ✅ Manually specify cache directory when loading the model
23
- pipeline = TrellisImageTo3DPipeline.from_pretrained(
24
- "JeffreyXiang/TRELLIS-image-large",
25
- cache_dir=cache_dir
26
- )
27
  pipeline.cuda()
28
 
29
- # Preload model (to prevent cold starts)
30
  try:
31
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
32
  except:
33
  pass
34
 
35
- # API Key for private access
36
- HF_API_KEY = os.getenv("HF_API_KEY", "your-secure-api-key")
 
 
37
 
 
38
  app = FastAPI()
39
 
 
 
 
40
  def preprocess_image(image: Image.Image) -> Image.Image:
41
- """Preprocess a single input image using the Trellis pipeline."""
 
42
  return pipeline.preprocess_image(image)
43
 
44
  @app.post("/generate_3d/")
@@ -46,27 +48,27 @@ async def generate_3d(
46
  image: UploadFile = File(...),
47
  authorization: str = Header(None)
48
  ):
49
- """Accepts an image upload, runs inference, and returns a GLB file."""
50
-
51
- # 🔒 API Key authentication
52
  if authorization != f"Bearer {HF_API_KEY}":
53
  raise HTTPException(status_code=403, detail="Invalid API key")
54
 
 
55
  if not image.filename.lower().endswith(("png", "jpg", "jpeg")):
56
- raise HTTPException(status_code=400, detail="Invalid image format. Upload a PNG or JPG.")
57
 
58
- # Save the uploaded image
59
  image_path = os.path.join(TMP_DIR, image.filename)
60
  with open(image_path, "wb") as f:
61
  f.write(image.file.read())
62
 
63
- # Load and preprocess the image
64
  img = Image.open(image_path).convert("RGBA")
65
- processed_image = preprocess_image(img)
66
 
67
- # Run the Trellis pipeline
68
  outputs = pipeline.run(
69
- processed_image,
70
  seed=np.random.randint(0, np.iinfo(np.int32).max),
71
  formats=["gaussian", "mesh"],
72
  preprocess_image=False,
@@ -74,12 +76,20 @@ async def generate_3d(
74
  slat_sampler_params={"steps": 12, "cfg_strength": 3.0},
75
  )
76
 
77
- # Extract the GLB file
78
  gs, mesh = outputs["gaussian"][0], outputs["mesh"][0]
79
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=0.95, texture_size=1024, verbose=False)
80
  glb_path = os.path.join(TMP_DIR, "sample.glb")
81
  glb.export(glb_path)
82
 
 
83
  torch.cuda.empty_cache()
84
 
 
85
  return FileResponse(glb_path, media_type="model/gltf-binary", filename="sample.glb")
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ # Must happen before Trellis / huggingface_hub is imported
3
+ os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_cache"
4
+ os.makedirs("/tmp/huggingface_cache", exist_ok=True)
5
+
6
+ import torch
7
  import numpy as np
 
8
  from PIL import Image
9
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Header
10
+ from fastapi.responses import FileResponse
11
+
12
+ # Trellis pipeline imports
13
  from trellis.pipelines import TrellisImageTo3DPipeline
14
+ from trellis.utils import postprocessing_utils
 
 
15
 
16
+ # Use /tmp/space_tmp for user data & avoid read-only /app
17
  TMP_DIR = "/tmp/space_tmp"
18
  os.makedirs(TMP_DIR, exist_ok=True)
19
 
20
+ # Load the pipeline (no extra args like cache_dir)
21
+ pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
 
 
 
 
 
 
 
22
  pipeline.cuda()
23
 
24
+ # Preload the model (avoids cold-start latencies)
25
  try:
26
  pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
27
  except:
28
  pass
29
 
30
+ # Read your HF_API_KEY from Secrets (set in Space settings)
31
+ HF_API_KEY = os.getenv("HF_API_KEY")
32
+ if not HF_API_KEY:
33
+ raise RuntimeError("No HF_API_KEY found. Please set a secret in your Space settings.")
34
 
35
+ # FastAPI App
36
  app = FastAPI()
37
 
38
+ # (Optional) Limit max input image size
39
+ MAX_IMAGE_SIZE = (1024, 1024)
40
+
41
  def preprocess_image(image: Image.Image) -> Image.Image:
42
+ """Resize large images to keep memory usage in check, then let Trellis do its own preprocessing."""
43
+ image.thumbnail(MAX_IMAGE_SIZE)
44
  return pipeline.preprocess_image(image)
45
 
46
  @app.post("/generate_3d/")
 
48
  image: UploadFile = File(...),
49
  authorization: str = Header(None)
50
  ):
51
+ """Accept an image upload and return a .glb file of the 3D model."""
52
+ # Enforce HF_API_KEY check
 
53
  if authorization != f"Bearer {HF_API_KEY}":
54
  raise HTTPException(status_code=403, detail="Invalid API key")
55
 
56
+ # Require PNG/JPG
57
  if not image.filename.lower().endswith(("png", "jpg", "jpeg")):
58
+ raise HTTPException(status_code=400, detail="Upload PNG or JPG images.")
59
 
60
+ # Save upload to /tmp
61
  image_path = os.path.join(TMP_DIR, image.filename)
62
  with open(image_path, "wb") as f:
63
  f.write(image.file.read())
64
 
65
+ # Preprocess the image
66
  img = Image.open(image_path).convert("RGBA")
67
+ processed = preprocess_image(img)
68
 
69
+ # Run Trellis pipeline
70
  outputs = pipeline.run(
71
+ processed,
72
  seed=np.random.randint(0, np.iinfo(np.int32).max),
73
  formats=["gaussian", "mesh"],
74
  preprocess_image=False,
 
76
  slat_sampler_params={"steps": 12, "cfg_strength": 3.0},
77
  )
78
 
79
+ # Extract and save the GLB
80
  gs, mesh = outputs["gaussian"][0], outputs["mesh"][0]
81
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=0.95, texture_size=1024, verbose=False)
82
  glb_path = os.path.join(TMP_DIR, "sample.glb")
83
  glb.export(glb_path)
84
 
85
+ # Clear GPU memory
86
  torch.cuda.empty_cache()
87
 
88
+ # Return the GLB to the client
89
  return FileResponse(glb_path, media_type="model/gltf-binary", filename="sample.glb")
90
+
91
+ # If you want to run locally or override CMD in Docker:
92
+ if __name__ == "__main__":
93
+ import uvicorn
94
+ port = int(os.environ.get("PORT", "7860"))
95
+ uvicorn.run(app, host="0.0.0.0", port=port)