File size: 4,599 Bytes
d2885a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from pathlib import Path

import numpy as np
import torch

from src.entity.config_entity import PostureModelConfig
from src.models.posture_cnn import MLP3d
from src.utils.common import resolve_device
from src.utils.logger import get_logger


class PostureModelService:
    """
    Load and run the posture classifier.

    This service wraps:
    - model construction (an ``MLP3d`` built from the supplied config)
    - checkpoint loading
    - single-sample inference
    - confidence thresholding

    The model is loaded lazily: the first prediction call loads the default
    checkpoint if ``load_model`` has not been called explicitly.
    """

    def __init__(
        self,
        config: PostureModelConfig,
        log_dir: Path | None = None,
        log_level: str = "INFO",
    ) -> None:
        """
        Store the configuration and prepare logging and device selection.

        Args:
            config: Posture model configuration (architecture, input shape,
                default weights, inference threshold).
            log_dir: Forwarded to ``get_logger`` as ``log_dir``.
            log_level: Forwarded to ``get_logger`` as ``level``.
        """
        self.config = config
        self.logger = get_logger(
            self.__class__.__name__, log_dir=log_dir, level=log_level
        )

        self.device = resolve_device("auto")
        # Populated by load_model(); None means "not loaded yet".
        self.model: MLP3d | None = None

    def build_model(self) -> MLP3d:
        """
        Build an untrained ``MLP3d`` network from the configuration.

        Returns:
            A freshly constructed model; no weights are loaded here.
        """
        architecture = self.config.architecture
        shape = self.config.input_shape
        return MLP3d(
            input_channel_num=self.config.input_channels,
            output_class_num=self.config.output_classes,
            input_shape=(shape.depth, shape.height, shape.width),
            conv_kernel_size=tuple(architecture.conv_kernel_size),
            pool_kernel_size=architecture.pool_kernel_size,
            activation_name=architecture.activation,
            fc_dims=architecture.fc_dims,
        )

    def load_model(self, weight_path: Path | None = None) -> MLP3d:
        """
        Load posture model weights from a checkpoint and cache the model.

        Args:
            weight_path: Checkpoint file. When ``None``, the configured
                ``weights.default_weight_file`` is used. Relative paths are
                resolved against the directory two levels above this file's
                directory (presumably the project root - confirm layout).

        Returns:
            The model in eval mode, moved to ``self.device``. It is also
            stored on ``self.model``.

        Raises:
            FileNotFoundError: If the resolved checkpoint does not exist.
            ValueError: If the checkpoint is not a dict containing
                ``"model_state_dict"``.
        """
        if weight_path is None:
            weight_path = self.config.weights.default_weight_file

        weight_path = Path(weight_path)
        if not weight_path.is_absolute():
            project_root = Path(__file__).resolve().parents[2]
            weight_path = project_root / weight_path
        weight_path = weight_path.resolve()
        # Diagnostics go through the service logger rather than stdout so
        # they respect the configured log level and destination.
        self.logger.debug("Resolved posture checkpoint path: %s", weight_path)

        if not weight_path.exists():
            raise FileNotFoundError(
                f"Posture model checkpoint not found: {weight_path}"
            )

        model = self.build_model()
        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from trusted sources should be loaded here. Consider passing
        # weights_only=True if the installed torch version supports it and
        # the checkpoint holds nothing but tensors/primitive containers.
        checkpoint = torch.load(weight_path, map_location=self.device)

        if not isinstance(checkpoint, dict) or "model_state_dict" not in checkpoint:
            raise ValueError(
                "Invalid posture checkpoint format. Expected a dictionary with 'model_state_dict'."
            )

        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()
        model.to(self.device)

        self.model = model
        self.logger.info("Posture model loaded from: %s", weight_path)
        return model

    def predict_tensor(self, input_tensor: torch.Tensor) -> tuple[str, int, np.ndarray]:
        """
        Predict from one already-prepared tensor.

        Expected tensor shape:
        (1, C, D, H, W)

        The model output for the sample is passed through a sigmoid. Index 0
        is treated as the "not using" score and index 1 as the "using" score.
        "Using" is reported only when its score beats both the "not using"
        score and the configured ``inference.confidence_threshold``.

        Args:
            input_tensor: 5-D tensor holding exactly one sample.

        Returns:
            A tuple ``(score_text, class_signal, probabilities)``:
            ``score_text`` is the reported class's score formatted to two
            decimals, ``class_signal`` is 1 for "using" and 0 otherwise, and
            ``probabilities`` is the sigmoid output as a NumPy array.

        Raises:
            ValueError: If the tensor is not 5-D or the batch size is not 1.
        """
        # Lazy-load the default checkpoint on first use.
        model = self.model if self.model is not None else self.load_model()

        if input_tensor.ndim != 5:
            raise ValueError(
                f"Expected posture tensor shape (N, C, D, H, W), got: {input_tensor.shape}"
            )

        if input_tensor.shape[0] != 1:
            raise ValueError(
                "PostureModelService expects exactly one sample for runtime inference. "
                f"Received batch size: {input_tensor.shape[0]}"
            )

        input_tensor = input_tensor.to(self.device)

        with torch.no_grad():
            outputs = model(input_tensor)
            # outputs[0] selects the single sample in the batch.
            probabilities = torch.sigmoid(outputs[0])

        score_not_using = float(probabilities[0].item())
        score_using = float(probabilities[1].item())

        threshold = self.config.inference.confidence_threshold
        prediction_is_using = score_using > score_not_using and score_using > threshold

        class_signal = 1 if prediction_is_using else 0
        display_score = score_using if prediction_is_using else score_not_using
        score_text = f"{display_score:.2f}"

        return score_text, class_signal, probabilities.cpu().numpy()

    def predict_numpy(self, input_array: np.ndarray) -> tuple[str, int, np.ndarray]:
        """
        Predict from a NumPy array.

        Expected numpy shape:
        (1, C, D, H, W)

        The array is converted to a float32 tensor and passed to
        ``predict_tensor``; see that method for the return value and the
        errors raised.
        """
        tensor = torch.tensor(input_array, dtype=torch.float32)
        return self.predict_tensor(tensor)