import torch
import torch.nn.functional as F
import numpy as np

# Breakpoints (on the min-max normalised control map) of the piecewise-linear
# ramp used by ``compute_pwv``. Kept immutable so it is safe as a default.
_DEFAULT_THRESHOLDS = (0.5, 0.55, 0.65, 0.95)


def compute_pwv(s: torch.Tensor, dscale: torch.Tensor, ratio=2,
                thresholds=_DEFAULT_THRESHOLDS):
    """Turn a per-token score map into a per-token scale adjustment.

    ``s`` is min-max normalised along dim 1 and mapped through a
    piecewise-linear ramp defined by the four ``thresholds`` (t0..t3):

    * ``<= t0``      -> ``-dscale * ratio``
    * ``(t0, t1]``   -> ramps from ``-dscale`` to ``0``
    * ``(t1, t2]``   -> ramps from ``0`` to ``0.5 * dscale``
    * ``(t2, t3]``   -> ramps from ``0.5 * dscale`` to ``dscale``
    * ``> t3``       -> ``dscale``

    Args:
        s: score tensor; must be 3-D (the original docstring gives
            ``(b, n, 1)``).
        dscale: scale difference to distribute; broadcast against ``s``.
        ratio: multiplier for the suppression applied at or below ``t0``.
        thresholds: sequence of exactly four breakpoints.

    Returns:
        Tensor of adjustments, broadcast to the shape of ``s``.

    Raises:
        ValueError: if ``s`` is not 3-D or ``thresholds`` does not have
            exactly four entries.
    """
    # Explicit checks instead of ``assert cond, other_cond``: in that form the
    # second expression is only the assertion message and is never enforced.
    if s.dim() != 3:
        raise ValueError(f"expected a 3-D tensor, got {s.dim()} dimension(s)")
    if len(thresholds) != 4:
        raise ValueError(
            f"expected exactly 4 thresholds, got {len(thresholds)}")
    t0, t1, t2, t3 = thresholds

    maxm = s.max(dim=1, keepdim=True).values
    minm = s.min(dim=1, keepdim=True).values
    d = maxm - minm
    # A constant map has zero range; dividing by it would yield NaN and poison
    # every token downstream. Use 1 there so the normalised map is all zeros.
    d = torch.where(d == 0, torch.ones_like(d), d)
    maxmin = (s - minm) / d

    adjust_scale = torch.where(
        maxmin <= t0,
        -dscale * ratio,
        -dscale + dscale * (maxmin - t0) / (t1 - t0),
    )
    adjust_scale = torch.where(
        maxmin > t1,
        0.5 * dscale * (maxmin - t1) / (t2 - t1),
        adjust_scale,
    )
    adjust_scale = torch.where(
        maxmin > t2,
        0.5 * dscale + 0.5 * dscale * (maxmin - t2) / (t3 - t2),
        adjust_scale,
    )
    adjust_scale = torch.where(maxmin > t3, dscale, adjust_scale)
    return adjust_scale


def local_manipulate_step(clip, v, t, target_scale, a=None, c=None,
                          enhance=False, thresholds=None):
    """Apply one spatially-weighted manipulation to the visual tokens.

    The first token of ``v`` is split off, used for the global scale
    measurements, and re-attached unchanged; only the remaining tokens are
    edited.

    Args:
        clip: object providing ``calculate_scale`` and ``encode_text``.
        v: visual tokens, indexed as ``v[:, 0]`` (first token) and
            ``v[:, 1:]`` (the rest).
        t: target text embedding.
        target_scale: desired scale of the target.
        a: anchor prompt (raw text); ``None`` or ``"none"`` disables it.
        c: control passed to ``clip.calculate_scale`` to build the weight
            map; the literal ``"everything"`` applies ``dscale`` uniformly.
        enhance: if true (and an anchor is given), ``dscale`` is measured
            against the anchor scale and the anchor map is added to the
            weights.
        thresholds: four breakpoints forwarded to ``compute_pwv``; ``None``
            or an empty sequence selects the default breakpoints.

    Returns:
        Tokens with the same layout as ``v``: first token followed by the
        edited remaining tokens.
    """
    if thresholds is None or len(thresholds) == 0:
        thresholds = _DEFAULT_THRESHOLDS

    cls_token = v[:, 0].unsqueeze(1)
    v = v[:, 1:]
    cur_target_scale = clip.calculate_scale(cls_token, t)

    if a is not None and a != "none":
        # The anchor arrives as raw text: repeat it per batch item and encode.
        a = clip.encode_text([a] * v.shape[0])
        anchor_scale = clip.calculate_scale(cls_token, a)
        if enhance:
            dscale = target_scale - anchor_scale
        else:
            dscale = target_scale - cur_target_scale

        c_map = clip.calculate_scale(v, c)
        a_map = clip.calculate_scale(v, a)
        if c != "everything":
            pwm = compute_pwv(c_map, dscale, thresholds=thresholds)
        else:
            pwm = dscale
        base = 1 if enhance else 0
        v = v + (pwm + base * a_map) * (t - a)
    else:
        dscale = target_scale - cur_target_scale
        c_map = clip.calculate_scale(v, c)
        if c != "everything":
            pwm = compute_pwv(c_map, dscale, thresholds=thresholds)
        else:
            pwm = dscale
        v = v + pwm * t

    return torch.cat([cls_token, v], dim=1)


