| import torch |
| import math |
| import ldm_patched.modules.utils |
|
|
|
|
|
|
class CONDRegular:
    """Wrap a conditioning tensor that is batched by repetition and merged with ``torch.cat``.

    The wrapped value lives on ``self.cond``. Methods that produce a new
    conditioning object go through ``_copy_with`` so that subclasses get
    instances of their own concrete type back.
    """

    def __init__(self, cond):
        # Stored as-is: no copy, validation, or device move happens here.
        self.cond = cond

    def _copy_with(self, cond):
        """Return a fresh instance of this object's concrete class wrapping ``cond``."""
        cls = type(self)
        return cls(cond)

    def process_cond(self, batch_size, device, **kwargs):
        """Return a copy of this conditioning sized to ``batch_size`` and moved to ``device``.

        Batching is delegated to ``ldm_patched.modules.utils.repeat_to_batch_size``
        (presumably it repeats/trims along dim 0 -- confirm in that module).
        Extra keyword arguments are accepted and ignored.
        """
        batched = ldm_patched.modules.utils.repeat_to_batch_size(self.cond, batch_size)
        return self._copy_with(batched.to(device))

    def can_concat(self, other):
        """Return whether ``other`` can be merged with this one; only identical shapes qualify."""
        return self.cond.shape == other.cond.shape

    def concat(self, others):
        """Concatenate ``self.cond`` followed by each ``other.cond`` along dim 0.

        Returns the raw tensor, not a wrapped conditioning object.
        NOTE(review): callers are presumably expected to have checked
        ``can_concat`` first -- confirm at call sites.
        """
        return torch.cat([self.cond, *(other.cond for other in others)])
|
|
class CONDNoiseShape(CONDRegular):
    """Conditioning that is spatially cropped to the sampled ``area`` before batching."""

    def process_cond(self, batch_size, device, area, **kwargs):
        """Crop ``self.cond`` to ``area``, batch it, and move it to ``device``.

        ``area`` is indexed as four numbers:
          - extent along dim 2;
          - extent along dim 3;
          - offset along dim 2;
          - offset along dim 3.
        NOTE(review): presumably (height, width, y, x) for a (B, C, H, W)
        tensor -- confirm with callers.
        Extra keyword arguments are ignored.
        """
        extent_2, extent_3 = area[0], area[1]
        start_2, start_3 = area[2], area[3]
        stop_2 = extent_2 + start_2
        stop_3 = extent_3 + start_3
        window = self.cond[:, :, start_2:stop_2, start_3:stop_3]
        batched = ldm_patched.modules.utils.repeat_to_batch_size(window, batch_size)
        return self._copy_with(batched.to(device))
|
|
|
|
class CONDCrossAttn(CONDRegular):
    """Cross-attention conditioning whose dim-1 lengths may differ within a batch.

    Tensors with different sizes along dim 1 are made compatible by tiling
    each one along dim 1 up to the least common multiple of all lengths.
    NOTE(review): dim 1 is presumably the token/sequence axis of a
    (batch, tokens, channels) tensor -- confirm against the text encoders.
    """

    def can_concat(self, other):
        """Return whether ``other`` can be merged with this conditioning.

        Identical shapes always qualify. Otherwise dims 0 and 2 must match,
        and the tiling factor needed by the shorter dim-1 length
        (``lcm(len_a, len_b) // min(len_a, len_b)``) must not exceed 4.
        """
        shape_a = self.cond.shape
        shape_b = other.cond.shape
        if shape_a == shape_b:
            return True
        if shape_a[0] != shape_b[0] or shape_a[2] != shape_b[2]:
            return False
        len_a, len_b = shape_a[1], shape_b[1]
        # Cap on how many times the shorter sequence may be tiled to reach the
        # common length; larger factors are rejected outright.
        tile_factor = math.lcm(len_a, len_b) // min(len_a, len_b)
        return tile_factor <= 4

    def concat(self, others):
        """Tile every tensor along dim 1 to a common length, then concatenate along dim 0.

        The common length is the LCM of all dim-1 sizes. Any tensor shorter
        than that is repeated (not zero-padded) ``target // length`` times
        along dim 1. Returns the raw concatenated tensor.
        """
        tensors = [self.cond] + [other.cond for other in others]
        target_len = math.lcm(*(t.shape[1] for t in tensors))
        return torch.cat([
            t if t.shape[1] >= target_len else t.repeat(1, target_len // t.shape[1], 1)
            for t in tensors
        ])
|
|
class CONDConstant(CONDRegular):
    """Conditioning value that is passed through unchanged for the whole batch.

    Unlike the base class, it is never repeated to the batch size or moved
    to a device.
    """

    def __init__(self, cond):
        # Stored as-is; the only requirement visible here is support for `!=`
        # (used by can_concat).
        self.cond = cond

    def process_cond(self, batch_size, device, **kwargs):
        """Return a copy wrapping the same value; ``batch_size`` and ``device`` are ignored."""
        unchanged = self.cond
        return self._copy_with(unchanged)

    def can_concat(self, other):
        """Return whether the two constants do not compare unequal."""
        # Kept as a negated `!=` (rather than `==`) so the exact comparison the
        # value type implements is the one consulted.
        # NOTE(review): if cond is a multi-element tensor, bool() of this
        # comparison raises.
        differs = self.cond != other.cond
        return not differs

    def concat(self, others):
        """Return this object's value; ``others`` are not inspected.

        NOTE(review): presumably callers only reach this after ``can_concat``
        accepted every member, so all values are equal -- confirm at call sites.
        """
        return self.cond
|
|