| | """ |
| | Author: Luigi Piccinelli |
| | Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) |
| | """ |
| |
|
| | from functools import wraps |
| |
|
| | import numpy as np |
| | from scipy import interpolate |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from einops import rearrange, repeat, reduce |
| |
|
| |
|
def max_stack(tensors):
    """Return the element-wise maximum across a sequence of tensors.

    The tensors are stacked along a new trailing axis, which is then reduced
    with ``max``; only the values (not the arg-max indices) are returned.
    """
    stacked = torch.stack(tensors, dim=-1)
    values, _indices = stacked.max(dim=-1)
    return values
| |
|
| |
|
def softmax_stack(tensors, temperature=1.0):
    """Combine tensors with a softmax-weighted sum (a smooth maximum).

    The tensors are stacked along a new trailing axis. Softmax weights are
    computed over that axis from ``stacked / temperature`` and used to take a
    weighted sum of the stacked values.

    The previous implementation summed the softmax weights themselves, which
    is identically 1 for every element and therefore discarded the inputs.

    Args:
        tensors: sequence of tensors accepted by ``torch.stack``.
        temperature: softmax temperature; smaller values push the result
            toward the element-wise maximum, larger values toward the mean.

    Returns:
        Tensor with the same shape as each input.
    """
    stacked = torch.stack(tensors, dim=-1)
    weights = F.softmax(stacked / temperature, dim=-1)
    return (weights * stacked).sum(dim=-1)
| |
|
| |
|
def mean_stack(tensors):
    """Return the element-wise mean across a sequence of tensors.

    A single-element sequence is returned as-is (the very same tensor object,
    no stacking or copy).
    """
    if len(tensors) != 1:
        stacked = torch.stack(tensors, dim=-1)
        return stacked.mean(dim=-1)
    # Nothing to average: hand back the only tensor untouched.
    return tensors[0]
| |
|
| |
|
def sum_stack(tensors):
    """Return the element-wise sum across a sequence of tensors."""
    stacked = torch.stack(tensors, dim=-1)
    return stacked.sum(dim=-1)
| |
|
| |
|
def convert_module_to_f16(l):
    """Cast the weight (and bias, if any) of a conv module to float16 in place.

    Only ``nn.Conv1d``, ``nn.Conv2d`` and ``nn.Conv3d`` instances are touched;
    any other module is left unchanged.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    if not isinstance(l, conv_types):
        return
    l.weight.data = l.weight.data.half()
    # Convolutions built with bias=False expose `bias` as None.
    if l.bias is not None:
        l.bias.data = l.bias.data.half()
| |
|
| |
|
def convert_module_to_f32(l):
    """Cast the weight (and bias, if any) of a conv module to float32 in place.

    This undoes ``convert_module_to_f16()``. Only ``nn.Conv1d``, ``nn.Conv2d``
    and ``nn.Conv3d`` instances are touched; other modules are left unchanged.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    if not isinstance(l, conv_types):
        return
    l.weight.data = l.weight.data.float()
    # Convolutions built with bias=False expose `bias` as None.
    if l.bias is not None:
        l.bias.data = l.bias.data.float()
| |
|
| |
|
def format_seconds(seconds):
    """Format a duration in seconds as ``H:MM:SS``.

    Args:
        seconds: duration as an int or float (anything ``int()`` accepts).
            Fractional seconds are truncated; previously a float argument
            raised ``ValueError`` because the ``d`` format code rejects floats.

    Returns:
        String with un-padded hours and zero-padded minutes and seconds.
    """
    total_seconds = int(seconds)
    minutes, secs = divmod(total_seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours:d}:{minutes:02d}:{secs:02d}"
