File size: 16,282 Bytes
87b6aef
 
 
 
 
 
 
 
 
 
 
088412a
 
 
87b6aef
 
 
088412a
 
 
 
 
 
 
 
 
 
 
87b6aef
 
f93ea45
87b6aef
 
 
f93ea45
 
 
 
 
87b6aef
 
f93ea45
87b6aef
f93ea45
87b6aef
 
f93ea45
87b6aef
 
 
 
 
 
 
 
 
 
088412a
 
 
 
 
 
 
87b6aef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f93ea45
 
 
 
87b6aef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
088412a
 
 
 
 
 
 
 
 
 
 
 
87b6aef
088412a
 
 
 
 
 
 
 
87b6aef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
import gradio as gr
import spaces
import torch
from pathlib import Path
import tempfile
import os
import base64
from typing import Optional
import json

# SHARP モデルのインポート (遅延読み込み)
SHARP_AVAILABLE = False
SHARP_ERROR = None

try:
    from sharp import Sharp
    SHARP_AVAILABLE = True
    print("✅ SHARP module loaded successfully")
except ImportError as e:
    SHARP_ERROR = str(e)
    print(f"❌ SHARP import failed: {e}")
    import traceback
    traceback.print_exc()
except Exception as e:
    SHARP_ERROR = str(e)
    print(f"❌ Unexpected error loading SHARP: {e}")
    import traceback
    traceback.print_exc()

# グローバルモデルインスタンス (メモリ効率のため)
# 注意: ZeroGPUのマルチプロセッシングに対応するため、モジュールレベルで管理
_model = None

def get_model():
    """モデルインスタンスを取得(キャッシング)

    GPU workerプロセス内でモデルを初期化してキャッシュします。
    これによりpickling問題を回避します。
    """
    global _model
    if _model is None and SHARP_AVAILABLE:
        print("🔄 Initializing SHARP model in GPU worker...")
        _model = Sharp()
        print("✅ SHARP model initialized successfully")
    return _model

def _process_image_impl(image) -> tuple[Optional[str], str, str]:
    """
    画像から3D Gaussian Splatsを生成

    Args:
        image: PIL Image or numpy array

    Returns:
        tuple: (PLYファイルパス, ステータスメッセージ, PLYデータ(base64))
    """
    if not SHARP_AVAILABLE:
        error_msg = f"❌ SHARPモデルが利用できません\n\nエラー詳細: {SHARP_ERROR}\n\n"
        error_msg += "考えられる原因:\n"
        error_msg += "1. ml-sharpパッケージのインストール失敗\n"
        error_msg += "2. Python バージョンの非互換性\n"
        error_msg += "3. 依存関係の問題\n\n"
        error_msg += "ログを確認してください。"
        return None, error_msg, ""

    if image is None:
        return None, "❌ 画像をアップロードしてください", ""

    try:
        # 一時ファイルとして保存
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_input:
            input_path = Path(tmp_input.name)

            # PIL Imageとして保存
            if hasattr(image, 'save'):
                image.save(input_path, format='JPEG')
            else:
                from PIL import Image
                Image.fromarray(image).save(input_path, format='JPEG')

        # モデルで推論
        model = get_model()
        print(f"🔄 Processing image: {input_path}")
        gaussians = model.predict(input_path)

        # PLYファイルとして保存
        with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as tmp_output:
            output_path = Path(tmp_output.name)

        gaussians.save(str(output_path))

        # PLYファイルをBase64エンコード (Three.jsで使用)
        with open(output_path, 'rb') as f:
            ply_data = f.read()
            ply_base64 = base64.b64encode(ply_data).decode('utf-8')

        # 統計情報を取得
        file_size = output_path.stat().st_size / (1024 * 1024)  # MB

        # 入力ファイルを削除
        if input_path.exists():
            input_path.unlink()

        status_msg = f"✅ 生成完了!\n📦 ファイルサイズ: {file_size:.2f} MB"

        return str(output_path), status_msg, ply_base64

    except Exception as e:
        import traceback
        error_msg = f"❌ エラーが発生しました:\n{str(e)}\n\n{traceback.format_exc()}"
        print(error_msg)
        return None, error_msg, ""

# ZeroGPUデコレータを適用 (180秒のGPUタイムアウト)
# 注意: モジュールレベル関数に適用することでpickling問題を回避
process_image = spaces.GPU(duration=180)(_process_image_impl)

