BerkIGuler commited on
Commit
9727e5e
·
1 Parent(s): 8a42349

minor fixes fpr src/dataset

Browse files
Files changed (5) hide show
  1. src/data/dataset.py +79 -40
  2. src/main.py +30 -0
  3. src/main/parser.py +43 -62
  4. src/main/trainer.py +6 -3
  5. src/utils.py +13 -4
src/data/dataset.py CHANGED
@@ -1,19 +1,41 @@
1
- """Module for loading and processing .mat files containing channel estimates for PyTorch."""
2
- from dataclasses import dataclass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from pathlib import Path
4
  from typing import Callable, List, Optional, Tuple, Union
5
 
6
  import scipy.io as sio
7
  import torch
8
  from torch.utils.data import Dataset, DataLoader
 
9
 
10
  from src.utils import extract_values
11
 
12
  __all__ = ['MatDataset', 'get_test_dataloaders']
13
 
14
 
15
- @dataclass
16
- class PilotDimensions:
17
  """Container for pilot signal dimensions.
18
 
19
  Stores and validates the dimensions of pilot signals used in channel estimation.
@@ -22,17 +44,8 @@ class PilotDimensions:
22
  num_subcarriers: Number of subcarriers in the pilot signal
23
  num_ofdm_symbols: Number of OFDM symbols in the pilot signal
24
  """
25
- num_subcarriers: int
26
- num_ofdm_symbols: int
27
-
28
- def __post_init__(self):
29
- """Validate dimensions after initialization.
30
-
31
- Raises:
32
- ValueError: If either dimension is not a positive integer
33
- """
34
- if self.num_subcarriers <= 0 or self.num_ofdm_symbols <= 0:
35
- raise ValueError("Pilot dimensions must be positive integers")
36
 
37
  def as_tuple(self) -> Tuple[int, int]:
38
  """Return dimensions as a tuple.
@@ -48,6 +61,20 @@ class MatDataset(Dataset):
48
 
49
  Processes .mat files containing channel estimation data and converts them into
50
  PyTorch complex tensors for channel estimation tasks.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  """
52
 
