Commit 9727e5e (1 parent: 8a42349)
minor fixes for src/dataset
Files changed:
- src/data/dataset.py   +79 -40
- src/main.py           +30 -0
- src/main/parser.py    +43 -62
- src/main/trainer.py   +6 -3
- src/utils.py          +13 -4
src/data/dataset.py
CHANGED
@@ -1,19 +1,41 @@
-"""Module for loading and processing .mat files containing channel estimates for PyTorch.
+"""Module for loading and processing .mat files containing channel estimates for PyTorch.
+
+This module expects .mat files with a specific naming convention and internal structure:
+
+File Naming Convention:
+    Files must follow the pattern: {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
+
+    Example: 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
+    - file_number: Sequential file identifier
+    - SNR: Signal-to-Noise Ratio in dB
+    - DS: Delay Spread
+    - DOP: Maximum Doppler Shift
+    - N: Pilot placement frequency
+    - channel_type: Channel model type (e.g., TDL-A)
+
+File Content Structure:
+    Each .mat file must contain a variable 'H' with shape [subcarriers, symbols, 3]:
+    - H[:, :, 0]: Ground truth channel (complex values)
+    - H[:, :, 1]: LS channel estimate with zeros for non-pilot positions
+    - H[:, :, 2]: Unused (reserved for future use)
+
+The dataset extracts pilot values from the LS estimates and provides metadata from the filename
+for adaptive channel estimation models.
+"""
 from pathlib import Path
 from typing import Callable, List, Optional, Tuple, Union
 
 import scipy.io as sio
 import torch
 from torch.utils.data import Dataset, DataLoader
+from pydantic import BaseModel, Field
 
 from src.utils import extract_values
 
 __all__ = ['MatDataset', 'get_test_dataloaders']
 
 
-class PilotDimensions:
+class PilotDimensions(BaseModel):
     """Container for pilot signal dimensions.
 
     Stores and validates the dimensions of pilot signals used in channel estimation.
@@ -22,17 +44,8 @@ class PilotDimensions:
         num_subcarriers: Number of subcarriers in the pilot signal
         num_ofdm_symbols: Number of OFDM symbols in the pilot signal
     """
-    num_subcarriers: int
-    num_ofdm_symbols: int
-
-    def __post_init__(self):
-        """Validate dimensions after initialization.
-
-        Raises:
-            ValueError: If either dimension is not a positive integer
-        """
-        if self.num_subcarriers <= 0 or self.num_ofdm_symbols <= 0:
-            raise ValueError("Pilot dimensions must be positive integers")
+    num_subcarriers: int = Field(..., gt=0, description="Number of subcarriers in the pilot signal")
+    num_ofdm_symbols: int = Field(..., gt=0, description="Number of OFDM symbols in the pilot signal")
 
     def as_tuple(self) -> Tuple[int, int]:
         """Return dimensions as a tuple.
@@ -48,6 +61,20 @@ class MatDataset(Dataset):
 
     Processes .mat files containing channel estimation data and converts them into
     PyTorch complex tensors for channel estimation tasks.
+
+    Expected File Format:
+        - Files must be named according to the pattern:
+          {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
+        - Each .mat file must contain a variable 'H' with shape [subcarriers, symbols, 3]
+          - H[:, :, 0]: Ground truth channel (complex values)
+          - H[:, :, 1]: LS channel estimate with zeros for non-pilot positions
+          - H[:, :, 2]: Bilinear interpolated LS channel estimate
+
+    Returns:
+        For each sample, returns a tuple of:
+        - Pilot LS channel estimate (complex tensor, shape [pilot_subcarriers, pilot_symbols])
+        - Ground truth channel estimate (complex tensor, shape [ofdm_subcarriers, ofdm_symbols])
+        - Metadata tuple: (file_number, snr, delay_spread, doppler, pilot_freq, channel_type)
     """
 
     def __init__(
@@ -59,16 +86,16 @@ class MatDataset(Dataset):
         """Initialize the MatDataset.
 
         Args:
-            data_dir: Path to the directory containing the dataset.
+            data_dir: Path to the directory containing the dataset (should contain .mat files).
             pilot_dims: Dimensions of pilot data as [num_subcarriers, num_ofdm_symbols].
             transform: Optional transformation to apply to samples.
 
         Raises:
-            ValueError: If pilot dimensions are invalid.
             FileNotFoundError: If data_dir doesn't exist.
+            ValueError: If no .mat files are found in data_dir.
         """
         self.data_dir = Path(data_dir)
-        self.pilot_dims = PilotDimensions(
+        self.pilot_dims = PilotDimensions(num_subcarriers=pilot_dims[0], num_ofdm_symbols=pilot_dims[1])
         self.transform = transform
 
         if not self.data_dir.exists():
@@ -88,7 +115,6 @@
 
     def _process_channel_data(
         self,
-        h_ideal: torch.Tensor,
         mat_data: dict
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Process channel data and extract pilot values from LS estimates.
@@ -97,8 +123,7 @@
         returning complex-valued tensors for both estimate and ground truth.
 
         Args:
-            mat_data: Loaded .mat file data
+            mat_data: Loaded .mat file data containing 'H' variable
 
         Returns:
             Tuple of (pilot LS estimate, ground truth channel)
@@ -107,6 +132,9 @@
             ValueError: If the data format is unexpected or processing fails
         """
         try:
+            # Extract ground truth channel
+            h_ideal = torch.tensor(mat_data['H'][:, :, 0], dtype=torch.cfloat)
+
             # Extract LS channel estimate with zero entries
             hzero_ls = torch.tensor(mat_data['H'][:, :, 1], dtype=torch.cfloat)
 
@@ -143,9 +171,10 @@
 
         Returns:
             Tuple containing:
-            - Pilot LS channel estimate (complex tensor)
-            - Ground truth channel estimate (complex tensor)
-            - Metadata
+            - Pilot LS channel estimate (complex tensor, shape [pilot_subcarriers, pilot_symbols])
+            - Ground truth channel estimate (complex tensor, shape [ofdm_subcarriers, ofdm_symbols])
+            - Metadata tuple: (file_number, snr, delay_spread, doppler, pilot_freq, channel_type)
+              All metadata values are torch.Tensor except channel_type which is a list
 
         Raises:
             ValueError: If file format is invalid or processing fails.
@@ -155,16 +184,12 @@
             raise IndexError(f"Index {idx} out of range for dataset of size {len(self)}")
 
         try:
-            # Load .mat file
             mat_data = sio.loadmat(self.file_list[idx])
             if 'H' not in mat_data or mat_data['H'].shape[-1] < 3:
                 raise ValueError("Invalid .mat file format: missing required data")
 
-            # Extract ground truth channel
-            h_ideal = torch.tensor(mat_data['H'][:, :, 0], dtype=torch.cfloat)
-
             # Process channel data to extract pilot estimates
-            h_est, h_ideal = self._process_channel_data(
+            h_est, h_ideal = self._process_channel_data(mat_data)
 
             # Extract metadata from filename
             meta_data = extract_values(self.file_list[idx].name)
@@ -184,7 +209,8 @@
 
 def get_test_dataloaders(
     dataset_dir: Union[str, Path],
+    pilot_dims: List[int],
+    batch_size: int
 ) -> List[Tuple[str, DataLoader]]:
     """Create DataLoaders for each subdirectory in the dataset directory.
 
@@ -192,26 +218,39 @@ def get_test_dataloaders(
     all subdirectories in the specified dataset directory, useful for testing
     across multiple test conditions or scenarios.
 
+    Expected Directory Structure:
+        dataset_dir/
+        ├── DS_50/            # Delay Spread = 50
+        │   ├── 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
+        │   ├── 2_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
+        │   └── ...
+        ├── DS_100/           # Delay Spread = 100
+        │   ├── 1_SNR-20_DS-100_DOP-500_N-3_TDL-A.mat
+        │   └── ...
+        ├── SNR_10/           # SNR = 10 dB
+        │   ├── 1_SNR-10_DS-50_DOP-500_N-3_TDL-A.mat
+        │   └── ...
+        └── ...
+
+    Each subdirectory should contain .mat files with the naming convention:
+        {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
+
     Args:
         dataset_dir: Path to main directory containing dataset subdirectories
-        - batch_size: Number of samples per batch
+        pilot_dims: List of [num_subcarriers, num_ofdm_symbols] for pilot dimensions
+        batch_size: Number of samples per batch
 
     Returns:
         List of tuples containing (subdirectory_name, corresponding_dataloader)
 
     Raises:
        FileNotFoundError: If dataset_dir doesn't exist
-        ValueError: If
+        ValueError: If no valid subdirectories are found
    """
    dataset_dir = Path(dataset_dir)
    if not dataset_dir.exists():
        raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")

-    if not isinstance(params, dict) or "pilot_dims" not in params or "batch_size" not in params:
-        raise ValueError("params must be a dict containing 'pilot_dims' and 'batch_size'")
-
    subdirs = [d for d in dataset_dir.iterdir() if d.is_dir()]
    if not subdirs:
        raise ValueError(f"No subdirectories found in {dataset_dir}")
@@ -221,7 +260,7 @@ def get_test_dataloaders(
            subdir.name,
            MatDataset(
                subdir,
+                pilot_dims
            )
        )
        for subdir in subdirs
@@ -230,8 +269,8 @@
    return [
        (name, DataLoader(
            dataset,
-            batch_size=
-            shuffle=False,
+            batch_size=batch_size,
+            shuffle=False,  # no shuffling for testing
            num_workers=0
        ))
        for name, dataset in test_datasets
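Note on the new get_test_dataloaders signature: the old params dict is replaced by explicit pilot_dims and batch_size arguments. A minimal usage sketch of the changed API follows; the paths and pilot dimensions here are illustrative placeholders, not values taken from the repository.

    from torch.utils.data import DataLoader

    from src.data.dataset import MatDataset, get_test_dataloaders

    # Hypothetical pilot grid: 120 pilot subcarriers x 4 pilot OFDM symbols.
    pilot_dims = [120, 4]

    # A single directory of .mat files (hypothetical path) loads as one dataset.
    train_ds = MatDataset("data/train_set", pilot_dims)
    train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)

    # One DataLoader per test-condition subdirectory, e.g. data/test_set/DS_test_set/DS_50, DS_100, ...
    test_loaders = get_test_dataloaders("data/test_set/DS_test_set", pilot_dims, batch_size=64)
    for name, loader in test_loaders:
        h_pilot, h_ideal, meta = next(iter(loader))
        # h_pilot: complex pilot LS estimate, h_ideal: complex ground-truth channel
        print(name, h_pilot.shape, h_ideal.shape)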
src/main.py
CHANGED
@@ -5,6 +5,36 @@ Main entry point for OFDM channel estimation model training.
 This script provides the command-line interface for training OFDM channel estimation
 models. It loads configuration files, parses command-line arguments, and initiates
 the training process.
+
+Dataset Requirements:
+    The training script expects datasets with the following structure:
+
+    Training/Validation Sets:
+        Directory containing .mat files with naming convention:
+        {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
+
+        Example: 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
+
+    Test Sets:
+        Directory with subdirectories for different test conditions:
+        test_set/
+        ├── DS_test_set/      # Delay Spread tests
+        │   ├── DS_50/
+        │   ├── DS_100/
+        │   └── ...
+        ├── SNR_test_set/     # SNR tests
+        │   ├── SNR_10/
+        │   ├── SNR_20/
+        │   └── ...
+        └── MDS_test_set/     # Multi-Doppler tests
+            ├── DOP_200/
+            ├── DOP_400/
+            └── ...
+
+    Each .mat file must contain variable 'H' with shape [subcarriers, symbols, 3]:
+    - H[:, :, 0]: Ground truth channel
+    - H[:, :, 1]: LS channel estimate with zeros for non-pilot positions
+    - H[:, :, 2]: Unused (reserved)
 """
 
 import logging
src/main/parser.py
CHANGED
@@ -7,10 +7,11 @@ their types, default values, and validation rules to ensure proper configuration
 of training runs.
 """
 
-from dataclasses import dataclass
 from pathlib import Path
 import argparse
 from enum import Enum
+from pydantic import BaseModel, Field, model_validator
+from typing import Self
 
 
 class LossType(Enum):
@@ -20,8 +21,7 @@ class LossType(Enum):
     HUBER = "huber"
 
 
-class TrainingArguments:
+class TrainingArguments(BaseModel):
     """Container for OFDM model training arguments.
 
     Stores, validates, and provides access to all parameters needed for
@@ -57,45 +57,34 @@ class TrainingArguments:
     """
 
     # Model Configuration
-    model_name: str
-    system_config_path: Path
-    model_config_path: Path
+    model_name: str = Field(..., description="Model type to train")
+    system_config_path: Path = Field(..., description="Path to OFDM system configuration file")
+    model_config_path: Path = Field(..., description="Path to model configuration file")
 
     # Dataset Paths
-    train_set: Path
-    val_set: Path
-    test_set: Path
+    train_set: Path = Field(..., description="Training dataset folder path")
+    val_set: Path = Field(..., description="Validation dataset folder path")
+    test_set: Path = Field(..., description="Test dataset folder path")
 
     # Experiment Settings
-    exp_id: str
-    python_log_level: str = "INFO"
-    tensorboard_log_dir: Path = Path("runs")
+    exp_id: str = Field(..., description="Experiment identifier for log folder naming")
+    python_log_level: str = Field(default="INFO", description="Logger level for python logging module")
+    tensorboard_log_dir: Path = Field(default=Path("runs"), description="Directory for tensorboard logs")
 
     # Training Hyperparameters
-    batch_size: int = 64
-    lr: float = 1e-3
-    max_epoch: int = 10
-    patience: int = 3
-    loss_type: LossType = LossType.MSE
-    return_type: str = "complex"
+    batch_size: int = Field(default=64, gt=0, description="Training batch size")
+    lr: float = Field(default=1e-3, gt=0, description="Initial learning rate")
+    max_epoch: int = Field(default=10, gt=0, description="Maximum number of training epochs")
+    patience: int = Field(default=3, gt=0, description="Early stopping patience (epochs)")
+    loss_type: LossType = Field(default=LossType.MSE, description="Loss function type")
+    return_type: str = Field(default="complex", description="Type of data to return from dataset")
 
     # Hardware & Evaluation
-    cuda: int = 0
-    test_every_n: int = 10
+    cuda: int = Field(default=0, ge=0, description="CUDA device index (0 for single GPU)")
+    test_every_n: int = Field(default=10, gt=0, description="Test model every N epochs")
 
-        Runs multiple validation checks on the provided arguments to ensure
-        they are consistent and valid for training.
-
-        Raises:
-            ValueError: If any validation check fails
-        """
-        self._validate_paths()
-        self._validate_numeric_args()
-
-    def _validate_paths(self) -> None:
+    @model_validator(mode='after')
+    def validate_paths(self) -> Self:
         """Validate path-related arguments.
 
         Checks that the config files exist and have the correct extension.
@@ -115,33 +104,7 @@ class TrainingArguments:
         if not self.model_config_path.suffix == '.yaml':
             raise ValueError(f"Model config file must be a .yaml file: {self.model_config_path}")
 
-        """Validate numeric arguments.
-
-        Ensures that all numeric parameters have appropriate values:
-        - test_every_n, max_epoch, patience, batch_size, lr must be positive
-        - cuda must be non-negative
-
-        Raises:
-            ValueError: If any numeric argument has an invalid value
-        """
-        if self.test_every_n <= 0:
-            raise ValueError(f"test_every_n must be positive, got: {self.test_every_n}")
-
-        if self.max_epoch <= 0:
-            raise ValueError(f"max_epoch must be positive, got: {self.max_epoch}")
-
-        if self.patience <= 0:
-            raise ValueError(f"patience must be positive, got: {self.patience}")
-
-        if self.batch_size <= 0:
-            raise ValueError(f"batch_size must be positive, got: {self.batch_size}")
-
-        if self.cuda < 0:
-            raise ValueError(f"cuda must be non-negative, got: {self.cuda}")
-
-        if self.lr <= 0:
-            raise ValueError(f"lr must be positive, got: {self.lr}")
+        return self
 
 
 def parse_arguments() -> TrainingArguments:
@@ -278,7 +241,25 @@ def parse_arguments() -> TrainingArguments:
    args = parser.parse_args()

    # Convert loss_type string to enum
+    loss_type = LossType(args.loss_type)

    # Create and validate TrainingArguments
-    return TrainingArguments(
+    return TrainingArguments(
+        model_name=args.model_name,
+        system_config_path=args.system_config_path,
+        model_config_path=args.model_config_path,
+        train_set=args.train_set,
+        val_set=args.val_set,
+        test_set=args.test_set,
+        exp_id=args.exp_id,
+        python_log_level=args.python_log_level,
+        tensorboard_log_dir=args.tensorboard_log_dir,
+        batch_size=args.batch_size,
+        lr=args.lr,
+        max_epoch=args.max_epoch,
+        patience=args.patience,
+        loss_type=loss_type,
+        return_type=args.return_type,
+        cuda=args.cuda,
+        test_every_n=args.test_every_n
+    )
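Note: the dataclass-plus-__post_init__ validation is replaced here by pydantic field constraints plus an after-model validator. A standalone sketch of the same pattern follows; the model and field names are illustrative, not the project's.

    from pathlib import Path
    from typing import Self

    from pydantic import BaseModel, Field, ValidationError, model_validator


    class ExampleArgs(BaseModel):
        config_path: Path = Field(..., description="Path to a .yaml config file")
        batch_size: int = Field(default=64, gt=0)   # gt=0 replaces an explicit `if batch_size <= 0: raise ...`
        cuda: int = Field(default=0, ge=0)          # ge=0 replaces the non-negative check

        @model_validator(mode="after")
        def validate_paths(self) -> Self:
            # Cross-field or filesystem checks still live in an explicit validator.
            if self.config_path.suffix != ".yaml":
                raise ValueError(f"Config file must be a .yaml file: {self.config_path}")
            return self


    try:
        ExampleArgs(config_path=Path("conf/model.txt"), batch_size=0)
    except ValidationError as err:
        # Field-constraint violations are collected into a single ValidationError;
        # the after-validator runs only once all field constraints pass.
        print(err)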
src/main/trainer.py
CHANGED
@@ -163,15 +163,18 @@ class ModelTrainer:
         test_loaders = {
             "DS": get_test_dataloaders(
                 self.args.test_set / "DS_test_set",
+                pilot_dims,
+                self.args.batch_size
             ),
             "MDS": get_test_dataloaders(
                 self.args.test_set / "MDS_test_set",
+                pilot_dims,
+                self.args.batch_size
             ),
             "SNR": get_test_dataloaders(
                 self.args.test_set / "SNR_test_set",
+                pilot_dims,
+                self.args.batch_size
             ),
         }
         return train_loader, val_loader, test_loaders
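Note: each entry in test_loaders now maps a test-set family ("DS", "MDS", "SNR") to the list of (subdirectory_name, DataLoader) pairs returned by get_test_dataloaders. A hedged sketch of walking this structure during evaluation; the model call and MSE metric are placeholders, not the trainer's actual evaluation step.

    import torch


    def evaluate_test_loaders(model, test_loaders):
        """Walk {"DS"/"MDS"/"SNR": [(condition, DataLoader), ...]} and report a per-condition metric."""
        results = {}
        for family, loaders in test_loaders.items():
            for condition, loader in loaders:
                total, batches = 0.0, 0
                for h_pilot, h_ideal, _meta in loader:
                    h_hat = model(h_pilot)  # placeholder forward pass
                    total += torch.mean(torch.abs(h_hat - h_ideal) ** 2).item()
                    batches += 1
                results[f"{family}/{condition}"] = total / max(batches, 1)
        return results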
src/utils.py
CHANGED
@@ -70,19 +70,28 @@ def extract_values(file_name):
     Extract channel information from a file name.
 
     Parses file names with format:
-    '{
+    '{file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat'
+
+    Example:
+        For filename "1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat":
+        - file_number: 1
+        - snr: 20 (Signal-to-Noise Ratio in dB)
+        - delay_spread: 50 (Delay Spread)
+        - doppler: 500 (Maximum Doppler Shift)
+        - pilot_freq: 3 (Pilot placement frequency)
+        - channel_type: TDL-A (Channel model type)
 
     Args:
         file_name: The file name from which to extract channel information
 
     Returns:
         tuple: A tuple containing:
-        - file_number (torch.Tensor): The file number
-        - snr (torch.Tensor): Signal-to-noise ratio value
+        - file_number (torch.Tensor): The file number (sequential identifier)
+        - snr (torch.Tensor): Signal-to-noise ratio value in dB
         - delay_spread (torch.Tensor): Delay spread value
         - max_doppler_shift (torch.Tensor): Maximum Doppler shift value
         - pilot_placement_frequency (torch.Tensor): Pilot placement frequency
-        - channel_type (list): The channel type
+        - channel_type (list): The channel type (e.g., ['TDL-A'])
 
     Raises:
         ValueError: If the file name does not match the expected pattern
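Note: this hunk only updates the extract_values docstring; the parsing code itself is not part of the diff. For reference, a minimal regex-based sketch of parsing this naming convention; it is an assumption about how such parsing could look, not the repository's actual implementation, and parse_channel_filename is a hypothetical helper name.

    import re

    import torch

    _PATTERN = re.compile(
        r"^(?P<file_number>\d+)_SNR-(?P<snr>-?\d+)_DS-(?P<ds>\d+)"
        r"_DOP-(?P<dop>\d+)_N-(?P<n>\d+)_(?P<channel>.+)\.mat$"
    )


    def parse_channel_filename(file_name: str):
        """Parse '1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat'-style names (illustrative only)."""
        match = _PATTERN.match(file_name)
        if match is None:
            raise ValueError(f"File name does not match expected pattern: {file_name}")

        def as_tensor(key: str) -> torch.Tensor:
            return torch.tensor(int(match.group(key)))

        return (
            as_tensor("file_number"),
            as_tensor("snr"),
            as_tensor("ds"),
            as_tensor("dop"),
            as_tensor("n"),
            [match.group("channel")],  # channel_type returned as a list, matching the docstring
        )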