Spaces:
Runtime error
Runtime error
| import torch | |
| import math | |
def repeat_to_batch_size(tensor, batch_size):
    """Return ``tensor`` with its first dimension made equal to ``batch_size``.

    * Longer than ``batch_size``: the leading ``batch_size`` entries are kept.
    * Shorter than ``batch_size``: the tensor is tiled along dim 0 enough
      times to cover ``batch_size`` and the surplus is sliced off.
    * Already the right length: the very same object is returned.
    """
    current = tensor.shape[0]
    if current > batch_size:
        return tensor[:batch_size]
    if current < batch_size:
        # Tile only along dim 0; every other dimension is repeated once.
        copies = math.ceil(batch_size / current)
        repeat_spec = [copies] + [1] * (len(tensor.shape) - 1)
        return tensor.repeat(repeat_spec)[:batch_size]
    return tensor
def lcm(a, b):
    """Return the least common multiple of the integers ``a`` and ``b``.

    The result is never negative. When either argument is zero the result
    is 0 (the same convention as ``math.lcm``); without this guard
    ``lcm(0, 0)`` would divide by ``math.gcd(0, 0) == 0`` and raise
    ``ZeroDivisionError``.
    """
    if a == 0 or b == 0:
        return 0
    return abs(a * b) // math.gcd(a, b)
class Condition:
    """Holds one conditioning value in ``self.cond`` and knows how to batch it.

    Subclasses override the individual steps (``process_cond``,
    ``can_concat``, ``concat``) to change how the value is prepared or merged.
    """

    def __init__(self, cond):
        self.cond = cond

    def _copy_with(self, cond):
        """Create a new object of this instance's own class wrapping ``cond``."""
        cls = self.__class__
        return cls(cond)

    def process_cond(self, batch_size, device, **kwargs):
        """Return a copy whose value is resized to ``batch_size`` and moved to ``device``.

        Extra keyword arguments are accepted and ignored.
        """
        batched = repeat_to_batch_size(self.cond, batch_size)
        return self._copy_with(batched.to(device))

    def can_concat(self, other):
        """Tell whether ``other`` holds a value of exactly the same shape."""
        return self.cond.shape == other.cond.shape

    def concat(self, others):
        """Concatenate this value with those of ``others`` using ``torch.cat``.

        Returns the raw concatenated tensor, not a ``Condition``.
        """
        return torch.cat([self.cond, *(item.cond for item in others)])
class ConditionNoiseShape(Condition):
    """Condition whose value is cropped to a window before being batched."""

    def process_cond(self, batch_size, device, area, **kwargs):
        """Crop dims 2 and 3 to ``area``, then resize to ``batch_size`` on ``device``.

        ``area`` is indexed as ``(size_dim2, size_dim3, offset_dim2,
        offset_dim3)``: dim 2 is sliced ``area[2]:area[0] + area[2]`` and
        dim 3 is sliced ``area[3]:area[1] + area[3]``.
        Extra keyword arguments are accepted and ignored.
        """
        span_dim2 = slice(area[2], area[0] + area[2])
        span_dim3 = slice(area[3], area[1] + area[3])
        cropped = self.cond[:, :, span_dim2, span_dim3]
        batched = repeat_to_batch_size(cropped, batch_size)
        return self._copy_with(batched.to(device))
class ConditionCrossAttn(Condition):
    """Condition whose values may differ in length along dim 1 when merged."""

    def can_concat(self, other):
        """Tell whether ``other`` can be merged with this condition.

        Identical shapes always qualify. Otherwise dims 0 and 2 must match,
        and padding both values to the least common multiple of their dim-1
        lengths must not require more than 4 repeats of the shorter one.
        """
        mine = self.cond.shape
        theirs = other.cond.shape
        if mine == theirs:
            return True
        if mine[0] != theirs[0] or mine[2] != theirs[2]:
            return False
        common_len = lcm(mine[1], theirs[1])
        repeats_needed = common_len // min(mine[1], theirs[1])
        # Cap on how far the shorter value may be tiled.
        return repeats_needed <= 4

    def concat(self, others):
        """Concatenate this value with those of ``others`` using ``torch.cat``.

        Every tensor shorter along dim 1 than the least common multiple of
        all dim-1 lengths is first tiled along dim 1 up to that length.
        Returns the raw concatenated tensor.
        """
        tensors = [self.cond]
        target_len = self.cond.shape[1]
        for item in others:
            tensor = item.cond
            target_len = lcm(target_len, tensor.shape[1])
            tensors.append(tensor)

        padded = [
            t.repeat(1, target_len // t.shape[1], 1) if t.shape[1] < target_len else t
            for t in tensors
        ]
        return torch.cat(padded)
class ConditionConstant(Condition):
    """Condition whose value is passed through untouched (no batching, no device move)."""

    def __init__(self, cond):
        self.cond = cond

    def process_cond(self, batch_size, device, **kwargs):
        """Return a copy wrapping the same value; all arguments are ignored."""
        unchanged = self.cond
        return self._copy_with(unchanged)

    def can_concat(self, other):
        """Tell whether ``other`` holds a value that does not compare unequal to ours."""
        differs = self.cond != other.cond
        return not differs

    def concat(self, others):
        """Return this instance's value; ``others`` is not consulted."""
        return self.cond
def compile_conditions(cond):
    """Wrap a raw conditioning value into a one-element list of condition dicts.

    * ``None`` is returned as ``None``.
    * A ``torch.Tensor`` yields a dict with ``cross_attn`` and a
      ``model_conds`` mapping holding a single ``c_crossattn`` entry.
    * Anything else is indexed by key: ``'crossattn'`` and ``'vector'`` are
      required, ``'guidance'`` is optional and adds a ``guidance`` entry to
      ``model_conds``.
    """
    if cond is None:
        return None

    if isinstance(cond, torch.Tensor):
        return [{
            'cross_attn': cond,
            'model_conds': {
                'c_crossattn': ConditionCrossAttn(cond),
            },
        }]

    cross_attn = cond['crossattn']
    pooled_output = cond['vector']
    compiled = {
        'cross_attn': cross_attn,
        'pooled_output': pooled_output,
        'model_conds': {
            'c_crossattn': ConditionCrossAttn(cross_attn),
            'y': Condition(pooled_output),
        },
    }

    if 'guidance' in cond:
        compiled['model_conds']['guidance'] = Condition(cond['guidance'])

    return [compiled]
def compile_weighted_conditions(cond, weights):
    """Compile one condition entry per column of ``weights``.

    ``weights`` is a sequence of rows of ``(index, weight)`` pairs. The rows
    are transposed; for each resulting column the indices select a subset of
    ``cond`` (through ``cond.advanced_indexing`` when that attribute exists,
    plain ``cond[indices]`` otherwise). The subset is compiled with
    ``compile_conditions`` and its first entry gets a ``'strength'`` equal
    to the weight of the last pair in the column.

    Returns the flat list of all compiled entries.
    """
    compiled = []
    for column in zip(*weights):
        pairs = [(index, weight) for index, weight in column]
        indices = [index for index, _ in pairs]
        # Only the last weight in the column is kept; 0 if the column is empty.
        strength = pairs[-1][1] if pairs else 0

        if hasattr(cond, 'advanced_indexing'):
            subset = cond.advanced_indexing(indices)
        else:
            subset = cond[indices]

        entries = compile_conditions(subset)
        entries[0]['strength'] = strength
        compiled.extend(entries)
    return compiled