Spaces:
Sleeping
Sleeping
"""
Pose-LSTM Training Pipeline
============================
Training script for the pose-based LSTM classifier in the badminton shot
analysis system. Trains a Conv1D + LSTM architecture on normalized 3D pose
sequences extracted from MediaPipe.

Key Features:
    - LSTM-based temporal modeling of pose sequences
    - Train/validation/test split (70/10/20) with stratification
    - MLflow experiment tracking with interactive run naming
    - DVC live callback for experiment versioning
    - Early stopping with best model checkpointing
    - Automatic model registration to MLflow Model Registry

Pipeline Position:
    preprocess_pose.py → [train_pose.py] → evaluate.py
    Consumes preprocessed .npz files containing normalized pose features
    (T, 99) where T=sequence_length and 99=33 joints × 3 coordinates.

Dependencies:
    External: tensorflow, mlflow, sklearn, numpy, yaml, dvclive
    Internal: models.build_lstm_pose, mlflow_utils.MLflowRunManager

Configuration (params.yaml):
    pose_pipeline:
        data_path: Path to preprocessed pose data
        model_path: Output path for the best checkpoint (by val_accuracy)
        sequence_length: Number of frames per sample
        epochs: Maximum training epochs
        batch_size: Training batch size
    base:
        random_state: Seed for reproducibility
    mediapipe, segment_rules:
        Logged to MLflow as run parameters

Usage:
    python train_pose.py

Author: IPD Research Team
Version: 1.0.0
"""
| import os | |
| import yaml | |
| import numpy as np | |
| import mlflow | |
| import mlflow.tensorflow | |
| import tensorflow as tf | |
| from sklearn.model_selection import train_test_split | |
| from tensorflow.keras.utils import to_categorical | |
| from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint | |
| from dvclive.keras import DVCLiveCallback | |
| from models import build_lstm_pose | |
| from mlflow_utils import MLflowRunManager | |
def _load_pose_dataset(data_path):
    """Load pose samples from a ``<data_path>/<class_name>/*.npz`` directory tree.

    Each non-hidden sub-directory of ``data_path`` is one class. Class indices
    follow the sorted directory names. Every ``.npz`` file must contain a
    ``features`` array. Files are visited in sorted order because
    ``os.listdir`` order is filesystem-dependent. Without sorting, the seeded
    train/val/test split downstream would not be reproducible across machines.

    Args:
        data_path: Root directory of the preprocessed pose data.

    Returns:
        ``(features, labels, classes)``: a list of per-sample feature arrays,
        a parallel list of integer class indices, and the sorted class names.
    """
    # Only sub-directories are classes. Stray files (e.g. .gitignore) would
    # crash listdir below, and hidden directories (e.g. .ipynb_checkpoints)
    # would silently become extra labels.
    classes = sorted(
        entry
        for entry in os.listdir(data_path)
        if not entry.startswith(".")
        and os.path.isdir(os.path.join(data_path, entry))
    )
    features, labels = [], []
    for label, cls in enumerate(classes):
        class_dir = os.path.join(data_path, cls)
        for fname in sorted(os.listdir(class_dir)):
            # Only .npz archives carry a 'features' key; np.load on anything
            # else either fails or returns a plain array.
            if not fname.endswith(".npz"):
                continue
            # NpzFile keeps the archive open; the context manager closes it.
            with np.load(os.path.join(class_dir, fname)) as sample:
                features.append(sample["features"])
            labels.append(label)
    return features, labels, classes


def _split_dataset(X, y, num_classes, seed):
    """Split into stratified 70/10/20 train/val/test sets with one-hot targets.

    Args:
        X: Stacked feature array, one row per sample.
        y: Integer class labels aligned with ``X``.
        num_classes: Width of the one-hot encoding.
        seed: ``random_state`` used for both splits.

    Returns:
        ``(X_train, X_val, X_test, y_train, y_val, y_test)``. The ``y_*``
        arrays are one-hot encoded.
    """
    y_cat = to_categorical(y, num_classes)
    X_trainval, X_test, y_trainval, y_test, y_trainval_lbl, _ = train_test_split(
        X,
        y_cat,
        y,
        test_size=0.2,
        stratify=y,
        random_state=seed,
    )
    # 0.125 of the remaining 80% is 10% of the full dataset -> 70/10/20 overall.
    # Integer labels are used for stratification because they mirror y_trainval.
    X_train, X_val, y_train, y_val = train_test_split(
        X_trainval,
        y_trainval,
        test_size=0.125,
        stratify=y_trainval_lbl,
        random_state=seed,
    )
    return X_train, X_val, X_test, y_train, y_val, y_test


