File size: 21,598 Bytes
166ab04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
import gc
from shutil import ignore_patterns
import argparse
import json
import sys
import os
import os.path as osp
import datetime
import shutil
from typing import List

# from PIL import Image
import toml
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.nn.functional as F
# import torch.distributed as dist 
from accelerate.utils import set_seed

from diffusers import DDPMScheduler, DDIMScheduler
from accelerate import DistributedType
from diffusers.utils import logging
# from diffusers.models import AutoencoderKL

import library.train_util as train_util
import library.chinese_sdxl_train_util as chinese_sdxl_train_util
# import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
    apply_snr_weight,
    prepare_scheduler_for_custom_training,
    scale_v_prediction_loss_like_noise_prediction,
    add_v_prediction_like_loss,
)

from model_lib.nets.layers.ema import LitEma, load_litema, save_litema, ema_scope
from removal.v1_2 import (
    RemovalDataset, RemovalDataset_v1_2,
    load_cfg,
    build_removal_model,
    load_removal_model,
)

from utils_train import (
    build_accelerator,
    build_dataloader,
    build_vae,
    build_models,
    save,
    common_arguments,
    build_progress_bar
)
from model_lib.nets.utils import CustomOutput
from utils_infer import encode_clean_latents #, predict_noise

import warnings
# Silence known-noisy warnings so they do not flood the training logs.
# NOTE(review): the first message looks like DDP's gradient-bucket stride
# warning; the second comes from torch.compile's AOTAutograd path (relevant
# when --use_compile is set). Confirm both are benign for this setup.
warnings.filterwarnings("ignore", message="Grad strides do not match bucket view strides.*")
warnings.filterwarnings("ignore", message="Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.")

# Module logger obtained through diffusers.utils.logging.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def build_teacher_model(args, weight_dtype, accelerator):
    """Construct the frozen teacher network used as the distillation target.

    The architecture is read from ``args.teacher_config_path`` and, when
    ``args.teacher_weight_path`` is set, a state dict is loaded into it with
    ``load_state_dict`` (default strict mode). The model is then frozen, put
    into eval mode and moved to ``accelerator.device`` in float32.
    ``weight_dtype`` is only printed here; it is not applied to the teacher.
    Finally the diffusers xformers flag is switched on for ``diff_model``
    unconditionally (the mem_eff_attn / xformers args are only printed).

    Returns:
        The prepared teacher model.
    """
    cfg = load_cfg(args.teacher_config_path)
    model = build_removal_model(cfg, args.num_embeddings)

    ckpt_path = args.teacher_weight_path
    if ckpt_path:
        accelerator.print(f"==> Loading teacher model from: {ckpt_path}")
        model.load_state_dict(torch.load(ckpt_path, map_location=accelerator.device))

    accelerator.print(f"weight_dtype:{weight_dtype}")
    unet = getattr(model, 'unet', None)
    if unet:
        accelerator.print(f"unet:{unet.dtype}")
    else:
        accelerator.print(f"diff_model:{model.diff_model.dtype}")

    # Inference-only: no parameter gradients, eval-mode layers, full precision.
    model.requires_grad_(False)
    model.eval()
    model.to(accelerator.device, dtype=torch.float32)

    if accelerator.is_main_process:
        from pprint import pprint
        pprint("Teacher Model Config:")
        pprint(model.diff_model.config)

    # Memory-efficient attention is always enabled for the teacher.
    accelerator.print(f"Enable memory efficient attention, mem_eff_attn:{args.mem_eff_attn}, xformers:{args.xformers}")
    chinese_sdxl_train_util.set_diffusers_xformers_flag(model.diff_model, True)

    return model



