File size: 3,643 Bytes
37ed720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
Basic training example for HYPERDOA.

This example demonstrates:
1. Loading a pre-generated dataset
2. Training an HDC model
3. Evaluating on test data
4. Saving the trained model

Dataset format:
    List of (X, Y) tuples where:
    - X: Complex tensor of shape (N, T) - sensor observations
    - Y: Tensor of shape (M,) - ground truth DOA in radians

Usage:
    python examples/train_basic.py --data-dir data/

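Note:
    You need to generate the datasets first using SubspaceNet or similar tools.
    The script expects ``train_dataset.pt`` and ``test_dataset.pt`` inside the
    directory given by ``--data-dir``.
    See README.md for dataset generation instructions.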
"""

import argparse
from pathlib import Path
import torch

# Add parent directory to path for local development
import sys

sys.path.insert(0, str(Path(__file__).parent.parent))

from hyperdoa import HDCAoAModel, DOAConfig, evaluate_hdc, set_seed, save_checkpoint


def main():
    """Run the basic HYPERDOA example: load data, train/evaluate, save the model.

    Command-line options (read from ``sys.argv``):
        --data-dir: Directory containing ``train_dataset.pt`` and
            ``test_dataset.pt`` (default: ``data``).
        --feature-type: ``lag`` or ``spatial_smoothing`` (default: ``lag``).
        --seed: Random seed passed to ``set_seed`` and ``evaluate_hdc``
            (default: 42).
        --N, --M, --T: Number of sensors, sources and snapshots used to build
            the ``DOAConfig`` (defaults: 8, 2, 100).

    Side effects:
        Prints progress to stdout, creates a ``checkpoints`` directory in the
        current working directory if needed, and writes the checkpoint to
        ``checkpoints/hdc_<feature_type>_model.pt``.

    Raises:
        SystemExit: With status 1 when either dataset file is missing (the
            explanation is written to stderr), so that shells and CI jobs can
            detect the failure. argparse also exits on ``--help`` or invalid
            arguments.
    """
    parser = argparse.ArgumentParser(description="Train HDC model for DOA estimation")
    parser.add_argument(
        "--data-dir", type=Path, default=Path("data"), help="Data directory"
    )
    parser.add_argument(
        "--feature-type",
        type=str,
        default="lag",
        choices=["lag", "spatial_smoothing"],
        help="Feature extraction method",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--N", type=int, default=8, help="Number of sensors")
    parser.add_argument("--M", type=int, default=2, help="Number of sources")
    parser.add_argument("--T", type=int, default=100, help="Number of snapshots")
    args = parser.parse_args()

    # Set seed for reproducibility
    set_seed(args.seed)

    # Configure system parameters
    config = DOAConfig(
        N=args.N,
        M=args.M,
        T=args.T,
    )

    print("=" * 60)
    print("HYPERDOA - Basic Training Example")
    print("=" * 60)
    print("Configuration:")
    print(f"  N (sensors): {config.N}")
    print(f"  M (sources): {config.M}")
    print(f"  T (snapshots): {config.T}")
    print(f"  Feature type: {args.feature_type}")
    print(f"  Seed: {args.seed}")
    print("=" * 60)

    # Load datasets
    train_path = args.data_dir / "train_dataset.pt"
    test_path = args.data_dir / "test_dataset.pt"

    if not train_path.exists() or not test_path.exists():
        # Diagnostics go to stderr and the process exits non-zero; returning
        # normally here would report success (status 0) for a failed run.
        print("\nERROR: Dataset files not found!", file=sys.stderr)
        print(f"  Expected: {train_path}", file=sys.stderr)
        print(f"  Expected: {test_path}", file=sys.stderr)
        print("\nTo generate datasets, use SubspaceNet:", file=sys.stderr)
        print("  https://github.com/ShlezingerLab/SubspaceNet", file=sys.stderr)
        print("\nDataset format: List of (X, Y) tuples", file=sys.stderr)
        print("  X: Complex tensor (N, T)", file=sys.stderr)
        print("  Y: Tensor (M,) in radians", file=sys.stderr)
        sys.exit(1)

    print("\nLoading datasets...")
    # SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary
    # Python objects, which can execute code embedded in the file. Only point
    # --data-dir at dataset files you generated yourself or otherwise trust.
    train_data = torch.load(train_path, weights_only=False)
    test_data = torch.load(test_path, weights_only=False)

    print(f"  Train samples: {len(train_data)}")
    print(f"  Test samples: {len(test_data)}")

    # Train and evaluate; return_model=True makes evaluate_hdc hand back the
    # model alongside the test loss so it can be checkpointed below.
    print("\nTraining...")
    test_loss, model = evaluate_hdc(
        train_data=train_data,
        test_data=test_data,
        config=config,
        feature_type=args.feature_type,
        return_model=True,
        verbose=True,
        seed=args.seed,
    )

    print("\n" + "=" * 60)
    print(f"Final Test MSPE: {test_loss:.2f} dB")
    print("=" * 60)

    # Save model (path is relative to the current working directory)
    ckpt_dir = Path("checkpoints")
    ckpt_dir.mkdir(exist_ok=True)
    ckpt_path = ckpt_dir / f"hdc_{args.feature_type}_model.pt"

    save_checkpoint(model, str(ckpt_path), meta={"feature_type": args.feature_type})
    print(f"\nModel saved to: {ckpt_path}")


# Run the example only when this file is executed as a script; importing the
# module does not start training.
if __name__ == "__main__":
    main()