# Three.js ビューアのHTMLテンプレート
def create_viewer_html(ply_base64: str) -> str:
    """Three.js + GaussianSplats3D ビューアのHTMLを生成"""

    if not ply_base64:
        return """
        <div style="width: 100%; height: 600px; display: flex; align-items: center; justify-content: center; background: #1a1a1a; color: white; border-radius: 8px;">
            <div style="text-align: center;">
                <h2>🎨 3D Gaussian Splats ビューア</h2>
                <p>左側で画像を処理すると、ここに3Dプレビューが表示されます</p>
            </div>
        </div>
        """

    html = f"""
    <!DOCTYPE html>
    <html lang="ja">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>3D Gaussian Splats Viewer</title>
        <style>
            body {{
                margin: 0;
                padding: 0;
                overflow: hidden;
                background: #000;
                font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            }}
            #container {{
                width: 100%;
                height: 600px;
                position: relative;
            }}
            #loading {{
                position: absolute;
                top: 50%;
                left: 50%;
                transform: translate(-50%, -50%);
                color: white;
                font-size: 18px;
                z-index: 1000;
            }}
            #controls {{
                position: absolute;
                top: 10px;
                left: 10px;
                background: rgba(0, 0, 0, 0.7);
                color: white;
                padding: 10px;
                border-radius: 5px;
                font-size: 12px;
                z-index: 1000;
            }}
        </style>
    </head>
    <body>
        <div id="container">
            <div id="loading">🔄 3Dモデルを読み込み中...</div>
            <div id="controls">
                <div>🖱️ ドラッグ: 回転</div>
                <div>🔍 スクロール: ズーム</div>
                <div>⌨️ 右クリック: パン</div>
            </div>
        </div>

        <script type="importmap">
        {{
            "imports": {{
                "three": "https://cdn.jsdelivr.net/npm/three@0.168.0/build/three.module.js",
                "three/addons/": "https://cdn.jsdelivr.net/npm/three@0.168.0/examples/jsm/"
            }}
        }}
        </script>

        <script type="module">
            import * as THREE from 'three';
            import {{ OrbitControls }} from 'three/addons/controls/OrbitControls.js';

            // シーンの初期化
            const container = document.getElementById('container');
            const loading = document.getElementById('loading');

            const scene = new THREE.Scene();
            scene.background = new THREE.Color(0x1a1a1a);

            const camera = new THREE.PerspectiveCamera(
                75,
                container.clientWidth / container.clientHeight,
                0.1,
                1000
            );
            camera.position.set(0, 0, 5);

            const renderer = new THREE.WebGLRenderer({{ antialias: true }});
            renderer.setSize(container.clientWidth, container.clientHeight);
            renderer.setPixelRatio(window.devicePixelRatio);
            container.appendChild(renderer.domElement);

            // OrbitControls
            const controls = new OrbitControls(camera, renderer.domElement);
            controls.enableDamping = true;
            controls.dampingFactor = 0.05;

            // ライト
            const ambientLight = new THREE.AmbientLight(0xffffff, 0.5);
            scene.add(ambientLight);

            const directionalLight = new THREE.DirectionalLight(0xffffff, 1);
            directionalLight.position.set(5, 10, 7.5);
            scene.add(directionalLight);

            // グリッドヘルパー
            const gridHelper = new THREE.GridHelper(10, 10);
            scene.add(gridHelper);

            // PLYローダー
            async function loadPLY() {{
                try {{
                    // Base64からArrayBufferに変換
                    const plyBase64 = '{ply_base64}';
                    const binaryString = atob(plyBase64);
                    const bytes = new Uint8Array(binaryString.length);
                    for (let i = 0; i < binaryString.length; i++) {{
                        bytes[i] = binaryString.charCodeAt(i);
                    }}

                    // PLYLoaderを動的にインポート
                    const {{ PLYLoader }} = await import('three/addons/loaders/PLYLoader.js');
                    const loader = new PLYLoader();

                    // ArrayBufferをBlob経由でロード
                    const blob = new Blob([bytes], {{ type: 'application/octet-stream' }});
                    const url = URL.createObjectURL(blob);

                    loader.load(
                        url,
                        function (geometry) {{
                            loading.style.display = 'none';

                            // ポイントクラウドとしてレンダリング
                            geometry.computeVertexNormals();

                            // カラー情報があるか確認
                            const hasColors = geometry.attributes.color !== undefined;

                            const material = new THREE.PointsMaterial({{
                                size: 0.01,
                                vertexColors: hasColors,
                                color: hasColors ? undefined : 0x00ff00,
                                sizeAttenuation: true
                            }});

                            const points = new THREE.Points(geometry, material);
                            scene.add(points);

                            // カメラ位置を調整
                            geometry.computeBoundingBox();
                            const bbox = geometry.boundingBox;
                            const center = new THREE.Vector3();
                            bbox.getCenter(center);

                            const size = new THREE.Vector3();
                            bbox.getSize(size);
                            const maxDim = Math.max(size.x, size.y, size.z);
                            const fov = camera.fov * (Math.PI / 180);
                            let cameraZ = Math.abs(maxDim / Math.tan(fov / 2));
                            cameraZ *= 1.5;

                            camera.position.set(center.x, center.y, center.z + cameraZ);
                            camera.lookAt(center);
                            controls.target.copy(center);
                            controls.update();

                            URL.revokeObjectURL(url);

                            console.log('✅ PLYファイルの読み込み完了');
                        }},
                        function (xhr) {{
                            const percent = (xhr.loaded / xhr.total * 100).toFixed(0);
                            loading.textContent = `🔄 読み込み中... ${{percent}}%`;
                        }},
                        function (error) {{
                            console.error('❌ PLY読み込みエラー:', error);
                            loading.textContent = '❌ 読み込みエラー';
                            loading.style.color = 'red';
                        }}
                    );
                }} catch (error) {{
                    console.error('❌ エラー:', error);
                    loading.textContent = '❌ エラーが発生しました';
                    loading.style.color = 'red';
                }}
            }}

            // アニメーションループ
            function animate() {{
                requestAnimationFrame(animate);
                controls.update();
                renderer.render(scene, camera);
            }}

            // リサイズ対応
            window.addEventListener('resize', () => {{
                camera.aspect = container.clientWidth / container.clientHeight;
                camera.updateProjectionMatrix();
                renderer.setSize(container.clientWidth, container.clientHeight);
            }});

            // PLYを読み込んで開始
            loadPLY();
            animate();
        </script>
    </body>
    </html>
    """
    return html

