pvs_backend / src /components /posture_model.py
adnankhan-11's picture
PVD System - Initial deployment
d2885a7
from pathlib import Path
import numpy as np
import torch
from src.entity.config_entity import PostureModelConfig
from src.models.posture_cnn import MLP3d
from src.utils.common import resolve_device
from src.utils.logger import get_logger
class PostureModelService:
"""
Load and run the posture classifier.
This service wraps:
- model construction
- checkpoint loading
- inference
- confidence thresholding
It uses the exact same MLP3d architecture as your current project.
"""
def __init__(
self,
config: PostureModelConfig,
log_dir: Path | None = None,
log_level: str = "INFO",
) -> None:
self.config = config
self.logger = get_logger(
self.__class__.__name__, log_dir=log_dir, level=log_level
)
self.device = resolve_device("auto")
self.model: MLP3d | None = None
def build_model(self) -> MLP3d:
"""
Build the posture CNN architecture.
"""
model = MLP3d(
input_channel_num=self.config.input_channels,
output_class_num=self.config.output_classes,
input_shape=(
self.config.input_shape.depth,
self.config.input_shape.height,
self.config.input_shape.width,
),
conv_kernel_size=tuple(self.config.architecture.conv_kernel_size),
pool_kernel_size=self.config.architecture.pool_kernel_size,
activation_name=self.config.architecture.activation,
fc_dims=self.config.architecture.fc_dims,
)
return model
def load_model(self, weight_path: Path | None = None) -> MLP3d:
"""
Load posture model weights from checkpoint.
"""
if weight_path is None:
weight_path = self.config.weights.default_weight_file
weight_path = Path(weight_path)
if not weight_path.is_absolute():
ROOT_DIR = Path(__file__).resolve().parents[2]
weight_path = ROOT_DIR / weight_path
weight_path = weight_path.resolve()
print("DEBUG PATH:", weight_path)
if not weight_path.exists():
raise FileNotFoundError(
f"Posture model checkpoint not found: {weight_path}"
)
model = self.build_model()
checkpoint = torch.load(weight_path, map_location=self.device)
if not isinstance(checkpoint, dict) or "model_state_dict" not in checkpoint:
raise ValueError(
"Invalid posture checkpoint format. Expected a dictionary with 'model_state_dict'."
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
model.to(self.device)
self.model = model
self.logger.info("Posture model loaded from: %s", weight_path)
return model
def predict_tensor(self, input_tensor: torch.Tensor) -> tuple[str, int, np.ndarray]:
"""
Predict from one already-prepared tensor.
Expected tensor shape:
(1, C, D, H, W)
"""
if self.model is None:
self.load_model()
if input_tensor.ndim != 5:
raise ValueError(
f"Expected posture tensor shape (N, C, D, H, W), got: {input_tensor.shape}"
)
if input_tensor.shape[0] != 1:
raise ValueError(
"PostureModelService expects exactly one sample for runtime inference. "
f"Received batch size: {input_tensor.shape[0]}"
)
input_tensor = input_tensor.to(self.device)
with torch.no_grad():
outputs = self.model(input_tensor)
probabilities = torch.sigmoid(outputs[0])
score_not_using = float(probabilities[0].item())
score_using = float(probabilities[1].item())
threshold = self.config.inference.confidence_threshold
prediction_is_using = (score_using > score_not_using) and (
score_using > threshold
)
class_signal = 1 if prediction_is_using else 0
display_score = score_using if prediction_is_using else score_not_using
score_text = f"{display_score:.2f}"
return score_text, class_signal, probabilities.cpu().numpy()
def predict_numpy(self, input_array: np.ndarray) -> tuple[str, int, np.ndarray]:
"""
Predict from numpy tensor.
Expected numpy shape:
(1, C, D, H, W)
"""
tensor = torch.tensor(input_array, dtype=torch.float32)
return self.predict_tensor(tensor)