BerkIGuler committed on
Commit
2fa0d24
·
1 Parent(s): 4e938bd

fixed linter issues in trainer.py

Browse files
Files changed (1) hide show
  1. src/main/trainer.py +31 -15
src/main/trainer.py CHANGED
@@ -120,13 +120,18 @@ class ModelTrainer:
120
  Returns:
121
  Initialized model instance of the specified type
122
  """
123
- model_class = self.MODEL_REGISTRY[self.args.model_name]
124
- if model_class is LinearEstimator:
125
- model = model_class(self.system_config, device=str(self.device))
126
- else:
 
 
 
127
  if self.model_config is None:
128
- raise ValueError("model_config must be provided for non-linear models.")
129
- model = model_class(self.system_config, self.model_config)
 
 
130
  num_params, model_summary = get_model_details(model)
131
  self.logger.info("\n" + model_summary)
132
  self.logger.info(f"Model name: {self.args.model_name} | Number of parameters: {num_params}")
@@ -276,41 +281,51 @@ class ModelTrainer:
276
 
277
  def _forward_pass(self, batch, model):
278
  estimated_channel, ideal_channel, meta_data = batch
279
- if hasattr(model, 'name') and model.name in ["fortitran", "MMSE"]:
280
  h_est_re = model(torch.real(estimated_channel))
281
  h_est_im = model(torch.imag(estimated_channel))
282
  estimated_channel = torch.complex(h_est_re, h_est_im)
283
- elif hasattr(model, 'name') and model.name == "adafortitran":
284
  h_est_re = model(torch.real(estimated_channel), meta_data)
285
  h_est_im = model(torch.imag(estimated_channel), meta_data)
286
  estimated_channel = torch.complex(h_est_re, h_est_im)
 
 
 
 
287
  else:
288
- raise ValueError(f"Unknown model type: {getattr(model, 'name', type(model))}")
289
  return estimated_channel, ideal_channel.to(model.device)
290
 
291
  def _train_epoch(self):
292
  train_loss = 0.0
293
  self.model.train()
 
294
  for batch in self.train_loader:
295
  self.optimizer.zero_grad()
296
  estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
297
  output = self._compute_loss(estimated_channel, ideal_channel, self.training_loss)
298
  output.backward()
299
  self.optimizer.step()
300
- train_loss += (2 * output.item() * batch[0].size(0))
 
 
301
  self.scheduler.step()
302
- train_loss /= len(self.train_loader.dataset)
303
  return train_loss
304
 
305
  def _eval_model(self, eval_dataloader):
306
  val_loss = 0.0
307
  self.model.eval()
 
308
  with torch.no_grad():
309
  for batch in eval_dataloader:
310
  estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
311
  output = self._compute_loss(estimated_channel, ideal_channel, self.training_loss)
312
- val_loss += (2 * output.item() * batch[0].size(0))
313
- val_loss /= len(eval_dataloader.dataset)
 
 
314
  return val_loss
315
 
316
  def _predict_channels(self, test_dataloaders):
@@ -359,9 +374,10 @@ class ModelTrainer:
359
  - Early stopping when validation loss plateaus
360
  - Logging final metrics and results
361
  """
362
- epoch = None
363
  pbar = tqdm(range(self.args.max_epoch), desc="Training")
364
  for epoch in pbar:
 
365
  # Training step
366
  train_loss = self._train_epoch()
367
  self.writer.add_scalar('Loss/Train', train_loss, epoch + 1)
@@ -383,7 +399,7 @@ class ModelTrainer:
383
  message = f"Test results after epoch {epoch + 1}:\n" + 50 * "-"
384
  pbar.write(message)
385
  self._run_tests(epoch)
386
- self._log_final_metrics(epoch)
387
  self.writer.close()
388
 
389
 
 
120
  Returns:
121
  Initialized model instance of the specified type
122
  """
123
+ if self.args.model_name == "linear":
124
+ model = LinearEstimator(self.system_config, device=str(self.device))
125
+ elif self.args.model_name == "adafortitran":
126
+ if self.model_config is None:
127
+ raise ValueError("model_config must be provided for AdaFortiTranEstimator.")
128
+ model = AdaFortiTranEstimator(self.system_config, self.model_config)
129
+ elif self.args.model_name == "fortitran":
130
  if self.model_config is None:
131
+ raise ValueError("model_config must be provided for FortiTranEstimator.")
132
+ model = FortiTranEstimator(self.system_config, self.model_config)
133
+ else:
134
+ raise ValueError(f"Unknown model name: {self.args.model_name}")
135
  num_params, model_summary = get_model_details(model)
136
  self.logger.info("\n" + model_summary)
137
  self.logger.info(f"Model name: {self.args.model_name} | Number of parameters: {num_params}")
 
281
 
282
  def _forward_pass(self, batch, model):
283
  estimated_channel, ideal_channel, meta_data = batch
284
+ if isinstance(model, FortiTranEstimator):
285
  h_est_re = model(torch.real(estimated_channel))
286
  h_est_im = model(torch.imag(estimated_channel))
287
  estimated_channel = torch.complex(h_est_re, h_est_im)
288
+ elif isinstance(model, AdaFortiTranEstimator):
289
  h_est_re = model(torch.real(estimated_channel), meta_data)
290
  h_est_im = model(torch.imag(estimated_channel), meta_data)
291
  estimated_channel = torch.complex(h_est_re, h_est_im)
292
+ elif isinstance(model, LinearEstimator):
293
+ h_est_re = model(torch.real(estimated_channel))
294
+ h_est_im = model(torch.imag(estimated_channel))
295
+ estimated_channel = torch.complex(h_est_re, h_est_im)
296
  else:
297
+ raise ValueError(f"Unknown model type: {type(model)}")
298
  return estimated_channel, ideal_channel.to(model.device)
299
 
300
  def _train_epoch(self):
301
  train_loss = 0.0
302
  self.model.train()
303
+ num_samples = 0
304
  for batch in self.train_loader:
305
  self.optimizer.zero_grad()
306
  estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
307
  output = self._compute_loss(estimated_channel, ideal_channel, self.training_loss)
308
  output.backward()
309
  self.optimizer.step()
310
+ batch_size = batch[0].size(0)
311
+ train_loss += (2 * output.item() * batch_size)
312
+ num_samples += batch_size
313
  self.scheduler.step()
314
+ train_loss /= num_samples
315
  return train_loss
316
 
317
  def _eval_model(self, eval_dataloader):
318
  val_loss = 0.0
319
  self.model.eval()
320
+ num_samples = 0
321
  with torch.no_grad():
322
  for batch in eval_dataloader:
323
  estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
324
  output = self._compute_loss(estimated_channel, ideal_channel, self.training_loss)
325
+ batch_size = batch[0].size(0)
326
+ val_loss += (2 * output.item() * batch_size)
327
+ num_samples += batch_size
328
+ val_loss /= num_samples
329
  return val_loss
330
 
331
  def _predict_channels(self, test_dataloaders):
 
374
  - Early stopping when validation loss plateaus
375
  - Logging final metrics and results
376
  """
377
+ last_epoch = 0
378
  pbar = tqdm(range(self.args.max_epoch), desc="Training")
379
  for epoch in pbar:
380
+ last_epoch = epoch
381
  # Training step
382
  train_loss = self._train_epoch()
383
  self.writer.add_scalar('Loss/Train', train_loss, epoch + 1)
 
399
  message = f"Test results after epoch {epoch + 1}:\n" + 50 * "-"
400
  pbar.write(message)
401
  self._run_tests(epoch)
402
+ self._log_final_metrics(last_epoch)
403
  self.writer.close()
404
 
405