import os, torch
from typing import Optional
from tqdm import tqdm
from accelerate import Accelerator
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger


def launch_training_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 1,
    save_steps: Optional[int] = None,
    num_epochs: int = 1,
    args = None,
):
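    """
    Train `model` on `dataset` with Accelerate.

    Builds an AdamW optimizer, a constant LR scheduler, and a DataLoader,
    prepares them with `accelerator`, then runs `num_epochs` epochs, writing
    per-step losses to `training_log.txt` under the logger's output path and
    delegating checkpointing to `model_logger`. If `args` is given, its values
    override the corresponding keyword arguments.
    """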
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs
    
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
    model.to(device=accelerator.device)
    # Exclude VAE from DeepSpeed ZeRO-3 wrapping to avoid compatibility issues
    # Store VAE outside the module tree so DeepSpeed doesn't touch it
    vae_module = getattr(model.pipe, 'vae', None)
    if vae_module is not None:
        del model.pipe._modules['vae']
    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
    if vae_module is not None:
        vae_module.to(accelerator.device)
        # Store VAE as a non-module attribute so pipeline code can still use pipe.vae
        pipe = model.module.pipe if hasattr(model, 'module') else model.pipe
        # Use object.__setattr__ to bypass nn.Module's __setattr__ which would register it as a submodule
        object.__setattr__(pipe, 'vae', vae_module)
    initialize_deepspeed_gradient_checkpointing(accelerator)
    # Training log file
    log_path = os.path.join(model_logger.output_path, "training_log.txt")
    if accelerator.is_main_process:
        os.makedirs(model_logger.output_path, exist_ok=True)
        log_file = open(log_path, "a")
        log_file.write(f"Training started. Epochs: {num_epochs}, LR: {learning_rate}, Steps/epoch: {len(dataloader)}\n")
        log_file.flush()
    else:
        log_file = None

    total_target = num_epochs * len(dataloader)
    reached_target = False
    for epoch_id in range(num_epochs):
        if reached_target:
            break
        progress = tqdm(
            total=total_target,
            initial=model_logger.num_steps,
            desc=f"Epoch {epoch_id+1}/{num_epochs}",
        )
        for step_id, data in enumerate(dataloader):
            if model_logger.num_steps >= total_target:
                reached_target = True
                break
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                if dataset.load_from_cache:
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
                scheduler.step()

                # Log loss
                loss_val = loss.item()
                progress.update(1)
                progress.set_postfix(loss=f"{loss_val:.4f}")
                if accelerator.is_main_process and log_file is not None and (model_logger.num_steps % 10 == 0 or model_logger.num_steps <= 5):
                    log_file.write(f"epoch={epoch_id+1} step={model_logger.num_steps} loss={loss_val:.6f}\n")
                    log_file.flush()
        progress.close()
        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)
        if accelerator.is_main_process and log_file is not None:
            # Only report a checkpoint when on_epoch_end actually saved one.
            checkpoint_note = " Checkpoint saved." if save_steps is None else ""
            log_file.write(f"Epoch {epoch_id+1} completed.{checkpoint_note}\n")
            log_file.flush()
    model_logger.on_training_end(accelerator, model, save_steps)
    if log_file is not None:
        log_file.close()


def launch_data_process_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    num_workers: int = 8,
    args = None,
):
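    """
    Run `model` over `dataset` without gradients and cache the outputs.

    Each process saves its results to `<output_path>/<process_index>/<data_id>.pth`.
    """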
    if args is not None:
        num_workers = args.dataset_num_workers
        
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
    model.to(device=accelerator.device)
    model, dataloader = accelerator.prepare(model, dataloader)
    
    for data_id, data in enumerate(tqdm(dataloader)):
        with accelerator.accumulate(model):
            with torch.no_grad():
                folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
                os.makedirs(folder, exist_ok=True)
                save_path = os.path.join(folder, f"{data_id}.pth")
                data = model(data)
                torch.save(data, save_path)


def initialize_deepspeed_gradient_checkpointing(accelerator: Accelerator):
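    """
    Enable DeepSpeed activation checkpointing when the accelerator's DeepSpeed
    config contains an `activation_checkpointing` section.
    """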
    if getattr(accelerator.state, "deepspeed_plugin", None) is not None:
        ds_config = accelerator.state.deepspeed_plugin.deepspeed_config
        if "activation_checkpointing" in ds_config:
            import deepspeed
            act_config = ds_config["activation_checkpointing"]
            deepspeed.checkpointing.configure(
                mpu_=None, 
                partition_activations=act_config.get("partition_activations", False),
                checkpoint_in_cpu=act_config.get("cpu_checkpointing", False),
                contiguous_checkpointing=act_config.get("contiguous_memory_optimization", False)
            )
        else:
            print("Do not find activation_checkpointing config in deepspeed config, skip initializing deepspeed gradient checkpointing.")