Commit
·
2fa0d24
1
Parent(s):
4e938bd
fixed linter issues in trainer.py
Browse files
- src/main/trainer.py +31 -15
src/main/trainer.py
CHANGED
|
@@ -120,13 +120,18 @@ class ModelTrainer:
|
|
| 120 |
Returns:
|
| 121 |
Initialized model instance of the specified type
|
| 122 |
"""
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
| 127 |
if self.model_config is None:
|
| 128 |
-
raise ValueError("model_config must be provided for
|
| 129 |
-
model =
|
|
|
|
|
|
|
| 130 |
num_params, model_summary = get_model_details(model)
|
| 131 |
self.logger.info("\n" + model_summary)
|
| 132 |
self.logger.info(f"Model name: {self.args.model_name} | Number of parameters: {num_params}")
|
|
@@ -276,41 +281,51 @@ class ModelTrainer:
|
|
| 276 |
|
| 277 |
def _forward_pass(self, batch, model):
|
| 278 |
estimated_channel, ideal_channel, meta_data = batch
|
| 279 |
-
if
|
| 280 |
h_est_re = model(torch.real(estimated_channel))
|
| 281 |
h_est_im = model(torch.imag(estimated_channel))
|
| 282 |
estimated_channel = torch.complex(h_est_re, h_est_im)
|
| 283 |
-
elif
|
| 284 |
h_est_re = model(torch.real(estimated_channel), meta_data)
|
| 285 |
h_est_im = model(torch.imag(estimated_channel), meta_data)
|
| 286 |
estimated_channel = torch.complex(h_est_re, h_est_im)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
else:
|
| 288 |
-
raise ValueError(f"Unknown model type: {
|
| 289 |
return estimated_channel, ideal_channel.to(model.device)
|
| 290 |
|
| 291 |
def _train_epoch(self):
|
| 292 |
train_loss = 0.0
|
| 293 |
self.model.train()
|
|
|
|
| 294 |
for batch in self.train_loader:
|
| 295 |
self.optimizer.zero_grad()
|
| 296 |
estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
|
| 297 |
output = self._compute_loss(estimated_channel, ideal_channel, self.training_loss)
|
| 298 |
output.backward()
|
| 299 |
self.optimizer.step()
|
| 300 |
-
|
|
|
|
|
|
|
| 301 |
self.scheduler.step()
|
| 302 |
-
train_loss /=
|
| 303 |
return train_loss
|
| 304 |
|
| 305 |
def _eval_model(self, eval_dataloader):
|
| 306 |
val_loss = 0.0
|
| 307 |
self.model.eval()
|
|
|
|
| 308 |
with torch.no_grad():
|
| 309 |
for batch in eval_dataloader:
|
| 310 |
estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
|
| 311 |
output = self._compute_loss(estimated_channel, ideal_channel, self.training_loss)
|
| 312 |
-
|
| 313 |
-
|
|
|
|
|
|
|
| 314 |
return val_loss
|
| 315 |
|
| 316 |
def _predict_channels(self, test_dataloaders):
|
|
@@ -359,9 +374,10 @@ class ModelTrainer:
|
|
| 359 |
- Early stopping when validation loss plateaus
|
| 360 |
- Logging final metrics and results
|
| 361 |
"""
|
| 362 |
-
|
| 363 |
pbar = tqdm(range(self.args.max_epoch), desc="Training")
|
| 364 |
for epoch in pbar:
|
|
|
|
| 365 |
# Training step
|
| 366 |
train_loss = self._train_epoch()
|
| 367 |
self.writer.add_scalar('Loss/Train', train_loss, epoch + 1)
|
|
@@ -383,7 +399,7 @@ class ModelTrainer:
|
|
| 383 |
message = f"Test results after epoch {epoch + 1}:\n" + 50 * "-"
|
| 384 |
pbar.write(message)
|
| 385 |
self._run_tests(epoch)
|
| 386 |
-
self._log_final_metrics(
|
| 387 |
self.writer.close()
|
| 388 |
|
| 389 |
|
|
|
|
| 120 |
Returns:
|
| 121 |
Initialized model instance of the specified type
|
| 122 |
"""
|
| 123 |
+
if self.args.model_name == "linear":
|
| 124 |
+
model = LinearEstimator(self.system_config, device=str(self.device))
|
| 125 |
+
elif self.args.model_name == "adafortitran":
|
| 126 |
+
if self.model_config is None:
|
| 127 |
+
raise ValueError("model_config must be provided for AdaFortiTranEstimator.")
|
| 128 |
+
model = AdaFortiTranEstimator(self.system_config, self.model_config)
|
| 129 |
+
elif self.args.model_name == "fortitran":
|
| 130 |
if self.model_config is None:
|
| 131 |
+
raise ValueError("model_config must be provided for FortiTranEstimator.")
|
| 132 |
+
model = FortiTranEstimator(self.system_config, self.model_config)
|
| 133 |
+
else:
|
| 134 |
+
raise ValueError(f"Unknown model name: {self.args.model_name}")
|
| 135 |
num_params, model_summary = get_model_details(model)
|
| 136 |
self.logger.info("\n" + model_summary)
|
| 137 |
self.logger.info(f"Model name: {self.args.model_name} | Number of parameters: {num_params}")
|
|
|
|
| 281 |
|
| 282 |
def _forward_pass(self, batch, model):
|
| 283 |
estimated_channel, ideal_channel, meta_data = batch
|
| 284 |
+
if isinstance(model, FortiTranEstimator):
|
| 285 |
h_est_re = model(torch.real(estimated_channel))
|
| 286 |
h_est_im = model(torch.imag(estimated_channel))
|
| 287 |
estimated_channel = torch.complex(h_est_re, h_est_im)
|
| 288 |
+
elif isinstance(model, AdaFortiTranEstimator):
|
| 289 |
h_est_re = model(torch.real(estimated_channel), meta_data)
|
| 290 |
h_est_im = model(torch.imag(estimated_channel), meta_data)
|
| 291 |
estimated_channel = torch.complex(h_est_re, h_est_im)
|
| 292 |
+
elif isinstance(model, LinearEstimator):
|
| 293 |
+
h_est_re = model(torch.real(estimated_channel))
|
| 294 |
+
h_est_im = model(torch.imag(estimated_channel))
|
| 295 |
+
estimated_channel = torch.complex(h_est_re, h_est_im)
|
| 296 |
else:
|
| 297 |
+
raise ValueError(f"Unknown model type: {type(model)}")
|
| 298 |
return estimated_channel, ideal_channel.to(model.device)
|
| 299 |
|
| 300 |
def _train_epoch(self):
|
| 301 |
train_loss = 0.0
|
| 302 |
self.model.train()
|
| 303 |
+
num_samples = 0
|
| 304 |
for batch in self.train_loader:
|
| 305 |
self.optimizer.zero_grad()
|
| 306 |
estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
|
| 307 |
output = self._compute_loss(estimated_channel, ideal_channel, self.training_loss)
|
| 308 |
output.backward()
|
| 309 |
self.optimizer.step()
|
| 310 |
+
batch_size = batch[0].size(0)
|
| 311 |
+
train_loss += (2 * output.item() * batch_size)
|
| 312 |
+
num_samples += batch_size
|
| 313 |
self.scheduler.step()
|
| 314 |
+
train_loss /= num_samples
|
| 315 |
return train_loss
|
| 316 |
|
| 317 |
def _eval_model(self, eval_dataloader):
|
| 318 |
val_loss = 0.0
|
| 319 |
self.model.eval()
|
| 320 |
+
num_samples = 0
|
| 321 |
with torch.no_grad():
|
| 322 |
for batch in eval_dataloader:
|
| 323 |
estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
|
| 324 |
output = self._compute_loss(estimated_channel, ideal_channel, self.training_loss)
|
| 325 |
+
batch_size = batch[0].size(0)
|
| 326 |
+
val_loss += (2 * output.item() * batch_size)
|
| 327 |
+
num_samples += batch_size
|
| 328 |
+
val_loss /= num_samples
|
| 329 |
return val_loss
|
| 330 |
|
| 331 |
def _predict_channels(self, test_dataloaders):
|
|
|
|
| 374 |
- Early stopping when validation loss plateaus
|
| 375 |
- Logging final metrics and results
|
| 376 |
"""
|
| 377 |
+
last_epoch = 0
|
| 378 |
pbar = tqdm(range(self.args.max_epoch), desc="Training")
|
| 379 |
for epoch in pbar:
|
| 380 |
+
last_epoch = epoch
|
| 381 |
# Training step
|
| 382 |
train_loss = self._train_epoch()
|
| 383 |
self.writer.add_scalar('Loss/Train', train_loss, epoch + 1)
|
|
|
|
| 399 |
message = f"Test results after epoch {epoch + 1}:\n" + 50 * "-"
|
| 400 |
pbar.write(message)
|
| 401 |
self._run_tests(epoch)
|
| 402 |
+
self._log_final_metrics(last_epoch)
|
| 403 |
self.writer.close()
|
| 404 |
|
| 405 |
|