|
|
import logging
import os

from huggingface_hub import hf_hub_download
import onnxruntime as ort

from .error_handler import ModelLoadError, with_retry
|
|
|
|
|
class DWPoseManager:
    """Download the DWPose ONNX models and hold their ONNX Runtime sessions.

    Two models are fetched from a Hugging Face Hub repository: a YOLOX
    detection model and the DWPose pose model. Call ``initialize()`` before
    using the sessions; ``is_initialized()`` reports whether both exist.
    """

    # Model files fetched from ``model_repo``.
    YOLOX_FILENAME = "yolox_l.onnx"
    DWPOSE_FILENAME = "dw-ll_ucoco_384.onnx"

    # Execution providers in priority order. Providers not present in the
    # installed onnxruntime build are dropped before sessions are created, so
    # ORT does not warn about requested-but-unavailable providers.
    PREFERRED_PROVIDERS = ("CUDAExecutionProvider", "CPUExecutionProvider")

    _logger = logging.getLogger(__name__)

    def __init__(self, *, model_repo="yzd-v/DWPose", cache_dir="./models",
                 detection_threshold=0.3):
        """Create an uninitialized manager; no I/O happens here.

        Args:
            model_repo: Hugging Face Hub repository id holding the models.
            cache_dir: Local directory used as the download cache.
            detection_threshold: Stored as-is; not read by this class.
        """
        self.model_repo = model_repo
        self.cache_dir = cache_dir

        # ONNX Runtime sessions; both are assigned together by initialize().
        self.yolox_session = None
        self.dwpose_session = None

        # Name of each model's first input tensor, read from its session.
        self.yolox_input_name = None
        self.dwpose_input_name = None

        # NOTE(review): not used inside this class; presumably consumed by
        # detection post-processing elsewhere -- confirm against callers.
        self.detection_threshold = detection_threshold

    @with_retry(max_retries=3, delay=2.0)
    def _download_model(self, filename):
        """Download *filename* from ``model_repo`` into ``cache_dir``.

        Retry behaviour comes from ``with_retry(max_retries=3, delay=2.0)``;
        see ``error_handler`` for its exact semantics. Returns whatever
        ``hf_hub_download`` returns (the local file path).
        """
        return hf_hub_download(
            repo_id=self.model_repo,
            filename=filename,
            cache_dir=self.cache_dir,
        )

    def _fetch(self, filename, label):
        """Download one model file and reject an empty/falsy result.

        Raises:
            ModelLoadError: if the download returned no path.
        """
        path = self._download_model(filename)
        if not path:
            raise ModelLoadError(f"{label}モデルのダウンロードに失敗しました")
        return path

    def _select_providers(self):
        """Return PREFERRED_PROVIDERS filtered to those this ORT build offers."""
        available = set(ort.get_available_providers())
        chosen = [p for p in self.PREFERRED_PROVIDERS if p in available]
        # CPUExecutionProvider is expected to always exist; fall back to it
        # explicitly rather than passing an empty list.
        return chosen or ["CPUExecutionProvider"]

    def _create_session(self, path, label, providers):
        """Build an InferenceSession for *path*.

        Returns:
            ``(session, first_input_name)``.

        Raises:
            ModelLoadError: if session creation or input inspection fails;
                the original exception is chained as the cause.
        """
        try:
            session = ort.InferenceSession(path, providers=providers)
            first_input = session.get_inputs()[0]
        except Exception as e:
            raise ModelLoadError(f"{label}モデルの初期化に失敗: {e}") from e
        self._logger.debug("%s input name: %s", label, first_input.name)
        self._logger.debug("%s input shape: %s", label, first_input.shape)
        return session, first_input.name

    def initialize(self) -> "tuple[bool, str]":
        """Download both models and create their inference sessions.

        Sessions are built into locals and only stored on ``self`` once both
        succeed, so a failed (re)initialization never leaves the manager
        half-updated.

        Returns:
            ``(True, "モデル初期化成功")`` on success, otherwise
            ``(False, message)``. Exceptions derived from ``Exception`` are
            reported through the return value instead of being raised.
        """
        try:
            os.makedirs(self.cache_dir, exist_ok=True)

            yolox_path = self._fetch(self.YOLOX_FILENAME, "YOLOX")
            dwpose_path = self._fetch(self.DWPOSE_FILENAME, "DWPose")

            providers = self._select_providers()
            yolox_session, yolox_input_name = self._create_session(
                yolox_path, "YOLOX", providers
            )
            dwpose_session, dwpose_input_name = self._create_session(
                dwpose_path, "DWPose", providers
            )
        except ModelLoadError as e:
            return False, str(e)
        except Exception as e:
            return False, f"予期しないエラー: {e}"

        # Commit both sessions together only after everything succeeded.
        self.yolox_session = yolox_session
        self.yolox_input_name = yolox_input_name
        self.dwpose_session = dwpose_session
        self.dwpose_input_name = dwpose_input_name
        return True, "モデル初期化成功"

    def is_initialized(self) -> bool:
        """Return True when both the YOLOX and DWPose sessions exist."""
        return self.yolox_session is not None and self.dwpose_session is not None