"""
Basic training example for HYPERDOA.

This example demonstrates:
1. Loading a pre-generated dataset
2. Training an HDC model
3. Evaluating on test data
4. Saving the trained model

Dataset format:
    List of (X, Y) tuples where:
    - X: Complex tensor of shape (N, T) - sensor observations
    - Y: Tensor of shape (M,) - ground truth DOA in radians

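Example (toy data to illustrate the format only; the values are random and not
physically consistent - real datasets should be generated with SubspaceNet):

    import torch
    N, T, M = 8, 100, 2
    X = torch.randn(N, T, dtype=torch.complex64)  # sensor observations
    Y = torch.rand(M) * torch.pi - torch.pi / 2   # DOAs in radians
    torch.save([(X, Y)], "data/train_dataset.pt")
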
Usage:
    python examples/train_basic.py --data-dir data/
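
    Optionally override the feature type and the array settings (N, M, T):
    python examples/train_basic.py --data-dir data/ --feature-type spatial_smoothing --N 8 --M 2 --T 100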

Note:
    You need to generate datasets first using SubspaceNet or similar tools.
    See README.md for dataset generation instructions.
"""
|
|
import argparse
import sys
from pathlib import Path

import torch

# Make the local hyperdoa package importable when running from the repository checkout.
sys.path.insert(0, str(Path(__file__).parent.parent))

from hyperdoa import HDCAoAModel, DOAConfig, evaluate_hdc, set_seed, save_checkpoint


def main():
    parser = argparse.ArgumentParser(description="Train HDC model for DOA estimation")
    parser.add_argument(
        "--data-dir", type=Path, default=Path("data"), help="Data directory"
    )
    parser.add_argument(
        "--feature-type",
        type=str,
        default="lag",
        choices=["lag", "spatial_smoothing"],
        help="Feature extraction method",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--N", type=int, default=8, help="Number of sensors")
    parser.add_argument("--M", type=int, default=2, help="Number of sources")
    parser.add_argument("--T", type=int, default=100, help="Number of snapshots")
    args = parser.parse_args()

    set_seed(args.seed)

    config = DOAConfig(
        N=args.N,
        M=args.M,
        T=args.T,
    )

    print("=" * 60)
    print("HYPERDOA - Basic Training Example")
    print("=" * 60)
    print("Configuration:")
    print(f" N (sensors): {config.N}")
    print(f" M (sources): {config.M}")
    print(f" T (snapshots): {config.T}")
    print(f" Feature type: {args.feature_type}")
    print(f" Seed: {args.seed}")
    print("=" * 60)

    train_path = args.data_dir / "train_dataset.pt"
    test_path = args.data_dir / "test_dataset.pt"

    if not train_path.exists() or not test_path.exists():
        print("\nERROR: Dataset files not found!")
        print(f" Expected: {train_path}")
        print(f" Expected: {test_path}")
        print("\nTo generate datasets, use SubspaceNet:")
        print(" https://github.com/ShlezingerLab/SubspaceNet")
        print("\nDataset format: List of (X, Y) tuples")
        print(" X: Complex tensor (N, T)")
        print(" Y: Tensor (M,) in radians")
        return

    print("\nLoading datasets...")
    train_data = torch.load(train_path, weights_only=False)
    test_data = torch.load(test_path, weights_only=False)

    print(f" Train samples: {len(train_data)}")
    print(f" Test samples: {len(test_data)}")
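
    # Optional sanity check against the documented dataset format
    # (X: complex tensor of shape (N, T), Y: tensor of shape (M,) in radians);
    # adjust or drop this if your dataset layout differs.
    X0, Y0 = train_data[0]
    if X0.shape != (config.N, config.T) or Y0.shape != (config.M,) or not torch.is_complex(X0):
        print("WARNING: first training sample does not match the configured N/M/T format.")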

    print("\nTraining...")
    test_loss, model = evaluate_hdc(
        train_data=train_data,
        test_data=test_data,
        config=config,
        feature_type=args.feature_type,
        return_model=True,
        verbose=True,
        seed=args.seed,
    )

    print("\n" + "=" * 60)
    print(f"Final Test MSPE: {test_loss:.2f} dB")
    print("=" * 60)

    ckpt_dir = Path("checkpoints")
    ckpt_dir.mkdir(exist_ok=True)
    ckpt_path = ckpt_dir / f"hdc_{args.feature_type}_model.pt"

    save_checkpoint(model, str(ckpt_path), meta={"feature_type": args.feature_type})
    print(f"\nModel saved to: {ckpt_path}")
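
    # The checkpoint can later be reloaded with torch.load (illustration only;
    # its exact contents depend on save_checkpoint):
    #   ckpt = torch.load(ckpt_path, weights_only=False)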
|
|
|
|
if __name__ == "__main__":
    main()
|
|