File size: 3,025 Bytes
4555cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
from huggingface_hub import hf_hub_download
import onnxruntime as ort
from .error_handler import ModelLoadError, with_retry

class DWPoseManager:
    def __init__(self):
        self.model_repo = "yzd-v/DWPose"
        self.cache_dir = "./models"
        self.yolox_session = None
        self.dwpose_session = None
        self.yolox_input_name = None
        self.dwpose_input_name = None
        # refs互換の標準閾値を使用
        self.detection_threshold = 0.3
        
    @with_retry(max_retries=3, delay=2.0)
    def _download_model(self, filename):
        """モデルファイルダウンロード(リトライ付き)"""
        return hf_hub_download(
            repo_id=self.model_repo,
            filename=filename,
            cache_dir=self.cache_dir
        )
    
    def initialize(self):
        """モデルのダウンロードと初期化"""
        try:
            # キャッシュディレクトリ作成
            os.makedirs(self.cache_dir, exist_ok=True)
            
            # YOLOXモデル(リトライ付き)
            yolox_path = self._download_model("yolox_l.onnx")
            if not yolox_path:
                raise ModelLoadError("YOLOXモデルのダウンロードに失敗しました")
            
            # DWPoseモデル(リトライ付き)
            dwpose_path = self._download_model("dw-ll_ucoco_384.onnx")
            if not dwpose_path:
                raise ModelLoadError("DWPoseモデルのダウンロードに失敗しました")
            
            # ONNXセッション作成
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
            
            try:
                self.yolox_session = ort.InferenceSession(yolox_path, providers=providers)
                self.yolox_input_name = self.yolox_session.get_inputs()[0].name
                print(f"[DEBUG] YOLOX input name: {self.yolox_input_name}")
            except Exception as e:
                raise ModelLoadError(f"YOLOXモデルの初期化に失敗: {str(e)}")
            
            try:
                self.dwpose_session = ort.InferenceSession(dwpose_path, providers=providers)
                self.dwpose_input_name = self.dwpose_session.get_inputs()[0].name
                dwpose_input_shape = self.dwpose_session.get_inputs()[0].shape
                print(f"[DEBUG] DWPose input name: {self.dwpose_input_name}")
                print(f"[DEBUG] DWPose input shape: {dwpose_input_shape}")
            except Exception as e:
                raise ModelLoadError(f"DWPoseモデルの初期化に失敗: {str(e)}")
            
            return True, "モデル初期化成功"
            
        except ModelLoadError as e:
            return False, str(e)
        except Exception as e:
            return False, f"予期しないエラー: {str(e)}"
    
    def is_initialized(self):
        """初期化済みかチェック"""
        return self.yolox_session is not None and self.dwpose_session is not None