| | import gc |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | from flow.flow_utils import flow_warp |
| |
|
| | |
| |
|
| |
|
| | def calc_mean_std(feat, eps=1e-5): |
| | |
| | size = feat.size() |
| | assert (len(size) == 4) |
| | N, C = size[:2] |
| | feat_var = feat.view(N, C, -1).var(dim=2) + eps |
| | feat_std = feat_var.sqrt().view(N, C, 1, 1) |
| | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) |
| | return feat_mean, feat_std |
| |
|
| |
|
| | class AttentionControl(): |
| |
|
| | def __init__(self, inner_strength, mask_period, cross_period, ada_period, |
| | warp_period): |
| | self.step_store = self.get_empty_store() |
| | self.cur_step = 0 |
| | self.total_step = 0 |
| | self.cur_index = 0 |
| | self.init_store = False |
| | self.restore = False |
| | self.update = False |
| | self.flow = None |
| | self.mask = None |
| | self.restorex0 = False |
| | self.updatex0 = False |
| | self.inner_strength = inner_strength |
| | self.cross_period = cross_period |
| | self.mask_period = mask_period |
| | self.ada_period = ada_period |
| | self.warp_period = warp_period |
| |
|
| | @staticmethod |
| | def get_empty_store(): |
| | return { |
| | 'first': [], |
| | 'previous': [], |
| | 'x0_previous': [], |
| | 'first_ada': [] |
| | } |
| |
|
| | def forward(self, context, is_cross: bool, place_in_unet: str): |
| | cross_period = (self.total_step * self.cross_period[0], |
| | self.total_step * self.cross_period[1]) |
| | if not is_cross and place_in_unet == 'up': |
| | if self.init_store: |
| | self.step_store['first'].append(context.detach()) |
| | self.step_store['previous'].append(context.detach()) |
| | if self.update: |
| | tmp = context.clone().detach() |
| | if self.restore and self.cur_step >= cross_period[0] and \ |
| | self.cur_step <= cross_period[1]: |
| | context = torch.cat( |
| | (self.step_store['first'][self.cur_index], |
| | self.step_store['previous'][self.cur_index]), |
| | dim=1).clone() |
| | if self.update: |
| | self.step_store['previous'][self.cur_index] = tmp |
| | self.cur_index += 1 |
| | return context |
| |
|
| | def update_x0(self, x0): |
| | if self.init_store: |
| | self.step_store['x0_previous'].append(x0.detach()) |
| | style_mean, style_std = calc_mean_std(x0.detach()) |
| | self.step_store['first_ada'].append(style_mean.detach()) |
| | self.step_store['first_ada'].append(style_std.detach()) |
| | if self.updatex0: |
| | tmp = x0.clone().detach() |
| | if self.restorex0: |
| | if self.cur_step >= self.total_step * self.ada_period[ |
| | 0] and self.cur_step <= self.total_step * self.ada_period[ |
| | 1]: |
| | x0 = F.instance_norm(x0) * self.step_store['first_ada'][ |
| | 2 * self.cur_step + |
| | 1] + self.step_store['first_ada'][2 * self.cur_step] |
| | if self.cur_step >= self.total_step * self.warp_period[ |
| | 0] and self.cur_step <= self.total_step * self.warp_period[ |
| | 1]: |
| | pre = self.step_store['x0_previous'][self.cur_step] |
| | x0 = flow_warp(pre, self.flow, mode='nearest') * self.mask + ( |
| | 1 - self.mask) * x0 |
| | if self.updatex0: |
| | self.step_store['x0_previous'][self.cur_step] = tmp |
| | return x0 |
| |
|
| | def set_warp(self, flow, mask): |
| | self.flow = flow.clone() |
| | self.mask = mask.clone() |
| |
|
| | def __call__(self, context, is_cross: bool, place_in_unet: str): |
| | context = self.forward(context, is_cross, place_in_unet) |
| | return context |
| |
|
| | def set_step(self, step): |
| | self.cur_step = step |
| |
|
| | def set_total_step(self, total_step): |
| | self.total_step = total_step |
| | self.cur_index = 0 |
| |
|
| | def clear_store(self): |
| | del self.step_store |
| | torch.cuda.empty_cache() |
| | gc.collect() |
| | self.step_store = self.get_empty_store() |
| |
|
| | def set_task(self, task, restore_step=1.0): |
| | self.init_store = False |
| | self.restore = False |
| | self.update = False |
| | self.cur_index = 0 |
| | self.restore_step = restore_step |
| | self.updatex0 = False |
| | self.restorex0 = False |
| | if 'initfirst' in task: |
| | self.init_store = True |
| | self.clear_store() |
| | if 'updatestyle' in task: |
| | self.update = True |
| | if 'keepstyle' in task: |
| | self.restore = True |
| | if 'updatex0' in task: |
| | self.updatex0 = True |
| | if 'keepx0' in task: |
| | self.restorex0 = True |
| |
|