def local_manipulate(clip, v, targets, target_scales, anchors, controls,
                     enhances=(), thresholds_list=()):
    """Apply a sequence of local manipulations to the visual tokens.

    v: visual tokens in shape (b, n, c)
    targets / controls: text prompts; they are encoded together with
        ``clip.encode_text`` and split back into two halves, so both lists
        are expected to have the same length.

    NOTE(review): the per-step inputs are combined with ``zip``, which stops
    at the shortest one, so leaving ``enhances`` or ``thresholds_list`` at
    their empty defaults applies no manipulation at all -- confirm this is
    intended by callers.

    Returns:
        The manipulated tokens (``v`` unchanged when ``targets`` is empty).
    """
    # Same guard as ``get_heatmaps``: skip the encoder on an empty batch.
    if len(targets) == 0:
        return v

    controls, targets = clip.encode_text(controls + targets).chunk(2)
    steps = zip(targets, anchors, controls, target_scales, enhances,
                thresholds_list)
    for t, a, c, s_t, enhance, thresholds in steps:
        v = local_manipulate_step(clip, v, t, s_t, a, c, enhance, thresholds)
    return v


def global_manipulate_step(clip, v, t, target_scale, a=None, enhance=False):
    """Apply one global (non-spatial) manipulation to the tokens.

    Args:
        clip: object providing ``calculate_scale`` and ``encode_text``.
        v: visual tokens.
        t: target text embedding.
        target_scale: desired scale of the target.
        a: anchor prompt (raw text); ``None`` or ``"none"`` disables it.
        enhance: selects the additive "enhance" behaviour described below.

    Behaviour:
        * anchor, no enhance: ``v + target_scale * (t - a)``.
        * anchor, enhance: the anchor component (``scale(v, a) * a``) is
          removed first, then ``target_scale * t`` is added.
        * no anchor, enhance: ``v + target_scale * t``.
        * no anchor, no enhance: ``t`` is added so that the current target
          scale is moved to ``target_scale``.

    Returns:
        The manipulated tokens.
    """
    if a is not None and a != "none":
        a = clip.encode_text([a] * v.shape[0])
        if not enhance:
            return v + target_scale * (t - a)
        # Enhance: strip the anchor direction, then fall through to add the
        # target below.
        v = v - clip.calculate_scale(v, a) * a

    if enhance:
        return v + target_scale * t

    cur_target_scale = clip.calculate_scale(v, t)
    return v + (target_scale - cur_target_scale) * t


def global_manipulate(clip, v, targets, target_scales, anchors, enhances):
    """Apply a sequence of global manipulations to the tokens.

    ``targets`` are encoded with ``clip.encode_text`` and then iterated in
    lockstep with ``anchors``, ``target_scales`` and ``enhances`` (``zip``
    stops at the shortest of them).
    """
    targets = clip.encode_text(targets)
    for t, a, s_t, enhance in zip(targets, anchors, target_scales, enhances):
        v = global_manipulate_step(clip, v, t, s_t, a, enhance)
    return v


def assign_heatmap(s: torch.Tensor, threshold: float):
    """Binarise a score map against a threshold on its min-max normalisation.

    ``s`` is normalised along dim 1; entries whose normalised value is below
    ``threshold`` become 0 and all others become 0.25.

    The original docstring gives the input shape as (b, n, 1); the code only
    requires that dim 1 exists.

    NOTE: for a constant map the range is zero, the normalised values are
    NaN, the ``<`` comparison is false, and every entry becomes 0.25.
    """
    maxm = s.max(dim=1, keepdim=True).values
    minm = s.min(dim=1, keepdim=True).values
    normalised = (s - minm) / (maxm - minm)
    return torch.where(normalised < threshold,
                       torch.zeros_like(s),
                       torch.ones_like(s) * 0.25)


def get_heatmaps(model, reference, height, width, vis_c, ts0, ts1, ts2, ts3,
                 controls, targets, anchors, thresholds_list, target_scales,
                 enhances):
    """Render four thresholded heatmaps of a text prompt over an image.

    The reference is encoded, optionally edited with the given local
    manipulations, and the per-token scale of ``vis_c`` is upsampled to
    ``(height, width)`` and thresholded at ``ts0`` .. ``ts3``.

    Returns:
        A list of four ``uint8`` NumPy arrays of shape ``(height, width, 1)``
        (the result of scaling the 0 / 0.25 maps by 255 and casting).

    NOTE(review): the reshapes below use a leading dimension of 1, so this
    appears to expect a single reference image -- confirm against callers.
    """
    model.low_vram_shift("cond")
    clip = model.cond_stage_model

    v = clip.encode(reference, "full")
    if len(targets) > 0:
        controls, targets = clip.encode_text(controls + targets).chunk(2)
        inputs_iter = zip(controls, targets, anchors, target_scales,
                          thresholds_list, enhances)
        for c, t, a, target_scale, thresholds, enhance in inputs_iter:
            # update image tokens
            v = local_manipulate_step(clip, v, t, target_scale, a, c,
                                      enhance, thresholds)

    # All tokens except the first are laid out on a square grid.
    token_length = v.shape[1] - 1
    grid_num = int(token_length ** 0.5)

    vis_c = clip.encode_text([vis_c])
    local_v = v[:, 1:]
    scale = clip.calculate_scale(local_v, vis_c)
    scale = scale.permute(0, 2, 1).view(1, 1, grid_num, grid_num)
    scale = F.interpolate(scale, size=(height, width), mode="bicubic")
    scale = scale.squeeze(0).view(1, height * width)

    # calculate heatmaps
    heatmaps = []
    for threshold in (ts0, ts1, ts2, ts3):
        heatmap = assign_heatmap(scale, threshold=threshold)
        heatmap = heatmap.view(1, height, width).permute(1, 2, 0).cpu().numpy()
        heatmaps.append((heatmap * 255.).astype(np.uint8))
    return heatmaps