BerkIGuler committed on
Commit
71dbdc8
·
1 Parent(s): 687eaba

minor fixes in src/main/parser.py

Browse files
Files changed (4) hide show
  1. src/main.py +38 -7
  2. src/main/parser.py +16 -61
  3. src/main/train_helpers.py +9 -12
  4. src/main/trainer.py +15 -22
src/main.py CHANGED
@@ -32,13 +32,14 @@ Dataset Requirements:
32
  └── ...
33
 
34
  Each .mat file must contain variable 'H' with shape [subcarriers, symbols, 3]:
35
- - H[:, :, 0]: Ground truth channel
36
- - H[:, :, 1]: LS channel estimate with zeros for non-pilot positions
37
- - H[:, :, 2]: Unused (reserved)
38
  """
39
 
40
  import logging
41
  import sys
 
42
  from pathlib import Path
43
 
44
  from src.main.parser import parse_arguments
@@ -47,18 +48,26 @@ from src.config import load_config
47
  from src.config.schemas import ModelConfig
48
 
49
 
50
- def setup_logging(log_level: str) -> None:
51
  """Set up logging configuration.
52
 
53
  Args:
54
  log_level: Logging level string (DEBUG, INFO, WARNING, ERROR, CRITICAL)
 
 
55
  """
 
 
 
 
 
 
56
  logging.basicConfig(
57
  level=getattr(logging, log_level.upper()),
58
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
59
  handlers=[
60
  logging.StreamHandler(sys.stdout),
61
- logging.FileHandler('training.log')
62
  ]
63
  )
64
 
@@ -70,7 +79,7 @@ def main() -> None:
70
  args = parse_arguments()
71
 
72
  # Set up logging
73
- setup_logging(args.python_log_level)
74
  logger = logging.getLogger(__name__)
75
 
76
  logger.info("Starting OFDM channel estimation model training")
@@ -86,13 +95,35 @@ def main() -> None:
86
  args.model_config_path
87
  )
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  logger.info("Configuration loaded successfully")
90
  logger.info(f"OFDM dimensions: {system_config.ofdm.num_scs} subcarriers x {system_config.ofdm.num_symbols} symbols")
91
  logger.info(f"Pilot dimensions: {system_config.pilot.num_scs} subcarriers x {system_config.pilot.num_symbols} symbols")
 
 
92
  if model_config.model_type == "linear":
93
  logger.info(f"Linear model with device: {model_config.device}")
 
 
 
 
 
 
 
94
  else:
95
- logger.info(f"Model architecture: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
96
 
97
  # Start training
98
  logger.info("Initializing training...")
 
32
  └── ...
33
 
34
  Each .mat file must contain variable 'H' with shape [subcarriers, symbols, 3]:
35
+ - H[:, :, 0]: Ground truth channel (complex-valued channel matrix)
36
+ - H[:, :, 1]: LS channel estimate with zeros for non-pilot positions (complex-valued) - used as input to models
37
+ - H[:, :, 2]: Bilinear interpolated LS channel estimate (complex-valued) - available but currently unused
38
  """
39
 
40
  import logging
41
  import sys
42
+ from datetime import datetime
43
  from pathlib import Path
44
 
45
  from src.main.parser import parse_arguments
 
48
  from src.config.schemas import ModelConfig
49
 
50
 
51
+ def setup_logging(log_level: str, log_dir: Path, exp_id: str) -> None:
52
  """Set up logging configuration.
53
 
54
  Args:
55
  log_level: Logging level string (DEBUG, INFO, WARNING, ERROR, CRITICAL)
56
+ log_dir: Directory path for log files
57
+ exp_id: Experiment identifier for log file naming
58
  """
59
+ # Create logs directory if it doesn't exist
60
+ log_dir.mkdir(parents=True, exist_ok=True)
61
+
62
+ # Create log file path using exp_id for easy matching
63
+ log_file = log_dir / f"training_{exp_id}.log"
64
+
65
  logging.basicConfig(
66
  level=getattr(logging, log_level.upper()),
67
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
68
  handlers=[
69
  logging.StreamHandler(sys.stdout),
70
+ logging.FileHandler(log_file)
71
  ]
72
  )
73
 
 
79
  args = parse_arguments()
80
 
81
  # Set up logging
