demohug commited on
Commit
92318e3
·
1 Parent(s): 701c397
Files changed (4) hide show
  1. .dockerignore +13 -0
  2. Dockerfile +24 -0
  3. api.py +217 -0
  4. requirements.txt +5 -0
.dockerignore ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .git
2
+ .gitignore
3
+ __pycache__
4
+ *.pyc
5
+ *.pyo
6
+ *.pyd
7
+ .Python
8
+ env/
9
+ venv/
10
+ .env
11
+ .venv
12
+ tmp/
13
+ *.log
Dockerfile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /code
4
+
5
+ # 安装系统依赖
6
+ RUN apt-get update && apt-get install -y \
7
+ build-essential \
8
+ git \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ # 复制依赖文件
12
+ COPY requirements.txt .
13
+
14
+ # 安装 Python 依赖
15
+ RUN pip install --no-cache-dir -r requirements.txt
16
+
17
+ # 复制应用代码
18
+ COPY . .
19
+
20
+ # 暴露端口
21
+ EXPOSE 7860
22
+
23
+ # 启动命令
24
+ CMD ["python", "api.py"]
api.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import FileResponse
4
+ import uvicorn
5
+ import os
6
+ import shutil
7
+ import torch
8
+ import numpy as np
9
+ from PIL import Image
10
+ from typing import List, Optional
11
+ from pydantic import BaseModel
12
+ import imageio
13
+ from trellis.pipelines import TrellisImageTo3DPipeline
14
+ from trellis.representations import Gaussian, MeshExtractResult
15
+ from trellis.utils import render_utils, postprocessing_utils
16
+ from easydict import EasyDict as edict
17
+
18
+ app = FastAPI(title="TRELLIS 3D API")
19
+
20
+ # 添加 CORS 中间件
21
+ app.add_middleware(
22
+ CORSMiddleware,
23
+ allow_origins=["*"], # 在生产环境中应该设置具体的域名
24
+ allow_credentials=True,
25
+ allow_methods=["*"],
26
+ allow_headers=["*"],
27
+ )
28
+
29
+ # 配置
30
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
31
+ os.makedirs(TMP_DIR, exist_ok=True)
32
+ MAX_SEED = np.iinfo(np.int32).max
33
+
34
+ # 初始化 pipeline
35
+ pipeline = TrellisImageTo3DPipeline.from_pretrained("cavargas10/TRELLIS")
36
+ pipeline.cuda()
37
+ try:
38
+ pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # 预加载 rembg
39
+ except:
40
+ pass
41
+
42
+ class GenerationParams(BaseModel):
43
+ seed: int = 0
44
+ ss_guidance_strength: float = 7.5
45
+ ss_sampling_steps: int = 12
46
+ slat_guidance_strength: float = 3.0
47
+ slat_sampling_steps: int = 12
48
+ multiimage_algo: str = "stochastic"
49
+
50
+ class GLBParams(BaseModel):
51
+ mesh_simplify: float = 0.95
52
+ texture_size: int = 1024
53
+
54
+ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
55
+ return {
56
+ 'gaussian': {
57
+ **gs.init_params,
58
+ '_xyz': gs._xyz.cpu().numpy(),
59
+ '_features_dc': gs._features_dc.cpu().numpy(),
60
+ '_scaling': gs._scaling.cpu().numpy(),
61
+ '_rotation': gs._rotation.cpu().numpy(),
62
+ '_opacity': gs._opacity.cpu().numpy(),
63
+ },
64
+ 'mesh': {
65
+ 'vertices': mesh.vertices.cpu().numpy(),
66
+ 'faces': mesh.faces.cpu().numpy(),
67
+ },
68
+ }
69
+
70
+ def unpack_state(state: dict) -> tuple[Gaussian, edict]:
71
+ gs = Gaussian(
72
+ aabb=state['gaussian']['aabb'],
73
+ sh_degree=state['gaussian']['sh_degree'],
74
+ mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
75
+ scaling_bias=state['gaussian']['scaling_bias'],
76
+ opacity_bias=state['gaussian']['opacity_bias'],
77
+ scaling_activation=state['gaussian']['scaling_activation'],
78
+ )
79
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
80
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
81
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
82
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
83
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
84
+ mesh = edict(
85
+ vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
86
+ faces=torch.tensor(state['mesh']['faces'], device='cuda'),
87
+ )
88
+ return gs, mesh
89
+
90
+ @app.post("/generate")
91
+ async def generate_3d(
92
+ files: List[UploadFile] = File(...),
93
+ params: GenerationParams = None
94
+ ):
95
+ if not params:
96
+ params = GenerationParams()
97
+
98
+ # 创建临时目录
99
+ session_id = str(np.random.randint(0, MAX_SEED))
100
+ user_dir = os.path.join(TMP_DIR, session_id)
101
+ os.makedirs(user_dir, exist_ok=True)
102
+
103
+ try:
104
+ # 处理上传的图片
105
+ images = []
106
+ for file in files:
107
+ image = Image.open(file.file)
108
+ images.append(image)
109
+
110
+ # 运行生成
111
+ outputs = pipeline.run_multi_image(
112
+ images,
113
+ seed=params.seed,
114
+ formats=["gaussian", "mesh"],
115
+ preprocess_image=False,
116
+ sparse_structure_sampler_params={
117
+ "steps": params.ss_sampling_steps,
118
+ "cfg_strength": params.ss_guidance_strength,
119
+ },
120
+ slat_sampler_params={
121
+ "steps": params.slat_sampling_steps,
122
+ "cfg_strength": params.slat_guidance_strength,
123
+ },
124
+ mode=params.multiimage_algo,
125
+ )
126
+
127
+ # 生成预览视频
128
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
129
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
130
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
131
+ video_path = os.path.join(user_dir, 'preview.mp4')
132
+ imageio.mimsave(video_path, video, fps=15)
133
+
134
+ # 保存状态
135
+ state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
136
+ state_path = os.path.join(user_dir, 'state.npz')
137
+ np.savez(state_path, **state)
138
+
139
+ return {
140
+ "session_id": session_id,
141
+ "preview_url": f"/preview/{session_id}",
142
+ "state_url": f"/state/{session_id}"
143
+ }
144
+
145
+ except Exception as e:
146
+ shutil.rmtree(user_dir)
147
+ raise HTTPException(status_code=500, detail=str(e))
148
+
149
+ @app.post("/extract_glb")
150
+ async def extract_glb(
151
+ session_id: str,
152
+ params: GLBParams = None
153
+ ):
154
+ if not params:
155
+ params = GLBParams()
156
+
157
+ user_dir = os.path.join(TMP_DIR, session_id)
158
+ if not os.path.exists(user_dir):
159
+ raise HTTPException(status_code=404, detail="Session not found")
160
+
161
+ try:
162
+ # 加载状态
163
+ state_path = os.path.join(user_dir, 'state.npz')
164
+ state = np.load(state_path)
165
+ state = {k: state[k] for k in state.files}
166
+
167
+ # 生成 GLB
168
+ gs, mesh = unpack_state(state)
169
+ glb = postprocessing_utils.to_glb(
170
+ gs,
171
+ mesh,
172
+ simplify=params.mesh_simplify,
173
+ texture_size=params.texture_size,
174
+ verbose=False
175
+ )
176
+
177
+ glb_path = os.path.join(user_dir, 'model.glb')
178
+ glb.export(glb_path)
179
+
180
+ return {"glb_url": f"/glb/{session_id}"}
181
+
182
+ except Exception as e:
183
+ raise HTTPException(status_code=500, detail=str(e))
184
+
185
+ @app.get("/preview/{session_id}")
186
+ async def get_preview(session_id: str):
187
+ preview_path = os.path.join(TMP_DIR, session_id, 'preview.mp4')
188
+ if not os.path.exists(preview_path):
189
+ raise HTTPException(status_code=404, detail="Preview not found")
190
+ return FileResponse(preview_path)
191
+
192
+ @app.get("/glb/{session_id}")
193
+ async def get_glb(session_id: str):
194
+ glb_path = os.path.join(TMP_DIR, session_id, 'model.glb')
195
+ if not os.path.exists(glb_path):
196
+ raise HTTPException(status_code=404, detail="GLB not found")
197
+ return FileResponse(glb_path)
198
+
199
+ @app.get("/state/{session_id}")
200
+ async def get_state(session_id: str):
201
+ state_path = os.path.join(TMP_DIR, session_id, 'state.npz')
202
+ if not os.path.exists(state_path):
203
+ raise HTTPException(status_code=404, detail="State not found")
204
+ return FileResponse(state_path)
205
+
206
+ if __name__ == "__main__":
207
+ # 在 Hugging Face Spaces 中,我们需要使用 0.0.0.0 作为主机
208
+ # 端口 7860 是 Hugging Face Spaces 的默认端口
209
+ uvicorn.run(
210
+ app,
211
+ host="0.0.0.0",
212
+ port=7860,
213
+ # 添加以下配置以提高性能
214
+ workers=1, # 由于 GPU 限制,使用单工作进程
215
+ loop="uvloop",
216
+ http="httptools"
217
+ )
requirements.txt CHANGED
@@ -20,6 +20,11 @@ git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057
20
  xformers==0.0.27.post2
21
  spconv-cu120==2.3.6
22
  transformers==4.46.3
 
 
 
 
 
23
  https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
24
  https://huggingface.co/spaces/cavargas10/TRELLIS-Multiple3D/resolve/main/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl?download=true
25
  https://huggingface.co/spaces/cavargas10/TRELLIS-Multiple3D/resolve/main/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl?download=true
 
20
  xformers==0.0.27.post2
21
  spconv-cu120==2.3.6
22
  transformers==4.46.3
23
+ fastapi==0.110.0
24
+ uvicorn[standard]==0.27.1
25
+ python-multipart==0.0.9
26
+ uvloop==0.19.0
27
+ httptools==0.6.1
28
  https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
29
  https://huggingface.co/spaces/cavargas10/TRELLIS-Multiple3D/resolve/main/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl?download=true
30
  https://huggingface.co/spaces/cavargas10/TRELLIS-Multiple3D/resolve/main/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl?download=true