def cal_KD_loss(pred: "CustomOutput", target: "CustomOutput", args):
    """Compute the knowledge-distillation losses between student and teacher.

    Two terms are produced:

    * ``loss_featkd`` -- KD on intermediate ``block_outputs``. For every
      ``(feat_index_S[i], feat_index_T[i], feat_loss_weight[i])`` triple the
      student feature is matched to the teacher feature with either a
      temperature-softened KL divergence over dim 1 (``args.kl_feat_loss``,
      which takes precedence) or an MSE (``args.mse_feat_loss``); the
      per-layer losses are scaled by their weight and summed. When neither
      flag is set it is a float32 zero tensor (not a Python int) so callers
      can still call ``.item()`` / ``float()`` on it.
    * ``loss_outkd`` -- MSE between the final ``sample`` outputs, in float32.

    Args:
        pred: Student output exposing ``.sample`` and ``.block_outputs``.
        target: Teacher output with the same fields.
        args: Namespace providing ``kl_feat_loss``, ``mse_feat_loss``,
            ``kl_temp``, ``feat_index_S``, ``feat_index_T`` and
            ``feat_loss_weight``.

    Returns:
        ``(loss_kd, loss_dict)`` where ``loss_dict`` has keys
        ``"loss_featkd"`` and ``"loss_outkd"`` and ``loss_kd`` is the sum of
        its values.

    Raises:
        ValueError: If the feature index / weight lists differ in length.
    """
    loss_dict = dict()
    zero = torch.zeros((), dtype=torch.float32, device=pred.sample.device)

    # get feat KD loss from intermediate layers.
    if args.kl_feat_loss or args.mse_feat_loss:
        n_feats = len(args.feat_index_S)
        if len(args.feat_loss_weight) != n_feats:
            raise ValueError(
                f"feat_index_S ({n_feats}) and feat_loss_weight "
                f"({len(args.feat_loss_weight)}) must have the same length")
        if len(args.feat_index_T) != n_feats:
            raise ValueError(
                f"feat_index_S ({n_feats}) and feat_index_T "
                f"({len(args.feat_index_T)}) must have the same length")

        feat_loss_list = []
        for idx_s, idx_t, weight in zip(args.feat_index_S, args.feat_index_T, args.feat_loss_weight):
            feat_S, feat_T = pred.block_outputs[idx_s], target.block_outputs[idx_t]
            if args.kl_feat_loss:
                # The teacher distribution is a fixed target: keep it out of the graph.
                with torch.no_grad():
                    probs_T = torch.softmax(feat_T / args.kl_temp, dim=1)
                log_probs_S = torch.log_softmax(feat_S / args.kl_temp, dim=1)
                feat_loss = torch.nn.functional.kl_div(log_probs_S, probs_T, reduction='batchmean')
            else:  # args.mse_feat_loss is guaranteed by the enclosing condition
                feat_loss = torch.nn.functional.mse_loss(feat_S, feat_T, reduction='mean')
            feat_loss_list.append(feat_loss * weight)
        loss_dict["loss_featkd"] = sum(feat_loss_list) if feat_loss_list else zero
    else:
        loss_dict["loss_featkd"] = zero

    loss_dict["loss_outkd"] = torch.nn.functional.mse_loss(
        pred.sample.float(), target.sample.float(), reduction="mean")
    loss_kd = sum(loss_dict.values())
    return loss_kd, loss_dict


def cal_task_loss(pred: "CustomOutput", target: torch.Tensor, args):
    '''Task loss: MSE between the student prediction and the ground-truth target.

    Refers to the task loss between GT noise and student-predicted noise in
    SnapGen. Both tensors are upcast to float32 before the MSE.

    Args:
        pred: Student output exposing ``.sample``.
        target: Ground-truth tensor with the same shape as ``pred.sample``.
        args: Namespace providing the boolean ``task_loss`` switch.

    Returns:
        ``(loss_task, {"loss_task": loss_task})``. When ``args.task_loss`` is
        false the loss is a float32 zero tensor (not a Python int) so callers
        can still call ``.item()`` / ``float()`` on it.
    '''
    if args.task_loss:
        loss_task = torch.nn.functional.mse_loss(pred.sample.float(), target.float(), reduction='mean')
    else:
        loss_task = torch.zeros((), dtype=torch.float32, device=pred.sample.device)
    return loss_task, {'loss_task': loss_task}

