character_openpose_editor / utils /dwpose_manager.py
gearmachine's picture
feat: Add DWPose model management and error handling utilities
4555cad
import os
from huggingface_hub import hf_hub_download
import onnxruntime as ort
from .error_handler import ModelLoadError, with_retry
class DWPoseManager:
def __init__(self):
self.model_repo = "yzd-v/DWPose"
self.cache_dir = "./models"
self.yolox_session = None
self.dwpose_session = None
self.yolox_input_name = None
self.dwpose_input_name = None
# refs互換の標準閾値を使用
self.detection_threshold = 0.3
@with_retry(max_retries=3, delay=2.0)
def _download_model(self, filename):
"""モデルファイルダウンロード(リトライ付き)"""
return hf_hub_download(
repo_id=self.model_repo,
filename=filename,
cache_dir=self.cache_dir
)
def initialize(self):
"""モデルのダウンロードと初期化"""
try:
# キャッシュディレクトリ作成
os.makedirs(self.cache_dir, exist_ok=True)
# YOLOXモデル(リトライ付き)
yolox_path = self._download_model("yolox_l.onnx")
if not yolox_path:
raise ModelLoadError("YOLOXモデルのダウンロードに失敗しました")
# DWPoseモデル(リトライ付き)
dwpose_path = self._download_model("dw-ll_ucoco_384.onnx")
if not dwpose_path:
raise ModelLoadError("DWPoseモデルのダウンロードに失敗しました")
# ONNXセッション作成
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
try:
self.yolox_session = ort.InferenceSession(yolox_path, providers=providers)
self.yolox_input_name = self.yolox_session.get_inputs()[0].name
print(f"[DEBUG] YOLOX input name: {self.yolox_input_name}")
except Exception as e:
raise ModelLoadError(f"YOLOXモデルの初期化に失敗: {str(e)}")
try:
self.dwpose_session = ort.InferenceSession(dwpose_path, providers=providers)
self.dwpose_input_name = self.dwpose_session.get_inputs()[0].name
dwpose_input_shape = self.dwpose_session.get_inputs()[0].shape
print(f"[DEBUG] DWPose input name: {self.dwpose_input_name}")
print(f"[DEBUG] DWPose input shape: {dwpose_input_shape}")
except Exception as e:
raise ModelLoadError(f"DWPoseモデルの初期化に失敗: {str(e)}")
return True, "モデル初期化成功"
except ModelLoadError as e:
return False, str(e)
except Exception as e:
return False, f"予期しないエラー: {str(e)}"
def is_initialized(self):
"""初期化済みかチェック"""
return self.yolox_session is not None and self.dwpose_session is not None