| | import torch |
| | import math |
| | import comfy.utils |
| | import logging |
| |
|
| |
|
class CONDRegular:
    """Wrapper around a single conditioning tensor stored on ``self.cond``.

    Subclasses override the batching (``process_cond``), compatibility
    (``can_concat``), merging (``concat``) and ``size`` hooks to handle
    differently shaped payloads.
    """

    def __init__(self, cond):
        self.cond = cond

    def _copy_with(self, cond):
        """Build a new instance of the same concrete class around ``cond``."""
        cls = type(self)
        return cls(cond)

    def process_cond(self, batch_size, **kwargs):
        """Return a copy whose tensor is passed through
        ``comfy.utils.repeat_to_batch_size(..., batch_size)``.

        Extra keyword arguments are accepted and ignored so every subclass can
        be called with the same keyword set.
        """
        batched = comfy.utils.repeat_to_batch_size(self.cond, batch_size)
        return self._copy_with(batched)

    def can_concat(self, other):
        """Return True when ``other`` has an identical shape and the same device.

        A device mismatch is logged as a warning and reported as False.
        """
        if not (self.cond.shape == other.cond.shape):
            return False
        if self.cond.device == other.cond.device:
            return True
        logging.warning("WARNING: conds not on same device, skipping concat.")
        return False

    def concat(self, others):
        """Join this tensor and each ``others[i].cond`` along dim 0 (``torch.cat`` default)."""
        return torch.cat([self.cond, *(other.cond for other in others)])

    def size(self):
        """Return the wrapped tensor's shape as a plain Python list."""
        return list(self.cond.size())
| |
|
| |
|
class CONDNoiseShape(CONDRegular):
    """Conditioning tensor that can be cropped to an ``area`` before batching."""

    def process_cond(self, batch_size, area, **kwargs):
        """Optionally crop to ``area``, then repeat to ``batch_size``.

        ``area`` is read as two halves: entry ``k`` of the first half is the
        length kept on tensor dim ``k + 2``, and entry ``k`` of the second half
        is the start offset on that same dim. ``None`` skips cropping.
        """
        cropped = self.cond
        if area is not None:
            half = len(area) // 2
            for axis in range(half):
                length = area[axis]
                start = area[half + axis]
                # Spatial dims begin at tensor dim 2 (dims 0/1 are left untouched).
                cropped = cropped.narrow(axis + 2, start, length)

        batched = comfy.utils.repeat_to_batch_size(cropped, batch_size)
        return self._copy_with(batched)
| |
|
| |
|
class CONDCrossAttn(CONDRegular):
    """Conditioning whose dim-1 length may differ between entries.

    Entries with different dim-1 lengths are made compatible by tiling each
    one along dim 1 up to the least common multiple of all lengths.
    """

    def can_concat(self, other):
        """Return True when the two tensors can be tiled to a common shape.

        Dims 0 and 2 must match. When dim 1 differs, the shorter tensor would
        need to be tiled by ``lcm(len_a, len_b) // min(len_a, len_b)``; factors
        above 4 are rejected. A device mismatch is logged and rejected.
        """
        shape_a = self.cond.shape
        shape_b = other.cond.shape
        if shape_a != shape_b:
            if shape_a[0] != shape_b[0] or shape_a[2] != shape_b[2]:
                return False

            shorter = min(shape_a[1], shape_b[1])
            repeat_factor = math.lcm(shape_a[1], shape_b[1]) // shorter
            # Cap the tiling so heavily mismatched lengths are not padded out.
            if repeat_factor > 4:
                return False

        if self.cond.device != other.cond.device:
            logging.warning("WARNING: conds not on same device: skipping concat.")
            return False
        return True

    def concat(self, others):
        """Tile every tensor along dim 1 to the common lcm length, then join on dim 0."""
        tensors = [self.cond, *(other.cond for other in others)]
        target_len = math.lcm(*[t.shape[1] for t in tensors])

        tiled = [
            t.repeat(1, target_len // t.shape[1], 1) if t.shape[1] < target_len else t
            for t in tensors
        ]
        return torch.cat(tiled)
| |
|
| |
|
class CONDConstant(CONDRegular):
    """Conditioning value that is never repeated to the batch size.

    ``size`` always reports ``[1]``, and ``concat`` returns the value itself.
    """

    def __init__(self, cond):
        self.cond = cond

    def process_cond(self, batch_size, **kwargs):
        """Return an unchanged copy; ``batch_size`` is intentionally ignored."""
        unchanged = self.cond
        return self._copy_with(unchanged)

    def can_concat(self, other):
        """Values are compatible only when ``self.cond != other.cond`` is falsy."""
        return not (self.cond != other.cond)

    def concat(self, others):
        """Return this value; ``can_concat`` only admits values equal to it."""
        return self.cond

    def size(self):
        """Report a fixed nominal size of ``[1]``."""
        return [1]
| |
|
| |
|
class CONDList(CONDRegular):
    """Conditioning made of a list of tensors, each batched and concatenated
    independently (element ``i`` is only ever combined with element ``i``)."""

    def __init__(self, cond):
        self.cond = cond

    def process_cond(self, batch_size, **kwargs):
        """Return a copy with every list element repeated to ``batch_size``."""
        return self._copy_with(
            [comfy.utils.repeat_to_batch_size(c, batch_size) for c in self.cond]
        )

    def can_concat(self, other):
        """Return True when both lists have the same length, their elements
        match shape pairwise, and paired elements share a device.

        A device mismatch is logged and rejected, matching ``CONDRegular``
        instead of letting ``torch.cat`` fail later in ``concat``.
        """
        if len(self.cond) != len(other.cond):
            return False
        pairs = list(zip(self.cond, other.cond))
        if any(mine.shape != theirs.shape for mine, theirs in pairs):
            return False
        if any(mine.device != theirs.device for mine, theirs in pairs):
            logging.warning("WARNING: conds not on same device, skipping concat.")
            return False
        return True

    def concat(self, others):
        """For each index ``i``, join element ``i`` of this list with element
        ``i`` of every entry in ``others`` along dim 0."""
        return [
            torch.cat([c, *(other.cond[i] for other in others)])
            for i, c in enumerate(self.cond)
        ]

    def size(self):
        """Return a pseudo-shape ``[1, channels, total // channels]``.

        ``total`` is the summed element count of all tensors. ``channels`` is
        dim 1 of the last tensor with at least two dims and a non-zero dim 1,
        or 1 if there is none. This is presumably consumed by size/memory
        estimation code — TODO confirm against callers.
        """
        total = 0
        channels = 1
        # Use a distinct loop name: reusing ``channels`` as the loop variable
        # would clobber the default and leak a tensor into the result when the
        # last element is 1-D.
        for tensor in self.cond:
            shape = tensor.size()
            total += math.prod(shape)
            if len(shape) > 1 and shape[1] > 0:  # skip 0 to avoid // 0 below
                channels = shape[1]

        return [1, channels, total // channels]
| |
|