| |
|
| |
|
def get_params(module, lr, wd):
    """Split a module's trainable parameters into decay / no-decay groups.

    A parameter is exempt from weight decay when its name is listed by
    ``module.no_weight_decay()``, contains one of the keywords from
    ``module.no_weight_decay_keywords()`` (both hooks are optional), or when
    the parameter is one-dimensional. Parameters with ``requires_grad=False``
    are skipped entirely.

    Args:
        module: object exposing ``named_parameters()``.
        lr: learning rate stored in both groups.
        wd: weight decay stored in the decayed group.

    Returns:
        ``([decay_group, no_decay_group], [lr, lr])`` where each group is an
        optimizer-style dict. Only the no-decay group carries a
        ``"weight_decay_final"`` entry.
    """
    skip_list = module.no_weight_decay() if hasattr(module, "no_weight_decay") else {}
    skip_keywords = (
        module.no_weight_decay_keywords()
        if hasattr(module, "no_weight_decay_keywords")
        else {}
    )

    decayed, exempt = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            continue
        is_exempt = (
            name in skip_list
            or any(kw in name for kw in skip_keywords)
            or len(param.shape) == 1
        )
        bucket = exempt if is_exempt else decayed
        bucket.append(param)

    decay_group = {
        "params": decayed,
        "weight_decay": wd,
        "lr": lr,
        "weight_decay_init": wd,
        "weight_decay_base": wd,
        "lr_init": lr,
        "lr_base": lr,
    }
    no_decay_group = {
        "params": exempt,
        "weight_decay": 0.0,
        "lr": lr,
        "weight_decay_init": 0.0,
        "weight_decay_base": 0.0,
        "weight_decay_final": 0.0,
        "lr_init": lr,
        "lr_base": lr,
    }
    return [decay_group, no_decay_group], [lr, lr]
| |
|
| |
|
def get_num_layer_for_swin(var_name, num_max_layer, layers_per_stage):
    """Map a parameter name to a layer index for layer-wise scaling.

    Args:
        var_name: dotted parameter name, e.g. ``layers.<stage>.blocks.<i>...``.
        num_max_layer: total number of layer slots; names that match no known
            prefix are assigned ``num_max_layer - 1``.
        layers_per_stage: sequence with the number of blocks in each stage.

    Returns:
        ``0`` for the token/position embeddings and ``patch_embed*``; a
        1-based global block index for ``layers.<s>.blocks.<i>``; the number
        of blocks up to and including stage ``s`` for ``layers.<s>.downsample``;
        ``None`` for any other ``layers.*`` entry; ``num_max_layer - 1``
        otherwise.
    """
    embedding_names = ("cls_token", "mask_token", "pos_embed", "absolute_pos_embed")
    if var_name in embedding_names or var_name.startswith("patch_embed"):
        return 0
    if not var_name.startswith("layers"):
        return num_max_layer - 1

    parts = var_name.split(".")
    component = parts[2]
    if component == "blocks":
        stage_id = int(parts[1])
        block_offset = int(parts[3]) + sum(layers_per_stage[:stage_id])
        return block_offset + 1
    if component == "downsample":
        stage_id = int(parts[1])
        return sum(layers_per_stage[: stage_id + 1])
    # Any other sub-module under `layers.<stage>` has no assigned index.
    return None
| |
|
| |
|
def get_params_layerdecayswin(module, lr, wd, ld):
    """Build one optimizer group per trainable parameter with layer-wise LR decay.

    Each parameter's learning rate is ``lr * ld ** (num_layers - layer_id - 1)``
    where ``layer_id`` comes from ``get_num_layer_for_swin`` and
    ``num_layers = sum(module.depths) + 1``. Weight decay is set to ``0.0``
    for names listed by ``module.no_weight_decay()`` or containing a keyword
    from ``module.no_weight_decay_keywords()`` (both hooks optional), and to
    ``wd`` otherwise. Frozen parameters are reported on stdout and skipped.

    Args:
        module: object exposing ``named_parameters()`` and ``depths``.
        lr: base learning rate.
        wd: weight decay for non-exempt parameters.
        ld: per-layer learning-rate decay factor.

    Returns:
        ``(param_groups, lrs)``: the list of group dicts and the matching
        list of per-group learning rates.
    """
    skip_list = module.no_weight_decay() if hasattr(module, "no_weight_decay") else {}
    skip_keywords = (
        module.no_weight_decay_keywords()
        if hasattr(module, "no_weight_decay_keywords")
        else {}
    )
    depths = module.depths
    total_layers = sum(depths) + 1

    groups, group_lrs = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            print(f"{name} frozen")
            continue
        layer_index = get_num_layer_for_swin(name, total_layers, depths)
        scaled_lr = lr * ld ** (total_layers - layer_index - 1)
        exempt = (name in skip_list) or any(kw in name for kw in skip_keywords)
        groups.append(
            {
                "params": param,
                "weight_decay": 0.0 if exempt else wd,
                "lr": scaled_lr,
            }
        )
        group_lrs.append(scaled_lr)
    return groups, group_lrs
| |
|
| |
|
def log(t, eps: float = 1e-5):
    """Natural log of ``t`` after clamping it from below at ``eps``."""
    clamped = t.clamp(min=eps)
    return torch.log(clamped)