53
  def __init__(
@@ -59,16 +86,16 @@ class MatDataset(Dataset):
59
  """Initialize the MatDataset.
60
 
61
  Args:
62
- data_dir: Path to the directory containing the dataset.
63
  pilot_dims: Dimensions of pilot data as [num_subcarriers, num_ofdm_symbols].
64
  transform: Optional transformation to apply to samples.
65
 
66
  Raises:
67
- ValueError: If pilot dimensions are invalid.
68
  FileNotFoundError: If data_dir doesn't exist.
 
69
  """
70
  self.data_dir = Path(data_dir)
71
- self.pilot_dims = PilotDimensions(*pilot_dims)
72
  self.transform = transform
73
 
74
  if not self.data_dir.exists():
@@ -88,7 +115,6 @@ class MatDataset(Dataset):
88
 
89
  def _process_channel_data(
90
  self,
91
- h_ideal: torch.Tensor,
92
  mat_data: dict
93
  ) -> Tuple[torch.Tensor, torch.Tensor]:
94
  """Process channel data and extract pilot values from LS estimates.
@@ -97,8 +123,7 @@ class MatDataset(Dataset):
97
  returning complex-valued tensors for both estimate and ground truth.
98
 
99
  Args:
100
- h_ideal: Ground truth channel tensor
101
- mat_data: Loaded .mat file data
102
 
103
  Returns:
104
  Tuple of (pilot LS estimate, ground truth channel)
@@ -107,6 +132,9 @@ class MatDataset(Dataset):
107
  ValueError: If the data format is unexpected or processing fails
108
  """
109
  try:
 
 
 
110
  # Extract LS channel estimate with zero entries
111
  hzero_ls = torch.tensor(mat_data['H'][:, :, 1], dtype=torch.cfloat)
112
 
@@ -143,9 +171,10 @@ class MatDataset(Dataset):
143
 
144
  Returns:
145
  Tuple containing:
146
- - Pilot LS channel estimate (complex tensor)
147
- - Ground truth channel estimate (complex tensor)
148
- - Metadata extracted from filename
 
149
 
150
  Raises:
151
  ValueError: If file format is invalid or processing fails.
@@ -155,16 +184,12 @@ class MatDataset(Dataset):
155
  raise IndexError(f"Index {idx} out of range for dataset of size {len(self)}")
156
 
157
  try:
158
- # Load .mat file
159
  mat_data = sio.loadmat(self.file_list[idx])
160
  if 'H' not in mat_data or mat_data['H'].shape[-1] < 3:
161
  raise ValueError("Invalid .mat file format: missing required data")
162
 
163
- # Extract ground truth channel
164
- h_ideal = torch.tensor(mat_data['H'][:, :, 0], dtype=torch.cfloat)
165
-
166
  # Process channel data to extract pilot estimates
167
- h_est, h_ideal = self._process_channel_data(h_ideal, mat_data)
168
 
169
  # Extract metadata from filename
170
  meta_data = extract_values(self.file_list[idx].name)
@@ -184,7 +209,8 @@ class MatDataset(Dataset):
184
 
185
  def get_test_dataloaders(
186
  dataset_dir: Union[str, Path],
187
- params: dict
 
188
  ) -> List[Tuple[str, DataLoader]]:
189
  """Create DataLoaders for each subdirectory in the dataset directory.
190
 
@@ -192,26 +218,39 @@ def get_test_dataloaders(
192
  all subdirectories in the specified dataset directory, useful for testing
193
  across multiple test conditions or scenarios.
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  Args:
196
  dataset_dir: Path to main directory containing dataset subdirectories
197
- params: Configuration parameters including:
198
- - pilot_dims: List of [num_subcarriers, num_ofdm_symbols]
199
- - batch_size: Number of samples per batch
200
 
201
  Returns:
202
  List of tuples containing (subdirectory_name, corresponding_dataloader)
203
 
204
  Raises:
205
  FileNotFoundError: If dataset_dir doesn't exist
206
- ValueError: If params are invalid or no valid subdirectories are found
207
  """
208
  dataset_dir = Path(dataset_dir)
209
  if not dataset_dir.exists():
210
  raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")
211
 
212
- if not isinstance(params, dict) or "pilot_dims" not in params or "batch_size" not in params:
213
- raise ValueError("params must be a dict containing 'pilot_dims' and 'batch_size'")
214
-
215
  subdirs = [d for d in dataset_dir.iterdir() if d.is_dir()]
216
  if not subdirs:
217
  raise ValueError(f"No subdirectories found in {dataset_dir}")
@@ -221,7 +260,7 @@ def get_test_dataloaders(
221
  subdir.name,
222
  MatDataset(
223
  subdir,
224
- params["pilot_dims"]
225
  )
226
  )
227
  for subdir in subdirs
@@ -230,8 +269,8 @@ def get_test_dataloaders(
230
  return [
231
  (name, DataLoader(
232
  dataset,
233
- batch_size=params["batch_size"],
234
- shuffle=False,
235
  num_workers=0
236
  ))
237
  for name, dataset in test_datasets
 
1
+ """Module for loading and processing .mat files containing channel estimates for PyTorch.
2
+
3
+ This module expects .mat files with a specific naming convention and internal structure:
4
+
5
+ File Naming Convention:
6
+ Files must follow the pattern: {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
7
+
8
+ Example: 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
9
+ - file_number: Sequential file identifier
10
+ - SNR: Signal-to-Noise Ratio in dB
11
+ - DS: Delay Spread
12
+ - DOP: Maximum Doppler Shift
13
+ - N: Pilot placement frequency
14
+ - channel_type: Channel model type (e.g., TDL-A)
15
+
16
+ File Content Structure:
17
+ Each .mat file must contain a variable 'H' with shape [subcarriers, symbols, 3]:
18
+ - H[:, :, 0]: Ground truth channel (complex values)
19
+ - H[:, :, 1]: LS channel estimate with zeros for non-pilot positions
20
+ - H[:, :, 2]: Unused (reserved for future use)
21
+
22
+ The dataset extracts pilot values from the LS estimates and provides metadata from the filename
23
+ for adaptive channel estimation models.
24
+ """
25
  from pathlib import Path
26
  from typing import Callable, List, Optional, Tuple, Union
27
 
28
  import scipy.io as sio
29
  import torch
30
  from torch.utils.data import Dataset, DataLoader
31
+ from pydantic import BaseModel, Field
32
 
33
  from src.utils import extract_values
34
 
35
  __all__ = ['MatDataset', 'get_test_dataloaders']
36
 
37
 
38
+ class PilotDimensions(BaseModel):
 
39
  """Container for pilot signal dimensions.
40
 
41
  Stores and validates the dimensions of pilot signals used in channel estimation.
 
44
  num_subcarriers: Number of subcarriers in the pilot signal
45
  num_ofdm_symbols: Number of OFDM symbols in the pilot signal
46
  """
47
+ num_subcarriers: int = Field(..., gt=0, description="Number of subcarriers in the pilot signal")
48
+ num_ofdm_symbols: int = Field(..., gt=0, description="Number of OFDM symbols in the pilot signal")
 
 
 
 
 
 
 
 
 
49
 
50
  def as_tuple(self) -> Tuple[int, int]:
51
  """Return dimensions as a tuple.
 
61
 
62
  Processes .mat files containing channel estimation data and converts them into
63
  PyTorch complex tensors for channel estimation tasks.
64
+
65
+ Expected File Format:
66
+ - Files must be named according to the pattern:
67
+ {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
68
+ - Each .mat file must contain a variable 'H' with shape [subcarriers, symbols, 3]
69
+ - H[:, :, 0]: Ground truth channel (complex values)
70
+ - H[:, :, 1]: LS channel estimate with zeros for non-pilot positions
71
+ - H[:, :, 2]: Bilinear interpolated LS channel estimate
72
+
73
+ Returns:
74
+ For each sample, returns a tuple of:
75
+ - Pilot LS channel estimate (complex tensor, shape [pilot_subcarriers, pilot_symbols])
76
+ - Ground truth channel estimate (complex tensor, shape [ofdm_subcarriers, ofdm_symbols])
77
+ - Metadata tuple: (file_number, snr, delay_spread, doppler, pilot_freq, channel_type)
78
  """
79
 
80
  def __init__(
 
86
  """Initialize the MatDataset.
87
 
88
  Args:
89
+ data_dir: Path to the directory containing the dataset (should contain .mat files).
90
  pilot_dims: Dimensions of pilot data as [num_subcarriers, num_ofdm_symbols].
91
  transform: Optional transformation to apply to samples.
92
 
93
  Raises:
 
94
  FileNotFoundError: If data_dir doesn't exist.
95
+ ValueError: If no .mat files are found in data_dir.
96
  """
97
  self.data_dir = Path(data_dir)
98
+ self.pilot_dims = PilotDimensions(num_subcarriers=pilot_dims[0], num_ofdm_symbols=pilot_dims[1])
99
  self.transform = transform
100
 
101
  if not self.data_dir.exists():
 
115
 
116
  def _process_channel_data(
117
  self,
 
118
  mat_data: dict
119
  ) -> Tuple[torch.Tensor, torch.Tensor]:
120
  """Process channel data and extract pilot values from LS estimates.
 
123
  returning complex-valued tensors for both estimate and ground truth.
124
 
125
  Args:
126
+ mat_data: Loaded .mat file data containing 'H' variable
 
127
 
128
  Returns:
129
  Tuple of (pilot LS estimate, ground truth channel)
 
132
  ValueError: If the data format is unexpected or processing fails
133
  """
134
  try:
135
+ # Extract ground truth channel
136
+ h_ideal = torch.tensor(mat_data['H'][:, :, 0], dtype=torch.cfloat)
137
+
138
  # Extract LS channel estimate with zero entries
139
  hzero_ls = torch.tensor(mat_data['H'][:, :, 1], dtype=torch.cfloat)
140
 
 
171
 
172
  Returns:
173
  Tuple containing:
174
+ - Pilot LS channel estimate (complex tensor, shape [pilot_subcarriers, pilot_symbols])
175
+ - Ground truth channel estimate (complex tensor, shape [ofdm_subcarriers, ofdm_symbols])
176
+ - Metadata tuple: (file_number, snr, delay_spread, doppler, pilot_freq, channel_type)
177
+ All metadata values are torch.Tensor except channel_type which is a list
178
 
179
  Raises:
180
  ValueError: If file format is invalid or processing fails.
 
184
  raise IndexError(f"Index {idx} out of range for dataset of size {len(self)}")
185
 
186
  try:
 
187
  mat_data = sio.loadmat(self.file_list[idx])
188
  if 'H' not in mat_data or mat_data['H'].shape[-1] < 3:
189
  raise ValueError("Invalid .mat file format: missing required data")
190
 
 
 
 
191
  # Process channel data to extract pilot estimates
192
+ h_est, h_ideal = self._process_channel_data(mat_data)
193
 
194
  # Extract metadata from filename
195
  meta_data = extract_values(self.file_list[idx].name)
 
209
 
210
  def get_test_dataloaders(
211
  dataset_dir: Union[str, Path],
212
+ pilot_dims: List[int],
213
+ batch_size: int
214
  ) -> List[Tuple[str, DataLoader]]:
215
  """Create DataLoaders for each subdirectory in the dataset directory.
216
 
 
218
  all subdirectories in the specified dataset directory, useful for testing
219
  across multiple test conditions or scenarios.
220
 
221
+ Expected Directory Structure:
222
+ dataset_dir/
223
+ ├── DS_50/ # Delay Spread = 50
224
+ │ ├── 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
225
+ │ ├── 2_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
226
+ │ └── ...
227
+ ├── DS_100/ # Delay Spread = 100
228
+ │ ├── 1_SNR-20_DS-100_DOP-500_N-3_TDL-A.mat
229
+ │ └── ...
230
+ ├── SNR_10/ # SNR = 10 dB
231
+ │ ├── 1_SNR-10_DS-50_DOP-500_N-3_TDL-A.mat
232
+ │ └── ...
233
+ └── ...
234
+
235
+ Each subdirectory should contain .mat files with the naming convention:
236
+ {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
237
+
238
  Args:
239
  dataset_dir: Path to main directory containing dataset subdirectories
240
+ pilot_dims: List of [num_subcarriers, num_ofdm_symbols] for pilot dimensions
241
+ batch_size: Number of samples per batch
 
242
 
243
  Returns:
244
  List of tuples containing (subdirectory_name, corresponding_dataloader)
245
 
246
  Raises:
247
  FileNotFoundError: If dataset_dir doesn't exist
248
+ ValueError: If no valid subdirectories are found
249
  """
250
  dataset_dir = Path(dataset_dir)
251
  if not dataset_dir.exists():
252
  raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")
253
 
 
 
 
254
  subdirs = [d for d in dataset_dir.iterdir() if d.is_dir()]
255
  if not subdirs:
256
  raise ValueError(f"No subdirectories found in {dataset_dir}")
 
260
  subdir.name,
261
  MatDataset(
262
  subdir,
263
+ pilot_dims
264
  )
265
  )
266
  for subdir in subdirs
 
269
  return [
270
  (name, DataLoader(
271
  dataset,
272
+ batch_size=batch_size,
273
+ shuffle=False, # no shuffling for testing
274
  num_workers=0
275
  ))
276
  for name, dataset in test_datasets
src/main.py CHANGED
@@ -5,6 +5,36 @@ Main entry point for OFDM channel estimation model training.
5
  This script provides the command-line interface for training OFDM channel estimation
6
  models. It loads configuration files, parses command-line arguments, and initiates
7
  the training process.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  """
9
 
10
  import logging
 
5
  This script provides the command-line interface for training OFDM channel estimation
6
  models. It loads configuration files, parses command-line arguments, and initiates
7
  the training process.
8
+
9
+ Dataset Requirements:
10
+ The training script expects datasets with the following structure:
11
+
12
+ Training/Validation Sets:
13
+ Directory containing .mat files with naming convention:
14
+ {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
15
+
16
+ Example: 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
17
+
18
+ Test Sets:
19
+ Directory with subdirectories for different test conditions:
20
+ test_set/
21
+ ├── DS_test_set/ # Delay Spread tests
22
+ │ ├── DS_50/
23
+ │ ├── DS_100/
24
+ │ └── ...
25
+ ├── SNR_test_set/ # SNR tests
26
+ │ ├── SNR_10/
27
+ │ ├── SNR_20/
28
+ │ └── ...
29
+ └── MDS_test_set/ # Multi-Doppler tests
30
+ ├── DOP_200/
31
+ ├── DOP_400/
32
+ └── ...
33
+
34
+ Each .mat file must contain variable 'H' with shape [subcarriers, symbols, 3]:
35
+ - H[:, :, 0]: Ground truth channel
36
+ - H[:, :, 1]: LS channel estimate with zeros for non-pilot positions
37
+ - H[:, :, 2]: Unused (reserved)
38
  """
39
 
40
  import logging
src/main/parser.py CHANGED
@@ -7,10 +7,11 @@ their types, default values, and validation rules to ensure proper configuration
7
  of training runs.
8
  """
9
 
10
- from dataclasses import dataclass
11
  from pathlib import Path
12
  import argparse
13
  from enum import Enum
 
 
14
 
15
 
16
  class LossType(Enum):
@@ -20,8 +21,7 @@ class LossType(Enum):
20
  HUBER = "huber"
21
 
22
 
23
- @dataclass
24
- class TrainingArguments:
25
  """Container for OFDM model training arguments.
26
 
27
  Stores, validates, and provides access to all parameters needed for
@@ -57,45 +57,34 @@ class TrainingArguments:
57
  """
58
 
59
  # Model Configuration
60
- model_name: str
61
- system_config_path: Path
62
- model_config_path: Path
63
 
64
  # Dataset Paths
65
- train_set: Path
66
- val_set: Path
67
- test_set: Path
68
 
69
  # Experiment Settings
70
- exp_id: str
71
- python_log_level: str = "INFO"
72
- tensorboard_log_dir: Path = Path("runs")
73
 
74
  # Training Hyperparameters
75
- batch_size: int = 64
76
- lr: float = 1e-3
77
- max_epoch: int = 10
78
- patience: int = 3
79
- loss_type: LossType = LossType.MSE
80
- return_type: str = "complex"
81
 
82
  # Hardware & Evaluation
83
- cuda: int = 0
84
- test_every_n: int = 10
85
 
86
- def __post_init__(self) -> None:
87
- """Validate arguments after initialization.
88
-
89
- Runs multiple validation checks on the provided arguments to ensure
90
- they are consistent and valid for training.
91
-
92
- Raises:
93
- ValueError: If any validation check fails
94
- """
95
- self._validate_paths()
96
- self._validate_numeric_args()
97
-
98
- def _validate_paths(self) -> None:
99
  """Validate path-related arguments.
100
 
101
  Checks that the config files exist and have the correct extension.
@@ -115,33 +104,7 @@ class TrainingArguments:
115
  if not self.model_config_path.suffix == '.yaml':
116
  raise ValueError(f"Model config file must be a .yaml file: {self.model_config_path}")
117
 
118
- def _validate_numeric_args(self) -> None:
119
- """Validate numeric arguments.
120
-
121
- Ensures that all numeric parameters have appropriate values:
122
- - test_every_n, max_epoch, patience, batch_size, lr must be positive
123
- - cuda must be non-negative
124
-
125
- Raises:
126
- ValueError: If any numeric argument has an invalid value
127
- """
128
- if self.test_every_n <= 0:
129
- raise ValueError(f"test_every_n must be positive, got: {self.test_every_n}")
130
-
131
- if self.max_epoch <= 0:
132
- raise ValueError(f"max_epoch must be positive, got: {self.max_epoch}")
133
-
134
- if self.patience <= 0:
135
- raise ValueError(f"patience must be positive, got: {self.patience}")
136
-
137
- if self.batch_size <= 0:
138
- raise ValueError(f"batch_size must be positive, got: {self.batch_size}")
139
-
140
- if self.cuda < 0:
141
- raise ValueError(f"cuda must be non-negative, got: {self.cuda}")
142
-
143
- if self.lr <= 0:
144
- raise ValueError(f"lr must be positive, got: {self.lr}")
145
 
146
 
147
  def parse_arguments() -> TrainingArguments:
@@ -278,7 +241,25 @@ def parse_arguments() -> TrainingArguments:
278
  args = parser.parse_args()
279
 
280
  # Convert loss_type string to enum
281
- args.loss_type = LossType(args.loss_type)
282
 
283
  # Create and validate TrainingArguments
284
- return TrainingArguments(**vars(args))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  of training runs.
8
  """
9
 
 
10
  from pathlib import Path
11
  import argparse
12
  from enum import Enum
13
+ from pydantic import BaseModel, Field, model_validator
14
+ from typing import Self
15
 
16
 
17
  class LossType(Enum):
 
21
  HUBER = "huber"
22
 
23
 
24
+ class TrainingArguments(BaseModel):
 
25
  """Container for OFDM model training arguments.
26
 
27
  Stores, validates, and provides access to all parameters needed for
 
57
  """
58
 
59
  # Model Configuration
60
+ model_name: str = Field(..., description="Model type to train")
61
+ system_config_path: Path = Field(..., description="Path to OFDM system configuration file")
62
+ model_config_path: Path = Field(..., description="Path to model configuration file")
63
 
64
  # Dataset Paths
65
+ train_set: Path = Field(..., description="Training dataset folder path")
66
+ val_set: Path = Field(..., description="Validation dataset folder path")
67
+ test_set: Path = Field(..., description="Test dataset folder path")
68
 
69
  # Experiment Settings
70
+ exp_id: str = Field(..., description="Experiment identifier for log folder naming")
71
+ python_log_level: str = Field(default="INFO", description="Logger level for python logging module")
72
+ tensorboard_log_dir: Path = Field(default=Path("runs"), description="Directory for tensorboard logs")
73
 
74
  # Training Hyperparameters
75
+ batch_size: int = Field(default=64, gt=0, description="Training batch size")
76
+ lr: float = Field(default=1e-3, gt=0, description="Initial learning rate")
77
+ max_epoch: int = Field(default=10, gt=0, description="Maximum number of training epochs")
78
+ patience: int = Field(default=3, gt=0, description="Early stopping patience (epochs)")
79
+ loss_type: LossType = Field(default=LossType.MSE, description="Loss function type")
80
+ return_type: str = Field(default="complex", description="Type of data to return from dataset")
81
 
82
  # Hardware & Evaluation
83
+ cuda: int = Field(default=0, ge=0, description="CUDA device index (0 for single GPU)")
84
+ test_every_n: int = Field(default=10, gt=0, description="Test model every N epochs")
85
 
86
+ @model_validator(mode='after')
87
+ def validate_paths(self) -> Self:
 
 
 
 
 
 
 
 
 
 
 
88
  """Validate path-related arguments.
89
 
90
  Checks that the config files exist and have the correct extension.
 
104
  if not self.model_config_path.suffix == '.yaml':
105
  raise ValueError(f"Model config file must be a .yaml file: {self.model_config_path}")
106
 
107
+ return self
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
 
110
  def parse_arguments() -> TrainingArguments:
 
241
  args = parser.parse_args()
242
 
243
  # Convert loss_type string to enum
244
+ loss_type = LossType(args.loss_type)
245
 
246
  # Create and validate TrainingArguments
247
+ return TrainingArguments(
248
+ model_name=args.model_name,
249
+ system_config_path=args.system_config_path,
250
+ model_config_path=args.model_config_path,
251
+ train_set=args.train_set,
252
+ val_set=args.val_set,
253
+ test_set=args.test_set,
254
+ exp_id=args.exp_id,
255
+ python_log_level=args.python_log_level,
256
+ tensorboard_log_dir=args.tensorboard_log_dir,
257
+ batch_size=args.batch_size,
258
+ lr=args.lr,
259
+ max_epoch=args.max_epoch,
260
+ patience=args.patience,
261
+ loss_type=loss_type,
262
+ return_type=args.return_type,
263
+ cuda=args.cuda,
264
+ test_every_n=args.test_every_n
265
+ )
src/main/trainer.py CHANGED
@@ -163,15 +163,18 @@ class ModelTrainer:
163
  test_loaders = {
164
  "DS": get_test_dataloaders(
165
  self.args.test_set / "DS_test_set",
166
- {"pilot_dims": pilot_dims, "batch_size": self.args.batch_size}
 
167
  ),
168
  "MDS": get_test_dataloaders(
169
  self.args.test_set / "MDS_test_set",
170
- {"pilot_dims": pilot_dims, "batch_size": self.args.batch_size}
 
171
  ),
172
  "SNR": get_test_dataloaders(
173
  self.args.test_set / "SNR_test_set",
174
- {"pilot_dims": pilot_dims, "batch_size": self.args.batch_size}
 
175
  ),
176
  }
177
  return train_loader, val_loader, test_loaders
 
163
  test_loaders = {
164
  "DS": get_test_dataloaders(
165
  self.args.test_set / "DS_test_set",
166
+ pilot_dims,
167
+ self.args.batch_size
168
  ),
169
  "MDS": get_test_dataloaders(
170
  self.args.test_set / "MDS_test_set",
171
+ pilot_dims,
172
+ self.args.batch_size
173
  ),
174
  "SNR": get_test_dataloaders(
175
  self.args.test_set / "SNR_test_set",
176
+ pilot_dims,
177
+ self.args.batch_size
178
  ),
179
  }
180
  return train_loader, val_loader, test_loaders
src/utils.py CHANGED
@@ -70,19 +70,28 @@ def extract_values(file_name):
70
  Extract channel information from a file name.
71
 
72
  Parses file names with format:
73
- '{number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat'
 
 
 
 
 
 
 
 
 
74
 
75
  Args:
76
  file_name: The file name from which to extract channel information
77
 
78
  Returns:
79
  tuple: A tuple containing:
80
- - file_number (torch.Tensor): The file number
81
- - snr (torch.Tensor): Signal-to-noise ratio value
82
  - delay_spread (torch.Tensor): Delay spread value
83
  - max_doppler_shift (torch.Tensor): Maximum Doppler shift value
84
  - pilot_placement_frequency (torch.Tensor): Pilot placement frequency
85
- - channel_type (list): The channel type
86
 
87
  Raises:
88
  ValueError: If the file name does not match the expected pattern
 
70
  Extract channel information from a file name.
71
 
72
  Parses file names with format:
73
+ '{file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat'
74
+
75
+ Example:
76
+ For filename "1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat":
77
+ - file_number: 1
78
+ - snr: 20 (Signal-to-Noise Ratio in dB)
79
+ - delay_spread: 50 (Delay Spread)
80
+ - doppler: 500 (Maximum Doppler Shift)
81
+ - pilot_freq: 3 (Pilot placement frequency)
82
+ - channel_type: TDL-A (Channel model type)
83
 
84
  Args:
85
  file_name: The file name from which to extract channel information
86
 
87
  Returns:
88
  tuple: A tuple containing:
89
+ - file_number (torch.Tensor): The file number (sequential identifier)
90
+ - snr (torch.Tensor): Signal-to-noise ratio value in dB
91
  - delay_spread (torch.Tensor): Delay spread value
92
  - max_doppler_shift (torch.Tensor): Maximum Doppler shift value
93
  - pilot_placement_frequency (torch.Tensor): Pilot placement frequency
94
+ - channel_type (list): The channel type (e.g., ['TDL-A'])
95
 
96
  Raises:
97
  ValueError: If the file name does not match the expected pattern