|
|
import json |
|
|
import logging |
|
|
import math |
|
|
import time |
|
|
import torch |
|
|
|
|
|
from open_clip import get_cast_dtype |
|
|
from .distributed import is_master |
|
|
from .zero_shot import zero_shot_eval |
|
|
from .precision import get_autocast |
|
|
import os |
|
|
|
|
|
class AverageMeter(object):
    """Keep the latest reported value together with a weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far (value, sum, count and mean all return to 0)."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, counting it ``n`` times toward the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Recomputed eagerly so ``avg`` is always readable without a method call.
        self.avg = self.sum / self.count
|
|
|
|
|
def postprocess_clip_output(model_out):
    """Name the first three positional entries of a CLIP-style model output.

    Args:
        model_out: indexable sequence; entries 0, 1 and 2 are taken as the image
            features, text features and logit scale respectively. Any further
            entries are ignored.

    Returns:
        dict with keys ``image_features``, ``text_features`` and ``logit_scale``.
    """
    image_features, text_features, logit_scale = model_out[0], model_out[1], model_out[2]
    return dict(
        image_features=image_features,
        text_features=text_features,
        logit_scale=logit_scale,
    )
|
|
|
|
|
def unwrap_model(model):
    """Return ``model.module`` when the object has a ``module`` attribute, else ``model``.

    Presumably used to strip a wrapper such as DistributedDataParallel — confirm
    against callers.
    """
    return model.module if hasattr(model, 'module') else model
|
|
|
|
|
def backward(total_loss, scaler):
    """Back-propagate ``total_loss``, scaling it first through ``scaler`` when one is given.

    Args:
        total_loss: object exposing ``.backward()`` (a loss tensor).
        scaler: object exposing ``.scale(loss)`` (e.g. an AMP grad scaler), or ``None``.
    """
    loss_to_backprop = total_loss if scaler is None else scaler.scale(total_loss)
    loss_to_backprop.backward()
|
|
|
|
|
@torch.no_grad()
def student_teacher_ensemble(student, teacher, alpha=0.5):
    """Linearly blend two state dicts without tracking gradients.

    Each output entry is ``student[k] * alpha + teacher[k] * (1 - alpha)``.
    Only the keys of ``student`` are produced (in its iteration order); every one
    of them must also exist in ``teacher``.

    Args:
        student: mapping of parameter name -> tensor.
        teacher: mapping containing at least the same keys as ``student``.
        alpha: weight given to the student entries.

    Returns:
        New dict with the blended values.
    """
    teacher_weight = 1.0 - alpha
    return {
        name: student_value * alpha + teacher[name] * teacher_weight
        for name, student_value in student.items()
    }
|
|
|
|
|
def _clip_gradients(model, args):
    """Clip the L2 norm of ``model``'s gradients when ``args.grad_clip_norm`` is set."""
    if args.grad_clip_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)


def _optimizer_step(model, optimizer, scaler, args):
    """Apply one optimizer update, handling the optional AMP scaler, Horovod and clipping."""
    if scaler is None:
        _clip_gradients(model, args)
        optimizer.step()
        return

    if args.horovod:
        # Horovod clipping pattern: synchronize() before touching the gradients,
        # then step under skip_synchronize() so step() does not synchronize again.
        optimizer.synchronize()
        scaler.unscale_(optimizer)
        _clip_gradients(model, args)
        with optimizer.skip_synchronize():
            scaler.step(optimizer)
    else:
        if args.grad_clip_norm is not None:
            # Gradients must be unscaled before their norm can be clipped.
            scaler.unscale_(optimizer)
            _clip_gradients(model, args)
        scaler.step(optimizer)
    scaler.update()


