Spaces:
Runtime error
Runtime error
| import torch | |
| from backend import memory_management, attention | |
| from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler | |
class KModel(torch.nn.Module):
    """Adapter that drives a diffusion network through a k-prediction helper.

    The wrapped ``model`` is called with inputs prepared by ``self.predictor``
    (``calculate_input`` / ``timestep``) and its output is converted back with
    ``self.predictor.calculate_denoised``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, model, diffusers_scheduler, k_predictor=None, config=None):
        """Store the wrapped model and pick the predictor.

        Args:
            model: The diffusion network. Must expose ``storage_dtype`` and
                ``computation_dtype`` attributes and be callable as
                ``model(x, t, context=..., control=..., transformer_options=..., **extra)``.
            diffusers_scheduler: Passed to
                ``k_prediction_from_diffusers_scheduler`` to build the predictor;
                only used when ``k_predictor`` is ``None``.
            k_predictor: Optional ready-made predictor. When given it is used
                as-is and ``diffusers_scheduler`` is ignored.
            config: Opaque configuration object, stored unchanged.
        """
        super().__init__()

        self.config = config

        # Dtypes are mirrored from the wrapped model so callers can query them
        # on the wrapper directly.
        self.storage_dtype = model.storage_dtype
        self.computation_dtype = model.computation_dtype

        print(f'K-Model Created: {dict(storage_dtype=self.storage_dtype, computation_dtype=self.computation_dtype)}')

        self.diffusion_model = model

        if k_predictor is None:
            self.predictor = k_prediction_from_diffusers_scheduler(diffusers_scheduler)
        else:
            self.predictor = k_predictor

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options=None, **kwargs):
        """Run the wrapped model once and return the predictor's denoised result.

        Args:
            x: Input tensor handed to ``predictor.calculate_input`` and
                ``predictor.calculate_denoised``.
            t: Sigma value(s); also converted by ``predictor.timestep`` into the
                timestep given to the model.
            c_concat: Optional tensor concatenated to the model input along
                ``dim=1``.
            c_crossattn: Optional context tensor; cast to the computation dtype
                when present and forwarded as ``context``. ``None`` is
                forwarded unchanged.
            control: Forwarded to the model untouched.
            transformer_options: Optional dict forwarded to the model. A fresh
                empty dict is created per call when omitted, so state written
                by the model never leaks between calls.
            **kwargs: Extra conditioning forwarded to the model. Values that
                have a ``dtype`` other than ``torch.int`` / ``torch.long`` are
                cast to the computation dtype; everything else passes through.

        Returns:
            ``predictor.calculate_denoised(sigma, model_output, x)`` where
            ``model_output`` is the model's result converted to float32.
        """
        # A mutable default ({}) would be shared by every call of this method.
        if transformer_options is None:
            transformer_options = {}

        sigma = t
        dtype = self.computation_dtype

        xc = self.predictor.calculate_input(sigma, x)
        if c_concat is not None:
            xc = torch.cat([xc, c_concat], dim=1)
        xc = xc.to(dtype)

        # The timestep is always passed in float32, independent of `dtype`.
        t = self.predictor.timestep(t).float()

        # c_crossattn defaults to None, so only cast when a tensor was supplied.
        context = c_crossattn
        if context is not None:
            context = context.to(dtype)

        # Integer tensors are left alone; casting them to a float dtype would
        # change their meaning (presumably indices/ids - confirm with callers).
        integer_dtypes = (torch.int, torch.long)
        extra_conds = {}
        for name, value in kwargs.items():
            if hasattr(value, "dtype") and value.dtype not in integer_dtypes:
                value = value.to(dtype)
            extra_conds[name] = value

        model_output = self.diffusion_model(
            xc,
            t,
            context=context,
            control=control,
            transformer_options=transformer_options,
            **extra_conds,
        ).float()

        return self.predictor.calculate_denoised(sigma, model_output, x)

    def memory_required(self, input_shape):
        """Estimate the memory needed to run the model on ``input_shape``.

        Args:
            input_shape: Indexable shape; elements 0, 2 and 3 are multiplied
                together (presumably batch, height, width of a 4-D input -
                confirm with callers).

        Returns:
            ``scaler * area * dtype_size * 16384`` as a number. The unit is
            whatever ``memory_management.dtype_size`` reports (presumably
            bytes - confirm).
        """
        area = input_shape[0] * input_shape[2] * input_shape[3]
        dtype_size = memory_management.dtype_size(self.computation_dtype)

        # NOTE(review): 1.28 / 1.65 / 16384 look like empirically tuned
        # constants; their derivation is not visible in this file.
        if attention.attention_function in [attention.attention_pytorch, attention.attention_xformers]:
            scaler = 1.28
        else:
            scaler = 1.65
            # In this branch, float32 attention precision overrides the
            # element size taken from the computation dtype.
            if attention.get_attn_precision() == torch.float32:
                dtype_size = 4

        return scaler * area * dtype_size * 16384