def update_viewer(ply_base64: str) -> str:
    """ビューアを更新"""
    return create_viewer_html(ply_base64)

# Gradio UI
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple"),
    title="SHARP: 3D Gaussian Splats Generator"
) as demo:
    # SHARPステータスバナー
    if SHARP_AVAILABLE:
        status_banner = """
        # 🎨 SHARP: 単一画像から3D Gaussian Splatsを生成

        ✅ **SHARPモデル: 正常に読み込まれました**
        """
    else:
        status_banner = f"""
        # 🎨 SHARP: 単一画像から3D Gaussian Splatsを生成

        ⚠️ **警告: SHARPモデルが読み込めませんでした**

        エラー: `{SHARP_ERROR}`

        Spaceのログを確認するか、リポジトリの管理者にお問い合わせください。
        """

    gr.Markdown(status_banner)

    gr.Markdown("""
    Appleの最新技術「SHARP」を使用して、1枚の画像から高品質な3D Gaussian Splatsを生成します。
    生成された3DモデルはThree.jsで右側にリアルタイムプレビューされます。

    ### 使い方
    1. 左側のエリアに画像をアップロード
    2. 「生成開始」ボタンをクリック
    3. 右側で3Dモデルをインタラクティブに確認
    4. PLYファイルをダウンロード可能

    **ZeroGPU (Nvidia H200)** で高速に処理されます 🚀
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📸 入力画像")
            input_image = gr.Image(
                label="画像をアップロード",
                type="pil",
                sources=["upload", "clipboard"],
                height=400
            )

            generate_btn = gr.Button(
                "🚀 生成開始",
                variant="primary",
                size="lg"
            )

            status_box = gr.Textbox(
                label="ステータス",
                lines=3,
                interactive=False
            )

            output_file = gr.File(
                label="📦 PLYファイルをダウンロード",
                interactive=False
            )

        with gr.Column(scale=1):
            gr.Markdown("### 🎬 3Dプレビュー (Three.js)")
            viewer_html = gr.HTML(
                create_viewer_html(""),
                label="3D Viewer"
            )

    # 非表示のステート (PLY Base64データ)
    ply_data_state = gr.State("")

    # イベントハンドラ
    def on_generate(image):
        ply_path, status, ply_base64 = process_image(image)
        viewer = create_viewer_html(ply_base64)
        return ply_path, status, ply_base64, viewer

    generate_btn.click(
        fn=on_generate,
        inputs=[input_image],
        outputs=[output_file, status_box, ply_data_state, viewer_html]
    )

    gr.Markdown("""
    ---
    ### ℹ️ 技術情報

    - **モデル**: Apple SHARP (Sharp Monocular View Synthesis)
    - **出力形式**: PLY (Polygon File Format)
    - **レンダリング**: Three.js + PLYLoader
    - **GPU**: ZeroGPU (Nvidia H200, 動的割り当て)
    - **処理時間**: 通常1秒以下

    ### 📚 リソース
    - [SHARP GitHub](https://github.com/apple/ml-sharp)
    - [論文 (arXiv)](https://arxiv.org/abs/2512.10685)
    - [Hugging Face Model](https://huggingface.co/apple/Sharp)

    ### ⚠️ 注意事項
    - 処理にはGPUを使用するため、待機時間が発生する場合があります
    - ZeroGPUは60秒のタイムアウトがあります
    - 大きな画像は自動的にリサイズされます
    """)

if __name__ == "__main__":
    demo.launch()