82
+ setup_logging(args.python_log_level, args.python_log_dir, args.exp_id)
83
  logger = logging.getLogger(__name__)
84
 
85
  logger.info("Starting OFDM channel estimation model training")
 
95
  args.model_config_path
96
  )
97
 
98
+ # Validate model type consistency
99
+ expected_model_types = {
100
+ "linear": "linear",
101
+ "fortitran": "fortitran",
102
+ "adafortitran": "adafortitran"
103
+ }
104
+
105
+ if args.model_name not in expected_model_types:
106
+ raise ValueError(f"Unknown model name: {args.model_name}. Expected one of: {list(expected_model_types.keys())}")
107
+
108
+ if model_config.model_type != expected_model_types[args.model_name]:
109
+ raise ValueError(f"Model type mismatch: config specifies '{model_config.model_type}' but model name is '{args.model_name}'")
110
+
111
  logger.info("Configuration loaded successfully")
112
  logger.info(f"OFDM dimensions: {system_config.ofdm.num_scs} subcarriers x {system_config.ofdm.num_symbols} symbols")
113
  logger.info(f"Pilot dimensions: {system_config.pilot.num_scs} subcarriers x {system_config.pilot.num_symbols} symbols")
114
+
115
+ # Log model-specific information
116
  if model_config.model_type == "linear":
117
  logger.info(f"Linear model with device: {model_config.device}")
