import torch

import comfy.conds
import comfy.model_management
import comfy.utils
|
|
|
def prepare_mask(noise_mask, shape, device):
    """Resize and tile a noise mask so it lines up with a tensor of ``shape``.

    Args:
        noise_mask: mask tensor whose last two dims are spatial (H, W); any
            leading dims are flattened into a single batch dim.
        shape: target shape, indexed as ``shape[0]`` (batch size),
            ``shape[1]`` (channel count) and ``shape[2]``/``shape[3]``
            (spatial size) -- presumably the latent's (B, C, H, W).
        device: device the resulting mask is moved to.

    Returns:
        The mask bilinearly resized to ``(shape[2], shape[3])``, replicated
        ``shape[1]`` times along dim 1, with its batch dim adjusted to
        ``shape[0]`` by ``comfy.utils.repeat_to_batch_size``, on ``device``.
    """
    height, width = noise_mask.shape[-2], noise_mask.shape[-1]
    # Collapse any leading dims into the batch and add a singleton channel
    # dim so interpolate() sees a 4-D (N, 1, H, W) input.
    noise_mask = torch.nn.functional.interpolate(
        noise_mask.reshape((-1, 1, height, width)),
        size=(shape[2], shape[3]),
        mode="bilinear",
    )
    # Replicate the single mask channel across every target channel.
    noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
    noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
    return noise_mask.to(device)
|
|
|
def get_models_from_cond(cond, model_type):
    """Collect ``entry[model_type]`` from every entry of ``cond`` that has that key.

    Entries lacking ``model_type`` are skipped; order and duplicates are kept.
    """
    return [entry[model_type] for entry in cond if model_type in entry]
|
|
|
def convert_cond(cond):
    """Convert ``[tensor, options]`` conditioning pairs into standalone option dicts.

    Each input item ``c`` is indexed as ``c[0]`` (cross-attention conditioning,
    may be None) and ``c[1]`` (a dict of extra options). The returned dicts are
    copies: neither ``c[1]`` nor its nested ``"model_conds"`` dict is modified.

    Returns:
        A list with one dict per input item, always containing a
        ``"model_conds"`` dict. When ``c[0]`` is not None, ``"cross_attn"`` is
        set to ``c[0]`` and ``model_conds["c_crossattn"]`` wraps it in
        ``comfy.conds.CONDCrossAttn``.
    """
    out = []
    for c in cond:
        temp = c[1].copy()
        # The copy above is shallow, so the nested "model_conds" dict would
        # still be shared with the caller; copy it before adding keys so the
        # caller's conditioning is not mutated.
        model_conds = temp.get("model_conds", {}).copy()
        if c[0] is not None:
            model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0])
            temp["cross_attn"] = c[0]
        temp["model_conds"] = model_conds
        out.append(temp)
    return out
|
|
|
def get_additional_models(conds, dtype):
    """Gather the extra models referenced by the conditioning and their memory cost.

    Args:
        conds: mapping whose values are lists of cond dicts; each dict may carry
            a ``"control"`` and/or ``"gligen"`` entry.
        dtype: forwarded to each control object's
            ``inference_memory_requirements``.

    Returns:
        ``(models, inference_memory)``. ``models`` holds the models returned by
        ``get_models()`` of every distinct control object, followed by element
        ``[1]`` of each gligen entry. ``inference_memory`` is the sum of the
        control objects' ``inference_memory_requirements(dtype)`` values, in
        whatever unit those return (presumably bytes -- TODO confirm).
    """
    cnets = []
    gligen = []
    for k in conds:
        cnets += get_models_from_cond(conds[k], "control")
        gligen += get_models_from_cond(conds[k], "gligen")

    # dict.fromkeys drops duplicates like set() does, but keeps first-seen
    # order, so the model list handed to the loader is stable across runs.
    control_nets = dict.fromkeys(cnets)

    inference_memory = 0
    control_models = []
    for m in control_nets:
        control_models += m.get_models()
        inference_memory += m.inference_memory_requirements(dtype)

    # gligen entries are indexed with [1]; presumably (name, model) pairs -- TODO confirm
    gligen = [x[1] for x in gligen]
    models = control_models + gligen
    return models, inference_memory
|
|
|
def cleanup_additional_models(models):
    """Call ``cleanup()`` on each item of ``models`` that defines a ``cleanup`` attribute.

    Items without one are ignored. An item appearing several times is cleaned
    up once per occurrence.
    """
    for model in filter(lambda candidate: hasattr(candidate, 'cleanup'), models):
        model.cleanup()
|
|
|
|
|
def prepare_sampling(model, noise_shape, conds):
    """Load ``model`` plus the extra models required by ``conds`` via the model manager.

    Args:
        model: model wrapper exposing ``model_dtype()``, ``memory_required(shape)``
            and a ``model`` attribute.
        noise_shape: shape of the noise tensor; ``noise_shape[0]`` is treated as
            the batch size when estimating memory.
        conds: conditioning mapping passed to ``get_additional_models``.

    Returns:
        ``(model.model, conds, models)``, where ``conds`` is returned unchanged
        and ``models`` are the additional models that were loaded alongside
        ``model``.
    """
    models, inference_memory = get_additional_models(conds, model.model_dtype())
    trailing_dims = list(noise_shape[1:])
    # Request room for a doubled batch (presumably cond + uncond evaluated
    # together -- TODO confirm) but accept a single batch as the minimum.
    memory_required = model.memory_required([noise_shape[0] * 2] + trailing_dims) + inference_memory
    minimum_memory_required = model.memory_required([noise_shape[0]] + trailing_dims) + inference_memory
    comfy.model_management.load_models_gpu(
        [model] + models,
        memory_required=memory_required,
        minimum_memory_required=minimum_memory_required,
    )
    real_model = model.model
    return real_model, conds, models
|
|
|
def cleanup_models(conds, models):
    """Clean up ``models``, then each distinct ``"control"`` object found in ``conds``.

    ``conds`` is a mapping whose values are lists of cond dicts, as consumed by
    ``get_models_from_cond``.
    """
    cleanup_additional_models(models)

    control_cleanup = [
        control
        for key in conds
        for control in get_models_from_cond(conds[key], "control")
    ]
    # Deduplicate so a control object shared by several conds is cleaned once.
    cleanup_additional_models(set(control_cleanup))
|
|
|