def train_one_epoch(model, method, data, loss, epoch, optimizer, scaler, scheduler, dist_P_VLM, dist_model, args):
    """Run one training epoch over ``data['train']``.

    Every batch is passed to ``method``:
    ``method(batch, model, dist_P_VLM, dist_model, loss, device, cast_dtype, args.distributed, args)``.
    It must return ``(losses, batch_size, logit_scale)``, where ``losses`` is a dict of
    loss tensors. Those losses are summed and back-propagated. The optimizer is then
    stepped, and ``model.logit_scale`` is clamped to ``[0, ln(100)]``.

    Args:
        model: model being trained; put in train mode. Its (unwrapped) form must
            expose a ``logit_scale`` tensor.
        method: per-batch loss computation, called as described above.
        data: mapping whose ``'train'`` entry provides ``set_epoch(epoch)`` and a
            ``dataloader`` with ``num_batches`` and ``num_samples`` attributes.
        loss: loss object, forwarded unchanged to ``method``.
        epoch: zero-based epoch index, used to compute the global scheduler step.
        optimizer: torch optimizer (Horovod-wrapped when ``args.horovod``).
        scaler: AMP grad scaler, or ``None`` to step without scaling.
        scheduler: callable taking the global step; skipped when ``args.skip_scheduler``.
        dist_P_VLM, dist_model: optional auxiliary models (may be ``None``). They are
            only put in eval mode here and forwarded to ``method``.
        args: run configuration (device, precision, accum_freq, grad_clip_norm,
            horovod, distributed, world_size, batch_size, log_every_n_steps, ...).

    Raises:
        ValueError: if ``args.accum_freq != 1`` (gradient accumulation is not implemented).
    """
    # Checked once, up front, with an explicit raise: an ``assert`` inside the loop
    # would be stripped under ``python -O`` and training would silently misbehave.
    if args.accum_freq != 1:
        raise ValueError("accum freq disabled")

    device = torch.device(args.device)
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)

    model.train()
    # Auxiliary models are only evaluated, never switched to training mode here.
    for aux_model in (dist_model, dist_P_VLM):
        if aux_model is not None:
            aux_model.eval()

    data['train'].set_epoch(epoch)
    dataloader = data['train'].dataloader
    num_batches_per_epoch = dataloader.num_batches // args.accum_freq
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    losses_m = {}
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()
    for i, batch in enumerate(dataloader):
        i_accum = i // args.accum_freq
        step = num_batches_per_epoch * epoch + i_accum

        if not args.skip_scheduler:
            scheduler(step)

        data_time_m.update(time.time() - end)
        optimizer.zero_grad()

        with autocast():
            losses, batch_size, logit_scale = method(
                batch, model, dist_P_VLM, dist_model, loss, device, cast_dtype, args.distributed, args
            )
            total_loss = sum(losses.values())
            losses["loss"] = total_loss

        backward(total_loss, scaler)
        _optimizer_step(model, optimizer, scaler, args)

        # Keep the learned logit scale inside [0, ln(100)].
        with torch.no_grad():
            unwrap_model(model).logit_scale.clamp_(0, math.log(100))

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i_accum + 1
        if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch):
            num_samples = batch_count * batch_size * args.accum_freq * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            # NOTE: loss meters are sampled coarsely: only on the master process and
            # only on logging steps, so ``avg`` is the mean over logged steps.
            for key, val in losses.items():
                if key not in losses_m:
                    losses_m[key] = AverageMeter()
                losses_m[key].update(val.item(), batch_size)

            logit_scale_scalar = logit_scale.item()
            loss_log = " ".join(
                f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})"
                for loss_name, loss_m in losses_m.items()
            )
            samples_per_second = args.accum_freq * args.batch_size * args.world_size / batch_time_m.val
            samples_per_second_per_gpu = args.accum_freq * args.batch_size / batch_time_m.val
            logging.info(
                f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                f"Data (t): {data_time_m.avg:.3f} "
                f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu "
                f"LR: {optimizer.param_groups[0]['lr']:5f} "
                f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log
            )

            # Timing meters restart so each log line covers the window since the previous one.
            batch_time_m.reset()
            data_time_m.reset()
|
|
|
|
|
def evaluate(model, data, epoch, args):
    """Run zero-shot evaluation; on the master process, log and optionally persist the metrics.

    ``zero_shot_eval`` is called on every process before the master check. Only the
    master logs the results and appends them to ``results.json``.

    Args:
        model: model to evaluate; switched to eval mode.
        data: evaluation data, forwarded to ``zero_shot_eval``.
        epoch: epoch index, used in the log line and forwarded to ``zero_shot_eval``.
        args: run configuration (``save_logs``, ``checkpoint_path``, ...).

    Returns:
        A dict copy of the zero-shot metrics on the master process; ``{}`` elsewhere.
    """
    model.eval()

    zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
    if not is_master(args):
        return {}

    metrics = dict(zero_shot_metrics)
    if not metrics:
        return metrics

    # The one-line summary only includes metrics whose name contains 'all'.
    summary = {k: v for k, v in metrics.items() if 'all' in k}
    keys = ", ".join(summary)
    values = ", ".join(f"{round(v, 4):.4f}" for v in summary.values())
    logging.info(
        f"Eval Epoch: {epoch}. "
        + f"{keys}: {values}."
    )

    logging.info(metrics)

    if args.save_logs:
        # One JSON object per line, appended across evaluations.
        results_path = os.path.join(args.checkpoint_path, "results.json")
        with open(results_path, "a+", encoding="utf-8") as f:
            f.write(json.dumps(metrics))
            f.write("\n")

    return metrics
|
|
|