def main():
    """Train the Pose-LSTM classifier and register the best checkpoint in MLflow.

    Steps:
        1. Read ``params.yaml`` from the current working directory.
        2. Load the pose samples under ``pose_pipeline.data_path``.
        3. Train ``build_lstm_pose`` with early stopping on ``val_accuracy``.
        4. Save the best epoch to ``pose_pipeline.model_path``.
        5. Log that checkpoint to MLflow and register it as ``Pose_LSTM``.

    Prints a message and returns early if the data directory is missing or
    holds no ``.npz`` samples.
    """
    with open("params.yaml", encoding="utf-8") as fh:
        params = yaml.safe_load(fh)
    cfg = params['pose_pipeline']
    seed = params['base']['random_state']

    run_manager = MLflowRunManager("Pose_LSTM_Experiment")
    mlflow.enable_system_metrics_logging()
    # The best checkpoint is logged and registered explicitly at the end, so
    # autolog must not log a second (last-epoch) copy of the model.
    mlflow.tensorflow.autolog(log_models=False)

    with run_manager.start_interactive_run(
        default_description="LSTM-Pose pipeline training with geometric normalization"
    ):
        mlflow.log_params(cfg)
        # NOTE(review): MLflow rejects re-logging an existing key with a
        # different value. Confirm that the mediapipe/segment_rules keys never
        # collide with the pose_pipeline keys above.
        mlflow.log_params(params['mediapipe'])
        mlflow.log_params(params['segment_rules'])
        mlflow.log_param("base.random_state", seed)

        # isdir rather than exists: a file path would crash os.listdir.
        if not os.path.isdir(cfg['data_path']):
            print(f"Data path {cfg['data_path']} not found.")
            return

        X, y, classes = _load_pose_dataset(cfg['data_path'])
        if not X:
            print("❌ No data loaded.")
            return
        X = np.array(X)
        y = np.array(y)

        X_train, X_val, X_test, y_train, y_val, y_test = _split_dataset(
            X, y, len(classes), seed
        )
        run_manager.log_dataset_info(X_train, X_val, X_test, y_train, y_val, y_test, classes)

        model = build_lstm_pose(X_train.shape[1:], len(classes))
        run_manager.log_model_architecture(model)

        # Make sure the checkpoint's parent directory exists before the
        # first save.
        model_dir = os.path.dirname(cfg['model_path'])
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)

        callbacks = [
            EarlyStopping(
                patience=10,
                restore_best_weights=True,
                monitor='val_accuracy',
                verbose=1,
            ),
            ModelCheckpoint(
                cfg['model_path'],
                save_best_only=True,
                monitor='val_accuracy',
                verbose=1,
            ),
            DVCLiveCallback(save_dvc_exp=True),
        ]

        print("\n🚀 Starting Pose-LSTM Training...")
        print(f"   Train samples: {len(X_train)}")
        print(f"   Val samples: {len(X_val)}")
        print(f"   Test samples: {len(X_test)}")
        print(f"   Classes: {classes}\n")

        history = model.fit(
            X_train, y_train,
            validation_data=(X_val, y_val),
            epochs=cfg['epochs'],
            batch_size=cfg['batch_size'],
            callbacks=callbacks,
            verbose=1,
        )
        run_manager.log_training_artifacts(history, save_plots=True)

        print("\n✅ Training finished!")
        print(f"   Best Val Acc: {max(history.history['val_accuracy']):.4f}")

        print("\n📦 Logging and Registering Best Model to MLflow...")
        # Register the checkpoint written by ModelCheckpoint (best val_accuracy).
        # The in-memory `model` holds last-epoch weights if training ran all
        # epochs without early stopping.
        best_model = tf.keras.models.load_model(cfg['model_path'])
        # The signature only needs dtype/shape, so a single example suffices.
        # Input and output come from the same model that is registered.
        sample_input = X_train[:1]
        signature = mlflow.models.infer_signature(
            sample_input, best_model.predict(sample_input, verbose=0)
        )
        mlflow.keras.log_model(
            best_model,
            artifact_path="model",
            registered_model_name="Pose_LSTM",
            signature=signature,
        )
        print("✅ Model registered as 'Pose_LSTM'.")
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()