| import os, torch |
| from tqdm import tqdm |
| from accelerate import Accelerator |
| from .training_module import DiffusionTrainingModule |
| from .logger import ModelLogger |
| try: |
| import swanlab |
| except ImportError: |
| swanlab = None |
|
|
|
|
def _init_swanlab(accelerator, args) -> bool:
    """Best-effort SwanLab setup; return True only if a run was initialised.

    Tracking is attempted only on the main process and only when the optional
    ``swanlab`` import succeeded. Any failure is reported and swallowed so
    that experiment tracking can never prevent training from starting.
    """
    if not accelerator.is_main_process or swanlab is None:
        return False
    try:
        # Credentials must not live in source control: take the key from the
        # environment. Without it, no explicit login is performed and SwanLab
        # is left to use whatever credentials it already has configured.
        api_key = os.environ.get("SWANLAB_API_KEY")
        if api_key:
            swanlab.login(api_key=api_key)
        config_dict = vars(args) if args is not None else {}
        swanlab.init(project="AI4VA_track2_solution", config=config_dict)
    except Exception as e:
        # Deliberately broad: tracking is optional, training is not.
        print("SwanLab init failed:", e)
        return False
    return True


def launch_training_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 1,
    save_steps: int = None,
    num_epochs: int = 1,
    args = None,
):
    """Run the training loop for ``model`` over ``dataset``.

    Args:
        accelerator: Accelerate handle used to prepare the model, optimizer,
            dataloader and scheduler, and to run backward/accumulation.
        dataset: Training dataset. Its ``load_from_cache`` attribute selects
            how each batch is passed to the model (see below).
        model: Training module; ``model.trainable_modules()`` supplies the
            parameters given to AdamW, and calling the model returns the loss.
        model_logger: Receives ``on_step_end``, ``on_epoch_end`` and
            ``on_training_end`` callbacks.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        num_workers: Number of DataLoader worker processes.
        save_steps: Forwarded to the logger callbacks. When ``None``,
            ``on_epoch_end`` is invoked after every epoch.
        num_epochs: Number of passes over the dataloader.
        args: Optional namespace. When given, its ``learning_rate``,
            ``weight_decay``, ``dataset_num_workers``, ``save_steps`` and
            ``num_epochs`` attributes override the keyword arguments above.

    The SwanLab API key is read from the ``SWANLAB_API_KEY`` environment
    variable. Loss values are sent to SwanLab only if initialisation succeeded.
    """
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs

    # Remember whether tracking actually started, so a failed init does not
    # turn into a crash at the first swanlab.log() call inside the loop.
    swanlab_enabled = _init_swanlab(accelerator, args)

    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
    # NOTE(review): ConstantLR is used with its default arguments, which per
    # the PyTorch docs scale the LR by 1/3 for the first 5 scheduler steps.
    # Kept as-is to preserve training behaviour; confirm this is intended.
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    # Each batch is a one-element list; the collate function unwraps it.
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
    model.to(device=accelerator.device)
    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)

    for epoch_id in range(num_epochs):
        progress_bar = tqdm(dataloader)
        for data in progress_bar:
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                if dataset.load_from_cache:
                    # Cached samples are passed via the ``inputs`` keyword
                    # with an empty positional payload.
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
                scheduler.step()

            # Read the scalar once; each .item() call forces a host sync.
            loss_value = loss.item()
            progress_bar.set_description(f"Epoch {epoch_id} | Loss: {loss_value:.4f}")

            if swanlab_enabled:
                swanlab.log({"train/loss": loss_value})
        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)
    model_logger.on_training_end(accelerator, model, save_steps)
|
|
|
|
def launch_data_process_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    num_workers: int = 8,
    args = None,
):
    """Run ``model`` over every dataset sample without gradients and save each result.

    Each output is written with ``torch.save`` to
    ``<model_logger.output_path>/<accelerator.process_index>/<index>.pth``,
    where ``index`` is the sample's position in this process's dataloader.

    Args:
        accelerator: Accelerate handle used to prepare the model and dataloader.
        dataset: Dataset to iterate, in order (no shuffling).
        model: Module called on each sample; its return value is what gets saved.
        model_logger: Only its ``output_path`` attribute is read here.
        num_workers: Number of DataLoader worker processes.
        args: Optional namespace; when given, ``args.dataset_num_workers``
            overrides ``num_workers``.
    """
    worker_count = num_workers if args is None else args.dataset_num_workers

    def take_first(batch):
        # Batches arrive as one-element lists; hand the model the bare sample.
        return batch[0]

    loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        collate_fn=take_first,
        num_workers=worker_count,
    )
    model.to(device=accelerator.device)
    model, loader = accelerator.prepare(model, loader)

    for sample_index, sample in enumerate(tqdm(loader)):
        with accelerator.accumulate(model), torch.no_grad():
            # One sub-directory per process keeps per-process indices from colliding.
            rank_dir = os.path.join(model_logger.output_path, str(accelerator.process_index))
            os.makedirs(rank_dir, exist_ok=True)
            target_file = os.path.join(rank_dir, f"{sample_index}.pth")
            processed = model(sample)
            torch.save(processed, target_file)
|
|