118
+ elif model_config.model_type == "fortitran":
119
+ logger.info(f"FortiTran model: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
120
+ logger.info(f"Channel adaptation: disabled")
121
+ elif model_config.model_type == "adafortitran":
122
+ logger.info(f"AdaFortiTran model: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
123
+ logger.info(f"Channel adaptation: enabled")
124
+ logger.info(f"Adaptive token length: {model_config.adaptive_token_length}")
125
  else:
126
+ logger.warning(f"Unknown model type: {model_config.model_type}")
127
 
128
  # Start training
129
  logger.info("Initializing training...")
src/main/parser.py CHANGED
@@ -9,18 +9,10 @@ of training runs.
9
 
10
  from pathlib import Path
11
  import argparse
12
- from enum import Enum
13
  from pydantic import BaseModel, Field, model_validator
14
  from typing import Self
15
 
16
 
17
- class LossType(Enum):
18
- """Enumeration of supported loss functions."""
19
- MSE = "mse"
20
- MAE = "mae"
21
- HUBER = "huber"
22
-
23
-
24
  class TrainingArguments(BaseModel):
25
  """Container for OFDM model training arguments.
26
 
@@ -29,7 +21,7 @@ class TrainingArguments(BaseModel):
29
 
30
  Attributes:
31
  # Model Configuration
32
- model_name: Supports Linear, AdaFortiTran, or FortiTran training
33
  system_config_path: Path to OFDM system configuration file
34
  model_config_path: Path to model configuration file
35
 
@@ -42,17 +34,15 @@ class TrainingArguments(BaseModel):
42
  exp_id: Experiment identifier string
43
  python_log_level: Logging verbosity level
44
  tensorboard_log_dir: Directory for tensorboard logs
 
45
 
46
  # Training Hyperparameters
47
  batch_size: Number of samples per batch
48
  lr: Learning rate for optimizer
49
  max_epoch: Maximum number of training epochs
50
  patience: Early stopping patience in epochs
51
- loss_type: Type of loss function to use
52
- return_type: Type of data to return from dataset
53
 
54
- # Hardware & Evaluation
55
- cuda: CUDA device index
56
  test_every_n: Number of epochs between test evaluations
57
  """
58
 
@@ -70,17 +60,15 @@ class TrainingArguments(BaseModel):
70
  exp_id: str = Field(..., description="Experiment identifier for log folder naming")
71
  python_log_level: str = Field(default="INFO", description="Logger level for python logging module")
72
  tensorboard_log_dir: Path = Field(default=Path("runs"), description="Directory for tensorboard logs")
 
73
 
74
  # Training Hyperparameters
75
  batch_size: int = Field(default=64, gt=0, description="Training batch size")
76
  lr: float = Field(default=1e-3, gt=0, description="Initial learning rate")
77
  max_epoch: int = Field(default=10, gt=0, description="Maximum number of training epochs")
78
  patience: int = Field(default=3, gt=0, description="Early stopping patience (epochs)")
79
- loss_type: LossType = Field(default=LossType.MSE, description="Loss function type")
80
- return_type: str = Field(default="complex", description="Type of data to return from dataset")
81
 
82
- # Hardware & Evaluation
83
- cuda: int = Field(default=0, ge=0, description="CUDA device index (0 for single GPU)")
84
  test_every_n: int = Field(default=10, gt=0, description="Test model every N epochs")
85
 
86
  @model_validator(mode='after')
@@ -133,8 +121,8 @@ def parse_arguments() -> TrainingArguments:
133
  '--model_name',
134
  type=str,
135
  required=True,
136
- choices=['Linear', 'AdaFortiTran', 'FortiTran'],
137
- help='Model type to train (Linear, AdaFortiTran, or FortiTran)'
138
  )
139
  required.add_argument(
140
  '--system_config_path',
@@ -187,6 +175,12 @@ def parse_arguments() -> TrainingArguments:
187
  default="runs",
188
  help='Directory for tensorboard logs'
189
  )
 
 
 
 
 
 
190
  optional.add_argument(
191
  '--test_every_n',
192
  type=int,
@@ -211,55 +205,16 @@ def parse_arguments() -> TrainingArguments:
211
  default=64,
212
  help='Training batch size'
213
  )
214
- optional.add_argument(
215
- '--cuda',
216
- type=int,
217
- default=0,
218
- help='CUDA device index (0 for single GPU)'
219
- )
220
  optional.add_argument(
221
  '--lr',
222
  type=float,
223
  default=1e-3,
224
  help='Initial learning rate'
225
  )
226
- optional.add_argument(
227
- '--loss_type',
228
- type=str,
229
- default="mse",
230
- choices=['mse', 'mae', 'huber'],
231
- help='Loss function type'
232
- )
233
- optional.add_argument(
234
- '--return_type',
235
- type=str,
236
- default="complex",
237
- choices=['complex', 'real'],
238
- help='Type of data to return from dataset'
239
- )
240
 
241
- args = parser.parse_args()
242
 
243
- # Convert loss_type string to enum
244
- loss_type = LossType(args.loss_type)
245
 
246
  # Create and validate TrainingArguments
247
- return TrainingArguments(
248
- model_name=args.model_name,
249
- system_config_path=args.system_config_path,
250
- model_config_path=args.model_config_path,
251
- train_set=args.train_set,
252
- val_set=args.val_set,
253
- test_set=args.test_set,
254
- exp_id=args.exp_id,
255
- python_log_level=args.python_log_level,
256
- tensorboard_log_dir=args.tensorboard_log_dir,
257
- batch_size=args.batch_size,
258
- lr=args.lr,
259
- max_epoch=args.max_epoch,
260
- patience=args.patience,
261
- loss_type=loss_type,
262
- return_type=args.return_type,
263
- cuda=args.cuda,
264
- test_every_n=args.test_every_n
265
- )
 
9
 
10
  from pathlib import Path
11
  import argparse
 
12
  from pydantic import BaseModel, Field, model_validator
13
  from typing import Self
14
 
15
 
 
 
 
 
 
 
 
16
  class TrainingArguments(BaseModel):
17
  """Container for OFDM model training arguments.
18
 
 
21
 
22
  Attributes:
23
  # Model Configuration
24
+ model_name: Supports linear, adafortitran, or fortitran training
25
  system_config_path: Path to OFDM system configuration file
26
  model_config_path: Path to model configuration file
27
 
 
34
  exp_id: Experiment identifier string
35
  python_log_level: Logging verbosity level
36
  tensorboard_log_dir: Directory for tensorboard logs
37
+ python_log_dir: Directory for python logging files
38
 
39
  # Training Hyperparameters
40
  batch_size: Number of samples per batch
41
  lr: Learning rate for optimizer
42
  max_epoch: Maximum number of training epochs
43
  patience: Early stopping patience in epochs
 
 
44
 
45
+ # Evaluation
 
46
  test_every_n: Number of epochs between test evaluations
47
  """
48
 
 
60
  exp_id: str = Field(..., description="Experiment identifier for log folder naming")
61
  python_log_level: str = Field(default="INFO", description="Logger level for python logging module")
62
  tensorboard_log_dir: Path = Field(default=Path("runs"), description="Directory for tensorboard logs")
63
+ python_log_dir: Path = Field(default=Path("logs"), description="Directory for python logging files")
64
 
65
  # Training Hyperparameters
66
  batch_size: int = Field(default=64, gt=0, description="Training batch size")
67
  lr: float = Field(default=1e-3, gt=0, description="Initial learning rate")
68
  max_epoch: int = Field(default=10, gt=0, description="Maximum number of training epochs")
69
  patience: int = Field(default=3, gt=0, description="Early stopping patience (epochs)")
 
 
70
 
71
+ # Evaluation
 
72
  test_every_n: int = Field(default=10, gt=0, description="Test model every N epochs")
73
 
74
  @model_validator(mode='after')
 
121
  '--model_name',
122
  type=str,
123
  required=True,
124
+ choices=['linear', 'adafortitran', 'fortitran'],
125
+ help='Model type to train (linear, adafortitran, or fortitran)'
126
  )