def cal_elatentlpips_loss(
    pred: "CustomOutput", target: torch.Tensor, encoder_model: torch.nn.Module,
    noise_scheduler, timesteps, noisy_latents, args=None):
    '''E-LatentLPIPS perceptual loss between predicted clean latents.

    Both the student prediction (``pred.sample``) and ``target`` are turned,
    sample by sample, into predicted clean latents via
    ``noise_scheduler.step(output, t, noisy_latent).pred_original_sample``
    using that sample's timestep and noisy latent. ``encoder_model`` then
    scores the two latent batches and the result is averaged.

    Args:
        pred: Student output exposing ``.sample`` (the model prediction).
        target: Ground-truth target in the same parameterization as
            ``pred.sample`` (upcast to float32 here).
        encoder_model: E-LatentLPIPS-style callable
            ``(latent0, latent1, normalize=..., ensembling=...)``.
        noise_scheduler: Scheduler whose ``step`` returns an object with
            ``pred_original_sample``.
        timesteps: Per-sample timesteps, iterated alongside the batch.
        noisy_latents: Per-sample noisy latents, iterated alongside the batch.
        args: Namespace providing the boolean ``elatentlpips_loss`` switch;
            ``None`` is treated as "disabled".

    Returns:
        ``(loss_elatentlpips, {"loss_elatentlpips": loss_elatentlpips})``.
        When disabled the loss is a float32 zero tensor (not a Python int).
    '''
    if args is None or not args.elatentlpips_loss:
        loss_elatentlpips = torch.zeros((), dtype=torch.float32, device=pred.sample.device)
        return loss_elatentlpips, {'loss_elatentlpips': loss_elatentlpips}

    def _predicted_x0(model_outputs):
        # Scheduler.step works on one sample at a time here, so walk the batch.
        return torch.stack([
            noise_scheduler.step(out, t, x_t).pred_original_sample
            for out, t, x_t in zip(model_outputs, timesteps, noisy_latents)
        ])

    student_x0 = _predicted_x0(pred.sample)
    target_x0 = _predicted_x0(target.float())
    # Note: `normalize=True` is for latents that are not already normalized by
    # `vae.config.scaling_factor` and `vae.config.shift_factor`.
    loss_elatentlpips = encoder_model(student_x0, target_x0, normalize=True, ensembling=True).mean()
    return loss_elatentlpips, {'loss_elatentlpips': loss_elatentlpips}



def cal_adaptive_weights_type8(featkd_loss, task_loss, outkd_loss, elatentlpips_loss, last_featkd_layer=None, outkd_layer=None):
    """Gradient-norm-based adaptive weights for balancing the distillation losses.

    Gradients of each loss are taken w.r.t. two probe parameters:
    ``last_featkd_layer`` (for feat-KD, task, out-KD and E-LatentLPIPS) and
    ``outkd_layer`` (for task, out-KD and E-LatentLPIPS). The weights are:

    * ``feat_weight_task = |g_feat(featkd)| / (|g_feat(task)| + 1e-4)``, clamped to [0, 1e4]
    * ``out_weight_outkd = |g_out(task)| / (|g_out(outkd)| + 1e-6)``, clamped to [0, 1e6]
    * ``out_weight_elatentlpips = |g_out(task)| / (|g_out(elatentlpips)| + 1e-6)``, clamped to [0, 1e6]

    All weights are detached. A loss that is not a grad-requiring tensor
    (e.g. a disabled loss given as ``0``), or that does not depend on a probe
    parameter, contributes a zero gradient instead of making
    ``torch.autograd.grad`` raise. The autograd graph is retained so the caller
    can still back-propagate the combined loss.

    Returns:
        A 10-tuple: the three weights above, followed by the gradient norms
        ``feat(featkd), feat(task), feat(outkd), feat(elatentlpips),
        out(task), out(outkd), out(elatentlpips)``.

    Raises:
        ValueError: If either probe parameter is missing.
    """
    if last_featkd_layer is None:
        raise ValueError("need last_featkd_layer's parameter to get gradient")
    if outkd_layer is None:
        raise ValueError("need outkd_layer's parameter to get gradient")

    def _grad(loss, param):
        # Disabled / constant losses, or losses that never reach `param`,
        # have no gradient path; treat their gradient as zero.
        if not (torch.is_tensor(loss) and loss.requires_grad):
            return torch.zeros_like(param)
        g = torch.autograd.grad(loss, param, retain_graph=True, allow_unused=True)[0]
        return torch.zeros_like(param) if g is None else g

    get_norm = torch.norm

    feat_grad_featkd = _grad(featkd_loss, last_featkd_layer)
    feat_grad_outkd = _grad(outkd_loss, last_featkd_layer)
    feat_grad_task = _grad(task_loss, last_featkd_layer)
    feat_grad_elatentlpips = _grad(elatentlpips_loss, last_featkd_layer)

    out_grad_outkd = _grad(outkd_loss, outkd_layer)
    out_grad_task = _grad(task_loss, outkd_layer)
    out_grad_elatentlpips = _grad(elatentlpips_loss, outkd_layer)

    out_weight_outkd = get_norm(out_grad_task) / (get_norm(out_grad_outkd) + 1e-6)
    out_weight_outkd = torch.clamp(out_weight_outkd, 0.0, 1e6).detach()

    out_weight_elatentlpips = get_norm(out_grad_task) / (get_norm(out_grad_elatentlpips) + 1e-6)
    out_weight_elatentlpips = torch.clamp(out_weight_elatentlpips, 0.0, 1e6).detach()

    feat_weight_task = get_norm(feat_grad_featkd) / (get_norm(feat_grad_task) + 1e-4)
    feat_weight_task = torch.clamp(feat_weight_task, 0.0, 1e4).detach()

    return feat_weight_task, out_weight_outkd, out_weight_elatentlpips, \
        get_norm(feat_grad_featkd), get_norm(feat_grad_task), get_norm(feat_grad_outkd), get_norm(feat_grad_elatentlpips), \
        get_norm(out_grad_task), get_norm(out_grad_outkd), get_norm(out_grad_elatentlpips)