| |
|
| |
|
def l2norm(t):
    """Normalize ``t`` along its last dimension using ``F.normalize``."""
    normalized = F.normalize(t, dim=-1)
    return normalized
| |
|
| |
|
def exists(val):
    """Return True unless ``val`` is ``None`` (falsy values still count)."""
    is_missing = val is None
    return not is_missing
| |
|
| |
|
def identity(t, *args, **kwargs):
    """Return ``t`` unchanged; any extra arguments are accepted and ignored."""
    del args, kwargs  # accepted only so this can stand in for other callables
    return t
| |
|
| |
|
def divisible_by(numer, denom):
    """Return whether ``numer`` is an exact multiple of ``denom``."""
    remainder = numer % denom
    return remainder == 0
| |
|
| |
|
def first(arr, d=None):
    """Return ``arr[0]``, or ``d`` when ``arr`` has length zero."""
    return arr[0] if len(arr) != 0 else d
| |
|
| |
|
def default(val, d):
    """Return ``val`` if it is not ``None``, otherwise the fallback ``d``.

    A callable ``d`` is invoked (with no arguments) to produce the fallback,
    and only when it is actually needed.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val
| |
|
| |
|
def maybe(fn):
    """Wrap a one-argument function so that a ``None`` input passes through.

    The returned wrapper calls ``fn(x)`` when ``x`` is not ``None`` and
    returns ``x`` itself (i.e. ``None``) otherwise.
    """

    @wraps(fn)
    def inner(x):
        return fn(x) if exists(x) else x

    return inner
| |
|
| |
|
def once(fn):
    """Wrap a one-argument function so that only the first call runs it.

    The first invocation forwards to ``fn`` and returns its result; every
    later invocation returns ``None`` without calling ``fn``. The wrapper is
    marked as used *before* ``fn`` runs, so a first call that raises still
    consumes the single attempt.
    """
    already_ran = False

    @wraps(fn)
    def inner(x):
        nonlocal already_ran
        if not already_ran:
            already_ran = True
            return fn(x)
        return None

    return inner
| |
|
| |
|
| | def _many(fn): |
| | @wraps(fn) |
| | def inner(tensors, pattern, **kwargs): |
| | return (fn(tensor, pattern, **kwargs) for tensor in tensors) |
| |
|
| | return inner |
| |
|
| |
|
# Multi-tensor variants of the einops primitives: each takes an iterable of
# tensors plus one pattern and lazily applies the primitive to every tensor.
rearrange_many, repeat_many, reduce_many = (
    _many(rearrange),
    _many(repeat),
    _many(reduce),
)
| |
|
| |
|
def load_pretrained(state_dict, checkpoint):
    """Prepare a pre-trained SWIN checkpoint for loading into a model.

    If any key of ``checkpoint["model"]`` contains ``"encoder."``, only the
    keys that start with ``"encoder."`` are kept and that substring is
    removed from them; otherwise the mapping is used unchanged. The result is
    then remapped by ``load_checkpoint_swin``.

    Args:
        state_dict: forwarded as the ``model`` argument of
            ``load_checkpoint_swin``, which calls ``.state_dict()`` on it, so
            despite the name this must be the model object itself.
        checkpoint: mapping holding the weights under the ``"model"`` key.

    Returns:
        The remapped checkpoint dict. Previously the result was bound to a
        local and discarded, so callers received ``None``.
    """
    checkpoint_model = checkpoint["model"]
    if any("encoder." in k for k in checkpoint_model.keys()):
        checkpoint_model = {
            k.replace("encoder.", ""): v
            for k, v in checkpoint_model.items()
            if k.startswith("encoder.")
        }
        print("Detect pre-trained model, remove [encoder.] prefix.")
    else:
        print("Detect non-pre-trained model, pass without doing anything.")
    print(">>>>>>>>>> Remapping pre-trained keys for SWIN ..........")
    return load_checkpoint_swin(state_dict, checkpoint_model)
| |
|
| |
|
def _interpolate_relative_position_bias(table, src_size, dst_size):
    """Resize a (src_size**2, heads) bias table to (dst_size**2, heads) on a geometric grid."""

    def geometric_progression(a, r, n):
        return a * (1.0 - r**n) / (1.0 - r)

    # Bisect for the ratio q whose geometric series over half the source
    # window spans half the destination window.
    left, right = 1.01, 1.5
    while right - left > 1e-6:
        q = (left + right) / 2.0
        gp = geometric_progression(1, q, src_size // 2)
        if gp > dst_size // 2:
            right = q
        else:
            left = q

    # Source sample positions: geometrically spaced and symmetric around 0.
    dis = []
    cur = 1
    for i in range(src_size // 2):
        dis.append(cur)
        cur += q ** (i + 1)
    r_ids = [-d for d in reversed(dis)]
    x = r_ids + [0] + dis
    y = r_ids + [0] + dis

    # Target sample positions: unit-spaced and symmetric around 0.
    t = dst_size // 2.0
    dx = np.arange(-t, t + 0.1, 1.0)
    dy = np.arange(-t, t + 0.1, 1.0)

    print("Original positions = %s" % str(x))
    print("Target positions = %s" % str(dx))

    columns = []
    for head in range(table.size(1)):
        # detach()/cpu() make the NumPy conversion work for CUDA tensors and
        # for tensors that require grad.
        z = table[:, head].view(src_size, src_size).detach().float().cpu().numpy()
        # interp2d was removed in SciPy 1.14; on a regular grid the documented
        # equivalent is RectBivariateSpline on the transposed data, with the
        # evaluated result transposed back.
        spline = interpolate.RectBivariateSpline(x, y, z.T, kx=3, ky=3)
        resized = np.ascontiguousarray(spline(dx, dy).T)
        columns.append(
            torch.from_numpy(resized).float().view(-1, 1).to(table.device)
        )
    return torch.cat(columns, dim=-1)


def load_checkpoint_swin(model, checkpoint_model):
    """Adapt a SWIN checkpoint dict (in place) to the given model.

    Steps, in order:
      1. Every ``*relative_position_bias_table*`` entry whose length differs
         from the model's table (same number of heads) is resized with
         bicubic spline interpolation on a geometric grid. A head-count
         mismatch is reported on stdout and the entry is left untouched.
      2. Entries whose name contains ``relative_position_index``,
         ``relative_coords_table`` or ``attn_mask`` are deleted (presumably
         buffers the model rebuilds itself; verify against the model).
      3. ``cpb_mlp`` is renamed to ``rpe_mlp`` in key names.
      4. The ``encoder.`` substring is removed from keys starting with it.

    Args:
        model: object exposing ``state_dict()``; it must contain every
            ``relative_position_bias_table`` key present in the checkpoint
            (a missing key raises ``KeyError``).
        checkpoint_model: mutable mapping of parameter names to tensors.

    Returns:
        The same ``checkpoint_model`` object, modified in place.
    """
    state_dict = model.state_dict()

    # Snapshot the keys: entries are reassigned while looping.
    for key in list(checkpoint_model.keys()):
        if "relative_position_bias_table" not in key:
            continue
        pretrained_table = checkpoint_model[key]
        current_table = state_dict[key]
        L1, nH1 = pretrained_table.size()
        L2, nH2 = current_table.size()
        if nH1 != nH2:
            print(f"Error in loading {key}, passing......")
            continue
        if L1 == L2:
            continue
        print(f"{key}: Interpolate relative_position_bias_table using geo.")
        checkpoint_model[key] = _interpolate_relative_position_bias(
            pretrained_table, int(L1**0.5), int(L2**0.5)
        )

    relative_position_index_keys = [
        k for k in checkpoint_model.keys() if "relative_position_index" in k
    ]
    for k in relative_position_index_keys:
        del checkpoint_model[k]

    relative_coords_table_keys = [
        k for k in checkpoint_model.keys() if "relative_coords_table" in k
    ]
    for k in relative_coords_table_keys:
        del checkpoint_model[k]

    # The checkpoint names this sub-module `cpb_mlp`; it is stored as `rpe_mlp`.
    rpe_mlp_keys = [k for k in checkpoint_model.keys() if "cpb_mlp" in k]
    for k in rpe_mlp_keys:
        checkpoint_model[k.replace("cpb_mlp", "rpe_mlp")] = checkpoint_model.pop(k)

    attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k]
    for k in attn_mask_keys:
        del checkpoint_model[k]

    encoder_keys = [k for k in checkpoint_model.keys() if k.startswith("encoder.")]
    for k in encoder_keys:
        checkpoint_model[k.replace("encoder.", "")] = checkpoint_model.pop(k)

    return checkpoint_model
| |
|
| |
|
def add_padding_metas(out, image_metas):
    """Zero-pad every sample of a batch by the amount stored in its metadata.

    Args:
        out: batched tensor; it is iterated along its first dimension and
            each sample is padded independently.
        image_metas: one mapping per sample. The optional ``"padding_size"``
            entry (list, tuple or tensor of ints) is handed to ``F.pad`` and
            is therefore interpreted in ``F.pad`` order, i.e.
            ``(left, right, top, bottom)`` over the last two dimensions.
            A missing entry means no padding.
            NOTE(review): ``remove_padding`` in this file indexes its padding
            as ``(left, top, right, bottom)``; confirm which layout the
            metadata actually uses.

    Returns:
        The padded samples stacked into one tensor (all samples must end up
        with the same shape).

    Previously the padding entries were passed to ``torch.stack`` as-is,
    which fails for the list default ``[0] * 4``, and each row was then given
    to ``F.pad`` as a tensor, which expects a sequence of ints.
    """
    padded = []
    for img_meta, sample in zip(image_metas, out):
        padding = torch.as_tensor(img_meta.get("padding_size", [0] * 4))
        pad_widths = [int(p) for p in padding.tolist()]
        padded.append(F.pad(sample, pad_widths, value=0.0))
    return torch.stack(padded)
| |
|
| |
|
def remove_padding(out, paddings):
    """Crop per-sample padding from a 4-D batch.

    Args:
        out: tensor that unpacks to four dimensions; the last two are cropped.
        paddings: sequence of tensors, one per sample, indexed as
            ``[0]`` left, ``[1]`` top, ``[2]`` right, ``[3]`` bottom (as used
            by the slicing below).

    Returns:
        The cropped samples stacked into one tensor (all crops must share
        the same shape).
    """
    _, _, height, width = out.shape
    per_sample = torch.stack(paddings).to(out.device)

    cropped = []
    for pad, sample in zip(per_sample, out):
        rows = slice(pad[1], height - pad[3])
        cols = slice(pad[0], width - pad[2])
        cropped.append(sample[:, rows, cols])
    return torch.stack(cropped)
| |
|
| |
|
def remove_padding_metas(out, image_metas):
    """Crop per-sample padding from a batch using each sample's metadata.

    The optional ``"padding_size"`` entry of every metadata mapping (four
    zeros when absent) is converted to a tensor and the actual cropping is
    delegated to ``remove_padding``.
    """
    # Unpacking doubles as an early check that `out` is 4-D.
    _batch, _channels, _height, _width = out.shape
    paddings = []
    for img_meta in image_metas:
        paddings.append(torch.tensor(img_meta.get("padding_size", [0] * 4)))
    return remove_padding(out, paddings)
| |
|
| |
|
def ssi_helper(tensor1, tensor2):
    """Least-squares scale and shift such that ``scale * tensor2 + shift ~ tensor1``.

    Solves the 2x2 normal equations of the affine fit, with ``1e-4`` added to
    the diagonal for numerical stability. The matrix algebra below implies
    both inputs are 1-D tensors of equal length.

    Returns:
        ``(scale, shift)``, each obtained by chunking the two-element solution.
    """
    regularizer = 1e-4 * torch.eye(2, device=tensor1.device)
    # Design matrix with columns [tensor2, 1].
    design = torch.stack([tensor2, torch.ones_like(tensor2)], dim=1)
    gram = design.T @ design + regularizer
    rhs = design.T @ tensor1.unsqueeze(1)
    solution = torch.inverse(gram) @ rhs
    scale, shift = solution.squeeze().chunk(2, dim=0)
    return scale, shift
| |
|
| |
|
def calculate_mean_values(names, values):
    """Average the values that share a name.

    Args:
        names: sequence of hashable names, paired positionally with ``values``.
        values: sequence of numbers (or anything supporting ``+`` and ``/``).

    Returns:
        Dict mapping each distinct name, in order of first appearance, to the
        mean of its values.
    """
    # One accumulator per distinct name, created up front.
    stats = {}
    for name in names:
        stats[name] = {}

    for name, value in zip(names, values):
        entry = stats[name]
        entry["sum"] = entry.get("sum", 0.0) + value
        entry["count"] = entry.get("count", 0.0) + 1

    means = {}
    for name, entry in stats.items():
        means[name] = entry["sum"] / entry["count"]
    return means
| |
|
| |
|
def remove_leading_dim(infos):
    """Recursively ``squeeze(0)`` every tensor inside a (nested) dict.

    Tensors have their first dimension squeezed (a no-op unless that
    dimension has size 1), dicts are rebuilt with their values processed
    recursively, and every other object is returned unchanged.
    """
    if isinstance(infos, torch.Tensor):
        return infos.squeeze(0)
    if isinstance(infos, dict):
        return {key: remove_leading_dim(value) for key, value in infos.items()}
    return infos
| |
|