| import os |
| import torch |
| from tqdm import tqdm |
| from configs import paths_config, hyperparameters, global_config |
| from training.coaches.base_coach import BaseCoach |
| from utils.log_utils import log_images_from_w |
|
|
|
|
class SingleIDCoach(BaseCoach):
    """Coach that tunes the generator independently for each input image.

    For every ``(fname, image)`` pair from ``data_loader`` the coach:
      1. resets training state via ``restart_training()``,
      2. obtains a pivot latent (loaded from disk when
         ``hyperparameters.use_last_w_pivots`` is set and a saved pivot is
         found, otherwise computed with ``calc_inversions``),
      3. optimizes ``self.G`` for up to ``hyperparameters.max_pti_steps``
         steps so that ``self.forward(w_pivot)`` reconstructs the image,
      4. saves the pivot and the tuned generator to disk.
    """

    def __init__(self, data_loader, use_wandb):
        """Initialize the coach through ``BaseCoach``.

        Args:
            data_loader: Iterable yielding ``(fname, image)`` pairs. ``fname``
                is indexed with ``[0]`` to get the image name (presumably a
                batch of size 1 -- confirm against the loader).
            use_wandb: When true, reconstruction snapshots are logged with
                ``log_images_from_w``.
        """
        super().__init__(data_loader, use_wandb)

    def train(self):
        """Run per-image pivotal tuning over the whole data loader.

        Stops early once ``self.image_counter`` reaches
        ``hyperparameters.max_images_to_invert``.

        Side effects:
            * Creates ``<embedding_base_dir>/<input_data_id>/<pti_results_keyword>/<image_name>/``
              and saves the pivot latent there as ``0.pt``.
            * Creates ``paths_config.checkpoints_dir`` if missing and saves the
              tuned generator as ``model_<run_name>_<image_name>.pt``.
            * Increments ``global_config.training_step`` once per optimizer
              step and ``self.image_counter`` once per processed image.
        """
        w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}'
        results_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}'
        os.makedirs(w_path_dir, exist_ok=True)
        os.makedirs(results_dir, exist_ok=True)
        # Create the checkpoint directory up front: otherwise the final
        # torch.save fails only after all tuning for the first image is done.
        os.makedirs(paths_config.checkpoints_dir, exist_ok=True)

        # Carried across images on purpose (as before); it is recomputed from
        # the global step counter after every optimizer step.
        use_ball_holder = True

        for fname, image in tqdm(self.data_loader):
            image_name = fname[0]

            # Check the budget before restart_training() so training state is
            # not reset for an image that will be skipped anyway.
            if self.image_counter >= hyperparameters.max_images_to_invert:
                break

            self.restart_training()

            embedding_dir = f'{results_dir}/{image_name}'
            os.makedirs(embedding_dir, exist_ok=True)

            w_pivot = None
            if hyperparameters.use_last_w_pivots:
                w_pivot = self.load_inversions(w_path_dir, image_name)
            # Independent check (not `elif`): fall back to a fresh inversion
            # when cached pivots are disabled OR none was loaded for this image.
            if w_pivot is None:
                w_pivot = self.calc_inversions(image, image_name)

            w_pivot = w_pivot.to(global_config.device)
            torch.save(w_pivot, f'{embedding_dir}/0.pt')

            real_images_batch = image.to(global_config.device)
            log_images_counter = 0

            for _ in tqdm(range(hyperparameters.max_pti_steps)):
                generated_images = self.forward(w_pivot)
                loss, _, loss_lpips = self.calc_loss(generated_images, real_images_batch, image_name,
                                                     self.G, use_ball_holder, w_pivot)

                self.optimizer.zero_grad()

                # Early stop once the LPIPS loss is at or below the configured threshold.
                if loss_lpips <= hyperparameters.LPIPS_value_threshold:
                    break

                loss.backward()
                self.optimizer.step()

                # use_ball_holder is forwarded to calc_loss; it is enabled only every
                # `locality_regularization_interval` global steps (presumably gating
                # the locality-regularization term -- confirm in BaseCoach.calc_loss).
                use_ball_holder = (global_config.training_step
                                   % hyperparameters.locality_regularization_interval == 0)

                if self.use_wandb and log_images_counter % global_config.image_rec_result_log_snapshot == 0:
                    log_images_from_w([w_pivot], self.G, [image_name])

                global_config.training_step += 1
                log_images_counter += 1

            self.image_counter += 1

            torch.save(self.G,
                       f'{paths_config.checkpoints_dir}/model_{global_config.run_name}_{image_name}.pt')
|
|