# Hugging Face Spaces page header (scrape residue): "Running on Zero"; file size: 5,660 bytes; revision d066167.
import torch
import torch.nn.functional as F
import numpy as np
def compute_pwv(s: torch.Tensor, dscale: torch.Tensor, ratio=2, thresholds=(0.5, 0.55, 0.65, 0.95)):
    """
    Turn a scale map into a per-position adjustment with a piecewise-linear ramp.

    ``s`` is min-max normalised along dim 1, then each normalised value ``p``
    is mapped as follows (``t0..t3`` are the four thresholds):

    * ``p <= t0``       -> ``-dscale * ratio``
    * ``t0 < p <= t1``  -> ramps linearly from ``-dscale`` to ``0``
    * ``t1 < p <= t2``  -> ramps linearly from ``0`` to ``0.5 * dscale``
    * ``t2 < p <= t3``  -> ramps linearly from ``0.5 * dscale`` to ``dscale``
    * ``p > t3``        -> ``dscale``

    Args:
        s: scale tensor; must be 3-dimensional, documented as (b, n, 1).
        dscale: adjustment magnitude; must broadcast against ``s``.
        ratio: multiplier applied to ``-dscale`` in the lowest segment.
        thresholds: four breakpoints on the normalised [0, 1] axis.

    Returns:
        Tensor of adjustments, broadcast of ``s`` and ``dscale``.

    Raises:
        ValueError: if ``s`` is not 3-dimensional or ``thresholds`` does not
            hold exactly four values.
    """
    # The original `assert cond, len(thresholds) == 4` used the second
    # expression as the assertion message, so it was never actually checked.
    if len(s.shape) != 3:
        raise ValueError(f"expected a 3-dimensional scale tensor, got shape {tuple(s.shape)}")
    if len(thresholds) != 4:
        raise ValueError(f"expected exactly 4 thresholds, got {len(thresholds)}")
    t0, t1, t2, t3 = thresholds

    high = s.max(dim=1, keepdim=True).values
    low = s.min(dim=1, keepdim=True).values
    span = high - low
    # A constant map has zero range; dividing by it would make every output NaN.
    # With the range replaced by 1 the normalised value is 0 everywhere.
    span = torch.where(span == 0, torch.ones_like(span), span)
    position = (s - low) / span

    adjust_scale = torch.where(
        position <= t0,
        -dscale * ratio,
        -dscale + dscale * (position - t0) / (t1 - t0),
    )
    adjust_scale = torch.where(
        position > t1,
        0.5 * dscale * (position - t1) / (t2 - t1),
        adjust_scale,
    )
    adjust_scale = torch.where(
        position > t2,
        0.5 * dscale + 0.5 * dscale * (position - t2) / (t3 - t2),
        adjust_scale,
    )
    # Everything above the last threshold saturates at the full adjustment.
    adjust_scale = torch.where(position > t3, dscale, adjust_scale)
    return adjust_scale
def local_manipulate_step(clip, v, t, target_scale, a=None, c=None, enhance=False, thresholds=[]):
    """
    Apply one target edit to the non-first tokens of ``v`` and return the result.

    The token at index 0 along dim 1 is split off (the code names it the cls
    token), used only to measure the current scales, and re-attached unchanged.
    The remaining tokens are shifted along ``t`` (or ``t - a`` when an anchor is
    given) by a per-token weight from ``compute_pwv``, or by the plain scale
    difference when ``c`` compares equal to "everything".

    clip: object providing ``calculate_scale`` and ``encode_text``.
    v: token tensor indexed as (batch, token, ...).
    t: target embedding (already encoded by the caller).
    target_scale: desired scale for the target.
    a: anchor text, or None / "none" for no anchor.
    c: control passed to ``clip.calculate_scale`` to build the weight map.
    enhance: with an anchor, measure the scale difference against the anchor
        scale instead of the current target scale and add the anchor map to
        the weights.
    thresholds: forwarded to ``compute_pwv``.
    """
    cls_token = v[:, 0].unsqueeze(1)
    patches = v[:, 1:]
    current_scale = clip.calculate_scale(cls_token, t)

    anchored = a is not None and a != "none"
    if anchored:
        # The anchor arrives as raw text: repeat it per batch item and encode.
        anchor = clip.encode_text([a] * patches.shape[0])
        anchor_scale = clip.calculate_scale(cls_token, anchor)
        if not enhance:
            dscale = target_scale - current_scale
        else:
            dscale = target_scale - anchor_scale
        control_map = clip.calculate_scale(patches, c)
        anchor_map = clip.calculate_scale(patches, anchor)
        if c != "everything":
            weights = compute_pwv(control_map, dscale, thresholds=thresholds)
        else:
            weights = dscale
        # The anchor map only contributes when enhancing (factor 1, else 0).
        base = 1 if enhance else 0
        patches = patches + (weights + base * anchor_map) * (t - anchor)
    else:
        dscale = target_scale - current_scale
        control_map = clip.calculate_scale(patches, c)
        if c != "everything":
            weights = compute_pwv(control_map, dscale, thresholds=thresholds)
        else:
            weights = dscale
        patches = patches + weights * t

    return torch.cat([cls_token, patches], dim=1)