def train_distillation(args):
    """Distill the frozen teacher removal model into the student model.

    Builds the accelerator, student, VAE (encoder only), frozen teacher, an
    optional EMA of the student, an optional torch.compile wrapper, a DDPM
    noise scheduler and an optional E-LatentLPIPS model. It then trains for
    steps ``args.global_step .. args.max_train_steps - 1``. Each step:

    1. encodes a batch and resizes its masks to latent resolution;
    2. noises the latents and runs student and teacher on the same input;
    3. combines feature-KD, output-KD, task and E-LatentLPIPS losses with
       gradient-norm-based adaptive weights (``cal_adaptive_weights_type8``);
    4. back-propagates, steps the optimizer and LR scheduler, and updates the
       EMA;
    5. logs, and periodically saves checkpoints under
       ``<output_dir>/ckpt/exp-stepXXXXXXXX/``.
    """
    chinese_sdxl_train_util.verify_sdxl_training_args(args, False)

    if args.seed is not None:
        set_seed(args.seed)

    accelerator = build_accelerator(args, fsdp_plugin=None)

    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    student, vae = build_models(args, weight_dtype, accelerator)
    teacher = build_teacher_model(args, weight_dtype, accelerator)
    del vae.decoder  # only the VAE encoder is needed in this training paradigm

    # EMA
    if args.use_model_ema:
        student_ema = LitEma(student, decay=args.ema_decay).to(accelerator.device)
        accelerator.print(f"Keeping EMAs of {len(list(student_ema.buffers()))}.")

    # torch.compile
    if args.use_compile:
        student.compile(backend="cudagraphs",
                        fullgraph=False,
                        dynamic=False)

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
    )
    prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)

    if args.resume_from_ckpt:
        accelerator.print(f"==> resume ckpt from : {args.resume_from_ckpt}")
        msg = load_removal_model(
            student, args.resume_from_ckpt, accelerator.device, strict=True
        )
        accelerator.print(f'load state dict {msg}')

        if args.use_model_ema:
            load_path = args.resume_from_ckpt.replace(
                "diffusion_pytorch_model.bin",
                "diffusion_pytorch_model.EMA.bin"
            )
            load_litema(student_ema, load_path, map_location=accelerator.device)
            accelerator.print(f'load EMA state dict from {load_path}')

    # E-Latent-LPIPS
    if args.elatentlpips_loss:
        from elatentlpips import ELatentLPIPS
        # Initialize E-LatentLPIPS with the specified encoder model (options: sd15, sd21, sdxl, sd3, flux)
        # The 'augment' parameter can be set to one of the following: b, bg, bgc, bgco
        elatentlpips_model = ELatentLPIPS(encoder="sdxl", augment="bg").eval()
        elatentlpips_model = accelerator.prepare(elatentlpips_model)
    else:
        elatentlpips_model = None

    # training_models
    training_models = [student]
    params_to_optimize = [{"params": list(student.parameters()), "lr": args.learning_rate}]
    named_params_to_optimize = [{"params": list(student.named_parameters()), "lr": args.learning_rate}]

    n_params = sum(p.numel() for group in params_to_optimize for p in group["params"])
    accelerator.print(f"number of models: {len(training_models)}")
    accelerator.print(f"number of trainable parameters: {n_params}")

    accelerator.print("prepare optimizer, data loader etc.")

    _, _, optimizer = train_util.get_optimizer(args,
        trainable_params=params_to_optimize,
        named_trainable_params=named_params_to_optimize)
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    # Single background worker so tracker logging does not block the training loop.
    executor = ThreadPoolExecutor(max_workers=1)

    student, optimizer, lr_scheduler = accelerator.prepare(
        student, optimizer, lr_scheduler
    )

    teacher = accelerator.prepare(teacher)

    if accelerator.is_main_process:
        init_kwargs = {}
        if args.log_tracker_config is not None:
            init_kwargs = toml.load(args.log_tracker_config)
        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

    loss_total = 0
    accumulate_loss = 0
    grad_norm = None  # only set once gradient clipping has run (max_grad_norm != 0)
    for m in training_models:
        m.train()

    # Resolve the dataset class from an explicit whitelist rather than eval():
    # values loaded from a config file may not go through argparse `choices`.
    known_datasets = ("RemovalDataset", "RemovalDataset_v1_2")
    if args.data_type not in known_datasets:
        raise ValueError(f"unsupported data_type {args.data_type!r}; expected one of {known_datasets}")
    dataset_class = globals()[args.data_type]
    train_dataloader, _ = build_dataloader(args,
        dataset_class, accelerator)

    global_step = args.global_step
    pbar = build_progress_bar(
        range(args.max_train_steps), args.global_step,
        disable=not accelerator.is_local_main_process)

    for step in range(args.global_step, args.max_train_steps):
        with accelerator.accumulate(training_models[0]):
            batch = next(train_dataloader)
            latents, masked_image_latents = encode_clean_latents(batch, vae, weight_dtype, accelerator)

            # resize mask to latent resolution (VAE downsamples by 2 per extra block)
            masks = batch["masks"]
            h, w = masks.shape[-2:]
            vae_ds_ratio = 2 ** (len(vae.config.block_out_channels) - 1)
            size = (h // vae_ds_ratio, w // vae_ds_ratio)
            resized_masks = F.interpolate(masks, size=size).to(accelerator.device, dtype=weight_dtype)

            # Sample noise
            noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
            noisy_latents = noisy_latents.to(weight_dtype)

            # Predict the noise residual
            with accelerator.autocast():
                latent_model_input = torch.cat([
                    noisy_latents, resized_masks, masked_image_latents], dim=1)

                pred_S = student(
                    latent_model_input, timesteps=timesteps, input_ids=batch["input_ids"])

                # The teacher is frozen; its forward never needs an autograd graph.
                with torch.no_grad():
                    pred_T = teacher(
                        latent_model_input, timesteps=timesteps, input_ids=batch["input_ids"])

            # target = noise
            loss_kd, loss_dict_kd = cal_KD_loss(pred_S, pred_T, args)
            loss_task, loss_dict_task = cal_task_loss(pred_S, noise, args)
            loss_elatentlpips, loss_dict_elatentlpips = cal_elatentlpips_loss(
                pred_S, noise, elatentlpips_model,
                noise_scheduler=noise_scheduler,
                timesteps=timesteps,
                noisy_latents=noisy_latents,
                args=args)

            loss_dict = loss_dict_kd | loss_dict_task | loss_dict_elatentlpips

            raw_student = accelerator.unwrap_model(student)
            (feat_weight_task,
             out_weight_outkd,
             out_weight_elatentlpips,
             feat_gnorm_featkd,
             feat_gnorm_task,
             feat_gnorm_outkd,
             feat_gnorm_elatentlpips,
             out_gnorm_task,
             out_gnorm_outkd,
             out_gnorm_elatentlpips) = cal_adaptive_weights_type8(
                loss_dict["loss_featkd"],
                loss_dict["loss_task"],
                loss_dict["loss_outkd"],
                loss_dict["loss_elatentlpips"],
                last_featkd_layer=raw_student.diff_model.down_blocks[2].attentions[1].proj_out.weight,
                outkd_layer=raw_student.diff_model.conv_out.conv_pw.weight)

            loss = loss_dict["loss_featkd"] * args.KD_loss_weight \
                    + feat_weight_task * (
                        loss_dict["loss_task"] * args.task_loss_weight
                        + loss_dict["loss_outkd"] * out_weight_outkd * args.KD_loss_weight
                        + loss_dict["loss_elatentlpips"] * out_weight_elatentlpips * args.elatentlpips_loss_weight)

            accelerator.backward(loss)
            # Clip only on the micro-step whose gradients are actually synced
            # and applied; earlier accumulation steps hold partial gradients.
            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                grad_norm = accelerator.clip_grad_norm_(
                    student.parameters(), args.max_grad_norm).item()

        optimizer.step()

        # Track the EMA once per real optimizer update, not per micro-step.
        if args.use_model_ema and accelerator.sync_gradients:
            student_ema(accelerator.unwrap_model(student))

        lr_scheduler.step()
        optimizer.zero_grad()

        current_loss = loss.detach()
        accumulate_loss += current_loss

        # logging
        if accelerator.sync_gradients:
            loss_total += accumulate_loss
            logs = {
                "avr_loss": loss_total.item() / (step + 1 - args.global_step),
                "loss": accumulate_loss.item() / accelerator.gradient_accumulation_steps,
                "lr": float(lr_scheduler.get_last_lr()[0]),
                'global_step': global_step,
                "feat_gnorm_featkd": feat_gnorm_featkd.item(),
                "feat_gnorm_task": feat_gnorm_task.item(),
                "feat_gnorm_outkd": feat_gnorm_outkd.item(),
                "feat_gnorm_elatentlpips": feat_gnorm_elatentlpips.item(),
                "out_gnorm_task": out_gnorm_task.item(),
                "out_gnorm_outkd": out_gnorm_outkd.item(),
                "out_gnorm_elatentlpips": out_gnorm_elatentlpips.item(),
                "feat_weight_task": feat_weight_task.item(),
                "out_weight_outkd": out_weight_outkd.item(),
                "out_weight_elatentlpips": out_weight_elatentlpips.item()
            }
            if grad_norm is not None:
                logs["grad_norm"] = grad_norm
            # float() accepts both 0-d tensors and plain numbers (disabled losses).
            logs |= {k: float(v) for k, v in loss_dict.items()}
            pbar.set_postfix(**logs, refresh=False)

            if args.logging_dir:
                tb_logs = logs | {"rank": accelerator.process_index,}
                executor.submit(accelerator.log, tb_logs, step=global_step)

            accumulate_loss = 0

        # save model by step
        if (global_step != args.global_step
                and args.save_every_n_steps
                and global_step % args.save_every_n_steps == 0):
            save_path = osp.join(args.output_dir, "ckpt", f"exp-step{global_step:08d}", f"diffusion_pytorch_model.bin")
            save(student, save_path, accelerator)

            if args.use_model_ema:
                save_path = osp.join(args.output_dir, "ckpt", f"exp-step{global_step:08d}", f"diffusion_pytorch_model.EMA.bin")
                if accelerator.is_main_process:
                    save_litema(student_ema, save_path)
                accelerator.print(f"d[info]: EMA Model saved at: {save_path}\n")

        pbar.update()
        global_step += 1

    # save the final model
    save_path = osp.join(args.output_dir, "ckpt", f"exp-step{global_step:08d}", f"diffusion_pytorch_model.bin")
    save(student, save_path, accelerator)

    if args.use_model_ema:
        save_path = osp.join(args.output_dir, "ckpt", f"exp-step{global_step:08d}", f"diffusion_pytorch_model.EMA.bin")
        # Only rank 0 writes, matching the periodic save; otherwise every
        # process would write the same file concurrently.
        if accelerator.is_main_process:
            save_litema(student_ema, save_path)
        accelerator.print(f"d[info]: EMA Model saved at: {save_path}\n")

    # Flush pending background tracker writes before the trackers are closed.
    executor.shutdown(wait=True)
    accelerator.wait_for_everyone()
    accelerator.end_training()

def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for distillation training.

    Starts from the shared ``train_util`` and ``common_arguments`` options and
    adds teacher, KD-loss, task-loss, E-LatentLPIPS, EMA, torch.compile and
    dataset-type options.
    """
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_training_arguments(parser, False)
    train_util.add_sd_saving_arguments(parser)
    train_util.add_optimizer_arguments(parser)

    common_arguments(parser)

    # teacher model
    parser.add_argument('--teacher_config_path', type=str, default=None,
        help='config file describing the teacher model architecture.')
    parser.add_argument('--teacher_weight_path', type=str, default=None,
        help='optional state-dict file loaded into the teacher model.')

    # feature KD
    parser.add_argument('--kl_feat_loss', action='store_true',
        help='use temperature-softened KL divergence for intermediate-feature KD '
             '(takes precedence over --mse_feat_loss; output KD always uses MSE).')
    # '--kl_tempeature' (misspelled) is kept so existing command lines keep working.
    parser.add_argument('--kl_tempeature', '--kl_temperature', type=float, default=1.0, dest='kl_temp',
        help='temperature used to soften features before the KL feature loss.')

    parser.add_argument('--mse_feat_loss', action='store_true',
        help='use MSELoss for intermediate-feature KD (output KD always uses MSE).')

    parser.add_argument('--feat_index_T', nargs='*', type=int, default=[4,],
        help='index list of Teacher intermediate features for KD.')
    parser.add_argument('--feat_index_S', nargs='*', type=int, default=[4,],
        help='index list of Student intermediate features for KD.')
    parser.add_argument('--feat_loss_weight', nargs='*', type=float, default=[0.2,],
        help='loss weights of intermediate features for KD.')

    # task loss / KD weights
    parser.add_argument('--task_loss', action='store_true',
        help='enable MSELoss for output and gt_noise.')
    parser.add_argument('--task_loss_weight', type=float, default=1.0,
        help='weight multiplied to loss_task.')
    parser.add_argument('--KD_loss_weight', type=float, default=1.0,
        help='weight multiplied to the KD losses (feature KD and output KD).')

    # E-LatentLPIPS
    parser.add_argument('--elatentlpips_loss', action='store_true',
        help='enable E-LatentLPIPS perceptual loss between the predicted clean latents '
             'of the student output and of gt_noise.')
    parser.add_argument('--elatentlpips_loss_weight', type=float, default=1.0,
        help='weight multiplied to loss_elatentlpips.')

    # EMA
    parser.add_argument('--use_model_ema', action='store_true',
        help='enable EMA on training model.')
    parser.add_argument('--ema_decay', type=float, default=0.9999,
        help='decay rate of the EMA.')

    parser.add_argument('--use_compile', action='store_true',
        help='use torch.compile on the student model.')

    # datatype
    parser.add_argument('--data_type',
        type=str, default="RemovalDataset",
        choices=['RemovalDataset', 'RemovalDataset_v1_2'],
        help='different mask assignment strategy.')

    return parser


if __name__ == "__main__":
    # Let torch.compile fall back to eager execution instead of raising
    # when dynamo fails to compile a graph.
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

    # CLI flags first, then any values from the config file referenced by them.
    parser = setup_parser()
    args = train_util.read_config_from_file(parser.parse_args(), parser)

    train_distillation(args)