demohug commited on
Commit
d9837fa
·
1 Parent(s): 571da3d
Files changed (2) hide show
  1. app.py +5 -3
  2. index.html +275 -0
app.py CHANGED
@@ -12,7 +12,7 @@ from PIL import Image
12
  from trellis.pipelines import TrellisImageTo3DPipeline
13
  from trellis.representations import Gaussian, MeshExtractResult
14
  from trellis.utils import render_utils, postprocessing_utils
15
- from fastapi import FastAPI, UploadFile, File, HTTPException
16
  from fastapi.middleware.cors import CORSMiddleware
17
  from fastapi.responses import FileResponse
18
  from pydantic import BaseModel
@@ -156,9 +156,11 @@ def extract_glb(
156
  @app.post("/api/generate")
157
  async def api_generate_3d(
158
  files: List[UploadFile] = File(...),
159
- params: GenerationParams = None
160
  ):
161
- if not params:
 
 
162
  params = GenerationParams()
163
 
164
  # 创建临时目录
 
12
  from trellis.pipelines import TrellisImageTo3DPipeline
13
  from trellis.representations import Gaussian, MeshExtractResult
14
  from trellis.utils import render_utils, postprocessing_utils
15
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Form
16
  from fastapi.middleware.cors import CORSMiddleware
17
  from fastapi.responses import FileResponse
18
  from pydantic import BaseModel
 
156
  @app.post("/api/generate")
157
  async def api_generate_3d(
158
  files: List[UploadFile] = File(...),
159
+ params: str = Form(None)
160
  ):
161
+ if params:
162
+ params = GenerationParams.parse_raw(params)
163
+ else:
164
  params = GenerationParams()
165
 
166
  # 创建临时目录
index.html ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="zh">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>TRELLIS 3D Demo</title>
7
+ <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
8
+ <style>
9
+ .preview-container {
10
+ max-width: 100%;
11
+ margin: 20px 0;
12
+ }
13
+ .preview-video {
14
+ max-width: 100%;
15
+ height: auto;
16
+ }
17
+ .loading {
18
+ display: none;
19
+ margin: 20px 0;
20
+ }
21
+ .progress {
22
+ display: none;
23
+ margin: 20px 0;
24
+ }
25
+ .error-message {
26
+ display: none;
27
+ color: red;
28
+ margin: 10px 0;
29
+ }
30
+ </style>
31
+ </head>
32
+ <body>
33
+ <div class="container mt-5">
34
+ <h1 class="mb-4">TRELLIS 3D Demo</h1>
35
+
36
+ <!-- 图片上传区域 -->
37
+ <div class="card mb-4">
38
+ <div class="card-body">
39
+ <h5 class="card-title">上传图片</h5>
40
+ <p class="card-text">请上传1-3张图片来生成3D模型</p>
41
+ <input type="file" class="form-control" id="imageInput" multiple accept="image/*">
42
+ <div class="mt-3">
43
+ <button class="btn btn-primary" id="generateBtn" disabled>生成3D模型</button>
44
+ </div>
45
+ </div>
46
+ </div>
47
+
48
+ <!-- 生成设置 -->
49
+ <div class="card mb-4">
50
+ <div class="card-body">
51
+ <h5 class="card-title">生成设置</h5>
52
+ <div class="row">
53
+ <div class="col-md-6">
54
+ <div class="mb-3">
55
+ <label class="form-label">Stage 1: Sparse Structure Generation</label>
56
+ <div class="mb-2">
57
+ <label class="form-label">Guidance Strength</label>
58
+ <input type="range" class="form-range" id="ssGuidanceStrength" min="0" max="10" step="0.1" value="7.5">
59
+ <span id="ssGuidanceStrengthValue">7.5</span>
60
+ </div>
61
+ <div class="mb-2">
62
+ <label class="form-label">Sampling Steps</label>
63
+ <input type="range" class="form-range" id="ssSamplingSteps" min="1" max="50" step="1" value="12">
64
+ <span id="ssSamplingStepsValue">12</span>
65
+ </div>
66
+ </div>
67
+ </div>
68
+ <div class="col-md-6">
69
+ <div class="mb-3">
70
+ <label class="form-label">Stage 2: Structured Latent Generation</label>
71
+ <div class="mb-2">
72
+ <label class="form-label">Guidance Strength</label>
73
+ <input type="range" class="form-range" id="slatGuidanceStrength" min="0" max="10" step="0.1" value="3.0">
74
+ <span id="slatGuidanceStrengthValue">3.0</span>
75
+ </div>
76
+ <div class="mb-2">
77
+ <label class="form-label">Sampling Steps</label>
78
+ <input type="range" class="form-range" id="slatSamplingSteps" min="1" max="50" step="1" value="12">
79
+ <span id="slatSamplingStepsValue">12</span>
80
+ </div>
81
+ </div>
82
+ </div>
83
+ </div>
84
+ <div class="mb-3">
85
+ <label class="form-label">Multi-image Algorithm</label>
86
+ <select class="form-select" id="multiimageAlgo">
87
+ <option value="stochastic">Stochastic</option>
88
+ <option value="multidiffusion">Multi-diffusion</option>
89
+ </select>
90
+ </div>
91
+ </div>
92
+ </div>
93
+
94
+ <!-- 加载状态 -->
95
+ <div class="loading text-center">
96
+ <div class="spinner-border text-primary" role="status">
97
+ <span class="visually-hidden">Loading...</span>
98
+ </div>
99
+ <p class="mt-2">正在生成3D模型,请稍候...</p>
100
+ </div>
101
+
102
+ <!-- 进度条 -->
103
+ <div class="progress">
104
+ <div class="progress-bar progress-bar-striped progress-bar-animated" role="progressbar" style="width: 0%"></div>
105
+ </div>
106
+
107
+ <!-- 错误信息 -->
108
+ <div class="error-message alert alert-danger"></div>
109
+
110
+ <!-- 预览区域 -->
111
+ <div class="preview-container">
112
+ <h3>预览</h3>
113
+ <video id="previewVideo" class="preview-video" controls style="display: none;"></video>
114
+ </div>
115
+
116
+ <!-- GLB 下载区域 -->
117
+ <div class="card mb-4" id="glbCard" style="display: none;">
118
+ <div class="card-body">
119
+ <h5 class="card-title">GLB 文件</h5>
120
+ <p class="card-text">3D模型已生成,点击下方按钮下载GLB文件</p>
121
+ <button class="btn btn-success" id="downloadGlbBtn">下载GLB文件</button>
122
+ </div>
123
+ </div>
124
+ </div>
125
+
126
+ <script>
127
+ const API_BASE_URL = 'https://demohug-trellis-multiple3d.hf.space';
128
+ let currentSessionId = null;
129
+
130
+ // 更新滑块值显示
131
+ document.querySelectorAll('input[type="range"]').forEach(input => {
132
+ const valueDisplay = document.getElementById(input.id + 'Value');
133
+ input.addEventListener('input', () => {
134
+ valueDisplay.textContent = input.value;
135
+ });
136
+ });
137
+
138
+ // 检查文件选择
139
+ document.getElementById('imageInput').addEventListener('change', function() {
140
+ const generateBtn = document.getElementById('generateBtn');
141
+ generateBtn.disabled = this.files.length < 1 || this.files.length > 3;
142
+ });
143
+
144
+ // 生成3D模型
145
+ document.getElementById('generateBtn').addEventListener('click', async function() {
146
+ const files = document.getElementById('imageInput').files;
147
+ if (files.length < 1 || files.length > 3) {
148
+ showError('请选择1-3张图片');
149
+ return;
150
+ }
151
+
152
+ // 显示加载状态
153
+ showLoading(true);
154
+ hideError();
155
+
156
+ try {
157
+ // 准备表单数据
158
+ const formData = new FormData();
159
+ for (let file of files) {
160
+ formData.append('files', file);
161
+ }
162
+
163
+ // 添加生成参数
164
+ const params = {
165
+ seed: 0,
166
+ ss_guidance_strength: parseFloat(document.getElementById('ssGuidanceStrength').value),
167
+ ss_sampling_steps: parseInt(document.getElementById('ssSamplingSteps').value),
168
+ slat_guidance_strength: parseFloat(document.getElementById('slatGuidanceStrength').value),
169
+ slat_sampling_steps: parseInt(document.getElementById('slatSamplingSteps').value),
170
+ multiimage_algo: document.getElementById('multiimageAlgo').value
171
+ };
172
+ formData.append('params', JSON.stringify(params));
173
+
174
+ // 发送生成请求
175
+ const response = await fetch(`${API_BASE_URL}/api/generate`, {
176
+ method: 'POST',
177
+ body: formData
178
+ });
179
+
180
+ if (!response.ok) {
181
+ throw new Error('生成失败');
182
+ }
183
+
184
+ const result = await response.json();
185
+ currentSessionId = result.session_id;
186
+
187
+ // 显示预览视频
188
+ const video = document.getElementById('previewVideo');
189
+ video.src = `${API_BASE_URL}${result.preview_url}`;
190
+ video.style.display = 'block';
191
+ video.play();
192
+
193
+ // 显示GLB下载按钮
194
+ document.getElementById('glbCard').style.display = 'block';
195
+
196
+ } catch (error) {
197
+ showError(error.message);
198
+ } finally {
199
+ showLoading(false);
200
+ }
201
+ });
202
+
203
+ // 下载GLB文件
204
+ document.getElementById('downloadGlbBtn').addEventListener('click', async function() {
205
+ if (!currentSessionId) {
206
+ showError('请先生成3D模型');
207
+ return;
208
+ }
209
+
210
+ showLoading(true);
211
+ hideError();
212
+
213
+ try {
214
+ // 提取GLB文件
215
+ const extractResponse = await fetch(`${API_BASE_URL}/api/extract_glb`, {
216
+ method: 'POST',
217
+ headers: {
218
+ 'Content-Type': 'application/json'
219
+ },
220
+ body: JSON.stringify({
221
+ session_id: currentSessionId,
222
+ params: {
223
+ mesh_simplify: 0.95,
224
+ texture_size: 1024
225
+ }
226
+ })
227
+ });
228
+
229
+ if (!extractResponse.ok) {
230
+ throw new Error('GLB提取失败');
231
+ }
232
+
233
+ const extractResult = await extractResponse.json();
234
+
235
+ // 下载GLB文件
236
+ const glbResponse = await fetch(`${API_BASE_URL}${extractResult.glb_url}`);
237
+ if (!glbResponse.ok) {
238
+ throw new Error('GLB下载失败');
239
+ }
240
+
241
+ const blob = await glbResponse.blob();
242
+ const url = window.URL.createObjectURL(blob);
243
+ const a = document.createElement('a');
244
+ a.href = url;
245
+ a.download = 'model.glb';
246
+ document.body.appendChild(a);
247
+ a.click();
248
+ window.URL.revokeObjectURL(url);
249
+ document.body.removeChild(a);
250
+
251
+ } catch (error) {
252
+ showError(error.message);
253
+ } finally {
254
+ showLoading(false);
255
+ }
256
+ });
257
+
258
+ // 辅助函数
259
+ function showLoading(show) {
260
+ document.querySelector('.loading').style.display = show ? 'block' : 'none';
261
+ document.getElementById('generateBtn').disabled = show;
262
+ }
263
+
264
+ function showError(message) {
265
+ const errorDiv = document.querySelector('.error-message');
266
+ errorDiv.textContent = message;
267
+ errorDiv.style.display = 'block';
268
+ }
269
+
270
+ function hideError() {
271
+ document.querySelector('.error-message').style.display = 'none';
272
+ }
273
+ </script>
274
+ </body>
275
+ </html>