"""
Train the posture classifier using the same MLP3d logic as the current project.

The module reads saved pose-feature ``.npy`` files, builds train/val/test
loaders, and runs the training loop. Metrics are logged through the project
logger and MLflow (which can be pointed at DAGsHub through the tracking URI).
The saved artifacts are plain files on disk, so they stay DVC-friendly for
later versioning.
"""

from pathlib import Path
import os
from typing import Tuple

import mlflow
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam

from src.entity.config_entity import PathsConfig, PostureModelConfig, TrainingConfig
from src.models.posture_cnn import MCLoss, MLP3d
from src.utils.common import resolve_device, save_json, save_yaml
from src.utils.logger import get_logger
from src.utils.naming_utils import parse_pose_data_filename


class PostureTrainer:
    """
    Train the posture model using .npy pose feature files.

    The trainer builds an ``MLP3d`` model from ``posture_model_config``, moves
    it to the device resolved from ``training_config.general.device_preference``
    and exposes :meth:`train` as the single public entry point.
    """

    def __init__(
        self,
        paths_config: PathsConfig,
        posture_model_config: PostureModelConfig,
        training_config: TrainingConfig,
        log_dir: Path | None = None,
        log_level: str = "INFO",
    ) -> None:
        """
        Store the configs, create the logger, and build the model.

        Args:
            paths_config: Provides ``metrics_dir`` and ``posture_weights_dir``.
            posture_model_config: Input shape, channel/class counts and
                architecture settings forwarded to ``MLP3d``.
            training_config: General, tracking and posture-training settings.
            log_dir: Optional directory handed to ``get_logger``.
            log_level: Log level name handed to ``get_logger``.
        """
        self.paths_config = paths_config
        self.posture_model_config = posture_model_config
        self.training_config = training_config

        self.logger = get_logger(
            self.__class__.__name__, log_dir=log_dir, level=log_level
        )

        self.device = resolve_device(self.training_config.general.device_preference)

        self.model = MLP3d(
            input_channel_num=self.posture_model_config.input_channels,
            output_class_num=self.posture_model_config.output_classes,
            input_shape=(
                self.posture_model_config.input_shape.depth,
                self.posture_model_config.input_shape.height,
                self.posture_model_config.input_shape.width,
            ),
            conv_kernel_size=tuple(
                self.posture_model_config.architecture.conv_kernel_size
            ),
            pool_kernel_size=self.posture_model_config.architecture.pool_kernel_size,
            activation_name=self.posture_model_config.architecture.activation,
            fc_dims=self.posture_model_config.architecture.fc_dims,
        ).to(self.device)

    def _setup_mlflow(self) -> None:
        """
        Configure MLflow tracking.

        The ``MLFLOW_TRACKING_URI`` environment variable takes precedence over
        the URI from the training config. DAGsHub can be used through the
        tracking URI.
        """
        tracking_uri = os.getenv(
            "MLFLOW_TRACKING_URI",
            self.training_config.tracking.tracking_uri,
        )
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(self.training_config.tracking.experiment_name_posture)

    def _collect_feature_files(self) -> list[Path]:
        """
        Find all .npy posture feature files (recursively, sorted by path).

        Raises:
            FileNotFoundError: If no ``.npy`` file exists under the source dir.
        """
        source_dir = self.training_config.posture_training.data.source_feature_dir
        feature_files = sorted(source_dir.rglob("*.npy"))

        if not feature_files:
            raise FileNotFoundError(
                f"No posture feature .npy files found in: {source_dir}"
            )

        self.logger.info("Found %d posture feature files.", len(feature_files))
        return feature_files

    def _load_dataset_arrays(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load all posture feature arrays and derive labels from the filenames.

        Every sample of a file gets label 1 when the parsed filename label
        starts with "U", otherwise 0.

        Returns:
            ``(inputs, labels)``: float32 inputs concatenated along axis 0 and
            the matching int64 label vector.

        Raises:
            ValueError: If a loaded array does not have exactly 5 dimensions.
        """
        feature_files = self._collect_feature_files()

        inputs_list = []
        labels_list = []

        for file_path in feature_files:
            file_info = parse_pose_data_filename(file_path.name, extension=".npy")
            label_str = file_info["label"]
            label_value = 1 if label_str.startswith("U") else 0

            feature_array = np.load(file_path)
            if feature_array.ndim != 5:
                raise ValueError(
                    f"Expected posture feature array with 5 dims (N, C, D, H, W), got {feature_array.shape} "
                    f"for file {file_path}"
                )

            file_labels = np.full(
                shape=(feature_array.shape[0],), fill_value=label_value, dtype=np.int64
            )

            inputs_list.append(feature_array.astype(np.float32))
            labels_list.append(file_labels)

        inputs = np.concatenate(inputs_list, axis=0)
        labels = np.concatenate(labels_list, axis=0)

        self.logger.info(
            "Combined posture dataset shape: X=%s, y=%s", inputs.shape, labels.shape
        )
        return inputs, labels

    def _split_dataset(
        self,
        inputs: np.ndarray,
        labels: np.ndarray,
    ) -> Tuple[
        tuple[np.ndarray, np.ndarray],
        tuple[np.ndarray, np.ndarray],
        tuple[np.ndarray, np.ndarray],
    ]:
        """
        Split arrays into train/val/test using config ratios.

        The test split is whatever remains after the train and val portions.
        When ``data.shuffle`` is set, samples are permuted first using NumPy's
        global random state (seed it beforehand for reproducible splits).

        Returns:
            ``((x_train, y_train), (x_val, y_val), (x_test, y_test))``.
        """
        data_cfg = self.training_config.posture_training.data
        total_samples = inputs.shape[0]
        indices = np.arange(total_samples)

        if data_cfg.shuffle:
            np.random.shuffle(indices)

        inputs = inputs[indices]
        labels = labels[indices]

        train_ratio = data_cfg.train_split
        val_ratio = data_cfg.val_split

        # int() truncates, so very small datasets can yield empty splits.
        train_end = int(total_samples * train_ratio)
        val_end = train_end + int(total_samples * val_ratio)

        x_train, y_train = inputs[:train_end], labels[:train_end]
        x_val, y_val = inputs[train_end:val_end], labels[train_end:val_end]
        x_test, y_test = inputs[val_end:], labels[val_end:]

        self.logger.info(
            "Posture split sizes -> train=%d, val=%d, test=%d",
            len(x_train),
            len(x_val),
            len(x_test),
        )

        # An empty val split makes val_loss a constant 0.0, which silently
        # defeats best-model selection and early stopping; make that visible.
        for split_name, split_inputs in (
            ("train", x_train),
            ("val", x_val),
            ("test", x_test),
        ):
            if len(split_inputs) == 0:
                self.logger.warning(
                    "Posture %s split is empty (total_samples=%d); "
                    "check the configured split ratios.",
                    split_name,
                    total_samples,
                )

        return (x_train, y_train), (x_val, y_val), (x_test, y_test)

    def _to_dataloader(
        self, inputs: np.ndarray, labels: np.ndarray, shuffle: bool
    ) -> DataLoader:
        """
        Convert numpy arrays to a PyTorch DataLoader.

        Inputs become float32 tensors and labels long tensors; batch size,
        worker count and pin-memory come from the training config.
        """
        x_tensor = torch.tensor(inputs, dtype=torch.float32)
        y_tensor = torch.tensor(labels, dtype=torch.long)

        dataset = TensorDataset(x_tensor, y_tensor)
        loader = DataLoader(
            dataset,
            batch_size=self.training_config.posture_training.hyperparameters.batch_size,
            shuffle=shuffle,
            num_workers=self.training_config.general.num_workers,
            pin_memory=self.training_config.general.pin_memory,
        )
        return loader

    @staticmethod
    def _ensure_parent_dir(file_path) -> None:
        """Create the parent directory of *file_path* if it does not exist."""
        Path(file_path).parent.mkdir(parents=True, exist_ok=True)

    def _save_dataset_manifest(self, x_train, x_val, x_test) -> None:
        """
        Save a small manifest file with the split sizes and feature source.

        This is useful for DVC tracking later.
        """
        manifest = {
            "task": "posture_training",
            "train_samples": int(len(x_train)),
            "val_samples": int(len(x_val)),
            "test_samples": int(len(x_test)),
            "feature_source_dir": str(
                self.training_config.posture_training.data.source_feature_dir
            ),
        }
        manifest_path = self.paths_config.metrics_dir / "posture_dataset_manifest.json"
        # Whether save_json creates missing directories is not visible from
        # here, so make sure the target directory exists first.
        self._ensure_parent_dir(manifest_path)
        save_json(manifest_path, manifest)

    def _load_existing_weights(self) -> bool:
        """
        Optionally warm-start the model from an existing checkpoint.

        Only checkpoints stored as ``{"model_state_dict": ...}`` are loaded.
        A missing file or an unexpected checkpoint layout is reported with a
        warning instead of being skipped silently; training then continues
        from the current model initialisation.

        Returns:
            True when weights were loaded into ``self.model``, else False.
        """
        posture_cfg = self.training_config.posture_training
        if not posture_cfg.use_existing_weights:
            return False

        existing_weight_path = posture_cfg.existing_weights_path
        if not existing_weight_path.exists():
            self.logger.warning(
                "use_existing_weights is enabled but no file exists at: %s. "
                "Training starts from the current model initialisation.",
                existing_weight_path,
            )
            return False

        # NOTE(security): torch.load unpickles the file; only point
        # existing_weights_path at checkpoints from a trusted source.
        checkpoint = torch.load(existing_weight_path, map_location=self.device)

        if not (isinstance(checkpoint, dict) and "model_state_dict" in checkpoint):
            self.logger.warning(
                "Checkpoint at %s has no 'model_state_dict' entry; weights were "
                "not loaded.",
                existing_weight_path,
            )
            return False

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.logger.info(
            "Loaded existing posture weights from: %s", existing_weight_path
        )
        return True

    def _train_one_epoch(self, loader: DataLoader, criterion, optimizer) -> float:
        """Run one optimisation pass over *loader*; return the mean sample loss."""
        self.model.train()
        running_loss = 0.0

        for batch_inputs, batch_labels in loader:
            batch_inputs = batch_inputs.to(self.device)
            batch_labels = batch_labels.to(self.device)

            optimizer.zero_grad()
            outputs = self.model(batch_inputs)
            loss = criterion(outputs, batch_labels, self.model)
            loss.backward()
            optimizer.step()

            # Weight by batch size so a smaller last batch does not skew the mean.
            running_loss += loss.item() * len(batch_inputs)

        # max(..., 1) avoids a ZeroDivisionError on an empty dataset.
        return running_loss / max(len(loader.dataset), 1)

    def _evaluate_loss(self, loader: DataLoader, criterion) -> float:
        """Compute the mean sample loss over *loader* without gradient tracking."""
        self.model.eval()
        running_loss = 0.0

        with torch.no_grad():
            for batch_inputs, batch_labels in loader:
                batch_inputs = batch_inputs.to(self.device)
                batch_labels = batch_labels.to(self.device)

                outputs = self.model(batch_inputs)
                loss = criterion(outputs, batch_labels, self.model)
                running_loss += loss.item() * len(batch_inputs)

        return running_loss / max(len(loader.dataset), 1)

    def train(self) -> dict:
        """
        Run full posture training.

        Steps: configure MLflow, load and split the data, write the dataset
        manifest, optionally warm-start from existing weights, then train with
        Adam and ``MCLoss``. The best checkpoint (lowest validation loss,
        improved by more than ``min_delta``) is saved whenever it improves and
        training stops early after ``early_stopping_patience`` epochs without
        improvement. Finally the last checkpoint, the loss history (JSON) and
        a DVC manifest (YAML) are written.

        Returns:
            A dict with keys ``"history"`` (per-epoch train/val losses),
            ``"test_loader"``, ``"best_val_loss"`` and ``"model"``. The
            returned model carries the weights of the last epoch run; the best
            weights are only available from the saved best checkpoint.
        """
        self._setup_mlflow()

        inputs, labels = self._load_dataset_arrays()
        (x_train, y_train), (x_val, y_val), (x_test, y_test) = self._split_dataset(
            inputs, labels
        )

        self._save_dataset_manifest(x_train, x_val, x_test)

        train_loader = self._to_dataloader(x_train, y_train, shuffle=True)
        val_loader = self._to_dataloader(x_val, y_val, shuffle=False)
        test_loader = self._to_dataloader(x_test, y_test, shuffle=False)

        self._load_existing_weights()

        posture_cfg = self.training_config.posture_training
        hyperparameters = posture_cfg.hyperparameters
        outputs_cfg = posture_cfg.outputs

        criterion = MCLoss()
        optimizer = Adam(
            self.model.parameters(),
            lr=hyperparameters.learning_rate,
            weight_decay=hyperparameters.weight_decay,
        )

        # Output locations do not change between epochs, so resolve them once
        # and make sure their directories exist before the first write.
        weights_dir = self.paths_config.posture_weights_dir
        best_path = weights_dir / outputs_cfg.save_best_model_as
        last_path = weights_dir / outputs_cfg.save_last_model_as
        history_path = self.paths_config.metrics_dir / outputs_cfg.history_file_name
        dvc_manifest_path = self.paths_config.metrics_dir / "posture_dvc_manifest.yaml"
        for output_path in (best_path, last_path, history_path, dvc_manifest_path):
            self._ensure_parent_dir(output_path)

        best_val_loss = float("inf")
        patience_counter = 0

        history = {
            "train_loss": [],
            "val_loss": [],
        }

        with mlflow.start_run(run_name="posture_training_run"):
            mlflow.log_param("device", str(self.device))
            mlflow.log_param("epochs", hyperparameters.epochs)
            mlflow.log_param("batch_size", hyperparameters.batch_size)
            mlflow.log_param("learning_rate", hyperparameters.learning_rate)

            for epoch in range(hyperparameters.epochs):
                train_loss = self._train_one_epoch(train_loader, criterion, optimizer)
                val_loss = self._evaluate_loss(val_loader, criterion)

                history["train_loss"].append(train_loss)
                history["val_loss"].append(val_loss)

                mlflow.log_metric("train_loss", train_loss, step=epoch)
                mlflow.log_metric("val_loss", val_loss, step=epoch)

                self.logger.info(
                    "Epoch %d | train_loss=%.6f | val_loss=%.6f",
                    epoch + 1,
                    train_loss,
                    val_loss,
                )

                # Only an improvement larger than min_delta counts as progress.
                if val_loss < best_val_loss - hyperparameters.min_delta:
                    best_val_loss = val_loss
                    patience_counter = 0

                    torch.save({"model_state_dict": self.model.state_dict()}, best_path)
                    self.logger.info("Saved new best posture model to: %s", best_path)
                else:
                    patience_counter += 1

                if patience_counter >= hyperparameters.early_stopping_patience:
                    self.logger.info("Early stopping triggered for posture training.")
                    break

            torch.save({"model_state_dict": self.model.state_dict()}, last_path)

            save_json(history_path, history)

            dvc_meta = {
                "artifact_type": "posture_model",
                "best_model_path": str(best_path),
                "last_model_path": str(last_path),
                "history_path": str(history_path),
            }
            save_yaml(dvc_manifest_path, dvc_meta)

        return {
            "history": history,
            "test_loader": test_loader,
            "best_val_loss": best_val_loss,
            "model": self.model,
        }