import os
from typing import Optional

import torch
from tqdm import tqdm
from accelerate import Accelerator

from .training_module import DiffusionTrainingModule
from .logger import ModelLogger


def launch_training_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 1,
    save_steps: Optional[int] = None,
    num_epochs: int = 1,
    args=None,
):
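    """Train ``model`` on ``dataset`` with AdamW and a constant learning rate.

    Values on ``args``, when provided, override the corresponding keyword
    arguments. ``model_logger.on_step_end`` is invoked after every optimizer
    step; when ``save_steps`` is None, ``model_logger.on_epoch_end`` is also
    invoked at the end of each epoch.
    """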
    # Values parsed from the command line take precedence over keyword defaults.
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs

    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    # Batch size defaults to 1; the collate_fn unwraps the single-sample batch list.
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)

    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)

    for epoch_id in range(num_epochs):
        for data in tqdm(dataloader):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                if dataset.load_from_cache:
                    # Cached datasets yield precomputed intermediate tensors,
                    # which are passed through the `inputs` argument.
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                # Step-level logging/checkpoint hook; receives save_steps to
                # decide whether this step should be saved.
                model_logger.on_step_end(accelerator, model, save_steps)
                scheduler.step()
        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)
    model_logger.on_training_end(accelerator, model, save_steps)
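
# Minimal usage sketch (not part of this module): `MyDataset` and
# `MyDiffusionModule` are hypothetical stand-ins for a dataset exposing a
# `load_from_cache` attribute and a `DiffusionTrainingModule` subclass, and
# constructing `ModelLogger` with `output_path` is an assumption (this file
# only reads that attribute).
#
#     accelerator = Accelerator()
#     launch_training_task(
#         accelerator=accelerator,
#         dataset=MyDataset(load_from_cache=False),
#         model=MyDiffusionModule(),
#         model_logger=ModelLogger(output_path="./checkpoints"),
#         learning_rate=1e-5,
#         save_steps=500,
#         num_epochs=1,
#     )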


def launch_data_process_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    num_workers: int = 8,
    args=None,
):
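    """Run ``model`` over ``dataset`` once and cache its outputs to disk.

    Each process writes its shard to
    ``<model_logger.output_path>/<process_index>/<data_id>.pth``; these files
    can later serve as the cached inputs consumed by the training loop above.
    """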
    if args is not None:
        num_workers = args.dataset_num_workers

    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
    model, dataloader = accelerator.prepare(model, dataloader)

    for data_id, data in enumerate(tqdm(dataloader)):
        with torch.no_grad():
            # Each process writes to its own subfolder, so shards from
            # different processes never collide.
            folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
            os.makedirs(folder, exist_ok=True)
            save_path = os.path.join(folder, f"{data_id}.pth")
            processed_data = model(data)
            torch.save(processed_data, save_path)
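
# Minimal usage sketch for the caching pass (same hypothetical `MyDataset`
# and `MyDiffusionModule` stand-ins as above), typically run under
# `accelerate launch` so that each process caches its own shard:
#
#     accelerator = Accelerator()
#     launch_data_process_task(
#         accelerator=accelerator,
#         dataset=MyDataset(),
#         model=MyDiffusionModule(),
#         model_logger=ModelLogger(output_path="./cache"),
#     )
#
# A cached sample can then be reloaded with, e.g.:
#
#     sample = torch.load("./cache/0/0.pth")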