127
  required.add_argument(
128
  '--system_config_path',
 
175
  default="runs",
176
  help='Directory for tensorboard logs'
177
  )
178
+ optional.add_argument(
179
+ '--python_log_dir',
180
+ type=Path,
181
+ default="logs",
182
+ help='Directory for python logging files'
183
+ )
184
  optional.add_argument(
185
  '--test_every_n',
186
  type=int,
 
205
  default=64,
206
  help='Training batch size'
207
  )
208
+
 
 
 
 
 
209
  optional.add_argument(
210
  '--lr',
211
  type=float,
212
  default=1e-3,
213
  help='Initial learning rate'
214
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
 
216
 
217
+ args = parser.parse_args()
 
218
 
219
  # Create and validate TrainingArguments
220
+ return TrainingArguments(**vars(args))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/main/train_helpers.py CHANGED
@@ -120,7 +120,7 @@ def eval_model(
120
  output = _compute_loss(estimated_channel, ideal_channel, loss_fn)
121
  val_loss += (2 * output.item() * batch[0].size(0))
122
 
123
- val_loss /= len(eval_dataloader.dataset)
124
  return val_loss
125
 
126
 
@@ -206,7 +206,7 @@ def train_epoch(
206
  train_loss += (2 * output.item() * batch[0].size(0))
207
 
208
  scheduler.step()
209
- train_loss /= len(train_dataloader.dataset)
210
  return train_loss
211
 
212
 
@@ -225,20 +225,17 @@ def _forward_pass(batch: BatchType, model: nn.Module) -> Tuple[ComplexTensor, Co
225
  Tuple of (processed_estimated_channel, ideal_channel)
226
 
227
  Raises:
228
- ValueError: If model name is not recognized
229
  """
230
  estimated_channel, ideal_channel, meta_data = batch
231
 
232
- if model.name in ["fortitran", "MMSE"]:
233
- h_est_re = model(torch.real(estimated_channel))
234
- h_est_im = model(torch.imag(estimated_channel))
235
- estimated_channel = torch.complex(h_est_re, h_est_im)
236
- elif model.name == "adafortitran":
237
- h_est_re = model(torch.real(estimated_channel), meta_data)
238
- h_est_im = model(torch.imag(estimated_channel), meta_data)
239
- estimated_channel = torch.complex(h_est_re, h_est_im)
240
  else:
241
- raise ValueError(f"Unknown model type: {model.name}")
 
242
 
243
  return estimated_channel, ideal_channel.to(model.device)
244
 
 
120
  output = _compute_loss(estimated_channel, ideal_channel, loss_fn)
121
  val_loss += (2 * output.item() * batch[0].size(0))
122
 
123
+ val_loss /= sum(len(batch[0]) for batch in eval_dataloader)
124
  return val_loss
125
 
126
 
 
206
  train_loss += (2 * output.item() * batch[0].size(0))
207
 
208
  scheduler.step()
209
+ train_loss /= sum(len(batch[0]) for batch in train_dataloader)
210
  return train_loss
211
 
212
 
 
225
  Tuple of (processed_estimated_channel, ideal_channel)
226
 
227
  Raises:
228
+ ValueError: If model type is not recognized
229
  """
230
  estimated_channel, ideal_channel, meta_data = batch
231
 
232
+ # All models now handle complex input directly
233
+ if hasattr(model, 'use_channel_adaptation') and model.use_channel_adaptation:
234
+ # AdaFortiTran uses meta_data for channel adaptation
235
+ estimated_channel = model(estimated_channel, meta_data)
 
 
 
 
236
  else:
237
+ # Linear and FortiTran models don't use meta_data
238
+ estimated_channel = model(estimated_channel)
239
 
240
  return estimated_channel, ideal_channel.to(model.device)
241
 
src/main/trainer.py CHANGED
@@ -81,7 +81,7 @@ class ModelTrainer:
81
  self.system_config = system_config
82
  self.model_config = model_config
83
  self.args = args
84
- self.device = torch.device(f"cuda:{args.cuda}")
85
  self.writer = self._setup_tensorboard()
86
  self.logger = logging.getLogger(__name__)
87
 
@@ -120,14 +120,12 @@ class ModelTrainer:
120
  Returns:
121
  Initialized model instance of the specified type
122
  """
123
- if self.args.model_name == "linear":
124
- model = LinearEstimator(self.system_config, self.model_config)
125
- elif self.args.model_name == "adafortitran":
126
- model = AdaFortiTranEstimator(self.system_config, self.model_config)
127
- elif self.args.model_name == "fortitran":
128
- model = FortiTranEstimator(self.system_config, self.model_config)
129
- else:
130
- raise ValueError(f"Unknown model name: {self.args.model_name}")
131
  num_params, model_summary = get_model_details(model)
132
  self.logger.info("\n" + model_summary)
133
  self.logger.info(f"Model name: {self.args.model_name} | Number of parameters: {num_params}")
@@ -280,20 +278,15 @@ class ModelTrainer:
280
 
281
  def _forward_pass(self, batch, model):
282
  estimated_channel, ideal_channel, meta_data = batch
283
- if isinstance(model, FortiTranEstimator):
284
- h_est_re = model(torch.real(estimated_channel))
285
- h_est_im = model(torch.imag(estimated_channel))
286
- estimated_channel = torch.complex(h_est_re, h_est_im)
287
- elif isinstance(model, AdaFortiTranEstimator):
288
- h_est_re = model(torch.real(estimated_channel), meta_data)
289
- h_est_im = model(torch.imag(estimated_channel), meta_data)
290
- estimated_channel = torch.complex(h_est_re, h_est_im)
291
- elif isinstance(model, LinearEstimator):
292
- h_est_re = model(torch.real(estimated_channel))
293
- h_est_im = model(torch.imag(estimated_channel))
294
- estimated_channel = torch.complex(h_est_re, h_est_im)
295
  else:
296
- raise ValueError(f"Unknown model type: {type(model)}")
 
 
297
  return estimated_channel, ideal_channel.to(model.device)
298
 
299
  def _train_epoch(self):
 
81
  self.system_config = system_config
82
  self.model_config = model_config
83
  self.args = args
84
+ self.device = torch.device(model_config.device)
85
  self.writer = self._setup_tensorboard()
86
  self.logger = logging.getLogger(__name__)
87
 
 
120
  Returns:
121
  Initialized model instance of the specified type
122
  """
123
+ if self.args.model_name not in self.MODEL_REGISTRY:
124
+ raise ValueError(f"Unknown model name: {self.args.model_name}. Available: {list(self.MODEL_REGISTRY.keys())}")
125
+
126
+ model_class = self.MODEL_REGISTRY[self.args.model_name]
127
+ model = model_class(self.system_config, self.model_config)
128
+
 
 
129
  num_params, model_summary = get_model_details(model)
130
  self.logger.info("\n" + model_summary)
131
  self.logger.info(f"Model name: {self.args.model_name} | Number of parameters: {num_params}")
 
278
 
279
  def _forward_pass(self, batch, model):
280
  estimated_channel, ideal_channel, meta_data = batch
281
+
282
+ # All models now handle complex input directly
283
+ if isinstance(model, AdaFortiTranEstimator):
284
+ # AdaFortiTran uses meta_data for channel adaptation
285
+ estimated_channel = model(estimated_channel, meta_data)
 
 
 
 
 
 
 
286
  else:
287
+ # Linear and FortiTran models don't use meta_data
288
+ estimated_channel = model(estimated_channel)
289
+
290
  return estimated_channel, ideal_channel.to(model.device)
291
 
292
  def _train_epoch(self):