def local_manipulate(clip, v, targets, target_scales, anchors, controls, enhances=(), thresholds_list=()):
    """
    Apply a sequence of local edits to the visual tokens ``v``.

    v: visual tokens in shape (b, n, c)
    targets: target prompts; encoded here via ``clip.encode_text``
    target_scales: one desired scale per target
    anchors: one anchor per target, forwarded unencoded to the step function
    controls: control prompts; encoded here together with ``targets``
    enhances: one enhance flag per target
    thresholds_list: one threshold set per target

    ``controls`` and ``targets`` are concatenated, encoded in one call and
    split in half, so they must be sequences of the same kind and length
    (presumably lists of strings -- confirm against the caller).

    All per-target sequences are consumed with ``zip``, so iteration stops at
    the shortest one. With the default empty ``enhances`` or
    ``thresholds_list`` no step runs and ``v`` comes back unchanged.

    Returns the manipulated tokens.
    """
    # Nothing to apply: return early so the text encoder is never asked to
    # encode an empty batch (same guard as get_heatmaps).
    if len(targets) == 0:
        return v

    controls, targets = clip.encode_text(controls + targets).chunk(2)
    steps = zip(targets, anchors, controls, target_scales, enhances, thresholds_list)
    for t, a, c, s_t, enhance, thresholds in steps:
        v = local_manipulate_step(clip, v, t, s_t, a, c, enhance, thresholds)
    return v
def global_manipulate_step(clip, v, t, target_scale, a=None, enhance=False):
    """
    Apply one global edit to ``v`` along the target embedding ``t``.

    Behaviour by case:
    * anchor given, not enhance: return ``v + target_scale * (t - anchor)``.
    * anchor given, enhance: subtract the anchor (weighted by its measured
      scale) from ``v``, then add ``target_scale * t``.
    * no anchor, enhance: return ``v + target_scale * t``.
    * no anchor, not enhance: add ``t`` weighted by the difference between
      ``target_scale`` and the currently measured target scale.

    clip: object providing ``encode_text`` and ``calculate_scale``; only
        touched when an anchor is given or ``enhance`` is false.
    a: anchor text; None or "none" means no anchor.
    """
    anchored = a is not None and a != "none"
    if anchored:
        # The anchor arrives as raw text: repeat it per batch item and encode.
        anchor = clip.encode_text([a] * v.shape[0])
        if not enhance:
            return v + target_scale * (t - anchor)
        anchor_scale = clip.calculate_scale(v, anchor)
        v = v - anchor_scale * anchor

    if enhance:
        return v + target_scale * t

    current_scale = clip.calculate_scale(v, t)
    return v + (target_scale - current_scale) * t
def global_manipulate(clip, v, targets, target_scales, anchors, enhances):
    """
    Apply a sequence of global edits to ``v``.

    clip: object providing ``encode_text`` (and whatever the step function needs).
    v: tokens to manipulate.
    targets: target prompts; encoded here via ``clip.encode_text``.
    target_scales: one desired scale per target.
    anchors: one anchor per target, forwarded unencoded to the step function.
    enhances: one enhance flag per target.

    The per-target sequences are consumed with ``zip``, so iteration stops at
    the shortest one.

    Returns the manipulated tokens.
    """
    # Nothing to apply: skip the text encoder instead of encoding an empty batch.
    if len(targets) == 0:
        return v

    encoded_targets = clip.encode_text(targets)
    for t, a, s_t, enhance in zip(encoded_targets, anchors, target_scales, enhances):
        v = global_manipulate_step(clip, v, t, s_t, a, enhance)
    return v
def assign_heatmap(s: torch.Tensor, threshold: float):
    """
    Binarise a scale map: 0 where the min-max normalised value (along dim 1)
    is below ``threshold``, 0.25 everywhere else.

    The shape of input scales tensor should be (b, n, 1); any tensor with a
    dim 1 works, since the min and max are taken along that dim only.
    """
    upper = s.max(dim=1, keepdim=True).values
    lower = s.min(dim=1, keepdim=True).values
    value_range = upper - lower
    normalised = (s - lower) / value_range
    off = torch.zeros_like(s)
    on = torch.ones_like(s) * 0.25
    return torch.where(normalised < threshold, off, on)
def get_heatmaps(model, reference, height, width, vis_c, ts0, ts1, ts2, ts3,
                 controls, targets, anchors, thresholds_list, target_scales, enhances):
    """
    Render four thresholded heatmaps of the ``vis_c`` scale map.

    The reference is encoded with ``model.cond_stage_model``, any requested
    local edits are applied to the resulting tokens, and the scale of the
    non-first tokens against the encoded ``vis_c`` is resized to
    ``height`` x ``width`` with bicubic interpolation. One heatmap is then
    produced per threshold ``ts0..ts3``.

    The reshape to (1, 1, grid, grid) means this expects a batch of one and a
    square number of non-first tokens.

    Returns a list of four uint8 numpy arrays of shape (height, width, 1).
    """
    model.low_vram_shift("cond")
    clip = model.cond_stage_model
    v = clip.encode(reference, "full")

    if len(targets) > 0:
        controls, targets = clip.encode_text(controls + targets).chunk(2)
        edits = zip(controls, targets, anchors, target_scales, thresholds_list, enhances)
        for c, t, a, target_scale, thresholds, enhance in edits:
            # update image tokens
            v = local_manipulate_step(clip, v, t, target_scale, a, c, enhance, thresholds)

    # Token 0 is dropped; the rest are laid out on a square grid.
    grid_num = int((v.shape[1] - 1) ** 0.5)
    vis_c = clip.encode_text([vis_c])
    scale = clip.calculate_scale(v[:, 1:], vis_c)
    scale = scale.permute(0, 2, 1).view(1, 1, grid_num, grid_num)
    scale = F.interpolate(scale, size=(height, width), mode="bicubic")
    scale = scale.squeeze(0).view(1, height * width)

    # calculate heatmaps
    heatmaps = []
    for threshold in (ts0, ts1, ts2, ts3):
        mask = assign_heatmap(scale, threshold=threshold)
        image = mask.view(1, height, width).permute(1, 2, 0).cpu().numpy()
        heatmaps.append((image * 255.).astype(np.uint8))
    return heatmaps