import gradio as gr import spaces import torch from pathlib import Path import tempfile import os import base64 from typing import Optional import json # SHARP モデルのインポート (遅延読み込み) SHARP_AVAILABLE = False SHARP_ERROR = None try: from sharp import Sharp SHARP_AVAILABLE = True print("✅ SHARP module loaded successfully") except ImportError as e: SHARP_ERROR = str(e) print(f"❌ SHARP import failed: {e}") import traceback traceback.print_exc() except Exception as e: SHARP_ERROR = str(e) print(f"❌ Unexpected error loading SHARP: {e}") import traceback traceback.print_exc() # グローバルモデルインスタンス (メモリ効率のため) # 注意: ZeroGPUのマルチプロセッシングに対応するため、モジュールレベルで管理 _model = None def get_model(): """モデルインスタンスを取得(キャッシング) GPU workerプロセス内でモデルを初期化してキャッシュします。 これによりpickling問題を回避します。 """ global _model if _model is None and SHARP_AVAILABLE: print("🔄 Initializing SHARP model in GPU worker...") _model = Sharp() print("✅ SHARP model initialized successfully") return _model def _process_image_impl(image) -> tuple[Optional[str], str, str]: """ 画像から3D Gaussian Splatsを生成 Args: image: PIL Image or numpy array Returns: tuple: (PLYファイルパス, ステータスメッセージ, PLYデータ(base64)) """ if not SHARP_AVAILABLE: error_msg = f"❌ SHARPモデルが利用できません\n\nエラー詳細: {SHARP_ERROR}\n\n" error_msg += "考えられる原因:\n" error_msg += "1. ml-sharpパッケージのインストール失敗\n" error_msg += "2. Python バージョンの非互換性\n" error_msg += "3. 依存関係の問題\n\n" error_msg += "ログを確認してください。" return None, error_msg, "" if image is None: return None, "❌ 画像をアップロードしてください", "" try: # 一時ファイルとして保存 with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_input: input_path = Path(tmp_input.name) # PIL Imageとして保存 if hasattr(image, 'save'): image.save(input_path, format='JPEG') else: from PIL import Image Image.fromarray(image).save(input_path, format='JPEG') # モデルで推論 model = get_model() print(f"🔄 Processing image: {input_path}") gaussians = model.predict(input_path) # PLYファイルとして保存 with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as tmp_output: output_path = Path(tmp_output.name) gaussians.save(str(output_path)) # PLYファイルをBase64エンコード (Three.jsで使用) with open(output_path, 'rb') as f: ply_data = f.read() ply_base64 = base64.b64encode(ply_data).decode('utf-8') # 統計情報を取得 file_size = output_path.stat().st_size / (1024 * 1024) # MB # 入力ファイルを削除 if input_path.exists(): input_path.unlink() status_msg = f"✅ 生成完了!\n📦 ファイルサイズ: {file_size:.2f} MB" return str(output_path), status_msg, ply_base64 except Exception as e: import traceback error_msg = f"❌ エラーが発生しました:\n{str(e)}\n\n{traceback.format_exc()}" print(error_msg) return None, error_msg, "" # ZeroGPUデコレータを適用 (180秒のGPUタイムアウト) # 注意: モジュールレベル関数に適用することでpickling問題を回避 process_image = spaces.GPU(duration=180)(_process_image_impl) # Three.js ビューアのHTMLテンプレート def create_viewer_html(ply_base64: str) -> str: """Three.js + GaussianSplats3D ビューアのHTMLを生成""" if not ply_base64: return """
左側で画像を処理すると、ここに3Dプレビューが表示されます